Adam Gundry pushed to branch wip/amg/castz at Glasgow Haskell Compiler / GHC

Commits:

5 changed files:

Changes:

  • compiler/GHC/Core/Coercion.hs
    ... ... @@ -984,12 +984,13 @@ mkForAllCo v visL visR kind_co co
    984 984
       | otherwise
    
    985 985
       = mk_forall_co v visL visR kind_co co
    
    986 986
     
    
    987
    -mkForAllCastCo :: HasDebugCallStack => TyCoVar -> ForAllTyFlag -> ForAllTyFlag
    
    988
    -            -> CastCoercion -> CastCoercion
    
    989
    -mkForAllCastCo v visL visR cco = case cco of
    
    987
    +mkForAllCastCo :: HasDebugCallStack => Role -> TyCoVar -> ForAllTyFlag -> ForAllTyFlag
    
    988
    +            -> Type -> CastCoercion -> CastCoercion
    
    989
    +mkForAllCastCo r v visL visR ty cco = case cco of
    
    990 990
         CCoercion co -> CCoercion (mkForAllCo v visL visR MRefl co)
    
    991
    -    ZCoercion ty cos -> ZCoercion (mkTyCoForAllTy v visL ty) cos
    
    992
    -    ReflCastCo -> ReflCastCo
    
    991
    +    ZCoercion ty cos -> ZCoercion (mkTyCoForAllTy v visR ty) cos
    
    992
    +    ReflCastCo | visL `eqForAllVis` visR -> ReflCastCo
    
    993
    +               | otherwise -> CCoercion (mk_forall_co v visL visR MRefl (mkReflCo r ty))
    
    993 994
     
    
    994 995
     -- mkForAllVisCos [tv{vis}] constructs a cast
    
    995 996
     --   forall tv. res  ~R#   forall tv{vis} res`.
    
    ... ... @@ -1816,7 +1817,7 @@ mkPiCastCo :: Role -> Var -> CastCoercion -> CastCoercion
    1816 1817
     mkPiCastCo _ _ ReflCastCo     = ReflCastCo
    
    1817 1818
     mkPiCastCo r v (CCoercion co) = CCoercion (mkPiCo r v co)
    
    1818 1819
     mkPiCastCo _ v (ZCoercion ty cos)
    
    1819
    -  | isTyVar v = ZCoercion (mkForAllTy (Bndr v vis) ty) cos
    
    1820
    +  | isTyVar v = ZCoercion (mkTyCoForAllTy v vis ty) cos
    
    1820 1821
       | otherwise = ZCoercion (mkFunctionType (idMult v) (varType v) ty) cos
    
    1821 1822
       where
    
    1822 1823
         vis = coreTyLamForAllTyFlag
    

  • compiler/GHC/Core/Opt/Arity.hs
    ... ... @@ -2359,12 +2359,12 @@ mkEtaWW orig_oss ppr_orig_expr in_scope orig_ty
    2359 2359
             -- with an explicit lambda having a non-function type
    
    2360 2360
     
    
    2361 2361
     mkEtaForAllMCo :: ForAllTyBinder -> Type -> CastCoercion -> CastCoercion
    
    2362
    -mkEtaForAllMCo bdnr@(Bndr tcv vis) ty mco
    
    2362
    +mkEtaForAllMCo (Bndr tcv vis) ty mco
    
    2363 2363
       = case mco of
    
    2364 2364
           ReflCastCo | vis == coreTyLamForAllTyFlag -> ReflCastCo
    
    2365 2365
                      | otherwise                    -> mk_fco (mkRepReflCo ty)
    
    2366 2366
           CCoercion co                              -> mk_fco co
    
    2367
    -      ZCoercion tyR cos                         -> ZCoercion (mkForAllTy bdnr tyR) cos
    
    2367
    +      ZCoercion tyR cos                         -> ZCoercion (mkTyCoForAllTy tcv coreTyLamForAllTyFlag tyR) cos
    
    2368 2368
       where
    
    2369 2369
         mk_fco co = CCoercion (mkForAllCo tcv vis coreTyLamForAllTyFlag MRefl co)
    
    2370 2370
         -- coreTyLamForAllTyFlag: See Note [The EtaInfo mechanism], particularly
    
    ... ... @@ -2723,7 +2723,7 @@ tryEtaReduce rec_ids bndrs body eval_sd
    2723 2723
           -- Float app ticks: \x -> Tick t (e x) ==> Tick t e
    
    2724 2724
     
    
    2725 2725
         go (b : bs) (App fun arg) co
    
    2726
    -      | Just (co', ticks) <- ok_arg b arg co (exprType fun)
    
    2726
    +      | Just (co', ticks) <- ok_arg b arg co (exprType fun) (exprType (App fun arg))
    
    2727 2727
           = fmap (flip (foldr mkTick) ticks) $ go bs fun co'
    
    2728 2728
                 -- Float arg ticks: \x -> e (Tick t x) ==> Tick t e
    
    2729 2729
     
    
    ... ... @@ -2798,15 +2798,16 @@ tryEtaReduce rec_ids bndrs body eval_sd
    2798 2798
                -> CastCoercion     -- Of kind (t1~t2)
    
    2799 2799
                -> Type             -- Type (arg_t -> t1) of the function
    
    2800 2800
                                    --      to which the argument is supplied
    
    2801
    +           -> Type             -- Type t1 of the result (AMG TODO: avoid needing to pass this?)
    
    2801 2802
                -> Maybe (CastCoercion  -- Of type (arg_t -> t1 ~  bndr_t -> t2)
    
    2802 2803
                                    --   (and similarly for tyvars, coercion args)
    
    2803 2804
                         , [CoreTickish])
    
    2804 2805
         -- See Note [Eta reduction with casted arguments]
    
    2805
    -    ok_arg bndr (Type arg_ty) co fun_ty
    
    2806
    +    ok_arg bndr (Type arg_ty) co fun_ty res_ty
    
    2806 2807
            | Just tv <- getTyVar_maybe arg_ty
    
    2807 2808
            , bndr == tv  = case splitForAllForAllTyBinder_maybe fun_ty of
    
    2808 2809
                Just (Bndr _ vis, _) -> Just (fco, [])
    
    2809
    -             where !fco = mkForAllCastCo tv vis coreTyLamForAllTyFlag co
    
    2810
    +             where !fco = mkForAllCastCo Representational tv vis coreTyLamForAllTyFlag res_ty co
    
    2810 2811
                        -- The lambda we are eta-reducing always has visibility
    
    2811 2812
                        -- 'coreTyLamForAllTyFlag' which may or may not match
    
    2812 2813
                        -- the visibility on the inner function (#24014)
    
    ... ... @@ -2814,24 +2815,24 @@ tryEtaReduce rec_ids bndrs body eval_sd
    2814 2815
                                    (text "fun:" <+> ppr bndr
    
    2815 2816
                                     $$ text "arg:" <+> ppr arg_ty
    
    2816 2817
                                     $$ text "fun_ty:" <+> ppr fun_ty)
    
    2817
    -    ok_arg bndr (Var v) co fun_ty
    
    2818
    +    ok_arg bndr (Var v) co fun_ty _
    
    2818 2819
            | bndr == v
    
    2819 2820
            , let mult = idMult bndr
    
    2820 2821
            , Just (_af, fun_mult, _, _) <- splitFunTy_maybe fun_ty
    
    2821 2822
            , mult `eqType` fun_mult -- There is no change in multiplicity, otherwise we must abort
    
    2822 2823
            = Just (mkFunResCastCo Representational bndr co, [])
    
    2823
    -    ok_arg bndr (Cast e co_arg) co fun_ty
    
    2824
    +    ok_arg bndr (Cast e co_arg) co fun_ty _
    
    2824 2825
            | (ticks, Var v) <- stripTicksTop tickishFloatable e
    
    2825 2826
            , Just (_, fun_mult, _, res_ty) <- splitFunTy_maybe fun_ty
    
    2826 2827
            , bndr == v
    
    2827 2828
            , fun_mult `eqType` idMult bndr
    
    2828
    -       = Just (mkFunCastCoNoFTF Representational fun_mult (castCoercionRKind (exprType e) co_arg) (mkSymCastCo (exprType e) co_arg) res_ty co, ticks) -- TODO check types
    
    2829
    +       = Just (mkFunCastCoNoFTF Representational fun_mult (castCoercionRKind (exprType e) co_arg) (mkSymCastCo (exprType e) co_arg) res_ty co, ticks)
    
    2829 2830
            -- The simplifier combines multiple casts into one,
    
    2830 2831
            -- so we can have a simple-minded pattern match here
    
    2831
    -    ok_arg bndr (Tick t arg) co fun_ty
    
    2832
    -       | tickishFloatable t, Just (co', ticks) <- ok_arg bndr arg co fun_ty
    
    2832
    +    ok_arg bndr (Tick t arg) co fun_ty res_ty
    
    2833
    +       | tickishFloatable t, Just (co', ticks) <- ok_arg bndr arg co fun_ty res_ty
    
    2833 2834
            = Just (co', t:ticks)
    
    2834
    -    ok_arg _ _ _ _ = Nothing
    
    2835
    +    ok_arg _ _ _ _ _ = Nothing
    
    2835 2836
     
    
    2836 2837
     -- | Can we eta-reduce the given function
    
    2837 2838
     -- See Note [Eta reduction soundness], criteria (B), (J), and (W).
    

  • compiler/GHC/Core/Opt/Simplify/Iteration.hs
    ... ... @@ -28,7 +28,7 @@ import GHC.Core.Make ( FloatBind, mkImpossibleExpr, castBottomExpr )
    28 28
     import qualified GHC.Core.Make
    
    29 29
     import GHC.Core.Coercion hiding ( substCo, substCoVar )
    
    30 30
     import GHC.Core.Reduction
    
    31
    -import GHC.Core.Coercion.Opt    ( optCoercion )
    
    31
    +import GHC.Core.Coercion.Opt    ( optCoercion, optCastCoercion )
    
    32 32
     import GHC.Core.FamInstEnv      ( FamInstEnv, topNormaliseType_maybe )
    
    33 33
     import GHC.Core.DataCon
    
    34 34
     import GHC.Core.Opt.Stats ( Tick(..) )
    
    ... ... @@ -1545,7 +1545,7 @@ rebuild_go env expr cont
    1545 1545
             -> rebuild_go env (mkCastCo expr co') cont
    
    1546 1546
                -- NB: mkCast implements the (Coercion co |> g) optimisation
    
    1547 1547
             where
    
    1548
    -          co' = optOutCastCoercion env co opt
    
    1548
    +          co' = optOutCoercion env (exprType expr) co opt
    
    1549 1549
     
    
    1550 1550
           Select { sc_bndr = bndr, sc_alts = alts, sc_env = se, sc_cont = cont }
    
    1551 1551
             -> rebuildCase (se `setInScopeFromE` env) expr bndr alts cont
    
    ... ... @@ -1674,17 +1674,11 @@ on each successive composition -- that's at least quadratic. So:
    1674 1674
     -}
    
    1675 1675
     
    
    1676 1676
     
    
    1677
    -optOutCastCoercion :: SimplEnvIS -> OutCastCoercion -> Bool -> OutCastCoercion
    
    1678
    -optOutCastCoercion env cco already_optimised = case cco of
    
    1679
    -    ReflCastCo   -> ReflCastCo
    
    1680
    -    CCoercion co -> CCoercion (optOutCoercion env co already_optimised)
    
    1681
    -    ZCoercion{}  -> cco
    
    1682
    -
    
    1683
    -optOutCoercion :: SimplEnvIS -> OutCoercion -> Bool -> OutCoercion
    
    1677
    +optOutCoercion :: SimplEnvIS -> Type -> OutCastCoercion -> Bool -> OutCastCoercion
    
    1684 1678
     -- See Note [Avoid re-simplifying coercions]
    
    1685
    -optOutCoercion env co already_optimised
    
    1679
    +optOutCoercion env ty co already_optimised
    
    1686 1680
       | already_optimised = co  -- See Note [Avoid re-simplifying coercions]
    
    1687
    -  | otherwise         = optCoercion opts empty_subst co
    
    1681
    +  | otherwise         = optCastCoercion opts empty_subst ty co
    
    1688 1682
       where
    
    1689 1683
         empty_subst = mkEmptySubst (seInScope env)
    
    1690 1684
         opts = seOptCoercionOpts env
    
    ... ... @@ -1732,7 +1726,7 @@ simplCast env body co0 cont0
    1732 1726
                                                     , sc_dup = dup, sc_cont = tail
    
    1733 1727
                                                     , sc_hole_ty = fun_ty })
    
    1734 1728
               | not co_is_opt  -- pushCoValArg duplicates the coercion, so optimise first
    
    1735
    -          = addCoerce tyL (optOutCastCoercion (zapSubstEnv env) co co_is_opt) True cont
    
    1729
    +          = addCoerce tyL (optOutCoercion (zapSubstEnv env) tyL co co_is_opt) True cont
    
    1736 1730
     
    
    1737 1731
               | Just (_, m_co1, res_ty, m_co2) <- pushCastCoValArg tyL co
    
    1738 1732
               = {-#SCC "addCoerce-pushCoValArg" #-}
    
    ... ... @@ -3886,7 +3880,7 @@ mkDupableContWithDmds _ _ (Stop {}) = panic "mkDupableCont" -- Handled by pr
    3886 3880
     
    
    3887 3881
     mkDupableContWithDmds env dmds (CastIt { sc_co = co, sc_hole_ty = ty, sc_opt = opt, sc_cont = cont })
    
    3888 3882
       = do  { (floats, cont') <- mkDupableContWithDmds env dmds cont
    
    3889
    -        ; return (floats, CastIt { sc_co = optOutCastCoercion env co opt
    
    3883
    +        ; return (floats, CastIt { sc_co = optOutCoercion env ty co opt
    
    3890 3884
                                      , sc_hole_ty = ty
    
    3891 3885
                                      , sc_opt = True, sc_cont = cont' }) }
    
    3892 3886
                      -- optOutCoercion: see Note [Avoid re-simplifying coercions]
    

  • compiler/GHC/Tc/Zonk/Type.hs
    ... ... @@ -1862,9 +1862,10 @@ zonkEvTerm (EvExpr e)
    1862 1862
       = EvExpr <$> zonkCoreExpr e
    
    1863 1863
     zonkEvTerm (EvCastExpr e (CCoercion co) co_res_ty)
    
    1864 1864
       = do { zap_casts <- hasZapCasts <$> lift getDynFlags
    
    1865
    -       ; co_res_ty' <- zonkTcTypeToTypeX co_res_ty
    
    1866
    -       ; if zap_casts
    
    1867
    -         then EvCastExpr <$> zonkCoreExpr e <*> (ZCoercion co_res_ty' <$> zonkShallowCoVarsOfCo co) <*> pure co_res_ty'
    
    1865
    +       ; if zap_casts && coercionSize co > typeSize co_res_ty -- AMG TODO: experimental heuristic
    
    1866
    +         then do { co_res_ty' <- zonkTcTypeToTypeX co_res_ty
    
    1867
    +                 ; EvCastExpr <$> zonkCoreExpr e <*> (ZCoercion co_res_ty' <$> zonkShallowCoVarsOfCo co) <*> pure co_res_ty'
    
    1868
    +                 }
    
    1868 1869
              else EvExpr <$> zonkCoreExpr (Cast e (CCoercion co))
    
    1869 1870
            }
    
    1870 1871
     zonkEvTerm ev@(EvCastExpr _ (ZCoercion{}) _)
    

  • testsuite/tests/ghci/prog-mhu002/prog-mhu002c.stdout
    ... ... @@ -15,6 +15,7 @@ other dynamic, non-language, flag settings:
    15 15
       -fshow-warning-groups
    
    16 16
       -fprefer-byte-code
    
    17 17
       -fbreak-points
    
    18
    +  -fno-zap-casts
    
    18 19
     warning settings:
    
    19 20
       -Wpattern-namespace-specifier
    
    20 21
     === :set
    
    ... ... @@ -34,6 +35,7 @@ other dynamic, non-language, flag settings:
    34 35
       -fshow-warning-groups
    
    35 36
       -fprefer-byte-code
    
    36 37
       -fbreak-points
    
    38
    +  -fno-zap-casts
    
    37 39
     warning settings:
    
    38 40
       -Wpattern-namespace-specifier
    
    39 41
     Unit ID: b-0.0.0
    
    ... ... @@ -51,6 +53,7 @@ other dynamic, non-language, flag settings:
    51 53
       -fshow-warning-groups
    
    52 54
       -fprefer-byte-code
    
    53 55
       -fbreak-points
    
    56
    +  -fno-zap-casts
    
    54 57
     warning settings:
    
    55 58
       -Wpattern-namespace-specifier
    
    56 59
     Unit ID: c-0.0.0
    
    ... ... @@ -68,6 +71,7 @@ other dynamic, non-language, flag settings:
    68 71
       -fshow-warning-groups
    
    69 72
       -fprefer-byte-code
    
    70 73
       -fbreak-points
    
    74
    +  -fno-zap-casts
    
    71 75
     warning settings:
    
    72 76
       -Wpattern-namespace-specifier
    
    73 77
     Unit ID: d-0.0.0
    
    ... ... @@ -85,5 +89,6 @@ other dynamic, non-language, flag settings:
    85 89
       -fshow-warning-groups
    
    86 90
       -fprefer-byte-code
    
    87 91
       -fbreak-points
    
    92
    +  -fno-zap-casts
    
    88 93
     warning settings:
    
    89 94
       -Wpattern-namespace-specifier