... |
... |
@@ -1079,14 +1079,15 @@ dsSpec_help poly_nm poly_id poly_rhs inl orig_bndrs ds_call |
1079
|
1079
|
; let simpl_opts = initSimpleOpts dflags
|
1080
|
1080
|
core_call = simpleOptExprNoInline simpl_opts ds_call
|
1081
|
1081
|
|
1082
|
|
- ; case decomposeCall poly_id [] core_call of {
|
|
1082
|
+ ; case decomposeCall poly_id core_call of {
|
1083
|
1083
|
Nothing -> do { diagnosticDs (DsRuleLhsTooComplicated ds_call core_call)
|
1084
|
1084
|
; return Nothing } ;
|
1085
|
1085
|
|
1086
|
1086
|
Just (rev_binds, rule_lhs_args) ->
|
1087
|
1087
|
|
1088
|
1088
|
do { let orig_bndr_set = mkVarSet orig_bndrs
|
1089
|
|
- rule_bndrs = scopedSort (exprsSomeFreeVarsList (`elemVarSet` orig_bndr_set)
|
|
1089
|
+ locally_bound = orig_bndr_set `extendVarSetList` bindersOfBinds rev_binds
|
|
1090
|
+ rule_bndrs = scopedSort (exprsSomeFreeVarsList (`elemVarSet` locally_bound)
|
1090
|
1091
|
rule_lhs_args)
|
1091
|
1092
|
spec_binds = grabSpecBinds orig_bndr_set (mkVarSet rule_bndrs) rev_binds
|
1092
|
1093
|
spec_binds_bndr_set = mkVarSet (bindersOfBinds spec_binds)
|
... |
... |
@@ -1100,17 +1101,14 @@ dsSpec_help poly_nm poly_id poly_rhs inl orig_bndrs ds_call |
1100
|
1101
|
; tracePm "dsSpec(new route)" $
|
1101
|
1102
|
vcat [ text "poly_id" <+> ppr poly_id
|
1102
|
1103
|
, text "unfolding" <+> ppr (realIdUnfolding poly_id)
|
1103
|
|
- , text "bndrs" <+> ppr bndrs
|
|
1104
|
+ , text "orig_bndrs" <+> ppr orig_bndrs
|
1104
|
1105
|
, text "ds_call" <+> ppr ds_call
|
1105
|
1106
|
, text "core_call" <+> ppr core_call
|
1106
|
|
- , text "bndr_set" <+> ppr bndr_set
|
1107
|
|
- , text "all_bndrs" <+> ppr all_bndrs
|
|
1107
|
+ , text "rev_binds" <+> ppr rev_binds
|
1108
|
1108
|
, text "rule_bndrs" <+> ppr rule_bndrs
|
1109
|
1109
|
, text "rule_lhs_args" <+> ppr rule_lhs_args
|
1110
|
|
- , text "const_bndrs" <+> ppr const_bndrs
|
1111
|
1110
|
, text "spec_bndrs" <+> ppr spec_bndrs
|
1112
|
|
- , text "core_call fvs" <+> ppr (exprFreeVars core_call)
|
1113
|
|
- , text "spec_const_binds" <+> ppr spec_const_binds ]
|
|
1111
|
+ , text "spec_binds" <+> ppr spec_binds ]
|
1114
|
1112
|
|
1115
|
1113
|
; finishSpecPrag poly_nm poly_rhs
|
1116
|
1114
|
rule_bndrs poly_id rule_lhs_args
|
... |
... |
@@ -1124,8 +1122,8 @@ decomposeCall poly_id binds |
1124
|
1122
|
where
|
1125
|
1123
|
go acc (Let bind body)
|
1126
|
1124
|
= go (bind:acc) body
|
1127
|
|
- go add e
|
1128
|
|
- | Just (Var fun, args) <- collectArgs e
|
|
1125
|
+ go acc e
|
|
1126
|
+ | (Var fun, args) <- collectArgs e
|
1129
|
1127
|
= assertPpr (fun == poly_id) (ppr fun $$ ppr poly_id) $
|
1130
|
1128
|
Just (acc, args)
|
1131
|
1129
|
| otherwise
|
... |
... |
@@ -1134,44 +1132,55 @@ decomposeCall poly_id binds |
1134
|
1132
|
|
1135
|
1133
|
-- | Choose the let-bindings, taken from the desugared SPECIALISE call,
-- that should be kept with the specialisation.
--
--  * @orig_bndrs@: the original binders of the pragma
--  * @rule_bndrs@: the variables bound by a successful match on the call
--  * @rev_binds@:  the bindings wrapped around the call, innermost first
--
-- The result is the "renaming" bindings produced by 'get_renamings'
-- (reversed), followed by the bindings kept by 'pick_spec_binds'.
grabSpecBinds :: VarSet -> VarSet -> [CoreBind] -> [CoreBind]
grabSpecBinds orig_bndrs rule_bndrs rev_binds
  = reverse rename_binds ++ spec_binds
  where
    (known_bndrs, (rename_binds, other_binds))
      = get_renamings orig_bndrs rule_bndrs ([],[]) rev_binds
    spec_binds = pick_spec_binds known_bndrs other_binds

    ------------------------
    -- Walk the bindings innermost first.
    --
    -- A non-recursive binding @d = v@ or @d = v |> co@ is flipped into
    -- @v = d |> sym co@ when @d@ is in the rule-binder set and @v@ is
    -- locally bound. After the flip, @v@ joins the rule-binder set and
    -- @d@ joins the local set.
    --
    -- Every other binding is set aside, and its binders are added to the
    -- local set.
    --
    -- The result pairs the final rule-binder set with the accumulator.
    --
    -- NB: both set arguments are VarSets. The recursive calls must pass
    -- them in signature order (locals first, then bndrs). A swap still
    -- type-checks but silently confuses the two sets.
    get_renamings :: VarSet   -- Locally bound variables
                  -> VarSet   -- Variables bound by a successful match on the call
                  -> ([CoreBind],[CoreBind])   -- Accumulating parameter, in order
                  -> [CoreBind]                -- Reversed, innermost first
                  -> ( VarSet
                     , ( [CoreBind]            -- Renamings, in order
                       , [CoreBind] ) )        -- Other bindings, in order
    get_renamings _ bndrs acc [] = (bndrs, acc)

    get_renamings locals bndrs (rn_binds, other_binds) (bind : binds)
      | NonRec d r <- bind
      , d `elemVarSet` bndrs
      , Just (v, mco) <- getCastedVar r
      , v `elemVarSet` locals
      , let flipped_bind = NonRec v (mkCastMCo (Var d) (mkSymMCo mco))
      = get_renamings (locals `extendVarSet` d)
                      (bndrs  `extendVarSet` v)
                      (flipped_bind:rn_binds, other_binds)
                      binds

      | otherwise
      = get_renamings (locals `extendVarSetList` bindersOf bind)
                      bndrs
                      (rn_binds, bind:other_binds)
                      binds

    ------------------------
    -- Keep each binding, in the order given, whose right-hand sides
    -- mention no original binder outside the known set. The binders of a
    -- kept binding become known for the bindings after it.
    pick_spec_binds :: VarSet -> [CoreBind] -> [CoreBind]
    pick_spec_binds _ [] = []
    pick_spec_binds known_bndrs (bind:binds)
      | all keep_me (rhssOfBind bind)
      , let known_bndrs' = known_bndrs `extendVarSetList` bindersOf bind
      = bind : pick_spec_binds known_bndrs' binds
      | otherwise
      = pick_spec_binds known_bndrs binds
      where
        keep_me rhs = isEmptyVarSet (exprSomeFreeVars bad_var rhs)
        -- A variable is "bad" if it is an original binder not yet known
        bad_var v = v `elemVarSet` orig_bndrs && not (v `elemVarSet` known_bndrs)
|
1179
|
+
|
|
1180
|
-- | View an expression as a variable, possibly under exactly one cast.
--
-- Returns the variable together with the cast's coercion (as an
-- 'MCoercionR'). A bare variable yields 'MRefl'. Any other expression
-- shape yields 'Nothing'.
getCastedVar :: CoreExpr -> Maybe (Var, MCoercionR)
getCastedVar expr =
  case expr of
    Var var              -> Just (var, MRefl)
    Cast (Var var) coer  -> Just (var, MCo coer)
    _                    -> Nothing
|
1175
|
1184
|
|
1176
|
1185
|
{-
|
1177
|
1186
|
where
|