Simon Peyton Jones pushed to branch wip/T20264 at Glasgow Haskell Compiler / GHC
Commits:
- e07ba567 by Simon Peyton Jones at 2025-07-24T17:44:04+01:00
6 changed files:
- compiler/GHC/Core/Opt/Arity.hs
- compiler/GHC/Core/Opt/Exitify.hs
- compiler/GHC/Core/Opt/OccurAnal.hs
- compiler/GHC/Core/SimpleOpt.hs
- compiler/GHC/Core/Utils.hs
- compiler/GHC/CoreToStg/Prep.hs
Changes:
| ... | ... | @@ -39,7 +39,7 @@ module GHC.Core.Opt.Arity |
| 39 | 39 | |
| 40 | 40 | |
| 41 | 41 | -- ** Join points
|
| 42 | - , etaExpandToJoinPoint, etaExpandToJoinPointRule
|
|
| 42 | + , etaExpandToJoinPoint, etaExpandToJoinPointRule, mkNewJoinPointBinding
|
|
| 43 | 43 | |
| 44 | 44 | -- ** Coercions and casts
|
| 45 | 45 | , pushCoArg, pushCoArgs, pushCoValArg, pushCoTyArg
|
| ... | ... | @@ -3168,6 +3168,16 @@ more elaborate stuff, but it'd involve substitution etc. |
| 3168 | 3168 | ********************************************************************* -}
|
| 3169 | 3169 | |
| 3170 | 3170 | -------------------
|
| 3171 | +mkNewJoinPointBinding :: Id -> JoinArity -> CoreExpr -> (Id, CoreExpr)
|
|
| 3172 | +mkNewJoinPointBinding bndr join_arity rhs
|
|
| 3173 | + = (join_bndr, mkLams join_lam_bndrs join_body)
|
|
| 3174 | + where
|
|
| 3175 | + (join_lam_bndrs, join_body) = etaExpandToJoinPoint join_arity rhs
|
|
| 3176 | + str_sig = idDmdSig bndr
|
|
| 3177 | + str_arity = count isId join_lam_bndrs -- Strictness demands are for Ids only
|
|
| 3178 | + join_bndr = bndr `asJoinId` join_arity
|
|
| 3179 | + `setIdDmdSig` etaConvertDmdSig str_arity str_sig
|
|
| 3180 | + |
|
| 3171 | 3181 | -- | Split an expression into the given number of binders and a body,
|
| 3172 | 3182 | -- eta-expanding if necessary. Counts value *and* type binders.
|
| 3173 | 3183 | etaExpandToJoinPoint :: JoinArity -> CoreExpr -> ([CoreBndr], CoreExpr)
|
| ... | ... | @@ -38,6 +38,7 @@ Now `t` is no longer in a recursive function, and good things happen! |
| 38 | 38 | import GHC.Prelude
|
| 39 | 39 | import GHC.Builtin.Uniques
|
| 40 | 40 | import GHC.Core
|
| 41 | +import GHC.Core.Opt.Arity( mkNewJoinPointBinding )
|
|
| 41 | 42 | import GHC.Core.Utils
|
| 42 | 43 | import GHC.Core.FVs
|
| 43 | 44 | import GHC.Core.Type
|
| ... | ... | @@ -49,7 +50,7 @@ import GHC.Types.Var.Set |
| 49 | 50 | import GHC.Types.Var.Env
|
| 50 | 51 | import GHC.Types.Basic( JoinPointHood(..) )
|
| 51 | 52 | |
| 52 | -import GHC.Utils.Monad.State.Strict
|
|
| 53 | +import qualified GHC.Utils.Monad.State.Strict as S
|
|
| 53 | 54 | import GHC.Utils.Misc( mapSnd, count )
|
| 54 | 55 | |
| 55 | 56 | import GHC.Data.FastString
|
| ... | ... | @@ -105,7 +106,7 @@ exitifyProgram binds = map goTopLvl binds |
| 105 | 106 | |
| 106 | 107 | |
| 107 | 108 | -- | State Monad used inside `exitify`
|
| 108 | -type ExitifyM = State [(JoinId, CoreExpr)]
|
|
| 109 | +type ExitifyM = S.State [(JoinId, CoreExpr)]
|
|
| 109 | 110 | |
| 110 | 111 | -- | Given a recursive group of a joinrec, identifies “exit paths” and binds them as
|
| 111 | 112 | -- join-points outside the joinrec.
|
| ... | ... | @@ -121,7 +122,7 @@ exitifyRec in_scope pairs |
| 121 | 122 | -- Which are the recursive calls?
|
| 122 | 123 | recursive_calls = mkVarSet $ map fst pairs
|
| 123 | 124 | |
| 124 | - (pairs',exits) = (`runState` []) $
|
|
| 125 | + (pairs',exits) = (`S.runState` []) $
|
|
| 125 | 126 | forM ann_pairs $ \(x,rhs) -> do
|
| 126 | 127 | -- go past the lambdas of the join point
|
| 127 | 128 | let (args, body) = collectNAnnBndrs (idJoinArity x) rhs
|
| ... | ... | @@ -262,28 +263,27 @@ exitifyRec in_scope pairs |
| 262 | 263 | captures_join_points = any isJoinId abs_vars
|
| 263 | 264 | |
| 264 | 265 | |
| 265 | --- Picks a new unique, which is disjoint from
|
|
| 266 | --- * the free variables of the whole joinrec
|
|
| 267 | --- * any bound variables (captured)
|
|
| 268 | --- * any exit join points created so far.
|
|
| 269 | -mkExitJoinId :: InScopeSet -> Type -> JoinArity -> ExitifyM JoinId
|
|
| 270 | -mkExitJoinId in_scope ty join_arity = do
|
|
| 271 | - fs <- get
|
|
| 272 | - let avoid = in_scope `extendInScopeSetList` (map fst fs)
|
|
| 273 | - `extendInScopeSet` exit_id_tmpl -- just cosmetics
|
|
| 274 | - return (uniqAway avoid exit_id_tmpl)
|
|
| 275 | - where
|
|
| 276 | - exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique ManyTy ty
|
|
| 277 | - `asJoinId` join_arity
|
|
| 278 | - |
|
| 279 | 266 | addExit :: InScopeSet -> JoinArity -> CoreExpr -> ExitifyM JoinId
|
| 280 | -addExit in_scope join_arity rhs = do
|
|
| 281 | - -- Pick a suitable name
|
|
| 282 | - let ty = exprType rhs
|
|
| 283 | - v <- mkExitJoinId in_scope ty join_arity
|
|
| 284 | - fs <- get
|
|
| 285 | - put ((v,rhs):fs)
|
|
| 286 | - return v
|
|
| 267 | +addExit in_scope join_arity rhs
|
|
| 268 | + = do { fs <- S.get
|
|
| 269 | + ; let ty = exprType rhs
|
|
| 270 | + avoid = in_scope `extendInScopeSetList` (map fst fs)
|
|
| 271 | + `extendInScopeSet` exit_id1 -- just cosmetics
|
|
| 272 | + -- avoid: pick a new unique, that is disjoint from
|
|
| 273 | + -- * the free variables of the whole joinrec
|
|
| 274 | + -- * any bound variables (captured)
|
|
| 275 | + -- * any exit join points created so far (in `fs`)
|
|
| 276 | + |
|
| 277 | + exit_id1 = mkSysLocal (fsLit "exit") initExitJoinUnique ManyTy ty
|
|
| 278 | + exit_id2 = uniqAway avoid exit_id1
|
|
| 279 | + |
|
| 280 | + bind_pr@(exit_id3,_) = mkNewJoinPointBinding exit_id2 join_arity rhs
|
|
| 281 | + -- NB: mkNewJoinPointBinding does eta-expansion if needed,
|
|
| 282 | + -- to make sure that the join-point binding has the
|
|
| 283 | + -- right number of lambdas all lined up at the top
|
|
| 284 | + |
|
| 285 | + ; S.put (bind_pr : fs)
|
|
| 286 | + ; return exit_id3 }
|
|
| 287 | 287 | |
| 288 | 288 | {-
|
| 289 | 289 | Note [Interesting expression]
|
| ... | ... | @@ -742,21 +742,27 @@ Wrinkles (W1) and (W2) are very similar to Note [Binder swap] (BS3). |
| 742 | 742 | Note [Finding join points]
|
| 743 | 743 | ~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 744 | 744 | It's the occurrence analyser's job to find bindings that we can turn into join
|
| 745 | -points, but it doesn't perform that transformation right away. Rather, it marks
|
|
| 746 | -the eligible bindings as part of their occurrence data, leaving it to the
|
|
| 747 | -simplifier (or to simpleOptPgm) to actually change the binder's 'IdDetails'.
|
|
| 748 | -The simplifier then eta-expands the RHS if needed and then updates the
|
|
| 749 | -occurrence sites. Dividing the work this way means that the occurrence analyser
|
|
| 745 | +points, but it doesn't /perform/ that transformation right away. Rather:
|
|
| 746 | + |
|
| 747 | +* The occurrence analyser marks the eligible bindings as part of their
|
|
| 748 | + occurrence data. To track potential join points, we use the 'occ_tail' field of
|
|
| 749 | + OccInfo. A value of `AlwaysTailCalled n` indicates that every occurrence of
|
|
| 750 | + the variable is a tail call with `n` arguments (counting both value and type
|
|
| 751 | + arguments). Otherwise `occ_tail` will be 'NoTailCallInfo'. The tail call info
|
|
| 752 | + flows bottom-up with the rest of `OccInfo` until it goes on the binder.
|
|
| 753 | + |
|
| 754 | +* The simplifier (or simpleOptPgm) then
|
|
| 755 | + * Spots join points from their `AlwaysTailCalled` OccInfo
|
|
| 756 | + * Eta-expands the RHS if needed
|
|
| 757 | + * Changes the binder's `IdDetails`
|
|
| 758 | + * Updates the occurrence sites
|
|
| 759 | + The first three steps are done by GHC.Core.Opt.SimpleOpt.joinPointBinding_maybe, which delegates to GHC.Core.Opt.Arity.mkNewJoinPointBinding.
|
|
| 760 | + |
|
| 761 | +Dividing the work this way means that the occurrence analyser
|
|
| 750 | 762 | still only takes one pass, yet one can always tell the difference between a
|
| 751 | 763 | function call and a jump by looking at the occurrence (because the same pass
|
| 752 | 764 | changes the 'IdDetails' and propagates the binders to their occurrence sites).
|
| 753 | 765 | |
| 754 | -To track potential join points, we use the 'occ_tail' field of OccInfo. A value
|
|
| 755 | -of `AlwaysTailCalled n` indicates that every occurrence of the variable is a
|
|
| 756 | -tail call with `n` arguments (counting both value and type arguments). Otherwise
|
|
| 757 | -'occ_tail' will be 'NoTailCallInfo'. The tail call info flows bottom-up with the
|
|
| 758 | -rest of 'OccInfo' until it goes on the binder.
|
|
| 759 | - |
|
| 760 | 766 | Note [Join arity prediction based on joinRhsArity]
|
| 761 | 767 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 762 | 768 | In general, the join arity from tail occurrences of a join point (O) may be
|
| ... | ... | @@ -42,7 +42,7 @@ import GHC.Types.Id.Info ( realUnfoldingInfo, setUnfoldingInfo, setRuleInfo, Id |
| 42 | 42 | import GHC.Types.Var ( isNonCoVarId, setTyVarUnfolding, tyVarOccInfo )
|
| 43 | 43 | import GHC.Types.Var.Set
|
| 44 | 44 | import GHC.Types.Var.Env
|
| 45 | -import GHC.Types.Demand( etaConvertDmdSig, topSubDmd )
|
|
| 45 | +import GHC.Types.Demand( topSubDmd )
|
|
| 46 | 46 | import GHC.Types.Tickish
|
| 47 | 47 | import GHC.Types.Basic
|
| 48 | 48 | |
| ... | ... | @@ -998,12 +998,7 @@ joinPointBinding_maybe bndr rhs |
| 998 | 998 | = Just (bndr, rhs)
|
| 999 | 999 | |
| 1000 | 1000 | | AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr)
|
| 1001 | - , (bndrs, body) <- etaExpandToJoinPoint join_arity rhs
|
|
| 1002 | - , let str_sig = idDmdSig bndr
|
|
| 1003 | - str_arity = count isId bndrs -- Strictness demands are for Ids only
|
|
| 1004 | - join_bndr = bndr `asJoinId` join_arity
|
|
| 1005 | - `setIdDmdSig` etaConvertDmdSig str_arity str_sig
|
|
| 1006 | - = Just (join_bndr, mkLams bndrs body)
|
|
| 1001 | + = Just (mkNewJoinPointBinding bndr join_arity rhs)
|
|
| 1007 | 1002 | |
| 1008 | 1003 | | otherwise
|
| 1009 | 1004 | = Nothing
|
| ... | ... | @@ -3121,25 +3121,20 @@ mkPolyAbsLams :: (b -> AbsVar, Var -> b -> b) |
| 3121 | 3121 | -- use it for both CoreExpr and LevelledExpr
|
| 3122 | 3122 | {-# INLINE mkPolyAbsLams #-}
|
| 3123 | 3123 | mkPolyAbsLams (get,set) bndrs body
|
| 3124 | - = go emptyVarSet [] bndrs
|
|
| 3124 | + = go bndrs
|
|
| 3125 | 3125 | where
|
| 3126 | - go _ tv_binds []
|
|
| 3127 | - = mkLets (reverse tv_binds) body
|
|
| 3128 | - go tvs tv_binds (bndr:bndrs)
|
|
| 3126 | + go [] = body
|
|
| 3127 | + go (bndr:bndrs)
|
|
| 3129 | 3128 | | Just ty <- tyVarUnfolding_maybe var
|
| 3130 | - = go (tvs `extendVarSet` var) (NonRec bndr (Type ty) : tv_binds) bndrs
|
|
| 3129 | + = Let (NonRec bndr (Type ty)) $
|
|
| 3130 | + go bndrs
|
|
| 3131 | 3131 | | otherwise
|
| 3132 | - = Lam bndr' (go tvs tv_binds bndrs)
|
|
| 3132 | + = Lam bndr' (go bndrs)
|
|
| 3133 | 3133 | where
|
| 3134 | 3134 | var = get bndr
|
| 3135 | - var' = updateVarType (expandTyVarUnfoldings tvs) $
|
|
| 3136 | - zap_unfolding var
|
|
| 3137 | - bndr' | isEmptyVarSet tvs = bndr
|
|
| 3138 | - | otherwise = set var' bndr
|
|
| 3139 | - |
|
| 3140 | 3135 | -- zap: We are going to lambda-abstract, so nuke any IdInfo
|
| 3141 | - zap_unfolding var | isId var = setIdInfo var vanillaIdInfo
|
|
| 3142 | - | otherwise = var
|
|
| 3136 | + bndr' | isId var = set (setIdInfo var vanillaIdInfo) bndr
|
|
| 3137 | + | otherwise = bndr
|
|
| 3143 | 3138 | |
| 3144 | 3139 | mkCoreAbsLams :: AbsVars -> CoreExpr -> CoreExpr
|
| 3145 | 3140 | -- Specialise for CoreExpr
|
| ... | ... | @@ -819,21 +819,17 @@ cpeRhsE :: CorePrepEnv -> CoreExpr -> UniqSM (Floats, CpeRhs) |
| 819 | 819 | -- For example
|
| 820 | 820 | -- f (g x) ===> ([v = g x], f v)
|
| 821 | 821 | |
| 822 | -cpeRhsE env (Type ty)
|
|
| 823 | - = return (emptyFloats, Type (cpSubstTy env ty))
|
|
| 824 | -cpeRhsE env (Coercion co)
|
|
| 825 | - = return (emptyFloats, Coercion (cpSubstCo env co))
|
|
| 826 | -cpeRhsE env expr@(Lit lit)
|
|
| 827 | - | LitNumber LitNumBigNat i <- lit
|
|
| 828 | - = cpeBigNatLit env i
|
|
| 829 | - | otherwise = return (emptyFloats, expr)
|
|
| 822 | +cpeRhsE env (Type ty) = return (emptyFloats, Type (cpSubstTy env ty))
|
|
| 823 | +cpeRhsE env (Coercion co) = return (emptyFloats, Coercion (cpSubstCo env co))
|
|
| 830 | 824 | cpeRhsE env expr@(Var {}) = cpeApp env expr
|
| 831 | 825 | cpeRhsE env expr@(App {}) = cpeApp env expr
|
| 832 | 826 | |
| 827 | +cpeRhsE env expr@(Lit lit)
|
|
| 828 | + = case lit of
|
|
| 829 | + LitNumber LitNumBigNat i -> cpeBigNatLit env i
|
|
| 830 | + _ -> return (emptyFloats, expr)
|
|
| 831 | + |
|
| 833 | 832 | cpeRhsE env (Let bind body)
|
| 834 | - | isTypeBind bind
|
|
| 835 | - = cpeRhsE env body
|
|
| 836 | - | otherwise
|
|
| 837 | 833 | = do { (env', bind_floats, maybe_bind') <- cpeBind NotTopLevel env bind
|
| 838 | 834 | ; (body_floats, body') <- cpeRhsE env' body
|
| 839 | 835 | ; let expr' = case maybe_bind' of Just bind' -> Let bind' body'
|