sheaf pushed to branch wip/andreask/ticked_joins at Glasgow Haskell Compiler / GHC
Commits:
-
77386709
by sheaf at 2026-01-27T18:46:12+01:00
4 changed files:
- compiler/GHC/Core/Opt/Exitify.hs
- compiler/GHC/Core/Opt/OccurAnal.hs
- compiler/GHC/Core/Opt/Simplify/Iteration.hs
- compiler/GHC/Types/Id.hs
Changes:
| ... | ... | @@ -45,12 +45,14 @@ import GHC.Core.Type |
| 45 | 45 | import GHC.Types.Var
|
| 46 | 46 | import GHC.Types.Id
|
| 47 | 47 | import GHC.Types.Id.Info
|
| 48 | +import GHC.Types.Tickish ( GenTickish(..), tickishCanScopeJoin )
|
|
| 49 | + |
|
| 48 | 50 | import GHC.Types.Var.Set
|
| 49 | 51 | import GHC.Types.Var.Env
|
| 50 | -import GHC.Types.Basic( JoinPointHood(..) )
|
|
| 51 | 52 | |
| 52 | 53 | import GHC.Utils.Monad.State.Strict
|
| 53 | 54 | import GHC.Utils.Misc( mapSnd )
|
| 55 | +import GHC.Utils.Outputable
|
|
| 54 | 56 | |
| 55 | 57 | import GHC.Data.FastString
|
| 56 | 58 | |
| ... | ... | @@ -93,23 +95,23 @@ exitifyProgram binds = map goTopLvl binds |
| 93 | 95 | where
|
| 94 | 96 | in_scope' = in_scope `extendInScopeSet` bndr
|
| 95 | 97 | |
| 96 | - go in_scope (Let (Rec pairs) body)
|
|
| 97 | - | is_join_rec = mkLets (exitifyRec in_scope' pairs') body'
|
|
| 98 | - | otherwise = Let (Rec pairs') body'
|
|
| 98 | + go in_scope (Let (Rec pairs) body) =
|
|
| 99 | + case joinPointType_maybe (joinId_maybe . fst) pairs of
|
|
| 100 | + Just join_ty -> mkLets (exitifyRec join_ty in_scope' pairs') body'
|
|
| 101 | + Nothing -> Let (Rec pairs') body'
|
|
| 99 | 102 | where
|
| 100 | - is_join_rec = any (isJoinId . fst) pairs
|
|
| 101 | 103 | in_scope' = in_scope `extendInScopeSetBind` (Rec pairs)
|
| 102 | 104 | pairs' = mapSnd (go in_scope') pairs
|
| 103 | 105 | body' = go in_scope' body
|
| 104 | 106 | |
| 105 | 107 | |
| 106 | 108 | -- | State Monad used inside `exitify`
|
| 107 | -type ExitifyM = State [(JoinId, CoreExpr)]
|
|
| 109 | +type ExitifyM = State [(JoinId, CoreExpr)]
|
|
| 108 | 110 | |
| 109 | 111 | -- | Given a recursive group of a joinrec, identifies “exit paths” and binds them as
|
| 110 | 112 | -- join-points outside the joinrec.
|
| 111 | -exitifyRec :: InScopeSet -> [(Var,CoreExpr)] -> [CoreBind]
|
|
| 112 | -exitifyRec in_scope pairs
|
|
| 113 | +exitifyRec :: JoinPointType -> InScopeSet -> [(Var,CoreExpr)] -> [CoreBind]
|
|
| 114 | +exitifyRec joinrec_join_ty in_scope pairs
|
|
| 113 | 115 | = [ NonRec xid rhs | (xid,rhs) <- exits ] ++ [Rec pairs']
|
| 114 | 116 | where
|
| 115 | 117 | -- We need the set of free variables of many subexpressions here, so
|
| ... | ... | @@ -124,7 +126,7 @@ exitifyRec in_scope pairs |
| 124 | 126 | forM ann_pairs $ \(x,rhs) -> do
|
| 125 | 127 | -- go past the lambdas of the join point
|
| 126 | 128 | let (args, body) = collectNAnnBndrs (idJoinArity x) rhs
|
| 127 | - body' <- go args body
|
|
| 129 | + body' <- go joinrec_join_ty args body -- (ExitJoin2): start with JoinPointType of parent joinrec
|
|
| 128 | 130 | let rhs' = mkLams args body'
|
| 129 | 131 | return (x, rhs')
|
| 130 | 132 | |
| ... | ... | @@ -135,40 +137,41 @@ exitifyRec in_scope pairs |
| 135 | 137 | -- variables bound on the way and lifts it out as a join point.
|
| 136 | 138 | --
|
| 137 | 139 | -- ExitifyM is a state monad to keep track of floated binds
|
| 138 | - go :: [Var] -- Variables that are in-scope here, but
|
|
| 139 | - -- not in scope at the joinrec; that is,
|
|
| 140 | - -- we must potentially abstract over them.
|
|
| 141 | - -- Invariant: they are kept in dependency order
|
|
| 140 | + go :: JoinPointType -- what join point type to create; see Note [Exitification and quasi join points]
|
|
| 141 | + -> [Var] -- Variables that are in-scope here, but
|
|
| 142 | + -- not in scope at the joinrec; that is,
|
|
| 143 | + -- we must potentially abstract over them.
|
|
| 144 | + -- Invariant: they are kept in dependency order
|
|
| 142 | 145 | -> CoreExprWithFVs -- Current expression in tail position
|
| 143 | 146 | -> ExitifyM CoreExpr
|
| 144 | 147 | |
| 145 | 148 | -- We first look at the expression (no matter what its shape is)
|
| 146 | 149 | -- and determine if we can turn it into an exit join point
|
| 147 | - go captured ann_e
|
|
| 150 | + go exit_join_ty captured ann_e
|
|
| 148 | 151 | | -- An exit expression has no recursive calls
|
| 149 | 152 | let fvs = dVarSetToVarSet (freeVarsOf ann_e)
|
| 150 | 153 | , disjointVarSet fvs recursive_calls
|
| 151 | - = go_exit captured (deAnnotate ann_e) fvs
|
|
| 154 | + = go_exit exit_join_ty captured (deAnnotate ann_e) fvs
|
|
| 152 | 155 | |
| 153 | 156 | -- We could not turn it into an exit join point. So now recurse
|
| 154 | 157 | -- into all expressions where eligible exit join points might sit,
|
| 155 | 158 | -- i.e. into all tail-call positions:
|
| 156 | 159 | |
| 157 | 160 | -- Case right hand sides are in tail-call position
|
| 158 | - go captured (_, AnnCase scrut bndr ty alts) = do
|
|
| 161 | + go exit_join_ty captured (_, AnnCase scrut bndr ty alts) = do
|
|
| 159 | 162 | alts' <- forM alts $ \(AnnAlt dc pats rhs) -> do
|
| 160 | - rhs' <- go (captured ++ [bndr] ++ pats) rhs
|
|
| 163 | + rhs' <- go exit_join_ty (captured ++ [bndr] ++ pats) rhs
|
|
| 161 | 164 | return (Alt dc pats rhs')
|
| 162 | 165 | return $ Case (deAnnotate scrut) bndr ty alts'
|
| 163 | 166 | |
| 164 | - go captured (_, AnnLet ann_bind body)
|
|
| 167 | + go exit_join_ty captured (_, AnnLet ann_bind body)
|
|
| 165 | 168 | -- join point, RHS and body are in tail-call position
|
| 166 | 169 | | AnnNonRec j rhs <- ann_bind
|
| 167 | 170 | , JoinPoint { joinPointArity = join_arity } <- idJoinPointHood j
|
| 168 | 171 | = do let (params, join_body) = collectNAnnBndrs join_arity rhs
|
| 169 | - join_body' <- go (captured ++ params) join_body
|
|
| 172 | + join_body' <- go exit_join_ty (captured ++ params) join_body
|
|
| 170 | 173 | let rhs' = mkLams params join_body'
|
| 171 | - body' <- go (captured ++ [j]) body
|
|
| 174 | + body' <- go exit_join_ty (captured ++ [j]) body
|
|
| 172 | 175 | return $ Let (NonRec j rhs') body'
|
| 173 | 176 | |
| 174 | 177 | -- rec join point, RHSs and body are in tail-call position
|
| ... | ... | @@ -178,30 +181,41 @@ exitifyRec in_scope pairs |
| 178 | 181 | pairs' <- forM pairs $ \(j,rhs) -> do
|
| 179 | 182 | let join_arity = idJoinArity j
|
| 180 | 183 | (params, join_body) = collectNAnnBndrs join_arity rhs
|
| 181 | - join_body' <- go (captured ++ js ++ params) join_body
|
|
| 184 | + join_body' <- go exit_join_ty (captured ++ js ++ params) join_body
|
|
| 182 | 185 | let rhs' = mkLams params join_body'
|
| 183 | 186 | return (j, rhs')
|
| 184 | - body' <- go (captured ++ js) body
|
|
| 187 | + body' <- go exit_join_ty (captured ++ js) body
|
|
| 185 | 188 | return $ Let (Rec pairs') body'
|
| 186 | 189 | |
| 187 | 190 | -- normal Let, only the body is in tail-call position
|
| 188 | 191 | | otherwise
|
| 189 | - = do body' <- go (captured ++ bindersOf bind ) body
|
|
| 192 | + = do body' <- go exit_join_ty (captured ++ bindersOf bind ) body
|
|
| 190 | 193 | return $ Let bind body'
|
| 191 | 194 | where bind = deAnnBind ann_bind
|
| 192 | 195 | |
| 196 | + -- (ExitJoin1) from Note [Exitification and quasi join points]
|
|
| 197 | + go _ captured (_, AnnCast ann_e (_, co)) = do
|
|
| 198 | + e' <- go QuasiJoinPoint captured ann_e
|
|
| 199 | + return (Cast e' co)
|
|
| 200 | + go exit_join_ty captured (_, AnnTick tickish ann_e)
|
|
| 201 | + | tickishCanScopeJoin tickish
|
|
| 202 | + = Tick tickish <$> go exit_join_ty captured ann_e
|
|
| 203 | + | ProfNote {} <- tickish
|
|
| 204 | + = Tick tickish <$> go QuasiJoinPoint captured ann_e
|
|
| 205 | + |
|
| 193 | 206 | -- Cannot be turned into an exit join point, but also has no
|
| 194 | 207 | -- tail-call subexpression. Nothing to do here.
|
| 195 | - go _ ann_e = return (deAnnotate ann_e)
|
|
| 208 | + go _ _ ann_e = return (deAnnotate ann_e)
|
|
| 196 | 209 | |
| 197 | 210 | ---------------------
|
| 198 | - go_exit :: [Var] -- Variables captured locally
|
|
| 211 | + go_exit :: JoinPointType -- what join point type to create; see Note [Exitification and quasi join points]
|
|
| 212 | + -> [Var] -- Variables captured locally
|
|
| 199 | 213 | -> CoreExpr -- An exit expression
|
| 200 | 214 | -> VarSet -- Free vars of the expression
|
| 201 | 215 | -> ExitifyM CoreExpr
|
| 202 | 216 | -- go_exit deals with a tail expression that is floatable
|
| 203 | 217 | -- out as an exit point; that is, it mentions no recursive calls
|
| 204 | - go_exit captured e fvs
|
|
| 218 | + go_exit exit_join_ty captured e fvs
|
|
| 205 | 219 | -- Do not touch an expression that is already a join jump where all arguments
|
| 206 | 220 | -- are captured variables. See Note [Idempotency]
|
| 207 | 221 | -- But _do_ float join jumps with interesting arguments.
|
| ... | ... | @@ -226,7 +240,7 @@ exitifyRec in_scope pairs |
| 226 | 240 | let rhs = mkLams abs_vars e
|
| 227 | 241 | avoid = in_scope `extendInScopeSetList` captured
|
| 228 | 242 | -- Remember this binding under a suitable name
|
| 229 | - ; v <- addExit avoid (length abs_vars) rhs
|
|
| 243 | + ; v <- addExit avoid exit_join_ty (length abs_vars) rhs
|
|
| 230 | 244 | -- And jump to it from here
|
| 231 | 245 | ; return $ mkVarApps (Var v) abs_vars }
|
| 232 | 246 | |
| ... | ... | @@ -263,7 +277,7 @@ exitifyRec in_scope pairs |
| 263 | 277 | -- * any bound variables (captured)
|
| 264 | 278 | -- * any exit join points created so far.
|
| 265 | 279 | mkExitJoinId :: InScopeSet -> Type -> JoinPointType -> JoinArity -> ExitifyM JoinId
|
| 266 | -mkExitJoinId in_scope ty join_ty join_arity = do
|
|
| 280 | +mkExitJoinId in_scope ty exit_join_ty join_arity = do
|
|
| 267 | 281 | fs <- get
|
| 268 | 282 | let avoid = in_scope `extendInScopeSetList` (map fst fs)
|
| 269 | 283 | `extendInScopeSet` exit_id_tmpl -- just cosmetics
|
| ... | ... | @@ -271,17 +285,65 @@ mkExitJoinId in_scope ty join_ty join_arity = do |
| 271 | 285 | where
|
| 272 | 286 | exit_id_tmpl =
|
| 273 | 287 | asJoinId (mkSysLocal (fsLit "exit") initExitJoinUnique ManyTy ty)
|
| 274 | - join_ty join_arity
|
|
| 288 | + exit_join_ty join_arity
|
|
| 275 | 289 | |
| 276 | -addExit :: InScopeSet -> JoinArity -> CoreExpr -> ExitifyM JoinId
|
|
| 277 | -addExit in_scope join_arity rhs = do
|
|
| 290 | +addExit :: InScopeSet -> JoinPointType -> JoinArity -> CoreExpr -> ExitifyM JoinId
|
|
| 291 | +addExit in_scope exit_join_ty join_arity rhs = do
|
|
| 278 | 292 | -- Pick a suitable name
|
| 279 | 293 | let ty = exprType rhs
|
| 280 | - v <- mkExitJoinId in_scope ty TrueJoinPoint join_arity
|
|
| 294 | + v <- mkExitJoinId in_scope ty exit_join_ty join_arity
|
|
| 281 | 295 | fs <- get
|
| 282 | 296 | put ((v,rhs):fs)
|
| 283 | 297 | return v
|
| 284 | 298 | |
| 299 | +{- Note [Exitification and quasi join points]
|
|
| 300 | +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
| 301 | +When we float an exit path, we must determine if the new exit join point
|
|
| 302 | +should be a true join point or a quasi join point, in the sense of
|
|
| 303 | +Note [Quasi join points] in GHC.Core.Opt.Simplify.Iteration.
|
|
| 304 | + |
|
| 305 | +The new exit join point must be a quasi join point if either of the following
|
|
| 306 | +conditions apply:
|
|
| 307 | + |
|
| 308 | + (ExitJoin1) The exit path occurs under a cast or a profiling tick.
|
|
| 309 | + |
|
| 310 | + (ExitJoin2) The original joinrec was a quasi join point.
|
|
| 311 | + |
|
| 312 | +Rationale for (ExitJoin1):
|
|
| 313 | + |
|
| 314 | + Suppose we have:
|
|
| 315 | + |
|
| 316 | + joinrec j x = ... case ... of alts -> e |> co ... in ...
|
|
| 317 | + |
|
| 318 | + After exitifying 'e' to 'exit':
|
|
| 319 | + |
|
| 320 | + join exit y = e in
|
|
| 321 | + joinrec j x = ... case ... of alts -> (exit y) |> co ... in ...
|
|
| 322 | + |
|
| 323 | + Because the jump to 'exit' occurs under a cast, 'exit' must be classified
|
|
| 324 | + as a quasi join point.
|
|
| 325 | + |
|
| 326 | +Rationale for (ExitJoin2):
|
|
| 327 | + |
|
| 328 | + Suppose we have:
|
|
| 329 | + |
|
| 330 | + quasijoinrec j x = case x of { 0 -> 100; _ -> j (x-1) } in j 0 |> co
|
|
| 331 | + |
|
| 332 | + If we float an exit out of 'j', we end up with
|
|
| 333 | + |
|
| 334 | + join exit = 100 in
|
|
| 335 | + quasijoinrec j x = case x of { 0 -> exit ; _ -> j (x-1) } in j 0 |> co
|
|
| 336 | + |
|
| 337 | + Now suppose we inline j and simplify; we end up with:
|
|
| 338 | + |
|
| 339 | + join exit = 100 in exit |> co
|
|
| 340 | + |
|
| 341 | + We see now that 'exit' must be a quasi join point, due to the cast.
|
|
| 342 | + |
|
| 343 | + Hence: exit join points for a parent quasi join point must themselves be
|
|
| 344 | + quasi join points.
|
|
| 345 | +-}
|
|
| 346 | + |
|
| 285 | 347 | {-
|
| 286 | 348 | Note [Interesting expression]
|
| 287 | 349 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| ... | ... | @@ -68,7 +68,6 @@ import GHC.Builtin.Names( runRWKey ) |
| 68 | 68 | import GHC.Unit.Module( Module )
|
| 69 | 69 | |
| 70 | 70 | import Data.List (mapAccumL)
|
| 71 | -import qualified Data.List.NonEmpty as NE
|
|
| 72 | 71 | import qualified Data.Semigroup as Semi
|
| 73 | 72 | |
| 74 | 73 | {-
|
| ... | ... | @@ -4118,10 +4117,7 @@ setBinderOcc occ_info bndr |
| 4118 | 4117 | -- See Note [Invariants on join points] in "GHC.Core".
|
| 4119 | 4118 | decideRecJoinPointHood :: TopLevelFlag -> UsageDetails
|
| 4120 | 4119 | -> [CoreBndr] -> Maybe JoinPointType
|
| 4121 | -decideRecJoinPointHood lvl usage bndrs = do
|
|
| 4122 | - bndrsNE <- NE.nonEmpty bndrs
|
|
| 4123 | - -- Invariant 3: Either all are join points or none are
|
|
| 4124 | - Semi.sconcat <$> traverse ok bndrsNE
|
|
| 4120 | +decideRecJoinPointHood lvl usage = joinPointType_maybe ok
|
|
| 4125 | 4121 | where
|
| 4126 | 4122 | ok bndr = okForJoinPoint lvl bndr (lookupTailCallInfo usage bndr)
|
| 4127 | 4123 |
| ... | ... | @@ -2056,93 +2056,118 @@ is a join point, and what 'cont' is, in a value of type MaybeJoinCont |
| 2056 | 2056 | of a SpecConstr-generated RULE for a join point.
|
| 2057 | 2057 | -}
|
| 2058 | 2058 | |
| 2059 | --- SLD TODO horrible logic that must be removed
|
|
| 2060 | -peelJoinResTy :: Int -> Type -> Type
|
|
| 2061 | -peelJoinResTy 0 ty = ty
|
|
| 2062 | -peelJoinResTy n ty
|
|
| 2063 | - | Just (_bndr, inner_ty) <- splitForAllTyCoVar_maybe ty
|
|
| 2064 | - = peelJoinResTy n inner_ty
|
|
| 2065 | - | Just (_, _mult, _arg, res_ty) <- splitFunTy_maybe ty
|
|
| 2066 | - = peelJoinResTy (n-1) res_ty
|
|
| 2067 | - | otherwise
|
|
| 2068 | - = ty
|
|
| 2059 | +joinResTy :: HasDebugCallStack => JoinArity -> Type -> Type
|
|
| 2060 | +joinResTy n0 ty0 = go n0 ty0
|
|
| 2061 | + where
|
|
| 2062 | + go 0 ty = ty
|
|
| 2063 | + go n ty
|
|
| 2064 | + | Just (_bndr, res_ty) <- splitPiTy_maybe ty
|
|
| 2065 | + = go (n-1) res_ty
|
|
| 2066 | + | otherwise
|
|
| 2067 | + = pprPanic "joinResTy" $
|
|
| 2068 | + vcat [ text "join arity:" <+> ppr n0
|
|
| 2069 | + , text "join ty:" <+> ppr ty0
|
|
| 2070 | + , text "n:" <+> ppr n
|
|
| 2071 | + , text "ty:" <+> ppr ty
|
|
| 2072 | + ]
|
|
| 2069 | 2073 | |
| 2070 | 2074 | simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr
|
| 2071 | 2075 | -> InExpr -> SimplCont
|
| 2072 | 2076 | -> SimplM (SimplFloats, OutExpr)
|
| 2073 | -simplNonRecJoinPoint env bndr rhs body cont
|
|
| 2077 | +simplNonRecJoinPoint env0 bndr rhs body cont0
|
|
| 2074 | 2078 | = assert (isJoinId bndr) $
|
| 2075 | - wrapJoinCont do_case_case env cont $ \ env cont ->
|
|
| 2079 | + wrapJoinCont do_case_case env0 bndr cont0 $
|
|
| 2080 | + \ WJC { wjc_bind_env = env, wjc_bind_cont = bind_cont, wjc_body_cont = body_cont } ->
|
|
| 2076 | 2081 | do { -- We push join_cont into the join RHS and the body;
|
| 2077 | 2082 | -- and wrap wrap_cont around the whole thing
|
| 2078 | - ; let (mult, res_ty)
|
|
| 2079 | - -- SLD TODO
|
|
| 2080 | - | Just QuasiJoinPoint <- joinId_maybe bndr
|
|
| 2081 | - = (idMult bndr, peelJoinResTy (idJoinArity bndr) $ substTy env (idType bndr))
|
|
| 2082 | - | otherwise
|
|
| 2083 | - = (contHoleScaling cont, contResultType cont)
|
|
| 2083 | + let mult = contHoleScaling bind_cont
|
|
| 2084 | + res_ty = contResultType bind_cont
|
|
| 2084 | 2085 | ; (env1, bndr1) <- simplNonRecJoinBndr env bndr mult res_ty
|
| 2085 | - ; (env2, bndr2) <- addBndrRules env1 bndr bndr1 (BC_Join NonRecursive cont)
|
|
| 2086 | - ; (floats1, env3) <- simplJoinBind NonRecursive cont (bndr,env) (bndr2,env2) (rhs,env)
|
|
| 2087 | - ; (floats2, body') <- simplExprF env3 body cont
|
|
| 2086 | + ; (env2, bndr2) <- addBndrRules env1 bndr bndr1 (BC_Join NonRecursive bind_cont)
|
|
| 2087 | + ; (floats1, env3) <- simplJoinBind NonRecursive bind_cont (bndr,env) (bndr2,env2) (rhs,env)
|
|
| 2088 | + ; (floats2, body') <- simplExprF env3 body body_cont
|
|
| 2088 | 2089 | ; return (floats1 `addFloats` floats2, body') }
|
| 2089 | 2090 | where
|
| 2090 | 2091 | do_case_case
|
| 2091 | 2092 | | Just TrueJoinPoint <- joinId_maybe bndr
|
| 2092 | - = seCaseCase env
|
|
| 2093 | + = seCaseCase env0
|
|
| 2093 | 2094 | | otherwise
|
| 2094 | 2095 | = False
|
| 2095 | 2096 | |
| 2096 | 2097 | simplRecJoinPoint :: SimplEnv -> [(InId, InExpr)]
|
| 2097 | 2098 | -> InExpr -> SimplCont
|
| 2098 | 2099 | -> SimplM (SimplFloats, OutExpr)
|
| 2099 | -simplRecJoinPoint env pairs body cont
|
|
| 2100 | - = wrapJoinCont do_case_case env cont $ \ env cont ->
|
|
| 2101 | - do { let bndrs = map fst pairs
|
|
| 2102 | - (mult, res_ty)
|
|
| 2103 | - -- SLD TODO
|
|
| 2104 | - | [b] <- bndrs
|
|
| 2105 | - , Just QuasiJoinPoint <- joinId_maybe b
|
|
| 2106 | - = (idMult b, peelJoinResTy (idJoinArity b) $ substTy env (idType b))
|
|
| 2107 | - | otherwise
|
|
| 2108 | - = (contHoleScaling cont, contResultType cont)
|
|
| 2100 | +simplRecJoinPoint env0 pairs body cont0
|
|
| 2101 | + = wrapJoinCont do_case_case env0 (head bndrs) cont0 $
|
|
| 2102 | + \ WJC { wjc_bind_env = env, wjc_bind_cont = bind_cont, wjc_body_cont = body_cont } ->
|
|
| 2103 | + do { let mult = contHoleScaling bind_cont
|
|
| 2104 | + res_ty = contResultType bind_cont
|
|
| 2109 | 2105 | ; env1 <- simplRecJoinBndrs env bndrs mult res_ty
|
| 2110 | 2106 | -- NB: bndrs' don't have unfoldings or rules
|
| 2111 | 2107 | -- We add them as we go down
|
| 2112 | - ; (floats1, env2) <- simplRecBind env1 (BC_Join Recursive cont) pairs
|
|
| 2113 | - ; (floats2, body') <- simplExprF env2 body cont
|
|
| 2108 | + ; (floats1, env2) <- simplRecBind env1 (BC_Join Recursive bind_cont) pairs
|
|
| 2109 | + ; (floats2, body') <- simplExprF env2 body body_cont
|
|
| 2114 | 2110 | ; return (floats1 `addFloats` floats2, body') }
|
| 2115 | 2111 | where
|
| 2112 | + bndrs = map fst pairs
|
|
| 2113 | + |
|
| 2116 | 2114 | do_case_case =
|
| 2117 | - if all ((== Just TrueJoinPoint) . joinId_maybe . fst) pairs
|
|
| 2118 | - then seCaseCase env
|
|
| 2115 | + if all ((== Just TrueJoinPoint) . joinId_maybe) bndrs
|
|
| 2116 | + then seCaseCase env0
|
|
| 2119 | 2117 | else False
|
| 2120 | 2118 | |
| 2121 | 2119 | --------------------
|
| 2120 | + |
|
| 2121 | +-- | Information computed by 'wrapJoinCont'.
|
|
| 2122 | +data WrapJoinCont
|
|
| 2123 | + = WJC
|
|
| 2124 | + { wjc_bind_env :: !SimplEnv
|
|
| 2125 | + , wjc_bind_cont :: !SimplCont
|
|
| 2126 | + , wjc_body_cont :: !SimplCont
|
|
| 2127 | + }
|
|
| 2128 | + |
|
| 2122 | 2129 | wrapJoinCont :: Bool
|
| 2123 | - -> SimplEnv -> SimplCont
|
|
| 2124 | - -> (SimplEnv -> SimplCont -> SimplM (SimplFloats, OutExpr))
|
|
| 2130 | + -> SimplEnv -> InId -> SimplCont
|
|
| 2131 | + -> (WrapJoinCont -> SimplM (SimplFloats, OutExpr))
|
|
| 2125 | 2132 | -> SimplM (SimplFloats, OutExpr)
|
| 2126 | 2133 | -- Deal with making the continuation duplicable if necessary,
|
| 2127 | 2134 | -- and with the no-case-of-case situation.
|
| 2128 | -wrapJoinCont do_case_case env cont thing_inside
|
|
| 2135 | +wrapJoinCont do_case_case env join_bndr cont thing_inside
|
|
| 2129 | 2136 | | contIsStop cont -- Common case; no need for fancy footwork
|
| 2130 | - = thing_inside env cont
|
|
| 2137 | + = thing_inside $
|
|
| 2138 | + WJC { wjc_bind_env = env
|
|
| 2139 | + , wjc_bind_cont = if do_case_case then cont else no_case_case_bind_cont
|
|
| 2140 | + , wjc_body_cont = cont
|
|
| 2141 | + }
|
|
| 2131 | 2142 | |
| 2132 | 2143 | | do_case_case
|
| 2133 | 2144 | -- Normal situation: do the "case-of-case" transformation.
|
| 2134 | 2145 | -- See Note [Join points and case-of-case].
|
| 2135 | 2146 | = do { (floats1, cont') <- mkDupableCont env cont
|
| 2136 | - ; (floats2, result) <- thing_inside (env `setInScopeFromF` floats1) cont'
|
|
| 2147 | + ; let wjc = WJC { wjc_bind_env = env `setInScopeFromF` floats1
|
|
| 2148 | + , wjc_bind_cont = cont'
|
|
| 2149 | + , wjc_body_cont = cont'
|
|
| 2150 | + }
|
|
| 2151 | + ; (floats2, result) <- thing_inside wjc
|
|
| 2137 | 2152 | ; return (floats1 `addFloats` floats2, result) }
|
| 2138 | 2153 | |
| 2139 | 2154 | | otherwise
|
| 2140 | 2155 | -- No "case-of-case" transformation.
|
| 2141 | 2156 | -- See Note [Join points with -fno-case-of-case].
|
| 2142 | - = do { (floats1, expr1) <- thing_inside env (mkBoringStop (contHoleType cont))
|
|
| 2157 | + = do { let
|
|
| 2158 | + wjc = WJC { wjc_bind_env = env
|
|
| 2159 | + , wjc_bind_cont = no_case_case_bind_cont
|
|
| 2160 | + , wjc_body_cont = mkBoringStop (contHoleType cont)
|
|
| 2161 | + }
|
|
| 2162 | + ; (floats1, expr1) <- thing_inside wjc
|
|
| 2143 | 2163 | ; let (floats2, expr2) = wrapJoinFloatsX floats1 expr1
|
| 2144 | 2164 | ; (floats3, expr3) <- rebuild (env `setInScopeFromF` floats2) expr2 cont
|
| 2145 | 2165 | ; return (floats2 `addFloats` floats3, expr3) }
|
| 2166 | + where
|
|
| 2167 | + -- See Wrinkle [Casts and join point result types]
|
|
| 2168 | + join_res_ty = joinResTy (idJoinArity join_bndr)
|
|
| 2169 | + $ substTy env (idType join_bndr)
|
|
| 2170 | + no_case_case_bind_cont = mkBoringStop join_res_ty
|
|
| 2146 | 2171 | |
| 2147 | 2172 | --------------------
|
| 2148 | 2173 | trimJoinCont :: Id -- Used only in error message
|
| ... | ... | @@ -2282,9 +2307,9 @@ As per Note [Join points and case-of-case], we proceed by first applying the |
| 2282 | 2307 | argument to both the join point RHS and the case alternatives:
|
| 2283 | 2308 | |
| 2284 | 2309 | join { j :: Bool -> IO (); j _ = guts arg ] }
|
| 2285 | - in case b of
|
|
| 2286 | - False -> (scctick<foo> jump j True) arg
|
|
| 2287 | - True -> jump j False arg
|
|
| 2310 | + in case b of
|
|
| 2311 | + False -> (scctick<foo> jump j True) arg
|
|
| 2312 | + True -> jump j False arg
|
|
| 2288 | 2313 | |
| 2289 | 2314 | Then we rely on 'trimJoinCont' to remove the argument. In this case, this fails
|
| 2290 | 2315 | for the first branch, because 'trimJoinCont' doesn't look through profiling
|
| ... | ... | @@ -2293,9 +2318,9 @@ end up with, as we don't want to misattribute profiling costs. |
| 2293 | 2318 | We could plausibly transform to the following:
|
| 2294 | 2319 | |
| 2295 | 2320 | join { j :: Bool -> IO (); j scc_or_null _ = (setSCC# scc_or_null guts) arg ] }
|
| 2296 | - in case b of
|
|
| 2297 | - False -> jump j <foo> True
|
|
| 2298 | - True -> jump j null False
|
|
| 2321 | + in case b of
|
|
| 2322 | + False -> jump j <foo> True
|
|
| 2323 | + True -> jump j null False
|
|
| 2299 | 2324 | |
| 2300 | 2325 | where `setSCC#` is a new primop that would set the current cost centre pointer
|
| 2301 | 2326 | (or no-op if the given pointer is null).
|
| ... | ... | @@ -2307,17 +2332,17 @@ So instead, for now, we simply disallow the case-of-case transformation for 'j'. |
| 2307 | 2332 | Similarly for casts:
|
| 2308 | 2333 | |
| 2309 | 2334 | join { j = blah }
|
| 2310 | - in case e of
|
|
| 2311 | - False -> j True |> co1
|
|
| 2312 | - True -> j False |> co2
|
|
| 2335 | + in case e of
|
|
| 2336 | + False -> j True |> co1
|
|
| 2337 | + True -> j False |> co2
|
|
| 2313 | 2338 | |
| 2314 | 2339 | if we want to apply this to an argument 'arg', we would need to perform the
|
| 2315 | 2340 | following transformation:
|
| 2316 | 2341 | |
| 2317 | 2342 | join { j co = ( blah |> co ) arg }
|
| 2318 | - in case e of
|
|
| 2319 | - False -> j co1 True
|
|
| 2320 | - True -> j co2 False
|
|
| 2343 | + in case e of
|
|
| 2344 | + False -> j co1 True
|
|
| 2345 | + True -> j co2 False
|
|
| 2321 | 2346 | |
| 2322 | 2347 | in which we add a coercion argument to the join point. Again, this is not a
|
| 2323 | 2348 | transformation we currently implement, so we instead prevent case-of-case for
|
| ... | ... | @@ -2339,6 +2364,33 @@ we proceed as follows: |
| 2339 | 2364 | If we are dealing with a quasi join point, we switch off the case-of-case
|
| 2340 | 2365 | transformation.
|
| 2341 | 2366 | |
| 2367 | +Wrinkle [Casts and join point result types]
|
|
| 2368 | + |
|
| 2369 | + When dealing with a quasi join point, we must preserve the original type of
|
|
| 2370 | + the join point instead of transforming the type (as in GHC.Core.Opt.Simplify.Env.adjustJoinPointType).
|
|
| 2371 | + This is because we don't trim the continuation like we do in
|
|
| 2372 | + Note [Join points and case-of-case].
|
|
| 2373 | + |
|
| 2374 | + For example, suppose we have:
|
|
| 2375 | + |
|
| 2376 | + type family F a
|
|
| 2377 | + |
|
| 2378 | + join
|
|
| 2379 | + j :: forall a. a -> F a
|
|
| 2380 | + j @a x = ...
|
|
| 2381 | + in case e of
|
|
| 2382 | + False -> j @T1 x1 |> ( co1 :: F T1 ~ Int )
|
|
| 2383 | + True -> j @T2 x2 |> ( co2 :: F T2 ~ Int )
|
|
| 2384 | + |
|
| 2385 | + If we used 'contHoleType cont' to compute the result type of 'j', we would
|
|
| 2386 | + change the result type of 'j' to 'Int', when it needs to remain 'F a'.
|
|
| 2387 | + |
|
| 2388 | + Instead, we avoid doing that and re-compute the result type of 'j' using
|
|
| 2389 | + 'joinResTy' to get 'F a', as required.
|
|
| 2390 | + |
|
| 2391 | +See also Note [Exitification and quasi join points] in GHC.Core.Opt.Exitify
|
|
| 2392 | +for another wrinkle.
|
|
| 2393 | + |
|
| 2342 | 2394 | ************************************************************************
|
| 2343 | 2395 | * *
|
| 2344 | 2396 | Variables
|
| ... | ... | @@ -79,7 +79,8 @@ module GHC.Types.Id ( |
| 79 | 79 | |
| 80 | 80 | -- ** Join variables
|
| 81 | 81 | JoinId, JoinPointHood,
|
| 82 | - isJoinId, joinId_maybe, idJoinPointHood, idJoinArity,
|
|
| 82 | + isJoinId, joinId_maybe, joinPointType_maybe,
|
|
| 83 | + idJoinPointHood, idJoinArity,
|
|
| 83 | 84 | asJoinId, asJoinId_maybe, zapJoinId,
|
| 84 | 85 | |
| 85 | 86 | -- ** Inline pragma stuff
|
| ... | ... | @@ -172,6 +173,8 @@ import GHC.Data.FastString |
| 172 | 173 | import GHC.Utils.Misc
|
| 173 | 174 | import GHC.Utils.Outputable
|
| 174 | 175 | import GHC.Utils.Panic
|
| 176 | +import qualified Data.List.NonEmpty as NE
|
|
| 177 | +import qualified Data.Semigroup as Semi
|
|
| 175 | 178 | |
| 176 | 179 | -- infixl so you can say (id `set` a `set` b)
|
| 177 | 180 | infixl 1 `setIdUnfolding`,
|
| ... | ... | @@ -584,6 +587,13 @@ joinId_maybe id |
| 584 | 587 | _ -> Nothing
|
| 585 | 588 | | otherwise = Nothing
|
| 586 | 589 | |
| 590 | +joinPointType_maybe :: (a -> Maybe JoinPointType) -> [a] -> Maybe JoinPointType
|
|
| 591 | +joinPointType_maybe f xs = do
|
|
| 592 | + xsNE <- NE.nonEmpty xs
|
|
| 593 | + Semi.sconcat <$> traverse f xsNE
|
|
| 594 | + -- traverse: either all are join points or none are
|
|
| 595 | + -- sconcat: only a 'TrueJoinPoint' if all are
|
|
| 596 | + |
|
| 587 | 597 | -- | Doesn't return strictness marks
|
| 588 | 598 | idJoinPointHood :: Var -> JoinPointHood
|
| 589 | 599 | idJoinPointHood id
|