Simon Peyton Jones pushed to branch wip/T23162-part2 at Glasgow Haskell Compiler / GHC

Commits:

8 changed files:

Changes:

  • compiler/GHC/Tc/Solver.hs
    ... ... @@ -1068,7 +1068,7 @@ findInferredDiff annotated_theta inferred_theta
    1068 1068
       | null annotated_theta   -- Short cut the common case when the user didn't
    
    1069 1069
       = return inferred_theta  -- write any constraints in the partial signature
    
    1070 1070
       | otherwise
    
    1071
    -  = pushTcLevelM_ $
    
    1071
    +  = TcM.pushTcLevelM_ $
    
    1072 1072
         do { lcl_env   <- TcM.getLclEnv
    
    1073 1073
            ; given_ids <- mapM TcM.newEvVar annotated_theta
    
    1074 1074
            ; wanteds   <- newWanteds AnnOrigin inferred_theta
    

  • compiler/GHC/Tc/Solver/FunDeps.hs
    ... ... @@ -13,11 +13,11 @@ import {-# SOURCE #-} GHC.Tc.Solver.Solve( solveSimpleWanteds )
    13 13
     
    
    14 14
     import GHC.Tc.Instance.FunDeps
    
    15 15
     import GHC.Tc.Solver.InertSet
    
    16
    -import GHC.Tc.Solver.Monad
    
    17 16
     import GHC.Tc.Solver.Types
    
    17
    +import GHC.Tc.Solver.Monad   as TcS
    
    18
    +import GHC.Tc.Utils.Monad    as TcM
    
    18 19
     import GHC.Tc.Utils.TcType
    
    19 20
     import GHC.Tc.Utils.Unify( UnifyEnv(..) )
    
    20
    -import GHC.Tc.Utils.Monad    as TcM
    
    21 21
     import GHC.Tc.Types.Evidence
    
    22 22
     import GHC.Tc.Types.Constraint
    
    23 23
     
    
    ... ... @@ -39,7 +39,7 @@ import GHC.Utils.Panic
    39 39
     import GHC.Utils.Misc( lengthExceeds )
    
    40 40
     
    
    41 41
     import GHC.Data.Pair
    
    42
    -import Data.Maybe( isNothing, mapMaybe )
    
    42
    +import Data.Maybe( isNothing, isJust, mapMaybe )
    
    43 43
     
    
    44 44
     
    
    45 45
     {- Note [Overview of functional dependencies in type inference]
    
    ... ... @@ -469,8 +469,8 @@ tryEqFunDeps work_item@(EqCt { eq_lhs = work_lhs
    469 469
                                  , eq_eq_rel = eq_rel })
    
    470 470
       | NomEq <- eq_rel
    
    471 471
       , TyFamLHS fam_tc work_args <- work_lhs     -- We have F args ~N# rhs
    
    472
    -  = do { simpleStage $ traceTcS "tryEqFunDeps" (ppr work_item)
    
    473
    -       ; eqs_for_me <- simpleStage $ getInertFamEqsFor fam_tc work_args work_rhs
    
    472
    +  = do { eqs_for_me <- simpleStage $ getInertFamEqsFor fam_tc work_args work_rhs
    
    473
    +       ; simpleStage $ traceTcS "tryEqFunDeps" (ppr work_item $$ ppr eqs_for_me)
    
    474 474
            ; tryFamEqFunDeps eqs_for_me fam_tc work_args work_item }
    
    475 475
       | otherwise
    
    476 476
       = nopStage ()
    
    ... ... @@ -485,10 +485,11 @@ tryFamEqFunDeps eqs_for_me fam_tc work_args
    485 485
         else do { -- Note [Do local fundeps before top-level instances]
    
    486 486
                   tryFDEqns fam_tc work_args work_item $
    
    487 487
                   mkLocalBuiltinFamEqFDs eqs_for_me fam_tc ops work_args work_rhs
    
    488
    -            ; if all (isWanted . eqCtEvidence) eqs_for_me
    
    489
    -              then tryFDEqns fam_tc work_args work_item $
    
    490
    -                   mkTopBuiltinFamEqFDs fam_tc ops work_args work_rhs
    
    491
    -              else nopStage () }
    
    488
    +
    
    489
    +            ; if hasRelevantGiven eqs_for_me work_args work_item
    
    490
    +            ; then nopStage ()
    
    491
    +              else tryFDEqns fam_tc work_args work_item $
    
    492
    +                   mkTopBuiltinFamEqFDs fam_tc ops work_args work_rhs }
    
    492 493
     
    
    493 494
       | isGiven ev    -- See (INJFAM:Given)
    
    494 495
       = nopStage ()
    
    ... ... @@ -502,10 +503,10 @@ tryFamEqFunDeps eqs_for_me fam_tc work_args
    502 503
                Injective inj -> tryFDEqns fam_tc work_args work_item $
    
    503 504
                                 mkLocalFamEqFDs eqs_for_me fam_tc inj work_args work_rhs
    
    504 505
     
    
    505
    -       ; if all (isWanted . eqCtEvidence) eqs_for_me
    
    506
    -         then tryFDEqns fam_tc work_args work_item $
    
    507
    -              mkTopFamEqFDs fam_tc work_args work_rhs
    
    508
    -         else nopStage () }
    
    506
    +       ; if hasRelevantGiven eqs_for_me work_args work_item
    
    507
    +         then nopStage ()
    
    508
    +         else tryFDEqns fam_tc work_args work_item $
    
    509
    +              mkTopFamEqFDs fam_tc work_args work_rhs }
    
    509 510
     
    
    510 511
     mkTopFamEqFDs :: TyCon -> [TcType] -> Xi -> TcS [FunDepEqns]
    
    511 512
     mkTopFamEqFDs fam_tc work_args work_rhs
    
    ... ... @@ -548,10 +549,8 @@ mkTopClosedFamEqFDs ax work_args work_rhs
    548 549
       = do { let branches = fromBranches (coAxiomBranches ax)
    
    549 550
            ; traceTcS "mkTopClosed" (ppr branches $$ ppr work_args $$ ppr work_rhs)
    
    550 551
            ; case getRelevantBranches ax work_args work_rhs of
    
    551
    -           [CoAxBranch { cab_tvs = qtvs, cab_lhs = lhs_tys, cab_rhs = rhs_ty }]
    
    552
    -              -> return [FDEqns { fd_qtvs = qtvs
    
    553
    -                                , fd_eqs = zipWith Pair (rhs_ty:lhs_tys) (work_rhs:work_args) }]
    
    554
    -           _  -> return [] }
    
    552
    +           [eqn] -> return [eqn]
    
    553
    +           _     -> return [] }
    
    555 554
        | otherwise
    
    556 555
        = return []
    
    557 556
     
    
    ... ... @@ -566,7 +565,21 @@ isInformativeType (TyConApp tc tys) = isGenerativeTyCon tc Nominal ||
    566 565
                                                      tys `lengthExceeds` tyConArity tc
    
    567 566
     isInformativeType _ = True  -- AppTy, ForAllTy, FunTy, LitTy
    
    568 567
     
    
    569
    -getRelevantBranches :: CoAxiom Branched -> [TcType] -> Xi -> [CoAxBranch]
    
    568
    +hasRelevantGiven :: [EqCt] -> [TcType] -> EqCt -> Bool
    
    569
    +-- True if some Given in eqs_for_me is relevant, i.e. is not apart from the Wanted
    
    570
    +hasRelevantGiven eqs_for_me work_args (EqCt { eq_rhs = work_rhs })
    
    571
    +  = any relevant eqs_for_me
    
    572
    +  where
    
    573
    +    work_tys = work_rhs : work_args
    
    574
    +
    
    575
    +    relevant (EqCt { eq_ev = ev, eq_lhs = lhs, eq_rhs = rhs_ty })
    
    576
    +       | isGiven ev
    
    577
    +       , TyFamLHS _ lhs_tys <- lhs
    
    578
    +       = isJust (tcUnifyTysForInjectivity True work_tys (rhs_ty:lhs_tys))
    
    579
    +       | otherwise
    
    580
    +       = False
    
    581
    +
    
    582
    +getRelevantBranches :: CoAxiom Branched -> [TcType] -> Xi -> [FunDepEqns]
    
    570 583
     getRelevantBranches ax work_args work_rhs
    
    571 584
       = go [] (fromBranches (coAxiomBranches ax))
    
    572 585
       where
    
    ... ... @@ -574,13 +587,21 @@ getRelevantBranches ax work_args work_rhs
    574 587
     
    
    575 588
         go _ [] = []
    
    576 589
         go preceding (branch:branches)
    
    577
    -      | is_relevant branch = branch : go (branch:preceding) branches
    
    578
    -      | otherwise          =          go (branch:preceding) branches
    
    590
    +      = case is_relevant branch of
    
    591
    +          Just eqn -> eqn : go (branch:preceding) branches
    
    592
    +          Nothing  ->       go (branch:preceding) branches
    
    579 593
           where
    
    580
    -         is_relevant (CoAxBranch { cab_lhs = lhs_tys, cab_rhs = rhs_ty })
    
    581
    -            = case tcUnifyTysForInjectivity True work_tys (rhs_ty:lhs_tys) of
    
    582
    -                     Nothing    -> False
    
    583
    -                     Just subst -> all (no_match (substTys subst lhs_tys)) preceding
    
    594
    +         is_relevant (CoAxBranch { cab_tvs = qtvs, cab_lhs = lhs_tys, cab_rhs = rhs_ty })
    
    595
    +            | Just subst <- tcUnifyTysForInjectivity True work_tys (rhs_ty:lhs_tys)
    
    596
    +            , let (subst', qtvs') = trim_qtvs subst qtvs
    
    597
    +                  lhs_tys' = substTys subst' lhs_tys
    
    598
    +                  rhs_ty'  = substTy  subst' rhs_ty
    
    599
    +            , all (no_match lhs_tys') preceding
    
    600
    +            = pprTrace "grb" (ppr qtvs $$ ppr subst $$ ppr qtvs' $$ ppr subst' $$ ppr lhs_tys $$ ppr lhs_tys') $
    
    601
    +              Just (FDEqns { fd_qtvs = qtvs'
    
    602
    +                           , fd_eqs = zipWith Pair (rhs_ty':lhs_tys') work_tys })
    
    603
    +            | otherwise
    
    604
    +            = Nothing
    
    584 605
     
    
    585 606
              no_match lhs_tys (CoAxBranch { cab_lhs = lhs_tys1 })
    
    586 607
                 = isNothing (tcUnifyTysForInjectivity False lhs_tys1 lhs_tys)
    
    ... ... @@ -608,15 +629,15 @@ mkTopOpenFamEqFDs fam_tc inj_flags work_args work_rhs
    608 629
           | otherwise
    
    609 630
           = Nothing
    
    610 631
     
    
    611
    -    trim_qtvs :: Subst -> [TcTyVar] -> (Subst,[TcTyVar])
    
    612
    -    -- Tricky stuff: see (TIF1) in
    
    613
    -    -- Note [Type inference for type families with injectivity]
    
    614
    -    trim_qtvs subst []       = (subst, [])
    
    615
    -    trim_qtvs subst (tv:tvs)
    
    616
    -      | tv `elemSubst` subst = trim_qtvs subst tvs
    
    617
    -      | otherwise            = let !(subst1, tv')  = substTyVarBndr subst tv
    
    618
    -                                   !(subst', tvs') = trim_qtvs subst1 tvs
    
    619
    -                               in (subst', tv':tvs')
    
    632
    +trim_qtvs :: Subst -> [TcTyVar] -> (Subst,[TcTyVar])
    
    633
    +-- Tricky stuff: see (TIF1) in
    
    634
    +-- Note [Type inference for type families with injectivity]
    
    635
    +trim_qtvs subst []       = (subst, [])
    
    636
    +trim_qtvs subst (tv:tvs)
    
    637
    +  | tv `elemSubst` subst = trim_qtvs subst tvs
    
    638
    +  | otherwise            = let !(subst1, tv')  = substTyVarBndr subst tv
    
    639
    +                               !(subst', tvs') = trim_qtvs subst1 tvs
    
    640
    +                           in (subst', tv':tvs')
    
    620 641
     
    
    621 642
     mkLocalFamEqFDs :: [EqCt] -> TyCon -> [Bool] -> [TcType] -> Xi -> TcS [FunDepEqns]
    
    622 643
     mkLocalFamEqFDs eqs_for_me fam_tc inj_flags work_args work_rhs
    
    ... ... @@ -823,7 +844,7 @@ For /built-in/ type families, it's pretty similar, except that
    823 844
         FDEqn { fd_qtvs = [b:kappa], fd_eqs = [ beta ~ Proxy @kappa b ] }
    
    824 845
       Notice that
    
    825 846
         * we must quantify the FunDepEqns over `b`, which is not matched; for this
    
    826
    -      we will generate a fresh unfication variable in `instantiateFunDepEqn`.
    
    847
    +      we will generate a fresh unification variable in `instantiateFunDepEqn`.
    
    827 848
         * we must substitute `k:->kappa` in the kind of `b`.
    
    828 849
       This fancy footwork for `fd_qtvs` is done by the top-level function
    
    829 850
       `trim_qtvs`.
    
    ... ... @@ -889,6 +910,10 @@ solveFunDeps work_ev fd_eqns
    889 910
       = do { (unifs, _res)
    
    890 911
                  <- reportFineGrainUnifications $
    
    891 912
                     nestFunDepsTcS              $
    
    913
    +                TcS.pushTcLevelM_           $
    
    914
    +                   -- pushTcLevelM_: increase the level so that unification variables
    
    915
    +                   -- allocated by the fundep-creation itself don't count as useful unifications
    
    916
    +                   -- See Note [Deeper TcLevel for partial improvement unification variables]
    
    892 917
                     do { (_, eqs) <- wrapUnifier work_ev Nominal do_fundeps
    
    893 918
                        ; solveSimpleWanteds eqs }
    
    894 919
         -- Why solveSimpleWanteds?  Answer
    

  • compiler/GHC/Tc/Solver/Monad.hs
    ... ... @@ -25,7 +25,7 @@ module GHC.Tc.Solver.Monad (
    25 25
         selectNextWorkItem,
    
    26 26
         getWorkList,
    
    27 27
         updWorkListTcS,
    
    28
    -    pushLevelNoWorkList,
    
    28
    +    pushLevelNoWorkList, pushTcLevelM_,
    
    29 29
     
    
    30 30
         runTcPluginTcS, recordUsedGREs,
    
    31 31
         matchGlobalInst, TcM.ClsInstResult(..),
    
    ... ... @@ -1320,11 +1320,6 @@ nestImplicTcS skol_info ev_binds_var inner_tclvl (TcS thing_inside)
    1320 1320
     nestFunDepsTcS :: TcS a -> TcS a
    
    1321 1321
     nestFunDepsTcS (TcS thing_inside)
    
    1322 1322
       = TcS $ \ env@(TcSEnv { tcs_inerts = inerts_var }) ->
    
    1323
    -    TcM.pushTcLevelM_  $
    
    1324
    -         -- pushTcLevelTcM: increase the level so that unification variables
    
    1325
    -         -- allocated by the fundep-creation itself don't count as useful unifications
    
    1326
    -         -- See Note [Deeper TcLevel for partial improvement unification variables]
    
    1327
    -         --     in GHC.Tc.Solver.FunDeps
    
    1328 1323
         do { inerts <- TcM.readTcRef inerts_var
    
    1329 1324
            ; let nest_inerts = resetInertCans inerts
    
    1330 1325
                      -- resetInertCans: like nestImplicTcS
    
    ... ... @@ -1834,6 +1829,10 @@ selectNextWorkItem
    1834 1829
          } }
    
    1835 1830
     
    
    1836 1831
     
    
    1832
    +pushTcLevelM_ :: TcS a -> TcS a
    
    1833
    +pushTcLevelM_ (TcS thing_inside)
    
    1834
    +  = TcS (\env -> TcM.pushTcLevelM_ (thing_inside env))
    
    1835
    +
    
    1837 1836
     pushLevelNoWorkList :: SDoc -> TcS a -> TcS (TcLevel, a)
    
    1838 1837
     -- Push the level and run thing_inside
    
    1839 1838
     -- However, thing_inside should not generate any work items
    

  • compiler/GHC/Tc/Utils/Unify.hs
    ... ... @@ -2440,7 +2440,7 @@ The eager unifier, `uType`, is called by
    2440 2440
         via the wrappers `unifyType`, `unifyKind` etc
    
    2441 2441
     
    
    2442 2442
       * The constraint solver (e.g. in GHC.Tc.Solver.Equality),
    
    2443
    -    via `GHC.Tc.Solver.Monad.wrapUnifie`.
    
    2443
    +    via `GHC.Tc.Solver.Monad.wrapUnifier`.
    
    2444 2444
     
    
    2445 2445
     `uType` runs in the TcM monad, but it carries a UnifyEnv that tells it
    
    2446 2446
     what to do when unifying a variable or deferring a constraint. Specifically,
    

  • testsuite/tests/typecheck/should_fail/T23162b.hs
    1
    +{-# LANGUAGE DataKinds              #-}
    
    2
    +{-# LANGUAGE GADTs                  #-}
    
    3
    +{-# LANGUAGE TypeFamilyDependencies #-}
    
    4
    +{-# LANGUAGE TypeOperators          #-}
    
    5
    +{-# LANGUAGE ViewPatterns           #-}
    
    6
    +{-# LANGUAGE UndecidableInstances   #-}
    
    7
    +
    
    8
    +module T23162b where
    
    9
    +
    
    10
    +import Data.Kind  ( Type )
    
    11
    +import Data.Proxy
    
    12
    +
    
    13
    +type family LV (as :: [Type]) (b :: Type) = (r :: Type) | r -> as b where
    
    14
    +  LV (a ': as) b = a -> LV as b
    
    15
    +
    
    16
    +eq :: a -> a -> ()
    
    17
    +eq x y = ()
    
    18
    +
    
    19
    +foo :: Proxy a -> b -> LV a b
    
    20
    +foo = foo
    
    21
    +
    
    22
    +bar :: (c->()) -> ()
    
    23
    +bar =  bar
    
    24
    +
    
    25
    +f1 :: Int -> ()
    
    26
    +-- LV alpha Bool ~ LV alpha Char
    
    27
    +f1 x = bar (\y -> eq (foo y True) (foo y 'c'))
    
    28
    +
    
    29
    +f2 :: Int -> ()
    
    30
    +-- LV alpha Bool ~ Int -> LV alpha Char
    
    31
    +f2 x = bar (\y -> eq (foo y True) (\(z::Int) -> foo y 'c'))
    
    32
    +
    
    33
    +

  • testsuite/tests/typecheck/should_fail/T23162c.hs
    1
    +{-# LANGUAGE TypeFamilies #-}
    
    2
    +{-# LANGUAGE TypeFamilyDependencies #-}
    
    3
    +
    
    4
    +module T23162c where
    
    5
    +
    
    6
    +type family Bak a = r | r -> a where
    
    7
    +     Bak Int  = Char
    
    8
    +     Bak Char = Int
    
    9
    +     Bak a    = a
    
    10
    +
    
    11
    +eq :: a -> a -> ()
    
    12
    +eq x y = ()
    
    13
    +
    
    14
    +bar :: (c->()) -> ()
    
    15
    +bar =  bar
    
    16
    +
    
    17
    +foo :: a -> Bak a
    
    18
    +foo = foo
    
    19
    +
    
    20
    +-- Bak alpha ~ ()
    
    21
    +f :: ()
    
    22
    +f = bar (\y -> eq (foo y) ())

  • testsuite/tests/typecheck/should_fail/T23162d.hs
    1
    +{-# LANGUAGE DataKinds #-}
    
    2
    +{-# LANGUAGE ScopedTypeVariables #-}
    
    3
    +{-# LANGUAGE TypeFamilies #-}
    
    4
    +{-# LANGUAGE AllowAmbiguousTypes #-}
    
    5
    +
    
    6
    +module T23162d where
    
    7
    +
    
    8
    +import GHC.TypeNats
    
    9
    +import Data.Kind
    
    10
    +
    
    11
    +data T2 a b = MkT2 a b
    
    12
    +
    
    13
    +type TArgKind :: Nat -> Type
    
    14
    +type family TArgKind n where
    
    15
    +   TArgKind 2 = T2 Type Type
    
    16
    +
    
    17
    +eq :: a -> a -> ()
    
    18
    +eq x y = ()
    
    19
    +
    
    20
    +bar :: (c->()) -> ()
    
    21
    +bar =  bar
    
    22
    +
    
    23
    +foo :: forall n k0 k1. (TArgKind n ~ T2 k0 k1) => Int
    
    24
    +foo = foo @n
    
    25
    +
    
    26
    +f :: () -> Int
    
    27
    +f () = foo

  • testsuite/tests/typecheck/should_fail/all.T
    ... ... @@ -746,3 +746,6 @@ test('T26255a', normal, compile_fail, [''])
    746 746
     test('T26255b', normal, compile_fail, [''])
    
    747 747
     test('T26330', normal, compile_fail, [''])
    
    748 748
     test('T23162a', normal, compile_fail, [''])
    
    749
    +test('T23162b', normal, compile_fail, [''])
    
    750
    +test('T23162c', normal, compile, [''])
    
    751
    +test('T23162d', normal, compile, [''])