[Git][ghc/ghc][wip/T26681] Refactor SetLevels [skip ci]
Simon Peyton Jones pushed to branch wip/T26681 at Glasgow Haskell Compiler / GHC Commits: 459bd467 by Simon Peyton Jones at 2025-12-19T17:38:29+00:00 Refactor SetLevels [skip ci] - - - - - 2 changed files: - compiler/GHC/Core.hs - compiler/GHC/Core/Opt/SetLevels.hs Changes: ===================================== compiler/GHC/Core.hs ===================================== @@ -10,7 +10,8 @@ module GHC.Core ( -- * Main data types Expr(..), Alt(..), Bind(..), AltCon(..), Arg, CoreProgram, CoreExpr, CoreAlt, CoreBind, CoreArg, CoreBndr, - TaggedExpr, TaggedAlt, TaggedBind, TaggedArg, TaggedBndr(..), deTagExpr, + TaggedExpr, TaggedAlt, TaggedBind, TaggedArg, TaggedBndr(..), + deTagExpr, taggedBndrBndr, -- * In/Out type synonyms InId, InBind, InExpr, InAlt, InArg, InType, InKind, @@ -1931,6 +1932,9 @@ type TaggedAlt t = Alt (TaggedBndr t) instance Outputable b => Outputable (TaggedBndr b) where ppr (TB b l) = char '<' <> ppr b <> comma <> ppr l <> char '>' +taggedBndrBndr :: TaggedBndr t -> CoreBndr +taggedBndrBndr (TB b _) = b + deTagExpr :: TaggedExpr t -> CoreExpr deTagExpr (Var v) = Var v deTagExpr (Lit l) = Lit l ===================================== compiler/GHC/Core/Opt/SetLevels.hs ===================================== @@ -111,12 +111,13 @@ import GHC.Types.Demand ( DmdSig, prependArgsDmdSig ) import GHC.Types.Cpr ( CprSig, prependArgsCprSig ) import GHC.Types.Name ( getOccName, mkSystemVarName ) import GHC.Types.Name.Occurrence ( occNameFS ) -import GHC.Types.Unique ( hasKey ) +import GHC.Types.Unique ( Unique, hasKey ) import GHC.Types.Tickish ( tickishIsCode ) import GHC.Types.Unique.Supply import GHC.Types.Unique.DFM import GHC.Types.Basic ( Arity, RecFlag(..), isRec ) +import GHC.Data.Maybe ( orElse ) import GHC.Builtin.Types import GHC.Builtin.Names ( runRWKey ) @@ -290,7 +291,7 @@ lvl_top env is_rec bndr rhs = do { rhs' <- lvlRhs env is_rec (isDeadEndId bndr) NotJoinPoint (freeVars rhs) - ; return (stayPut tOP_LEVEL bndr, rhs') } + ; return (TB bndr (StayPut tOP_LEVEL), rhs') } {- ************************************************************************ @@ -363,8 +364,8 @@ lvlExpr env expr@(_, AnnLam {}) ; return (mkLams new_bndrs new_body) } where (bndrs, body) = collectAnnBndrs expr - (env1, bndrs1) = substBndrsSL NonRecursive env bndrs - (new_env, new_bndrs) = lvlLamBndrs env1 (le_ctxt_lvl env) bndrs1 + bndr_lvl = lamBndrLevel (le_ctxt_lvl env) bndrs + (new_env, new_bndrs) = substAndLvlBndrs env NonRecursive bndr_lvl bndrs -- At one time we called a special version of collectBinders, -- which ignored coercions, because we don't want to split -- a lambda like this (\x -> coerce t (\s -> ...)) @@ -455,11 +456,11 @@ lvlCase env scrut_fvs scrut' case_bndr ty alts do { (env1, (case_bndr' : bs')) <- cloneCaseBndrs env dest_lvl (case_bndr : bs) ; let rhs_env = extendCaseBndrEnv env1 case_bndr scrut' ; body' <- lvlMFE rhs_env True body - ; let alt' = Alt con (map (stayPut dest_lvl) bs') body' - ; return (Case scrut' (TB case_bndr' (FloatMe dest_lvl)) ty' [alt']) } + ; let alt' = Alt con bs' body' + ; return (Case scrut' case_bndr' ty' [alt']) } | otherwise -- Stays put - = do { let (alts_env1, [case_bndr']) = substAndLvlBndrs NonRecursive env incd_lvl [case_bndr] + = do { let (alts_env1, [case_bndr']) = substAndLvlBndrs env NonRecursive incd_lvl [case_bndr] alts_env = extendCaseBndrEnv alts_env1 case_bndr scrut' ; alts' <- mapM (lvl_alt alts_env) alts ; return (Case scrut' case_bndr' ty' alts') } @@ -474,7 +475,7 @@ lvlCase env scrut_fvs scrut' case_bndr ty alts = do { rhs' <- lvlMFE new_env True rhs ; return (Alt con bs' rhs') } where - (new_env, bs') = substAndLvlBndrs NonRecursive alts_env incd_lvl bs + (new_env, bs') = substAndLvlBndrs alts_env NonRecursive incd_lvl bs {- Note [Floating single-alternative cases] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -631,13 +632,14 @@ lvlMFE env strict_ctxt ann_expr | float_is_new_lam || exprIsTopLevelBindable expr expr_ty -- No wrapping needed if the type is lifted, or is a literal string -- or if we are wrapping it in one or more value lambdas - = do { expr1 <- lvlFloatRhs abs_vars dest_lvl rhs_env NonRecursive - is_bot_lam NotJoinPoint ann_expr + = do { rhs' <- lvlFloatRhs env dest_lvl abs_vars NonRecursive + is_bot_lam NotJoinPoint ann_expr -- Treat the expr just like a right-hand side - ; var <- newLvlVar expr1 NotJoinPoint is_mk_static - ; let var2 = annotateBotStr var float_n_lams mb_bot_str - ; return (Let (NonRec (TB var2 (FloatMe dest_lvl)) expr1) - (mkVarApps (Var var2) abs_vars)) } + ; var <- newLvlVar rhs' NotJoinPoint is_mk_static + ; let lb = TB (FloatMe dest_lvl) var + lb' = annotateBotStr lb float_n_lams mb_bot_str + ; return (Let (NonRec lb' rhs') + (mkVarApps (Var var abs_vars))) } -- OK, so the float has an unlifted type (not top-level bindable) -- and no new value lambdas (float_is_new_lam is False) @@ -649,14 +651,15 @@ lvlMFE env strict_ctxt ann_expr , BI_Box { bi_data_con = box_dc, bi_inst_con = boxing_expr , bi_boxed_type = box_ty } <- boxingDataCon expr_ty , let [bx_bndr, ubx_bndr] = mkTemplateLocals [box_ty, expr_ty] - = do { expr1 <- lvlExpr rhs_env ann_expr - ; let l1r = incMinorLvlFrom rhs_env - float_rhs = mkLams abs_vars_w_lvls $ + = do { let bndr_lvl = lamBndrLevel dest_lvl abs_vars + ; expr1 <- lvlExpr (env `setCtxtLevel` bndr_lvl) ann_expr + ; let l1r = incMinorLvl bndr_lvl + float_rhs = mkLams (stayPut bndr_lvl abs_vars) $ Case expr1 (stayPut l1r ubx_bndr) box_ty [Alt DEFAULT [] (App boxing_expr (Var ubx_bndr))] ; var <- newLvlVar float_rhs NotJoinPoint is_mk_static - ; let l1u = incMinorLvlFrom env + ; let l1u = incMinorLvl (ctxtLevel env) use_expr = Case (mkVarApps (Var var) abs_vars) (stayPut l1u bx_bndr) expr_ty [Alt (DataAlt box_dc) [stayPut l1u ubx_bndr] (Var ubx_bndr)] @@ -690,8 +693,6 @@ lvlMFE env strict_ctxt ann_expr float_is_new_lam = float_n_lams > 0 float_n_lams = count isId abs_vars - (rhs_env, abs_vars_w_lvls) = lvlLamBndrs env dest_lvl abs_vars - is_mk_static = isJust (collectMakeStaticArgs expr) -- Yuk: See Note [Grand plan for static forms] in GHC.Iface.Tidy.StaticPtrTable @@ -706,7 +707,7 @@ lvlMFE env strict_ctxt ann_expr saves_work = escapes_value_lam -- (a) && not is_hnf -- (b) && not float_is_new_lam -- (c) - escapes_value_lam = dest_lvl `ltMajLvl` (le_ctxt_lvl env) + escapes_value_lam = dest_lvl `ltMajLvl` (ctxtLevel env) -- See Note [Saving allocation] and Note [Floating to the top] saves_alloc = isTopLvl dest_lvl @@ -720,8 +721,9 @@ hasFreeJoin :: LevelEnv -> DVarSet -> Bool -- (In the latter case it won't be a join point any more.) -- Not treating top-level ones specially had a massive effect -- on nofib/minimax/Prog.prog -hasFreeJoin env fvs - = not (maxFvLevel isJoinId env fvs == tOP_LEVEL) +hasFreeJoin env fvs = anyDVarSet bad_join fvs + where + bad_join v = isJoinId v && lookupLevel env v == tOP_LEVEL {- Note [Saving work] ~~~~~~~~~~~~~~~~~~~~~ @@ -1117,18 +1119,20 @@ artificial benchmarks (e.g. integer, queens), but there is no perfect answer. -} -annotateBotStr :: Id -> Arity -> Maybe (Arity, DmdSig, CprSig) -> Id +annotateBotStr :: LevelledBndr -> Arity -> Maybe (Arity, DmdSig, CprSig) -> LevelledBndr -- See Note [Bottoming floats] for why we want to add -- bottoming information right now -- -- n_extra are the number of extra value arguments added during floating -annotateBotStr id n_extra mb_bot_str - | Just (arity, str_sig, cpr_sig) <- mb_bot_str - = id `setIdArity` (arity + n_extra) - `setIdDmdSig` prependArgsDmdSig n_extra str_sig - `setIdCprSig` prependArgsCprSig n_extra cpr_sig - | otherwise - = id +annotateBotStr lb@(TB lvl id) n_extra mb_bot_str + = case mb_bot_str of + Nothing -> lb + Just (arity, str_sig, cpr_sig) + -> TB lvl id' + where + id' = id `setIdArity` (arity + n_extra) + `setIdDmdSig` prependArgsDmdSig n_extra str_sig + `setIdCprSig` prependArgsCprSig n_extra cpr_sig notWorthFloating :: CoreExpr -> [Var] -> Bool -- See Note [notWorthFloating] @@ -1269,26 +1273,26 @@ lvlBind env (AnnNonRec bndr rhs) || not (wantToFloat env NonRecursive dest_lvl is_join is_top_bindable) = -- No float do { rhs' <- lvlRhs env NonRecursive is_bot_lam mb_join_arity rhs - ; let bind_lvl = incMinorLvl (le_ctxt_lvl env) - (env', [bndr']) = substAndLvlBndrs NonRecursive env bind_lvl [bndr] + ; let bind_lvl = incMinorLvl (ctxtLevel env) + (env', [bndr']) = substAndLvlBndrs env NonRecursive bind_lvl [bndr] ; return (NonRec bndr' rhs', env') } -- Otherwise we are going to float | null abs_vars = do { -- No type abstraction; clone existing binder - rhs' <- lvlFloatRhs [] dest_lvl env NonRecursive + rhs' <- lvlFloatRhs env dest_lvl [] NonRecursive is_bot_lam NotJoinPoint rhs - ; (env', [bndr']) <- cloneLetVars NonRecursive env dest_lvl [bndr] - ; let bndr2 = annotateBotStr bndr' 0 mb_bot_str - ; return (NonRec (TB bndr2 (FloatMe dest_lvl)) rhs', env') } + ; (env', [lbndr]) <- cloneLetVars NonRecursive env dest_lvl [bndr] + ; let lbndr' = annotateBotStr lbndr 0 mb_bot_str + ; return (NonRec lbndr' rhs', env') } | otherwise = do { -- Yes, type abstraction; create a new binder, extend substitution, etc - rhs' <- lvlFloatRhs abs_vars dest_lvl env NonRecursive + rhs' <- lvlFloatRhs env dest_lvl abs_vars NonRecursive is_bot_lam NotJoinPoint rhs - ; (env', [bndr']) <- newPolyBndrs dest_lvl env abs_vars [bndr] - ; let bndr2 = annotateBotStr bndr' n_extra mb_bot_str - ; return (NonRec (TB bndr2 (FloatMe dest_lvl)) rhs', env') } + ; (env', [lbndr]) <- newPolyBndrs env dest_lvl abs_vars [bndr] + ; let lbndr' = annotateBotStr lbndr n_extra mb_bot_str + ; return (NonRec lbndr' rhs', env') } where bndr_ty = idType bndr @@ -1314,8 +1318,8 @@ lvlBind env (AnnNonRec bndr rhs) lvlBind env (AnnRec pairs) | not (wantToFloat env Recursive dest_lvl is_join is_top_bindable) = -- No float - do { let bind_lvl = incMinorLvl (le_ctxt_lvl env) - (env', bndrs') = substAndLvlBndrs Recursive env bind_lvl bndrs + do { let bind_lvl = incMinorLvl (ctxtLevel env) + (env', bndrs') = substAndLvlBndrs env Recursive bind_lvl bndrs lvl_rhs (b,r) = lvlRhs env' Recursive is_bot (idJoinPointHood b) r ; rhss' <- mapM lvl_rhs pairs ; return (Rec (bndrs' `zip` rhss'), env') } @@ -1331,7 +1335,7 @@ lvlBind env (AnnRec pairs) -- I think we want to stop doing this | [(bndr,rhs)] <- pairs , count isId abs_vars > 1 - = do -- Special case for self recursion where there are + = -- Special case for self recursion where there are -- several variables carried around: build a local loop: -- poly_f = \abs_vars. \lam_vars . letrec f = \lam_vars. rhs in f lam_vars -- This just makes the closures a bit smaller. If we don't do @@ -1341,26 +1345,25 @@ lvlBind env (AnnRec pairs) -- mutually recursive functions, but it's quite a bit more complicated -- -- This all seems a bit ad hoc -- sigh - let (rhs_env, abs_vars_w_lvls) = lvlLamBndrs env dest_lvl abs_vars - rhs_lvl = le_ctxt_lvl rhs_env - - (rhs_env', [new_bndr]) <- cloneLetVars Recursive rhs_env rhs_lvl [bndr] - let - (lam_bndrs, rhs_body) = collectAnnBndrs rhs - (body_env1, lam_bndrs1) = substBndrsSL NonRecursive rhs_env' lam_bndrs - (body_env2, lam_bndrs2) = lvlLamBndrs body_env1 rhs_lvl lam_bndrs1 - new_rhs_body <- lvlRhs body_env2 Recursive is_bot NotJoinPoint rhs_body - (poly_env, [poly_bndr]) <- newPolyBndrs dest_lvl env abs_vars [bndr] - return (Rec [(TB poly_bndr (FloatMe dest_lvl) - , mkLams abs_vars_w_lvls $ - mkLams lam_bndrs2 $ - Let (Rec [( TB new_bndr (StayPut rhs_lvl) - , mkLams lam_bndrs2 new_rhs_body)]) - (mkVarApps (Var new_bndr) lam_bndrs1))] - , poly_env) + do { let (lam_bndrs, body) = collectAnnBndrs rhs + bndr_lvl = lamBndrLevel dest_lvl (abs_vars ++ lam_bndrs) + abs_lbs = stayPut bndr_lvl abs_vars + (body_env1, lam_lbs) = substAndLvlBndrs env NonRecursive bndr_lvl lam_bndrs + + ; (body_env2, [new_bndr]) <- cloneLetVars Recursive body_env1 (ctxtLevel body_env1) [bndr] + ; new_body <- lvlRhs body_env2 Recursive is_bot NotJoinPoint body + ; (poly_env, [poly_bndr]) <- newPolyBndrs dest_lvl env abs_vars [bndr] + + ; let rec_rhs = mkLams lam_lbs new_body + new_rhs = mkLams abs_lbs $ + mkLams lam_lbs $ + Let (Rec [( new_bndr, rec_rhs )]) $ + mkVarApps (Var new_bndr) (map taggedBndrBndr lam_lbs) + ; return ( Rec [(poly_bndr, new_rhs)] + , poly_env) } | otherwise -- Non-null abs_vars - = do { (new_env, new_bndrs) <- newPolyBndrs dest_lvl env abs_vars bndrs + = do { (new_env, new_bndrs) <- newPolyBndrs env dest_lvl abs_vars bndrs ; new_rhss <- mapM (do_rhs new_env) pairs ; return ( Rec ([TB b (FloatMe dest_lvl) | b <- new_bndrs] `zip` new_rhss) , new_env) } @@ -1375,7 +1378,7 @@ lvlBind env (AnnRec pairs) -- function in a Rec, and we don't much care what -- happens to it. False is simple! - do_rhs env (_,rhs) = lvlFloatRhs abs_vars dest_lvl env Recursive + do_rhs env (_,rhs) = lvlFloatRhs env dest_lvl abs_vars Recursive is_bot NotJoinPoint rhs @@ -1425,7 +1428,7 @@ wantToFloat env is_rec dest_lvl is_join is_top_bindable profitableFloat :: LevelEnv -> Level -> Bool profitableFloat env dest_lvl - = (dest_lvl `ltMajLvl` le_ctxt_lvl env) -- Escapes a value lambda + = (dest_lvl `ltMajLvl` ctxtLevel env) -- Escapes a value lambda || (isTopLvl dest_lvl && floatConsts env) -- Going all the way to top level @@ -1439,33 +1442,31 @@ lvlRhs :: LevelEnv -> CoreExprWithFVs -> LvlM LevelledExpr lvlRhs env rec_flag is_bot mb_join_arity expr - = lvlFloatRhs [] (le_ctxt_lvl env) env + = lvlFloatRhs env (ctxtLevel env) [] rec_flag is_bot mb_join_arity expr -lvlFloatRhs :: [OutVar] -> Level -> LevelEnv -> RecFlag +lvlFloatRhs :: LevelEnv -> Level -> [OutVar] -> RecFlag -> Bool -- Binding is for a bottoming function -> JoinPointHood -> CoreExprWithFVs -> LvlM (Expr LevelledBndr) -- Ignores the le_ctxt_lvl in env; treats dest_lvl as the baseline -lvlFloatRhs abs_vars dest_lvl env rec is_bot mb_join_arity rhs +lvlFloatRhs env dest_lvl abs_vars rec_flag is_bot mb_join_arity rhs = do { body' <- if not is_bot -- See Note [Floating from a RHS] && any isId bndrs then lvlMFE body_env True body else lvlExpr body_env body - ; return (mkLams bndrs' body') } + ; return (mkLams (abs_bndrs ++ bndrs') body') } where - (bndrs, body) | JoinPoint join_arity <- mb_join_arity - = collectNAnnBndrs join_arity rhs - | otherwise - = collectAnnBndrs rhs - (env1, bndrs1) = substBndrsSL NonRecursive env bndrs - all_bndrs = abs_vars ++ bndrs1 - (body_env, bndrs') | JoinPoint {} <- mb_join_arity - = lvlJoinBndrs env1 dest_lvl rec all_bndrs - | otherwise - = lvlLamBndrs env1 dest_lvl all_bndrs - -- The important thing here is that we call lvlLamBndrs on + (bndrs, body) = collectAnnBndrs rhs + bndr_lvl = case mb_join_arity of + JoinPoint ja -> assertPpr (null abs_vars) (ppr abs_vars) $ + joinLamBndrLevel dest_lvl rec_flag ja bndrs + NotJoinPoint -> lamBndrLevel dest_lvl (abs_bndrs ++ bndrs) + + abs_bndrs = stayPut bndr_lvl abs_vars + (body_env, bndrs') = substAndLvlBndrs env NonRecursive bndr_lvl bndrs + -- The important thing here is that we call `lamBndrLevel` on -- all these binders at once (abs_vars and bndrs), so they -- all get the same major level. Otherwise we create stupid -- let-bindings inside, joyfully thinking they can float; but @@ -1522,31 +1523,32 @@ Use lvlExpr otherwise. A little subtle, and I got it wrong at least twice ************************************************************************ -} -substAndLvlBndrs :: RecFlag -> LevelEnv -> Level -> [InVar] -> (LevelEnv, [LevelledBndr]) -substAndLvlBndrs is_rec env lvl bndrs - = lvlBndrs subst_env lvl subst_bndrs +setCtxtLevel :: LevelEnv -> Level -> LevelEnv +setCtxtLevel env lvl = env { le_ctxt_lvl = lvl } + +substAndLvlBndrs :: LevelEnv -> RecFlag -> Level -> [InVar] -> (LevelEnv, [LevelledBndr]) +-- New env has +-- * Updated context level +-- * Updated le_lvl_env for the InVars +-- * Updated le_subst and le_env for cloning +substAndLvlBndrs env@(LE { le_subst = subst, le_env = id_env, le_lvl_env = lvl_env }) + is_rec bndr_lvl in_bndrs + = ( env { le_ctxt_lvl = bndr_lvl + , le_lvl_env = addLvls bndr_lvl lvl_env in_bndrs + , le_subst = subst' + , le_env = foldl' add_id id_env (in_bndrs `zip` lvld_bndrs) } + , lvld_bndrs) where - (subst_env, subst_bndrs) = substBndrsSL is_rec env bndrs - -substBndrsSL :: RecFlag -> LevelEnv -> [InVar] -> (LevelEnv, [OutVar]) --- So named only to avoid the name clash with GHC.Core.Subst.substBndrs -substBndrsSL is_rec env@(LE { le_subst = subst, le_env = id_env }) bndrs - = ( env { le_subst = subst' - , le_env = foldl' add_id id_env (bndrs `zip` bndrs') } - , bndrs') + lvld_bndrs = stayPut bndr_lvl out_bndrs + (subst', out_bndrs) = case is_rec of + NonRecursive -> substBndrs subst in_bndrs + Recursive -> substRecBndrs subst in_bndrs + +lamBndrLevel :: Level -> [InVar] -> Level +lamBndrLevel ctxt_lvl bndrs + | any is_major bndrs = incMajorLvl ctxt_lvl + | otherwise = incMinorLvl ctxt_lvl where - (subst', bndrs') = case is_rec of - NonRecursive -> substBndrs subst bndrs - Recursive -> substRecBndrs subst bndrs - -lvlLamBndrs :: LevelEnv -> Level -> [OutVar] -> (LevelEnv, [LevelledBndr]) --- Compute the levels for the binders of a lambda group -lvlLamBndrs env lvl bndrs - = lvlBndrs env new_lvl bndrs - where - new_lvl | any is_major bndrs = incMajorLvl lvl - | otherwise = incMinorLvl lvl - is_major bndr = not (isOneShotBndr bndr) -- Only non-one-shot lambdas bump a major level, which in -- turn triggers floating. NB: isOneShotBndr is always @@ -1554,14 +1556,10 @@ lvlLamBndrs env lvl bndrs -- out of a big lambda. -- See Note [Computing one-shot info] in GHC.Types.Demand -lvlJoinBndrs :: LevelEnv -> Level -> RecFlag -> [OutVar] - -> (LevelEnv, [LevelledBndr]) -lvlJoinBndrs env lvl rec bndrs - = lvlBndrs env new_lvl bndrs - where - new_lvl | isRec rec = incMajorLvl lvl - | otherwise = incMinorLvl lvl - -- Non-recursive join points are one-shot; recursive ones are not +joinLamBndrLevel :: Level -> RecFlag -> JoinArity -> [InVar] -> Level +joinLamBndrLevel ctxt_lvl rec_flag join_arity bndrs + | isRec rec_flag = lamBndrLevel ctxt_lvl bndrs + | otherwise = lamBndrLevel ctxt_lvl (drop join_arity bndrs) lvlBndrs :: LevelEnv -> Level -> [CoreBndr] -> (LevelEnv, [LevelledBndr]) -- The binders returned are exactly the same as the ones passed, @@ -1576,10 +1574,10 @@ lvlBndrs :: LevelEnv -> Level -> [CoreBndr] -> (LevelEnv, [LevelledBndr]) lvlBndrs env@(LE { le_lvl_env = lvl_env }) new_lvl bndrs = ( env { le_ctxt_lvl = new_lvl , le_lvl_env = addLvls new_lvl lvl_env bndrs } - , map (stayPut new_lvl) bndrs) + , stayPut new_lvl bndrs) -stayPut :: Level -> OutVar -> LevelledBndr -stayPut new_lvl bndr = TB bndr (StayPut new_lvl) +stayPut :: Level -> [OutVar] -> [LevelledBndr] +stayPut new_lvl bndrs = [ TB bndr (StayPut new_lvl) | bndr <- bndrs ] -- Destination level is the max Id level of the expression -- (We'll abstract the type variables, if any.) @@ -1677,14 +1675,15 @@ countFreeIds = nonDetStrictFoldUDFM add 0 . getUniqDSet data LevelEnv = LE { le_switches :: FloatOutSwitches , le_ctxt_lvl :: Level -- The current level - , le_lvl_env :: VarEnv Level -- Domain is *post-cloned* TyVars and Ids + , le_lvl_env :: VarEnv Level -- Domain is *pre-cloned* InVars -- See Note [le_subst and le_env] - , le_subst :: Subst -- Domain is pre-cloned TyVars and Ids - -- The Id -> CoreExpr in the Subst is ignored - -- (since we want to substitute a LevelledExpr for - -- an Id via le_env) but we do use the Co/TyVar substs - , le_env :: IdEnv ([OutVar], LevelledExpr) -- Domain is pre-cloned Ids + , le_subst :: Subst -- Domain is pre-cloned TyVars and Ids + -- The Id -> CoreExpr in the Subst is ignored + -- (since we want to substitute a LevelledExpr for + -- an Id via le_env) but we do use the Co/TyVar substs + , le_env :: IdEnv ([LevelledBndr], LevelledExpr) -- Domain is pre-cloned Ids + -- The LevelledBndrs are the free vars of LevelledExpr } {- Note [le_subst and le_env] @@ -1733,10 +1732,13 @@ initialEnv float_lams binds -- to a later one. So here we put all the top-level binders in scope before -- we start, to satisfy the lookupIdSubst invariants (#20200 and #20294) -addLvl :: Level -> VarEnv Level -> OutVar -> VarEnv Level +ctxtLevel :: LevelEnv -> Level +ctxtLevel = le_ctxt_lvl + +addLvl :: Level -> VarEnv Level -> InVar -> VarEnv Level addLvl dest_lvl env v' = extendVarEnv env v' dest_lvl -addLvls :: Level -> VarEnv Level -> [OutVar] -> VarEnv Level +addLvls :: Level -> VarEnv Level -> [InVar] -> VarEnv Level addLvls dest_lvl env vs = foldl' (addLvl dest_lvl) env vs floatLams :: LevelEnv -> Maybe Int @@ -1751,9 +1753,6 @@ floatOverSat le = floatOutOverSatApps (le_switches le) floatTopLvlOnly :: LevelEnv -> Bool floatTopLvlOnly le = floatToTopLevelOnly (le_switches le) -incMinorLvlFrom :: LevelEnv -> Level -incMinorLvlFrom env = incMinorLvl (le_ctxt_lvl env) - -- extendCaseBndrEnv adds the mapping case-bndr->scrut-var if it can -- See Note [Binder-swap during float-out] extendCaseBndrEnv :: LevelEnv @@ -1769,34 +1768,25 @@ extendCaseBndrEnv le@(LE { le_subst = subst, le_env = id_env }) , le_env = add_id id_env (case_bndr, scrut_var) } extendCaseBndrEnv env _ _ = env -maxFvLevel :: (OutVar -> Bool) -> LevelEnv -> DVarSet -> Level -maxFvLevel max_me env var_set - = nonDetStrictFoldDVarSet (maxIn max_me env) tOP_LEVEL var_set +maxFvLevel :: Bool -> LevelEnv -> DVarSet -> Level +-- True <=> include type variables +maxFvLevel include_tyvars (LE { le_lvl_env = env }) var_set + = nonDetStrictFoldDVarSet (maxIn include_tyvars env) tOP_LEVEL var_set -- It's OK to use a non-deterministic fold here because maxIn commutes. -maxFvLevel' :: (OutVar -> Bool) -> LevelEnv -> TyCoVarSet -> Level +maxFvLevel' :: Bool -> LevelEnv -> TyCoVarSet -> Level -- Same but for TyCoVarSet -maxFvLevel' max_me env var_set - = nonDetStrictFoldUniqSet (maxIn max_me env) tOP_LEVEL var_set +maxFvLevel' include_tyvars (LE { le_lvl_env = env }) var_set + = nonDetStrictFoldUniqSet (maxIn include_tyvars env) tOP_LEVEL var_set -- It's OK to use a non-deterministic fold here because maxIn commutes. -maxIn :: (OutVar -> Bool) -> LevelEnv -> InVar -> Level -> Level -maxIn max_me (LE { le_lvl_env = lvl_env, le_env = id_env, le_subst = subst }) in_var lvl - | isId in_var - = case lookupVarEnv id_env in_var of - Just (abs_vars, _) -> foldr max_out lvl abs_vars - Nothing -> max_out in_var lvl - | otherwise -- TyVars - = case lookupTyVar subst in_var of - Just ty -> nonDetStrictFoldVarSet max_out lvl (tyCoVarsOfType ty) - Nothing -> max_out in_var lvl - where - max_out :: OutVar -> Level -> Level - max_out out_var lvl - | max_me out_var = case lookupVarEnv lvl_env out_var of - Just lvl' -> maxLvl lvl' lvl - Nothing -> lvl - | otherwise = lvl -- Ignore some vars depending on max_me +maxIn :: Bool -> VarEnv Level -> InVar -> Level -> Level +maxIn include_tyvars lvl_env var lvl + | not include_tyvars, isTyVar var = lvl + | otherwise = maxLvl lvl (lookupLevel lvl_env var) + +lookupLevel :: VarEnv Level -> InVar -> Level +lookupLevel env v = lookupVarEnv env v `orElse` tOP_LEVEL lookupVar :: LevelEnv -> Id -> LevelledExpr lookupVar le v = case lookupVarEnv (le_env le) v of @@ -1841,18 +1831,17 @@ type LvlM result = UniqSM result initLvl :: UniqSupply -> UniqSM a -> a initLvl = initUs_ -newPolyBndrs :: Level -> LevelEnv -> [OutVar] -> [InId] - -> LvlM (LevelEnv, [OutId]) +newPolyBndrs :: LevelEnv -> Level -> [OutVar] -> [InId] + -> LvlM (LevelEnv, [LevelledBndr]) -- The envt is extended to bind the new bndrs to dest_lvl, but -- the le_ctxt_lvl is unaffected -newPolyBndrs dest_lvl - env@(LE { le_lvl_env = lvl_env, le_subst = subst, le_env = id_env }) - abs_vars bndrs +newPolyBndrs env@(LE { le_lvl_env = lvl_env, le_subst = subst, le_env = id_env }) + dest_lvl abs_vars in_bndrs = assert (all (not . isCoVar) bndrs) $ -- What would we add to the CoSubst in this case. No easy answer. do { uniqs <- getUniquesM ; let new_bndrs = zipWith mk_poly_bndr bndrs uniqs bndr_prs = bndrs `zip` new_bndrs - env' = env { le_lvl_env = addLvls dest_lvl lvl_env new_bndrs + env' = env { le_lvl_env = addLvls dest_lvl lvl_env in_bndrs , le_subst = foldl' add_subst subst bndr_prs , le_env = foldl' add_id id_env bndr_prs } ; return (env', new_bndrs) } @@ -1860,12 +1849,15 @@ newPolyBndrs dest_lvl add_subst env (v, v') = extendIdSubst env v (mkVarApps (Var v') abs_vars) add_id env (v, v') = extendVarEnv env v ((v':abs_vars), mkVarApps (Var v') abs_vars) - mk_poly_bndr bndr uniq = transferPolyIdInfo bndr abs_vars $ -- Note [transferPolyIdInfo] in GHC.Types.Id - transfer_join_info bndr $ - mkSysLocal str uniq (idMult bndr) poly_ty - where - str = fsLit "poly_" `appendFS` occNameFS (getOccName bndr) - poly_ty = mkLamTypes abs_vars (substTyUnchecked subst (idType bndr)) + mk_poly_bndr :: InId -> Unique -> LevelledBndr + mk_poly_bndr bndr uniq + = TB new_bndr (FloatMe dest_lvl) + where + new_bndr = transferPolyIdInfo bndr abs_vars $ -- Note [transferPolyIdInfo] in GHC.Types.Id + transfer_join_info bndr $ + mkSysLocal str uniq (idMult bndr) poly_ty + str = fsLit "poly_" `appendFS` occNameFS (getOccName bndr) + poly_ty = mkLamTypes abs_vars (substTyUnchecked subst (idType bndr)) -- If we are floating a join point to top level, it stops being -- a join point. Otherwise it continues to be a join point, @@ -1900,21 +1892,22 @@ newLvlVar lvld_rhs join_arity_maybe is_mk_static = mkSysLocal (mkFastString "lvl") uniq ManyTy rhs_ty -- | Clone the binders bound by a single-alternative case. -cloneCaseBndrs :: LevelEnv -> Level -> [Var] -> LvlM (LevelEnv, [Var]) +cloneCaseBndrs :: LevelEnv -> Level -> [Var] -> LvlM (LevelEnv, [LevelledBndr]) cloneCaseBndrs env@(LE { le_subst = subst, le_lvl_env = lvl_env, le_env = id_env }) - new_lvl vs + dest_lvl vs = do { (subst', vs') <- cloneBndrsM subst vs -- N.B. We are not moving the body of the case, merely its case -- binders. Consequently we should *not* set le_ctxt_lvl. -- See Note [Setting levels when floating single-alternative cases]. - ; let env' = env { le_lvl_env = addLvls new_lvl lvl_env vs' + ; let lvld_bndrs = stayPut dest_lvl vs' + env' = env { le_lvl_env = addLvls dest_lvl lvl_env vs' , le_subst = subst' - , le_env = foldl' add_id id_env (vs `zip` vs') } + , le_env = foldl' add_id id_env (vs `zip` lvld_bndrs) } - ; return (env', vs') } + ; return (env', lvld_bndrs) } cloneLetVars :: RecFlag -> LevelEnv -> Level -> [InVar] - -> LvlM (LevelEnv, [OutVar]) + -> LvlM (LevelEnv, [LevelledBndr]) -- See Note [Need for cloning during float-out] -- Works for Ids bound by let(rec) -- The dest_lvl is attributed to the binders in the new env, @@ -1927,12 +1920,13 @@ cloneLetVars is_rec NonRecursive -> cloneBndrsM subst vs1 Recursive -> cloneRecIdBndrsM subst vs1 - ; let prs = vs `zip` vs2 - env' = env { le_lvl_env = addLvls dest_lvl lvl_env vs2 + ; let lvld_bndrs = [ TB v2 (FloatMe dest_lvl) | v2 <- vs2 ] + prs = vs `zip` lvld_bndrs + env' = env { le_lvl_env = addLvls dest_lvl lvl_env vs , le_subst = subst' , le_env = foldl' add_id id_env prs } - ; return (env', vs2) } + ; return (env', lvld_bndrs) } where zap :: Var -> Var -- See Note [Floatifying demand info when floating] @@ -1944,10 +1938,11 @@ cloneLetVars is_rec zap_join | isTopLvl dest_lvl = zapJoinId | otherwise = id -add_id :: IdEnv ([Var], LevelledExpr) -> (Var, Var) -> IdEnv ([Var], LevelledExpr) -add_id id_env (v, v1) +add_id :: IdEnv ([LevelledBndr], LevelledExpr) -> (Var, LevelledBndr) + -> IdEnv ([LevelledBndr], LevelledExpr) +add_id id_env (v, lb@(TB v1 _)) | isTyVar v = delVarEnv id_env v - | otherwise = extendVarEnv id_env v ([v1], assert (not (isCoVar v1)) $ Var v1) + | otherwise = extendVarEnv id_env v ([lb], assert (not (isCoVar v1)) $ Var v1) {- Note [Zapping JoinId when floating] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/459bd467af024a2570e5d3ac6ca5328f... -- View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/commit/459bd467af024a2570e5d3ac6ca5328f... You're receiving this email because of your account on gitlab.haskell.org.
participants (1)
-
Simon Peyton Jones (@simonpj)