From 8fc2a552ef4b2e53ccd8030326243959582c096d Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Fri, 16 Jul 2021 15:18:06 -0700 Subject: [PATCH] Add test --- tests/test_magma.py | 92 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 89 insertions(+), 3 deletions(-) diff --git a/tests/test_magma.py b/tests/test_magma.py index d2a1bb6..17be6a2 100644 --- a/tests/test_magma.py +++ b/tests/test_magma.py @@ -5,13 +5,16 @@ from peak import family from peak import name_outputs -from hwtypes import Bit, SMTBit, SMTBitVector, BitVector, Enum +from hwtypes import Bit, SMTBit, SMTBitVector, BitVector, Enum, Product, Sum from examples.demo_pes.pe6 import PE_fc from examples.demo_pes.pe6.sim import Inst import fault import magma import itertools +import random + +N_TESTS = 16 @pytest.mark.parametrize('named_outputs', [True, False]) @@ -71,8 +74,8 @@ def __call__(self, in0: Bit, in1: Bit) -> Bit: def test_enum(): class Op(Enum): - And=1 - Or=2 + And=Enum.Auto() + Or=Enum.Auto() @family_closure def PE_fc(family): @@ -121,6 +124,89 @@ def __call__(self, op: Const(Op), in0: Bit, in1: Bit) -> Bit: tester.circuit.O.expect(gold) tester.compile_and_run("verilator", flags=["-Wno-fatal"]) + +def test_sum(): + Data = BitVector[16] + class Op(Enum): + Add = Enum.Auto() + Or = Enum.Auto() + + class ImmOp(Product): + opcode = Op + imm = Data + + + class Inst(Sum[ImmOp, Op]): pass + + @family_closure + def PE_fc(family): + @family.assemble(locals(), globals()) + class PE_Sum(Peak): + def __call__(self, op: Const(Inst), in0: Data, in1: Data) -> Data: + imm = family.BitVector[16](0) + if op[ImmOp].match: + imm = op[ImmOp].value.imm + op = op[ImmOp].value.opcode + else: + op = op[Op].value + + if op == Op.Add: + r = in0 + in1 + imm + else: + r = in0 | in1 | imm + + return r + return PE_Sum + + + # generate golds + PE_bv = PE_fc(family.PyFamily())() + + golds = [] + for _ in range(N_TESTS): + + if random.randint(0, 1): + op = Op.Add + else: + op = Op.Or + + if random.randint(0, 1): + imm = BitVector.random(16) + inst = Inst(ImmOp(op, imm)) + else: + inst = Inst(op) + + in0 = BitVector.random(16) + in1 = BitVector.random(16) + out = PE_bv(inst, in0, in1) + golds.append((inst, in0, in1, out)) + + # verify smt works + PE_smt = PE_fc(family.SMTFamily()) + Inst_aadt = AssembledADT[Inst, Assembler, SMTBitVector] + + def to_smt(v): + return SMTBitVector[16](v.value) + + for inst, in0, in1, out in golds: + inst = Inst_aadt(inst) + res = PE_smt()(inst, to_smt(in0), to_smt(in1)) + assert res == out + + # verify magma works + asm = Assembler(Inst) + PE_magma = PE_fc(family.MagmaFamily()) + tester = fault.Tester(PE_magma) + for inst, in0, in1, out in golds: + tester.circuit.op = int(asm.assemble(inst)) + tester.circuit.in0 = in0 + tester.circuit.in1 = in1 + tester.eval() + tester.circuit.O.expect(out) + tester.compile_and_run("verilator", flags=["-Wno-fatal"]) + + + def test_wrap_with_disassembler(): class HashableDict(dict): def __hash__(self):