Marge Bot pushed to branch master at Glasgow Haskell Compiler / GHC

Commits:

7 changed files:

Changes:

  • changelog.d/simd_constant_folding
    1
    +section: codegen
    
    2
    +synopsis: Implement Cmm constant folding for some SIMD vector instructions
    
    3
    +issues: #25030 #26915
    
    4
    +mrs: !15512
    
    5
    +
    
    6
    +description: {
    
    7
    +The Cmm constant folding pass now handles the following vector operations:
    
    8
    +
    
    9
    +- insert and extract (broadcast was already supported)
    
    10
    +- integer arithmetic operations: negation, addition, subtraction, multiplication,
    
    11
    +  minimum, maximum
    
    12
    +- logical operations: and, or, xor
    
    13
    +}
    
    14
    +

  • compiler/GHC/Cmm/Opt.hs
    ... ... @@ -24,6 +24,7 @@ import GHC.Platform
    24 24
     import GHC.Types.Literal.Floating
    
    25 25
     
    
    26 26
     import Data.Maybe
    
    27
    +import Control.Monad (zipWithM, guard)
    
    27 28
     import GHC.Float
    
    28 29
     
    
    29 30
     
    
    ... ... @@ -47,7 +48,6 @@ cmmMachOpFold
    47 48
         -> MachOp       -- The operation from an CmmMachOp
    
    48 49
         -> [CmmExpr]    -- The optimized arguments
    
    49 50
         -> CmmExpr
    
    50
    -
    
    51 51
     cmmMachOpFold platform op args = fromMaybe (CmmMachOp op args) (cmmMachOpFoldM platform op args)
    
    52 52
     
    
    53 53
     -- Returns Nothing if no changes, useful for Hoopl, also reduces
    
    ... ... @@ -65,6 +65,30 @@ cmmMachOpFoldM _ (MO_VF_Broadcast lg _w) exprs =
    65 65
       case exprs of
    
    66 66
         [CmmLit l] -> Just $! CmmLit (CmmVec $ replicate lg l)
    
    67 67
         _ -> Nothing
    
    68
    +
    
    69
    +cmmMachOpFoldM plat (MO_V_Extract l _)  [v, (CmmLit (CmmInt idx W32))]
    
    70
    +  | idx >= 0, idx < fromIntegral l
    
    71
    +  = do
    
    72
    +    es <- vectorElements_maybe plat v
    
    73
    +    es !! fromInteger idx
    
    74
    +
    
    75
    +cmmMachOpFoldM plat (MO_VF_Extract l _) [v, (CmmLit (CmmInt idx W32))]
    
    76
    +  | idx >= 0, idx < fromIntegral l
    
    77
    +  = do
    
    78
    +    es <- vectorElements_maybe plat v
    
    79
    +    es !! fromInteger idx
    
    80
    +
    
    81
    +cmmMachOpFoldM plat op [v, newval@(CmmLit _), CmmLit (CmmInt idx W32)]
    
    82
    +  | MO_V_Insert  l _ <- op = foldToVecLit l
    
    83
    +  | MO_VF_Insert l _ <- op = foldToVecLit l
    
    84
    +  where foldToVecLit l = do
    
    85
    +          guard (idx >= 0 && idx < fromIntegral l)
    
    86
    +          ls <- vectorElements_maybe plat v
    
    87
    +          lits <- sequence $ map toLit_maybe (replaceAt (fromIntegral idx) (Just newval) ls)
    
    88
    +          Just $! CmmLit (CmmVec lits)
    
    89
    +        toLit_maybe (Just (CmmLit l)) = Just l
    
    90
    +        toLit_maybe _ = Nothing
    
    91
    +
    
    68 92
     cmmMachOpFoldM _ op [CmmLit (CmmInt x rep)]
    
    69 93
       | MO_WF_Bitcast width <- op = case width of
    
    70 94
           W32 | res <- castWord32ToFloat (fromInteger x)
    
    ... ... @@ -457,6 +481,64 @@ cmmMachOpFoldM platform mop [x, (CmmLit (CmmInt n _w))]
    457 481
             x2 = if p == 1 then x1 else
    
    458 482
                  CmmMachOp (MO_And rep) [x1, CmmLit (CmmInt (n-1) rep)]
    
    459 483
     
    
    484
    +-- Many vector MachOps are simply element-wise scalar MachOps. For these, we reduce
    
    485
    +-- to the scalar case using 'vectorMachOpToScalarMachOp_maybe' and 'vectorElements_maybe'.
    
    486
    +
    
    487
    +-- Unary vector MachOps.
    
    488
    +cmmMachOpFoldM plat op [v]
    
    489
    +  | Just scalar_op <- vectorMachOpToScalarMachOp_maybe op
    
    490
    +  = do es <- vectorElements_maybe plat v
    
    491
    +       ls <- mapM (foldToLit plat scalar_op) es
    
    492
    +       Just $! CmmLit $ CmmVec ls
    
    493
    +
    
    494
    +  where foldToLit plat mop (Just a) = do
    
    495
    +          CmmLit l <- cmmMachOpFoldM plat mop [a]
    
    496
    +          return l
    
    497
    +        foldToLit _ _ _ = Nothing
    
    498
    +
    
    499
    +-- Binary vector MachOps.
    
    500
    +cmmMachOpFoldM plat op [v1, v2]
    
    501
    +  | Just scalar_op <- vectorMachOpToScalarMachOp_maybe op
    
    502
    +  = do
    
    503
    +      es1 <- vectorElements_maybe plat v1
    
    504
    +      es2 <- vectorElements_maybe plat v2
    
    505
    +      ls <- zipWithM (foldToLit plat scalar_op) es1 es2
    
    506
    +      Just $! CmmLit $ CmmVec ls
    
    507
    +  -- Integer MIN/MAX have no scalar MachOp equivalent (unlike MO_F_Min/MO_F_Max), so handle them manually.
    
    508
    +  | MO_VS_Max _ w <- op = do
    
    509
    +      es1 <- vectorElements_maybe plat v1
    
    510
    +      es2 <- vectorElements_maybe plat v2
    
    511
    +      ls <- zipWithM (foldOp (narrowS w) max) es1 es2
    
    512
    +      Just $! CmmLit $ CmmVec ls
    
    513
    +  | MO_VU_Max _ w <- op = do
    
    514
    +      es1 <- vectorElements_maybe plat v1
    
    515
    +      es2 <- vectorElements_maybe plat v2
    
    516
    +      ls <- zipWithM (foldOp (narrowU w) max) es1 es2
    
    517
    +      Just $! CmmLit $ CmmVec ls
    
    518
    +  | MO_VS_Min _ w <- op = do
    
    519
    +      es1 <- vectorElements_maybe plat v1
    
    520
    +      es2 <- vectorElements_maybe plat v2
    
    521
    +      ls <- zipWithM (foldOp (narrowS w) min) es1 es2
    
    522
    +      Just $! CmmLit $ CmmVec ls
    
    523
    +  | MO_VU_Min _ w <- op = do
    
    524
    +      es1 <- vectorElements_maybe plat v1
    
    525
    +      es2 <- vectorElements_maybe plat v2
    
    526
    +      ls <- zipWithM (foldOp (narrowU w) min) es1 es2
    
    527
    +      Just $! CmmLit $ CmmVec ls
    
    528
    +
    
    529
    +  where
    
    530
    +    foldToLit plat mop (Just a1) (Just a2) = do
    
    531
    +      CmmLit l <- cmmMachOpFoldM plat mop [a1, a2]
    
    532
    +      return l
    
    533
    +    foldToLit _ _ _ _  = Nothing
    
    534
    +
    
    535
    +    foldOp do_narrow op
    
    536
    +      (Just (CmmLit (CmmInt x rep)))
    
    537
    +      (Just (CmmLit (CmmInt y _)))
    
    538
    +        = Just $! CmmInt (do_narrow x `op` do_narrow y) rep
    
    539
    +    foldOp _ _ _ _ = Nothing
    
    540
    +
    
    541
    +
    
    460 542
     -- ToDo (#7116): optimise floating-point multiplication, e.g. x*2.0 -> x+x
    
    461 543
     -- Unfortunately this needs a unique supply because x might not be a
    
    462 544
     -- register.  See #2253 (program 6) for an example.
    
    ... ... @@ -473,6 +555,59 @@ validOffsetRep :: Width -> Bool
    473 555
     validOffsetRep rep = widthInBits rep <= finiteBitSize (undefined :: Int)
    
    474 556
     
    
    475 557
     
    
    558
    +-- Is this vector 'MachOp' an element-wise lift of a scalar 'MachOp'? If so, return
    
    559
    +-- the scalar 'MachOp' to apply to each element, at the same element width.
    
    560
    +vectorMachOpToScalarMachOp_maybe :: MachOp -> Maybe MachOp
    
    561
    +vectorMachOpToScalarMachOp_maybe (MO_VS_Neg _ w) = Just (MO_S_Neg w)
    
    562
    +vectorMachOpToScalarMachOp_maybe (MO_VF_Neg _ w) = Just (MO_F_Neg w)
    
    563
    +vectorMachOpToScalarMachOp_maybe (MO_V_Add  _ w) = Just (MO_Add w)
    
    564
    +vectorMachOpToScalarMachOp_maybe (MO_V_Sub  _ w) = Just (MO_Sub w)
    
    565
    +vectorMachOpToScalarMachOp_maybe (MO_V_Mul  _ w) = Just (MO_Mul w)
    
    566
    +vectorMachOpToScalarMachOp_maybe (MO_VF_Add _ w) = Just (MO_F_Add w)
    
    567
    +vectorMachOpToScalarMachOp_maybe (MO_VF_Sub _ w) = Just (MO_F_Sub w)
    
    568
    +vectorMachOpToScalarMachOp_maybe (MO_VF_Mul _ w) = Just (MO_F_Mul w)
    
    569
    +vectorMachOpToScalarMachOp_maybe (MO_VF_Min _ w) = Just (MO_F_Min w)
    
    570
    +vectorMachOpToScalarMachOp_maybe (MO_VF_Max _ w) = Just (MO_F_Max w)
    
    571
    +vectorMachOpToScalarMachOp_maybe (MO_V_And  _ w) = Just (MO_And w)
    
    572
    +vectorMachOpToScalarMachOp_maybe (MO_V_Or   _ w) = Just (MO_Or w)
    
    573
    +vectorMachOpToScalarMachOp_maybe (MO_V_Xor  _ w) = Just (MO_Xor w)
    
    574
    +-- Anything else (including integer vector min/max) is not treated as a lift.
    
    575
    +vectorMachOpToScalarMachOp_maybe _ = Nothing
    
    576
    +
    
    577
    +
    
    578
    +-- | What we statically know about the elements of a vector expression.
    
    579
    +-- Literal vectors and broadcasts of a literal are fully known; an insert at a literal
    
    580
    +-- index refines the known elements of its input vector. Returns 'Nothing' for non-vector
    
    581
    +-- expressions, and @[Nothing, Nothing, ...]@ for vectors whose elements are unknown.
    
    582
    +vectorElements_maybe :: Platform -> CmmExpr -> Maybe [Maybe CmmExpr]
    
    583
    +vectorElements_maybe _plat (CmmLit (CmmVec es)) = Just $! map (Just . CmmLit) es
    
    584
    +
    
    585
    +vectorElements_maybe _plat (CmmMachOp (MO_V_Broadcast l _) args)
    
    586
    +  | [CmmLit v] <- args = Just $! replicate l (Just $! CmmLit v)
    
    587
    +vectorElements_maybe _plat (CmmMachOp (MO_VF_Broadcast l _) args)
    
    588
    +  | [CmmLit v] <- args = Just $! replicate l (Just $! CmmLit v)
    
    589
    +
    
    590
    +vectorElements_maybe plat (CmmMachOp (MO_V_Insert _ _) args)
    
    591
    +  | [v, e, (CmmLit (CmmInt i _w))] <- args
    
    592
    +  , Just es <- vectorElements_maybe plat v
    
    593
    +    = Just $! (replaceAt (fromInteger i) (Just $! e) es)
    
    594
    +
    
    595
    +vectorElements_maybe plat (CmmMachOp (MO_VF_Insert _ _) args)
    
    596
    +  | [v, e, (CmmLit (CmmInt i _w))] <- args
    
    597
    +  , Just es <- vectorElements_maybe plat v
    
    598
    +    = Just $! (replaceAt (fromInteger i) (Just $! e) es)
    
    599
    +
    
    600
    +vectorElements_maybe plat (CmmMachOp mop args)
    
    601
    +  | isVecType result_type = Just $! replicate (vecLength result_type) Nothing
    
    602
    +  where result_type = machOpResultType plat mop (map (cmmExprType plat) args) -- pass the real arg types; machOpResultType may inspect them
    
    603
    +
    
    604
    +vectorElements_maybe _plat (CmmReg reg)
    
    605
    +  | isVecType reg_type = Just $! replicate (vecLength reg_type) Nothing
    
    606
    +  where reg_type = cmmRegType reg
    
    607
    +
    
    608
    +vectorElements_maybe _ _ = Nothing
    
    609
    +
    
    610
    +
    
    476 611
     {- Note [Comparison operators]
    
    477 612
     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    
    478 613
     If we have
    

  • compiler/GHC/Utils/Misc.hs
    ... ... @@ -56,7 +56,7 @@ module GHC.Utils.Misc (
    56 56
     
    
    57 57
             -- * List operations controlled by another list
    
    58 58
             takeList, dropList, splitAtList, split,
    
    59
    -        dropTail, capitalise,
    
    59
    +        replaceAt, dropTail, capitalise,
    
    60 60
     
    
    61 61
             -- * Sorting
    
    62 62
             sortWith, minWith, nubSort, ordNub, ordNubOn,
    
    ... ... @@ -718,6 +718,14 @@ splitAtList xs ys = go 0# xs ys
    718 718
           go n  []     bs     = (take (I# n) ys, bs) -- = splitAt n ys
    
    719 719
           go n  (_:as) (_:bs) = go (n +# 1#) as bs
    
    720 720
     
    
    721
    +-- | Replace the element at 0-based index @n@ of @xs@ with @y@ in one lazy pass; @xs@ is returned unchanged when @n@ is out of range.
    
    722
    +replaceAt :: Int -> a -> [a] -> [a]
    
    723
    +replaceAt n y xs = if n < 0 then xs else go n xs
    
    724
    +  where
    
    725
    +    go _ []     = []
    
    726
    +    go 0 (_:zs) = y : zs
    
    727
    +    go k (z:zs) = z : go (k - 1) zs
    
    728
    +
    
    721 729
     -- | drop from the end of a list
    
    722 730
     dropTail :: Int -> [a] -> [a]
    
    723 731
     -- Specification: dropTail n = reverse . drop n . reverse
    

  • testsuite/tests/simd/should_run/Makefile
    1
    +TOP=../../..
    
    2
    +include $(TOP)/mk/boilerplate.mk
    
    3
    +include $(TOP)/mk/test.mk
    
    4
    +
    
    5
    +T25030:
    
    6
    +	'$(TEST_HC)' $(TEST_HC_OPTS) T25030.hs -v0 -O1 -fforce-recomp -ddump-cmm > T25030.cmm 2>&1
    
    7
    +
    
    8
    +	# testFoldPlus: 111111+121212=232323, 121212+131313=252525 should be folded
    
    9
    +	grep -m 1 -o "232323" T25030.cmm
    
    10
    +	grep -m 1 -o "252525" T25030.cmm
    
    11
    +	# operands should not appear in the output
    
    12
    +	grep -o "111111" T25030.cmm || echo "Does not appear: 111111"
    
    13
    +	grep -o "121212" T25030.cmm || echo "Does not appear: 121212"
    
    14
    +	grep -o "131313" T25030.cmm || echo "Does not appear: 131313"
    
    15
    +
    
    16
    +	# testFoldMax: max(333333,333332)=333333 should be folded
    
    17
    +	grep -m 1 -o "333333" T25030.cmm
    
    18
    +	# lesser operand should not appear
    
    19
    +	grep -o "333332" T25030.cmm || echo "Does not appear: 333332"
    
    20
    +
    
    21
    +	# testNeg: negate(343434)=-343434 should be folded
    
    22
    +	grep -m 1 -o -- "-343434" T25030.cmm
    
    23
    +
    
    24
    +	# testInserts: insert 363636 into broadcast(353535) and extract it;
    
    25
    +	# should fold to constant 363636
    
    26
    +	grep -m 1 -o "363636" T25030.cmm
    
    27
    +	# broadcast operand should not appear
    
    28
    +	grep -o "353535" T25030.cmm || echo "Does not appear: 353535"
    
    29
    +
    
    30
    +	# testInserts2: 383838+393939=777777 should be folded
    
    31
    +	grep -m 1 -o "777777" T25030.cmm
    
    32
    +	# addends should not appear
    
    33
    +	grep -o "383838" T25030.cmm || echo "Does not appear: 383838"
    
    34
    +
    
    35
    +	# testOverwrite: inserting 404040,404041 into broadcast(414141) should fold to <404040,404041>
    
    36
    +	grep -m 1 -o "404040" T25030.cmm
    
    37
    +	grep -m 1 -o "404041" T25030.cmm
    
    38
    +	# original broadcast value should not appear
    
    39
    +	grep -o "414141" T25030.cmm || echo "Does not appear: 414141"
    
    40
    +
    
    41
    +	# testExtractFromInsert: extract(insert(unknown_v, 454545, 3), 3) should fold to 454545
    
    42
    +	grep -m 1 -o "454545" T25030.cmm
    
    43
    +
    
    44
    +	# testFoldMin: min(474747,474748)=474747 should be folded and the greater operand
    
    45
    +	# should not appear. Checked via exit status only, so T25030.stdout is unaffected.
    
    46
    +	grep -q "474747" T25030.cmm
    
    47
    +	! grep -q "474748" T25030.cmm

  • testsuite/tests/simd/should_run/T25030.hs
    1
    +{-# LANGUAGE MagicHash, UnboxedTuples, LexicalNegation, ExtendedLiterals #-}
    
    2
    +
    
    3
    +import GHC.Prim
    
    4
    +import GHC.Int
    
    5
    +
    
    6
    +-- Cmm constant folding tests for vector operations (#25030). The T25030 Makefile target compiles this module with -ddump-cmm and greps the output for the folded constants.
    
    7
    +
    
    8
    +data IntX2 = IX2# Int64X2#
    
    9
    +data IntX4 = IX4# Int32X4#
    
    10
    +
    
    11
    +instance Show IntX2 where
    
    12
    +  show (IX2# d) = case (unpackInt64X2# d) of
    
    13
    +    (# a, b #) -> show ((I64# a), (I64# b))
    
    14
    +
    
    15
    +instance Show IntX4 where
    
    16
    +  show (IX4# v) = case (unpackInt32X4# v) of
    
    17
    +    (# a, b, c, d #) -> show ((I32# a), (I32# b), (I32# c), (I32# d))
    
    18
    +
    
    19
    +testFoldPlus = do
    
    20
    +  let v1    = packInt64X2# (# 111111#Int64,  121212#Int64 #)
    
    21
    +  let v2    = packInt64X2# (# 121212#Int64,  131313#Int64 #)
    
    22
    +  print $ IX2# $ plusInt64X2# v1 v2 -- expect to see 232323 and 252525 here,
    
    23
    +                                    -- and not 111111, 121212, or 131313
    
    24
    +
    
    25
    +testFoldMax = do
    
    26
    +  let v1    = broadcastInt32X4# 333333#Int32
    
    27
    +  let v2    = broadcastInt32X4# 333332#Int32
    
    28
    +  print $ IX4# $ maxInt32X4# v1 v2 -- expect to see 333333 here and not 333332
    
    29
    +
    
    30
    +testFoldMin = do
    
    31
    +  let v1 = broadcastInt32X4# 474747#Int32
    
    32
    +  let v2 = broadcastInt32X4# 474748#Int32
    
    33
    +  print $ IX4# $ minInt32X4# v1 v2 -- expect to see 474747 here and not 474748
    
    34
    +
    
    35
    +testNeg = do
    
    36
    +  let v1 = broadcastInt32X4# 343434#Int32
    
    37
    +  print $ IX4# $ negateInt32X4# v1 -- expect to see -343434 here, not positive 343434
    
    38
    +
    
    39
    +
    
    40
    +testInserts = do
    
    41
    +  let v1 = broadcastInt32X4# 353535#Int32
    
    42
    +  let v2 = insertInt32X4# v1 363636#Int32 0#
    
    43
    +  let (# a, _, _, _ #) = unpackInt32X4# v2
    
    44
    +  print $ (I32# a) -- expect to see 363636 here, not 353535
    
    45
    +
    
    46
    +
    
    47
    +testInserts2 = do
    
    48
    +  let v1 = broadcastInt32X4# 373737#Int32
    
    49
    +  let v2 = insertInt32X4# v1 383838#Int32 0#
    
    50
    +  let v3 = plusInt32X4# v2 (broadcastInt32X4# 393939#Int32)
    
    51
    +  let (# a, _, _, _ #) = unpackInt32X4# v3
    
    52
    +  print $ (I32# a) -- expect to see 777777 == 383838+393939 here, and not 373737, 383838, or 393939
    
    53
    +
    
    54
    +{-# INLINE testOverwrite #-} -- presumably inlined so main's broadcast(414141) argument is known when folding
    
    55
    +testOverwrite :: Int64X2# -> IO ()
    
    56
    +testOverwrite v = do
    
    57
    +  let v1 = insertInt64X2# v 404040#Int64 0#
    
    58
    +  let v2 = insertInt64X2# v1 404041#Int64 1#
    
    59
    +  print $ IX2# v2 -- expect <404040, 404041> to appear in the cmm as a single assignment,
    
    60
    +                  -- rather than a series of inserts
    
    61
    +
    
    62
    +{-# NOINLINE testExtractFromInsert #-} -- keeps v a runtime value; only the inserted lane is a constant
    
    63
    +testExtractFromInsert :: Int32X4# -> IO ()
    
    64
    +testExtractFromInsert v = do
    
    65
    +  let v2 = insertInt32X4# v 454545#Int32 3#
    
    66
    +  let (# _, _, _, d #) = unpackInt32X4# v2
    
    67
    +  print (I32# d) -- 454545 should fold as a constant even though v is a runtime value
    
    68
    +
    
    69
    +
    
    70
    +main = do
    
    71
    +  testFoldPlus
    
    72
    +  testFoldMax
    
    73
    +  testFoldMin
    
    74
    +  testNeg
    
    75
    +  testInserts
    
    76
    +  testInserts2
    
    77
    +  testOverwrite (broadcastInt64X2# 414141#Int64)
    
    78
    +  testExtractFromInsert (broadcastInt32X4# 464646#Int32)
    
    79
    +

  • testsuite/tests/simd/should_run/T25030.stdout
    1
    +232323
    
    2
    +252525
    
    3
    +Does not appear: 111111
    
    4
    +Does not appear: 121212
    
    5
    +Does not appear: 131313
    
    6
    +333333
    
    7
    +333333
    
    8
    +333333
    
    9
    +Does not appear: 333332
    
    10
    +-343434
    
    11
    +-343434
    
    12
    +-343434
    
    13
    +363636
    
    14
    +Does not appear: 353535
    
    15
    +777777
    
    16
    +Does not appear: 383838
    
    17
    +404040
    
    18
    +404041
    
    19
    +Does not appear: 414141
    
    20
    +454545

  • testsuite/tests/simd/should_run/all.T
    ... ... @@ -49,6 +49,8 @@ test('int16x8_shuffle_baseline', [], compile_and_run, [''])
    49 49
     test('int32x4_shuffle_baseline', [], compile_and_run, [''])
    
    50 50
     test('int64x2_shuffle_baseline', [], compile_and_run, [''])
    
    51 51
     
    
    52
    +test('T25030', [when(arch('i386'), expect_broken_for(25498, ['optllvm']))], makefile_test, [])
    
    53
    +
    
    52 54
     test('T25658', [], compile_and_run, ['']) # #25658 is a bug with SSE2 code generation
    
    53 55
     test('T25659', [], compile_and_run, [''])
    
    54 56
     
    
    ... ... @@ -83,6 +85,7 @@ test('simd007', [], compile_and_run, [''])
    83 85
     test('simd008', [], compile_and_run, [''])
    
    84 86
     test('simd009', [ req_th
    
    85 87
                     , extra_files(['Simd009b.hs', 'Simd009c.hs'])
    
    88
    +                , when(arch('i386'), expect_broken_for(25498, ['optllvm']))
    
    86 89
                     ]
    
    87 90
                   , multimod_compile_and_run, ['simd009', ''])
    
    88 91
     test('simd010', [], compile_and_run, [''])
    
    ... ... @@ -174,7 +177,7 @@ test('T25062_V64'
    174 177
         , compile_and_run if have_cpu_feature('avx512f') else compile
    
    175 178
         , [''])
    
    176 179
     
    
    177
    -test('T25169', [], compile_and_run, [''])
    
    180
    +test('T25169', [when(arch('i386'), expect_broken_for(25498, ['optllvm']))], compile_and_run, [''])
    
    178 181
     test('T25455', [], compile_and_run, [''])
    
    179 182
     test('T25486', [], compile_and_run, [''])
    
    180 183