diff --git a/lib/Transforms/FuncPreprocess.cpp b/lib/Transforms/FuncPreprocess.cpp
index b92dedb6..ae6d1865 100644
--- a/lib/Transforms/FuncPreprocess.cpp
+++ b/lib/Transforms/FuncPreprocess.cpp
@@ -98,15 +98,15 @@ struct AddIRaisePattern : public OpRewritePattern {
       return success();
     }
 
-    if (auto rhs = add.getRhs().getDefiningOp();
-        isValidDim(add.getLhs())) {
+    auto rhs = add.getRhs().getDefiningOp();
+    if (rhs != nullptr && isValidDim(add.getLhs())) {
       r.replaceOpWithNewOp(
           add, r.getAffineDimExpr(0) + rhs.value(), add.getLhs());
       return success();
     }
 
-    if (auto lhs = add.getLhs().getDefiningOp();
-        isValidDim(add.getRhs())) {
+    auto lhs = add.getLhs().getDefiningOp();
+    if (lhs != nullptr && isValidDim(add.getRhs())) {
       r.replaceOpWithNewOp(
           add, lhs.value() + r.getAffineDimExpr(0), add.getRhs());
       return success();
@@ -125,17 +125,17 @@ struct MulIRaisePattern : public OpRewritePattern {
                                 PatternRewriter &r) const override {
     r.setInsertionPoint(mul);
 
-    if (auto rhs = mul.getRhs().getDefiningOp();
-        isValidDim(mul.getLhs())) {
+    auto rhs = mul.getRhs().getDefiningOp();
+    if (rhs != nullptr && isValidDim(mul.getLhs())) {
       r.replaceOpWithNewOp(
-          mul, r.getAffineDimExpr(0) + rhs.value(), mul.getLhs());
+          mul, r.getAffineDimExpr(0) * rhs.value(), mul.getLhs());
       return success();
     }
 
-    if (auto lhs = mul.getLhs().getDefiningOp();
-        isValidDim(mul.getRhs())) {
+    auto lhs = mul.getLhs().getDefiningOp();
+    if (lhs != nullptr && isValidDim(mul.getRhs())) {
       r.replaceOpWithNewOp(
-          mul, lhs.value() + r.getAffineDimExpr(0), mul.getRhs());
+          mul, lhs.value() * r.getAffineDimExpr(0), mul.getRhs());
       return success();
     }
     return failure();
diff --git a/lib/Transforms/Memory/SimplifyCopy.cpp b/lib/Transforms/Memory/SimplifyCopy.cpp
index 4e054b2b..2efa63a1 100644
--- a/lib/Transforms/Memory/SimplifyCopy.cpp
+++ b/lib/Transforms/Memory/SimplifyCopy.cpp
@@ -26,6 +26,10 @@ struct SplitElementwiseGenericOp : public OpRewritePattern {
         op.getNumOutputs() == 1) {
       auto &input = op->getOpOperand(0);
       auto &output = op->getOpOperand(1);
+      if (input.get().getType() != output.get().getType()) {
+        LLVM_DEBUG(llvm::dbgs() << "\nCurrent generic: " << op << "\n";);
+        return failure();
+      }
       if (input.get() == output.get())
         return failure();
 
diff --git a/samples/pytorch/lenet/lenet.py b/samples/pytorch/lenet/lenet.py
index 9b612f6c..d005ec3c 100644
--- a/samples/pytorch/lenet/lenet.py
+++ b/samples/pytorch/lenet/lenet.py
@@ -6,6 +6,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch_mlir
+import torch_mlir.torchscript
 
 
 class LeNet(nn.Module):
@@ -31,7 +32,7 @@ def forward(self, x):
         return out
 
 
-module = torch_mlir.compile(LeNet(), torch.ones(
-    1, 3, 32, 32), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
+module = torch_mlir.torchscript.compile(LeNet(), torch.ones(
+    1, 3, 32, 32), output_type=torch_mlir.torchscript.OutputType.LINALG_ON_TENSORS)
 print(module)
 
diff --git a/samples/pytorch/resnet18/resnet18.py b/samples/pytorch/resnet18/resnet18.py
index 2b683e5f..fd8fcdf3 100644
--- a/samples/pytorch/resnet18/resnet18.py
+++ b/samples/pytorch/resnet18/resnet18.py
@@ -6,6 +6,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch_mlir
+import torch_mlir.torchscript
 
 
 class BasicBlock(nn.Module):
@@ -71,7 +72,7 @@ def ResNet18():
     return ResNet(BasicBlock, [2, 2, 2, 2])
 
 
-module = torch_mlir.compile(ResNet18(), torch.ones(
-    1, 3, 32, 32), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
+module = torch_mlir.torchscript.compile(ResNet18(), torch.ones(
+    1, 3, 32, 32), output_type=torch_mlir.torchscript.OutputType.LINALG_ON_TENSORS)
 print(module)
 
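Note on the two sample changes above: newer torch-mlir releases expose the TorchScript-based compile entry point and its OutputType enum under torch_mlir.torchscript rather than the torch_mlir top level, which is what the lenet.py and resnet18.py updates track. As a minimal standalone sketch of the updated call pattern (the toy model below is illustrative only and is not part of this patch):

    import torch
    import torch_mlir.torchscript as ts

    # Illustrative model only; any scriptable nn.Module works the same way.
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

    # Lower to linalg-on-tensors IR, matching the samples in this patch.
    module = ts.compile(model, torch.ones(1, 3, 32, 32),
                        output_type=ts.OutputType.LINALG_ON_TENSORS)
    print(module)
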
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index b217d4a7..d521d092 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(pyscalehls)
 add_subdirectory(scalehls-opt)
+add_subdirectory(scalehls-lsp-server)
 add_subdirectory(scalehls-translate)
diff --git a/tools/scalehls-lsp-server/CMakeLists.txt b/tools/scalehls-lsp-server/CMakeLists.txt
new file mode 100644
index 00000000..c14a6ab8
--- /dev/null
+++ b/tools/scalehls-lsp-server/CMakeLists.txt
@@ -0,0 +1,19 @@
+project(scalehls-lsp-server)
+
+add_executable(${PROJECT_NAME}
+  scalehls-lsp-server.cpp
+)
+
+# Link all standard MLIR dialect and conversion libs.
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
+get_property(circt_dialect_libs GLOBAL PROPERTY CIRCT_DIALECT_LIBS)
+target_link_libraries(${PROJECT_NAME}
+  PRIVATE
+  MLIRLspServerLib
+
+  ${circt_dialect_libs}
+
+  ${dialect_libs}
+  ${conversion_libs}
+)
diff --git a/tools/scalehls-lsp-server/scalehls-lsp-server.cpp b/tools/scalehls-lsp-server/scalehls-lsp-server.cpp
new file mode 100644
index 00000000..d67f0cc1
--- /dev/null
+++ b/tools/scalehls-lsp-server/scalehls-lsp-server.cpp
@@ -0,0 +1,15 @@
+#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h"
+#include "scalehls/InitAllDialects.h"
+
+using namespace mlir;
+
+static int asMainReturnCode(LogicalResult r) {
+  return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE;
+}
+
+int main(int argc, char *argv[]) {
+  DialectRegistry registry;
+  scalehls::registerAllDialects(registry);
+
+  return asMainReturnCode(MlirLspServerMain(argc, argv, registry));
+}
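Note on the new tool: scalehls-lsp-server is a thin wrapper around MlirLspServerMain, so like upstream mlir-lsp-server it speaks the Language Server Protocol over stdin/stdout; the only ScaleHLS-specific part is the dialect registration done through scalehls::registerAllDialects. Assuming the MLIR VS Code extension is installed, pointing its mlir.server_path setting at the built binary (for example build/bin/scalehls-lsp-server, though the exact path depends on the build directory) enables it for .mlir files; any other LSP-capable editor can launch the executable directly as well.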