Skip to content

Commit b7c04f3

Browse files
committed
Add preliminary RISC-V vector support (Assembly only)
Signed-off-by: Patrick O'Neill <[email protected]>
1 parent 7443d60 commit b7c04f3

File tree

14 files changed

+4203
-65
lines changed

14 files changed

+4203
-65
lines changed

dev_tools/parsers/parse_binutils_riscv.py

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99

1010
VECTOR_EXTS = [
1111
"INSN_CLASS_V",
12+
"INSN_CLASS_ZVBB",
13+
"INSN_CLASS_ZVBC",
14+
"INSN_CLASS_ZVKNED",
15+
"INSN_CLASS_ZVKN",
16+
"INSN_CLASS_ZVKS",
17+
"INSN_CLASS_ZVEF",
1218
]
1319

1420

@@ -88,6 +94,25 @@ def get_pattern_info() -> Dict[str, PatternInfo]:
8894
["funct10", "vd", "vrs2", "rs1", "opcode"], "OPC vd, rs1, vrs2"
8995
)
9096

97+
# Vec, Float -> Vec patterns
98+
pattern_info_dict["Vd,Vt,S"] = PatternInfo(
99+
["funct10", "vd", "vrs2", "frs1", "opcode"], "OPC vd, vrs2, frs1"
100+
)
101+
pattern_info_dict["Vd,S,Vt"] = PatternInfo(
102+
["funct10", "vd", "vrs2", "frs1", "opcode"],
103+
"OPC vd, frs1, vrs2",
104+
)
105+
106+
# Float -> Vec patterns
107+
pattern_info_dict["Vd,S"] = PatternInfo(
108+
["funct10", "funct5", "vd", "frs1", "opcode"], "OPC vd, frs1"
109+
)
110+
111+
# Vec -> Float patterns
112+
pattern_info_dict["D,Vt"] = PatternInfo(
113+
["funct10", "funct5", "frd", "vrs2", "opcode"], "OPC frd, vrs2"
114+
)
115+
91116
# Vec -> Vec patterns
92117
pattern_info_dict["Vd,Vs"] = PatternInfo(
93118
["funct10", "funct5", "vd", "vrs1", "opcode"], "OPC vd, vrs1"
@@ -103,10 +128,6 @@ def get_pattern_info() -> Dict[str, PatternInfo]:
103128
["funct10", "vd", "vrs1", "vrs2", "opcode"],
104129
"OPC vd, vrs1, vrs2",
105130
)
106-
pattern_info_dict["Vd,S,Vt"] = PatternInfo(
107-
["funct10", "vd", "rs1", "vrs2", "opcode"],
108-
"OPC vd, rs1, vrs2",
109-
)
110131
pattern_info_dict["Vd,Vt"] = PatternInfo(
111132
["funct10", "funct5", "vd", "vrs2", "opcode"], "OPC vd, vrs2"
112133
)
@@ -150,6 +171,32 @@ def get_pattern_info() -> Dict[str, PatternInfo]:
150171
"OPC vd, vrs2, i_imm6",
151172
)
152173

174+
pattern_info_dict["Vd,Vt,Vs,V0"] = PatternInfo(
175+
["funct5", "vmd", "vrs2", "vrs1", "vmask", "opcode"],
176+
"OPC vmd, vrs1, vrs2, vmask",
177+
)
178+
179+
pattern_info_dict["Vd,Vt,s,V0"] = PatternInfo(
180+
["funct5", "vmd", "vrs2", "rs1", "vmask", "opcode"],
181+
"OPC vmd, vrs2, rs1, vmask",
182+
)
183+
184+
# Float
185+
pattern_info_dict["Vd,Vt,S,V0"] = PatternInfo(
186+
["funct5", "vmd", "vrs2", "frs1", "vmask", "opcode"],
187+
"OPC vmd, vrs2, frs1, vmask",
188+
)
189+
190+
pattern_info_dict["d,Vt"] = PatternInfo(
191+
["funct10", "funct5", "rd", "vrs2", "opcode"],
192+
"OPC rd, vrs2",
193+
)
194+
195+
pattern_info_dict["d,VtVm"] = PatternInfo(
196+
["funct10", "rd", "vrs2", "vmask", "opcode"],
197+
"OPC rd, vrs2, vmask.t",
198+
)
199+
153200
# Add masked variants
154201
masked_patterns: Dict[str, PatternInfo] = {}
155202
for pattern, pattern_info in pattern_info_dict.items():
@@ -210,26 +257,6 @@ def get_pattern_info() -> Dict[str, PatternInfo]:
210257

211258
pattern_info_dict |= widening_patterns
212259

213-
pattern_info_dict["Vd,Vt,Vs,V0"] = PatternInfo(
214-
["funct5", "vmd", "vrs2", "vrs1", "vmask", "opcode"],
215-
"OPC vmd, vrs1, vrs2, vmask",
216-
)
217-
218-
pattern_info_dict["Vd,Vt,s,V0"] = PatternInfo(
219-
["funct5", "vmd", "vrs2", "rs1", "vmask", "opcode"],
220-
"OPC vmd, vrs2, rs1, vmask",
221-
)
222-
223-
pattern_info_dict["d,Vt"] = PatternInfo(
224-
["funct10", "funct5", "rd", "vrs2", "opcode"],
225-
"OPC rd, vrs2",
226-
)
227-
228-
pattern_info_dict["d,VtVm"] = PatternInfo(
229-
["funct10", "rd", "vrs2", "vmask", "opcode"],
230-
"OPC rd, vrs2, vmask.t",
231-
)
232-
233260
return pattern_info_dict
234261

235262

src/microprobe/code/ins.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
# Built-in modules
2222
import copy
2323
from itertools import product
24-
from typing import TYPE_CHECKING, Callable, List
24+
from typing import TYPE_CHECKING, Callable, Dict, List
2525

2626
# Third party modules
2727

@@ -1603,7 +1603,8 @@ def __init__(self):
16031603
self._generic_type = None
16041604
self._label = None
16051605
self._mem_operands = []
1606-
self._operands = RejectingOrderedDict()
1606+
self._operands: Dict[str,
1607+
InstructionOperandValue] = RejectingOrderedDict()
16071608

16081609
def set_arch_type(self, instrtype):
16091610
"""
@@ -1612,7 +1613,8 @@ def set_arch_type(self, instrtype):
16121613
16131614
"""
16141615
self._arch_type = instrtype
1615-
self._operands = RejectingOrderedDict()
1616+
self._operands: Dict[str,
1617+
InstructionOperandValue] = RejectingOrderedDict()
16161618
self._mem_operands = []
16171619
self._allowed_regs = []
16181620
self._address = None

src/microprobe/passes/initialization/__init__.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
"""
1717

1818
# Futures
19-
from __future__ import absolute_import, print_function
19+
from __future__ import absolute_import, print_function, annotations
20+
21+
# Built-in modules
22+
from typing import TYPE_CHECKING
2023

2124
# Third party modules
2225

@@ -33,6 +36,11 @@
3336
from microprobe.utils.ieee import ieee_float_to_int64
3437
from microprobe.utils.logger import get_logger
3538

39+
# Type hinting
40+
if TYPE_CHECKING:
41+
from microprobe.code.benchmark import Benchmark
42+
from microprobe.target import Target
43+
3644
# Constants
3745
LOG = get_logger(__name__)
3846
__all__ = [
@@ -213,6 +221,8 @@ def __init__(self, *args, **kwargs):
213221
skip_unknown = kwargs.get("skip_unknown", False)
214222
warn_unknown = kwargs.get("warn_unknown", False)
215223
self._force_code = kwargs.get("force_code", False)
224+
self.lmul = kwargs.get("lmul", 1)
225+
self.sew = kwargs.get("sew", 32)
216226

217227
if len(args) == 1:
218228
self._reg_dict = dict(
@@ -242,7 +252,7 @@ def __init__(self, *args, **kwargs):
242252
% (self._value, self._fp_value, v_value)
243253
)
244254

245-
def __call__(self, building_block, target):
255+
def __call__(self, building_block: Benchmark, target: Target):
246256
"""
247257
248258
:param building_block:
@@ -251,15 +261,15 @@ def __call__(self, building_block, target):
251261
"""
252262
if not self._skip_unknown:
253263
for register_name in self._reg_dict:
254-
if register_name not in list(target.registers.keys()):
264+
if register_name not in list(target.isa.registers.keys()):
255265
raise MicroprobeCodeGenerationError(
256266
f"Unknown register name: '{register_name}'. "
257267
"Unable to set it"
258268
)
259269

260270
if self._warn_unknown:
261271
for register_name in self._reg_dict:
262-
if register_name not in list(target.registers.keys()):
272+
if register_name not in list(target.isa.registers.keys()):
263273
print_warning(
264274
f"Unknown register name: '{register_name}'. "
265275
"Unable to set it"
@@ -275,7 +285,7 @@ def __call__(self, building_block, target):
275285
#
276286
# Make sure scratch registers are set last
277287
#
278-
for reg in target.scratch_registers:
288+
for reg in target.isa.scratch_registers:
279289
if reg in regs:
280290
regs.remove(reg)
281291
regs.append(reg)
@@ -290,10 +300,23 @@ def __call__(self, building_block, target):
290300
self._reg_dict.pop(reg.name)
291301
force_direct = True
292302

293-
if (
294-
reg in building_block.context.reserved_registers
295-
and not self._force_reserved
296-
):
303+
if reg.name == "LMUL":
304+
packed_lmul_sew = self.lmul << 9 | self.sew & 127
305+
building_block.add_init(
306+
target.isa.set_register(reg, packed_lmul_sew, building_block.context)
307+
)
308+
building_block.context.set_register_value(reg, packed_lmul_sew)
309+
continue
310+
311+
all_vec_regs = set([f"V{i}" for i in range(0, 32)])
312+
lmul_allowed_regs = set([f"V{i}" for i in range(0, 32, self.lmul)])
313+
314+
if reg.name in all_vec_regs - lmul_allowed_regs:
315+
# Skip vector registers ignored by lmul
316+
continue
317+
318+
if (reg in building_block.context.reserved_registers
319+
and not self._force_reserved):
297320
LOG.debug("Skip reserved - %s", reg)
298321
continue
299322

@@ -304,7 +327,7 @@ def __call__(self, building_block, target):
304327
continue
305328

306329
if value is None:
307-
if reg.used_for_vector_arithmetic:
330+
if reg.type.used_for_vector_arithmetic:
308331
if self._vect_value is not None:
309332
value = self._vect_value
310333
elemsize = self._vect_elemsize
@@ -313,7 +336,7 @@ def __call__(self, building_block, target):
313336
"Skip no vector default value provided - %s", reg
314337
)
315338
continue
316-
elif reg.used_for_float_arithmetic:
339+
elif reg.type.used_for_float_arithmetic:
317340
if self._fp_value is not None:
318341
value = self._fp_value
319342
else:
@@ -334,10 +357,10 @@ def __call__(self, building_block, target):
334357
if isinstance(value, int):
335358
value = value & ((2**reg.size) - 1)
336359

337-
if reg.used_for_float_arithmetic:
360+
if reg.type.used_for_float_arithmetic:
338361
value = ieee_float_to_int64(float(value))
339362

340-
elif reg.used_for_vector_arithmetic:
363+
elif reg.type.used_for_vector_arithmetic:
341364
if isinstance(value, float):
342365
if elemsize != 64:
343366
raise MicroprobeCodeGenerationError(

src/microprobe/passes/register/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(self,
5555
rand: random.Random,
5656
minimize=False,
5757
value=None,
58+
lmul=1,
5859
dd: Union[int, float] = 0,
5960
relax=False):
6061
"""
@@ -80,6 +81,8 @@ def __init__(self,
8081
else:
8182
raise MicroprobeValueError("Invalid parameter")
8283

84+
self._lmul = lmul
85+
8386
self._relax = relax
8487

8588
if value is not None:
@@ -101,6 +104,17 @@ def __call__(self, building_block, target):
101104
"""
102105

103106
allregs = target.registers
107+
108+
# RISC-V Specific
109+
all_vec_regs = set([f"V{i}" for i in range(0, 32)])
110+
if self._dd != 0:
111+
# Play it safe, we might have a narrowing/widening insn.
112+
# Pretend our lmul is one higher than it is.
113+
lmul_allowed_regs = set([f"V{i}" for i in range(0, 32, self._lmul * 2)])
114+
for lmul_disallowed_reg in all_vec_regs - lmul_allowed_regs:
115+
if lmul_disallowed_reg in allregs:
116+
del allregs[lmul_disallowed_reg]
117+
104118
lastdefined = {}
105119
lastused = {}
106120
rregs = set(building_block.context.reserved_registers)

src/microprobe/target/isa/instruction.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,6 +1779,13 @@ def assembly(self, args, dissabled_fields=None):
17791779
"," + field.name + ")",
17801780
"," + next_operand_value().representation + ")", 1)
17811781

1782+
elif assembly_str.find(" " + field.name + ".t") >= 0:
1783+
assembly_str = assembly_str.replace(
1784+
", " + field.name + ".t",
1785+
", " + next_operand_value().representation + ".t",
1786+
1,
1787+
)
1788+
17821789
else:
17831790
LOG.debug(
17841791
"%s",

0 commit comments

Comments
 (0)