Sven Tennie pushed to branch wip/supersven/riscv-vectors at Glasgow Haskell Compiler / GHC

Commits:

8 changed files:

Changes:

  • compiler/GHC/CmmToAsm/RV64/CodeGen.hs
    ... ... @@ -1400,12 +1400,23 @@ getRegister' config plat expr =
    1400 1400
         -- Generic ternary case.
    
    1401 1401
         CmmMachOp op [x, y, z] ->
    
    1402 1402
           case op of
    
    1403
    -        -- Floating-point fused multiply-add operations
    
    1403
    +        -- Floating-point fused multiply-add operations:
    
    1404 1404
             --
    
    1405
    -        -- x86 fmadd    x * y + z <=> RISCV64 fmadd : d =   r1 * r2 + r3
    
    1406
    -        -- x86 fmsub    x * y - z <=> RISCV64 fnmsub: d =   r1 * r2 - r3
    
    1407
    -        -- x86 fnmadd - x * y + z <=> RISCV64 fmsub : d = - r1 * r2 + r3
    
    1408
    -        -- x86 fnmsub - x * y - z <=> RISCV64 fnmadd: d = - r1 * r2 - r3
    
    1405
    +        -- x86 fmadd    x * y + z <=> RISCV64 fmadd :  d =   r1 * r2 + r3
    
    1406
    +        -- x86 fmsub    x * y - z <=> RISCV64 fmsub:   d =   r1 * r2 - r3
    
    1407
    +        -- x86 fnmadd - x * y + z <=> RISCV64 fnmsub:  d = - r1 * r2 + r3
    
    1408
    +        -- x86 fnmsub - x * y - z <=> RISCV64 fnmadd:  d = - r1 * r2 - r3
    
    1409
    +        --
    
    1410
    +        -- Vector fused multiply-add operations (exactly what x86 does doesn't
    
    1411
    +        -- matter here; we care about the abstract spec):
    
    1412
    +        --
    
    1413
    +        -- FMAdd    x * y + z <=> RISCV64 vfmadd :  d =   r1 * r2 + r3
    
    1414
    +        -- FMSub    x * y - z <=> RISCV64 vfmsub:   d =   r1 * r2 - r3
    
    1415
    +        -- FNMAdd - x * y + z <=> RISCV64 vfnmsub:  d = - r1 * r2 + r3
    
    1416
    +        -- FNMSub - x * y - z <=> RISCV64 vfnmadd:  d = - r1 * r2 - r3
    
    1417
    +        --
    
    1418
    +        -- For both formats, the instruction selection happens in the
    
    1419
    +        -- pretty-printer.
    
    1409 1420
             MO_FMA var length w
    
    1410 1421
               | length == 1 ->
    
    1411 1422
                   float3Op w (\d n m a -> unitOL $ FMA var d n m a)
    
    ... ... @@ -1414,12 +1425,10 @@ getRegister' config plat expr =
    1414 1425
                   (reg_y, format_y, code_y) <- getSomeReg y
    
    1415 1426
                   (reg_z, format_z, code_z) <- getSomeReg z
    
    1416 1427
                   let targetFormat = VecFormat length (floatScalarFormat w)
    
    1417
    -                  negate_z = if var `elem` [FNMAdd, FNMSub] then unitOL (VNEG (OpReg format_z reg_z) (OpReg format_z reg_z)) else nilOL
    
    1418 1428
                   pure $ Any targetFormat $ \dst ->
    
    1419 1429
                     code_x
    
    1420 1430
                       `appOL` code_y
    
    1421 1431
                       `appOL` code_z
    
    1422
    -                  `appOL` negate_z
    
    1423 1432
                       `snocOL` annExpr
    
    1424 1433
                         expr
    
    1425 1434
                         (VMV (OpReg targetFormat dst) (OpReg format_x reg_x))
    

  • compiler/GHC/CmmToAsm/RV64/Ppr.hs
    ... ... @@ -804,8 +804,8 @@ pprInstr platform instr = case instr of
    804 804
             let fma = case variant of
    
    805 805
                   FMAdd -> text "\tfmadd" <> dot <> floatPrecission d
    
    806 806
                   FMSub -> text "\tfmsub" <> dot <> floatPrecission d
    
    807
    -              FNMAdd -> text "\tfnmadd" <> dot <> floatPrecission d
    
    808
    -              FNMSub -> text "\tfnmsub" <> dot <> floatPrecission d
    
    807
    +              FNMAdd -> text "\tfnmsub" <> dot <> floatPrecission d
    
    808
    +              FNMSub -> text "\tfnmadd" <> dot <> floatPrecission d
    
    809 809
              in op4 fma d r1 r2 r3
    
    810 810
       VFMA variant o1@(OpReg fmt _reg) o2 o3
    
    811 811
         | VecFormat _l fmt' <- fmt ->
    
    ... ... @@ -815,8 +815,8 @@ pprInstr platform instr = case instr of
    815 815
                 fma = case variant of
    
    816 816
                   FMAdd -> text "madd"
    
    817 817
                   FMSub -> text "msub" -- TODO: Works only for floats!
    
    818
    -              FNMAdd -> text "nmadd" -- TODO: Works only for floats!
    
    819
    -              FNMSub -> text "nmsub"
    
    818
    +              FNMAdd -> text "nmsub" -- TODO: Works only for floats!
    
    819
    +              FNMSub -> text "nmadd"
    
    820 820
              in op3 (tab <> prefix <> fma <> dot <> suffix) o1 o2 o3
    
    821 821
       VFMA _variant o1 _o2 _o3 -> pprPanic "RV64.pprInstr - VFMA can only target registers." (pprOp platform o1)
    
    822 822
       VMV o1@(OpReg fmt _reg) o2
    

  • configure.ac
    ... ... @@ -612,9 +612,10 @@ AC_SYS_INTERPRETER()
    612 612
     
    
    613 613
     dnl ** look for GCC and find out which version
    
    614 614
     dnl     Figure out which C compiler to use.  Gcc is preferred.
    
    615
    -dnl     If gcc, make sure it's at least 4.7
    
    615
    +dnl     If gcc, make sure it's at least 4.7 (14 for 64-bit RISC-V)
    
    616 616
     dnl
    
    617 617
     FP_GCC_VERSION
    
    618
    +FP_RISCV_CHECK_GCC_VERSION
    
    618 619
     
    
    619 620
     
    
    620 621
     dnl ** Check support for the extra flags passed by GHC when compiling via C
    

  • distrib/configure.ac.in
    ... ... @@ -225,6 +225,7 @@ dnl ** Check gcc version and flags we need to pass it **
    225 225
     FP_GCC_VERSION
    
    226 226
     FP_GCC_SUPPORTS_NO_PIE
    
    227 227
     FP_GCC_SUPPORTS_VIA_C_FLAGS
    
    228
    +FP_RISCV_CHECK_GCC_VERSION
    
    228 229
     
    
    229 230
     FPTOOLS_SET_C_LD_FLAGS([target],[CFLAGS],[LDFLAGS],[IGNORE_LINKER_LD_FLAGS],[CPPFLAGS])
    
    230 231
     FPTOOLS_SET_C_LD_FLAGS([build],[CONF_CC_OPTS_STAGE0],[CONF_GCC_LINKER_OPTS_STAGE0],[CONF_LD_LINKER_OPTS_STAGE0],[CONF_CPP_OPTS_STAGE0])
    

  • m4/fp_riscv_check_gcc_version.m4
    1
    +# FP_RISCV_CHECK_GCC_VERSION
    
    2
    +#
    
    3
    +# We cannot use every GCC version that is generally supported: up to
    
    4
    +# and including GCC 13, GCC does not support the expected C calling convention
    
    5
    +# for vectors. Thus, we require at least GCC 14.
    
    6
    +#
    
    7
    +# Details: GCC 13 expects vector arguments to be passed on the stack / by
    
    8
    +# reference, though the "Standard Vector Calling Convention Variant"
    
    9
    +# (https://github.com/riscv-non-isa/riscv-elf-psabi-doc/blob/master/riscv-cc.adoc#standard-vector-calling-convention-variant)
    
    10
    +# - which is the new default (e.g. for GCC 14) - expects vector arguments in
    
    11
    +# registers v8 to v23. Presumably, this is due to the "Standard Vector Calling
    
    12
    +# Convention Variant" being pretty new: the GCC implementers had to make
    
    13
    +# design decisions before this part of the standard was ratified.
    
    14
    +# As long as the calling convention is consistently used for all code, this
    
    15
    +# isn't an issue. However, we have to be able to call C functions compiled by GCC
    
    16
    +# with code emitted by GHC.
    
    17
    +
    
    18
    +AC_DEFUN([FP_RISCV_CHECK_GCC_VERSION], [
    
    19
    +  AC_REQUIRE([FP_GCC_VERSION])
    
    20
    +  AC_REQUIRE([AC_CANONICAL_TARGET])
    
    21
    +  
    
    22
    +  # Check if target is RISC-V
    
    23
    +  case "$target" in
    
    24
    +    riscv64*-*-*)
    
    25
    +      AC_MSG_NOTICE([Assert GCC version for RISC-V. Detected version is $GccVersion])
    
    26
    +      if test -n "$GccVersion"; then
    
    27
    +        AC_CACHE_CHECK([risc-v version of gcc], [fp_cv_riscv_check_gcc_version], [
    
    28
    +            FP_COMPARE_VERSIONS([$GccVersion], [-lt], [14.0],
    
    29
    +                                [AC_MSG_ERROR([Need at least GCC version 14 for RISC-V])],
    
    30
    +                                [AC_MSG_RESULT([good])]
    
    31
    +                                )
    
    32
    +        ])
    
    33
    +      fi
    
    34
    +      ;;
    
    35
    +    # Ignore riscv32*-*-* as we don't have an NCG for 32-bit RISC-V targets
    
    36
    +  esac
    
    37
    +])

  • testsuite/driver/testlib.py
    ... ... @@ -426,7 +426,8 @@ def req_fma_cpu( name, opts ):
    426 426
     
    
    427 427
         # RISC-V: We imply float and double extensions (rv64g), so we only have to
    
    428 428
         # check for vector support.
    
    429
    -    if not(have_cpu_feature('avx') or have_cpu_feature('zvl128b')):
    
    429
    +    # AArch64: Always expect FMA support.
    
    430
    +    if not (have_cpu_feature('avx') or arch('aarch64') or have_cpu_feature('zvl128b')):
    
    430 431
             opts.skip = True
    
    431 432
     
    
    432 433
     def ignore_stdout(name, opts):
    

  • testsuite/tests/primops/should_run/all.T
    ... ... @@ -63,16 +63,12 @@ test('UnliftedTVar2', normal, compile_and_run, [''])
    63 63
     test('UnliftedWeakPtr', normal, compile_and_run, [''])
    
    64 64
     
    
    65 65
     test('FMA_Primops'
    
    66
    -    , [ when(have_cpu_feature('fma'), extra_hc_opts('-mfma'))
    
    67
    -      , js_skip # JS backend doesn't have an FMA implementation
    
    68
    -      , when(arch('wasm32'), skip)
    
    66
    +    , [ req_fma_cpu, extra_hc_opts('-mfma')
    
    69 67
           , when(have_llvm(), extra_ways(["optllvm"]))
    
    70 68
           ]
    
    71 69
          , compile_and_run, [''])
    
    72 70
     test('FMA_ConstantFold'
    
    73
    -    , [ when(have_cpu_feature('fma'), extra_hc_opts('-mfma'))
    
    74
    -      , js_skip # JS backend doesn't have an FMA implementation
    
    75
    -      , when(arch('wasm32'), skip)
    
    71
    +    , [ req_fma_cpu, extra_hc_opts('-mfma')
    
    76 72
           , expect_broken(21227)
    
    77 73
           , when(have_llvm(), extra_ways(["optllvm"]))
    
    78 74
           ]
    
    ... ... @@ -85,9 +81,7 @@ test('T23071',
    85 81
          [''])
    
    86 82
     test('T22710', normal, compile_and_run, [''])
    
    87 83
     test('T24496'
    
    88
    -    , [ when(have_cpu_feature('fma'), extra_hc_opts('-mfma'))
    
    89
    -      , js_skip # JS backend doesn't have an FMA implementation
    
    90
    -      , when(arch('wasm32'), skip)
    
    84
    +    , [ req_fma_cpu, extra_hc_opts('-mfma')
    
    91 85
           , when(have_llvm(), extra_ways(["optllvm"]))
    
    92 86
           ]
    
    93 87
         , compile_and_run, ['-O'])
    

  • testsuite/tests/simd/should_run/all.T
    ... ... @@ -127,10 +127,10 @@ test('floatx4_arith', [], compile_and_run, [''])
    127 127
     test('doublex2_arith', [], compile_and_run, [''])
    
    128 128
     test('floatx4_shuffle', [], compile_and_run, [''])
    
    129 129
     test('doublex2_shuffle', [], compile_and_run, [''])
    
    130
    -test('floatx4_fma', [ unless(have_cpu_feature('fma'), skip)
    
    130
    +test('floatx4_fma', [ req_fma_cpu
    
    131 131
                         , extra_hc_opts('-mfma')
    
    132 132
                         ], compile_and_run, [''])
    
    133
    -test('doublex2_fma', [ unless(have_cpu_feature('fma'), skip)
    
    133
    +test('doublex2_fma', [ req_fma_cpu
    
    134 134
                          , extra_hc_opts('-mfma')
    
    135 135
                          ], compile_and_run, [''])
    
    136 136
     test('int8x16_arith', [], compile_and_run, [''])