diff --git a/README.md b/README.md index 1539c65..ae270c0 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,11 @@ This README only contains a brief overview of the library's current contents. Al Utilizing Chisel and ChiselSim, `approx` requires a suitable installation of Scala. For this purpose, we use the Scala Build Tool (`sbt`) for which we provide a suitable build script. The provided tests require a recent version of Verilator. -This library is tested in Ubuntu 24.04 with Verilator 5.032. Note that the default Verilator version (5.020) available through `apt` in Ubunty 24.04 is _not_ new enough. +This library is tested in Ubuntu 24.04 with Verilator 5.032. Note that the default Verilator version (5.020) available through `apt` in Ubuntu 24.04 is _not_ new enough. If you wish to have VCD dumps from the simulations, pass the `emitVcd` flag to `testOnly`, for example: + +```bash +sbt "testOnly approx.addition.RCASpec -- -DemitVcd=1" +``` *** # Adders diff --git a/src/main/scala/approx/accumulation/Exact.scala b/src/main/scala/approx/accumulation/Exact.scala index 11f38a8..550c937 100644 --- a/src/main/scala/approx/accumulation/Exact.scala +++ b/src/main/scala/approx/accumulation/Exact.scala @@ -1,8 +1,10 @@ package approx.accumulation import chisel3._ +import chisel3.util.RegEnable import chisel3.util.experimental.FlattenInstance +import approx.util.PRShiftReg import approx.multiplication.comptree.{Approximation, Signature, CompressorTree} /** Simple accumulator @@ -10,39 +12,70 @@ import approx.multiplication.comptree.{Approximation, Signature, CompressorTree} * @param inW the width of the input operand * @param accW the width of the accumulator * @param signed whether the input operands are signed (defaults to false) + * @param pipes the number of pipeline stages (defaults to 0) + * + * Pipelining relies on retiming! 
*/ -class SimpleAccumulator(inW: Int, accW: Int, signed: Boolean = false) extends SA(inW, accW, signed) { +class SimpleAccumulator(inW: Int, accW: Int, signed: Boolean = false, pipes: Int = 0) + extends SA(inW, accW, signed, pipes) { // Extend the input to the width of the accumulator if needed val inExt = if (inW < accW) { val sext = if (signed) VecInit(Seq.fill(accW - inW)(io.in(inW-1))).asUInt else 0.U((accW - inW).W) sext ## io.in } else io.in(accW-1, 0) - val acc = RegInit(0.U(accW.W)) - acc := inExt + Mux(io.zero, 0.U, acc) + // Pass the extended input through a series of registers + val dataShReg = Module(new PRShiftReg(UInt(accW.W), pipes)) + dataShReg.io.in := inExt + + // Pass enable and zero through a shift register as needed + val enShReg = Module(new PRShiftReg(Bool(), pipes)) + enShReg.io.in := io.en + val zeroShReg = Module(new PRShiftReg(Bool(), pipes)) + zeroShReg.io.in := io.zero + + // Compute the sum and register the accumulator + val sum = Wire(UInt(accW.W)) + val acc = RegEnable(sum, 0.U(accW.W), enShReg.io.out.last) + sum := dataShReg.io.out.last + Mux(zeroShReg.io.out.last, 0.U, acc) io.acc := acc } /** Multiply accumulator * - * @param inW the width of the input operands + * @param inAW the width of the first input operand + * @param inBW the width of the second input operand * @param accW the width of the accumulator * @param signed whether the input operands are signed (defaults to false) + * @param pipes the number of pipeline stages (defaults to 0) * - * @todo Extend with different signs and operand bit-widths. + * Pipelining relies on retiming! 
*/ -class MultiplyAccumulator(inW: Int, accW: Int, signed: Boolean = false) extends MAC(inW, accW, signed) { +class MultiplyAccumulator(inAW: Int, inBW: Int, accW: Int, signed: Boolean = false, pipes: Int = 0) + extends MAC(inAW, inBW, accW, signed, pipes) { // Compute and extend the product to the width of the accumulator if needed - val prodExt = if (2 * inW < accW) { + val prodExt = if ((inAW + inBW) < accW) { val prod = if (signed) (io.a.asSInt * io.b.asSInt).asUInt else io.a * io.b - val sext = if (signed) VecInit(Seq.fill(accW - 2 * inW)(prod(2*inW-1))).asUInt else 0.U((accW - 2 * inW).W) + val sext = if (signed) VecInit(Seq.fill(accW - inAW - inBW)(prod(inAW + inBW - 1))).asUInt else 0.U((accW - inAW - inBW).W) sext ## prod } else { (if (signed) (io.a.asSInt * io.b.asSInt) else (io.a * io.b))(accW-1, 0) } - val acc = RegInit(0.U(accW.W)) - acc := prodExt + Mux(io.zero, 0.U, acc) + // Pass the extended product through a series of registers + val dataShReg = Module(new PRShiftReg(UInt(accW.W), pipes)) + dataShReg.io.in := prodExt + + // Pass enable and zero through a shift register as needed + val enShReg = Module(new PRShiftReg(Bool(), pipes)) + enShReg.io.in := io.en + val zeroShReg = Module(new PRShiftReg(Bool(), pipes)) + zeroShReg.io.in := io.zero + + // Compute the sum and register the accumulator + val sum = Wire(UInt(accW.W)) + val acc = RegEnable(sum, 0.U(accW.W), enShReg.io.out.last) + sum := dataShReg.io.out.last + Mux(zeroShReg.io.out.last, 0.U, acc) io.acc := acc } @@ -50,16 +83,19 @@ class MultiplyAccumulator(inW: Int, accW: Int, signed: Boolean = false) extends * * @param sig the input bit matrix' signature * @param accW the width of the accumulator + * @param pipes the number of pipeline stages (defaults to 0) * @param targetDevice a string indicating the target device * (defaults to "", meaning ASIC) * @param mtrc which metric to use for selecting counters (defaults to efficiency) * @param approx the targeted approximation styles (defaults to 
no approximation) + * + * Pipelining relies on retiming! + * + * @todo Consider building pipelining into the compressor tree generator. */ -class BitMatrixAccumulator(sig: Signature, accW: Int, targetDevice: String = "", +class BitMatrixAccumulator(sig: Signature, accW: Int, pipes: Int = 0, targetDevice: String = "", mtrc: Char = 'e', approx: Seq[Approximation] = Seq.empty[Approximation]) - extends MxAC(sig, accW) with FlattenInstance { - val acc = RegInit(0.U(accW.W)) - + extends MxAC(sig, accW, pipes) with FlattenInstance { // Add the accumulator to the input signature val sigExt = new Signature((0 until scala.math.max(accW, sig.length)).map { c => val sigCnt = if (c < sig.length) sig.signature(c) else 0 @@ -68,7 +104,8 @@ class BitMatrixAccumulator(sig: Signature, accW: Int, targetDevice: String = "", }.toArray) // Build a compressor tree and assign its inputs and outputs - val comp = Module(CompressorTree(sigExt, targetDevice=targetDevice, mtrc=mtrc, approx=approx)) + val acc = Wire(UInt(accW.W)) + val comp = Module(CompressorTree(sigExt, targetDevice=targetDevice, mtrc=mtrc, approx=approx)) val compIns = Wire(Vec(sigExt.count, Bool())) var (inOffset, compOffset) = (0, 0) (0 until scala.math.max(accW, sig.length)).foreach { c => @@ -88,9 +125,19 @@ class BitMatrixAccumulator(sig: Signature, accW: Int, targetDevice: String = "", } } + // Pass the compressor output through a series of registers + val dataShReg = Module(new PRShiftReg(UInt(accW.W), pipes)) comp.io.in := compIns.asUInt - acc := comp.io.out - io.acc := acc + dataShReg.io.in := comp.io.out + + // Pass enable through a shift register as needed + val enShReg = Module(new PRShiftReg(Bool(), pipes)) + enShReg.io.in := io.en + + // Compute the sum and register the accumulator + val accReg = RegEnable(comp.io.out, 0.U(accW.W), enShReg.io.out.last) + acc := accReg + io.acc := accReg } /** Parallel simple accumulator @@ -99,15 +146,19 @@ class BitMatrixAccumulator(sig: Signature, accW: Int, targetDevice: 
String = "", * @param inW the width of the input operands * @param accW the width of the accumulator * @param signed whether the input operands are signed (defaults to false) + * @param pipes the number of pipeline stages (defaults to 0) * @param comp whether to use the compressor tree generator (defaults to false) * @param targetDevice a string indicating the target device * (defaults to "", meaning ASIC) * @param mtrc which metric to use for selecting counters (defaults to efficiency) * @param approx the targeted approximation styles (defaults to no approximation) + * + * Pipelining relies on retiming! */ class ParallelSimpleAccumulator(nIn: Int, inW: Int, accW: Int, signed: Boolean = false, - comp: Boolean = false, targetDevice: String = "", mtrc: Char = 'e', approx: Seq[Approximation] = Seq.empty[Approximation]) - extends PSA(nIn, inW, accW, signed) with FlattenInstance { + pipes: Int = 0, comp: Boolean = false, targetDevice: String = "", mtrc: Char = 'e', + approx: Seq[Approximation] = Seq.empty[Approximation]) + extends PSA(nIn, inW, accW, signed, pipes) with FlattenInstance { // Extend the inputs to the width of the accumulator if needed val insExt = if (inW < accW) { if (signed) { @@ -123,18 +174,28 @@ class ParallelSimpleAccumulator(nIn: Int, inW: Int, accW: Int, signed: Boolean = val sig = new Signature(Array.fill(extW)(nIn)) // Build a bit matrix accumulator and assign its inputs and outputs - val mxAcc = Module(new BitMatrixAccumulator(sig, accW, targetDevice, mtrc, approx)) + val mxAcc = Module(new BitMatrixAccumulator(sig, accW, pipes, targetDevice, mtrc, approx)) val accIns = VecInit((0 until extW).flatMap { c => (0 until nIn).map(i => insExt(i)(c)) }).asUInt + mxAcc.io.en := io.en mxAcc.io.zero := io.zero mxAcc.io.in := accIns io.acc := mxAcc.io.acc } else { - // Instantiate an accumulator register - val acc = RegInit(0.U(accW.W)) + // Pass the parallel sum through a series of registers + val dataShReg = Module(new PRShiftReg(UInt(accW.W), pipes)) + 
dataShReg.io.in := insExt.reduceTree(_ +& _) + + // Pass enable and zero through a shift register as needed + val enShReg = Module(new PRShiftReg(Bool(), pipes)) + enShReg.io.in := io.en + val zeroShReg = Module(new PRShiftReg(Bool(), pipes)) + zeroShReg.io.in := io.zero - // Connect and sum the extended inputs - acc := insExt.reduceTree(_ +& _) + Mux(io.zero, 0.U, acc) + // Compute the sum and register the accumulator + val sum = Wire(UInt(accW.W)) + val acc = RegEnable(sum, 0.U(accW.W), enShReg.io.out.last) + sum := dataShReg.io.out.last + Mux(zeroShReg.io.out.last, 0.U, acc) io.acc := acc } } @@ -142,30 +203,31 @@ class ParallelSimpleAccumulator(nIn: Int, inW: Int, accW: Int, signed: Boolean = /** Parallel multiply accumulator * * @param nIn the number of parallel input operands - * @param inW the width of the input operands + * @param inAW the width of the first input operands + * @param inBW the width of the second input operands * @param accW the width of the accumulator * @param signed whether the input operands are signed (defaults to false) + * @param pipes the number of pipeline stages (defaults to 0) * @param comp whether to use the compressor tree generator (defaults to false) * @param targetDevice a string indicating the target device * (defaults to "", meaning ASIC) * @param mtrc which metric to use for selecting counters (defaults to efficiency) * @param approx the targeted approximation styles (defaults to no approximation) * - * @todo Extend with different signs and operand bit-widths. + * Pipelining relies on retiming! 
*/ -class ParallelMultiplyAccumulator(nIn: Int, inW: Int, accW: Int, signed: Boolean = false, - comp: Boolean = false, targetDevice: String = "", mtrc: Char = 'e', approx: Seq[Approximation] = Seq.empty[Approximation]) - extends PMAC(nIn, inW, accW, signed) with FlattenInstance { - val aW = io.as.head.getWidth - val bW = io.bs.head.getWidth +class ParallelMultiplyAccumulator(nIn: Int, inAW: Int, inBW: Int, accW: Int, signed: Boolean = false, + pipes: Int = 0, comp: Boolean = false, targetDevice: String = "", mtrc: Char = 'e', + approx: Seq[Approximation] = Seq.empty[Approximation]) + extends PMAC(nIn, inAW, inBW, accW, signed, pipes) with FlattenInstance { // Depending on the parameters passed, generate a naive accumulator or use // the custom compressor tree generator if (comp) { // Compute some constants and generate the sign-extension constant - val midLo = scala.math.min(aW, bW) - 1 - val midHi = scala.math.max(aW, bW) - 1 - val upper = aW + bW - 1 + val midLo = scala.math.min(inAW, inBW) - 1 + val midHi = scala.math.max(inAW, inBW) - 1 + val upper = inAW + inBW - 1 val extConst = if (signed) Seq.fill(nIn) { (BigInt(-1) << upper) + (BigInt(1) << midLo) + (BigInt(1) << midHi) }.sum else BigInt(0) @@ -177,7 +239,7 @@ class ParallelMultiplyAccumulator(nIn: Int, inW: Int, accW: Int, signed: Boolean */ def dotCount(col: Int): Int = { if (col < midLo) col + 1 - else if (midLo <= col && col <= midHi) scala.math.min(aW, bW) + else if (midLo <= col && col <= midHi) scala.math.min(inAW, inBW) else if (col < upper) upper - col else 0 } @@ -188,7 +250,7 @@ class ParallelMultiplyAccumulator(nIn: Int, inW: Int, accW: Int, signed: Boolean * @param col the index of the column * @return the index of the least significant row */ - def lsRow(col: Int): Int = if (col < inW) 0 else (col - inW + 1) + def lsRow(col: Int): Int = if (col < inBW) 0 else (col - inBW + 1) // Generate the signature of the needed compressor tree val sig = new Signature((0 until scala.math.max(upper + 1, 
accW)).map { c => @@ -200,22 +262,22 @@ class ParallelMultiplyAccumulator(nIn: Int, inW: Int, accW: Int, signed: Boolean // Compute the partial products val prods = if (signed) { (0 until nIn).map { i => - (0 until aW).map { r => - val pprod = VecInit((0 until bW).map { c => + (0 until inAW).map { r => + val pprod = VecInit((0 until inBW).map { c => val dot = io.as(i)(r) & io.bs(i)(c) - if (c == (bW - 1)) !dot else dot + if (c == (inBW - 1)) !dot else dot }).asUInt - if (r == (aW - 1)) ~pprod else pprod + if (r == (inAW - 1)) ~pprod else pprod } } } else { (0 until nIn).map { i => - (0 until aW).map { r => VecInit(Seq.fill(bW)(io.as(i)(r))).asUInt & io.bs(i) } + (0 until inAW).map { r => VecInit(Seq.fill(inBW)(io.as(i)(r))).asUInt & io.bs(i) } } } // Build a bit matrix accumulator and assign its inputs and outputs - val mxAcc = Module(new BitMatrixAccumulator(sig, accW, targetDevice, mtrc, approx)) + val mxAcc = Module(new BitMatrixAccumulator(sig, accW, pipes, targetDevice, mtrc, approx)) val accIns = Wire(Vec(sig.count, Bool())) var compOffset = 0 (0 until sig.length).foreach { c => @@ -236,15 +298,16 @@ class ParallelMultiplyAccumulator(nIn: Int, inW: Int, accW: Int, signed: Boolean } } + mxAcc.io.en := io.en mxAcc.io.zero := io.zero mxAcc.io.in := accIns.asUInt io.acc := mxAcc.io.acc } else { // Compute and sign-extend the incoming products as needed - val prodsExt = if (aW + bW < accW) { + val prodsExt = if ((inAW + inBW) < accW) { VecInit(io.as.zip(io.bs).map { case (a, b) => val prod = if (signed) (a.asSInt * b.asSInt).asUInt else (a * b) - val sext = if (signed) VecInit(Seq.fill(accW - aW - bW)(prod(aW+bW-1))).asUInt else 0.U((accW - aW - bW).W) + val sext = if (signed) VecInit(Seq.fill(accW - inAW - inBW)(prod(inAW + inBW - 1))).asUInt else 0.U((accW - inAW - inBW).W) sext ## prod }) } else { @@ -253,11 +316,20 @@ class ParallelMultiplyAccumulator(nIn: Int, inW: Int, accW: Int, signed: Boolean }) } - // Instantiate an accumulator register - val acc = 
RegInit(0.U(accW.W)) + // Pass the parallel sum through a series of registers + val dataShReg = Module(new PRShiftReg(UInt(accW.W), pipes)) + dataShReg.io.in := prodsExt.reduceTree(_ +& _) + + // Pass enable and zero through a shift register as needed + val enShReg = Module(new PRShiftReg(Bool(), pipes)) + enShReg.io.in := io.en + val zeroShReg = Module(new PRShiftReg(Bool(), pipes)) + zeroShReg.io.in := io.zero // Connect and sum the extended products - acc := prodsExt.reduceTree(_ +& _) + Mux(io.zero, 0.U, acc) + val sum = Wire(UInt(accW.W)) + val acc = RegEnable(sum, 0.U(accW.W), enShReg.io.out.last) + sum := dataShReg.io.out.last + Mux(zeroShReg.io.out.last, 0.U, acc) io.acc := acc } } diff --git a/src/main/scala/approx/accumulation/package.scala b/src/main/scala/approx/accumulation/package.scala index a4081a9..510bc7b 100644 --- a/src/main/scala/approx/accumulation/package.scala +++ b/src/main/scala/approx/accumulation/package.scala @@ -5,14 +5,16 @@ import chisel3._ import approx.multiplication.comptree.Signature package object accumulation { - - /** @todo Extend all these with support for pipelining! */ /** Accumulator IO bundle * * @param accW the width of the accumulator + * + * Asserting `en` enables accumulation of the input operand(s) into the + * accumulator. Asserting `zero` resets the accumulator to zero. 
*/ private[accumulation] abstract class AccumulatorIO(accW: Int) extends Bundle { + val en = Input(Bool()) val zero = Input(Bool()) val acc = Output(UInt(accW.W)) } @@ -28,12 +30,13 @@ package object accumulation { /** Multiply accumulator IO bundle * - * @param inW the width of the input operands + * @param inAW the width of the first input operand + * @param inBW the width of the second input operand * @param accW the width of the accumulator */ - class MultiplyAccumulatorIO(inW: Int, accW: Int) extends AccumulatorIO(accW) { - val a = Input(UInt(inW.W)) - val b = Input(UInt(inW.W)) + class MultiplyAccumulatorIO(inAW: Int, inBW: Int, accW: Int) extends AccumulatorIO(accW) { + val a = Input(UInt(inAW.W)) + val b = Input(UInt(inBW.W)) } /** Bit matrix accumulator IO bundle @@ -58,12 +61,13 @@ package object accumulation { /** Parallel multiply accumulator IO bundle * * @param nIn the number of parallel input operands - * @param inW the width of the input operands + * @param inAW the width of the first input operands + * @param inBW the width of the second input operands * @param accW the width of the accumulator */ - class ParallelMultiplyAccumulatorIO(nIn: Int, inW: Int, accW: Int) extends AccumulatorIO(accW) { - val as = Input(Vec(nIn, UInt(inW.W))) - val bs = Input(Vec(nIn, UInt(inW.W))) + class ParallelMultiplyAccumulatorIO(nIn: Int, inAW: Int, inBW: Int, accW: Int) extends AccumulatorIO(accW) { + val as = Input(Vec(nIn, UInt(inAW.W))) + val bs = Input(Vec(nIn, UInt(inBW.W))) } /** Abstract simple accumulator module class @@ -71,29 +75,31 @@ package object accumulation { * @param inW the width of the input operand * @param accW the width of the accumulator * @param signed whether input operands are signed + * @param pipes the number of pipeline stages */ - abstract class SA(val inW: Int, val accW: Int, val signed: Boolean) extends Module { + abstract class SA(val inW: Int, val accW: Int, val signed: Boolean, val pipes: Int) extends Module { val io = IO(new 
SimpleAccumulatorIO(inW, accW)) } /** Abstract multiply accumulator module class * - * @param inW the width of the input operands + * @param inAW the width of the first input operand + * @param inBW the width of the second input operand * @param accW the width of the accumulator * @param signed whether input operands are signed - * - * @todo Extend with different operand widths and different signs. + * @param pipes the number of pipeline stages */ - abstract class MAC(val inW: Int, val accW: Int, val signed: Boolean) extends Module { - val io = IO(new MultiplyAccumulatorIO(inW, accW)) + abstract class MAC(val inAW: Int, val inBW: Int, val accW: Int, val signed: Boolean, val pipes: Int) extends Module { + val io = IO(new MultiplyAccumulatorIO(inAW, inBW, accW)) } /** Abstract bit matrix accumulator module class * * @param sig the input bit matrix' signature * @param accW the width of the accumulator + * @param pipes the number of pipeline stages */ - abstract class MxAC(val sig: Signature, val accW: Int) extends Module { + abstract class MxAC(val sig: Signature, val accW: Int, val pipes: Int) extends Module { val io = IO(new MatrixAccumulatorIO(sig, accW)) } @@ -103,21 +109,22 @@ package object accumulation { * @param inW the width of the input operands * @param accW the width of the accumulator * @param signed whether the input operands are signed + * @param pipes the number of pipeline stages */ - abstract class PSA(val nIn: Int, val inW: Int, val accW: Int, val signed: Boolean) extends Module { + abstract class PSA(val nIn: Int, val inW: Int, val accW: Int, val signed: Boolean, val pipes: Int) extends Module { val io = IO(new ParallelSimpleAccumulatorIO(nIn, inW, accW)) } /** Parallel multiply accumulator module class * * @param nIn the number of parallel input operands - * @param inW the width of the input operands + * @param inAW the width of the first input operands + * @param inBW the width of the second input operands * @param accW the width of the 
accumulator * @param signed whether the input operands are signed - * - * @todo Extend with different operand widths and different signs. + * @param pipes the number of pipeline stages */ - abstract class PMAC(val nIn: Int, val inW: Int, val accW: Int, val signed: Boolean) extends Module { - val io = IO(new ParallelMultiplyAccumulatorIO(nIn, inW, accW)) + abstract class PMAC(val nIn: Int, val inAW: Int, val inBW: Int, val accW: Int, val signed: Boolean, val pipes: Int) extends Module { + val io = IO(new ParallelMultiplyAccumulatorIO(nIn, inAW, inBW, accW)) } } diff --git a/src/main/scala/approx/addition/Exact.scala b/src/main/scala/approx/addition/Exact.scala index c495634..8128b05 100644 --- a/src/main/scala/approx/addition/Exact.scala +++ b/src/main/scala/approx/addition/Exact.scala @@ -217,8 +217,6 @@ class CSA(width: Int, val stages: Int) extends Adder(width) { /** Exact parallel prefix adder base class * * @param width the width of the adder - * - * @todo Remove redundant extension bits in this design. */ abstract class PPA(width: Int) extends Adder(width) { /** Bundle of generate and propagate bits */ diff --git a/src/main/scala/approx/multiplication/comptree/CompressorTree.scala b/src/main/scala/approx/multiplication/comptree/CompressorTree.scala index b0ac820..71e507c 100644 --- a/src/main/scala/approx/multiplication/comptree/CompressorTree.scala +++ b/src/main/scala/approx/multiplication/comptree/CompressorTree.scala @@ -6,6 +6,7 @@ import chisel3.util.experimental.FlattenInstance import scala.collection.mutable import Counters._ +import TerminalAdders._ /** Compressor tree generator object */ object CompressorTree { @@ -31,8 +32,6 @@ object CompressorTree { * * Not expected to be manually instantiated by users. Instead, one may rely * on the companion object for a simplified interface to this generator. - * - * @todo Update to make use of the state in LUT placement and pipelining. 
*/ private[comptree] class CompressorTree(val sig: Signature, context: Context) extends Module with FlattenInstance { private val state = new State() @@ -44,21 +43,29 @@ private[comptree] class CompressorTree(val sig: Signature, context: Context) ext val out = Output(UInt(outW.W)) }) - // Select the appropriate set of counters and sort them by the desired - // fitness metric + // Select the appropriate sets of regular and variable-length counters + // and sort them by the desired fitness metric. For the latter, only + // evaluate these metrics at a length of three, as this seems a good + // balancing point in Hossfeld et al. [2024] val isApprox = context.approximations.exists(_.isInstanceOf[Miscounting]) private val counters = (if (isApprox) context.counters.approxCounters else context.counters.exactCounters) .sortBy { cntr => context.metric match { case FitnessMetric.Efficiency => cntr.efficiency case FitnessMetric.Strength => cntr.strength }}.reverse + private val vlcounters = (if (isApprox) context.counters.approxVarLenCounters else context.counters.exactVarLenCounters) + .sortBy { cntr => context.metric match { + case FitnessMetric.Efficiency => cntr.efficiency(3) + case FitnessMetric.Strength => cntr.strength(3) + }}.reverse // Generate and connect the inputs to a bit matrix private val inMtrx = buildMatrix(sig, io.in) // Iteratively compress the bit matrix till the compression goal is reached private val mtrcs = mutable.ArrayBuffer(inMtrx) - while (!mtrcs.last.meetsGoal(context.goal)) mtrcs += compress(mtrcs.last, counters) + while (!mtrcs.last.meetsGoal(context.goal)) + mtrcs += compress(mtrcs.last, counters, vlcounters) // Perform a final summation and output the result io.out := finalSummation(mtrcs.last) @@ -124,14 +131,16 @@ private[comptree] class CompressorTree(val sig: Signature, context: Context) ext * * @param bits the input bit matrix * @param cntrs the appropriate available counters + * @param vlcntrs the appropriate available variable-length 
counters * @return the bit matrix after compression, before final summation */ - private[CompressorTree] def compress(bits: BitMatrix, cntrs: Seq[Counter]): BitMatrix = { + private[CompressorTree] def compress(bits: BitMatrix, cntrs: Seq[Counter], vlcntrs: Seq[VarLenCounter]): BitMatrix = { state.addStage() // Place a new counter in the bit matrix while possible val res = new BitMatrix() - while (!bits.meetsGoal(context.goal)) placeLargestCntr(cntrs, bits, res) + while (!bits.meetsGoal(context.goal)) + placeLargestCntr(cntrs, vlcntrs, bits, res) // Transfer any remaining bits to the next stage transferBits(bits, res) @@ -142,6 +151,7 @@ private[comptree] class CompressorTree(val sig: Signature, context: Context) ext * or efficiency) in a compression stage * * @param cntrs the selection of counters to pick from + * @param vlcntrs the selection of variable-length counters to pick from * @param inBits the current bit matrix (will be updated) * @param outBits the output bit matrix (will be updated) * @@ -153,7 +163,7 @@ private[comptree] class CompressorTree(val sig: Signature, context: Context) ext * process of a column (i.e., if its placement brings a column's bit count * below the compression goal) */ - private[CompressorTree] def placeLargestCntr(cntrs: Seq[Counter], inBits: BitMatrix, outBits: BitMatrix): Unit = { + private[CompressorTree] def placeLargestCntr(cntrs: Seq[Counter], vlcntrs: Seq[VarLenCounter], inBits: BitMatrix, outBits: BitMatrix): Unit = { // Find the least significant column that still needs compression val lsColOpt = inBits.bits .zipWithIndex @@ -161,36 +171,70 @@ private[comptree] class CompressorTree(val sig: Signature, context: Context) ext lsColOpt match { case Some(lsCol) => - // Pick the best counter that fits in the bit matrix starting from - // the found, least significant column - val bestCntrOpt = cntrs.collectFirst { - case cntr if canPlaceCntr(cntr, inBits, outBits, lsCol) => cntr + // Pick the best regular and variable-length 
counters that fit in + // the bit matrix starting from the least significant column found + val bestRegCntrOpt = cntrs.collectFirst { + case rcntr if canPlaceCntr(rcntr, inBits, outBits, lsCol) => + val mtrc = context.metric match { + case FitnessMetric.Efficiency => rcntr.efficiency + case FitnessMetric.Strength => rcntr.strength + } + (rcntr, mtrc) + } + val bestVLCntrOpt = vlcntrs.collectFirst { + case vlcntr if canPlaceVLCntr(vlcntr, inBits, lsCol) => // search for length 1 + val len = maxVLCntrLen(vlcntr, inBits, lsCol) + val mtrc = context.metric match { + case FitnessMetric.Efficiency => vlcntr.efficiency(len) + case FitnessMetric.Strength => vlcntr.strength(len) + } + (vlcntr, len, mtrc) + } + + // Select the best of the two counters + val bestCntrOpt = (bestRegCntrOpt, bestVLCntrOpt) match { + case (Some((rcntr, rMtrc)), Some((vlcntr, len, vlMtrc))) => + if (rMtrc >= vlMtrc) Some((rcntr, 0)) else Some((vlcntr, len)) + case (Some((rcntr, _)), _) => Some((rcntr, 0)) + case (_, Some((vlcntr, len, _))) => Some((vlcntr, len)) + case _ => None } // Construct the counter and connect it accordingly - bestCntrOpt match { - case Some(cntr) => - state.addCounter(cntr) - - val hwCntr = context.counters.construct(cntr, state) - - // ... inputs first - hwCntr.io.in := VecInit((0 until cntr.sig._1.length).flatMap { col => - (0 until cntr.sig._1(col)).map(_ => inBits.popBit(col + lsCol)) - }.reverse).asUInt - - // ... 
outputs second - var index = 0 - (0 until cntr.sig._2.length).foreach { col => - (0 until cntr.sig._2(col)).foreach { _ => - outBits.insertBit(hwCntr.io.out(index), col + lsCol) - index += 1 - } - } + val (inSig, outSig, hwCntr) = bestCntrOpt match { + case Some((rcntr: Counter, _)) => // regular counter + state.addCounter(rcntr) + + val hwCntr = context.counters.construct(rcntr, state) + val inSig = rcntr.sig._1 + val outSig = rcntr.sig._2 + (inSig, outSig, hwCntr) + + case Some((vlcntr: VarLenCounter, len: Int)) => // variable-length counter + state.addVLCounter(vlcntr, len) + + val hwCntr = context.counters.construct(vlcntr, len, state) + val inSig = vlcntr.inSigFn(len) + val outSig = vlcntr.outSigFn(len) + (inSig, outSig, hwCntr) case _ => throw new Exception("cannot place a counter in input bit matrix") } + // ... inputs first + hwCntr.io.in := VecInit((0 until inSig.length).flatMap { col => + (0 until inSig(col)).map(_ => inBits.popBit(col + lsCol)) + }.reverse).asUInt + + // ... outputs second + var index = 0 + (0 until outSig.length).foreach { col => + (0 until outSig(col)).foreach { _ => + outBits.insertBit(hwCntr.io.out(index), col + lsCol) + index += 1 + } + } + case _ => throw new Exception("input bit matrix does not need compression") } } @@ -228,6 +272,53 @@ private[comptree] class CompressorTree(val sig: Signature, context: Context) ext } } + /** Check whether a particular variable-length counter fits in an input + * bit matrix starting from the given column + * + * @param vlcntr the variable-length counter + * @param inBits the input bit matrix + * @param lsCol the least significant (starting) column + * @return true if the counter fits, false otherwise + */ + private[CompressorTree] def canPlaceVLCntr(vlcntr: VarLenCounter, inBits: BitMatrix, lsCol: Int): Boolean = { + val inSig = vlcntr.inSigFn(1) // search for length 1 + inSig.zipWithIndex.forall { case (cnt, col) => inBits.colCount(lsCol + col) >= cnt } + } + + /** Determine the maximum length 
of a variable-length counter that + * can be placed in an input bit matrix starting from the given column + * + * @param vlcntr the variable-length counter + * @param inBits the input bit matrix + * @param lsCol the least significant (starting) column + * @return the maximum length of the variable-length counter that can be placed + */ + private[CompressorTree] def maxVLCntrLen(vlcntr: VarLenCounter, inBits: BitMatrix, lsCol: Int): Int = { + /** Binary search to find the maximum permissible length */ + def search(low: Int, high: Int, maxLen: Int): Int = { + if (low >= high) { + maxLen + } else { + val mid = (low + high) / 2 + val inSig = vlcntr.inSigFn(mid) + val fits = inSig.zipWithIndex.forall { case (cnt, col) => + inBits.colCount(lsCol + col) >= cnt + } + + if (fits) search(mid + 1, high, mid) + else search(low, mid, maxLen) + } + } + + /** Default to length 1 as this function is called after `canPlaceVLCntr`. + * Maximum length is constrained either by the width of the input matrix + * or the height of its tallest column + */ + val lowInit = 2 + val highInit = scala.math.max(inBits.length - lsCol, inBits.bits.map(_.size).max) + search(lowInit, highInit, 1) + } + /** Check if the stacked pair of an input and an output bit matrix * meet the compression goal * @@ -262,8 +353,6 @@ private[comptree] class CompressorTree(val sig: Signature, context: Context) ext * * For now, this implementation relies on the synthesis tools to remove * gates from constant-low bit positions. - * - * @todo Implement ternary and quaternary adders for FPGAs. 
*/ private[CompressorTree] def finalSummation(bits: BitMatrix): UInt = { require((0 until outW).forall(i => bits.colCount(i) <= context.goal)) @@ -273,12 +362,26 @@ private[comptree] class CompressorTree(val sig: Signature, context: Context) ext (bits.colCount(i) until context.goal).foreach(_ => bits.insertBit(false.B, i)) } - // Collect the operands as integers and sum them + // Collect the operands as integers val oprs = WireDefault(VecInit((0 until context.goal).map { _ => val op = Wire(Vec(outW, Bool())) (0 until outW).foreach(i => op(i) := bits.popBit(i)) op.asUInt })) - oprs.reduceTree(_ + _) + + // Depending on the target device, instantiate a different adder + context.terminal match { + case "ternary" => + val adder = Module(new TernaryAdder(outW)) + adder.io.in := oprs + adder.io.out + case "quaternary" => + val adder = Module(new QuaternaryAdder(outW)) + adder.io.in := oprs + adder.io.out + case _ => + // Default to a simple binary adder tree + oprs.reduceTree(_ + _) + } } } diff --git a/src/main/scala/approx/multiplication/comptree/Counters.scala b/src/main/scala/approx/multiplication/comptree/Counters.scala index 7e55eac..6dbad3d 100644 --- a/src/main/scala/approx/multiplication/comptree/Counters.scala +++ b/src/main/scala/approx/multiplication/comptree/Counters.scala @@ -21,6 +21,8 @@ import approx.util.Xilinx.Versal.{genLUT6CYInitString, LOOKAHEAD8, LUT6CY} * Atoms can (and should) be composed arbitrarily to form larger counters. * For simplicity, we only consider compositions of two counters. All counter * libraries per default include half adders and full adders. + * + * @todo Consider entirely reworking how composed counters are constructed. 
*/ private[comptree] object Counters { /** Abstract atom class @@ -71,12 +73,55 @@ private[comptree] object Counters { val efficiency: Double = if (cost == -1) Double.MinValue else (sig._1.sum - sig._2.sum).toDouble / cost } + /** Abstract variable-length counter class + * + * @param inSigFn function to compute the input signature for a given length + * @param outSigFn function to compute the output signature for a given length + * @param costFn function to compute the hardware cost of the counter for a + * given length (for FPGAs: no. of LUTs, for ASICs: ~number of XORs) + * + * Variable-length counters are counters that can be cascaded to form + * arbitrarily long chains. They are defined by a triple of functions + * that define the input and output signatures and cost for a given length. + */ + abstract class VarLenCounter( + val inSigFn : Int => Array[Int], + val outSigFn: Int => Array[Int], + val costFn : Int => Int = (n: Int) => -1 + ) { + this: CounterType => + + /** Returns the strength of the counter for length `n` + * + * @param n the length of the counter + */ + def strength(n: Int): Double = inSigFn(n).sum.toDouble / outSigFn(n).sum + + /** Returns the efficiency of the counter for length `n` + * + * @param n the length of the counter + * + * We use the definition by Preusser [2017] and adapt it somewhat for + * Intel FPGAs and ASICs. 
+ */ + def efficiency(n: Int): Double = { + if (costFn(n) == -1) Double.MinValue + else (inSigFn(n).sum - outSigFn(n).sum).toDouble / costFn(n) + } + } + /** Returns true iff the counter mixes in `Approximate` */ def isApproximate(ctr: Counter): Boolean = ctr match { case _: Counter with Approximate => true case _ => false } + /** Returns true iff the variable-length counter mixes in `Approximate` */ + def isApproximate(ctr: VarLenCounter): Boolean = ctr match { + case _: VarLenCounter with Approximate => true + case _ => false + } + /** Abstract hardware counter class * * @param sig the signature of the counter @@ -106,6 +151,12 @@ private[comptree] object Counters { // Collection of approximate and exact counters val approxCounters: Seq[Counter] + // Collection of exact variable-length counters + val exactVarLenCounters: Seq[VarLenCounter] + + // Collection of approximate and exact variable-length counters + val approxVarLenCounters: Seq[VarLenCounter] + /** Function to construct a counter * * @param cntr the counter to construct @@ -113,6 +164,15 @@ private[comptree] object Counters { * @return a module representing the counter */ def construct(cntr: Counter, state: State): Instance[HardwareCounter] + + /** Function to construct a variable-length counter + * + * @param cntr the variable-length counter to construct + * @param len the length of the counter to construct + * @param state the present compressor generator state + * @return a module representing the counter + */ + def construct(cntr: VarLenCounter, len: Int, state: State): Instance[HardwareCounter] } /** Collection of counters for ASIC @@ -185,6 +245,12 @@ private[comptree] object Counters { (new Counter8_111) ) + /** Collection of exact variable-length counters */ + lazy val exactVarLenCounters: Seq[VarLenCounter] = Seq() + + /** Collection of approximate and exact variable-length counters */ + lazy val approxVarLenCounters: Seq[VarLenCounter] = exactVarLenCounters ++ Seq() + /** Function to construct 
a counter */ def construct(cntr: Counter, state: State): Instance[HardwareCounter] = { /** Generic extension of the hardware counter for returning */ @@ -284,6 +350,23 @@ private[comptree] object Counters { state.cntrDefs(cntrName) = Definition(new ASICCounter(cntr)) Instance(state.cntrDefs(cntrName)) } + + /** Function to construct a variable-length counter */ + def construct(cntr: VarLenCounter, len: Int, state: State): Instance[HardwareCounter] = { + /** Generic extension of the hardware counter for returning */ + class ASICCounter(counter: VarLenCounter, length: Int) extends HardwareCounter((counter.inSigFn(length), counter.outSigFn(length))) { + /** Different counters require different amounts of logic here */ + counter match { + case _ => throw new IllegalArgumentException(s"cannot generate hardware for unsupported variable-length counter ${counter}") + } + } + + // Store a definition of this hardware counter for future reference + val cntrName = s"${cntr.getClass().getName()}_$len" + if (!state.cntrDefs.contains(cntrName)) + state.cntrDefs(cntrName) = Definition(new ASICCounter(cntr, len)) + Instance(state.cntrDefs(cntrName)) + } } /** Collection of counters for Xilinx 7 Series and UltraScale FPGAs @@ -297,8 +380,6 @@ private[comptree] object Counters { * - (3 : 1,1] * - (2,5 : 1,2,1] * - (8 : 1,1,1] (approximate) - * - * @todo Extend to support approximate compound counters. 
*/ object SevenSeries extends Library { /** Use an Atom class specific to this library to simplify type checking */ @@ -330,7 +411,7 @@ private[comptree] object Counters { * * Implementation of the counter from Boroumand and Brisk [2019] */ - private[SevenSeries] class Counter8_111 extends Counter((Array(8), Array(1,1,1)), 4) with Approximate + private[SevenSeries] class Counter8_111 extends Counter((Array(8), Array(1, 1, 1)), 4) with Approximate /** Function to compute the signature for a composed counter * @@ -377,6 +458,12 @@ private[comptree] object Counters { (new Counter8_111) ) + /** Collection of exact variable-length counters */ + lazy val exactVarLenCounters: Seq[VarLenCounter] = Seq() + + /** Collection of approximate and exact variable-length counters */ + lazy val approxVarLenCounters: Seq[VarLenCounter] = exactVarLenCounters ++ Seq() + /** Function to construct a counter */ def construct(cntr: Counter, state: State): Instance[HardwareCounter] = { /** Generic extension of the hardware counter for returning */ @@ -678,21 +765,37 @@ private[comptree] object Counters { state.cntrDefs(cntrName) = Definition(new SevenSeriesCounter(cntr)) Instance(state.cntrDefs(cntrName)) } + + /** Function to construct a variable-length counter */ + def construct(cntr: VarLenCounter, len: Int, state: State): Instance[HardwareCounter] = { + /** Generic extension of the hardware counter for returning */ + class SevenSeriesCounter(counter: VarLenCounter, length: Int) extends HardwareCounter((counter.inSigFn(length), counter.outSigFn(length))) { + /** Different counters require different amounts of logic here */ + counter match { + case _ => throw new IllegalArgumentException(s"cannot generate hardware for unsupported variable-length counter ${counter}") + } + } + + // Store a definition of this hardware counter for future reference + val cntrName = s"${cntr.getClass().getName()}_$len" + if (!state.cntrDefs.contains(cntrName)) + state.cntrDefs(cntrName) = Definition(new 
SevenSeriesCounter(cntr, len)) + Instance(state.cntrDefs(cntrName)) + } } /** Collection of counters for Xilinx Versal FPGAs * * The current library includes the following atoms: - * - @todo Add atoms and compound counters here! + * - (2,2) + * - (1,4) * And the following standalone counters: * - (2 : 1,1] * - (3 : 1,1] + * - (2,5 : 1,2,1] * - (7 : 1,1,1] * - (8 : 1,1,1] (approximate) - * - * @todo Extend with the (10 : 4,2] counter from Hossfeld et al. [2024] - * - * @todo Extend to support approximate compound counters. + * - (10 : 4,2] */ object Versal extends Library { /** Use an Atom class specific to this library to simplify type checking */ @@ -701,57 +804,101 @@ private[comptree] object Counters { def inSig: Array[Int] } + /** Atom (2,2) */ + private[Versal] class Atom22(val inSig: Array[Int] = Array(2, 2)) extends VersalAtom(2) + /** Atom (1,4) */ + private[Versal] class Atom14(val inSig: Array[Int] = Array(4, 1)) extends VersalAtom(2) + /** Counter (2 : 1,1] (half adder) */ private[Versal] class Counter2_11 extends Counter((Array(2), Array(1, 1)), 1) with Exact /** Counter (3 : 1,1] (full adder) */ private[Versal] class Counter3_11 extends Counter((Array(3), Array(1, 1)), 1) with Exact + /** Counter (2,5 : 1,2,1] + * + * Adaptation of the counter from Preusser [2017] + */ + private[Versal] class Counter25_121 extends Counter((Array(5, 2), Array(1, 2, 1)), 2) with Exact + /** Counter (7 : 1,1,1] */ private[Versal] class Counter7_111 extends Counter((Array(7), Array(1, 1, 1)), 3) with Exact + /** Counter (10 : 4,2] + * + * Adaptation of the counter from Hossfeld et al. 
[2024] + */ + private[Versal] class Counter10_42 extends Counter((Array(10), Array(2, 4)), 3) with Exact // signatures are LSB-column first: 2 sum bits (weight 1), 4 carry bits (weight 2) + /** Approximate counter (8 : 1,1,1] * * Adaptation of the counter from Boroumand and Brisk [2019] */ - private[Versal] class Counter8_111 extends Counter((Array(8), Array(1,1,1)), 3) with Approximate + private[Versal] class Counter8_111 extends Counter((Array(8), Array(1, 1, 1)), 3) with Approximate /** Function to compute the signature for a composed counter * - * @param atom1 the first atom - * @param atom2 the second atom - * @return the signature arising from the two atoms' composition + * @param atoms a list of atoms to compose + * @return the signature arising from the atoms' composition */ - private[Versal] def compose(atom1: VersalAtom, atom2: VersalAtom): (Array[Int], Array[Int]) = { - val ins = { - val comb = atom1.inSig ++ atom2.inSig - comb(0) += 1 - comb - } - val outs = Array(1, 1, 1, 1, 1) + private[Versal] def compose(atoms: Seq[VersalAtom]): (Array[Int], Array[Int]) = { + require(!atoms.isEmpty && atoms.length <= 4, "can only compose between 1 and 4 atoms") + val ins = atoms.flatMap(_.inSig).toArray + ins(0) += 1 + val outs = Array.fill(2*atoms.length + 1)(1) (ins, outs) } /** Counter composed from two atoms * - * @param atom1 the first atom - * @param atom2 the second atom + * @param atoms a list of atoms to compose */ - private[Versal] class ComposedCounter(val atom1: VersalAtom, val atom2: VersalAtom) - extends Counter(compose(atom1, atom2), atom1.luts + atom2.luts) with Exact + private[Versal] class ComposedCounter(val atoms: Seq[VersalAtom]) + extends Counter(compose(atoms), atoms.map(_.luts).sum) with Exact + + /** Ripple-sum counter with signature (2n+1 : n,1] + * + * Adaptation of the counter from Hossfeld et al.
[2024] + */ + private[Versal] class RippleSum extends VarLenCounter( + (n: Int) => Array(2*n + 1), // inputs: 2n+1 bits in column 0 (transfer-in plus two bits per LUT) + (n: Int) => Array(1, n), // outputs: 1 ripple sum bit in column 0, n carries in column 1 + (n: Int) => n + ) with Exact + + /** Dual-rail ripple-sum counter with signature (n+1,4n+1 : n,n+1,1] + * + * Adaptation of the counter from Hossfeld et al. [2024] + */ + private[Versal] class DualRailRippleSum extends VarLenCounter( + (n: Int) => Array(4*n + 1, n + 1), + (n: Int) => Array(1, n + 1, n), + (n: Int) => 2*n + ) with Exact /** Collection of exact counters */ lazy val exactCounters: Seq[Counter] = Seq( (new Counter2_11), (new Counter3_11), - (new Counter7_111) - ) + (new Counter25_121), + (new Counter7_111), + (new Counter10_42) + ) ++ (Seq.fill(4) { new Atom22 } ++ Seq.fill(4) { new Atom14 }).combinations(4).map(atoms => new ComposedCounter(atoms)) /** Collection of approximate and exact counters */ lazy val approxCounters: Seq[Counter] = exactCounters ++ Seq( (new Counter8_111) ) + /** Collection of exact variable-length counters */ + lazy val exactVarLenCounters: Seq[VarLenCounter] = Seq( + (new RippleSum), + (new DualRailRippleSum) + ) + + /** Collection of approximate and exact variable-length counters */ + lazy val approxVarLenCounters: Seq[VarLenCounter] = exactVarLenCounters ++ Seq() + /** Function to construct a counter */ def construct(cntr: Counter, state: State): Instance[HardwareCounter] = { /** Generic extension of the hardware counter for returning */ @@ -794,6 +941,43 @@ private[comptree] object Counters { // Outputs: [c0 = out(1), s0 = out(0)] io.out := lut.io.O51 ## lut.io.O52 + case _: Counter25_121 => + // Boolean functions for the two LUTs + val lutLOFO51 = (ins: Seq[Boolean]) => ins.take(5).reduceLeft(_ ^ _) + val lutLOFO52 = (ins: Seq[Boolean]) => { + val s2 = ins(2) ^ ins(3) ^ ins(4) + (ins(0) && s2) || (ins(1) && s2) || (ins(0) && ins(1)) + } + val lutHIFO51 = (ins: Seq[Boolean]) => { + val c0 = (ins(2) && ins(3)) || (ins(2) && ins(4)) || (ins(3) && ins(4)) + ins(0) ^ ins(1) ^ c0 + } + val lutHIFO52 = 
(ins: Seq[Boolean]) => { + val c0 = (ins(2) && ins(3)) || (ins(2) && ins(4)) || (ins(3) && ins(4)) + (ins(0) && ins(1)) || (ins(0) && c0) || (ins(1) && c0) + } + + // LUTLO computes s0 as O51 and c0 as O52 + // Inputs: (x0 = in(0), x1 = in(1), x2 = in(2), x3 = in(3), x4 = in(4)) + val lutLO = Module(new LUT6CY(genLUT6CYInitString(lutLOFO51, lutLOFO52))) + lutLO.io.I0 := io.in(0) + lutLO.io.I1 := io.in(1) + lutLO.io.I2 := io.in(2) + lutLO.io.I3 := io.in(3) + lutLO.io.I4 := io.in(4) + + // LUTHI computes c1 as O51 and t1' as O52 + // Inputs: (x5 = in(5), x6 = in(6), x2 = in(2), x3 = in(3), x4 = in(4)) + val lutHI = Module(new LUT6CY(genLUT6CYInitString(lutHIFO51, lutHIFO52))) + lutHI.io.I0 := io.in(5) + lutHI.io.I1 := io.in(6) + lutHI.io.I2 := io.in(2) + lutHI.io.I3 := io.in(3) + lutHI.io.I4 := io.in(4) + + // Outputs: [t1' = out(3), c1 = out(2), c0 = out(1), s0 = out(0)] + io.out := lutHI.io.O52 ## lutHI.io.O51 ## lutLO.io.O52 ## lutLO.io.O51 + case _: Counter7_111 => // Boolean functions for the three LUTs val lutS1FO51 = (ins: Seq[Boolean]) => (ins(0) && ins(1)) || (ins(0) && ins(2)) || (ins(1) && ins(2)) @@ -853,6 +1037,49 @@ private[comptree] object Counters { // Outputs: [z2 = out(2), z1 = out(1), z0 = out(0)] io.out := lutZ.io.O52 ## lutZ.io.O51 ## lutC2.io.O51 + case _: Counter10_42 => + // Boolean functions for the three LUTs + val lut0FO51 = (ins: Seq[Boolean]) => ins.take(5).reduceLeft(_ ^ _) + val lut0FO52 = (ins: Seq[Boolean]) => { + val s2 = ins(2) ^ ins(3) ^ ins(4) + (ins(0) && s2) || (ins(1) && s2) || (ins(0) && ins(1)) + } + val lut1FO51 = lut0FO51 + val lut1FO52 = lut0FO52 + val lut2FO5 = (ins: Seq[Boolean]) => (ins(0) && ins(1)) || (ins(0) && ins(2)) || (ins(1) && ins(2)) + val lut2FO6 = (ins: Seq[Boolean]) => (ins(3) && ins(4)) || (ins(3) && ins(5)) || (ins(4) && ins(5)) + + // LUT0 computes S0 as O51 and C0 as O52 + // Inputs: (x0 = in(0), x1 = in(1), x2 = in(2), x3 = in(3), x4 = in(4)) + val lut0 = Module(new LUT6CY(genLUT6CYInitString(lut0FO51, 
lut0FO52))) + lut0.io.I0 := io.in(0) + lut0.io.I1 := io.in(1) + lut0.io.I2 := io.in(2) + lut0.io.I3 := io.in(3) + lut0.io.I4 := io.in(4) + + // LUT1 computes S1 as O51 and C1 as O52 + // Inputs: (x5 = in(5), x6 = in(6), x7 = in(7), x8 = in(8), x9 = in(9)) + val lut1 = Module(new LUT6CY(genLUT6CYInitString(lut1FO51, lut1FO52))) + lut1.io.I0 := io.in(5) + lut1.io.I1 := io.in(6) + lut1.io.I2 := io.in(7) + lut1.io.I3 := io.in(8) + lut1.io.I4 := io.in(9) + + // LUT2 computes C2 and C3 + // Inputs: (x2 = in(2), x3 = in(3), x4 = in(4), x7 = in(7), x8 = in(8), x9 = in(9)) + val lut2 = Module(new LUT6_2(genLUT6_2InitString(lut2FO5, lut2FO6))) + lut2.io.I0 := io.in(2) + lut2.io.I1 := io.in(3) + lut2.io.I2 := io.in(4) + lut2.io.I3 := io.in(7) + lut2.io.I4 := io.in(8) + lut2.io.I5 := io.in(9) + + // Outputs: [t3 = out(5), t2 = out(4), t1 = out(3), t0 = out(2), s1 = out(1), s0 = out(0)] + io.out := lut2.io.O6 ## lut2.io.O5 ## lut1.io.O52 ## lut0.io.O52 ## lut1.io.O51 ## lut0.io.O51 + case _: Counter8_111 => // Boolean functions for the three LUTs val lutS1FO51 = (ins: Seq[Boolean]) => (ins(0) && ins(1)) || (ins(0) && (ins(2) || ins(3))) || (ins(1) && (ins(2) || ins(3))) @@ -912,6 +1139,120 @@ private[comptree] object Counters { // Outputs: [z2 = out(2), z1 = out(1), z0 = out(0)] io.out := lutZ.io.O52 ## lutZ.io.O51 ## lutC2.io.O51 + case comp: ComposedCounter => + /** Instantiate a LOOKAHEAD8 block and connect the atoms to it */ + + /** Build the LUT structure of an atom (2,2) + * + * @param inputs the input bits to the structure + * @param cins the input carries to the structure + * @return a triple of (O51s, O52s, PROPs) output bits + * + * Implementation of the atom from Hossfeld et al. 
[2024] + */ + def buildAtom22(inputs: UInt, cins: UInt): (UInt, UInt, UInt) = { + // Boolean functions for the LUT + val lutFO51 = (ins: Seq[Boolean]) => ins(0) ^ ins(1) ^ ins(4) + val lutFO52 = (ins: Seq[Boolean]) => { + (ins(0) && ins(1)) || (ins(0) && ins(4)) || (ins(1) && ins(4)) + } + + // Inputs: (a0 = inputs(0), a1 = inputs(1), false, false, cins(0), false) + val lut0 = Module(new LUT6CY(genLUT6CYInitString(lutFO51, lutFO52))) + lut0.io.I0 := inputs(0) + lut0.io.I1 := inputs(1) + lut0.io.I2 := false.B + lut0.io.I3 := false.B + lut0.io.I4 := cins(0) + + // Inputs: (b0 = inputs(2), b1 = inputs(3), false, false, cins(1), false) + val lut1 = Module(new LUT6CY(genLUT6CYInitString(lutFO51, lutFO52))) + lut1.io.I0 := inputs(2) + lut1.io.I1 := inputs(3) + lut1.io.I2 := false.B + lut1.io.I3 := false.B + lut1.io.I4 := cins(1) + + // Outputs combined + (lut1.io.O51 ## lut0.io.O51, lut1.io.O52 ## lut0.io.O52, lut1.io.PROP ## lut0.io.PROP) + } + + /** Build the LUT structure of an atom (1,4) + * + * @param inputs the input bits to the structure + * @param cins the input carries to the structure + * @return a triple of (O51s, O52s, PROPs) output bits + * + * Implementation of the atom from Hossfeld et al. 
[2024] + */ + def buildAtom14(inputs: UInt, cins: UInt): (UInt, UInt, UInt) = { + // Boolean functions for the two LUTs + val lut0FO51 = (ins: Seq[Boolean]) => ins.take(5).reduce(_ ^ _) + val lut0FO52 = (ins: Seq[Boolean]) => { + val s = ins.take(3).reduce(_ ^ _) + (s && ins(3)) || (s && ins(4)) || (ins(3) && ins(4)) + } + val lut1FO51 = (ins: Seq[Boolean]) => { + val c = (ins(0) && ins(1)) || (ins(0) && ins(2)) || (ins(1) && ins(2)) + c ^ ins(3) ^ ins(4) + } + val lut1FO52 = (ins: Seq[Boolean]) => { + val c = (ins(0) && ins(1)) || (ins(0) && ins(2)) || (ins(1) && ins(2)) + (c && ins(3)) || (c && ins(4)) || (ins(3) && ins(4)) + } + + // Inputs: (a0 = inputs(0), a1 = inputs(1), a2 = inputs(2), a3 = inputs(3), a4 = cins(0), false) + val lut0 = Module(new LUT6CY(genLUT6CYInitString(lut0FO51, lut0FO52))) + lut0.io.I0 := inputs(0) + lut0.io.I1 := inputs(1) + lut0.io.I2 := inputs(2) + lut0.io.I3 := inputs(3) + lut0.io.I4 := cins(0) + + // Inputs: (a0 = inputs(0), a1 = inputs(1), a2 = inputs(2), a3 = inputs(4), a4 = cins(1), false) + val lut1 = Module(new LUT6CY(genLUT6CYInitString(lut1FO51, lut1FO52))) + lut1.io.I0 := inputs(0) + lut1.io.I1 := inputs(1) + lut1.io.I2 := inputs(2) + lut1.io.I3 := inputs(4) + lut1.io.I4 := cins(1) + + // Outputs combined + (lut1.io.O51 ## lut0.io.O51, lut1.io.O52 ## lut0.io.O52, lut1.io.PROP ## lut0.io.PROP) + } + + // Split the inputs between the atoms and the LOOKAHEAD8 block + val ins = Wire(MixedVec(Bool(), comp.atoms.map { atom => UInt(atom.inSig.sum.W) } :_*)) + ins := io.in.asTypeOf(ins) + + // Instantiate a LOOKAHEAD8 block and get its carry and propagate in-/outputs + val look8 = Module(new LOOKAHEAD8("TRUE", "TRUE", "TRUE", "TRUE")) + val look8CYs = look8.allCYs + val look8COs = look8.allCOs + val look8Props = look8.allProps + // Default assignments to avoid unconnected wires + look8CYs .foreach(_ := false.B) + look8Props.foreach(_ := false.B) + look8.io.CIN := ins(0) + + // Generate the LUT structure of the atoms and connect them 
to the LOOKAHEAD8 + val luts = comp.atoms.zipWithIndex.map { + case (_: Atom22, i) => buildAtom22(ins(i+1), look8COs(2*i+1) ## look8COs(2*i)) + case (_: Atom14, i) => buildAtom14(ins(i+1), look8COs(2*i+1) ## look8COs(2*i)) + case (atom, _) => + throw new IllegalArgumentException(s"cannot generate hardware for unsupported atom ${atom}") + } + luts.zipWithIndex.foreach { case ((_, cos, props), i) => + look8CYs(2*i) := cos(0) + look8CYs(2*i + 1) := cos(1) + + look8Props(2*i) := props(0) + look8Props(2*i + 1) := props(1) + } + + // Outputs: [carry out, sum bits] + io.out := look8COs(2*comp.atoms.length - 1) ## VecInit(luts.reverse.map(_._1)).asUInt + case _ => throw new IllegalArgumentException(s"cannot generate hardware for unsupported counter ${counter}") } } @@ -922,6 +1263,109 @@ private[comptree] object Counters { state.cntrDefs(cntrName) = Definition(new VersalCounter(cntr)) Instance(state.cntrDefs(cntrName)) } + + /** Function to construct a variable-length counter */ + def construct(cntr: VarLenCounter, len: Int, state: State): Instance[HardwareCounter] = { + /** Generic extension of the hardware counter for returning */ + class VersalCounter(counter: VarLenCounter, length: Int) extends HardwareCounter((counter.inSigFn(length), counter.outSigFn(length))) { + /** Different counters require different amounts of logic here */ + counter match { + case _: RippleSum => + val carries = Wire(Vec(length, Bool())) + val ripples = Wire(Vec(length + 1, Bool())) + ripples(0) := io.in(0) // transfer in + + // Construct and connect LUTs + (0 until length).foreach { i => + // Boolean functions for the LUT + val lutFO5 = (ins: Seq[Boolean]) => (ins(0) && ins(1)) || (ins(0) && ins(4)) || (ins(1) && ins(4)) + val lutFO6 = (ins: Seq[Boolean]) => ins(0) ^ ins(1) ^ ins(4) + + // LUT computes S as O6 and C as O5 + // Inputs: (a0 = in(2*i+1), a1 = in(2*i+2), false, false, cascade, false) + val lut = Module(new LUT6_2(genLUT6_2InitString(lutFO5, lutFO6))) + lut.io.I0 := io.in(2*i + 1) 
+ lut.io.I1 := io.in(2*i + 2) + lut.io.I2 := false.B + lut.io.I3 := false.B + lut.io.I4 := ripples(i) // cascade + lut.io.I5 := false.B + carries(i) := lut.io.O5 + ripples(i+1) := lut.io.O6 + } + + // Outputs + io.out := carries.asUInt ## ripples(length) + + case _: DualRailRippleSum => + // Construct the top half first + val topCarries = Wire(Vec(length, Bool())) + val topRipples = Wire(Vec(length + 1, Bool())) + topRipples(0) := io.in(0) // transfer 0 in + + (0 until length).foreach { i => + // Boolean functions for the LUT + val lutFO5 = (ins: Seq[Boolean]) => { + val s = (ins(1) ^ ins(2) ^ ins(3)) + (ins(0) && s) || (ins(0) && ins(4)) || (s && ins(4)) + } + val lutFO6 = (ins: Seq[Boolean]) => ins.take(5).reduce(_ ^ _) + + // LUT computes S as O6 and C as O5 + // Inputs: (a0 = in(4*i+1), a1 = in(4*i+2), a2 = in(4*i+3), a3 = in(4*i+4), cascade, false) + val lut = Module(new LUT6_2(genLUT6_2InitString(lutFO5, lutFO6))) + lut.io.I0 := io.in(4*i + 1) + lut.io.I1 := io.in(4*i + 2) + lut.io.I2 := io.in(4*i + 3) + lut.io.I3 := io.in(4*i + 4) + lut.io.I4 := topRipples(i) // cascade + lut.io.I5 := false.B + topCarries(i) := lut.io.O5 + topRipples(i+1) := lut.io.O6 + } + + // Construct the bottom half second (column-1 inputs start after the 4n+1 column-0 inputs) + val botCarries = Wire(Vec(length, Bool())) + val botRipples = Wire(Vec(length + 1, Bool())) + botRipples(0) := io.in(4*length + 1) // transfer 1 in + + (0 until length).foreach { i => + // Boolean functions for the LUT + val lutFO5 = (ins: Seq[Boolean]) => { + val c = (ins(1) && ins(2)) || (ins(1) && ins(3)) || (ins(2) && ins(3)) + (ins(0) && c) || (ins(0) && ins(4)) || (c && ins(4)) + } + val lutFO6 = (ins: Seq[Boolean]) => { + val c = (ins(1) && ins(2)) || (ins(1) && ins(3)) || (ins(2) && ins(3)) + ins(0) ^ c ^ ins(4) + } + + // LUT computes S as O6 and C as O5 + // Inputs: (a0 = in(4*n+i+2), a1 = in(4*i+2), a2 = in(4*i+3), a3 = in(4*i+4), cascade, false) + val lut = Module(new LUT6_2(genLUT6_2InitString(lutFO5, lutFO6))) + lut.io.I0 := io.in(4*length + i + 2) + 
lut.io.I1 := io.in(4*i + 2) + lut.io.I2 := io.in(4*i + 3) + lut.io.I3 := io.in(4*i + 4) + lut.io.I4 := botRipples(i) // cascade + lut.io.I5 := false.B + botCarries(i) := lut.io.O5 + botRipples(i+1) := lut.io.O6 + } + + // Outputs + io.out := botCarries.asUInt ## botRipples(length) ## topCarries.asUInt ## topRipples(length) + + case _ => throw new IllegalArgumentException(s"cannot generate hardware for unsupported variable-length counter ${counter}") + } + } + + // Store a definition of this hardware counter for future reference + val cntrName = s"${cntr.getClass().getName()}_$len" + if (!state.cntrDefs.contains(cntrName)) + state.cntrDefs(cntrName) = Definition(new VersalCounter(cntr, len)) + Instance(state.cntrDefs(cntrName)) + } } /** Collection of counters for Intel FPGAs @@ -955,6 +1399,12 @@ private[comptree] object Counters { (new Counter8_111) ) + /** Collection of exact variable-length counters */ + lazy val exactVarLenCounters: Seq[VarLenCounter] = Seq() + + /** Collection of approximate and exact variable-length counters */ + lazy val approxVarLenCounters: Seq[VarLenCounter] = exactVarLenCounters ++ Seq() + /** Function to construct a counter */ def construct(cntr: Counter, state: State): Instance[HardwareCounter] = { /** Generic extension of the hardware counter for returning */ @@ -998,5 +1448,22 @@ private[comptree] object Counters { state.cntrDefs(cntrName) = Definition(new IntelCounter(cntr)) Instance(state.cntrDefs(cntrName)) } + + /** Function to construct a variable-length counter */ + def construct(cntr: VarLenCounter, len: Int, state: State): Instance[HardwareCounter] = { + /** Generic extension of the hardware counter for returning */ + class IntelCounter(counter: VarLenCounter, length: Int) extends HardwareCounter((counter.inSigFn(length), counter.outSigFn(length))) { + /** Different counters require different amounts of logic here */ + counter match { + case _ => throw new IllegalArgumentException(s"cannot generate hardware for unsupported 
variable-length counter ${counter}") + } + } + + // Store a definition of this hardware counter for future reference + val cntrName = s"${cntr.getClass().getName()}_$len" + if (!state.cntrDefs.contains(cntrName)) + state.cntrDefs(cntrName) = Definition(new IntelCounter(cntr, len)) + Instance(state.cntrDefs(cntrName)) + } } } diff --git a/src/main/scala/approx/multiplication/comptree/TerminalAdders.scala b/src/main/scala/approx/multiplication/comptree/TerminalAdders.scala new file mode 100644 index 0000000..bf33206 --- /dev/null +++ b/src/main/scala/approx/multiplication/comptree/TerminalAdders.scala @@ -0,0 +1,174 @@ +package approx.multiplication.comptree + +import chisel3._ +import chisel3.util._ + +import approx.util.Xilinx.Common.{genLUT6_2InitString, LUT6_2} +import approx.util.Xilinx.SevenSeries.CARRY4 +import approx.util.Xilinx.Versal.{genLUT6CYInitString, LOOKAHEAD8, LUT6CY} + +/** Collection of terminal adders useful for compressor tree + * generation for different devices. We currently include the + * following device types: + * - Xilinx FPGAs (denoted by `Xilinx.{SevenSeries, Versal}`) + */ +private[comptree] object TerminalAdders { + /** Abstract terminal adder class + * + * @param inOps number of input operands + * @param outW in-/output bit width + */ + private[TerminalAdders] abstract class TerminalAdder(inOps: Int, outW: Int) extends Module { + val io = IO(new Bundle { + val in = Input(Vec(inOps, UInt(outW.W))) + val out = Output(UInt(outW.W)) + }) + } + + /** Ternary adder for Xilinx 7-Series and UltraScale FPGAs + * + * @param outW in-/output bit width + * + * Implements a ternary adder using LUT6_2 and CARRY4 primitives. + * The adder takes three input operands of equal bit width and + * produces a single output sum of the same bit width. + * + * Note, does not implement any form of carry-in. 
+ */ + class TernaryAdder(outW: Int) extends TerminalAdder(3, outW) { + // Boolean functions for the LUTs: O5 is the majority (carry-save carry) of the three operand bits on I0-I2, O6 their XOR with the incoming carry-save bit on I3 + val lutFO5 = (ins: Seq[Boolean]) => (ins(0) && ins(1)) || (ins(0) && ins(2)) || (ins(1) && ins(2)) + val lutFO6 = (ins: Seq[Boolean]) => ins.take(4).reduce(_ ^ _) + + // Generate CARRY4 elements + val nCarry4 = (outW + 3) / 4 + val carries = Seq.fill(nCarry4) { Module(new CARRY4) } + carries.head.io.CI := false.B + carries.head.io.CYINIT := false.B + (1 until nCarry4).foreach { i => + carries(i).io.CI := carries(i-1).io.CO(3) + carries(i).io.CYINIT := false.B + } + + // Generate LUT6_2 elements and connect to CARRY4s + val tPrimes = Wire(Vec(outW+1, Bool())) + val luts = Seq.fill(outW) { Module(new LUT6_2(genLUT6_2InitString(lutFO5, lutFO6))) } + tPrimes(0) := false.B + (0 until outW).foreach { i => + luts(i).io.I0 := io.in(0)(i) + luts(i).io.I1 := io.in(1)(i) + luts(i).io.I2 := io.in(2)(i) + luts(i).io.I3 := tPrimes(i) + luts(i).io.I4 := false.B + luts(i).io.I5 := false.B + + // Connect to CARRY4 + tPrimes(i+1) := luts(i).io.O5 + } + carries.zipWithIndex.foreach { + case (carry4, idx) if idx == nCarry4 - 1 && outW % 4 != 0 => // last CARRY4, not fully used + val usedBits = outW % 4 + carry4.io.S := VecInit(luts.drop(idx * 4).map(_.io.O6) ++ Seq.fill(4 - usedBits)(false.B)).asUInt + carry4.io.DI := VecInit(tPrimes.drop(idx * 4).take(usedBits) ++ Seq.fill(4 - usedBits)(false.B)).asUInt + case (carry4, idx) => // fully used CARRY4 + carry4.io.S := VecInit(luts.drop(idx * 4).take(4).map(_.io.O6)).asUInt + carry4.io.DI := VecInit(tPrimes.drop(idx * 4).take(4)).asUInt + } + + // Connect sum outputs + io.out := VecInit(carries.map(_.io.O)).asUInt(outW-1, 0) + } + + /** Quaternary adder for Xilinx Versal FPGAs + * + * @param outW in-/output bit width + * + * Implements a quaternary adder using LUT6CY and LOOKAHEAD8 primitives. + * The adder takes four input operands of equal bit width and + * produces a single output sum of the same bit width. 
+ * + * Note, does not implement any form of carry-in. + */ + class QuaternaryAdder(outW: Int) extends TerminalAdder(4, outW) { + // Generate top and bottom LOOKAHEAD8s + val nLook8 = (outW + 7) / 8 + val topLook8 = Seq.fill(nLook8) { Module(new LOOKAHEAD8("TRUE", "TRUE", "TRUE", "TRUE")) } + val topL8CYs = topLook8.flatMap(_.allCYs) + val topL8COs = topLook8.flatMap(_.allCOs) + val topL8Props = topLook8.flatMap(_.allProps) + topLook8.head.io.CIN := false.B + topLook8.sliding(2).foreach { + case Seq(prev, next) => next.io.CIN := prev.io.COUTH + case _ => + } + + val botLook8 = Seq.fill(nLook8) { Module(new LOOKAHEAD8("TRUE", "TRUE", "TRUE", "TRUE")) } + val botL8CYs = botLook8.flatMap(_.allCYs) + val botL8COs = botLook8.flatMap(_.allCOs) + val botL8Props = botLook8.flatMap(_.allProps) + botLook8.head.io.CIN := false.B + botLook8.sliding(2).foreach { + case Seq(prev, next) => next.io.CIN := prev.io.COUTH + case _ => + } + + // ... default assignments to avoid unconnected wires + topL8CYs .foreach(_ := false.B) + topL8Props.foreach(_ := false.B) + + botL8CYs .foreach(_ := false.B) + botL8Props.foreach(_ := false.B) + + // Boolean functions for the LUTs + val topLutFO51 = (ins: Seq[Boolean]) => ins.take(5).reduce(_ ^ _) + val topLutFO52 = (ins: Seq[Boolean]) => { + val s = ins.take(3).reduce(_ ^ _) + (s && ins(3)) || (s && ins(4)) || (ins(3) && ins(4)) + } + + val botLutFO51 = (ins: Seq[Boolean]) => { + val c = (ins(0) && ins(1)) || (ins(0) && ins(2)) || (ins(1) && ins(2)) + c ^ ins(3) ^ ins(4) + } + val botLutFO52 = (ins: Seq[Boolean]) => { + val c = (ins(0) && ins(1)) || (ins(0) && ins(2)) || (ins(1) && ins(2)) + (c && ins(3)) || (c && ins(4)) || (ins(3) && ins(4)) + } + + // Generate LUT6CY elements and connect to LOOKAHEAD8s + val topLuts = Seq.fill(outW) { Module(new LUT6CY(genLUT6CYInitString(topLutFO51, topLutFO52))) } + val botLuts = Seq.fill(outW) { Module(new LUT6CY(genLUT6CYInitString(botLutFO51, botLutFO52))) } + (0 until outW).foreach { i => + 
topL8Props(i) := topLuts(i).io.PROP + botL8Props(i) := botLuts(i).io.PROP + + topL8CYs(i) := topLuts(i).io.O52 + botL8CYs(i) := botLuts(i).io.O52 + + if (i > 0) { // carries only from bit 1 onwards + topLuts(i).io.I4 := topL8COs(i-1) + botLuts(i).io.I4 := botL8COs(i-1) + } else { + topLuts(i).io.I4 := false.B + botLuts(i).io.I4 := false.B + } + } + topLuts.tail.zip(botLuts).foreach { case (top, bot) => bot.io.I3 := top.io.O51 } + botLuts.last.io.I3 := topLuts.last.io.O52 // final top-bottom carry + + // Connect data inputs + (0 until outW).foreach { i => + topLuts(i).io.I0 := io.in(0)(i) + topLuts(i).io.I1 := io.in(1)(i) + topLuts(i).io.I2 := io.in(2)(i) + topLuts(i).io.I3 := io.in(3)(i) + + botLuts(i).io.I0 := io.in(0)(i) + botLuts(i).io.I1 := io.in(1)(i) + botLuts(i).io.I2 := io.in(2)(i) + } + + // Connect sum outputs + io.out := VecInit(botLuts.map(_.io.O51)).asUInt ## topLuts.head.io.O51 + } +} diff --git a/src/main/scala/approx/multiplication/comptree/package.scala b/src/main/scala/approx/multiplication/comptree/package.scala index 06443e1..ec23d47 100644 --- a/src/main/scala/approx/multiplication/comptree/package.scala +++ b/src/main/scala/approx/multiplication/comptree/package.scala @@ -7,7 +7,7 @@ import scala.collection.mutable package object comptree { - import Counters.{Counter, HardwareCounter} + import Counters.{Counter, HardwareCounter, VarLenCounter} /** Compressor approximation styles * @@ -143,11 +143,15 @@ package object comptree { * when generating compressors for FPGAs. 
*/ private[comptree] class State { - private val _counters = mutable.ArrayBuffer.empty[mutable.HashMap[Counter, Int]] + private val _counters = mutable.ArrayBuffer.empty[mutable.HashMap[Counter, Int]] + private val _vlcounters = mutable.ArrayBuffer.empty[mutable.HashMap[(VarLenCounter, Int), Int]] /** Return the counters */ def counters = _counters.map(_.toMap).toSeq + /** Return the variable-length counters */ + def vlcounters = _vlcounters.map(_.toMap).toSeq + /** Add a stage */ def addStage(): Unit = _counters.append(mutable.HashMap.empty[Counter, Int]) @@ -157,6 +161,13 @@ package object comptree { _counters.last(cntr) = _counters.last.getOrElse(cntr, 0) + 1 } + /** Add a variable-length counter to this stage */ + def addVLCounter(cntr: VarLenCounter, len: Int): Unit = { + require(_vlcounters.nonEmpty) + val key = (cntr, len) + _vlcounters.last(key) = _vlcounters.last.getOrElse(key, 0) + 1 + } + /** For hierarchy building */ val cntrDefs = mutable.HashMap.empty[String, Definition[HardwareCounter]] } @@ -175,8 +186,9 @@ package object comptree { private[comptree] class Context(val outW: Int, targetDevice: String, mtrc: Char, approx: Seq[Approximation]) { private val _device: String = targetDevice.toLowerCase() private val _mtrc: Char = mtrc.toLower - val goal: Int = _device match { + val goal = _device match { case "7series" | "ultrascale" => 3 + case "versal" => 4 case _ => 2 } val metric = _mtrc match { @@ -191,6 +203,11 @@ package object comptree { case "intel" => Counters.Intel case _ => Counters.ASIC } + val terminal = _device match { + case "7series" | "ultrascale" => "ternary" + case "versal" => "quaternary" + case _ => "" + } val approximations = { // Remove dominated column-wise approximations val colTruncOpt = approx.collect { case ct: ColumnTruncation => ct }.sortBy(_.width).lastOption diff --git a/src/main/scala/approx/util/Xilinx.scala b/src/main/scala/approx/util/Xilinx.scala index f54366c..427edaa 100644 --- 
a/src/main/scala/approx/util/Xilinx.scala +++ b/src/main/scala/approx/util/Xilinx.scala @@ -500,6 +500,15 @@ object Xilinx { val PROPG = Input(Bool()) val PROPH = Input(Bool()) }) + + /** Return a sequence of carry inputs */ + def allCYs = Seq(io.CYA, io.CYB, io.CYC, io.CYD, io.CYE, io.CYF, io.CYG, io.CYH) + + /** Return a sequence of carry outputs */ + def allCOs = Seq(io.CYA, io.COUTB, io.CYC, io.COUTD, io.CYE, io.COUTF, io.CYG, io.COUTH) + + /** Return a sequence of propagate inputs */ + def allProps = Seq(io.PROPA, io.PROPB, io.PROPC, io.PROPD, io.PROPE, io.PROPF, io.PROPG, io.PROPH) } class DSP58(aInput: String = "DIRECT", aMultSel: String = "A", bInput: String = "DIRECT", diff --git a/src/main/scala/approx/util/package.scala b/src/main/scala/approx/util/package.scala index b497f75..34dc231 100644 --- a/src/main/scala/approx/util/package.scala +++ b/src/main/scala/approx/util/package.scala @@ -47,4 +47,26 @@ package object util { VecInit(Seq.fill(log2Up(width))(oneHot(i))).asUInt & i.U(log2Up(width).W) }).reduceTree(_ | _) } + + /** Parallel-read shift register + * + * @param depth the depth of the shift register + */ + class PRShiftReg[T <: Data](gen: T, depth: Int) extends Module { + require(depth >= 0, "depth of shift register must be non-negative") + val io = IO(new Bundle { + val in = Input(gen) + val out = Output(Vec(depth + 1, gen)) + }) + if (depth == 0) { + io.out(0) := io.in + } else { + val regs = Seq.fill(depth) { RegInit(gen, 0.U.asTypeOf(gen)) } + regs.head := io.in + (1 until depth).foreach { s => + regs(s) := regs(s-1) + } + io.out := VecInit(io.in +: regs) + } + } } diff --git a/src/test/scala/approx/accumulation/ExactAccumulatorSpec.scala b/src/test/scala/approx/accumulation/ExactAccumulatorSpec.scala index 02651d2..bb10488 100644 --- a/src/test/scala/approx/accumulation/ExactAccumulatorSpec.scala +++ b/src/test/scala/approx/accumulation/ExactAccumulatorSpec.scala @@ -6,10 +6,7 @@ import org.scalatest.flatspec.AnyFlatSpec import 
approx.multiplication.comptree.Signature -/** Common test patterns for exact accumulators - * - * @todo Extend all these with support for pipelining! - */ +/** Common test patterns for exact accumulators */ abstract trait ExactAccumulatorSpec extends AnyFlatSpec with ChiselSim { val CommonWidths = List(4, 8, 16, 32) val OddWidths = List(5, 13, 29) @@ -29,9 +26,12 @@ trait SASpec extends ExactAccumulatorSpec { * @param acc the expected sum */ def pokeAndExpect[T <: SA](in: UInt, zero: Bool)(acc: UInt)(implicit dut: T) = { + dut.io.en.poke(true.B) dut.io.in.poke(in) dut.io.zero.poke(zero) dut.clock.step() + dut.io.en.poke(false.B) + if (dut.pipes > 0) dut.clock.step(dut.pipes) dut.io.acc.expect(acc) } @@ -46,10 +46,13 @@ trait SASpec extends ExactAccumulatorSpec { * inputs. Ignores extra inputs in the opposite case. */ def pokeAndExpect[T <: PSA](ins: Seq[UInt], zero: Bool)(acc: UInt)(implicit dut: T) = { + dut.io.en.poke(true.B) val insExt = if (ins.size < dut.io.ins.size) ins ++ Seq.fill(dut.io.ins.size - ins.size)(0.U) else ins dut.io.ins.zip(insExt).foreach { case (port, inv) => port.poke(inv) } dut.io.zero.poke(zero) dut.clock.step() + dut.io.en.poke(false.B) + if (dut.pipes > 0) dut.clock.step(dut.pipes) dut.io.acc.expect(acc) } @@ -146,16 +149,19 @@ trait SASpec extends ExactAccumulatorSpec { class SimpleAccumulatorSpec extends SASpec { behavior of "Simple Accumulator" + val rng = new scala.util.Random(1337) for (width <- CommonWidths ++ OddWidths) { - it should s"do random $width-bit unsigned accumulation" in { - simulate(new SimpleAccumulator(width-3, width)) { dut => + val uPipes = rng.nextInt(5) + it should s"do random $width-bit unsigned accumulation with $uPipes pipeline stages" in { + simulate(new SimpleAccumulator(width-3, width, pipes = uPipes)) { dut => randomUnsignedTest(dut) } } - it should s"do random $width-bit signed accumulation" in { - simulate(new SimpleAccumulator(width-3, width, true)) { dut => + val sPipes = rng.nextInt(5) + it should s"do 
random $width-bit signed accumulation with $sPipes pipeline stages" in { + simulate(new SimpleAccumulator(width-3, width, true, pipes = sPipes)) { dut => randomSignedTest(dut) } } @@ -164,16 +170,19 @@ class SimpleAccumulatorSpec extends SASpec { class ParallelSimpleAccumulatorSpec extends SASpec { behavior of "Parallel Simple Accumulator" + val rng = new scala.util.Random(27) for (width <- CommonWidths ++ OddWidths) { - it should s"do random $width-bit unsigned accumulation" in { - simulate(new ParallelSimpleAccumulator(width / 2, width-3, width)) { dut => + val uPipes = rng.nextInt(5) + it should s"do random $width-bit unsigned accumulation with $uPipes pipeline stages" in { + simulate(new ParallelSimpleAccumulator(width / 2, width-3, width, pipes = uPipes)) { dut => randomUnsignedTest(dut) } } - it should s"do random $width-bit signed accumulation" in { - simulate(new ParallelSimpleAccumulator(width / 2, width-3, width, true)) { dut => + val sPipes = rng.nextInt(5) + it should s"do random $width-bit signed accumulation with $sPipes pipeline stages" in { + simulate(new ParallelSimpleAccumulator(width / 2, width-3, width, true, pipes = sPipes)) { dut => randomSignedTest(dut) } } @@ -189,10 +198,13 @@ trait MACSpec extends ExactAccumulatorSpec { * @param acc the expected sum */ def pokeAndExpect[T <: MAC](a: UInt, b: UInt, zero: Bool)(acc: UInt)(implicit dut: T) = { + dut.io.en.poke(true.B) dut.io.a.poke(a) dut.io.b.poke(b) dut.io.zero.poke(zero) dut.clock.step() + dut.io.en.poke(false.B) + if (dut.pipes > 0) dut.clock.step(dut.pipes) dut.io.acc.expect(acc) } @@ -208,12 +220,13 @@ trait MACSpec extends ExactAccumulatorSpec { * inputs. Ignores extra inputs in the opposite case. 
*/ def pokeAndExpect[T <: PMAC](as: Seq[UInt], bs: Seq[UInt], zero: Bool)(acc: UInt)(implicit dut: T) = { - if (as.size != bs.size) - println(s"Warning: ignoring values on MAC input") + dut.io.en.poke(true.B) dut.io.as.zip(as).foreach { case (port, inv) => port.poke(inv) } dut.io.bs.zip(bs).foreach { case (port, inv) => port.poke(inv) } dut.io.zero.poke(zero) dut.clock.step() + dut.io.en.poke(false.B) + if (dut.pipes > 0) dut.clock.step(dut.pipes) dut.io.acc.expect(acc) } @@ -234,8 +247,8 @@ trait MACSpec extends ExactAccumulatorSpec { pokeAndExpect(0.U, 0.U, false.B)(0.U)(mac) // Accumulate some numbers - val as = Array.fill(n) { BigInt(mac.inW, rng) } - val bs = Array.fill(n) { BigInt(mac.inW, rng) } + val as = Array.fill(n) { BigInt(mac.inAW, rng) } + val bs = Array.fill(n) { BigInt(mac.inBW, rng) } val zeros = Array.fill(n) { rng.nextBoolean() } (0 until n).foreach { i => acc = if (zeros(i)) (as(i) * bs(i)) & mask else (acc + as(i) * bs(i)) & mask @@ -250,8 +263,8 @@ trait MACSpec extends ExactAccumulatorSpec { pokeAndExpect(Seq.fill(pmac.nIn)(0.U), Seq.fill(pmac.nIn)(0.U), false.B)(0.U)(pmac) // Accumulate some numbers - val as = Array.fill(n) { Seq.fill(pmac.nIn) { BigInt(pmac.inW, rng) } } - val bs = Array.fill(n) { Seq.fill(pmac.nIn) { BigInt(pmac.inW, rng) } } + val as = Array.fill(n) { Seq.fill(pmac.nIn) { BigInt(pmac.inAW, rng) } } + val bs = Array.fill(n) { Seq.fill(pmac.nIn) { BigInt(pmac.inBW, rng) } } val zeros = Array.fill(n) { rng.nextBoolean() } (0 until n).foreach { i => acc = if (zeros(i)) as(i).zip(bs(i)).map{ case (a, b) => a * b }.sum & mask @@ -259,7 +272,7 @@ trait MACSpec extends ExactAccumulatorSpec { pokeAndExpect(as(i).map(_.U), bs(i).map(_.U), zeros(i).B)(acc.U)(pmac) } - case _ => throw new IllegalArgumentException("can only verify SAs and PSAs") + case _ => throw new IllegalArgumentException("can only verify MACs and PMACs") } } @@ -280,13 +293,14 @@ trait MACSpec extends ExactAccumulatorSpec { pokeAndExpect(0.U, 0.U, 
false.B)(0.U)(mac) // Accumulate some numbers - val ext = ((BigInt(1) << mac.accW) - 1) & ~((BigInt(1) << mac.inW) - 1) - val as = Array.fill(n) { BigInt(mac.inW, rng) } - val bs = Array.fill(n) { BigInt(mac.inW, rng) } + val extA = ((BigInt(1) << mac.accW) - 1) & ~((BigInt(1) << mac.inAW) - 1) + val extB = ((BigInt(1) << mac.accW) - 1) & ~((BigInt(1) << mac.inBW) - 1) + val as = Array.fill(n) { BigInt(mac.inAW, rng) } + val bs = Array.fill(n) { BigInt(mac.inBW, rng) } val zeros = Array.fill(n) { rng.nextBoolean() } (0 until n).foreach { i => - val aExt = if (as(i).testBit(mac.inW-1)) ext | as(i) else as(i) - val bExt = if (bs(i).testBit(mac.inW-1)) ext | bs(i) else bs(i) + val aExt = if (as(i).testBit(mac.inAW-1)) extA | as(i) else as(i) + val bExt = if (bs(i).testBit(mac.inBW-1)) extB | bs(i) else bs(i) acc = if (zeros(i)) (aExt * bExt) & mask else (acc + aExt * bExt) & mask pokeAndExpect(as(i).U, bs(i).U, zeros(i).B)(acc.U)(mac) } @@ -299,13 +313,14 @@ trait MACSpec extends ExactAccumulatorSpec { pokeAndExpect(Seq.fill(pmac.nIn)(0.U), Seq.fill(pmac.nIn)(0.U), false.B)(0.U)(pmac) // Accumulate some numbers - val ext = ((BigInt(1) << pmac.accW) - 1) & ~((BigInt(1) << pmac.inW) - 1) - val as = Array.fill(n) { Seq.fill(pmac.nIn) { BigInt(pmac.inW, rng) }} - val bs = Array.fill(n) { Seq.fill(pmac.nIn) { BigInt(pmac.inW, rng) }} + val extA = ((BigInt(1) << pmac.accW) - 1) & ~((BigInt(1) << pmac.inAW) - 1) + val extB = ((BigInt(1) << pmac.accW) - 1) & ~((BigInt(1) << pmac.inBW) - 1) + val as = Array.fill(n) { Seq.fill(pmac.nIn) { BigInt(pmac.inAW, rng) }} + val bs = Array.fill(n) { Seq.fill(pmac.nIn) { BigInt(pmac.inBW, rng) }} val zeros = Array.fill(n) { rng.nextBoolean() } (0 until n).foreach { i => - val asExt = as(i).map { a => if (a.testBit(pmac.inW-1)) ext | a else a } - val bsExt = bs(i).map { b => if (b.testBit(pmac.inW-1)) ext | b else b } + val asExt = as(i).map { a => if (a.testBit(pmac.inAW-1)) extA | a else a } + val bsExt = bs(i).map { b => if 
(b.testBit(pmac.inBW-1)) extB | b else b } acc = if (zeros(i)) asExt.zip(bsExt).map{ case (a, b) => a * b }.sum & mask else (acc + asExt.zip(bsExt).map{ case (a, b) => a * b }.sum) & mask pokeAndExpect(as(i).map(_.U), bs(i).map(_.U), zeros(i).B)(acc.U)(pmac) @@ -318,17 +333,22 @@ trait MACSpec extends ExactAccumulatorSpec { class MultiplyAccumulatorSpec extends MACSpec { behavior of "Multiply Accumulator" - - for (width <- CommonWidths ++ OddWidths) { - it should s"do random $width-bit unsigned accumulation" in { - simulate(new MultiplyAccumulator(width-3, width)) { dut => - randomUnsignedTest(dut) + val rng = new scala.util.Random(1337) + + for (aW <- CommonWidths ++ OddWidths) { + for (bW <- CommonWidths) { + val uPipes = rng.nextInt(5) + it should s"do random $aW-by-$bW-bit unsigned accumulation" in { + simulate(new MultiplyAccumulator(aW, bW, aW+bW-3, pipes = uPipes)) { dut => + randomUnsignedTest(dut) + } } - } - it should s"do random $width-bit signed accumulation" in { - simulate(new MultiplyAccumulator(width-3, width, true)) { dut => - randomSignedTest(dut) + val sPipes = rng.nextInt(5) + it should s"do random $aW-by-$bW-bit signed accumulation" in { + simulate(new MultiplyAccumulator(aW, bW, aW+bW+3, true, pipes = sPipes)) { dut => + randomSignedTest(dut) + } } } } @@ -336,19 +356,26 @@ class MultiplyAccumulatorSpec extends MACSpec { class ParallelMultiplyAccumulatorSpec extends MACSpec { behavior of "Parallel Multiply Accumulator" + val rng = new scala.util.Random(1997) // These tests are only run for relatively low bit-widths due to // long execution times otherwise - for (width <- CommonWidths ++ OddWidths) { - it should s"do random $width-bit unsigned accumulation" in { - simulate(new ParallelMultiplyAccumulator(width / 2, width-3, width)) { dut => - randomUnsignedTest(dut) + for (aW <- CommonWidths ++ OddWidths) { + for (bW <- CommonWidths) { + val uPipes = rng.nextInt(5) + it should s"do random $aW-by-$bW-bit unsigned accumulation with $uPipes 
pipeline stages" in { + simulate(new ParallelMultiplyAccumulator( + scala.math.min(aW, bW) / 2, aW, bW, aW + bW + 3, pipes = uPipes)) { dut => + randomUnsignedTest(dut) + } } - } - it should s"do random $width-bit signed accumulation" in { - simulate(new ParallelMultiplyAccumulator(width / 2, width-3, width, true)) { dut => - randomSignedTest(dut) + val sPipes = rng.nextInt(5) + it should s"do random $aW-by-$bW-bit signed accumulation with $sPipes pipeline stages" in { + simulate(new ParallelMultiplyAccumulator( + scala.math.min(aW, bW) / 2, aW, bW, aW + bW - 3, true, pipes = sPipes)) { dut => + randomSignedTest(dut) + } } } } @@ -362,9 +389,12 @@ trait MxACSpec extends ExactAccumulatorSpec { * @param acc the expected sum */ def pokeAndExpect[T <: MxAC](in: UInt, zero: Bool)(acc: UInt)(implicit dut: T) = { + dut.io.en.poke(true.B) dut.io.in.poke(in) dut.io.zero.poke(zero) dut.clock.step() + dut.io.en.poke(false.B) + if (dut.pipes > 0) dut.clock.step(dut.pipes) dut.io.acc.expect(acc) } @@ -421,9 +451,10 @@ class BitMatrixAccumulatorSpec extends MxACSpec { val rng = new scala.util.Random(0) for (width <- CommonWidths ++ OddWidths) { - it should s"do random $width-bit accumulation" in { - val sig = new Signature(Array.fill(width)(rng.nextInt(2 * width))) - simulate(new BitMatrixAccumulator(sig, width)) { dut => + val pipes = rng.nextInt(5) + it should s"do random $width-bit accumulation with $pipes pipeline stages" in { + val sig = new Signature(Array.fill(width)(rng.nextInt(2 * width))) + simulate(new BitMatrixAccumulator(sig, width, pipes = pipes)) { dut => randomTest(dut) } }