Simon Peyton Jones pushed to branch wip/T26682 at Glasgow Haskell Compiler / GHC

Commits:

4 changed files:

Changes:

  • compiler/GHC/Core/Opt/Specialise.hs
    ... ... @@ -654,9 +654,7 @@ specProgram guts@(ModGuts { mg_module = this_mod
    654 654
                   -- Easiest thing is to do it all at once, as if all the top-level
    
    655 655
                   -- decls were mutually recursive
    
    656 656
            ; let top_env = SE { se_subst = Core.mkEmptySubst $
    
    657
    -                                        mkInScopeSetBndrs binds
    
    658
    -                                      --    mkInScopeSetList $
    
    659
    -                                      --  bindersOfBinds binds
    
    657
    +                                       mkInScopeSetBndrs binds
    
    660 658
                               , se_module = this_mod
    
    661 659
                               , se_rules  = rule_env
    
    662 660
                               , se_dflags = dflags }
    
    ... ... @@ -816,9 +814,12 @@ spec_imports env callers dict_binds calls
    816 814
         go :: SpecEnv -> [CallInfoSet] -> CoreM (SpecEnv, [CoreRule], [CoreBind])
    
    817 815
         go env [] = return (env, [], [])
    
    818 816
         go env (cis : other_calls)
    
    819
    -      = do { -- debugTraceMsg (text "specImport {" <+> ppr cis)
    
    817
    +      = do {
    
    818
    +--             debugTraceMsg (text "specImport {" <+> vcat [ ppr cis
    
    819
    +--                                                         , text "callers" <+> ppr callers
    
    820
    +--                                                         , text "dict_binds" <+> ppr dict_binds ])
    
    820 821
                ; (env, rules1, spec_binds1) <- spec_import env callers dict_binds cis
    
    821
    -           ; -- debugTraceMsg (text "specImport }" <+> ppr cis)
    
    822
    +--           ; debugTraceMsg (text "specImport }" <+> ppr cis)
    
    822 823
     
    
    823 824
                ; (env, rules2, spec_binds2) <- go env other_calls
    
    824 825
                ; return (env, rules1 ++ rules2, spec_binds1 ++ spec_binds2) }
    
    ... ... @@ -835,13 +836,18 @@ spec_import :: SpecEnv -- Passed in so that all top-level Ids are
    835 836
                          , [CoreBind] )  -- Specialised bindings
    
    836 837
     spec_import env callers dict_binds cis@(CIS fn _)
    
    837 838
       | isIn "specImport" fn callers
    
    838
    -  = return (env, [], [])  -- No warning.  This actually happens all the time
    
    839
    -                          -- when specialising a recursive function, because
    
    840
    -                          -- the RHS of the specialised function contains a recursive
    
    841
    -                          -- call to the original function
    
    839
    +  = do {
    
    840
    +--         debugTraceMsg (text "specImport1-bad" <+> (ppr fn $$ text "callers" <+> ppr callers))
    
    841
    +       ; return (env, [], []) }
    
    842
    +    -- No warning.  This actually happens all the time
    
    843
    +    -- when specialising a recursive function, because
    
    844
    +    -- the RHS of the specialised function contains a recursive
    
    845
    +    -- call to the original function
    
    842 846
     
    
    843 847
       | null good_calls
    
    844
    -  = return (env, [], [])
    
    848
    +  = do {
    
    849
    +--        debugTraceMsg (text "specImport1-no-good" <+> (ppr cis $$ text "dict_binds" <+> ppr dict_binds))
    
    850
    +       ; return (env, [], []) }
    
    845 851
     
    
    846 852
       | Just rhs <- canSpecImport dflags fn
    
    847 853
       = do {     -- Get rules from the external package state
    
    ... ... @@ -890,7 +896,10 @@ spec_import env callers dict_binds cis@(CIS fn _)
    890 896
            ; return (env, rules2 ++ rules1, final_binds) }
    
    891 897
     
    
    892 898
       | otherwise
    
    893
    -  = do { tryWarnMissingSpecs dflags callers fn good_calls
    
    899
    +  = do {
    
    900
    +--         debugTraceMsg (hang (text "specImport1-missed")
    
    901
    +--                          2 (vcat [ppr cis, text "can-spec" <+> ppr (canSpecImport dflags fn)]))
    
    902
    +       ; tryWarnMissingSpecs dflags callers fn good_calls
    
    894 903
            ; return (env, [], [])}
    
    895 904
     
    
    896 905
       where
    
    ... ... @@ -1455,7 +1464,9 @@ specBind top_lvl env (NonRec fn rhs) do_body
    1455 1464
     
    
    1456 1465
            ; (fn4, spec_defns, body_uds1) <- specDefn env body_uds fn3 rhs
    
    1457 1466
     
    
    1458
    -       ; let (free_uds, dump_dbs, float_all) = dumpBindUDs [fn4] body_uds1
    
    1467
    +       ; let can_float_this_one = exprIsTopLevelBindable rhs (idType fn)
    
    1468
    +                 -- exprIsTopLevelBindable: see Note [Care with unlifted bindings]
    
    1469
    +             (free_uds, dump_dbs, float_all) = dumpBindUDs can_float_this_one [fn4] body_uds1
    
    1459 1470
                  all_free_uds                    = free_uds `thenUDs` rhs_uds
    
    1460 1471
     
    
    1461 1472
                  pairs = spec_defns ++ [(fn4, rhs')]
    
    ... ... @@ -1471,10 +1482,8 @@ specBind top_lvl env (NonRec fn rhs) do_body
    1471 1482
                              = [mkDB $ NonRec b r | (b,r) <- pairs]
    
    1472 1483
                                ++ fromOL dump_dbs
    
    1473 1484
     
    
    1474
    -             can_float_this_one = exprIsTopLevelBindable rhs (idType fn)
    
    1475
    -             -- exprIsTopLevelBindable: see Note [Care with unlifted bindings]
    
    1476 1485
     
    
    1477
    -       ; if float_all && can_float_this_one then
    
    1486
    +       ; if float_all then
    
    1478 1487
                  -- Rather than discard the calls mentioning the bound variables
    
    1479 1488
                  -- we float this (dictionary) binding along with the others
    
    1480 1489
                   return ([], body', all_free_uds `snocDictBinds` final_binds)
    
    ... ... @@ -1509,7 +1518,7 @@ specBind top_lvl env (Rec pairs) do_body
    1509 1518
                                   <- specDefns rec_env uds2 (bndrs2 `zip` rhss)
    
    1510 1519
                             ; return (bndrs3, spec_defns3 ++ spec_defns2, uds3) }
    
    1511 1520
     
    
    1512
    -       ; let (final_uds, dumped_dbs, float_all) = dumpBindUDs bndrs1 uds3
    
    1521
    +       ; let (final_uds, dumped_dbs, float_all) = dumpBindUDs True bndrs1 uds3
    
    1513 1522
                  final_bind = recWithDumpedDicts (spec_defns3 ++ zip bndrs3 rhss')
    
    1514 1523
                                                  dumped_dbs
    
    1515 1524
     
    
    ... ... @@ -1630,7 +1639,6 @@ specCalls spec_imp env existing_rules calls_for_me fn rhs
    1630 1639
         dflags    = se_dflags env
    
    1631 1640
         this_mod  = se_module env
    
    1632 1641
         subst     = se_subst env
    
    1633
    -    in_scope  = Core.substInScopeSet subst
    
    1634 1642
             -- Figure out whether the function has an INLINE pragma
    
    1635 1643
             -- See Note [Inline specialisations]
    
    1636 1644
     
    
    ... ... @@ -1646,9 +1654,6 @@ specCalls spec_imp env existing_rules calls_for_me fn rhs
    1646 1654
           | otherwise
    
    1647 1655
           = inl_prag
    
    1648 1656
     
    
    1649
    -    not_in_scope :: InterestingVarFun
    
    1650
    -    not_in_scope v = isLocalVar v && not (v `elemInScopeSet` in_scope)
    
    1651
    -
    
    1652 1657
         ----------------------------------------------------------
    
    1653 1658
             -- Specialise to one particular call pattern
    
    1654 1659
         spec_call :: SpecInfo                         -- Accumulating parameter
    
    ... ... @@ -1662,47 +1667,34 @@ specCalls spec_imp env existing_rules calls_for_me fn rhs
    1662 1667
                      mk_extra_dfun_arg bndr | isTyVar bndr = UnspecType
    
    1663 1668
                                             | otherwise    = UnspecArg
    
    1664 1669
     
    
    1665
    -             -- Find qvars, the type variables to add to the binders for the rule
    
    1666
    -             -- Namely those free in `ty` that aren't in scope
    
    1667
    -             -- See (MP2) in Note [Specialising polymorphic dictionaries]
    
    1668
    -           ; let poly_qvars = scopedSort $ fvVarList $ specArgsFVs not_in_scope call_args
    
    1669
    -                 subst'     = subst `Core.extendSubstInScopeList` poly_qvars
    
    1670
    -                              -- Maybe we should clone the poly_qvars telescope?
    
    1671
    -
    
    1672
    -             -- Any free Ids will have caused the call to be dropped
    
    1673
    -           ; massertPpr (all isTyCoVar poly_qvars)
    
    1674
    -                        (ppr fn $$ ppr all_call_args $$ ppr poly_qvars)
    
    1675
    -
    
    1676
    -           ; (useful, subst'', rule_bndrs, rule_lhs_args, spec_bndrs, dx_binds, spec_args)
    
    1677
    -                 <- specHeader subst' rhs_bndrs all_call_args
    
    1678
    -           ; let all_rule_bndrs = poly_qvars ++ rule_bndrs
    
    1679
    -                 env' = env { se_subst = subst'' }
    
    1670
    +           ; (useful, subst', rule_bndrs, rule_lhs_args, spec_bndrs, dx_binds, spec_args)
    
    1671
    +                 <- specHeader subst rhs_bndrs all_call_args
    
    1672
    +           ; let env' = env { se_subst = subst' }
    
    1680 1673
     
    
    1681 1674
                -- Check for (a) usefulness and (b) not already covered
    
    1682 1675
                -- See (SC1) in Note [Specialisations already covered]
    
    1683 1676
                ; let all_rules = rules_acc ++ existing_rules
    
    1684 1677
                      -- all_rules: we look both in the rules_acc (generated by this invocation
    
    1685 1678
                      --   of specCalls), and in existing_rules (passed in to specCalls)
    
    1686
    -                 already_covered = alreadyCovered env' all_rule_bndrs fn
    
    1679
    +                 already_covered = alreadyCovered env' rule_bndrs fn
    
    1687 1680
                                                       rule_lhs_args is_active all_rules
    
    1688 1681
     
    
    1689
    -{-         ; pprTrace "spec_call" (vcat
    
    1690
    -                [ text "fun:       "  <+> ppr fn
    
    1691
    -                , text "call info: "  <+> ppr _ci
    
    1692
    -                , text "useful:    "  <+> ppr useful
    
    1693
    -                , text "already_covered:"  <+> ppr already_covered
    
    1694
    -                , text "poly_qvars: " <+> ppr poly_qvars
    
    1695
    -                , text "useful:    "  <+> ppr useful
    
    1696
    -                , text "all_rule_bndrs:"  <+> ppr all_rule_bndrs
    
    1697
    -                , text "rule_lhs_args:"  <+> ppr rule_lhs_args
    
    1698
    -                , text "spec_bndrs:" <+> ppr spec_bndrs
    
    1699
    -                , text "dx_binds:"   <+> ppr dx_binds
    
    1700
    -                , text "spec_args: "  <+> ppr spec_args
    
    1701
    -                , text "rhs_bndrs"    <+> ppr rhs_bndrs
    
    1702
    -                , text "rhs_body"     <+> ppr rhs_body
    
    1703
    -                , text "subst''" <+> ppr subst'' ]) $
    
    1704
    -             return ()
    
    1705
    --}
    
    1682
    +--         ; pprTrace "spec_call" (vcat
    
    1683
    +--                [ text "fun:       "  <+> ppr fn
    
    1684
    +--                , text "call info: "  <+> ppr _ci
    
    1685
    +--                , text "useful:    "  <+> ppr useful
    
    1686
    +--                , text "already_covered:"  <+> ppr already_covered
    
    1687
    +--                , text "useful:    "  <+> ppr useful
    
    1688
    +--                , text "rule_bndrs:"  <+> ppr (sep (map (pprBndr LambdaBind) rule_bndrs))
    
    1689
    +--                , text "rule_lhs_args:"  <+> ppr rule_lhs_args
    
    1690
    +--                , text "spec_bndrs:" <+> ppr (sep (map (pprBndr LambdaBind) spec_bndrs))
    
    1691
    +--                , text "dx_binds:"   <+> ppr dx_binds
    
    1692
    +--                , text "spec_args: "  <+> ppr spec_args
    
    1693
    +--                , text "rhs_bndrs"    <+> ppr (sep (map (pprBndr LambdaBind) rhs_bndrs))
    
    1694
    +--                , text "rhs_body"     <+> ppr rhs_body
    
    1695
    +--                , text "subst'" <+> ppr subst'
    
    1696
    +--                ]) $ return ()
    
    1697
    +
    
    1706 1698
     
    
    1707 1699
                ; if not useful          -- No useful specialisation
    
    1708 1700
                     || already_covered  -- Useful, but done already
    
    ... ... @@ -1716,23 +1708,15 @@ specCalls spec_imp env existing_rules calls_for_me fn rhs
    1716 1708
                  -- Run the specialiser on the specialised RHS
    
    1717 1709
                ; (rhs_body', rhs_uds) <- specExpr env'' rhs_body
    
    1718 1710
     
    
    1719
    -{-         ; pprTrace "spec_call2" (vcat
    
    1720
    -                 [ text "fun:" <+> ppr fn
    
    1721
    -                 , text "rhs_body':" <+> ppr rhs_body' ]) $
    
    1722
    -             return ()
    
    1723
    --}
    
    1724
    -
    
    1725 1711
                -- Make the RHS of the specialised function
    
    1726 1712
                ; let spec_rhs_bndrs = spec_bndrs ++ inner_rhs_bndrs'
    
    1727
    -                 (rhs_uds1, inner_dumped_dbs) = dumpUDs spec_rhs_bndrs rhs_uds
    
    1728
    -                 (rhs_uds2, outer_dumped_dbs) = dumpUDs poly_qvars (dx_binds `consDictBinds` rhs_uds1)
    
    1729
    -                 -- dx_binds comes from the arguments to the call, and so can mention
    
    1730
    -                 -- poly_qvars but no other local binders
    
    1731
    -                 spec_rhs = mkLams poly_qvars               $
    
    1732
    -                            wrapDictBindsE outer_dumped_dbs $
    
    1733
    -                            mkLams spec_rhs_bndrs           $
    
    1713
    +                 (rhs_uds2, inner_dumped_dbs) = dumpUDs spec_rhs_bndrs $
    
    1714
    +                                                dx_binds `consDictBinds` rhs_uds
    
    1715
    +                 -- dx_binds comes from the arguments to the call,
    
    1716
    +                 -- and so can mention the call's free tyvars (now among spec_bndrs) but no other local binders
    
    1717
    +                 spec_rhs = mkLams spec_rhs_bndrs           $
    
    1734 1718
                                 wrapDictBindsE inner_dumped_dbs rhs_body'
    
    1735
    -                 rule_rhs_args = poly_qvars ++ spec_bndrs
    
    1719
    +                 rule_rhs_args = spec_bndrs
    
    1736 1720
     
    
    1737 1721
                      -- Maybe add a void arg to the specialised function,
    
    1738 1722
                      -- to avoid unlifted bindings
    
    ... ... @@ -1787,7 +1771,7 @@ specCalls spec_imp env existing_rules calls_for_me fn rhs
    1787 1771
                                          text "SPEC"
    
    1788 1772
     
    
    1789 1773
                     spec_rule = mkSpecRule dflags this_mod True inl_act
    
    1790
    -                                    herald fn all_rule_bndrs rule_lhs_args
    
    1774
    +                                    herald fn rule_bndrs rule_lhs_args
    
    1791 1775
                                         (mkVarApps (Var spec_fn) rule_rhs_args1)
    
    1792 1776
     
    
    1793 1777
                     _rule_trace_doc = vcat [ ppr fn <+> dcolon <+> ppr fn_type
    
    ... ... @@ -1798,8 +1782,12 @@ specCalls spec_imp env existing_rules calls_for_me fn rhs
    1798 1782
                                            , text "existing" <+> ppr existing_rules
    
    1799 1783
                                            ]
    
    1800 1784
     
    
    1801
    -           ; -- pprTrace "spec_call: rule" _rule_trace_doc
    
    1802
    -             return ( spec_rule            : rules_acc
    
    1785
    +--           ; pprTrace "spec_call: rule" (vcat [ -- text "poly_qvars" <+> ppr poly_qvars
    
    1786
    +--                                                text "rule_bndrs" <+> ppr rule_bndrs
    
    1787
    +--                                              , text "rule_lhs_args" <+> ppr rule_lhs_args
    
    1788
    +--                                              , text "all_call_args" <+> ppr all_call_args
    
    1789
    +--                                              , ppr spec_rule ]) $
    
    1790
    +           ; return ( spec_rule            : rules_acc
    
    1803 1791
                         , (spec_fn, spec_rhs1) : pairs_acc
    
    1804 1792
                         , rhs_uds2 `thenUDs` uds_acc
    
    1805 1793
                         ) } }
    
    ... ... @@ -1946,6 +1934,16 @@ floating to top level anyway; but that is hard to spot (since we don't know what
    1946 1934
     the non-top-level in-scope binders are) and rare (since the binding must satisfy
    
    1947 1935
     Note [Core let-can-float invariant] in GHC.Core).
    
    1948 1936
     
    
    1937
    +Arguably we'd be better off if we had left that `x` in the RHS of `n`, thus
    
    1938
    +    f x = let n::Natural = let x::ByteArray# = <some literal> in
    
    1939
    +                           NB x
    
    1940
    +          in wombat @192827 (n |> co)
    
    1941
    +Now we could float `n` happily.  But that's in conflict with exposing the `NB`
    
    1942
    +data constructor in the body of the `let`, so I'm leaving this unresolved.
    
    1943
    +
    
    1944
    +Another case came up in #26682, where the binding had an unlifted sum type
    
    1945
    +(# Word# | ByteArray# #), itself arising from an UNPACK pragma.  Test case
    
    1946
    +T26682.
    
    1949 1947
     
    
    1950 1948
     Note [Specialising Calls]
    
    1951 1949
     ~~~~~~~~~~~~~~~~~~~~~~~~~
    
    ... ... @@ -2593,12 +2591,22 @@ specHeader subst _ [] = pure (False, subst, [], [], [], [], [])
    2593 2591
     -- 'a->T1', as well as a LHS argument for the resulting RULE and unfolding
    
    2594 2592
     -- details.
    
    2595 2593
     specHeader subst (bndr:bndrs) (SpecType ty : args)
    
    2596
    -  = do { let subst1 = Core.extendTvSubst subst bndr ty
    
    2597
    -       ; (useful, subst2, rule_bs, rule_args, spec_bs, dx, spec_args)
    
    2598
    -             <- specHeader subst1 bndrs args
    
    2599
    -       ; pure ( useful, subst2
    
    2600
    -              , rule_bs,     Type ty : rule_args
    
    2601
    -              , spec_bs, dx, Type ty : spec_args ) }
    
    2594
    +  = do { -- Find free_tvs, the type variables to add to the binders for the rule
    
    2595
    +         -- Namely those free in `ty` that aren't in scope
    
    2596
    +         -- See (MP2) in Note [Specialising polymorphic dictionaries]
    
    2597
    +         let in_scope = Core.substInScopeSet subst
    
    2598
    +             not_in_scope tv = not (tv `elemInScopeSet` in_scope)
    
    2599
    +             free_tvs = scopedSort $ fvVarList $
    
    2600
    +                        filterFV not_in_scope  $
    
    2601
    +                        tyCoFVsOfType ty
    
    2602
    +             subst1 = subst `Core.extendSubstInScopeList` free_tvs
    
    2603
    +
    
    2604
    +       ; let subst2 = Core.extendTvSubst subst1 bndr ty
    
    2605
    +       ; (useful, subst3, rule_bs, rule_args, spec_bs, dx, spec_args)
    
    2606
    +             <- specHeader subst2 bndrs args
    
    2607
    +       ; pure ( useful, subst3
    
    2608
    +              , free_tvs ++ rule_bs,     Type ty : rule_args
    
    2609
    +              , free_tvs ++ spec_bs, dx, Type ty : spec_args ) }
    
    2602 2610
     
    
    2603 2611
     -- Next we have a type that we don't want to specialise. We need to perform
    
    2604 2612
     -- a substitution on it (in case the type refers to 'a'). Additionally, we need
    
    ... ... @@ -2682,7 +2690,7 @@ bindAuxiliaryDict subst orig_dict_id fresh_dict_id dict_arg
    2682 2690
       -- don’t bother creating a new dict binding; just substitute
    
    2683 2691
       | exprIsTrivial dict_arg
    
    2684 2692
       , let subst' = Core.extendSubst subst orig_dict_id dict_arg
    
    2685
    -  = -- pprTrace "bindAuxiliaryDict:trivial" (ppr orig_dict_id <+> ppr dict_id) $
    
    2693
    +  = -- pprTrace "bindAuxiliaryDict:trivial" (ppr orig_dict_id <+> ppr dict_arg) $
    
    2686 2694
         (subst', Nothing, dict_arg)
    
    2687 2695
     
    
    2688 2696
       | otherwise  -- Non-trivial dictionary arg; make an auxiliary binding
    
    ... ... @@ -2978,7 +2986,8 @@ pprCallInfo fn (CI { ci_key = key })
    2978 2986
     
    
    2979 2987
     instance Outputable CallInfo where
    
    2980 2988
       ppr (CI { ci_key = key, ci_fvs = _fvs })
    
    2981
    -    = text "CI" <> braces (sep (map ppr key))
    
    2989
    +    = text "CI" <> braces (text "fvs" <+> ppr _fvs
    
    2990
    +                           $$ sep (map ppr key))
    
    2982 2991
     
    
    2983 2992
     unionCalls :: CallDetails -> CallDetails -> CallDetails
    
    2984 2993
     unionCalls c1 c2 = plusDVarEnv_C unionCallInfoSet c1 c2
    
    ... ... @@ -3394,38 +3403,49 @@ wrapDictBindsE dbs expr
    3394 3403
     
    
    3395 3404
     ----------------------
    
    3396 3405
     dumpUDs :: [CoreBndr] -> UsageDetails -> (UsageDetails, OrdList DictBind)
    
    3397
    --- Used at a lambda or case binder; just dump anything mentioning the binder
    
    3406
    +-- Used at a binding site; just dump anything mentioning any of the binders
    
    3398 3407
     dumpUDs bndrs uds@(MkUD { ud_binds = orig_dbs, ud_calls = orig_calls })
    
    3399 3408
       | null bndrs = (uds, nilOL)  -- Common in case alternatives
    
    3400 3409
       | otherwise  = -- pprTrace "dumpUDs" (vcat
    
    3401
    -                 --    [ text "bndrs" <+> ppr bndrs
    
    3402
    -                 --    , text "uds" <+> ppr uds
    
    3403
    -                 --    , text "free_uds" <+> ppr free_uds
    
    3404
    -                 --    , text "dump-dbs" <+> ppr dump_dbs ]) $
    
    3410
    +                 --   [ text "bndrs" <+> ppr bndrs
    
    3411
    +                 --   , text "uds" <+> ppr uds
    
    3412
    +                 --   , text "free_uds" <+> ppr free_uds
    
    3413
    +                 --   , text "dump_dbs" <+> ppr dump_dbs ]) $
    
    3405 3414
                      (free_uds, dump_dbs)
    
    3406 3415
       where
    
    3407 3416
         free_uds = uds { ud_binds = free_dbs, ud_calls = free_calls }
    
    3408 3417
         bndr_set = mkVarSet bndrs
    
    3409 3418
         (free_dbs, dump_dbs, dump_set) = splitDictBinds orig_dbs bndr_set
    
    3410
    -    free_calls = deleteCallsMentioning dump_set $   -- Drop calls mentioning bndr_set on the floor
    
    3411
    -                 deleteCallsFor bndrs orig_calls    -- Discard calls for bndr_set; there should be
    
    3412
    -                                                    -- no calls for any of the dicts in dump_dbs
    
    3413 3419
     
    
    3414
    -dumpBindUDs :: [CoreBndr] -> UsageDetails -> (UsageDetails, OrdList DictBind, Bool)
    
    3420
    +    -- Delete calls:
    
    3421
    +    --   * For any binder in `bndrs`
    
    3422
    +    --   * That mention a dictionary bound in `dump_set`
    
    3423
    +    -- These variables aren't in scope "above" the binding and the `dump_dbs`,
    
    3424
    +    -- so no call should mention them.  (See #26682.)
    
    3425
    +    free_calls = deleteCallsMentioning dump_set $
    
    3426
    +                 deleteCallsFor bndrs orig_calls
    
    3427
    +
    
    3428
    +dumpBindUDs :: Bool   -- Main binding can float to top
    
    3429
    +            -> [CoreBndr] -> UsageDetails
    
    3430
    +            -> (UsageDetails, OrdList DictBind, Bool)
    
    3415 3431
     -- Used at a let(rec) binding.
    
    3416
    --- We return a boolean indicating whether the binding itself is mentioned,
    
    3417
    --- directly or indirectly, by any of the ud_calls; in that case we want to
    
    3418
    --- float the binding itself;
    
    3419
    --- See Note [Floated dictionary bindings]
    
    3420
    -dumpBindUDs bndrs (MkUD { ud_binds = orig_dbs, ud_calls = orig_calls })
    
    3421
    -  = -- pprTrace "dumpBindUDs" (ppr bndrs $$ ppr free_uds $$ ppr dump_dbs $$ ppr float_all) $
    
    3422
    -    (free_uds, dump_dbs, float_all)
    
    3432
    +-- We return a boolean indicating whether the binding itself
    
    3433
    +--    is mentioned, directly or indirectly, by any of the ud_calls, and whether
    
    3434
    +--    `can_float_bind` holds; if both are true we want to float the binding itself.
    
    3435
    +--    See Note [Floated dictionary bindings]
    
    3436
    +-- If the boolean is True, then the returned ud_calls can mention `bndrs`;
    
    3437
    +-- if False, then the returned ud_calls must not mention `bndrs`
    
    3438
    +dumpBindUDs can_float_bind bndrs (MkUD { ud_binds = orig_dbs, ud_calls = orig_calls })
    
    3439
    +  = ( MkUD { ud_binds = free_dbs, ud_calls = free_calls2 }
    
    3440
    +    , dump_dbs
    
    3441
    +    , can_float_bind && calls_mention_bndrs )
    
    3423 3442
       where
    
    3424
    -    free_uds = MkUD { ud_binds = free_dbs, ud_calls = free_calls }
    
    3425 3443
         bndr_set = mkVarSet bndrs
    
    3426 3444
         (free_dbs, dump_dbs, dump_set) = splitDictBinds orig_dbs bndr_set
    
    3427
    -    free_calls = deleteCallsFor bndrs orig_calls
    
    3428
    -    float_all = dump_set `intersectsVarSet` callDetailsFVs free_calls
    
    3445
    +    free_calls1 = deleteCallsFor bndrs orig_calls
    
    3446
    +    calls_mention_bndrs = dump_set `intersectsVarSet` callDetailsFVs free_calls1
    
    3447
    +    free_calls2 | can_float_bind = free_calls1
    
    3448
    +                | otherwise      = deleteCallsMentioning dump_set free_calls1
    
    3429 3449
     
    
    3430 3450
     callsForMe :: Id -> UsageDetails -> (UsageDetails, [CallInfo])
    
    3431 3451
     callsForMe fn uds@MkUD { ud_binds = orig_dbs, ud_calls = orig_calls }
    

  • testsuite/tests/simplCore/should_compile/T26682.hs
    1
    +{-# LANGUAGE Haskell2010 #-}
    
    2
    +
    
    3
    +{-# LANGUAGE AllowAmbiguousTypes #-}
    
    4
    +{-# LANGUAGE BangPatterns #-}
    
    5
    +{-# LANGUAGE DataKinds #-}
    
    6
    +{-# LANGUAGE PolyKinds #-}
    
    7
    +{-# LANGUAGE StandaloneKindSignatures #-}
    
    8
    +{-# LANGUAGE TypeApplications #-}
    
    9
    +{-# LANGUAGE TypeFamilies #-}
    
    10
    +
    
    11
    +{-# OPTIONS_GHC -fspecialise-aggressively #-}
    
    12
    +
    
    13
    +-- This is the result of @sheaf's work in minimising
    
    14
    +-- @mikolaj's original bug report for #26682
    
    15
    +
    
    16
    +module T26682 ( tensorADOnceMnistTests2 ) where
    
    17
    +
    
    18
    +import Prelude
    
    19
    +
    
    20
    +import Data.Proxy
    
    21
    +  ( Proxy (Proxy) )
    
    22
    +
    
    23
    +import GHC.TypeNats
    
    24
    +import Data.Kind
    
    25
    +
    
    26
    +import T26682a
    
    27
    +
    
    28
    +
    
    29
    +data Concrete2 x = Concrete2
    
    30
    +
    
    31
    +instance Eq ( Concrete2 a ) where
    
    32
    +  _ == _ = error "no"
    
    33
    +  {-# OPAQUE (==) #-}
    
    34
    +
    
    35
    +type X :: Type -> TK
    
    36
    +type family X a
    
    37
    +
    
    38
    +type instance X (target y) = y
    
    39
    +type instance X (a, b) = TKProduct (X a) (X b)
    
    40
    +type instance X (a, b, c) = TKProduct (TKProduct (X a) (X b)) (X c)
    
    41
    +
    
    42
    +tensorADOnceMnistTests2 :: Int -> Bool
    
    43
    +tensorADOnceMnistTests2 seed0 =
    
    44
    +  withSomeSNat 999 $ \ _ ->
    
    45
    +    let seed1 =
    
    46
    +          randomValue2
    
    47
    +            @(Concrete2 (X (ADFcnnMnist2ParametersShaped Concrete2 101 101 Double Double)))
    
    48
    +            seed0
    
    49
    +        art = mnistTrainBench2VTOGradient3 seed1
    
    50
    +
    
    51
    +        gg :: Concrete2
    
    52
    +                (TKProduct
    
    53
    +                   (TKProduct
    
    54
    +                      (TKProduct
    
    55
    +                         (TKProduct (TKR2 2 (TKScalar Double)) (TKR2 1 (TKScalar Double)))
    
    56
    +                         (TKProduct (TKR2 2 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
    
    57
    +                      (TKProduct (TKR2 2 (TKScalar Double)) (TKR2 1 (TKScalar Double))))
    
    58
    +                   (TKProduct (TKR 1 Double) (TKR 1 Double)))
    
    59
    +        gg = undefined
    
    60
    +        value1 = revInterpretArtifact2 art gg
    
    61
    +    in
    
    62
    +      value1 == value1
    
    63
    +
    
    64
    +mnistTrainBench2VTOGradient3
    
    65
    +  :: Int
    
    66
    +  -> AstArtifactRev2
    
    67
    +        (TKProduct
    
    68
    +           (XParams2 Double Double)
    
    69
    +           (TKProduct (TKR2 1 (TKScalar Double))
    
    70
    +                      (TKR2 1 (TKScalar Double))))
    
    71
    +        (TKScalar Double)
    
    72
    +mnistTrainBench2VTOGradient3 !_
    
    73
    +  | Dict0 <- lemTKScalarAllNumAD2 (Proxy @Double)
    
    74
    +  = undefined
    
    75
    +
    
    76
    +type ADFcnnMnist2ParametersShaped
    
    77
    +       (target :: TK -> Type) (widthHidden :: Nat) (widthHidden2 :: Nat) r q =
    
    78
    +  ( ( target (TKS '[widthHidden, 784] r)
    
    79
    +    , target (TKS '[widthHidden] r) )
    
    80
    +  , ( target (TKS '[widthHidden2, widthHidden] q)
    
    81
    +    , target (TKS '[widthHidden2] r) )
    
    82
    +  , ( target (TKS '[10, widthHidden2] r)
    
    83
    +    , target (TKS '[10] r) )
    
    84
    +  )
    
    85
    +
    
    86
    +-- | The differentiable type of all trainable parameters of this nn.
    
    87
    +type ADFcnnMnist2Parameters (target :: TK -> Type) r q =
    
    88
    +  ( ( target (TKR 2 r)
    
    89
    +    , target (TKR 1 r) )
    
    90
    +  , ( target (TKR 2 q)
    
    91
    +    , target (TKR 1 r) )
    
    92
    +  , ( target (TKR 2 r)
    
    93
    +    , target (TKR 1 r) )
    
    94
    +  )
    
    95
    +
    
    96
    +type XParams2 r q = X (ADFcnnMnist2Parameters Concrete2 r q)
    
    97
    +
    
    98
    +data AstArtifactRev2 x z = AstArtifactRev2
    
    99
    +
    
    100
    +revInterpretArtifact2
    
    101
    +  :: AstArtifactRev2 x z
    
    102
    +  -> Concrete2 x
    
    103
    +  -> Concrete2 z
    
    104
    +{-# OPAQUE revInterpretArtifact2 #-}
    
    105
    +revInterpretArtifact2 _ _ = error "no"

  • testsuite/tests/simplCore/should_compile/T26682a.hs
    1
    +{-# LANGUAGE Haskell2010 #-}
    
    2
    +
    
    3
    +{-# LANGUAGE RankNTypes #-}
    
    4
    +{-# LANGUAGE AllowAmbiguousTypes #-}
    
    5
    +{-# LANGUAGE BangPatterns #-}
    
    6
    +{-# LANGUAGE DataKinds #-}
    
    7
    +{-# LANGUAGE FlexibleInstances #-}
    
    8
    +{-# LANGUAGE GADTs #-}
    
    9
    +{-# LANGUAGE PolyKinds #-}
    
    10
    +{-# LANGUAGE ScopedTypeVariables #-}
    
    11
    +{-# LANGUAGE StandaloneKindSignatures #-}
    
    12
    +{-# LANGUAGE TypeApplications #-}
    
    13
    +{-# LANGUAGE TypeData #-}
    
    14
    +{-# LANGUAGE TypeFamilies #-}
    
    15
    +{-# LANGUAGE TypeOperators #-}
    
    16
    +{-# LANGUAGE UndecidableSuperClasses #-}
    
    17
    +{-# LANGUAGE UndecidableInstances #-}
    
    18
    +
    
    19
    +module T26682a
    
    20
    +  ( TK(..), TKR, TKS, TKX
    
    21
    +  , Dict0(..)
    
    22
    +  , randomValue2
    
    23
    +  , lemTKScalarAllNumAD2
    
    24
    +  ) where
    
    25
    +
    
    26
    +import Prelude
    
    27
    +
    
    28
    +
    
    29
    +import GHC.TypeLits ( KnownNat(..), Nat, SNat )
    
    30
    +import Data.Kind ( Type, Constraint )
    
    31
    +import Data.Typeable ( Typeable )
    
    32
    +import Data.Proxy ( Proxy )
    
    33
    +
    
    34
    +import Type.Reflection
    
    35
    +import Data.Type.Equality
    
    36
    +
    
    37
    +ifDifferentiable2 :: forall r a. Typeable r
    
    38
    +                 => (Num r => a) -> a -> a
    
    39
    +{-# INLINE ifDifferentiable2 #-}
    
    40
    +ifDifferentiable2 ra _
    
    41
    +  | Just Refl <- testEquality (typeRep @r) (typeRep @Double) = ra
    
    42
    +ifDifferentiable2 ra _
    
    43
    +  | Just Refl <- testEquality (typeRep @r) (typeRep @Float) = ra
    
    44
    +ifDifferentiable2 _ a = a
    
    45
    +
    
    46
    +data Dict0 c where
    
    47
    +  Dict0 :: c => Dict0 c
    
    48
    +
    
    49
    +type ShS2 :: [Nat] -> Type
    
    50
    +data ShS2 ns where
    
    51
    +  Z :: ShS2 '[]
    
    52
    +  S :: {-# UNPACK #-} !( SNat n ) -> !( ShS2 ns ) -> ShS2 (n ': ns)
    
    53
    +
    
    54
    +type KnownShS2 :: [Nat] -> Constraint
    
    55
    +class KnownShS2 ns where
    
    56
    +  knownShS2 :: ShS2 ns
    
    57
    +
    
    58
    +instance KnownShS2 '[] where
    
    59
    +  knownShS2 = Z
    
    60
    +instance ( KnownNat n, KnownShS2 ns ) => KnownShS2 ( n ': ns ) where
    
    61
    +  knownShS2 =
    
    62
    +    case natSing @n of
    
    63
    +      !i ->
    
    64
    +        case knownShS2 @ns of
    
    65
    +          !j ->
    
    66
    +            S i j
    
    67
    +
    
    68
    +type RandomValue2 :: Type -> Constraint
    
    69
    +class RandomValue2 vals where
    
    70
    +  randomValue2 :: Int -> Int
    
    71
    +
    
    72
    +
    
    73
    +type IsDouble :: Type -> Constraint
    
    74
    +type family IsDouble a where
    
    75
    +  IsDouble Double = ( () :: Constraint )
    
    76
    +
    
    77
    +class ( Typeable r, IsDouble r ) => NumScalar2 r
    
    78
    +instance ( Typeable r, IsDouble r ) => NumScalar2 r
    
    79
    +
    
    80
    +instance forall sh r target. (KnownShS2 sh, NumScalar2 r)
    
    81
    +         => RandomValue2 (target (TKS sh r)) where
    
    82
    +  randomValue2 g =
    
    83
    +    ifDifferentiable2 @r
    
    84
    +      ( case knownShS2 @sh of
    
    85
    +          !_ -> g )
    
    86
    +      g
    
    87
    +
    
    88
    +instance (RandomValue2 (target a), RandomValue2 (target b))
    
    89
    +         => RandomValue2 (target (TKProduct a b)) where
    
    90
    +  randomValue2 g =
    
    91
    +    let g1 = randomValue2 @(target a) g
    
    92
    +        g2 = randomValue2 @(target b) g1
    
    93
    +    in g2
    
    94
    +
    
    95
    +lemTKScalarAllNumAD2 :: Proxy r -> Dict0 ( IsDouble r )
    
    96
    +lemTKScalarAllNumAD2 _ = undefined
    
    97
    +{-# OPAQUE lemTKScalarAllNumAD2 #-}
    
    98
    +
    
    99
    +
    
    100
    +type data TK =
    
    101
    +    TKScalar Type
    
    102
    +  | TKR2 Nat TK
    
    103
    +  | TKS2 [Nat] TK
    
    104
    +  | TKX2 [Maybe Nat] TK
    
    105
    +  | TKProduct TK TK
    
    106
    +
    
    107
    +type TKR n r = TKR2 n (TKScalar r)
    
    108
    +type TKS sh r = TKS2 sh (TKScalar r)
    
    109
    +type TKX sh r = TKX2 sh (TKScalar r)

  • testsuite/tests/simplCore/should_compile/all.T
    ... ... @@ -563,3 +563,4 @@ test('T26115', [grep_errmsg(r'DFun')], compile, ['-O -ddump-simpl -dsuppress-uni
    563 563
     test('T26116', normal, compile, ['-O -ddump-rules'])
    
    564 564
     test('T26117', [grep_errmsg(r'==')], compile, ['-O -ddump-simpl -dsuppress-uniques'])
    
    565 565
     test('T26349',  normal, compile, ['-O -ddump-rules'])
    
    566
    +test('T26682',  normal, multimod_compile, ['T26682', '-O'])