[Git][ghc/ghc][wip/supersven/riscv-vectors] 7 commits: Use BiFunctor
Sven Tennie pushed to branch wip/supersven/riscv-vectors at Glasgow Haskell Compiler / GHC Commits: de8ee118 by Sven Tennie at 2025-07-24T11:28:23+02:00 Use BiFunctor - - - - - 1f5173f7 by Sven Tennie at 2025-07-24T11:28:41+02:00 Cleanup HasCallstack annotations - - - - - 01b50257 by Sven Tennie at 2025-07-24T11:28:49+02:00 Better comment - - - - - 77a15fb0 by Sven Tennie at 2025-07-25T18:09:13+02:00 Formatting - - - - - 995b22a4 by Sven Tennie at 2025-07-25T18:09:34+02:00 Overhaul comments for shuffle - - - - - fe090b7b by Sven Tennie at 2025-07-26T17:50:08+02:00 Test and fix (V)FMA - - - - - 9e4bd9d4 by Sven Tennie at 2025-07-26T17:50:18+02:00 Formatting - - - - - 5 changed files: - compiler/GHC/CmmToAsm/RV64/CodeGen.hs - compiler/GHC/CmmToAsm/RV64/Ppr.hs - testsuite/driver/testlib.py - testsuite/tests/primops/should_run/all.T - testsuite/tests/simd/should_run/all.T Changes: ===================================== compiler/GHC/CmmToAsm/RV64/CodeGen.hs ===================================== @@ -12,6 +12,7 @@ module GHC.CmmToAsm.RV64.CodeGen where import Control.Monad +import Data.Bifunctor (bimap) import Data.Maybe import Data.Word import GHC.Cmm @@ -55,7 +56,6 @@ import GHC.Types.SrcLoc (srcSpanFile, srcSpanStartCol, srcSpanStartLine) import GHC.Types.Tickish (GenTickish (..)) import GHC.Types.Unique.DSM import GHC.Utils.Constants (debugIsOn) -import GHC.Utils.Misc import GHC.Utils.Monad import GHC.Utils.Outputable import GHC.Utils.Panic @@ -266,7 +266,6 @@ annExpr e {- debugIsOn -} = ANN (text . show $ e) -- This seems to be PIC compatible; at least `scanelf` (pax-utils) does not -- complain. - -- | Generate jump to jump table target -- -- The index into the jump table is calulated by evaluating @expr@. The @@ -423,22 +422,22 @@ getRegisterReg platform (CmmGlobal mid) = -- General things for putting together code sequences -- | Compute an expression into any register -getSomeReg :: HasCallStack => CmmExpr -> NatM (Reg, Format, InstrBlock) +getSomeReg :: CmmExpr -> NatM (Reg, Format, InstrBlock) getSomeReg expr = do r <- getRegister expr res@(reg, fmt, _) <- case r of - Any rep code -> do - newReg <- getNewRegNat rep - pure (newReg, rep, code newReg) - Fixed rep reg code -> - pure (reg, rep, code) + Any rep code -> do + newReg <- getNewRegNat rep + pure (newReg, rep, code newReg) + Fixed rep reg code -> + pure (reg, rep, code) pure $ assertFmtReg fmt reg res -- | Compute an expression into any floating-point register -- -- If the initial expression is not a floating-point expression, finally move -- the result into a floating-point register. -getFloatReg :: (HasCallStack) => CmmExpr -> NatM (Reg, Format, InstrBlock) +getFloatReg :: CmmExpr -> NatM (Reg, Format, InstrBlock) getFloatReg expr = do r <- getRegister expr case r of @@ -866,13 +865,10 @@ getRegister' config plat expr = MO_AlignmentCheck align wordWidth -> do reg <- getRegister' config plat e addAlignmentCheck align wordWidth reg - MO_V_Broadcast length w -> vectorBroadcast (intVecFormat length w) e MO_VF_Broadcast length w -> vectorBroadcast (floatVecFormat length w) e - MO_VS_Neg length w -> vectorNegation (intVecFormat length w) MO_VF_Neg length w -> vectorNegation (floatVecFormat length w) - x -> pprPanic ("getRegister' (monadic CmmMachOp): " ++ show x) (pdoc plat expr) where -- In the case of 16- or 8-bit values we need to sign-extend to 32-bits @@ -1236,9 +1232,14 @@ getRegister' config plat expr = genericVectorShuffle :: (Int -> Width -> Format) -> Int -> Width -> [Int] -> NatM Register genericVectorShuffle toDstFormat length w idxs = do -- Our strategy: - -- - Gather elemens of v1 on the right positions - -- - Gather elemenrs of v2 of the right positions - -- - Merge v1 and v2 with an adequate bitmask (v0) + -- - Gather elements of v1 on the right positions -> v1' + -- - Gather elements of v2 of the right positions -> v2' + -- - Merge v1' and v2' with an adequate bitmask (v0) + -- + -- We create three supporting data section entries / structures: + -- - A mapping vector that describes the mapping of v1 -> v1' + -- - The same for v2 -> v2' + -- - the mask vector used to merge v1' and v2' lbl_selVec_v1 <- getNewLabelNat lbl_selVec_v2 <- getNewLabelNat lbl_mask <- getNewLabelNat @@ -1247,7 +1248,7 @@ getRegister' config plat expr = (reg_y, format_y, code_y) <- getSomeReg y let (idxs_v1, idxs_v2) = - mapTuple reverse + bimap reverse reverse $ foldl' ( \(acc1, acc2) i -> if i < length then (Just i : acc1, Nothing : acc2) else (Nothing : acc1, Just (i - length) : acc2) @@ -1259,7 +1260,6 @@ getRegister' config plat expr = -- Finally, the mask must be 0 where v1 should be taken and 1 for v2. -- That's why we do an implicit negation here by focussing on v2. maskVecData = maskData idxs_v2 - -- Longest vector length is 64, so 8bit (0 - 255) is sufficient selVecFormat = intVecFormat length w dstFormat = toDstFormat length w addrFormat = intFormat W64 @@ -1301,9 +1301,6 @@ getRegister' config plat expr = VMERGE (OpReg dstFormat dst) (OpReg format_x gathered_x) (OpReg format_y gathered_y) (OpReg maskFormat v0Reg) ] - mapTuple :: (a -> b) -> (a, a) -> (b, b) - mapTuple f (x, y) = (f x, f y) - selVecData :: Width -> [Maybe Int] -> [CmmStatic] selVecData w idxs = -- Using the width `w` here is a bit wasteful. But, it saves @@ -1376,7 +1373,6 @@ getRegister' config plat expr = MO_Shl w -> intOp False w (\d x y -> unitOL $ annExpr expr (SLL d x y)) MO_U_Shr w -> intOp False w (\d x y -> unitOL $ annExpr expr (SRL d x y)) MO_S_Shr w -> intOp True w (\d x y -> unitOL $ annExpr expr (SRA d x y)) - -- Vector operations MO_VF_Extract _length w -> vecExtract ((scalarFormatFormat . floatScalarFormat) w) MO_V_Extract _length w -> vecExtract ((scalarFormatFormat . intScalarFormat) w) @@ -1404,12 +1400,23 @@ getRegister' config plat expr = -- Generic ternary case. CmmMachOp op [x, y, z] -> case op of - -- Floating-point fused multiply-add operations + -- Floating-point fused multiply-add operations: + -- + -- x86 fmadd x * y + z <=> RISCV64 fmadd : d = r1 * r2 + r3 + -- x86 fmsub x * y - z <=> RISCV64 fmsub: d = r1 * r2 - r3 + -- x86 fnmadd - x * y + z <=> RISCV64 fnmsub: d = - r1 * r2 + r3 + -- x86 fnmsub - x * y - z <=> RISCV64 fnmadd: d = - r1 * r2 - r3 -- - -- x86 fmadd x * y + z <=> RISCV64 fmadd : d = r1 * r2 + r3 - -- x86 fmsub x * y - z <=> RISCV64 fnmsub: d = r1 * r2 - r3 - -- x86 fnmadd - x * y + z <=> RISCV64 fmsub : d = - r1 * r2 + r3 - -- x86 fnmsub - x * y - z <=> RISCV64 fnmadd: d = - r1 * r2 - r3 + -- Vector fused multiply-add operations (what x86 exactly does doesn't + -- matter here, we care about the abstract spec): + -- + -- FMAdd x * y + z <=> RISCV64 vfmadd : d = r1 * r2 + r3 + -- FMSub x * y - z <=> RISCV64 vfmsub: d = r1 * r2 - r3 + -- FNMAdd - x * y + z <=> RISCV64 vfnmsub: d = - r1 * r2 + r3 + -- FNMSub - x * y - z <=> RISCV64 vfnmadd: d = - r1 * r2 - r3 + -- + -- For both formats, the instruction selection happens in the + -- pretty-printer. MO_FMA var length w | length == 1 -> float3Op w (\d n m a -> unitOL $ FMA var d n m a) @@ -1418,12 +1425,10 @@ getRegister' config plat expr = (reg_y, format_y, code_y) <- getSomeReg y (reg_z, format_z, code_z) <- getSomeReg z let targetFormat = VecFormat length (floatScalarFormat w) - negate_z = if var `elem` [FNMAdd, FNMSub] then unitOL (VNEG (OpReg format_z reg_z) (OpReg format_z reg_z)) else nilOL pure $ Any targetFormat $ \dst -> code_x `appOL` code_y `appOL` code_z - `appOL` negate_z `snocOL` annExpr expr (VMV (OpReg targetFormat dst) (OpReg format_x reg_x)) @@ -2138,8 +2143,7 @@ genCCall target@(ForeignTarget expr _cconv) dest_regs arg_regs = do -- See Note [RISC-V vector C calling convention] passArguments _gpRegs _fpRegs [] ((_r, format, _hint, _code_r) : _args) _stackSpaceWords _accumRegs _accumCode | isVecFormat format = - panic "C call: no free vector argument registers. We only support 16 vector arguments (registers v8 - v23)." - + panic "C call: no free vector argument registers. We only support 16 vector arguments (registers v8 - v23)." passArguments _ _ _ _ _ _ _ = pprPanic "passArguments" (text "invalid state") readResults :: [Reg] -> [Reg] -> [Reg] -> [LocalReg] -> [Reg] -> InstrBlock -> NatM InstrBlock ===================================== compiler/GHC/CmmToAsm/RV64/Ppr.hs ===================================== @@ -15,6 +15,7 @@ import GHC.CmmToAsm.RV64.Instr import GHC.CmmToAsm.RV64.Regs import GHC.CmmToAsm.Types import GHC.CmmToAsm.Utils +import GHC.Data.OrdList import GHC.Platform import GHC.Platform.Reg import GHC.Prelude hiding (EQ) @@ -23,7 +24,6 @@ import GHC.Types.Basic (Alignment, alignmentBytes, mkAlignment) import GHC.Types.Unique (getUnique, pprUniqueAlways) import GHC.Utils.Outputable import GHC.Utils.Panic -import GHC.Data.OrdList pprNatCmmDecl :: forall doc. (IsDoc doc) => NCGConfig -> NatCmmDecl RawCmmStatics Instr -> doc pprNatCmmDecl config (CmmData section dats) = @@ -804,8 +804,8 @@ pprInstr platform instr = case instr of let fma = case variant of FMAdd -> text "\tfmadd" <> dot <> floatPrecission d FMSub -> text "\tfmsub" <> dot <> floatPrecission d - FNMAdd -> text "\tfnmadd" <> dot <> floatPrecission d - FNMSub -> text "\tfnmsub" <> dot <> floatPrecission d + FNMAdd -> text "\tfnmsub" <> dot <> floatPrecission d + FNMSub -> text "\tfnmadd" <> dot <> floatPrecission d in op4 fma d r1 r2 r3 VFMA variant o1@(OpReg fmt _reg) o2 o3 | VecFormat _l fmt' <- fmt -> @@ -815,8 +815,8 @@ pprInstr platform instr = case instr of fma = case variant of FMAdd -> text "madd" FMSub -> text "msub" -- TODO: Works only for floats! - FNMAdd -> text "nmadd" -- TODO: Works only for floats! - FNMSub -> text "nmsub" + FNMAdd -> text "nmsub" -- TODO: Works only for floats! + FNMSub -> text "nmadd" in op3 (tab <> prefix <> fma <> dot <> suffix) o1 o2 o3 VFMA _variant o1 _o2 _o3 -> pprPanic "RV64.pprInstr - VFMA can only target registers." (pprOp platform o1) VMV o1@(OpReg fmt _reg) o2 @@ -830,7 +830,7 @@ pprInstr platform instr = case instr of VMV o1 o2 -> pprPanic "RV64.pprInstr - invalid VMV instruction" (text "VMV" <+> pprOp platform o1 <> comma <+> pprOp platform o2) VID op | isVectorRegOp op -> op1 (text "\tvid.v") op VID op -> pprPanic "RV64.pprInstr - VID can only target registers." (pprOp platform op) - VMSEQ o1 o2 o3 | allVectorRegOps [o1, o2] && isIntOp o3 && isImmOp o3-> op3 (text "\tvmseq.vi") o1 o2 o3 + VMSEQ o1 o2 o3 | allVectorRegOps [o1, o2] && isIntOp o3 && isImmOp o3 -> op3 (text "\tvmseq.vi") o1 o2 o3 VMSEQ o1 o2 o3 | allVectorRegOps [o1, o2] && isIntOp o3 -> op3 (text "\tvmseq.vx") o1 o2 o3 VMSEQ o1 o2 o3 -> pprPanic "RV64.pprInstr - VMSEQ wrong operands." (pprOps platform [o1, o2, o3]) VMERGE o1 o2 o3 o4 | allVectorRegOps [o1, o2, o3, o4] -> op4 (text "\tvmerge.vvm") o1 o2 o3 o4 ===================================== testsuite/driver/testlib.py ===================================== @@ -426,7 +426,8 @@ def req_fma_cpu( name, opts ): # RISC-V: We imply float and double extensions (rv64g), so we only have to # check for vector support. - if not(have_cpu_feature('avx') or have_cpu_feature('zvl128b')): + # AArch64: Always expect FMA support. + if not (have_cpu_feature('avx') or arch('aarch64') or have_cpu_feature('zvl128b')): opts.skip = True def ignore_stdout(name, opts): ===================================== testsuite/tests/primops/should_run/all.T ===================================== @@ -63,16 +63,12 @@ test('UnliftedTVar2', normal, compile_and_run, ['']) test('UnliftedWeakPtr', normal, compile_and_run, ['']) test('FMA_Primops' - , [ when(have_cpu_feature('fma'), extra_hc_opts('-mfma')) - , js_skip # JS backend doesn't have an FMA implementation - , when(arch('wasm32'), skip) + , [ req_fma_cpu, extra_hc_opts('-mfma') , when(have_llvm(), extra_ways(["optllvm"])) ] , compile_and_run, ['']) test('FMA_ConstantFold' - , [ when(have_cpu_feature('fma'), extra_hc_opts('-mfma')) - , js_skip # JS backend doesn't have an FMA implementation - , when(arch('wasm32'), skip) + , [ req_fma_cpu, extra_hc_opts('-mfma') , expect_broken(21227) , when(have_llvm(), extra_ways(["optllvm"])) ] @@ -85,9 +81,7 @@ test('T23071', ['']) test('T22710', normal, compile_and_run, ['']) test('T24496' - , [ when(have_cpu_feature('fma'), extra_hc_opts('-mfma')) - , js_skip # JS backend doesn't have an FMA implementation - , when(arch('wasm32'), skip) + , [ req_fma_cpu, extra_hc_opts('-mfma') , when(have_llvm(), extra_ways(["optllvm"])) ] , compile_and_run, ['-O']) ===================================== testsuite/tests/simd/should_run/all.T ===================================== @@ -123,10 +123,10 @@ test('word32x4_basic', [], compile_and_run, ['']) test('word64x2_basic', [], compile_and_run, ['']) test('floatx4_arith', [], compile_and_run, ['']) test('doublex2_arith', [], compile_and_run, ['']) -test('floatx4_fma', [ unless(have_cpu_feature('fma'), skip) +test('floatx4_fma', [ req_fma_cpu , extra_hc_opts('-mfma') ], compile_and_run, ['']) -test('doublex2_fma', [ unless(have_cpu_feature('fma'), skip) +test('doublex2_fma', [ req_fma_cpu , extra_hc_opts('-mfma') ], compile_and_run, ['']) test('int8x16_arith', [], compile_and_run, ['']) View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/be50e5e89425932d4553cbf660e66e6... -- View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/be50e5e89425932d4553cbf660e66e6... You're receiving this email because of your account on gitlab.haskell.org.
participants (1)
-
Sven Tennie (@supersven)