[Git][ghc/ghc][wip/andreask/ticked_joins] Allow join point Ids to occur below ticks & casts
sheaf pushed to branch wip/andreask/ticked_joins at Glasgow Haskell Compiler / GHC Commits: 190bc495 by sheaf at 2026-01-23T10:41:45+01:00 Allow join point Ids to occur below ticks & casts This commit classifies all join points into two categories: - true join points - quasi join points A quasi join point is a join point in which one of the binders occurs under more profiling ticks or casts than its binding site. The only operational difference is that, for quasi join points, we cannot perform the case-of-case transformation described in Note [Join points and case-of-case] in GHC.Core.Opt.Simplify.Iteration. All of this is explained in detail in Note [Quasi join points]. Fixes #26693 and #26642 Improves on #26157 and #26422, but doesn't entirely fix them because in an ideal world casts & profiling ticks should not inhibit optimisations. ------------------------- Metric Increase: T21839c T9961 ------------------------- - - - - - 10 changed files: - compiler/GHC/Core/Lint.hs - compiler/GHC/Core/Opt/Arity.hs - compiler/GHC/Core/Opt/OccurAnal.hs - compiler/GHC/Core/Opt/Simplify/Env.hs - compiler/GHC/Core/Opt/Simplify/Iteration.hs - compiler/GHC/Core/SimpleOpt.hs - compiler/GHC/Core/Utils.hs - compiler/GHC/CoreToStg/Prep.hs - compiler/GHC/Types/Basic.hs - compiler/GHC/Types/Tickish.hs Changes: ===================================== compiler/GHC/Core/Lint.hs ===================================== @@ -672,7 +672,7 @@ lintRhs :: Id -> CoreExpr -> LintM (OutType, UsageEnv) lintRhs bndr rhs | JoinPoint arity <- idJoinPointHood bndr = lintJoinLams arity (Just bndr) rhs - | AlwaysTailCalled arity <- tailCallInfo (idOccInfo bndr) + | AlwaysTailCalled { tailCallArity = arity } <- tailCallInfo (idOccInfo bndr) = lintJoinLams arity Nothing rhs -- Allow applications of the data constructor @StaticPtr@ at the top @@ -929,9 +929,12 @@ lintCoreExpr (Tick tickish expr) = do { case tickish of Breakpoint _ _ ids -> forM_ ids $ \id -> lintIdOcc id 0 _ -> return () - ; markAllJoinsBadIf 
block_joins $ lintCoreExpr expr } + ; expr_l <- lintCoreExpr expr + ; r <- markAllJoinsBadIf block_joins $ pure expr_l + -- ; when block_joins + ; pure r} where - block_joins = not (tickish `tickishScopesLike` SoftScope) + block_joins = not (tickishCanScopeJoin tickish) -- TODO Consider whether this is the correct rule. It is consistent with -- the simplifier's behaviour - cost-centre-scoped ticks become part of -- the continuation, and thus they behave like part of an evaluation ===================================== compiler/GHC/Core/Opt/Arity.hs ===================================== @@ -90,7 +90,6 @@ import GHC.Utils.Misc import Data.List.NonEmpty ( nonEmpty ) import qualified Data.List.NonEmpty as NE -import Data.Maybe( isJust ) {- ************************************************************************ @@ -2835,22 +2834,6 @@ tryEtaReduce rec_ids bndrs body eval_sd ok_arg _ _ _ _ = Nothing --- | Can we eta-reduce the given function --- See Note [Eta reduction soundness], criteria (B), (J), and (W). -cantEtaReduceFun :: Id -> Bool -cantEtaReduceFun fun - = hasNoBinding fun -- (B) - -- Don't undersaturate functions with no binding. - - || isJoinId fun -- (J) - -- Don't undersaturate join points. - -- See Note [Invariants on join points] in GHC.Core, and #20599 - - || (isJust (idCbvMarks_maybe fun)) -- (W) - -- Don't undersaturate StrictWorkerIds. - -- See Note [CBV Function Ids: overview] in GHC.Types.Id.Info. - - {- ********************************************************************* * * The "push rules" ===================================== compiler/GHC/Core/Opt/OccurAnal.hs ===================================== @@ -797,10 +797,10 @@ function call and a jump by looking at the occurrence (because the same pass changes the 'IdDetails' and propagates the binders to their occurrence sites). To track potential join points, we use the 'occ_tail' field of OccInfo. 
A value -of `AlwaysTailCalled n` indicates that every occurrence of the variable is a -tail call with `n` arguments (counting both value and type arguments). Otherwise -'occ_tail' will be 'NoTailCallInfo'. The tail call info flows bottom-up with the -rest of 'OccInfo' until it goes on the binder. +of `AlwaysTailCalled { tailCallArity = n }` indicates that every occurrence of +the variable is a tail call with `n` arguments (counting both value and type +arguments). Otherwise 'occ_tail' will be 'NoTailCallInfo'. The tail call info +flows bottom-up with the rest of 'OccInfo' until it goes on the binder. Note [Join arity prediction based on joinRhsArity] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -2585,13 +2585,21 @@ But it is not necessary to gather CoVars from the types of other binders. occAnal env (Tick tickish body) = WUD usage' (Tick tickish body') where - WUD usage body' = occAnal env body + WUD usage body' = occAnal env' body + + env' = case tickish of + -- setInsideProfTick: join points under profiling ticks turn + -- into quasi-join points. See Note [Quasi join points] + ProfNote {} -> setInsideProfTick env + _ -> env usage' - | tickish `tickishScopesLike` SoftScope + | tickishCanScopeJoin tickish = usage -- For soft-scoped ticks (including SourceNotes) we don't want -- to lose join-point-hood, so we don't mess with `usage` (#24078) + -- Similarly for cost centres. (#26157) + -- For a non-soft tick scope, we can inline lambdas only, so we -- abandon tail calls, and do markAllInsideLam too: usage_lam @@ -2613,11 +2621,12 @@ occAnal env (Tick tickish body) -- See #14242. 
occAnal env (Cast expr co) - = let (WUD usage expr') = occAnal env expr - usage1 = addManyOccs usage (coVarsOfCo co) - -- usage2: see Note [Gather occurrences of coercion variables] - usage2 = markAllNonTail usage1 - -- usage3: calls inside expr aren't tail calls any more + = let (WUD usage expr') = occAnal (setInsideCast env) expr + -- setInsideCast: join points inside casts turn into quasi join points + -- See Note [Quasi join points] + usage1 = addManyOccs usage (coVarsOfCo co) + -- usage2: see Note [Gather occurrences of coercion variables] + usage2 = markAllNonTail usage1 in WUD usage2 (Cast expr' co) occAnal env app@(App _ _) @@ -2927,6 +2936,8 @@ scrutinised y). data OccEnv = OccEnv { occ_encl :: !OccEncl -- Enclosing context information + , occ_prof_ticks :: !Int -- ^ How many profiling ticks are we under? See Note [Quasi join points] + , occ_casts :: !Int -- ^ How many casts are we under? See Note [Quasi join points] , occ_one_shots :: !OneShots -- See Note [OneShots] , occ_unf_act :: Id -> Bool -- Which Id unfoldings are active , occ_rule_act :: ActivationGhc -> Bool -- Which rules are active @@ -2992,6 +3003,8 @@ type OneShots = [OneShotInfo] initOccEnv :: OccEnv initOccEnv = OccEnv { occ_encl = OccVanilla + , occ_prof_ticks = 0 + , occ_casts = 0 , occ_one_shots = [] -- To be conservative, we say that all @@ -3070,6 +3083,12 @@ setTailCtxt !env = env { occ_encl = OccVanilla } -- Preserve occ_one_shots, occ_join points -- Do not use OccRhs for the RHS of a join point (which is a tail ctxt): +setInsideProfTick :: OccEnv -> OccEnv +setInsideProfTick !env = env { occ_prof_ticks = 1 + occ_prof_ticks env } + +setInsideCast :: OccEnv -> OccEnv +setInsideCast !env = env { occ_casts = 1 + occ_casts env } + mkRhsOccEnv :: OccEnv -> RecFlag -> OccEncl -> JoinPointHood -> Id -> CoreExpr -> OccEnv -- See Note [The OccEnv for a right hand side] -- For a join point: @@ -3696,7 +3715,7 @@ type OccInfoEnv = IdEnv LocalOcc -- A finite map from an expression's data 
LocalOcc -- See Note [LocalOcc] = OneOccL { lo_n_br :: {-# UNPACK #-} !BranchCount -- Number of syntactic occurrences , lo_tail :: !TailCallInfo - -- Combining (AlwaysTailCalled 2) and (AlwaysTailCalled 3) + -- NB: combining 'TailCallInfo's with different arities -- gives NoTailCallInfo , lo_int_cxt :: !InterestingCxt } @@ -3789,9 +3808,20 @@ mkOneOcc !env id int_cxt arity = mkSimpleDetails (unitVarEnv id occ) where - occ = OneOccL { lo_n_br = 1 - , lo_int_cxt = int_cxt - , lo_tail = AlwaysTailCalled arity } + occ = + OneOccL + { lo_n_br = 1 + , lo_int_cxt = int_cxt + , lo_tail = + AlwaysTailCalled + { tailCallArity = arity + + -- See Note [Quasi join points] for justification of these + -- two fields. + , tailCallUnderProfTicks = occ_prof_ticks env + , tailCallUnderCasts = occ_casts env + } + } -- Add several occurrences, assumed not to be tail calls add_many_occ :: Var -> OccInfoEnv -> OccInfoEnv @@ -3844,13 +3874,14 @@ delBndrsFromUDs bndrs (UD { ud_env = env, ud_z_many = z_many , ud_z_tail = z_tail `delVarEnvList` bndrs } markAllMany, markAllInsideLam, markAllNonTail, markAllManyNonTail - :: UsageDetails -> UsageDetails + :: HasDebugCallStack => UsageDetails -> UsageDetails markAllMany ud@(UD { ud_env = env }) = ud { ud_z_many = env } markAllInsideLam ud@(UD { ud_env = env }) = ud { ud_z_in_lam = env } -markAllNonTail ud@(UD { ud_env = env }) = ud { ud_z_tail = env } markAllManyNonTail = markAllMany . 
markAllNonTail -- effectively sets to noOccInfo -markAllInsideLamIf, markAllNonTailIf :: Bool -> UsageDetails -> UsageDetails +markAllNonTail ud@(UD { ud_env = env }) = ud { ud_z_tail = env } + +markAllInsideLamIf, markAllNonTailIf :: HasDebugCallStack => Bool -> UsageDetails -> UsageDetails markAllInsideLamIf True ud = markAllInsideLam ud markAllInsideLamIf False ud = ud @@ -3947,7 +3978,7 @@ adjustNonRecRhs mb_join_arity (WTUD (TUD rhs_ja uds) rhs) where exact_join = mb_join_arity == JoinPoint rhs_ja -adjustTailUsage :: Bool -- True <=> Exactly-matching join point; don't do markNonTail +adjustTailUsage :: HasDebugCallStack => Bool -- True <=> Exactly-matching join point; don't do markNonTail -> CoreExpr -- Rhs usage, AFTER occAnalLamTail -> UsageDetails -> UsageDetails @@ -3959,7 +3990,7 @@ adjustTailUsage exact_join rhs uds where one_shot = isOneShotFun rhs -adjustTailArity :: JoinPointHood -> TailUsageDetails -> UsageDetails +adjustTailArity :: HasDebugCallStack => JoinPointHood -> TailUsageDetails -> UsageDetails adjustTailArity mb_rhs_ja (TUD ja usage) = markAllNonTailIf (mb_rhs_ja /= JoinPoint ja) usage @@ -3993,7 +4024,7 @@ tagNonRecBinder :: TopLevelFlag -- At top level? 
-- Precondition: OccInfo is not IAmDead tagNonRecBinder lvl occ bndr | okForJoinPoint lvl bndr tail_call_info - , AlwaysTailCalled ar <- tail_call_info + , AlwaysTailCalled { tailCallArity = ar } <- tail_call_info = (setBinderOcc occ bndr, JoinPoint ar) | otherwise = (setBinderOcc zapped_occ bndr, NotJoinPoint) @@ -4080,7 +4111,7 @@ okForJoinPoint lvl bndr tail_call_info = False where valid_join | NotTopLevel <- lvl - , AlwaysTailCalled arity <- tail_call_info + , AlwaysTailCalled { tailCallArity = arity } <- tail_call_info , -- Invariant 1 as applied to LHSes of rules all (ok_rule arity) (idCoreRules bndr) @@ -4097,9 +4128,9 @@ okForJoinPoint lvl bndr tail_call_info lost_join | JoinPoint ja <- idJoinPointHood bndr = not valid_join || - (case tail_call_info of -- Valid join but arity differs - AlwaysTailCalled ja' -> ja /= ja' - _ -> False) + (case tail_call_info of -- Valid join but arity differs + AlwaysTailCalled { tailCallArity = ja' } -> ja /= ja' + _ -> False) | otherwise = False ok_rule _ BuiltinRule{} = False -- only possible with plugin shenanigans @@ -4121,7 +4152,7 @@ okForJoinPoint lvl bndr tail_call_info , text "tc:" <+> ppr tail_call_info , text "rules:" <+> ppr (idCoreRules bndr) , case tail_call_info of - AlwaysTailCalled arity -> + AlwaysTailCalled { tailCallArity = arity } -> vcat [ text "ok_unf:" <+> ppr (ok_unfolding arity (realIdUnfolding bndr)) , text "ok_type:" <+> ppr (isValidJoinPointType arity (idType bndr)) ] _ -> empty ] @@ -4184,6 +4215,6 @@ orLocalOcc (OneOccL { lo_n_br = nbr1, lo_int_cxt = int_cxt1, lo_tail = tci1 }) orLocalOcc occ1 occ2 = andLocalOcc occ1 occ2 andTailCallInfo :: TailCallInfo -> TailCallInfo -> TailCallInfo -andTailCallInfo info@(AlwaysTailCalled arity1) (AlwaysTailCalled arity2) - | arity1 == arity2 = info +andTailCallInfo (AlwaysTailCalled arity1 p1 c1) (AlwaysTailCalled arity2 p2 c2) + | arity1 == arity2 = AlwaysTailCalled arity1 (max p1 p2) (max c1 c2) andTailCallInfo _ _ = NoTailCallInfo 
===================================== compiler/GHC/Core/Opt/Simplify/Env.hs ===================================== @@ -201,6 +201,9 @@ data SimplEnv , seCaseDepth :: !Int -- Depth of multi-branch case alternatives + , seProfTicks :: !Int -- Current depth of profiling ticks; see Note [Quasi join points] + , seCasts :: !Int -- Current depth of casts; see Note [Quasi join points] + , seInlineDepth :: !Int -- 0 initially, 1 when we inline an already-simplified -- unfolding, and simplify again; and so on -- See Note [Inline depth] @@ -590,6 +593,8 @@ mkSimplEnv mode fam_envs , seIdSubst = emptyVarEnv , seRecIds = emptyUnVarSet , seCaseDepth = 0 + , seProfTicks = 0 + , seCasts = 0 , seInlineDepth = 0 } -- The top level "enclosing CC" is "SUBSUMED". ===================================== compiler/GHC/Core/Opt/Simplify/Iteration.hs ===================================== @@ -39,7 +39,7 @@ import GHC.Core.Opt.Arity ( ArityType, exprArity, arityTypeBotSigs_maybe , pushCoTyArg, pushCoValArg, exprIsDeadEnd , typeArity, arityTypeArity, etaExpandAT ) import GHC.Core.SimpleOpt ( exprIsConApp_maybe, joinPointBinding_maybe, joinPointBindings_maybe ) -import GHC.Core.FVs ( mkRuleInfo {- exprsFreeIds -} ) +import GHC.Core.FVs ( mkRuleInfo ) import GHC.Core.Rules ( lookupRule, getRules ) import GHC.Core.Multiplicity @@ -57,6 +57,7 @@ import GHC.Types.Unique ( hasKey ) import GHC.Types.Basic import GHC.Types.Tickish import GHC.Types.Var ( isTyCoVar ) + import GHC.Builtin.Types.Prim( realWorldStatePrimTy ) import GHC.Builtin.Names( runRWKey, seqHashKey ) @@ -1442,7 +1443,10 @@ simplTick env tickish expr cont no_floating_past_tick = do { let (inc,outc) = splitCont cont - ; (floats, expr1) <- simplExprF env expr inc + env' = case tickish of + ProfNote {} -> env { seProfTicks = seProfTicks env + 1 } + _ -> env + ; (floats, expr1) <- simplExprF env' expr inc ; let expr2 = wrapFloats floats expr1 tickish' = simplTickish env tickish ; rebuild env (mkTick tickish' expr2) outc @@ -1680,39 +1684,54 
@@ optOutCoercion env co already_optimised empty_subst = mkEmptySubst (seInScope env) opts = seOptCoercionOpts env +-- | Number of casts we are adding around an expression as we process a 'Cast'. +-- +-- We need the cast depth to implement the logic of Note [Quasi join points]. +type NbCastsAdded = Int + simplCast :: SimplEnv -> InExpr -> InCoercion -> SimplCont -> SimplM (SimplFloats, OutExpr) simplCast env body co0 cont0 = do { co1 <- {-#SCC "simplCast-simplCoercion" #-} simplCoercion env co0 - ; cont1 <- {-#SCC "simplCast-addCoerce" #-} - if isReflCo co1 - then return cont0 -- See Note [Optimising reflexivity] - else addCoerce co1 True cont0 - -- True <=> co1 is optimised - ; {-#SCC "simplCast-simplExprF" #-} simplExprF env body cont1 } + ; (cont1, nbAddedCasts) <- {-#SCC "simplCast-addCoerce" #-} + if isReflCo co1 + then return (cont0, 0) -- See Note [Optimising reflexivity] + else addCoerce co1 True cont0 + -- True <=> co1 is optimised + + -- Keep track of how many casts we have added, because we need this + -- information for Note [Quasi join points]. + ; let env' = env { seCasts = seCasts env + nbAddedCasts } + ; {-#SCC "simplCast-simplExprF" #-} simplExprF env' body cont1 } where -- If the first parameter is MRefl, then simplifying revealed a -- reflexive coercion. Omit. 
- addCoerceM :: MOutCoercion -> Bool -> SimplCont -> SimplM SimplCont - addCoerceM MRefl _ cont = return cont + addCoerceM :: MOutCoercion -> Bool -> SimplCont -> SimplM (SimplCont, NbCastsAdded) + addCoerceM MRefl _ cont = return (cont, 0) addCoerceM (MCo co) opt cont = addCoerce co opt cont - addCoerce :: OutCoercion -> Bool -> SimplCont -> SimplM SimplCont + addCoerce :: OutCoercion -> Bool -> SimplCont -> SimplM (SimplCont, NbCastsAdded) addCoerce co1 _ (CastIt { sc_co = co2, sc_cont = cont }) -- See Note [Optimising reflexivity] - = addCoerce (mkTransCo co1 co2) False cont - -- False: (mkTransCo co1 co2) is not fully optimised - -- See Note [Avoid re-simplifying coercions] + = do { (cont', nbCastsAdded) <- addCoerce (mkTransCo co1 co2) False cont + -- False: (mkTransCo co1 co2) is not fully optimised + -- See Note [Avoid re-simplifying coercions] + ; return (cont', nbCastsAdded - 1) + -- -1: the coercion coalesced with an existing coercion. + } addCoerce co co_is_opt (ApplyToTy { sc_arg_ty = arg_ty, sc_cont = tail }) | Just (arg_ty', m_co') <- pushCoTyArg co arg_ty = {-#SCC "addCoerce-pushCoTyArg" #-} - do { tail' <- addCoerceM m_co' co_is_opt tail - ; return (ApplyToTy { sc_arg_ty = arg_ty' - , sc_cont = tail' - , sc_hole_ty = coercionLKind co }) } - -- NB! As the cast goes past, the - -- type of the hole changes (#16312) + do { (tail', nbCastsAdded) <- addCoerceM m_co' co_is_opt tail + ; return ( ApplyToTy { sc_arg_ty = arg_ty' + , sc_cont = tail' + , sc_hole_ty = coercionLKind co } + -- NB! 
As the cast goes past, the + -- type of the hole changes (#16312) + , nbCastsAdded ) + } + -- (f |> co) e ===> (f (e |> co1)) |> co2 -- where co :: (s1->s2) ~ (t1->t2) -- co1 :: t1 ~ s1 @@ -1725,10 +1744,12 @@ simplCast env body co0 cont0 | Just (m_co1, m_co2) <- pushCoValArg co = {-#SCC "addCoerce-pushCoValArg" #-} - do { tail' <- addCoerceM m_co2 co_is_opt tail + do { (tail', nbCastsAdded) <- addCoerceM m_co2 co_is_opt tail ; case m_co1 of { - MRefl -> return (cont { sc_cont = tail' - , sc_hole_ty = coercionLKind co }) ; + MRefl -> return + ( cont { sc_cont = tail' + , sc_hole_ty = coercionLKind co } + , nbCastsAdded ) ; -- See Note [Avoiding simplifying repeatedly] MCo co1 -> @@ -1738,17 +1759,23 @@ simplCast env body co0 cont0 -- to make it all consistent. It's a bit messy. -- But it isn't a common case. -- Example of use: #995 - ; return (ApplyToVal { sc_arg = mkCast arg' co1 - , sc_env = arg_se' - , sc_dup = dup' - , sc_cont = tail' - , sc_hole_ty = coercionLKind co }) } } } + ; return + ( ApplyToVal { sc_arg = mkCast arg' co1 + , sc_env = arg_se' + , sc_dup = dup' + , sc_cont = tail' + , sc_hole_ty = coercionLKind co } + , nbCastsAdded ) } } } addCoerce co co_is_opt cont - | isReflCo co = return cont -- Having this at the end makes a huge - -- difference in T12227, for some reason - -- See Note [Optimising reflexivity] - | otherwise = return (CastIt { sc_co = co, sc_opt = co_is_opt, sc_cont = cont }) + | isReflCo co = return (cont, 0 :: NbCastsAdded ) + -- Having this at the end makes a huge + -- difference in T12227, for some reason + -- See Note [Optimising reflexivity] + | otherwise = + return + ( CastIt { sc_co = co, sc_opt = co_is_opt, sc_cont = cont } + , 1 :: NbCastsAdded ) simplLazyArg :: SimplEnvIS -- ^ Used only for its InScopeSet -> DupFlag @@ -2051,8 +2078,8 @@ simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr -> InExpr -> SimplCont -> SimplM (SimplFloats, OutExpr) simplNonRecJoinPoint env bndr rhs body cont - = assert (isJoinId bndr ) $ - 
wrapJoinCont env cont $ \ env cont -> + = assert (isJoinId bndr) $ + wrapJoinCont do_case_case env cont $ \ env cont -> do { -- We push join_cont into the join RHS and the body; -- and wrap wrap_cont around the whole thing ; let mult = contHoleScaling cont @@ -2062,14 +2089,17 @@ simplNonRecJoinPoint env bndr rhs body cont ; (floats1, env3) <- simplJoinBind NonRecursive cont (bndr,env) (bndr2,env2) (rhs,env) ; (floats2, body') <- simplExprF env3 body cont ; return (floats1 `addFloats` floats2, body') } + where + do_case_case = + if isTrueJoinPoint env bndr + then seCaseCase env + else False - ------------------- simplRecJoinPoint :: SimplEnv -> [(InId, InExpr)] -> InExpr -> SimplCont -> SimplM (SimplFloats, OutExpr) simplRecJoinPoint env pairs body cont - = wrapJoinCont env cont $ \ env cont -> + = wrapJoinCont do_case_case env cont $ \ env cont -> do { let bndrs = map fst pairs mult = contHoleScaling cont res_ty = contResultType cont @@ -2079,30 +2109,53 @@ simplRecJoinPoint env pairs body cont ; (floats1, env2) <- simplRecBind env1 (BC_Join Recursive cont) pairs ; (floats2, body') <- simplExprF env2 body cont ; return (floats1 `addFloats` floats2, body') } + where + do_case_case = + if all (isTrueJoinPoint env . fst) pairs + then seCaseCase env + else False + +-- | Is this a true join point, or only a quasi join point? +-- +-- See Note [Quasi join points] +isTrueJoinPoint :: SimplEnv -> InId -> Bool +isTrueJoinPoint env id + | Just occMaxProfTicks <- occursUnderProfTicks (idOccInfo id) + , occMaxProfTicks > seProfTicks env + -- The join point occurs under more profiling ticks than its binding. + = False + | Just occMaxCasts <- occursUnderCasts (idOccInfo id) + , occMaxCasts > seCasts env + -- The join point occurs under more casts than its binding.
+ = False + | otherwise + = True -------------------- -wrapJoinCont :: SimplEnv -> SimplCont +wrapJoinCont :: Bool + -> SimplEnv -> SimplCont -> (SimplEnv -> SimplCont -> SimplM (SimplFloats, OutExpr)) -> SimplM (SimplFloats, OutExpr) -- Deal with making the continuation duplicable if necessary, -- and with the no-case-of-case situation. -wrapJoinCont env cont thing_inside +wrapJoinCont do_case_case env cont thing_inside | contIsStop cont -- Common case; no need for fancy footwork = thing_inside env cont - | not (seCaseCase env) - -- See Note [Join points with -fno-case-of-case] - = do { (floats1, expr1) <- thing_inside env (mkBoringStop (contHoleType cont)) - ; let (floats2, expr2) = wrapJoinFloatsX floats1 expr1 - ; (floats3, expr3) <- rebuild (env `setInScopeFromF` floats2) expr2 cont - ; return (floats2 `addFloats` floats3, expr3) } - - | otherwise - -- Normal case; see Note [Join points and case-of-case] + | do_case_case + -- Normal situation: do the "case-of-case" transformation. + -- See Note [Join points and case-of-case]. = do { (floats1, cont') <- mkDupableCont env cont ; (floats2, result) <- thing_inside (env `setInScopeFromF` floats1) cont' ; return (floats1 `addFloats` floats2, result) } + | otherwise + -- No "case-of-case" transformation. + -- See Note [Join points with -fno-case-of-case]. + = do { (floats1, expr1) <- thing_inside env (mkBoringStop (contHoleType cont)) + ; let (floats2, expr2) = wrapJoinFloatsX floats1 expr1 + ; (floats3, expr3) <- rebuild (env `setInScopeFromF` floats2) expr2 cont + ; return (floats2 `addFloats` floats3, expr3) } -------------------- trimJoinCont :: Id -- Used only in error message @@ -2151,15 +2204,18 @@ evaluation context E): As is evident from the example, there are two components to this behavior: - 1. When entering the RHS of a join point, copy the context inside. - 2. When a join point is invoked, discard the outer context. + (wrapJoinCont) When entering the RHS of a join point, copy the context inside. 
+ (trimJoinCont) When a join point is invoked, discard the outer context. We need to be very careful here to remain consistent---neither part is optional! -We need do make the continuation E duplicable (since we are duplicating it) +We need to make the continuation E duplicable (since we are duplicating it) with mkDupableCont. +Note that not all join points support this transformation: +see Note [Quasi join points]. + Note [Join points with -fno-case-of-case] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -2184,7 +2240,8 @@ case-of-case we may then end up with this totally bogus result This would be OK in the language of the paper, but not in GHC: j is no longer a join point. We can only do the "push continuation into the RHS of the join point j" if we also push the continuation right down to the /jumps/ to -j, so that it can evaporate there. If we are doing case-of-case, we'll get to +j, so that it can evaporate there (trimJoinCont). Then, if we are doing +case-of-case, we'll get to join x = case <j-rhs> of <outer-alts> in case y of @@ -2199,6 +2256,105 @@ inwards altogether at any join point. Instead simplify the (join ... in ...) with a Stop continuation, and wrap the original continuation around the outside. Surprisingly tricky! +Note [Quasi join points] +~~~~~~~~~~~~~~~~~~~~~~~~ +We currently classify join points into two separate categories + + - true join points + - quasi join points + +Definition: + A join point binding defines a *quasi* join point if any of the join point + binders occurs under more profiling ticks or casts than its binding site. + + If a join point binding is not a quasi join point, it is a *true* join point. + +We can push continuations into true join points, as described in +Note [Join points and case-of-case]: + + K[ join j = rhs in body ] --> join j = K[ rhs ] in K[ body ] + +This transformation is not valid if the occurrences of 'j' in 'body' appear: + + 1.
under profiling ticks, see #26693 #26157 #26642 + +For example, consider (a minimisation of) the program in #26693: + + join { j :: Bool -> IO (); j _ = guts } + in case pass of + False -> scctick<foo> jump j True + True -> jump j False + +Let's try to push the application to an argument 'arg' into this expression. +As per Note [Join points and case-of-case], we proceed by first applying the +argument to both the join point RHS and the case alternatives: + + join { j :: Bool -> IO (); j _ = guts arg } + in case pass of + False -> (scctick<foo> jump j True) arg + True -> jump j False arg + +Then we rely on 'trimJoinCont' to remove the argument. In this case, this fails +for the first branch, because 'trimJoinCont' doesn't look through profiling +ticks. Were we to address this, it's still not clear what code we would want to +end up with, as we don't want to misattribute profiling costs. +We could plausibly transform to the following: + + join { j :: Bool -> IO (); j scc_or_null _ = (setSCC# scc_or_null guts) arg } + in case pass of + False -> jump j <foo> True + True -> jump j null False + +where `setSCC#` is a new primop that would set the current cost centre pointer +(or no-op if the given pointer is null). +However: + - this primop doesn't exist today, + - it requires adding an argument to the join point (hence changing its arity) +So instead, for now, we simply disallow the case-of-case transformation for 'j'. + +Similarly for casts: + + join { j = blah } + in case e of + False -> j True |> co1 + True -> j False |> co2 + +if we want to apply this to an argument 'arg', we would need to perform the +following transformation: + + join { j co = ( blah |> co ) arg } + in case e of + False -> j co1 True + True -> j co2 False + +in which we add a coercion argument to the join point. Again, this is not a +transformation we currently implement, so we instead prevent case-of-case for +such join points.
+ +To figure out whether a join point is a true join point or a quasi join point, +we proceed as follows: + + 1. In occurrence analysis, we compute how many profiling ticks/casts each + join point Id occurs under. + + This is stored in the 'tailCallUnderProfTicks' and 'tailCallUnderCasts' + fields of 'TailCallInfo', and populated by keeping track of how many + profiling ticks and casts we are under when doing occurrence analysis + (see 'occ_prof_ticks' and 'occ_casts'). + + 2. In the simplifier, we keep track of how many profiling ticks/casts we are + currently inside. See 'seProfTicks' and 'seCasts', which are updated + in 'simplTick' and 'simplCast', respectively. + + 3. In the simplifier, when we come across a join point binding (in either + 'simplNonRecJoinPoint' or 'simplRecJoinPoint'), we compare the current + cast depth/profiling tick depth with the cast depth/profiling tick depth + of the occurrences of the join point binders. + + If a join point binder occurs under more profiling ticks/casts than its + binding site, then it is a quasi join point and we switch off the + case-of-case transformation. 
************************************************************************ * * ===================================== compiler/GHC/Core/SimpleOpt.hs ===================================== @@ -1076,7 +1076,7 @@ joinPointBinding_maybe bndr rhs | isJoinId bndr = Just (bndr, rhs) - | AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr) + | AlwaysTailCalled { tailCallArity = join_arity } <- tailCallInfo (idOccInfo bndr) , (bndrs, body) <- etaExpandToJoinPoint join_arity rhs , let str_sig = idDmdSig bndr str_arity = count isId bndrs -- Strictness demands are for Ids only ===================================== compiler/GHC/Core/Utils.hs ===================================== @@ -35,6 +35,7 @@ module GHC.Core.Utils ( exprIsTopLevelBindable, exprIsUnaryClassFun, isUnaryClassId, altsAreExhaustive, etaExpansionTick, + cantEtaReduceFun, -- * Equality cheapEqExpr, cheapEqExpr', diffBinds, @@ -2081,9 +2082,24 @@ altsAreExhaustive (Alt con1 _ _ : alts) -- Takes the function we are applying as argument. etaExpansionTick :: Id -> GenTickish pass -> Bool etaExpansionTick id t - = hasNoBinding id && + = ( cantEtaReduceFun id ) && ( tickishFloatable t || isProfTick t ) +-- | Must we refrain from eta-reducing the given function? ('True' means eta-reduction is not allowed.) +-- See Note [Eta reduction soundness], criteria (B), (J), and (W). +cantEtaReduceFun :: Id -> Bool +cantEtaReduceFun fun + = hasNoBinding fun -- (B) + -- Don't undersaturate functions with no binding. + + || isJoinId fun -- (J) + -- Don't undersaturate join points. + -- See Note [Invariants on join points] in GHC.Core, and #20599 + + || isJust (idCbvMarks_maybe fun) -- (W) + -- Don't undersaturate StrictWorkerIds. + -- See Note [CBV Function Ids] in GHC.Types.Id.Info.
+ {- Note [exprOkForSpeculation and type classes] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Consider (#22745, #15205) ===================================== compiler/GHC/CoreToStg/Prep.hs ===================================== @@ -1130,7 +1130,10 @@ cpeApp top_env expr hd = getIdFromTrivialExpr_maybe e2 -- Determine number of required arguments. See Note [Ticks and mandatory eta expansion] min_arity = case hd of - Just v_hd -> if hasNoBinding v_hd then Just $! (idArity v_hd) else Nothing + Just v_hd -> + if cantEtaReduceFun v_hd + then Just $! idArity v_hd + else Nothing Nothing -> Nothing -- ; pprTraceM "cpe_app:stricts:" (ppr v <+> ppr args $$ ppr stricts $$ ppr (idCbvMarks_maybe v)) ; (app, floats, unsat_ticks) <- rebuild_app env args e2 emptyFloats stricts min_arity ===================================== compiler/GHC/Types/Basic.hs ===================================== @@ -70,7 +70,7 @@ module GHC.Types.Basic ( BranchCount, oneBranch, InterestingCxt(..), TailCallInfo(..), tailCallInfo, zapOccTailCallInfo, - isAlwaysTailCalled, + isAlwaysTailCalled, occursUnderProfTicks, occursUnderCasts, EP(..), @@ -1149,8 +1149,14 @@ instance Monoid InsideLam where mappend = (Semi.<>) ----------------- + +-- | See Note [TailCallInfo] data TailCallInfo - = AlwaysTailCalled {-# UNPACK #-} !JoinArity -- See Note [TailCallInfo] + = AlwaysTailCalled + { tailCallArity :: {-# UNPACK #-} !JoinArity + , tailCallUnderProfTicks :: !Int -- See Note [Quasi join points] + , tailCallUnderCasts :: !Int -- See Note [Quasi join points] + } | NoTailCallInfo deriving (Eq) @@ -1167,9 +1173,26 @@ isAlwaysTailCalled occ = case tailCallInfo occ of AlwaysTailCalled{} -> True NoTailCallInfo -> False +-- | If this 'Id' is always tail called, how many profiling ticks does +-- it occur under? See Note [Quasi join points]. 
+occursUnderProfTicks :: OccInfo -> Maybe Int +occursUnderProfTicks occ = + case tailCallInfo occ of + AlwaysTailCalled { tailCallUnderProfTicks = nb } -> Just nb + NoTailCallInfo -> Nothing + +-- | If this 'Id' is always tail called, how many casts does +-- it occur under? See Note [Quasi join points]. +occursUnderCasts :: OccInfo -> Maybe Int +occursUnderCasts occ = + case tailCallInfo occ of + AlwaysTailCalled { tailCallUnderCasts = nb } -> Just nb + NoTailCallInfo -> Nothing + instance Outputable TailCallInfo where - ppr (AlwaysTailCalled ar) = sep [ text "Tail", int ar ] - ppr _ = empty + ppr (AlwaysTailCalled ar p c) = + sep [ text "Tail", brackets (int p <> comma <> int c), int ar ] + ppr NoTailCallInfo = text "NoTailCallInfo" ----------------- strongLoopBreaker, weakLoopBreaker :: OccInfo @@ -1217,7 +1240,10 @@ instance Outputable OccInfo where pp_tail = pprShortTailCallInfo tail_info pprShortTailCallInfo :: TailCallInfo -> SDoc -pprShortTailCallInfo (AlwaysTailCalled ar) = char 'T' <> brackets (int ar) +pprShortTailCallInfo (AlwaysTailCalled ar p c) + = char 'T' <> (brackets (text "P" <+> int p)) + <> (brackets (text "C" <+> int c)) + <> brackets (int ar) pprShortTailCallInfo NoTailCallInfo = empty {- @@ -1251,6 +1277,9 @@ point can also be invoked from other join points, not just from case branches: Here both 'j1' and 'j2' will get marked AlwaysTailCalled, but j1 will get ManyOccs and j2 will get `OneOcc { occ_n_br = 2 }`. +We also store how many profiling ticks and casts the join point occurs under. +The rationale is described in Note [Quasi join points]. 
+ ************************************************************************ * * Default method specification ===================================== compiler/GHC/Types/Tickish.hs ===================================== @@ -11,6 +11,7 @@ module GHC.Types.Tickish ( tickishScopesLike, tickishFloatable, tickishCanSplit, + tickishCanScopeJoin, mkNoCount, mkNoScope, tickishIsCode, @@ -326,6 +327,14 @@ tickishCanSplit ProfNote{profNoteScope = True, profNoteCount = True} = True tickishCanSplit _ = False +-- | Is @join f x in <tick> jump f x@ valid? +tickishCanScopeJoin :: GenTickish pass -> Bool +tickishCanScopeJoin tick = case tick of + ProfNote{} -> True + HpcTick{} -> False + Breakpoint{} -> False + SourceNote{} -> True + mkNoCount :: GenTickish pass -> GenTickish pass mkNoCount n | not (tickishCounts n) = n | not (tickishCanSplit n) = panic "mkNoCount: Cannot split!" View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/190bc4951c50301c36d1af84343f469b... -- View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/190bc4951c50301c36d1af84343f469b... You're receiving this email because of your account on gitlab.haskell.org.
participants (1)
-
sheaf (@sheaf)