diff --git a/p-isa_tools/kerngen/pisa_generators/basic.py b/p-isa_tools/kerngen/pisa_generators/basic.py index 442d8b0..0835561 100644 --- a/p-isa_tools/kerngen/pisa_generators/basic.py +++ b/p-isa_tools/kerngen/pisa_generators/basic.py @@ -69,14 +69,18 @@ def to_pisa(self) -> list[PIsaOp]: # Not the same number of parts first, second = (self.input0, self.input1) if self.input0.parts < self.input1.parts else (self.input1, self.input0) + # Preserve the ordering of the inputs (output = input0 op input1) to avoid issues with non-commutative ops (e.g. sub) + # For example input0 - input1 != input1 - input0 + order_preserved = first == self.input0 + ls: list[PIsaOp] = [] for unit, q in it.product(range(self.context.units), range(self.input0.start_rns, self.input0.rns)): ls.extend( self.op( self.context.label, self.output(part, q, unit), - first(part, q, unit), - second(0, q, unit), + first(part, q, unit) if order_preserved else second(0, q, unit), + second(0, q, unit) if order_preserved else first(part, q, unit), q, ) for part in range(first.parts)