20 changes: 10 additions & 10 deletions lib/Transforms/FuncPreprocess.cpp
@@ -98,15 +98,15 @@ struct AddIRaisePattern : public OpRewritePattern<arith::AddIOp> {
       return success();
     }
 
-    if (auto rhs = add.getRhs().getDefiningOp<arith::ConstantIndexOp>();
-        isValidDim(add.getLhs())) {
+    auto rhs = add.getRhs().getDefiningOp<arith::ConstantIndexOp>();
+    if (rhs != nullptr && isValidDim(add.getLhs())) {
       r.replaceOpWithNewOp<mlir::AffineApplyOp>(
           add, r.getAffineDimExpr(0) + rhs.value(), add.getLhs());
       return success();
     }
 
-    if (auto lhs = add.getLhs().getDefiningOp<arith::ConstantIndexOp>();
-        isValidDim(add.getRhs())) {
+    auto lhs = add.getLhs().getDefiningOp<arith::ConstantIndexOp>();
+    if (lhs != nullptr && isValidDim(add.getRhs())) {
       r.replaceOpWithNewOp<mlir::AffineApplyOp>(
           add, lhs.value() + r.getAffineDimExpr(0), add.getRhs());
       return success();
@@ -125,17 +125,17 @@ struct MulIRaisePattern : public OpRewritePattern<arith::MulIOp> {
                                 PatternRewriter &r) const override {
     r.setInsertionPoint(mul);
 
-    if (auto rhs = mul.getRhs().getDefiningOp<arith::ConstantIndexOp>();
-        isValidDim(mul.getLhs())) {
+    auto rhs = mul.getRhs().getDefiningOp<arith::ConstantIndexOp>();
+    if (rhs != nullptr && isValidDim(mul.getLhs())) {
       r.replaceOpWithNewOp<mlir::AffineApplyOp>(
-          mul, r.getAffineDimExpr(0) * rhs.value(), mul.getLhs());
+          mul, r.getAffineDimExpr(0) + rhs.value(), mul.getLhs());
       return success();
     }
 
-    if (auto lhs = mul.getLhs().getDefiningOp<arith::ConstantIndexOp>();
-        isValidDim(mul.getRhs())) {
+    auto lhs = mul.getLhs().getDefiningOp<arith::ConstantIndexOp>();
+    if (lhs != nullptr && isValidDim(mul.getRhs())) {
       r.replaceOpWithNewOp<mlir::AffineApplyOp>(
-          mul, lhs.value() * r.getAffineDimExpr(0), mul.getRhs());
+          mul, lhs.value() + r.getAffineDimExpr(0), mul.getRhs());
       return success();
     }
     return failure();
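Note on the change above: in a C++17 if-with-initializer, only the expression after the semicolon is the branch condition, so the old form could enter the branch even when getDefiningOp returned a null handle and then call .value() on it. The rewritten form hoists the lookup and tests it explicitly. A minimal standalone sketch of the two shapes, where ConstOp and maybeGetConstant are hypothetical stand-ins rather than ScaleHLS code:

#include <cstdio>

// Hypothetical stand-in for an op handle that may be "null".
struct ConstOp {
  int value() const { return 4; }
};

// Hypothetical stand-in for getDefiningOp<arith::ConstantIndexOp>(): may fail.
static ConstOp *maybeGetConstant(bool found) {
  static ConstOp op;
  return found ? &op : nullptr;
}

int main() {
  bool otherOperandIsValidDim = true;

  // Old shape: the branch condition is only the part after the ';', so the
  // body can run with rhs == nullptr, and rhs->value() would be undefined.
  if (ConstOp *rhs = maybeGetConstant(false); otherOperandIsValidDim)
    std::printf("old form entered the branch, rhs is %s\n",
                rhs ? "non-null" : "null");

  // New shape from the diff: hoist the lookup and test it explicitly.
  ConstOp *rhs = maybeGetConstant(false);
  if (rhs != nullptr && otherOperandIsValidDim)
    std::printf("raising with constant %d\n", rhs->value());
  else
    std::printf("new form skipped: no constant operand\n");
  return 0;
}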
4 changes: 4 additions & 0 deletions lib/Transforms/Memory/SimplifyCopy.cpp
@@ -26,6 +26,10 @@ struct SplitElementwiseGenericOp : public OpRewritePattern<linalg::GenericOp> {
         op.getNumOutputs() == 1) {
       auto &input = op->getOpOperand(0);
       auto &output = op->getOpOperand(1);
+      if (input.get().getType() != output.get().getType()) {
+        LLVM_DEBUG(llvm::dbgs() << "\nCurrent generic: " << op << "\n";);
+        return failure();
+      }
       if (input.get() == output.get())
         return failure();
 
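Note on the guard added above: it rejects generic ops whose input and output operand types differ and logs the op through LLVM_DEBUG. The message compiles away in NDEBUG builds and prints only when the tool runs with -debug or -debug-only=<DEBUG_TYPE>. A minimal sketch of that wiring, assuming a hypothetical DEBUG_TYPE of "simplify-copy" (the actual tag in SimplifyCopy.cpp may differ):

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

// LLVM_DEBUG references DEBUG_TYPE, so it must be defined before use.
#define DEBUG_TYPE "simplify-copy"

int main(int argc, char **argv) {
  // Registers and parses LLVM's built-in options, including -debug and
  // -debug-only=<tag> in assertion-enabled builds.
  llvm::cl::ParseCommandLineOptions(argc, argv);

  // Prints only when the binary is built without NDEBUG and run with
  // -debug or -debug-only=simplify-copy; otherwise this is a no-op.
  LLVM_DEBUG(llvm::dbgs() << "input/output types differ; skipping rewrite\n");

  llvm::outs() << "done\n";
  return 0;
}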
5 changes: 3 additions & 2 deletions samples/pytorch/lenet/lenet.py
@@ -6,6 +6,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch_mlir
+import torch_mlir.torchscript
 
 
 class LeNet(nn.Module):
@@ -31,7 +32,7 @@ def forward(self, x):
         return out
 
 
-module = torch_mlir.compile(LeNet(), torch.ones(
-    1, 3, 32, 32), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
+module = torch_mlir.torchscript.compile(LeNet(), torch.ones(
+    1, 3, 32, 32), output_type=torch_mlir.torchscript.OutputType.LINALG_ON_TENSORS)
 
 print(module)
5 changes: 3 additions & 2 deletions samples/pytorch/resnet18/resnet18.py
@@ -6,6 +6,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch_mlir
+import torch_mlir.torchscript
 
 
 class BasicBlock(nn.Module):
@@ -71,7 +72,7 @@ def ResNet18():
     return ResNet(BasicBlock, [2, 2, 2, 2])
 
 
-module = torch_mlir.compile(ResNet18(), torch.ones(
-    1, 3, 32, 32), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
+module = torch_mlir.torchscript.compile(ResNet18(), torch.ones(
+    1, 3, 32, 32), output_type=torch_mlir.torchscript.OutputType.LINALG_ON_TENSORS)
 
 print(module)
1 change: 1 addition & 0 deletions tools/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(pyscalehls)
 add_subdirectory(scalehls-opt)
+add_subdirectory(scalehls-lsp-server)
 add_subdirectory(scalehls-translate)
19 changes: 19 additions & 0 deletions tools/scalehls-lsp-server/CMakeLists.txt
@@ -0,0 +1,19 @@
+project(scalehls-lsp-server)
+
+add_executable(${PROJECT_NAME}
+  scalehls-lsp-server.cpp
+)
+
+# Link all standard MLIR dialect and conversion libs.
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
+get_property(circt_dialect_libs GLOBAL PROPERTY CIRCT_DIALECT_LIBS)
+target_link_libraries(${PROJECT_NAME}
+  PRIVATE
+  MLIRLspServerLib
+
+  ${circt_dialect_libs}
+
+  ${dialect_libs}
+  ${conversion_libs}
+)
15 changes: 15 additions & 0 deletions tools/scalehls-lsp-server/scalehls-lsp-server.cpp
@@ -0,0 +1,15 @@
+#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h"
+#include "scalehls/InitAllDialects.h"
+
+using namespace mlir;
+
+static int asMainReturnCode(LogicalResult r) {
+  return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE;
+}
+
+int main(int argc, char *argv[]) {
+  DialectRegistry registry;
+  scalehls::registerAllDialects(registry);
+
+  return asMainReturnCode(MlirLspServerMain(argc, argv, registry));
+}