diff --git a/experimental/lib/Support/FormalProperty.cpp b/experimental/lib/Support/FormalProperty.cpp index 388d4e3b35..160bb41b35 100644 --- a/experimental/lib/Support/FormalProperty.cpp +++ b/experimental/lib/Support/FormalProperty.cpp @@ -11,7 +11,7 @@ //===----------------------------------------------------------------------===// #include "experimental/Support/FormalProperty.h" #include "dynamatic/Analysis/NameAnalysis.h" -#include "dynamatic/Dialect/Handshake/HandshakeInterfaces.h" +#include "dynamatic/Dialect/Handshake/HandshakeOps.h" #include "dynamatic/Support/JSON/JSON.h" #include "llvm/Support/JSON.h" #include @@ -128,9 +128,6 @@ AbsenceOfBackpressure::AbsenceOfBackpressure(unsigned long id, TAG tag, Operation *ownerOp = res.getOwner(); Operation *userOp = *res.getUsers().begin(); - handshake::PortNamer ownerNamer(ownerOp); - handshake::PortNamer userNamer(userOp); - unsigned long operandIndex = userOp->getNumOperands(); for (auto [j, arg] : llvm::enumerate(userOp->getOperands())) { if (arg == res) { @@ -144,9 +141,11 @@ AbsenceOfBackpressure::AbsenceOfBackpressure(unsigned long id, TAG tag, userChannel.operationName = getUniqueName(userOp).str(); ownerChannel.channelIndex = res.getResultNumber(); userChannel.channelIndex = operandIndex; + auto handshakeOwnerOp = handshake::getHandshakeBase(ownerOp); ownerChannel.channelName = - ownerNamer.getOutputName(res.getResultNumber()).str(); - userChannel.channelName = userNamer.getInputName(operandIndex).str(); + handshakeOwnerOp.getResultName(res.getResultNumber()); + auto handshakeUserOp = handshake::getHandshakeBase(userOp); + userChannel.channelName = handshakeUserOp.getOperandName(operandIndex); } llvm::json::Value AbsenceOfBackpressure::extraInfoToJSON() const { @@ -184,18 +183,18 @@ ValidEquivalence::ValidEquivalence(unsigned long id, TAG tag, : FormalProperty(id, tag, TYPE::VEQ) { Operation *op1 = res1.getOwner(); unsigned int i = res1.getResultNumber(); - handshake::PortNamer namer1(op1); Operation *op2 = res2.getOwner(); unsigned int j = res2.getResultNumber(); - handshake::PortNamer namer2(op2); ownerChannel.operationName = getUniqueName(op1).str(); targetChannel.operationName = getUniqueName(op2).str(); ownerChannel.channelIndex = i; targetChannel.channelIndex = j; - ownerChannel.channelName = namer1.getOutputName(i).str(); - targetChannel.channelName = namer2.getOutputName(j).str(); + auto handshakeOp1 = handshake::getHandshakeBase(op1); + auto handshakeOp2 = handshake::getHandshakeBase(op2); + ownerChannel.channelName = handshakeOp1.getResultName(i); + targetChannel.channelName = handshakeOp2.getResultName(j); } llvm::json::Value ValidEquivalence::extraInfoToJSON() const { diff --git a/include/dynamatic/Dialect/Handshake/HandshakeArithOps.td b/include/dynamatic/Dialect/Handshake/HandshakeArithOps.td index e253dc9212..1fdb372833 100644 --- a/include/dynamatic/Dialect/Handshake/HandshakeArithOps.td +++ b/include/dynamatic/Dialect/Handshake/HandshakeArithOps.td @@ -24,25 +24,12 @@ class Handshake_Arith_Op traits = []> : class Handshake_Arith_BinaryOp traits = []> : Handshake_Arith_Op, + SameOperandsAndResultType ]> { let arguments = (ins ChannelType:$lhs, ChannelType:$rhs); let results = (outs ChannelType:$result); let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; - - let extraClassDefinition = [{ - std::string $cppClass::getOperandName(unsigned idx) { - assert(idx < 2 && "index too high"); - return (idx == 0) ? "lhs" : "rhs"; - } - - std::string $cppClass::getResultName(unsigned idx) { - assert(idx < 1 && "index too high"); - return "result"; - } - }]; } class Handshake_Arith_IntBinaryOp traits = []> : @@ -76,24 +63,11 @@ class Handshake_Arith_CompareOp traits = []> : AllTypesMatch<["lhs", "rhs"]>, AllExtraSignalsMatch<["lhs", "rhs", "result"]>, IsIntSizedChannel<1, "result">, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods ]> { let results = (outs ChannelType:$result); let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; - - let extraClassDefinition = [{ - std::string $cppClass::getOperandName(unsigned idx) { - assert(idx < getNumOperands() && "index too high"); - return (idx == 0) ? "lhs" : "rhs"; - } - - std::string $cppClass::getResultName(unsigned idx) { - assert(idx < 1 && "index too high"); - return "result"; - } - }]; } class Handshake_Arith_IToICastOp traits = []> : @@ -155,17 +129,24 @@ class Handshake_Arith_FToICastOp traits = []> : def Handshake_AddFOp : Handshake_Arith_FloatBinaryOp<"addf", [ Commutative, + BinaryArithNamedIOInterface, FPUImplInterface, LatencyInterface ]> { let summary = "Floating-point addition."; } -def Handshake_AddIOp : Handshake_Arith_IntBinaryOp<"addi", [Commutative]> { +def Handshake_AddIOp : Handshake_Arith_IntBinaryOp<"addi", [ + Commutative, + BinaryArithNamedIOInterface +]> { let summary = "Integer addition."; } -def Handshake_AndIOp : Handshake_Arith_IntBinaryOp<"andi", [Commutative]> { +def Handshake_AndIOp : Handshake_Arith_IntBinaryOp<"andi", [ + Commutative, + BinaryArithNamedIOInterface +]> { let summary = "Bitwise conjunction."; } @@ -196,6 +177,7 @@ def Handshake_CmpFPredicateAttr : I64EnumAttr< def Handshake_CmpFOp : Handshake_Arith_CompareOp<"cmpf", [ IsFloatChannel<"lhs">, IsFloatChannel<"rhs">, + BinaryArithNamedIOInterface, FPUImplInterface, LatencyInterface ]> { @@ -225,7 +207,8 @@ def Handshake_CmpIPredicateAttr : I64EnumAttr< def Handshake_CmpIOp : Handshake_Arith_CompareOp<"cmpi", [ IsIntChannel<"lhs">, - IsIntChannel<"rhs"> + IsIntChannel<"rhs">, + BinaryArithNamedIOInterface ]> { let summary = "Integer comparison."; @@ -235,8 +218,11 @@ def Handshake_CmpIOp : Handshake_Arith_CompareOp<"cmpi", [ def Handshake_ConstantOp : Handshake_Arith_Op<"constant", [ AllExtraSignalsMatch<["ctrl", "result"]>, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods, + // The IOs of a Handshake_ConstantOp have custom names, + // defined here in the Tablegen + // as part of the Handshake_ConstantOp declaration + CustomNamedIOInterface ]> { let summary = "constant operation"; let description = [{ @@ -256,6 +242,25 @@ def Handshake_ConstantOp : Handshake_Arith_Op<"constant", [ let arguments = (ins TypedAttrInterface:$value, ControlType:$ctrl); let results = (outs ChannelType:$result); + + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + // Operand 0 is named ctrl + std::string getOperandNameImpl(unsigned idx) { + assert(idx < 1 && "index too high"); + return "ctrl"; + } + + // Result 0 is named outs + // detail::simpleResultName handles idx validation + std::string getResultNameImpl(unsigned idx) { + return detail::simpleResultName(idx, 1); + } + }]; + // The type of the control also needs to be specified in the IR. // It may have extra bits, which could affect the result's type and token. let assemblyFormat = "$ctrl attr-dict `:` type($ctrl) `,` type($result)"; @@ -263,13 +268,16 @@ def Handshake_ConstantOp : Handshake_Arith_Op<"constant", [ } def Handshake_DivFOp : Handshake_Arith_FloatBinaryOp<"divf", [ + BinaryArithNamedIOInterface, FPUImplInterface, LatencyInterface ]> { let summary = "Floating-point division."; } + def Handshake_DivSIOp : Handshake_Arith_IntBinaryOp<"divsi", [ + BinaryArithNamedIOInterface, IsIntSizedChannel<32, "lhs">, IsIntSizedChannel<32, "rhs">, IsIntSizedChannel<32, "result">, @@ -279,6 +287,7 @@ def Handshake_DivSIOp : Handshake_Arith_IntBinaryOp<"divsi", [ } def Handshake_RemSIOp : Handshake_Arith_IntBinaryOp<"remsi", [ + BinaryArithNamedIOInterface, IsIntSizedChannel<32, "lhs">, IsIntSizedChannel<32, "rhs">, IsIntSizedChannel<32, "result">, @@ -288,6 +297,7 @@ def Handshake_RemSIOp : Handshake_Arith_IntBinaryOp<"remsi", [ } def Handshake_DivUIOp : Handshake_Arith_IntBinaryOp<"divui", [ + BinaryArithNamedIOInterface, IsIntSizedChannel<32, "lhs">, IsIntSizedChannel<32, "rhs">, IsIntSizedChannel<32, "result">, @@ -297,25 +307,34 @@ def Handshake_DivUIOp : Handshake_Arith_IntBinaryOp<"divui", [ let summary = "Unsigned integer division."; } -def Handshake_MaxSIOp : Handshake_Arith_IntBinaryOp<"maxsi"> { +def Handshake_MaxSIOp : Handshake_Arith_IntBinaryOp<"maxsi", [ + BinaryArithNamedIOInterface +]> { let summary = "Outputs the maximum between two input signed integers"; } -def Handshake_MaxUIOp : Handshake_Arith_IntBinaryOp<"maxui"> { +def Handshake_MaxUIOp : Handshake_Arith_IntBinaryOp<"maxui", [ + BinaryArithNamedIOInterface +]> { let summary = "Outputs the maximum between two input unsigned integers"; } -def Handshake_ExtSIOp : Handshake_Arith_IToICastOp<"extsi"> { +def Handshake_ExtSIOp : Handshake_Arith_IToICastOp<"extsi", [ + SimpleNamedIOInterface +]> { let summary = "Integer unsigned width extension."; let hasCanonicalizer = 1; } -def Handshake_ExtUIOp : Handshake_Arith_IToICastOp<"extui"> { +def Handshake_ExtUIOp : Handshake_Arith_IToICastOp<"extui", [ + SimpleNamedIOInterface +]> { let summary = "Integer signed width extension."; } def Handshake_MaximumFOp : Handshake_Arith_FloatBinaryOp<"maximumf", [ Commutative, + BinaryArithNamedIOInterface, IsFloatSizedChannel<32, "lhs">, IsFloatSizedChannel<32, "rhs">, IsFloatSizedChannel<32, "result">, @@ -326,6 +345,7 @@ def Handshake_MaximumFOp : Handshake_Arith_FloatBinaryOp<"maximumf", [ def Handshake_MinimumFOp : Handshake_Arith_FloatBinaryOp< "minimumf", [ Commutative, + BinaryArithNamedIOInterface, IsFloatSizedChannel<32, "lhs">, IsFloatSizedChannel<32, "rhs">, IsFloatSizedChannel<32, "result">, @@ -336,6 +356,7 @@ def Handshake_MinimumFOp : Handshake_Arith_FloatBinaryOp< "minimumf", [ def Handshake_MulFOp : Handshake_Arith_FloatBinaryOp<"mulf", [ Commutative, + BinaryArithNamedIOInterface, FPUImplInterface, LatencyInterface ]> { @@ -344,16 +365,22 @@ def Handshake_MulFOp : Handshake_Arith_FloatBinaryOp<"mulf", [ def Handshake_MulIOp : Handshake_Arith_IntBinaryOp<"muli", [ Commutative, + BinaryArithNamedIOInterface, LatencyInterface ]> { let summary = "Integer multiplication."; } -def Handshake_NegFOp : Handshake_Arith_FloatUnaryOp<"negf"> { +def Handshake_NegFOp : Handshake_Arith_FloatUnaryOp<"negf", [ + SimpleNamedIOInterface +]> { let summary = "Floating-point sign negation."; } -def Handshake_OrIOp : Handshake_Arith_IntBinaryOp<"ori", [Commutative]> { +def Handshake_OrIOp : Handshake_Arith_IntBinaryOp<"ori", [ + Commutative, + BinaryArithNamedIOInterface +]> { let summary = "Bitwise union."; } @@ -361,7 +388,10 @@ def Handshake_SelectOp : Handshake_Arith_Op<"select", [ AllTypesMatch<["trueValue", "falseValue", "result"]>, AllExtraSignalsMatch<["condition", "trueValue", "falseValue", "result"]>, IsIntSizedChannel<1, "condition">, - DeclareOpInterfaceMethods, + // The IOs of a Handshake_SelectOp have custom names, + // defined here in the Tablegen + // as part of the Handshake_SelectOp declaration + CustomNamedIOInterface ]> { let summary = "Select a value based on a 1-bit predicate."; @@ -369,41 +399,75 @@ def Handshake_SelectOp : Handshake_Arith_Op<"select", [ ChannelType:$falseValue); let results = (outs ChannelType:$result); + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + // Operand 0 is named condition + // Operand 1 is named trueValue + // Operand 2 is named falseValue + std::string getOperandNameImpl(unsigned idx) { + assert(idx < getNumOperands() && "index too high"); + if (idx == 0) + return "condition"; + return (idx == 1) ? "trueValue" : "falseValue"; + } + + // Result 0 is named result + std::string getResultNameImpl(unsigned idx) { + assert(idx < 1 && "index too high"); + return "result"; + } + }]; + let assemblyFormat = [{ $condition `[` $trueValue `,` $falseValue `]` attr-dict `:` type($condition) `,` type($result) }]; } -def Handshake_ShLIOp : Handshake_Arith_IntBinaryOp<"shli"> { +def Handshake_ShLIOp : Handshake_Arith_IntBinaryOp<"shli", [ + BinaryArithNamedIOInterface +]> { let summary = "Logical left shift."; } -def Handshake_ShRSIOp : Handshake_Arith_IntBinaryOp<"shrsi"> { +def Handshake_ShRSIOp : Handshake_Arith_IntBinaryOp<"shrsi", [ + BinaryArithNamedIOInterface +]> { let summary = "Arithmetic right shift."; } -def Handshake_ShRUIOp : Handshake_Arith_IntBinaryOp<"shrui"> { +def Handshake_ShRUIOp : Handshake_Arith_IntBinaryOp<"shrui", [ + BinaryArithNamedIOInterface +]> { let summary = "Logical right shift."; } def Handshake_SubFOp : Handshake_Arith_FloatBinaryOp<"subf", [ + BinaryArithNamedIOInterface, FPUImplInterface, LatencyInterface ]> { let summary = "Floating-point subtraction."; } -def Handshake_SubIOp : Handshake_Arith_IntBinaryOp<"subi"> { +def Handshake_SubIOp : Handshake_Arith_IntBinaryOp<"subi", [ + BinaryArithNamedIOInterface +]> { let summary = "Integer subtraction."; } -def Handshake_TruncIOp : Handshake_Arith_IToICastOp<"trunci"> { +def Handshake_TruncIOp : Handshake_Arith_IToICastOp<"trunci", [ + SimpleNamedIOInterface +]> { let summary = "Integer truncation."; let hasCanonicalizer = 1; } def Handshake_TruncFOp : Handshake_Arith_FToFCastOp<"truncf", [ + SimpleNamedIOInterface, IsFloatSizedChannel<64, "in">, IsFloatSizedChannel<32, "out"> ]>{ @@ -411,11 +475,15 @@ def Handshake_TruncFOp : Handshake_Arith_FToFCastOp<"truncf", [ // TODO: add canonicalizer } -def Handshake_XOrIOp : Handshake_Arith_IntBinaryOp<"xori", [Commutative]> { +def Handshake_XOrIOp : Handshake_Arith_IntBinaryOp<"xori", [ + Commutative, + BinaryArithNamedIOInterface +]> { let summary = "Bitwise exclusive union."; } def Handshake_UIToFPOp : Handshake_Arith_IToFCastOp<"uitofp", [ + SimpleNamedIOInterface, IsFloatSizedChannel<32, "out"> ]> { // NOTE: @@ -432,7 +500,8 @@ def Handshake_UIToFPOp : Handshake_Arith_IToFCastOp<"uitofp", [ let summary = "Converts a unsigned integer to float."; } -def Handshake_SIToFPOp : Handshake_Arith_IToFCastOp<"sitofp",[ +def Handshake_SIToFPOp : Handshake_Arith_IToFCastOp<"sitofp", [ + SimpleNamedIOInterface, IsIntSizedChannel<32, "in">, IsFloatSizedChannel<32, "out"> ]> { @@ -440,6 +509,7 @@ def Handshake_SIToFPOp : Handshake_Arith_IToFCastOp<"sitofp",[ } def Handshake_FPToSIOp : Handshake_Arith_FToICastOp<"fptosi", [ + SimpleNamedIOInterface, IsFloatSizedChannel<32, "in">, IsIntSizedChannel<32, "out">, LatencyInterface @@ -448,6 +518,7 @@ def Handshake_FPToSIOp : Handshake_Arith_FToICastOp<"fptosi", [ } def Handshake_ExtFOp : Handshake_Arith_FToFCastOp<"extf", [ + SimpleNamedIOInterface, IsFloatSizedChannel<32, "in">, IsFloatSizedChannel<64, "out"> ]>{ @@ -455,7 +526,9 @@ def Handshake_ExtFOp : Handshake_Arith_FToFCastOp<"extf", [ // TODO: add canonicalizer } -def Handshake_AbsFOp : Handshake_Arith_FToFCastOp<"absf"> { +def Handshake_AbsFOp : Handshake_Arith_FToFCastOp<"absf", [ + SimpleNamedIOInterface +]> { let summary = "floating point absolute-value operation"; // TODO: add folder } diff --git a/include/dynamatic/Dialect/Handshake/HandshakeInterfaces.h b/include/dynamatic/Dialect/Handshake/HandshakeInterfaces.h index c8d488100c..5cddc48b84 100644 --- a/include/dynamatic/Dialect/Handshake/HandshakeInterfaces.h +++ b/include/dynamatic/Dialect/Handshake/HandshakeInterfaces.h @@ -32,51 +32,37 @@ namespace dynamatic { namespace handshake { -class NamedIOInterface; class FuncOp; -/// Provides an opaque interface for generating the port names of an operation; -/// handshake operations generate names by the `handshake::NamedIOInterface`; -/// other operations, such as arithmetic ones, are assigned default names. -class PortNamer { -public: - /// Does nothing; no port name will be generated. - PortNamer() = default; - - /// Derives port names for the operation on object creation. - PortNamer(Operation *op); - - /// Returs the port name of the input at the specified index. - StringRef getInputName(unsigned idx) const { return inputs[idx]; } - - /// Returs the port name of the output at the specified index. - StringRef getOutputName(unsigned idx) const { return outputs[idx]; } +class ControlType; -private: - /// Maps the index of an input or output to its port name. - using IdxToStrF = const std::function &; +namespace detail { - /// Infers port names for the operation using the provided callbacks. - void infer(Operation *op, IdxToStrF &inF, IdxToStrF &outF); +inline std::string simpleOperandName(unsigned idx, unsigned numOperands) { + assert(idx < numOperands && "index too high"); - /// Infers default port names when nothing better can be achieved. - void inferDefault(Operation *op); + // TODO: Remove 2D I/O packing + // but for now this is needed + if (numOperands == 1) { + return "ins"; + } - /// Infers port names for an operation implementing the - /// `handshake::NamedIOInterface` interface. - void inferFromNamedOpInterface(NamedIOInterface namedIO); + return "ins_" + std::to_string(idx); +} - /// Infers port names for a Handshake function. - void inferFromFuncOp(FuncOp funcOp); +inline std::string simpleResultName(unsigned idx, unsigned numResults) { + assert(idx < numResults && "index too high"); - /// List of input port names. - SmallVector inputs; - /// List of output port names. - SmallVector outputs; -}; + // TODO: Remove 2D I/O packing + // but for now this is needed + if (numResults == 1) { + return "outs"; + } -class ControlType; + return "outs_" + std::to_string(idx); +} +} // end namespace detail } // end namespace handshake } // end namespace dynamatic diff --git a/include/dynamatic/Dialect/Handshake/HandshakeInterfaces.td b/include/dynamatic/Dialect/Handshake/HandshakeInterfaces.td index 957ef78f2d..131ca38856 100644 --- a/include/dynamatic/Dialect/Handshake/HandshakeInterfaces.td +++ b/include/dynamatic/Dialect/Handshake/HandshakeInterfaces.td @@ -143,65 +143,172 @@ def MemPortOpInterface : OpInterface<"MemPortOpInterface"> { ]; } -def NamedIOInterface : OpInterface<"NamedIOInterface"> { + +//===----------------------------------------------------------------------===// +// IO Naming Interfaces +//===----------------------------------------------------------------------===// +// +// To generate the netlist, we need to know the names of +// the operands (inputs) and results (outputs) +// for each operation. +// +// These are already named inside of the Tablegen, +// but this does not generate any way to access the names as strings +// only using the names to generate functions. +// +// There are 3 interfaces: simple, binary arith, custom +// For an operation which uses the simple interface: +// The operand names are 'ins_n', where n is the index of the operand +// the result names are 'outs_n' where n is the idnex of the result +// +// If there is only 1 operand or 1 result, the name is 'ins' or 'outs', +// respectively. +// +// To use binary arith, the operation must have 2 operands and 1 result +// They will be named 'lhs', 'rhs', and 'result'. +// +// If the custom interface is used, the name uses a fully custom function +// If the naming is easy to do, +// the function definition can be added to the operation declaration in tablegen +// +// If it is hard, +// the function declaration can be added to the operation declaration in tablegen +// and the definition can be written in HandshakeOps.cpp +// +//===----------------------------------------------------------------------===// + +def SimpleNamedIOInterface : OpInterface<"SimpleNamedIOInterface"> { + let cppNamespace = "::dynamatic::handshake::detail::io"; + let description = + [{"Used by operations whose operands and results are named 'ins" and 'outs'."}]; + + let methods = [ + InterfaceMethod< + "Returns the name of a specific operand.", + "std::string", "getOperandNameImpl", (ins "unsigned" : $idx), + "", + [{ + ConcreteOp concreteOp = mlir::cast($_op); + + return simpleOperandName(idx, concreteOp->getNumOperands()); + }]>, + InterfaceMethod< + "Returns the name of a specific result.", + "std::string", "getResultNameImpl", (ins "unsigned" : $idx), + "", + [{ + ConcreteOp concreteOp = mlir::cast($_op); + + return simpleResultName(idx, concreteOp->getNumResults()); + }]> + ]; +} + +def BinaryArithNamedIOInterface : OpInterface<"BinaryArithNamedIOInterface"> { + let cppNamespace = "::dynamatic::handshake::detail::io"; + let description = + [{"Used by operations whose operands are named 'lhs' and 'rhs', and result is named 'result'."}]; + + let methods = [ + InterfaceMethod< + "Returns the name of a specific operand.", + "std::string", "getOperandNameImpl", (ins "unsigned" : $idx), + "", + [{ + assert(idx < 2 && "index too high"); + return (idx == 0) ? "lhs" : "rhs"; + }]>, + InterfaceMethod< + "Returns the name of a specific result.", + "std::string", "getResultNameImpl", (ins "unsigned" : $idx), + "", + [{ + assert(idx < 1 && "index too high"); + return "result"; + }]> + ]; +} + +def CustomNamedIOInterface : OpInterface<"CustomNamedIOInterface"> { + let cppNamespace = "::dynamatic::handshake::detail::io"; + let description = + [{"Used by operations with unique names for their operands and results."}]; + + let methods = [ + InterfaceMethod< + "Returns the name of a specific operand.", + "std::string", "getOperandNameImpl", (ins "unsigned" : $idx)>, + InterfaceMethod< + "Returns the name of a specific result.", + "std::string", "getResultNameImpl", (ins "unsigned" : $idx) + > + ]; +} + +//===----------------------------------------------------------------------===// +// Handshake Base Interface +//===----------------------------------------------------------------------===// +// +// The MLIR Operation inheritance hierarchy does not allow a shared base +// class to place shared functionality into. +// +// Instead we use this interface, present on every handshake operation, +// to implement shared functionality +// +//===----------------------------------------------------------------------===// + +def HandshakeBaseInterface : OpInterface<"HandshakeBaseInterface"> { let cppNamespace = "::dynamatic::handshake"; let description = - [{"Provides detailed names for the operands and results of an operation."}]; + [{"Base interface for shared functionality of all handshake operations."}]; let methods = [ - StaticInterfaceMethod< - "Returns the default name of a specific operand.", - "std::string", "getDefaultOperandName", (ins "unsigned" : $idx), - "", - [{ - return "ins_" + std::to_string(idx); - }]>, - StaticInterfaceMethod< - "Returns the default name of a specific result.", - "std::string", "getDefaultResultName", (ins "unsigned" : $idx), - "", - [{ - return "outs_" + std::to_string(idx); - }]>, InterfaceMethod< - "Returns the name of a specific operand.", - "std::string", "getOperandName", (ins "unsigned" : $idx), - "", - [{ - ConcreteOp concreteOp = mlir::cast($_op); - - // Operations which always have a single operand get a specific port - // name for it - if (concreteOp.template hasTrait()) { - assert(idx == 0 && "index too high"); - return "ins"; - } - - // Generic input name - assert(idx < concreteOp->getNumOperands() && "index too high"); - return getDefaultOperandName(idx); - }]>, + "Returns the name of a specific operand.", + "std::string", "getOperandName", (ins "unsigned" : $idx), + "", + [{ + auto *op = const_cast($_op.getOperation()); + + if (auto nameInterface = dyn_cast(op)) { + return nameInterface.getOperandNameImpl(idx); + } else if (auto nameInterface = + dyn_cast(op)) { + return nameInterface.getOperandNameImpl(idx); + } else if (auto nameInterface = + dyn_cast(op)) { + return nameInterface.getOperandNameImpl(idx); + } + + op->emitError() << "must specify operand names, op: " << *op; + llvm::report_fatal_error("All operation must specify IO names"); + }]>, InterfaceMethod< - "Returns the name of a specific result.", - "std::string", "getResultName", (ins "unsigned" : $idx), - "", - [{ - ConcreteOp concreteOp = mlir::cast($_op); - - // Operations which always have a single result get a specific port - // name for it - if (concreteOp.template hasTrait()) { - assert(idx == 0 && "index too high"); - return "outs"; - } - - // Generic output name - assert(idx < concreteOp->getNumResults() && "index too high"); - return getDefaultResultName(idx); - }]> + "Returns the name of a specific result.", + "std::string", "getResultName", (ins "unsigned" : $idx), + "", + [{ + auto *op = const_cast($_op.getOperation()); + + if (auto nameInterface = dyn_cast(op)) { + return nameInterface.getResultNameImpl(idx); + } else if (auto nameInterface = + dyn_cast(op)) { + return nameInterface.getResultNameImpl(idx); + } else if (auto nameInterface = + dyn_cast(op)) { + return nameInterface.getResultNameImpl(idx); + } + + op->emitError() << "must specify result names, op: " << *op; + llvm::report_fatal_error("All operation must specify IO names"); + }]> ]; } + +//===----------------------------------------------------------------------===// + def ControlInterface : OpInterface<"ControlInterface"> { let cppNamespace = "::dynamatic::handshake"; let description = diff --git a/include/dynamatic/Dialect/Handshake/HandshakeOps.h b/include/dynamatic/Dialect/Handshake/HandshakeOps.h index 849b63a89b..e1d97431ef 100644 --- a/include/dynamatic/Dialect/Handshake/HandshakeOps.h +++ b/include/dynamatic/Dialect/Handshake/HandshakeOps.h @@ -42,6 +42,8 @@ class StoreOp; class MemoryControllerOp; class LSQOp; +HandshakeBaseInterface getHandshakeBase(Operation *op); + } // end namespace handshake } // end namespace dynamatic diff --git a/include/dynamatic/Dialect/Handshake/HandshakeOps.td b/include/dynamatic/Dialect/Handshake/HandshakeOps.td index fba390a4c4..fc4a98444e 100644 --- a/include/dynamatic/Dialect/Handshake/HandshakeOps.td +++ b/include/dynamatic/Dialect/Handshake/HandshakeOps.td @@ -31,8 +31,26 @@ include "dynamatic/Dialect/Handshake/HandshakeTypes.td" class Handshake_Op traits = []> : Op, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + HandshakeBaseInterface, + DeclareOpInterfaceMethods, +]> { + // Shared functions to include in child classes + string validateIO = [{ + //===------------------------------------------------------------------===// + // Helper Methods for CustomNamedIOInterface + //===------------------------------------------------------------------===// + + void validateOperandIdx(unsigned idx){ + if (idx >= getOperation()->getNumOperands()) + llvm::report_fatal_error("operand index too high"); + } + + void validateResultIdx(unsigned idx){ + if (idx >= getOperation()->getNumResults()) + llvm::report_fatal_error("result index too high"); + } + }]; + } // This is almost exactly like a standard FuncOp, except that it has some @@ -44,7 +62,12 @@ def FuncOp : Op { let summary = "Handshake dialect function."; let description = [{ @@ -95,12 +118,18 @@ def FuncOp : Op(); + auto names = getArgNames(); + if (idx >= names.size()) + llvm::report_fatal_error("argument index too high"); + return names[idx].cast(); } /// Returns the result name at the given index. StringAttr getResName(unsigned idx) { - return getResNames()[idx].cast(); + auto names = getResNames(); + if (idx >= names.size()) + llvm::report_fatal_error("result index too high"); + return names[idx].cast(); } /// Hook for FunctionOpInterface, called after verifying that the 'type' @@ -128,6 +157,18 @@ def FuncOp : Op { let summary = "buffer operation"; let description = [{ @@ -287,7 +330,9 @@ def BufferOp : Handshake_Op<"buffer", [ } def InitOp : Handshake_Op<"init", [ - HasClock, SameOperandsAndResultType + HasClock, + SameOperandsAndResultType, + SimpleNamedIOInterface ]> { let summary = "init operation"; let description = [{ @@ -309,7 +354,9 @@ def InitOp : Handshake_Op<"init", [ } def NDWireOp : Handshake_Op<"ndwire", [ - HasClock, SameOperandsAndResultType + HasClock, + SameOperandsAndResultType, + SimpleNamedIOInterface ]> { let summary = "non-deterministic wire operation"; let description = [{ @@ -326,7 +373,10 @@ def NDWireOp : Handshake_Op<"ndwire", [ class Handshake_ForkOp traits = []> : Handshake_Op { let arguments = (ins HandshakeType:$operand); let results = (outs Variadic:$result); @@ -382,7 +432,8 @@ def MergeOp : Handshake_Op<"merge", [ // require that the data types match. SameOperandsAndResultType, VariadicHasElement<"dataOperands">, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods, + SimpleNamedIOInterface ]> { let summary = "merge operation"; let description = [{ @@ -413,7 +464,10 @@ def MuxOp : Handshake_Op<"mux", [ // Interface declarations DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods + // The IOs of a MuxOp have custom names, + // defined here in the Tablegen + // as part of the MuxOp declaration + CustomNamedIOInterface ]> { let summary = "mux operation"; let description = [{ @@ -439,6 +493,27 @@ def MuxOp : Handshake_Op<"mux", [ Variadic:$dataOperands); let results = (outs HandshakeType:$result); + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + /// Operand 0 is named "index" + /// Operands 1 to N are named "ins_0" to "ins_" + /// simpleOperandName handles validation + std::string getOperandNameImpl(unsigned idx) { + return idx == 0 ? "index" : detail::simpleOperandName(idx - 1, getOperation()->getNumOperands() - 1); + } + + /// Result 0 is named "outs" + /// simpleResultName handles validation + std::string getResultNameImpl(unsigned idx) { + return detail::simpleResultName(idx, 1); + } + }]; + + + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -450,7 +525,10 @@ def ControlMergeOp : Handshake_Op<"control_merge", [ AllExtraSignalsMatchWithVariadic<"dataOperands", ["result", "index"]>, // Interface declarations DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods + // The IOs of a CMergeOP have custom names, + // defined here in the Tablegen + // as part of the CMergeOP declaration + CustomNamedIOInterface ]> { let summary = "control merge operation"; let description = [{ @@ -482,12 +560,34 @@ def ControlMergeOp : Handshake_Op<"control_merge", [ $_state.addTypes({operands[0].getType(), idxType}); }]>]; + // include validateIO functions from handshake op + let extraClassDeclaration = validateIO # [{ + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + /// Operands 0 to N are named "ins_0" to "ins_" + /// simpleOperandName handles validation + std::string getOperandNameImpl(unsigned idx) { + return detail::simpleOperandName(idx, getOperation()->getNumOperands()); + } + + /// Result 0 is named outs + /// Result 1 is named index + std::string getResultNameImpl(unsigned idx) { + validateResultIdx(idx); + return idx == 0 ? "outs" : "index"; + } + }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } def BranchOp : Handshake_Op<"br", [ - Pure, SameOperandsAndResultType + Pure, + SameOperandsAndResultType, + SimpleNamedIOInterface ]> { let summary = "branch operation"; let description = [{ @@ -516,7 +616,10 @@ def ConditionalBranchOp : Handshake_Op<"cond_br", [ IsIntSizedChannel<1, "conditionOperand">, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods + // The IOs of a ConditionalBranch have custom names, + // defined here in the Tablegen + // as part of the ConditionalBranch declaration + CustomNamedIOInterface ]> { let summary = "conditional branch operation"; let description = [{ @@ -541,13 +644,35 @@ def ConditionalBranchOp : Handshake_Op<"cond_br", [ $conditionOperand `,` $dataOperand attr-dict `:` type($conditionOperand) `,` custom(type($dataOperand)) }]; - let extraClassDeclaration = [{ + + // include validateIO functions from handshake op + let extraClassDeclaration = validateIO # [{ + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + // These are the indices into the dests list. enum { trueIndex = 0, falseIndex = 1 }; + + /// Operand 0 is named condition + /// Operand 1 is named data + std::string getOperandNameImpl(unsigned idx) { + validateOperandIdx(idx); + return idx == 0 ? "condition" : "data"; + } + + /// Operand 0 is named trueOut + /// Operand 1 is named falseOut + std::string getResultNameImpl(unsigned idx) { + validateResultIdx(idx); + return idx == trueIndex ? "trueOut" : "falseOut"; + } }]; } -def SinkOp : Handshake_Op<"sink"> { +def SinkOp : Handshake_Op<"sink", [ + SimpleNamedIOInterface +]> { let summary = "sink operation"; let description = [{ The sink operation discards any data that arrives at its @@ -566,7 +691,10 @@ def SinkOp : Handshake_Op<"sink"> { }]; } -def SourceOp : Handshake_Op<"source", [Pure]> { +def SourceOp : Handshake_Op<"source", [ + Pure, + SimpleNamedIOInterface +]> { let summary = "source operation"; let description = [{ The source operation represents a continuous control-only-token source. The @@ -594,7 +722,8 @@ def SourceOp : Handshake_Op<"source", [Pure]> { def JoinOp : Handshake_Op<"join", [ SameOperandsAndResultType, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods, + SimpleNamedIOInterface ]> { let summary = "join operation"; let description = [{ @@ -616,7 +745,8 @@ def JoinOp : Handshake_Op<"join", [ // TODO: Split the input into two parts and explicitly mark the forwarded input, // once the backend supports variadic channels with zero or one item correctly. def BlockerOp : Handshake_Op<"blocker", [ - SameOperandsAndResultType + SameOperandsAndResultType, + SimpleNamedIOInterface ]> { let summary = "blocker operation"; let description = [{ @@ -635,7 +765,11 @@ def BlockerOp : Handshake_Op<"blocker", [ let assemblyFormat = "$data attr-dict `:` type($result)"; } -def NotOp : Handshake_Op<"not", [Pure, SameOperandsAndResultType]> { +def NotOp : Handshake_Op<"not", [ + Pure, + SameOperandsAndResultType, + SimpleNamedIOInterface +]> { let summary = "Logical negation"; let description = [{ Bitwise logical negation. @@ -665,8 +799,10 @@ def MemoryControllerOp : Handshake_Op<"mem_controller", [ IsSimpleHandshakeVariadic<"inputs">, IsSimpleHandshakeVariadic<"outputs">, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods + // The IOs of a MemoryControllerOp have custom names, + // As they are complex, they are declared here, + // but defined in HandshakeOps.cpp + CustomNamedIOInterface ]> { let summary = "memory controller (dynamatic)"; let description = [{ @@ -720,6 +856,7 @@ def MemoryControllerOp : Handshake_Op<"mem_controller", [ "ValueRange":$inputs, "Value":$ctrlEnd, "ArrayRef":$blocks, "unsigned":$numLoads)>]; + let hasVerifier = 1; // Dispatch SimpleControl signals to custom print and parse @@ -740,13 +877,23 @@ def MemoryControllerOp : Handshake_Op<"mem_controller", [ /// Returns a convenient data-structure to go over the controls and memory /// accesses that are connected to the memory controller. dynamatic::MCPorts getPorts(); + + + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + std::string getOperandNameImpl(unsigned idx); + std::string getResultNameImpl(unsigned idx); }]; } def LSQOp : Handshake_Op<"lsq", [ DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods + // The IOs of a MemoryControllerOp have custom names, + // As they are complex, they are declared here, + // but defined in HandshakeOps.cpp + CustomNamedIOInterface ]> { let summary = "load-store queue (dynamatic)"; let description = [{ @@ -815,7 +962,7 @@ def LSQOp : Handshake_Op<"lsq", [ /// Returns a convenient data-structure to go over the controls and memory /// accesses that are connected to the LSQ. - dynamatic::LSQPorts getPorts(); + dynamatic::LSQPorts getPorts(); /// Determines whether the LSQ is connected to an MC. bool isConnectedToMC() { @@ -835,6 +982,13 @@ def LSQOp : Handshake_Op<"lsq", [ /// materialized when called. The returned values are guaranteed to be in /// the same order as the control operation's results. SmallVector getControlPaths(Operation *ctrlOp); + + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + std::string getOperandNameImpl(unsigned idx); + std::string getResultNameImpl(unsigned idx); }]; } @@ -848,7 +1002,6 @@ class Handshake_MemPortOp< MemPortOpInterface, AllDataTypesMatch<["address", "addressResult"]>, AllDataTypesMatch<["data", "dataResult"]>, - DeclareOpInterfaceMethods, IsIntChannel<"address">, IsIntChannel<"addressResult"> ] @@ -876,7 +1029,11 @@ def LoadOp : Handshake_MemPortOp<"load", [ AllExtraSignalsMatch<["address", "dataResult"]>, // In LoadOp, addressResult and data are connected to a memory controller. IsSimpleHandshake<"addressResult">, - IsSimpleHandshake<"data"> + IsSimpleHandshake<"data">, + // The IOs of a LoadOp have custom names, + // defined here in the Tablegen + // as part of the LoadOp declaration + CustomNamedIOInterface ], [ // TODO: Please add comments on why this builder is necessary OpBuilder<(ins "MemRefType":$memrefType, "Value":$address), [{ @@ -913,14 +1070,23 @@ def LoadOp : Handshake_MemPortOp<"load", [ ``` }]; - let extraClassDefinition = [{ - std::string $cppClass::getOperandName(unsigned idx) { - assert(idx < getNumOperands() && "index too high"); + // include validateIO functions from handshake op + let extraClassDeclaration = validateIO # [{ + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + /// Operand 0 is named addrIn + /// Operand 1 is named dataFromMem + std::string getOperandNameImpl(unsigned idx) { + validateOperandIdx(idx); return (idx == 0) ? "addrIn" : "dataFromMem"; } - std::string $cppClass::getResultName(unsigned idx) { - assert(idx < getNumResults() && "index too high"); + /// Result 0 is named addrOut + /// Result 1 is named dataOut + std::string getResultNameImpl(unsigned idx) { + validateResultIdx(idx); return (idx == 0) ? "addrOut" : "dataOut"; } }]; @@ -930,7 +1096,11 @@ def StoreOp : Handshake_MemPortOp<"store", [ AllExtraSignalsMatch<["address", "data"]>, // In StoreOp, addressResult and dataResult are connected to a memory controller. IsSimpleHandshake<"addressResult">, - IsSimpleHandshake<"dataResult"> + IsSimpleHandshake<"dataResult">, + // The IOs of a StoreOp have custom names, + // defined here in the Tablegen + // as part of the StoreOp declaration + CustomNamedIOInterface ], []> { let summary = "store operation for memory controller (MC)"; let description = [{ @@ -954,12 +1124,23 @@ def StoreOp : Handshake_MemPortOp<"store", [ ``` }]; - let extraClassDefinition = [{ - std::string $cppClass::getOperandName(unsigned int idx) { + // include validateIO functions from handshake op + let extraClassDeclaration = validateIO # [{ + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + /// Operand 0 is named addrIn + /// Operand 1 is named dataIn + std::string getOperandNameImpl(unsigned idx) { + validateOperandIdx(idx); return (idx == 0) ? "addrIn" : "dataIn"; } - std::string $cppClass::getResultName(unsigned int idx) { + /// Result 0 is named addrOut + /// Result 1 is named dataToMem + std::string getResultNameImpl(unsigned idx) { + validateResultIdx(idx); return (idx == 0) ? "addrOut" : "dataToMem"; } }]; @@ -1019,8 +1200,8 @@ def RAMOp : Handshake_Op<"ram", []> { //===----------------------------------------------------------------------===// def EndOp : Handshake_Op<"end", [ - DeclareOpInterfaceMethods, - Terminator + Terminator, + SimpleNamedIOInterface ]> { let summary = "function endpoint (dynamatic)"; let description = [{ @@ -1062,7 +1243,10 @@ def SpeculatorOp : Handshake_Op<"speculator", [ IsIntSizedChannel<3, "SCCommitCtrl">, IsSimpleHandshake<"SCIsMisspec">, IsIntSizedChannel<1, "SCIsMisspec">, - DeclareOpInterfaceMethods + // The IOs of the SpeculatorOp have custom names + // defined here in the Tablegen + // as part of the SpeculatorOp declaration + CustomNamedIOInterface ]> { let summary = "Central control unit of the speculative circuit."; let description = [{ @@ -1113,14 +1297,27 @@ def SpeculatorOp : Handshake_Op<"speculator", [ wideControlType, ctrlType}); }]>]; - let extraClassDefinition = [{ - std::string $cppClass::getOperandName(unsigned idx) { - assert(idx < getNumOperands() && "index too high"); + // include validateIO functions from handshake op + let extraClassDeclaration = validateIO # [{ + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + /// Operand 0 is named ins + /// Operand 1 is named trigger + std::string getOperandNameImpl(unsigned idx) { + validateOperandIdx(idx); return idx == 0 ? "ins" : "trigger"; } - std::string $cppClass::getResultName(unsigned idx) { - assert(idx < getNumResults() && "index too high"); + /// Result 0 is named outs + /// Result 1 is named ctrl_save + /// Result 2 is named ctrl_commit + /// Result 3 is named ctrl_sc_save + /// Result 4 is named ctrl_sc_commit + /// Result 5 is named ctrl_sc_branch + std::string getResultNameImpl(unsigned idx) { + validateResultIdx(idx); switch (idx) { case 0: return "outs"; @@ -1146,7 +1343,10 @@ def SpecSaveOp : Handshake_Op<"spec_save", [ HasValidSpecTag<"dataOut">, IsSimpleHandshake<"ctrl">, IsIntSizedChannel<1, "ctrl">, - DeclareOpInterfaceMethods + // The IOs of a SpecSaveOp have custom names, + // defined here in the Tablegen + // as part of the SpecSaveOp declaration + CustomNamedIOInterface ]> { let summary = "Saves data tokens that interact in the speculative region."; let description = [{ @@ -1172,14 +1372,22 @@ def SpecSaveOp : Handshake_Op<"spec_save", [ `[` $ctrl `]` $dataIn attr-dict `:` type($dataIn) `,` type($dataOut) `,` type($ctrl) }]; - let extraClassDefinition = [{ - std::string $cppClass::getOperandName(unsigned idx) { - assert(idx < getOperation()->getNumOperands() && "index too high"); + // include validateIO functions from handshake op + let extraClassDeclaration = validateIO # [{ + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + /// Operand 0 is named ins + /// Operand 1 is named ctrl + std::string getOperandNameImpl(unsigned idx) { + validateOperandIdx(idx); return idx == 0 ? "ins" : "ctrl"; } - std::string $cppClass::getResultName(unsigned idx) { - assert(idx < getOperation()->getNumResults() && "index too high"); + /// Result 0 is named outs + std::string getResultNameImpl(unsigned idx) { + validateResultIdx(idx); return "outs"; } }]; @@ -1192,7 +1400,10 @@ def SpecCommitOp : Handshake_Op<"spec_commit", [ LacksSpecTag<"dataOut">, IsSimpleHandshake<"ctrl">, IsIntSizedChannel<1, "ctrl">, - DeclareOpInterfaceMethods + // The IOs of a SpecCommitOp have custom names, + // defined here in the Tablegen + // as part of the SpecCommitOp declaration + CustomNamedIOInterface ]> { let summary = "Stall speculative data tokens until they are resolved."; let description = [{ @@ -1220,14 +1431,21 @@ def SpecCommitOp : Handshake_Op<"spec_commit", [ `[` $ctrl `]` $dataIn attr-dict `:` type($dataIn) `,` type($dataOut) `,` type($ctrl) }]; - let extraClassDefinition = [{ - std::string $cppClass::getOperandName(unsigned idx) { - assert(idx < getOperation()->getNumOperands() && "index too high"); + let extraClassDeclaration = validateIO # [{ + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + /// Operand 0 is named ins + /// Operand 1 is named ctrl + std::string getOperandNameImpl(unsigned idx) { + validateOperandIdx(idx); return idx == 0 ? "ins" : "ctrl"; } - std::string $cppClass::getResultName(unsigned idx) { - assert(idx < getOperation()->getNumResults() && "index too high"); + /// Result 0 is named result + std::string getResultNameImpl(unsigned idx) { + validateResultIdx(idx); return "outs"; } }]; @@ -1239,7 +1457,10 @@ def SpecSaveCommitOp : Handshake_Op<"spec_save_commit", [ AllTypesMatch<["dataIn", "dataOut"]>, IsSimpleHandshake<"ctrl">, IsIntSizedChannel<3, "ctrl">, - DeclareOpInterfaceMethods + // The IOs of a SpecSaveCommitOp have custom names, + // defined here in the Tablegen + // as part of the SpecCommitOp declaration + CustomNamedIOInterface ]> { let summary = "Lets all tokens pass and saves a copy of them."; let description = [{ @@ -1262,14 +1483,22 @@ def SpecSaveCommitOp : Handshake_Op<"spec_save_commit", [ `[` $ctrl `]` $dataIn attr-dict `:` type($dataIn) `,` type($ctrl) }]; - let extraClassDefinition = [{ - std::string $cppClass::getOperandName(unsigned idx) { - assert(idx < getOperation()->getNumOperands() && "index too high"); + // include validateIO functions from handshake op + let extraClassDeclaration = validateIO # [{ + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + /// Operand 0 is named ins + /// Operand 1 is named ctrl + std::string getOperandNameImpl(unsigned idx) { + validateOperandIdx(idx); return idx == 0 ? "ins" : "ctrl"; } - std::string $cppClass::getResultName(unsigned idx) { - assert(idx < getOperation()->getNumResults() && "index too high"); + /// Result 0 is named result + std::string getResultNameImpl(unsigned idx) { + validateResultIdx(idx); return "outs"; } }]; @@ -1282,7 +1511,10 @@ def SpeculatingBranchOp : Handshake_Op<"speculating_branch", [ HasValidSpecTag<"dataOperand">, LacksSpecTag<"trueResult">, LacksSpecTag<"falseResult">, - DeclareOpInterfaceMethods + // The IOs of a SpeculatingBranchOp have custom names, + // defined here in the Tablegen + // as part of the SpeculatingBranchOp declaration + CustomNamedIOInterface ]> { let summary = "speculating branch operation"; let description = [{ @@ -1310,14 +1542,24 @@ def SpeculatingBranchOp : Handshake_Op<"speculating_branch", [ type($trueResult) `,` type($falseResult) }]; - let extraClassDefinition = [{ - std::string $cppClass::getOperandName(unsigned idx) { - assert(idx < getNumOperands() && "index too high"); + + // include validateIO functions from handshake op + let extraClassDeclaration = validateIO # [{ + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + /// Operand 0 is named spec_tag_data + /// Operand 1 is named data + std::string getOperandNameImpl(unsigned idx) { + validateOperandIdx(idx); return idx == 0 ? "spec_tag_data" : "data"; } - std::string $cppClass::getResultName(unsigned idx) { - assert(idx < getNumResults() && "index too high"); + /// Result 0 is named trueOut + /// Result 1 is named falseOut + std::string getResultNameImpl(unsigned idx) { + validateResultIdx(idx); return idx == 0 ? "trueOut" : "falseOut"; } }]; @@ -1328,7 +1570,10 @@ def NonSpecOp : Handshake_Op<"non_spec", [ AllExtraSignalsMatchExcept<"spec", ["dataIn", "dataOut"]>, LacksSpecTag<"dataIn">, HasValidSpecTag<"dataOut">, - DeclareOpInterfaceMethods + // The IOs of a NonSpecOp have custom names, + // defined here in the Tablegen + // as part of the NonSpecOp declaration + CustomNamedIOInterface ]> { let summary = "Adds a non-speculative spec bit."; let description = [{ @@ -1350,14 +1595,22 @@ def NonSpecOp : Handshake_Op<"non_spec", [ $dataIn attr-dict `:` type($dataIn) `to` type($dataOut) }]; - let extraClassDefinition = [{ - std::string $cppClass::getOperandName(unsigned idx) { - assert(idx < getOperation()->getNumOperands() && "index too high"); + + // include validateIO functions from handshake op + let extraClassDeclaration = validateIO # [{ + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + /// Operand 0 is named dataIn + std::string getOperandNameImpl(unsigned idx) { + validateOperandIdx(idx); return "dataIn"; } - std::string $cppClass::getResultName(unsigned idx) { - assert(idx < getOperation()->getNumResults() && "index too high"); + /// Result 0 is named dataOut + std::string getResultNameImpl(unsigned idx) { + validateResultIdx(idx); return "dataOut"; } }]; @@ -1369,7 +1622,10 @@ def NonSpecOp : Handshake_Op<"non_spec", [ def SharingWrapperOp : Handshake_Op<"sharing_wrapper", [ SameOperandsAndResultType, - DeclareOpInterfaceMethods, + // The IOs of a SharingWrapperOp have custom names, + // defined here in the Tablegen + // as part of the SharingWrapperOp declaration + CustomNamedIOInterface ]> { let summary = "sharing wrapper operation"; @@ -1414,6 +1670,33 @@ def SharingWrapperOp : Handshake_Op<"sharing_wrapper", [ ConfinedAttr]>:$latency); let results = (outs Variadic : $dataOut); + // include validateIO functions from handshake op + let extraClassDeclaration = validateIO # [{ + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + /// The first set of operands are named op__in_ + /// The last operand is named fromSharedUnitOut0 + std::string getOperandNameImpl(unsigned idx) { + validateOperandIdx(idx); + if (idx < getNumSharedOperands() * getNumSharedOperations()) { + return "op" + std::to_string(idx / getNumSharedOperands()) + "in" + + std::to_string(idx % getNumSharedOperands()); + } + return "fromSharedUnitOut0"; + } + + /// The first set of results are named op__out_0 + /// The last set of results are named toSharedUnitIn_ + std::string getResultNameImpl(unsigned idx) { + validateResultIdx(idx); + if (idx < getNumSharedOperations()) + return "op" + std::to_string(idx) + "out0"; + return "toSharedUnitIn" + std::to_string(idx - getNumSharedOperations()); + } + }]; + let assemblyFormat = [{ `[` $dataOperands `]` `,` `[` $sharedOpResult `]` attr-dict `:` functional-type(operands, results) @@ -1550,7 +1833,7 @@ def UnbundleOp : Handshake_Op<"unbundle", [ def ReadyRemoverOp : Handshake_Op<"ready_remover",[ SameOperandsAndResultType, - DeclareOpInterfaceMethods + SimpleNamedIOInterface ]> { let summary = [{ Rigidifies a channel to simplify the handshake logic. @@ -1567,22 +1850,15 @@ def ReadyRemoverOp : Handshake_Op<"ready_remover",[ $channelIn attr-dict `:` custom(type($channelIn)) }]; - let extraClassDefinition = [{ - std::string $cppClass::getOperandName(unsigned idx = 0) { - return "ins"; - } - - std::string $cppClass::getResultName(unsigned idx = 0) { - return "outs"; - } - }]; - } def ValidMergerOp : Handshake_Op<"valid_merger",[ AllTypesMatch<["lhsIn", "lhsOut"]>, AllTypesMatch<["rhsIn", "rhsOut"]>, - DeclareOpInterfaceMethods + // The IOs of a ValidMergerOp have custom names, + // defined here in the Tablegen + // as part of the ValidMergerOp declaration + CustomNamedIOInterface ]> { let summary = [{ Merges the valid signals of two channels. @@ -1600,15 +1876,24 @@ def ValidMergerOp : Handshake_Op<"valid_merger",[ $lhsIn `,` $rhsIn attr-dict `:` custom(type($lhsIn)) `,` custom(type($rhsIn)) }]; - let extraClassDefinition = [{ - std::string $cppClass::getOperandName(unsigned idx) { - assert(idx < getNumOperands() && "index too high"); - return (idx == 0) ? "lhs_ins" : "rhs_ins"; + // include validateIO functions from handshake op + let extraClassDeclaration = validateIO # [{ + //===------------------------------------------------------------------===// + // CustomNamedIOInterface Methods + //===------------------------------------------------------------------===// + + /// Operand 0 is named lhs_ins + /// Operand 1 is named rhs_ins + std::string getOperandNameImpl(unsigned idx) { + validateOperandIdx(idx); + return (idx == 0) ? "lhs_ins" : "rhs_ins"; } - std::string $cppClass::getResultName(unsigned idx) { - assert(idx < getNumOperands() && "index too high"); - return (idx == 0) ? "lhs_outs" : "rhs_outs"; + /// Result 0 is named lhs_outs + /// Result 1 is named rhs_outs + std::string getResultNameImpl(unsigned idx) { + validateResultIdx(idx); + return (idx == 0) ? "lhs_outs" : "rhs_outs"; } }]; diff --git a/lib/Analysis/NameAnalysis.cpp b/lib/Analysis/NameAnalysis.cpp index f014bac6f9..38b25872ab 100644 --- a/lib/Analysis/NameAnalysis.cpp +++ b/lib/Analysis/NameAnalysis.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "dynamatic/Analysis/NameAnalysis.h" +#include "dynamatic/Dialect/Handshake/HandshakeInterfaces.h" #include "dynamatic/Dialect/Handshake/HandshakeOps.h" #include "dynamatic/Support/LLVM.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -26,6 +27,7 @@ using namespace mlir; using namespace dynamatic; +using namespace handshake; /// Shortcut to get the name attribute of an operation. inline static mlir::StringAttr getNameAttr(Operation *op) { @@ -92,24 +94,6 @@ static bool tryToGetBlockArgName(BlockArgument arg, StringRef parentOpName, }); } -/// Returns the name of a result which is either provided by the -/// handshake::NamedIOInterface interface or, failing that, is its index. -static std::string getResultName(Operation *op, size_t resIdx) { - std::string oprName; - if (auto namedIO = dyn_cast(op)) - return namedIO.getResultName(resIdx); - return std::to_string(resIdx); -} - -/// Returns the name of an operand which is either provided by the -/// handshake::NamedIOInterface interface or, failing that, is its index. -static std::string getOperandName(Operation *op, size_t oprdIdx) { - std::string oprName; - if (auto namedIO = dyn_cast(op)) - return namedIO.getOperandName(oprdIdx); - return std::to_string(oprdIdx); -} - StringRef NameAnalysis::getName(Operation *op) { assert(namesValid && "analysis invariant is broken"); // If the operation already has a name or is intrinsically named , do nothing @@ -135,7 +119,8 @@ std::string NameAnalysis::getName(OpOperand &oprd) { Value val = oprd.get(); if (Operation *defOp = val.getDefiningOp()) { defName = getName(defOp); - resName = getResultName(defOp, cast(val).getResultNumber()); + auto handshakeOp = getHandshakeBase(defOp); + resName = handshakeOp.getResultName(cast(val).getResultNumber()); } else { getBlockArgName(cast(val), defName, resName); } @@ -144,7 +129,8 @@ std::string NameAnalysis::getName(OpOperand &oprd) { std::string userName, oprName; Operation *userOp = oprd.getOwner(); userName = getName(userOp); - oprName = getOperandName(userOp, oprd.getOperandNumber()); + auto handshakeOp = getHandshakeBase(userOp); + oprName = handshakeOp.getOperandName(oprd.getOperandNumber()); return defName + "_" + resName + "_" + oprName + "_" + userName; } @@ -310,7 +296,8 @@ std::string dynamatic::getUniqueName(OpOperand &oprd) { if (Operation *defOp = val.getDefiningOp()) { if (mlir::StringAttr attr = getNameAttr(defOp)) { defName = attr.str(); - resName = getResultName(defOp, cast(val).getResultNumber()); + auto handshakeOp = getHandshakeBase(defOp); + resName = handshakeOp.getResultName(cast(val).getResultNumber()); } else { return ""; } @@ -326,7 +313,8 @@ std::string dynamatic::getUniqueName(OpOperand &oprd) { Operation *userOp = oprd.getOwner(); if (mlir::StringAttr attr = getNameAttr(userOp)) { userName = attr.str(); - oprName = getOperandName(userOp, oprd.getOperandNumber()); + auto handshakeOp = getHandshakeBase(userOp); + oprName = handshakeOp.getOperandName(oprd.getOperandNumber()); } else { return ""; } diff --git a/lib/Conversion/HandshakeToHW/HandshakeToHW.cpp b/lib/Conversion/HandshakeToHW/HandshakeToHW.cpp index 3cb605cca5..de468aac07 100644 --- a/lib/Conversion/HandshakeToHW/HandshakeToHW.cpp +++ b/lib/Conversion/HandshakeToHW/HandshakeToHW.cpp @@ -167,9 +167,7 @@ struct MemLoweringState { /// Cache memory port information before modifying the interface, which can /// make them impossible to query. FuncMemoryPorts ports; - /// Generates and stores the interface's port names before starting the - /// conversion, when those are still queryable. - handshake::PortNamer portNames; + /// Backedges to the containing module's `hw::OutputOp` operation, which /// must be set, in order, with the memory interface's results that connect /// to the top-level module IO. @@ -186,7 +184,7 @@ struct MemLoweringState { /// Needed because we use the class as a value type in a map, which needs to /// be default-constructible. - MemLoweringState() : ports(nullptr), portNames(nullptr) { + MemLoweringState() : ports(nullptr) { llvm_unreachable("object should never be default-constructed"); } @@ -194,7 +192,7 @@ struct MemLoweringState { MemLoweringState(handshake::MemoryOpInterface memOp, const Twine &name) : name(name.str()), dataType(lowerType(memOp.getMemRefType().getElementType())), - ports(getMemoryPorts(memOp)), portNames(memOp) { + ports(getMemoryPorts(memOp)) { assert(dataType && "unsupported memory element type"); }; @@ -221,20 +219,17 @@ struct InternalMemLoweringState { handshake::MemoryOpInterface memInterface; FuncMemoryPorts ports; - handshake::PortNamer portNames; - /// Needed because we use the class as a value type in a map, which needs to /// be default-constructible. InternalMemLoweringState() - : ramOp(nullptr), memInterface(nullptr), ports(nullptr), - portNames(nullptr) { + : ramOp(nullptr), memInterface(nullptr), ports(nullptr) { llvm_unreachable("object should never be default-constructed"); } InternalMemLoweringState(handshake::RAMOp ramOp, handshake::MemoryOpInterface memInterface) : ramOp(ramOp), memInterface(memInterface), - ports(getMemoryPorts(memInterface)), portNames(memInterface) {}; + ports(getMemoryPorts(memInterface)) {}; }; /// Summarizes information to convert a Handshake function into a @@ -1154,11 +1149,10 @@ static void addMemIO(ModuleBuilder &modBuilder, handshake::FuncOp funcOp, hw::ModulePortInfo getFuncPortInfo(handshake::FuncOp funcOp, ModuleLoweringState &state) { ModuleBuilder modBuilder(funcOp.getContext()); - handshake::PortNamer portNames(funcOp); // Add all function outputs to the module for (auto [idx, res] : llvm::enumerate(funcOp.getResultTypes())) - modBuilder.addOutput(portNames.getOutputName(idx), lowerType(res)); + modBuilder.addOutput(funcOp.getResName(idx).getValue(), lowerType(res)); // Add all function inputs to the module, expanding memory references into a // set of individual ports for loads and stores @@ -1168,7 +1162,7 @@ hw::ModulePortInfo getFuncPortInfo(handshake::FuncOp funcOp, if (TypedValue memref = dyn_cast>(arg)) addMemIO(modBuilder, funcOp, memref, argName, state); else - modBuilder.addInput(portNames.getInputName(idx), lowerType(type)); + modBuilder.addInput(argName.getValue(), lowerType(type)); } modBuilder.addClkAndRst(); @@ -1320,11 +1314,10 @@ LogicalResult ConvertExternalFunc::matchAndRewrite( StringAttr name = rewriter.getStringAttr(funcOp.getName()); ModuleBuilder modBuilder(funcOp.getContext()); - handshake::PortNamer portNames(funcOp); // Add all function outputs to the module for (auto [idx, res] : llvm::enumerate(funcOp.getResultTypes())) - modBuilder.addOutput(portNames.getOutputName(idx), lowerType(res)); + modBuilder.addOutput(funcOp.getResName(idx).getValue(), lowerType(res)); // Add all function inputs to the module for (auto [idx, type] : llvm::enumerate(funcOp.getArgumentTypes())) { @@ -1333,7 +1326,7 @@ LogicalResult ConvertExternalFunc::matchAndRewrite( << "Memory interfaces are not supported for external " "functions"; } - modBuilder.addInput(portNames.getInputName(idx), lowerType(type)); + modBuilder.addInput(funcOp.getArgName(idx).getValue(), lowerType(type)); } modBuilder.addClkAndRst(); @@ -1409,15 +1402,19 @@ LogicalResult ConvertMemInterface::matchAndRewrite( for (auto [port, arg] : llvm::zip_equal(inputModPorts, memArgs)) converter.addInput(removePortNamePrefix(port), arg); for (auto [idx, oprd] : llvm::enumerate(operands)) { - if (!isa(oprd.getType())) - converter.addInput(memState.portNames.getInputName(idx), oprd); + if (!isa(oprd.getType())) { + auto handshakeOp = handshake::getHandshakeBase(memOp); + converter.addInput(handshakeOp.getOperandName(idx), oprd); + } } converter.addClkAndRst(parentModOp); // The HW instance will be connected to the top-level module through a // number of output ports, add those last after the regular interface ports for (auto [idx, res] : llvm::enumerate(memOp->getResults())) { - converter.addOutput(memState.portNames.getOutputName(idx), + auto handshakeOp = handshake::getHandshakeBase(memOp); + + converter.addOutput(handshakeOp.getResultName(idx), lowerType(res.getType())); } auto outputModPorts = memState.getMemOutputPorts(parentModOp); @@ -1534,17 +1531,18 @@ LogicalResult ConvertMemInterfaceForInternalArray::matchAndRewrite( memInterfaceConverter.addInput("loadData", bramInstanceOp.getResult(0)); } - // Add the ports from handshake op (here we use the port namer to name the - // ports that are directly converted from handshake op), except for the memref - // type. + // Add the ports from handshake op for (auto [i, oprd] : llvm::enumerate(operands)) { - if (!isa(oprd.getType())) - memInterfaceConverter.addInput(memState.portNames.getInputName(i), oprd); + if (!isa(oprd.getType())){ + auto handshakeOp = handshake::getHandshakeBase(memOp); + memInterfaceConverter.addInput(handshakeOp.getOperandName(i), oprd); + } } memInterfaceConverter.addClkAndRst(parentModOp); for (auto [idx, res] : llvm::enumerate(memOp->getResults())) { - memInterfaceConverter.addOutput(memState.portNames.getOutputName(idx), + auto handshakeOp = handshake::getHandshakeBase(memOp); + memInterfaceConverter.addOutput(handshakeOp.getResultName(idx), lowerType(res.getType())); } @@ -1592,17 +1590,20 @@ template LogicalResult ConvertToHWInstance::matchAndRewrite( T op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { HWConverter converter(this->getContext()); - handshake::PortNamer portNames(op); // Add all operation operands to the inputs - for (auto [idx, oprd] : llvm::enumerate(adaptor.getOperands())) - converter.addInput(portNames.getInputName(idx), oprd); + for (auto [idx, oprd] : llvm::enumerate(adaptor.getOperands())){ + auto handshakeOp = handshake::getHandshakeBase(op); + converter.addInput(handshakeOp.getOperandName(idx), oprd); + } converter.addClkAndRst(((Operation *)op)->getParentOfType()); // Add all operation results to the outputs - for (auto [idx, type] : llvm::enumerate(op->getResultTypes())) - converter.addOutput(portNames.getOutputName(idx), lowerType(type)); + for (auto [idx, type] : llvm::enumerate(op->getResultTypes())){ + auto handshakeOp = handshake::getHandshakeBase(op); + converter.addOutput(handshakeOp.getResultName(idx), lowerType(type)); + } hw::InstanceOp instOp = converter.convertToInstance(op, rewriter); return instOp ? success() : failure(); } diff --git a/lib/Dialect/Handshake/HandshakeInterfaces.cpp b/lib/Dialect/Handshake/HandshakeInterfaces.cpp index 8bd260b861..b6615dbda2 100644 --- a/lib/Dialect/Handshake/HandshakeInterfaces.cpp +++ b/lib/Dialect/Handshake/HandshakeInterfaces.cpp @@ -28,309 +28,6 @@ using namespace mlir; using namespace dynamatic; using namespace dynamatic::handshake; -//===----------------------------------------------------------------------===// -// PortNameGenerator (uses NamedIOInterface) -//===----------------------------------------------------------------------===// - -PortNamer::PortNamer(Operation *op) { - assert(op && "cannot generate port names for null operation"); - if (auto namedOpInterface = dyn_cast(op)) - inferFromNamedOpInterface(namedOpInterface); - else if (auto funcOp = dyn_cast(op)) - inferFromFuncOp(funcOp); - else - inferDefault(op); -} - -void PortNamer::infer(Operation *op, IdxToStrF &inF, IdxToStrF &outF) { - for (size_t idx = 0, e = op->getNumOperands(); idx < e; ++idx) - inputs.push_back(inF(idx)); - for (size_t idx = 0, e = op->getNumResults(); idx < e; ++idx) - outputs.push_back(outF(idx)); - - // The Handshake terminator forwards its non-memory inputs to its outputs, so - // it needs port names for them - if (handshake::EndOp endOp = dyn_cast(op)) { - handshake::FuncOp funcOp = endOp->getParentOfType(); - assert(funcOp && "end must be child of handshake function"); - size_t numResults = funcOp.getFunctionType().getNumResults(); - for (size_t idx = 0, e = numResults; idx < e; ++idx) - outputs.push_back(endOp.getDefaultResultName(idx)); - } -} - -void PortNamer::inferDefault(Operation *op) { - llvm::TypeSwitch(op) - .Case([&](auto) { - infer( - op, [](unsigned idx) { return idx == 0 ? "lhs" : "rhs"; }, - [](unsigned idx) { return "result"; }); - }) - .Case( - [&](auto) { - infer( - op, [](unsigned idx) { return "ins"; }, - [](unsigned idx) { return "outs"; }); - }) - .Case([&](auto) { - infer( - op, - [](unsigned idx) { - if (idx == 0) - return "condition"; - if (idx == 1) - return "trueValue"; - return "falseValue"; - }, - [](unsigned idx) { return "result"; }); - }) - .Default([&](auto) { - infer( - op, [](unsigned idx) { return "in" + std::to_string(idx); }, - [](unsigned idx) { return "out" + std::to_string(idx); }); - }); -} - -void PortNamer::inferFromNamedOpInterface(handshake::NamedIOInterface namedIO) { - auto inF = [&](unsigned idx) { return namedIO.getOperandName(idx); }; - auto outF = [&](unsigned idx) { return namedIO.getResultName(idx); }; - infer(namedIO, inF, outF); -} - -void PortNamer::inferFromFuncOp(handshake::FuncOp funcOp) { - llvm::transform(funcOp.getArgNames(), std::back_inserter(inputs), - [](Attribute arg) { return cast(arg).str(); }); - llvm::transform(funcOp.getResNames(), std::back_inserter(outputs), - [](Attribute res) { return cast(res).str(); }); -} - -//===----------------------------------------------------------------------===// -// NamedIOInterface (getOperandName/getResultName) -//===----------------------------------------------------------------------===// - -static inline std::string getArrayElemName(const Twine &name, unsigned idx) { - return name.str() + "_" + std::to_string(idx); -} - -std::string handshake::MuxOp::getOperandName(unsigned idx) { - assert(idx < getNumOperands() && "index too high"); - return idx == 0 ? "index" : getDefaultOperandName(idx - 1); -} - -std::string handshake::ControlMergeOp::getResultName(unsigned idx) { - assert(idx < getNumResults() && "index too high"); - return idx == 0 ? "outs" : "index"; -} - -std::string handshake::ConditionalBranchOp::getOperandName(unsigned idx) { - assert(idx < getNumOperands() && "index too high"); - return idx == 0 ? "condition" : "data"; -} - -std::string handshake::ConditionalBranchOp::getResultName(unsigned idx) { - assert(idx < getNumResults() && "index too high"); - return idx == ConditionalBranchOp::trueIndex ? "trueOut" : "falseOut"; -} - -std::string handshake::ConstantOp::getOperandName(unsigned idx) { - assert(idx == 0 && "index too high"); - return "ctrl"; -} - -std::string handshake::EndOp::getOperandName(unsigned idx) { - assert(idx < getNumOperands() && "index too high"); - handshake::FuncOp funcOp = (*this)->getParentOfType(); - assert(funcOp && "end must be child of handshake function"); - - unsigned numResults = funcOp.getFunctionType().getNumResults(); - if (idx < numResults) - return getDefaultOperandName(idx); - return "memDone_" + std::to_string(idx - numResults); -} - -std::string handshake::SelectOp::getOperandName(unsigned idx) { - assert(idx < getNumOperands() && "index too high"); - if (idx == 0) - return "condition"; - return (idx == 1) ? "trueValue" : "falseValue"; -} - -std::string handshake::SelectOp::getResultName(unsigned idx) { - assert(idx == 0 && "index too high"); - return "result"; -} - -/// Load/Store base signal names common to all memory interfaces -static constexpr llvm::StringLiteral MEMREF("memref"), MEM_START("memStart"), - MEM_END("memEnd"), CTRL_END("ctrlEnd"), CTRL("ctrl"), LD_ADDR("ldAddr"), - LD_DATA("ldData"), ST_ADDR("stAddr"), ST_DATA("stData"); - -static StringRef getIfControlOprd(MemoryOpInterface memOp, unsigned idx) { - if (!memOp.isMasterInterface()) - return ""; - switch (idx) { - case 0: - return MEMREF; - case 1: - return MEM_START; - default: - return idx == memOp->getNumOperands() - 1 ? CTRL_END : ""; - } -} - -static StringRef getIfControlRes(MemoryOpInterface memOp, unsigned idx) { - if (memOp.isMasterInterface() && idx == memOp->getNumResults() - 1) - return MEM_END; - return ""; -} - -/// Common operand naming logic for memory controllers and LSQs. -static std::string getMemOperandName(const FuncMemoryPorts &ports, - unsigned idx) { - // Iterate through all memory ports to find out the type of the operand - unsigned ctrlIdx = 0, loadIdx = 0, storeIdx = 0; - for (const GroupMemoryPorts &blockPorts : ports.groups) { - if (blockPorts.hasControl()) { - if (idx == blockPorts.ctrlPort->getCtrlInputIndex()) - return getArrayElemName(CTRL, ctrlIdx); - ++ctrlIdx; - } - for (const MemoryPort &accessPort : blockPorts.accessPorts) { - if (std::optional loadPort = dyn_cast(accessPort)) { - if (loadPort->getAddrInputIndex() == idx) - return getArrayElemName(LD_ADDR, loadIdx); - ++loadIdx; - } else { - std::optional storePort = cast(accessPort); - if (storePort->getAddrInputIndex() == idx) - return getArrayElemName(ST_ADDR, storeIdx); - if (storePort->getDataInputIndex() == idx) - return getArrayElemName(ST_DATA, storeIdx); - ++storeIdx; - } - } - } - - return ""; -} - -/// Common result naming logic for memory controllers and LSQs. -static std::string getMemResultName(FuncMemoryPorts &ports, unsigned idx) { - // Iterate through all memory ports to find out the type of the - // operand - unsigned loadIdx = 0; - for (const GroupMemoryPorts &blockPorts : ports.groups) { - for (const MemoryPort &accessPort : blockPorts.accessPorts) { - if (std::optional loadPort = dyn_cast(accessPort)) { - if (loadPort->getDataOutputIndex() == idx) - return getArrayElemName(LD_DATA, loadIdx); - ++loadIdx; - } - } - } - return ""; -} - -std::string handshake::MemoryControllerOp::getOperandName(unsigned idx) { - assert(idx < getNumOperands() && "index too high"); - - if (StringRef name = getIfControlOprd(*this, idx); !name.empty()) - return name.str(); - - // Try to get the operand name from the regular ports - MCPorts mcPorts = getPorts(); - if (std::string name = getMemOperandName(mcPorts, idx); !name.empty()) - return name; - - // Get the operand name from a port to an LSQ - assert(mcPorts.connectsToLSQ() && "expected MC to connect to LSQ"); - LSQLoadStorePort lsqPort = mcPorts.getLSQPort(); - if (lsqPort.getLoadAddrInputIndex() == idx) - return getArrayElemName(LD_ADDR, mcPorts.getNumPorts()); - if (lsqPort.getStoreAddrInputIndex() == idx) - return getArrayElemName(ST_ADDR, mcPorts.getNumPorts()); - assert(lsqPort.getStoreDataInputIndex() == idx && "unknown MC/LSQ operand"); - return getArrayElemName(ST_DATA, mcPorts.getNumPorts()); -} - -std::string handshake::MemoryControllerOp::getResultName(unsigned idx) { - assert(idx < getNumResults() && "index too high"); - - if (StringRef name = getIfControlRes(*this, idx); !name.empty()) - return name.str(); - - // Try to get the operand name from the regular ports - MCPorts mcPorts = getPorts(); - if (std::string name = getMemResultName(mcPorts, idx); !name.empty()) - return name; - - // Get the operand name from a port to an LSQ - assert(mcPorts.connectsToLSQ() && "expected MC to connect to LSQ"); - LSQLoadStorePort lsqPort = mcPorts.getLSQPort(); - assert(lsqPort.getLoadDataOutputIndex() == idx && "unknown MC/LSQ result"); - return getArrayElemName(LD_DATA, mcPorts.getNumPorts()); -} - -std::string handshake::LSQOp::getOperandName(unsigned idx) { - assert(idx < getNumOperands() && "index too high"); - - if (StringRef name = getIfControlOprd(*this, idx); !name.empty()) - return name.str(); - - // Try to get the operand name from the regular ports - LSQPorts lsqPorts = getPorts(); - if (std::string name = getMemOperandName(lsqPorts, idx); !name.empty()) - return name; - - // Get the operand name from a port to a memory controller - assert(lsqPorts.connectsToMC() && "expected LSQ to connect to MC"); - assert(lsqPorts.getMCPort().getLoadDataInputIndex() == idx && - "unknown LSQ/MC operand"); - return "ldDataFromMC"; -} - -std::string handshake::LSQOp::getResultName(unsigned idx) { - assert(idx < getNumResults() && "index too high"); - - if (StringRef name = getIfControlRes(*this, idx); !name.empty()) - return name.str(); - - // Try to get the operand name from the regular ports - LSQPorts lsqPorts = getPorts(); - if (std::string name = getMemResultName(lsqPorts, idx); !name.empty()) - return name; - - // Get the operand name from a port to a memory controller - assert(lsqPorts.connectsToMC() && "expected LSQ to connect to MC"); - MCLoadStorePort mcPort = lsqPorts.getMCPort(); - if (mcPort.getLoadAddrOutputIndex() == idx) - return "ldAddrToMC"; - if (mcPort.getStoreAddrOutputIndex() == idx) - return "stAddrToMC"; - assert(mcPort.getStoreDataOutputIndex() == idx && "unknown LSQ/MC result"); - return "stDataToMC"; -} - -std::string handshake::SharingWrapperOp::getOperandName(unsigned idx) { - assert(idx < getNumOperands() && "index too high"); - if (idx < getNumSharedOperands() * getNumSharedOperations()) { - return "op" + std::to_string(idx / getNumSharedOperands()) + "in" + - std::to_string(idx % getNumSharedOperands()); - } - return "fromSharedUnitOut0"; -} - -std::string handshake::SharingWrapperOp::getResultName(unsigned idx) { - assert(idx < getNumResults() && "index too high"); - if (idx < getNumSharedOperations()) - return "op" + std::to_string(idx) + "out0"; - return "toSharedUnitIn" + std::to_string(idx - getNumSharedOperations()); -} - //===----------------------------------------------------------------------===// // MemoryOpInterface //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Handshake/HandshakeOps.cpp b/lib/Dialect/Handshake/HandshakeOps.cpp index f357ac52d1..b0f4a0bd3c 100644 --- a/lib/Dialect/Handshake/HandshakeOps.cpp +++ b/lib/Dialect/Handshake/HandshakeOps.cpp @@ -1964,5 +1964,185 @@ LogicalResult TruncIOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// Memory Controller GetOperandName and GetResultName Utilities +//===----------------------------------------------------------------------===// + +/// Load/Store base signal names common to all memory interfaces +static constexpr llvm::StringLiteral MEMREF("memref"), MEM_START("memStart"), + MEM_END("memEnd"), CTRL_END("ctrlEnd"), CTRL("ctrl"), LD_ADDR("ldAddr"), + LD_DATA("ldData"), ST_ADDR("stAddr"), ST_DATA("stData"); + +static inline std::string getArrayElemName(const Twine &name, unsigned idx) { + return name.str() + "_" + std::to_string(idx); +} + +inline static StringRef getIfControlOprd(MemoryOpInterface memOp, + unsigned idx) { + if (!memOp.isMasterInterface()) + return ""; + switch (idx) { + case 0: + return MEMREF; + case 1: + return MEM_START; + default: + return idx == memOp->getNumOperands() - 1 ? CTRL_END : ""; + } +} + +static StringRef getIfControlRes(MemoryOpInterface memOp, unsigned idx) { + if (memOp.isMasterInterface() && idx == memOp->getNumResults() - 1) + return MEM_END; + return ""; +} + +/// Common operand naming logic for memory controllers and LSQs. +inline static std::string getMemOperandName(const FuncMemoryPorts &ports, + unsigned idx) { + // Iterate through all memory ports to find out the type of the operand + unsigned ctrlIdx = 0, loadIdx = 0, storeIdx = 0; + for (const GroupMemoryPorts &blockPorts : ports.groups) { + if (blockPorts.hasControl()) { + if (idx == blockPorts.ctrlPort->getCtrlInputIndex()) + return getArrayElemName(CTRL, ctrlIdx); + ++ctrlIdx; + } + for (const MemoryPort &accessPort : blockPorts.accessPorts) { + if (std::optional loadPort = dyn_cast(accessPort)) { + if (loadPort->getAddrInputIndex() == idx) + return getArrayElemName(LD_ADDR, loadIdx); + ++loadIdx; + } else { + std::optional storePort = cast(accessPort); + if (storePort->getAddrInputIndex() == idx) + return getArrayElemName(ST_ADDR, storeIdx); + if (storePort->getDataInputIndex() == idx) + return getArrayElemName(ST_DATA, storeIdx); + ++storeIdx; + } + } + } + + return ""; +} + +/// Common result naming logic for memory controllers and LSQs. +static std::string getMemResultName(FuncMemoryPorts &ports, unsigned idx) { + // Iterate through all memory ports to find out the type of the + // operand + unsigned loadIdx = 0; + for (const GroupMemoryPorts &blockPorts : ports.groups) { + for (const MemoryPort &accessPort : blockPorts.accessPorts) { + if (std::optional loadPort = dyn_cast(accessPort)) { + if (loadPort->getDataOutputIndex() == idx) + return getArrayElemName(LD_DATA, loadIdx); + ++loadIdx; + } + } + } + return ""; +} + +std::string LSQOp::getOperandNameImpl(unsigned idx) { + + assert(idx < getOperation()->getNumOperands() && "index too high"); + + if (StringRef name = getIfControlOprd(*this, idx); !name.empty()) + return name.str(); + + // Try to get the operand name from the regular ports + LSQPorts lsqPorts = getPorts(); + if (std::string name = getMemOperandName(lsqPorts, idx); !name.empty()) + return name; + + // Get the operand name from a port to a memory controller + assert(lsqPorts.connectsToMC() && "expected LSQ to connect to MC"); + assert(lsqPorts.getMCPort().getLoadDataInputIndex() == idx && + "unknown LSQ/MC operand"); + return "ldDataFromMC"; +} + +std::string LSQOp::getResultNameImpl(unsigned idx) { + assert(idx < getOperation()->getNumResults() && "index too high"); + + if (StringRef name = getIfControlRes(*this, idx); !name.empty()) + return name.str(); + + // Try to get the operand name from the regular ports + LSQPorts lsqPorts = getPorts(); + if (std::string name = getMemResultName(lsqPorts, idx); !name.empty()) + return name; + + // Get the operand name from a port to a memory controller + assert(lsqPorts.connectsToMC() && "expected LSQ to connect to MC"); + MCLoadStorePort mcPort = lsqPorts.getMCPort(); + if (mcPort.getLoadAddrOutputIndex() == idx) + return "ldAddrToMC"; + if (mcPort.getStoreAddrOutputIndex() == idx) + return "stAddrToMC"; + assert(mcPort.getStoreDataOutputIndex() == idx && "unknown LSQ/MC result"); + return "stDataToMC"; +} + +std::string MemoryControllerOp::getOperandNameImpl(unsigned idx) { + assert(idx < getOperation()->getNumOperands() && "index too high"); + + if (StringRef name = getIfControlOprd(*this, idx); !name.empty()) + return name.str(); + + // Try to get the operand name from the regular ports + MCPorts mcPorts = getPorts(); + if (std::string name = getMemOperandName(mcPorts, idx); !name.empty()) + return name; + + // Get the operand name from a port to an LSQ + assert(mcPorts.connectsToLSQ() && "expected MC to connect to LSQ"); + LSQLoadStorePort lsqPort = mcPorts.getLSQPort(); + if (lsqPort.getLoadAddrInputIndex() == idx) + return getArrayElemName(LD_ADDR, mcPorts.getNumPorts()); + if (lsqPort.getStoreAddrInputIndex() == idx) + return getArrayElemName(ST_ADDR, mcPorts.getNumPorts()); + assert(lsqPort.getStoreDataInputIndex() == idx && "unknown MC/LSQ operand"); + return getArrayElemName(ST_DATA, mcPorts.getNumPorts()); +} + +std::string MemoryControllerOp::getResultNameImpl(unsigned idx) { + assert(idx < getOperation()->getNumResults() && "index too high"); + + if (StringRef name = getIfControlRes(*this, idx); !name.empty()) + return name.str(); + + // Try to get the operand name from the regular ports + MCPorts mcPorts = getPorts(); + if (std::string name = getMemResultName(mcPorts, idx); !name.empty()) + return name; + + // Get the operand name from a port to an LSQ + assert(mcPorts.connectsToLSQ() && "expected MC to connect to LSQ"); + LSQLoadStorePort lsqPort = mcPorts.getLSQPort(); + assert(lsqPort.getLoadDataOutputIndex() == idx && "unknown MC/LSQ result"); + return getArrayElemName(LD_DATA, mcPorts.getNumPorts()); +} + +namespace dynamatic { +namespace handshake { + +//===----------------------------------------------------------------------===// +// Operand and Result Names +//===----------------------------------------------------------------------===// + +handshake::HandshakeBaseInterface getHandshakeBase(Operation *op) { + if (auto handshakeBase = + llvm::dyn_cast(op)) { + return handshakeBase; + } + op->emitError() << "must implement HandshakeBaseInterface, op: " << *op; + llvm::report_fatal_error("Missing HandshakeBaseInterface"); +} + +} // end namespace handshake +} // end namespace dynamatic + #define GET_OP_CLASSES -#include "dynamatic/Dialect/Handshake/Handshake.cpp.inc" \ No newline at end of file +#include "dynamatic/Dialect/Handshake/Handshake.cpp.inc" diff --git a/tools/backend/log2csv/log2csv.cpp b/tools/backend/log2csv/log2csv.cpp index 07f5bd2c2a..8183f7d18e 100644 --- a/tools/backend/log2csv/log2csv.cpp +++ b/tools/backend/log2csv/log2csv.cpp @@ -250,16 +250,14 @@ static LogicalResult mapSignalsToValues(mlir::ModuleOp modOp, } // First associate names to all function arguments - handshake::PortNamer argNameGen(funcOp); for (auto [idx, arg] : llvm::enumerate(funcOp.getArguments())) - ports.insert({argNameGen.getInputName(idx), arg}); + ports.insert({funcOp.getArgName(idx), arg}); // Then associate names to each operation's results for (Operation &op : funcOp.getOps()) { - handshake::PortNamer resNameGen(&op); for (auto [idx, res] : llvm::enumerate(op.getResults())) { std::string signalName = - getUniqueName(&op).str() + "_" + resNameGen.getOutputName(idx).str(); + getUniqueName(&op).str() + "_" + funcOp.getResName(idx).str(); ports.insert({signalName, res}); } } diff --git a/tools/export-dot/export-dot.cpp b/tools/export-dot/export-dot.cpp index c70fb6bc4e..939739b7b9 100644 --- a/tools/export-dot/export-dot.cpp +++ b/tools/export-dot/export-dot.cpp @@ -276,13 +276,6 @@ static LogicalResult getDOTGraph(handshake::FuncOp funcOp, DOTGraph &graph) { mlir::DenseMap bbSubgraphs; DOTGraph::Subgraph *root = &builder.getRoot(); - // Collect port names for all operations and the top-level function - using PortNames = mlir::DenseMap; - PortNames portNames; - portNames.try_emplace(funcOp, funcOp); - for (Operation &op : funcOp.getOps()) - portNames.try_emplace(&op, &op); - auto addNode = [&](Operation *op, DOTGraph::Subgraph &subgraph) -> LogicalResult { // The node's DOT "mlir_op" attribute @@ -335,11 +328,13 @@ static LogicalResult getDOTGraph(handshake::FuncOp funcOp, DOTGraph &graph) { Operation *srcOp = res.getDefiningOp(); srcNodeName = getUniqueName(srcOp).str(); srcIdx = res.getResultNumber(); - srcPortName = portNames.at(srcOp).getOutputName(srcIdx); + auto handshakeOp = getHandshakeBase(srcOp); + srcPortName = handshakeOp.getResultName(srcIdx); } else { Operation *parentOp = val.getParentBlock()->getParentOp(); srcIdx = cast(val).getArgNumber(); - srcNodeName = srcPortName = portNames.at(parentOp).getInputName(srcIdx); + auto handshakeOp = getHandshakeBase(parentOp); + srcNodeName = srcPortName = handshakeOp.getOperandName(srcIdx); } // Determine the edge's destination @@ -348,11 +343,13 @@ static LogicalResult getDOTGraph(handshake::FuncOp funcOp, DOTGraph &graph) { if (isa(dstOp)) { Operation *parentOp = dstOp->getParentOp(); dstIdx = oprd.getOperandNumber(); - dstNodeName = dstPortName = portNames.at(parentOp).getOutputName(dstIdx); + auto handshakeOp = getHandshakeBase(parentOp); + dstNodeName = dstPortName = handshakeOp.getResultName(dstIdx); } else { dstNodeName = getUniqueName(dstOp).str(); dstIdx = oprd.getOperandNumber(); - dstPortName = portNames.at(dstOp).getInputName(dstIdx); + auto handshakeOp = getHandshakeBase(dstOp); + dstPortName = handshakeOp.getOperandName(dstIdx); } DOTGraph::Edge &edge = builder.addEdge(srcNodeName, dstNodeName, subgraph); @@ -377,7 +374,7 @@ static LogicalResult getDOTGraph(handshake::FuncOp funcOp, DOTGraph &graph) { continue; // Create a node for the argument - StringRef argName = portNames.at(funcOp).getInputName(idx); + StringRef argName = funcOp.getArgName(idx); DOTGraph::Node *node = builder.addNode(argName, *root); if (!node) return funcOp.emitError() << "failed to create node for argument " << idx; @@ -394,7 +391,7 @@ static LogicalResult getDOTGraph(handshake::FuncOp funcOp, DOTGraph &graph) { // Create nodes for all function results ValueRange results = funcOp.getBodyBlock()->getTerminator()->getOperands(); for (const auto &[idx, res] : llvm::enumerate(results)) { - StringRef resName = portNames.at(funcOp).getOutputName(idx); + StringRef resName = funcOp.getResName(idx); DOTGraph::Node *node = builder.addNode(resName, *root); if (!node) return funcOp.emitError() << "failed to create node for argument " << idx;