sheaf pushed to branch wip/andreask/ticked_joins at Glasgow Haskell Compiler / GHC Commits: 678d8950 by sheaf at 2026-01-27T18:45:58+01:00 deal with exitification - - - - - 4 changed files: - compiler/GHC/Core/Opt/Exitify.hs - compiler/GHC/Core/Opt/OccurAnal.hs - compiler/GHC/Core/Opt/Simplify/Iteration.hs - compiler/GHC/Types/Id.hs Changes: ===================================== compiler/GHC/Core/Opt/Exitify.hs ===================================== @@ -45,12 +45,14 @@ import GHC.Core.Type import GHC.Types.Var import GHC.Types.Id import GHC.Types.Id.Info +import GHC.Types.Tickish ( GenTickish(..), tickishCanScopeJoin ) + import GHC.Types.Var.Set import GHC.Types.Var.Env -import GHC.Types.Basic( JoinPointHood(..) ) import GHC.Utils.Monad.State.Strict import GHC.Utils.Misc( mapSnd ) +import GHC.Utils.Outputable import GHC.Data.FastString @@ -93,23 +95,23 @@ exitifyProgram binds = map goTopLvl binds where in_scope' = in_scope `extendInScopeSet` bndr - go in_scope (Let (Rec pairs) body) - | is_join_rec = mkLets (exitifyRec in_scope' pairs') body' - | otherwise = Let (Rec pairs') body' + go in_scope (Let (Rec pairs) body) = + case joinPointType_maybe (joinId_maybe . fst) pairs of + Just join_ty -> mkLets (exitifyRec join_ty in_scope' pairs') body' + Nothing -> Let (Rec pairs') body' where - is_join_rec = any (isJoinId . fst) pairs in_scope' = in_scope `extendInScopeSetBind` (Rec pairs) pairs' = mapSnd (go in_scope') pairs body' = go in_scope' body -- | State Monad used inside `exitify` -type ExitifyM = State [(JoinId, CoreExpr)] +type ExitifyM = State [(JoinId, CoreExpr)] -- | Given a recursive group of a joinrec, identifies “exit paths” and binds them as -- join-points outside the joinrec. 
-exitifyRec :: InScopeSet -> [(Var,CoreExpr)] -> [CoreBind] -exitifyRec in_scope pairs +exitifyRec :: JoinPointType -> InScopeSet -> [(Var,CoreExpr)] -> [CoreBind] +exitifyRec joinrec_join_ty in_scope pairs = [ NonRec xid rhs | (xid,rhs) <- exits ] ++ [Rec pairs'] where -- We need the set of free variables of many subexpressions here, so @@ -124,7 +126,7 @@ exitifyRec in_scope pairs forM ann_pairs $ \(x,rhs) -> do -- go past the lambdas of the join point let (args, body) = collectNAnnBndrs (idJoinArity x) rhs - body' <- go args body + body' <- go joinrec_join_ty args body -- (ExitJoin2): start with JoinPointType of parent joinrec let rhs' = mkLams args body' return (x, rhs') @@ -135,40 +137,41 @@ exitifyRec in_scope pairs -- variables bound on the way and lifts it out as a join point. -- -- ExitifyM is a state monad to keep track of floated binds - go :: [Var] -- Variables that are in-scope here, but - -- not in scope at the joinrec; that is, - -- we must potentially abstract over them. - -- Invariant: they are kept in dependency order + go :: JoinPointType -- what join point type to create; see Note [Exitification and quasi join points] + -> [Var] -- Variables that are in-scope here, but + -- not in scope at the joinrec; that is, + -- we must potentially abstract over them. + -- Invariant: they are kept in dependency order -> CoreExprWithFVs -- Current expression in tail position -> ExitifyM CoreExpr -- We first look at the expression (no matter what it shape is) -- and determine if we can turn it into a exit join point - go captured ann_e + go exit_join_ty captured ann_e | -- An exit expression has no recursive calls let fvs = dVarSetToVarSet (freeVarsOf ann_e) , disjointVarSet fvs recursive_calls - = go_exit captured (deAnnotate ann_e) fvs + = go_exit exit_join_ty captured (deAnnotate ann_e) fvs -- We could not turn it into a exit join point. So now recurse -- into all expression where eligible exit join points might sit, -- i.e. 
into all tail-call positions: -- Case right hand sides are in tail-call position - go captured (_, AnnCase scrut bndr ty alts) = do + go exit_join_ty captured (_, AnnCase scrut bndr ty alts) = do alts' <- forM alts $ \(AnnAlt dc pats rhs) -> do - rhs' <- go (captured ++ [bndr] ++ pats) rhs + rhs' <- go exit_join_ty (captured ++ [bndr] ++ pats) rhs return (Alt dc pats rhs') return $ Case (deAnnotate scrut) bndr ty alts' - go captured (_, AnnLet ann_bind body) + go exit_join_ty captured (_, AnnLet ann_bind body) -- join point, RHS and body are in tail-call position | AnnNonRec j rhs <- ann_bind , JoinPoint { joinPointArity = join_arity } <- idJoinPointHood j = do let (params, join_body) = collectNAnnBndrs join_arity rhs - join_body' <- go (captured ++ params) join_body + join_body' <- go exit_join_ty (captured ++ params) join_body let rhs' = mkLams params join_body' - body' <- go (captured ++ [j]) body + body' <- go exit_join_ty (captured ++ [j]) body return $ Let (NonRec j rhs') body' -- rec join point, RHSs and body are in tail-call position @@ -178,30 +181,41 @@ exitifyRec in_scope pairs pairs' <- forM pairs $ \(j,rhs) -> do let join_arity = idJoinArity j (params, join_body) = collectNAnnBndrs join_arity rhs - join_body' <- go (captured ++ js ++ params) join_body + join_body' <- go exit_join_ty (captured ++ js ++ params) join_body let rhs' = mkLams params join_body' return (j, rhs') - body' <- go (captured ++ js) body + body' <- go exit_join_ty (captured ++ js) body return $ Let (Rec pairs') body' -- normal Let, only the body is in tail-call position | otherwise - = do body' <- go (captured ++ bindersOf bind ) body + = do body' <- go exit_join_ty (captured ++ bindersOf bind ) body return $ Let bind body' where bind = deAnnBind ann_bind + -- (ExitJoin1) from Note [Exitification and quasi join points] + go _ captured (_, AnnCast ann_e (_, co)) = do + e' <- go QuasiJoinPoint captured ann_e + return (Cast e' co) + go exit_join_ty captured (_, AnnTick tickish ann_e) + 
| tickishCanScopeJoin tickish + = Tick tickish <$> go exit_join_ty captured ann_e + | ProfNote {} <- tickish + = Tick tickish <$> go QuasiJoinPoint captured ann_e + -- Cannot be turned into an exit join point, but also has no -- tail-call subexpression. Nothing to do here. - go _ ann_e = return (deAnnotate ann_e) + go _ _ ann_e = return (deAnnotate ann_e) --------------------- - go_exit :: [Var] -- Variables captured locally + go_exit :: JoinPointType -- what join point type to create; see Note [Exitification and quasi join points] + -> [Var] -- Variables captured locally -> CoreExpr -- An exit expression -> VarSet -- Free vars of the expression -> ExitifyM CoreExpr -- go_exit deals with a tail expression that is floatable -- out as an exit point; that is, it mentions no recursive calls - go_exit captured e fvs + go_exit exit_join_ty captured e fvs -- Do not touch an expression that is already a join jump where all arguments -- are captured variables. See Note [Idempotency] -- But _do_ float join jumps with interesting arguments. @@ -226,7 +240,7 @@ exitifyRec in_scope pairs let rhs = mkLams abs_vars e avoid = in_scope `extendInScopeSetList` captured -- Remember this binding under a suitable name - ; v <- addExit avoid (length abs_vars) rhs + ; v <- addExit avoid exit_join_ty (length abs_vars) rhs -- And jump to it from here ; return $ mkVarApps (Var v) abs_vars } @@ -263,7 +277,7 @@ exitifyRec in_scope pairs -- * any bound variables (captured) -- * any exit join points created so far. 
mkExitJoinId :: InScopeSet -> Type -> JoinPointType -> JoinArity -> ExitifyM JoinId -mkExitJoinId in_scope ty join_ty join_arity = do +mkExitJoinId in_scope ty exit_join_ty join_arity = do fs <- get let avoid = in_scope `extendInScopeSetList` (map fst fs) `extendInScopeSet` exit_id_tmpl -- just cosmetics @@ -271,17 +285,65 @@ mkExitJoinId in_scope ty join_ty join_arity = do where exit_id_tmpl = asJoinId (mkSysLocal (fsLit "exit") initExitJoinUnique ManyTy ty) - join_ty join_arity + exit_join_ty join_arity -addExit :: InScopeSet -> JoinArity -> CoreExpr -> ExitifyM JoinId -addExit in_scope join_arity rhs = do +addExit :: InScopeSet -> JoinPointType -> JoinArity -> CoreExpr -> ExitifyM JoinId +addExit in_scope exit_join_ty join_arity rhs = do -- Pick a suitable name let ty = exprType rhs - v <- mkExitJoinId in_scope ty TrueJoinPoint join_arity + v <- mkExitJoinId in_scope ty exit_join_ty join_arity fs <- get put ((v,rhs):fs) return v +{- Note [Exitification and quasi join points] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +When we float an exit path, we must determine if the new exit join point +should be a true join point or a quasi join point, in the sense of +Note [Quasi join points] in GHC.Core.Opt.Simplify.Iteration. + +The new exit join point must be a quasi join point if either of the following +conditions applies: + + (ExitJoin1) The exit path occurs under a cast or a profiling tick. + + (ExitJoin2) The original joinrec was a quasi join point. + +Rationale for (ExitJoin1): + + Suppose we have: + + joinrec j x = ... case ... of alts -> e |> co ... in ... + + After exitifying 'e' to 'exit': + + join exit y = e in + joinrec j x = ... case ... of alts -> (exit y) |> co ... in ... + + Because the jump to 'exit' occurs under a cast, 'exit' must be classified + as a quasi join point.
+ +Rationale for (ExitJoin2): + + Suppose we have: + + quasijoinrec j x = case x of { 0 -> 100; _ -> j (x-1) } in j 0 |> co + + If we float an exit out of 'j', we end up with + + join exit = 100 in + quasijoinrec j x = case x of { 0 -> exit ; _ -> j (x-1) } in j 0 |> co + + Now suppose we inline j and simplify; we end up with: + + join exit = 100 in exit |> co + + We see now that 'exit' must be a quasi join point, due to the cast. + + Hence: exit join points for a parent quasi join point must themselves be + quasi join points. +-} + {- Note [Interesting expression] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ===================================== compiler/GHC/Core/Opt/OccurAnal.hs ===================================== @@ -68,7 +68,6 @@ import GHC.Builtin.Names( runRWKey ) import GHC.Unit.Module( Module ) import Data.List (mapAccumL) -import qualified Data.List.NonEmpty as NE import qualified Data.Semigroup as Semi {- @@ -4118,10 +4117,7 @@ setBinderOcc occ_info bndr -- See Note [Invariants on join points] in "GHC.Core". decideRecJoinPointHood :: TopLevelFlag -> UsageDetails -> [CoreBndr] -> Maybe JoinPointType -decideRecJoinPointHood lvl usage bndrs = do - bndrsNE <- NE.nonEmpty bndrs - -- Invariant 3: Either all are join points or none are - Semi.sconcat <$> traverse ok bndrsNE +decideRecJoinPointHood lvl usage = joinPointType_maybe ok where ok bndr = okForJoinPoint lvl bndr (lookupTailCallInfo usage bndr) ===================================== compiler/GHC/Core/Opt/Simplify/Iteration.hs ===================================== @@ -2056,93 +2056,118 @@ is a join point, and what 'cont' is, in a value of type MaybeJoinCont of a SpecConstr-generated RULE for a join point. 
-} --- SLD TODO horrible logic that must be removed -peelJoinResTy :: Int -> Type -> Type -peelJoinResTy 0 ty = ty -peelJoinResTy n ty - | Just (_bndr, inner_ty) <- splitForAllTyCoVar_maybe ty - = peelJoinResTy n inner_ty - | Just (_, _mult, _arg, res_ty) <- splitFunTy_maybe ty - = peelJoinResTy (n-1) res_ty - | otherwise - = ty +joinResTy :: HasDebugCallStack => JoinArity -> Type -> Type +joinResTy n0 ty0 = go n0 ty0 + where + go 0 ty = ty + go n ty + | Just (_bndr, res_ty) <- splitPiTy_maybe ty + = go (n-1) res_ty + | otherwise + = pprPanic "joinResTy" $ + vcat [ text "join arity:" <+> ppr n0 + , text "join ty:" <+> ppr ty0 + , text "n:" <+> ppr n + , text "ty:" <+> ppr ty + ] simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr -> InExpr -> SimplCont -> SimplM (SimplFloats, OutExpr) -simplNonRecJoinPoint env bndr rhs body cont +simplNonRecJoinPoint env0 bndr rhs body cont0 = assert (isJoinId bndr) $ - wrapJoinCont do_case_case env cont $ \ env cont -> + wrapJoinCont do_case_case env0 bndr cont0 $ + \ WJC { wjc_bind_env = env, wjc_bind_cont = bind_cont, wjc_body_cont = body_cont } -> do { -- We push join_cont into the join RHS and the body; -- and wrap wrap_cont around the whole thing - ; let (mult, res_ty) - -- SLD TODO - | Just QuasiJoinPoint <- joinId_maybe bndr - = (idMult bndr, peelJoinResTy (idJoinArity bndr) $ substTy env (idType bndr)) - | otherwise - = (contHoleScaling cont, contResultType cont) + let mult = contHoleScaling bind_cont + res_ty = contResultType bind_cont ; (env1, bndr1) <- simplNonRecJoinBndr env bndr mult res_ty - ; (env2, bndr2) <- addBndrRules env1 bndr bndr1 (BC_Join NonRecursive cont) - ; (floats1, env3) <- simplJoinBind NonRecursive cont (bndr,env) (bndr2,env2) (rhs,env) - ; (floats2, body') <- simplExprF env3 body cont + ; (env2, bndr2) <- addBndrRules env1 bndr bndr1 (BC_Join NonRecursive bind_cont) + ; (floats1, env3) <- simplJoinBind NonRecursive bind_cont (bndr,env) (bndr2,env2) (rhs,env) + ; (floats2, body') <- simplExprF env3 
body body_cont ; return (floats1 `addFloats` floats2, body') } where do_case_case | Just TrueJoinPoint <- joinId_maybe bndr - = seCaseCase env + = seCaseCase env0 | otherwise = False simplRecJoinPoint :: SimplEnv -> [(InId, InExpr)] -> InExpr -> SimplCont -> SimplM (SimplFloats, OutExpr) -simplRecJoinPoint env pairs body cont - = wrapJoinCont do_case_case env cont $ \ env cont -> - do { let bndrs = map fst pairs - (mult, res_ty) - -- SLD TODO - | [b] <- bndrs - , Just QuasiJoinPoint <- joinId_maybe b - = (idMult b, peelJoinResTy (idJoinArity b) $ substTy env (idType b)) - | otherwise - = (contHoleScaling cont, contResultType cont) +simplRecJoinPoint env0 pairs body cont0 + = wrapJoinCont do_case_case env0 (head bndrs) cont0 $ + \ WJC { wjc_bind_env = env, wjc_bind_cont = bind_cont, wjc_body_cont = body_cont } -> + do { let mult = contHoleScaling bind_cont + res_ty = contResultType bind_cont ; env1 <- simplRecJoinBndrs env bndrs mult res_ty -- NB: bndrs' don't have unfoldings or rules -- We add them as we go down - ; (floats1, env2) <- simplRecBind env1 (BC_Join Recursive cont) pairs - ; (floats2, body') <- simplExprF env2 body cont + ; (floats1, env2) <- simplRecBind env1 (BC_Join Recursive bind_cont) pairs + ; (floats2, body') <- simplExprF env2 body body_cont ; return (floats1 `addFloats` floats2, body') } where + bndrs = map fst pairs + do_case_case = - if all ((== Just TrueJoinPoint) . joinId_maybe . fst) pairs - then seCaseCase env + if all ((== Just TrueJoinPoint) . joinId_maybe) bndrs + then seCaseCase env0 else False -------------------- + +-- | Information computed by 'wrapJoinCont'. 
+data WrapJoinCont + = WJC + { wjc_bind_env :: !SimplEnv + , wjc_bind_cont :: !SimplCont + , wjc_body_cont :: !SimplCont + } + wrapJoinCont :: Bool - -> SimplEnv -> SimplCont - -> (SimplEnv -> SimplCont -> SimplM (SimplFloats, OutExpr)) + -> SimplEnv -> InId -> SimplCont + -> (WrapJoinCont -> SimplM (SimplFloats, OutExpr)) -> SimplM (SimplFloats, OutExpr) -- Deal with making the continuation duplicable if necessary, -- and with the no-case-of-case situation. -wrapJoinCont do_case_case env cont thing_inside +wrapJoinCont do_case_case env join_bndr cont thing_inside | contIsStop cont -- Common case; no need for fancy footwork - = thing_inside env cont + = thing_inside $ + WJC { wjc_bind_env = env + , wjc_bind_cont = if do_case_case then cont else no_case_case_bind_cont + , wjc_body_cont = cont + } | do_case_case -- Normal situation: do the "case-of-case" transformation. -- See Note [Join points and case-of-case]. = do { (floats1, cont') <- mkDupableCont env cont - ; (floats2, result) <- thing_inside (env `setInScopeFromF` floats1) cont' + ; let wjc = WJC { wjc_bind_env = env `setInScopeFromF` floats1 + , wjc_bind_cont = cont' + , wjc_body_cont = cont' + } + ; (floats2, result) <- thing_inside wjc ; return (floats1 `addFloats` floats2, result) } | otherwise -- No "case-of-case" transformation. -- See Note [Join points with -fno-case-of-case]. 
- = do { (floats1, expr1) <- thing_inside env (mkBoringStop (contHoleType cont)) + = do { let + wjc = WJC { wjc_bind_env = env + , wjc_bind_cont = no_case_case_bind_cont + , wjc_body_cont = mkBoringStop (contHoleType cont) + } + ; (floats1, expr1) <- thing_inside wjc ; let (floats2, expr2) = wrapJoinFloatsX floats1 expr1 ; (floats3, expr3) <- rebuild (env `setInScopeFromF` floats2) expr2 cont ; return (floats2 `addFloats` floats3, expr3) } + where + -- See Wrinkle [Casts and join point result types] + join_res_ty = joinResTy (idJoinArity join_bndr) + $ substTy env (idType join_bndr) + no_case_case_bind_cont = mkBoringStop join_res_ty -------------------- trimJoinCont :: Id -- Used only in error message @@ -2282,9 +2307,9 @@ As per Note [Join points and case-of-case], we proceed by first applying the argument to both the join point RHS and the case alternatives: join { j :: Bool -> IO (); j _ = guts arg ] } - in case b of - False -> (scctick<foo> jump j True) arg - True -> jump j False arg + in case b of + False -> (scctick<foo> jump j True) arg + True -> jump j False arg Then we rely on 'trimJoinCont' to remove the argument. In this case, this fails for the first branch, because 'trimJoinCont' doesn't look through profiling @@ -2293,9 +2318,9 @@ end up with, as we don't want to misattribute profiling costs. We could plausibly transform to the following: join { j :: Bool -> IO (); j scc_or_null _ = (setSCC# scc_or_null guts) arg ] } - in case b of - False -> jump j <foo> True - True -> jump j null False + in case b of + False -> jump j <foo> True + True -> jump j null False where `setSCC#` is a new primop that would set the current cost centre pointer (or no-op if the given pointer is null). @@ -2307,17 +2332,17 @@ So instead, for now, we simply disallow the case-of-case transformation for 'j'. 
Similarly for casts: join { j = blah } - in case e of - False -> j True |> co1 - True -> j False |> co2 + in case e of + False -> j True |> co1 + True -> j False |> co2 if we want to apply this to an argument 'arg', we would need to perform the following transformation: join { j co = ( blah |> co ) arg } - in case e of - False -> j co1 True - True -> j co2 False + in case e of + False -> j co1 True + True -> j co2 False in which we add a coercion argument to the join point. Again, this is not a transformation we currently implement, so we instead prevent case-of-case for @@ -2339,6 +2364,33 @@ we proceed as follows: If we are dealing with a quasi join point, we switch off the case-of-case transformation. +Wrinkle [Casts and join point result types] + + When dealing with a quasi join point, we must preserve the original type of + the join point instead of transforming the type (as in GHC.Core.Opt.Simplify.Env.adjustJoinPointType). + This is because we don't trim the continuation like we do in + Note [Join points and case-of-case]. + + For example, suppose we have: + + type family F a + + join + j :: forall a. a -> F a + j @a x = ... + in case e of + False -> j @T1 x1 |> ( co1 :: F T1 ~ Int ) + True -> j @T2 x2 |> ( co2 :: F T2 ~ Int ) + + If we used 'contHoleType cont' to compute the result type of 'j', we would + change the result type of 'j' to 'Int', when it needs to remain 'F a'. + + Instead, we avoid doing that and re-compute the result type of 'j' using + 'joinResTy' to get 'F a', as required. + +See also Note [Exitification and quasi join points] in GHC.Core.Opt.Exitify +for another wrinkle.
+ ************************************************************************ * * Variables ===================================== compiler/GHC/Types/Id.hs ===================================== @@ -79,7 +79,8 @@ module GHC.Types.Id ( -- ** Join variables JoinId, JoinPointHood, - isJoinId, joinId_maybe, idJoinPointHood, idJoinArity, + isJoinId, joinId_maybe, joinPointType_maybe, + idJoinPointHood, idJoinArity, asJoinId, asJoinId_maybe, zapJoinId, -- ** Inline pragma stuff @@ -172,6 +173,8 @@ import GHC.Data.FastString import GHC.Utils.Misc import GHC.Utils.Outputable import GHC.Utils.Panic +import qualified Data.List.NonEmpty as NE +import qualified Data.Semigroup as Semi -- infixl so you can say (id `set` a `set` b) infixl 1 `setIdUnfolding`, @@ -584,6 +587,13 @@ joinId_maybe id _ -> Nothing | otherwise = Nothing +joinPointType_maybe :: (a -> Maybe JoinPointType) -> [a] -> Maybe JoinPointType +joinPointType_maybe f xs = do + xsNE <- NE.nonEmpty xs + Semi.sconcat <$> traverse f xsNE + -- traverse: either all are join points or none are + -- sconcat: only a 'TrueJoinPoint' if all are + -- | Doesn't return strictness marks idJoinPointHood :: Var -> JoinPointHood idJoinPointHood id View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/678d8950191d2669134d4f126b966491... -- View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/678d8950191d2669134d4f126b966491... You're receiving this email because of your account on gitlab.haskell.org.