[Git][ghc/ghc][wip/andreask/ticked_joins] 2 commits: fix fix
sheaf pushed to branch wip/andreask/ticked_joins at Glasgow Haskell Compiler / GHC Commits: e4500d48 by sheaf at 2026-01-26T23:46:57+01:00 fix fix - - - - - 5239e258 by sheaf at 2026-01-27T00:01:20+01:00 working? - - - - - 4 changed files: - compiler/GHC/Core/Lint.hs - compiler/GHC/Core/Opt/OccurAnal.hs - compiler/GHC/Core/Opt/SetLevels.hs - compiler/GHC/Core/Opt/Simplify/Iteration.hs Changes: ===================================== compiler/GHC/Core/Lint.hs ===================================== @@ -1151,7 +1151,11 @@ lintJoinBndrType :: OutType -- Type of the body -- E.g. join j x = rhs in body -- The type of 'rhs' must be the same as the type of 'body' lintJoinBndrType body_ty bndr - | JoinPoint { joinPointArity = arity } <- idJoinPointHood bndr + | JoinPoint + { joinPointArity = arity + , joinPointType = TrueJoinPoint + -- SLD TODO: quasi join points can have intervening casts + } <- idJoinPointHood bndr , let bndr_ty = idType bndr , (bndrs, res) <- splitPiTys bndr_ty = do let msg = ===================================== compiler/GHC/Core/Opt/OccurAnal.hs ===================================== @@ -1127,7 +1127,7 @@ occAnalNonRecRhs !env lvl imp_rule_edges mb_join bndr rhs -- returned by of occAnalLamTail. 
It's totally OK for them to mismatch; -- hence adjust the UDs from the RHS - WUD adj_rhs_uds final_rhs = adjustNonRecRhs (joinPointHoodArity mb_join) $ + WUD adj_rhs_uds final_rhs = adjustNonRecRhs mb_join $ occAnalLamTail rhs_env rhs final_bndr_with_rules | noBinderSwaps env = bndr -- See Note [Unfoldings and rules] @@ -1217,7 +1217,7 @@ occAnalRec !_ lvl = WUD body_uds binds | otherwise = let (bndr', mb_join) = tagNonRecBinder lvl occ bndr - !(WUD rhs_uds' rhs') = adjustNonRecRhs (joinPointHoodArity mb_join) wtuds + !(WUD rhs_uds' rhs') = adjustNonRecRhs mb_join wtuds in WUD (body_uds `andUDs` rhs_uds') (NonRec bndr' rhs' : binds) where @@ -2621,7 +2621,7 @@ occAnal env app@(App _ _) = occAnalApp env (collectArgsTicks tickishFloatable app) occAnal env expr@(Lam {}) - = adjustNonRecRhs Nothing $ -- Nothing <=> markAllManyNonTail + = adjustNonRecRhs NotJoinPoint $ -- NotJoinPoint <=> markAllManyNonTail occAnalLamTail env expr occAnal env (Case scrut bndr ty alts) @@ -2749,7 +2749,7 @@ occAnalApp env (Var fun, args, ticks) -- This caused #18296 | fun `hasKey` runRWKey , [t1, t2, arg] <- args - , WUD usage arg' <- adjustNonRecRhs (Just 1) $ occAnalLamTail env arg + , WUD usage arg' <- adjustNonRecRhs (JoinPoint TrueJoinPoint 1) $ occAnalLamTail env arg = let app_out = mkTicks ticks $ mkApps (Var fun) [t1, t2, arg'] in WUD usage app_out @@ -3975,21 +3975,21 @@ lookupOccInfoByUnique (UD { ud_env = env ------------------- -- See Note [Adjusting right-hand sides] -adjustNonRecRhs :: Maybe JoinArity +adjustNonRecRhs :: JoinPointHood -> WithTailUsageDetails CoreExpr -> WithUsageDetails CoreExpr -- ^ This function concentrates shared logic between occAnalNonRecBind and the -- AcyclicSCC case of occAnalRec. 
-- It returns the adjusted rhs UsageDetails combined with the body usage -adjustNonRecRhs mb_join_arity (WTUD (TUD rhs_ja uds) rhs) +adjustNonRecRhs mb_join (WTUD (TUD rhs_ja uds) rhs) = WUD (adjustTailUsage exact_join rhs uds) rhs where exact_join = - case mb_join_arity of - Nothing -> Nothing - Just ja' -> + case mb_join of + NotJoinPoint -> Nothing + JoinPoint { joinPointArity = ja', joinPointType = ty } -> if ja' == rhs_ja - then Just TrueJoinPoint + then Just ty else Nothing adjustTailUsage :: HasDebugCallStack @@ -4120,11 +4120,8 @@ decideRecJoinPointHood :: TopLevelFlag -> UsageDetails -> [CoreBndr] -> Maybe JoinPointType decideRecJoinPointHood lvl usage bndrs = do bndrsNE <- NE.nonEmpty bndrs - res <- Semi.sconcat <$> traverse ok bndrsNE -- Invariant 3: Either all are join points or none are - pprTraceM "decideRecJoinPointHood" $ - vcat [ text "bndrs:" <+> ppr bndrs - , text "res:" <+> ppr res ] - return res + -- Invariant 3: Either all are join points or none are + Semi.sconcat <$> traverse ok bndrsNE where ok bndr = okForJoinPoint lvl bndr (lookupTailCallInfo usage bndr) @@ -4132,10 +4129,11 @@ okForJoinPoint :: TopLevelFlag -> Id -> TailCallInfo -> Maybe JoinPointType -- See Note [Invariants on join points]; invariants cited by number below. -- Invariant 2 is always satisfiable by the simplifier by eta expansion. okForJoinPoint lvl bndr tail_call_info - | Just join_ty <- joinId_maybe bndr + | isJoinId bndr -- A current join point should still be one! = warnPprTrace lost_join "Lost join point" lost_join_doc $ - Just join_ty + mb_valid_join + -- NB: we might downgrade 'TrueJoinPoint' to 'QuasiJoinPoint'. 
| otherwise = mb_valid_join where ===================================== compiler/GHC/Core/Opt/SetLevels.hs ===================================== @@ -1895,7 +1895,7 @@ newPolyBndrs dest_lvl , not dest_is_top = asJoinId new_bndr join_ty - ( join_arity + length abs_vars ) + (join_arity + length abs_vars) | otherwise = new_bndr ===================================== compiler/GHC/Core/Opt/Simplify/Iteration.hs ===================================== @@ -2088,7 +2088,7 @@ simplNonRecJoinPoint env bndr rhs body cont ; return (floats1 `addFloats` floats2, body') } where do_case_case - | Just TrueJoinPoint <- occInfoJoinPointType_maybe (idOccInfo bndr) + | Just TrueJoinPoint <- joinId_maybe bndr = seCaseCase env | otherwise = False @@ -2114,7 +2114,7 @@ simplRecJoinPoint env pairs body cont ; return (floats1 `addFloats` floats2, body') } where do_case_case = - if all ((== Just TrueJoinPoint) . occInfoJoinPointType_maybe . idOccInfo . fst) pairs + if all ((== Just TrueJoinPoint) . joinId_maybe . fst) pairs then seCaseCase env else False @@ -2154,15 +2154,15 @@ trimJoinCont :: Id -- Used only in error message trimJoinCont _ NotJoinPoint cont = cont -- Not a jump trimJoinCont var (JoinPoint { joinPointType = join_ty, joinPointArity = arity }) cont - | QuasiJoinPoint <- join_ty - -- SLD TODO - = cont - | otherwise = trim arity cont where trim 0 cont@(Stop {}) = cont trim 0 cont + | QuasiJoinPoint <- join_ty + -- SLD TODO explain + = cont + | otherwise = mkBoringStop (contResultType cont) trim n cont@(ApplyToVal { sc_cont = k }) = cont { sc_cont = trim (n-1) k } View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/a3ef7ffcc6418ab11ba039915f99513... -- View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/a3ef7ffcc6418ab11ba039915f99513... You're receiving this email because of your account on gitlab.haskell.org.
Participants (1):
- sheaf (@sheaf)