Skip to content

Commit 5516e90

Browse files
committed
Add preliminary RISC-V vector support (Assembly only)
Signed-off-by: Patrick O'Neill <patrick@rivosinc.com>
1 parent 301e765 commit 5516e90

File tree

11 files changed

+2853
-39
lines changed

11 files changed

+2853
-39
lines changed

src/microprobe/code/ins.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
# Built-in modules
2222
import copy
2323
from itertools import product
24-
from typing import TYPE_CHECKING, Callable, List
24+
from typing import TYPE_CHECKING, Callable, Dict, List
2525

2626
# Third party modules
2727
import six
@@ -1595,7 +1595,8 @@ def __init__(self):
15951595
self._generic_type = None
15961596
self._label = None
15971597
self._mem_operands = []
1598-
self._operands = RejectingOrderedDict()
1598+
self._operands: Dict[str,
1599+
InstructionOperandValue] = RejectingOrderedDict()
15991600

16001601
def set_arch_type(self, instrtype):
16011602
"""
@@ -1604,7 +1605,8 @@ def set_arch_type(self, instrtype):
16041605
16051606
"""
16061607
self._arch_type = instrtype
1607-
self._operands = RejectingOrderedDict()
1608+
self._operands: Dict[str,
1609+
InstructionOperandValue] = RejectingOrderedDict()
16081610
self._mem_operands = []
16091611
self._allowed_regs = []
16101612
self._address = None

src/microprobe/passes/initialization/__init__.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
"""
1717

1818
# Futures
19-
from __future__ import absolute_import, print_function
19+
from __future__ import absolute_import, print_function, annotations
2020

2121
# Built-in modules
22+
from typing import TYPE_CHECKING
2223

2324
# Third party modules
2425
from six.moves import zip
@@ -36,6 +37,10 @@
3637

3738
# Local modules
3839

40+
# Type hinting
41+
if TYPE_CHECKING:
42+
from microprobe.code.benchmark import Benchmark
43+
from microprobe.target import Target
3944

4045
# Constants
4146
LOG = get_logger(__name__)
@@ -222,6 +227,7 @@ def __init__(self, *args, **kwargs):
222227
skip_unknown = kwargs.get("skip_unknown", False)
223228
warn_unknown = kwargs.get("warn_unknown", False)
224229
self._force_code = kwargs.get("force_code", False)
230+
self.lmul = kwargs.get("lmul", 1)
225231

226232
if len(args) == 1:
227233
self._reg_dict = dict([
@@ -250,7 +256,7 @@ def __init__(self, *args, **kwargs):
250256
self._fp_value,
251257
v_value)
252258

253-
def __call__(self, building_block, target):
259+
def __call__(self, building_block: Benchmark, target: Target):
254260
"""
255261
256262
:param building_block:
@@ -259,26 +265,26 @@ def __call__(self, building_block, target):
259265
"""
260266
if not self._skip_unknown:
261267
for register_name in self._reg_dict:
262-
if register_name not in list(target.registers.keys()):
268+
if register_name not in list(target.isa.registers.keys()):
263269
raise MicroprobeCodeGenerationError(
264270
"Unknown register name: '%s'. Unable to set it" %
265271
register_name)
266272

267273
if self._warn_unknown:
268274
for register_name in self._reg_dict:
269-
if register_name not in list(target.registers.keys()):
275+
if register_name not in list(target.isa.registers.keys()):
270276
print_warning(
271277
"Unknown register name: '%s'. Unable to set it" %
272278
register_name)
273279

274-
regs = sorted(target.registers.values(),
280+
regs = sorted(target.isa.registers.values(),
275281
key=lambda x: self._priolist.index(x.name)
276282
if x.name in self._priolist else 314159)
277283

278284
#
279285
# Make sure scratch registers are set last
280286
#
281-
for reg in target.scratch_registers:
287+
for reg in target.isa.scratch_registers:
282288
if reg in regs:
283289
regs.remove(reg)
284290
regs.append(reg)
@@ -294,25 +300,39 @@ def __call__(self, building_block, target):
294300
self._reg_dict.pop(reg.name)
295301
force_direct = True
296302

297-
if (reg in building_block.context.reserved_registers and
298-
not self._force_reserved):
303+
if reg.name == "LMUL":
304+
building_block.add_init(
305+
target.isa.set_register(reg, self.lmul,
306+
building_block.context))
307+
building_block.context.set_register_value(reg, self.lmul)
308+
continue
309+
310+
all_vec_regs = set([f"V{i}" for i in range(0, 32)])
311+
lmul_allowed_regs = set([f"V{i}" for i in range(0, 32, self.lmul)])
312+
313+
if reg.name in all_vec_regs - lmul_allowed_regs:
314+
# Skip vector registers ignored by lmul
315+
continue
316+
317+
if (reg in building_block.context.reserved_registers
318+
and not self._force_reserved):
299319
LOG.debug("Skip reserved - %s", reg)
300320
continue
301-
elif (reg in target.control_registers and
302-
(value is None or self._skip_control)):
321+
elif (reg in target.isa.control_registers
322+
and (value is None or self._skip_control)):
303323
LOG.debug("Skip control - %s", reg)
304324
continue
305325

306326
if value is None:
307-
if reg.used_for_vector_arithmetic:
327+
if reg.type.used_for_vector_arithmetic:
308328
if self._vect_value is not None:
309329
value = self._vect_value
310330
elemsize = self._vect_elemsize
311331
else:
312332
LOG.debug("Skip no vector default value provided - %s",
313333
reg)
314334
continue
315-
elif reg.used_for_float_arithmetic:
335+
elif reg.type.used_for_float_arithmetic:
316336
if self._fp_value is not None:
317337
value = self._fp_value
318338
else:
@@ -332,10 +352,10 @@ def __call__(self, building_block, target):
332352
if isinstance(value, int):
333353
value = value & ((2**reg.size)-1)
334354

335-
if reg.used_for_float_arithmetic:
355+
if reg.type.used_for_float_arithmetic:
336356
value = ieee_float_to_int64(float(value))
337357

338-
elif reg.used_for_vector_arithmetic:
358+
elif reg.type.used_for_vector_arithmetic:
339359
if isinstance(value, float):
340360
if elemsize != 64:
341361
raise MicroprobeCodeGenerationError(
@@ -360,13 +380,13 @@ def __call__(self, building_block, target):
360380
else:
361381
LOG.debug("Direct set of '%s' to '0x%x'", reg, value)
362382
except MicroprobeCodeGenerationError:
363-
building_block.add_init(target.set_register(
383+
building_block.add_init(target.isa.set_register(
364384
reg, value, building_block.context))
365385
LOG.debug("Set '%s' to '0x%x'", reg, value)
366386
except MicroprobeDuplicatedValueError:
367387
LOG.debug("Skip already set - %s", reg)
368388
else:
369-
building_block.add_init(target.set_register(
389+
building_block.add_init(target.isa.set_register(
370390
reg, value, building_block.context))
371391
building_block.context.set_register_value(reg, value)
372392

src/microprobe/target/isa/instruction.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1763,6 +1763,13 @@ def assembly(self, args, dissabled_fields=None):
17631763
"," + field.name + ")",
17641764
"," + next_operand_value().representation + ")", 1)
17651765

1766+
elif assembly_str.find(" " + field.name + ".t") >= 0:
1767+
assembly_str = assembly_str.replace(
1768+
", " + field.name + ".t",
1769+
", " + next_operand_value().representation + ".t",
1770+
1,
1771+
)
1772+
17661773
else:
17671774
LOG.debug(
17681775
"%s",

src/microprobe/target/isa/operand.py

Lines changed: 97 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
"""
1717

1818
# Futures
19-
from __future__ import absolute_import, print_function
19+
from __future__ import absolute_import, print_function, annotations
2020

2121
# Built-in modules
2222
import abc
2323
import os
2424
import random
25+
from typing import Dict, List, TYPE_CHECKING, cast
2526

2627
# Third party modules
2728
import six
@@ -38,6 +39,10 @@
3839
from microprobe.utils.misc import OrderedDict, natural_sort
3940
from microprobe.utils.yaml import read_yaml
4041

42+
# Type hinting
43+
if TYPE_CHECKING:
44+
from microprobe.code.context import Context
45+
4146
# Constants
4247
SCHEMA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "schemas",
4348
"operand.yaml")
@@ -281,7 +286,7 @@ class OperandDescriptor:
281286
282287
"""
283288

284-
def __init__(self, mtype, is_input, is_output):
289+
def __init__(self, mtype: Operand, is_input, is_output):
285290
"""
286291
287292
:param mtype:
@@ -308,7 +313,7 @@ def is_output(self):
308313
"""Is output flag (:class:`~.bool`) """
309314
return self._is_output
310315

311-
def set_type(self, new_type):
316+
def set_type(self, new_type: Operand):
312317
"""
313318
314319
:param new_type:
@@ -609,7 +614,14 @@ def copy(self):
609614
raise NotImplementedError
610615

611616
@abc.abstractmethod
612-
def values(self):
617+
def values(self) -> List[Register]:
618+
"""Return the possible value of the operand."""
619+
raise NotImplementedError
620+
621+
# TODO: Consider making filtered_values into values.
622+
def filtered_values(
623+
self, context: Context, fieldname: str
624+
) -> List[Register]:
613625
"""Return the possible value of the operand."""
614626
raise NotImplementedError
615627

@@ -759,8 +771,14 @@ class OperandReg(Operand):
759771
760772
"""
761773

762-
def __init__(self, name, descr, regs, address_base, address_index,
763-
floating_point, vector):
774+
def __init__(self,
775+
name: str,
776+
descr: str,
777+
regs: List[Register] | Dict[Register, List[Register]],
778+
address_base,
779+
address_index: int,
780+
floating_point: bool | None,
781+
vector: bool | None):
764782
"""
765783
766784
:param name:
@@ -775,7 +793,7 @@ def __init__(self, name, descr, regs, address_base, address_index,
775793
super(OperandReg, self).__init__(name, descr)
776794

777795
if isinstance(regs, list):
778-
self._regs = OrderedDict()
796+
self._regs: Dict[Register, List[Register]] = OrderedDict()
779797
for reg in regs:
780798
self._regs[reg] = [reg]
781799
else:
@@ -801,6 +819,53 @@ def values(self):
801819
"""
802820
return list(self._regs.keys())
803821

822+
def filtered_values(self, context: Context, fieldname: str):
823+
lmul = cast(int | None, context.get_registername_value("LMUL"))
824+
825+
if lmul is None or not fieldname.startswith("v"):
826+
return self.values()
827+
elif fieldname in ["vd", "vmd", "vrs1", "vrs2", "vmask"]:
828+
lmul *= 1
829+
elif fieldname in ["vdd", "vdmd", "vdrs1", "vdrs2", "vnd", "vnmd"]:
830+
lmul *= 2
831+
elif fieldname in []:
832+
lmul *= 4
833+
elif fieldname in []:
834+
lmul *= 8
835+
else:
836+
raise ValueError(f"Unhandled LMUL operand name: {fieldname}")
837+
838+
regs = list(self._regs.keys())
839+
840+
class LMULRegs:
841+
lmul1 = regs
842+
lmul2 = [
843+
reg
844+
for reg in self._regs.keys()
845+
if reg.name in set([f"V{i}" for i in range(0, 32, 2)])
846+
]
847+
lmul4 = [
848+
reg
849+
for reg in self._regs.keys()
850+
if reg.name in set([f"V{i}" for i in range(0, 32, 4)])
851+
]
852+
lmul8 = [
853+
reg
854+
for reg in self._regs.keys()
855+
if reg.name in set([f"V{i}" for i in range(0, 32, 8)])
856+
]
857+
858+
if lmul == 1:
859+
return LMULRegs.lmul1
860+
elif lmul == 2:
861+
return LMULRegs.lmul2
862+
elif lmul == 4:
863+
return LMULRegs.lmul4
864+
elif lmul == 8:
865+
return LMULRegs.lmul8
866+
else:
867+
raise ValueError(f"Unhandled LMUL value: {lmul}")
868+
804869
def representation(self, value):
805870
"""
806871
@@ -918,6 +983,11 @@ def values(self):
918983
]
919984
return self._computed_values
920985

986+
def filtered_values(
987+
self, context: Context, fieldname: str
988+
):
989+
return super().filtered_values(context, fieldname)
990+
921991
def set_valid_values(self, values):
922992
"""
923993
@@ -1073,6 +1143,11 @@ def values(self):
10731143
"""
10741144
return self._values
10751145

1146+
def filtered_values(
1147+
self, context: Context, fieldname: str
1148+
):
1149+
return super().filtered_values(context, fieldname)
1150+
10761151
def representation(self, value):
10771152
"""
10781153
@@ -1166,6 +1241,11 @@ def values(self):
11661241
"""
11671242
return [self._value]
11681243

1244+
def filtered_values(
1245+
self, context: Context, fieldname: str
1246+
):
1247+
return super().filtered_values(context, fieldname)
1248+
11691249
def representation(self, value):
11701250
"""
11711251
@@ -1273,6 +1353,11 @@ def values(self):
12731353
"""
12741354
return [self._reg]
12751355

1356+
def filtered_values(
1357+
self, context: Context, fieldname: str
1358+
):
1359+
return super().filtered_values(context, fieldname)
1360+
12761361
def random_value(self):
12771362
"""Return a random possible value for the operand.
12781363
@@ -1380,6 +1465,11 @@ def values(self):
13801465
"""
13811466
return [self._mindispl << self._shift]
13821467

1468+
def filtered_values(
1469+
self, context: Context, fieldname: str
1470+
):
1471+
return super().filtered_values(context, fieldname)
1472+
13831473
def random_value(self):
13841474
"""Return a random possible value for the operand.
13851475

0 commit comments

Comments
 (0)