Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
// https://github.com/openxla/xla/blob/main/LICENSE for the license
// information.
//
// Changes are copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
// Changes are copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
//
//===----------------------------------------------------------------------===//
///
/// Implementation of a pass to convert stablehlo control flow ops to scf ops.
///
//===----------------------------------------------------------------------===//

#include "mlir-tensorrt/Conversion/Passes.h"
#include "mlir-tensorrt/Transforms/Transforms.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -538,4 +538,4 @@ struct StablehloToScfPass
}
}
};
} // namespace
} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ convertFuncRegionTypes(RewriterBase &rewriter, FunctionOpInterface funcOp,
/// change are RankedTensorTypes where the encoding has been updated. Therefore,
/// we only insert `tensor.cast` operations to cast the values back to their
/// original types.
struct LogicalResult convertFuncUsers(RewriterBase &rewriter,
static LogicalResult convertFuncUsers(RewriterBase &rewriter,
FunctionOpInterface func,
const SymbolUserMap &userMap) {
OpBuilder::InsertionGuard g(rewriter);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//===- PostClusteringValidation.cpp ---------------------------------------===//
//
// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
// SPDX-FileCopyrightText: Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
Expand All @@ -27,6 +27,7 @@
#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
#include "mlir-tensorrt/Dialect/CUDA/IR/CUDADialect.h"
#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h"
#include "mlir-tensorrt/Dialect/TensorRTRuntime/IR/TensorRTRuntime.h"
#include "mlir-tensorrt/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ constexpr std::pair<int64_t, int64_t> OpToHostParametersOffsetAndSize() {
// Simple macro for creating the appearance of a table.
#define CASE(OpType, start, size) \
if constexpr (std::is_same_v<T, OpType>) \
return {start, size}
return { start, size }

CASE(stablehlo::DynamicIotaOp, 0, 1);
// Note that `stablehlo.dynamic_slice` and `stablehlo.dynamic_update_slice`
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//===- MemRefCastElimination.cpp ------------------------------------------===//
//
// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
// SPDX-FileCopyrightText: Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -41,7 +41,7 @@ using namespace mlir;
/// "then" and "else" terminator operands are produced by "compatible" cast
/// operations that can be moved to act on the result of the if. The types of
/// the new `scf.if` operation are returned as well.
FailureOr<std::pair<SmallVector<int64_t>, SmallVector<Type>>>
static FailureOr<std::pair<SmallVector<int64_t>, SmallVector<Type>>>
isCastEliminationCandidate(scf::IfOp op) {
SmallVector<int64_t> resultIndices;
SmallVector<Type> newResultTypes;
Expand Down Expand Up @@ -139,4 +139,4 @@ class MemRefCastEliminationPass
}
}
};
} // namespace
} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ static std::string printWithoutRegions(Operation *op) {
/// Unrolls `op` if its trip count is static and less than `unrollThreshold`.
/// Returns `success()` if the loop is unrolled or ignored, `failure()` if the
/// transformation fails.
LogicalResult unrollForLoopWithStaticTripCount(IRRewriter &rewriter,
scf::ForOp op,
uint64_t unrollThreshold) {
static LogicalResult
unrollForLoopWithStaticTripCount(IRRewriter &rewriter, scf::ForOp op,
uint64_t unrollThreshold) {
std::optional<int64_t> tripCount = getConstantTripCount(op);
if (!tripCount)
return success();
Expand Down
4 changes: 1 addition & 3 deletions mlir-tensorrt/compiler/test/lib/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
add_mlir_tensorrt_test_library(MLIRTensorRTTestTensorKindAnalysis
TestTensorKindAnalysis.cpp
add_mlir_tensorrt_test_library(MLIRTensorRTCompilerAnalysisTestPasses
TestBoundsAnalysis.cpp

LINK_LIBS PUBLIC
MLIRTensorRTAnalysis
MLIRTensorRTPlanAnalysis

MLIR_LIBS PUBLIC
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//===- TestBoundsAnalysis.cpp ---------------------------------------------===//
//
// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
//
//===----------------------------------------------------------------------===//
///
Expand All @@ -20,6 +20,10 @@
using namespace mlir;
using namespace mlir::dataflow;

namespace mlir {
void registerTestBoundsAnalysisPass();
}

/// Print out the lattice information for the given value `v`.
template <typename T>
static void printLatticeInfo(llvm::raw_ostream &os, Value v,
Expand Down Expand Up @@ -135,9 +139,7 @@ struct TestTensorValueBoundsAnalysisPass
};
} // namespace

namespace mlir {
void registerTestBoundsAnalysisPass() {
void mlir::registerTestBoundsAnalysisPass() {
PassRegistration<TestBoundsAnalysisPass>();
PassRegistration<TestTensorValueBoundsAnalysisPass>();
}
} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,19 @@ using namespace llvm;

#ifdef MLIR_TRT_ENABLE_TESTING
namespace mlir {
namespace tensorrt {
void registerTestTensorKindAnalysisPass();
void registerTestTensorRTShapeInferencePass();
} // namespace tensorrt

#ifdef MLIR_TRT_ENABLE_HLO
void registerTestBoundsAnalysisPass();
#endif // MLIR_TRT_ENABLE_HLO
} // namespace mlir

static void registerTestPasses() {
mlir::registerTestTensorKindAnalysisPass();
mlir::registerTestTensorRTShapeInferencePass();
mlir::tensorrt::registerTestTensorKindAnalysisPass();
mlir::tensorrt::registerTestTensorRTShapeInferencePass();
IF_MLIR_TRT_ENABLE_HLO({ mlir::registerTestBoundsAnalysisPass(); });
}
#endif // MLIR_TRT_ENABLE_TESTING
Expand Down
10 changes: 5 additions & 5 deletions mlir-tensorrt/executor/lib/CAPI/Common/Common.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//===- Common.cpp -----0---------------------------------------------------===//
//===- Common.cpp ---------------------------------------------------------===//
//
// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
// SPDX-FileCopyrightText: Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -398,7 +398,7 @@ mtrtFunctionSignatureGetNumOutputArgs(MTRT_FunctionSignature signature,
return mtrtStatusGetOk();
}

MTRT_Status getTypeHelper(TypeUnionView typeUnionView, MTRT_Type *type) {
static MTRT_Status getTypeHelper(TypeUnionView typeUnionView, MTRT_Type *type) {
// Allocate the TypeUnion object, populate it by moving in the
// concrete object, and release it to be owned by the CAPI object.
auto typeUnion = std::make_unique<mtrt::flat::TypeUnion>();
Expand Down Expand Up @@ -466,8 +466,8 @@ mtrtFunctionSignatureGetShapeFuncName(MTRT_FunctionSignature signature,
return mtrtStatusGetOk();
}

MTRT_Status getBoundsHelper(BoundsUnionView boundsUnionView,
MTRT_Bounds *bounds) {
static MTRT_Status getBoundsHelper(BoundsUnionView boundsUnionView,
MTRT_Bounds *bounds) {
// Allocate the BoundsUnion object, populate it by moving in the
// concrete object, and release it to be owned by the CAPI object.
auto boundsUnion = std::make_unique<mtrt::flat::BoundsUnion>();
Expand Down
4 changes: 2 additions & 2 deletions mlir-tensorrt/executor/lib/Runtime/API/API.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//===- Executable.cpp ------ ----------------------------------------------===//
//
// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
// SPDX-FileCopyrightText: Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -549,7 +549,7 @@ struct CompletionToken
// completes. userData is a heap-allocated pointer to
// std::shared_ptr<CompletionToken<T>>. Implementation posts a small job to IO
// thread pool and returns.
void cuda_event_host_callback(void *userData) {
static void cuda_event_host_callback(void *userData) {
// // userData is pointer-to-heap-allocated Ref<CompletionToken>.
Ref<CompletionToken> *tokenPtr =
static_cast<Ref<CompletionToken> *>(userData);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ int32_t mtrt_cuda_set_active_device(int32_t device) {
return device;
}

int32_t mtrt_get_device(int32_t device) { return device; }
int32_t mtrt_cuda_get_device(int32_t device) { return device; }

CUstream mtrt_cuda_stream_create() {
CUstream stream;
Expand Down Expand Up @@ -108,12 +108,6 @@ void mtrt_cuda_free(CUstream stream, void *ptr, int8_t isHostPinned,
HANDLE_CUDART_ERROR(cudaFreeAsync(ptr, stream), );
}

CUfunction mtrt_cumodule_load_func(CUmodule module, const char *funcName) {
CUfunction func;
cuModuleGetFunction(&func, module, funcName);
return func;
}

static StatusOr<std::string> getDeviceArch(int32_t deviceNumber) {
CUdevice deviceID;
RETURN_ERROR_WITH_MSG_IF_CUDADRV_ERROR(cuDeviceGet(&deviceID, deviceNumber),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,11 @@ registerCudaMemoryManagementOps(sol::state_view &lua,
};
}

#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmissing-prototypes"
#endif

namespace mtrt {
void registerLuaCudaRuntimeExtension() {
registerLuaRuntimeExtension(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//===- CoreModule.cpp -----------------------------------------------------===//
//
// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
// SPDX-FileCopyrightText: Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -43,6 +43,10 @@
using namespace mtrt;
using namespace mtrt;

namespace mtrt {
void registerLuaCoreRuntimeExtension();
}

//===----------------------------------------------------------------------===//
// Templated helpers
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,11 @@ static void registerExecutorCuBLASModuleLuaRuntimeMethods(
};
}

#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmissing-prototypes"
#endif

namespace mtrt {
void registerLuaCublasRuntimeExtension() {
registerLuaRuntimeExtension(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,11 @@ static void registerExecutorTensorRTModuleLuaRuntimeMethods(
};
}

#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmissing-prototypes"
#endif

namespace mtrt {
void registerLuaTensorRTRuntimeExtension() {
registerLuaRuntimeExtension(
Expand Down
2 changes: 1 addition & 1 deletion mlir-tensorrt/executor/lib/Support/DeviceInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ static Status makeCudaStringError(cudaError_t errCode,
}
#endif // MLIR_TRT_ENABLE_CUDA

static StatusOr<DeviceInfo>
[[maybe_unused]] static StatusOr<DeviceInfo>
getDeviceInformationFromHostImpl(int cudaDeviceOridinal) {
#ifdef MLIR_TRT_ENABLE_CUDA
cudaDeviceProp properties;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//===- TranslateToRuntimeExecutable.cpp -----------------------------------===//
//
// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
// SPDX-FileCopyrightText: Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -479,7 +479,7 @@ static mtrt::flat::FunctionSignatureT generateSignature() {
return signature;
}

LogicalResult
static LogicalResult
translateBoundsIfPresent(FunctionOpInterface func, unsigned argIndex,
mtrt::flat::FunctionSignatureT &signature,
bool isInput) {
Expand Down
10 changes: 6 additions & 4 deletions mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//===- BufferizationTestPass.cpp ------------------------------------------===//
//
// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
//
//===----------------------------------------------------------------------===//
///
Expand All @@ -22,6 +22,10 @@
using namespace mlir;
using namespace mlir::executor;

namespace mlir::executor {
void registerTestExecutorBufferizePass();
}

namespace {
class ExecutorBufferizationTestPass
: public PassWrapper<ExecutorBufferizationTestPass,
Expand Down Expand Up @@ -55,8 +59,7 @@ class ExecutorBufferizationTestPass
};
} // namespace

namespace mlir::executor {
void registerTestExecutorBufferizePass() {
void executor::registerTestExecutorBufferizePass() {
PassRegistration<ExecutorBufferizationTestPass>();

PassPipelineRegistration<> executorBufferizationPipeline(
Expand All @@ -71,4 +74,3 @@ void registerTestExecutorBufferizePass() {
pm.addPass(createCanonicalizerPass());
});
}
} // namespace mlir::executor
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//===- TestClustering.cpp ------------------------------------------------===//
//
// Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2023-2025, NVIDIA CORPORATION. All rights reserved.
//
//===----------------------------------------------------------------------===//
///
Expand All @@ -23,8 +23,12 @@

using namespace mlir;

scf::ExecuteRegionOp createScfRegionOpFromCluster(const Cluster &cluster,
RewriterBase &rewriter) {
namespace mlir::executor {
void registerTestClusteringTransformPass();
}

static scf::ExecuteRegionOp
createScfRegionOpFromCluster(const Cluster &cluster, RewriterBase &rewriter) {
return cast<scf::ExecuteRegionOp>(mlir::createRegionOpFromCluster(
cluster, rewriter,
[](OpBuilder &b, Location loc, TypeRange types, Attribute target) {
Expand Down Expand Up @@ -215,8 +219,6 @@ class TestClusteringPass
};
} // namespace

namespace mlir::executor {
void registerTestClusteringTransformPass() {
void executor::registerTestClusteringTransformPass() {
PassRegistration<TestClusteringPass>();
}
} // namespace mlir::executor
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ nvinfer1::ILayer *NvInferNetworkEncoder::addDequantizeLayer(
#endif
}

nvinfer1::IFillLayer *populateFillLayerParameters(
static nvinfer1::IFillLayer *populateFillLayerParameters(
nvinfer1::IFillLayer *layer, const nvinfer1::Dims &staticShape,
nvinfer1::ITensor *dynamicShape, std::optional<double> alpha,
std::optional<double> beta, nvinfer1::ITensor *dynamicAlpha,
Expand Down
3 changes: 2 additions & 1 deletion mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2011,7 +2011,8 @@ void tensorrt::IfOp::build(
/// https://docs.nvidia.com/deeplearning/tensorrt/operators/index.html#layers-flow-control-constructs
/// However, it is missing some checks on convolution/activation/fill/unary ops
/// and therefore may give false positives.
bool isOperationSupportedInControlFlowBranchRegion(TensorRTOpInterface op) {
static bool
isOperationSupportedInControlFlowBranchRegion(TensorRTOpInterface op) {
return !isa<PaddingOp, DeconvolutionOp, ParametricReLUOp, PoolingOp,
RaggedSoftMaxOp, ResizeNearestOp, ResizeLinearOp, ResizeCubicOp>(
op);
Expand Down
Loading