
Sven Tennie pushed to branch wip/supersven/riscv-vectors at Glasgow Haskell Compiler / GHC Commits: b22c53e1 by Sven Tennie at 2025-04-20T12:08:16+02:00 Implement MO_VS_Quot and MO_VU_Quot - - - - - d8441ea3 by Sven Tennie at 2025-04-20T13:12:41+02:00 Implement MO_X64 and MO_W64 CallishOps - - - - - 4b485b8c by Sven Tennie at 2025-04-21T19:27:09+02:00 WIP: Vector shuffle - - - - - 3 changed files: - compiler/GHC/CmmToAsm/RV64/CodeGen.hs - compiler/GHC/CmmToAsm/RV64/Instr.hs - compiler/GHC/CmmToAsm/RV64/Ppr.hs Changes: ===================================== compiler/GHC/CmmToAsm/RV64/CodeGen.hs ===================================== @@ -1294,7 +1294,7 @@ getRegister' config plat expr = MO_V_Sub length w -> vecOp (intVecFormat length w) VSUB MO_VF_Mul length w -> vecOp (floatVecFormat length w) VMUL MO_V_Mul length w -> vecOp (intVecFormat length w) VMUL - MO_VF_Quot length w -> vecOp (floatVecFormat length w) VQUOT + MO_VF_Quot length w -> vecOp (floatVecFormat length w) (VQUOT Nothing) -- See https://godbolt.org/z/PvcWKMKoW MO_VS_Min length w -> vecOp (intVecFormat length w) VSMIN MO_VS_Max length w -> vecOp (intVecFormat length w) VSMAX @@ -1302,6 +1302,66 @@ getRegister' config plat expr = MO_VU_Max length w -> vecOp (intVecFormat length w) VUMAX MO_VF_Min length w -> vecOp (floatVecFormat length w) VFMIN MO_VF_Max length w -> vecOp (floatVecFormat length w) VFMAX + MO_V_Shuffle length w idxs -> do + -- Our strategy: + -- - Gather elements of v1 into the right positions + -- - Gather elements of v2 into the right positions + -- - Merge v1 and v2 with an adequate bitmask (v0) + lbl_selVec_v1 <- getNewLabelNat + lbl_selVec_v2 <- getNewLabelNat + + (reg_x, format_x, code_x) <- getSomeReg x + (reg_y, format_y, code_y) <- getSomeReg y + + let (idxs_v1, idxs_v2) = + mapTuple reverse + $ foldl' + ( \(acc1, acc2) i -> + if i < length then (Just i : acc1, Nothing : acc2) else (Nothing : acc1, Just (i - length) : acc2) + ) + ([], []) + idxs + selVecData_v1 = selVecData idxs_v1 + 
selVecData_v2 = selVecData idxs_v2 + selVecFormat = intVecFormat length W16 + dstFormat = intVecFormat length w + addrFormat = intFormat W64 + sel_v1 <- getNewRegNat selVecFormat + sel_v2 <- getNewRegNat selVecFormat + sel_v1_addr <- getNewRegNat addrFormat + sel_v2_addr <- getNewRegNat addrFormat + gathered_x <- getNewRegNat format_x + gathered_y <- getNewRegNat format_y + pure $ Any dstFormat $ \dst -> + toOL + [ LDATA (Section ReadOnlyData lbl_selVec_v1) (CmmStaticsRaw lbl_selVec_v1 selVecData_v1), + LDATA (Section ReadOnlyData lbl_selVec_v2) (CmmStaticsRaw lbl_selVec_v2 selVecData_v2) + ] + `appOL` code_x + `appOL` code_y + `appOL` toOL + [ LDR addrFormat (OpReg addrFormat sel_v1_addr) (OpImm (ImmCLbl lbl_selVec_v1)), + LDR addrFormat (OpReg addrFormat sel_v2_addr) (OpImm (ImmCLbl lbl_selVec_v2)), + LDRU selVecFormat (OpReg selVecFormat sel_v1) (OpAddr (AddrReg sel_v1_addr)), + LDRU selVecFormat (OpReg selVecFormat sel_v2) (OpAddr (AddrReg sel_v2_addr)), + VRGATHER (OpReg format_x gathered_x) (OpReg format_x reg_x) (OpReg selVecFormat sel_v1), + VRGATHER (OpReg format_y gathered_y) (OpReg format_y reg_y) (OpReg selVecFormat sel_v2), + VMV (OpReg selVecFormat v0Reg) (OpReg selVecFormat sel_v1), + VMERGE (OpReg dstFormat dst)(OpReg format_x gathered_x)(OpReg format_y gathered_y) (OpReg selVecFormat v0Reg) + ] + where + mapTuple :: (a -> b) -> (a, a) -> (b, b) + mapTuple f (x, y) = (f x, f y) + selVecData :: [Maybe Int] -> [CmmStatic] + selVecData idxs = + (CmmStaticLit . (flip CmmInt) W16 . fromIntegral) + `map` ( map + ( \i -> case i of + Just i' -> i' + Nothing -> 0 + ) + idxs + ) _e -> panic $ "Missing operation " ++ show expr -- Generic ternary case. 
@@ -1331,7 +1391,6 @@ getRegister' config plat expr = expr (VMV (OpReg targetFormat dst) (OpReg format_x reg_x)) `snocOL` VFMA var (OpReg targetFormat dst) (OpReg format_y reg_y) (OpReg format_z reg_z) - MO_VF_Insert length width -> vecInsert floatVecFormat length width MO_V_Insert length width -> vecInsert intVecFormat length width _ -> @@ -1348,7 +1407,7 @@ getRegister' config plat expr = (reg_idx, format_idx, code_idx) <- getSomeReg z let format = toFormat length width format_mask = intVecFormat length W8 -- Actually, W1 (one bit) would be correct, but that does not exist. - format_vid = intVecFormat length vidWidth + format_vid = intVecFormat length (vidWidth length) vidReg <- getNewRegNat format_vid tmp <- getNewRegNat format pure $ Any format $ \dst -> @@ -1373,18 +1432,20 @@ getRegister' config plat expr = `snocOL` -- 4. Merge with mask -> set element at index VMERGE (OpReg format dst) (OpReg format_v reg_v) (OpReg format tmp) (OpReg format_mask v0Reg) + + -- Which element width do I need in my vector to store indexes in it? + vidWidth :: Int -> Width + vidWidth length = case bitWidthFixed (fromIntegral length :: Word) of + x + | x <= widthInBits W8 -> W8 + | x <= widthInBits W16 -> W16 + | x <= widthInBits W32 -> W32 + | x <= widthInBits W64 -> W64 + | x <= widthInBits W128 -> W128 + | x <= widthInBits W256 -> W256 + | x <= widthInBits W512 -> W512 + e -> panic $ "length " ++ show length ++ "not representable in a single element's Width (" ++ show e ++ ")" where - -- Which element width do I need in my vector to store indexes in it? 
- vidWidth = case bitWidthFixed (fromIntegral length :: Word) of - x - | x <= widthInBits W8 -> W8 - | x <= widthInBits W16 -> W16 - | x <= widthInBits W32 -> W32 - | x <= widthInBits W64 -> W64 - | x <= widthInBits W128 -> W128 - | x <= widthInBits W256 -> W256 - | x <= widthInBits W512 -> W512 - e -> panic $ "length " ++ show length ++ "not representable in a single element's Width (" ++ show e ++ ")" bitWidthFixed :: Word -> Int bitWidthFixed 0 = 1 bitWidthFixed n = finiteBitSize n - countLeadingZeros n @@ -1489,14 +1550,6 @@ getRegister' config plat expr = ) -- TODO: Missing MachOps: --- - MO_V_Add --- - MO_V_Sub --- - MO_V_Mul --- - MO_VS_Quot --- - MO_VS_Rem --- - MO_VS_Neg --- - MO_VU_Quot --- - MO_VU_Rem -- - MO_V_Shuffle -- - MO_VF_Shuffle @@ -2142,19 +2195,45 @@ genCCall (PrimTarget mop) dest_regs arg_regs = do MO_AddIntC _w -> unsupported mop MO_SubIntC _w -> unsupported mop MO_U_Mul2 _w -> unsupported mop + MO_VS_Quot length w + | [x, y] <- arg_regs, + [dst_reg] <- dest_regs -> + v3op mop (intVecFormat length w) dst_reg x y (VQUOT (Just Signed)) MO_VS_Quot {} -> unsupported mop + MO_VU_Quot length w + | [x, y] <- arg_regs, + [dst_reg] <- dest_regs -> + v3op mop (intVecFormat length w) dst_reg x y (VQUOT (Just Unsigned)) MO_VU_Quot {} -> unsupported mop MO_VS_Rem length w | [x, y] <- arg_regs, - [dst_reg] <- dest_regs -> vrem mop length w dst_reg x y Signed + [dst_reg] <- dest_regs -> + v3op mop (intVecFormat length w) dst_reg x y (VREM Signed) MO_VS_Rem {} -> unsupported mop MO_VU_Rem length w | [x, y] <- arg_regs, - [dst_reg] <- dest_regs -> vrem mop length w dst_reg x y Unsigned + [dst_reg] <- dest_regs -> + v3op mop (intVecFormat length w) dst_reg x y (VREM Unsigned) MO_VU_Rem {} -> unsupported mop + MO_I64X2_Min + | [x, y] <- arg_regs, + [dst_reg] <- dest_regs -> + v3op mop (intVecFormat 2 W64) dst_reg x y VSMIN MO_I64X2_Min -> unsupported mop + MO_I64X2_Max + | [x, y] <- arg_regs, + [dst_reg] <- dest_regs -> + v3op mop (intVecFormat 2 W64) dst_reg 
x y VSMAX MO_I64X2_Max -> unsupported mop + MO_W64X2_Min + | [x, y] <- arg_regs, + [dst_reg] <- dest_regs -> + v3op mop (intVecFormat 2 W64) dst_reg x y VUMIN MO_W64X2_Min -> unsupported mop + MO_W64X2_Max + | [x, y] <- arg_regs, + [dst_reg] <- dest_regs -> + v3op mop (intVecFormat 2 W64) dst_reg x y VUMAX MO_W64X2_Max -> unsupported mop -- Memory Ordering -- The related C functions are: @@ -2275,24 +2354,23 @@ genCCall (PrimTarget mop) dest_regs arg_regs = do let code = code_fx `appOL` op (OpReg fmt dst) (OpReg format_x reg_fx) pure code - vrem :: CallishMachOp -> Int -> Width -> LocalReg -> CmmExpr -> CmmExpr -> Signage -> NatM InstrBlock - vrem mop length w dst_reg x y s = do - platform <- getPlatform - let dst = getRegisterReg platform (CmmLocal dst_reg) - format = intVecFormat length w - moDescr = pprCallishMachOp mop - (reg_x, format_x, code_x) <- getSomeReg x - (reg_y, format_y, code_y) <- getSomeReg y - massertPpr (isVecFormat format_x && isVecFormat format_y) - $ text "vecOp: non-vector operand. operands: " - <+> ppr format_x - <+> ppr format_y - pure - $ code_x - `appOL` code_y - `snocOL` - ann moDescr - (VREM s (OpReg format dst) (OpReg format_x reg_x) (OpReg format_y reg_y)) + v3op :: CallishMachOp -> Format -> LocalReg -> CmmExpr -> CmmExpr -> (Operand -> Operand -> Operand -> Instr) -> NatM InstrBlock + v3op mop dst_format dst_reg x y op = do + platform <- getPlatform + let dst = getRegisterReg platform (CmmLocal dst_reg) + moDescr = pprCallishMachOp mop + (reg_x, format_x, code_x) <- getSomeReg x + (reg_y, format_y, code_y) <- getSomeReg y + massertPpr (isVecFormat format_x && isVecFormat format_y) + $ text "vecOp: non-vector operand. 
operands: " + <+> ppr format_x + <+> ppr format_y + pure + $ code_x + `appOL` code_y + `snocOL` ann + moDescr + (op (OpReg dst_format dst) (OpReg format_x reg_x) (OpReg format_y reg_y)) {- Note [RISCV64 far jumps] ~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -2540,6 +2618,7 @@ makeFarBranches {- only used when debugging -} _platform statics basic_blocks = VUMAX {} -> 2 VFMIN {} -> 2 VFMAX {} -> 2 + VRGATHER {} -> 2 VFMA {} -> 3 -- estimate the subsituted size for jumps to lables -- jumps to registers have size 1 ===================================== compiler/GHC/CmmToAsm/RV64/Instr.hs ===================================== @@ -119,14 +119,15 @@ regUsageOfInstr platform instr = case instr of VADD dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst) VSUB dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst) VMUL dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst) - VQUOT dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst) - VREM s dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst) + VQUOT _mbS dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst) + VREM _s dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst) VSMIN dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst) VSMAX dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst) VUMIN dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst) VUMAX dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst) VFMIN dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst) VFMAX dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst) + VRGATHER dst src1 src2 -> usage (regOp src1 ++ regOp src2, regOp dst) FMA _ dst src1 src2 src3 -> usage (regOp src1 ++ regOp src2 ++ regOp src3, regOp dst) VFMA _ op1 op2 op3 -> @@ -233,7 +234,7 @@ patchRegsOfInstr instr env = case instr of VADD o1 o2 o3 -> VADD (patchOp o1) (patchOp o2) (patchOp o3) VSUB o1 o2 o3 -> VSUB (patchOp o1) (patchOp o2) (patchOp o3) VMUL o1 o2 o3 -> VMUL (patchOp o1) (patchOp o2) (patchOp o3) - 
VQUOT o1 o2 o3 -> VQUOT (patchOp o1) (patchOp o2) (patchOp o3) + VQUOT mbS o1 o2 o3 -> VQUOT mbS (patchOp o1) (patchOp o2) (patchOp o3) VREM s o1 o2 o3 -> VREM s (patchOp o1) (patchOp o2) (patchOp o3) VSMIN o1 o2 o3 -> VSMIN (patchOp o1) (patchOp o2) (patchOp o3) VSMAX o1 o2 o3 -> VSMAX (patchOp o1) (patchOp o2) (patchOp o3) @@ -241,6 +242,7 @@ patchRegsOfInstr instr env = case instr of VUMAX o1 o2 o3 -> VUMAX (patchOp o1) (patchOp o2) (patchOp o3) VFMIN o1 o2 o3 -> VFMIN (patchOp o1) (patchOp o2) (patchOp o3) VFMAX o1 o2 o3 -> VFMAX (patchOp o1) (patchOp o2) (patchOp o3) + VRGATHER o1 o2 o3 -> VRGATHER (patchOp o1) (patchOp o2) (patchOp o3) FMA s o1 o2 o3 o4 -> FMA s (patchOp o1) (patchOp o2) (patchOp o3) (patchOp o4) VFMA s o1 o2 o3 -> @@ -676,7 +678,7 @@ data Instr | VADD Operand Operand Operand | VSUB Operand Operand Operand | VMUL Operand Operand Operand - | VQUOT Operand Operand Operand + | VQUOT (Maybe Signage) Operand Operand Operand | VREM Signage Operand Operand Operand | VSMIN Operand Operand Operand | VSMAX Operand Operand Operand @@ -685,6 +687,7 @@ data Instr | VFMIN Operand Operand Operand | VFMAX Operand Operand Operand | VFMA FMASign Operand Operand Operand + | VRGATHER Operand Operand Operand data Signage = Signed | Unsigned deriving (Eq, Show) @@ -770,6 +773,7 @@ instrCon i = VUMAX {} -> "VUMAX" VFMIN {} -> "VFMIN" VFMAX {} -> "VFMAX" + VRGATHER {} -> "VRGATHER" FMA variant _ _ _ _ -> case variant of FMAdd -> "FMADD" ===================================== compiler/GHC/CmmToAsm/RV64/Ppr.hs ===================================== @@ -853,8 +853,10 @@ pprInstr platform instr = case instr of VMUL o1 o2 o3 | allIntVectorRegOps [o1, o2, o3] -> op3 (text "\tvmul.vv") o1 o2 o3 VMUL o1 o2 o3 | allFloatVectorRegOps [o1, o2, o3] -> op3 (text "\tvfmul.vv") o1 o2 o3 VMUL o1 o2 o3 -> pprPanic "RV64.pprInstr - VMUL wrong operands." 
(pprOps platform [o1, o2, o3]) - VQUOT o1 o2 o3 | allVectorRegOps [o1, o2, o3] -> op3 (text "\tvfdiv.vv") o1 o2 o3 - VQUOT o1 o2 o3 -> pprPanic "RV64.pprInstr - VQUOT wrong operands." (pprOps platform [o1, o2, o3]) + VQUOT (Just Signed) o1 o2 o3 | allIntVectorRegOps [o1, o2, o3] -> op3 (text "\tvdiv.vv") o1 o2 o3 + VQUOT (Just Unsigned) o1 o2 o3 | allIntVectorRegOps [o1, o2, o3] -> op3 (text "\tvdivu.vv") o1 o2 o3 + VQUOT Nothing o1 o2 o3 | allFloatVectorRegOps [o1, o2, o3] -> op3 (text "\tvfdiv.vv") o1 o2 o3 + VQUOT mbS o1 o2 o3 -> pprPanic ("RV64.pprInstr - VQUOT wrong operands. " ++ show mbS) (pprOps platform [o1, o2, o3]) VREM Signed o1 o2 o3 | allIntVectorRegOps [o1, o2, o3] -> op3 (text "\tvrem.vv") o1 o2 o3 VREM Unsigned o1 o2 o3 | allIntVectorRegOps [o1, o2, o3] -> op3 (text "\tvremu.vv") o1 o2 o3 VREM s o1 o2 o3 -> pprPanic ("RV64.pprInstr - VREM wrong operands. " ++ show s) (pprOps platform [o1, o2, o3]) @@ -870,6 +872,8 @@ pprInstr platform instr = case instr of VFMIN o1 o2 o3 -> pprPanic "RV64.pprInstr - VFMIN wrong operands." (pprOps platform [o1, o2, o3]) VFMAX o1 o2 o3 | allVectorRegOps [o1, o2, o3] -> op3 (text "\tvfmax.vv") o1 o2 o3 VFMAX o1 o2 o3 -> pprPanic "RV64.pprInstr - VFMAX wrong operands." (pprOps platform [o1, o2, o3]) + VRGATHER o1 o2 o3 | allVectorRegOps [o1, o2, o3] -> op3 (text "\tvrgatherei16.vv") o1 o2 o3 + VRGATHER o1 o2 o3 -> pprPanic "RV64.pprInstr - VRGATHER wrong operands." 
(pprOps platform [o1, o2, o3]) instr -> panic $ "RV64.pprInstr - Unknown instruction: " ++ instrCon instr where op1 op o1 = line $ op <+> pprOp platform o1 @@ -984,9 +988,9 @@ instrVecFormat platform instr = case instr of VMUL (OpReg fmt _reg) _o2 _o3 | isVecFormat fmt -> checkedJustFmt fmt VMUL _o1 _o2 _o3 -> pprPanic "Did not match" (pprInstr platform instr) - VQUOT (OpReg fmt _reg) _o2 _o3 + VQUOT _mbS (OpReg fmt _reg) _o2 _o3 | isVecFormat fmt -> checkedJustFmt fmt - VQUOT _o1 _o2 _o3 -> pprPanic "Did not match" (pprInstr platform instr) + VQUOT _mbS _o1 _o2 _o3 -> pprPanic "Did not match" (pprInstr platform instr) VSMIN (OpReg fmt _reg) _o2 _o3 | isVecFormat fmt -> checkedJustFmt fmt VSMIN _o1 _o2 _o3 -> pprPanic "Did not match" (pprInstr platform instr) @@ -1004,6 +1008,8 @@ instrVecFormat platform instr = case instr of VFMIN _o1 _o2 _o3 -> pprPanic "Did not match" (pprInstr platform instr) VFMAX (OpReg fmt _reg) _o2 _o3 -> checkedJustFmt fmt VFMAX _o1 _o2 _o3 -> pprPanic "Did not match" (pprInstr platform instr) + VRGATHER (OpReg fmt _reg) _o2 _o3 -> checkedJustFmt fmt + VRGATHER _o1 _o2 _o3 -> pprPanic "Did not match" (pprInstr platform instr) _ -> Nothing where checkedJustFmt :: Format -> Maybe Format View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/5a31e90c45ea723391dfd1e331ba0ff... -- View it on GitLab: https://gitlab.haskell.org/ghc/ghc/-/compare/5a31e90c45ea723391dfd1e331ba0ff... You're receiving this email because of your account on gitlab.haskell.org.