Simon Peyton Jones pushed to branch wip/T26425 at Glasgow Haskell Compiler / GHC Commits: 4e7dd636 by Simon Peyton Jones at 2025-10-31T17:36:56+00:00 Undo the "suppress unfoldings of big join points" - - - - - 3916cdd3 by Simon Peyton Jones at 2025-10-31T17:37:28+00:00 Tidying up on the main point of this MR ..needs better docs! - - - - - 2 changed files: - compiler/GHC/Core/Opt/OccurAnal.hs - compiler/GHC/Core/Opt/Simplify/Iteration.hs Changes: ===================================== compiler/GHC/Core/Opt/OccurAnal.hs ===================================== @@ -759,62 +759,62 @@ rest of 'OccInfo' until it goes on the binder. Note [Join arity prediction based on joinRhsArity] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -In general, the join arity from tail occurrences of a join point (O) may be -higher or lower than the manifest join arity of the join body (M). E.g., +In general, the join arity from tail occurrences of a join point (OAr) may be +higher or lower than the manifest join arity of the join body (MAr). E.g., - -- M > O: - let f x y = x + y -- M = 2 - in if b then f 1 else f 2 -- O = 1 + -- MAr > Oar: + let f x y = x + y -- MAr = 2 + in if b then f 1 else f 2 -- OAr = 1 ==> { Contify for join arity 1 } join f x = \y -> x + y in if b then jump f 1 else jump f 2 - -- M < O - let f = id -- M = 0 - in if ... then f 12 else f 13 -- O = 1 + -- MAr < Oar + let f = id -- MAr = 0 + in if ... then f 12 else f 13 -- OAr = 1 ==> { Contify for join arity 1, eta-expand f } join f x = id x in if b then jump f 12 else jump f 13 -But for *recursive* let, it is crucial that both arities match up, consider +But for *recursive* let, it is crucial MAr=OAr. Consider: letrec f x y = if ... then f x else True in f 42 -Here, M=2 but O=1. If we settled for a joinrec arity of 1, the recursive jump +Here, MAr=2 but OAr=1. If we settled for a joinrec arity of 1, the recursive jump would not happen in a tail context! Contification is invalid here. 
-So indeed it is crucial to demand that M=O. +So indeed it is crucial to demand that MAr=OAr. -(Side note: Actually, we could be more specific: Let O1 be the join arity of -occurrences from the letrec RHS and O2 the join arity from the let body. Then -we need M=O1 and M<=O2 and could simply eta-expand the RHS to match O2 later. -M=O is the specific case where we don't want to eta-expand. Neither the join +(Side note: Actually, we could be more specific: Let OAr1 be the join arity of +occurrences from the letrec RHS and OAr2 the join arity from the let body. Then +we need MAr=OAr1 and MAr<=OAr2 and could simply eta-expand the RHS to match OAr2 later. +MAr=OAr is the specific case where we don't want to eta-expand. Neither the join points paper nor GHC does this at the moment.) We can capitalise on this observation and conclude that *if* f could become a -joinrec (without eta-expansion), it will have join arity M. -Now, M is just the result of 'joinRhsArity', a rather simple, local analysis. +joinrec (without eta-expansion), it will have join arity MAr. +Now, MAr is just the result of 'joinRhsArity', a rather simple, local analysis. It is also the join arity inside the 'TailUsageDetails' returned by 'occAnalLamTail', so we can predict join arity without doing any fixed-point iteration or really doing any deep traversal of let body or RHS at all. -We check for M in the 'adjustTailUsage' call inside 'tagRecBinders'. +We check for MAr in the 'adjustTailUsage' call inside 'tagRecBinders'. All this is quite apparent if you look at the contification transformation in Fig. 5 of "Compiling without Continuations" (which does not account for eta-expansion at all, mind you). The letrec case looks like this - +n letrec f = /\as.\xs. L[us] in L'[es] ... and a bunch of conditions establishing that f only occurs in app heads of join arity (len as + len xs) inside us and es ... -The syntactic form `/\as.\xs. L[us]` forces M=O iff `f` occurs in `us`. 
However, +The syntactic form `/\as.\xs. L[us]` forces MAr=OAr iff `f` occurs in `us`. However, for non-recursive functions, this is the definition of contification from the paper: let f = /\as.\xs.u in L[es] ... conditions ... -Note that u could be a lambda itself, as we have seen. No relationship between M -and O to exploit here. +Note that u could be a lambda itself, as we have seen. No relationship between MAr +and OAr to exploit here. Note [Join points and unfoldings/rules] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -998,7 +998,8 @@ occAnalBind !env lvl ire (NonRec bndr rhs) thing_inside combine -- Now analyse the body, adding the join point -- into the environment with addJoinPoint - !(WUD body_uds (occ, body)) = occAnalNonRecBody env bndr' $ \env -> + env_body = addLocalLet env lvl bndr + !(WUD body_uds (occ, body)) = occAnalNonRecBody env_body bndr' $ \env -> thing_inside (addJoinPoint env bndr' rhs_uds) in if isDeadOcc occ -- Drop dead code; see Note [Dead code] @@ -1012,8 +1013,8 @@ occAnalBind !env lvl ire (NonRec bndr rhs) thing_inside combine -- The normal case, including newly-discovered join points -- Analyse the body and /then/ the RHS - | WUD body_uds (occ,body) <- occAnalNonRecBody (addLocalLet env lvl bndr) - bndr thing_inside + | let env_body = addLocalLet env lvl bndr + , WUD body_uds (occ,body) <- occAnalNonRecBody env_body bndr thing_inside = if isDeadOcc occ -- Drop dead code; see Note [Dead code] then WUD body_uds body else let @@ -1059,7 +1060,7 @@ occAnalNonRecRhs !env lvl imp_rule_edges mb_join bndr rhs rhs_ctxt = mkNonRecRhsCtxt lvl bndr unf -- See Note [Join arity prediction based on joinRhsArity] - -- Match join arity O from mb_join_arity with manifest join arity M as + -- Match join arity OAr from mb_join_arity with manifest join arity MAr as -- returned by of occAnalLamTail. 
It's totally OK for them to mismatch; -- hence adjust the UDs from the RHS @@ -1769,7 +1770,7 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs) -- here because that is what we are setting! WTUD unf_tuds unf' = occAnalUnfolding rhs_env unf adj_unf_uds = adjustTailArity (JoinPoint rhs_ja) unf_tuds - -- `rhs_ja` is `joinRhsArity rhs` and is the prediction for source M + -- `rhs_ja` is `joinRhsArity rhs` and is the prediction for source MAr -- of Note [Join arity prediction based on joinRhsArity] --------- IMP-RULES -------- @@ -1780,7 +1781,7 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs) --------- All rules -------- -- See Note [Join points and unfoldings/rules] - -- `rhs_ja` is `joinRhsArity rhs'` and is the prediction for source M + -- `rhs_ja` is `joinRhsArity rhs'` and is the prediction for source MAr -- of Note [Join arity prediction based on joinRhsArity] rules_w_uds :: [(CoreRule, UsageDetails, UsageDetails)] rules_w_uds = [ (r,l,adjustTailArity (JoinPoint rhs_ja) rhs_wuds) @@ -2182,7 +2183,9 @@ occAnalLamTail :: OccEnv -> CoreExpr -> WithTailUsageDetails CoreExpr -- See Note [Adjusting right-hand sides] occAnalLamTail env expr = let !(WUD usage expr') = occ_anal_lam_tail env expr - in WTUD (TUD (joinRhsArity expr) usage) expr' + in WTUD (TUD (joinRhsArity expr') usage) expr' + -- If expr looks like (\x. let dead = e in \y. blah), where `dead` is dead + -- then joinRhsArity expr' might exceed joinRhsArity expr occ_anal_lam_tail :: OccEnv -> CoreExpr -> WithUsageDetails CoreExpr -- Does not markInsideLam etc for the outmost batch of lambdas @@ -3643,7 +3646,12 @@ data LocalOcc -- See Note [LocalOcc] -- Combining (AlwaysTailCalled 2) and (AlwaysTailCalled 3) -- gives NoTailCallInfo , lo_int_cxt :: !InterestingCxt } + | ManyOccL !TailCallInfo + -- Why do we need TailCallInfo on ManyOccL? + -- Answer: recursive bindings are entered many times: + -- rec { j x = ...j x'... 
} in j y + -- See the uses of `andUDs` in `tagRecBinders` instance Outputable LocalOcc where ppr (OneOccL { lo_n_br = n, lo_tail = tci }) @@ -3678,7 +3686,7 @@ instance Outputable UsageDetails where -- | TailUsageDetails captures the result of applying 'occAnalLamTail' -- to a function `\xyz.body`. The TailUsageDetails pairs together -- * the number of lambdas (including type lambdas: a JoinArity) --- * UsageDetails for the `body` of the lambda, unadjusted by `adjustTailUsage`. +-- * UsageDetails for the `body` of the lambda, /unadjusted/ by `adjustTailUsage`. -- If the binding turns out to be a join point with the indicated join -- arity, this unadjusted usage details is just what we need; otherwise we -- need to discard tail calls. That's what `adjustTailUsage` does. @@ -3865,8 +3873,6 @@ lookupOccInfoByUnique (UD { ud_env = env | uniq `elemVarEnvByKey` z_tail = NoTailCallInfo | otherwise = ti - - ------------------- -- See Note [Adjusting right-hand sides] @@ -3876,21 +3882,22 @@ adjustNonRecRhs :: JoinPointHood -- ^ This function concentrates shared logic between occAnalNonRecBind and the -- AcyclicSCC case of occAnalRec. -- It returns the adjusted rhs UsageDetails combined with the body usage -adjustNonRecRhs mb_join_arity rhs_wuds@(WTUD _ rhs) - = WUD (adjustTailUsage mb_join_arity rhs_wuds) rhs - +adjustNonRecRhs mb_join_arity (WTUD (TUD rhs_ja uds) rhs) + = WUD (adjustTailUsage exact_join rhs uds) rhs + where + exact_join = mb_join_arity == JoinPoint rhs_ja -adjustTailUsage :: JoinPointHood - -> WithTailUsageDetails CoreExpr -- Rhs usage, AFTER occAnalLamTail +adjustTailUsage :: Bool -- True <=> Exactly-matching join point; don't do markNonTail + -> CoreExpr -- Rhs usage, AFTER occAnalLamTail + -> UsageDetails -> UsageDetails -adjustTailUsage mb_join_arity (WTUD (TUD rhs_ja uds) rhs) +adjustTailUsage exact_join rhs uds = -- c.f. 
occAnal (Lam {}) markAllInsideLamIf (not one_shot) $ markAllNonTailIf (not exact_join) $ uds where one_shot = isOneShotFun rhs - exact_join = mb_join_arity == JoinPoint rhs_ja adjustTailArity :: JoinPointHood -> TailUsageDetails -> UsageDetails adjustTailArity mb_rhs_ja (TUD ja usage) @@ -3937,8 +3944,9 @@ tagNonRecBinder lvl occ bndr tagRecBinders :: TopLevelFlag -- At top level? -> UsageDetails -- Of body of let ONLY -> [NodeDetails] - -> WithUsageDetails -- Adjusted details for whole scope, - -- with binders removed + -> WithUsageDetails -- Adjusted details for whole scope + -- still including the binders; + -- (they are removed by `addInScope`) [IdWithOccInfo] -- Tagged binders -- Substantially more complicated than non-recursive case. Need to adjust RHS -- details *before* tagging binders (because the tags depend on the RHSes). @@ -3948,32 +3956,21 @@ tagRecBinders lvl body_uds details_s -- 1. See Note [Join arity prediction based on joinRhsArity] -- Determine possible join-point-hood of whole group, by testing for - -- manifest join arity M. - -- This (re-)asserts that makeNode had made tuds for that same arity M! + -- manifest join arity MAr. + -- This (re-)asserts that makeNode had made tuds for that same arity MAr! unadj_uds = foldr (andUDs . test_manifest_arity) body_uds details_s - test_manifest_arity ND{nd_rhs = WTUD tuds rhs} - = adjustTailArity (JoinPoint (joinRhsArity rhs)) tuds + test_manifest_arity ND{nd_rhs = WTUD (TUD rhs_ja uds) rhs} + = assertPpr (rhs_ja == joinRhsArity rhs) (ppr rhs_ja $$ ppr uds $$ ppr rhs) $ + uds + will_be_joins :: Bool will_be_joins = decideRecJoinPointHood lvl unadj_uds bndrs - mb_join_arity :: Id -> JoinPointHood - -- mb_join_arity: See Note [Join arity prediction based on joinRhsArity] - -- This is the source O - mb_join_arity bndr - -- Can't use willBeJoinId_maybe here because we haven't tagged - -- the binder yet (the tag depends on these adjustments!) 
- | will_be_joins - , AlwaysTailCalled arity <- lookupTailCallInfo unadj_uds bndr - = JoinPoint arity - | otherwise - = assert (not will_be_joins) -- Should be AlwaysTailCalled if - NotJoinPoint -- we are making join points! - -- 2. Adjust usage details of each RHS, taking into account the -- join-point-hood decision - rhs_udss' = [ adjustTailUsage (mb_join_arity bndr) rhs_wuds + rhs_udss' = [ adjustTailUsage will_be_joins rhs rhs_uds -- Matching occAnalLamTail in makeNode - | ND { nd_bndr = bndr, nd_rhs = rhs_wuds } <- details_s ] + | ND { nd_rhs = WTUD (TUD _ rhs_uds) rhs } <- details_s ] -- 3. Compute final usage details from adjusted RHS details adj_uds = foldr andUDs body_uds rhs_udss' @@ -3992,9 +3989,9 @@ setBinderOcc occ_info bndr | otherwise = setIdOccInfo bndr occ_info -- | Decide whether some bindings should be made into join points or not, based --- on its occurrences. This is +-- on its occurrences. -- Returns `False` if they can't be join points. Note that it's an --- all-or-nothing decision, as if multiple binders are given, they're +-- all-or-nothing decision: if multiple binders are given, they are -- assumed to be mutually recursive. -- -- It must, however, be a final decision. 
If we say `True` for 'f', ===================================== compiler/GHC/Core/Opt/Simplify/Iteration.hs ===================================== @@ -4598,12 +4598,12 @@ mkLetUnfolding :: SimplEnv -> TopLevelFlag -> UnfoldingSource -> InId -> Bool -- True <=> this is a join point -> OutExpr -> SimplM Unfolding mkLetUnfolding env top_lvl src id is_join new_rhs - | is_join - , UnfNever <- guidance - = -- For large join points, don't keep an unfolding at all if it is large - -- This is just an attempt to keep residency under control in - -- deeply-nested join-point such as those arising in #26425 - return NoUnfolding +-- | is_join +-- , UnfNever <- guidance +-- = -- For large join points, don't keep an unfolding at all if it is large +-- -- This is just an attempt to keep residency under control in +-- -- deeply-nested join-point such as those arising in #26425 +-- return NoUnfolding | otherwise = return (mkCoreUnfolding src is_top_lvl new_rhs Nothing guidance) View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/193bf31249a81c6c5bcddb956dbca7d...
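The rule stated in Note [Join arity prediction based on joinRhsArity] can be sketched as a small toy model. This is only an illustration under stated assumptions: `Expr`, `OccArity`, `Decision`, `manifestArity`, `decideNonRec` and `decideRec` are hypothetical names invented here and are not GHC's API. The real logic lives in `joinRhsArity`, `adjustTailUsage` and `tagRecBinders` and handles far more (type lambdas, rules, unfoldings, mutually recursive groups). The sketch only shows that a non-recursive binding may be contified at whatever arity its tail calls use, while a recursive one needs MAr = OAr.

```haskell
-- Toy model only: these types and functions are hypothetical, not GHC's API.

-- | A minimal expression type with just enough structure to count lambdas.
data Expr
  = Var String
  | Lam String Expr
  | App Expr Expr
  deriving (Eq, Show)

-- | Manifest arity (MAr): the number of leading lambdas of a right-hand side.
manifestArity :: Expr -> Int
manifestArity (Lam _ body) = 1 + manifestArity body
manifestArity _            = 0

-- | Occurrence arity (OAr): 'Just n' if every occurrence is a tail call
-- applied to exactly n arguments, 'Nothing' otherwise.
type OccArity = Maybe Int

data Decision = Contify Int | NoContify
  deriving (Eq, Show)

-- | Non-recursive let: MAr and OAr may differ, so OAr alone decides.
-- (If MAr > OAr the extra lambdas stay in the join body; if MAr < OAr
-- the right-hand side is eta-expanded.)
decideNonRec :: OccArity -> Decision
decideNonRec (Just oar) = Contify oar
decideNonRec Nothing    = NoContify

-- | Recursive let: contify only when MAr == OAr, because otherwise the
-- recursive jump would not be in a tail context.
decideRec :: Expr -> OccArity -> Decision
decideRec rhs (Just oar)
  | manifestArity rhs == oar = Contify oar
decideRec _ _ = NoContify

-- | Stand-in for the right-hand side of @f x y = x + y@ (MAr = 2).
rhsTwoLams :: Expr
rhsTwoLams = Lam "x" (Lam "y" (App (App (Var "+") (Var "x")) (Var "y")))
```

With these definitions, the first example from the Note (`f 1` and `f 2` in tail position, OAr = 1) corresponds to `decideNonRec (Just 1)`, and the `letrec` example (MAr = 2, OAr = 1) corresponds to `decideRec rhsTwoLams (Just 1)`, which rejects contification.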