[Git][ghc/ghc][wip/andreask/ticked_joins] WIP try harder
sheaf pushed to branch wip/andreask/ticked_joins at Glasgow Haskell Compiler / GHC Commits: a3ef7ffc by sheaf at 2026-01-26T23:12:26+01:00 WIP try harder - - - - - 5 changed files: - compiler/GHC/Core/Lint.hs - compiler/GHC/Core/Opt/OccurAnal.hs - compiler/GHC/Core/Opt/Simplify/Iteration.hs - compiler/GHC/Core/Opt/Simplify/Monad.hs - compiler/GHC/Types/Id/Info.hs Changes: ===================================== compiler/GHC/Core/Lint.hs ===================================== @@ -915,7 +915,7 @@ lintCoreExpr (Lit lit) ; return (literalType lit, zeroUE) } lintCoreExpr (Cast expr co) - = do { (expr_ty, ue) <- markAllJoinsBad (lintCoreExpr expr) + = do { (expr_ty, ue) <- lintCoreExpr expr -- SLD TODO markAllJoinsBad (lintCoreExpr expr) -- markAllJoinsBad: see Note [Join points and casts] ; lintCoercion co @@ -1146,7 +1146,7 @@ checkDeadIdOcc id ------------------ lintJoinBndrType :: OutType -- Type of the body -> OutId -- Possibly a join Id - -> LintM () + -> LintM () -- Checks that the return type of a join Id matches the body -- E.g. join j x = rhs in body -- The type of 'rhs' must be the same as the type of 'body' ===================================== compiler/GHC/Core/Opt/OccurAnal.hs ===================================== @@ -42,7 +42,7 @@ import GHC.Core.Coercion import GHC.Core.Type import GHC.Core.TyCo.FVs ( tyCoVarsOfMCo ) -import GHC.Data.Maybe( orElse ) +import GHC.Data.Maybe( orElse, isNothing ) import GHC.Data.Graph.Directed ( SCC(..), Node(..) , stronglyConnCompFromEdgedVerticesUniq , stronglyConnCompFromEdgedVerticesUniqR ) @@ -68,6 +68,7 @@ import GHC.Builtin.Names( runRWKey ) import GHC.Unit.Module( Module ) import Data.List (mapAccumL) +import qualified Data.List.NonEmpty as NE import qualified Data.Semigroup as Semi {- @@ -2299,7 +2300,7 @@ occ_anal_lam_tail env (Cast expr co) _ -> usage1 -- usage3: see Note [Quasi join points] in GHC.Core.Opt.Simplify.Iteration. 
- usage3 = markAllNonTail usage2 -- SLD TODO + usage3 = markAllQuasiTail usage2 -- SLD TODO in WUD usage3 (Cast expr' co) @@ -2612,7 +2613,7 @@ occAnal env (Cast expr co) = let (WUD usage expr') = occAnal env expr usage1 = addManyOccs usage (coVarsOfCo co) -- usage1: see Note [Gather occurrences of coercion variables] - usage2 = markAllNonTail usage1 -- SLD TODO + usage2 = markAllQuasiTail usage1 -- SLD TODO -- usage2: see Note [Quasi join points] in WUD usage2 (Cast expr' co) @@ -3985,20 +3986,31 @@ adjustNonRecRhs mb_join_arity (WTUD (TUD rhs_ja uds) rhs) where exact_join = case mb_join_arity of - Nothing -> False - Just ja' -> ja' == rhs_ja - -adjustTailUsage :: Bool -- True <=> Exactly-matching join point; don't do markNonTail + Nothing -> Nothing + Just ja' -> + if ja' == rhs_ja + then Just TrueJoinPoint + else Nothing + +adjustTailUsage :: HasDebugCallStack + => Maybe JoinPointType -> CoreExpr -- Rhs usage, AFTER occAnalLamTail -> UsageDetails -> UsageDetails -adjustTailUsage exact_join rhs uds +adjustTailUsage mb_join rhs uds = -- c.f. occAnal (Lam {}) markAllInsideLamIf (not one_shot) $ - markAllNonTailIf (not exact_join) $ + mb_mark_nontail $ uds where - one_shot = isOneShotFun rhs + one_shot = isOneShotFun rhs + mb_mark_nontail = + case mb_join of + Nothing -> markAllNonTail + Just join_ty -> + case join_ty of + QuasiJoinPoint -> markAllQuasiTail + TrueJoinPoint -> id adjustTailArity :: Maybe JoinArity -> TailUsageDetails -> UsageDetails adjustTailArity mb_rhs_ja (TUD ja usage) = markAllNonTailIf not_same_arity usage @@ -4036,11 +4048,8 @@ tagNonRecBinder :: TopLevelFlag -- At top level? 
-- No-op on TyVars -- Precondition: OccInfo is not IAmDead tagNonRecBinder lvl occ bndr - | okForJoinPoint lvl bndr tail_call_info - , AlwaysTailCalled - { tailCallArity = ar - , tailCallJoinPointType = join_ty - } <- tail_call_info + | Just join_ty <- okForJoinPoint lvl bndr tail_call_info + , AlwaysTailCalled { tailCallArity = ar } <- tail_call_info = (setBinderOcc occ bndr, JoinPoint join_ty ar) | otherwise = (setBinderOcc zapped_occ bndr, NotJoinPoint) @@ -4070,7 +4079,7 @@ tagRecBinders lvl body_uds details_s = assertPpr (rhs_ja == joinRhsArity rhs) (ppr rhs_ja $$ ppr uds $$ ppr rhs) $ uds - will_be_joins :: Bool + will_be_joins :: Maybe JoinPointType will_be_joins = decideRecJoinPointHood lvl unadj_uds bndrs -- 2. Adjust usage details of each RHS, taking into account the @@ -4108,42 +4117,50 @@ setBinderOcc occ_info bndr -- -- See Note [Invariants on join points] in "GHC.Core". decideRecJoinPointHood :: TopLevelFlag -> UsageDetails - -> [CoreBndr] -> Bool -decideRecJoinPointHood lvl usage bndrs - = all ok bndrs -- Invariant 3: Either all are join points or none are + -> [CoreBndr] -> Maybe JoinPointType +decideRecJoinPointHood lvl usage bndrs = do + bndrsNE <- NE.nonEmpty bndrs + res <- Semi.sconcat <$> traverse ok bndrsNE -- Invariant 3: Either all are join points or none are + pprTraceM "decideRecJoinPointHood" $ + vcat [ text "bndrs:" <+> ppr bndrs + , text "res:" <+> ppr res ] + return res where ok bndr = okForJoinPoint lvl bndr (lookupTailCallInfo usage bndr) -okForJoinPoint :: TopLevelFlag -> Id -> TailCallInfo -> Bool +okForJoinPoint :: TopLevelFlag -> Id -> TailCallInfo -> Maybe JoinPointType -- See Note [Invariants on join points]; invariants cited by number below. -- Invariant 2 is always satisfiable by the simplifier by eta expansion. okForJoinPoint lvl bndr tail_call_info - | isJoinId bndr -- A current join point should still be one! + | Just join_ty <- joinId_maybe bndr + -- A current join point should still be one! 
= warnPprTrace lost_join "Lost join point" lost_join_doc $ - True - | valid_join - = True + Just join_ty | otherwise - = False + = mb_valid_join where - valid_join | NotTopLevel <- lvl - , AlwaysTailCalled { tailCallArity = arity } <- tail_call_info - - , -- Invariant 1 as applied to LHSes of rules - all (ok_rule arity) (idCoreRules bndr) - - -- Invariant 2a: stable unfoldings - -- See Note [Join points and INLINE pragmas] - , ok_unfolding arity (realIdUnfolding bndr) - - -- Invariant 4: Satisfies polymorphism rule - , isValidJoinPointType arity (idType bndr) - = True - | otherwise - = False + mb_valid_join + | NotTopLevel <- lvl + , AlwaysTailCalled + { tailCallArity = arity + , tailCallJoinPointType = join_ty + } <- tail_call_info + + , -- Invariant 1 as applied to LHSes of rules + all (ok_rule arity) (idCoreRules bndr) + + -- Invariant 2a: stable unfoldings + -- See Note [Join points and INLINE pragmas] + , ok_unfolding arity (realIdUnfolding bndr) + + -- Invariant 4: Satisfies polymorphism rule + , isValidJoinPointType arity (idType bndr) + = Just join_ty + | otherwise + = Nothing lost_join | JoinPoint { joinPointArity = ja } <- idJoinPointHood bndr - = not valid_join || + = isNothing mb_valid_join || (case tail_call_info of -- Valid join but arity differs AlwaysTailCalled { tailCallArity = ja' } -> ja /= ja' _ -> False) ===================================== compiler/GHC/Core/Opt/Simplify/Iteration.hs ===================================== @@ -2056,6 +2056,17 @@ is a join point, and what 'cont' is, in a value of type MaybeJoinCont of a SpecConstr-generated RULE for a join point. 
-} +-- SLD TODO horrible logic that must be removed +peelJoinResTy :: Int -> Type -> Type +peelJoinResTy 0 ty = ty +peelJoinResTy n ty + | Just (_bndr, inner_ty) <- splitForAllTyCoVar_maybe ty + = peelJoinResTy n inner_ty + | Just (_, _mult, _arg, res_ty) <- splitFunTy_maybe ty + = peelJoinResTy (n-1) res_ty + | otherwise + = ty + simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr -> InExpr -> SimplCont -> SimplM (SimplFloats, OutExpr) @@ -2064,8 +2075,12 @@ simplNonRecJoinPoint env bndr rhs body cont wrapJoinCont do_case_case env cont $ \ env cont -> do { -- We push join_cont into the join RHS and the body; -- and wrap wrap_cont around the whole thing - ; let mult = contHoleScaling cont - res_ty = contResultType cont + ; let (mult, res_ty) + -- SLD TODO + | Just QuasiJoinPoint <- occInfoJoinPointType_maybe (idOccInfo bndr) + = (idMult bndr, peelJoinResTy (idJoinArity bndr) $ substTy env (idType bndr)) + | otherwise + = (contHoleScaling cont, contResultType cont) ; (env1, bndr1) <- simplNonRecJoinBndr env bndr mult res_ty ; (env2, bndr2) <- addBndrRules env1 bndr bndr1 (BC_Join NonRecursive cont) ; (floats1, env3) <- simplJoinBind NonRecursive cont (bndr,env) (bndr2,env2) (rhs,env) @@ -2084,8 +2099,13 @@ simplRecJoinPoint :: SimplEnv -> [(InId, InExpr)] simplRecJoinPoint env pairs body cont = wrapJoinCont do_case_case env cont $ \ env cont -> do { let bndrs = map fst pairs - mult = contHoleScaling cont - res_ty = contResultType cont + (mult, res_ty) + -- SLD TODO + | [b] <- bndrs + , Just QuasiJoinPoint <- occInfoJoinPointType_maybe (idOccInfo b) + = (idMult b, peelJoinResTy (idJoinArity b) $ substTy env (idType b)) + | otherwise + = (contHoleScaling cont, contResultType cont) ; env1 <- simplRecJoinBndrs env bndrs mult res_ty -- NB: bndrs' don't have unfoldings or rules -- We add them as we go down @@ -2135,7 +2155,7 @@ trimJoinCont _ NotJoinPoint cont = cont -- Not a jump trimJoinCont var (JoinPoint { joinPointType = join_ty, joinPointArity = arity }) cont | 
QuasiJoinPoint <- join_ty - -- As per Note [Quasi join points], don't do any trimming for quasi join points. + -- SLD TODO = cont | otherwise = trim arity cont ===================================== compiler/GHC/Core/Opt/Simplify/Monad.hs ===================================== @@ -214,7 +214,7 @@ newJoinId bndrs body_ty -- arity: See Note [Invariants on join points] invariant 2b, in GHC.Core join_arity = length bndrs details = JoinId - { joinIdType = TrueJoinPoint -- SLD TODO this is very suspicious + { joinIdType = TrueJoinPoint -- SLD TODO this is suspicious , joinIdArity = join_arity , joinIdCbvMarks = Nothing } ===================================== compiler/GHC/Types/Id/Info.hs ===================================== @@ -441,7 +441,7 @@ pprIdDetails other = brackets (pp other) pp CoVarId = text "CoVarId" pp (JoinId ty arity marks) = quasi <> text "JoinId" <> parens (int arity) <> parens (ppr marks) where - quasi = case ty of { QuasiJoinPoint -> text "quasi"; TrueJoinPoint -> empty } + quasi = case ty of { QuasiJoinPoint -> text "Quasi"; TrueJoinPoint -> empty } {- ************************************************************************ View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/a3ef7ffcc6418ab11ba039915f99513a... -- View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/a3ef7ffcc6418ab11ba039915f99513a... You're receiving this email because of your account on gitlab.haskell.org.
participants (1)
- sheaf (@sheaf)