sheaf pushed to branch wip/andreask/ticked_joins at Glasgow Haskell Compiler / GHC

Commits:

4 changed files:

Changes:

  • compiler/GHC/Core/Opt/Exitify.hs
    ... ... @@ -45,12 +45,14 @@ import GHC.Core.Type
    45 45
     import GHC.Types.Var
    
    46 46
     import GHC.Types.Id
    
    47 47
     import GHC.Types.Id.Info
    
    48
    +import GHC.Types.Tickish ( GenTickish(..), tickishCanScopeJoin )
    
    49
    +
    
    48 50
     import GHC.Types.Var.Set
    
    49 51
     import GHC.Types.Var.Env
    
    50
    -import GHC.Types.Basic( JoinPointHood(..) )
    
    51 52
     
    
    52 53
     import GHC.Utils.Monad.State.Strict
    
    53 54
     import GHC.Utils.Misc( mapSnd )
    
    55
    +import GHC.Utils.Outputable
    
    54 56
     
    
    55 57
     import GHC.Data.FastString
    
    56 58
     
    
    ... ... @@ -93,23 +95,23 @@ exitifyProgram binds = map goTopLvl binds
    93 95
           where
    
    94 96
             in_scope' = in_scope `extendInScopeSet` bndr
    
    95 97
     
    
    96
    -    go in_scope (Let (Rec pairs) body)
    
    97
    -      | is_join_rec = mkLets (exitifyRec in_scope' pairs') body'
    
    98
    -      | otherwise   = Let (Rec pairs') body'
    
    98
    +    go in_scope (Let (Rec pairs) body) =
    
    99
    +      case joinPointType_maybe (joinId_maybe . fst) pairs of
    
    100
    +        Just join_ty -> mkLets (exitifyRec join_ty in_scope' pairs') body'
    
    101
    +        Nothing -> Let (Rec pairs') body'
    
    99 102
           where
    
    100
    -        is_join_rec = any (isJoinId . fst) pairs
    
    101 103
             in_scope'   = in_scope `extendInScopeSetBind` (Rec pairs)
    
    102 104
             pairs'      = mapSnd (go in_scope') pairs
    
    103 105
             body'       = go in_scope' body
    
    104 106
     
    
    105 107
     
    
    106 108
     -- | State Monad used inside `exitify`
    
    107
    -type ExitifyM =  State [(JoinId, CoreExpr)]
    
    109
    +type ExitifyM = State [(JoinId, CoreExpr)]
    
    108 110
     
    
    109 111
     -- | Given a recursive group of a joinrec, identifies “exit paths” and binds them as
    
    110 112
     --   join-points outside the joinrec.
    
    111
    -exitifyRec :: InScopeSet -> [(Var,CoreExpr)] -> [CoreBind]
    
    112
    -exitifyRec in_scope pairs
    
    113
    +exitifyRec :: JoinPointType -> InScopeSet -> [(Var,CoreExpr)] -> [CoreBind]
    
    114
    +exitifyRec joinrec_join_ty in_scope pairs
    
    113 115
       = [ NonRec xid rhs | (xid,rhs) <- exits ] ++ [Rec pairs']
    
    114 116
       where
    
    115 117
         -- We need the set of free variables of many subexpressions here, so
    
    ... ... @@ -124,7 +126,7 @@ exitifyRec in_scope pairs
    124 126
             forM ann_pairs $ \(x,rhs) -> do
    
    125 127
                 -- go past the lambdas of the join point
    
    126 128
                 let (args, body) = collectNAnnBndrs (idJoinArity x) rhs
    
    127
    -            body' <- go args body
    
    129
    +            body' <- go joinrec_join_ty args body -- (ExitJoin2): start with JoinPointType of parent joinrec
    
    128 130
                 let rhs' = mkLams args body'
    
    129 131
                 return (x, rhs')
    
    130 132
     
    
    ... ... @@ -135,40 +137,41 @@ exitifyRec in_scope pairs
    135 137
         -- variables bound on the way and lifts it out as a join point.
    
    136 138
         --
    
    137 139
         -- ExitifyM is a state monad to keep track of floated binds
    
    138
    -    go :: [Var]           -- Variables that are in-scope here, but
    
    139
    -                          -- not in scope at the joinrec; that is,
    
    140
    -                          -- we must potentially abstract over them.
    
    141
    -                          -- Invariant: they are kept in dependency order
    
    140
    +    go :: JoinPointType -- what join point type to create; see Note [Exitification and quasi join points]
    
    141
    +       -> [Var] -- Variables that are in-scope here, but
    
    142
    +                -- not in scope at the joinrec; that is,
    
    143
    +                -- we must potentially abstract over them.
    
    144
    +                -- Invariant: they are kept in dependency order
    
    142 145
            -> CoreExprWithFVs -- Current expression in tail position
    
    143 146
            -> ExitifyM CoreExpr
    
    144 147
     
    
    145 148
         -- We first look at the expression (no matter what it shape is)
    
    146 149
         -- and determine if we can turn it into a exit join point
    
    147
    -    go captured ann_e
    
    150
    +    go exit_join_ty captured ann_e
    
    148 151
             | -- An exit expression has no recursive calls
    
    149 152
               let fvs = dVarSetToVarSet (freeVarsOf ann_e)
    
    150 153
             , disjointVarSet fvs recursive_calls
    
    151
    -        = go_exit captured (deAnnotate ann_e) fvs
    
    154
    +        = go_exit exit_join_ty captured (deAnnotate ann_e) fvs
    
    152 155
     
    
    153 156
         -- We could not turn it into a exit join point. So now recurse
    
    154 157
         -- into all expression where eligible exit join points might sit,
    
    155 158
         -- i.e. into all tail-call positions:
    
    156 159
     
    
    157 160
         -- Case right hand sides are in tail-call position
    
    158
    -    go captured (_, AnnCase scrut bndr ty alts) = do
    
    161
    +    go exit_join_ty captured (_, AnnCase scrut bndr ty alts) = do
    
    159 162
             alts' <- forM alts $ \(AnnAlt dc pats rhs) -> do
    
    160
    -            rhs' <- go (captured ++ [bndr] ++ pats) rhs
    
    163
    +            rhs' <- go exit_join_ty (captured ++ [bndr] ++ pats) rhs
    
    161 164
                 return (Alt dc pats rhs')
    
    162 165
             return $ Case (deAnnotate scrut) bndr ty alts'
    
    163 166
     
    
    164
    -    go captured (_, AnnLet ann_bind body)
    
    167
    +    go exit_join_ty captured (_, AnnLet ann_bind body)
    
    165 168
             -- join point, RHS and body are in tail-call position
    
    166 169
             | AnnNonRec j rhs <- ann_bind
    
    167 170
             , JoinPoint { joinPointArity = join_arity } <- idJoinPointHood j
    
    168 171
             = do let (params, join_body) = collectNAnnBndrs join_arity rhs
    
    169
    -             join_body' <- go (captured ++ params) join_body
    
    172
    +             join_body' <- go exit_join_ty (captured ++ params) join_body
    
    170 173
                  let rhs' = mkLams params join_body'
    
    171
    -             body' <- go (captured ++ [j]) body
    
    174
    +             body' <- go exit_join_ty (captured ++ [j]) body
    
    172 175
                  return $ Let (NonRec j rhs') body'
    
    173 176
     
    
    174 177
             -- rec join point, RHSs and body are in tail-call position
    
    ... ... @@ -178,30 +181,41 @@ exitifyRec in_scope pairs
    178 181
                  pairs' <- forM pairs $ \(j,rhs) -> do
    
    179 182
                      let join_arity = idJoinArity j
    
    180 183
                          (params, join_body) = collectNAnnBndrs join_arity rhs
    
    181
    -                 join_body' <- go (captured ++ js ++ params) join_body
    
    184
    +                 join_body' <- go exit_join_ty (captured ++ js ++ params) join_body
    
    182 185
                      let rhs' = mkLams params join_body'
    
    183 186
                      return (j, rhs')
    
    184
    -             body' <- go (captured ++ js) body
    
    187
    +             body' <- go exit_join_ty (captured ++ js) body
    
    185 188
                  return $ Let (Rec pairs') body'
    
    186 189
     
    
    187 190
             -- normal Let, only the body is in tail-call position
    
    188 191
             | otherwise
    
    189
    -        = do body' <- go (captured ++ bindersOf bind ) body
    
    192
    +        = do body' <- go exit_join_ty (captured ++ bindersOf bind ) body
    
    190 193
                  return $ Let bind body'
    
    191 194
           where bind = deAnnBind ann_bind
    
    192 195
     
    
    196
    +    -- (ExitJoin1) from Note [Exitification and quasi join points]
    
    197
    +    go _ captured (_, AnnCast ann_e (_, co)) = do
    
    198
    +        e' <- go QuasiJoinPoint captured ann_e
    
    199
    +        return (Cast e' co)
    
    200
    +    go exit_join_ty captured (_, AnnTick tickish ann_e)
    
    201
    +      | tickishCanScopeJoin tickish
    
    202
    +      = Tick tickish <$> go exit_join_ty captured ann_e
    
    203
    +      | ProfNote {} <- tickish
    
    204
    +      = Tick tickish <$> go QuasiJoinPoint captured ann_e
    
    205
    +
    
    193 206
         -- Cannot be turned into an exit join point, but also has no
    
    194 207
         -- tail-call subexpression. Nothing to do here.
    
    195
    -    go _ ann_e = return (deAnnotate ann_e)
    
    208
    +    go _ _ ann_e = return (deAnnotate ann_e)
    
    196 209
     
    
    197 210
         ---------------------
    
    198
    -    go_exit :: [Var]      -- Variables captured locally
    
    211
    +    go_exit :: JoinPointType -- what join point type to create; see Note [Exitification and quasi join points]
    
    212
    +            -> [Var]      -- Variables captured locally
    
    199 213
                 -> CoreExpr   -- An exit expression
    
    200 214
                 -> VarSet     -- Free vars of the expression
    
    201 215
                 -> ExitifyM CoreExpr
    
    202 216
         -- go_exit deals with a tail expression that is floatable
    
    203 217
         -- out as an exit point; that is, it mentions no recursive calls
    
    204
    -    go_exit captured e fvs
    
    218
    +    go_exit exit_join_ty captured e fvs
    
    205 219
           -- Do not touch an expression that is already a join jump where all arguments
    
    206 220
           -- are captured variables. See Note [Idempotency]
    
    207 221
           -- But _do_ float join jumps with interesting arguments.
    
    ... ... @@ -226,7 +240,7 @@ exitifyRec in_scope pairs
    226 240
                  let rhs   = mkLams abs_vars e
    
    227 241
                      avoid = in_scope `extendInScopeSetList` captured
    
    228 242
                  -- Remember this binding under a suitable name
    
    229
    -           ; v <- addExit avoid (length abs_vars) rhs
    
    243
    +           ; v <- addExit avoid exit_join_ty (length abs_vars) rhs
    
    230 244
                  -- And jump to it from here
    
    231 245
                ; return $ mkVarApps (Var v) abs_vars }
    
    232 246
     
    
    ... ... @@ -263,7 +277,7 @@ exitifyRec in_scope pairs
    263 277
     --  * any bound variables (captured)
    
    264 278
     --  * any exit join points created so far.
    
    265 279
     mkExitJoinId :: InScopeSet -> Type -> JoinPointType -> JoinArity -> ExitifyM JoinId
    
    266
    -mkExitJoinId in_scope ty join_ty join_arity = do
    
    280
    +mkExitJoinId in_scope ty exit_join_ty join_arity = do
    
    267 281
         fs <- get
    
    268 282
         let avoid = in_scope `extendInScopeSetList` (map fst fs)
    
    269 283
                              `extendInScopeSet` exit_id_tmpl -- just cosmetics
    
    ... ... @@ -271,17 +285,65 @@ mkExitJoinId in_scope ty join_ty join_arity = do
    271 285
       where
    
    272 286
         exit_id_tmpl =
    
    273 287
           asJoinId (mkSysLocal (fsLit "exit") initExitJoinUnique ManyTy ty)
    
    274
    -        join_ty join_arity
    
    288
    +        exit_join_ty join_arity
    
    275 289
     
    
    276
    -addExit :: InScopeSet -> JoinArity -> CoreExpr -> ExitifyM JoinId
    
    277
    -addExit in_scope join_arity rhs = do
    
    290
    +addExit :: InScopeSet -> JoinPointType -> JoinArity -> CoreExpr -> ExitifyM JoinId
    
    291
    +addExit in_scope exit_join_ty join_arity rhs = do
    
    278 292
         -- Pick a suitable name
    
    279 293
         let ty = exprType rhs
    
    280
    -    v <- mkExitJoinId in_scope ty TrueJoinPoint join_arity
    
    294
    +    v <- mkExitJoinId in_scope ty exit_join_ty join_arity
    
    281 295
         fs <- get
    
    282 296
         put ((v,rhs):fs)
    
    283 297
         return v
    
    284 298
     
    
    299
    +{- Note [Exitification and quasi join points]
    
    300
    +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    
    301
    +When we float an exit path, we must determine if the new exit join point
    
    302
    +should be a true join point or a quasi join point, in the sense of
    
    303
    +Note [Quasi join points] in GHC.Core.Opt.Simplify.Iteration.
    
    304
    +
    
    305
    +The new exit join point must be a quasi join point if either of the following
    
    306
    +conditions apply:
    
    307
    +
    
    308
    +  (ExitJoin1) The exit path occurs under a cast or a profiling tick.
    
    309
    +
    
    310
    +  (ExitJoin2) The original joinrec was a quasi join point.
    
    311
    +
    
    312
    +Rationale for (ExitJoin1):
    
    313
    +
    
    314
    +  Suppose we have:
    
    315
    +
    
    316
    +    joinrec j x = ... case ... of alts -> e |> co ... in ...
    
    317
    +
    
    318
    +  After exitifying 'e' to 'exit':
    
    319
    +
    
    320
    +    join exit y = e in
    
    321
    +    joinrec j x = ... case ... of alts -> (exit y) |> co ... in ...
    
    322
    +
    
    323
    +  Because the jump to 'exit' occurs under a cast, 'exit' must be classified
    
    324
    +  as a quasi join point.
    
    325
    +
    
    326
    +Rationale for (ExitJoin2):
    
    327
    +
    
    328
    +  Suppose we have:
    
    329
    +
    
    330
    +  quasijoinrec j x = case x of { 0 -> 100; _ -> j (x-1) } in j 0 |> co
    
    331
    +
    
    332
    +  If we float an exit out of 'j', we end up with:
    
    333
    +
    
    334
    +  join exit = 100 in
    
    335
    +  quasijoinrec j x = case x of { 0 -> exit ; _ -> j (x-1) } in j 0 |> co
    
    336
    +
    
    337
    +  Now suppose we inline j and simplify; we end up with:
    
    338
    +
    
    339
    +  join exit = 100 in exit |> co
    
    340
    +
    
    341
    +  We see now that 'exit' must be a quasi join point, due to the cast.
    
    342
    +
    
    343
    +  Hence: exit join points for a parent quasi join point must themselves be
    
    344
    +  quasi join points.
    
    345
    +-}
    
    346
    +
    
    285 347
     {-
    
    286 348
     Note [Interesting expression]
    
    287 349
     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    

  • compiler/GHC/Core/Opt/OccurAnal.hs
    ... ... @@ -68,7 +68,6 @@ import GHC.Builtin.Names( runRWKey )
    68 68
     import GHC.Unit.Module( Module )
    
    69 69
     
    
    70 70
     import Data.List (mapAccumL)
    
    71
    -import qualified Data.List.NonEmpty as NE
    
    72 71
     import qualified Data.Semigroup as Semi
    
    73 72
     
    
    74 73
     {-
    
    ... ... @@ -4118,10 +4117,7 @@ setBinderOcc occ_info bndr
    4118 4117
     -- See Note [Invariants on join points] in "GHC.Core".
    
    4119 4118
     decideRecJoinPointHood :: TopLevelFlag -> UsageDetails
    
    4120 4119
                            -> [CoreBndr] -> Maybe JoinPointType
    
    4121
    -decideRecJoinPointHood lvl usage bndrs = do
    
    4122
    -  bndrsNE <- NE.nonEmpty bndrs
    
    4123
    -  -- Invariant 3: Either all are join points or none are
    
    4124
    -  Semi.sconcat <$> traverse ok bndrsNE
    
    4120
    +decideRecJoinPointHood lvl usage = joinPointType_maybe ok
    
    4125 4121
       where
    
    4126 4122
         ok bndr = okForJoinPoint lvl bndr (lookupTailCallInfo usage bndr)
    
    4127 4123
     
    

  • compiler/GHC/Core/Opt/Simplify/Iteration.hs
    ... ... @@ -2056,93 +2056,118 @@ is a join point, and what 'cont' is, in a value of type MaybeJoinCont
    2056 2056
     of a SpecConstr-generated RULE for a join point.
    
    2057 2057
     -}
    
    2058 2058
     
    
    2059
    --- SLD TODO horrible logic that must be removed
    
    2060
    -peelJoinResTy :: Int -> Type -> Type
    
    2061
    -peelJoinResTy 0 ty = ty
    
    2062
    -peelJoinResTy n ty
    
    2063
    -  | Just (_bndr, inner_ty) <- splitForAllTyCoVar_maybe ty
    
    2064
    -  = peelJoinResTy n inner_ty
    
    2065
    -  | Just (_, _mult, _arg, res_ty) <- splitFunTy_maybe ty
    
    2066
    -  = peelJoinResTy (n-1) res_ty
    
    2067
    -  | otherwise
    
    2068
    -  = ty
    
    2059
    +joinResTy :: HasDebugCallStack => JoinArity -> Type -> Type
    
    2060
    +joinResTy n0 ty0 = go n0 ty0
    
    2061
    +   where
    
    2062
    +    go 0 ty = ty
    
    2063
    +    go n ty
    
    2064
    +      | Just (_bndr, res_ty) <- splitPiTy_maybe ty
    
    2065
    +      = go (n-1) res_ty
    
    2066
    +      | otherwise
    
    2067
    +      = pprPanic "joinResTy" $
    
    2068
    +         vcat [ text "join arity:" <+> ppr n0
    
    2069
    +              , text "join ty:" <+> ppr ty0
    
    2070
    +              , text "n:" <+> ppr n
    
    2071
    +              , text "ty:" <+> ppr ty
    
    2072
    +              ]
    
    2069 2073
     
    
    2070 2074
     simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr
    
    2071 2075
                          -> InExpr -> SimplCont
    
    2072 2076
                          -> SimplM (SimplFloats, OutExpr)
    
    2073
    -simplNonRecJoinPoint env bndr rhs body cont
    
    2077
    +simplNonRecJoinPoint env0 bndr rhs body cont0
    
    2074 2078
        = assert (isJoinId bndr) $
    
    2075
    -     wrapJoinCont do_case_case env cont $ \ env cont ->
    
    2079
    +     wrapJoinCont do_case_case env0 bndr cont0 $
    
    2080
    +     \ WJC { wjc_bind_env = env, wjc_bind_cont = bind_cont, wjc_body_cont = body_cont } ->
    
    2076 2081
          do { -- We push join_cont into the join RHS and the body;
    
    2077 2082
               -- and wrap wrap_cont around the whole thing
    
    2078
    -        ; let (mult, res_ty)
    
    2079
    -                -- SLD TODO
    
    2080
    -                | Just QuasiJoinPoint <- joinId_maybe bndr
    
    2081
    -                = (idMult bndr, peelJoinResTy (idJoinArity bndr) $ substTy env (idType bndr))
    
    2082
    -                | otherwise
    
    2083
    -                = (contHoleScaling cont, contResultType cont)
    
    2083
    +          let mult   = contHoleScaling bind_cont
    
    2084
    +              res_ty = contResultType  bind_cont
    
    2084 2085
             ; (env1, bndr1)    <- simplNonRecJoinBndr env bndr mult res_ty
    
    2085
    -        ; (env2, bndr2)    <- addBndrRules env1 bndr bndr1 (BC_Join NonRecursive cont)
    
    2086
    -        ; (floats1, env3)  <- simplJoinBind NonRecursive cont (bndr,env) (bndr2,env2) (rhs,env)
    
    2087
    -        ; (floats2, body') <- simplExprF env3 body cont
    
    2086
    +        ; (env2, bndr2)    <- addBndrRules env1 bndr bndr1 (BC_Join NonRecursive bind_cont)
    
    2087
    +        ; (floats1, env3)  <- simplJoinBind NonRecursive bind_cont (bndr,env) (bndr2,env2) (rhs,env)
    
    2088
    +        ; (floats2, body') <- simplExprF env3 body body_cont
    
    2088 2089
             ; return (floats1 `addFloats` floats2, body') }
    
    2089 2090
       where
    
    2090 2091
         do_case_case
    
    2091 2092
           | Just TrueJoinPoint <- joinId_maybe bndr
    
    2092
    -      = seCaseCase env
    
    2093
    +      = seCaseCase env0
    
    2093 2094
           | otherwise
    
    2094 2095
           = False
    
    2095 2096
     
    
    2096 2097
     simplRecJoinPoint :: SimplEnv -> [(InId, InExpr)]
    
    2097 2098
                       -> InExpr -> SimplCont
    
    2098 2099
                       -> SimplM (SimplFloats, OutExpr)
    
    2099
    -simplRecJoinPoint env pairs body cont
    
    2100
    -  = wrapJoinCont do_case_case env cont $ \ env cont ->
    
    2101
    -    do { let bndrs  = map fst pairs
    
    2102
    -             (mult, res_ty)
    
    2103
    -                -- SLD TODO
    
    2104
    -                | [b] <- bndrs
    
    2105
    -                , Just QuasiJoinPoint <- joinId_maybe b
    
    2106
    -                = (idMult b, peelJoinResTy (idJoinArity b) $ substTy env (idType b))
    
    2107
    -                | otherwise
    
    2108
    -                = (contHoleScaling cont, contResultType cont)
    
    2100
    +simplRecJoinPoint env0 pairs body cont0
    
    2101
    +  = wrapJoinCont do_case_case env0 (head bndrs) cont0 $
    
    2102
    +    \ WJC { wjc_bind_env = env, wjc_bind_cont = bind_cont, wjc_body_cont = body_cont } ->
    
    2103
    +    do { let mult   = contHoleScaling bind_cont
    
    2104
    +             res_ty = contResultType  bind_cont
    
    2109 2105
            ; env1 <- simplRecJoinBndrs env bndrs mult res_ty
    
    2110 2106
                    -- NB: bndrs' don't have unfoldings or rules
    
    2111 2107
                    -- We add them as we go down
    
    2112
    -       ; (floats1, env2)  <- simplRecBind env1 (BC_Join Recursive cont) pairs
    
    2113
    -       ; (floats2, body') <- simplExprF env2 body cont
    
    2108
    +       ; (floats1, env2)  <- simplRecBind env1 (BC_Join Recursive bind_cont) pairs
    
    2109
    +       ; (floats2, body') <- simplExprF env2 body body_cont
    
    2114 2110
            ; return (floats1 `addFloats` floats2, body') }
    
    2115 2111
       where
    
    2112
    +    bndrs = map fst pairs
    
    2113
    +
    
    2116 2114
         do_case_case =
    
    2117
    -      if all ((== Just TrueJoinPoint) . joinId_maybe . fst) pairs
    
    2118
    -      then seCaseCase env
    
    2115
    +      if all ((== Just TrueJoinPoint) . joinId_maybe) bndrs
    
    2116
    +      then seCaseCase env0
    
    2119 2117
           else False
    
    2120 2118
     
    
    2121 2119
     --------------------
    
    2120
    +
    
    2121
    +-- | Information computed by 'wrapJoinCont'.
    
    2122
    +data WrapJoinCont
    
    2123
    +  = WJC
    
    2124
    +    { wjc_bind_env :: !SimplEnv
    
    2125
    +    , wjc_bind_cont :: !SimplCont
    
    2126
    +    , wjc_body_cont :: !SimplCont
    
    2127
    +    }
    
    2128
    +
    
    2122 2129
     wrapJoinCont :: Bool
    
    2123
    -             -> SimplEnv -> SimplCont
    
    2124
    -             -> (SimplEnv -> SimplCont -> SimplM (SimplFloats, OutExpr))
    
    2130
    +             -> SimplEnv -> InId -> SimplCont
    
    2131
    +             -> (WrapJoinCont -> SimplM (SimplFloats, OutExpr))
    
    2125 2132
                  -> SimplM (SimplFloats, OutExpr)
    
    2126 2133
     -- Deal with making the continuation duplicable if necessary,
    
    2127 2134
     -- and with the no-case-of-case situation.
    
    2128
    -wrapJoinCont do_case_case env cont thing_inside
    
    2135
    +wrapJoinCont do_case_case env join_bndr cont thing_inside
    
    2129 2136
       | contIsStop cont        -- Common case; no need for fancy footwork
    
    2130
    -  = thing_inside env cont
    
    2137
    +  = thing_inside $
    
    2138
    +    WJC { wjc_bind_env  = env
    
    2139
    +        , wjc_bind_cont = if do_case_case then cont else no_case_case_bind_cont
    
    2140
    +        , wjc_body_cont = cont
    
    2141
    +        }
    
    2131 2142
     
    
    2132 2143
       | do_case_case
    
    2133 2144
         -- Normal situation: do the "case-of-case" transformation.
    
    2134 2145
         -- See Note [Join points and case-of-case].
    
    2135 2146
       = do { (floats1, cont')  <- mkDupableCont env cont
    
    2136
    -       ; (floats2, result) <- thing_inside (env `setInScopeFromF` floats1) cont'
    
    2147
    +       ; let wjc = WJC { wjc_bind_env  = env `setInScopeFromF` floats1
    
    2148
    +                       , wjc_bind_cont = cont'
    
    2149
    +                       , wjc_body_cont = cont'
    
    2150
    +                       }
    
    2151
    +       ; (floats2, result) <- thing_inside wjc
    
    2137 2152
            ; return (floats1 `addFloats` floats2, result) }
    
    2138 2153
     
    
    2139 2154
       | otherwise
    
    2140 2155
         -- No "case-of-case" transformation.
    
    2141 2156
         -- See Note [Join points with -fno-case-of-case].
    
    2142
    -  = do { (floats1, expr1) <- thing_inside env (mkBoringStop (contHoleType cont))
    
    2157
    +  = do { let
    
    2158
    +           wjc = WJC { wjc_bind_env  = env
    
    2159
    +                     , wjc_bind_cont = no_case_case_bind_cont
    
    2160
    +                     , wjc_body_cont = mkBoringStop (contHoleType cont)
    
    2161
    +                     }
    
    2162
    +       ; (floats1, expr1) <- thing_inside wjc
    
    2143 2163
            ; let (floats2, expr2) = wrapJoinFloatsX floats1 expr1
    
    2144 2164
            ; (floats3, expr3) <- rebuild (env `setInScopeFromF` floats2) expr2 cont
    
    2145 2165
            ; return (floats2 `addFloats` floats3, expr3) }
    
    2166
    +  where
    
    2167
    +    -- See Wrinkle [Casts and join point result types]
    
    2168
    +    join_res_ty = joinResTy (idJoinArity join_bndr)
    
    2169
    +                $ substTy env (idType join_bndr)
    
    2170
    +    no_case_case_bind_cont = mkBoringStop join_res_ty
    
    2146 2171
     
    
    2147 2172
     --------------------
    
    2148 2173
     trimJoinCont :: Id         -- Used only in error message
    
    ... ... @@ -2282,9 +2307,9 @@ As per Note [Join points and case-of-case], we proceed by first applying the
    2282 2307
     argument to both the join point RHS and the case alternatives:
    
    2283 2308
     
    
    2284 2309
       join { j :: Bool -> IO (); j _ = guts arg ] }
    
    2285
    -    in case b of
    
    2286
    -      False -> (scctick<foo> jump j True) arg
    
    2287
    -      True  ->               jump j False arg
    
    2310
    +  in case b of
    
    2311
    +    False -> (scctick<foo> jump j True) arg
    
    2312
    +    True  ->               jump j False arg
    
    2288 2313
     
    
    2289 2314
     Then we rely on 'trimJoinCont' to remove the argument. In this case, this fails
    
    2290 2315
     for the first branch, because 'trimJoinCont' doesn't look through profiling
    
    ... ... @@ -2293,9 +2318,9 @@ end up with, as we don't want to misattribute profiling costs.
    2293 2318
     We could plausibly transform to the following:
    
    2294 2319
     
    
    2295 2320
       join { j :: Bool -> IO (); j scc_or_null _ = (setSCC# scc_or_null guts) arg ] }
    
    2296
    -    in case b of
    
    2297
    -      False -> jump j <foo> True
    
    2298
    -      True  -> jump j null  False
    
    2321
    +  in case b of
    
    2322
    +    False -> jump j <foo> True
    
    2323
    +    True  -> jump j null  False
    
    2299 2324
     
    
    2300 2325
     where `setSCC#` is a new primop that would set the current cost centre pointer
    
    2301 2326
     (or no-op if the given pointer is null).
    
    ... ... @@ -2307,17 +2332,17 @@ So instead, for now, we simply disallow the case-of-case transformation for 'j'.
    2307 2332
     Similarly for casts:
    
    2308 2333
     
    
    2309 2334
         join { j = blah }
    
    2310
    -      in case e of
    
    2311
    -        False -> j True  |> co1
    
    2312
    -        True  -> j False |> co2
    
    2335
    +    in case e of
    
    2336
    +      False -> j True  |> co1
    
    2337
    +      True  -> j False |> co2
    
    2313 2338
     
    
    2314 2339
     if we want to apply this to an argument 'arg', we would need to perform the
    
    2315 2340
     following transformation:
    
    2316 2341
     
    
    2317 2342
         join { j co = ( blah |> co ) arg }
    
    2318
    -      in case e of
    
    2319
    -        False -> j co1 True
    
    2320
    -        True  -> j co2 False
    
    2343
    +    in case e of
    
    2344
    +      False -> j co1 True
    
    2345
    +      True  -> j co2 False
    
    2321 2346
     
    
    2322 2347
     in which we add a coercion argument to the join point. Again, this is not a
    
    2323 2348
     transformation we currently implement, so we instead prevent case-of-case for
    
    ... ... @@ -2339,6 +2364,33 @@ we proceed as follows:
    2339 2364
          If we are dealing with a quasi join point, we switch off the case-of-case
    
    2340 2365
          transformation.
    
    2341 2366
     
    
    2367
    +Wrinkle [Casts and join point result types]
    
    2368
    +
    
    2369
    +  When dealing with a quasi join point, we must preserve the original type of
    
    2370
    +  the join point instead of transforming the type (as in GHC.Core.Opt.Simplify.Env.adjustJoinPointType).
    
    2371
    +  This is because we don't trim the continuation like we do in
    
    2372
    +  Note [Join points and case-of-case].
    
    2373
    +
    
    2374
    +  For example, suppose we have:
    
    2375
    +
    
    2376
    +    type family F a
    
    2377
    +
    
    2378
    +    join
    
    2379
    +      j :: forall a. a -> F a
    
    2380
    +      j @a x = ...
    
    2381
    +    in case e of
    
    2382
    +      False -> j @T1 x1 |> ( co1 :: F T1 ~ Int )
    
    2383
    +      True  -> j @T2 x2 |> ( co2 :: F T2 ~ Int )
    
    2384
    +
    
    2385
    +  If we used 'contHoleType cont' to compute the result type of 'j', we would
    
    2386
    +  change the result type of 'j' to 'Int', when it needs to remain 'F a'.
    
    2387
    +
    
    2388
    +  Instead, we avoid doing that and re-compute the result type of 'j' using
    
    2389
    +  'joinResTy' to get 'F a', as required.
    
    2390
    +
    
    2391
    +See also Note [Exitification and quasi join points] in GHC.Core.Opt.Exitify
    
    2392
    +for another wrinkle.
    
    2393
    +
    
    2342 2394
     ************************************************************************
    
    2343 2395
     *                                                                      *
    
    2344 2396
                          Variables
    

  • compiler/GHC/Types/Id.hs
    ... ... @@ -79,7 +79,8 @@ module GHC.Types.Id (
    79 79
     
    
    80 80
             -- ** Join variables
    
    81 81
             JoinId, JoinPointHood,
    
    82
    -        isJoinId, joinId_maybe, idJoinPointHood, idJoinArity,
    
    82
    +        isJoinId, joinId_maybe, joinPointType_maybe,
    
    83
    +        idJoinPointHood, idJoinArity,
    
    83 84
             asJoinId, asJoinId_maybe, zapJoinId,
    
    84 85
     
    
    85 86
             -- ** Inline pragma stuff
    
    ... ... @@ -172,6 +173,8 @@ import GHC.Data.FastString
    172 173
     import GHC.Utils.Misc
    
    173 174
     import GHC.Utils.Outputable
    
    174 175
     import GHC.Utils.Panic
    
    176
    +import qualified Data.List.NonEmpty as NE
    
    177
    +import qualified Data.Semigroup as Semi
    
    175 178
     
    
    176 179
     -- infixl so you can say (id `set` a `set` b)
    
    177 180
     infixl  1 `setIdUnfolding`,
    
    ... ... @@ -584,6 +587,13 @@ joinId_maybe id
    584 587
                     _             -> Nothing
    
    585 588
       | otherwise = Nothing
    
    586 589
     
    
    590
    +joinPointType_maybe :: (a -> Maybe JoinPointType) -> [a] -> Maybe JoinPointType
    
    591
    +joinPointType_maybe f xs = do
    
    592
    +  xsNE <- NE.nonEmpty xs
    
    593
    +  Semi.sconcat <$> traverse f xsNE
    
    594
    +    -- traverse: Nothing unless every element is a join point (all or none)
    
    595
    +    -- sconcat: only a 'TrueJoinPoint' if all are
    
    596
    +
    
    587 597
     -- | Doesn't return strictness marks
    
    588 598
     idJoinPointHood :: Var -> JoinPointHood
    
    589 599
     idJoinPointHood id