Simon Peyton Jones pushed to branch wip/T26115 at Glasgow Haskell Compiler / GHC

Commits:

1 changed file:

Changes:

  • compiler/GHC/HsToCore/Binds.hs
    ... ... @@ -57,6 +57,7 @@ import GHC.Core.TyCon
    57 57
     import GHC.Core.Type
    
    58 58
     import GHC.Core.Coercion
    
    59 59
     import GHC.Core.Rules
    
    60
    +import GHC.Core.Ppr( pprCoreBinders )
    
    60 61
     import GHC.Core.TyCo.Compare( eqType )
    
    61 62
     
    
    62 63
     import GHC.Builtin.Names
    
    ... ... @@ -1002,6 +1003,88 @@ when we would rather avoid passing both dictionaries, and instead generate:
    1002 1003
       $sg @c d = let { d1 = $p1Ord d; d2 = d } in <g-rhs> @c @c d1 d2
    
    1003 1004
     
    
    1004 1005
     For now, we accept this infelicity.
    
    1006
    +
    
    1007
    +Note [Desugaring new-form SPECIALISE pragmas] -- Take 2
    
    1008
    +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    
    1009
    +Suppose we have
    
    1010
    +  f :: forall a b c d. (Ord a, Ord b, Eq c, Ix d) => ...
    
    1011
    +  f = rhs
    
    1012
    +  {-# SPECIALISE f @p @[p] @[Int] @(q,r) #-}
    
    1013
    +
    
    1014
    +The type-checker generates `the_call` which looks like
    
    1015
    +
    
    1016
    +  spe_bndrs = (dx1 :: Ord p) (dx2::Ix q) (dx3::Ix r)
    
    1017
    +  the_call = let d6 = dx1
    
    1018
    +                 d2 = $fOrdList d6
    
    1019
    +                 d3 = $fEqList $fEqInt
    
    1020
    +                 d7 = dx1   -- Solver may introduce
    
    1021
    +                 d1 = d7    -- these indirections
    
    1022
    +                 d4 = $fIxPair dx2 dx3
    
    1023
    +             in f @p @[p] @[Int] @(q,r)
    
    1024
    +                  (d1::Ord p) (d2::Ord [p]) (d3::Eq [Int]) (d4::Ix (q,r))
    
    1025
    +
    
    1026
    +We /could/ generate
    
    1027
    +   RULE  f d1 d2 d3 d4 e1..en = $sf d1 d2 d3 d4
    
    1028
    +   $sf d1 d2 d3 d4 = <rhs> d1 d2 d3 d4
    
    1029
    +
    
    1030
    +But that would do no specialisation! What we want is this:
    
    1031
    +   RULE  f d1 _d2 _d3 d4 e1..en = $sf d1 d4
    
    1032
    +   $sf d1 d4 =  let d7 = d1   -- Renaming
    
    1033
    +                    dx1 = d7  -- Renaming
    
    1034
    +                    d6 = dx1
    
    1035
    +                    d2 = $fOrdList d6
    
    1036
    +                    d3 = $fEqList $fEqInt
    
    1037
    +                in rhs d1 d2 d3 d4
    
    1038
    +
    
    1039
    +Notice that:
    
    1040
    +  * We pass some, but not all, of the matched dictionaries to $sf
    
    1041
    +
    
    1042
    +  * We get specialisations for d2 and d3, but not for d1, nor d4.
    
    1043
    +
    
    1044
    +  * We had to introduce some renaming bindings at the top
    
    1045
    +    to line things up
    
    1046
    +
    
    1047
    +The transformation goes in these steps
    
    1048
    +(S1) decomposeCall: decompose `the_call` into
    
    1049
    +     - `rev_binds`: the enclosing let-bindings (actually reversed)
    
    1050
    +     - `rule_lhs_args`: the arguments of the call itself
    
    1051
    +    We carefully arrange that the dictionary arguments of the actual
    
    1052
    +    call, `rule_lhs_args` are all distinct dictionary variables,
    
    1053
    +    not expressions. How? We use `simpleOptExprNoInline` to avoid
    
    1054
    +    inlining the let-bindings.
    
    1055
    +
    
    1056
    +(S2) Compute `rule_bndrs`: the free vars of `rule_lhs_args`, which
    
    1057
    +   will be the forall'd template variables of the RULE.  In the example,
    
    1058
    +       rule_bndrs = d1,d2,d3,d4
    
    1059
    +
    
    1060
    +(S3) grabSpecBinds: transform `rev_binds` into `spec_binds`: the
    
    1061
    +   bindings we will wrap around the call in the RHS of `$sf`
    
    1062
    +
    
    1063
    +(S4) Find `spec_bndrs`, the subset of `rule_bndrs` that we actually
    
    1064
    +   need to pass to `$sf`, simply by filtering out those that are
    
    1065
    +   bound by `spec_binds`.  In the example
    
    1066
    +      spec_bndrs = d1,d4
    
    1067
    +
    
    1068
    +
    
    1069
    +     Working inner
    
    1070
    +* Grab any bindings we can that will "shadow" the forall'd
    
    1071
    +  rule-bndrs, giving specialised bindings for them.
    
    1072
    +  * We keep a set of known_bndrs starting with {d1,..,dn}
    
    1073
    +  * We keep a binding iff no free var is
    
    1074
    +      (a) in orig_bndrs (i.e. not totally free), and
    
    1075
    +      (b) not in known_bndrs
    
    1076
    +  * If we keep it, add its binder to known_bndrs; if not, don't
    
    1077
    +
    
    1078
    +To maximise what we can "grab", start by extracting /renamings/ of the
    
    1079
    +forall'd rule_bndrs, and bringing them to the top.  A renaming is
    
    1080
    +    rule_bndr = d
    
    1081
    +If we see this:
    
    1082
    +  * Bring d=rule_bndr to the top
    
    1083
    +  * Add d to the set of variables to look for on the right.
    
    1084
    +    e.g.    rule_bndrs = d1, d2
    
    1085
    +            Bindings   { d7=d9; d1=d7 }
    
    1086
    +        Bring to the top  { d7=d1; d9=d7 }
    
    1087
    +
    
    1005 1088
     -}
    
    1006 1089
     
    
    1007 1090
     ------------------------
    
    ... ... @@ -1083,13 +1166,18 @@ dsSpec_help poly_nm poly_id poly_rhs inl orig_bndrs ds_call
    1083 1166
                 Nothing -> do { diagnosticDs (DsRuleLhsTooComplicated ds_call core_call)
    
    1084 1167
                                ; return Nothing } ;
    
    1085 1168
     
    
    1086
    -            Just (rev_binds, rule_lhs_args) ->
    
    1169
    +            Just (binds, rule_lhs_args) ->
    
    1170
    +
    
    1171
    +    do { let locals = mkVarSet orig_bndrs `extendVarSetList` bindersOfBinds binds
    
    1172
    +             is_local :: Var -> Bool
    
    1173
    +             is_local v = v `elemVarSet` locals
    
    1174
    +
    
    1175
    +             rule_bndrs = scopedSort (exprsSomeFreeVarsList is_local rule_lhs_args)
    
    1176
    +             rn_binds = getRenamings orig_bndrs binds rule_bndrs
    
    1177
    +
    
    1178
    +             spec_binds = pickSpecBinds is_local (mkVarSet rule_bndrs)
    
    1179
    +                                        (rn_binds ++ binds)
    
    1087 1180
     
    
    1088
    -    do { let orig_bndr_set = mkVarSet orig_bndrs
    
    1089
    -             locally_bound = orig_bndr_set `extendVarSetList` bindersOfBinds rev_binds
    
    1090
    -             rule_bndrs = scopedSort (exprsSomeFreeVarsList (`elemVarSet` locally_bound)
    
    1091
    -                                                            rule_lhs_args)
    
    1092
    -             spec_binds = grabSpecBinds orig_bndr_set (mkVarSet rule_bndrs) rev_binds
    
    1093 1181
                  spec_binds_bndr_set = mkVarSet (bindersOfBinds spec_binds)
    
    1094 1182
                  spec_bndrs = filterOut (`elemVarSet` spec_binds_bndr_set) rule_bndrs
    
    1095 1183
     
    
    ... ... @@ -1101,13 +1189,14 @@ dsSpec_help poly_nm poly_id poly_rhs inl orig_bndrs ds_call
    1101 1189
            ; tracePm "dsSpec(new route)" $
    
    1102 1190
              vcat [ text "poly_id" <+> ppr poly_id
    
    1103 1191
                   , text "unfolding" <+> ppr (realIdUnfolding poly_id)
    
    1104
    -              , text "orig_bndrs"   <+> ppr orig_bndrs
    
    1192
    +              , text "orig_bndrs"   <+> pprCoreBinders orig_bndrs
    
    1105 1193
                   , text "ds_call" <+> ppr ds_call
    
    1106 1194
                   , text "core_call" <+> ppr core_call
    
    1107
    -              , text "rev_binds" <+> ppr rev_binds
    
    1195
    +              , text "binds" <+> ppr binds
    
    1108 1196
                   , text "rule_bndrs" <+> ppr rule_bndrs
    
    1109 1197
                   , text "rule_lhs_args" <+> ppr rule_lhs_args
    
    1110 1198
                   , text "spec_bndrs" <+> ppr spec_bndrs
    
    1199
    +              , text "rn_binds" <+> ppr rn_binds
    
    1111 1200
                   , text "spec_binds" <+> ppr spec_binds ]
    
    1112 1201
     
    
    1113 1202
            ; finishSpecPrag poly_nm poly_rhs
    
    ... ... @@ -1115,7 +1204,7 @@ dsSpec_help poly_nm poly_id poly_rhs inl orig_bndrs ds_call
    1115 1204
                             spec_bndrs mk_spec_body inl } } }
    
    1116 1205
     
    
    1117 1206
     decomposeCall :: Id -> CoreExpr
    
    1118
    -               -> Maybe ( [CoreBind]    -- Reversed bindings
    
    1207
    +               -> Maybe ( [CoreBind]
    
    1119 1208
                             , [CoreExpr] )  -- Args of the call
    
    1120 1209
     decomposeCall poly_id binds
    
    1121 1210
       = go [] binds
    
    ... ... @@ -1125,42 +1214,78 @@ decomposeCall poly_id binds
    1125 1214
         go acc e
    
    1126 1215
           | (Var fun, args) <- collectArgs e
    
    1127 1216
           = assertPpr (fun == poly_id) (ppr fun $$ ppr poly_id) $
    
    1128
    -        Just (acc, args)
    
    1217
    +        Just (reverse acc, args)
    
    1129 1218
           | otherwise
    
    1130 1219
           = Nothing
    
    1131 1220
     
    
    1221
    +getRenamings :: [Var] -> [CoreBind]  -- orig_bndrs and bindings
    
    1222
    +             -> [Var]                -- rule_bndrs
    
    1223
    +             -> [CoreBind]           -- Binds some of the orig_bndrs to a rule_bndr
    
    1224
    +getRenamings orig_bndrs binds rule_bndrs
    
    1225
    +  = [ NonRec b e | b <- orig_bndrs
    
    1226
    +                 , not (b `elem` rule_bndrs)
    
    1227
    +                 , Just e <- [lookupVarEnv final_renamings b] ]
    
    1228
    +  where
    
    1229
    +    init_renamings, final_renamings :: IdEnv CoreExpr
    
    1230
    +    -- In this function, IdEnv maps a local variable to (v |> co),
    
    1231
    +    -- where `v` is a rule_bndr
    
    1232
    +
    
    1233
    +    init_renamings = mkVarEnv [ (v, Var v) | v <- rule_bndrs, isId v ]
    
    1234
    +    final_renamings = go binds
    
    1235
    +
    
    1236
    +    go :: [CoreBind] -> IdEnv CoreExpr
    
    1237
    +    go [] = init_renamings
    
    1238
    +    go (bind : binds)
    
    1239
    +       | NonRec b rhs <- bind
    
    1240
    +       , Just (v, mco) <- getCastedVar rhs
    
    1241
    +       , Just e <- lookupVarEnv renamings v
    
    1242
    +       = extendVarEnv renamings b (mkCastMCo e (mkSymMCo mco))
    
    1243
    +       | otherwise
    
    1244
    +       = renamings
    
    1245
    +       where
    
    1246
    +         renamings = go binds
    
    1132 1247
     
    
    1133
    -grabSpecBinds :: VarSet -> VarSet -> [CoreBind] -> [CoreBind]
    
    1134
    -grabSpecBinds orig_bndrs rule_bndrs rev_binds
    
    1135
    -   = reverse rename_binds ++ spec_binds
    
    1248
    +pickSpecBinds :: (Var -> Bool) -> VarSet -> [CoreBind] -> [CoreBind]
    
    1249
    +pickSpecBinds _ _ [] = []
    
    1250
    +pickSpecBinds is_local known_bndrs (bind:binds)
    
    1251
    +      | all keep_me (rhssOfBind bind)
    
    1252
    +      , let known_bndrs' = known_bndrs `extendVarSetList` bindersOf bind
    
    1253
    +      = bind : pickSpecBinds is_local known_bndrs' binds
    
    1254
    +      | otherwise
    
    1255
    +      = pickSpecBinds is_local known_bndrs binds
    
    1256
    +      where
    
    1257
    +        keep_me rhs = isEmptyVarSet (exprSomeFreeVars bad_var rhs)
    
    1258
    +        bad_var v = is_local v && not (v `elemVarSet` known_bndrs)
    
    1259
    +{-
    
    1260
    +grabSpecBinds :: (Var -> Bool) -> VarSet -> [CoreBind]
    
    1261
    +              -> ([CoreBind], [CoreBind])
    
    1262
    +grabSpecBinds is_local rule_bndrs rev_binds
    
    1263
    +   = (reverse rename_binds, spec_binds)
    
    1136 1264
       where
    
    1137 1265
         (known_bndrs, (rename_binds, other_binds))
    
    1138
    -        = get_renamings orig_bndrs rule_bndrs ([],[]) rev_binds
    
    1266
    +        = get_renamings rule_bndrs ([],[]) rev_binds
    
    1139 1267
         spec_binds = pick_spec_binds known_bndrs other_binds
    
    1140 1268
     
    
    1141 1269
         ------------------------
    
    1142
    -    get_renamings :: VarSet  -- Locally bound variables
    
    1143
    -                  -> VarSet  -- Variables bound by a successful match on the call
    
    1270
    +    get_renamings :: VarSet  -- Variables bound by a successful match on the call
    
    1144 1271
                       -> ([CoreBind],[CoreBind])   -- Accumulating parameter, in order
    
    1145 1272
                       -> [CoreBind]     -- Reversed, innermost first
    
    1146 1273
                       -> ( VarSet
    
    1147 1274
                          , ([CoreBind]    -- Renamings, in order
    
    1148 1275
                          ,  [CoreBind]))  -- Other bindings, in order
    
    1149
    -    get_renamings _ bndrs acc [] = (bndrs, acc)
    
    1276
    +    get_renamings bndrs acc [] = (bndrs, acc)
    
    1150 1277
     
    
    1151
    -    get_renamings locals bndrs (rn_binds, other_binds) (bind : binds)
    
    1278
    +    get_renamings bndrs (rn_binds, other_binds) (bind : binds)
    
    1152 1279
           | NonRec d r <- bind
    
    1153 1280
           , d `elemVarSet` bndrs
    
    1154 1281
           , Just (v, mco) <- getCastedVar r
    
    1155
    -      , v `elemVarSet` locals
    
    1282
    +      , is_local v
    
    1156 1283
           , let flipped_bind = NonRec v (mkCastMCo (Var d) (mkSymMCo mco))
    
    1157 1284
           = get_renamings (bndrs `extendVarSet` v)
    
    1158
    -                      (locals `extendVarSet` d)
    
    1159 1285
                           (flipped_bind:rn_binds, other_binds)
    
    1160 1286
                           binds
    
    1161 1287
           | otherwise
    
    1162 1288
           = get_renamings bndrs
    
    1163
    -                     (locals `extendVarSetList` bindersOf bind)
    
    1164 1289
                          (rn_binds, bind:other_binds)
    
    1165 1290
                          binds
    
    1166 1291
     
    
    ... ... @@ -1175,7 +1300,8 @@ grabSpecBinds orig_bndrs rule_bndrs rev_binds
    1175 1300
           = pick_spec_binds known_bndrs binds
    
    1176 1301
           where
    
    1177 1302
             keep_me rhs = isEmptyVarSet (exprSomeFreeVars bad_var rhs)
    
    1178
    -        bad_var v = v `elemVarSet` orig_bndrs && not (v `elemVarSet` known_bndrs)
    
    1303
    +        bad_var v = is_local v && not (v `elemVarSet` known_bndrs)
    
    1304
    +-}
    
    1179 1305
     
    
    1180 1306
     getCastedVar :: CoreExpr -> Maybe (Var, MCoercionR)
    
    1181 1307
     getCastedVar (Var v)           = Just (v, MRefl)