Simon Peyton Jones pushed to branch wip/T20264 at Glasgow Haskell Compiler / GHC

Commits:

6 changed files:

Changes:

  • compiler/GHC/Core/Opt/Arity.hs
    ... ... @@ -39,7 +39,7 @@ module GHC.Core.Opt.Arity
    39 39
     
    
    40 40
     
    
    41 41
        -- ** Join points
    
    42
    -   , etaExpandToJoinPoint, etaExpandToJoinPointRule
    
    42
    +   , etaExpandToJoinPoint, etaExpandToJoinPointRule, mkNewJoinPointBinding
    
    43 43
     
    
    44 44
        -- ** Coercions and casts
    
    45 45
        , pushCoArg, pushCoArgs, pushCoValArg, pushCoTyArg
    
    ... ... @@ -3168,6 +3168,16 @@ more elaborate stuff, but it'd involve substitution etc.
    3168 3168
     ********************************************************************* -}
    
    3169 3169
     
    
    3170 3170
     -------------------
    
    3171
    +mkNewJoinPointBinding :: Id -> JoinArity -> CoreExpr -> (Id, CoreExpr)
    
    3172
    +mkNewJoinPointBinding bndr join_arity rhs
    
    3173
    +  = (join_bndr, mkLams join_lam_bndrs join_body)
    
    3174
    +  where
    
    3175
    +    (join_lam_bndrs, join_body) = etaExpandToJoinPoint join_arity rhs
    
    3176
    +    str_sig   = idDmdSig bndr
    
    3177
    +    str_arity = count isId join_lam_bndrs  -- Strictness demands are for Ids only
    
    3178
    +    join_bndr = bndr `asJoinId`    join_arity
    
    3179
    +                     `setIdDmdSig` etaConvertDmdSig str_arity str_sig
    
    3180
    +
    
    3171 3181
     -- | Split an expression into the given number of binders and a body,
    
    3172 3182
     -- eta-expanding if necessary. Counts value *and* type binders.
    
    3173 3183
     etaExpandToJoinPoint :: JoinArity -> CoreExpr -> ([CoreBndr], CoreExpr)
    

  • compiler/GHC/Core/Opt/Exitify.hs
    ... ... @@ -38,6 +38,7 @@ Now `t` is no longer in a recursive function, and good things happen!
    38 38
     import GHC.Prelude
    
    39 39
     import GHC.Builtin.Uniques
    
    40 40
     import GHC.Core
    
    41
    +import GHC.Core.Opt.Arity( mkNewJoinPointBinding )
    
    41 42
     import GHC.Core.Utils
    
    42 43
     import GHC.Core.FVs
    
    43 44
     import GHC.Core.Type
    
    ... ... @@ -49,7 +50,7 @@ import GHC.Types.Var.Set
    49 50
     import GHC.Types.Var.Env
    
    50 51
     import GHC.Types.Basic( JoinPointHood(..) )
    
    51 52
     
    
    52
    -import GHC.Utils.Monad.State.Strict
    
    53
    +import qualified GHC.Utils.Monad.State.Strict as S
    
    53 54
     import GHC.Utils.Misc( mapSnd, count )
    
    54 55
     
    
    55 56
     import GHC.Data.FastString
    
    ... ... @@ -105,7 +106,7 @@ exitifyProgram binds = map goTopLvl binds
    105 106
     
    
    106 107
     
    
    107 108
     -- | State Monad used inside `exitify`
    
    108
    -type ExitifyM =  State [(JoinId, CoreExpr)]
    
    109
    +type ExitifyM =  S.State [(JoinId, CoreExpr)]
    
    109 110
     
    
    110 111
     -- | Given a recursive group of a joinrec, identifies “exit paths” and binds them as
    
    111 112
     --   join-points outside the joinrec.
    
    ... ... @@ -121,7 +122,7 @@ exitifyRec in_scope pairs
    121 122
         -- Which are the recursive calls?
    
    122 123
         recursive_calls = mkVarSet $ map fst pairs
    
    123 124
     
    
    124
    -    (pairs',exits) = (`runState` []) $
    
    125
    +    (pairs',exits) = (`S.runState` []) $
    
    125 126
             forM ann_pairs $ \(x,rhs) -> do
    
    126 127
                 -- go past the lambdas of the join point
    
    127 128
                 let (args, body) = collectNAnnBndrs (idJoinArity x) rhs
    
    ... ... @@ -262,28 +263,27 @@ exitifyRec in_scope pairs
    262 263
             captures_join_points = any isJoinId abs_vars
    
    263 264
     
    
    264 265
     
    
    265
    --- Picks a new unique, which is disjoint from
    
    266
    ---  * the free variables of the whole joinrec
    
    267
    ---  * any bound variables (captured)
    
    268
    ---  * any exit join points created so far.
    
    269
    -mkExitJoinId :: InScopeSet -> Type -> JoinArity -> ExitifyM JoinId
    
    270
    -mkExitJoinId in_scope ty join_arity = do
    
    271
    -    fs <- get
    
    272
    -    let avoid = in_scope `extendInScopeSetList` (map fst fs)
    
    273
    -                         `extendInScopeSet` exit_id_tmpl -- just cosmetics
    
    274
    -    return (uniqAway avoid exit_id_tmpl)
    
    275
    -  where
    
    276
    -    exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique ManyTy ty
    
    277
    -                    `asJoinId` join_arity
    
    278
    -
    
    279 266
     addExit :: InScopeSet -> JoinArity -> CoreExpr -> ExitifyM JoinId
    
    280
    -addExit in_scope join_arity rhs = do
    
    281
    -    -- Pick a suitable name
    
    282
    -    let ty = exprType rhs
    
    283
    -    v <- mkExitJoinId in_scope ty join_arity
    
    284
    -    fs <- get
    
    285
    -    put ((v,rhs):fs)
    
    286
    -    return v
    
    267
    +addExit in_scope join_arity rhs
    
    268
    +  = do { fs <- S.get
    
    269
    +       ; let ty = exprType rhs
    
    270
    +             avoid = in_scope `extendInScopeSetList` (map fst fs)
    
    271
    +                              `extendInScopeSet` exit_id1 -- just cosmetics
    
    272
    +               -- avoid: used to pick a new unique that is disjoint from
    
    273
    +               --  * the free variables of the whole joinrec
    
    274
    +               --  * any bound variables (captured)
    
    275
    +               --  * any exit join points created so far (in `fs`)
    
    276
    +
    
    277
    +             exit_id1 = mkSysLocal (fsLit "exit") initExitJoinUnique ManyTy ty
    
    278
    +             exit_id2 = uniqAway avoid exit_id1
    
    279
    +
    
    280
    +             bind_pr@(exit_id3,_) = mkNewJoinPointBinding exit_id2 join_arity rhs
    
    281
    +               -- NB: mkNewJoinPointBinding does eta-expansion if needed,
    
    282
    +               --     to make sure that the join-point binding has the
    
    283
    +               --     right number of lambdas all lined up at the top
    
    284
    +
    
    285
    +       ; S.put (bind_pr : fs)
    
    286
    +       ; return exit_id3 }
    
    287 287
     
    
    288 288
     {-
    
    289 289
     Note [Interesting expression]
    

  • compiler/GHC/Core/Opt/OccurAnal.hs
    ... ... @@ -742,21 +742,27 @@ Wrinkles (W1) and (W2) are very similar to Note [Binder swap] (BS3).
    742 742
     Note [Finding join points]
    
    743 743
     ~~~~~~~~~~~~~~~~~~~~~~~~~~
    
    744 744
     It's the occurrence analyser's job to find bindings that we can turn into join
    
    745
    -points, but it doesn't perform that transformation right away. Rather, it marks
    
    746
    -the eligible bindings as part of their occurrence data, leaving it to the
    
    747
    -simplifier (or to simpleOptPgm) to actually change the binder's 'IdDetails'.
    
    748
    -The simplifier then eta-expands the RHS if needed and then updates the
    
    749
    -occurrence sites. Dividing the work this way means that the occurrence analyser
    
    745
    +points, but it doesn't /perform/ that transformation right away. Rather:
    
    746
    +
    
    747
    +* The occurrence analyser marks the eligible bindings as part of their
    
    748
    +  occurrence data. To track potential join points, we use the 'occ_tail' field of
    
    749
    +  OccInfo. A value of `AlwaysTailCalled n` indicates that every occurrence of
    
    750
    +  the variable is a tail call with `n` arguments (counting both value and type
    
    751
    +  arguments). Otherwise `occ_tail` will be 'NoTailCallInfo'. The tail call info
    
    752
    +  flows bottom-up with the rest of `OccInfo` until it goes on the binder.
    
    753
    +
    
    754
    +* The simplifier (or simpleOptPgm) then
    
    755
    +  * Spots join points from their AlwaysTailCalled OccInfo
    
    756
    +  * Eta-expands the RHS if needed
    
    757
    +  * Changes the binder's `IdDetails`
    
    758
    +  * Updates the occurrence sites
    
    759
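    +  The first three steps are done by GHC.Core.Opt.SimpleOpt.joinPointBinding_maybe (which delegates the eta-expansion and the IdDetails change to GHC.Core.Opt.Arity.mkNewJoinPointBinding).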
    
    760
    +
    
    761
    +Dividing the work this way means that the occurrence analyser
    
    750 762
     still only takes one pass, yet one can always tell the difference between a
    
    751 763
     function call and a jump by looking at the occurrence (because the same pass
    
    752 764
     changes the 'IdDetails' and propagates the binders to their occurrence sites).
    
    753 765
     
    
    754
    -To track potential join points, we use the 'occ_tail' field of OccInfo. A value
    
    755
    -of `AlwaysTailCalled n` indicates that every occurrence of the variable is a
    
    756
    -tail call with `n` arguments (counting both value and type arguments). Otherwise
    
    757
    -'occ_tail' will be 'NoTailCallInfo'. The tail call info flows bottom-up with the
    
    758
    -rest of 'OccInfo' until it goes on the binder.
    
    759
    -
    
    760 766
     Note [Join arity prediction based on joinRhsArity]
    
    761 767
     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    
    762 768
     In general, the join arity from tail occurrences of a join point (O) may be
    

  • compiler/GHC/Core/SimpleOpt.hs
    ... ... @@ -42,7 +42,7 @@ import GHC.Types.Id.Info ( realUnfoldingInfo, setUnfoldingInfo, setRuleInfo, Id
    42 42
     import GHC.Types.Var      ( isNonCoVarId, setTyVarUnfolding, tyVarOccInfo )
    
    43 43
     import GHC.Types.Var.Set
    
    44 44
     import GHC.Types.Var.Env
    
    45
    -import GHC.Types.Demand( etaConvertDmdSig, topSubDmd )
    
    45
    +import GHC.Types.Demand( topSubDmd )
    
    46 46
     import GHC.Types.Tickish
    
    47 47
     import GHC.Types.Basic
    
    48 48
     
    
    ... ... @@ -998,12 +998,7 @@ joinPointBinding_maybe bndr rhs
    998 998
       = Just (bndr, rhs)
    
    999 999
     
    
    1000 1000
       | AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr)
    
    1001
    -  , (bndrs, body) <- etaExpandToJoinPoint join_arity rhs
    
    1002
    -  , let str_sig   = idDmdSig bndr
    
    1003
    -        str_arity = count isId bndrs  -- Strictness demands are for Ids only
    
    1004
    -        join_bndr = bndr `asJoinId`        join_arity
    
    1005
    -                         `setIdDmdSig` etaConvertDmdSig str_arity str_sig
    
    1006
    -  = Just (join_bndr, mkLams bndrs body)
    
    1001
    +  = Just (mkNewJoinPointBinding bndr join_arity rhs)
    
    1007 1002
     
    
    1008 1003
       | otherwise
    
    1009 1004
       = Nothing
    

  • compiler/GHC/Core/Utils.hs
    ... ... @@ -3121,25 +3121,20 @@ mkPolyAbsLams :: (b -> AbsVar, Var -> b -> b)
    3121 3121
     -- use it for both CoreExpr and LevelledExpr
    
    3122 3122
     {-# INLINE mkPolyAbsLams #-}
    
    3123 3123
     mkPolyAbsLams (get,set) bndrs body
    
    3124
    -  = go emptyVarSet [] bndrs
    
    3124
    +  = go bndrs
    
    3125 3125
       where
    
    3126
    -    go _ tv_binds []
    
    3127
    -      = mkLets (reverse tv_binds) body
    
    3128
    -    go tvs tv_binds (bndr:bndrs)
    
    3126
    +    go [] = body
    
    3127
    +    go (bndr:bndrs)
    
    3129 3128
           | Just ty <- tyVarUnfolding_maybe var
    
    3130
    -      = go (tvs `extendVarSet` var) (NonRec bndr (Type ty) : tv_binds) bndrs
    
    3129
    +      = Let (NonRec bndr (Type ty)) $
    
    3130
    +        go bndrs
    
    3131 3131
           | otherwise
    
    3132
    -      = Lam bndr' (go tvs tv_binds bndrs)
    
    3132
    +      = Lam bndr' (go bndrs)
    
    3133 3133
           where
    
    3134 3134
             var = get bndr
    
    3135
    -        var' = updateVarType (expandTyVarUnfoldings tvs) $
    
    3136
    -               zap_unfolding var
    
    3137
    -        bndr' | isEmptyVarSet tvs = bndr
    
    3138
    -              | otherwise         = set var' bndr
    
    3139
    -
    
    3140 3135
             -- zap: We are going to lambda-abstract, so nuke any IdInfo
    
    3141
    -        zap_unfolding var | isId var  = setIdInfo var vanillaIdInfo
    
    3142
    -                          | otherwise = var
    
    3136
    +        bndr' | isId var  = set (setIdInfo var vanillaIdInfo) bndr
    
    3137
    +              | otherwise = bndr
    
    3143 3138
     
    
    3144 3139
     mkCoreAbsLams :: AbsVars -> CoreExpr -> CoreExpr
    
    3145 3140
     -- Specialise for CoreExpr
    

  • compiler/GHC/CoreToStg/Prep.hs
    ... ... @@ -819,21 +819,17 @@ cpeRhsE :: CorePrepEnv -> CoreExpr -> UniqSM (Floats, CpeRhs)
    819 819
     -- For example
    
    820 820
     --      f (g x)   ===>   ([v = g x], f v)
    
    821 821
     
    
    822
    -cpeRhsE env (Type ty)
    
    823
    -  = return (emptyFloats, Type (cpSubstTy env ty))
    
    824
    -cpeRhsE env (Coercion co)
    
    825
    -  = return (emptyFloats, Coercion (cpSubstCo env co))
    
    826
    -cpeRhsE env expr@(Lit lit)
    
    827
    -  | LitNumber LitNumBigNat i <- lit
    
    828
    -    = cpeBigNatLit env i
    
    829
    -  | otherwise = return (emptyFloats, expr)
    
    822
    +cpeRhsE env (Type ty)      = return (emptyFloats, Type (cpSubstTy env ty))
    
    823
    +cpeRhsE env (Coercion co)  = return (emptyFloats, Coercion (cpSubstCo env co))
    
    830 824
     cpeRhsE env expr@(Var {})  = cpeApp env expr
    
    831 825
     cpeRhsE env expr@(App {})  = cpeApp env expr
    
    832 826
     
    
    827
    +cpeRhsE env expr@(Lit lit)
    
    828
    +  = case lit of
    
    829
    +      LitNumber LitNumBigNat i -> cpeBigNatLit env i
    
    830
    +      _                        -> return (emptyFloats, expr)
    
    831
    +
    
    833 832
     cpeRhsE env (Let bind body)
    
    834
    -  | isTypeBind bind
    
    835
    -  = cpeRhsE env body
    
    836
    -  | otherwise
    
    837 833
       = do { (env', bind_floats, maybe_bind') <- cpeBind NotTopLevel env bind
    
    838 834
            ; (body_floats, body') <- cpeRhsE env' body
    
    839 835
            ; let expr' = case maybe_bind' of Just bind' -> Let bind' body'