From 57c5c8df30969e0ab411e1a2f80cb75cee22732b Mon Sep 17 00:00:00 2001 From: Yizhuo Zhang Date: Fri, 31 Oct 2025 16:28:38 -0700 Subject: [PATCH 1/2] Add AttentionOp --- mlir-tensorrt/build_tools/docker/Dockerfile | 2 +- .../TensorRT/IR/TensorRTEnums.td | 38 ++++ .../TensorRT/IR/TensorRTOps.td | 168 ++++++++++++++++++ .../IR/TensorRTVersionCompatibility.cpp | 13 ++ 4 files changed, 220 insertions(+), 1 deletion(-) diff --git a/mlir-tensorrt/build_tools/docker/Dockerfile b/mlir-tensorrt/build_tools/docker/Dockerfile index bb2996369..cf8e5df11 100644 --- a/mlir-tensorrt/build_tools/docker/Dockerfile +++ b/mlir-tensorrt/build_tools/docker/Dockerfile @@ -35,7 +35,7 @@ case "${LINUX_DISTRO}" in dnf install -y \ which wget gcc zlib-devel bzip2 bzip2-devel readline-devel sqlite \ sqlite-devel xz xz-devel libffi-devel curl git ncurses-devel \ - openssh-clients libcudnn8-devel zip jq \ + openssh-clients zip jq \ protobuf-compiler autoconf automake libtool dnf-plugins-core cmake dnf config-manager --set-enabled powertools dnf -y install gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTEnums.td b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTEnums.td index 0bb4e91fd..4d4fd144e 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTEnums.td +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTEnums.td @@ -378,4 +378,42 @@ def TensorRT_ScatterMode : TensorRT_I32EnumAttr< def TensorRT_ScatterModeAttr : TensorRT_EnumAttr{ } +def TensorRT_AttentionNormalizationOp : TensorRT_I32EnumAttr< + "AttentionNormalizationOp", "", + [ + I32EnumAttrCase<"kNONE", 0>, + I32EnumAttrCase<"kSOFTMAX", 1> + ]> +{ + let cppNamespace = "::mlir::tensorrt"; + let genSpecializedAttr = 0; +} + +def TensorRT_AttentionNormalizationOpAttr : TensorRT_EnumAttr{ +} + +def TensorRT_DataType : TensorRT_I32EnumAttr< + "DataType", "", + [ + I32EnumAttrCase<"kFLOAT", 0>, + I32EnumAttrCase<"kHALF", 1>, + I32EnumAttrCase<"kINT8", 2>, + I32EnumAttrCase<"kINT32", 3>, + I32EnumAttrCase<"kBOOL", 4>, + I32EnumAttrCase<"kUINT8", 5>, + I32EnumAttrCase<"kFP8", 6>, + I32EnumAttrCase<"kBF16", 7>, + I32EnumAttrCase<"kINT64", 8>, + I32EnumAttrCase<"kINT4", 9>, + I32EnumAttrCase<"kFP4", 10>, + I32EnumAttrCase<"kE8M0", 11> + ]> +{ + let cppNamespace = "::mlir::tensorrt"; + let genSpecializedAttr = 0; +} + +def TensorRT_DataTypeAttr : TensorRT_EnumAttr{ +} + #endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTENUMS diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td index e11ef94e6..20b012e1e 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td @@ -4432,4 +4432,172 @@ def TensorRT_ScatterElementsOp : TensorRT_Op<"scatter_elements", }]; } +//===----------------------------------------------------------------------===// +// AttentionOp +//===----------------------------------------------------------------------===// + +def TensorRT_AttentionOp : TensorRT_Op<"attention", + [Pure, AttrSizedOperandSegments, TensorRTInferTensorResultTypes, + AllElementTypesMatch<["query", "key", "value"]>, + AllRanksMatch<["query", "key", "value"]>]>{ + let summary = "TensorRT attention (IAttention) operation"; + let description = [{ + The `tensorrt.attention` operation 
implements a fused attention mechanism + that consumes query, key, and value tensors. The operation implicitly includes + two matrix multiplication layers (BMM1 and BMM2) and a normalization operation + (typically softmax). + + By default, TensorRT will try to use a single fused kernel for better efficiency. + The operation can optionally be decomposed into multiple kernels if no fused + kernel is available by setting `decomposable` to true. + + #### Architecture: + + ``` + Query Key Value Mask (optional) NormalizationQuantizeScale (optional) + | | | | | + | Transpose | | | + | | | | | + ----BMM1---- | | | + | | | | + *--------------------------- | + | | | + Normalization | | + | | | + *------------------------------------------------ + | | + -------BMM2------ + | + Output + ``` + + #### Inputs: + + - Query: tensor of type f32, f16, or bf16 with shape + [batchSize, numHeadsQuery, sequenceLengthQuery, dimHead] + - Key: tensor of type f32, f16, or bf16 with shape + [batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead] + - Value: tensor of type f32, f16, or bf16 with shape + [batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead] + - Mask (optional): tensor of type i1 or same type as BMM1 output with shape + [batchSize, numHeadsQuery, sequenceLengthQuery, sequenceLengthKeyValue] + where batchSize and numHeadsQuery are broadcastable. For i1 mask, true + indicates the position is allowed to attend. For other types, mask values + are added to BMM1 output. + - NormalizationQuantizeScale (optional): tensor of type f32, f16, or bf16 + with rank 0 or 1, used for quantizing the normalization output. + + #### Attributes: + + - normalization_operation: The normalization operation to use (default: kSOFTMAX) + - causal: Whether to use causal masking (default: false). Cannot be used with mask input. + - decomposable: Whether the operation can be decomposed (default: false) + - normalization_quantize_to_type: Optional output type for quantized normalization. + When specified, must be one of kFP8 or kINT8. Requires normalization_quantize_scale input to be provided. 
+ + #### Constraints: + + - All query, key, and value tensors must be rank 4 with shape [batchSize, numHeads, sequenceLength, dimHead] + - Query, key, and value must have the same element type (f32, f16, or bf16) + - If normalization_quantize_to_type is specified: + * It must be kFP8 or kINT8 + * normalization_quantize_scale input must be provided + - Cannot use both mask input and causal=true simultaneously + + #### Examples: + + Basic attention: + ```mlir + %output = tensorrt.attention ins(%query, %key, %value : + tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>) + -> tensor<2x8x128x64xf16> + ``` + + Causal attention: + ```mlir + %output_causal = tensorrt.attention {causal = true} ins(%query, %key, %value : + tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>) + -> tensor<2x8x128x64xf16> + ``` + + Attention with quantization: + ```mlir + %scale = tensorrt.constant dense<1.0> : tensor + %output_quant = tensorrt.attention { + normalization_quantize_to_type = #tensorrt.data_type + } ins(%query, %key, %value, + normalization_quantize_scale = %scale : + tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, + tensor<2x8x128x64xf16>, tensor) + -> tensor<2x8x128x64xf16> + ``` + }]; + + let arguments = (ins + TensorRT_RankedTensorOf<[F16, BF16, F32]>:$query, + TensorRT_RankedTensorOf<[F16, BF16, F32]>:$key, + TensorRT_RankedTensorOf<[F16, BF16, F32]>:$value, + Optional:$mask, + Optional>:$normalization_quantize_scale, + OptionalAttr:$normalization_operation, + DefaultValuedAttr:$causal, + DefaultValuedAttr:$decomposable, + OptionalAttr:$normalization_quantize_to_type + ); + + let results = (outs TensorRT_RankedTensorOf<[F16, BF16, F32]>:$result); + + let assemblyFormat = [{ + attr-dict `ins` `(` $query `,` $key `,` $value + (`,` `mask` `=` $mask^)? + (`,` `normalization_quantize_scale` `=` $normalization_quantize_scale^)? + `:` type($query) `,` type($key) `,` type($value) + (`,` type($mask)^)? + (`,` type($normalization_quantize_scale)^)? + `)` `->` type($result) + }]; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + /// Returns true if created op is valid for TensorRT major version. + bool isValidForTensorRTVersion(int64_t trtMajorVersion); + }] # baseClassDeclaration; + + let trtLayerAdd = [{ + // Get normalization operation, default to kSOFTMAX + nvinfer1::AttentionNormalizationOp normOp = $normalization_operation + ? 
*$normalization_operation + : nvinfer1::AttentionNormalizationOp::kSOFTMAX; + + nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, normOp, $causal); + if (!layer) + return failure(); + + if ($mask) + layer->setMask(*$mask); + + layer->setDecomposable($decomposable); + + if ($normalization_quantize_scale) { + layer->setNormalizationQuantizeScale(*$normalization_quantize_scale); + } + + if ($normalization_quantize_to_type) { + layer->setNormalizationQuantizeToType(*$normalization_quantize_to_type); + } + + if (!$e.isStronglyTyped()){ + FailureOr outputTrtType = getNvInferDataType($op.getLoc(), + $op.getType().getElementType()); + if (failed(outputTrtType)) + return failure(); + layer->setOutputType(0, *outputTrtType); + } + + $results.push_back(layer->getOutput(0)); + $e.setMetadata(layer, $op); + }]; +} + #endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTOPS_TD diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRTVersionCompatibility.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRTVersionCompatibility.cpp index 8b4413952..1b85e4787 100644 --- a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRTVersionCompatibility.cpp +++ b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRTVersionCompatibility.cpp @@ -914,3 +914,16 @@ bool tensorrt::ScatterElementsOp::isValidForTensorRTVersion( return isValidForTensorRTVersionScatterOpImpl( trtMajorVersion, dataElementType, indicesElementType); } + +//===----------------------------------------------------------------------===// +// AttentionOp +//===----------------------------------------------------------------------===// + +bool tensorrt::AttentionOp::isValidForTensorRTVersion( + int64_t trtMajorVersion) { + // IAttention layer is only supported in TensorRT >= 10.14.0 + if (trtMajorVersion < 10) + return false; + + return true; +} From dcfc0c74d9f3509a70e104a21d42330795617563 Mon Sep 17 00:00:00 2001 From: Yizhuo Zhang Date: Mon, 10 Nov 2025 17:34:49 -0800 Subject: [PATCH 2/2] Fix Attention addLayer, make cmake to work with TRT 10.14 --- mlir-tensorrt/CMakePresets.json | 14 + .../build_tools/cmake/MTRTDependencies.cmake | 4 +- .../cmake/TensorRTDownloadURL.cmake | 7 + .../dialects/test_tensorrt.py | 4 + mlir-tensorrt/compiler/tools/CMakeLists.txt | 2 +- .../integrations/python/setup_utils.py | 6 +- .../TensorRTAttributes.h | 16 + .../TensorRT/IR/TensorRTOps.td | 333 +++++++++--------- .../lib/Bindings/Python/DialectTensorRT.cpp | 2 + .../tensorrt/lib/CAPI/TensorRTAttributes.cpp | 8 + .../IR/TypeInferenceInterfaceImpls.cpp | 18 + .../tensorrt/lib/TensorRT/IR/Verification.cpp | 56 +++ .../test/Target/TensorRT/TRT10/attention.mlir | 72 ++++ 13 files changed, 368 insertions(+), 174 deletions(-) create mode 100644 mlir-tensorrt/tensorrt/test/Target/TensorRT/TRT10/attention.mlir diff --git a/mlir-tensorrt/CMakePresets.json b/mlir-tensorrt/CMakePresets.json index 9c575f9e5..ebea04a7f 100644 --- a/mlir-tensorrt/CMakePresets.json +++ b/mlir-tensorrt/CMakePresets.json @@ -100,6 +100,20 @@ "MLIR_TRT_ENABLE_NCCL": "OFF", "MLIR_TRT_DOWNLOAD_TENSORRT_VERSION": "$env{DOWNLOAD_TENSORRT_VERSION}" } + }, + { + "name": "python-wheel-build", + "displayName": "Configuration for building the compiler/runtime Python package wheels", + "generator": "Ninja", + "binaryDir": "build", + "inherits": "ninja-llvm", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "LLVM_ENABLE_ASSERTIONS": "OFF", + "CMAKE_PLATFORM_NO_VERSIONED_SONAME": "ON", + "MLIR_TRT_ENABLE_NCCL": "OFF", + "MLIR_TRT_DOWNLOAD_TENSORRT_VERSION": "$env{DOWNLOAD_TENSORRT_VERSION}" 
+ } } ] } \ No newline at end of file diff --git a/mlir-tensorrt/build_tools/cmake/MTRTDependencies.cmake b/mlir-tensorrt/build_tools/cmake/MTRTDependencies.cmake index 1eb9efa8b..c16573523 100644 --- a/mlir-tensorrt/build_tools/cmake/MTRTDependencies.cmake +++ b/mlir-tensorrt/build_tools/cmake/MTRTDependencies.cmake @@ -57,8 +57,8 @@ macro(configure_tensorrt_python_plugin_header) find_file( trt_python_plugin_header NAMES NvInferPythonPlugin.h plugin.h - HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl - PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl + HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl ${ARG_INSTALL_DIR}/include/impl + PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl ${ARG_INSTALL_DIR}/include/impl REQUIRED NO_CMAKE_PATH NO_DEFAULT_PATH NO_CACHE diff --git a/mlir-tensorrt/build_tools/cmake/TensorRTDownloadURL.cmake b/mlir-tensorrt/build_tools/cmake/TensorRTDownloadURL.cmake index 7a1745f1e..394f4ea70 100644 --- a/mlir-tensorrt/build_tools/cmake/TensorRTDownloadURL.cmake +++ b/mlir-tensorrt/build_tools/cmake/TensorRTDownloadURL.cmake @@ -80,6 +80,10 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_ set(ARG_VERSION "10.12.0.36") endif() + if(ARG_VERSION VERSION_EQUAL "10.14") + set(ARG_VERSION "10.14.1.48") + endif() + set(downloadable_versions "8.6.1.6" "9.0.1.4" "9.1.0.4" "9.2.0.5" @@ -97,6 +101,7 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_ "10.8.0.43" "10.9.0.34" "10.12.0.36" + "10.14.1.48" ) if(NOT ARG_VERSION IN_LIST downloadable_versions) @@ -164,6 +169,8 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_ elseif(ARG_VERSION VERSION_GREATER 10.10 AND ARG_VERSION VERSION_LESS 10.13) set(TRT_CUDA_VERSION 12.9) + elseif(ARG_VERSION VERSION_GREATER 10.13) + set(TRT_CUDA_VERSION 13.0) endif() # Handle TRT 8 versions. 
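The download-table and CUDA-version changes above make the 10.14.1.48 package fetchable. A minimal configure sketch, assuming the `python-wheel-build` preset added in `CMakePresets.json` and a plain `cmake --preset` workflow run from the `mlir-tensorrt/` directory (the CI driver scripts may differ):

```sh
# Sketch only: "10.14" is canonicalized to "10.14.1.48" by the new
# VERSION_EQUAL branch in TensorRTDownloadURL.cmake above.
export DOWNLOAD_TENSORRT_VERSION=10.14

# The preset forwards $env{DOWNLOAD_TENSORRT_VERSION} into
# MLIR_TRT_DOWNLOAD_TENSORRT_VERSION and uses "build" as its binaryDir.
cmake --preset python-wheel-build
cmake --build build
```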
diff --git a/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/dialects/test_tensorrt.py b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/dialects/test_tensorrt.py index d57237885..b55f5d334 100644 --- a/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/dialects/test_tensorrt.py +++ b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/dialects/test_tensorrt.py @@ -49,6 +49,8 @@ def test_attributes(): tensorrt.TripLimitAttr.get("kWHILE"), tensorrt.FillOperationAttr.get("kRANDOM_UNIFORM"), tensorrt.ScatterModeAttr.get("kELEMENT"), + tensorrt.AttentionNormalizationOpAttr.get("kSOFTMAX"), + tensorrt.DataTypeAttr.get("kFLOAT"), ]: print(attr) @@ -74,3 +76,5 @@ def test_attributes(): # CHECK-NEXT: #tensorrt.trip_limit # CHECK-NEXT: #tensorrt.fill_operation # CHECK-NEXT: #tensorrt.scatter_mode +# CHECK-NEXT: #tensorrt.attention_normalization_op +# CHECK-NEXT: #tensorrt.data_type diff --git a/mlir-tensorrt/compiler/tools/CMakeLists.txt b/mlir-tensorrt/compiler/tools/CMakeLists.txt index a8683aa16..75266f072 100644 --- a/mlir-tensorrt/compiler/tools/CMakeLists.txt +++ b/mlir-tensorrt/compiler/tools/CMakeLists.txt @@ -21,5 +21,5 @@ set(LLVM_LINK_COMPONENTS add_subdirectory(mlir-tensorrt-opt) add_subdirectory(mlir-tensorrt-compiler) add_subdirectory(mlir-tensorrt-translate) -add_subdirectory(mlir-tensorrt-lsp-server) +# add_subdirectory(mlir-tensorrt-lsp-server) add_subdirectory(mlir-tensorrt-runner) diff --git a/mlir-tensorrt/integrations/python/setup_utils.py b/mlir-tensorrt/integrations/python/setup_utils.py index 31572b79b..27fec8097 100644 --- a/mlir-tensorrt/integrations/python/setup_utils.py +++ b/mlir-tensorrt/integrations/python/setup_utils.py @@ -13,7 +13,7 @@ import subprocess import atexit -TENSORRT_VERSION = os.getenv("MLIR_TRT_DOWNLOAD_TENSORRT_VERSION", "10.12") +TENSORRT_VERSION = os.getenv("MLIR_TRT_DOWNLOAD_TENSORRT_VERSION", "10.14") def log(*args): @@ -105,8 +105,8 @@ def run_cmake_build(python_package_name: str, python_wheel_staging_dir: Path): # Environment variable overrides cmake_preset = os.environ.get("MLIR_TRT_CMAKE_PRESET", "python-wheel-build") - install_prefix = os.environ.get("MLIR_TRT_INSTALL_DIR", None) - build_dir = os.environ.get("MLIR_TRT_BUILD_DIR", None) + install_prefix = os.environ.get("MLIR_TRT_INSTALL_DIR", "./install") + build_dir = os.environ.get("MLIR_TRT_BUILD_DIR", "./build") parallel_jobs = os.environ.get("MLIR_TRT_PARALLEL_JOBS", str(os.cpu_count() or 1)) # Additional CMake options from environment diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect-c/TensorRTAttributes.h b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect-c/TensorRTAttributes.h index ed5e9d336..a5fefab71 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect-c/TensorRTAttributes.h +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect-c/TensorRTAttributes.h @@ -188,6 +188,22 @@ DECLARE_ATTR_GETTER_FROM_STRING(ScatterMode) DECLARE_IS_ATTR(ScatterMode) DECLARE_STRING_GETTER_FROM_ATTR(ScatterMode) +//===----------------------------------------------------------------------===// +// AttentionNormalizationOp +//===----------------------------------------------------------------------===// + +DECLARE_ATTR_GETTER_FROM_STRING(AttentionNormalizationOp) +DECLARE_IS_ATTR(AttentionNormalizationOp) +DECLARE_STRING_GETTER_FROM_ATTR(AttentionNormalizationOp) + +//===----------------------------------------------------------------------===// +// DataType +//===----------------------------------------------------------------------===// 
+ +DECLARE_ATTR_GETTER_FROM_STRING(DataType) +DECLARE_IS_ATTR(DataType) +DECLARE_STRING_GETTER_FROM_ATTR(DataType) + #ifdef __cplusplus } #endif diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td index 20b012e1e..ac228d153 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td @@ -3504,6 +3504,171 @@ def TensorRT_DequantizeOp : TensorRT_Op<"dequantize", }]; } +//===----------------------------------------------------------------------===// +// AttentionOp +//===----------------------------------------------------------------------===// + +def TensorRT_AttentionOp : TensorRT_Op<"attention", + [Pure, AttrSizedOperandSegments, TensorRTPartiallyInferTensorResultTypes, + AllElementTypesMatch<["query", "key", "value"]>, + AllRanksMatch<["query", "key", "value"]>]>{ + let summary = "TensorRT attention (IAttention) operation"; + let description = [{ + The `tensorrt.attention` operation implements a fused attention mechanism + that consumes query, key, and value tensors. The operation implicitly includes + two matrix multiplication layers (BMM1 and BMM2) and a normalization operation + (typically softmax). + + By default, TensorRT will try to use a single fused kernel for better efficiency. + The operation can optionally be decomposed into multiple kernels if no fused + kernel is available by setting `decomposable` to true. + + #### Architecture: + + ``` + Query Key Value Mask (optional) NormalizationQuantizeScale (optional) + | | | | | + | Transpose | | | + | | | | | + ----BMM1---- | | | + | | | | + *--------------------------- | + | | | + Normalization | | + | | | + *------------------------------------------------ + | | + -------BMM2------ + | + Output + ``` + + #### Inputs: + + - Query: tensor of type f32, f16, or bf16 with shape + [batchSize, numHeadsQuery, sequenceLengthQuery, dimHead] + - Key: tensor of type f32, f16, or bf16 with shape + [batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead] + - Value: tensor of type f32, f16, or bf16 with shape + [batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead] + - Mask (optional): tensor of type i1 or same type as BMM1 output with shape + [batchSize, numHeadsQuery, sequenceLengthQuery, sequenceLengthKeyValue] + where batchSize and numHeadsQuery are broadcastable. For i1 mask, true + indicates the position is allowed to attend. For other types, mask values + are added to BMM1 output. + - NormalizationQuantizeScale (optional): tensor of type f32, f16, or bf16 + with rank 0 (scalar) or 1 (1D tensor), used for quantizing the normalization output. + Required when normalization_quantize_to_type is specified. + + #### Attributes: + + - normalization_operation: The normalization operation to use (default: kSOFTMAX) + - causal: Whether to use causal masking (default: false). Cannot be used with mask input. + - decomposable: Whether the operation can be decomposed (default: false) + - normalization_quantize_to_type: Optional output type for quantized normalization. + When specified, must be one of kFP8 or kINT8. Requires normalization_quantize_scale input to be provided. 
+ + #### Constraints: + + - All query, key, and value tensors must be rank 4 with shape [batchSize, numHeads, sequenceLength, dimHead] + - Query, key, and value must have the same element type (f32, f16, or bf16) + - If normalization_quantize_to_type is specified: + * It must be kFP8 or kINT8 + * normalization_quantize_scale input must be provided + - If normalization_quantize_scale is provided: + * normalization_quantize_to_type must be specified + * Element type must be f32, f16, or bf16 + * Rank must be 0 (scalar) or 1 (1D tensor) + - Cannot use both mask input and causal=true simultaneously + + #### Examples: + + Basic attention: + ```mlir + %output = tensorrt.attention ins(%query, %key, %value : + tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>) + -> tensor<2x8x128x64xf16> + ``` + + Causal attention: + ```mlir + %output_causal = tensorrt.attention {causal = true} ins(%query, %key, %value : + tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>) + -> tensor<2x8x128x64xf16> + ``` + + Attention with quantization: + ```mlir + %scale = tensorrt.constant dense<1.0> : tensor + %output_quant = tensorrt.attention { + normalization_quantize_to_type = #tensorrt.data_type + } ins(%query, %key, %value, + normalization_quantize_scale = %scale : + tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, + tensor<2x8x128x64xf16>, tensor) + -> tensor<2x8x128x64xf16> + ``` + }]; + + let arguments = (ins + TensorRT_RankedTensorOf<[F16, BF16, F32]>:$query, + TensorRT_RankedTensorOf<[F16, BF16, F32]>:$key, + TensorRT_RankedTensorOf<[F16, BF16, F32]>:$value, + Optional:$mask, + Optional>:$normalization_quantize_scale, + DefaultValuedAttr:$normalization_operation, + DefaultValuedAttr:$causal, + DefaultValuedAttr:$decomposable, + OptionalAttr:$normalization_quantize_to_type + ); + + let results = (outs TensorRT_RankedTensorOf<[F16, BF16, F32]>:$result); + + let assemblyFormat = [{ + attr-dict `ins` `(` $query `,` $key `,` $value + (`,` `mask` `=` $mask^)? + (`,` `normalization_quantize_scale` `=` $normalization_quantize_scale^)? + `:` type($query) `,` type($key) `,` type($value) + (`,` type($mask)^)? + (`,` type($normalization_quantize_scale)^)? + `)` `->` type($result) + }]; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + /// Returns true if created op is valid for TensorRT major version. 
+ bool isValidForTensorRTVersion(int64_t trtMajorVersion); + }] # baseClassDeclaration; + + let trtLayerAdd = [{ + nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, *$normalization_operation, $causal); + if (!layer) + return failure(); + + if ($mask) + layer->setMask(*$mask); + + layer->setDecomposable($decomposable); + + if ($normalization_quantize_scale) { + layer->setNormalizationQuantizeScale(*$normalization_quantize_scale); + } + + if ($normalization_quantize_to_type) { + auto convertedDataType = ::mlir::tensorrt::convertDataTypeToNvInferEnum(*$normalization_quantize_to_type); + if (!convertedDataType) + return emitError($op->getLoc()) << "failed to convert DataType to nvinfer enum"; + layer->setNormalizationQuantizeToType(*convertedDataType); + } + + $results.push_back(layer->getOutput(0)); + #if MLIR_TRT_COMPILE_TIME_TENSORRT_VERSION_GTE(10, 15, 0) + layer->setMetadata($op); + #endif + }]; +} + //===----------------------------------------------------------------------===// // TensorRT Dialect Extension Operations // @@ -4432,172 +4597,4 @@ def TensorRT_ScatterElementsOp : TensorRT_Op<"scatter_elements", }]; } -//===----------------------------------------------------------------------===// -// AttentionOp -//===----------------------------------------------------------------------===// - -def TensorRT_AttentionOp : TensorRT_Op<"attention", - [Pure, AttrSizedOperandSegments, TensorRTInferTensorResultTypes, - AllElementTypesMatch<["query", "key", "value"]>, - AllRanksMatch<["query", "key", "value"]>]>{ - let summary = "TensorRT attention (IAttention) operation"; - let description = [{ - The `tensorrt.attention` operation implements a fused attention mechanism - that consumes query, key, and value tensors. The operation implicitly includes - two matrix multiplication layers (BMM1 and BMM2) and a normalization operation - (typically softmax). - - By default, TensorRT will try to use a single fused kernel for better efficiency. - The operation can optionally be decomposed into multiple kernels if no fused - kernel is available by setting `decomposable` to true. - - #### Architecture: - - ``` - Query Key Value Mask (optional) NormalizationQuantizeScale (optional) - | | | | | - | Transpose | | | - | | | | | - ----BMM1---- | | | - | | | | - *--------------------------- | - | | | - Normalization | | - | | | - *------------------------------------------------ - | | - -------BMM2------ - | - Output - ``` - - #### Inputs: - - - Query: tensor of type f32, f16, or bf16 with shape - [batchSize, numHeadsQuery, sequenceLengthQuery, dimHead] - - Key: tensor of type f32, f16, or bf16 with shape - [batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead] - - Value: tensor of type f32, f16, or bf16 with shape - [batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead] - - Mask (optional): tensor of type i1 or same type as BMM1 output with shape - [batchSize, numHeadsQuery, sequenceLengthQuery, sequenceLengthKeyValue] - where batchSize and numHeadsQuery are broadcastable. For i1 mask, true - indicates the position is allowed to attend. For other types, mask values - are added to BMM1 output. - - NormalizationQuantizeScale (optional): tensor of type f32, f16, or bf16 - with rank 0 or 1, used for quantizing the normalization output. - - #### Attributes: - - - normalization_operation: The normalization operation to use (default: kSOFTMAX) - - causal: Whether to use causal masking (default: false). Cannot be used with mask input. 
- - decomposable: Whether the operation can be decomposed (default: false) - - normalization_quantize_to_type: Optional output type for quantized normalization. - When specified, must be one of kFP8 or kINT8. Requires normalization_quantize_scale input to be provided. - - #### Constraints: - - - All query, key, and value tensors must be rank 4 with shape [batchSize, numHeads, sequenceLength, dimHead] - - Query, key, and value must have the same element type (f32, f16, or bf16) - - If normalization_quantize_to_type is specified: - * It must be kFP8 or kINT8 - * normalization_quantize_scale input must be provided - - Cannot use both mask input and causal=true simultaneously - - #### Examples: - - Basic attention: - ```mlir - %output = tensorrt.attention ins(%query, %key, %value : - tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>) - -> tensor<2x8x128x64xf16> - ``` - - Causal attention: - ```mlir - %output_causal = tensorrt.attention {causal = true} ins(%query, %key, %value : - tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>) - -> tensor<2x8x128x64xf16> - ``` - - Attention with quantization: - ```mlir - %scale = tensorrt.constant dense<1.0> : tensor - %output_quant = tensorrt.attention { - normalization_quantize_to_type = #tensorrt.data_type - } ins(%query, %key, %value, - normalization_quantize_scale = %scale : - tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, - tensor<2x8x128x64xf16>, tensor) - -> tensor<2x8x128x64xf16> - ``` - }]; - - let arguments = (ins - TensorRT_RankedTensorOf<[F16, BF16, F32]>:$query, - TensorRT_RankedTensorOf<[F16, BF16, F32]>:$key, - TensorRT_RankedTensorOf<[F16, BF16, F32]>:$value, - Optional:$mask, - Optional>:$normalization_quantize_scale, - OptionalAttr:$normalization_operation, - DefaultValuedAttr:$causal, - DefaultValuedAttr:$decomposable, - OptionalAttr:$normalization_quantize_to_type - ); - - let results = (outs TensorRT_RankedTensorOf<[F16, BF16, F32]>:$result); - - let assemblyFormat = [{ - attr-dict `ins` `(` $query `,` $key `,` $value - (`,` `mask` `=` $mask^)? - (`,` `normalization_quantize_scale` `=` $normalization_quantize_scale^)? - `:` type($query) `,` type($key) `,` type($value) - (`,` type($mask)^)? - (`,` type($normalization_quantize_scale)^)? - `)` `->` type($result) - }]; - - let hasVerifier = 1; - - let extraClassDeclaration = [{ - /// Returns true if created op is valid for TensorRT major version. - bool isValidForTensorRTVersion(int64_t trtMajorVersion); - }] # baseClassDeclaration; - - let trtLayerAdd = [{ - // Get normalization operation, default to kSOFTMAX - nvinfer1::AttentionNormalizationOp normOp = $normalization_operation - ? 
*$normalization_operation - : nvinfer1::AttentionNormalizationOp::kSOFTMAX; - - nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, normOp, $causal); - if (!layer) - return failure(); - - if ($mask) - layer->setMask(*$mask); - - layer->setDecomposable($decomposable); - - if ($normalization_quantize_scale) { - layer->setNormalizationQuantizeScale(*$normalization_quantize_scale); - } - - if ($normalization_quantize_to_type) { - layer->setNormalizationQuantizeToType(*$normalization_quantize_to_type); - } - - if (!$e.isStronglyTyped()){ - FailureOr outputTrtType = getNvInferDataType($op.getLoc(), - $op.getType().getElementType()); - if (failed(outputTrtType)) - return failure(); - layer->setOutputType(0, *outputTrtType); - } - - $results.push_back(layer->getOutput(0)); - $e.setMetadata(layer, $op); - }]; -} - #endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTOPS_TD diff --git a/mlir-tensorrt/tensorrt/lib/Bindings/Python/DialectTensorRT.cpp b/mlir-tensorrt/tensorrt/lib/Bindings/Python/DialectTensorRT.cpp index 0e134a405..2dc6d3167 100644 --- a/mlir-tensorrt/tensorrt/lib/Bindings/Python/DialectTensorRT.cpp +++ b/mlir-tensorrt/tensorrt/lib/Bindings/Python/DialectTensorRT.cpp @@ -77,4 +77,6 @@ PYBIND11_MODULE(_tensorrt, m) { ADD_PYTHON_ATTRIBUTE_ADAPTOR(TripLimit) ADD_PYTHON_ATTRIBUTE_ADAPTOR(FillOperation) ADD_PYTHON_ATTRIBUTE_ADAPTOR(ScatterMode) + ADD_PYTHON_ATTRIBUTE_ADAPTOR(AttentionNormalizationOp) + ADD_PYTHON_ATTRIBUTE_ADAPTOR(DataType) } diff --git a/mlir-tensorrt/tensorrt/lib/CAPI/TensorRTAttributes.cpp b/mlir-tensorrt/tensorrt/lib/CAPI/TensorRTAttributes.cpp index 50e87551c..456f3b11c 100644 --- a/mlir-tensorrt/tensorrt/lib/CAPI/TensorRTAttributes.cpp +++ b/mlir-tensorrt/tensorrt/lib/CAPI/TensorRTAttributes.cpp @@ -121,3 +121,11 @@ DEFINE_STRING_GETTER_FROM_ATTR(FillOperation) DEFINE_ATTR_GETTER_FROM_STRING(ScatterMode) DEFINE_IS_ATTR(ScatterMode) DEFINE_STRING_GETTER_FROM_ATTR(ScatterMode) + +DEFINE_ATTR_GETTER_FROM_STRING(AttentionNormalizationOp) +DEFINE_IS_ATTR(AttentionNormalizationOp) +DEFINE_STRING_GETTER_FROM_ATTR(AttentionNormalizationOp) + +DEFINE_ATTR_GETTER_FROM_STRING(DataType) +DEFINE_IS_ATTR(DataType) +DEFINE_STRING_GETTER_FROM_ATTR(DataType) diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TypeInferenceInterfaceImpls.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TypeInferenceInterfaceImpls.cpp index e0ad4a1fc..96d107305 100644 --- a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TypeInferenceInterfaceImpls.cpp +++ b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TypeInferenceInterfaceImpls.cpp @@ -1633,3 +1633,21 @@ LogicalResult tensorrt::DequantizeOp::inferReturnTypeComponents( /*elementType=*/nullptr); return success(); } + +//===----------------------------------------------------------------------===// +// AttentionOp +//===----------------------------------------------------------------------===// + +LogicalResult tensorrt::AttentionOp::inferReturnTypeComponents( + MLIRContext *ctx, std::optional loc, ValueShapeRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnShapes) { + AttentionOp::Adaptor adaptor(operands, attributes, properties, regions); + auto queryType = dyn_cast(adaptor.getQuery().getType()); + if (!queryType) + return emitOptionalError(loc, "expected query to be a ranked tensor"); + inferredReturnShapes.emplace_back( + /*vec=*/queryType.getShape(), + /*elementType=*/queryType.getElementType()); + return success(); +} diff --git 
a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/Verification.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/Verification.cpp index 92eca6142..a03b9c997 100644 --- a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/Verification.cpp +++ b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/Verification.cpp @@ -1464,3 +1464,59 @@ static LogicalResult verifyAllowedDataTypes(UnaryOp op) { LogicalResult tensorrt::UnaryOp::verify() { return verifyAllowedDataTypes(*this); } + +//===----------------------------------------------------------------------===// +// AttentionOp +//===----------------------------------------------------------------------===// + +LogicalResult tensorrt::AttentionOp::verify() { + // Check 1: Cannot use both mask input and causal=true simultaneously + if (getMask() && getCausal()) + return emitOpError( + "cannot use both mask input and causal=true simultaneously"); + + // Check 2: If normalization_quantize_to_type is specified, it must be kFP8 + // or kINT8 and normalization_quantize_scale must be provided + std::optional quantizeType = getNormalizationQuantizeToType(); + if (quantizeType.has_value()) { + if (*quantizeType != DataType::kFP8 && *quantizeType != DataType::kINT8) + return emitOpError("normalization_quantize_to_type must be kFP8 or " + "kINT8, but got ") + << stringifyDataType(*quantizeType); + + if (!getNormalizationQuantizeScale()) + return emitOpError( + "normalization_quantize_scale input must be provided when " + "normalization_quantize_to_type is specified"); + } + + // Check 3: If normalization_quantize_scale is provided, + // normalization_quantize_to_type must be specified + if (getNormalizationQuantizeScale() && !quantizeType.has_value()) + return emitOpError( + "normalization_quantize_to_type must be specified when " + "normalization_quantize_scale input is provided"); + + // Check 4: If normalization_quantize_scale is provided, validate its type + if (getNormalizationQuantizeScale()) { + RankedTensorType scaleType = getNormalizationQuantizeScale().getType(); + Type scaleElemType = scaleType.getElementType(); + + // Check that element type is f32, f16, or bf16 + if (!scaleElemType.isF32() && !scaleElemType.isF16() && + !scaleElemType.isBF16()) + return emitOpError( + "normalization_quantize_scale element type must be f32, f16, " + "or bf16, but got ") + << scaleElemType; + + // Check that scale is rank 0 or 1 + if (scaleType.getRank() != 0 && scaleType.getRank() != 1) + return emitOpError( + "normalization_quantize_scale must be rank 0 or 1, but got " + "rank ") + << scaleType.getRank(); + } + + return success(); +} diff --git a/mlir-tensorrt/tensorrt/test/Target/TensorRT/TRT10/attention.mlir b/mlir-tensorrt/tensorrt/test/Target/TensorRT/TRT10/attention.mlir new file mode 100644 index 000000000..c4efe1bf5 --- /dev/null +++ b/mlir-tensorrt/tensorrt/test/Target/TensorRT/TRT10/attention.mlir @@ -0,0 +1,72 @@ +// REQUIRES: tensorrt-version-ge-10.14 +// RUN: %pick-one-gpu tensorrt-opt -split-input-file -pass-pipeline="builtin.module(translate-tensorrt-to-engine)" \ +// RUN: -mlir-elide-elementsattrs-if-larger=32 -tensorrt-builder-opt-level=0 -tensorrt-strongly-typed %s | FileCheck %s +// RUN: %pick-one-gpu tensorrt-opt -split-input-file -pass-pipeline="builtin.module(translate-tensorrt-to-engine)" \ +// RUN: -mlir-elide-elementsattrs-if-larger=32 -tensorrt-builder-opt-level=0 %s | FileCheck %s + +// CHECK-LABEL: @trt_attention_f16 +// CHECK-SAME: tensorrt.engine +func.func @trt_attention_f16(%arg0: tensor<2x8x128x64xf16>, + %arg1: tensor<2x8x128x64xf16>, + %arg2: tensor<2x8x128x64xf16>) 
+ -> tensor<2x8x128x64xf16> { + %0 = tensorrt.attention ins(%arg0, %arg1, %arg2 : + tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>) + -> tensor<2x8x128x64xf16> + return %0 : tensor<2x8x128x64xf16> +} + +// CHECK-LABEL: @trt_attention_causal_f16 +// CHECK-SAME: tensorrt.engine +func.func @trt_attention_causal_f16(%arg0: tensor<2x8x128x64xf16>, + %arg1: tensor<2x8x128x64xf16>, + %arg2: tensor<2x8x128x64xf16>) + -> tensor<2x8x128x64xf16> { + %0 = tensorrt.attention {causal = true} ins(%arg0, %arg1, %arg2 : + tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>) + -> tensor<2x8x128x64xf16> + return %0 : tensor<2x8x128x64xf16> +} + +// CHECK-LABEL: @trt_attention_with_mask_f16 +// CHECK-SAME: tensorrt.engine +func.func @trt_attention_with_mask_f16(%arg0: tensor<2x8x128x64xf16>, + %arg1: tensor<2x8x128x64xf16>, + %arg2: tensor<2x8x128x64xf16>, + %mask: tensor<2x8x128x128xf16>) + -> tensor<2x8x128x64xf16> { + %0 = tensorrt.attention ins(%arg0, %arg1, %arg2, mask = %mask : + tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x128xf16>) + -> tensor<2x8x128x64xf16> + return %0 : tensor<2x8x128x64xf16> +} + +// CHECK-LABEL: @trt_attention_with_quantization_f16 +// CHECK-SAME: tensorrt.engine +func.func @trt_attention_with_quantization_f16(%arg0: tensor<2x8x128x64xf16>, + %arg1: tensor<2x8x128x64xf16>, + %arg2: tensor<2x8x128x64xf16>) + -> tensor<2x8x128x64xf16> { + %scale = tensorrt.constant dense<1.0> : tensor + %0 = tensorrt.attention { + normalization_quantize_to_type = #tensorrt.data_type + } ins(%arg0, %arg1, %arg2, + normalization_quantize_scale = %scale : + tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, + tensor<2x8x128x64xf16>, tensor) + -> tensor<2x8x128x64xf16> + return %0 : tensor<2x8x128x64xf16> +} + +// CHECK-LABEL: @trt_attention_decomposable_f16 +// CHECK-SAME: tensorrt.engine +func.func @trt_attention_decomposable_f16(%arg0: tensor<2x8x128x64xf16>, + %arg1: tensor<2x8x128x64xf16>, + %arg2: tensor<2x8x128x64xf16>) + -> tensor<2x8x128x64xf16> { + %0 = tensorrt.attention {decomposable = true} ins(%arg0, %arg1, %arg2 : + tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>) + -> tensor<2x8x128x64xf16> + return %0 : tensor<2x8x128x64xf16> +} +
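The new enum attributes are also exposed to Python through the adaptors registered in `DialectTensorRT.cpp` and checked in `test_tensorrt.py` above. A minimal sketch of constructing them outside that test; the import paths and context setup are assumptions (the test's fixture is not part of this diff), while the adaptor names, string arguments, and printed prefixes come from the patch:

```python
# Assumed package layout for the compiler Python bindings.
from mlir_tensorrt.compiler.ir import Context
from mlir_tensorrt.compiler.dialects import tensorrt

with Context():
    # Adaptors added in this patch; the strings are enum cases from TensorRTEnums.td.
    norm = tensorrt.AttentionNormalizationOpAttr.get("kSOFTMAX")
    dtype = tensorrt.DataTypeAttr.get("kFP8")
    print(norm)   # expected to start with #tensorrt.attention_normalization_op
    print(dtype)  # expected to start with #tensorrt.data_type
```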