From a59052db71f203fd0bdefc46146a7f4e98672cf9 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Thu, 6 Nov 2025 15:27:08 +0100 Subject: [PATCH 01/27] Replace polygeist with llvm-project --- .gitmodules | 11 +++++------ llvm-project | 1 + polygeist | 1 - 3 files changed, 6 insertions(+), 7 deletions(-) create mode 160000 llvm-project delete mode 160000 polygeist diff --git a/.gitmodules b/.gitmodules index 76353fc964..7d97e79a39 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,9 +1,3 @@ -[submodule "polygeist"] - path = polygeist - url = https://github.com/EPFL-LAP/Polygeist.git - branch = main - shallow = true - [submodule "visual-dataflow/godot-cpp"] path = visual-dataflow/godot-cpp url = https://github.com/godotengine/godot-cpp @@ -12,3 +6,8 @@ [submodule "data/aig"] path = data/aig url = https://github.com/ETHZ-DYNAMO/dataflow-aig-library +[submodule "llvm-project"] + path = llvm-project + url = https://github.com/llvm/llvm-project.git + branch = main + shallow = true diff --git a/llvm-project b/llvm-project new file mode 160000 index 0000000000..9d1b578a22 --- /dev/null +++ b/llvm-project @@ -0,0 +1 @@ +Subproject commit 9d1b578a2237e9c65993d3b9f959e64de184e479 diff --git a/polygeist b/polygeist deleted file mode 160000 index 62f04a7326..0000000000 --- a/polygeist +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 62f04a7326f9d3ef7dd655e1d07b3c4a748d5a99 From 7d1ea23125836f4e8a169dcdcc9e68e893d00c0a Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Thu, 6 Nov 2025 18:45:50 +0100 Subject: [PATCH 02/27] Replace all changed calls to MLIR APIs --- build.sh | 81 ++++++----------- .../experimental/Support/CFGAnnotation.h | 1 + .../lib/Support/CreateSmvFormalTestbench.cpp | 14 +-- experimental/lib/Support/SubjectGraph.cpp | 2 +- .../HandshakeCombineSteeringLogic.cpp | 4 +- .../LSQSizing/HandshakeSizeLSQs.cpp | 4 +- .../Speculation/HandshakeSpeculation.cpp | 13 +-- .../elastic-miter/ElasticMiterTestbench.cpp | 4 +- .../tools/elastic-miter/FabricGeneration.cpp | 20 ++-- .../tools/frequency-profiler/Simulator.cpp | 20 ++-- include/dynamatic/Dialect/HW/HWDialect.td | 2 +- include/dynamatic/Dialect/HW/HWOps.h | 1 + include/dynamatic/Dialect/HW/HWStructure.td | 2 +- include/dynamatic/Dialect/HW/HWTypes.h | 10 +- include/dynamatic/Dialect/HW/HWTypes.td | 2 +- .../Dialect/Handshake/HandshakeAttributes.td | 6 +- .../Dialect/Handshake/HandshakeInterfaces.td | 2 +- .../Dialect/Handshake/HandshakeOps.td | 7 +- .../Dialect/Handshake/HandshakeTypes.td | 14 +-- include/dynamatic/Support/LLVM.h | 6 +- .../CfToHandshake/CfToHandshake.cpp | 50 +++++----- .../HandshakeToHW/HandshakeToHW.cpp | 14 +-- lib/Dialect/HW/ConversionPatterns.cpp | 14 +-- lib/Dialect/HW/CustomDirectiveImpl.cpp | 17 ++-- lib/Dialect/HW/HWAttributes.cpp | 86 +++++++++--------- lib/Dialect/HW/HWDialect.cpp | 8 +- lib/Dialect/HW/HWInstanceImplementation.cpp | 20 ++-- lib/Dialect/HW/HWModuleOpInterface.cpp | 2 +- lib/Dialect/HW/HWOpInterfaces.cpp | 2 +- lib/Dialect/HW/HWOps.cpp | 91 +++++++++---------- lib/Dialect/HW/HWTypes.cpp | 44 ++++----- lib/Dialect/HW/ModuleImplementation.cpp | 4 +- lib/Dialect/Handshake/HandshakeOps.cpp | 29 +++--- lib/Dialect/Handshake/MemoryInterfaces.cpp | 6 +- lib/Support/JSON/JSON.cpp | 13 +-- lib/Support/RTL/RTL.cpp | 4 +- lib/Transforms/ArithReduceStrength.cpp | 8 +- lib/Transforms/FlattenMemRefRowMajor.cpp | 9 +- lib/Transforms/HandshakeCanonicalize.cpp | 6 +- lib/Transforms/HandshakeHoistExtInstances.cpp | 4 +- lib/Transforms/HandshakeInferBasicBlocks.cpp | 2 +- lib/Transforms/HandshakeMaterialize.cpp | 10 +- lib/Transforms/HandshakeMinimizeCstWidth.cpp | 4 +- lib/Transforms/HandshakeOptimizeBitwidths.cpp | 14 +-- lib/Transforms/ScfRotateForLoops.cpp | 4 +- lib/Transforms/ScfSimpleIfToSelect.cpp | 4 +- tools/dynamatic/dynamatic.cpp | 2 +- tools/export-rtl/export-rtl.cpp | 11 +-- tools/hls-verifier/hls-verifier.cpp | 4 +- tools/hls-verifier/include/HlsVhdlTb.h | 4 +- tools/translate-llvm-to-std/CMakeLists.txt | 2 + tools/translate-llvm-to-std/InferArgTypes.h | 1 + .../TranslateLLVMToStd.cpp | 14 +-- tools/translate-llvm-to-std/main.cpp | 3 +- .../Transforms/GreedySimplifyMergeLike.cpp | 8 +- 55 files changed, 355 insertions(+), 378 deletions(-) diff --git a/build.sh b/build.sh index 3ec3f2da15..ca7da8803d 100755 --- a/build.sh +++ b/build.sh @@ -16,7 +16,7 @@ print_help_and_exit () { List of options: --release | -r : build in \"Release\" mode (default is \"Debug\") - --skip-polygeist : skip building POLYGEIST + --skip-llvm : skip building LLVM --visual-dataflow | -v : build visual-dataflow's C++ library --export-godot | -e : export the Godot project (requires engine) --force | -f : force cmake reconfiguration in each (sub)project @@ -126,7 +126,6 @@ run_ninja() { CMAKE_COMPILERS="-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++" CMAKE_LLVM_BUILD_OPTIMIZATIONS="-DLLVM_CCACHE_BUILD=ON -DLLVM_USE_LINKER=lld" -CMAKE_POLYGEIST_BUILD_OPTIMIZATIONS="-DPOLYGEIST_USE_LINKER=lld" CMAKE_DYNAMATIC_BUILD_OPTIMIZATIONS="-DDYNAMATIC_CCACHE_BUILD=ON -DLLVM_USE_LINKER=lld" CMAKE_DYNAMATIC_ENABLE_XLS="" CMAKE_DYNAMATIC_ENABLE_LEQ_BINARIES="" @@ -138,8 +137,8 @@ BUILD_TYPE="Debug" BUILD_VISUAL_DATAFLOW=0 GODOT_PATH="" ENABLE_XLS_INTEGRATION=0 -SKIP_POLYGEIST=0 -POLYGEIST_DIR="$PWD/polygeist" +SKIP_LLVM=0 +LLVM_DIR="$PWD/llvm-project" # Loop over command line arguments and update script variables PARSE_ARG="" @@ -156,8 +155,8 @@ do GODOT_PATH="../$GODOT_PATH" fi PARSE_ARG="" - elif [[ $PARSE_ARG == "polygeist-path" ]]; then - POLYGEIST_DIR="$arg" + elif [[ $PARSE_ARG == "llvm-path" ]]; then + LLVM_DIR="$arg" PARSE_ARG="" elif [[ $PARSE_ARG == "llvm-parallel-link-jobs" ]]; then LLVM_PARALLEL_LINK_JOBS="$arg" @@ -167,7 +166,6 @@ do "--disable-build-opt" | "-o") CMAKE_COMPILERS="" CMAKE_LLVM_BUILD_OPTIMIZATIONS="" - CMAKE_POLYGEIST_BUILD_OPTIMIZATIONS="" CMAKE_DYNAMATIC_BUILD_OPTIMIZATIONS="" ;; "--force" | "-f") @@ -191,9 +189,9 @@ do "--export-godot" | "-e") PARSE_ARG="godot-path" ;; - "--skip-polygeist") - SKIP_POLYGEIST=1 - PARSE_ARG="polygeist-path" + "--skip-llvm") + SKIP_LLVM=1 + PARSE_ARG="llvm-path" ;; "--experimental-enable-xls") ENABLE_XLS_INTEGRATION=1 @@ -225,11 +223,11 @@ echo "########################################################################## echo "############# DYNAMATIC - DHLS COMPILER INFRASTRUCTURE - EPFL/LAP ##############" echo "################################################################################" -if [[ $SKIP_POLYGEIST -eq 0 ]]; then +if [[ $SKIP_LLVM -eq 0 ]]; then - #### Polygeist #### + #### LLVM #### - prepare_to_build_project "LLVM" "polygeist/llvm-project/build" + prepare_to_build_project "LLVM" "llvm-project/build" # CMake if should_run_cmake ; then @@ -241,44 +239,21 @@ if [[ $SKIP_POLYGEIST -eq 0 ]]; then -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ -DLLVM_PARALLEL_LINK_JOBS=$LLVM_PARALLEL_LINK_JOBS \ $CMAKE_COMPILERS $CMAKE_LLVM_BUILD_OPTIMIZATIONS - exit_on_fail "Failed to cmake polygeist/llvm-project" + exit_on_fail "Failed to cmake llvm-project" fi # Build run_ninja - exit_on_fail "Failed to build polygeist/llvm-project" + exit_on_fail "Failed to build llvm-project" if [[ ENABLE_TESTS -eq 1 ]]; then ninja check-mlir - exit_on_fail "Tests for polygeist/llvm-project failed" - fi - - prepare_to_build_project "Polygeist" "polygeist/build" - - # CMake - if should_run_cmake ; then - cmake -G Ninja .. \ - -DMLIR_DIR=$PWD/../llvm-project/build/lib/cmake/mlir \ - -DCLANG_DIR=$PWD/../llvm-project/build/lib/cmake/clang \ - -DLLVM_TARGETS_TO_BUILD="host" \ - -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ - $CMAKE_COMPILERS $CMAKE_POLYGEIST_BUILD_OPTIMIZATIONS - exit_on_fail "Failed to cmake polygeist" - fi - - # Build - run_ninja - exit_on_fail "Failed to build polygeist" - if [[ ENABLE_TESTS -eq 1 ]]; then - ninja check-polygeist-opt - exit_on_fail "Tests for polygeist failed" - ninja check-cgeist - exit_on_fail "Tests for polygeist failed" + exit_on_fail "Tests for llvm-project failed" fi else - echo "Skipping POLYGEIST/LLVM build. IMPORTANT: Verify that the path of polygeist in the script tools/dynamatic/scripts/compile.sh is the same" - if [[ ! -d $POLYGEIST_DIR ]]; then - echo "POLYGEIST directory not found: $POLYGEIST_DIR" + echo "Skipping LLVM build. IMPORTANT: Verify that the path of llvm in the script tools/dynamatic/scripts/compile.sh is the same" + if [[ ! -d $LLVM_DIR ]]; then + echo "LLVM directory not found: $LLVM_DIR" exit 1 fi fi @@ -336,10 +311,10 @@ prepare_to_build_project "Dynamatic" "build" # CMake if should_run_cmake ; then cmake -G Ninja .. \ - -DMLIR_DIR="$POLYGEIST_DIR"/llvm-project/build/lib/cmake/mlir \ - -DLLVM_DIR="$POLYGEIST_DIR"/llvm-project/build/lib/cmake/llvm \ - -DCLANG_DIR="$POLYGEIST_DIR"/llvm-project/build/lib/cmake/clang \ - -DPolly_DIR="$POLYGEIST_DIR"/llvm-project/build/tools/polly/lib/cmake/polly \ + -DMLIR_DIR="$LLVM_DIR/build/lib/cmake/mlir" \ + -DLLVM_DIR="$LLVM_DIR/build/lib/cmake/llvm" \ + -DCLANG_DIR="$LLVM_DIR/build/lib/cmake/clang" \ + -DPolly_DIR="$LLVM_DIR/build/tools/polly/lib/cmake/polly" \ -DLLVM_TARGETS_TO_BUILD="host" \ -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ -DCMAKE_EXPORT_COMPILE_COMMANDS="ON" \ @@ -374,8 +349,8 @@ if [[ BUILD_VISUAL_DATAFLOW -ne 0 ]]; then # CMake if should_run_cmake ; then cmake -G Ninja .. \ - -DMLIR_DIR="$POLYGEIST_DIR"/llvm-project/build/lib/cmake/mlir \ - -DLLVM_DIR="$POLYGEIST_DIR"/llvm-project/build/lib/cmake/llvm \ + -DMLIR_DIR="$LLVM_DIR/llvm-project/build/lib/cmake/mlir" \ + -DLLVM_DIR="$LLVM_DIR/llvm-project/build/lib/cmake/llvm" \ -DLLVM_TARGETS_TO_BUILD="host" \ -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ -DCMAKE_EXPORT_COMPILE_COMMANDS="ON" \ @@ -411,11 +386,9 @@ cd "$SCRIPT_CWD" && mkdir -p bin/generators # Create symbolic links to all binaries we use from subfolders -create_symlink "$POLYGEIST_DIR"/build/bin/cgeist -create_symlink "$POLYGEIST_DIR"/build/bin/polygeist-opt -create_symlink "$POLYGEIST_DIR"/llvm-project/build/bin/clang++ -create_symlink "$POLYGEIST_DIR"/llvm-project/build/bin/opt -create_symlink "$POLYGEIST_DIR"/llvm-project/build/bin/clang +create_symlink "$LLVM_DIR/build/bin/opt" +create_symlink "$LLVM_DIR/build/bin/clang++" +create_symlink "$LLVM_DIR/build/bin/clang" create_symlink ../build/bin/dynamatic create_symlink ../build/bin/dynamatic-mlir-lsp-server create_symlink ../build/bin/dynamatic-opt @@ -435,7 +408,7 @@ create_generator_symlink build/bin/exp-sharing-wrapper-generator create_generator_symlink "$LSQ_GEN_PATH/$LSQ_GEN_JAR" # Create symbolic links to polygeist headers (standard c library for clang) -create_include_symlink "$POLYGEIST_DIR"/llvm-project/clang/lib/Headers +create_include_symlink "$LLVM_DIR/clang/lib/Headers" if [[ $GODOT_PATH != "" ]]; then diff --git a/experimental/include/experimental/Support/CFGAnnotation.h b/experimental/include/experimental/Support/CFGAnnotation.h index b7bd5aed5f..014dee618c 100644 --- a/experimental/include/experimental/Support/CFGAnnotation.h +++ b/experimental/include/experimental/Support/CFGAnnotation.h @@ -27,6 +27,7 @@ #include "dynamatic/Analysis/NameAnalysis.h" #include "dynamatic/Dialect/Handshake/HandshakeOps.h" +#include namespace dynamatic { namespace experimental { diff --git a/experimental/lib/Support/CreateSmvFormalTestbench.cpp b/experimental/lib/Support/CreateSmvFormalTestbench.cpp index 8ac8ee1274..e576281444 100644 --- a/experimental/lib/Support/CreateSmvFormalTestbench.cpp +++ b/experimental/lib/Support/CreateSmvFormalTestbench.cpp @@ -84,12 +84,12 @@ static std::string instantiateModuleUnderTest( if (syncOutput) for (const auto &[i, result] : llvm::enumerate(results)) { const auto &[resultName, type] = result; - if (type.isa()) + if (llvm::isa(type)) inputVariables.push_back(llvm::formatv("join_global.ins_{0}_ready", i)); } else for (const auto &[resultName, type] : results) { - if (type.isa()) + if (llvm::isa(type)) inputVariables.push_back( llvm::formatv("sink_{0}.{1}", resultName, SINK_READY_NAME.str())); } @@ -238,7 +238,7 @@ static std::string createSupportEntities( llvm::DenseSet types; for (const auto &[_, type] : arguments) - if (type.isa()) + if (llvm::isa(type)) types.insert(type); std::ostringstream supportEntities; @@ -279,7 +279,7 @@ static std::string instantiateSequenceGenerators( size_t nrOfTokens, bool generateExactNrOfTokens = false) { std::ostringstream sequenceGenerators; for (const auto &[argumentName, type] : arguments) { - if (!type.isa()) + if (!isa(type)) continue; std::string typePrefixName = @@ -328,7 +328,7 @@ instantiateSinks(const std::string &moduleName, std::ostringstream sinks; for (const auto &[resultName, type] : results) { - if (type.isa()) + if (llvm::isa(type)) sinks << llvm::formatv(" VAR sink_{0} : sink_main({1}.{0}_valid);\n", resultName, moduleName) .str(); @@ -345,7 +345,7 @@ instantiateJoin(const std::string &moduleName, str << " VAR join_global : tb_join("; for (const auto &[resultName, type] : results) { - if (type.isa()) + if (llvm::isa(type)) outputValids.push_back( llvm::formatv("{0}.{1}_valid", moduleName, resultName)); } @@ -384,4 +384,4 @@ std::string createSmvFormalTestbench(const SmvTestbenchConfig &config) { return wrapper.str(); } -} // namespace dynamatic::experimental \ No newline at end of file +} // namespace dynamatic::experimental diff --git a/experimental/lib/Support/SubjectGraph.cpp b/experimental/lib/Support/SubjectGraph.cpp index faa11913af..6ca7c7832e 100644 --- a/experimental/lib/Support/SubjectGraph.cpp +++ b/experimental/lib/Support/SubjectGraph.cpp @@ -145,7 +145,7 @@ void BaseSubjectGraph::buildSubjectGraphConnections() { // Store the Result Number of the input operand in the // inputSubjectGraphToResultNumber map. inputSubjectGraphToResultNumber[inputSubjectGraph] = - inputOperand.cast().getResultNumber(); + mlir::cast(inputOperand).getResultNumber(); } } diff --git a/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp b/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp index d8987238c4..1bfc946b8f 100644 --- a/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp +++ b/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp @@ -391,8 +391,8 @@ struct HandshakeCombineSteeringLogicPass MLIRContext *ctx = &getContext(); ModuleOp mod = getOperation(); GreedyRewriteConfig config; - config.useTopDownTraversal = true; - config.enableRegionSimplification = false; + config.setUseTopDownTraversal(true); + config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); RewritePatternSet patterns(ctx); patterns.add(attr.getValue()); llvm::SetVector cfdfcBBs; for (auto bb : bbList) - cfdfcBBs.insert(bb.cast().getUInt()); + cfdfcBBs.insert(mlir::cast(bb).getUInt()); unsigned index; if (attr.getName().getValue().getAsInteger(10, index)) @@ -666,4 +666,4 @@ dynamatic::experimental::lsqsizing::createHandshakeSizeLSQs( StringRef timingModels, StringRef collisions, double targetCP) { return std::make_unique(timingModels, collisions, targetCP); -} \ No newline at end of file +} diff --git a/experimental/lib/Transforms/Speculation/HandshakeSpeculation.cpp b/experimental/lib/Transforms/Speculation/HandshakeSpeculation.cpp index aa8c609e63..6f58fa8676 100644 --- a/experimental/lib/Transforms/Speculation/HandshakeSpeculation.cpp +++ b/experimental/lib/Transforms/Speculation/HandshakeSpeculation.cpp @@ -522,7 +522,8 @@ std::optional findControlInputToBB(handshake::FuncOp &funcOp, continue; // We are looking for the control branch: data should be of control type - if (branchOp.getDataOperand().getType().isa()) { + if (llvm::isa( + branchOp.getDataOperand().getType())) { // BB should have only one control branch at most if (isControlBranchFound) { branchOp->emitError("Multiple control branches found in the BB #" + @@ -601,8 +602,8 @@ static LogicalResult addSpecTagToValue(Value value) { // The value type must implement ExtraSignalsTypeInterface (e.g., ChannelType // or ControlType). - if (auto valueType = - value.getType().dyn_cast()) { + if (auto valueType = llvm::dyn_cast( + value.getType())) { // Skip if the spec tag was already added during the algorithm. if (!valueType.hasExtraSignal(EXTRA_BIT_SPEC)) { llvm::SmallVector newExtraSignals( @@ -787,8 +788,8 @@ LogicalResult HandshakeSpeculationPass::addNonSpecOp() { OpBuilder builder(&getContext()); for (auto mergeLikeOp : funcOp.getOps()) { - auto dataResultType = - mergeLikeOp.getDataResult().getType().cast(); + auto dataResultType = mlir::cast( + mergeLikeOp.getDataResult().getType()); if (dataResultType.hasExtraSignal(EXTRA_BIT_SPEC)) { // This MuxOp/CMergeOp is within the speculative region. @@ -797,7 +798,7 @@ LogicalResult HandshakeSpeculationPass::addNonSpecOp() { // non-speculative edges. for (auto dataOperand : mergeLikeOp.getDataOperands()) { auto dataOperandType = - dataOperand.getType().cast(); + cast(dataOperand.getType()); if (!dataOperandType.hasExtraSignal(EXTRA_BIT_SPEC)) { // Create a NonSpecOp to add the spec tag to the data operand diff --git a/experimental/tools/elastic-miter/ElasticMiterTestbench.cpp b/experimental/tools/elastic-miter/ElasticMiterTestbench.cpp index dc2e581a88..4c0b22f7a2 100644 --- a/experimental/tools/elastic-miter/ElasticMiterTestbench.cpp +++ b/experimental/tools/elastic-miter/ElasticMiterTestbench.cpp @@ -43,7 +43,7 @@ static std::string createMiterProperties(const std::string &moduleName, // INVARSPEC (model.EQ_A_valid -> model.EQ_A_out) // INVARSPEC (model.EQ_B_valid -> model.EQ_B_out) for (const auto &[resultName, resultType] : config.results) { - if (resultType.isa()) + if (isa(resultType)) properties << llvm::formatv("INVARSPEC ({0}.{1}_valid -> {0}.{1}_out)", moduleName, resultName) .str(); @@ -132,4 +132,4 @@ LogicalResult createSmvSequenceLengthTestbench( mainFile.close(); return success(); } -} // namespace dynamatic::experimental \ No newline at end of file +} // namespace dynamatic::experimental diff --git a/experimental/tools/elastic-miter/FabricGeneration.cpp b/experimental/tools/elastic-miter/FabricGeneration.cpp index 246ae3320e..d83ca9c6be 100644 --- a/experimental/tools/elastic-miter/FabricGeneration.cpp +++ b/experimental/tools/elastic-miter/FabricGeneration.cpp @@ -96,7 +96,7 @@ buildEmptyMiterFuncOp(OpBuilder builder, FuncOp &lhsFuncOp, FuncOp &rhsFuncOp) { // on the LHS resNames but are prefixed with EQ_ SmallVector prefixedResAttr; for (Attribute attr : lhsFuncOp.getResNames()) { - auto strAttr = attr.dyn_cast(); + auto strAttr = dyn_cast(attr); if (strAttr) { prefixedResAttr.push_back( builder.getStringAttr("EQ_" + strAttr.getValue().str())); @@ -121,7 +121,7 @@ buildEmptyMiterFuncOp(OpBuilder builder, FuncOp &lhsFuncOp, FuncOp &rhsFuncOp) { for (Type type : lhsFuncOp.getResultTypes()) { // If the type is a handshake::ControlType, keep it - if (type.isa()) { + if (isa(type)) { outputTypes.push_back(type); } else { // Otherwise replace it with !handshake.channel @@ -170,7 +170,7 @@ FailureOr getModuleFuncOpAndCheck(ModuleOp module) { // Check that arguments are all handshake.channel or handshake.control type for (Type ty : funcOp.getArgumentTypes()) { - if (!ty.isa()) { + if (!isa(ty)) { llvm::errs() << "All arguments need to be of handshake.channel or " "handshake.control type\n"; return failure(); @@ -179,7 +179,7 @@ FailureOr getModuleFuncOpAndCheck(ModuleOp module) { // Check that results are all handshake.channel or handshake.control type for (Type ty : funcOp.getResultTypes()) { - if (!ty.isa()) { + if (!isa(ty)) { llvm::errs() << "All results need to be of handshake.channel or " "handshake.control type\n"; return failure(); @@ -263,7 +263,7 @@ createReachabilityCircuit(MLIRContext &context, } Attribute attr = funcOp.getArgNames()[i]; - auto strAttr = attr.dyn_cast(); + auto strAttr = dyn_cast(attr); config.arguments.push_back( std::make_pair(strAttr.getValue().str(), arg.getType())); } @@ -299,7 +299,7 @@ createReachabilityCircuit(MLIRContext &context, op->replaceUsesOfWith(result, endNDWireOp.getResult()); } Attribute attr = funcOp.getResNames()[i]; - auto strAttr = attr.dyn_cast(); + auto strAttr = dyn_cast(attr); config.results.push_back( std::make_pair(strAttr.getValue().str(), result.getType())); } @@ -429,7 +429,7 @@ createElasticMiter(MLIRContext &context, ModuleOp lhsModule, ModuleOp rhsModule, op->replaceUsesOfWith(rhsArg, rhsNDWireOp.getResult()); Attribute attr = lhsFuncOp.getArgNames()[i]; - auto strAttr = attr.dyn_cast(); + auto strAttr = dyn_cast(attr); config.arguments.push_back( std::make_pair(strAttr.getValue().str(), lhsArg.getType())); } @@ -488,7 +488,7 @@ createElasticMiter(MLIRContext &context, ModuleOp lhsModule, ModuleOp rhsModule, setHandshakeAttributes(builder, lhsEndBufferOp, BB_OUT, lhsBufName); setHandshakeAttributes(builder, rhsEndBufferOp, BB_OUT, rhsBufName); - if (lhsResult.getType().isa()) { + if (isa(lhsResult.getType())) { ValueRange joinInputs = {lhsEndBufferOp.getResult(), rhsEndBufferOp.getResult()}; JoinOp joinOp = @@ -509,7 +509,7 @@ createElasticMiter(MLIRContext &context, ModuleOp lhsModule, ModuleOp rhsModule, config.eq.push_back(eqName); Attribute attr = lhsFuncOp.getResNames()[i]; - auto strAttr = attr.dyn_cast(); + auto strAttr = dyn_cast(attr); // The result name is prefixed with EQ_ config.results.push_back( std::make_pair("EQ_" + strAttr.getValue().str(), lhsResult.getType())); @@ -580,4 +580,4 @@ createMiterFabric(MLIRContext &context, const std::filesystem::path &lhsPath, return std::make_pair(outputDir / mlirFilename, config); } -} // namespace dynamatic::experimental \ No newline at end of file +} // namespace dynamatic::experimental diff --git a/experimental/tools/frequency-profiler/Simulator.cpp b/experimental/tools/frequency-profiler/Simulator.cpp index 8b8ae79d5a..030b0c8ca9 100644 --- a/experimental/tools/frequency-profiler/Simulator.cpp +++ b/experimental/tools/frequency-profiler/Simulator.cpp @@ -88,7 +88,7 @@ static Any readValueWithType(mlir::Type type, std::stringstream &arg) { APInt aparg(width, x); return aparg; } - if (type.isa()) { + if (isa(type)) { int64_t x; arg >> x; int64_t width = type.getIntOrFloatBitWidth(); @@ -107,7 +107,7 @@ static Any readValueWithType(mlir::Type type, std::stringstream &arg) { APFloat aparg(x); return aparg; } - if (auto tupleType = type.dyn_cast()) { + if (auto tupleType = llvm::dyn_cast(type)) { char tmp; arg >> tmp; assert(tmp == '(' && "tuple should start with '('"); @@ -160,9 +160,9 @@ static unsigned allocateMemRef(mlir::MemRefType type, std::vector &in, mlir::Type elementType = type.getElementType(); int64_t width = elementType.getIntOrFloatBitWidth(); for (int i = 0; i < allocationSize; ++i) { - if (elementType.isa()) { + if (isa(elementType)) { store[ptr][i] = APInt(width, 0); - } else if (elementType.isa()) { + } else if (isa(elementType)) { store[ptr][i] = APFloat(0.0); } else { fatalValueError("Unknown result type!\n", elementType); @@ -856,9 +856,9 @@ LogicalResult simulate(func::FuncOp funcOp, ArrayRef inputArgs, for (unsigned i = 0; i < numInputs; ++i) { mlir::Type type = ftype.getInput(i); - if (type.isa()) { + if (isa(type)) { // We require this memref type to be fully specified. - auto memreftype = type.dyn_cast(); + auto memreftype = llvm::dyn_cast(type); // emptyDims: the dynamic dimension type of alloca takes an array of // dimensions. We cannot pass an temporary object "{}" to a reference type @@ -899,7 +899,7 @@ LogicalResult simulate(func::FuncOp funcOp, ArrayRef inputArgs, // } auto modOp = funcOp->getParentOfType(); modOp.walk([&](memref::GlobalOp gblOp) { - auto memreftype = gblOp.getTypeAttr().getValue().dyn_cast(); + auto memreftype = dyn_cast(gblOp.getTypeAttr().getValue()); // emptyDims: the dynamic dimension type of alloca takes an array of // dimensions. We cannot pass an temporary object "{}" to a reference type // "std::vector &". So we construct an empty array here. @@ -920,13 +920,13 @@ LogicalResult simulate(func::FuncOp funcOp, ArrayRef inputArgs, // If the GlobalOp has a dense initializer, use it the initialize the memory // content: mlir::Attribute initValueAttr = gblOp.getInitialValueAttr(); - if (auto denseAttr = initValueAttr.dyn_cast()) { + if (auto denseAttr = dyn_cast(initValueAttr)) { mlir::Type elemType = denseAttr.getElementType(); - if (elemType.isa()) { + if (isa(elemType)) { for (auto [id, val] : llvm::enumerate(denseAttr.getValues())) { programStackMemory[pointer][id] = val; } - } else if (elemType.isa()) { + } else if (isa(elemType)) { for (auto [id, val] : llvm::enumerate(denseAttr.getValues())) { programStackMemory[pointer][id] = val; } diff --git a/include/dynamatic/Dialect/HW/HWDialect.td b/include/dynamatic/Dialect/HW/HWDialect.td index 37a9157da2..5f295ca29e 100644 --- a/include/dynamatic/Dialect/HW/HWDialect.td +++ b/include/dynamatic/Dialect/HW/HWDialect.td @@ -34,7 +34,7 @@ def HWDialect : Dialect { let useDefaultTypePrinterParser = 1; // Opt-out of properties for now, must migrate by LLVM 19. #5273. - let usePropertiesForAttributes = 0; + // let usePropertiesForAttributes = 0; let extraClassDeclaration = [{ /// Register all HW types. diff --git a/include/dynamatic/Dialect/HW/HWOps.h b/include/dynamatic/Dialect/HW/HWOps.h index 1353751fb9..a673af2191 100644 --- a/include/dynamatic/Dialect/HW/HWOps.h +++ b/include/dynamatic/Dialect/HW/HWOps.h @@ -31,6 +31,7 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/ADT/StringExtras.h" +#include namespace dynamatic { namespace hw { diff --git a/include/dynamatic/Dialect/HW/HWStructure.td b/include/dynamatic/Dialect/HW/HWStructure.td index 88471b5767..4afcb8e43d 100644 --- a/include/dynamatic/Dialect/HW/HWStructure.td +++ b/include/dynamatic/Dialect/HW/HWStructure.td @@ -514,7 +514,7 @@ def OutputOp : HWOp<"output", [Terminator, HasParent<"HWModuleOp">, let arguments = (ins Variadic:$outputs); let builders = [ - OpBuilder<(ins), "build($_builder, $_state, std::nullopt);"> + OpBuilder<(ins), "build($_builder, $_state, {});"> ]; let assemblyFormat = "attr-dict ($outputs^ `:` qualified(type($outputs)))?"; diff --git a/include/dynamatic/Dialect/HW/HWTypes.h b/include/dynamatic/Dialect/HW/HWTypes.h index 094821b7d5..5be1c3d0a6 100644 --- a/include/dynamatic/Dialect/HW/HWTypes.h +++ b/include/dynamatic/Dialect/HW/HWTypes.h @@ -102,11 +102,11 @@ bool hasHWInOutType(mlir::Type type); template bool type_isa(Type type) { // First check if the type is the requested type. - if (type.isa()) + if (llvm::isa(type)) return true; // Then check if it is a type alias wrapping the requested type. - if (auto alias = type.dyn_cast()) + if (auto alias = mlir::dyn_cast(type)) return type_isa(alias.getInnerType()); return false; @@ -125,11 +125,11 @@ BaseTy type_cast(Type type) { assert(type_isa(type) && "type must convert to requested type"); // If the type is the requested type, return it. - if (type.isa()) - return type.cast(); + if (llvm::isa(type)) + return mlir::cast(type); // Otherwise, it must be a type alias wrapping the requested type. - return type_cast(type.cast().getInnerType()); + return type_cast(mlir::cast(type).getInnerType()); } template diff --git a/include/dynamatic/Dialect/HW/HWTypes.td b/include/dynamatic/Dialect/HW/HWTypes.td index 4d89098e80..111938d2e9 100644 --- a/include/dynamatic/Dialect/HW/HWTypes.td +++ b/include/dynamatic/Dialect/HW/HWTypes.td @@ -91,7 +91,7 @@ def HWStringType : /// A flat symbol reference or a reference to a name within a module. def NameRefAttr : Attr< - CPred<"$_self.isa<::mlir::FlatSymbolRefAttr, ::dynamatic::hw::InnerRefAttr>()">, + CPred<"isa<::mlir::FlatSymbolRefAttr, ::dynamatic::hw::InnerRefAttr>($_self)">, "name reference attribute">{ let returnType = "::mlir::Attribute"; let convertFromStorage = "$_self"; diff --git a/include/dynamatic/Dialect/Handshake/HandshakeAttributes.td b/include/dynamatic/Dialect/Handshake/HandshakeAttributes.td index 27b998b979..32d65bda68 100644 --- a/include/dynamatic/Dialect/Handshake/HandshakeAttributes.td +++ b/include/dynamatic/Dialect/Handshake/HandshakeAttributes.td @@ -88,7 +88,7 @@ class OperandContainerAttr traits size_t numProps = operandAttrs.size(); for (auto [idx, attr] : llvm::enumerate(operandAttrs)) { odsPrinter << attr.getName() << ": "; - attr.getValue().cast<$cppClass::OperandAttr>().print(odsPrinter); + mlir::cast<$cppClass::OperandAttr>(attr.getValue()).print(odsPrinter); if (idx != numProps - 1) odsPrinter << ", "; } @@ -140,7 +140,7 @@ class OperandContainerAttr traits << attr.getName(); // Value must be channel buffering properties - if (!attr.getValue().isa<$cppClass::OperandAttr>()) + if (!isa<$cppClass::OperandAttr>(attr.getValue())) return emitError() << "map values are not of the correct type"; } return success(); @@ -392,7 +392,7 @@ def CFDFCThroughputAttr : Handshake_Attr< llvm::SmallVector attrs; mlir::Builder builder(context); for (const auto &pair: cfdfcThroughputMap) { - mlir::Type doubleType = mlir::FloatType::getF64(context); + mlir::Type doubleType = builder.getF64Type(); auto tmpCFDFCThroughputAttr = builder.getFloatAttr(doubleType, pair.second); attrs.push_back({builder.getStringAttr(std::to_string(pair.first)), tmpCFDFCThroughputAttr}); } diff --git a/include/dynamatic/Dialect/Handshake/HandshakeInterfaces.td b/include/dynamatic/Dialect/Handshake/HandshakeInterfaces.td index 957ef78f2d..1ce932e9c0 100644 --- a/include/dynamatic/Dialect/Handshake/HandshakeInterfaces.td +++ b/include/dynamatic/Dialect/Handshake/HandshakeInterfaces.td @@ -73,7 +73,7 @@ def MemoryOpInterface : OpInterface<"MemoryOpInterface"> { "::mlir::MemRefType", "getMemRefType", (ins), "", [{ ConcreteOp concreteOp = cast($_op); - return concreteOp.getMemRef().getType().template cast(); + return cast(concreteOp.getMemRef().getType()); }] >, InterfaceMethod<[{ diff --git a/include/dynamatic/Dialect/Handshake/HandshakeOps.td b/include/dynamatic/Dialect/Handshake/HandshakeOps.td index fba390a4c4..7c6ac67670 100644 --- a/include/dynamatic/Dialect/Handshake/HandshakeOps.td +++ b/include/dynamatic/Dialect/Handshake/HandshakeOps.td @@ -95,12 +95,12 @@ def FuncOp : Op(); + return mlir::cast(getArgNames()[idx]); } /// Returns the result name at the given index. StringAttr getResName(unsigned idx) { - return getResNames()[idx].cast(); + return mlir::cast(getResNames()[idx]); } /// Hook for FunctionOpInterface, called after verifying that the 'type' @@ -109,7 +109,7 @@ def FuncOp : Op()) + if (!llvm::isa(type)) return emitOpError( "requires '" + getFunctionTypeAttrName().getValue() + "' attribute of function type"); @@ -136,7 +136,6 @@ def FuncOp : Op ]> { diff --git a/include/dynamatic/Dialect/Handshake/HandshakeTypes.td b/include/dynamatic/Dialect/Handshake/HandshakeTypes.td index 66305ab5e9..787d6c8b94 100644 --- a/include/dynamatic/Dialect/Handshake/HandshakeTypes.td +++ b/include/dynamatic/Dialect/Handshake/HandshakeTypes.td @@ -144,7 +144,7 @@ class valuesToDataTypes names> { names, // Treat ControlType's data type as i0 "isa($" # name # ".getType()) ? IntegerType::get($_ctxt, 0) :" - "$" # name # ".getType().cast().getDataType()" + "mlir::cast($" # name # ".getType()).getDataType()" ), ", ") # "}"; } @@ -158,7 +158,7 @@ class variadicToExtraSignalArrays { for (Type type : variadic.getTypes()) { extraSignalArrays.push_back( - type.cast().getExtraSignals()); + mlir::cast(type).getExtraSignals()); } return extraSignalArrays; @@ -179,7 +179,7 @@ class variadicToDataTypes { // Treat ControlType's data type as i0 dataTypes.push_back(IntegerType::get($_ctxt, 0)); } else { - dataTypes.push_back(type.cast().getDataType()); + dataTypes.push_back(cast(type).getDataType()); } } @@ -287,7 +287,7 @@ class IsSimpleHandshakeVariadic : PredOpTrait< name # " shouldn't have any extra signals", CPred<"llvm::all_of($" # name # ".getTypes(), " "[](Type type) {" - " return type.cast().getNumExtraSignals() == 0;" + " return cast(type).getNumExtraSignals() == 0;" "})"> >; @@ -407,7 +407,7 @@ class HasValidSpecTag : PredOpTrait< return true; } - }] # "($" # operand # ".getType().cast())"> + }] # "(cast($" # operand # ".getType()))"> >; class LacksSpecTag : PredOpTrait< @@ -422,12 +422,12 @@ class HasSpecTagIfPresentIn : PredOpTrait< [](OperandRange::type_range inputTypes, ExtraSignalsTypeInterface outputType) { bool inputHasSpecTag = llvm::any_of(inputTypes, [](Type type) { - return type.cast().hasExtraSignal("spec"); + return cast(type).hasExtraSignal("spec"); }); return outputType.hasExtraSignal("spec") == inputHasSpecTag; } }] # "($" # inputVariadic # ".getTypes(), " - "$" # output # ".getType().cast())"> + "cast($" # output # ".getType()))"> >; #endif // DYNAMATIC_DIALECT_HANDSHAKE_HANDSHAKE_TYPES_TD diff --git a/include/dynamatic/Support/LLVM.h b/include/dynamatic/Support/LLVM.h index 9027f36251..13d84d90c5 100644 --- a/include/dynamatic/Support/LLVM.h +++ b/include/dynamatic/Support/LLVM.h @@ -23,7 +23,7 @@ // Can not forward declare inline functions with default arguments, so we // include the header directly. -#include "mlir/Support/LogicalResult.h" +#include "llvm/Support/LogicalResult.h" // Import classes from the `mlir` namespace into the `dynamatic` namespace. All // of the following classes have been already forward declared and imported from @@ -134,7 +134,6 @@ class OpOperand; class OpResult; template class OwningOpRef; -class ParseResult; class Pass; class PatternRewriter; class Region; @@ -160,13 +159,10 @@ class VectorType; class WalkResult; enum class RegionKind; struct CallInterfaceCallable; -struct LogicalResult; struct MemRefAccess; struct OperationState; class OperationName; -template -class FailureOr; template class OpConversionPattern; template diff --git a/lib/Conversion/CfToHandshake/CfToHandshake.cpp b/lib/Conversion/CfToHandshake/CfToHandshake.cpp index 04864b5312..193be79cfb 100644 --- a/lib/Conversion/CfToHandshake/CfToHandshake.cpp +++ b/lib/Conversion/CfToHandshake/CfToHandshake.cpp @@ -163,12 +163,13 @@ LogicalResult LowerFuncToHandshake::computeLinearDominance( // CfToHandshakeTypeConverter //===-----------------------------------------------------------------------==// -static std::optional oneToOneVoidMaterialization(OpBuilder &builder, - Type /*resultType*/, - ValueRange inputs, - Location /*loc*/) { +static Value oneToOneVoidMaterialization(OpBuilder &builder, + Type /*resultType*/, ValueRange inputs, + Location /*loc*/) { + // NOTE (@Jiahui17): Taking the solution here + // https://github.com/llvm/circt/blob/main/lib/Dialect/ESI/Passes/ESILowerTypes.cpp if (inputs.size() != 1) - return std::nullopt; + return mlir::Value(); return inputs[0]; } @@ -191,7 +192,10 @@ static Type channelifyType(Type type) { CfToHandshakeTypeConverter::CfToHandshakeTypeConverter() { addConversion(channelifyType); - addArgumentMaterialization(oneToOneVoidMaterialization); + // NOTE: addArgumentMaterialization() is replaced by + // addSourceMaterialization() + // addArgumentMaterialization(oneToOneVoidMaterialization); + // https://github.com/llvm/llvm-project/pull/116524#issue-2665213624 addSourceMaterialization(oneToOneVoidMaterialization); addTargetMaterialization(oneToOneVoidMaterialization); } @@ -395,17 +399,18 @@ FailureOr LowerFuncToHandshake::lowerSignature( TypeConverter::SignatureConversion entryConversion( entryBlock->getNumArguments()); setupEntryBlockConversion(entryBlock, numMemories, rewriter, entryConversion); - rewriter.applySignatureConversion(oldBody, entryConversion, typeConv); + rewriter.applySignatureConversion(entryBlock, entryConversion, typeConv); - // Convert the non entry blocks' signatures - SmallVector nonEntryConversions; - for (Block &block : llvm::drop_begin(funcOp)) { - auto &conv = nonEntryConversions.emplace_back(block.getNumArguments()); - setupBlockConversion(&block, rewriter, conv); + for (Block &nonEntryBlock : + llvm::make_early_inc_range(llvm::drop_begin(funcOp.getBody()))) { + + TypeConverter::SignatureConversion nonEntryConversion( + /*numOrigInputs=*/nonEntryBlock.getNumArguments()); + + setupBlockConversion(&nonEntryBlock, rewriter, nonEntryConversion); + rewriter.applySignatureConversion(&nonEntryBlock, nonEntryConversion, + typeConv); } - if (failed(rewriter.convertNonEntryRegionTypes(oldBody, *typeConv, - nonEntryConversions))) - return failure(); // Modify branch-like terminators to forward the new control value through // all blocks @@ -551,8 +556,7 @@ void LowerFuncToHandshake::insertMerge(BlockArgument blockArg, "might have accidentally maximized the SSA of a placeholder op " "like LSQ, MemoryController, or RAMOp."); } - assert(operand.getType() - .cast() + assert(cast(operand.getType()) .getNumExtraSignals() == 0 && "unexpected extra signals"); } @@ -1312,7 +1316,7 @@ ConvertCalls::matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, Operation *sourceFuncLookup = mlir::SymbolTable::lookupNearestSymbolFrom(callOp, sourceCalleeAttr); auto sourceFunc = dyn_cast(sourceFuncLookup); - assert(sourceFunc && sourceFunc.getSymName().startswith("__init") && + assert(sourceFunc && sourceFunc.getSymName().starts_with("__init") && "All placeholder outputs must be initialized via __init* calls"); if (!llvm::is_contained(initCalls, sourceCallOp)) { @@ -1562,7 +1566,7 @@ struct GetGlobalOpConversion /// The initial value doesn't have any type constraints. Therefore we need /// to check if it is stored as dense elements. mlir::Attribute initValueAttr = global.getInitialValueAttr(); - if (auto denseAttr = initValueAttr.dyn_cast()) { + if (auto denseAttr = dyn_cast(initValueAttr)) { rewriter.replaceOpWithNewOp(op, op.getType(), denseAttr); } else { @@ -1591,7 +1595,7 @@ struct GlobalOpConversion : public DynOpConversionPattern { /// Filters out block arguments of type MemRefType bool FuncSSAStrategy::maximizeArgument(BlockArgument arg) { - return !arg.getType().isa(); + return !isa(arg.getType()); } namespace { @@ -1679,14 +1683,14 @@ struct CfToHandshakePass // addIllegalDialect rule above and must be converted by a pattern. if (auto calledFn = dyn_cast_or_null( SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr()))) { - return calledFn.getSymName().startswith("__init"); + return calledFn.getSymName().starts_with("__init"); } // If symbol lookup fails or it's not a func::FuncOp, treat as default // (illegal) return false; }); target.addDynamicallyLegalOp( - [](func::FuncOp op) { return op.getSymName().startswith("__init"); }); + [](func::FuncOp op) { return op.getSymName().starts_with("__init"); }); if (failed(applyFullConversion(modOp, target, std::move(patterns)))) return signalPassFailure(); @@ -1695,7 +1699,7 @@ struct CfToHandshakePass // has no remaining uses. This is safe because all valid calls to __init* // were tracked and deleted earlier. for (auto func : llvm::make_early_inc_range(modOp.getOps())) { - if (func.getSymName().startswith("__init")) { + if (func.getSymName().starts_with("__init")) { assert(func.use_empty() && "__init function should not have users after transformation"); func.erase(); diff --git a/lib/Conversion/HandshakeToHW/HandshakeToHW.cpp b/lib/Conversion/HandshakeToHW/HandshakeToHW.cpp index 3cb605cca5..47bb14094d 100644 --- a/lib/Conversion/HandshakeToHW/HandshakeToHW.cpp +++ b/lib/Conversion/HandshakeToHW/HandshakeToHW.cpp @@ -1193,18 +1193,20 @@ class ChannelTypeConverter : public TypeConverter { }); addTargetMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, - Location loc) -> std::optional { + ValueRange inputs, Location loc) -> Value { + // NOTE (@Jiahui17): Taking the solution here + // https://github.com/llvm/circt/blob/main/lib/Dialect/ESI/Passes/ESILowerTypes.cpp if (inputs.size() != 1) - return std::nullopt; + return mlir::Value(); return inputs[0]; }); addSourceMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, - Location loc) -> std::optional { + ValueRange inputs, Location loc) -> Value { + // NOTE (@Jiahui17): Taking the solution here + // https://github.com/llvm/circt/blob/main/lib/Dialect/ESI/Passes/ESILowerTypes.cpp if (inputs.size() != 1) - return std::nullopt; + return mlir::Value(); return inputs[0]; }); } diff --git a/lib/Dialect/HW/ConversionPatterns.cpp b/lib/Dialect/HW/ConversionPatterns.cpp index 9122888430..e1a6465c51 100644 --- a/lib/Dialect/HW/ConversionPatterns.cpp +++ b/lib/Dialect/HW/ConversionPatterns.cpp @@ -46,13 +46,13 @@ LogicalResult dynamatic::doTypeConversion(Operation *op, ValueRange operands, llvm::SmallVector newAttrs; newAttrs.reserve(op->getAttrs().size()); for (auto attr : op->getAttrs()) { - if (auto typeAttr = attr.getValue().dyn_cast()) { + if (auto typeAttr = dyn_cast(attr.getValue())) { auto innerType = typeAttr.getValue(); // TypeConvert::convertType doesn't handle function types, so we need to // handle them manually. - if (auto funcType = innerType.dyn_cast()) + if (auto funcType = dyn_cast(innerType)) innerType = convertFunctionType(*typeConverter, funcType); - else if (auto modType = innerType.dyn_cast()) + else if (auto modType = dyn_cast(innerType)) innerType = convertModuleType(*typeConverter, modType); else innerType = typeConverter->convertType(innerType); @@ -78,7 +78,7 @@ LogicalResult dynamatic::doTypeConversion(Operation *op, ValueRange operands, Operation *newOp = rewriter.create(state); // Move the regions over, converting the signatures as we go. - rewriter.startRootUpdate(newOp); + rewriter.startOpModification(newOp); for (size_t i = 0, e = op->getNumRegions(); i < e; ++i) { Region ®ion = op->getRegion(i); Region *newRegion = &newOp->getRegion(i); @@ -90,9 +90,11 @@ LogicalResult dynamatic::doTypeConversion(Operation *op, ValueRange operands, newRegion->getArgumentTypes(), result))) return rewriter.notifyMatchFailure(op->getLoc(), "type conversion failed"); - rewriter.applySignatureConversion(newRegion, result, typeConverter); + + Block &entryBlock = newRegion->front(); + rewriter.applySignatureConversion(&entryBlock, result, typeConverter); } - rewriter.finalizeRootUpdate(newOp); + rewriter.finalizeOpModification(newOp); rewriter.replaceOp(op, newOp->getResults()); return success(); diff --git a/lib/Dialect/HW/CustomDirectiveImpl.cpp b/lib/Dialect/HW/CustomDirectiveImpl.cpp index e3b50155af..d2f6be671e 100644 --- a/lib/Dialect/HW/CustomDirectiveImpl.cpp +++ b/lib/Dialect/HW/CustomDirectiveImpl.cpp @@ -52,7 +52,7 @@ void dynamatic::printInputPortList(OpAsmPrinter &p, Operation *op, [&](std::tuple input) { Value val = std::get<0>(input); p.printKeywordOrString( - std::get<1>(input).cast().getValue()); + cast(std::get<1>(input)).getValue()); p << ": " << val << ": " << val.getType(); }); p << ")"; @@ -86,13 +86,12 @@ void dynamatic::printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames) { p << "("; - llvm::interleaveComma( - llvm::zip(resultTypes, resultNames), p, - [&](std::tuple result) { - p.printKeywordOrString( - std::get<1>(result).cast().getValue()); - p << ": " << std::get<0>(result); - }); + llvm::interleaveComma(llvm::zip(resultTypes, resultNames), p, + [&](std::tuple result) { + p.printKeywordOrString( + cast(std::get<1>(result)).getValue()); + p << ": " << std::get<0>(result); + }); p << ")"; } @@ -136,7 +135,7 @@ void dynamatic::printOptionalParameterList(OpAsmPrinter &p, Operation *op, p << '<'; llvm::interleaveComma(parameters, p, [&](Attribute param) { - auto paramAttr = param.cast(); + auto paramAttr = cast(param); p << paramAttr.getName().getValue() << ": " << paramAttr.getType(); if (auto value = paramAttr.getValue()) { p << " = "; diff --git a/lib/Dialect/HW/HWAttributes.cpp b/lib/Dialect/HW/HWAttributes.cpp index 23497a0367..2b99f67376 100644 --- a/lib/Dialect/HW/HWAttributes.cpp +++ b/lib/Dialect/HW/HWAttributes.cpp @@ -94,7 +94,7 @@ static std::string canonicalizeFilename(const Twine &directory, // separator and return it. // e.g. `directory` + `` -> `directory/`. auto separator = llvm::sys::path::get_separator(); - if (nativeFilename.empty() && !nativeDirectory.endswith(separator)) { + if (nativeFilename.empty() && !nativeDirectory.ends_with(separator)) { nativeDirectory += separator; return std::string(nativeDirectory); } @@ -211,7 +211,7 @@ EnumFieldAttr EnumFieldAttr::get(Location loc, StringAttr value, emitError(loc) << "expected enum type"; // Check whether the provided value is a member of the enum type. - EnumType enumType = getCanonicalType(type).cast(); + EnumType enumType = cast(getCanonicalType(type)); if (!enumType.contains(value.getValue())) { emitError(loc) << "enum value '" << value.getValue() << "' is not a member of enum type " << enumType; @@ -357,7 +357,7 @@ void InnerSymAttr::print(AsmPrinter &odsPrinter) const { auto props = getProps(); if (props.size() == 1 && - props[0].getSymVisibility().getValue().equals("public") && + (props[0].getSymVisibility().getValue() == "public") && props[0].getFieldID() == 0) { odsPrinter << "@" << props[0].getName().getValue(); return; @@ -456,8 +456,8 @@ static TypedAttr foldBinaryOp( ArrayRef operands, llvm::function_ref calculate) { assert(operands.size() == 2 && "binary operator always has two operands"); - if (auto lhs = operands[0].dyn_cast()) - if (auto rhs = operands[1].dyn_cast()) + if (auto lhs = dyn_cast(operands[0])) + if (auto rhs = dyn_cast(operands[1])) return IntegerAttr::get(lhs.getType(), calculate(lhs.getValue(), rhs.getValue())); return {}; @@ -469,7 +469,7 @@ static TypedAttr foldUnaryOp(ArrayRef operands, llvm::function_ref calculate) { assert(operands.size() == 1 && "unary operator always has one operand"); - if (auto intAttr = operands[0].dyn_cast()) + if (auto intAttr = dyn_cast(operands[0])) return IntegerAttr::get(intAttr.getType(), calculate(intAttr.getValue())); return {}; } @@ -477,7 +477,7 @@ foldUnaryOp(ArrayRef operands, /// If the specified attribute is a ParamExprAttr with the specified opcode, /// return it. Otherwise return null. static ParamExprAttr dyn_castPE(PEO opcode, Attribute value) { - if (auto expr = value.dyn_cast()) + if (auto expr = dyn_cast(value)) if (expr.getOpcode() == opcode) return expr; return {}; @@ -496,38 +496,38 @@ static bool paramExprOperandSortPredicate(Attribute lhs, Attribute rhs) { return false; // All expressions are "less than" a constant, since they appear on the right. - if (rhs.isa()) { + if (isa(rhs)) { // We don't bother to order constants w.r.t. each other since they will be // folded - they can all compare equal. - return !lhs.isa(); + return !isa(lhs); } - if (lhs.isa()) + if (isa(lhs)) return false; // Next up are named parameters. - if (auto rhsParam = rhs.dyn_cast()) { + if (auto rhsParam = dyn_cast(rhs)) { // Parameters are sorted lexically w.r.t. each other. - if (auto lhsParam = lhs.dyn_cast()) + if (auto lhsParam = dyn_cast(lhs)) return lhsParam.getName().getValue() < rhsParam.getName().getValue(); // They otherwise appear on the right of other things. return true; } - if (lhs.isa()) + if (isa(lhs)) return false; // Next up are verbatim parameters. - if (auto rhsParam = rhs.dyn_cast()) { + if (auto rhsParam = dyn_cast(rhs)) { // Verbatims are sorted lexically w.r.t. each other. - if (auto lhsParam = lhs.dyn_cast()) + if (auto lhsParam = dyn_cast(lhs)) return lhsParam.getValue().getValue() < rhsParam.getValue().getValue(); // They otherwise appear on the right of other things. return true; } - if (lhs.isa()) + if (isa(lhs)) return false; // The only thing left are nested expressions. - auto lhsExpr = lhs.cast(), rhsExpr = rhs.cast(); + auto lhsExpr = cast(lhs), rhsExpr = cast(rhs); // Sort by the string form of the opcode, e.g. add, .. mul,... then xor. if (lhsExpr.getOpcode() != rhsExpr.getOpcode()) return stringifyPEO(lhsExpr.getOpcode()) < @@ -584,16 +584,16 @@ static TypedAttr simplifyAssocOp( llvm::stable_sort(operands, paramExprOperandSortPredicate); // Merge any constants, they will appear at the back of the operand list now. - if (operands.back().isa()) { + if (isa(operands.back())) { while (operands.size() >= 2 && - operands[operands.size() - 2].isa()) { - APInt c1 = operands.pop_back_val().cast().getValue(); - APInt c2 = operands.pop_back_val().cast().getValue(); + isa(operands[operands.size() - 2])) { + APInt c1 = cast(operands.pop_back_val()).getValue(); + APInt c2 = cast(operands.pop_back_val()).getValue(); auto resultConstant = IntegerAttr::get(type, calculateFn(c1, c2)); operands.push_back(resultConstant); } - auto resultCst = operands.back().cast(); + auto resultCst = cast(operands.back()); // If the resulting constant is the destructive constant (e.g. `x*0`), then // return it. @@ -615,7 +615,7 @@ static TypedAttr simplifyAssocOp( /// null as the second (standin for "multiplication by 1"). static std::pair decomposeAddend(TypedAttr operand) { if (auto mul = dyn_castPE(PEO::Mul, operand)) - if (auto cst = mul.getOperands().back().dyn_cast()) { + if (auto cst = dyn_cast(mul.getOperands().back())) { auto nonCst = ParamExprAttr::get(PEO::Mul, mul.getOperands().drop_back()); return {nonCst, cst}; } @@ -728,9 +728,9 @@ static TypedAttr simplifyXor(SmallVector &operands) { static TypedAttr simplifyShl(SmallVector &operands) { assert(isHWIntegerType(operands[0].getType())); - if (auto rhs = operands[1].dyn_cast()) { + if (auto rhs = dyn_cast(operands[1])) { // Constant fold simple integers. - if (auto lhs = operands[0].dyn_cast()) + if (auto lhs = dyn_cast(operands[0])) return IntegerAttr::get(lhs.getType(), lhs.getValue().shl(rhs.getValue())); @@ -747,7 +747,7 @@ static TypedAttr simplifyShl(SmallVector &operands) { static TypedAttr simplifyShrU(SmallVector &operands) { assert(isHWIntegerType(operands[0].getType())); // Implement support for identities like `x >> 0`. - if (auto rhs = operands[1].dyn_cast()) + if (auto rhs = dyn_cast(operands[1])) if (rhs.getValue().isZero()) return operands[0]; @@ -757,7 +757,7 @@ static TypedAttr simplifyShrU(SmallVector &operands) { static TypedAttr simplifyShrS(SmallVector &operands) { assert(isHWIntegerType(operands[0].getType())); // Implement support for identities like `x >> 0`. - if (auto rhs = operands[1].dyn_cast()) + if (auto rhs = dyn_cast(operands[1])) if (rhs.getValue().isZero()) return operands[0]; @@ -767,7 +767,7 @@ static TypedAttr simplifyShrS(SmallVector &operands) { static TypedAttr simplifyDivU(SmallVector &operands) { assert(isHWIntegerType(operands[0].getType())); // Implement support for identities like `x/1`. - if (auto rhs = operands[1].dyn_cast()) + if (auto rhs = dyn_cast(operands[1])) if (rhs.getValue().isOne()) return operands[0]; @@ -777,7 +777,7 @@ static TypedAttr simplifyDivU(SmallVector &operands) { static TypedAttr simplifyDivS(SmallVector &operands) { assert(isHWIntegerType(operands[0].getType())); // Implement support for identities like `x/1`. - if (auto rhs = operands[1].dyn_cast()) + if (auto rhs = dyn_cast(operands[1])) if (rhs.getValue().isOne()) return operands[0]; @@ -787,7 +787,7 @@ static TypedAttr simplifyDivS(SmallVector &operands) { static TypedAttr simplifyModU(SmallVector &operands) { assert(isHWIntegerType(operands[0].getType())); // Implement support for identities like `x%1`. - if (auto rhs = operands[1].dyn_cast()) + if (auto rhs = dyn_cast(operands[1])) if (rhs.getValue().isOne()) return IntegerAttr::get(rhs.getType(), 0); @@ -797,7 +797,7 @@ static TypedAttr simplifyModU(SmallVector &operands) { static TypedAttr simplifyModS(SmallVector &operands) { assert(isHWIntegerType(operands[0].getType())); // Implement support for identities like `x%1`. - if (auto rhs = operands[1].dyn_cast()) + if (auto rhs = dyn_cast(operands[1])) if (rhs.getValue().isOne()) return IntegerAttr::get(rhs.getType(), 0); @@ -829,7 +829,7 @@ static TypedAttr simplifyStrConcat(SmallVector &operands) { }; for (TypedAttr op : operands) { - if (auto strOp = op.dyn_cast()) { + if (auto strOp = dyn_cast(op)) { // Queue up adjacent strings. stringsToCombine.push_back(strOp); } else { @@ -953,11 +953,11 @@ static FailureOr replaceDeclRefInExpr(Location loc, const std::map ¶meters, Attribute paramAttr, bool emitErrors) { - if (paramAttr.dyn_cast()) { + if (dyn_cast(paramAttr)) { // Nothing to do, constant value. return paramAttr; } - if (auto paramRefAttr = paramAttr.dyn_cast()) { + if (auto paramRefAttr = dyn_cast(paramAttr)) { // Get the value from the provided parameters. auto it = parameters.find(paramRefAttr.getName().str()); if (it == parameters.end()) { @@ -969,14 +969,14 @@ replaceDeclRefInExpr(Location loc, } return it->second; } - if (auto paramExprAttr = paramAttr.dyn_cast()) { + if (auto paramExprAttr = dyn_cast(paramAttr)) { // Recurse into all operands of the expression. llvm::SmallVector replacedOperands; for (auto operand : paramExprAttr.getOperands()) { auto res = replaceDeclRefInExpr(loc, parameters, operand, emitErrors); if (failed(res)) return {failure()}; - replacedOperands.push_back(res->cast()); + replacedOperands.push_back(cast(res.value())); } return { hw::ParamExprAttr::get(paramExprAttr.getOpcode(), replacedOperands)}; @@ -992,7 +992,7 @@ FailureOr hw::evaluateParametricAttr(Location loc, // Create a map of the provided parameters for faster lookup. std::map parameterMap; for (auto param : parameters) { - auto paramDecl = param.cast(); + auto paramDecl = cast(param); parameterMap[paramDecl.getName().str()] = paramDecl.getValue(); } @@ -1005,9 +1005,9 @@ FailureOr hw::evaluateParametricAttr(Location loc, paramAttr = *paramAttrRes; // Then, evaluate the parametric attribute. - if (paramAttr.isa()) - return paramAttr.cast(); - if (auto paramExprAttr = paramAttr.dyn_cast()) { + if (isa(paramAttr)) + return cast(paramAttr); + if (auto paramExprAttr = dyn_cast(paramAttr)) { // Since any ParamDeclRefAttr was replaced within the expression, // we re-evaluate the expression through the existing ParamExprAttr // canonicalizer. @@ -1033,7 +1033,7 @@ FailureOr evaluateParametricArrayType(Location loc, ArrayAttr parameters, // If the size was evaluated to a constant, use a 64-bit integer // attribute version of it - if (auto intAttr = size->template dyn_cast()) + if (auto intAttr = dyn_cast(size.value())) return TArray::get( arrayType.getContext(), *elementType, IntegerAttr::get(IntegerType::get(arrayType.getContext(), 64), @@ -1053,12 +1053,12 @@ FailureOr hw::evaluateParametricType(Location loc, ArrayAttr parameters, return {failure()}; // If the width was evaluated to a constant, return an `IntegerType` - if (auto intAttr = evaluatedWidth->dyn_cast()) + if (auto intAttr = dyn_cast(evaluatedWidth.value())) return {IntegerType::get(type.getContext(), intAttr.getValue().getSExtValue())}; // Otherwise parameter references are still involved - return hw::IntType::get(evaluatedWidth->cast()); + return hw::IntType::get(cast(evaluatedWidth.value())); }) .Case( [&](auto arrayType) -> FailureOr { diff --git a/lib/Dialect/HW/HWDialect.cpp b/lib/Dialect/HW/HWDialect.cpp index 2e806e27e9..257bb8998e 100644 --- a/lib/Dialect/HW/HWDialect.cpp +++ b/lib/Dialect/HW/HWDialect.cpp @@ -96,13 +96,13 @@ void HWDialect::initialize() { Operation *HWDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { // Integer constants can materialize into hw.constant - if (auto intType = type.dyn_cast()) - if (auto attrValue = value.dyn_cast()) + if (auto intType = dyn_cast(type)) + if (auto attrValue = dyn_cast(value)) return builder.create(loc, type, attrValue); // Aggregate constants. - if (auto arrayAttr = value.dyn_cast()) { - if (type.isa()) + if (auto arrayAttr = dyn_cast(value)) { + if (isa(type)) return builder.create(loc, type, arrayAttr); } diff --git a/lib/Dialect/HW/HWInstanceImplementation.cpp b/lib/Dialect/HW/HWInstanceImplementation.cpp index 55d2590810..2f9dbd2b48 100644 --- a/lib/Dialect/HW/HWInstanceImplementation.cpp +++ b/lib/Dialect/HW/HWInstanceImplementation.cpp @@ -179,8 +179,8 @@ instance_like_impl::verifyParameters(ArrayAttr parameters, } for (size_t i = 0; i != numParameters; ++i) { - auto param = parameters[i].cast(); - auto modParam = moduleParameters[i].cast(); + auto param = cast(parameters[i]); + auto modParam = cast(moduleParameters[i]); auto paramName = param.getName(); if (paramName != modParam.getName()) { @@ -290,14 +290,14 @@ instance_like_impl::verifyParameterStructure(ArrayAttr parameters, // Check that all the parameter values specified to the instance are // structurally valid. for (auto param : parameters) { - auto paramAttr = param.cast(); + auto paramAttr = cast(param); auto value = paramAttr.getValue(); // The SymbolUses verifier which checks that this exists may not have been // run yet. Let it issue the error. if (!value) continue; - auto typedValue = value.dyn_cast(); + auto typedValue = dyn_cast(value); if (!typedValue) { emitError([&](auto &diag) { diag << "parameter " << paramAttr @@ -325,7 +325,7 @@ instance_like_impl::verifyParameterStructure(ArrayAttr parameters, StringAttr instance_like_impl::getName(ArrayAttr names, size_t idx) { // Tolerate malformed IR here to enable debug printing etc. if (names && idx < names.size()) - return names[idx].cast(); + return cast(names[idx]); return StringAttr(); } @@ -375,23 +375,23 @@ SmallVector instance_like_impl::getPortList(Operation *instanceOp) { auto type = argTypes[i]; auto direction = ModulePort::Direction::Input; - if (auto inout = type.dyn_cast()) { + if (auto inout = dyn_cast(type)) { type = inout.getElementType(); direction = ModulePort::Direction::InOut; } LocationAttr loc; if (argLocs) - loc = argLocs[i].cast(); + loc = cast(argLocs[i]); ports.push_back( - {{argNames[i].cast(), type, direction}, i, emptyDict, loc}); + {{cast(argNames[i]), type, direction}, i, emptyDict, loc}); } for (unsigned i = 0, e = resultTypes.size(); i < e; ++i) { LocationAttr loc; if (resultLocs) - loc = resultLocs[i].cast(); - ports.push_back({{resultNames[i].cast(), resultTypes[i], + loc = cast(resultLocs[i]); + ports.push_back({{cast(resultNames[i]), resultTypes[i], ModulePort::Direction::Output}, i, emptyDict, diff --git a/lib/Dialect/HW/HWModuleOpInterface.cpp b/lib/Dialect/HW/HWModuleOpInterface.cpp index 51368f0661..16455a0861 100644 --- a/lib/Dialect/HW/HWModuleOpInterface.cpp +++ b/lib/Dialect/HW/HWModuleOpInterface.cpp @@ -64,7 +64,7 @@ static LogicalResult convertModuleOpTypes(HWModuleLike modOp, return failure(); auto newType = ModuleType::get(rewriter.getContext(), newPorts); - rewriter.updateRootInPlace(modOp, [&] { modOp.setHWModuleType(newType); }); + rewriter.modifyOpInPlace(modOp, [&] { modOp.setHWModuleType(newType); }); return success(); } diff --git a/lib/Dialect/HW/HWOpInterfaces.cpp b/lib/Dialect/HW/HWOpInterfaces.cpp index d18cb0e349..9ab01c380f 100644 --- a/lib/Dialect/HW/HWOpInterfaces.cpp +++ b/lib/Dialect/HW/HWOpInterfaces.cpp @@ -51,7 +51,7 @@ void hw::PortInfo::setSym(InnerSymAttr sym, MLIRContext *ctx) { StringRef hw::PortInfo::getVerilogName() const { if (attrs) if (auto updatedName = attrs.get("hw.verilogName")) - return updatedName.cast().getValue(); + return cast(updatedName).getValue(); return name.getValue(); } diff --git a/lib/Dialect/HW/HWOps.cpp b/lib/Dialect/HW/HWOps.cpp index edeb2e74c5..8d6a9df135 100644 --- a/lib/Dialect/HW/HWOps.cpp +++ b/lib/Dialect/HW/HWOps.cpp @@ -49,7 +49,7 @@ ModulePort::Direction hw::flip(ModulePort::Direction direction) { bool hw::isValidIndexBitWidth(Value index, Value array) { hw::ArrayType arrayType = - hw::getCanonicalType(array.getType()).dyn_cast(); + dyn_cast(hw::getCanonicalType(array.getType())); assert(arrayType && "expected array type"); unsigned indexWidth = index.getType().getIntOrFloatBitWidth(); auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements()); @@ -122,12 +122,12 @@ LogicalResult hw::checkParameterInContext( bool disallowParamRefs) { // Literals are always ok. Their types are already known to match // expectations. - if (value.isa() || value.isa() || - value.isa() || value.isa()) + if (isa(value) || isa(value) || + isa(value) || isa(value)) return success(); // Check both subexpressions of an expression. - if (auto expr = value.dyn_cast()) { + if (auto expr = dyn_cast(value)) { for (auto op : expr.getOperands()) if (failed(checkParameterInContext(op, moduleParameters, instanceError, disallowParamRefs))) @@ -137,7 +137,7 @@ LogicalResult hw::checkParameterInContext( // Parameter references need more analysis to make sure they are valid within // this module. - if (auto parameterRef = value.dyn_cast()) { + if (auto parameterRef = dyn_cast(value)) { auto nameAttr = parameterRef.getName(); // Don't allow references to parameters from the default values of a @@ -153,7 +153,7 @@ LogicalResult hw::checkParameterInContext( // Find the corresponding attribute in the module. for (auto param : moduleParameters) { - auto paramAttr = param.cast(); + auto paramAttr = cast(param); if (paramAttr.getName() != nameAttr) continue; @@ -269,7 +269,7 @@ ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) { LogicalResult ConstantOp::verify() { // If the result type has a bitwidth, then the attribute must match its width. - if (getValue().getBitWidth() != getType().cast().getWidth()) + if (getValue().getBitWidth() != cast(getType()).getWidth()) return emitError( "hw.constant attribute bitwidth doesn't match return type"); @@ -299,7 +299,7 @@ void ConstantOp::build(OpBuilder &builder, OperationState &result, /// an int64_t. Use APInt's instead. void ConstantOp::build(OpBuilder &builder, OperationState &result, Type type, int64_t value) { - auto numBits = type.cast().getWidth(); + auto numBits = cast(type).getWidth(); build(builder, result, APInt(numBits, (uint64_t)value, /*isSigned=*/true)); } @@ -309,7 +309,7 @@ void ConstantOp::getAsmResultNames( auto intCst = getValue(); // Sugar i1 constants with 'true' and 'false'. - if (intTy.cast().getWidth() == 1) + if (cast(intTy).getWidth() == 1) return setNameFn(getResult(), intCst.isZero() ? "false" : "true"); // Otherwise, build a complex name with the value and type. @@ -330,11 +330,11 @@ OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type) { // If this is a type alias, get the underlying type. - if (auto typeAlias = type.dyn_cast()) + if (auto typeAlias = dyn_cast(type)) type = typeAlias.getCanonicalType(); - if (auto structType = type.dyn_cast()) { - auto arrayAttr = attr.dyn_cast(); + if (auto structType = dyn_cast(type)) { + auto arrayAttr = dyn_cast(attr); if (!arrayAttr) return op->emitOpError("expected array attribute for constant of type ") << type; @@ -348,8 +348,8 @@ static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type) { if (failed(checkAttributes(op, attr, fieldInfo.type))) return failure(); } - } else if (auto arrayType = type.dyn_cast()) { - auto arrayAttr = attr.dyn_cast(); + } else if (auto arrayType = dyn_cast(type)) { + auto arrayAttr = dyn_cast(attr); if (!arrayAttr) return op->emitOpError("expected array attribute for constant of type ") << type; @@ -363,8 +363,8 @@ static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type) { if (failed(checkAttributes(op, attr, elementType))) return failure(); } - } else if (auto arrayType = type.dyn_cast()) { - auto arrayAttr = attr.dyn_cast(); + } else if (auto arrayType = dyn_cast(type)) { + auto arrayAttr = dyn_cast(attr); if (!arrayAttr) return op->emitOpError("expected array attribute for constant of type ") << type; @@ -379,14 +379,14 @@ static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type) { if (failed(checkAttributes(op, attr, elementType))) return failure(); } - } else if (auto enumType = type.dyn_cast()) { - auto stringAttr = attr.dyn_cast(); + } else if (auto enumType = dyn_cast(type)) { + auto stringAttr = dyn_cast(attr); if (!stringAttr) return op->emitOpError("expected string attribute for constant of type ") << type; - } else if (auto intType = type.dyn_cast()) { + } else if (auto intType = dyn_cast(type)) { // Check the attribute kind is correct. - auto intAttr = attr.dyn_cast(); + auto intAttr = dyn_cast(attr); if (!intAttr) return op->emitOpError("expected integer attribute for constant of type ") << type; @@ -456,9 +456,8 @@ FunctionType hw::getModuleType(Operation *moduleOrInstance) { .Case( [](auto mod) { return mod.getHWModuleType().getFuncType(); }) .Default([](Operation *op) { - return cast(op) - .getFunctionType() - .cast(); + return cast( + cast(op).getFunctionType()); }); } @@ -560,7 +559,7 @@ static void modifyModuleArgs( while (!insertArgs.empty() && insertArgs[0].first == argIdx) { auto port = insertArgs[0].second; if (port.dir == ModulePort::Direction::InOut && - !port.type.isa()) + !isa(port.type)) port.type = InOutType::get(port.type); auto sym = port.getSym(); Attribute attr = @@ -676,7 +675,7 @@ void HWModuleOp::build(OpBuilder &builder, OperationState &result, for (auto port : ports.getInputs()) { auto loc = port.loc ? Location(port.loc) : unknownLoc; auto type = port.type; - if (port.isInOut() && !type.isa()) + if (port.isInOut() && !isa(type)) type = InOutType::get(type); body->addArgument(type, loc); } @@ -945,9 +944,8 @@ ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser, FunctionType getHWModuleOpType(Operation *op) { if (auto mod = dyn_cast(op)) return mod.getHWModuleType().getFuncType(); - return cast(op) - .getFunctionType() - .cast(); + return cast( + cast(op).getFunctionType()); } template @@ -1021,7 +1019,7 @@ static LogicalResult verifyModuleCommon(HWModuleLike module) { // Check parameter default values are sensible. for (auto param : module->getAttrOfType("parameters")) { - auto paramAttr = param.cast(); + auto paramAttr = cast(param); // Check that we don't have any redundant parameter names. These are // resolved by string name: reuse of the same name would cause ambiguities. @@ -1034,7 +1032,7 @@ static LogicalResult verifyModuleCommon(HWModuleLike module) { if (!value) continue; - auto typedValue = value.dyn_cast(); + auto typedValue = dyn_cast(value); if (!typedValue) return module->emitOpError("parameter ") << paramAttr << " should have a typed value; has value " << value; @@ -1072,7 +1070,7 @@ LogicalResult HWModuleOp::verify() { getInputTypes(), getInputLocs())) { if (arg.getType() != type) return emitOpError("block argument types should match signature types"); - if (arg.getLoc() != loc.cast()) + if (arg.getLoc() != cast(loc)) return emitOpError( "block argument locations should match signature locations"); } @@ -1268,7 +1266,7 @@ HWModuleGeneratedOp::verifySymbolUses(SymbolTableCollection &symbolTable) { auto paramRef = referencedKindOp.getRequiredAttrs(); auto dict = (*this)->getAttrDictionary(); for (auto str : paramRef) { - auto strAttr = str.dyn_cast(); + auto strAttr = dyn_cast(str); if (!strAttr) return emitError("Unknown attribute type, expected a string"); if (!dict.get(strAttr.getValue())) @@ -1458,9 +1456,8 @@ LogicalResult InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { for (Attribute name : getModuleNamesAttr()) { if (failed(instance_like_impl::verifyInstanceOfHWModule( - *this, name.cast(), getInputs(), - getResultTypes(), getArgNames(), getResultNames(), getParameters(), - symbolTable))) { + *this, cast(name), getInputs(), getResultTypes(), + getArgNames(), getResultNames(), getParameters(), symbolTable))) { return failure(); } } @@ -1678,7 +1675,7 @@ void ArrayCreateOp::build(OpBuilder &b, OperationState &state, } LogicalResult ArrayCreateOp::verify() { - unsigned returnSize = getType().cast().getNumElements(); + unsigned returnSize = cast(getType()).getNumElements(); if (getInputs().size() != returnSize) return failure(); return success(); @@ -1953,19 +1950,19 @@ static void printArrayConcatTypes(OpAsmPrinter &p, Operation *, void ArrayConcatOp::build(OpBuilder &b, OperationState &state, ValueRange values) { assert(!values.empty() && "Cannot build array of zero elements"); - ArrayType arrayTy = values[0].getType().cast(); + ArrayType arrayTy = cast(values[0].getType()); Type elemTy = arrayTy.getElementType(); assert(llvm::all_of(values, [elemTy](Value v) -> bool { - return v.getType().isa() && - v.getType().cast().getElementType() == + return isa(v.getType()) && + cast(v.getType()).getElementType() == elemTy; }) && "All values must be of ArrayType with the same element type."); uint64_t resultSize = 0; for (Value val : values) - resultSize += val.getType().cast().getNumElements(); + resultSize += cast(val.getType()).getNumElements(); build(b, state, ArrayType::get(elemTy, resultSize), values); } @@ -1975,7 +1972,7 @@ OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) { for (size_t i = 0, e = getNumOperands(); i < e; ++i) { if (!inputs[i]) return {}; - llvm::copy(inputs[i].cast(), std::back_inserter(array)); + llvm::copy(cast(inputs[i]), std::back_inserter(array)); } return ArrayAttr::get(getContext(), array); } @@ -2270,7 +2267,7 @@ LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor, auto input = adaptor.getInput(); if (!input) return failure(); - llvm::copy(input.cast(), std::back_inserter(results)); + llvm::copy(cast(input), std::back_inserter(results)); return success(); } @@ -2298,7 +2295,7 @@ void StructExplodeOp::getAsmResultNames( void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState, Value input) { - StructType inputType = input.getType().dyn_cast(); + StructType inputType = dyn_cast(input.getType()); assert(inputType); SmallVector fieldTypes; for (auto field : inputType.getElements()) @@ -2512,7 +2509,7 @@ OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) { if (!input || !newValue) return {}; SmallVector array; - llvm::copy(input.cast(), std::back_inserter(array)); + llvm::copy(cast(input), std::back_inserter(array)); array[getFieldIndex()] = newValue; return ArrayAttr::get(getContext(), array); } @@ -2675,8 +2672,8 @@ void UnionExtractOp::build(OpBuilder &odsBuilder, OperationState &odsState, // the index. If the array is constructed from a constant by a bitcast // operation, we can fold into a constant. OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) { - auto inputCst = adaptor.getInput().dyn_cast_or_null(); - auto indexCst = adaptor.getIndex().dyn_cast_or_null(); + auto inputCst = dyn_cast_or_null(adaptor.getInput()); + auto indexCst = dyn_cast_or_null(adaptor.getIndex()); if (inputCst) { // Constant array index. @@ -2695,7 +2692,7 @@ OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) { // array_get(bitcast(c), i) -> c[i*w+w-1:i*w] if (auto bitcast = getInput().getDefiningOp()) { - auto intTy = getType().dyn_cast(); + auto intTy = dyn_cast(getType()); if (!intTy) return {}; auto bitcastInputOp = bitcast.getInput().getDefiningOp(); diff --git a/lib/Dialect/HW/HWTypes.cpp b/lib/Dialect/HW/HWTypes.cpp index f1ff259a7a..13cc858531 100644 --- a/lib/Dialect/HW/HWTypes.cpp +++ b/lib/Dialect/HW/HWTypes.cpp @@ -46,7 +46,7 @@ using namespace dynamatic::hw::detail; mlir::Type dynamatic::hw::getCanonicalType(mlir::Type type) { Type canonicalType; - if (auto typeAlias = type.dyn_cast()) + if (auto typeAlias = dyn_cast(type)) canonicalType = typeAlias.getCanonicalType(); else canonicalType = type; @@ -58,10 +58,10 @@ mlir::Type dynamatic::hw::getCanonicalType(mlir::Type type) { bool dynamatic::hw::isHWIntegerType(mlir::Type type) { Type canonicalType = getCanonicalType(type); - if (canonicalType.isa()) + if (isa(canonicalType)) return true; - auto intType = canonicalType.dyn_cast(); + auto intType = dyn_cast(canonicalType); if (!intType || !intType.isSignless()) return false; @@ -69,7 +69,7 @@ bool dynamatic::hw::isHWIntegerType(mlir::Type type) { } bool dynamatic::hw::isHWEnumType(mlir::Type type) { - return getCanonicalType(type).isa(); + return isa(getCanonicalType(type)); } /// Return true if the specified type can be used as an HW value type, that is @@ -77,24 +77,24 @@ bool dynamatic::hw::isHWEnumType(mlir::Type type) { /// hardware but not marker types like InOutType. bool dynamatic::hw::isHWValueType(Type type) { // Signless and signed integer types are both valid. - if (type.isa()) + if (isa(type)) return true; - if (auto array = type.dyn_cast()) + if (auto array = dyn_cast(type)) return isHWValueType(array.getElementType()); - if (auto array = type.dyn_cast()) + if (auto array = dyn_cast(type)) return isHWValueType(array.getElementType()); - if (auto t = type.dyn_cast()) + if (auto t = dyn_cast(type)) return llvm::all_of(t.getElements(), [](auto f) { return isHWValueType(f.type); }); - if (auto t = type.dyn_cast()) + if (auto t = dyn_cast(type)) return llvm::all_of(t.getElements(), [](auto f) { return isHWValueType(f.type); }); - if (auto t = type.dyn_cast()) + if (auto t = dyn_cast(type)) return isHWValueType(t.getCanonicalType()); return false; @@ -147,21 +147,21 @@ int64_t dynamatic::hw::getBitWidth(mlir::Type type) { /// InOutType. Unlike isHWValueType, this is not conservative, it only returns /// false on known InOut types, rather than any unknown types. bool dynamatic::hw::hasHWInOutType(Type type) { - if (auto array = type.dyn_cast()) + if (auto array = dyn_cast(type)) return hasHWInOutType(array.getElementType()); - if (auto array = type.dyn_cast()) + if (auto array = dyn_cast(type)) return hasHWInOutType(array.getElementType()); - if (auto t = type.dyn_cast()) { + if (auto t = dyn_cast(type)) { return std::any_of(t.getElements().begin(), t.getElements().end(), [](const auto &f) { return hasHWInOutType(f.type); }); } - if (auto t = type.dyn_cast()) + if (auto t = dyn_cast(type)) return hasHWInOutType(t.getCanonicalType()); - return type.isa(); + return isa(type); } /// Parse and print nested HW types nicely. These helper methods allow eliding @@ -199,12 +199,12 @@ static void printHWElementType(Type element, AsmPrinter &p) { Type IntType::get(mlir::TypedAttr width) { // The width expression must always be a 32-bit wide integer type itself. - auto widthWidth = width.getType().dyn_cast(); + auto widthWidth = dyn_cast(width.getType()); assert(widthWidth && widthWidth.getWidth() == 32 && "!hw.int width must be 32-bits"); (void)widthWidth; - if (auto cstWidth = width.dyn_cast()) + if (auto cstWidth = dyn_cast(width)) return IntegerType::get(width.getContext(), cstWidth.getValue().getZExtValue()); @@ -526,7 +526,7 @@ Type EnumType::parse(AsmParser &p) { void EnumType::print(AsmPrinter &p) const { p << '<'; llvm::interleaveComma(getFields(), p, [&](Attribute enumerator) { - p << enumerator.cast().getValue(); + p << cast(enumerator).getValue(); }); p << ">"; } @@ -537,7 +537,7 @@ bool EnumType::contains(mlir::StringRef field) { std::optional EnumType::indexOf(mlir::StringRef field) { for (auto it : llvm::enumerate(getFields())) - if (it.value().cast().getValue() == field) + if (cast(it.value()).getValue() == field) return it.index(); return {}; } @@ -565,7 +565,7 @@ static LogicalResult parseArray(AsmParser &p, Attribute &dim, Type &inner) { else if (!p.parseOptionalAttribute(dim, int64Type).has_value()) return failure(); - if (!dim.isa()) { + if (!isa(dim)) { p.emitError(p.getNameLoc(), "unsupported dimension kind in hw.array"); return failure(); } @@ -600,7 +600,7 @@ void ArrayType::print(AsmPrinter &p) const { } size_t ArrayType::getNumElements() const { - if (auto intAttr = getSizeAttr().dyn_cast()) + if (auto intAttr = dyn_cast(getSizeAttr())) return intAttr.getInt(); return -1; } @@ -685,7 +685,7 @@ UnpackedArrayType::verify(function_ref emitError, } size_t UnpackedArrayType::getNumElements() const { - if (auto intAttr = getSizeAttr().dyn_cast()) + if (auto intAttr = dyn_cast(getSizeAttr())) return intAttr.getInt(); return -1; } diff --git a/lib/Dialect/HW/ModuleImplementation.cpp b/lib/Dialect/HW/ModuleImplementation.cpp index 114b8903e8..397de63279 100644 --- a/lib/Dialect/HW/ModuleImplementation.cpp +++ b/lib/Dialect/HW/ModuleImplementation.cpp @@ -81,7 +81,7 @@ static StringRef getModuleArgumentName(Operation *module, size_t argNo) { auto argNames = module->getAttrOfType("argNames"); // Tolerate malformed IR here to enable debug printing etc. if (argNames && argNo < argNames.size()) - return argNames[argNo].cast().getValue(); + return cast(argNames[argNo]).getValue(); return StringRef(); } @@ -94,7 +94,7 @@ static StringRef getModuleResultName(Operation *module, size_t resultNo) { auto resultNames = module->getAttrOfType("resultNames"); // Tolerate malformed IR here to enable debug printing etc. if (resultNames && resultNo < resultNames.size()) - return resultNames[resultNo].cast().getValue(); + return cast(resultNames[resultNo]).getValue(); return StringRef(); } diff --git a/lib/Dialect/Handshake/HandshakeOps.cpp b/lib/Dialect/Handshake/HandshakeOps.cpp index f357ac52d1..2a4921425e 100644 --- a/lib/Dialect/Handshake/HandshakeOps.cpp +++ b/lib/Dialect/Handshake/HandshakeOps.cpp @@ -34,6 +34,7 @@ #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/DenseSet.h" @@ -41,6 +42,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/ErrorHandling.h" + #include using namespace mlir; @@ -383,7 +385,7 @@ LogicalResult FuncOp::verify() { << "."; if (llvm::any_of(portNames, - [&](Attribute attr) { return !attr.isa(); })) + [&](Attribute attr) { return !isa(attr); })) return emitOpError() << "expected all entries in attribute '" << attrName << "' to be strings."; @@ -424,7 +426,7 @@ parseFuncOpArgs(OpAsmParser &parser, SmallVectorImpl &resTypes, SmallVectorImpl &resAttrs) { bool isVariadic; - if (mlir::function_interface_impl::parseFunctionSignature( + if (mlir::function_interface_impl::parseFunctionSignatureWithArguments( parser, /*allowVariadic=*/true, entryArgs, isVariadic, resTypes, resAttrs) .failed()) @@ -481,7 +483,7 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { result.attributes) || parseFuncOpArgs(parser, args, resTypes, resAttributes)) return failure(); - mlir::function_interface_impl::addArgAndResultAttrs( + mlir::call_interface_impl::addArgAndResultAttrs( builder, result, args, resAttributes, handshake::FuncOp::getArgAttrsAttrName(result.name), handshake::FuncOp::getResAttrsAttrName(result.name)); @@ -696,7 +698,7 @@ void MemoryControllerOp::build(OpBuilder &odsBuilder, OperationState &odsState, odsState.addOperands(ctrlEnd); // Data outputs (get their type from memref) - MemRefType memrefType = memRef.getType().cast(); + MemRefType memrefType = mlir::cast(memRef.getType()); MLIRContext *ctx = odsBuilder.getContext(); odsState.types.append(numLoads, wrapChannel(memrefType.getElementType())); odsState.types.push_back(handshake::ControlType::get(ctx)); @@ -920,7 +922,7 @@ void LSQOp::build(OpBuilder &odsBuilder, OperationState &odsState, Value memref, odsState.addOperands(ctrlEnd); // Data outputs (get their type from memref) - MemRefType memrefType = memref.getType().cast(); + MemRefType memrefType = mlir::cast(memref.getType()); MLIRContext *ctx = odsBuilder.getContext(); odsState.types.append(numLoads, wrapChannel(memrefType.getElementType())); odsState.types.push_back(handshake::ControlType::get(ctx)); @@ -1318,7 +1320,7 @@ handshake::LoadOp LoadPort::getLoadOp() const { } StorePort::StorePort(handshake::StoreOp storeOp, unsigned addrInputIdx) - : MemoryPort(storeOp, {addrInputIdx, addrInputIdx + 1}, {}, Kind::STORE){}; + : MemoryPort(storeOp, {addrInputIdx, addrInputIdx + 1}, {}, Kind::STORE) {}; handshake::StoreOp StorePort::getStoreOp() const { return cast(portOp); @@ -1351,7 +1353,8 @@ handshake::MemoryControllerOp MCLoadStorePort::getMCOp() const { // GroupMemoryPorts //===----------------------------------------------------------------------===// -GroupMemoryPorts::GroupMemoryPorts(ControlPort ctrlPort) : ctrlPort(ctrlPort){}; +GroupMemoryPorts::GroupMemoryPorts(ControlPort ctrlPort) + : ctrlPort(ctrlPort) {}; unsigned GroupMemoryPorts::getNumInputs() const { unsigned numInputs = hasControl() ? 1 : 0; @@ -1468,9 +1471,9 @@ ValueRange FuncMemoryPorts::getInterfacesResults() { } MCBlock::MCBlock(GroupMemoryPorts *group, unsigned blockID) - : blockID(blockID), group(group){}; + : blockID(blockID), group(group) {}; -MCPorts::MCPorts(handshake::MemoryControllerOp mcOp) : FuncMemoryPorts(mcOp){}; +MCPorts::MCPorts(handshake::MemoryControllerOp mcOp) : FuncMemoryPorts(mcOp) {}; handshake::MemoryControllerOp MCPorts::getMCOp() const { return cast(memOp); @@ -1506,7 +1509,7 @@ SmallVector LSQPorts::getGroups() { return lsqGroups; } -LSQPorts::LSQPorts(handshake::LSQOp lsqOp) : FuncMemoryPorts(lsqOp){}; +LSQPorts::LSQPorts(handshake::LSQOp lsqOp) : FuncMemoryPorts(lsqOp) {}; handshake::LSQOp LSQPorts::getLSQOp() const { return cast(memOp); @@ -1834,7 +1837,7 @@ CmpFOp::inferReturnTypes(MLIRContext *context, std::optional location, // - operand[0] is a channel type // - all operands have the same extra signals // Note that this cast throws an error if the assumption is not met - operands[0].getType().cast().getExtraSignals())); + mlir::cast(operands[0].getType()).getExtraSignals())); return success(); } @@ -1856,7 +1859,7 @@ CmpIOp::inferReturnTypes(MLIRContext *context, std::optional location, // - operand[0] is a channel type // - all operands have the same extra signals // Note that this cast throws an error if the assumption is not met - operands[0].getType().cast().getExtraSignals())); + mlir::cast(operands[0].getType()).getExtraSignals())); return success(); } @@ -1965,4 +1968,4 @@ LogicalResult TruncIOp::verify() { } #define GET_OP_CLASSES -#include "dynamatic/Dialect/Handshake/Handshake.cpp.inc" \ No newline at end of file +#include "dynamatic/Dialect/Handshake/Handshake.cpp.inc" diff --git a/lib/Dialect/Handshake/MemoryInterfaces.cpp b/lib/Dialect/Handshake/MemoryInterfaces.cpp index 1bdef0ad44..7cd84f9aae 100644 --- a/lib/Dialect/Handshake/MemoryInterfaces.cpp +++ b/lib/Dialect/Handshake/MemoryInterfaces.cpp @@ -73,7 +73,9 @@ LogicalResult MemoryInterfaceBuilder::instantiateInterfaces( handshake::LSQOp &lsqOp) { BackedgeBuilder edgeBuilder(rewriter, memref.getLoc()); FConnectLoad connect = [&](LoadOp loadOp, Value dataIn) { - rewriter.updateRootInPlace(loadOp, [&] { loadOp->setOperand(1, dataIn); }); + // API changed here: https://github.com/llvm/llvm-project/pull/78260 + // updateRootInPlace -> modifyOpInPlace + rewriter.modifyOpInPlace(loadOp, [&] { loadOp->setOperand(1, dataIn); }); }; return instantiateInterfaces(rewriter, edgeBuilder, connect, mcOp, lsqOp); } @@ -111,7 +113,7 @@ LogicalResult MemoryInterfaceBuilder::instantiateInterfaces( // so that the LSQ can forward its loads and stores to the MC. We need // load address, store address, and store data channels from the LSQ to // the MC and a load data channel from the MC to the LSQ - MemRefType memrefType = memref.getType().cast(); + MemRefType memrefType = cast(memref.getType()); // Create 3 backedges (load address, store address, store data) for the MC // inputs that will eventually come from the LSQ. diff --git a/lib/Support/JSON/JSON.cpp b/lib/Support/JSON/JSON.cpp index 609af6c20f..86cf7e1827 100644 --- a/lib/Support/JSON/JSON.cpp +++ b/lib/Support/JSON/JSON.cpp @@ -21,17 +21,6 @@ using namespace mlir; using namespace dynamatic::json; -bool llvm::json::fromJSON(const json::Value &value, unsigned &number, - json::Path path) { - std::optional opt = value.getAsUINT64(); - if (!opt.has_value()) { - path.report("expected unsigned number"); - return false; - } - number = opt.value(); - return true; -} - namespace { /// Serializes MLIR attributes to JSON. Only supports a restricted number of @@ -237,4 +226,4 @@ bool ObjectDeserializer::exhausted(const DenseSet &allowUnmapped) { path.field(key).report("unmapped key in object"); return false; }); -} \ No newline at end of file +} diff --git a/lib/Support/RTL/RTL.cpp b/lib/Support/RTL/RTL.cpp index 01ffe8b3cd..901b811eb7 100644 --- a/lib/Support/RTL/RTL.cpp +++ b/lib/Support/RTL/RTL.cpp @@ -247,11 +247,11 @@ MapVector RTLMatch::getGenericParameterValues() const { } static std::string serializeExtraSignalsInner(const Type &type) { - assert(type.isa() && + assert(isa(type) && "type should be ChannelType or ControlType"); handshake::ExtraSignalsTypeInterface extraSignalsType = - type.cast(); + cast(type); std::string extraSignalsValue; llvm::raw_string_ostream extraSignals(extraSignalsValue); diff --git a/lib/Transforms/ArithReduceStrength.cpp b/lib/Transforms/ArithReduceStrength.cpp index f653d0eddb..e156c63d6e 100644 --- a/lib/Transforms/ArithReduceStrength.cpp +++ b/lib/Transforms/ArithReduceStrength.cpp @@ -439,8 +439,8 @@ struct PromoteSignedCmp : public OpRewritePattern { // Promote the signed comparison to an equivalent unsigned one if possible if (!isPromotionPossible(cmpOp)) return failure(); - rewriter.updateRootInPlace(cmpOp, - [&]() { cmpOp.setPredicate(newPredicate); }); + rewriter.modifyOpInPlace(cmpOp, + [&]() { cmpOp.setPredicate(newPredicate); }); return success(); } @@ -471,8 +471,8 @@ struct ArithReduceStrengthPass MLIRContext *ctx = &getContext(); mlir::GreedyRewriteConfig config; - config.useTopDownTraversal = true; - config.enableRegionSimplification = false; + config.setUseTopDownTraversal(true); + config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); RewritePatternSet patterns{ctx}; patterns.add(); + auto memref = dyn_cast(v.getType()); if (!memref) return false; return !isUniDimensional(memref); @@ -296,8 +296,9 @@ struct CondBranchOpConversion matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( - op, adaptor.getCondition(), adaptor.getTrueDestOperands(), - adaptor.getFalseDestOperands(), op.getTrueDest(), op.getFalseDest()); + op, adaptor.getCondition(), op.getTrueDest(), + adaptor.getTrueDestOperands(), op.getFalseDest(), + adaptor.getFalseDestOperands()); return success(); } }; @@ -380,7 +381,7 @@ static void populateFlattenMemRefsLegality(ConversionTarget &target) { }); auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) { - if (auto memref = type.dyn_cast()) + if (auto memref = dyn_cast(type)) return isUniDimensional(memref); return true; }); diff --git a/lib/Transforms/HandshakeCanonicalize.cpp b/lib/Transforms/HandshakeCanonicalize.cpp index 2ccf018683..557f8b6a6b 100644 --- a/lib/Transforms/HandshakeCanonicalize.cpp +++ b/lib/Transforms/HandshakeCanonicalize.cpp @@ -162,8 +162,8 @@ struct HandshakeCanonicalizePass mlir::ModuleOp mod = getOperation(); mlir::GreedyRewriteConfig config; - config.useTopDownTraversal = true; - config.enableRegionSimplification = false; + config.setUseTopDownTraversal(true); + config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); RewritePatternSet patterns{ctx}; patterns.add dynamatic::createHandshakeCanonicalize() { return std::make_unique(); -} \ No newline at end of file +} diff --git a/lib/Transforms/HandshakeHoistExtInstances.cpp b/lib/Transforms/HandshakeHoistExtInstances.cpp index b86aae69cd..e3669cd632 100644 --- a/lib/Transforms/HandshakeHoistExtInstances.cpp +++ b/lib/Transforms/HandshakeHoistExtInstances.cpp @@ -125,7 +125,7 @@ HandshakeHoistExtInstancesPass::hoistInstances(handshake::FuncOp funcOp, auto namedArguments = llvm::zip_equal(instFuncOp.getArgNames(), instOp.getOperandTypes()); for (auto [argNameAttr, argType] : namedArguments) { - StringRef argName = argNameAttr.cast().strref(); + StringRef argName = cast(argNameAttr).strref(); resTypes.push_back(argType); resNames.push_back(StringAttr::get(ctx, instFuncName + "_" + argName)); } @@ -135,7 +135,7 @@ HandshakeHoistExtInstancesPass::hoistInstances(handshake::FuncOp funcOp, auto namedResults = llvm::zip_equal(instFuncOp.getResNames(), instOp.getResultTypes()); for (auto [argNameAttr, resType] : namedResults) { - StringRef argName = argNameAttr.cast().strref(); + StringRef argName = cast(argNameAttr).strref(); argTypes.push_back(resType); argNames.push_back(StringAttr::get(ctx, instFuncName + "_" + argName)); } diff --git a/lib/Transforms/HandshakeInferBasicBlocks.cpp b/lib/Transforms/HandshakeInferBasicBlocks.cpp index 9619af2ae3..1cdbcc03ca 100644 --- a/lib/Transforms/HandshakeInferBasicBlocks.cpp +++ b/lib/Transforms/HandshakeInferBasicBlocks.cpp @@ -113,7 +113,7 @@ struct FuncOpInferBasicBlocks : public OpConversionPattern { LogicalResult matchAndRewrite(handshake::FuncOp funcOp, OpAdaptor /*adaptor*/, ConversionPatternRewriter &rewriter) const override { - rewriter.updateRootInPlace(funcOp, [&] { + rewriter.modifyOpInPlace(funcOp, [&] { bool progress = false; do { progress = false; diff --git a/lib/Transforms/HandshakeMaterialize.cpp b/lib/Transforms/HandshakeMaterialize.cpp index c1406964db..a7b2db10da 100644 --- a/lib/Transforms/HandshakeMaterialize.cpp +++ b/lib/Transforms/HandshakeMaterialize.cpp @@ -172,8 +172,12 @@ static void promoteEagerToLazyForks(handshake::FuncOp funcOp) { // Replace the original fork's outputs that are part of the memory control // network with the first lazy fork's outputs - for (auto [from, to] : llvm::zip(lazyResults, lazyForkOp->getResults())) + auto results = lazyForkOp->getResults(); // avoid repeated calls + for (size_t i = 0; i < lazyResults.size(); ++i) { + auto from = lazyResults[i]; + auto to = results[i]; from.replaceAllUsesWith(to); + } if (hasValueWithoutLazyConstr) { // If some of the control fork's result go outside the memory control @@ -357,8 +361,8 @@ struct HandshakeMaterializePass // Then, greedily optimize forks mlir::GreedyRewriteConfig config; - config.useTopDownTraversal = true; - config.enableRegionSimplification = false; + config.setUseTopDownTraversal(true); + config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); RewritePatternSet patterns{ctx}; patterns .add( diff --git a/lib/Transforms/HandshakeMinimizeCstWidth.cpp b/lib/Transforms/HandshakeMinimizeCstWidth.cpp index cab73eac3d..eb483c0fa9 100644 --- a/lib/Transforms/HandshakeMinimizeCstWidth.cpp +++ b/lib/Transforms/HandshakeMinimizeCstWidth.cpp @@ -188,8 +188,8 @@ struct HandshakeMinimizeCstWidthPass mlir::ModuleOp mod = getOperation(); mlir::GreedyRewriteConfig config; - config.useTopDownTraversal = true; - config.enableRegionSimplification = false; + config.setUseTopDownTraversal(true); + config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); RewritePatternSet patterns{ctx}; patterns.add(optNegatives, ctx); if (failed(applyPatternsAndFoldGreedily(mod, std::move(patterns), config))) diff --git a/lib/Transforms/HandshakeOptimizeBitwidths.cpp b/lib/Transforms/HandshakeOptimizeBitwidths.cpp index bf3322df9d..606e1bb414 100644 --- a/lib/Transforms/HandshakeOptimizeBitwidths.cpp +++ b/lib/Transforms/HandshakeOptimizeBitwidths.cpp @@ -389,7 +389,7 @@ class OptDataConfig { public: /// Constructs the configuration from the specific operation being /// transformed. - OptDataConfig(Op op) : op(op){}; + OptDataConfig(Op op) : op(op) {}; /// Returns the list of operands that carry data. The method must return at /// least one operand. If multiple operands are returned, they must all have @@ -461,7 +461,7 @@ class OptDataConfig { /// result which does not carry data. class CMergeDataConfig : public OptDataConfig { public: - CMergeDataConfig(handshake::ControlMergeOp op) : OptDataConfig(op){}; + CMergeDataConfig(handshake::ControlMergeOp op) : OptDataConfig(op) {}; SmallVector getDataResults() override { return SmallVector{op.getResult()}; @@ -487,7 +487,7 @@ class CMergeDataConfig : public OptDataConfig { /// which does not carry data. class MuxDataConfig : public OptDataConfig { public: - MuxDataConfig(handshake::MuxOp op) : OptDataConfig(op){}; + MuxDataConfig(handshake::MuxOp op) : OptDataConfig(op) {}; SmallVector getDataOperands() override { return op.getDataOperands(); } @@ -507,7 +507,7 @@ class MuxDataConfig : public OptDataConfig { /// condition operand which does not carry data. class CBranchDataConfig : public OptDataConfig { public: - CBranchDataConfig(handshake::ConditionalBranchOp op) : OptDataConfig(op){}; + CBranchDataConfig(handshake::ConditionalBranchOp op) : OptDataConfig(op) {}; SmallVector getDataOperands() override { return SmallVector{op.getDataOperand()}; @@ -528,7 +528,7 @@ class CBranchDataConfig : public OptDataConfig { class BufferDataConfig : public OptDataConfig { public: BufferDataConfig(handshake::BufferOp op) - : OptDataConfig(op){}; + : OptDataConfig(op) {}; SmallVector getDataOperands() override { return SmallVector{this->op.getOperand()}; @@ -1537,8 +1537,8 @@ struct HandshakeOptimizeBitwidthsPass // Create greedy config for all optimization passes mlir::GreedyRewriteConfig config; - config.useTopDownTraversal = true; - config.enableRegionSimplification = false; + config.setUseTopDownTraversal(true); + config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); // Some optimizations do not need to be applied iteratively. We include // patterns to downgrade control merges and muxes with useless indices into diff --git a/lib/Transforms/ScfRotateForLoops.cpp b/lib/Transforms/ScfRotateForLoops.cpp index 6bf7f84c83..f15b227aa9 100644 --- a/lib/Transforms/ScfRotateForLoops.cpp +++ b/lib/Transforms/ScfRotateForLoops.cpp @@ -129,8 +129,8 @@ struct ScfForLoopRotationPass void runDynamaticPass() override { auto *ctx = &getContext(); mlir::GreedyRewriteConfig config; - config.useTopDownTraversal = true; - config.enableRegionSimplification = false; + config.setUseTopDownTraversal(true); + config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); RewritePatternSet patterns{ctx}; patterns.add(ctx); diff --git a/lib/Transforms/ScfSimpleIfToSelect.cpp b/lib/Transforms/ScfSimpleIfToSelect.cpp index d6cf2bb1b8..6ade650f37 100644 --- a/lib/Transforms/ScfSimpleIfToSelect.cpp +++ b/lib/Transforms/ScfSimpleIfToSelect.cpp @@ -239,8 +239,8 @@ struct ScfSimpleIfToSelectPass void runDynamaticPass() override { auto *ctx = &getContext(); mlir::GreedyRewriteConfig config; - config.useTopDownTraversal = true; - config.enableRegionSimplification = false; + config.setUseTopDownTraversal(true); + config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); RewritePatternSet patterns{ctx}; patterns.add(ctx); diff --git a/tools/dynamatic/dynamatic.cpp b/tools/dynamatic/dynamatic.cpp index fe294935c1..d742e44ee5 100644 --- a/tools/dynamatic/dynamatic.cpp +++ b/tools/dynamatic/dynamatic.cpp @@ -399,7 +399,7 @@ class FrontendCommands { std::string FrontendState::makeAbsolutePath(StringRef path) { SmallString<128> str; path::append(str, path); - fs::make_absolute(cwd, str); + llvm::sys::path::make_absolute(cwd, str); return str.str().str(); } diff --git a/tools/export-rtl/export-rtl.cpp b/tools/export-rtl/export-rtl.cpp index 55f6cabfb9..c61ec1645c 100644 --- a/tools/export-rtl/export-rtl.cpp +++ b/tools/export-rtl/export-rtl.cpp @@ -135,7 +135,7 @@ struct ExportInfo { /// Creates export information for the given module and RTL configuration. ExportInfo(mlir::ModuleOp modOp, RTLConfiguration &config, StringRef outputPath) - : modOp(modOp), config(config), outputPath(outputPath){}; + : modOp(modOp), config(config), outputPath(outputPath) {}; /// Associates every external hardware module to its match according to the /// RTL configuration and concretizes each of them inside the output @@ -158,7 +158,7 @@ struct FormalPropertyInfo { StringRef outputPath; FormalPropertyInfo(FormalPropertyTable &table, StringRef outputPath) - : table(table), outputPath(outputPath){}; + : table(table), outputPath(outputPath) {}; }; } // namespace @@ -331,7 +331,7 @@ class RTLWriter { /// Creates the RTL writer. RTLWriter(ExportInfo &exportInfo, FormalPropertyInfo &propertyInfo, HDL hdl) - : exportInfo(exportInfo), propertyInfo(propertyInfo), hdl(hdl){}; + : exportInfo(exportInfo), propertyInfo(propertyInfo), hdl(hdl) {}; /// Writes the RTL implementation of the module to the output stream. On /// failure, the RTL implementation should be considered invalid and/or @@ -1162,8 +1162,7 @@ std::optional SMVWriter::getUserSignal(Value val) const { std::string argName; auto argNamesAttr = userInstance->getArgNames(); if (operandIndex < argNamesAttr.size()) { - if (auto strAttr = - argNamesAttr[operandIndex].dyn_cast()) { + if (auto strAttr = dyn_cast(argNamesAttr[operandIndex])) { argName = strAttr.getValue().str(); return instName + "." + argName; } @@ -1292,7 +1291,7 @@ void SMVWriter::constructIOMappings( auto signal = getValueName(oprd); std::string signalName = signal.str(); - if (oprd.isa()) + if (isa(oprd)) std::replace(signalName.begin(), signalName.end(), '.', '_'); llvm::TypeSwitch(portType) diff --git a/tools/hls-verifier/hls-verifier.cpp b/tools/hls-verifier/hls-verifier.cpp index ec38d48995..0495566459 100644 --- a/tools/hls-verifier/hls-verifier.cpp +++ b/tools/hls-verifier/hls-verifier.cpp @@ -57,7 +57,7 @@ mlir::LogicalResult compareCAndVhdlOutputs(const VerificationContext &ctx) { for (auto [arg, portAttr] : llvm::zip_equal( funcOp->getBodyBlock()->getArguments(), funcOp->getArgNames())) { - std::string argName = portAttr.dyn_cast().data(); + std::string argName = dyn_cast(portAttr).data(); if (handshake::ChannelType type = dyn_cast(arg.getType())) { @@ -74,7 +74,7 @@ mlir::LogicalResult compareCAndVhdlOutputs(const VerificationContext &ctx) { // data/control channels (no arrays). for (auto [resType, portAttr] : llvm::zip_equal(funcOp->getResultTypes(), funcOp->getResNames())) { - std::string argName = portAttr.dyn_cast().str(); + std::string argName = dyn_cast(portAttr).str(); if (handshake::ChannelType type = dyn_cast(resType)) { argAndTypeMap.emplace_back(argName, type.getDataType()); diff --git a/tools/hls-verifier/include/HlsVhdlTb.h b/tools/hls-verifier/include/HlsVhdlTb.h index 12faf006de..e3fa269c13 100644 --- a/tools/hls-verifier/include/HlsVhdlTb.h +++ b/tools/hls-verifier/include/HlsVhdlTb.h @@ -163,7 +163,7 @@ getInputArguments(handshake::FuncOp *funcOp) { for (auto [arg, portAttr] : llvm::zip_equal( funcOp->getBodyBlock()->getArguments(), funcOp->getArgNames())) { if (Ty type = dyn_cast(arg.getType())) { - std::string argName = portAttr.dyn_cast().data(); + std::string argName = dyn_cast(portAttr).data(); interfaces.emplace_back(type, argName); } } @@ -177,7 +177,7 @@ getOutputArguments(handshake::FuncOp *funcOp) { for (auto [resType, portAttr] : llvm::zip_equal(funcOp->getResultTypes(), funcOp->getResNames())) { - std::string argName = portAttr.dyn_cast().str(); + std::string argName = dyn_cast(portAttr).str(); if (Ty type = dyn_cast(resType)) { interfaces.emplace_back(type, argName); } diff --git a/tools/translate-llvm-to-std/CMakeLists.txt b/tools/translate-llvm-to-std/CMakeLists.txt index 75f58bc3e6..14d9d993bf 100644 --- a/tools/translate-llvm-to-std/CMakeLists.txt +++ b/tools/translate-llvm-to-std/CMakeLists.txt @@ -9,6 +9,7 @@ add_llvm_tool(translate-llvm-to-std ) llvm_update_compile_flags(translate-llvm-to-std) + target_link_libraries(translate-llvm-to-std PRIVATE DynamaticSupport @@ -33,6 +34,7 @@ target_link_libraries(translate-llvm-to-std LLVMCore LLVMSupport LLVMAnalysis + LLVMIRReader libclang ) diff --git a/tools/translate-llvm-to-std/InferArgTypes.h b/tools/translate-llvm-to-std/InferArgTypes.h index 43eee4e508..005e8e16d6 100644 --- a/tools/translate-llvm-to-std/InferArgTypes.h +++ b/tools/translate-llvm-to-std/InferArgTypes.h @@ -25,6 +25,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/DerivedTypes.h" +#include using namespace mlir; diff --git a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp index e928644d2e..c8a66fff32 100644 --- a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp +++ b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp @@ -36,10 +36,10 @@ static mlir::Type getMLIRType(llvm::Type *llvmType, return mlir::IntegerType::get(context, llvmType->getIntegerBitWidth()); } if (llvmType->isFloatTy()) { - return mlir::FloatType::getF32(context); + return mlir::Float32Type::get(context); } if (llvmType->isDoubleTy()) { - return mlir::FloatType::getF64(context); + return mlir::Float32Type::get(context); } llvm_unreachable("Unhandled scalar type"); @@ -360,12 +360,12 @@ void TranslateLLVMToStd::createConstants(llvm::Function *llvmFunc) { const APFloat &floatVal = floatConst->getValue(); if (&floatVal.getSemantics() == &llvm::APFloat::IEEEsingle()) { auto constOp = builder.create( - loc, floatVal, builder.getF32Type()); + loc, builder.getF32Type(), floatVal); valueMap[val] = constOp->getResult(0); loc = constOp->getLoc(); } else if (&floatVal.getSemantics() == &llvm::APFloat::IEEEdouble()) { auto constOp = builder.create( - loc, floatVal, builder.getF64Type()); + loc, builder.getF64Type(), floatVal); valueMap[val] = constOp->getResult(0); loc = constOp->getLoc(); } @@ -520,7 +520,7 @@ void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { mlir::Value baseAddress = valueMap[gepInst->getPointerOperand()]; SmallVector indexOperands; - auto memrefType = baseAddress.getType().dyn_cast(); + auto memrefType = dyn_cast(baseAddress.getType()); if (!memrefType) llvm_unreachable("GEP should take memref as reference"); @@ -641,7 +641,7 @@ void TranslateLLVMToStd::translateLoadInst(llvm::LoadInst *loadInst) { // NOTE: This condition handles a special case where a load only has // constant indices, e.g., tmp = mat[0][0]. memref = this->valueMap[instAddr]; - auto memrefType = memref.getType().dyn_cast(); + auto memrefType = dyn_cast(memref.getType()); int constZerosToAdd = memrefType.getShape().size(); for (int i = 0; i < constZerosToAdd; i++) { auto constZeroOp = this->builder.create( @@ -677,7 +677,7 @@ void TranslateLLVMToStd::translateStoreInst(llvm::StoreInst *storeInst) { llvm_unreachable( "Converting a load but the producer hasn't been converted yet!"); memref = this->valueMap[instAddr]; - auto memrefType = memref.getType().dyn_cast(); + auto memrefType = dyn_cast(memref.getType()); int constZerosToAdd = memrefType.getShape().size(); for (int i = 0; i < constZerosToAdd; i++) { diff --git a/tools/translate-llvm-to-std/main.cpp b/tools/translate-llvm-to-std/main.cpp index ae2a614912..797e48700c 100644 --- a/tools/translate-llvm-to-std/main.cpp +++ b/tools/translate-llvm-to-std/main.cpp @@ -10,6 +10,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" +#include "llvm/IRReader/IRReader.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/InitLLVM.h" @@ -64,7 +65,7 @@ int main(int argc, char **argv) { LLVMContext llvmContext; SMDiagnostic err; std::unique_ptr llvmModule = - parseAssemblyFile(StringRef(inputFilename), err, llvmContext); + parseIRFile(StringRef(inputFilename), err, llvmContext); if (!llvmModule) { errs() << "Failed to read LLVM IR file.\n"; err.print(argv[0], errs()); diff --git a/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/GreedySimplifyMergeLike.cpp b/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/GreedySimplifyMergeLike.cpp index ebe82e4838..a254ce49e1 100644 --- a/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/GreedySimplifyMergeLike.cpp +++ b/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/GreedySimplifyMergeLike.cpp @@ -35,7 +35,7 @@ struct EraseSingleInputMerge : public OpRewritePattern { if (mergeOp->getNumOperands() != 1) return failure(); - rewriter.updateRootInPlace(mergeOp, [&] { + rewriter.modifyOpInPlace(mergeOp, [&] { // Replace all occurences of the merge's single result throughout the IR // with the merge's single operand. This is equivalent to bypassing the // merge @@ -78,7 +78,7 @@ struct DowngradeIndexlessControlMerge cmergeOp.getLoc(), cmergeOp->getOperands()); // We are modifying the operation - rewriter.updateRootInPlace(cmergeOp, [&] { + rewriter.modifyOpInPlace(cmergeOp, [&] { // Then, replace the control merge's first result (the selected input) // with the single result of the newly created merge operation Value mergeRes = newMergeOp.getResult(); @@ -111,8 +111,8 @@ struct GreedySimplifyMergeLikePass // Set up a configuration object to customize the behavior of the rewriter mlir::GreedyRewriteConfig config; - config.useTopDownTraversal = true; - config.enableRegionSimplification = false; + config.setUseTopDownTraversal(true); + config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); // Create a rewrite pattern set and add our two patterns to it RewritePatternSet patterns{ctx}; From fce6f40cdd23f3ab3e7a82109e4a4ca8e210e576 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Fri, 7 Nov 2025 12:06:03 +0100 Subject: [PATCH 03/27] CfToHandshake: DynamaticPass -> Pass --- .../Conversion/FtdCfToHandshake.h | 17 +++++----- .../include/experimental/Conversion/Passes.td | 3 +- .../lib/Conversion/FtdCfToHandshake.cpp | 22 +++++++++---- include/dynamatic/Conversion/CfToHandshake.h | 3 -- include/dynamatic/Conversion/Passes.td | 8 +++-- .../CfToHandshake/CfToHandshake.cpp | 33 +++++++++++++------ tools/dynamatic/scripts/compile.sh | 2 +- tools/translate-llvm-to-std/InferArgTypes.cpp | 10 ++++++ 8 files changed, 65 insertions(+), 33 deletions(-) diff --git a/experimental/include/experimental/Conversion/FtdCfToHandshake.h b/experimental/include/experimental/Conversion/FtdCfToHandshake.h index 7a63dfff9b..6e07295b1b 100644 --- a/experimental/include/experimental/Conversion/FtdCfToHandshake.h +++ b/experimental/include/experimental/Conversion/FtdCfToHandshake.h @@ -21,6 +21,13 @@ #include "dynamatic/Support/LLVM.h" #include "experimental/Analysis/GSAAnalysis.h" +namespace dynamatic { +namespace experimental { +#define GEN_PASS_DECL_FTDCFTOHANDSHAKE +#include "experimental/Conversion/Passes.h.inc" +} // namespace experimental +} // namespace dynamatic + namespace dynamatic { namespace experimental { namespace ftd { @@ -35,14 +42,14 @@ class FtdLowerFuncToHandshake : public LowerFuncToHandshake { NameAnalysis &namer, MLIRContext *ctx, mlir::PatternBenefit benefit = 1) : LowerFuncToHandshake(namer, ctx, benefit), cdAnalysis(cda), - gsaAnalysis(gsa){}; + gsaAnalysis(gsa) {}; FtdLowerFuncToHandshake(ControlDependenceAnalysis &cda, gsa::GSAAnalysis &gsa, NameAnalysis &namer, const TypeConverter &typeConverter, MLIRContext *ctx, mlir::PatternBenefit benefit = 1) : LowerFuncToHandshake(namer, typeConverter, ctx, benefit), - cdAnalysis(cda), gsaAnalysis(gsa){}; + cdAnalysis(cda), gsaAnalysis(gsa) {}; LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp, OpAdaptor adaptor, @@ -85,12 +92,6 @@ struct FtdConvertIndexCast : public ConvertIndexCast { ConversionPatternRewriter &rewriter) const override; }; -#define GEN_PASS_DECL_FTDCFTOHANDSHAKE -#define GEN_PASS_DEF_FTDCFTOHANDSHAKE -#include "experimental/Conversion/Passes.h.inc" - -std::unique_ptr createFtdCfToHandshake(); - } // namespace ftd } // namespace experimental } // namespace dynamatic diff --git a/experimental/include/experimental/Conversion/Passes.td b/experimental/include/experimental/Conversion/Passes.td index ac68b127d5..71f71949c8 100644 --- a/experimental/include/experimental/Conversion/Passes.td +++ b/experimental/include/experimental/Conversion/Passes.td @@ -20,7 +20,7 @@ include "dynamatic/Support/Passes.td" // FtdCfToHandshake //===----------------------------------------------------------------------===// -def FtdCfToHandshake : DynamaticPass<"ftd-lower-cf-to-handshake"> { +def FtdCfToHandshake : Pass<"ftd-lower-cf-to-handshake"> { let summary = "Lowers func and cf dialects to handshake with fast token delivery"; let description = [{ The fast token delivery (FTD) methodology was described by @@ -33,7 +33,6 @@ def FtdCfToHandshake : DynamaticPass<"ftd-lower-cf-to-handshake"> { As the algorithm does not require any parameter, no input is necessary. This pass can be used through the flag "--fast-token-delivery" at compile tile in dynamatic. }]; - let constructor = "dynamatic::experimental::ftd::createFtdCfToHandshake()"; } #endif // EXPERIMENTAL_CONVERSION_PASSES_TD diff --git a/experimental/lib/Conversion/FtdCfToHandshake.cpp b/experimental/lib/Conversion/FtdCfToHandshake.cpp index e7836793c3..8f1fec89c8 100644 --- a/experimental/lib/Conversion/FtdCfToHandshake.cpp +++ b/experimental/lib/Conversion/FtdCfToHandshake.cpp @@ -36,15 +36,26 @@ using namespace dynamatic::experimental; using namespace dynamatic::experimental::boolean; using namespace dynamatic::experimental::ftd; +namespace dynamatic { +namespace experimental { +#define GEN_PASS_DEF_FTDCFTOHANDSHAKE +#include "experimental/Conversion/Passes.h.inc" +} // namespace experimental +} // namespace dynamatic + namespace { struct FtdCfToHandshakePass - : public dynamatic::experimental::ftd::impl::FtdCfToHandshakeBase< + : public dynamatic::experimental::impl::FtdCfToHandshakeBase< FtdCfToHandshakePass> { - void runDynamaticPass() override { + void runOnOperation() override { MLIRContext *ctx = &getContext(); - ModuleOp modOp = getOperation(); + mlir::ModuleOp modOp = llvm::dyn_cast(getOperation()); + + NameAnalysis &nameAnalysis = getAnalysis(); + if (!nameAnalysis.isAnalysisValid()) + return signalPassFailure(); CfToHandshakeTypeConverter converter; RewritePatternSet patterns(ctx); @@ -99,6 +110,7 @@ struct FtdCfToHandshakePass if (failed(applyFullConversion(modOp, target, std::move(patterns)))) return signalPassFailure(); + markAnalysesPreserved(); } }; } // namespace @@ -373,7 +385,3 @@ LogicalResult FtdConvertIndexCast::matchAndRewrite( castOp.getResult().replaceAllUsesWith(newOp->getResult(0)); return success(); } - -std::unique_ptr ftd::createFtdCfToHandshake() { - return std::make_unique(); -} diff --git a/include/dynamatic/Conversion/CfToHandshake.h b/include/dynamatic/Conversion/CfToHandshake.h index 49baa43b9c..ab95e9ad84 100644 --- a/include/dynamatic/Conversion/CfToHandshake.h +++ b/include/dynamatic/Conversion/CfToHandshake.h @@ -275,11 +275,8 @@ struct ConvertUndefinedValues }; #define GEN_PASS_DECL_CFTOHANDSHAKE -#define GEN_PASS_DEF_CFTOHANDSHAKE #include "dynamatic/Conversion/Passes.h.inc" -std::unique_ptr createCfToHandshake(); - } // namespace dynamatic #endif // DYNAMATIC_CONVERSION_CF_TO_HANDSHAKE_H diff --git a/include/dynamatic/Conversion/Passes.td b/include/dynamatic/Conversion/Passes.td index c86b98a5e5..0ea1461f16 100644 --- a/include/dynamatic/Conversion/Passes.td +++ b/include/dynamatic/Conversion/Passes.td @@ -57,7 +57,7 @@ def ScfToCf : DynamaticPass<"lower-scf-to-cf", [ // CfToHandshake //===----------------------------------------------------------------------===// -def CfToHandshake : DynamaticPass<"lower-cf-to-handshake"> { +def CfToHandshake : Pass<"lower-cf-to-handshake"> { let summary = "Lowers func and cf dialects to handshake."; let description = [{ Lowers func-level functions whose body have unstructured control flow into @@ -65,7 +65,11 @@ def CfToHandshake : DynamaticPass<"lower-cf-to-handshake"> { which represent dataflow circuits that can ultimately be converted to an RTL design. }]; -let constructor = "dynamatic::createCfToHandshake()"; + + let dependentDialects = [ + "mlir::cf::ControlFlowDialect", "mlir::arith::ArithDialect", + "mlir::func::FuncDialect", "mlir::memref::MemRefDialect", + "handshake::HandshakeDialect"]; } //===----------------------------------------------------------------------===// diff --git a/lib/Conversion/CfToHandshake/CfToHandshake.cpp b/lib/Conversion/CfToHandshake/CfToHandshake.cpp index 193be79cfb..9660375f8f 100644 --- a/lib/Conversion/CfToHandshake/CfToHandshake.cpp +++ b/lib/Conversion/CfToHandshake/CfToHandshake.cpp @@ -60,6 +60,11 @@ using namespace mlir::affine; using namespace mlir::memref; using namespace dynamatic; +namespace dynamatic { +#define GEN_PASS_DEF_CFTOHANDSHAKE +#include "dynamatic/Conversion/Passes.h.inc" +} // namespace dynamatic + //===-----------------------------------------------------------------------==// // Helper functions //===-----------------------------------------------------------------------==// @@ -399,17 +404,17 @@ FailureOr LowerFuncToHandshake::lowerSignature( TypeConverter::SignatureConversion entryConversion( entryBlock->getNumArguments()); setupEntryBlockConversion(entryBlock, numMemories, rewriter, entryConversion); - rewriter.applySignatureConversion(entryBlock, entryConversion, typeConv); + rewriter.applySignatureConversion(entryBlock, entryConversion, + getTypeConverter()); for (Block &nonEntryBlock : llvm::make_early_inc_range(llvm::drop_begin(funcOp.getBody()))) { - TypeConverter::SignatureConversion nonEntryConversion( /*numOrigInputs=*/nonEntryBlock.getNumArguments()); setupBlockConversion(&nonEntryBlock, rewriter, nonEntryConversion); rewriter.applySignatureConversion(&nonEntryBlock, nonEntryConversion, - typeConv); + getTypeConverter()); } // Modify branch-like terminators to forward the new control value through @@ -578,6 +583,9 @@ void LowerFuncToHandshake::addMergeOps(handshake::FuncOp funcOp, // Insert merge-like operations in all non-entry blocks (with backedges // instead as data operands) DenseMap> blockMerges; + + Block *entryBlock = &funcOp.getBody().front(); + for (Block &block : llvm::drop_begin(funcOp)) { rewriter.setInsertionPointToStart(&block); @@ -689,6 +697,7 @@ void LowerFuncToHandshake::addBranchOps( // Connect users of the branch to the appropriate branch result for (const auto &userGroup : branchUsers) { + rewriter.replaceUsesWithIf( branchOprd, getSuccResult(termOp, newOp, userGroup.first), [&](OpOperand &oprd) { @@ -1006,6 +1015,7 @@ void LowerFuncToHandshake::idBasicBlocks( LogicalResult LowerFuncToHandshake::flattenAndTerminate( handshake::FuncOp funcOp, ConversionPatternRewriter &rewriter, const ArgReplacements &argReplacements) const { + // Erase all cf-level terminators, accumulating operands to func-level returns // as we go SmallVector> returnsOperands; @@ -1040,9 +1050,10 @@ LogicalResult LowerFuncToHandshake::flattenAndTerminate( SmallVector replacements; for (BlockArgument blockArg : block.getArguments()) { Value mergeRes = argReplacements.at(blockArg); - replacements.push_back(mergeRes); + // Replacing BA with merge results rewriter.replaceAllUsesWith(blockArg, mergeRes); } + // Replacing the block arguments with merge results rewriter.inlineBlockBefore(&block, lastOp, replacements); } @@ -1606,9 +1617,13 @@ namespace { struct CfToHandshakePass : public dynamatic::impl::CfToHandshakeBase { - void runDynamaticPass() override { + void runOnOperation() override { MLIRContext *ctx = &getContext(); - ModuleOp modOp = getOperation(); + mlir::ModuleOp modOp = llvm::dyn_cast(getOperation()); + + NameAnalysis &nameAnalysis = getAnalysis(); + if (!nameAnalysis.isAnalysisValid()) + return signalPassFailure(); // Put all non-external functions into maximal SSA form for (auto funcOp : modOp.getOps()) { @@ -1705,10 +1720,8 @@ struct CfToHandshakePass func.erase(); } } + // The name analysis is always preserved across passes + markAnalysesPreserved(); } }; } // namespace - -std::unique_ptr dynamatic::createCfToHandshake() { - return std::make_unique(); -} diff --git a/tools/dynamatic/scripts/compile.sh b/tools/dynamatic/scripts/compile.sh index b8383a68dd..ac22c4bce7 100755 --- a/tools/dynamatic/scripts/compile.sh +++ b/tools/dynamatic/scripts/compile.sh @@ -20,7 +20,7 @@ DISABLE_LSQ=${10} FAST_TOKEN_DELIVERY=${11} MILP_SOLVER=${12} -LLVM=$DYNAMATIC_DIR/polygeist/llvm-project +LLVM=$DYNAMATIC_DIR/llvm-project LLVM_BINS=$DYNAMATIC_DIR/bin export PATH=$PATH:$LLVM_BINS diff --git a/tools/translate-llvm-to-std/InferArgTypes.cpp b/tools/translate-llvm-to-std/InferArgTypes.cpp index ef7f577805..5a71c0adcd 100644 --- a/tools/translate-llvm-to-std/InferArgTypes.cpp +++ b/tools/translate-llvm-to-std/InferArgTypes.cpp @@ -120,7 +120,15 @@ static std::optional processScalarType(CXType clangType) { llvm_unreachable("Unhandled CXType_Unexposed type!"); return std::nullopt; } + + case CXType_Typedef: { + CXCursor typedefCursor = clang_getTypeDeclaration(clangType); + return processScalarType(clang_getTypedefDeclUnderlyingType(typedefCursor)); + } default: { + llvm::errs() << "Type ID of unhandled scalar type: " << clangType.kind + << "\n"; + return std::nullopt; } } @@ -161,6 +169,8 @@ static std::optional fromCXType(CXType type) { return ArgType{scalarType.value(), arrayDimSizes, false}; } } + + llvm::errs() << "Unhandled compound type id: " << type.kind << "\n"; // TODO: One important thing to handle in the future is the arguments that // are **passed by reference**. It is probably correct to promote them to // the function return values. From 13164523bce3c01c3e922c8ada476a097e878183 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Fri, 7 Nov 2025 14:34:40 +0100 Subject: [PATCH 04/27] WIP: fixing the inlining block --- .../CfToHandshake/CfToHandshake.cpp | 162 ++++++++++++------ tools/dynamatic/scripts/compile.sh | 2 +- 2 files changed, 108 insertions(+), 56 deletions(-) diff --git a/lib/Conversion/CfToHandshake/CfToHandshake.cpp b/lib/Conversion/CfToHandshake/CfToHandshake.cpp index 9660375f8f..b4303b1a69 100644 --- a/lib/Conversion/CfToHandshake/CfToHandshake.cpp +++ b/lib/Conversion/CfToHandshake/CfToHandshake.cpp @@ -48,6 +48,7 @@ #include "mlir/IR/Value.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" @@ -122,7 +123,7 @@ mergeFuncResults(handshake::FuncOp funcOp, ConversionPatternRewriter &rewriter, SmallVector mergeOperands; for (ValueRange operands : returnsOperands) mergeOperands.push_back(operands[i]); - auto mergeOp = rewriter.create(loc, mergeOperands); + auto mergeOp = handshake::MergeOp::create(rewriter, loc, mergeOperands); results.push_back(mergeOp.getResult()); mergeOp->setAttr(BB_ATTR_NAME, rewriter.getUI32IntegerAttr(exitBlockID)); } @@ -388,8 +389,8 @@ FailureOr LowerFuncToHandshake::lowerSignature( rewriter.setInsertionPoint(funcOp); FunctionType funTy = rewriter.getFunctionType(argTypes, resTypes); SmallVector attrs = deriveNewAttributes(funcOp); - auto newFuncOp = rewriter.create( - funcOp.getLoc(), funcOp.getName(), funTy, attrs); + auto newFuncOp = handshake::FuncOp::create(rewriter, funcOp.getLoc(), + funcOp.getName(), funTy, attrs); if (funcOp.isExternal()) { rewriter.eraseOp(funcOp); return newFuncOp; @@ -433,19 +434,20 @@ FailureOr LowerFuncToHandshake::lowerSignature( trueOperands.push_back(blockCtrl); falseOperands.push_back(blockCtrl); - rewriter.replaceOp(termOp, - rewriter.create( - condBrOp->getLoc(), condBrOp.getCondition(), - condBrOp.getTrueDest(), trueOperands, - condBrOp.getFalseDest(), falseOperands)); + rewriter.replaceOp(termOp, cf::CondBranchOp::create( + rewriter, condBrOp->getLoc(), + condBrOp.getCondition(), + condBrOp.getTrueDest(), trueOperands, + condBrOp.getFalseDest(), falseOperands)); } else if (auto brOp = dyn_cast(termOp)) { SmallVector operands; if (failed(rewriter.getRemappedValues(brOp.getDestOperands(), operands))) return failure(); operands.push_back(blockCtrl); - rewriter.replaceOp(termOp, rewriter.create( - brOp->getLoc(), brOp.getDest(), operands)); + rewriter.replaceOp(termOp, + cf::BranchOp::create(rewriter, brOp->getLoc(), + brOp.getDest(), operands)); } } @@ -537,10 +539,10 @@ void LowerFuncToHandshake::insertMerge(BlockArgument blockArg, // Every block needs to feed it's entry control into a control merge if (blockArg == getBlockControl(block)) { addFromAllPredecessors(handshake::ControlType::get(rewriter.getContext())); - iMerge.op = rewriter.create(loc, operands); + iMerge.op = handshake::ControlMergeOp::create(rewriter, loc, operands); } else if (predecessors.size() == 1) { addFromAllPredecessors(blockArg.getType()); - iMerge.op = rewriter.create(loc, operands); + iMerge.op = handshake::MergeOp::create(rewriter, loc, operands); } else { // Create a backedge for the index operand, and another one for each data // operand. The index operand will eventually resolve to the current block's @@ -568,8 +570,8 @@ void LowerFuncToHandshake::insertMerge(BlockArgument blockArg, // Since none of the operands have extra signals, the result type matches // the first operand. - iMerge.op = rewriter.create( - loc, /*resultType=*/operands[0].getType(), index, operands); + iMerge.op = handshake::MuxOp::create( + rewriter, loc, /*result=*/operands[0].getType(), index, operands); } } @@ -675,10 +677,10 @@ void LowerFuncToHandshake::addBranchOps( // Create a branch-like operation for the branch operand Operation *newOp; if (cond) { - newOp = rewriter.create(loc, cond, - branchOprd); + newOp = handshake::ConditionalBranchOp::create(rewriter, loc, cond, + branchOprd); } else { - newOp = rewriter.create(loc, branchOprd); + newOp = handshake::BranchOp::create(rewriter, loc, branchOprd); } // Group users by the block which they belong to, which inform the result @@ -812,7 +814,7 @@ LogicalResult LowerFuncToHandshake::convertMemoryOps( assert(addr && "failed to remap address"); Type dataTy = cast(memref.getType()).getElementType(); Value data = edgeBuilder.get(channelifyType(dataTy)); - auto newOp = rewriter.create(loc, addr, data); + auto newOp = handshake::LoadOp::create(rewriter, loc, addr, data); copyDialectAttr(loadOp, newOp); namer.replaceOp(loadOp, newOp); @@ -838,7 +840,8 @@ LogicalResult LowerFuncToHandshake::convertMemoryOps( Value addr = rewriter.getRemappedValue(indices.front()); Value data = rewriter.getRemappedValue(storeOp.getValueToStore()); assert((addr && data) && "failed to remap address or data"); - auto newOp = rewriter.create(loc, addr, data); + auto newOp = + handshake::StoreOp::create(rewriter, loc, addr, data); copyDialectAttr(storeOp, newOp); @@ -930,7 +933,7 @@ LogicalResult LowerFuncToHandshake::verifyAndCreateMemInterfaces( } rewriter.setInsertionPointToStart(lastRetOp->getBlock()); auto mergeOp = - rewriter.create(lastRetOp.getLoc(), controls); + handshake::MergeOp::create(rewriter, lastRetOp.getLoc(), controls); ctrlEnd = mergeOp.getResult(); // The merge goes into an extra "end block" after all others, this will be @@ -1054,7 +1057,7 @@ LogicalResult LowerFuncToHandshake::flattenAndTerminate( rewriter.replaceAllUsesWith(blockArg, mergeRes); } // Replacing the block arguments with merge results - rewriter.inlineBlockBefore(&block, lastOp, replacements); + // rewriter.inlineBlockBefore(&block, lastOp, replacements); } // The terminator's operands are, in order @@ -1073,7 +1076,7 @@ LogicalResult LowerFuncToHandshake::flattenAndTerminate( if (arg.getUsers().empty()) { // When the memory region is not accessed, just a create a constant source // of valid "memory end" tokens for ir - auto sourceOp = rewriter.create(lastOp->getLoc()); + auto sourceOp = handshake::SourceOp::create(rewriter, lastOp->getLoc()); sourceOp->setAttr(BB_ATTR_NAME, rewriter.getUI32IntegerAttr(exitBlockID)); endOprds.push_back(sourceOp.getResult()); } else { @@ -1088,7 +1091,7 @@ LogicalResult LowerFuncToHandshake::flattenAndTerminate( } endOprds.push_back(getBlockControl(funcOp.getBodyBlock())); - auto endOp = rewriter.create(lastOp->getLoc(), endOprds); + auto endOp = handshake::EndOp::create(rewriter, lastOp->getLoc(), endOprds); endOp->setAttr(BB_ATTR_NAME, rewriter.getUI32IntegerAttr(exitBlockID)); return success(); } @@ -1190,13 +1193,13 @@ ConvertCalls::matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, TypeRange resultTypes; // Vectors storing indices of classified arguments for placeholder logic // handling - SmallVector InstanceOpInputIndices; - SmallVector InstanceOpOutputIndices; - SmallVector InstanceOpParameterIndices; + SmallVector instanceOpInputIndices; + SmallVector instanceOpOutputIndices; + SmallVector instanceOpParameterIndices; SmallVector initCalls; // OutputConnections: Maps output argument index -> list of operations that // consume its value - llvm::DenseMap> OutputConnections; + llvm::DenseMap> outputConnections; // parameterMap: Maps parameter name to its value llvm::DenseMap parameterMap; // Store and later erase const that are used to define parameters @@ -1219,13 +1222,13 @@ ConvertCalls::matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, assert(nameAttr && !nameAttr.getValue().empty() && "Argument name attribute is missing or empty"); if (nameAttr.getValue().starts_with("input_")) { - InstanceOpInputIndices.push_back(i); + instanceOpInputIndices.push_back(i); } else if (nameAttr.getValue().starts_with("output_")) { - InstanceOpOutputIndices.push_back(i); + instanceOpOutputIndices.push_back(i); // For each output argument index, find all operations that consume its // value and store the mapping in OutputConnections Value outputArg = callOp.getOperand(i); - auto &fanouts = OutputConnections[i]; + auto &fanouts = outputConnections[i]; for (auto &use : outputArg.getUses()) { Operation *user = use.getOwner(); if (user != callOp) { @@ -1234,7 +1237,7 @@ ConvertCalls::matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, } } else if (nameAttr.getValue().starts_with("parameter_")) { // Extract and Store parameter name and value inside a Dictionary. - InstanceOpParameterIndices.push_back(i); + instanceOpParameterIndices.push_back(i); StringRef parameterName = nameAttr.getValue().drop_front(strlen("parameter_")); Value operand = callOp.getOperand(i); @@ -1254,15 +1257,15 @@ ConvertCalls::matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, assert(false && "Invalid argument naming"); } } - assert(!InstanceOpOutputIndices.empty() && + assert(!instanceOpOutputIndices.empty() && "Placeholder functions must at least have one output_ argument!"); // For each operand, check if its index is in the output or parameter index // vector. If it is, remove that operand. This ensures that the operands // list only contains input arguments. We iterate in reverse to avoid // invalidating indices while erasing. for (int i = adaptor.getOperands().size() - 1; i >= 0; --i) { - if (llvm::is_contained(InstanceOpOutputIndices, i) || - llvm::is_contained(InstanceOpParameterIndices, i)) { + if (llvm::is_contained(instanceOpOutputIndices, i) || + llvm::is_contained(instanceOpParameterIndices, i)) { operands.erase(operands.begin() + i); } } @@ -1277,9 +1280,9 @@ ConvertCalls::matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, SmallVector newInputs; SmallVector newResults; for (unsigned i = 0; i < calledFuncOpType.getNumInputs(); ++i) { - if (llvm::is_contained(InstanceOpInputIndices, i)) { + if (llvm::is_contained(instanceOpInputIndices, i)) { newInputs.push_back(calledFuncOpType.getInput(i)); - } else if (llvm::is_contained(InstanceOpOutputIndices, i)) { + } else if (llvm::is_contained(instanceOpOutputIndices, i)) { newResults.push_back(calledFuncOpType.getInput(i)); } } @@ -1309,12 +1312,12 @@ ConvertCalls::matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, // our called placeholder. All output arguments must originate from __init // calls (enforced via assertions). All __init calls are collected so they // can be erased after rewiring. - for (unsigned outputArgIdx : InstanceOpOutputIndices) { + for (unsigned outputArgIdx : instanceOpOutputIndices) { assert(outputArgIdx < adaptor.getOperands().size() && "Output index out of bounds"); Value outputVal = adaptor.getOperands()[outputArgIdx]; - auto definingOp = stripCasts(outputVal); + auto *definingOp = stripCasts(outputVal); assert(definingOp && "Expected defining op for output value"); auto sourceCallOp = dyn_cast(definingOp); @@ -1335,8 +1338,8 @@ ConvertCalls::matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, } } // Erase parameters from callOp - llvm::sort(InstanceOpParameterIndices, std::greater<>()); - for (unsigned id : InstanceOpParameterIndices) { + llvm::sort(instanceOpParameterIndices, std::greater<>()); + for (unsigned id : instanceOpParameterIndices) { if (id < callOp->getNumOperands()) { callOp->eraseOperand(id); } @@ -1352,8 +1355,9 @@ ConvertCalls::matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, handshake::ControlType::get(rewriter.getContext())); rewriter.setInsertionPoint(callOp); - auto instOp = rewriter.create( - callOp.getLoc(), callOp.getCallee(), handshakeResultTypes, operands); + auto instOp = handshake::InstanceOp::create(rewriter, callOp.getLoc(), + callOp.getCallee(), + handshakeResultTypes, operands); instOp->setDialectAttrs(callOp->getDialectAttrs()); // attach parameters to the new Instance as attributes @@ -1367,23 +1371,23 @@ ConvertCalls::matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, // by the instance, so we update all uses to reference the correct result. // This step will be skipped for all non-placeholder function since for them // the output index list is empty - if (!InstanceOpOutputIndices.empty()) { - auto InstanceResults = instOp.getResults(); - unsigned ResultId = 0; - for (auto OutputId : InstanceOpOutputIndices) { - for (Operation *user : OutputConnections[OutputId]) { + if (!instanceOpOutputIndices.empty()) { + auto instanceResults = instOp.getResults(); + unsigned resultId = 0; + for (auto outputId : instanceOpOutputIndices) { + for (Operation *user : outputConnections[outputId]) { for (OpOperand &operand : user->getOpOperands()) { // ASSUMPTION: Data dependencies are acyclic, no output from the // instance is used to (directly or indirectly) compute its own input. // All uses of placeholder values must occur after the instance is // inserted. This avoids SSA violations and preserves correct dataflow // semantics. - if (operand.get() == callOp.getOperand(OutputId)) { - operand.set(InstanceResults[ResultId]); + if (operand.get() == callOp.getOperand(outputId)) { + operand.set(instanceResults[resultId]); } } } - ResultId++; + resultId++; } } @@ -1454,7 +1458,7 @@ ConvertConstants::matchAndRewrite(arith::ConstantOp cstOp, // Determine the new constant's control input Value controlVal; if (isCstSourcable(cstOp)) { - auto sourceOp = rewriter.create(cstOp.getLoc()); + auto sourceOp = handshake::SourceOp::create(rewriter, cstOp.getLoc()); inheritBB(cstOp, sourceOp); controlVal = sourceOp.getResult(); } else { @@ -1468,8 +1472,8 @@ ConvertConstants::matchAndRewrite(arith::ConstantOp cstOp, cstAttr = IntegerAttr::get(intType, cast(cstAttr).getValue().trunc(32)); } - auto newCstOp = rewriter.create(cstOp.getLoc(), - cstAttr, controlVal); + auto newCstOp = handshake::ConstantOp::create(rewriter, cstOp.getLoc(), + cstAttr, controlVal); newCstOp->setDialectAttrs(cstOp->getDialectAttrs()); namer.replaceOp(cstOp, newCstOp); rewriter.replaceOp(cstOp, newCstOp->getResults()); @@ -1495,8 +1499,8 @@ LogicalResult ConvertUndefinedValues::matchAndRewrite( // Create a constant with a default value and replace the undefined value rewriter.setInsertionPoint(undefOp); - auto cstOp = rewriter.create(undefOp.getLoc(), cstAttr, - getBlockControl(undefOp)); + auto cstOp = handshake::ConstantOp::create(rewriter, undefOp.getLoc(), + cstAttr, getBlockControl(undefOp)); cstOp->setDialectAttrs(undefOp->getAttrDictionary()); namer.replaceOp(cstOp, cstOp); rewriter.replaceOp(undefOp, cstOp.getResult()); @@ -1600,6 +1604,30 @@ struct GlobalOpConversion : public DynOpConversionPattern { } }; +// Simply Inlining all the blocks other than the first BB into the first BB +struct InlineAllBlocksIntoOneConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(handshake::FuncOp funcOp, OpAdaptor adapter, + ConversionPatternRewriter &rewriter) const override { + llvm::errs() << "Hiii!!!\n"; + Operation *lastOp = &funcOp.front().back(); + llvm::errs() << "Hiii2!!!\n"; + for (Block &block : + llvm::make_early_inc_range(llvm::drop_begin(funcOp, 1))) { + SmallVector replacements; + for (auto _ : block.getArguments()) { + replacements.push_back(mlir::Value()); + } + + llvm::errs() << "Hi!!!\n"; + rewriter.inlineBlockBefore(&block, lastOp, replacements); + } + return success(); + } +}; + //===-----------------------------------------------------------------------==// // Pass driver //===-----------------------------------------------------------------------==// @@ -1710,6 +1738,30 @@ struct CfToHandshakePass if (failed(applyFullConversion(modOp, target, std::move(patterns)))) return signalPassFailure(); + // RewritePatternSet inlinePatterns{ctx}; + // inlinePatterns.add(ctx); + + // modOp.dump(); + + // if (failed(applyPatternsGreedily(modOp, std::move(inlinePatterns)))) + // return signalPassFailure(); + + for (auto funcOp : modOp.getOps()) { + Block *firstBlock = &funcOp.getBlocks().front(); + + auto *endOpIt = firstBlock->getTerminator(); + + for (Block &otherBlock : + llvm::make_early_inc_range(llvm::drop_begin(funcOp))) { + for (Operation &op : llvm::make_early_inc_range(otherBlock)) { + op.moveBefore(endOpIt); + } + otherBlock.erase(); + } + } + + modOp.dump(); + // Clean up: Remove the definition of each __init* function, but only if it // has no remaining uses. This is safe because all valid calls to __init* // were tracked and deleted earlier. diff --git a/tools/dynamatic/scripts/compile.sh b/tools/dynamatic/scripts/compile.sh index ac22c4bce7..493a689e49 100755 --- a/tools/dynamatic/scripts/compile.sh +++ b/tools/dynamatic/scripts/compile.sh @@ -245,7 +245,7 @@ if [[ $FAST_TOKEN_DELIVERY -ne 0 ]]; then exit_on_fail "Failed to compile cf to handshake with FTD" "Compiled cf to handshake with FTD" else "$DYNAMATIC_OPT_BIN" "$F_CF_DYN_TRANSFORMED_MEM_DEP_MARKED" --lower-cf-to-handshake \ - > "$F_HANDSHAKE" + -debug-only=dialect-conversion > "$F_HANDSHAKE" exit_on_fail "Failed to compile cf to handshake" "Compiled cf to handshake" fi From a585fe3d2e1996b14a3a16beee14f7223fae5a9f Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Mon, 10 Nov 2025 20:10:28 +0100 Subject: [PATCH 05/27] fix handshake folder pattern --- .../CfToHandshake/CfToHandshake.cpp | 33 +------------------ lib/Dialect/Handshake/HandshakeOps.cpp | 12 ++++++- tools/dynamatic/scripts/compile.sh | 4 +-- tools/dynamatic/scripts/simulate.sh | 4 ++- 4 files changed, 17 insertions(+), 36 deletions(-) diff --git a/lib/Conversion/CfToHandshake/CfToHandshake.cpp b/lib/Conversion/CfToHandshake/CfToHandshake.cpp index b4303b1a69..633b145835 100644 --- a/lib/Conversion/CfToHandshake/CfToHandshake.cpp +++ b/lib/Conversion/CfToHandshake/CfToHandshake.cpp @@ -1604,30 +1604,6 @@ struct GlobalOpConversion : public DynOpConversionPattern { } }; -// Simply Inlining all the blocks other than the first BB into the first BB -struct InlineAllBlocksIntoOneConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(handshake::FuncOp funcOp, OpAdaptor adapter, - ConversionPatternRewriter &rewriter) const override { - llvm::errs() << "Hiii!!!\n"; - Operation *lastOp = &funcOp.front().back(); - llvm::errs() << "Hiii2!!!\n"; - for (Block &block : - llvm::make_early_inc_range(llvm::drop_begin(funcOp, 1))) { - SmallVector replacements; - for (auto _ : block.getArguments()) { - replacements.push_back(mlir::Value()); - } - - llvm::errs() << "Hi!!!\n"; - rewriter.inlineBlockBefore(&block, lastOp, replacements); - } - return success(); - } -}; - //===-----------------------------------------------------------------------==// // Pass driver //===-----------------------------------------------------------------------==// @@ -1738,14 +1714,6 @@ struct CfToHandshakePass if (failed(applyFullConversion(modOp, target, std::move(patterns)))) return signalPassFailure(); - // RewritePatternSet inlinePatterns{ctx}; - // inlinePatterns.add(ctx); - - // modOp.dump(); - - // if (failed(applyPatternsGreedily(modOp, std::move(inlinePatterns)))) - // return signalPassFailure(); - for (auto funcOp : modOp.getOps()) { Block *firstBlock = &funcOp.getBlocks().front(); @@ -1773,6 +1741,7 @@ struct CfToHandshakePass } } // The name analysis is always preserved across passes + markAnalysesPreserved(); } }; diff --git a/lib/Dialect/Handshake/HandshakeOps.cpp b/lib/Dialect/Handshake/HandshakeOps.cpp index 2a4921425e..87e8a4be38 100644 --- a/lib/Dialect/Handshake/HandshakeOps.cpp +++ b/lib/Dialect/Handshake/HandshakeOps.cpp @@ -1944,7 +1944,17 @@ OpFoldResult TruncIOp::fold(FoldAdaptor adaptor) { return getResult(); } // Bypass the preceeding extension operation and the truncation - return src; + if (srcWidth == dstWidth) { + return src; + } + + // NOTE (10.11.2025): The current MLIR version does not allow the folder to + // change the result type. + // + // if srcWidth <= dstWidth: we need to use special + // canonicalization rewrite to optimize away (ext -> trunc) instead of + // folding + return nullptr; } // Identical operand and result types mean that the trunc is a no-op diff --git a/tools/dynamatic/scripts/compile.sh b/tools/dynamatic/scripts/compile.sh index 493a689e49..2a2bd0cb70 100755 --- a/tools/dynamatic/scripts/compile.sh +++ b/tools/dynamatic/scripts/compile.sh @@ -25,7 +25,7 @@ LLVM_BINS=$DYNAMATIC_DIR/bin export PATH=$PATH:$LLVM_BINS POLYGEIST_CLANG_BIN="$DYNAMATIC_DIR/bin/cgeist" -CLANGXX_BIN="$DYNAMATIC_DIR/bin/clang++" +CLANGXX_BIN="$LLVM_BINS/clang++" LLVM_OPT="$LLVM_BINS/opt" LLVM_TO_STD_TRANSLATION_BIN="$DYNAMATIC_DIR/build/bin/translate-llvm-to-std" DYNAMATIC_OPT_BIN="$DYNAMATIC_DIR/bin/dynamatic-opt" @@ -245,7 +245,7 @@ if [[ $FAST_TOKEN_DELIVERY -ne 0 ]]; then exit_on_fail "Failed to compile cf to handshake with FTD" "Compiled cf to handshake with FTD" else "$DYNAMATIC_OPT_BIN" "$F_CF_DYN_TRANSFORMED_MEM_DEP_MARKED" --lower-cf-to-handshake \ - -debug-only=dialect-conversion > "$F_HANDSHAKE" + > "$F_HANDSHAKE" exit_on_fail "Failed to compile cf to handshake" "Compiled cf to handshake" fi diff --git a/tools/dynamatic/scripts/simulate.sh b/tools/dynamatic/scripts/simulate.sh index 2de89856c2..f927acb75c 100755 --- a/tools/dynamatic/scripts/simulate.sh +++ b/tools/dynamatic/scripts/simulate.sh @@ -16,6 +16,8 @@ VIVADO_FPU=$6 SIMULATOR_NAME=$7 # Generated directories/files +LLVM=$DYNAMATIC_DIR/llvm-project +LLVM_BINS=$LLVM/build/bin SIM_DIR="$(realpath "$OUTPUT_DIR/sim")" C_SRC_DIR="$SIM_DIR/C_SRC" C_OUT_DIR="$SIM_DIR/C_OUT" @@ -27,7 +29,7 @@ IO_GEN_BIN="$SIM_DIR/C_SRC/$KERNEL_NAME-io-gen" # Shortcuts HDL_DIR="$OUTPUT_DIR/hdl" -CLANGXX_BIN="$DYNAMATIC_DIR/bin/clang++" +CLANGXX_BIN="$LLVM_BINS/clang++" HLS_VERIFIER_BIN="$DYNAMATIC_DIR/bin/hls-verifier" RESOURCE_DIR="$DYNAMATIC_DIR/tools/hls-verifier/resources" From 0111b15b48a248ccd2fb982c17bb09206a12373e Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Mon, 10 Nov 2025 21:02:30 +0100 Subject: [PATCH 06/27] fix deprecated create statements --- .../lib/Conversion/FtdCfToHandshake.cpp | 22 +++---- .../lib/Transforms/ResourceSharing/Crush.cpp | 10 ++- .../HandshakeRigidification.cpp | 6 +- .../Speculation/HandshakeSpeculation.cpp | 60 +++++++++-------- .../Dialect/Handshake/HandshakeOps.td | 4 +- .../CfToHandshake/CfToHandshake.cpp | 14 ++-- .../HandshakeToHW/HandshakeToHW.cpp | 30 ++++----- lib/Conversion/ScfToCf/ScfToCf.cpp | 14 ++-- lib/Dialect/HW/HWDialect.cpp | 8 +-- lib/Dialect/HW/HWOps.cpp | 34 +++++----- lib/Dialect/Handshake/MemoryInterfaces.cpp | 26 ++++---- lib/Transforms/ArithReduceStrength.cpp | 39 ++++++----- .../BufferPlacement/HandshakePlaceBuffers.cpp | 8 +-- lib/Transforms/DropUnlistedFunctions.cpp | 19 ------ lib/Transforms/FlattenMemRefRowMajor.cpp | 26 ++++---- lib/Transforms/HandshakeCanonicalize.cpp | 17 ++--- lib/Transforms/HandshakeMaterialize.cpp | 24 +++---- lib/Transforms/HandshakeMinimizeCstWidth.cpp | 11 ++-- lib/Transforms/HandshakeOptimizeBitwidths.cpp | 45 +++++++------ lib/Transforms/PushConstants.cpp | 8 +-- lib/Transforms/ScfRotateForLoops.cpp | 22 ++++--- lib/Transforms/ScfSimpleIfToSelect.cpp | 17 +++-- .../TranslateLLVMToStd.cpp | 64 ++++++++++--------- 23 files changed, 254 insertions(+), 274 deletions(-) diff --git a/experimental/lib/Conversion/FtdCfToHandshake.cpp b/experimental/lib/Conversion/FtdCfToHandshake.cpp index 8f1fec89c8..2b1057f24a 100644 --- a/experimental/lib/Conversion/FtdCfToHandshake.cpp +++ b/experimental/lib/Conversion/FtdCfToHandshake.cpp @@ -163,8 +163,8 @@ static LogicalResult convertUndefinedValues(ConversionPatternRewriter &rewriter, // Create a constant with a default value and replace the undefined value rewriter.setInsertionPoint(undefOp); - auto cstOp = rewriter.create(undefOp.getLoc(), - cstAttr, startValue); + auto cstOp = handshake::ConstantOp::create(rewriter, undefOp.getLoc(), + cstAttr, startValue); cstOp->setDialectAttrs(undefOp->getAttrDictionary()); undefOp.getResult().replaceAllUsesWith(cstOp.getResult()); namer.replaceOp(cstOp, cstOp); @@ -206,8 +206,8 @@ static LogicalResult convertConstants(ConversionPatternRewriter &rewriter, intType, cast(valueAttr).getValue().trunc(32)); } - auto newCstOp = rewriter.create( - cstOp.getLoc(), valueAttr, controlValue); + auto newCstOp = handshake::ConstantOp::create(rewriter, cstOp.getLoc(), + valueAttr, controlValue); newCstOp->setDialectAttrs(cstOp->getDialectAttrs()); @@ -227,8 +227,8 @@ LogicalResult FtdOneToOneConversion::matchAndRewrite( for (Type resType : srcOp->getResultTypes()) newTypes.push_back(channelifyType(resType)); auto newOp = - rewriter.create(srcOp->getLoc(), newTypes, adaptor.getOperands(), - srcOp->getAttrDictionary().getValue()); + DstOp::create(rewriter, srcOp->getLoc(), newTypes, adaptor.getOperands(), + srcOp->getAttrDictionary().getValue()); // /!\ This is the main difference from the base function. Without such // replacement, a "null operand found" error is present at the end of the @@ -369,13 +369,13 @@ LogicalResult FtdConvertIndexCast::matchAndRewrite( if (srcWidth < dstWidth) { // This is an extension newOp = - rewriter.create(castOp.getLoc(), dstType, adaptor.getOperands(), - castOp->getAttrDictionary().getValue()); + ExtOp::create(rewriter, castOp.getLoc(), dstType, adaptor.getOperands(), + castOp->getAttrDictionary().getValue()); } else { // This is a truncation - newOp = rewriter.create( - castOp.getLoc(), dstType, adaptor.getOperands(), - castOp->getAttrDictionary().getValue()); + newOp = handshake::TruncIOp::create(rewriter, castOp.getLoc(), dstType, + adaptor.getOperands(), + castOp->getAttrDictionary().getValue()); } this->namer.replaceOp(castOp, newOp); rewriter.replaceOp(castOp, newOp); diff --git a/experimental/lib/Transforms/ResourceSharing/Crush.cpp b/experimental/lib/Transforms/ResourceSharing/Crush.cpp index 71baf87ea2..5d58bb02f3 100644 --- a/experimental/lib/Transforms/ResourceSharing/Crush.cpp +++ b/experimental/lib/Transforms/ResourceSharing/Crush.cpp @@ -415,12 +415,10 @@ LogicalResult CreditBasedSharingPass::sharingWrapperInsertion( "The sharing wrapper has an incorrect number of output ports."); builder.setInsertionPoint(*group.begin()); - handshake::SharingWrapperOp wrapperOp = - builder.create( - sharedOp->getLoc(), sharingWrapperOutputTypes, dataOperands, - sharedOp->getResult(0), llvm::ArrayRef(credits), - credits.size(), sharedOp->getNumOperands(), - (unsigned)round(latency)); + handshake::SharingWrapperOp wrapperOp = handshake::SharingWrapperOp::create( + builder, sharedOp->getLoc(), sharingWrapperOutputTypes, dataOperands, + sharedOp->getResult(0), llvm::ArrayRef(credits), + credits.size(), sharedOp->getNumOperands(), (unsigned)round(latency)); // Replace original connection from op->successor to // sharingWrapper->successor diff --git a/experimental/lib/Transforms/Rigidification/HandshakeRigidification.cpp b/experimental/lib/Transforms/Rigidification/HandshakeRigidification.cpp index 2858bef5d9..95b82f9a4d 100644 --- a/experimental/lib/Transforms/Rigidification/HandshakeRigidification.cpp +++ b/experimental/lib/Transforms/Rigidification/HandshakeRigidification.cpp @@ -93,7 +93,7 @@ HandshakeRigidificationPass::insertReadyRemover(AbsenceOfBackpressure prop) { builder.setInsertionPointAfter(ownerOp); auto loc = channel.getLoc(); - auto newOp = builder.create(loc, channel); + auto newOp = handshake::ReadyRemoverOp::create(builder, loc, channel); channel.replaceAllUsesExcept(newOp.getResult(), newOp); return success(); @@ -113,8 +113,8 @@ HandshakeRigidificationPass::insertValidMerger(ValidEquivalence prop) { Location loc = FusedLoc::get(ctx, {ownerChannel.getLoc(), targetChannel.getLoc()}); - auto newOp = builder.create(loc, ownerChannel, - targetChannel); + auto newOp = handshake::ValidMergerOp::create(builder, loc, ownerChannel, + targetChannel); ownerChannel.replaceAllUsesExcept(newOp.getLhsOut(), newOp); targetChannel.replaceAllUsesExcept(newOp.getRhsOut(), newOp); diff --git a/experimental/lib/Transforms/Speculation/HandshakeSpeculation.cpp b/experimental/lib/Transforms/Speculation/HandshakeSpeculation.cpp index 6f58fa8676..59cecf2e30 100644 --- a/experimental/lib/Transforms/Speculation/HandshakeSpeculation.cpp +++ b/experimental/lib/Transforms/Speculation/HandshakeSpeculation.cpp @@ -230,8 +230,8 @@ routeCommitControlRecursive(MLIRContext *ctx, SpeculatorOp &specOp, if (!branchDiscardNonSpec.has_value()) { // trueResultType and falseResultType are tentative and will be updated // in the addSpecTag algorithm later. - branchDiscardNonSpec = builder.create( - branchOp.getLoc(), + branchDiscardNonSpec = handshake::SpeculatingBranchOp::create( + builder, branchOp.getLoc(), /*trueResultType=*/conditionOperand.getType(), /*falseResultType=*/conditionOperand.getType(), /*specTag=*/valueForSpecTag, conditionOperand); @@ -244,8 +244,8 @@ routeCommitControlRecursive(MLIRContext *ctx, SpeculatorOp &specOp, if (!branchReplicated.has_value()) { // The replicated branch directs the control token based on the path the // speculative token took - branchReplicated = builder.create( - branchDiscardNonSpec->getLoc(), + branchReplicated = handshake::ConditionalBranchOp::create( + builder, branchDiscardNonSpec->getLoc(), /*condition=*/branchDiscardNonSpec->getTrueResult(), /*data=*/ctrlSignal); inheritBB(specOp, *branchReplicated); @@ -337,9 +337,8 @@ LogicalResult HandshakeSpeculationPass::placeCommits() { // generated by the BackedgeBuilder must be replaced before the builder is // destroyed, which occurs before exiting this method. fakeControlForCommits = - builder - .create( - specOp->getLoc(), commitCtrl.getType(), ValueRange{}) + mlir::UnrealizedConversionCastOp::create( + builder, specOp->getLoc(), commitCtrl.getType(), ValueRange{}) .getResult(0); // Place commits and connect to the fake control signal @@ -351,8 +350,8 @@ LogicalResult HandshakeSpeculationPass::placeCommits() { builder.setInsertionPoint(dstOp); // resultType is tentative and will be updated in the addSpecTag algorithm // later. - SpecCommitOp newOp = builder.create( - dstOp->getLoc(), /*resultType=*/srcOpResult.getType(), + SpecCommitOp newOp = SpecCommitOp::create( + builder, dstOp->getLoc(), /*resultType=*/srcOpResult.getType(), /*dataIn=*/srcOpResult, /*ctrl=*/fakeControlForCommits.value()); inheritBB(dstOp, newOp); @@ -381,8 +380,8 @@ LogicalResult HandshakeSpeculationPass::placeSaveCommits(Value ctrlSignal) { builder.setInsertionPoint(dstOp); // resultType is tentative and will be updated in the addSpecTag algorithm // later. - SpecSaveCommitOp newOp = builder.create( - dstOp->getLoc(), /*resultType=*/srcOpResult.getType(), + SpecSaveCommitOp newOp = SpecSaveCommitOp::create( + builder, dstOp->getLoc(), /*resultType=*/srcOpResult.getType(), /*dataIn=*/srcOpResult, /*ctrl=*/ctrlSignal, /*fifoDepth=*/fifoDepth); inheritBB(dstOp, newOp); @@ -440,27 +439,25 @@ FailureOr HandshakeSpeculationPass::generateSaveCommitCtrl() { auto conditionOperand = controlBranch.getConditionOperand(); // trueResultType and falseResultType are tentative and will be updated in the // addSpecTag algorithm later. - auto branchDiscardCondNonSpec = - builder.create( - controlBranch.getLoc(), - /*trueResultType=*/conditionOperand.getType(), - /*falseResultType=*/conditionOperand.getType(), - /*specTag=*/specOp.getDataOut(), conditionOperand); + auto branchDiscardCondNonSpec = handshake::SpeculatingBranchOp::create( + builder, controlBranch.getLoc(), + /*trueResultType=*/conditionOperand.getType(), + /*falseResultType=*/conditionOperand.getType(), + /*specTag=*/specOp.getDataOut(), conditionOperand); inheritBB(specOp, branchDiscardCondNonSpec); // Second, discard if speculation happened but it was correct // Create a conditional branch driven by SCBranchControl from speculator // SCBranchControl discards the commit-like signal when speculation is correct - auto branchDiscardCondNonMisspec = - builder.create( - branchDiscardCondNonSpec.getLoc(), specOp.getSCIsMisspec(), - branchDiscardCondNonSpec.getTrueResult()); + auto branchDiscardCondNonMisspec = handshake::ConditionalBranchOp::create( + builder, branchDiscardCondNonSpec.getLoc(), specOp.getSCIsMisspec(), + branchDiscardCondNonSpec.getTrueResult()); inheritBB(specOp, branchDiscardCondNonMisspec); // This branch will propagate the signal SCCommitControl according to // the control branch condition, which comes from branchDiscardCondNonMisSpec - auto branchReplicated = builder.create( - branchDiscardCondNonMisspec.getLoc(), + auto branchReplicated = handshake::ConditionalBranchOp::create( + builder, branchDiscardCondNonMisspec.getLoc(), branchDiscardCondNonMisspec.getTrueResult(), specOp.getSCCommitCtrl()); inheritBB(specOp, branchReplicated); @@ -496,16 +493,16 @@ FailureOr HandshakeSpeculationPass::generateSaveCommitCtrl() { } // All the inputs to the merge operation are ready - auto mergeOp = builder.create(branchReplicated.getLoc(), - mergeOperands); + auto mergeOp = handshake::MergeOp::create(builder, branchReplicated.getLoc(), + mergeOperands); inheritBB(specOp, mergeOp); // The control signal is the result of the merge op. return mergeOp.getResult(); } -std::optional findControlInputToBB(handshake::FuncOp &funcOp, - unsigned targetBB) { +static std::optional findControlInputToBB(handshake::FuncOp &funcOp, + unsigned targetBB) { // Here we fork control token to use as trigger signal to speculator. // The presence of a buffer between this fork and the control branch creates // performance issues (see detailed speculation documentation). Therefore we @@ -580,8 +577,8 @@ LogicalResult HandshakeSpeculationPass::placeSpeculator() { // resultType is tentative and will be updated in the addSpecTag algorithm // later. - specOp = builder.create( - dstOp->getLoc(), /*resultType=*/srcOpResult.getType(), + specOp = handshake::SpeculatorOp::create( + builder, dstOp->getLoc(), /*resultType=*/srcOpResult.getType(), /*dataIn=*/srcOpResult, /*specIn=*/specTrigger.value(), fifoDepth); // Replace uses of the original source operation's result with the @@ -803,8 +800,9 @@ LogicalResult HandshakeSpeculationPass::addNonSpecOp() { if (!dataOperandType.hasExtraSignal(EXTRA_BIT_SPEC)) { // Create a NonSpecOp to add the spec tag to the data operand builder.setInsertionPointAfterValue(dataOperand); - auto nonSpecOp = builder.create( - mergeLikeOp.getLoc(), dataOperand.getType(), dataOperand); + auto nonSpecOp = + NonSpecOp::create(builder, mergeLikeOp.getLoc(), + dataOperand.getType(), dataOperand); inheritBB(mergeLikeOp, nonSpecOp); // Add the spec tag to the NonSpecOp's result diff --git a/include/dynamatic/Dialect/Handshake/HandshakeOps.td b/include/dynamatic/Dialect/Handshake/HandshakeOps.td index 7c6ac67670..d207792241 100644 --- a/include/dynamatic/Dialect/Handshake/HandshakeOps.td +++ b/include/dynamatic/Dialect/Handshake/HandshakeOps.td @@ -195,7 +195,7 @@ def InstanceOp : Handshake_Op<"instance", [ /// Set the callee for this operation. void setCalleeFromCallable(mlir::CallInterfaceCallable callee) { - (*this)->setAttr(getModuleAttrName(), callee.get()); + (*this)->setAttr(getModuleAttrName(), cast(callee)); } /// Get the control operand of this instance op @@ -1003,7 +1003,7 @@ def RAMOp : Handshake_Op<"ram", []> { Example: Using handshake.RAMOp in the C++ API. ```c++ // Building a new RAMOp - builder.create(type, elementAttr); + handshake::RAMOp::create(builder, type, elementAttr); // Access its initialValue std::optional initialValue = ramOp.getInitialValue(); diff --git a/lib/Conversion/CfToHandshake/CfToHandshake.cpp b/lib/Conversion/CfToHandshake/CfToHandshake.cpp index 633b145835..a4a6360e38 100644 --- a/lib/Conversion/CfToHandshake/CfToHandshake.cpp +++ b/lib/Conversion/CfToHandshake/CfToHandshake.cpp @@ -1133,8 +1133,8 @@ LogicalResult OneToOneConversion::matchAndRewrite( for (Type resType : srcOp->getResultTypes()) newTypes.push_back(channelifyType(resType)); auto newOp = - rewriter.create(srcOp->getLoc(), newTypes, adaptor.getOperands(), - srcOp->getAttrDictionary().getValue()); + DstOp::create(rewriter, srcOp->getLoc(), newTypes, adaptor.getOperands(), + srcOp->getAttrDictionary().getValue()); namer.replaceOp(srcOp, newOp); rewriter.replaceOp(srcOp, newOp); return success(); @@ -1158,13 +1158,13 @@ LogicalResult ConvertIndexCast::matchAndRewrite( if (srcWidth < dstWidth) { // This is an extension newOp = - rewriter.create(castOp.getLoc(), dstType, adaptor.getOperands(), - castOp->getAttrDictionary().getValue()); + ExtOp::create(rewriter, castOp.getLoc(), dstType, adaptor.getOperands(), + castOp->getAttrDictionary().getValue()); } else { // This is a truncation - newOp = rewriter.create( - castOp.getLoc(), dstType, adaptor.getOperands(), - castOp->getAttrDictionary().getValue()); + newOp = handshake::TruncIOp::create(rewriter, castOp.getLoc(), dstType, + adaptor.getOperands(), + castOp->getAttrDictionary().getValue()); } namer.replaceOp(castOp, newOp); rewriter.replaceOp(castOp, newOp); diff --git a/lib/Conversion/HandshakeToHW/HandshakeToHW.cpp b/lib/Conversion/HandshakeToHW/HandshakeToHW.cpp index 47bb14094d..4a3d014844 100644 --- a/lib/Conversion/HandshakeToHW/HandshakeToHW.cpp +++ b/lib/Conversion/HandshakeToHW/HandshakeToHW.cpp @@ -1106,16 +1106,16 @@ hw::InstanceOp HWBuilder::createInstance(ModuleDiscriminator &discriminator, StringAttr modNameAttr = builder.getStringAttr(extModName); RewriterBase::InsertPoint instInsertPoint = builder.saveInsertionPoint(); builder.setInsertionPointToEnd(topLevelModOp.getBody()); - extModOp = builder.create(loc, modNameAttr, - modBuilder.getPortInfo()); + extModOp = hw::HWModuleExternOp::create(builder, loc, modNameAttr, + modBuilder.getPortInfo()); discriminator.setParameters(extModOp); builder.restoreInsertionPoint(instInsertPoint); } // Now create the instance corresponding to the external module StringAttr instNameAttr = builder.getStringAttr(instName); - return builder.create(loc, extModOp, instNameAttr, - instOperands); + return hw::InstanceOp::create(builder, loc, extModOp, instNameAttr, + instOperands); } hw::ModulePortInfo ModuleBuilder::getPortInfo() { @@ -1151,8 +1151,8 @@ static void addMemIO(ModuleBuilder &modBuilder, handshake::FuncOp funcOp, /// Handshake function. Fills in the lowering state object with information /// that will allow the conversion pass to connect memory interface to their /// top-level IO later on. -hw::ModulePortInfo getFuncPortInfo(handshake::FuncOp funcOp, - ModuleLoweringState &state) { +static hw::ModulePortInfo getFuncPortInfo(handshake::FuncOp funcOp, + ModuleLoweringState &state) { ModuleBuilder modBuilder(funcOp.getContext()); handshake::PortNamer portNames(funcOp); @@ -1266,7 +1266,7 @@ ConvertFunc::matchAndRewrite(handshake::FuncOp funcOp, OpAdaptor adaptor, // Create non-external HW module to replace the function with rewriter.setInsertionPoint(funcOp); - auto modOp = rewriter.create(funcOp.getLoc(), name, modInfo); + auto modOp = hw::HWModuleOp::create(rewriter, funcOp.getLoc(), name, modInfo); // Move the block from the Handshake function to the new HW module, after // which the Handshake function becomes empty and can be deleted @@ -1830,8 +1830,8 @@ hw::InstanceOp ConverterBuilder::createInstance(hw::HWModuleOp wrapperOp, // Create an instance of the converter StringAttr name = builder.getStringAttr("mem_to_bram_converter_" + memName); builder.setInsertionPoint(circuitOp); - hw::InstanceOp converterInstOp = builder.create( - circuitOp.getLoc(), converterModOp, name, instOperands); + hw::InstanceOp converterInstOp = hw::InstanceOp::create( + builder, circuitOp.getLoc(), converterModOp, name, instOperands); // Resolve backedges in the wrapped circuit operands and in the wrapper's // outputs @@ -1894,8 +1894,8 @@ MemToBRAMConverter::buildExternalModule(hw::HWModuleOp circuitMod, builder.setInsertionPointToEnd(topModOp.getBody()); StringAttr modNameAttr = builder.getStringAttr(extModName); - extModOp = builder.create( - circuitMod->getLoc(), modNameAttr, modBuilder.getPortInfo()); + extModOp = hw::HWModuleExternOp::create( + builder, circuitMod->getLoc(), modNameAttr, modBuilder.getPortInfo()); extModOp->setAttr(RTL_NAME_ATTR_NAME, StringAttr::get(ctx, HW_NAME)); SmallVector parameters; @@ -1974,8 +1974,8 @@ static hw::HWModuleOp createEmptyWrapperMod( // Create the wrapper builder.setInsertionPointToEnd(state.modOp.getBody()); - hw::HWModuleOp wrapperOp = builder.create( - circuitOp.getLoc(), + hw::HWModuleOp wrapperOp = hw::HWModuleOp::create( + builder, circuitOp.getLoc(), StringAttr::get(ctx, circuitOp.getSymName() + "_wrapper"), wrapperBuilder.getPortInfo()); builder.setInsertionPointToStart(wrapperOp.getBodyBlock()); @@ -2038,8 +2038,8 @@ static void createWrapper(hw::HWModuleOp circuitOp, LoweringState &state, } // Create the wrapped circuit instance inside the wrapper - hw::InstanceOp circuitInstOp = builder.create( - circuitOp.getLoc(), circuitOp, + hw::InstanceOp circuitInstOp = hw::InstanceOp::create( + builder, circuitOp.getLoc(), circuitOp, builder.getStringAttr(circuitOp.getSymName() + "_wrapped"), circuitOperands); diff --git a/lib/Conversion/ScfToCf/ScfToCf.cpp b/lib/Conversion/ScfToCf/ScfToCf.cpp index 659056dc46..f5240cf00b 100644 --- a/lib/Conversion/ScfToCf/ScfToCf.cpp +++ b/lib/Conversion/ScfToCf/ScfToCf.cpp @@ -93,14 +93,14 @@ struct ForLowering : public OpRewritePattern { Operation *terminator = lastBodyBlock->getTerminator(); rewriter.setInsertionPointToEnd(lastBodyBlock); auto step = forOp.getStep(); - auto stepped = rewriter.create(loc, iv, step).getResult(); + auto stepped = arith::AddIOp::create(rewriter, loc, iv, step).getResult(); if (!stepped) return failure(); SmallVector loopCarried; loopCarried.push_back(stepped); loopCarried.append(terminator->operand_begin(), terminator->operand_end()); - rewriter.create(loc, conditionBlock, loopCarried); + cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried); rewriter.eraseOp(terminator); // The initial values of loop-carried values is obtained from the operands @@ -109,15 +109,15 @@ struct ForLowering : public OpRewritePattern { destOperands.push_back(lowerBound); llvm::append_range(destOperands, forOp.getInitArgs()); rewriter.setInsertionPointToEnd(initBlock); - rewriter.create(loc, conditionBlock, destOperands); + cf::BranchOp::create(rewriter, loc, conditionBlock, destOperands); // With the body block done, we can fill in the condition block. rewriter.setInsertionPointToEnd(conditionBlock); - auto comparison = rewriter.create(loc, pred, iv, upperBound); + auto comparison = + arith::CmpIOp::create(rewriter, loc, pred, iv, upperBound); - rewriter.create(loc, comparison, firstBodyBlock, - ArrayRef(), endBlock, - ArrayRef()); + cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock, + ArrayRef(), endBlock, ArrayRef()); // The result of the loop operation is the values of the condition block // arguments except the induction variable on the last iteration. rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front()); diff --git a/lib/Dialect/HW/HWDialect.cpp b/lib/Dialect/HW/HWDialect.cpp index 257bb8998e..a1fedff53d 100644 --- a/lib/Dialect/HW/HWDialect.cpp +++ b/lib/Dialect/HW/HWDialect.cpp @@ -98,21 +98,21 @@ Operation *HWDialect::materializeConstant(OpBuilder &builder, Attribute value, // Integer constants can materialize into hw.constant if (auto intType = dyn_cast(type)) if (auto attrValue = dyn_cast(value)) - return builder.create(loc, type, attrValue); + return ConstantOp::create(builder, loc, type, attrValue); // Aggregate constants. if (auto arrayAttr = dyn_cast(value)) { if (isa(type)) - return builder.create(loc, type, arrayAttr); + return AggregateConstantOp::create(builder, loc, type, arrayAttr); } // Parameter expressions materialize into hw.param.value. - auto parentOp = builder.getBlock()->getParentOp(); + auto *parentOp = builder.getBlock()->getParentOp(); auto curModule = dyn_cast(parentOp); if (!curModule) curModule = parentOp->getParentOfType(); if (curModule && isValidParameterExpression(value, curModule)) - return builder.create(loc, type, value); + return ParamValueOp::create(builder, loc, type, value); return nullptr; } diff --git a/lib/Dialect/HW/HWOps.cpp b/lib/Dialect/HW/HWOps.cpp index 8d6a9df135..34f53e3d0d 100644 --- a/lib/Dialect/HW/HWOps.cpp +++ b/lib/Dialect/HW/HWOps.cpp @@ -707,7 +707,7 @@ void HWModuleOp::build(OpBuilder &builder, OperationState &odsState, modBuilder(builder, accessor); // Create output operands. llvm::SmallVector outputOperands = accessor.getOutputOperands(); - builder.create(odsState.location, outputOperands); + hw::OutputOp::create(builder, odsState.location, outputOperands); } void HWModuleOp::modifyPorts( @@ -941,7 +941,7 @@ ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser, return parseHWModuleOp(parser, result, GenMod); } -FunctionType getHWModuleOpType(Operation *op) { +static FunctionType getHWModuleOpType(Operation *op) { if (auto mod = dyn_cast(op)) return mod.getHWModuleType().getFuncType(); return cast( @@ -1820,8 +1820,8 @@ LogicalResult ArraySliceOp::canonicalize(ArraySliceOp op, if (sliceSize == 1) { // slice(a, n) -> create(a[n]) - auto get = rewriter.create(op.getLoc(), op.getInput(), - op.getLowIndex()); + auto get = ArrayGetOp::create(rewriter, op.getLoc(), op.getInput(), + op.getLowIndex()); rewriter.replaceOpWithNewOp(op, op.getType(), get.getResult()); return success(); @@ -1831,7 +1831,7 @@ LogicalResult ArraySliceOp::canonicalize(ArraySliceOp op, if (!offsetOpt) return failure(); - auto inputOp = op.getInput().getDefiningOp(); + auto *inputOp = op.getInput().getDefiningOp(); if (auto inputSlice = dyn_cast_or_null(inputOp)) { // slice(slice(a, n), m) -> slice(a, n + m) if (inputSlice == op) @@ -1844,7 +1844,7 @@ LogicalResult ArraySliceOp::canonicalize(ArraySliceOp op, uint64_t offset = *offsetOpt + *inputOffsetOpt; auto lowIndex = - rewriter.create(op.getLoc(), inputIndex.getType(), offset); + ConstantOp::create(rewriter, op.getLoc(), inputIndex.getType(), offset); rewriter.replaceOpWithNewOp(op, op.getType(), inputSlice.getInput(), lowIndex); return success(); @@ -1885,10 +1885,11 @@ LogicalResult ArraySliceOp::canonicalize(ArraySliceOp op, } else { // Slice the required bits from the input. unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize); - auto lowIndex = rewriter.create( - op.getLoc(), rewriter.getIntegerType(width), sliceStart); - chunks.push_back(rewriter.create( - op.getLoc(), hw::ArrayType::get(elemTy, cutSize), input, lowIndex)); + auto lowIndex = ConstantOp::create( + rewriter, op.getLoc(), rewriter.getIntegerType(width), sliceStart); + chunks.push_back(ArraySliceOp::create( + rewriter, op.getLoc(), hw::ArrayType::get(elemTy, cutSize), input, + lowIndex)); } sliceStart = 0; @@ -2422,7 +2423,7 @@ OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) { LogicalResult StructExtractOp::canonicalize(StructExtractOp op, PatternRewriter &rewriter) { - auto inputOp = op.getInput().getDefiningOp(); + auto *inputOp = op.getInput().getDefiningOp(); // b = extract(inject(x["a"], v0)["b"]) => extract(x, "b") if (auto structInject = dyn_cast_or_null(inputOp)) { @@ -2557,8 +2558,8 @@ LogicalResult StructInjectOp::canonicalize(StructInjectOp op, auto it = fields.find(elements[fieldIndex].name); if (it == fields.end()) continue; - input = rewriter.create(op.getLoc(), ty, input, fieldIndex, - it->second); + input = StructInjectOp::create(rewriter, op.getLoc(), ty, input, fieldIndex, + it->second); } rewriter.replaceOp(op, input); @@ -2743,7 +2744,7 @@ LogicalResult ArrayGetOp::canonicalize(ArrayGetOp op, uint64_t offset = *offsetOpt + *idxOpt; auto newOffset = - rewriter.create(op.getLoc(), offsetOp.getType(), offset); + ConstantOp::create(rewriter, op.getLoc(), offsetOp.getType(), offset); rewriter.replaceOpWithNewOp(op, inputSlice.getInput(), newOffset); return success(); @@ -2760,8 +2761,9 @@ LogicalResult ArrayGetOp::canonicalize(ArrayGetOp op, } unsigned indexWidth = size == 1 ? 1 : llvm::Log2_64_Ceil(size); - auto newIdxOp = rewriter.create( - op.getLoc(), rewriter.getIntegerType(indexWidth), elemIndex); + auto newIdxOp = + ConstantOp::create(rewriter, op.getLoc(), + rewriter.getIntegerType(indexWidth), elemIndex); rewriter.replaceOpWithNewOp(op, input, newIdxOp); return success(); diff --git a/lib/Dialect/Handshake/MemoryInterfaces.cpp b/lib/Dialect/Handshake/MemoryInterfaces.cpp index 7cd84f9aae..e3919386ab 100644 --- a/lib/Dialect/Handshake/MemoryInterfaces.cpp +++ b/lib/Dialect/Handshake/MemoryInterfaces.cpp @@ -100,14 +100,14 @@ LogicalResult MemoryInterfaceBuilder::instantiateInterfaces( if (!inputs.mcInputs.empty() && inputs.lsqInputs.empty()) { // We only need a memory controller - mcOp = builder.create( - loc, memref, memStart, inputs.mcInputs, ctrlEnd, inputs.mcBlocks, - mcNumLoads); + mcOp = handshake::MemoryControllerOp::create(builder, loc, memref, memStart, + inputs.mcInputs, ctrlEnd, + inputs.mcBlocks, mcNumLoads); } else if (inputs.mcInputs.empty() && !inputs.lsqInputs.empty()) { // We only need an LSQ - lsqOp = builder.create(loc, memref, memStart, - inputs.lsqInputs, ctrlEnd, - inputs.lsqGroupSizes, lsqNumLoads); + lsqOp = handshake::LSQOp::create(builder, loc, memref, memStart, + inputs.lsqInputs, ctrlEnd, + inputs.lsqGroupSizes, lsqNumLoads); } else { // We need a MC and an LSQ. They need to be connected with 4 new channels // so that the LSQ can forward its loads and stores to the MC. We need @@ -129,16 +129,16 @@ LogicalResult MemoryInterfaceBuilder::instantiateInterfaces( // Create the memory controller, adding 1 to its load count so that it // generates a load data result for the LSQ - mcOp = builder.create( - loc, memref, memStart, inputs.mcInputs, ctrlEnd, inputs.mcBlocks, - mcNumLoads + 1); + mcOp = handshake::MemoryControllerOp::create( + builder, loc, memref, memStart, inputs.mcInputs, ctrlEnd, + inputs.mcBlocks, mcNumLoads + 1); // Add the MC's load data result to the LSQ's inputs and create the LSQ, // passing a flag to the builder so that it generates the necessary // outputs that will go to the MC inputs.lsqInputs.push_back(mcOp.getOutputs().back()); - lsqOp = builder.create(loc, mcOp, inputs.lsqInputs, - inputs.lsqGroupSizes, lsqNumLoads); + lsqOp = handshake::LSQOp::create(builder, loc, mcOp, inputs.lsqInputs, + inputs.lsqGroupSizes, lsqNumLoads); // Resolve the backedges to fully connect the MC and LSQ ValueRange lsqMemResults = lsqOp.getOutputs().take_back(3); @@ -177,8 +177,8 @@ Value MemoryInterfaceBuilder::getMCControl(Value ctrl, unsigned numStores, builder.setInsertionPointAfter(defOp); else builder.setInsertionPointToStart(ctrl.getParentBlock()); - handshake::ConstantOp cstOp = builder.create( - ctrl.getLoc(), builder.getI32IntegerAttr(numStores), ctrl); + handshake::ConstantOp cstOp = handshake::ConstantOp::create( + builder, ctrl.getLoc(), builder.getI32IntegerAttr(numStores), ctrl); inheritBBFromValue(ctrl, cstOp); return cstOp.getResult(); } diff --git a/lib/Transforms/ArithReduceStrength.cpp b/lib/Transforms/ArithReduceStrength.cpp index e156c63d6e..046a67d375 100644 --- a/lib/Transforms/ArithReduceStrength.cpp +++ b/lib/Transforms/ArithReduceStrength.cpp @@ -14,7 +14,6 @@ #include "dynamatic/Transforms/ArithReduceStrength.h" #include "dynamatic/Analysis/NumericAnalysis.h" -#include "dynamatic/Dialect/Handshake/HandshakeOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -22,7 +21,6 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/raw_ostream.h" #include using namespace mlir; @@ -82,10 +80,9 @@ Value OpTree::buildTreeRecursive( return valIt->second; } auto cstResult = - rewriter - .create( - op->getLoc(), - rewriter.getIntegerAttr(op->getResult(0).getType(), *value)) + arith::ConstantOp::create( + rewriter, op->getLoc(), + rewriter.getIntegerAttr(op->getResult(0).getType(), *value)) .getResult(); cstCache[*value] = cstResult; return cstResult; @@ -104,15 +101,15 @@ Value OpTree::buildTreeRecursive( Value result; switch (opType) { case OpType::ADD: - result = rewriter.create(op->getLoc(), leftVal, rightVal) + result = arith::AddIOp::create(rewriter, op->getLoc(), leftVal, rightVal) .getResult(); break; case OpType::SUB: - result = rewriter.create(op->getLoc(), leftVal, rightVal) + result = arith::SubIOp::create(rewriter, op->getLoc(), leftVal, rightVal) .getResult(); break; case OpType::SHIFT_LEFT: - result = rewriter.create(op->getLoc(), leftVal, rightVal) + result = arith::ShLIOp::create(rewriter, op->getLoc(), leftVal, rightVal) .getResult(); break; } @@ -198,19 +195,19 @@ struct ReplaceMulNegOneUsers : public OpRewritePattern { Value newLhs = isLhs ? addOp.getRhs() : addOp.getLhs(); rewriter.replaceOp( user, - rewriter.create(loc, newLhs, oprd)->getResults()); + arith::SubIOp::create(rewriter, loc, newLhs, oprd)->getResults()); anyChange = true; } else if (arith::SubIOp subOp = dyn_cast(user)) { // Substractions are replaced with an equivalemt additiom and, // potentially, a sign flip (when the multiplication provides the RHS) if (mulRes == subOp.getRhs()) { rewriter.replaceOp( - user, rewriter.create(loc, subOp.getLhs(), oprd) + user, arith::AddIOp::create(rewriter, loc, subOp.getLhs(), oprd) ->getResults()); anyChange = true; } else { - arith::AddIOp addOp = rewriter.create( - mulOp->getLoc(), oprd, subOp.getRhs()); + arith::AddIOp addOp = arith::AddIOp::create(rewriter, mulOp->getLoc(), + oprd, subOp.getRhs()); Value addRes = addOp.getResult(); Type dataType = addRes.getType(); @@ -221,19 +218,19 @@ struct ReplaceMulNegOneUsers : public OpRewritePattern { IntegerAttr intAttr = rewriter.getIntegerAttr(dataType, getMaskAllOnes(dataType)); arith::ConstantOp maskOp = - rewriter.create(loc, intAttr); + arith::ConstantOp::create(rewriter, loc, intAttr); // Then create the XOR between the first addition's result and the // mask, inverting the former's bits arith::XOrIOp xorOp = - rewriter.create(loc, addRes, maskOp.getResult()); + arith::XOrIOp::create(rewriter, loc, addRes, maskOp.getResult()); // Finally, add one to the XOR's output to get the negated version of // the first result and replace the initial operation - arith::ConstantOp cstOneOp = rewriter.create( - loc, rewriter.getIntegerAttr(dataType, 1)); - arith::AddIOp negAddOp = rewriter.create( - loc, xorOp.getResult(), cstOneOp.getResult()); + arith::ConstantOp cstOneOp = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(dataType, 1)); + arith::AddIOp negAddOp = arith::AddIOp::create( + rewriter, loc, xorOp.getResult(), cstOneOp.getResult()); rewriter.replaceOp(user, negAddOp->getResults()); anyChange = true; } @@ -486,8 +483,8 @@ struct ArithReduceStrengthPass /// (area, performance, mixed) patterns.add(maxAdderDepthMul, ctx); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) + if (failed( + applyPatternsGreedily(getOperation(), std::move(patterns), config))) signalPassFailure(); }; }; diff --git a/lib/Transforms/BufferPlacement/HandshakePlaceBuffers.cpp b/lib/Transforms/BufferPlacement/HandshakePlaceBuffers.cpp index e7a25ed7c3..712506effc 100644 --- a/lib/Transforms/BufferPlacement/HandshakePlaceBuffers.cpp +++ b/lib/Transforms/BufferPlacement/HandshakePlaceBuffers.cpp @@ -245,8 +245,8 @@ void HandshakePlaceBuffersPass::runOnOperation() { if (!failed( timingDB.getLatency(op, SignalType::DATA, latency, targetCP))) { - int64_t latency_int = static_cast(latency); - latencyInterface.setLatency(latency_int); + int64_t latencyInt = static_cast(latency); + latencyInterface.setLatency(latencyInt); } else { op->emitError("Failed to get latency from timing model"); return signalPassFailure(); @@ -814,8 +814,8 @@ void HandshakePlaceBuffersPass::instantiateBuffers(BufferPlacement &placement, if (numSlots == 0) return; - auto bufOp = builder.create( - bufferIn.getLoc(), bufferIn, numSlots, bufferType); + auto bufOp = handshake::BufferOp::create(builder, bufferIn.getLoc(), + bufferIn, numSlots, bufferType); placedBuffers.push_back(bufOp); inheritBB(opDst, bufOp); nameAnalysis.setName(bufOp); diff --git a/lib/Transforms/DropUnlistedFunctions.cpp b/lib/Transforms/DropUnlistedFunctions.cpp index 992580b22d..b3aa47b20b 100644 --- a/lib/Transforms/DropUnlistedFunctions.cpp +++ b/lib/Transforms/DropUnlistedFunctions.cpp @@ -1,36 +1,17 @@ #include "dynamatic/Transforms/DropUnlistedFunctions.h" #include "dynamatic/Support/LLVM.h" -#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/Conversion/LLVMCommon/VectorPattern.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OwningOpRef.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/IR/DerivedTypes.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Process.h" -#include "llvm/Support/raw_ostream.h" -#include -#include - using namespace mlir; using namespace dynamatic; diff --git a/lib/Transforms/FlattenMemRefRowMajor.cpp b/lib/Transforms/FlattenMemRefRowMajor.cpp index 47684e1a4a..c31a755996 100644 --- a/lib/Transforms/FlattenMemRefRowMajor.cpp +++ b/lib/Transforms/FlattenMemRefRowMajor.cpp @@ -49,7 +49,7 @@ static Value flattenIndices(ConversionPatternRewriter &rewriter, Location loc, if (numIndices == 0) { // Singleton memref (e.g. memref) - return 0 - return rewriter.create(loc, rewriter.getIndexAttr(0)) + return arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)) .getResult(); } @@ -85,24 +85,22 @@ static Value flattenIndices(ConversionPatternRewriter &rewriter, Location loc, // Multiply product by the current index operand if (llvm::isPowerOf2_64(dimProduct)) { auto constant = - rewriter - .create( - loc, rewriter.getIndexAttr(llvm::Log2_64(dimProduct))) + arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(llvm::Log2_64(dimProduct))) .getResult(); - partialIdx = - rewriter.create(loc, partialIdx, constant).getResult(); + partialIdx = arith::ShLIOp::create(rewriter, loc, partialIdx, constant) + .getResult(); } else { - auto constant = - rewriter - .create(loc, rewriter.getIndexAttr(dimProduct)) - .getResult(); - partialIdx = - rewriter.create(loc, partialIdx, constant).getResult(); + auto constant = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(dimProduct)) + .getResult(); + partialIdx = arith::MulIOp::create(rewriter, loc, partialIdx, constant) + .getResult(); } // Sum up with the prior lower dimension accessors auto sumOp = - rewriter.create(loc, accumulatedArrayIndex, partialIdx); + arith::AddIOp::create(rewriter, loc, accumulatedArrayIndex, partialIdx); accumulatedArrayIndex = sumOp.getResult(); } return accumulatedArrayIndex; @@ -337,7 +335,7 @@ struct CallOpConversion : public OpConversionPattern { calledFunction, op.getCallee(), funcType); else newFuncOp = - rewriter.create(op.getLoc(), op.getCallee(), funcType); + func::FuncOp::create(rewriter, op.getLoc(), op.getCallee(), funcType); newFuncOp.setVisibility(SymbolTable::Visibility::Private); return success(); diff --git a/lib/Transforms/HandshakeCanonicalize.cpp b/lib/Transforms/HandshakeCanonicalize.cpp index 557f8b6a6b..592e7f3ee2 100644 --- a/lib/Transforms/HandshakeCanonicalize.cpp +++ b/lib/Transforms/HandshakeCanonicalize.cpp @@ -67,7 +67,7 @@ struct EraseSingleInputMuxes : public OpRewritePattern { // Insert a sink to consume the mux's select token rewriter.setInsertionPoint(muxOp); Value select = muxOp.getSelectOperand(); - rewriter.create(muxOp->getLoc(), select); + handshake::SinkOp::create(rewriter, muxOp->getLoc(), select); rewriter.replaceOp(muxOp, dataOperands.front()); return success(); @@ -94,8 +94,9 @@ struct EraseSingleInputControlMerges rewriter.setInsertionPoint(cmergeOp); // Create a source operation for the constant - handshake::SourceOp srcOp = rewriter.create( - cmergeOp->getLoc(), handshake::ControlType::get(getContext())); + handshake::SourceOp srcOp = handshake::SourceOp::create( + rewriter, cmergeOp->getLoc(), + handshake::ControlType::get(getContext())); inheritBB(cmergeOp, srcOp); /// NOTE: Sourcing this value may cause problems with very exotic uses of @@ -105,8 +106,8 @@ struct EraseSingleInputControlMerges // Build the attribute for the constant Type indexResType = indexRes.getType().getDataType(); - handshake::ConstantOp cstOp = rewriter.create( - cmergeOp.getLoc(), rewriter.getIntegerAttr(indexResType, 0), + handshake::ConstantOp cstOp = handshake::ConstantOp::create( + rewriter, cmergeOp.getLoc(), rewriter.getIntegerAttr(indexResType, 0), srcOp.getResult()); inheritBB(cmergeOp, cstOp); @@ -138,8 +139,8 @@ struct DowngradeIndexlessControlMerge // Create a merge operation to replace the cmerge rewriter.setInsertionPoint(cmergeOp); - handshake::MergeOp mergeOp = rewriter.create( - cmergeOp.getLoc(), cmergeOp->getOperands()); + handshake::MergeOp mergeOp = handshake::MergeOp::create( + rewriter, cmergeOp.getLoc(), cmergeOp->getOperands()); inheritBB(cmergeOp, mergeOp); // Replace the cmerge's data result with the merge's result, erase any @@ -168,7 +169,7 @@ struct HandshakeCanonicalizePass patterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(mod, std::move(patterns), config))) + if (failed(applyPatternsGreedily(mod, std::move(patterns), config))) return signalPassFailure(); }; }; diff --git a/lib/Transforms/HandshakeMaterialize.cpp b/lib/Transforms/HandshakeMaterialize.cpp index a7b2db10da..37a84721e6 100644 --- a/lib/Transforms/HandshakeMaterialize.cpp +++ b/lib/Transforms/HandshakeMaterialize.cpp @@ -90,7 +90,7 @@ static void materializeValue(Value val, OpBuilder &builder) { return; if (val.use_empty()) { builder.setInsertionPointAfterValue(val); - builder.create(val.getLoc(), val); + handshake::SinkOp::create(builder, val.getLoc(), val); return; } if (val.hasOneUse()) @@ -104,7 +104,7 @@ static void materializeValue(Value val, OpBuilder &builder) { // Insert a fork with as many results as the value has uses builder.setInsertionPointAfterValue(val); - auto forkOp = builder.create(val.getLoc(), val, numUses); + auto forkOp = handshake::ForkOp::create(builder, val.getLoc(), val, numUses); if (Operation *defOp = val.getDefiningOp()) inheritBB(defOp, forkOp); @@ -166,8 +166,8 @@ static void promoteEagerToLazyForks(handshake::FuncOp funcOp) { } builder.setInsertionPoint(forkOp); - handshake::LazyForkOp lazyForkOp = builder.create( - forkOp->getLoc(), forkOp.getOperand(), numLazyForkOutputs); + handshake::LazyForkOp lazyForkOp = handshake::LazyForkOp::create( + builder, forkOp->getLoc(), forkOp.getOperand(), numLazyForkOutputs); inheritBB(forkOp, lazyForkOp); // Replace the original fork's outputs that are part of the memory control @@ -192,8 +192,8 @@ static void promoteEagerToLazyForks(handshake::FuncOp funcOp) { if (!lazyResults.contains(res)) res.replaceAllUsesWith(lazyForkOp->getResults().back()); } else { - handshake::ForkOp eagerForkOp = builder.create( - forkOp->getLoc(), lazyForkOp->getResults().back(), + handshake::ForkOp eagerForkOp = handshake::ForkOp::create( + builder, forkOp->getLoc(), lazyForkOp->getResults().back(), numValuesWithoutLazyConstr); inheritBB(forkOp, eagerForkOp); @@ -245,8 +245,9 @@ struct MinimizeForkSizes : OpRewritePattern { if (!usedForkResults.empty()) { // Create a new fork operation rewriter.setInsertionPoint(forkOp); - handshake::ForkOp newForkOp = rewriter.create( - forkOp.getLoc(), forkOp.getOperand(), usedForkResults.size()); + handshake::ForkOp newForkOp = handshake::ForkOp::create( + rewriter, forkOp.getLoc(), forkOp.getOperand(), + usedForkResults.size()); inheritBB(forkOp, newForkOp); // Replace results with actual uses of the original fork with results from @@ -285,8 +286,8 @@ struct EliminateForksToForks : OpRewritePattern { if (isForkOprdSingleUse) --totalNumResults; rewriter.setInsertionPoint(defForkOp); - handshake::ForkOp newForkOp = rewriter.create( - defForkOp.getLoc(), defForkOp.getOperand(), totalNumResults); + handshake::ForkOp newForkOp = handshake::ForkOp::create( + rewriter, defForkOp.getLoc(), defForkOp.getOperand(), totalNumResults); inheritBB(defForkOp, newForkOp); // Replace the defining fork's results with the first results of the new @@ -367,8 +368,7 @@ struct HandshakeMaterializePass patterns .add( ctx); - if (failed( - applyPatternsAndFoldGreedily(modOp, std::move(patterns), config))) + if (failed(applyPatternsGreedily(modOp, std::move(patterns), config))) return signalPassFailure(); // Finally, promote forks to lazy wherever necessary diff --git a/lib/Transforms/HandshakeMinimizeCstWidth.cpp b/lib/Transforms/HandshakeMinimizeCstWidth.cpp index eb483c0fa9..b6eedc07c4 100644 --- a/lib/Transforms/HandshakeMinimizeCstWidth.cpp +++ b/lib/Transforms/HandshakeMinimizeCstWidth.cpp @@ -89,8 +89,9 @@ static handshake::ExtSIOp insertExtOp(handshake::ConstantOp toExtend, handshake::ConstantOp toReplace, PatternRewriter &rewriter) { rewriter.setInsertionPointAfter(toExtend); - auto extOp = rewriter.create( - toExtend.getLoc(), toReplace.getResult().getType(), toExtend.getResult()); + auto extOp = handshake::ExtSIOp::create(rewriter, toExtend.getLoc(), + toReplace.getResult().getType(), + toExtend.getResult()); inheritBB(toExtend, extOp); return extOp; } @@ -162,8 +163,8 @@ struct MinimizeConstantBitwidth } // Create a new constant to replace the matched one with - auto newCstOp = rewriter.create( - cstOp->getLoc(), newAttr, cstOp.getCtrl()); + auto newCstOp = handshake::ConstantOp::create(rewriter, cstOp->getLoc(), + newAttr, cstOp.getCtrl()); rewriter.replaceOp(cstOp, insertExtOp(newCstOp, cstOp, rewriter)); return success(); } @@ -192,7 +193,7 @@ struct HandshakeMinimizeCstWidthPass config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled); RewritePatternSet patterns{ctx}; patterns.add(optNegatives, ctx); - if (failed(applyPatternsAndFoldGreedily(mod, std::move(patterns), config))) + if (failed(applyPatternsGreedily(mod, std::move(patterns), config))) return signalPassFailure(); LLVM_DEBUG(llvm::dbgs() << "Number of saved bits is " << savedBits << "\n"); diff --git a/lib/Transforms/HandshakeOptimizeBitwidths.cpp b/lib/Transforms/HandshakeOptimizeBitwidths.cpp index 606e1bb414..b2f6977333 100644 --- a/lib/Transforms/HandshakeOptimizeBitwidths.cpp +++ b/lib/Transforms/HandshakeOptimizeBitwidths.cpp @@ -226,12 +226,12 @@ static ChannelVal modBitWidth(ExtValue extVal, unsigned targetWidth, if (ext == ExtType::LOGICAL || (ext == ExtType::UNKNOWN && val.getType().getDataType().isUnsignedInteger())) { - newOp = rewriter.create(loc, dstChannelType, val); + newOp = handshake::ExtUIOp::create(rewriter, loc, dstChannelType, val); } else { - newOp = rewriter.create(loc, dstChannelType, val); + newOp = handshake::ExtSIOp::create(rewriter, loc, dstChannelType, val); } } else { - newOp = rewriter.create(loc, dstChannelType, val); + newOp = handshake::TruncIOp::create(rewriter, loc, dstChannelType, val); } inheritBBFromValue(val, newOp); @@ -334,7 +334,7 @@ static void modArithOp(Op op, ExtValue lhs, ExtValue rhs, unsigned optWidth, Value newRhs = modBitWidth(rhs, optWidth, rewriter); rewriter.setInsertionPoint(op); auto newOp = - rewriter.create(op.getLoc(), newLhs.getType(), newLhs, newRhs); + Op::create(rewriter, op.getLoc(), newLhs.getType(), newLhs, newRhs); Value newRes = modBitWidth({newOp.getResult(), extRes}, resWidth, rewriter); namer.replaceOp(op, newOp); inheritBB(op, newOp); @@ -431,7 +431,7 @@ class OptDataConfig { /// default implementation of this function. virtual Op createOp(ArrayRef newResTypes, ArrayRef newOperands, PatternRewriter &rewriter) { - return rewriter.create(op.getLoc(), newResTypes, newOperands); + return Op::create(rewriter, op.getLoc(), newResTypes, newOperands); } /// Determines the list of values that the original operation will be replaced @@ -545,9 +545,9 @@ class BufferDataConfig : public OptDataConfig { handshake::BufferOp createOp(ArrayRef newResTypes, ArrayRef newOperands, PatternRewriter &rewriter) override { - return rewriter.create( - op.getLoc(), newOperands[0].getType(), newOperands[0], - op->getAttrDictionary().getValue()); + return handshake::BufferOp::create(rewriter, op.getLoc(), + newOperands[0].getType(), newOperands[0], + op->getAttrDictionary().getValue()); } }; @@ -676,9 +676,9 @@ struct HandshakeMuxSelect : public OpRewritePattern { modBitWidth({selectOperand, ExtType::LOGICAL}, optWidth, rewriter)); auto dataOprds = muxOp.getDataOperands(); newOperands.append(dataOprds.begin(), dataOprds.end()); - auto newMuxOp = rewriter.create( - muxOp.getLoc(), muxOp->getResultTypes(), newOperands, - muxOp->getAttrs()); + auto newMuxOp = handshake::MuxOp::create(rewriter, muxOp.getLoc(), + muxOp->getResultTypes(), + newOperands, muxOp->getAttrs()); namer.replaceOp(muxOp, newMuxOp); rewriter.replaceOp(muxOp, newMuxOp); return success(); @@ -718,8 +718,8 @@ struct HandshakeCMergeIndex cmergeOp->getOperandTypes().front(), indexType.withDataType(rewriter.getIntegerType(optWidth))}; rewriter.setInsertionPoint(cmergeOp); - auto newCmergeOp = rewriter.create( - cmergeOp.getLoc(), newResultTypes, cmergeOp.getDataOperands(), + auto newCmergeOp = handshake::ControlMergeOp::create( + rewriter, cmergeOp.getLoc(), newResultTypes, cmergeOp.getDataOperands(), cmergeOp->getAttrs()); namer.replaceOp(cmergeOp, newCmergeOp); Value modIndex = modBitWidth({newCmergeOp.getIndex(), ExtType::LOGICAL}, @@ -1111,8 +1111,8 @@ struct ArithSelect : public OpRewritePattern { Value newLhs = modBitWidth({minLhs, extLhs}, optWidth, rewriter); Value newRhs = modBitWidth({minRhs, extRhs}, optWidth, rewriter); rewriter.setInsertionPoint(selectOp); - auto newOp = rewriter.create( - selectOp.getLoc(), selectOp.getCondition(), newLhs, newRhs); + auto newOp = handshake::SelectOp::create( + rewriter, selectOp.getLoc(), selectOp.getCondition(), newLhs, newRhs); Value newRes = modBitWidth({newOp.getResult(), extLhs}, resWidth, rewriter); inheritBB(selectOp, newOp); namer.replaceOp(selectOp, newOp); @@ -1183,8 +1183,8 @@ struct ArithShift : public OpRewritePattern { Value newShifyBy = modBitWidth({minShiftBy, ExtType::LOGICAL}, optWidth, rewriter); rewriter.setInsertionPoint(op); - auto newOp = rewriter.create(op.getLoc(), newToShift.getType(), - newToShift, newShifyBy); + auto newOp = Op::create(rewriter, op.getLoc(), newToShift.getType(), + newToShift, newShifyBy); ChannelVal newRes = newOp.getResult(); if (isRightShift) // In the case of a right shift, we first truncate the result of the @@ -1248,8 +1248,8 @@ struct ArithCmpFW : public OpRewritePattern { Value newLhs = modBitWidth({minLhs, extLhs}, optWidth, rewriter); Value newRhs = modBitWidth({minRhs, extRhs}, optWidth, rewriter); rewriter.setInsertionPoint(cmpOp); - auto newOp = rewriter.create( - cmpOp.getLoc(), cmpOp.getPredicate(), newLhs, newRhs); + auto newOp = handshake::CmpIOp::create( + rewriter, cmpOp.getLoc(), cmpOp.getPredicate(), newLhs, newRhs); namer.replaceOp(cmpOp, newOp); inheritBB(cmpOp, newOp); @@ -1549,8 +1549,7 @@ struct HandshakeOptimizeBitwidthsPass RewritePatternSet patterns(ctx); patterns.add(getAnalysis(), ctx); - if (failed( - applyPatternsAndFoldGreedily(modOp, std::move(patterns), config))) + if (failed(applyPatternsGreedily(modOp, std::move(patterns), config))) return signalPassFailure(); for (auto funcOp : modOp.getOps()) { @@ -1568,8 +1567,8 @@ struct HandshakeOptimizeBitwidthsPass ops.clear(); llvm::transform(funcOp.getOps(), std::back_inserter(ops), [&](Operation &op) { return &op; }); - return applyOpPatternsAndFold(ops, std::move(patterns), config, - &changed); + return applyOpPatternsGreedily(ops, std::move(patterns), config, + &changed); }; // Apply the forward and backward pass continuously until the IR converges diff --git a/lib/Transforms/PushConstants.cpp b/lib/Transforms/PushConstants.cpp index 3d8085f36e..49620bf1e2 100644 --- a/lib/Transforms/PushConstants.cpp +++ b/lib/Transforms/PushConstants.cpp @@ -33,7 +33,7 @@ static LogicalResult pushConstants(func::FuncOp funcOp, MLIRContext *ctx) { // Determine blocks where the constant is used DenseMap> usingBlocks; for (auto *user : constantOp.getResult().getUsers()) - if (auto block = user->getBlock(); block != defBlock) + if (auto *block = user->getBlock(); block != defBlock) usingBlocks[block].push_back(user); else usedByDefiningBlock = true; @@ -41,9 +41,9 @@ static LogicalResult pushConstants(func::FuncOp funcOp, MLIRContext *ctx) { // Create a new constant operation in every block where the constant is used for (auto &[block, users] : usingBlocks) { builder.setInsertionPointToStart(block); - auto newCstOp = builder.create(constantOp->getLoc(), - constantOp.getValue()); - for (auto user : users) + auto newCstOp = arith::ConstantOp::create(builder, constantOp->getLoc(), + constantOp.getValue()); + for (auto *user : users) user->replaceUsesOfWith(constantOp.getResult(), newCstOp.getResult()); } diff --git a/lib/Transforms/ScfRotateForLoops.cpp b/lib/Transforms/ScfRotateForLoops.cpp index f15b227aa9..449594dd22 100644 --- a/lib/Transforms/ScfRotateForLoops.cpp +++ b/lib/Transforms/ScfRotateForLoops.cpp @@ -44,9 +44,9 @@ struct RotateLoop : public OpRewritePattern { // Create a do-while that is equivalent to the loop ValueRange whileArgsRange(whileOpArgs); - auto whileOp = - rewriter.create(forOp.getLoc(), whileArgsRange.getTypes(), - whileOpArgs, nullptr, nullptr); + auto whileOp = scf::WhileOp::create(rewriter, forOp.getLoc(), + whileArgsRange.getTypes(), whileOpArgs, + nullptr, nullptr); // Move all operations from the for loop body to the "before" region of the // while loop @@ -56,10 +56,12 @@ struct RotateLoop : public OpRewritePattern { // Check the for loop condition at the end of the before block rewriter.setInsertionPointToEnd(&beforeBlock); - auto addOp = rewriter.create( - forOp->getLoc(), beforeBlock.getArguments().front(), forOp.getStep()); - auto cmpOp = rewriter.create( - forOp->getLoc(), pred, addOp.getResult(), forOp.getUpperBound()); + auto addOp = arith::AddIOp::create(rewriter, forOp->getLoc(), + beforeBlock.getArguments().front(), + forOp.getStep()); + auto cmpOp = + arith::CmpIOp::create(rewriter, forOp->getLoc(), pred, + addOp.getResult(), forOp.getUpperBound()); // Get the yield operation that was moved from the for loop body to the // before block @@ -78,7 +80,7 @@ struct RotateLoop : public OpRewritePattern { // the before block Block &afterBlock = whileOp.getAfter().front(); rewriter.setInsertionPointToStart(&afterBlock); - rewriter.create(condOp->getLoc(), afterBlock.getArguments()); + scf::YieldOp::create(rewriter, condOp->getLoc(), afterBlock.getArguments()); // Replace for's results with while's results (drop while's first result, // which is the IV) @@ -135,8 +137,8 @@ struct ScfForLoopRotationPass RewritePatternSet patterns{ctx}; patterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) + if (failed( + applyPatternsGreedily(getOperation(), std::move(patterns), config))) signalPassFailure(); }; }; diff --git a/lib/Transforms/ScfSimpleIfToSelect.cpp b/lib/Transforms/ScfSimpleIfToSelect.cpp index 6ade650f37..a6b02162d1 100644 --- a/lib/Transforms/ScfSimpleIfToSelect.cpp +++ b/lib/Transforms/ScfSimpleIfToSelect.cpp @@ -110,9 +110,8 @@ Value ConvertIfToSelect::hoistSingleArithOp(scf::IfOp ifOp, Operation *arithOp, if (!otherValIsFalse) std::swap(trueVal, falseVal); - return rewriter - .create(ifOp->getLoc(), ifOp.getCondition(), trueVal, - falseVal) + return arith::SelectOp::create(rewriter, ifOp->getLoc(), ifOp.getCondition(), + trueVal, falseVal) .getResult(); }; @@ -121,8 +120,8 @@ Value ConvertIfToSelect::createSelectThenArithOp( Value otherArithVal, bool otherValIsRhs, PatternRewriter &rewriter) const { rewriter.setInsertionPoint(ifOp); - arith::SelectOp selectOp = rewriter.create( - ifOp->getLoc(), ifOp.getCondition(), trueVal, falseVal); + arith::SelectOp selectOp = arith::SelectOp::create( + rewriter, ifOp->getLoc(), ifOp.getCondition(), trueVal, falseVal); Value lhs = selectOp.getResult(); Value rhs = otherArithVal; if (!otherValIsRhs) @@ -225,8 +224,8 @@ Value ConvertIfToSelect::tryToConvert(scf::IfOp ifOp, // If the then block is just a yield too, then the entire if is equivalent to // a select - return rewriter.create(ifOp.getLoc(), ifOp.getCondition(), - thenYielded, elseYielded); + return arith::SelectOp::create(rewriter, ifOp.getLoc(), ifOp.getCondition(), + thenYielded, elseYielded); } namespace { @@ -245,8 +244,8 @@ struct ScfSimpleIfToSelectPass RewritePatternSet patterns{ctx}; patterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) + if (failed( + applyPatternsGreedily(getOperation(), std::move(patterns), config))) signalPassFailure(); }; }; diff --git a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp index c8a66fff32..7add9c6809 100644 --- a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp +++ b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp @@ -172,8 +172,8 @@ void TranslateLLVMToStd::translateFunction(llvm::Function *llvmFunc) { } auto funcType = builder.getFunctionType(argTypes, resTypes); - auto funcOp = builder.create(builder.getUnknownLoc(), - llvmFunc->getName(), funcType); + auto funcOp = func::FuncOp::create(builder, builder.getUnknownLoc(), + llvmFunc->getName(), funcType); initializeBlocksAndBlockMapping(llvmFunc, funcOp); @@ -225,8 +225,9 @@ void TranslateLLVMToStd::translateGlobalVars() { initialValueAttr = convertInitializerToDenseElemAttr(globalVar, ctx); } - auto globalOp = builder.create( + auto globalOp = memref::GlobalOp::create( // clang-format off + builder, UnknownLoc::get(ctx), symNameAttr, visibilityAttr, @@ -275,9 +276,9 @@ void TranslateLLVMToStd::translateInstruction(llvm::Instruction *inst) { } else if (auto *returnOp = dyn_cast(inst)) { if (returnOp->getNumOperands() == 1) { mlir::Value arg = valueMap[inst->getOperand(0)]; - builder.create(loc, arg); + func::ReturnOp::create(builder, loc, arg); } else { - builder.create(loc); + func::ReturnOp::create(builder, loc); } } else if (isa(inst)) { // At this stage, Phi nodes are all converted to the block arguments @@ -350,8 +351,8 @@ void TranslateLLVMToStd::createConstants(llvm::Function *llvmFunc) { if (auto *intConst = dyn_cast(val)) { APInt intVal = intConst->getValue(); - auto constOp = builder.create( - loc, intVal.getSExtValue(), intVal.getBitWidth()); + auto constOp = arith::ConstantIntOp::create( + builder, loc, intVal.getSExtValue(), intVal.getBitWidth()); valueMap[val] = constOp->getResult(0); loc = constOp->getLoc(); } @@ -359,13 +360,13 @@ void TranslateLLVMToStd::createConstants(llvm::Function *llvmFunc) { if (auto *floatConst = dyn_cast(val)) { const APFloat &floatVal = floatConst->getValue(); if (&floatVal.getSemantics() == &llvm::APFloat::IEEEsingle()) { - auto constOp = builder.create( - loc, builder.getF32Type(), floatVal); + auto constOp = arith::ConstantFloatOp::create( + builder, loc, builder.getF32Type(), floatVal); valueMap[val] = constOp->getResult(0); loc = constOp->getLoc(); } else if (&floatVal.getSemantics() == &llvm::APFloat::IEEEdouble()) { - auto constOp = builder.create( - loc, builder.getF64Type(), floatVal); + auto constOp = arith::ConstantFloatOp::create( + builder, loc, builder.getF64Type(), floatVal); valueMap[val] = constOp->getResult(0); loc = constOp->getLoc(); } @@ -389,8 +390,8 @@ void TranslateLLVMToStd::createGetGlobals(llvm::Function *llvmFunc) { auto memrefType = globalOp.getType(); - auto getGlobalOp = builder.create( - loc, memrefType, globalOp.getSymName()); + auto getGlobalOp = memref::GetGlobalOp::create( + builder, loc, memrefType, globalOp.getSymName()); valueMap[val] = getGlobalOp.getResult(); } @@ -477,7 +478,7 @@ void TranslateLLVMToStd::translateICmpInst(llvm::ICmpInst *inst) { } auto op = - builder.create(UnknownLoc::get(ctx), predicate, lhs, rhs); + arith::CmpIOp::create(builder, UnknownLoc::get(ctx), predicate, lhs, rhs); valueMap[inst] = op->getResult(0); } @@ -508,7 +509,7 @@ void TranslateLLVMToStd::translateFCmpInst(llvm::FCmpInst *inst) { // clang-format on } auto op = - builder.create(UnknownLoc::get(ctx), predicate, lhs, rhs); + arith::CmpFOp::create(builder, UnknownLoc::get(ctx), predicate, lhs, rhs); valueMap[inst] = op->getResult(0); } @@ -531,8 +532,8 @@ void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { // NOTE: memref::LoadOp and memref::StoreOp expect their indices to be of // IndexType. Therefore, we cast the i32/i64 indices to IndexType. This // pattern will later be folded in the bitwidth optimization pass. - auto idxCastOp = builder.create( - UnknownLoc::get(ctx), builder.getIndexType(), mlirIndexValue); + auto idxCastOp = arith::IndexCastOp::create( + builder, UnknownLoc::get(ctx), builder.getIndexType(), mlirIndexValue); indexOperands.push_back(idxCastOp); } @@ -571,10 +572,10 @@ void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { } for (int i = 0; i < remainingConstZeros; i++) { - auto constZeroOp = builder.create( - UnknownLoc::get(ctx), builder.getI64IntegerAttr(0)); - auto idxCastOp = builder.create( - UnknownLoc::get(ctx), builder.getIndexType(), constZeroOp); + auto constZeroOp = arith::ConstantOp::create(builder, UnknownLoc::get(ctx), + builder.getI64IntegerAttr(0)); + auto idxCastOp = arith::IndexCastOp::create( + builder, UnknownLoc::get(ctx), builder.getIndexType(), constZeroOp); indexOperands.push_back(idxCastOp); } @@ -589,8 +590,9 @@ void TranslateLLVMToStd::translateBranchInst(llvm::BranchInst *inst) { BasicBlock *nextLLVMBB = dyn_cast_or_null(inst->getOperand(0)); assert(nextLLVMBB && "The unconditional branch doesn't have a BB as operand!"); - builder.create( + cf::BranchOp::create( // clang-format off + builder, loc, blockMap[nextLLVMBB], getBranchOperandsForCFGEdge(currLLVMBB, nextLLVMBB) @@ -608,8 +610,9 @@ void TranslateLLVMToStd::translateBranchInst(llvm::BranchInst *inst) { SmallVector trueOperands = getBranchOperandsForCFGEdge(currLLVMBB, trueDestBB); mlir::Value condition = valueMap[inst->getCondition()]; - builder.create( + cf::CondBranchOp::create( // clang-format off + builder, loc, condition, blockMap[trueDestBB], @@ -644,13 +647,14 @@ void TranslateLLVMToStd::translateLoadInst(llvm::LoadInst *loadInst) { auto memrefType = dyn_cast(memref.getType()); int constZerosToAdd = memrefType.getShape().size(); for (int i = 0; i < constZerosToAdd; i++) { - auto constZeroOp = this->builder.create( - loc, this->builder.getIndexAttr(0)); + auto constZeroOp = arith::ConstantOp::create( + this->builder, loc, this->builder.getIndexAttr(0)); indices.push_back(constZeroOp); } } mlir::Type resType = getMLIRType(loadInst->getType(), ctx); - auto newOp = builder.create(loc, resType, memref, indices); + auto newOp = + memref::LoadOp::create(this->builder, loc, resType, memref, indices); valueMap[loadInst] = newOp.getResult(); translateMemDepAndNameAttrs(loadInst, newOp, *ctx, builder); } @@ -681,15 +685,15 @@ void TranslateLLVMToStd::translateStoreInst(llvm::StoreInst *storeInst) { int constZerosToAdd = memrefType.getShape().size(); for (int i = 0; i < constZerosToAdd; i++) { - auto constZeroOp = this->builder.create( - loc, this->builder.getIndexAttr(0)); + auto constZeroOp = arith::ConstantOp::create( + this->builder, loc, this->builder.getIndexAttr(0)); indices.push_back(constZeroOp); } } mlir::Value storeValue = valueMap[storeInst->getValueOperand()]; auto newOp = - builder.create(loc, storeValue, memref, indices); + memref::StoreOp::create(this->builder, loc, storeValue, memref, indices); translateMemDepAndNameAttrs(storeInst, newOp, *ctx, builder); } @@ -706,7 +710,7 @@ void TranslateLLVMToStd::translateAllocaInst(llvm::AllocaInst *allocaInst) { auto memrefType = MemRefType::get(shape, getMLIRType(baseElementType, ctx)); - auto allocaOp = builder.create(loc, memrefType); + auto allocaOp = memref::AllocaOp::create(builder, loc, memrefType); valueMap[allocaInst] = allocaOp->getResult(0); } From 1ba0c62fe12b5eab4d7615bf0df82edd0c2c51da Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Thu, 20 Nov 2025 16:14:37 +0100 Subject: [PATCH 07/27] Fix calling replaceUsesWith1 multiple times --- .../lib/Conversion/FtdCfToHandshake.cpp | 3 +- include/dynamatic/Conversion/CfToHandshake.h | 28 ++++-- .../CfToHandshake/CfToHandshake.cpp | 99 ++++++++++--------- 3 files changed, 74 insertions(+), 56 deletions(-) diff --git a/experimental/lib/Conversion/FtdCfToHandshake.cpp b/experimental/lib/Conversion/FtdCfToHandshake.cpp index 2b1057f24a..461c05c01c 100644 --- a/experimental/lib/Conversion/FtdCfToHandshake.cpp +++ b/experimental/lib/Conversion/FtdCfToHandshake.cpp @@ -310,7 +310,8 @@ LogicalResult ftd::FtdLowerFuncToHandshake::matchAndRewrite( // Create the memory interface according to the algorithm from FPGA'23. This // functions introduce new data dependencies that are then passed to FTD for // correctly delivering data between them like any real data dependencies - if (failed(verifyAndCreateMemInterfaces(funcOp, rewriter, memInfo))) + if (failed(verifyAndCreateMemInterfaces(funcOp, rewriter, memInfo, + argReplacements))) return failure(); // Convert the constants and undefined values from the `arith` dialect to diff --git a/include/dynamatic/Conversion/CfToHandshake.h b/include/dynamatic/Conversion/CfToHandshake.h index ab95e9ad84..807289d278 100644 --- a/include/dynamatic/Conversion/CfToHandshake.h +++ b/include/dynamatic/Conversion/CfToHandshake.h @@ -35,6 +35,10 @@ class CfToHandshakeTypeConverter : public TypeConverter { CfToHandshakeTypeConverter(); }; +/// CfToHandshake replaces MLIR block diagrams with merge-like ops. +/// This data type records this mapping. +using BlockArgToMergeResult = DenseMap; + /// Converts a func-level function into a handshake-level function. The function /// signature gets an extra control-only argument to represent the starting /// point of the control network. If the function did not return any result, a @@ -110,12 +114,16 @@ class LowerFuncToHandshake : public DynOpConversionPattern { /// use which interface. The backedge builder is used to create temporary /// values for the data input to converted load ports. A flag for FTD is also /// introduced to tweak the rewriting process. - virtual LogicalResult - convertMemoryOps(handshake::FuncOp funcOp, - ConversionPatternRewriter &rewriter, - const DenseMap &memrefToFuncArgIndex, - BackedgeBuilder &edgeBuilder, MemInterfacesInfo &memInfo, - bool isFtd = false) const; + virtual LogicalResult convertMemoryOps( + // clang-format off + handshake::FuncOp funcOp, + ConversionPatternRewriter &rewriter, + const DenseMap &memrefToFuncArgIndex, + BackedgeBuilder &edgeBuilder, + MemInterfacesInfo &memInfo, + bool isFtd = false + // clang-format on + ) const; /// Verifies that LSQ groups derived from input IR annotations make sense /// (check for linear dominance property within each group and cross-group @@ -129,10 +137,10 @@ class LowerFuncToHandshake : public DynOpConversionPattern { /// - Both a `handhsake::MemoryControllerOp` and `handhsake::LSQOp` will be /// instantiated if some but not all of its accesses indicate that they should /// connect to an LSQ. - virtual LogicalResult - verifyAndCreateMemInterfaces(handshake::FuncOp funcOp, - ConversionPatternRewriter &rewriter, - MemInterfacesInfo &memInfo) const; + virtual LogicalResult verifyAndCreateMemInterfaces( + handshake::FuncOp funcOp, ConversionPatternRewriter &rewriter, + MemInterfacesInfo &memInfo, + const BlockArgToMergeResult &argReplacements) const; /// Sets an integer "bb" attribute on each operation to identify the basic /// block from which the operation originates in the std-level IR. diff --git a/lib/Conversion/CfToHandshake/CfToHandshake.cpp b/lib/Conversion/CfToHandshake/CfToHandshake.cpp index a4a6360e38..aca60de441 100644 --- a/lib/Conversion/CfToHandshake/CfToHandshake.cpp +++ b/lib/Conversion/CfToHandshake/CfToHandshake.cpp @@ -48,7 +48,6 @@ #include "mlir/IR/Value.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" @@ -165,6 +164,27 @@ LogicalResult LowerFuncToHandshake::computeLinearDominance( return success(); } +/// This function finds the control merge result of a given Block. We +/// need to use this function to connect the control network to the memory +/// interface ops (mc, lsq). +/// +/// NOTE: this works only after we called addMergeOps +static Value +getBlockControlMergeOrFuncControl(const BlockArgToMergeResult &argReplacements, + Block *block) { + Operation *parentOp = block->getParentOp(); + + auto funcOp = cast(parentOp); + + assert(funcOp && "The parent Op of this block must be a handshake::FuncOp"); + + // Check if the block is the first block: + if (&(funcOp.getBlocks().front()) == block) { + return block->getArguments().back(); + } + return argReplacements.at(block->getArguments().back()); +} + //===-----------------------------------------------------------------------==// // CfToHandshakeTypeConverter //===-----------------------------------------------------------------------==// @@ -210,8 +230,6 @@ CfToHandshakeTypeConverter::CfToHandshakeTypeConverter() { // LowerFuncToHandshake //===-----------------------------------------------------------------------==// -using ArgReplacements = DenseMap; - LogicalResult LowerFuncToHandshake::matchAndRewrite( func::FuncOp lowerFuncOp, OpAdaptor /*adaptor*/, ConversionPatternRewriter &rewriter) const { @@ -235,7 +253,7 @@ LogicalResult LowerFuncToHandshake::matchAndRewrite( // Stores mapping from each value that passes through a merge-like operation // to the data result of that merge operation - ArgReplacements argReplacements; + BlockArgToMergeResult argReplacements; addMergeOps(funcOp, rewriter, argReplacements); addBranchOps(funcOp, rewriter); @@ -249,7 +267,8 @@ LogicalResult LowerFuncToHandshake::matchAndRewrite( // tagged with the BB they belong to (required by memory interface // instantiation logic) idBasicBlocks(funcOp, rewriter); - if (failed(verifyAndCreateMemInterfaces(funcOp, rewriter, memInfo))) + if (failed(verifyAndCreateMemInterfaces(funcOp, rewriter, memInfo, + argReplacements))) return failure(); idBasicBlocks(funcOp, rewriter); @@ -398,8 +417,6 @@ FailureOr LowerFuncToHandshake::lowerSignature( Region *oldBody = &funcOp.getBody(); - const TypeConverter *typeConv = getTypeConverter(); - // Convert the entry block's signature Block *entryBlock = &funcOp.getBody().front(); TypeConverter::SignatureConversion entryConversion( @@ -414,6 +431,9 @@ FailureOr LowerFuncToHandshake::lowerSignature( /*numOrigInputs=*/nonEntryBlock.getNumArguments()); setupBlockConversion(&nonEntryBlock, rewriter, nonEntryConversion); + + SmallVector origArg; + rewriter.applySignatureConversion(&nonEntryBlock, nonEntryConversion, getTypeConverter()); } @@ -575,9 +595,9 @@ void LowerFuncToHandshake::insertMerge(BlockArgument blockArg, } } -void LowerFuncToHandshake::addMergeOps(handshake::FuncOp funcOp, - ConversionPatternRewriter &rewriter, - ArgReplacements &argReplacements) const { +void LowerFuncToHandshake::addMergeOps( + handshake::FuncOp funcOp, ConversionPatternRewriter &rewriter, + BlockArgToMergeResult &argReplacements) const { // Create backedge builder to manage operands of merge operations between // insertion and reconnection BackedgeBuilder edgeBuilder(rewriter, funcOp.getLoc()); @@ -586,8 +606,6 @@ void LowerFuncToHandshake::addMergeOps(handshake::FuncOp funcOp, // instead as data operands) DenseMap> blockMerges; - Block *entryBlock = &funcOp.getBody().front(); - for (Block &block : llvm::drop_begin(funcOp)) { rewriter.setInsertionPointToStart(&block); @@ -894,7 +912,8 @@ LogicalResult LowerFuncToHandshake::convertMemoryOps( LogicalResult LowerFuncToHandshake::verifyAndCreateMemInterfaces( handshake::FuncOp funcOp, ConversionPatternRewriter &rewriter, - MemInterfacesInfo &memInfo) const { + MemInterfacesInfo &memInfo, + const BlockArgToMergeResult &argReplacements) const { if (memInfo.empty()) return success(); @@ -921,7 +940,10 @@ LogicalResult LowerFuncToHandshake::verifyAndCreateMemInterfaces( auto returns = funcOp.getOps(); assert(!returns.empty() && "no returns in function"); if (std::distance(returns.begin(), returns.end()) == 1) { - ctrlEnd = getBlockControl((*returns.begin())->getBlock()); + Value lastBlockCtrlArg = getBlockControlMergeOrFuncControl( + argReplacements, (*returns.begin())->getBlock()); + + ctrlEnd = lastBlockCtrlArg; } else { // Merge the control signals of all blocks with a return to create a control // representing the final control flow decision @@ -929,7 +951,8 @@ LogicalResult LowerFuncToHandshake::verifyAndCreateMemInterfaces( func::ReturnOp lastRetOp; for (func::ReturnOp retOp : returns) { lastRetOp = retOp; - controls.push_back(getBlockControl(retOp->getBlock())); + controls.push_back(getBlockControlMergeOrFuncControl(argReplacements, + retOp->getBlock())); } rewriter.setInsertionPointToStart(lastRetOp->getBlock()); auto mergeOp = @@ -945,8 +968,11 @@ LogicalResult LowerFuncToHandshake::verifyAndCreateMemInterfaces( // Create a mapping between each block and its control value in the right // format for the memory interface builder DenseMap ctrlVals; - for (auto [blockIdx, block] : llvm::enumerate(funcOp)) - ctrlVals.insert({blockIdx, getBlockControl(&block)}); + for (auto [blockIdx, block] : llvm::enumerate(funcOp)) { + Value blockCtrlVal; + blockCtrlVal = getBlockControlMergeOrFuncControl(argReplacements, &block); + ctrlVals.insert({blockIdx, blockCtrlVal}); + } // Each memory region is independent from the others for (auto &[memref, memAccesses] : memInfo) { @@ -1017,7 +1043,7 @@ void LowerFuncToHandshake::idBasicBlocks( LogicalResult LowerFuncToHandshake::flattenAndTerminate( handshake::FuncOp funcOp, ConversionPatternRewriter &rewriter, - const ArgReplacements &argReplacements) const { + const BlockArgToMergeResult &argReplacements) const { // Erase all cf-level terminators, accumulating operands to func-level returns // as we go @@ -1043,21 +1069,18 @@ LogicalResult LowerFuncToHandshake::flattenAndTerminate( // Inline all non-entry blocks into the entry block, erasing them as we go Operation *lastOp = &funcOp.front().back(); - for (Block &block : llvm::make_early_inc_range(funcOp)) { - if (block.isEntryBlock()) - continue; - + for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(funcOp))) { // Replace all block arguments with the data result of merge-like // operations; this effectively connects all merges to the rest of the // circuit - SmallVector replacements; + SmallVector mergeResults; for (BlockArgument blockArg : block.getArguments()) { - Value mergeRes = argReplacements.at(blockArg); - // Replacing BA with merge results - rewriter.replaceAllUsesWith(blockArg, mergeRes); + mergeResults.push_back(argReplacements.at(blockArg)); } - // Replacing the block arguments with merge results - // rewriter.inlineBlockBefore(&block, lastOp, replacements); + // This call does two things: + // - Move all the ops into the target block of lastOp. + // - Replacing the block arguments with merge results (replacements). + rewriter.inlineBlockBefore(&block, lastOp, mergeResults); } // The terminator's operands are, in order @@ -1089,10 +1112,12 @@ LogicalResult LowerFuncToHandshake::flattenAndTerminate( } } } - endOprds.push_back(getBlockControl(funcOp.getBodyBlock())); + endOprds.push_back(getBlockControlMergeOrFuncControl(argReplacements, + funcOp.getBodyBlock())); auto endOp = handshake::EndOp::create(rewriter, lastOp->getLoc(), endOprds); endOp->setAttr(BB_ATTR_NAME, rewriter.getUI32IntegerAttr(exitBlockID)); + return success(); } @@ -1714,22 +1739,6 @@ struct CfToHandshakePass if (failed(applyFullConversion(modOp, target, std::move(patterns)))) return signalPassFailure(); - for (auto funcOp : modOp.getOps()) { - Block *firstBlock = &funcOp.getBlocks().front(); - - auto *endOpIt = firstBlock->getTerminator(); - - for (Block &otherBlock : - llvm::make_early_inc_range(llvm::drop_begin(funcOp))) { - for (Operation &op : llvm::make_early_inc_range(otherBlock)) { - op.moveBefore(endOpIt); - } - otherBlock.erase(); - } - } - - modOp.dump(); - // Clean up: Remove the definition of each __init* function, but only if it // has no remaining uses. This is safe because all valid calls to __init* // were tracked and deleted earlier. From 78ce39226069282e1f168f4e83ba4e4adf5ea619 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Thu, 20 Nov 2025 16:21:29 +0100 Subject: [PATCH 08/27] Handle a chain of GEPs in translate-llvm-to-std --- .../TranslateLLVMToStd.cpp | 85 ++++++++++++------- 1 file changed, 56 insertions(+), 29 deletions(-) diff --git a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp index 7add9c6809..72bbe049b1 100644 --- a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp +++ b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp @@ -521,7 +521,19 @@ void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { mlir::Value baseAddress = valueMap[gepInst->getPointerOperand()]; SmallVector indexOperands; - auto memrefType = dyn_cast(baseAddress.getType()); + MemRefType memrefType; + + if (!baseAddress) { + // This is called when there is a chain of GEPs: GEP -> GEP -> ... -> many + // loads/stores + auto memrefAndIndices = + this->gepInstToMemRefAndIndicesMap[gepInst->getPointerOperand()]; + baseAddress = std::get<0>(memrefAndIndices); + memrefType = dyn_cast(baseAddress.getType()); + indexOperands = std::get<1>(memrefAndIndices); + } else { + memrefType = dyn_cast(baseAddress.getType()); + } if (!memrefType) llvm_unreachable("GEP should take memref as reference"); @@ -551,34 +563,6 @@ void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { indexOperands.erase(indexOperands.begin()); } - // NOTE: GEPOp has the following syntax (some details omitted): - // GEPOp %basePtr, %firstDim, %secondDim, %thirdDim, ... - // When you iterate through the indices, it also returns indices from left - // to right. However, the following two syntaxes are equivalent in LLVM: - // - (1) GEPop %basePtr, %firstDim, 0, 0 - // - (2) GEPop %basePtr, %firstDim - // Notice that, in the second example, the trailing constant 0s are omitted. - // Source: - // https://llvm.org/docs/GetElementPtr.html#why-do-gep-x-1-0-0-and-gep-x-1-alias - // - // However, memref::LoadOp and memref::StoreOp must have their indices - // match the memref. So here we need to fill in the constant zeros. - int remainingConstZeros = memrefType.getShape().size() - indexOperands.size(); - - if (remainingConstZeros < 0) { - llvm_unreachable( - "GEP should only omit indices, but shouldn't have more indices than " - "the original memref type extracted from the function argument!"); - } - - for (int i = 0; i < remainingConstZeros; i++) { - auto constZeroOp = arith::ConstantOp::create(builder, UnknownLoc::get(ctx), - builder.getI64IntegerAttr(0)); - auto idxCastOp = arith::IndexCastOp::create( - builder, UnknownLoc::get(ctx), builder.getIndexType(), constZeroOp); - indexOperands.push_back(idxCastOp); - } - this->gepInstToMemRefAndIndicesMap[gepInst] = MemRefAndIndices(baseAddress, indexOperands); } @@ -624,6 +608,39 @@ void TranslateLLVMToStd::translateBranchInst(llvm::BranchInst *inst) { } } +static void fillConstantZeroes(MemRefType memrefType, + SmallVector &indices, + OpBuilder &builder, MLIRContext *ctx) { + + // NOTE: GEPOp has the following syntax (some details omitted): + // GEPOp %basePtr, %firstDim, %secondDim, %thirdDim, ... + // When you iterate through the indices, it also returns indices from left + // to right. However, the following two syntaxes are equivalent in LLVM: + // - (1) GEPop %basePtr, %firstDim, 0, 0 + // - (2) GEPop %basePtr, %firstDim + // Notice that, in the second example, the trailing constant 0s are omitted. + // Source: + // https://llvm.org/docs/GetElementPtr.html#why-do-gep-x-1-0-0-and-gep-x-1-alias + // + // However, memref::LoadOp and memref::StoreOp must have their indices + // match the memref. So here we need to fill in the constant zeros. + int remainingConstZeros = memrefType.getShape().size() - indices.size(); + + if (remainingConstZeros < 0) { + llvm_unreachable( + "GEP should only omit indices, but shouldn't have more indices than " + "the original memref type extracted from the function argument!"); + } + + for (int i = 0; i < remainingConstZeros; i++) { + auto constZeroOp = arith::ConstantOp::create(builder, UnknownLoc::get(ctx), + builder.getI64IntegerAttr(0)); + auto idxCastOp = arith::IndexCastOp::create( + builder, UnknownLoc::get(ctx), builder.getIndexType(), constZeroOp); + indices.push_back(idxCastOp); + } +} + void TranslateLLVMToStd::translateLoadInst(llvm::LoadInst *loadInst) { Location loc = UnknownLoc::get(ctx); auto *instAddr = loadInst->getPointerOperand(); @@ -637,6 +654,11 @@ void TranslateLLVMToStd::translateLoadInst(llvm::LoadInst *loadInst) { auto memrefAndIndices = this->gepInstToMemRefAndIndicesMap[instAddr]; memref = memrefAndIndices.first; indices = memrefAndIndices.second; + if (auto memrefType = dyn_cast(memref.getType())) { + fillConstantZeroes(memrefType, indices, builder, ctx); + } else { + llvm::report_fatal_error("The pointer operand is not a memref type!"); + } } else { if (isa(instAddr)) llvm_unreachable( @@ -674,6 +696,11 @@ void TranslateLLVMToStd::translateStoreInst(llvm::StoreInst *storeInst) { auto memrefAndIndices = this->gepInstToMemRefAndIndicesMap[instAddr]; memref = memrefAndIndices.first; indices = memrefAndIndices.second; + if (auto memrefType = dyn_cast(memref.getType())) { + fillConstantZeroes(memrefType, indices, builder, ctx); + } else { + llvm::report_fatal_error("The pointer operand is not a memref type!"); + } } else { // NOTE: This condition handles a special case where a load only has // constant indices, e.g., tmp = mat[0][0]. From 1279c074cb579d9d197658b6de01d40423f81aae Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Fri, 21 Nov 2025 18:30:17 +0100 Subject: [PATCH 09/27] suppress the no-verify-fixpoint warning --- tools/dynamatic/scripts/compile.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/dynamatic/scripts/compile.sh b/tools/dynamatic/scripts/compile.sh index 2a2bd0cb70..8f722763f7 100755 --- a/tools/dynamatic/scripts/compile.sh +++ b/tools/dynamatic/scripts/compile.sh @@ -203,6 +203,7 @@ $LLVM_TO_STD_TRANSLATION_BIN \ exit_on_fail "Failed to convert to std dialect" \ "Converted to std dialect" + # cf transformations (dynamatic) # - "drop-unlist-functions": Dropping the functions that are not needed in HLS # compilation. From c25a45d40ca9afd5d514b915274cc0c8477ea4e5 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Fri, 21 Nov 2025 18:30:45 +0100 Subject: [PATCH 10/27] flatten the MD array in llvm to std conversion --- tools/translate-llvm-to-std/InferArgTypes.cpp | 15 +- tools/translate-llvm-to-std/InferArgTypes.h | 2 +- .../TranslateLLVMToStd.cpp | 238 ++++++++---------- .../TranslateLLVMToStd.h | 2 + 4 files changed, 115 insertions(+), 142 deletions(-) diff --git a/tools/translate-llvm-to-std/InferArgTypes.cpp b/tools/translate-llvm-to-std/InferArgTypes.cpp index 5a71c0adcd..7aede9b7a8 100644 --- a/tools/translate-llvm-to-std/InferArgTypes.cpp +++ b/tools/translate-llvm-to-std/InferArgTypes.cpp @@ -20,7 +20,7 @@ using namespace mlir; using namespace dynamatic; -Type ArgType::getMlirType(OpBuilder &builder) const { +Type ArgType::getMlirType(OpBuilder &builder, bool flattenArray) const { Type baseMLIRElemType; if (std::holds_alternative(baseElemType)) { @@ -50,6 +50,17 @@ Type ArgType::getMlirType(OpBuilder &builder) const { if (arrayDimensions.empty()) { return baseMLIRElemType; } + + // Instead of returning memref<8 * 8 * i32> for A[8][8], we just return a + // flattened version memref<64 * i32> + if (flattenArray) { + int64_t flattenedSize = 1; + for (auto dim : arrayDimensions) { + flattenedSize *= dim; + } + return MemRefType::get(/* shape = */ {flattenedSize}, baseMLIRElemType); + } + return MemRefType::get(llvm::ArrayRef(arrayDimensions), baseMLIRElemType); } @@ -267,7 +278,7 @@ SmallVector getFuncArgTypes(const std::string &funcName, OpBuilder &builder) { SmallVector mlirArgTypes; for (const ArgType &clangType : map.at(funcName)) { - mlirArgTypes.push_back(clangType.getMlirType(builder)); + mlirArgTypes.push_back(clangType.getMlirType(builder, true)); } return mlirArgTypes; diff --git a/tools/translate-llvm-to-std/InferArgTypes.h b/tools/translate-llvm-to-std/InferArgTypes.h index 005e8e16d6..f5a9341e14 100644 --- a/tools/translate-llvm-to-std/InferArgTypes.h +++ b/tools/translate-llvm-to-std/InferArgTypes.h @@ -71,7 +71,7 @@ struct ArgType { std::vector arrayDimensions; bool isPassedByReference; - mlir::Type getMlirType(OpBuilder &builder) const; + mlir::Type getMlirType(OpBuilder &builder, bool flattenArray) const; }; using CFuncArgs = SmallVector; diff --git a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp index 72bbe049b1..6bf75d14e7 100644 --- a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp +++ b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp @@ -203,13 +203,14 @@ void TranslateLLVMToStd::translateGlobalVars() { continue; auto *baseElemType = globalVar->getValueType(); - SmallVector shape; + int64_t numElements = 1; while (baseElemType->isArrayTy()) { - shape.push_back(baseElemType->getArrayNumElements()); + numElements *= baseElemType->getArrayNumElements(); baseElemType = baseElemType->getArrayElementType(); } auto baseMLIRElemType = getMLIRType(baseElemType, ctx); - auto memrefType = MemRefType::get(shape, baseMLIRElemType); + auto memrefType = + MemRefType::get(/* shape = */ {numElements}, baseMLIRElemType); StringRef symName = constant.getName(); StringAttr symNameAttr = StringAttr::get(ctx, Twine(symName)); @@ -514,41 +515,20 @@ void TranslateLLVMToStd::translateFCmpInst(llvm::FCmpInst *inst) { } void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { - // NOTE: this function does not create any corresponding op in the CF MLIR but - // only computes the indices for the LOAD/STORE ops that the gepInst drives. - - // Check if the GEP is not chained - mlir::Value baseAddress = valueMap[gepInst->getPointerOperand()]; - SmallVector indexOperands; - - MemRefType memrefType; - - if (!baseAddress) { - // This is called when there is a chain of GEPs: GEP -> GEP -> ... -> many - // loads/stores - auto memrefAndIndices = - this->gepInstToMemRefAndIndicesMap[gepInst->getPointerOperand()]; - baseAddress = std::get<0>(memrefAndIndices); - memrefType = dyn_cast(baseAddress.getType()); - indexOperands = std::get<1>(memrefAndIndices); - } else { - memrefType = dyn_cast(baseAddress.getType()); - } - if (!memrefType) - llvm_unreachable("GEP should take memref as reference"); + // Convert the GEP instruction into a series of "idx * dim + idx * dim ..." + + llvm::Type *baseElementType = gepInst->getSourceElementType(); - for (auto &indexUse : gepInst->indices()) { - llvm::Value *indexValue = indexUse; - mlir::Value mlirIndexValue = valueMap[indexValue]; - // NOTE: memref::LoadOp and memref::StoreOp expect their indices to be of - // IndexType. Therefore, we cast the i32/i64 indices to IndexType. This - // pattern will later be folded in the bitwidth optimization pass. - auto idxCastOp = arith::IndexCastOp::create( - builder, UnknownLoc::get(ctx), builder.getIndexType(), mlirIndexValue); - indexOperands.push_back(idxCastOp); + // Get index calculation: + SmallVector multipliers; + while (baseElementType->isArrayTy()) { + multipliers.push_back(baseElementType->getArrayNumElements()); + baseElementType = baseElementType->getArrayElementType(); } + SmallVector gepIndices(gepInst->indices()); + if (auto *defInst = gepInst->getPointerOperand(); isa_and_nonnull(defInst) || isa_and_nonnull(defInst)) { @@ -560,11 +540,69 @@ void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { // https://llvm.org/docs/GetElementPtr.html#why-is-the-extra-0-index-required // // Therefore, we drop the first element in this case - indexOperands.erase(indexOperands.begin()); + gepIndices.erase(gepIndices.begin()); } - this->gepInstToMemRefAndIndicesMap[gepInst] = - MemRefAndIndices(baseAddress, indexOperands); + SmallVector multipliedIndices; + + for (size_t i = 0; i < gepIndices.size(); ++i) { + mlir::Value mlirIndexValue = valueMap[gepIndices[i]]; + + // Calculate the partitial index + int64_t coeff = 1; + for (size_t j = i; j < multipliers.size(); j++) + coeff *= multipliers[j]; + + if (coeff == 1) { + multipliedIndices.push_back(mlirIndexValue); + } else if (llvm::isPowerOf2_64(coeff)) { + auto shiftValue = arith::ConstantOp::create( + builder, UnknownLoc::get(ctx), + builder.getIntegerAttr(builder.getI64Type(), llvm::Log2_64(coeff))); + auto idx = arith::ShLIOp::create(builder, UnknownLoc::get(ctx), + mlirIndexValue, shiftValue); + multipliedIndices.push_back(idx); + } else { + auto multipliedValue = arith::ConstantOp::create( + builder, UnknownLoc::get(ctx), + builder.getIntegerAttr(builder.getI64Type(), coeff)); + auto idx = arith::MulIOp::create(builder, UnknownLoc::get(ctx), + mlirIndexValue, multipliedValue); + multipliedIndices.push_back(idx); + } + } + + // If we do not start from a memref type, then it must be from a chain of + // GEPs. Here we accumulate our result onto that. + if (auto pointerOperand = valueMap[gepInst->getPointerOperand()]; + pointerOperand && !isa(pointerOperand.getType())) { + multipliedIndices.push_back(valueMap[gepInst->getPointerOperand()]); + } + + // [START accumulate the array index] + // Build balanced tree + std::function)> build = + [&](ArrayRef vals) -> mlir::Value { + assert(vals.size() > 0); + if (vals.size() == 1) + return vals[0]; + auto mid = vals.size() / 2; + auto lhs = build(vals.take_front(mid)); + auto rhs = build(vals.drop_front(mid)); + return arith::AddIOp::create(builder, UnknownLoc::get(ctx), lhs, rhs); + }; + mlir::Value accumulatedArrayIndex = build(multipliedIndices); + // [END accumulate the array index] + + valueMap[gepInst] = accumulatedArrayIndex; + + if (this->getInstToMemRefMap.count(gepInst->getPointerOperand())) { + this->getInstToMemRefMap[gepInst] = + this->getInstToMemRefMap[gepInst->getPointerOperand()]; + } else { + mlir::Value baseAddress = valueMap[gepInst->getPointerOperand()]; + this->getInstToMemRefMap[gepInst] = baseAddress; + } } void TranslateLLVMToStd::translateBranchInst(llvm::BranchInst *inst) { @@ -608,75 +646,20 @@ void TranslateLLVMToStd::translateBranchInst(llvm::BranchInst *inst) { } } -static void fillConstantZeroes(MemRefType memrefType, - SmallVector &indices, - OpBuilder &builder, MLIRContext *ctx) { - - // NOTE: GEPOp has the following syntax (some details omitted): - // GEPOp %basePtr, %firstDim, %secondDim, %thirdDim, ... - // When you iterate through the indices, it also returns indices from left - // to right. However, the following two syntaxes are equivalent in LLVM: - // - (1) GEPop %basePtr, %firstDim, 0, 0 - // - (2) GEPop %basePtr, %firstDim - // Notice that, in the second example, the trailing constant 0s are omitted. - // Source: - // https://llvm.org/docs/GetElementPtr.html#why-do-gep-x-1-0-0-and-gep-x-1-alias - // - // However, memref::LoadOp and memref::StoreOp must have their indices - // match the memref. So here we need to fill in the constant zeros. - int remainingConstZeros = memrefType.getShape().size() - indices.size(); - - if (remainingConstZeros < 0) { - llvm_unreachable( - "GEP should only omit indices, but shouldn't have more indices than " - "the original memref type extracted from the function argument!"); - } - - for (int i = 0; i < remainingConstZeros; i++) { - auto constZeroOp = arith::ConstantOp::create(builder, UnknownLoc::get(ctx), - builder.getI64IntegerAttr(0)); - auto idxCastOp = arith::IndexCastOp::create( - builder, UnknownLoc::get(ctx), builder.getIndexType(), constZeroOp); - indices.push_back(idxCastOp); - } -} - void TranslateLLVMToStd::translateLoadInst(llvm::LoadInst *loadInst) { Location loc = UnknownLoc::get(ctx); auto *instAddr = loadInst->getPointerOperand(); - mlir::Value memref; - SmallVector indices; - if (this->gepInstToMemRefAndIndicesMap.count(instAddr)) { - // Logic: In LLVM IR, load/store operations take the pointer computed from - // GEP ops, whereas in memref, load operations takes indices (of index - // type). This function uses the index operand collected when processing - // GEPs as operands of LOAD/STOREs. - auto memrefAndIndices = this->gepInstToMemRefAndIndicesMap[instAddr]; - memref = memrefAndIndices.first; - indices = memrefAndIndices.second; - if (auto memrefType = dyn_cast(memref.getType())) { - fillConstantZeroes(memrefType, indices, builder, ctx); - } else { - llvm::report_fatal_error("The pointer operand is not a memref type!"); - } - } else { - if (isa(instAddr)) - llvm_unreachable( - "Converting a load but the producer hasn't been converted yet!"); - // NOTE: This condition handles a special case where a load only has - // constant indices, e.g., tmp = mat[0][0]. - memref = this->valueMap[instAddr]; - auto memrefType = dyn_cast(memref.getType()); - int constZerosToAdd = memrefType.getShape().size(); - for (int i = 0; i < constZerosToAdd; i++) { - auto constZeroOp = arith::ConstantOp::create( - this->builder, loc, this->builder.getIndexAttr(0)); - indices.push_back(constZeroOp); - } - } + mlir::Value memref = getInstToMemRefMap[instAddr]; + + mlir::Value index = valueMap[loadInst->getPointerOperand()]; mlir::Type resType = getMLIRType(loadInst->getType(), ctx); - auto newOp = - memref::LoadOp::create(this->builder, loc, resType, memref, indices); + + // LoadOp needs the index operand to be of index type + auto indexOp = arith::IndexCastOp::create(builder, UnknownLoc::get(ctx), + builder.getIndexType(), index); + + auto newOp = memref::LoadOp::create(this->builder, loc, resType, memref, + /*indices = */ {indexOp}); valueMap[loadInst] = newOp.getResult(); translateMemDepAndNameAttrs(loadInst, newOp, *ctx, builder); } @@ -685,57 +668,34 @@ void TranslateLLVMToStd::translateStoreInst(llvm::StoreInst *storeInst) { Location loc = UnknownLoc::get(ctx); auto *instAddr = storeInst->getPointerOperand(); - mlir::Value memref; - SmallVector indices; - - if (this->gepInstToMemRefAndIndicesMap.count(instAddr)) { - // Logic: In LLVM IR, load/store operations take the pointer computed from - // GEP ops, whereas in memref, load operations takes indices (of index - // type). This function uses the index operand collected when processing - // GEPs as operands of LOAD/STOREs. - auto memrefAndIndices = this->gepInstToMemRefAndIndicesMap[instAddr]; - memref = memrefAndIndices.first; - indices = memrefAndIndices.second; - if (auto memrefType = dyn_cast(memref.getType())) { - fillConstantZeroes(memrefType, indices, builder, ctx); - } else { - llvm::report_fatal_error("The pointer operand is not a memref type!"); - } - } else { - // NOTE: This condition handles a special case where a load only has - // constant indices, e.g., tmp = mat[0][0]. - if (isa(instAddr)) - llvm_unreachable( - "Converting a load but the producer hasn't been converted yet!"); - memref = this->valueMap[instAddr]; - auto memrefType = dyn_cast(memref.getType()); - - int constZerosToAdd = memrefType.getShape().size(); - for (int i = 0; i < constZerosToAdd; i++) { - auto constZeroOp = arith::ConstantOp::create( - this->builder, loc, this->builder.getIndexAttr(0)); - indices.push_back(constZeroOp); - } - } + mlir::Value memref = getInstToMemRefMap[instAddr]; + mlir::Value index = valueMap[storeInst->getPointerOperand()]; + + // StoreOp needs the index operand to be of index type + auto indexOp = arith::IndexCastOp::create(builder, UnknownLoc::get(ctx), + builder.getIndexType(), index); mlir::Value storeValue = valueMap[storeInst->getValueOperand()]; - auto newOp = - memref::StoreOp::create(this->builder, loc, storeValue, memref, indices); + auto newOp = memref::StoreOp::create(this->builder, loc, storeValue, memref, + /*indices = */ {indexOp}); translateMemDepAndNameAttrs(storeInst, newOp, *ctx, builder); } void TranslateLLVMToStd::translateAllocaInst(llvm::AllocaInst *allocaInst) { Location loc = UnknownLoc::get(ctx); - SmallVector shape; + int64_t arraySize = 1; llvm::Type *baseElementType = allocaInst->getAllocatedType(); while (baseElementType->isArrayTy()) { - shape.push_back(baseElementType->getArrayNumElements()); + arraySize *= baseElementType->getArrayNumElements(); baseElementType = baseElementType->getArrayElementType(); } - auto memrefType = MemRefType::get(shape, getMLIRType(baseElementType, ctx)); + assert(arraySize > 0 && "The size of the array must be positive!"); + + auto memrefType = MemRefType::get(/*shape =*/{arraySize}, + getMLIRType(baseElementType, ctx)); auto allocaOp = memref::AllocaOp::create(builder, loc, memrefType); valueMap[allocaInst] = allocaOp->getResult(0); diff --git a/tools/translate-llvm-to-std/TranslateLLVMToStd.h b/tools/translate-llvm-to-std/TranslateLLVMToStd.h index 6d6714573c..d12f095c1d 100644 --- a/tools/translate-llvm-to-std/TranslateLLVMToStd.h +++ b/tools/translate-llvm-to-std/TranslateLLVMToStd.h @@ -74,6 +74,8 @@ class TranslateLLVMToStd { /// to lookup the input base address and indices. mlir::DenseMap gepInstToMemRefAndIndicesMap; + mlir::DenseMap getInstToMemRefMap; + /// The (C-code-level) argument types of the LLVM functions. FuncNameToCFuncArgsMap &argMap; From 963ad1ba2996e504750bcf6148eb17c5500446b9 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Fri, 21 Nov 2025 18:58:45 +0100 Subject: [PATCH 11/27] Fix load and store to constant zero index --- .../TranslateLLVMToStd.cpp | 46 +++++++++++++------ 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp index 6bf75d14e7..969394e5db 100644 --- a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp +++ b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp @@ -647,36 +647,54 @@ void TranslateLLVMToStd::translateBranchInst(llvm::BranchInst *inst) { } void TranslateLLVMToStd::translateLoadInst(llvm::LoadInst *loadInst) { - Location loc = UnknownLoc::get(ctx); auto *instAddr = loadInst->getPointerOperand(); - mlir::Value memref = getInstToMemRefMap[instAddr]; - + mlir::Value memref; mlir::Value index = valueMap[loadInst->getPointerOperand()]; mlir::Type resType = getMLIRType(loadInst->getType(), ctx); - // LoadOp needs the index operand to be of index type - auto indexOp = arith::IndexCastOp::create(builder, UnknownLoc::get(ctx), - builder.getIndexType(), index); + mlir::Value indexOp; + + if (getInstToMemRefMap.count(instAddr)) { + memref = getInstToMemRefMap[instAddr]; + // LoadOp needs the index operand to be of index type + indexOp = arith::IndexCastOp::create(builder, UnknownLoc::get(ctx), + builder.getIndexType(), index); + } else { + assert(isa(index.getType())); + memref = index; + indexOp = arith::ConstantOp::create(builder, UnknownLoc::get(ctx), + builder.getIndexAttr(0)); + } - auto newOp = memref::LoadOp::create(this->builder, loc, resType, memref, + auto newOp = memref::LoadOp::create(this->builder, UnknownLoc::get(ctx), + resType, memref, /*indices = */ {indexOp}); valueMap[loadInst] = newOp.getResult(); translateMemDepAndNameAttrs(loadInst, newOp, *ctx, builder); } void TranslateLLVMToStd::translateStoreInst(llvm::StoreInst *storeInst) { - Location loc = UnknownLoc::get(ctx); auto *instAddr = storeInst->getPointerOperand(); - - mlir::Value memref = getInstToMemRefMap[instAddr]; + mlir::Value memref; mlir::Value index = valueMap[storeInst->getPointerOperand()]; - // StoreOp needs the index operand to be of index type - auto indexOp = arith::IndexCastOp::create(builder, UnknownLoc::get(ctx), - builder.getIndexType(), index); + mlir::Value indexOp; + + if (getInstToMemRefMap.count(instAddr)) { + memref = getInstToMemRefMap[instAddr]; + // LoadOp needs the index operand to be of index type + indexOp = arith::IndexCastOp::create(builder, UnknownLoc::get(ctx), + builder.getIndexType(), index); + } else { + assert(isa(index.getType())); + memref = index; + indexOp = arith::ConstantOp::create(builder, UnknownLoc::get(ctx), + builder.getIndexAttr(0)); + } mlir::Value storeValue = valueMap[storeInst->getValueOperand()]; - auto newOp = memref::StoreOp::create(this->builder, loc, storeValue, memref, + auto newOp = memref::StoreOp::create(this->builder, UnknownLoc::get(ctx), + storeValue, memref, /*indices = */ {indexOp}); translateMemDepAndNameAttrs(storeInst, newOp, *ctx, builder); } From aa123e7492f4ceab5388d5fbf756f184f504b8f2 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Fri, 21 Nov 2025 19:24:35 +0100 Subject: [PATCH 12/27] disable unroll --- tools/dynamatic/scripts/compile.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/dynamatic/scripts/compile.sh b/tools/dynamatic/scripts/compile.sh index 8f722763f7..630b940d39 100755 --- a/tools/dynamatic/scripts/compile.sh +++ b/tools/dynamatic/scripts/compile.sh @@ -157,7 +157,7 @@ sed -i "s/^target triple = .*$//g" "$F_CLANG" # ------------------------------------------------------------------------------ $LLVM_BINS/opt -S \ - -passes="inline,mem2reg,consthoist,instcombine,function(loop-mssa(licm)),function(loop(loop-idiom,indvars,loop-deletion)),simplifycfg,loop-rotate,simplifycfg,sink,lowerswitch,simplifycfg" \ + -passes="inline,mem2reg,consthoist,instcombine,function(loop-mssa(licm)),function(loop(loop-idiom,indvars,loop-deletion)),simplifycfg,loop-rotate,simplifycfg,sink,lower-switch,simplifycfg" \ "$F_CLANG" \ > "$F_CLANG_OPTIMIZED" exit_on_fail "Failed to apply optimization to LLVM IR" \ From 70c99f029e81344ada0f025d101674c6a9f54dab Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Fri, 21 Nov 2025 19:25:44 +0100 Subject: [PATCH 13/27] Also flatten the constant value attribute into 1d --- tools/translate-llvm-to-std/TranslateLLVMToStd.cpp | 8 +++++++- tools/translate-llvm-to-std/TranslateLLVMToStd.h | 2 +- tools/translate-llvm-to-std/main.cpp | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp index 969394e5db..05c9793fa6 100644 --- a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp +++ b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp @@ -118,8 +118,10 @@ convertInitializerToDenseElemAttr(llvm::GlobalVariable *globVar, values.reserve(numElems); convertInitializerToDenseElemAttrRecursive(globVar->getInitializer(), values, baseMLIRElemType); + return mlir::DenseElementsAttr::get( - mlir::RankedTensorType::get(shape, baseMLIRElemType), values); + mlir::RankedTensorType::get(/*shape = */ {numElems}, baseMLIRElemType), + values); } void convertInitializerToDenseElemAttrRecursive( @@ -529,6 +531,8 @@ void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { SmallVector gepIndices(gepInst->indices()); +#if 0 + // NOTE: this is no longer true as of 21.11.2025. We will not have an extra leading zero. if (auto *defInst = gepInst->getPointerOperand(); isa_and_nonnull(defInst) || isa_and_nonnull(defInst)) { @@ -542,6 +546,7 @@ void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { // Therefore, we drop the first element in this case gepIndices.erase(gepIndices.begin()); } +#endif SmallVector multipliedIndices; @@ -702,6 +707,7 @@ void TranslateLLVMToStd::translateStoreInst(llvm::StoreInst *storeInst) { void TranslateLLVMToStd::translateAllocaInst(llvm::AllocaInst *allocaInst) { Location loc = UnknownLoc::get(ctx); + // flatten the MD array into 1D int64_t arraySize = 1; llvm::Type *baseElementType = allocaInst->getAllocatedType(); diff --git a/tools/translate-llvm-to-std/TranslateLLVMToStd.h b/tools/translate-llvm-to-std/TranslateLLVMToStd.h index d12f095c1d..fe511f5753 100644 --- a/tools/translate-llvm-to-std/TranslateLLVMToStd.h +++ b/tools/translate-llvm-to-std/TranslateLLVMToStd.h @@ -85,7 +85,7 @@ class TranslateLLVMToStd { void naiveTranslation(mlir::Type returnType, mlir::ValueRange values, Instruction *inst) { MLIRTy op = - builder.create(UnknownLoc::get(ctx), returnType, values); + MLIRTy::create(builder, UnknownLoc::get(ctx), returnType, values); // Register the corresponding MLIR value of the result of the original // instruction. valueMap[inst] = op.getResult(); diff --git a/tools/translate-llvm-to-std/main.cpp b/tools/translate-llvm-to-std/main.cpp index 797e48700c..5e463344b1 100644 --- a/tools/translate-llvm-to-std/main.cpp +++ b/tools/translate-llvm-to-std/main.cpp @@ -96,7 +96,7 @@ int main(int argc, char **argv) { OpBuilder builder(&context); - auto module = builder.create(builder.getUnknownLoc()); + auto module = ModuleOp::create(builder, builder.getUnknownLoc()); // LLVM IR's argument does not indicate high-level types such as array shapes. // We use the original C code to recover this information. From ed571c1e9d3a62b4c986ba3a3977e4b7b98c2911 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Fri, 21 Nov 2025 19:33:21 +0100 Subject: [PATCH 14/27] Implicit trunc when converting profiler input string into APInt --- experimental/tools/frequency-profiler/Simulator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/tools/frequency-profiler/Simulator.cpp b/experimental/tools/frequency-profiler/Simulator.cpp index 030b0c8ca9..5a10fe8bb4 100644 --- a/experimental/tools/frequency-profiler/Simulator.cpp +++ b/experimental/tools/frequency-profiler/Simulator.cpp @@ -92,7 +92,7 @@ static Any readValueWithType(mlir::Type type, std::stringstream &arg) { int64_t x; arg >> x; int64_t width = type.getIntOrFloatBitWidth(); - APInt aparg(width, x); + APInt aparg(width, x, /*implicittruc*/ true); return aparg; } if (type.isF32()) { From 6c70c4b33b2fd65f8edb592961469f498601099b Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Fri, 21 Nov 2025 19:53:25 +0100 Subject: [PATCH 15/27] typo in the type conversion --- tools/translate-llvm-to-std/TranslateLLVMToStd.cpp | 2 +- tools/translate-llvm-to-std/TranslateLLVMToStd.h | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp index 05c9793fa6..76a7d2e98b 100644 --- a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp +++ b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp @@ -39,7 +39,7 @@ static mlir::Type getMLIRType(llvm::Type *llvmType, return mlir::Float32Type::get(context); } if (llvmType->isDoubleTy()) { - return mlir::Float32Type::get(context); + return mlir::Float64Type::get(context); } llvm_unreachable("Unhandled scalar type"); diff --git a/tools/translate-llvm-to-std/TranslateLLVMToStd.h b/tools/translate-llvm-to-std/TranslateLLVMToStd.h index fe511f5753..91cf81a14e 100644 --- a/tools/translate-llvm-to-std/TranslateLLVMToStd.h +++ b/tools/translate-llvm-to-std/TranslateLLVMToStd.h @@ -69,11 +69,9 @@ class TranslateLLVMToStd { /// In LLVM IR to CF, we convert GEP -> LOAD/STORE to LOAD/STORE. /// - In LLVM IR: load and store take pointer operand /// - In MLIR IR: load and store take base address and indices - /// We use this data structure to store the base address and indices provided - /// by GEPs when processing GEPs. This will be used in LOAD/STORE conversions - /// to lookup the input base address and indices. - mlir::DenseMap gepInstToMemRefAndIndicesMap; - + /// + /// We use this data structure to store the mapping between the GEP + /// instruction and the corresponding base address. mlir::DenseMap getInstToMemRefMap; /// The (C-code-level) argument types of the LLVM functions. From decb7a9cc8e60c6dc97865e9ab137973b8626f13 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Fri, 21 Nov 2025 23:36:25 +0100 Subject: [PATCH 16/27] [StdProfiler] Handling signness and truncation when coverting to integer --- experimental/tools/frequency-profiler/Simulator.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/experimental/tools/frequency-profiler/Simulator.cpp b/experimental/tools/frequency-profiler/Simulator.cpp index 5a10fe8bb4..abbcfbb4c2 100644 --- a/experimental/tools/frequency-profiler/Simulator.cpp +++ b/experimental/tools/frequency-profiler/Simulator.cpp @@ -92,7 +92,7 @@ static Any readValueWithType(mlir::Type type, std::stringstream &arg) { int64_t x; arg >> x; int64_t width = type.getIntOrFloatBitWidth(); - APInt aparg(width, x, /*implicittruc*/ true); + APInt aparg(width, x, /*isSigned = */ true, /*implicitTrunc = */ true); return aparg; } if (type.isF32()) { @@ -236,8 +236,11 @@ LogicalResult StdExecuter::execute(mlir::arith::ShRSIOp, std::vector &in, std::vector &out) { auto toShift = any_cast(in[0]).getSExtValue(); auto shiftAmount = any_cast(in[1]).getZExtValue(); + auto shifted = - APInt(any_cast(in[0]).getBitWidth(), toShift >> shiftAmount); + APInt(any_cast(in[0]).getBitWidth(), toShift >> shiftAmount, + /* isSigned = */ true, /*implicitTrunc = */ true); + out[0] = shifted; return success(); } From d906d4e374c33c8db8c9d4d27dcca18ed023c443 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Fri, 21 Nov 2025 23:44:21 +0100 Subject: [PATCH 17/27] disable test_bitint test (waiting for a bugfix in llvm) --- integration-test/test_bitint/test_bitint.c | 3 +++ tools/integration/TEST_SUITE.cpp | 7 +++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/integration-test/test_bitint/test_bitint.c b/integration-test/test_bitint/test_bitint.c index fee57647c6..8ba012643a 100644 --- a/integration-test/test_bitint/test_bitint.c +++ b/integration-test/test_bitint/test_bitint.c @@ -1,5 +1,8 @@ #include "dynamatic/Integration.h" +// NOTE: This currently doesn't work in Clang! +// waiting for https://github.com/llvm/llvm-project/pull/161796 to be merged + #define AccumType unsigned _BitInt(32) #define DataType unsigned _BitInt(16) #define WeightType unsigned _BitInt(5) diff --git a/tools/integration/TEST_SUITE.cpp b/tools/integration/TEST_SUITE.cpp index 864442d9ed..f109bb274a 100644 --- a/tools/integration/TEST_SUITE.cpp +++ b/tools/integration/TEST_SUITE.cpp @@ -223,6 +223,10 @@ TEST_P(SpecFixture, spec) { INSTANTIATE_TEST_SUITE_P( MiscBenchmarks, BasicFixture, testing::Values( + // NOTE: Disabling "test_bitint": + // Waiting for a fix in Clang in the upstream (https://github.com/llvm/llvm-project/pull/161796) + // + // "test_bitint", "single_loop", "atax", "atax_float", @@ -285,8 +289,7 @@ INSTANTIATE_TEST_SUITE_P( "video_filter", "while_loop_1", "while_loop_3", - "test_loop_free", - "test_bitint" + "test_loop_free" ), [](const auto &info) { return info.param; }); From 32f7973c9c95f2500fb550e8480b0241f658b328 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Sat, 22 Nov 2025 00:31:42 +0100 Subject: [PATCH 18/27] InstCombine is no longer strictly necessary --- docs/DeveloperGuide/CompilerIntrinsics/Frontend.md | 10 +++------- tools/dynamatic/scripts/compile.sh | 2 +- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/docs/DeveloperGuide/CompilerIntrinsics/Frontend.md b/docs/DeveloperGuide/CompilerIntrinsics/Frontend.md index 38933efc72..6b5d4ec77f 100644 --- a/docs/DeveloperGuide/CompilerIntrinsics/Frontend.md +++ b/docs/DeveloperGuide/CompilerIntrinsics/Frontend.md @@ -30,7 +30,7 @@ The output of Dynamatic's C frontend is an MLIR IR written in standard MLIR dial Notable optimizations that we need from the LLVM project: - `mem2reg`: Suppressing allocas (allocate memory on the heap) into regs. -- `instcombine`: Performing local DAG-to-DAG rewriting. Notably, this canonicalizes a chain of `getelementptr` instructions (GEPs). +- `instcombine`: Performing local DAG-to-DAG rewriting. - `loop-rotate`: Transforming loops to do-while loops as much as possible. - `simplifycfg`, `loopsimplify`: reducing the number of BBs (fewer branches). - `consthoist`: Moving constants around. @@ -55,7 +55,7 @@ The translation between LLVM IR and the standard dialects (especially the subset > - LLVM uses void ptrs for array inputs (both for fixed-size arrays `int arr[10][20]` and arrays with unbounded length `int * arr`). While in standard dialect, we use MemRef types `memref<10 * 20 * i32>` for referencing an array. > - LLVM does not represent constants as operations, while in MILR, constants must be "materialized" as explicit constant operations. > - LLVM has explicit SSA Phi nodes. MLIR replaces the Phis by block arguments. -> - The MemRef dialect does not have a special GEP operation for the array index calculation (e.g., `a[0][1]`); instead, it has a high-level syntax like `%result = memref.load [%memrefValue] %dim0, %dim1`. Therefore, GEPs are replaced by a direct connection between indices to the loads/stores. +> - The MemRef dialect does not have a special GEP operation for the array index calculation (e.g., `a[0][1]`); instead, it has a high-level syntax like `%result = memref.load [%memrefValue] %dim0, %dim1`. We directly flatten the arrays into 1D and replace GEPs with multiplications and additions. > - In LLVM, global values can be referenced by GEPs, but in MLIR MemRef dialect, global values can only be referenced via `get_global` op via the `sym_name` symbol attached to the global op. ### Type Conversion for Function Arguments @@ -80,15 +80,11 @@ For each LLVM function, Dynamatic performs the following translation: 1. **Constant materialization**. Create a corresponding `arith::ConstantOp` for each constant input of each `llvm::Instruction *` in LLVM IR. 2. **Block conversion**. Create an MLIR block for every basic block in LLVM. Remember the BB mappings (see the list above). For every Phi output in LLVM, it creates the corresponding block argument in MLIR (for each array function argument, the original C code is used to recover the correct MemRef type). Remember the value mappings (see the list above). 3. **Global conversion**. Create a MemRef global operation for each global variable in LLVM. -4. **Instruction translation**. Create an operation in MLIR for each LLVM operation from the input values (retrieved from the value mapping). Exception: GEP are removed and the indices are directly connected to the loads and stores. +4. **Instruction translation**. Create an operation in MLIR for each LLVM operation from the input values (retrieved from the value mapping). > [!NOTE] > The syntax of the GEP instruction in LLVM is often simplified/shortened. This requires a sophisticated conversion rule for GEP. Check out the LLVM documentation on [caveats of GEP syntax](https://llvm.org/docs/GetElementPtr.html) for more details. -> [!IMPORTANT] -> The `instcombine` pass must be applied before the conversion to eliminate a -> chain of GEPs. - ## Memory Dependency Analysis TODO diff --git a/tools/dynamatic/scripts/compile.sh b/tools/dynamatic/scripts/compile.sh index 630b940d39..9f588a0613 100755 --- a/tools/dynamatic/scripts/compile.sh +++ b/tools/dynamatic/scripts/compile.sh @@ -142,7 +142,7 @@ sed -i "s/^target triple = .*$//g" "$F_CLANG" # - inline: Inlines the function calls. # - mem2reg: Promote allocas (allocate memory on the heap) into regs. # - lowerswitch: Convert switch case into branches. -# - instcombine: combine operations. Needed to canonicalize a chain of GEPs. +# - instcombine: combine operations. # - loop-rotate: canonicalize loops to do-while loops # - consthoist: moving constants around # - simplifycfg: merge BBs From 52b5f2bff8e70b59841b0a1a1f6e9d4532b35696 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Sat, 22 Nov 2025 00:32:13 +0100 Subject: [PATCH 19/27] Copy comments from FlattenMemRefRowMajor to GEP conversion --- .../TranslateLLVMToStd.cpp | 85 ++++++++++++++++--- 1 file changed, 75 insertions(+), 10 deletions(-) diff --git a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp index 76a7d2e98b..f1a5397c02 100644 --- a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp +++ b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp @@ -518,8 +518,24 @@ void TranslateLLVMToStd::translateFCmpInst(llvm::FCmpInst *inst) { void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { - // Convert the GEP instruction into a series of "idx * dim + idx * dim ..." + // The GEP instruction calculates the index that the load/store need to use + // the access memory. + // + // For instance A[i][j][k] would be GEP -> ... -> GEP -> load + // + // In TranslateLLVMToStd, we convert it to additions and multiplications. + // + // Also, we flatten the memory into 1D. Since it is very hard to reliably + // reverse engineer the order between the accesses. + // + // An example of flattening a multi-dimension array in a row-major order. + // Given an array: my_array[A][B][C][D]; + // we access it as my_array[i][j][k][l]; + // The flattened index is computed as: + // + // (B * C * D) * i + (C * D) * j + (D) * k + l + // Convert the GEP instruction into a series of "idx * dim + idx * dim ..." llvm::Type *baseElementType = gepInst->getSourceElementType(); // Get index calculation: @@ -531,6 +547,14 @@ void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { SmallVector gepIndices(gepInst->indices()); + mlir::Value baseAddress; + if (this->getInstToMemRefMap.count(gepInst->getPointerOperand())) { + baseAddress = this->getInstToMemRefMap[gepInst->getPointerOperand()]; + } else { + baseAddress = valueMap[gepInst->getPointerOperand()]; + } + this->getInstToMemRefMap[gepInst] = baseAddress; + #if 0 // NOTE: this is no longer true as of 21.11.2025. We will not have an extra leading zero. if (auto *defInst = gepInst->getPointerOperand(); @@ -548,6 +572,12 @@ void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { } #endif + // A list of value to be accumulated + // + // For the example above: + // + // multipliedIndices = { (B * C * D) * i, (C * D) * j, (D) * k + l } + SmallVector multipliedIndices; for (size_t i = 0; i < gepIndices.size(); ++i) { @@ -559,8 +589,50 @@ void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { coeff *= multipliers[j]; if (coeff == 1) { - multipliedIndices.push_back(mlirIndexValue); + if (auto *constVal = dyn_cast(gepIndices[i])) { + // Special case: when the GEP index is a constant number + // + // Here we need to handle tricky examples like this: + // %arrayidx2 = getelementptr inbounds nuw i8, ptr %2, i64 4 + // + // Here we are using GEP to advance 4 * i8 = 32 bits. If the + // element of the original array was 32-bit wide, then here we need to + // increment 1 step (instead of 4). + auto memrefType = dyn_cast(baseAddress.getType()); + assert(memrefType); + + // This is the size of the actual element (i.e., for 32 in the example + // above). + unsigned actualBaseElementWidth = memrefType.getElementTypeBitWidth(); + + // This is the size that the current GEP assumes (i.e., for i8 in the + // example above). + unsigned currBaseElementBitWidth = + baseElementType->getScalarSizeInBits(); + + // This the the number of `currBaseElementBitWidth` that we need to skip + // (i.e., 4 in the example above). + int64_t constInt = *constVal->getUniqueInteger().getRawData(); + + assert(actualBaseElementWidth % (currBaseElementBitWidth * constInt) == + 0 && + "Incorrect alignment!"); + + unsigned actualAdvanceValue = + actualBaseElementWidth / (currBaseElementBitWidth * constInt); + + auto byteAlignedConstantValue = arith::ConstantOp::create( + builder, UnknownLoc::get(ctx), + builder.getIntegerAttr(builder.getI64Type(), actualAdvanceValue)); + multipliedIndices.push_back(byteAlignedConstantValue); + } else { + multipliedIndices.push_back(mlirIndexValue); + } + } else if (llvm::isPowerOf2_64(coeff)) { + // Special case: the array dimension is a power of two. + // Here we can apply the optimization: multiply by power of 2 is the same + // as shifting auto shiftValue = arith::ConstantOp::create( builder, UnknownLoc::get(ctx), builder.getIntegerAttr(builder.getI64Type(), llvm::Log2_64(coeff))); @@ -568,6 +640,7 @@ void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { mlirIndexValue, shiftValue); multipliedIndices.push_back(idx); } else { + // Regular case: calculate the (paritial) flattened index auto multipliedValue = arith::ConstantOp::create( builder, UnknownLoc::get(ctx), builder.getIntegerAttr(builder.getI64Type(), coeff)); @@ -600,14 +673,6 @@ void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { // [END accumulate the array index] valueMap[gepInst] = accumulatedArrayIndex; - - if (this->getInstToMemRefMap.count(gepInst->getPointerOperand())) { - this->getInstToMemRefMap[gepInst] = - this->getInstToMemRefMap[gepInst->getPointerOperand()]; - } else { - mlir::Value baseAddress = valueMap[gepInst->getPointerOperand()]; - this->getInstToMemRefMap[gepInst] = baseAddress; - } } void TranslateLLVMToStd::translateBranchInst(llvm::BranchInst *inst) { From 03a6dc60a5bd29957709e77485b3787c2814038c Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Sat, 22 Nov 2025 00:40:50 +0100 Subject: [PATCH 20/27] Remove deprecated statements --- experimental/lib/Support/CFGAnnotation.cpp | 14 ++--- .../lib/Support/FtdImplementation.cpp | 62 +++++++++---------- .../HandshakeCombineSteeringLogic.cpp | 9 ++- .../HandshakePlaceBuffersCustom.cpp | 9 ++- .../tools/elastic-miter/FabricGeneration.cpp | 58 +++++++++-------- 5 files changed, 74 insertions(+), 78 deletions(-) diff --git a/experimental/lib/Support/CFGAnnotation.cpp b/experimental/lib/Support/CFGAnnotation.cpp index 0502445e54..0dfe6bbfd0 100644 --- a/experimental/lib/Support/CFGAnnotation.cpp +++ b/experimental/lib/Support/CFGAnnotation.cpp @@ -321,7 +321,7 @@ dynamatic::experimental::cfg::restoreCfStructure(handshake::FuncOp &funcOp, rewriter.setInsertionPointToEnd(&bb); if (!edges.contains(blockID)) { - rewriter.create(bb.back().getLoc()); + func::ReturnOp::create(rewriter, bb.back().getLoc()); continue; } @@ -333,14 +333,14 @@ dynamatic::experimental::cfg::restoreCfStructure(handshake::FuncOp &funcOp, Operation *condOp = getOpByName(edge.getCondition(), blockID); if (!condOp) return failure(); - rewriter.create(bb.back().getLoc(), - condOp->getResult(0), - indexToBlock[edge.getTrueSuccessor()], - indexToBlock[edge.getFalseSuccessor()]); + cf::CondBranchOp::create(rewriter, bb.back().getLoc(), + condOp->getResult(0), + indexToBlock[edge.getTrueSuccessor()], + indexToBlock[edge.getFalseSuccessor()]); } else { unsigned successor = edge.getSuccessor(); - rewriter.create(bb.back().getLoc(), - indexToBlock[successor]); + cf::BranchOp::create(rewriter, bb.back().getLoc(), + indexToBlock[successor]); } } diff --git a/experimental/lib/Support/FtdImplementation.cpp b/experimental/lib/Support/FtdImplementation.cpp index 1ac75bd076..baecaea2c8 100644 --- a/experimental/lib/Support/FtdImplementation.cpp +++ b/experimental/lib/Support/FtdImplementation.cpp @@ -443,8 +443,8 @@ LogicalResult experimental::ftd::createPhiNetwork( for (auto *bb : blocksToAddPhi) { rewriter.setInsertionPointToStart(bb); - auto mergeOp = rewriter.create(bb->front().getLoc(), - operandsPerPhi[bb]); + auto mergeOp = handshake::MergeOp::create(rewriter, bb->front().getLoc(), + operandsPerPhi[bb]); mergeOp->setAttr(NEW_PHI, rewriter.getUnitAttr()); newMergePerPhi.insert({bb, mergeOp}); } @@ -542,8 +542,8 @@ LogicalResult ftd::createPhiNetworkDeps( // connected with an SSA network, and then everything is joined. ValueRange operands = dependencies; rewriter.setInsertionPointToStart(operand->getOwner()->getBlock()); - auto joinOp = rewriter.create( - operand->getOwner()->getLoc(), operands); + auto joinOp = handshake::JoinOp::create( + rewriter, operand->getOwner()->getLoc(), operands); joinOp->moveBefore(operandOwner); for (unsigned i = 0; i < dependencies.size(); i++) { @@ -581,8 +581,8 @@ static Value boolVariableToCircuit(PatternRewriter &rewriter, // Add a not if the condition is negated. if (singleCond->isNegated) { rewriter.setInsertionPointToStart(block); - auto notOp = rewriter.create( - block->getOperations().front().getLoc(), + auto notOp = handshake::NotOp::create( + rewriter, block->getOperations().front().getLoc(), ftd::channelifyType(condition.getType()), condition); notOp->setAttr(FTD_OP_TO_SKIP, rewriter.getUnitAttr()); return notOp->getResult(0); @@ -606,16 +606,16 @@ static Value boolExpressionToCircuit(PatternRewriter &rewriter, // Constant case (either 0 or 1) rewriter.setInsertionPointToStart(block); - auto sourceOp = rewriter.create( - block->getOperations().front().getLoc()); + auto sourceOp = handshake::SourceOp::create( + rewriter, block->getOperations().front().getLoc()); Value cnstTrigger = sourceOp.getResult(); auto intType = rewriter.getIntegerType(1); auto cstAttr = rewriter.getIntegerAttr( intType, (expr->type == ExpressionType::One ? 1 : 0)); - auto constOp = rewriter.create( - block->getOperations().front().getLoc(), cstAttr, cnstTrigger); + auto constOp = handshake::ConstantOp::create( + rewriter, block->getOperations().front().getLoc(), cstAttr, cnstTrigger); constOp->setAttr(FTD_OP_TO_SKIP, rewriter.getUnitAttr()); @@ -644,9 +644,9 @@ static Value bddToCircuit(PatternRewriter &rewriter, BDD *bdd, Block *block, bi, needsChannelify); // Create the multiplxer and add it to the rest of the circuit - auto muxOp = rewriter.create( - block->getOperations().front().getLoc(), muxOperands[0].getType(), - muxCond, muxOperands); + auto muxOp = handshake::MuxOp::create( + rewriter, block->getOperations().front().getLoc(), + muxOperands[0].getType(), muxCond, muxOperands); muxOp->setAttr(FTD_OP_TO_SKIP, rewriter.getUnitAttr()); return muxOp.getResult(); @@ -746,14 +746,14 @@ void ftd::addRegenOperandConsumer(PatternRewriter &rewriter, conditionValue = loop->getExitingBlock()->getTerminator()->getOperand(0); // Create the false constant to feed `init` - auto constOp = rewriter.create(consumerOp->getLoc(), - cstAttr, startValue); + auto constOp = handshake::ConstantOp::create(rewriter, consumerOp->getLoc(), + cstAttr, startValue); constOp->setAttr(FTD_INIT_MERGE, rewriter.getUnitAttr()); // Create the `init` operation SmallVector mergeOperands = {constOp.getResult(), conditionValue}; - auto initMergeOp = rewriter.create(consumerOp->getLoc(), - mergeOperands); + auto initMergeOp = handshake::MergeOp::create( + rewriter, consumerOp->getLoc(), mergeOperands); initMergeOp->setAttr(FTD_INIT_MERGE, rewriter.getUnitAttr()); // The multiplexer is to be fed by the init block, and takes as inputs the @@ -762,9 +762,9 @@ void ftd::addRegenOperandConsumer(PatternRewriter &rewriter, selectSignal.setType(channelifyType(selectSignal.getType())); SmallVector muxOperands = {regeneratedValue, regeneratedValue}; - auto muxOp = rewriter.create(regeneratedValue.getLoc(), - regeneratedValue.getType(), - selectSignal, muxOperands); + auto muxOp = handshake::MuxOp::create(rewriter, regeneratedValue.getLoc(), + regeneratedValue.getType(), + selectSignal, muxOperands); muxOp->setOperand(2, muxOp->getResult(0)); muxOp->setAttr(FTD_REGEN, rewriter.getUnitAttr()); @@ -813,7 +813,7 @@ using PairOperandConsumer = std::pair; // outside the loop. static Block *findClosestLoopExit(Operation *consumer, Value connection, const ftd::BlockIndexing &bi, - SmallVector exitBlocks) { + const SmallVector &exitBlocks) { // Find all the paths from the producer to the consumer using DFS std::vector> allPaths = findAllPaths(connection.getParentBlock(), consumer->getBlock(), bi); @@ -882,8 +882,8 @@ static Value addSuppressionInLoop(PatternRewriter &rewriter, CFGLoop *loop, rewriter.setInsertionPointToStart(loopExit); - branchOp = rewriter.create( - loopExit->getOperations().front().getLoc(), + branchOp = handshake::ConditionalBranchOp::create( + rewriter, loopExit->getOperations().front().getLoc(), ftd::getListTypes(connection.getType()), branchCond, connection); Value newConnection = btlt == MoreProducerThanConsumers @@ -981,9 +981,9 @@ static void insertDirectSuppression( Value branchCond = bddToCircuit(rewriter, bdd, consumer->getBlock(), bi); rewriter.setInsertionPointToStart(consumer->getBlock()); - auto branchOp = rewriter.create( - consumer->getLoc(), ftd::getListTypes(connection.getType()), branchCond, - connection); + auto branchOp = handshake::ConditionalBranchOp::create( + rewriter, consumer->getLoc(), ftd::getListTypes(connection.getType()), + branchCond, connection); // Take into account the possibility of a mux to get the condition input // also as data input. In this case, a branch needs to be created, but only @@ -1274,7 +1274,7 @@ LogicalResult experimental::ftd::addGsaGates(Region ®ion, mergeOperands.push_back(conditionValue); auto initMergeOp = - rewriter.create(loc, mergeOperands); + handshake::MergeOp::create(rewriter, loc, mergeOperands); initMergeOp->setAttr(FTD_INIT_MERGE, rewriter.getUnitAttr()); @@ -1287,8 +1287,8 @@ LogicalResult experimental::ftd::addGsaGates(Region ®ion, auto cstType = rewriter.getIntegerType(1); auto cstAttr = IntegerAttr::get(cstType, 0); rewriter.setInsertionPointToStart(initMergeOp->getBlock()); - auto constOp = rewriter.create( - initMergeOp->getLoc(), cstAttr, startValue); + auto constOp = handshake::ConstantOp::create( + rewriter, initMergeOp->getLoc(), cstAttr, startValue); constOp->setAttr(FTD_INIT_MERGE, rewriter.getUnitAttr()); initMergeOp->setOperand(0, constOp.getResult()); } @@ -1302,8 +1302,8 @@ LogicalResult experimental::ftd::addGsaGates(Region ®ion, } // Create the multiplexer - auto mux = rewriter.create(loc, gate->result.getType(), - conditionValue, operands); + auto mux = handshake::MuxOp::create(rewriter, loc, gate->result.getType(), + conditionValue, operands); // The one input gamma is marked at an operation to skip in the IR and // later removed diff --git a/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp b/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp index 1bfc946b8f..a5d87a2da6 100644 --- a/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp +++ b/experimental/lib/Transforms/HandshakeCombineSteeringLogic.cpp @@ -22,7 +22,6 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/Support/Debug.h" #include using namespace mlir; @@ -138,7 +137,7 @@ Operation *returnMuxAtSameDepth(Operation *op, // Otherwise, explore all users in DFS-like traversal until you hit a match Operation *finalOp = nullptr; - for (auto cons : cast(op).getResult().getUsers()) { + for (auto *cons : cast(op).getResult().getUsers()) { Operation *potentialOp = returnMuxAtSameDepth(cons, referenceMuxOp); if (potentialOp != nullptr) { finalOp = potentialOp; @@ -338,8 +337,8 @@ struct RemoveNotCondition rewriter.setInsertionPointAfter(condBranchOp); - auto newBranch = rewriter.create( - condOp->getLoc(), drivingNot.getOperand(), + auto newBranch = handshake::ConditionalBranchOp::create( + rewriter, condOp->getLoc(), drivingNot.getOperand(), condBranchOp.getDataOperand()); rewriter.replaceAllUsesWith(condBranchOp.getTrueResult(), @@ -397,7 +396,7 @@ struct HandshakeCombineSteeringLogicPass patterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(mod, std::move(patterns), config))) + if (failed(applyPatternsGreedily(mod, std::move(patterns), config))) return signalPassFailure(); }; }; diff --git a/experimental/lib/Transforms/HandshakePlaceBuffersCustom.cpp b/experimental/lib/Transforms/HandshakePlaceBuffersCustom.cpp index 991a4a8ead..f3309c3b72 100644 --- a/experimental/lib/Transforms/HandshakePlaceBuffersCustom.cpp +++ b/experimental/lib/Transforms/HandshakePlaceBuffersCustom.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // // Buffer placement pass in Handshake functions, it takes the location (i.e., -// the predecessor, and which output channel of it), type, and slots of the +// the predecessor, and which output channel of it), type, and slots of the // buffer that should be placed. // // This pass facilitates externally prototyping a custom buffer placement @@ -23,7 +23,6 @@ #include "experimental/Transforms/HandshakePlaceBuffersCustom.h" #include "dynamatic/Analysis/NameAnalysis.h" -#include "dynamatic/Dialect/Handshake/HandshakeAttributes.h" #include "dynamatic/Dialect/Handshake/HandshakeOps.h" #include "dynamatic/Support/CFG.h" #include "dynamatic/Transforms/HandshakeMaterialize.h" @@ -89,8 +88,8 @@ struct HandshakePlaceBuffersCustomPass // pull the enum itself from the optional auto bufferType = bufferTypeOpt.value(); - auto bufOp = builder.create(channel.getLoc(), channel, - slots, bufferType); + auto bufOp = handshake::BufferOp::create(builder, channel.getLoc(), channel, + slots, bufferType); inheritBB(succ, bufOp); Value bufferRes = bufOp->getResult(0); succ->replaceUsesOfWith(channel, bufferRes); @@ -104,4 +103,4 @@ dynamatic::experimental::buffer::createHandshakePlaceBuffersCustom( const std::string &type) { return std::make_unique(pred, outid, slots, type); -} \ No newline at end of file +} diff --git a/experimental/tools/elastic-miter/FabricGeneration.cpp b/experimental/tools/elastic-miter/FabricGeneration.cpp index d83ca9c6be..290446d9d4 100644 --- a/experimental/tools/elastic-miter/FabricGeneration.cpp +++ b/experimental/tools/elastic-miter/FabricGeneration.cpp @@ -14,8 +14,6 @@ #include #include "dynamatic/Analysis/NameAnalysis.h" -#include "dynamatic/Dialect/Handshake/HandshakeAttributes.h" -#include "dynamatic/Dialect/Handshake/HandshakeDialect.h" #include "dynamatic/Dialect/Handshake/HandshakeOps.h" #include "dynamatic/Dialect/Handshake/HandshakeTypes.h" #include "dynamatic/Support/CFG.h" @@ -28,8 +26,8 @@ using namespace dynamatic::handshake; namespace dynamatic::experimental { -void setHandshakeName(OpBuilder &builder, Operation *op, - const std::string &name) { +static void setHandshakeName(OpBuilder &builder, Operation *op, + const std::string &name) { StringAttr nameAttr = builder.getStringAttr(name); op->setAttr(dynamatic::NameAnalysis::ATTR_NAME, nameAttr); } @@ -63,13 +61,13 @@ buildNewFuncWithBlock(OpBuilder builder, const std::string &name, NamedAttribute argNamedAttr, NamedAttribute resNamedAttr) { - ArrayRef funcAttr({argNamedAttr, resNamedAttr}); + SmallVector funcAttr({argNamedAttr, resNamedAttr}); FunctionType funcType = builder.getFunctionType(inputTypes, outputTypes); // Create the new function - FuncOp newFuncOp = - builder.create(builder.getUnknownLoc(), name, funcType, funcAttr); + FuncOp newFuncOp = FuncOp::create(builder, builder.getUnknownLoc(), name, + funcType, funcAttr); // Add an entry block to the function Block *newEntryBlock = newFuncOp.addEntryBlock(); @@ -252,7 +250,7 @@ createReachabilityCircuit(MLIRContext &context, std::string ndwName = "ndw_in_" + funcOp.getArgName(i).str(); - NDWireOp ndWireOp = builder.create(funcOp.getLoc(), arg); + NDWireOp ndWireOp = NDWireOp::create(builder, funcOp.getLoc(), arg); setHandshakeAttributes(builder, ndWireOp, 0, ndwName); // Use the newly created NDwire's output instead of the original argument in @@ -289,7 +287,7 @@ createReachabilityCircuit(MLIRContext &context, std::string ndwName = "ndw_out_" + funcOp.getResName(i).str(); - NDWireOp endNDWireOp = builder.create(endOp->getLoc(), result); + NDWireOp endNDWireOp = NDWireOp::create(builder, endOp->getLoc(), result); setHandshakeAttributes(builder, endNDWireOp, 3, ndwName); // Use the newly created NDwire's output instead of the original argument in @@ -394,22 +392,22 @@ createElasticMiter(MLIRContext &context, ModuleOp lhsModule, ModuleOp rhsModule, std::string rhsNdwName = "rhs_in_ndw_" + lhsFuncOp.getArgName(i).str(); LazyForkOp forkOp = - builder.create(newFuncOp.getLoc(), miterArg, 2); + LazyForkOp::create(builder, newFuncOp.getLoc(), miterArg, 2); setHandshakeAttributes(builder, forkOp, BB_IN, forkName); - BufferOp lhsBufferOp = builder.create( - forkOp.getLoc(), forkOp.getResults()[BB_IN], bufferSlots, + BufferOp lhsBufferOp = BufferOp::create( + builder, forkOp.getLoc(), forkOp.getResults()[BB_IN], bufferSlots, dynamatic::handshake::BufferType::FIFO_BREAK_DV); - BufferOp rhsBufferOp = builder.create( - forkOp.getLoc(), forkOp.getResults()[1], bufferSlots, + BufferOp rhsBufferOp = BufferOp::create( + builder, forkOp.getLoc(), forkOp.getResults()[1], bufferSlots, dynamatic::handshake::BufferType::FIFO_BREAK_DV); setHandshakeAttributes(builder, lhsBufferOp, BB_IN, lhsBufName); setHandshakeAttributes(builder, rhsBufferOp, BB_IN, rhsBufName); NDWireOp lhsNDWireOp = - builder.create(forkOp.getLoc(), lhsBufferOp.getResult()); + NDWireOp::create(builder, forkOp.getLoc(), lhsBufferOp.getResult()); NDWireOp rhsNDWireOp = - builder.create(forkOp.getLoc(), rhsBufferOp.getResult()); + NDWireOp::create(builder, forkOp.getLoc(), rhsBufferOp.getResult()); setHandshakeAttributes(builder, lhsNDWireOp, BB_IN, lhsNdwName); setHandshakeAttributes(builder, rhsNDWireOp, BB_IN, rhsNdwName); @@ -471,33 +469,33 @@ createElasticMiter(MLIRContext &context, ModuleOp lhsModule, ModuleOp rhsModule, NDWireOp rhsEndNDWireOp; lhsEndNDWireOp = - builder.create(nextLocation->getLoc(), lhsResult); + NDWireOp::create(builder, nextLocation->getLoc(), lhsResult); rhsEndNDWireOp = - builder.create(nextLocation->getLoc(), rhsResult); + NDWireOp::create(builder, nextLocation->getLoc(), rhsResult); setHandshakeAttributes(builder, lhsEndNDWireOp, BB_OUT, lhsNDwName); setHandshakeAttributes(builder, rhsEndNDWireOp, BB_OUT, rhsNDwName); - BufferOp lhsEndBufferOp = builder.create( - nextLocation->getLoc(), lhsEndNDWireOp.getResult(), bufferSlots, - dynamatic::handshake::BufferType::FIFO_BREAK_DV); - BufferOp rhsEndBufferOp = builder.create( - nextLocation->getLoc(), rhsEndNDWireOp.getResult(), bufferSlots, - dynamatic::handshake::BufferType::FIFO_BREAK_DV); + BufferOp lhsEndBufferOp = BufferOp::create( + builder, nextLocation->getLoc(), lhsEndNDWireOp.getResult(), + bufferSlots, dynamatic::handshake::BufferType::FIFO_BREAK_DV); + BufferOp rhsEndBufferOp = BufferOp::create( + builder, nextLocation->getLoc(), rhsEndNDWireOp.getResult(), + bufferSlots, dynamatic::handshake::BufferType::FIFO_BREAK_DV); setHandshakeAttributes(builder, lhsEndBufferOp, BB_OUT, lhsBufName); setHandshakeAttributes(builder, rhsEndBufferOp, BB_OUT, rhsBufName); if (isa(lhsResult.getType())) { - ValueRange joinInputs = {lhsEndBufferOp.getResult(), - rhsEndBufferOp.getResult()}; + SmallVector joinInputs = {lhsEndBufferOp.getResult(), + rhsEndBufferOp.getResult()}; JoinOp joinOp = - builder.create(builder.getUnknownLoc(), joinInputs); + JoinOp::create(builder, builder.getUnknownLoc(), joinInputs); setHandshakeAttributes(builder, joinOp, BB_OUT, eqName); miterResultValues.push_back(joinOp.getResult()); } else { - CmpIOp compOp = builder.create( - builder.getUnknownLoc(), CmpIPredicate::eq, + CmpIOp compOp = CmpIOp::create( + builder, builder.getUnknownLoc(), CmpIPredicate::eq, lhsEndBufferOp.getResult(), rhsEndBufferOp.getResult()); setHandshakeAttributes(builder, compOp, BB_OUT, eqName); miterResultValues.push_back(compOp.getResult()); @@ -516,7 +514,7 @@ createElasticMiter(MLIRContext &context, ModuleOp lhsModule, ModuleOp rhsModule, } EndOp newEndOp = - builder.create(builder.getUnknownLoc(), miterResultValues); + EndOp::create(builder, builder.getUnknownLoc(), miterResultValues); setHandshakeAttributes(builder, newEndOp, BB_OUT, "end"); // Delete old end operation, we can only have one end operation in a From 41f7effb4abc59582049e48b6d0e48409bcc28a0 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Sat, 22 Nov 2025 01:02:27 +0100 Subject: [PATCH 21/27] fix some compiler warnings --- experimental/include/experimental/Support/SubjectGraph.h | 2 +- experimental/lib/Support/SubjectGraph.cpp | 6 +++--- experimental/tools/elastic-miter/SmvUtils.cpp | 4 ++-- lib/Support/Backedge.cpp | 8 ++++---- .../CreatingPasses/Transforms/GreedySimplifyMergeLike.cpp | 6 +++--- .../lib/CreatingPasses/Transforms/SimplifyMergeLike.cpp | 4 ++-- unittests/Support/ConstraintProgramming/CPTest.cpp | 2 -- 7 files changed, 15 insertions(+), 17 deletions(-) diff --git a/experimental/include/experimental/Support/SubjectGraph.h b/experimental/include/experimental/Support/SubjectGraph.h index d7798a1d0e..75fdc2a643 100644 --- a/experimental/include/experimental/Support/SubjectGraph.h +++ b/experimental/include/experimental/Support/SubjectGraph.h @@ -57,7 +57,7 @@ class BaseSubjectGraph { bool isBlackbox = false; void loadBlifFile(std::initializer_list inputs, - std::string toAppend = ""); + const std::string &toAppend = ""); // Helper function to connect the input nodes of the current module // to the output nodes of the preceding module in the subject graph diff --git a/experimental/lib/Support/SubjectGraph.cpp b/experimental/lib/Support/SubjectGraph.cpp index 6ca7c7832e..6bbac91744 100644 --- a/experimental/lib/Support/SubjectGraph.cpp +++ b/experimental/lib/Support/SubjectGraph.cpp @@ -95,7 +95,7 @@ void BaseSubjectGraph::connectInputNodesHelper(ChannelSignals ¤tSignals, // Constructs the file path based on Operation name and parameters, calls the // Blif parser to load the Blif file void BaseSubjectGraph::loadBlifFile(std::initializer_list inputs, - std::string toAppend) { + const std::string &toAppend) { std::string moduleType; std::string fullPath; moduleType = op->getName().getStringRef(); @@ -119,8 +119,8 @@ void BaseSubjectGraph::loadBlifFile(std::initializer_list inputs, } // Assigns signals to the variables in ChannelSignals struct -void assignSignals(ChannelSignals &signals, Node *node, - const std::string &nodeName) { +static void assignSignals(ChannelSignals &signals, Node *node, + const std::string &nodeName) { // If nodeName includes "valid" or "ready", assign it to the respective // signal. If it does not, assign it to the data signals. if (nodeName.find("valid") != std::string::npos) { diff --git a/experimental/tools/elastic-miter/SmvUtils.cpp b/experimental/tools/elastic-miter/SmvUtils.cpp index 9634e7fc28..30bf153cf3 100644 --- a/experimental/tools/elastic-miter/SmvUtils.cpp +++ b/experimental/tools/elastic-miter/SmvUtils.cpp @@ -80,7 +80,7 @@ static int executeWithRedirect(const std::string &command, // Redirect stdout, keep default of stdin and stderr std::string stdoutFileString = stdoutFile.string(); - ArrayRef> redirects = { + SmallVector> redirects = { std::nullopt, stdoutFileString, std::nullopt}; std::string errMsg; @@ -181,4 +181,4 @@ handshake2smv(const std::filesystem::path &mlirPath, return std::make_pair(smvFile, moduleName); } -} // namespace dynamatic::experimental \ No newline at end of file +} // namespace dynamatic::experimental diff --git a/lib/Support/Backedge.cpp b/lib/Support/Backedge.cpp index df258a9fcf..6919a6e71e 100644 --- a/lib/Support/Backedge.cpp +++ b/lib/Support/Backedge.cpp @@ -57,11 +57,11 @@ Backedge BackedgeBuilder::get(Type resultType, mlir::LocationAttr optionalLoc) { // Create the opearion using either a builder or a rewriter Operation *op; if (rewriter) - op = rewriter->create( - optionalLoc, resultType, ValueRange{}); + op = mlir::UnrealizedConversionCastOp::create(*rewriter, optionalLoc, + resultType, ValueRange{}); else - op = builder->create( - optionalLoc, resultType, ValueRange{}); + op = mlir::UnrealizedConversionCastOp::create(*rewriter, optionalLoc, + resultType, ValueRange{}); edges.push_back(op); return Backedge(op); } diff --git a/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/GreedySimplifyMergeLike.cpp b/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/GreedySimplifyMergeLike.cpp index a254ce49e1..2bf16cf475 100644 --- a/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/GreedySimplifyMergeLike.cpp +++ b/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/GreedySimplifyMergeLike.cpp @@ -74,8 +74,8 @@ struct DowngradeIndexlessControlMerge // the control merge we are replacing. The merge has the exact same inputs // as the control merge rewriter.setInsertionPoint(cmergeOp); - handshake::MergeOp newMergeOp = rewriter.create( - cmergeOp.getLoc(), cmergeOp->getOperands()); + handshake::MergeOp newMergeOp = handshake::MergeOp::create( + rewriter, cmergeOp.getLoc(), cmergeOp->getOperands()); // We are modifying the operation rewriter.modifyOpInPlace(cmergeOp, [&] { @@ -119,7 +119,7 @@ struct GreedySimplifyMergeLikePass patterns.add(ctx); // Apply our two patterns recursively on all operations in the input module - if (failed(applyPatternsAndFoldGreedily(mod, std::move(patterns), config))) + if (failed(applyPatternsGreedily(mod, std::move(patterns), config))) signalPassFailure(); } }; diff --git a/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/SimplifyMergeLike.cpp b/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/SimplifyMergeLike.cpp index a1fec6f7c1..77ddf5cd8d 100644 --- a/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/SimplifyMergeLike.cpp +++ b/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/SimplifyMergeLike.cpp @@ -60,8 +60,8 @@ static LogicalResult performSimplification(handshake::FuncOp funcOp, // the control merge we are replacing. The merge has the exact same inputs // as the control merge builder.setInsertionPoint(cmergeOp); - handshake::MergeOp newMergeOp = builder.create( - cmergeOp.getLoc(), cmergeOp->getOperands()); + handshake::MergeOp newMergeOp = handshake::MergeOp::create( + builder, cmergeOp.getLoc(), cmergeOp->getOperands()); // Then, replace the control merge's first result (the selected input) with // the single result of the newly created merge operation diff --git a/unittests/Support/ConstraintProgramming/CPTest.cpp b/unittests/Support/ConstraintProgramming/CPTest.cpp index 7c0f4f93b6..b09a57a45b 100644 --- a/unittests/Support/ConstraintProgramming/CPTest.cpp +++ b/unittests/Support/ConstraintProgramming/CPTest.cpp @@ -110,7 +110,6 @@ TEST_P(ParamSolverTest, SimpleMaxLP) { auto xVal = solver->getValue(x); auto yVal = solver->getValue(y); - auto obj = solver->getObjective(); EXPECT_LE(xVal + yVal, 10 + 1e-6); // Constraint check } @@ -369,7 +368,6 @@ TEST(LinExprOpTest, ChainedAddSub) { // [END AI-generated test cases] // Factories for both solvers -std::unique_ptr makeCbc() { return std::make_unique(); } std::unique_ptr makeGurobi() { return std::make_unique(); From 99a94263cfa394bd27e385bac8b6db1f344ab074 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Sat, 22 Nov 2025 01:04:55 +0100 Subject: [PATCH 22/27] Revert "fix some compiler warnings" This reverts commit f44acf20eeedba5ea6b02f366ed89d918eadca73. --- experimental/include/experimental/Support/SubjectGraph.h | 2 +- experimental/lib/Support/SubjectGraph.cpp | 6 +++--- experimental/tools/elastic-miter/SmvUtils.cpp | 4 ++-- lib/Support/Backedge.cpp | 8 ++++---- .../CreatingPasses/Transforms/GreedySimplifyMergeLike.cpp | 6 +++--- .../lib/CreatingPasses/Transforms/SimplifyMergeLike.cpp | 4 ++-- unittests/Support/ConstraintProgramming/CPTest.cpp | 2 ++ 7 files changed, 17 insertions(+), 15 deletions(-) diff --git a/experimental/include/experimental/Support/SubjectGraph.h b/experimental/include/experimental/Support/SubjectGraph.h index 75fdc2a643..d7798a1d0e 100644 --- a/experimental/include/experimental/Support/SubjectGraph.h +++ b/experimental/include/experimental/Support/SubjectGraph.h @@ -57,7 +57,7 @@ class BaseSubjectGraph { bool isBlackbox = false; void loadBlifFile(std::initializer_list inputs, - const std::string &toAppend = ""); + std::string toAppend = ""); // Helper function to connect the input nodes of the current module // to the output nodes of the preceding module in the subject graph diff --git a/experimental/lib/Support/SubjectGraph.cpp b/experimental/lib/Support/SubjectGraph.cpp index 6bbac91744..6ca7c7832e 100644 --- a/experimental/lib/Support/SubjectGraph.cpp +++ b/experimental/lib/Support/SubjectGraph.cpp @@ -95,7 +95,7 @@ void BaseSubjectGraph::connectInputNodesHelper(ChannelSignals ¤tSignals, // Constructs the file path based on Operation name and parameters, calls the // Blif parser to load the Blif file void BaseSubjectGraph::loadBlifFile(std::initializer_list inputs, - const std::string &toAppend) { + std::string toAppend) { std::string moduleType; std::string fullPath; moduleType = op->getName().getStringRef(); @@ -119,8 +119,8 @@ void BaseSubjectGraph::loadBlifFile(std::initializer_list inputs, } // Assigns signals to the variables in ChannelSignals struct -static void assignSignals(ChannelSignals &signals, Node *node, - const std::string &nodeName) { +void assignSignals(ChannelSignals &signals, Node *node, + const std::string &nodeName) { // If nodeName includes "valid" or "ready", assign it to the respective // signal. If it does not, assign it to the data signals. if (nodeName.find("valid") != std::string::npos) { diff --git a/experimental/tools/elastic-miter/SmvUtils.cpp b/experimental/tools/elastic-miter/SmvUtils.cpp index 30bf153cf3..9634e7fc28 100644 --- a/experimental/tools/elastic-miter/SmvUtils.cpp +++ b/experimental/tools/elastic-miter/SmvUtils.cpp @@ -80,7 +80,7 @@ static int executeWithRedirect(const std::string &command, // Redirect stdout, keep default of stdin and stderr std::string stdoutFileString = stdoutFile.string(); - SmallVector> redirects = { + ArrayRef> redirects = { std::nullopt, stdoutFileString, std::nullopt}; std::string errMsg; @@ -181,4 +181,4 @@ handshake2smv(const std::filesystem::path &mlirPath, return std::make_pair(smvFile, moduleName); } -} // namespace dynamatic::experimental +} // namespace dynamatic::experimental \ No newline at end of file diff --git a/lib/Support/Backedge.cpp b/lib/Support/Backedge.cpp index 6919a6e71e..df258a9fcf 100644 --- a/lib/Support/Backedge.cpp +++ b/lib/Support/Backedge.cpp @@ -57,11 +57,11 @@ Backedge BackedgeBuilder::get(Type resultType, mlir::LocationAttr optionalLoc) { // Create the opearion using either a builder or a rewriter Operation *op; if (rewriter) - op = mlir::UnrealizedConversionCastOp::create(*rewriter, optionalLoc, - resultType, ValueRange{}); + op = rewriter->create( + optionalLoc, resultType, ValueRange{}); else - op = mlir::UnrealizedConversionCastOp::create(*rewriter, optionalLoc, - resultType, ValueRange{}); + op = builder->create( + optionalLoc, resultType, ValueRange{}); edges.push_back(op); return Backedge(op); } diff --git a/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/GreedySimplifyMergeLike.cpp b/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/GreedySimplifyMergeLike.cpp index 2bf16cf475..a254ce49e1 100644 --- a/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/GreedySimplifyMergeLike.cpp +++ b/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/GreedySimplifyMergeLike.cpp @@ -74,8 +74,8 @@ struct DowngradeIndexlessControlMerge // the control merge we are replacing. The merge has the exact same inputs // as the control merge rewriter.setInsertionPoint(cmergeOp); - handshake::MergeOp newMergeOp = handshake::MergeOp::create( - rewriter, cmergeOp.getLoc(), cmergeOp->getOperands()); + handshake::MergeOp newMergeOp = rewriter.create( + cmergeOp.getLoc(), cmergeOp->getOperands()); // We are modifying the operation rewriter.modifyOpInPlace(cmergeOp, [&] { @@ -119,7 +119,7 @@ struct GreedySimplifyMergeLikePass patterns.add(ctx); // Apply our two patterns recursively on all operations in the input module - if (failed(applyPatternsGreedily(mod, std::move(patterns), config))) + if (failed(applyPatternsAndFoldGreedily(mod, std::move(patterns), config))) signalPassFailure(); } }; diff --git a/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/SimplifyMergeLike.cpp b/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/SimplifyMergeLike.cpp index 77ddf5cd8d..a1fec6f7c1 100644 --- a/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/SimplifyMergeLike.cpp +++ b/tutorials/CreatingPasses/lib/CreatingPasses/Transforms/SimplifyMergeLike.cpp @@ -60,8 +60,8 @@ static LogicalResult performSimplification(handshake::FuncOp funcOp, // the control merge we are replacing. The merge has the exact same inputs // as the control merge builder.setInsertionPoint(cmergeOp); - handshake::MergeOp newMergeOp = handshake::MergeOp::create( - builder, cmergeOp.getLoc(), cmergeOp->getOperands()); + handshake::MergeOp newMergeOp = builder.create( + cmergeOp.getLoc(), cmergeOp->getOperands()); // Then, replace the control merge's first result (the selected input) with // the single result of the newly created merge operation diff --git a/unittests/Support/ConstraintProgramming/CPTest.cpp b/unittests/Support/ConstraintProgramming/CPTest.cpp index b09a57a45b..7c0f4f93b6 100644 --- a/unittests/Support/ConstraintProgramming/CPTest.cpp +++ b/unittests/Support/ConstraintProgramming/CPTest.cpp @@ -110,6 +110,7 @@ TEST_P(ParamSolverTest, SimpleMaxLP) { auto xVal = solver->getValue(x); auto yVal = solver->getValue(y); + auto obj = solver->getObjective(); EXPECT_LE(xVal + yVal, 10 + 1e-6); // Constraint check } @@ -368,6 +369,7 @@ TEST(LinExprOpTest, ChainedAddSub) { // [END AI-generated test cases] // Factories for both solvers +std::unique_ptr makeCbc() { return std::make_unique(); } std::unique_ptr makeGurobi() { return std::make_unique(); From 33fc5517bf1cd310995655c90a254ad70736115a Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Sat, 22 Nov 2025 01:26:15 +0100 Subject: [PATCH 23/27] No more lit test --- .github/workflows/ci.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3ba46bc4fa..900eb575b8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -82,13 +82,13 @@ jobs: - name: build run: ./build.sh --release --force - - name: check-dynamatic - if: steps.build.outputs.exit_code == 0 - run: ninja -C build check-dynamatic - - - name: check-dynamatic-experimental - if: steps.build.outputs.exit_code == 0 - run: ninja -C build check-dynamatic-experimental +# - name: check-dynamatic +# if: steps.build.outputs.exit_code == 0 +# run: ninja -C build check-dynamatic +# +# - name: check-dynamatic-experimental +# if: steps.build.outputs.exit_code == 0 +# run: ninja -C build check-dynamatic-experimental - name: integration-test if: steps.build.outputs.exit_code == 0 From 7b931168b03cce5ed1c3914be4305fdb65c261f5 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Sat, 22 Nov 2025 01:35:26 +0100 Subject: [PATCH 24/27] Disable verbose debug output --- tools/translate-llvm-to-std/InferArgTypes.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tools/translate-llvm-to-std/InferArgTypes.cpp b/tools/translate-llvm-to-std/InferArgTypes.cpp index 7aede9b7a8..74f67d9c89 100644 --- a/tools/translate-llvm-to-std/InferArgTypes.cpp +++ b/tools/translate-llvm-to-std/InferArgTypes.cpp @@ -137,8 +137,8 @@ static std::optional processScalarType(CXType clangType) { return processScalarType(clang_getTypedefDeclUnderlyingType(typedefCursor)); } default: { - llvm::errs() << "Type ID of unhandled scalar type: " << clangType.kind - << "\n"; + LLVM_DEBUG(llvm::errs() << "Type ID of unhandled scalar type: " + << clangType.kind << "\n"); return std::nullopt; } @@ -181,7 +181,8 @@ static std::optional fromCXType(CXType type) { } } - llvm::errs() << "Unhandled compound type id: " << type.kind << "\n"; + LLVM_DEBUG(llvm::errs() << "Unhandled compound type id: " << type.kind + << "\n"); // TODO: One important thing to handle in the future is the arguments that // are **passed by reference**. It is probably correct to promote them to // the function return values. @@ -214,9 +215,10 @@ static CXChildVisitResult visitParamDecl(CXCursor cursor, CXCursor parent, if (argType.has_value()) { args->push_back(argType.value()); } else { - llvm::errs() << "Warning - unable to parse " << getCursorSpelling(cursor) - << " with type " - << clang_getCString(clang_getTypeSpelling(type)) << "!\n"; + LLVM_DEBUG(llvm::errs() + << "Warning - unable to parse " << getCursorSpelling(cursor) + << " with type " + << clang_getCString(clang_getTypeSpelling(type)) << "!\n"); } // else: Maybe instead of push nothing here, we should have a ArgType that // is specifically for "I don't know what it is?" From 2ab4d43114b495e17c30b2212d7d4fe5b2e32511 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Sat, 22 Nov 2025 11:29:14 +0100 Subject: [PATCH 25/27] Upload also the err message --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 900eb575b8..b449fbae67 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -126,4 +126,5 @@ jobs: integration-test/*/out*/comp integration-test/*/out*/sim/report.txt integration-test/*/out/dynamatic_out.txt + integration-test/*/out/dynamatic_err.txt From ead24a47ba7e4c4f4cfca78759c80fc79435d731 Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Sat, 22 Nov 2025 12:17:22 +0100 Subject: [PATCH 26/27] Fix reversed --- tools/translate-llvm-to-std/TranslateLLVMToStd.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp index f1a5397c02..7f72aa2e01 100644 --- a/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp +++ b/tools/translate-llvm-to-std/TranslateLLVMToStd.cpp @@ -614,12 +614,12 @@ void TranslateLLVMToStd::translateGEPInst(llvm::GetElementPtrInst *gepInst) { // (i.e., 4 in the example above). int64_t constInt = *constVal->getUniqueInteger().getRawData(); - assert(actualBaseElementWidth % (currBaseElementBitWidth * constInt) == + assert((currBaseElementBitWidth * constInt) % actualBaseElementWidth == 0 && "Incorrect alignment!"); unsigned actualAdvanceValue = - actualBaseElementWidth / (currBaseElementBitWidth * constInt); + (currBaseElementBitWidth * constInt) / actualBaseElementWidth; auto byteAlignedConstantValue = arith::ConstantOp::create( builder, UnknownLoc::get(ctx), From f4b41d4d4723c0ae87a8adccbfac9abe6993429c Mon Sep 17 00:00:00 2001 From: Jiahui Xu Date: Thu, 27 Nov 2025 12:23:55 +0100 Subject: [PATCH 27/27] Fix APInt cast of SExt's input --- experimental/tools/frequency-profiler/Simulator.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/experimental/tools/frequency-profiler/Simulator.cpp b/experimental/tools/frequency-profiler/Simulator.cpp index abbcfbb4c2..921871c3d1 100644 --- a/experimental/tools/frequency-profiler/Simulator.cpp +++ b/experimental/tools/frequency-profiler/Simulator.cpp @@ -227,7 +227,8 @@ LogicalResult StdExecuter::execute(mlir::arith::ShLIOp, std::vector &in, auto toShift = any_cast(in[0]).getSExtValue(); auto shiftAmount = any_cast(in[1]).getZExtValue(); auto shifted = - APInt(any_cast(in[0]).getBitWidth(), toShift << shiftAmount); + APInt(any_cast(in[0]).getBitWidth(), toShift << shiftAmount, + /* isSigned = */ true, /*implicitTrunc = */ true); out[0] = shifted; return success(); } @@ -236,11 +237,9 @@ LogicalResult StdExecuter::execute(mlir::arith::ShRSIOp, std::vector &in, std::vector &out) { auto toShift = any_cast(in[0]).getSExtValue(); auto shiftAmount = any_cast(in[1]).getZExtValue(); - auto shifted = APInt(any_cast(in[0]).getBitWidth(), toShift >> shiftAmount, /* isSigned = */ true, /*implicitTrunc = */ true); - out[0] = shifted; return success(); } @@ -249,6 +248,7 @@ LogicalResult StdExecuter::execute(mlir::arith::ShRUIOp, std::vector &in, std::vector &out) { auto toShift = any_cast(in[0]).getZExtValue(); auto shiftAmount = any_cast(in[1]).getZExtValue(); + auto shifted = APInt(any_cast(in[0]).getBitWidth(), toShift >> shiftAmount); out[0] = shifted;