diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index fe9a01b06c..32f93fe576 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -126,6 +126,33 @@ def aten_ops_batch_norm_legit_no_training(
     )
 
 
+@dynamo_tensorrt_converter(
+    torch.ops.aten._native_batch_norm_legit.no_stats,
+    capability_validator=one_user_validator,
+    supports_dynamic_shapes=True,
+)
+def aten_ops_batch_norm_legit_no_stats(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.batch_norm(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        weight=args[1],
+        bias=args[2],
+        training=False,
+        momentum=args[4],
+        eps=args[5],
+        return_mean_rstd=True,
+    )
+
+
 @dynamo_tensorrt_converter(
     torch.ops.aten.native_layer_norm.default,
     supports_dynamic_shapes=True,
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 896bf37b42..1c04814f75 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -19,11 +19,10 @@
 import numpy as np
 import tensorrt as trt
 import torch
+import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch.fx.node import Argument, Target
 from torch.fx.passes.shape_prop import TensorMetadata
-
-import torch_tensorrt.dynamo.conversion.impl as impl
 from torch_tensorrt import _enums
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -345,7 +344,7 @@ def to_trt_weights(
     count: Optional[int] = None,
 ) -> trt.Weights:
     """
-    Convert a PyTorch tensor or NumPy array to TensorRT weights.
+    Convert a PyTorch tensor to TensorRT weights.
 
     Args:
         value (Union[torch.Tensor, np.ndarray]): The tensor or array to convert to TRT weights
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index f9b47542a8..d2c5e1f840 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -4,6 +4,7 @@
 import numpy as np
 import tensorrt as trt
 import torch
+from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -32,21 +33,22 @@ def batch_norm(
     source_ir: Optional[SourceIR],
     name: str,
     input: trt.ITensor,
-    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    running_mean: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    running_var: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    training: bool,
     momentum: float,
     eps: float,
-    cudnn_enabled: bool,
     return_mean_rstd: bool,
+    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    running_mean: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    running_var: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    training: bool = False,
+    cudnn_enabled: bool = False,
 ) -> Union[trt.ITensor, Tuple[trt.ITensor, torch.Tensor, torch.Tensor]]:
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
 
     # Save the original output shape for later use
     output_shape = input.shape
+    feature_num = output_shape[1]
     # We perform constant folding for batch norm when the weight, bias, running_mean, and running_var are all tensors.
     # Batch norm operation can be fused into a single layer, which is more efficient than the original implementation.
     # In this way, the batch norm layer will be fused with the Convolution layer and get a performance boost.
@@ -59,26 +61,41 @@ def batch_norm(
         ]
     ):
         # We name the weight here according to the state_dict name
-        weight = (
-            get_trt_tensor(ctx, 1.0, f"{name}_weight", dtype=input.dtype)
-            if weight is None
-            else get_trt_tensor(ctx, weight, f"{name}_weight")
-        )
-        bias = (
-            get_trt_tensor(ctx, 0.0, f"{name}_bias", dtype=input.dtype)
-            if bias is None
-            else get_trt_tensor(ctx, bias, f"{name}_bias")
-        )
-        running_mean = (
-            get_trt_tensor(ctx, 0.0, f"{name}_running_mean", dtype=input.dtype)
-            if running_mean is None
-            else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
-        )
-        running_var = (
-            get_trt_tensor(ctx, 1.0, f"{name}_running_var", dtype=input.dtype)
-            if running_var is None
-            else get_trt_tensor(ctx, running_var, f"{name}_running_var")
-        )
+        with unset_fake_temporarily():
+            weight = (
+                get_trt_tensor(
+                    ctx, torch.ones((feature_num,)), f"{name}_weight", dtype=input.dtype
+                )
+                if weight is None
+                else get_trt_tensor(ctx, weight, f"{name}_weight")
+            )
+            bias = (
+                get_trt_tensor(
+                    ctx, torch.zeros((feature_num,)), f"{name}_bias", dtype=input.dtype
+                )
+                if bias is None
+                else get_trt_tensor(ctx, bias, f"{name}_bias")
+            )
+            running_mean = (
+                get_trt_tensor(
+                    ctx,
+                    torch.zeros((feature_num,)),
+                    f"{name}_running_mean",
+                    dtype=input.dtype,
+                )
+                if running_mean is None
+                else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
+            )
+            running_var = (
+                get_trt_tensor(
+                    ctx,
+                    torch.ones((feature_num,)),
+                    f"{name}_running_var",
+                    dtype=input.dtype,
+                )
+                if running_var is None
+                else get_trt_tensor(ctx, running_var, f"{name}_running_var")
+            )
 
         # eps_tensor for numerical stability
         eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps", dtype=input.dtype)
@@ -110,8 +127,7 @@ def batch_norm(
 
         # Reshape scale and bias_adjusted to match input shape for broadcasting
         expanded_shape = [1] * len(output_shape)
-        expanded_shape[1] = output_shape[1]  # Set channel dimension
-
+        expanded_shape[1] = feature_num  # Set channel dimension
         scale_reshape = impl.shuffle.reshape(
             ctx,
             target,
@@ -143,21 +159,24 @@ def batch_norm(
         )
 
     else:
-        if weight is None:
-            weight = 1.0
+        with unset_fake_temporarily():
+            if weight is None:
+                weight = torch.ones((feature_num,))
 
-        if bias is None:
-            bias = 0.0
+            if bias is None:
+                bias = torch.zeros((feature_num,))
 
-        if running_mean is None:
-            running_mean = 0.0
+            if running_mean is None:
+                running_mean = torch.zeros((feature_num,))
 
-        if running_var is None:
-            running_var = 1.0
-        adjusted_scale, adjusted_bias = batch_norm_constant_folding(
-            weight, bias, running_mean, running_var, eps
-        )
-        power = torch.ones_like(adjusted_scale)
+            if running_var is None:
+                running_var = torch.ones((feature_num,))
+
+            power = torch.ones_like(weight)
+
+            adjusted_scale, adjusted_bias = batch_norm_constant_folding(
+                weight, bias, running_mean, running_var, eps
+            )
 
         adjusted_scale = to_trt_weights(
             ctx,
@@ -188,9 +207,7 @@ def batch_norm(
             source_ir=source_ir,
         )
 
-    output_shape = input.shape
     if len(input.shape) < 4:
-
         new_shape = (
             (input.shape[0], input.shape[1], 1, 1)
             if len(input.shape) == 2
diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
index 9d28ae70a5..a46e0c9d01 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -91,7 +91,6 @@
     aten.narrow,
     # TODO: Disable the below operators once freezing is done
     aten.native_batch_norm_backward,
-    aten._native_batch_norm_legit,
     aten._native_batch_norm_legit_functional,
     aten.native_dropout_backward,
     aten.native_group_norm_backward,
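
A quick reference for reviewers: the patched `batch_norm` path folds eval-mode batch norm into a single per-channel scale and bias before handing the weights to TensorRT. The sketch below is a minimal illustration of that fold under stated assumptions; `fold_batch_norm` is a hypothetical stand-in, not the repo's `batch_norm_constant_folding` helper, although the diff calls that helper with the same `(weight, bias, running_mean, running_var, eps)` arguments and consumes an `(adjusted_scale, adjusted_bias)` pair, so the math is presumably equivalent.

```python
import torch


def fold_batch_norm(
    weight: torch.Tensor,
    bias: torch.Tensor,
    running_mean: torch.Tensor,
    running_var: torch.Tensor,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fold eval-mode batch norm into a per-channel scale and bias.

    y = (x - mean) / sqrt(var + eps) * weight + bias
      = x * scale + adjusted_bias
    """
    scale = weight / torch.sqrt(running_var + eps)
    adjusted_bias = bias - running_mean * scale
    return scale, adjusted_bias


# With the per-channel defaults this PR introduces (weight/running_var -> ones,
# bias/running_mean -> zeros of length feature_num), the fold reduces to
# scale = 1 / sqrt(1 + eps) and adjusted_bias = 0 for every channel.
feature_num = 16  # hypothetical channel count for illustration only
scale, shift = fold_batch_norm(
    torch.ones(feature_num),
    torch.zeros(feature_num),
    torch.zeros(feature_num),
    torch.ones(feature_num),
    1e-5,
)
```

Defaulting the missing parameters to channel-sized tensors rather than the old Python scalars also keeps `torch.ones_like(weight)` and the later reshape to `expanded_shape` well defined, since both expect a tensor of length `feature_num`.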