Simon Peyton Jones pushed to branch wip/T26709 at Glasgow Haskell Compiler / GHC

Commits:

5 changed files:

Changes:

  • compiler/GHC/Core/Opt/Simplify/Utils.hs
    ... ... @@ -2693,7 +2693,7 @@ mkCase, mkCase1, mkCase2, mkCase3
    2693 2693
     
    
    2694 2694
     mkCase mode scrut outer_bndr alts_ty alts
    
    2695 2695
       | sm_case_merge mode
    
    2696
    -  , Just (joins, alts') <- mergeCaseAlts outer_bndr alts
    
    2696
    +  , Just (joins, alts') <- mergeCaseAlts scrut outer_bndr alts
    
    2697 2697
       = do  { tick (CaseMerge outer_bndr)
    
    2698 2698
             ; case_expr <- mkCase1 mode scrut outer_bndr alts_ty alts'
    
    2699 2699
             ; return (mkLets joins case_expr) }
    

  • compiler/GHC/Core/Utils.hs
    ... ... @@ -73,7 +73,7 @@ import GHC.Platform
    73 73
     
    
    74 74
     import GHC.Core
    
    75 75
     import GHC.Core.Ppr
    
    76
    -import GHC.Core.FVs( bindFreeVars )
    
    76
    +import GHC.Core.FVs( exprFreeVars, bindFreeVars )
    
    77 77
     import GHC.Core.DataCon
    
    78 78
     import GHC.Core.Type as Type
    
    79 79
     import GHC.Core.Predicate( isEqPred )
    
    ... ... @@ -113,11 +113,11 @@ import GHC.Utils.Outputable
    113 113
     import GHC.Utils.Panic
    
    114 114
     import GHC.Utils.Misc
    
    115 115
     
    
    116
    +import Control.Monad       ( guard )
    
    116 117
     import Data.ByteString     ( ByteString )
    
    117 118
     import Data.Function       ( on )
    
    118 119
     import Data.List           ( sort, sortBy, partition, zipWith4, mapAccumL )
    
    119 120
     import Data.Ord            ( comparing )
    
    120
    -import Control.Monad       ( guard )
    
    121 121
     import qualified Data.Set as Set
    
    122 122
     
    
    123 123
     {-
    
    ... ... @@ -674,11 +674,12 @@ filters down the matching alternatives in GHC.Core.Opt.Simplify.rebuildCase.
    674 674
     -}
    
    675 675
     
    
    676 676
     ---------------------------------
    
    677
    -mergeCaseAlts :: Id -> [CoreAlt] -> Maybe ([CoreBind], [CoreAlt])
    
    677
    +mergeCaseAlts :: CoreExpr -> Id -> [CoreAlt] -> Maybe ([CoreBind], [CoreAlt])
    
    678 678
     -- See Note [Merge Nested Cases]
    
    679
    -mergeCaseAlts outer_bndr (Alt DEFAULT _ deflt_rhs : outer_alts)
    
    679
    +mergeCaseAlts scrut outer_bndr (Alt DEFAULT _ deflt_rhs : outer_alts)
    
    680 680
       | Just (joins, inner_alts) <- go deflt_rhs
    
    681
    -  = Just (joins, mergeAlts outer_alts inner_alts)
    
    681
    +  , Just aux_binds <- mk_aux_binds joins
    
    682
    +  = Just ( aux_binds ++ joins, mergeAlts outer_alts inner_alts )
    
    682 683
                     -- NB: mergeAlts gives priority to the left
    
    683 684
                     --      case x of
    
    684 685
                     --        A -> e1
    
    ... ... @@ -688,6 +689,20 @@ mergeCaseAlts outer_bndr (Alt DEFAULT _ deflt_rhs : outer_alts)
    688 689
                     -- When we merge, we must ensure that e1 takes
    
    689 690
                     -- precedence over e2 as the value for A!
    
    690 691
       where
    
    692
    +    scrut_fvs = exprFreeVars scrut
    
    693
    +
    
    694
    +    -- See Note [Floating join points out of DEFAULT alternatives]
    
    695
    +    mk_aux_binds join_binds
    
    696
    +      | not (any mentions_outer_bndr join_binds)
    
    697
    +      = Just []                         -- Good!  No auxiliary bindings needed
    
    698
    +      | exprIsTrivial scrut
    
    699
    +      , not (outer_bndr `elemVarSet` scrut_fvs)
    
    700
    +      = Just [NonRec outer_bndr scrut]  -- Need a fixup binding
    
    701
    +      | otherwise
    
    702
    +      = Nothing                         -- Can't do it
    
    703
    +
    
    704
    +    mentions_outer_bndr bind = outer_bndr `elemVarSet` bindFreeVars bind
    
    705
    +
    
    691 706
         go :: CoreExpr -> Maybe ([CoreBind], [CoreAlt])
    
    692 707
     
    
    693 708
         -- Whizzo: we can merge!
    
    ... ... @@ -725,11 +740,10 @@ mergeCaseAlts outer_bndr (Alt DEFAULT _ deflt_rhs : outer_alts)
    725 740
           = do { (joins, alts) <- go body
    
    726 741
     
    
    727 742
                  -- Check for capture; but only if we could otherwise do a merge
    
    728
    -           ; let capture = outer_bndr `elem` bindersOf bind
    
    729
    -                           || outer_bndr `elemVarSet` bindFreeVars bind
    
    730
    -           ; guard (not capture)
    
    743
    +             --    (i.e. the recursive `go` succeeds)
    
    744
    +           ; guard (okToFloatJoin scrut_fvs outer_bndr bind)
    
    731 745
     
    
    732
    -           ; return (bind:joins, alts ) }
    
    746
    +           ; return (bind : joins, alts ) }
    
    733 747
           | otherwise
    
    734 748
           = Nothing
    
    735 749
     
    
    ... ... @@ -741,7 +755,18 @@ mergeCaseAlts outer_bndr (Alt DEFAULT _ deflt_rhs : outer_alts)
    741 755
     
    
    742 756
         go _ = Nothing
    
    743 757
     
    
    744
    -mergeCaseAlts _ _ = Nothing
    
    758
    +mergeCaseAlts _ _ _ = Nothing
    
    759
    +
    
    760
    +okToFloatJoin :: VarSet -> Id -> CoreBind -> Bool
    
    761
    +-- Check a join-point binding to see if it can be floated out of
    
    762
    +-- the DEFAULT branch of a `case`.
    
    763
    +-- See Note [Floating join points out of DEFAULT alternatives]
    
    764
    +okToFloatJoin scrut_fvs outer_bndr bind
    
    765
    +  = not (any bad_bndr (bindersOf bind))
    
    766
    +  where
    
    767
    +    bad_bndr bndr = bndr == outer_bndr              -- (a)
    
    768
    +                    || bndr `elemVarSet` scrut_fvs  -- (b)
    
    769
    +
    
    745 770
     
    
    746 771
     ---------------------------------
    
    747 772
     mergeAlts :: [Alt a] -> [Alt a] -> [Alt a]
    
    ... ... @@ -950,10 +975,46 @@ Wrinkles
    950 975
           non-join-points unless the /outer/ case has just one alternative; doing
    
    951 976
           so would risk more allocation
    
    952 977
     
    
    978
    +      Floating out join points isn't entirely straightforward.
    
    979
    +      See Note [Floating join points out of DEFAULT alternatives]
    
    980
    +
    
    953 981
     (MC5) See Note [Cascading case merge]
    
    954 982
     
    
    955 983
     See also Note [Example of case-merging and caseRules] in GHC.Core.Opt.Simplify.Utils
    
    956 984
     
    
    985
    +Note [Floating join points out of DEFAULT alternatives]
    
    986
    +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    
    987
    +Consider this, from (MC4) of Note [Merge Nested Cases]
    
    988
    +   case x of r
    
    989
    +     DEFAULT -> join j = rhs in case r of ...
    
    990
    +     alts
    
    991
    +
    
    992
    +We want to float that join point out to give this
    
    993
    +   join j = rhs
    
    994
    +   case x of r
    
    995
    +     DEFAULT -> case r of ...
    
    996
    +     alts
    
    997
    +
    
    998
    +But doing so is flat-out wrong if the scoping gets messed up:
    
    999
    +    (a) case x of r { DEFAULT -> join r = ... in ...r... }
    
    1000
    +    (b) case j of r { DEFAULT -> join j = ... in ... }
    
    1001
    +    (c) case x of r { DEFAULT -> join j = ...r.. in ... }
    
    1002
    +In all these cases we can't float the join point out because a variable changes
    
    1003
    +its meaning (`r` in (a) and (c), `j` in (b)).  For (a) and (b) the Simplifier
    
    1004
    +removes shadowing, so they'll be resolved in the next iteration.  But (c) will persist.
    
    1005
    +
    
    1006
    +Happily, we can fix up case (c) by adding an auxiliary binding for `r` (where `e` is the scrutinee), like this
    
    1007
    +    let r = e in
    
    1008
    +    join j = rhs[r]
    
    1009
    +    case e of r
    
    1010
    +       DEFAULT -> ...r...
    
    1011
    +       ...other alts...
    
    1012
    +
    
    1013
    +We can only do this if
    
    1014
    +  * We don't introduce shadowing: that is, `j` and `r` do not appear free in `e`.
    
    1015
    +    (Again the Simplifier will eliminate such shadowing.)
    
    1016
    +  * The scrutinee `e` is trivial so that the transformation doesn't duplicate work.
    
    1017
    +
    
    957 1018
     
    
    958 1019
     Note [Cascading case merge]
    
    959 1020
     ~~~~~~~~~~~~~~~~~~~~~~~~~~~
    

  • testsuite/tests/simplCore/should_compile/T26709.hs
    1
    +module T26709 where
    
    2
    +
    
    3
    +data T = A | B | C
    
    4
    +
    
    5
    +f x = case x of
    
    6
    +        A -> True
    
    7
    +        _ -> let {-# NOINLINE j #-}
    
    8
    +                 j y = y && not (f x)
    
    9
    +             in case x of
    
    10
    +                   B -> j True
    
    11
    +                   C -> j False

  • testsuite/tests/simplCore/should_compile/T26709.stderr
    1
    +[1 of 1] Compiling T26709           ( T26709.hs, T26709.o )
    
    2
    +
    
    3
    +==================== Tidy Core ====================
    
    4
    +Result size of Tidy Core
    
    5
    +  = {terms: 26, types: 9, coercions: 0, joins: 1/1}
    
    6
    +
    
    7
    +Rec {
    
    8
    +-- RHS size: {terms: 25, types: 7, coercions: 0, joins: 1/1}
    
    9
    +f [Occ=LoopBreaker] :: T -> Bool
    
    10
    +[GblId, Arity=1, Str=<SL>, Unf=OtherCon []]
    
    11
    +f = \ (x :: T) ->
    
    12
    +      join {
    
    13
    +        j [InlPrag=NOINLINE, Dmd=MC(1,L)] :: Bool -> Bool
    
    14
    +        [LclId[JoinId(1)(Just [!])], Arity=1, Str=<1L>, Unf=OtherCon []]
    
    15
    +        j (eta [OS=OneShot] :: Bool)
    
    16
    +          = case eta of {
    
    17
    +              False -> GHC.Internal.Types.False;
    
    18
    +              True ->
    
    19
    +                case f x of {
    
    20
    +                  False -> GHC.Internal.Types.True;
    
    21
    +                  True -> GHC.Internal.Types.False
    
    22
    +                }
    
    23
    +            } } in
    
    24
    +      case x of {
    
    25
    +        A -> GHC.Internal.Types.True;
    
    26
    +        B -> jump j GHC.Internal.Types.True;
    
    27
    +        C -> jump j GHC.Internal.Types.False
    
    28
    +      }
    
    29
    +end Rec }
    
    30
    +
    
    31
    +
    
    32
    +

  • testsuite/tests/simplCore/should_compile/all.T
    ... ... @@ -563,3 +563,8 @@ test('T26115', [grep_errmsg(r'DFun')], compile, ['-O -ddump-simpl -dsuppress-uni
    563 563
     test('T26116', normal, compile, ['-O -ddump-rules'])
    
    564 564
     test('T26117', [grep_errmsg(r'==')], compile, ['-O -ddump-simpl -dsuppress-uniques'])
    
    565 565
     test('T26349',  normal, compile, ['-O -ddump-rules'])
    
    566
    +
    
    567
    +# T26709: we expect three `case` expressions, not four (the inner `case x` is merged into the outer one)
    
    568
    +test('T26709', [grep_errmsg(r'case')],
    
    569
    +       multimod_compile,
    
    570
    +       ['T26709', '-O -ddump-simpl -dsuppress-uniques -dno-typeable-binds'])