[Git][ghc/ghc][wip/andreask/ticked_joins] WIP: turn off case-of-case when join point occurs under prof ticks
sheaf pushed to branch wip/andreask/ticked_joins at Glasgow Haskell Compiler / GHC Commits: 8b8e2154 by sheaf at 2026-01-10T11:31:45+01:00 WIP: turn off case-of-case when join point occurs under prof ticks - - - - - 8 changed files: - compiler/GHC/Core/Lint.hs - compiler/GHC/Core/Opt/OccurAnal.hs - compiler/GHC/Core/Opt/Simplify/Env.hs - compiler/GHC/Core/Opt/Simplify/Iteration.hs - compiler/GHC/Core/SimpleOpt.hs - compiler/GHC/Core/Utils.hs - compiler/GHC/CoreToStg/Prep.hs - compiler/GHC/Types/Basic.hs Changes: ===================================== compiler/GHC/Core/Lint.hs ===================================== @@ -672,7 +672,7 @@ lintRhs :: Id -> CoreExpr -> LintM (OutType, UsageEnv) lintRhs bndr rhs | JoinPoint arity <- idJoinPointHood bndr = lintJoinLams arity (Just bndr) rhs - | AlwaysTailCalled arity <- tailCallInfo (idOccInfo bndr) + | AlwaysTailCalled arity _ <- tailCallInfo (idOccInfo bndr) = lintJoinLams arity Nothing rhs -- Allow applications of the data constructor @StaticPtr@ at the top ===================================== compiler/GHC/Core/Opt/OccurAnal.hs ===================================== @@ -2585,7 +2585,13 @@ But it is not necessary to gather CoVars from the types of other binders. occAnal env (Tick tickish body) = WUD usage' (Tick tickish body') where - WUD usage body' = occAnal env body + WUD usage body' = occAnal env' body + + env' = case tickish of + -- Set that we are inside a profiling tick + -- SLD TODO: explain why we need this info + ProfNote {} -> setInProfTick env + _ -> env usage' | tickishCanScopeJoin tickish @@ -2809,6 +2815,13 @@ occAnalApp env (fun, args, ticks) in WUD (markAllNonTail (fun_uds `andUDs` args_uds)) app_out where + -- SLD TODO + -- !_ = pprTrace "occAnalApp fallback: marking all non-tail" + -- ( vcat [ text "fun:" <+> ppr fun + -- , text "args:" <+> ppr args + -- , text "ticks:" <+> ppr ticks + -- ]) + -- () !(WUD args_uds app') = occAnalArgs env fun' args [] !(WUD fun_uds fun') = occAnal (addAppCtxt env args) fun -- The addAppCtxt is a bit cunning. One iteration of the simplifier @@ -2929,6 +2942,7 @@ scrutinised y). data OccEnv = OccEnv { occ_encl :: !OccEncl -- Enclosing context information + , occ_prof_ticks :: !Int , occ_one_shots :: !OneShots -- See Note [OneShots] , occ_unf_act :: Id -> Bool -- Which Id unfoldings are active , occ_rule_act :: ActivationGhc -> Bool -- Which rules are active @@ -2994,6 +3008,7 @@ type OneShots = [OneShotInfo] initOccEnv :: OccEnv initOccEnv = OccEnv { occ_encl = OccVanilla + , occ_prof_ticks = 0 , occ_one_shots = [] -- To be conservative, we say that all @@ -3072,6 +3087,9 @@ setTailCtxt !env = env { occ_encl = OccVanilla } -- Preserve occ_one_shots, occ_join points -- Do not use OccRhs for the RHS of a join point (which is a tail ctxt): +setInProfTick :: OccEnv -> OccEnv +setInProfTick !env = env { occ_prof_ticks = 1 + occ_prof_ticks env } + mkRhsOccEnv :: OccEnv -> RecFlag -> OccEncl -> JoinPointHood -> Id -> CoreExpr -> OccEnv -- See Note [The OccEnv for a right hand side] -- For a join point: @@ -3813,7 +3831,7 @@ mkOneOcc !env id int_cxt arity where occ = OneOccL { lo_n_br = 1 , lo_int_cxt = int_cxt - , lo_tail = AlwaysTailCalled arity } + , lo_tail = AlwaysTailCalled arity (occ_prof_ticks env) } -- Add several occurrences, assumed not to be tail calls add_many_occ :: Var -> OccInfoEnv -> OccInfoEnv @@ -3866,13 +3884,20 @@ delBndrsFromUDs bndrs (UD { ud_env = env, ud_z_many = z_many , ud_z_tail = z_tail `delVarEnvList` bndrs } markAllMany, markAllInsideLam, markAllNonTail, markAllManyNonTail - :: UsageDetails -> UsageDetails + :: HasDebugCallStack => UsageDetails -> UsageDetails markAllMany ud@(UD { ud_env = env }) = ud { ud_z_many = env } markAllInsideLam ud@(UD { ud_env = env }) = ud { ud_z_in_lam = env } -markAllNonTail ud@(UD { ud_env = env }) = ud { ud_z_tail = env } markAllManyNonTail = markAllMany . markAllNonTail -- effectively sets to noOccInfo -markAllInsideLamIf, markAllNonTailIf :: Bool -> UsageDetails -> UsageDetails +markAllNonTail ud@(UD { ud_env = env }) = + if isNullUFM env + then + ud { ud_z_tail = env } + else + -- SLD TODO pprTrace "markAllNonTail" ( text "zapping:" <+> ppr env $$ callStackDoc ) $ + ud { ud_z_tail = env } + +markAllInsideLamIf, markAllNonTailIf :: HasDebugCallStack => Bool -> UsageDetails -> UsageDetails markAllInsideLamIf True ud = markAllInsideLam ud markAllInsideLamIf False ud = ud @@ -3969,7 +3994,7 @@ adjustNonRecRhs mb_join_arity (WTUD (TUD rhs_ja uds) rhs) where exact_join = mb_join_arity == JoinPoint rhs_ja -adjustTailUsage :: Bool -- True <=> Exactly-matching join point; don't do markNonTail +adjustTailUsage :: HasDebugCallStack => Bool -- True <=> Exactly-matching join point; don't do markNonTail -> CoreExpr -- Rhs usage, AFTER occAnalLamTail -> UsageDetails -> UsageDetails @@ -3981,7 +4006,7 @@ adjustTailUsage exact_join rhs uds where one_shot = isOneShotFun rhs -adjustTailArity :: JoinPointHood -> TailUsageDetails -> UsageDetails +adjustTailArity :: HasDebugCallStack => JoinPointHood -> TailUsageDetails -> UsageDetails adjustTailArity mb_rhs_ja (TUD ja usage) = markAllNonTailIf (mb_rhs_ja /= JoinPoint ja) usage @@ -4015,7 +4040,7 @@ tagNonRecBinder :: TopLevelFlag -- At top level? -- Precondition: OccInfo is not IAmDead tagNonRecBinder lvl occ bndr | okForJoinPoint lvl bndr tail_call_info - , AlwaysTailCalled ar <- tail_call_info + , AlwaysTailCalled ar _ <- tail_call_info = (setBinderOcc occ bndr, JoinPoint ar) | otherwise = (setBinderOcc zapped_occ bndr, NotJoinPoint) @@ -4102,7 +4127,7 @@ okForJoinPoint lvl bndr tail_call_info = False where valid_join | NotTopLevel <- lvl - , AlwaysTailCalled arity <- tail_call_info + , AlwaysTailCalled arity _ <- tail_call_info , -- Invariant 1 as applied to LHSes of rules all (ok_rule arity) (idCoreRules bndr) @@ -4120,8 +4145,8 @@ okForJoinPoint lvl bndr tail_call_info lost_join | JoinPoint ja <- idJoinPointHood bndr = not valid_join || (case tail_call_info of -- Valid join but arity differs - AlwaysTailCalled ja' -> ja /= ja' - _ -> False) + AlwaysTailCalled ja' _ -> ja /= ja' + _ -> False) | otherwise = False ok_rule _ BuiltinRule{} = False -- only possible with plugin shenanigans @@ -4143,7 +4168,7 @@ okForJoinPoint lvl bndr tail_call_info , text "tc:" <+> ppr tail_call_info , text "rules:" <+> ppr (idCoreRules bndr) , case tail_call_info of - AlwaysTailCalled arity -> + AlwaysTailCalled arity _ -> vcat [ text "ok_unf:" <+> ppr (ok_unfolding arity (realIdUnfolding bndr)) , text "ok_type:" <+> ppr (isValidJoinPointType arity (idType bndr)) ] _ -> empty ] @@ -4206,6 +4231,6 @@ orLocalOcc (OneOccL { lo_n_br = nbr1, lo_int_cxt = int_cxt1, lo_tail = tci1 }) orLocalOcc occ1 occ2 = andLocalOcc occ1 occ2 andTailCallInfo :: TailCallInfo -> TailCallInfo -> TailCallInfo -andTailCallInfo info@(AlwaysTailCalled arity1) (AlwaysTailCalled arity2) - | arity1 == arity2 = info +andTailCallInfo (AlwaysTailCalled arity1 p1) (AlwaysTailCalled arity2 p2) + | arity1 == arity2 = AlwaysTailCalled arity1 (max p1 p2) andTailCallInfo _ _ = NoTailCallInfo ===================================== compiler/GHC/Core/Opt/Simplify/Env.hs ===================================== @@ -201,6 +201,8 @@ data SimplEnv , seCaseDepth :: !Int -- Depth of multi-branch case alternatives + , seProfTicks :: !Int -- SLD TODO + , seInlineDepth :: !Int -- 0 initially, 1 when we inline an already-simplified -- unfolding, and simplify again; and so on -- See Note [Inline depth] @@ -588,6 +590,7 @@ mkSimplEnv mode fam_envs , seIdSubst = emptyVarEnv , seRecIds = emptyUnVarSet , seCaseDepth = 0 + , seProfTicks = 0 , seInlineDepth = 0 } -- The top level "enclosing CC" is "SUBSUMED". ===================================== compiler/GHC/Core/Opt/Simplify/Iteration.hs ===================================== @@ -39,7 +39,7 @@ import GHC.Core.Opt.Arity ( ArityType, exprArity, arityTypeBotSigs_maybe , pushCoTyArg, pushCoValArg, exprIsDeadEnd , typeArity, arityTypeArity, etaExpandAT ) import GHC.Core.SimpleOpt ( exprIsConApp_maybe, joinPointBinding_maybe, joinPointBindings_maybe ) -import GHC.Core.FVs ( mkRuleInfo {- exprsFreeIds -} ) +import GHC.Core.FVs ( mkRuleInfo ) import GHC.Core.Rules ( lookupRule, getRules ) import GHC.Core.Multiplicity @@ -57,10 +57,11 @@ import GHC.Types.Unique ( hasKey ) import GHC.Types.Basic import GHC.Types.Tickish import GHC.Types.Var ( isTyCoVar ) + import GHC.Builtin.Types.Prim( realWorldStatePrimTy ) import GHC.Builtin.Names( runRWKey, seqHashKey ) -import GHC.Data.Maybe ( isNothing, orElse, mapMaybe ) +import GHC.Data.Maybe ( isNothing, orElse, fromMaybe, mapMaybe ) import GHC.Data.FastString import GHC.Unit.Module ( moduleName ) import GHC.Utils.Outputable @@ -1442,7 +1443,10 @@ simplTick env tickish expr cont no_floating_past_tick = do { let (inc,outc) = splitCont cont - ; (floats, expr1) <- simplExprF env expr inc + env' = case tickish of + ProfNote {} -> env { seProfTicks = seProfTicks env + 1 } + _ -> env + ; (floats, expr1) <- simplExprF env' expr inc ; let expr2 = wrapFloats floats expr1 tickish' = simplTickish env tickish ; rebuild env (mkTick tickish' expr2) outc @@ -2051,8 +2055,8 @@ simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr -> InExpr -> SimplCont -> SimplM (SimplFloats, OutExpr) simplNonRecJoinPoint env bndr rhs body cont - = assert (isJoinId bndr ) $ - wrapJoinCont env cont $ \ env cont -> + = assert (isJoinId bndr) $ + wrapJoinCont do_case_case env cont $ \ env cont -> do { -- We push join_cont into the join RHS and the body; -- and wrap wrap_cont around the whole thing ; let mult = contHoleScaling cont @@ -2062,14 +2066,19 @@ simplNonRecJoinPoint env bndr rhs body cont ; (floats1, env3) <- simplJoinBind NonRecursive cont (bndr,env) (bndr2,env2) (rhs,env) ; (floats2, body') <- simplExprF env3 body cont ; return (floats1 `addFloats` floats2, body') } + where + do_case_case + | Just occMaxProfTicks <- occursUnderProfTick (idOccInfo bndr) + , occMaxProfTicks > seProfTicks env + = False + | otherwise + = seCaseCase env - ------------------- simplRecJoinPoint :: SimplEnv -> [(InId, InExpr)] -> InExpr -> SimplCont -> SimplM (SimplFloats, OutExpr) simplRecJoinPoint env pairs body cont - = wrapJoinCont env cont $ \ env cont -> + = wrapJoinCont do_case_case env cont $ \ env cont -> do { let bndrs = map fst pairs mult = contHoleScaling cont res_ty = contResultType cont @@ -2079,30 +2088,38 @@ simplRecJoinPoint env pairs body cont ; (floats1, env2) <- simplRecBind env1 (BC_Join Recursive cont) pairs ; (floats2, body') <- simplExprF env2 body cont ; return (floats1 `addFloats` floats2, body') } + where + do_case_case + | any ((seProfTicks env <) . fromMaybe 0 . occursUnderProfTick . idOccInfo . fst) pairs + = False + | otherwise + = seCaseCase env -------------------- -wrapJoinCont :: SimplEnv -> SimplCont +wrapJoinCont :: Bool + -> SimplEnv -> SimplCont -> (SimplEnv -> SimplCont -> SimplM (SimplFloats, OutExpr)) -> SimplM (SimplFloats, OutExpr) -- Deal with making the continuation duplicable if necessary, -- and with the no-case-of-case situation. -wrapJoinCont env cont thing_inside +wrapJoinCont do_case_case env cont thing_inside | contIsStop cont -- Common case; no need for fancy footwork = thing_inside env cont - | not (seCaseCase env) - -- See Note [Join points with -fno-case-of-case] - = do { (floats1, expr1) <- thing_inside env (mkBoringStop (contHoleType cont)) - ; let (floats2, expr2) = wrapJoinFloatsX floats1 expr1 - ; (floats3, expr3) <- rebuild (env `setInScopeFromF` floats2) expr2 cont - ; return (floats2 `addFloats` floats3, expr3) } - - | otherwise - -- Normal case; see Note [Join points and case-of-case] + | do_case_case + -- Normal situation: do the "case-of-case" transformation. + -- See Note [Join points and case-of-case]. = do { (floats1, cont') <- mkDupableCont env cont ; (floats2, result) <- thing_inside (env `setInScopeFromF` floats1) cont' ; return (floats1 `addFloats` floats2, result) } + | otherwise + -- No "case-of-case" transformation. + -- See Note [Join points with -fno-case-of-case]. + = do { (floats1, expr1) <- thing_inside env (mkBoringStop (contHoleType cont)) + ; let (floats2, expr2) = wrapJoinFloatsX floats1 expr1 + ; (floats3, expr3) <- rebuild (env `setInScopeFromF` floats2) expr2 cont + ; return (floats2 `addFloats` floats3, expr3) } -------------------- trimJoinCont :: Id -- Used only in error message @@ -2151,13 +2168,13 @@ evaluation context E): As is evident from the example, there are two components to this behavior: - 1. When entering the RHS of a join point, copy the context inside. - 2. When a join point is invoked, discard the outer context. + (wrapJoinCont) When entering the RHS of a join point, copy the context inside. + (trimJoinCont) When a join point is invoked, discard the outer context. We need to be very careful here to remain consistent---neither part is optional! -We need do make the continuation E duplicable (since we are duplicating it) +We need to make the continuation E duplicable (since we are duplicating it) with mkDupableCont. @@ -2184,7 +2201,8 @@ case-of-case we may then end up with this totally bogus result This would be OK in the language of the paper, but not in GHC: j is no longer a join point. We can only do the "push continuation into the RHS of the join point j" if we also push the continuation right down to the /jumps/ to -j, so that it can evaporate there. If we are doing case-of-case, we'll get to +j, so that it can evaporate there (trimJoinCont). Then, if we are doing +case-of-case, we'll get to join x = case <j-rhs> of <outer-alts> in case y of ===================================== compiler/GHC/Core/SimpleOpt.hs ===================================== @@ -1076,7 +1076,7 @@ joinPointBinding_maybe bndr rhs | isJoinId bndr = Just (bndr, rhs) - | AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr) + | AlwaysTailCalled join_arity _ <- tailCallInfo (idOccInfo bndr) , (bndrs, body) <- etaExpandToJoinPoint join_arity rhs , let str_sig = idDmdSig bndr str_arity = count isId bndrs -- Strictness demands are for Ids only ===================================== compiler/GHC/Core/Utils.hs ===================================== @@ -1977,7 +1977,7 @@ altsAreExhaustive (Alt con1 _ _ : alts) -- Takes the function we are applying as argument. etaExpansionTick :: Id -> GenTickish pass -> Bool etaExpansionTick id t - = hasNoBinding id && + = ( hasNoBinding id || isJoinId id ) && -- SLD TODO ( tickishFloatable t || isProfTick t ) {- Note [exprOkForSpeculation and type classes] ===================================== compiler/GHC/CoreToStg/Prep.hs ===================================== @@ -1130,7 +1130,10 @@ cpeApp top_env expr hd = getIdFromTrivialExpr_maybe e2 -- Determine number of required arguments. See Note [Ticks and mandatory eta expansion] min_arity = case hd of - Just v_hd -> if hasNoBinding v_hd then Just $! (idArity v_hd) else Nothing + Just v_hd -> + if hasNoBinding v_hd || isJoinId v_hd -- SLD TODO (re-use cantEtaReduceFun?) + then Just $! idArity v_hd + else Nothing Nothing -> Nothing -- ; pprTraceM "cpe_app:stricts:" (ppr v <+> ppr args $$ ppr stricts $$ ppr (idCbvMarks_maybe v)) ; (app, floats, unsat_ticks) <- rebuild_app env args e2 emptyFloats stricts min_arity ===================================== compiler/GHC/Types/Basic.hs ===================================== @@ -70,7 +70,7 @@ module GHC.Types.Basic ( BranchCount, oneBranch, InterestingCxt(..), TailCallInfo(..), tailCallInfo, zapOccTailCallInfo, - isAlwaysTailCalled, + isAlwaysTailCalled, occursUnderProfTick, EP(..), @@ -1150,7 +1150,7 @@ instance Monoid InsideLam where ----------------- data TailCallInfo - = AlwaysTailCalled {-# UNPACK #-} !JoinArity -- See Note [TailCallInfo] + = AlwaysTailCalled {-# UNPACK #-} !JoinArity !Int-- See Note [TailCallInfo] | NoTailCallInfo deriving (Eq) @@ -1167,9 +1167,15 @@ isAlwaysTailCalled occ = case tailCallInfo occ of AlwaysTailCalled{} -> True NoTailCallInfo -> False +occursUnderProfTick :: OccInfo -> Maybe Int +occursUnderProfTick occ = + case tailCallInfo occ of + AlwaysTailCalled _ b -> Just b + NoTailCallInfo -> Nothing + instance Outputable TailCallInfo where - ppr (AlwaysTailCalled ar) = sep [ text "Tail", int ar ] - ppr _ = empty + ppr (AlwaysTailCalled ar b) = sep [ text "Tail", brackets (int b), int ar ] + ppr _ = text "NoTailCallInfo" --empty ----------------- strongLoopBreaker, weakLoopBreaker :: OccInfo @@ -1217,7 +1223,8 @@ instance Outputable OccInfo where pp_tail = pprShortTailCallInfo tail_info pprShortTailCallInfo :: TailCallInfo -> SDoc -pprShortTailCallInfo (AlwaysTailCalled ar) = char 'T' <> brackets (int ar) +pprShortTailCallInfo (AlwaysTailCalled ar p) + = char 'T' <> (brackets (text "P" <+> int p)) <> brackets (int ar) pprShortTailCallInfo NoTailCallInfo = empty {- View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/8b8e2154801f619b716b9e46e69c65b5... -- View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/8b8e2154801f619b716b9e46e69c65b5... You're receiving this email because of your account on gitlab.haskell.org.
participants (1)
-
sheaf (@sheaf)