Skip to content

Commit

Permalink
Add preliminary RISC-V vector support (Assembly only)
Browse files Browse the repository at this point in the history
Signed-off-by: Patrick O'Neill <[email protected]>
  • Loading branch information
patrick-rivos committed Nov 16, 2024
1 parent 7443d60 commit b7c04f3
Show file tree
Hide file tree
Showing 14 changed files with 4,203 additions and 65 deletions.
75 changes: 51 additions & 24 deletions dev_tools/parsers/parse_binutils_riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@

VECTOR_EXTS = [
"INSN_CLASS_V",
"INSN_CLASS_ZVBB",
"INSN_CLASS_ZVBC",
"INSN_CLASS_ZVKNED",
"INSN_CLASS_ZVKN",
"INSN_CLASS_ZVKS",
"INSN_CLASS_ZVEF",
]


Expand Down Expand Up @@ -88,6 +94,25 @@ def get_pattern_info() -> Dict[str, PatternInfo]:
["funct10", "vd", "vrs2", "rs1", "opcode"], "OPC vd, rs1, vrs2"
)

# Vec, Float -> Vec patterns
pattern_info_dict["Vd,Vt,S"] = PatternInfo(
["funct10", "vd", "vrs2", "frs1", "opcode"], "OPC vd, vrs2, frs1"
)
pattern_info_dict["Vd,S,Vt"] = PatternInfo(
["funct10", "vd", "vrs2", "frs1", "opcode"],
"OPC vd, frs1, vrs2",
)

# Float -> Vec patterns
pattern_info_dict["Vd,S"] = PatternInfo(
["funct10", "funct5", "vd", "frs1", "opcode"], "OPC vd, frs1"
)

# Vec -> Float patterns
pattern_info_dict["D,Vt"] = PatternInfo(
["funct10", "funct5", "frd", "vrs2", "opcode"], "OPC frd, vrs2"
)

# Vec -> Vec patterns
pattern_info_dict["Vd,Vs"] = PatternInfo(
["funct10", "funct5", "vd", "vrs1", "opcode"], "OPC vd, vrs1"
Expand All @@ -103,10 +128,6 @@ def get_pattern_info() -> Dict[str, PatternInfo]:
["funct10", "vd", "vrs1", "vrs2", "opcode"],
"OPC vd, vrs1, vrs2",
)
pattern_info_dict["Vd,S,Vt"] = PatternInfo(
["funct10", "vd", "rs1", "vrs2", "opcode"],
"OPC vd, rs1, vrs2",
)
pattern_info_dict["Vd,Vt"] = PatternInfo(
["funct10", "funct5", "vd", "vrs2", "opcode"], "OPC vd, vrs2"
)
Expand Down Expand Up @@ -150,6 +171,32 @@ def get_pattern_info() -> Dict[str, PatternInfo]:
"OPC vd, vrs2, i_imm6",
)

pattern_info_dict["Vd,Vt,Vs,V0"] = PatternInfo(
["funct5", "vmd", "vrs2", "vrs1", "vmask", "opcode"],
"OPC vmd, vrs1, vrs2, vmask",
)

pattern_info_dict["Vd,Vt,s,V0"] = PatternInfo(
["funct5", "vmd", "vrs2", "rs1", "vmask", "opcode"],
"OPC vmd, vrs2, rs1, vmask",
)

# Float
pattern_info_dict["Vd,Vt,S,V0"] = PatternInfo(
["funct5", "vmd", "vrs2", "frs1", "vmask", "opcode"],
"OPC vmd, vrs2, frs1, vmask",
)

pattern_info_dict["d,Vt"] = PatternInfo(
["funct10", "funct5", "rd", "vrs2", "opcode"],
"OPC rd, vrs2",
)

pattern_info_dict["d,VtVm"] = PatternInfo(
["funct10", "rd", "vrs2", "vmask", "opcode"],
"OPC rd, vrs2, vmask.t",
)

# Add masked variants
masked_patterns: Dict[str, PatternInfo] = {}
for pattern, pattern_info in pattern_info_dict.items():
Expand Down Expand Up @@ -210,26 +257,6 @@ def get_pattern_info() -> Dict[str, PatternInfo]:

pattern_info_dict |= widening_patterns

pattern_info_dict["Vd,Vt,Vs,V0"] = PatternInfo(
["funct5", "vmd", "vrs2", "vrs1", "vmask", "opcode"],
"OPC vmd, vrs1, vrs2, vmask",
)

pattern_info_dict["Vd,Vt,s,V0"] = PatternInfo(
["funct5", "vmd", "vrs2", "rs1", "vmask", "opcode"],
"OPC vmd, vrs2, rs1, vmask",
)

pattern_info_dict["d,Vt"] = PatternInfo(
["funct10", "funct5", "rd", "vrs2", "opcode"],
"OPC rd, vrs2",
)

pattern_info_dict["d,VtVm"] = PatternInfo(
["funct10", "rd", "vrs2", "vmask", "opcode"],
"OPC rd, vrs2, vmask.t",
)

return pattern_info_dict


Expand Down
8 changes: 5 additions & 3 deletions src/microprobe/code/ins.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# Built-in modules
import copy
from itertools import product
from typing import TYPE_CHECKING, Callable, List
from typing import TYPE_CHECKING, Callable, Dict, List

# Third party modules

Expand Down Expand Up @@ -1603,7 +1603,8 @@ def __init__(self):
self._generic_type = None
self._label = None
self._mem_operands = []
self._operands = RejectingOrderedDict()
self._operands: Dict[str,
InstructionOperandValue] = RejectingOrderedDict()

def set_arch_type(self, instrtype):
"""
Expand All @@ -1612,7 +1613,8 @@ def set_arch_type(self, instrtype):
"""
self._arch_type = instrtype
self._operands = RejectingOrderedDict()
self._operands: Dict[str,
InstructionOperandValue] = RejectingOrderedDict()
self._mem_operands = []
self._allowed_regs = []
self._address = None
Expand Down
49 changes: 36 additions & 13 deletions src/microprobe/passes/initialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
"""

# Futures
from __future__ import absolute_import, print_function
from __future__ import absolute_import, print_function, annotations

# Built-in modules
from typing import TYPE_CHECKING

# Third party modules

Expand All @@ -33,6 +36,11 @@
from microprobe.utils.ieee import ieee_float_to_int64
from microprobe.utils.logger import get_logger

# Type hinting
if TYPE_CHECKING:
from microprobe.code.benchmark import Benchmark
from microprobe.target import Target

# Constants
LOG = get_logger(__name__)
__all__ = [
Expand Down Expand Up @@ -213,6 +221,8 @@ def __init__(self, *args, **kwargs):
skip_unknown = kwargs.get("skip_unknown", False)
warn_unknown = kwargs.get("warn_unknown", False)
self._force_code = kwargs.get("force_code", False)
self.lmul = kwargs.get("lmul", 1)
self.sew = kwargs.get("sew", 32)

if len(args) == 1:
self._reg_dict = dict(
Expand Down Expand Up @@ -242,7 +252,7 @@ def __init__(self, *args, **kwargs):
% (self._value, self._fp_value, v_value)
)

def __call__(self, building_block, target):
def __call__(self, building_block: Benchmark, target: Target):
"""
:param building_block:
Expand All @@ -251,15 +261,15 @@ def __call__(self, building_block, target):
"""
if not self._skip_unknown:
for register_name in self._reg_dict:
if register_name not in list(target.registers.keys()):
if register_name not in list(target.isa.registers.keys()):
raise MicroprobeCodeGenerationError(
f"Unknown register name: '{register_name}'. "
"Unable to set it"
)

if self._warn_unknown:
for register_name in self._reg_dict:
if register_name not in list(target.registers.keys()):
if register_name not in list(target.isa.registers.keys()):
print_warning(
f"Unknown register name: '{register_name}'. "
"Unable to set it"
Expand All @@ -275,7 +285,7 @@ def __call__(self, building_block, target):
#
# Make sure scratch registers are set last
#
for reg in target.scratch_registers:
for reg in target.isa.scratch_registers:
if reg in regs:
regs.remove(reg)
regs.append(reg)
Expand All @@ -290,10 +300,23 @@ def __call__(self, building_block, target):
self._reg_dict.pop(reg.name)
force_direct = True

if (
reg in building_block.context.reserved_registers
and not self._force_reserved
):
if reg.name == "LMUL":
packed_lmul_sew = self.lmul << 9 | self.sew & 127
building_block.add_init(
target.isa.set_register(reg, packed_lmul_sew, building_block.context)
)
building_block.context.set_register_value(reg, packed_lmul_sew)
continue

all_vec_regs = set([f"V{i}" for i in range(0, 32)])
lmul_allowed_regs = set([f"V{i}" for i in range(0, 32, self.lmul)])

if reg.name in all_vec_regs - lmul_allowed_regs:
# Skip vector registers ignored by lmul
continue

if (reg in building_block.context.reserved_registers
and not self._force_reserved):
LOG.debug("Skip reserved - %s", reg)
continue

Expand All @@ -304,7 +327,7 @@ def __call__(self, building_block, target):
continue

if value is None:
if reg.used_for_vector_arithmetic:
if reg.type.used_for_vector_arithmetic:
if self._vect_value is not None:
value = self._vect_value
elemsize = self._vect_elemsize
Expand All @@ -313,7 +336,7 @@ def __call__(self, building_block, target):
"Skip no vector default value provided - %s", reg
)
continue
elif reg.used_for_float_arithmetic:
elif reg.type.used_for_float_arithmetic:
if self._fp_value is not None:
value = self._fp_value
else:
Expand All @@ -334,10 +357,10 @@ def __call__(self, building_block, target):
if isinstance(value, int):
value = value & ((2**reg.size) - 1)

if reg.used_for_float_arithmetic:
if reg.type.used_for_float_arithmetic:
value = ieee_float_to_int64(float(value))

elif reg.used_for_vector_arithmetic:
elif reg.type.used_for_vector_arithmetic:
if isinstance(value, float):
if elemsize != 64:
raise MicroprobeCodeGenerationError(
Expand Down
14 changes: 14 additions & 0 deletions src/microprobe/passes/register/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self,
rand: random.Random,
minimize=False,
value=None,
lmul=1,
dd: Union[int, float] = 0,
relax=False):
"""
Expand All @@ -80,6 +81,8 @@ def __init__(self,
else:
raise MicroprobeValueError("Invalid parameter")

self._lmul = lmul

self._relax = relax

if value is not None:
Expand All @@ -101,6 +104,17 @@ def __call__(self, building_block, target):
"""

allregs = target.registers

# RISC-V Specific
all_vec_regs = set([f"V{i}" for i in range(0, 32)])
if self._dd != 0:
# Play it safe, we might have a narrowing/widening insn.
# Pretend our lmul is one higher than it is.
lmul_allowed_regs = set([f"V{i}" for i in range(0, 32, self._lmul * 2)])
for lmul_disallowed_reg in all_vec_regs - lmul_allowed_regs:
if lmul_disallowed_reg in allregs:
del allregs[lmul_disallowed_reg]

lastdefined = {}
lastused = {}
rregs = set(building_block.context.reserved_registers)
Expand Down
7 changes: 7 additions & 0 deletions src/microprobe/target/isa/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1779,6 +1779,13 @@ def assembly(self, args, dissabled_fields=None):
"," + field.name + ")",
"," + next_operand_value().representation + ")", 1)

elif assembly_str.find(" " + field.name + ".t") >= 0:
assembly_str = assembly_str.replace(
", " + field.name + ".t",
", " + next_operand_value().representation + ".t",
1,
)

else:
LOG.debug(
"%s",
Expand Down
Loading

0 comments on commit b7c04f3

Please sign in to comment.