Skip to content

Commit

Permalink
Implement MulRound operation on RISCV
Browse files Browse the repository at this point in the history
Signed-off-by: Zoltan Herczeg [email protected]
  • Loading branch information
zherczeg authored and ksh8281 committed Feb 17, 2025
1 parent 41601dd commit b64a1ce
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/jit/ByteCodeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)
#define OTMinMaxV128 OTOp2V128
#define OTPMinMaxV128 OTOp2V128
#define OTSwizzleV128 OTOp2V128
#define OTMulRoundSatV128 OTOp2V128

#elif (defined SLJIT_CONFIG_ARM_64 && SLJIT_CONFIG_ARM_64)

Expand All @@ -302,6 +303,7 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)
#define OTMinMaxV128 OTOp2V128
#define OTPopcntV128 OTOp1V128
#define OTSwizzleV128 OTOp2V128
#define OTMulRoundSatV128 OTOp2V128
#define OTShiftV128Tmp OTShiftV128
#define OTOp3DotAddV128 OTOp3V128

Expand All @@ -323,6 +325,7 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)
#define OTOp2V128Rev OTOp2V128
#define OTPMinMaxV128 OTOp2V128
#define OTPopcntV128 OTOp1V128
#define OTMulRoundSatV128 OTOp2V128
#define OTShiftV128Tmp OTShiftV128
#define OTOp3DotAddV128 OTOp3V128

Expand All @@ -337,6 +340,7 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)
OL2(OTExtractLaneF64, /* SD */ V128 | TMP, F64) \
OL3(OTSwizzleV128, /* SSD */ V128 | TMP, V128 | NOTMP, V128 | TMP | S1) \
OL3(OTShuffleV128, /* SSD */ V128 | TMP, V128 | TMP, V128 | TMP) \
OL4(OTMulRoundSatV128, /* SSDT */ V128 | TMP, V128 | TMP, V128 | TMP | S0 | S1, V128) \
OL3(OTShiftV128, /* SSD */ V128 | NOTMP, I32, V128 | TMP | S0)

// List of aliases.
Expand Down Expand Up @@ -1645,8 +1649,6 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::I16X8ExtmulHighI8X16UOpcode:
case ByteCode::I16X8NarrowI32X4SOpcode:
case ByteCode::I16X8NarrowI32X4UOpcode:
case ByteCode::I16X8Q15mulrSatSOpcode:
case ByteCode::I16X8RelaxedQ15mulrSOpcode:
case ByteCode::I32X4AddOpcode:
case ByteCode::I32X4SubOpcode:
case ByteCode::I32X4MulOpcode:
Expand Down Expand Up @@ -1708,6 +1710,13 @@ static void compileFunction(JITCompiler* compiler)
requiredInit = OTSwizzleV128;
break;
}
case ByteCode::I16X8Q15mulrSatSOpcode:
case ByteCode::I16X8RelaxedQ15mulrSOpcode: {
group = Instruction::BinarySIMD;
paramType = ParamTypes::ParamSrc2Dst;
requiredInit = OTMulRoundSatV128;
break;
}
case ByteCode::I64X2MulOpcode:
case ByteCode::I64X2LtSOpcode:
case ByteCode::I64X2LeSOpcode:
Expand Down
21 changes: 21 additions & 0 deletions src/jit/SimdRiscvInl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ enum TypeOpcode : uint32_t {
vaaddu_vv = InstructionType::opmvv | OPCODE(0x8),
vadd_vi = InstructionType::opivi | OPCODE(0x0),
vadd_vv = InstructionType::opivv | OPCODE(0x0),
vadd_vx = InstructionType::opivx | OPCODE(0x0),
vand_vv = InstructionType::opivv | OPCODE(0x9),
vcompress_vm = InstructionType::opmvv | OPCODE(0x17),
#if defined(__riscv_zvbb)
Expand Down Expand Up @@ -818,6 +819,23 @@ static void simdEmitNarrowUnsigned(sljit_compiler* compiler, sljit_s32 type, slj
simdEmitTypedOp(compiler, SLJIT_SIMD_ELEM_8, SimdOp::vslideup_vi, rd, tmp2, 8, SimdOp::rmIsImm);
}

static void simdEmitQ15Mul(sljit_compiler* compiler, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm, sljit_s32 tmp3)
{
sljit_s32 tmp1 = SLJIT_TMP_DEST_VREG;
sljit_s32 tmp2 = SLJIT_VR0;

simdEmitTypedOp(compiler, SLJIT_SIMD_ELEM_16, SimdOp::vslidedown_vi, tmp2, rn, 4, SimdOp::rmIsImm);
simdEmitOp(compiler, SimdOp::vslidedown_vi, tmp3, rm, 4, SimdOp::rmIsImm);
simdEmitTypedOp(compiler, SLJIT_SIMD_ELEM_16, SimdOp::vwmul_vv, tmp1, tmp2, tmp3, 0, SimdOp::vlMulF2);
simdEmitOp(compiler, SimdOp::vwmul_vv, tmp2, rn, rm);

/* The vxrm register is expected to be zero (default). */
simdEmitTypedOp(compiler, SLJIT_SIMD_ELEM_16, SimdOp::vnclip_wi, rd, tmp2, 15, SimdOp::rmIsImm, SimdOp::vlMulF2);
simdEmitOp(compiler, SimdOp::vnclip_wi, tmp2, tmp1, 15, SimdOp::rmIsImm);

simdEmitTypedOp(compiler, SLJIT_SIMD_ELEM_16, SimdOp::vslideup_vi, rd, tmp2, 4, SimdOp::rmIsImm);
}

static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
{
Operand* operands = instr->operands();
Expand Down Expand Up @@ -885,6 +903,7 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
case ByteCode::I16X8MaxUOpcode:
case ByteCode::I16X8AvgrUOpcode:
case ByteCode::I16X8Q15mulrSatSOpcode:
case ByteCode::I16X8RelaxedQ15mulrSOpcode:
srcType = SLJIT_SIMD_ELEM_16;
dstType = SLJIT_SIMD_ELEM_16;
break;
Expand Down Expand Up @@ -1141,6 +1160,8 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
simdEmitExtmulHigh(compiler, srcType, SimdOp::vwmulu_vv, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I16X8Q15mulrSatSOpcode:
case ByteCode::I16X8RelaxedQ15mulrSOpcode:
simdEmitQ15Mul(compiler, dst, args[0].arg, args[1].arg, instr->requiredReg(3));
break;
case ByteCode::I32X4DotI16X8SOpcode:
case ByteCode::I16X8DotI8X16I7X16SOpcode:
Expand Down

0 comments on commit b64a1ce

Please sign in to comment.