Simon Peyton Jones pushed to branch wip/T26425 at Glasgow Haskell Compiler / GHC

Commits:

2 changed files:

Changes:

  • compiler/GHC/Core/Opt/OccurAnal.hs
    ... ... @@ -759,62 +759,62 @@ rest of 'OccInfo' until it goes on the binder.
    759 759
     
    
    760 760
     Note [Join arity prediction based on joinRhsArity]
    
    761 761
     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    
    762
    -In general, the join arity from tail occurrences of a join point (O) may be
    
    763
    -higher or lower than the manifest join arity of the join body (M). E.g.,
    
    762
    +In general, the join arity from tail occurrences of a join point (OAr) may be
    
    763
    +higher or lower than the manifest join arity of the join body (MAr). E.g.,
    
    764 764
     
    
    765
    -  -- M > O:
    
    766
    -  let f x y = x + y              -- M = 2
    
    767
    -  in if b then f 1 else f 2      -- O = 1
    
    765
    +  -- MAr > OAr:
    
    766
    +  let f x y = x + y              -- MAr = 2
    
    767
    +  in if b then f 1 else f 2      -- OAr = 1
    
    768 768
       ==> { Contify for join arity 1 }
    
    769 769
       join f x = \y -> x + y
    
    770 770
       in if b then jump f 1 else jump f 2
    
    771 771
     
    
    772
    -  -- M < O
    
    773
    -  let f = id                     -- M = 0
    
    774
    -  in if ... then f 12 else f 13  -- O = 1
    
    772
    +  -- MAr < OAr
    
    773
    +  let f = id                     -- MAr = 0
    
    774
    +  in if ... then f 12 else f 13  -- OAr = 1
    
    775 775
       ==> { Contify for join arity 1, eta-expand f }
    
    776 776
       join f x = id x
    
    777 777
       in if b then jump f 12 else jump f 13
    
    778 778
     
    
    779
    -But for *recursive* let, it is crucial that both arities match up, consider
    
    779
    +But for *recursive* let, it is crucial that MAr=OAr.  Consider:
    
    780 780
     
    
    781 781
       letrec f x y = if ... then f x else True
    
    782 782
       in f 42
    
    783 783
     
    
    784
    -Here, M=2 but O=1. If we settled for a joinrec arity of 1, the recursive jump
    
    784
    +Here, MAr=2 but OAr=1. If we settled for a joinrec arity of 1, the recursive jump
    
    785 785
     would not happen in a tail context! Contification is invalid here.
    
    786
    -So indeed it is crucial to demand that M=O.
    
    786
    +So indeed it is crucial to demand that MAr=OAr.
    
    787 787
     
    
    788
    -(Side note: Actually, we could be more specific: Let O1 be the join arity of
    
    789
    -occurrences from the letrec RHS and O2 the join arity from the let body. Then
    
    790
    -we need M=O1 and M<=O2 and could simply eta-expand the RHS to match O2 later.
    
    791
    -M=O is the specific case where we don't want to eta-expand. Neither the join
    
    788
    +(Side note: Actually, we could be more specific: Let OAr1 be the join arity of
    
    789
    +occurrences from the letrec RHS and OAr2 the join arity from the let body. Then
    
    790
    +we need MAr=OAr1 and MAr<=OAr2 and could simply eta-expand the RHS to match OAr2 later.
    
    791
    +MAr=OAr is the specific case where we don't want to eta-expand. Neither the join
    
    792 792
     points paper nor GHC does this at the moment.)
    
    793 793
     
    
    794 794
     We can capitalise on this observation and conclude that *if* f could become a
    
    795
    -joinrec (without eta-expansion), it will have join arity M.
    
    796
    -Now, M is just the result of 'joinRhsArity', a rather simple, local analysis.
    
    795
    +joinrec (without eta-expansion), it will have join arity MAr.
    
    796
    +Now, MAr is just the result of 'joinRhsArity', a rather simple, local analysis.
    
    797 797
     It is also the join arity inside the 'TailUsageDetails' returned by
    
    798 798
     'occAnalLamTail', so we can predict join arity without doing any fixed-point
    
    799 799
     iteration or really doing any deep traversal of let body or RHS at all.
    
    800
    -We check for M in the 'adjustTailUsage' call inside 'tagRecBinders'.
    
    800
    +We check for MAr in the 'adjustTailUsage' call inside 'tagRecBinders'.
    
    801 801
     
    
    802 802
     All this is quite apparent if you look at the contification transformation in
    
    803 803
     Fig. 5 of "Compiling without Continuations" (which does not account for
    
    804 804
     eta-expansion at all, mind you). The letrec case looks like this
    
    805
    -
    
    805
    +
    
    806 806
       letrec f = /\as.\xs. L[us] in L'[es]
    
    807 807
         ... and a bunch of conditions establishing that f only occurs
    
    808 808
             in app heads of join arity (len as + len xs) inside us and es ...
    
    809 809
     
    
    810
    -The syntactic form `/\as.\xs. L[us]` forces M=O iff `f` occurs in `us`. However,
    
    810
    +The syntactic form `/\as.\xs. L[us]` forces MAr=OAr iff `f` occurs in `us`. However,
    
    811 811
     for non-recursive functions, this is the definition of contification from the
    
    812 812
     paper:
    
    813 813
     
    
    814 814
       let f = /\as.\xs.u in L[es]     ... conditions ...
    
    815 815
     
    
    816
    -Note that u could be a lambda itself, as we have seen. No relationship between M
    
    817
    -and O to exploit here.
    
    816
    +Note that u could be a lambda itself, as we have seen. No relationship between MAr
    
    817
    +and OAr to exploit here.
    
    818 818
     
    
    819 819
     Note [Join points and unfoldings/rules]
    
    820 820
     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
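
The MAr/OAr distinction in Note [Join arity prediction based on joinRhsArity] can be sketched with a toy model. This is only an illustration: `Expr`, `manifestArity`, `tailCalls` and `occArity` are hypothetical names, not GHC's Core types or the occurrence analyser's API, and the model assumes all binder names are unique. `manifestArity` counts leading lambdas (MAr); `occArity` reports the common argument count of the tail calls to a binder (OAr), if there is one.

```haskell
-- Toy model only: hypothetical types and functions, not GHC's Core.
data Expr
  = Var String
  | Lit Int
  | Lam String Expr
  | App Expr Expr
  | If Expr Expr Expr
  deriving (Eq, Show)

-- MAr: count the leading lambdas of a right-hand side.
manifestArity :: Expr -> Int
manifestArity (Lam _ body) = 1 + manifestArity body
manifestArity _            = 0

-- Does the name occur anywhere? (Assumes all binder names are unique.)
occurs :: String -> Expr -> Bool
occurs f (Var g)    = f == g
occurs _ (Lit _)    = False
occurs f (Lam _ b)  = occurs f b
occurs f (App a b)  = occurs f a || occurs f b
occurs f (If c t e) = occurs f c || occurs f t || occurs f e

-- Split an application into its head and its arguments.
collectArgs :: Expr -> (Expr, [Expr])
collectArgs = go []
  where
    go args (App f a) = go (a : args) f
    go args e         = (e, args)

-- Argument counts of the tail calls to f, or Nothing if f also occurs
-- in a non-tail position.
tailCalls :: String -> Expr -> Maybe [Int]
tailCalls f (If c t e)
  | occurs f c = Nothing
  | otherwise  = (++) <$> tailCalls f t <*> tailCalls f e
tailCalls f e = case collectArgs e of
  (Var g, args)
    | g == f -> if any (occurs f) args then Nothing else Just [length args]
  _ | occurs f e -> Nothing
    | otherwise  -> Just []

-- OAr: the single arity shared by all tail calls, if any.
occArity :: String -> Expr -> Maybe Int
occArity f body = case tailCalls f body of
  Just (n : ns) | all (== n) ns -> Just n
  _                             -> Nothing

-- The two non-recursive examples from the Note.
rhs1, body1, rhs2, body2 :: Expr
rhs1  = Lam "x" (Lam "y" (App (App (Var "+") (Var "x")) (Var "y")))
body1 = If (Var "b") (App (Var "f") (Lit 1)) (App (Var "f") (Lit 2))
rhs2  = Var "id"
body2 = If (Var "c") (App (Var "f") (Lit 12)) (App (Var "f") (Lit 13))

main :: IO ()
main = do
  print (manifestArity rhs1, occArity "f" body1)  -- the MAr > OAr case
  print (manifestArity rhs2, occArity "f" body2)  -- the MAr < OAr case
```

In both non-recursive cases the two numbers differ and contification is still possible; the Note's point is that only the recursive case needs them to agree.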
    
    ... ... @@ -998,7 +998,8 @@ occAnalBind !env lvl ire (NonRec bndr rhs) thing_inside combine
    998 998
     
    
    999 999
             -- Now analyse the body, adding the join point
    
    1000 1000
             -- into the environment with addJoinPoint
    
    1001
    -        !(WUD body_uds (occ, body)) = occAnalNonRecBody env bndr' $ \env ->
    
    1001
    +        env_body = addLocalLet env lvl bndr
    
    1002
    +        !(WUD body_uds (occ, body)) = occAnalNonRecBody env_body bndr' $ \env ->
    
    1002 1003
                                           thing_inside (addJoinPoint env bndr' rhs_uds)
    
    1003 1004
         in
    
    1004 1005
         if isDeadOcc occ     -- Drop dead code; see Note [Dead code]
    
    ... ... @@ -1012,8 +1013,8 @@ occAnalBind !env lvl ire (NonRec bndr rhs) thing_inside combine
    1012 1013
     
    
    1013 1014
       -- The normal case, including newly-discovered join points
    
    1014 1015
       -- Analyse the body and /then/ the RHS
    
    1015
    -  | WUD body_uds (occ,body) <- occAnalNonRecBody (addLocalLet env lvl bndr)
    
    1016
    -                                                 bndr thing_inside
    
    1016
    +  | let env_body = addLocalLet env lvl bndr
    
    1017
    +  , WUD body_uds (occ,body) <- occAnalNonRecBody env_body bndr thing_inside
    
    1017 1018
       = if isDeadOcc occ   -- Drop dead code; see Note [Dead code]
    
    1018 1019
         then WUD body_uds body
    
    1019 1020
         else let
    
    ... ... @@ -1059,7 +1060,7 @@ occAnalNonRecRhs !env lvl imp_rule_edges mb_join bndr rhs
    1059 1060
         rhs_ctxt = mkNonRecRhsCtxt lvl bndr unf
    
    1060 1061
     
    
    1061 1062
         -- See Note [Join arity prediction based on joinRhsArity]
    
    1062
    -    -- Match join arity O from mb_join_arity with manifest join arity M as
    
    1063
    +    -- Match join arity OAr from mb_join_arity with manifest join arity MAr as
    
    1063 1064
         -- returned by occAnalLamTail. It's totally OK for them to mismatch;
    
    1064 1065
         -- hence adjust the UDs from the RHS
    
    1065 1066
     
    
    ... ... @@ -1769,7 +1770,7 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
    1769 1770
                                    -- here because that is what we are setting!
    
    1770 1771
         WTUD unf_tuds unf' = occAnalUnfolding rhs_env unf
    
    1771 1772
         adj_unf_uds = adjustTailArity (JoinPoint rhs_ja) unf_tuds
    
    1772
    -      -- `rhs_ja` is `joinRhsArity rhs` and is the prediction for source M
    
    1773
    +      -- `rhs_ja` is `joinRhsArity rhs` and is the prediction for source MAr
    
    1773 1774
           -- of Note [Join arity prediction based on joinRhsArity]
    
    1774 1775
     
    
    1775 1776
         --------- IMP-RULES --------
    
    ... ... @@ -1780,7 +1781,7 @@ makeNode !env imp_rule_edges bndr_set (bndr, rhs)
    1780 1781
     
    
    1781 1782
         --------- All rules --------
    
    1782 1783
         -- See Note [Join points and unfoldings/rules]
    
    1783
    -    -- `rhs_ja` is `joinRhsArity rhs'` and is the prediction for source M
    
    1784
    +    -- `rhs_ja` is `joinRhsArity rhs'` and is the prediction for source MAr
    
    1784 1785
         -- of Note [Join arity prediction based on joinRhsArity]
    
    1785 1786
         rules_w_uds :: [(CoreRule, UsageDetails, UsageDetails)]
    
    1786 1787
         rules_w_uds = [ (r,l,adjustTailArity (JoinPoint rhs_ja) rhs_wuds)
    
    ... ... @@ -2182,7 +2183,9 @@ occAnalLamTail :: OccEnv -> CoreExpr -> WithTailUsageDetails CoreExpr
    2182 2183
     -- See Note [Adjusting right-hand sides]
    
    2183 2184
     occAnalLamTail env expr
    
    2184 2185
       = let !(WUD usage expr') = occ_anal_lam_tail env expr
    
    2185
    -    in WTUD (TUD (joinRhsArity expr) usage) expr'
    
    2186
    +    in WTUD (TUD (joinRhsArity expr') usage) expr'
    
    2187
    +       -- If expr looks like (\x. let dead = e in \y. blah), where `dead` is dead
    
    2188
    +       -- then joinRhsArity expr' might exceed joinRhsArity expr
    
    2186 2189
     
    
    2187 2190
     occ_anal_lam_tail :: OccEnv -> CoreExpr -> WithUsageDetails CoreExpr
    
    2188 2191
     -- Does not markInsideLam etc for the outmost batch of lambdas
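
The reason for switching from `joinRhsArity expr` to `joinRhsArity expr'` can be illustrated with a small sketch. It assumes a toy expression type and a hypothetical `dropDeadLets` pass standing in for the dead-code removal done during occurrence analysis; none of these names are GHC's, and binder names are assumed unique.

```haskell
-- Toy model only: hypothetical types and functions, not GHC's Core.
data Expr
  = Var String
  | Lam String Expr
  | Let String Expr Expr   -- let x = rhs in body
  | App Expr Expr
  deriving (Eq, Show)

-- Count leading lambdas, stopping at the first non-lambda.
manifestArity :: Expr -> Int
manifestArity (Lam _ body) = 1 + manifestArity body
manifestArity _            = 0

-- Does the name occur anywhere? (Assumes all binder names are unique.)
occurs :: String -> Expr -> Bool
occurs x (Var y)     = x == y
occurs x (Lam _ b)   = occurs x b
occurs x (Let _ r b) = occurs x r || occurs x b
occurs x (App f a)   = occurs x f || occurs x a

-- Stand-in for dead-binding removal: drop a let whose binder is unused.
dropDeadLets :: Expr -> Expr
dropDeadLets (Let x rhs body)
  | occurs x body' = Let x (dropDeadLets rhs) body'
  | otherwise      = body'
  where
    body' = dropDeadLets body
dropDeadLets (Lam x b) = Lam x (dropDeadLets b)
dropDeadLets (App f a) = App (dropDeadLets f) (dropDeadLets a)
dropDeadLets e         = e

-- \x. let dead = e in \y. blah
example :: Expr
example = Lam "x" (Let "dead" (Var "e") (Lam "y" (Var "blah")))

main :: IO ()
main = do
  print (manifestArity example)                 -- arity before the pass
  print (manifestArity (dropDeadLets example))  -- arity after the pass
```

Before the pass the dead `let` hides the inner lambda, so the arity measured on the input is lower than the arity of the expression that is actually returned.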
    
    ... ... @@ -3643,7 +3646,12 @@ data LocalOcc -- See Note [LocalOcc]
    3643 3646
                        -- Combining (AlwaysTailCalled 2) and (AlwaysTailCalled 3)
    
    3644 3647
                        -- gives NoTailCallInfo
    
    3645 3648
                   , lo_int_cxt :: !InterestingCxt }
    
    3649
    +
    
    3646 3650
         | ManyOccL !TailCallInfo
    
    3651
    +       -- Why do we need TailCallInfo on ManyOccL?
    
    3652
    +       -- Answer: recursive bindings are entered many times:
    
    3653
    +       --    rec { j x = ...j x'... } in j y
    
    3654
    +       -- See the uses of `andUDs` in `tagRecBinders`
    
    3647 3655
     
    
    3648 3656
     instance Outputable LocalOcc where
    
    3649 3657
       ppr (OneOccL { lo_n_br = n, lo_tail = tci })
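
The rule quoted in the comment above, that combining (AlwaysTailCalled 2) and (AlwaysTailCalled 3) gives NoTailCallInfo, can be sketched as follows. This is a minimal stand-alone model: the data type is redeclared locally and `andTail` is a hypothetical name, so it should not be read as the definition GHC uses.

```haskell
-- Toy model only: a local copy of the idea, not GHC's own definition.
data TailCallInfo
  = AlwaysTailCalled Int   -- every occurrence is a tail call with this arity
  | NoTailCallInfo
  deriving (Eq, Show)

-- Combine the information from two sets of occurrences.
-- Tail-call info survives only if both sides agree on the arity.
andTail :: TailCallInfo -> TailCallInfo -> TailCallInfo
andTail (AlwaysTailCalled a) (AlwaysTailCalled b)
  | a == b = AlwaysTailCalled a
andTail _ _ = NoTailCallInfo

main :: IO ()
main = do
  -- rec { j x = ...j x'... } in j y: the RHS and the body each call j
  -- with one argument, so the combined info is still a tail call.
  print (andTail (AlwaysTailCalled 1) (AlwaysTailCalled 1))
  -- Mismatched arities lose the tail-call information.
  print (andTail (AlwaysTailCalled 2) (AlwaysTailCalled 3))
```

This is why an occurrence count of "many" still needs to carry tail-call info: a recursive join point is referenced from both its own RHS and the body.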
    
    ... ... @@ -3678,7 +3686,7 @@ instance Outputable UsageDetails where
    3678 3686
     -- | TailUsageDetails captures the result of applying 'occAnalLamTail'
    
    3679 3687
     --   to a function `\xyz.body`. The TailUsageDetails pairs together
    
    3680 3688
     --   * the number of lambdas (including type lambdas: a JoinArity)
    
    3681
    ---   * UsageDetails for the `body` of the lambda, unadjusted by `adjustTailUsage`.
    
    3689
    +--   * UsageDetails for the `body` of the lambda, /unadjusted/ by `adjustTailUsage`.
    
    3682 3690
     -- If the binding turns out to be a join point with the indicated join
    
    3683 3691
     -- arity, this unadjusted usage details is just what we need; otherwise we
    
    3684 3692
     -- need to discard tail calls. That's what `adjustTailUsage` does.
    
    ... ... @@ -3865,8 +3873,6 @@ lookupOccInfoByUnique (UD { ud_env = env
    3865 3873
             | uniq `elemVarEnvByKey` z_tail = NoTailCallInfo
    
    3866 3874
             | otherwise                     = ti
    
    3867 3875
     
    
    3868
    -
    
    3869
    -
    
    3870 3876
     -------------------
    
    3871 3877
     -- See Note [Adjusting right-hand sides]
    
    3872 3878
     
    
    ... ... @@ -3876,21 +3882,22 @@ adjustNonRecRhs :: JoinPointHood
    3876 3882
     -- ^ This function concentrates shared logic between occAnalNonRecBind and the
    
    3877 3883
     -- AcyclicSCC case of occAnalRec.
    
    3878 3884
     -- It returns the adjusted rhs UsageDetails combined with the body usage
    
    3879
    -adjustNonRecRhs mb_join_arity rhs_wuds@(WTUD _ rhs)
    
    3880
    -  = WUD (adjustTailUsage mb_join_arity rhs_wuds) rhs
    
    3881
    -
    
    3885
    +adjustNonRecRhs mb_join_arity (WTUD (TUD rhs_ja uds) rhs)
    
    3886
    +  = WUD (adjustTailUsage exact_join rhs uds) rhs
    
    3887
    +  where
    
    3888
    +    exact_join = mb_join_arity == JoinPoint rhs_ja
    
    3882 3889
     
    
    3883
    -adjustTailUsage :: JoinPointHood
    
    3884
    -                -> WithTailUsageDetails CoreExpr    -- Rhs usage, AFTER occAnalLamTail
    
    3890
    +adjustTailUsage :: Bool        -- True <=> Exactly-matching join point; don't do markNonTail
    
    3891
    +                -> CoreExpr    -- Rhs usage, AFTER occAnalLamTail
    
    3892
    +                -> UsageDetails
    
    3885 3893
                     -> UsageDetails
    
    3886
    -adjustTailUsage mb_join_arity (WTUD (TUD rhs_ja uds) rhs)
    
    3894
    +adjustTailUsage exact_join rhs uds
    
    3887 3895
       = -- c.f. occAnal (Lam {})
    
    3888 3896
         markAllInsideLamIf (not one_shot) $
    
    3889 3897
         markAllNonTailIf (not exact_join) $
    
    3890 3898
         uds
    
    3891 3899
       where
    
    3892 3900
         one_shot   = isOneShotFun rhs
    
    3893
    -    exact_join = mb_join_arity == JoinPoint rhs_ja
    
    3894 3901
     
    
    3895 3902
     adjustTailArity :: JoinPointHood -> TailUsageDetails -> UsageDetails
    
    3896 3903
     adjustTailArity mb_rhs_ja (TUD ja usage)
    
    ... ... @@ -3937,8 +3944,9 @@ tagNonRecBinder lvl occ bndr
    3937 3944
     tagRecBinders :: TopLevelFlag           -- At top level?
    
    3938 3945
                   -> UsageDetails           -- Of body of let ONLY
    
    3939 3946
                   -> [NodeDetails]
    
    3940
    -              -> WithUsageDetails       -- Adjusted details for whole scope,
    
    3941
    -                                        -- with binders removed
    
    3947
    +              -> WithUsageDetails       -- Adjusted details for whole scope
    
    3948
    +                                        -- still including the binders;
    
    3949
    +                                        -- (they are removed by `addInScope`)
    
    3942 3950
                       [IdWithOccInfo]       -- Tagged binders
    
    3943 3951
     -- Substantially more complicated than non-recursive case. Need to adjust RHS
    
    3944 3952
     -- details *before* tagging binders (because the tags depend on the RHSes).
    
    ... ... @@ -3948,32 +3956,21 @@ tagRecBinders lvl body_uds details_s
    3948 3956
     
    
    3949 3957
          -- 1. See Note [Join arity prediction based on joinRhsArity]
    
    3950 3958
          --    Determine possible join-point-hood of whole group, by testing for
    
    3951
    -     --    manifest join arity M.
    
    3952
    -     --    This (re-)asserts that makeNode had made tuds for that same arity M!
    
    3959
    +     --    manifest join arity MAr.
    
    3960
    +     --    This (re-)asserts that makeNode had made tuds for that same arity MAr!
    
    3953 3961
          unadj_uds = foldr (andUDs . test_manifest_arity) body_uds details_s
    
    3954
    -     test_manifest_arity ND{nd_rhs = WTUD tuds rhs}
    
    3955
    -       = adjustTailArity (JoinPoint (joinRhsArity rhs)) tuds
    
    3962
    +     test_manifest_arity ND{nd_rhs = WTUD (TUD rhs_ja uds) rhs}
    
    3963
    +       = assertPpr (rhs_ja == joinRhsArity rhs) (ppr rhs_ja $$ ppr uds $$ ppr rhs) $
    
    3964
    +         uds
    
    3956 3965
     
    
    3966
    +     will_be_joins :: Bool
    
    3957 3967
          will_be_joins = decideRecJoinPointHood lvl unadj_uds bndrs
    
    3958 3968
     
    
    3959
    -     mb_join_arity :: Id -> JoinPointHood
    
    3960
    -     -- mb_join_arity: See Note [Join arity prediction based on joinRhsArity]
    
    3961
    -     -- This is the source O
    
    3962
    -     mb_join_arity bndr
    
    3963
    -         -- Can't use willBeJoinId_maybe here because we haven't tagged
    
    3964
    -         -- the binder yet (the tag depends on these adjustments!)
    
    3965
    -       | will_be_joins
    
    3966
    -       , AlwaysTailCalled arity <- lookupTailCallInfo unadj_uds bndr
    
    3967
    -       = JoinPoint arity
    
    3968
    -       | otherwise
    
    3969
    -       = assert (not will_be_joins) -- Should be AlwaysTailCalled if
    
    3970
    -         NotJoinPoint               -- we are making join points!
    
    3971
    -
    
    3972 3969
          -- 2. Adjust usage details of each RHS, taking into account the
    
    3973 3970
          --    join-point-hood decision
    
    3974
    -     rhs_udss' = [ adjustTailUsage (mb_join_arity bndr) rhs_wuds
    
    3971
    +     rhs_udss' = [ adjustTailUsage will_be_joins rhs rhs_uds
    
    3975 3972
                          -- Matching occAnalLamTail in makeNode
    
    3976
    -                 | ND { nd_bndr = bndr, nd_rhs = rhs_wuds } <- details_s ]
    
    3973
    +                 | ND { nd_rhs = WTUD (TUD _ rhs_uds) rhs } <- details_s ]
    
    3977 3974
     
    
    3978 3975
          -- 3. Compute final usage details from adjusted RHS details
    
    3979 3976
          adj_uds = foldr andUDs body_uds rhs_udss'
    
    ... ... @@ -3992,9 +3989,9 @@ setBinderOcc occ_info bndr
    3992 3989
       | otherwise                  = setIdOccInfo bndr occ_info
    
    3993 3990
     
    
    3994 3991
     -- | Decide whether some bindings should be made into join points or not, based
    
    3995
    --- on its occurrences. This is
    
    3992
    +-- on its occurrences.
    
    3996 3993
     -- Returns `False` if they can't be join points. Note that it's an
    
    3994
    +-- all-or-nothing decision: if multiple binders are given, they are
    
    3997 3995
     -- assumed to be mutually recursive.
    
    3998 3996
     --
    
    3999 3997
     -- It must, however, be a final decision. If we say `True` for 'f',
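
The all-or-nothing nature of the decision for a recursive group can be sketched in a simplified model. The assumption here, which is a simplification of what the occurrence analyser does, is that a binder qualifies only when its combined tail-call info is AlwaysTailCalled with exactly its manifest arity MAr; `decideGroup`, `Usage` and `Binder` are hypothetical names.

```haskell
-- Toy model only: hypothetical names, simplified from the real analysis.
data TailCallInfo
  = AlwaysTailCalled Int
  | NoTailCallInfo
  deriving (Eq, Show)

-- Combined tail-call info per binder, over all RHSs and the let body.
type Usage = [(String, TailCallInfo)]

-- A binder paired with its manifest arity (MAr).
type Binder = (String, Int)

-- The whole group becomes join points only if every binder is always
-- tail-called with exactly its manifest arity.
decideGroup :: Usage -> [Binder] -> Bool
decideGroup usage = all ok
  where
    ok (name, mar) = lookup name usage == Just (AlwaysTailCalled mar)

main :: IO ()
main = do
  -- letrec f x y = if ... then f x else True in f 42
  -- Every call passes one argument (OAr = 1) but MAr = 2: rejected.
  print (decideGroup [("f", AlwaysTailCalled 1)] [("f", 2)])
  -- rec { j x = ...j x'... } in j y: OAr = MAr = 1: accepted.
  print (decideGroup [("j", AlwaysTailCalled 1)] [("j", 1)])
  -- One failing binder disqualifies the whole group.
  print (decideGroup [("j", AlwaysTailCalled 1), ("k", NoTailCallInfo)]
                     [("j", 1), ("k", 1)])
```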
    

  • compiler/GHC/Core/Opt/Simplify/Iteration.hs
    ... ... @@ -4598,12 +4598,12 @@ mkLetUnfolding :: SimplEnv -> TopLevelFlag -> UnfoldingSource
    4598 4598
                    -> InId -> Bool    -- True <=> this is a join point
    
    4599 4599
                    -> OutExpr -> SimplM Unfolding
    
    4600 4600
     mkLetUnfolding env top_lvl src id is_join new_rhs
    
    4601
    -  | is_join
    
    4602
    -  , UnfNever <- guidance
    
    4603
    -  = -- For large join points, don't keep an unfolding at all if it is large
    
    4604
    -    -- This is just an attempt to keep residency under control in
    
    4605
    -    -- deeply-nested join-point such as those arising in #26425
    
    4606
    -    return NoUnfolding
    
    4601
    +--  | is_join
    
    4602
    +--  , UnfNever <- guidance
    
    4603
    +--  = -- For large join points, don't keep an unfolding at all if it is large
    
    4604
    +--    -- This is just an attempt to keep residency under control in
    
    4605
    +--    -- deeply-nested join-point such as those arising in #26425
    
    4606
    +--    return NoUnfolding
    
    4607 4607
     
    
    4608 4608
       | otherwise
    
    4609 4609
       = return (mkCoreUnfolding src is_top_lvl new_rhs Nothing guidance)