Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 89 additions & 3 deletions tests/test_magma.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
from peak import family
from peak import name_outputs

from hwtypes import Bit, SMTBit, SMTBitVector, BitVector, Enum
from hwtypes import Bit, SMTBit, SMTBitVector, BitVector, Enum, Product, Sum
from examples.demo_pes.pe6 import PE_fc
from examples.demo_pes.pe6.sim import Inst

import fault
import magma
import itertools
import random

N_TESTS = 16


@pytest.mark.parametrize('named_outputs', [True, False])
Expand Down Expand Up @@ -71,8 +74,8 @@ def __call__(self, in0: Bit, in1: Bit) -> Bit:
def test_enum():

class Op(Enum):
And=1
Or=2
And=Enum.Auto()
Or=Enum.Auto()

@family_closure
def PE_fc(family):
Expand Down Expand Up @@ -121,6 +124,89 @@ def __call__(self, op: Const(Op), in0: Bit, in1: Bit) -> Bit:
tester.circuit.O.expect(gold)
tester.compile_and_run("verilator", flags=["-Wno-fatal"])


def test_sum():
Data = BitVector[16]
class Op(Enum):
Add = Enum.Auto()
Or = Enum.Auto()

class ImmOp(Product):
opcode = Op
imm = Data


class Inst(Sum[ImmOp, Op]): pass

@family_closure
def PE_fc(family):
@family.assemble(locals(), globals())
class PE_Sum(Peak):
def __call__(self, op: Const(Inst), in0: Data, in1: Data) -> Data:
imm = family.BitVector[16](0)
if op[ImmOp].match:
imm = op[ImmOp].value.imm
op = op[ImmOp].value.opcode
else:
op = op[Op].value

if op == Op.Add:
r = in0 + in1 + imm
else:
r = in0 | in1 | imm

return r
return PE_Sum


# generate golds
PE_bv = PE_fc(family.PyFamily())()

golds = []
for _ in range(N_TESTS):

if random.randint(0, 1):
op = Op.Add
else:
op = Op.Or

if random.randint(0, 1):
imm = BitVector.random(16)
inst = Inst(ImmOp(op, imm))
else:
inst = Inst(op)

in0 = BitVector.random(16)
in1 = BitVector.random(16)
out = PE_bv(inst, in0, in1)
golds.append((inst, in0, in1, out))

# verify smt works
PE_smt = PE_fc(family.SMTFamily())
Inst_aadt = AssembledADT[Inst, Assembler, SMTBitVector]

def to_smt(v):
return SMTBitVector[16](v.value)

for inst, in0, in1, out in golds:
inst = Inst_aadt(inst)
res = PE_smt()(inst, to_smt(in0), to_smt(in1))
assert res == out

# verify magma works
asm = Assembler(Inst)
PE_magma = PE_fc(family.MagmaFamily())
tester = fault.Tester(PE_magma)
for inst, in0, in1, out in golds:
tester.circuit.op = int(asm.assemble(inst))
tester.circuit.in0 = in0
tester.circuit.in1 = in1
tester.eval()
tester.circuit.O.expect(out)
tester.compile_and_run("verilator", flags=["-Wno-fatal"])



def test_wrap_with_disassembler():
class HashableDict(dict):
def __hash__(self):
Expand Down