diff --git a/peak/ir.py b/peak/ir.py index 560f74fa..0491a910 100644 --- a/peak/ir.py +++ b/peak/ir.py @@ -3,6 +3,7 @@ from hwtypes.adt import Product from hwtypes import AbstractBitVector, AbstractBit, BitVector, Bit from .peak import Peak +from . import family as peak_family from .features import name_outputs, family_closure import itertools as it from hwtypes.adt_util import rebind_type @@ -19,7 +20,7 @@ def add_instruction(self, name, peak_fc : tp.Callable): self.instructions[name] = peak_fc #fun should have the form def fun(family, *args) - def add_peak_instruction(self, name : str, input_interface : Product, output_interface : Product, fun : tp.Callable, cls_name=None): + def add_peak_instruction(self, name : str, input_interface : Product, output_interface : Product, fun : tp.Callable, family=peak_family, cls_name=None): if cls_name is None: cls_name = name #Assuming for now that abstract bitvectors are used in the interfaces @@ -32,10 +33,11 @@ def add_peak_instruction(self, name : str, input_interface : Product, output_int if not t in t_to_tname: t_to_tname[t] = f"t{idx}" idx +=1 - class_src = [f"@_family_closure"] + class_src = [f"@_family_closure(_family)"] class_src.append(f"def peak_fc(family):") for t, tname in t_to_tname.items(): - class_src.append(f"{tab*1}_{tname} = _rebind_type({tname}, family)") + #class_src.append(f"{tab*1}_{tname} = _rebind_type({tname}, family)") + class_src.append(f"{tab * 1}_{tname} = {tname}") class_src.append(f"{tab*1}class {cls_name}(Peak):") output_types = ", ".join([f"{field} = _{t_to_tname[t]}" for field, t in outputs.items()]) input_types = ", ".join([f"{field} : _{t_to_tname[t]}" for field, t in inputs.items()]) @@ -52,8 +54,9 @@ def add_peak_instruction(self, name : str, input_interface : Product, output_int Peak=Peak, name_outputs=name_outputs, _fun_=fun, - _rebind_type=rebind_type, - _family_closure=family_closure + #_rebind_type=rebind_type, + _family_closure=family_closure, + _family=family )) exec(class_src, exec_gs, exec_ls) peak_fc = exec_ls["peak_fc"] diff --git a/peak/mapper/mapper.py b/peak/mapper/mapper.py index a02af079..dcc6eeed 100644 --- a/peak/mapper/mapper.py +++ b/peak/mapper/mapper.py @@ -91,7 +91,8 @@ def __init__(self, peak_fc : tp.Callable, family=peak_family): output_forms = [] self.const_valid_conditions = [] const_fields = [field for field, T in input_t.field_dict.items() if issubclass(T, Const)] - + print(const_fields) + print(list(input_t.field_dict.items())) for input_form in input_forms: inputs = aadt_product_to_dict(input_form.value) self.const_valid_conditions.append([is_valid(inputs[field]) for field in const_fields]) @@ -517,8 +518,7 @@ def rr_from_solver(solver, irmapper): bv_ibinding = strip_aadt(bv_ibinding) return RewriteRule(bv_ibinding, obinding, im.peak_fc, am.peak_fc) -def external_loop_solve(y, phi, logic = BV, maxloops = 10, solver_name = "cvc4", irmapper = None): - +def external_loop_solve(y, phi, logic = BV, maxloops = 10, solver_name = "z3", irmapper = None): y = set(y) x = phi.get_free_variables() - y diff --git a/tests/test_ir.py b/tests/test_ir.py index a81e1a2f..1be9c760 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -1,8 +1,6 @@ from peak import family_closure from peak.ir import IR from hwtypes import BitVector, Bit, UIntVector -from hwtypes import AbstractBitVector as ABV -from hwtypes import AbstractBit as ABit from hwtypes import SMTBit from hwtypes.adt import Product, Sum, Enum from examples.smallir import gen_SmallIR @@ -20,13 +18,13 @@ def rand_value(width): def test_add_peak_instruction(): class Input(Product): - a = ABV[16] - b = ABV[16] - c = ABit + a = BitVector[16] + b = BitVector[16] + c = Bit class Output(Product): - x = ABV[16] - y = ABit + x = BitVector[16] + y = Bit ir = IR() def fun(family, a, b, c):