Sven Tennie pushed to branch wip/supersven/riscv-vectors at Glasgow Haskell Compiler / GHC

Commits:

5 changed files:

  • compiler/GHC/CmmToAsm/RV64/CodeGen.hs
  • compiler/GHC/CmmToAsm/RV64/Ppr.hs
  • testsuite/driver/testlib.py
  • testsuite/tests/primops/should_run/all.T
  • testsuite/tests/simd/should_run/all.T

Changes:

  • compiler/GHC/CmmToAsm/RV64/CodeGen.hs

    @@ -12,6 +12,7 @@ module GHC.CmmToAsm.RV64.CodeGen
       where
     
     import Control.Monad
    +import Data.Bifunctor (bimap)
     import Data.Maybe
     import Data.Word
     import GHC.Cmm
    @@ -55,7 +56,6 @@ import GHC.Types.SrcLoc (srcSpanFile, srcSpanStartCol, srcSpanStartLine)
     import GHC.Types.Tickish (GenTickish (..))
     import GHC.Types.Unique.DSM
     import GHC.Utils.Constants (debugIsOn)
    -import GHC.Utils.Misc
     import GHC.Utils.Monad
     import GHC.Utils.Outputable
     import GHC.Utils.Panic
    @@ -266,7 +266,6 @@ annExpr e {- debugIsOn -} = ANN (text . show $ e)
     -- This seems to be PIC compatible; at least `scanelf` (pax-utils) does not
     -- complain.
     
    -
     -- | Generate jump to jump table target
     --
     -- The index into the jump table is calulated by evaluating @expr@. The
    @@ -423,22 +422,22 @@ getRegisterReg platform (CmmGlobal mid) =
     -- General things for putting together code sequences
     
     -- | Compute an expression into any register
    -getSomeReg :: HasCallStack => CmmExpr -> NatM (Reg, Format, InstrBlock)
    +getSomeReg :: CmmExpr -> NatM (Reg, Format, InstrBlock)
     getSomeReg expr = do
       r <- getRegister expr
       res@(reg, fmt, _) <- case r of
    -        Any rep code -> do
    -          newReg <- getNewRegNat rep
    -          pure (newReg, rep, code newReg)
    -        Fixed rep reg code ->
    -          pure (reg, rep, code)
    +    Any rep code -> do
    +      newReg <- getNewRegNat rep
    +      pure (newReg, rep, code newReg)
    +    Fixed rep reg code ->
    +      pure (reg, rep, code)
       pure $ assertFmtReg fmt reg res
     
     -- | Compute an expression into any floating-point register
     --
     -- If the initial expression is not a floating-point expression, finally move
     -- the result into a floating-point register.
    -getFloatReg :: (HasCallStack) => CmmExpr -> NatM (Reg, Format, InstrBlock)
    +getFloatReg :: CmmExpr -> NatM (Reg, Format, InstrBlock)
     getFloatReg expr = do
       r <- getRegister expr
       case r of
    @@ -866,13 +865,10 @@ getRegister' config plat expr =
             MO_AlignmentCheck align wordWidth -> do
               reg <- getRegister' config plat e
               addAlignmentCheck align wordWidth reg
    -
             MO_V_Broadcast length w -> vectorBroadcast (intVecFormat length w) e
             MO_VF_Broadcast length w -> vectorBroadcast (floatVecFormat length w) e
    -
             MO_VS_Neg length w -> vectorNegation (intVecFormat length w)
             MO_VF_Neg length w -> vectorNegation (floatVecFormat length w)
    -
             x -> pprPanic ("getRegister' (monadic CmmMachOp): " ++ show x) (pdoc plat expr)
           where
             -- In the case of 16- or 8-bit values we need to sign-extend to 32-bits
    @@ -1236,9 +1232,14 @@ getRegister' config plat expr =
              genericVectorShuffle :: (Int -> Width -> Format) -> Int -> Width -> [Int] -> NatM Register
              genericVectorShuffle toDstFormat length w idxs = do
                -- Our strategy:
    -            --   - Gather elemens of v1 on the right positions
    -            --   - Gather elemenrs of v2 of the right positions
    -            --   - Merge v1 and v2 with an adequate bitmask (v0)
    +            --   - Gather elements of v1 on the right positions -> v1'
    +            --   - Gather elements of v2 of the right positions -> v2'
    +            --   - Merge v1' and v2' with an adequate bitmask (v0)
    +            --
    +            -- We create three supporting data section entries / structures:
    +            --   - A mapping vector that describes the mapping of v1 -> v1'
    +            --   - The same for v2 -> v2'
    +            --   - the mask vector used to merge v1' and v2'
                lbl_selVec_v1 <- getNewLabelNat
                lbl_selVec_v2 <- getNewLabelNat
                lbl_mask <- getNewLabelNat
    @@ -1247,7 +1248,7 @@ getRegister' config plat expr =
                (reg_y, format_y, code_y) <- getSomeReg y
     
                let (idxs_v1, idxs_v2) =
    -                  mapTuple reverse
    +                  bimap reverse reverse
                        $ foldl'
                          ( \(acc1, acc2) i ->
                              if i < length then (Just i : acc1, Nothing : acc2) else (Nothing : acc1, Just (i - length) : acc2)
    @@ -1259,7 +1260,6 @@ getRegister' config plat expr =
                    -- Finally, the mask must be 0 where v1 should be taken and 1 for v2.
                    -- That's why we do an implicit negation here by focussing on v2.
                    maskVecData = maskData idxs_v2
    -                -- Longest vector length is 64, so 8bit (0 - 255) is sufficient
                    selVecFormat = intVecFormat length w
                    dstFormat = toDstFormat length w
                    addrFormat = intFormat W64
    @@ -1301,9 +1301,6 @@ getRegister' config plat expr =
                        VMERGE (OpReg dstFormat dst) (OpReg format_x gathered_x) (OpReg format_y gathered_y) (OpReg maskFormat v0Reg)
                      ]
     
    -          mapTuple :: (a -> b) -> (a, a) -> (b, b)
    -          mapTuple f (x, y) = (f x, f y)
    -
              selVecData :: Width -> [Maybe Int] -> [CmmStatic]
              selVecData w idxs =
                -- Using the width `w` here is a bit wasteful. But, it saves
    @@ -1376,7 +1373,6 @@ getRegister' config plat expr =
             MO_Shl w -> intOp False w (\d x y -> unitOL $ annExpr expr (SLL d x y))
             MO_U_Shr w -> intOp False w (\d x y -> unitOL $ annExpr expr (SRL d x y))
             MO_S_Shr w -> intOp True w (\d x y -> unitOL $ annExpr expr (SRA d x y))
    -
             -- Vector operations
             MO_VF_Extract _length w -> vecExtract ((scalarFormatFormat . floatScalarFormat) w)
             MO_V_Extract _length w -> vecExtract ((scalarFormatFormat . intScalarFormat) w)
    @@ -1404,12 +1400,23 @@ getRegister' config plat expr =
         -- Generic ternary case.
         CmmMachOp op [x, y, z] ->
           case op of
    -        -- Floating-point fused multiply-add operations
    +        -- Floating-point fused multiply-add operations:
    +        --
    +        -- x86 fmadd    x * y + z <=> RISCV64 fmadd :  d =   r1 * r2 + r3
    +        -- x86 fmsub    x * y - z <=> RISCV64 fmsub:   d =   r1 * r2 - r3
    +        -- x86 fnmadd - x * y + z <=> RISCV64 fnmsub:  d = - r1 * r2 + r3
    +        -- x86 fnmsub - x * y - z <=> RISCV64 fnmadd:  d = - r1 * r2 - r3
             --
    -        -- x86 fmadd    x * y + z <=> RISCV64 fmadd : d =   r1 * r2 + r3
    -        -- x86 fmsub    x * y - z <=> RISCV64 fnmsub: d =   r1 * r2 - r3
    -        -- x86 fnmadd - x * y + z <=> RISCV64 fmsub : d = - r1 * r2 + r3
    -        -- x86 fnmsub - x * y - z <=> RISCV64 fnmadd: d = - r1 * r2 - r3
    +        -- Vector fused multiply-add operations (what x86 exactly does doesn't
    +        -- matter here, we care about the abstract spec):
    +        --
    +        -- FMAdd    x * y + z <=> RISCV64 vfmadd :  d =   r1 * r2 + r3
    +        -- FMSub    x * y - z <=> RISCV64 vfmsub:   d =   r1 * r2 - r3
    +        -- FNMAdd - x * y + z <=> RISCV64 vfnmsub:  d = - r1 * r2 + r3
    +        -- FNMSub - x * y - z <=> RISCV64 vfnmadd:  d = - r1 * r2 - r3
    +        --
    +        -- For both formats, the instruction selection happens in the
    +        -- pretty-printer.
             MO_FMA var length w
               | length == 1 ->
                   float3Op w (\d n m a -> unitOL $ FMA var d n m a)
    @@ -1418,12 +1425,10 @@ getRegister' config plat expr =
                   (reg_y, format_y, code_y) <- getSomeReg y
                   (reg_z, format_z, code_z) <- getSomeReg z
                   let targetFormat = VecFormat length (floatScalarFormat w)
    -                  negate_z = if var `elem` [FNMAdd, FNMSub] then unitOL (VNEG (OpReg format_z reg_z) (OpReg format_z reg_z)) else nilOL
                   pure $ Any targetFormat $ \dst ->
                     code_x
                       `appOL` code_y
                       `appOL` code_z
    -                  `appOL` negate_z
                       `snocOL` annExpr
                         expr
                         (VMV (OpReg targetFormat dst) (OpReg format_x reg_x))
    @@ -2138,8 +2143,7 @@ genCCall target@(ForeignTarget expr _cconv) dest_regs arg_regs = do
         -- See Note [RISC-V vector C calling convention]
         passArguments _gpRegs _fpRegs [] ((_r, format, _hint, _code_r) : _args) _stackSpaceWords _accumRegs _accumCode
           | isVecFormat format =
    -      panic "C call: no free vector argument registers. We only support 16 vector arguments (registers v8 - v23)."
    -
    +          panic "C call: no free vector argument registers. We only support 16 vector arguments (registers v8 - v23)."
         passArguments _ _ _ _ _ _ _ = pprPanic "passArguments" (text "invalid state")
     
         readResults :: [Reg] -> [Reg] -> [Reg] -> [LocalReg] -> [Reg] -> InstrBlock -> NatM InstrBlock

  • compiler/GHC/CmmToAsm/RV64/Ppr.hs

    @@ -15,6 +15,7 @@ import GHC.CmmToAsm.RV64.Instr
     import GHC.CmmToAsm.RV64.Regs
     import GHC.CmmToAsm.Types
     import GHC.CmmToAsm.Utils
    +import GHC.Data.OrdList
     import GHC.Platform
     import GHC.Platform.Reg
     import GHC.Prelude hiding (EQ)
    @@ -23,7 +24,6 @@ import GHC.Types.Basic (Alignment, alignmentBytes, mkAlignment)
     import GHC.Types.Unique (getUnique, pprUniqueAlways)
     import GHC.Utils.Outputable
     import GHC.Utils.Panic
    -import GHC.Data.OrdList
     
     pprNatCmmDecl :: forall doc. (IsDoc doc) => NCGConfig -> NatCmmDecl RawCmmStatics Instr -> doc
     pprNatCmmDecl config (CmmData section dats) =
    @@ -804,8 +804,8 @@ pprInstr platform instr = case instr of
            let fma = case variant of
                  FMAdd -> text "\tfmadd" <> dot <> floatPrecission d
                  FMSub -> text "\tfmsub" <> dot <> floatPrecission d
    -              FNMAdd -> text "\tfnmadd" <> dot <> floatPrecission d
    -              FNMSub -> text "\tfnmsub" <> dot <> floatPrecission d
    +              FNMAdd -> text "\tfnmsub" <> dot <> floatPrecission d
    +              FNMSub -> text "\tfnmadd" <> dot <> floatPrecission d
             in op4 fma d r1 r2 r3
      VFMA variant o1@(OpReg fmt _reg) o2 o3
        | VecFormat _l fmt' <- fmt ->
    @@ -815,8 +815,8 @@ pprInstr platform instr = case instr of
                fma = case variant of
                  FMAdd -> text "madd"
                  FMSub -> text "msub" -- TODO: Works only for floats!
    -              FNMAdd -> text "nmadd" -- TODO: Works only for floats!
    -              FNMSub -> text "nmsub"
    +              FNMAdd -> text "nmsub" -- TODO: Works only for floats!
    +              FNMSub -> text "nmadd"
             in op3 (tab <> prefix <> fma <> dot <> suffix) o1 o2 o3
      VFMA _variant o1 _o2 _o3 -> pprPanic "RV64.pprInstr - VFMA can only target registers." (pprOp platform o1)
      VMV o1@(OpReg fmt _reg) o2
    @@ -830,7 +830,7 @@ pprInstr platform instr = case instr of
      VMV o1 o2 -> pprPanic "RV64.pprInstr - invalid VMV instruction" (text "VMV" <+> pprOp platform o1 <> comma <+> pprOp platform o2)
      VID op | isVectorRegOp op -> op1 (text "\tvid.v") op
      VID op -> pprPanic "RV64.pprInstr - VID can only target registers." (pprOp platform op)
    -  VMSEQ o1 o2 o3 | allVectorRegOps [o1, o2] && isIntOp o3 && isImmOp o3-> op3 (text "\tvmseq.vi") o1 o2 o3
    +  VMSEQ o1 o2 o3 | allVectorRegOps [o1, o2] && isIntOp o3 && isImmOp o3 -> op3 (text "\tvmseq.vi") o1 o2 o3
      VMSEQ o1 o2 o3 | allVectorRegOps [o1, o2] && isIntOp o3 -> op3 (text "\tvmseq.vx") o1 o2 o3
      VMSEQ o1 o2 o3 -> pprPanic "RV64.pprInstr - VMSEQ wrong operands." (pprOps platform [o1, o2, o3])
      VMERGE o1 o2 o3 o4 | allVectorRegOps [o1, o2, o3, o4] -> op4 (text "\tvmerge.vvm") o1 o2 o3 o4
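
    For context on the fmadd/fnmsub swap above: the mnemonics chosen for
    FNMAdd and FNMSub look reversed, but they match the ISA. RISC-V's fnmsub
    computes -(r1 * r2) + r3 and fnmadd computes -(r1 * r2) - r3, while GHC
    (following the x86-style naming) calls -x*y + z FNMAdd and -x*y - z
    FNMSub. A standalone sketch of the intended mapping follows; the
    FMAVariant type and function name here are illustrative only, not GHC's
    actual definitions.

        -- Illustrative sketch, not GHC source: the variant-to-mnemonic mapping
        -- the pretty-printer emits after this change. The vector forms
        -- (vfmadd.vv, vfnmsub.vv, ...) follow the same sign pattern.
        data FMAVariant = FMAdd | FMSub | FNMAdd | FNMSub

        scalarFmaMnemonic :: FMAVariant -> String
        scalarFmaMnemonic v = case v of
          FMAdd  -> "fmadd"   -- d =   x * y + z
          FMSub  -> "fmsub"   -- d =   x * y - z
          FNMAdd -> "fnmsub"  -- d = - x * y + z (fnmsub negates only the product)
          FNMSub -> "fnmadd"  -- d = - x * y - z (fnmadd negates the entire x * y + z)

    This is also why the CodeGen.hs hunk above can drop the explicit VNEG
    workaround for the vector FNMAdd/FNMSub cases: the negation is encoded in
    the selected mnemonic instead of being patched in with an extra
    instruction.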
    

  • testsuite/driver/testlib.py

    @@ -426,7 +426,8 @@ def req_fma_cpu( name, opts ):
     
         # RISC-V: We imply float and double extensions (rv64g), so we only have to
         # check for vector support.
    -    if not(have_cpu_feature('avx') or have_cpu_feature('zvl128b')):
    +    # AArch64: Always expect FMA support.
    +    if not (have_cpu_feature('avx') or arch('aarch64') or have_cpu_feature('zvl128b')):
             opts.skip = True
     
     def ignore_stdout(name, opts):
    

  • testsuite/tests/primops/should_run/all.T

    @@ -63,16 +63,12 @@ test('UnliftedTVar2', normal, compile_and_run, [''])
     test('UnliftedWeakPtr', normal, compile_and_run, [''])
     
     test('FMA_Primops'
    -    , [ when(have_cpu_feature('fma'), extra_hc_opts('-mfma'))
    -      , js_skip # JS backend doesn't have an FMA implementation
    -      , when(arch('wasm32'), skip)
    +    , [ req_fma_cpu, extra_hc_opts('-mfma')
           , when(have_llvm(), extra_ways(["optllvm"]))
           ]
          , compile_and_run, [''])
     test('FMA_ConstantFold'
    -    , [ when(have_cpu_feature('fma'), extra_hc_opts('-mfma'))
    -      , js_skip # JS backend doesn't have an FMA implementation
    -      , when(arch('wasm32'), skip)
    +    , [ req_fma_cpu, extra_hc_opts('-mfma')
           , expect_broken(21227)
           , when(have_llvm(), extra_ways(["optllvm"]))
           ]
    @@ -85,9 +81,7 @@ test('T23071',
          [''])
     test('T22710', normal, compile_and_run, [''])
     test('T24496'
    -    , [ when(have_cpu_feature('fma'), extra_hc_opts('-mfma'))
    -      , js_skip # JS backend doesn't have an FMA implementation
    -      , when(arch('wasm32'), skip)
    +    , [ req_fma_cpu, extra_hc_opts('-mfma')
           , when(have_llvm(), extra_ways(["optllvm"]))
           ]
         , compile_and_run, ['-O'])
    

  • testsuite/tests/simd/should_run/all.T

    @@ -123,10 +123,10 @@ test('word32x4_basic', [], compile_and_run, [''])
     test('word64x2_basic', [], compile_and_run, [''])
     test('floatx4_arith', [], compile_and_run, [''])
     test('doublex2_arith', [], compile_and_run, [''])
    -test('floatx4_fma', [ unless(have_cpu_feature('fma'), skip)
    +test('floatx4_fma', [ req_fma_cpu
                         , extra_hc_opts('-mfma')
                         ], compile_and_run, [''])
    -test('doublex2_fma', [ unless(have_cpu_feature('fma'), skip)
    +test('doublex2_fma', [ req_fma_cpu
                          , extra_hc_opts('-mfma')
                          ], compile_and_run, [''])
     test('int8x16_arith', [], compile_and_run, [''])