| ... |
... |
@@ -12,6 +12,7 @@ module GHC.CmmToAsm.RV64.CodeGen |
|
12
|
12
|
where
|
|
13
|
13
|
|
|
14
|
14
|
import Control.Monad
|
|
|
15
|
+import Data.Bifunctor (bimap)
|
|
15
|
16
|
import Data.Maybe
|
|
16
|
17
|
import Data.Word
|
|
17
|
18
|
import GHC.Cmm
|
| ... |
... |
@@ -55,7 +56,6 @@ import GHC.Types.SrcLoc (srcSpanFile, srcSpanStartCol, srcSpanStartLine) |
|
55
|
56
|
import GHC.Types.Tickish (GenTickish (..))
|
|
56
|
57
|
import GHC.Types.Unique.DSM
|
|
57
|
58
|
import GHC.Utils.Constants (debugIsOn)
|
|
58
|
|
-import GHC.Utils.Misc
|
|
59
|
59
|
import GHC.Utils.Monad
|
|
60
|
60
|
import GHC.Utils.Outputable
|
|
61
|
61
|
import GHC.Utils.Panic
|
| ... |
... |
@@ -266,7 +266,6 @@ annExpr e {- debugIsOn -} = ANN (text . show $ e) |
|
266
|
266
|
-- This seems to be PIC compatible; at least `scanelf` (pax-utils) does not
|
|
267
|
267
|
-- complain.
|
|
268
|
268
|
|
|
269
|
|
-
|
|
270
|
269
|
-- | Generate jump to jump table target
|
|
271
|
270
|
--
|
|
272
|
271
|
-- The index into the jump table is calculated by evaluating @expr@. The
|
| ... |
... |
@@ -423,22 +422,22 @@ getRegisterReg platform (CmmGlobal mid) = |
|
423
|
422
|
-- General things for putting together code sequences
|
|
424
|
423
|
|
|
425
|
424
|
-- | Compute an expression into any register
|
|
426
|
|
-getSomeReg :: HasCallStack => CmmExpr -> NatM (Reg, Format, InstrBlock)
|
|
|
425
|
+getSomeReg :: CmmExpr -> NatM (Reg, Format, InstrBlock)
|
|
427
|
426
|
getSomeReg expr = do
|
|
428
|
427
|
r <- getRegister expr
|
|
429
|
428
|
res@(reg, fmt, _) <- case r of
|
|
430
|
|
- Any rep code -> do
|
|
431
|
|
- newReg <- getNewRegNat rep
|
|
432
|
|
- pure (newReg, rep, code newReg)
|
|
433
|
|
- Fixed rep reg code ->
|
|
434
|
|
- pure (reg, rep, code)
|
|
|
429
|
+ Any rep code -> do
|
|
|
430
|
+ newReg <- getNewRegNat rep
|
|
|
431
|
+ pure (newReg, rep, code newReg)
|
|
|
432
|
+ Fixed rep reg code ->
|
|
|
433
|
+ pure (reg, rep, code)
|
|
435
|
434
|
pure $ assertFmtReg fmt reg res
|
|
436
|
435
|
|
|
437
|
436
|
-- | Compute an expression into any floating-point register
|
|
438
|
437
|
--
|
|
439
|
438
|
-- If the initial expression is not a floating-point expression, finally move
|
|
440
|
439
|
-- the result into a floating-point register.
|
|
441
|
|
-getFloatReg :: (HasCallStack) => CmmExpr -> NatM (Reg, Format, InstrBlock)
|
|
|
440
|
+getFloatReg :: CmmExpr -> NatM (Reg, Format, InstrBlock)
|
|
442
|
441
|
getFloatReg expr = do
|
|
443
|
442
|
r <- getRegister expr
|
|
444
|
443
|
case r of
|
| ... |
... |
@@ -866,13 +865,10 @@ getRegister' config plat expr = |
|
866
|
865
|
MO_AlignmentCheck align wordWidth -> do
|
|
867
|
866
|
reg <- getRegister' config plat e
|
|
868
|
867
|
addAlignmentCheck align wordWidth reg
|
|
869
|
|
-
|
|
870
|
868
|
MO_V_Broadcast length w -> vectorBroadcast (intVecFormat length w) e
|
|
871
|
869
|
MO_VF_Broadcast length w -> vectorBroadcast (floatVecFormat length w) e
|
|
872
|
|
-
|
|
873
|
870
|
MO_VS_Neg length w -> vectorNegation (intVecFormat length w)
|
|
874
|
871
|
MO_VF_Neg length w -> vectorNegation (floatVecFormat length w)
|
|
875
|
|
-
|
|
876
|
872
|
x -> pprPanic ("getRegister' (monadic CmmMachOp): " ++ show x) (pdoc plat expr)
|
|
877
|
873
|
where
|
|
878
|
874
|
-- In the case of 16- or 8-bit values we need to sign-extend to 32-bits
|
| ... |
... |
@@ -1236,9 +1232,14 @@ getRegister' config plat expr = |
|
1236
|
1232
|
genericVectorShuffle :: (Int -> Width -> Format) -> Int -> Width -> [Int] -> NatM Register
|
|
1237
|
1233
|
genericVectorShuffle toDstFormat length w idxs = do
|
|
1238
|
1234
|
-- Our strategy:
|
|
1239
|
|
- -- - Gather elemens of v1 on the right positions
|
|
1240
|
|
- -- - Gather elemenrs of v2 of the right positions
|
|
1241
|
|
- -- - Merge v1 and v2 with an adequate bitmask (v0)
|
|
|
1235
|
+ -- - Gather elements of v1 on the right positions -> v1'
|
|
|
1236
|
+ -- - Gather elements of v2 on the right positions -> v2'
|
|
|
1237
|
+ -- - Merge v1' and v2' with an adequate bitmask (v0)
|
|
|
1238
|
+ --
|
|
|
1239
|
+ -- We create three supporting data section entries / structures:
|
|
|
1240
|
+ -- - A mapping vector that describes the mapping of v1 -> v1'
|
|
|
1241
|
+ -- - The same for v2 -> v2'
|
|
|
1242
|
+ -- - the mask vector used to merge v1' and v2'
|
|
1242
|
1243
|
lbl_selVec_v1 <- getNewLabelNat
|
|
1243
|
1244
|
lbl_selVec_v2 <- getNewLabelNat
|
|
1244
|
1245
|
lbl_mask <- getNewLabelNat
|
| ... |
... |
@@ -1247,7 +1248,7 @@ getRegister' config plat expr = |
|
1247
|
1248
|
(reg_y, format_y, code_y) <- getSomeReg y
|
|
1248
|
1249
|
|
|
1249
|
1250
|
let (idxs_v1, idxs_v2) =
|
|
1250
|
|
- mapTuple reverse
|
|
|
1251
|
+ bimap reverse reverse
|
|
1251
|
1252
|
$ foldl'
|
|
1252
|
1253
|
( \(acc1, acc2) i ->
|
|
1253
|
1254
|
if i < length then (Just i : acc1, Nothing : acc2) else (Nothing : acc1, Just (i - length) : acc2)
|
| ... |
... |
@@ -1259,7 +1260,6 @@ getRegister' config plat expr = |
|
1259
|
1260
|
-- Finally, the mask must be 0 where v1 should be taken and 1 for v2.
|
|
1260
|
1261
|
-- That's why we do an implicit negation here by focussing on v2.
|
|
1261
|
1262
|
maskVecData = maskData idxs_v2
|
|
1262
|
|
- -- Longest vector length is 64, so 8bit (0 - 255) is sufficient
|
|
1263
|
1263
|
selVecFormat = intVecFormat length w
|
|
1264
|
1264
|
dstFormat = toDstFormat length w
|
|
1265
|
1265
|
addrFormat = intFormat W64
|
| ... |
... |
@@ -1301,9 +1301,6 @@ getRegister' config plat expr = |
|
1301
|
1301
|
VMERGE (OpReg dstFormat dst) (OpReg format_x gathered_x) (OpReg format_y gathered_y) (OpReg maskFormat v0Reg)
|
|
1302
|
1302
|
]
|
|
1303
|
1303
|
|
|
1304
|
|
- mapTuple :: (a -> b) -> (a, a) -> (b, b)
|
|
1305
|
|
- mapTuple f (x, y) = (f x, f y)
|
|
1306
|
|
-
|
|
1307
|
1304
|
selVecData :: Width -> [Maybe Int] -> [CmmStatic]
|
|
1308
|
1305
|
selVecData w idxs =
|
|
1309
|
1306
|
-- Using the width `w` here is a bit wasteful. But, it saves
|
| ... |
... |
@@ -1376,7 +1373,6 @@ getRegister' config plat expr = |
|
1376
|
1373
|
MO_Shl w -> intOp False w (\d x y -> unitOL $ annExpr expr (SLL d x y))
|
|
1377
|
1374
|
MO_U_Shr w -> intOp False w (\d x y -> unitOL $ annExpr expr (SRL d x y))
|
|
1378
|
1375
|
MO_S_Shr w -> intOp True w (\d x y -> unitOL $ annExpr expr (SRA d x y))
|
|
1379
|
|
-
|
|
1380
|
1376
|
-- Vector operations
|
|
1381
|
1377
|
MO_VF_Extract _length w -> vecExtract ((scalarFormatFormat . floatScalarFormat) w)
|
|
1382
|
1378
|
MO_V_Extract _length w -> vecExtract ((scalarFormatFormat . intScalarFormat) w)
|
| ... |
... |
@@ -1404,12 +1400,23 @@ getRegister' config plat expr = |
|
1404
|
1400
|
-- Generic ternary case.
|
|
1405
|
1401
|
CmmMachOp op [x, y, z] ->
|
|
1406
|
1402
|
case op of
|
|
1407
|
|
- -- Floating-point fused multiply-add operations
|
|
|
1403
|
+ -- Floating-point fused multiply-add operations:
|
|
|
1404
|
+ --
|
|
|
1405
|
+ -- x86 fmadd x * y + z <=> RISCV64 fmadd : d = r1 * r2 + r3
|
|
|
1406
|
+ -- x86 fmsub x * y - z <=> RISCV64 fmsub: d = r1 * r2 - r3
|
|
|
1407
|
+ -- x86 fnmadd - x * y + z <=> RISCV64 fnmsub: d = - r1 * r2 + r3
|
|
|
1408
|
+ -- x86 fnmsub - x * y - z <=> RISCV64 fnmadd: d = - r1 * r2 - r3
|
|
1408
|
1409
|
--
|
|
1409
|
|
- -- x86 fmadd x * y + z <=> RISCV64 fmadd : d = r1 * r2 + r3
|
|
1410
|
|
- -- x86 fmsub x * y - z <=> RISCV64 fnmsub: d = r1 * r2 - r3
|
|
1411
|
|
- -- x86 fnmadd - x * y + z <=> RISCV64 fmsub : d = - r1 * r2 + r3
|
|
1412
|
|
- -- x86 fnmsub - x * y - z <=> RISCV64 fnmadd: d = - r1 * r2 - r3
|
|
|
1410
|
+ -- Vector fused multiply-add operations (what x86 exactly does doesn't
|
|
|
1411
|
+ -- matter here, we care about the abstract spec):
|
|
|
1412
|
+ --
|
|
|
1413
|
+ -- FMAdd x * y + z <=> RISCV64 vfmadd : d = r1 * r2 + r3
|
|
|
1414
|
+ -- FMSub x * y - z <=> RISCV64 vfmsub: d = r1 * r2 - r3
|
|
|
1415
|
+ -- FNMAdd - x * y + z <=> RISCV64 vfnmsub: d = - r1 * r2 + r3
|
|
|
1416
|
+ -- FNMSub - x * y - z <=> RISCV64 vfnmadd: d = - r1 * r2 - r3
|
|
|
1417
|
+ --
|
|
|
1418
|
+ -- For both formats, the instruction selection happens in the
|
|
|
1419
|
+ -- pretty-printer.
|
|
1413
|
1420
|
MO_FMA var length w
|
|
1414
|
1421
|
| length == 1 ->
|
|
1415
|
1422
|
float3Op w (\d n m a -> unitOL $ FMA var d n m a)
|
| ... |
... |
@@ -1418,12 +1425,10 @@ getRegister' config plat expr = |
|
1418
|
1425
|
(reg_y, format_y, code_y) <- getSomeReg y
|
|
1419
|
1426
|
(reg_z, format_z, code_z) <- getSomeReg z
|
|
1420
|
1427
|
let targetFormat = VecFormat length (floatScalarFormat w)
|
|
1421
|
|
- negate_z = if var `elem` [FNMAdd, FNMSub] then unitOL (VNEG (OpReg format_z reg_z) (OpReg format_z reg_z)) else nilOL
|
|
1422
|
1428
|
pure $ Any targetFormat $ \dst ->
|
|
1423
|
1429
|
code_x
|
|
1424
|
1430
|
`appOL` code_y
|
|
1425
|
1431
|
`appOL` code_z
|
|
1426
|
|
- `appOL` negate_z
|
|
1427
|
1432
|
`snocOL` annExpr
|
|
1428
|
1433
|
expr
|
|
1429
|
1434
|
(VMV (OpReg targetFormat dst) (OpReg format_x reg_x))
|
| ... |
... |
@@ -2138,8 +2143,7 @@ genCCall target@(ForeignTarget expr _cconv) dest_regs arg_regs = do |
|
2138
|
2143
|
-- See Note [RISC-V vector C calling convention]
|
|
2139
|
2144
|
passArguments _gpRegs _fpRegs [] ((_r, format, _hint, _code_r) : _args) _stackSpaceWords _accumRegs _accumCode
|
|
2140
|
2145
|
| isVecFormat format =
|
|
2141
|
|
- panic "C call: no free vector argument registers. We only support 16 vector arguments (registers v8 - v23)."
|
|
2142
|
|
-
|
|
|
2146
|
+ panic "C call: no free vector argument registers. We only support 16 vector arguments (registers v8 - v23)."
|
|
2143
|
2147
|
passArguments _ _ _ _ _ _ _ = pprPanic "passArguments" (text "invalid state")
|
|
2144
|
2148
|
|
|
2145
|
2149
|
readResults :: [Reg] -> [Reg] -> [Reg] -> [LocalReg] -> [Reg] -> InstrBlock -> NatM InstrBlock
|