diff --git a/Paddle b/Paddle index 0025cdc9537..56be4659242 160000 --- a/Paddle +++ b/Paddle @@ -1 +1 @@ -Subproject commit 0025cdc95377be8a334ec943d35a33ef42edcfdf +Subproject commit 56be465924264e1251cf127dbff56d17a7554d01 diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt index d09d0128c4d..60b43d63637 100755 --- a/backends/metax_gpu/CMakeLists.txt +++ b/backends/metax_gpu/CMakeLists.txt @@ -158,8 +158,8 @@ file( ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/grid_sample_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/instance_norm_kernel.cu - ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rnn_grad_kernel.cu.cc - ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rnn_kernel.cu.cc + ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rnn_grad_kernel.cu + ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rnn_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/ctc_align_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/yolo_box_head_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/stft_grad_kernel.cu @@ -290,7 +290,7 @@ file( ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/unsqueeze_kernel.cc ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/squeeze_grad_kernel.cc ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/squeeze_kernel.cc - ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/sign_kernel.cu.cc + ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/sign_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/split_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/scatter_nd_add_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/scatter_nd_add_grad_kernel.cu diff --git a/backends/metax_gpu/kernels/cuda_kernels/sign_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/sign_kernel_register.cu index 52e0f9d256a..8e37408da9c 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/sign_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/sign_kernel_register.cu @@ -27,7 +27,7 @@ PD_CUSTOM_KERNEL_REGISTER(sign, int64_t, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} diff --git a/backends/metax_gpu/kernels/custom_kernel/layer_norm_grad_kernel_register.cu b/backends/metax_gpu/kernels/custom_kernel/layer_norm_grad_kernel_register.cu index 857dcb6d522..dd727cd21ae 100644 --- a/backends/metax_gpu/kernels/custom_kernel/layer_norm_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/custom_kernel/layer_norm_grad_kernel_register.cu @@ -14,48 +14,99 @@ #include "funcs/layer_norm_util.h" #include "impl/layer_norm_impl.cu.h" -#include "paddle/common/flags.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/layer_norm_grad_kernel.h" +#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) && !defined(_WIN32) +#include "paddle/phi/kernels/funcs/fast_ln_v2.h" +#endif +#ifdef PADDLE_WITH_CUDA +#include "paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h" +#endif namespace phi { +enum class LayerNormGadKernelVariant { FAST_LN_V2, GENERIC }; +static inline LayerNormGadKernelVariant LayerNormGradKernelDispatch( + const paddle::DataType weight_type, + const paddle::DataType input_type, + const paddle::DataType output_type, + const paddle::DataType compute_type, + const uint32_t hidden_size, + const int64_t x_numel, + const DenseTensor* scale, + const DenseTensor* bias) { +#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) && !defined(_WIN32) + if (scale != nullptr && bias != nullptr && + input_type != paddle::DataType::FLOAT32 && hidden_size != 4096 && + hidden_size > 1024 && hidden_size <= 10240 && + x_numel <= std::numeric_limits::max()) { + // using fast_ln_v2 only sm > 70 and x_numel <= uint32_max + auto prop = funcs::fast_ln_v2::GetDeviceProp(); + if (prop->major > 7 && + funcs::fast_ln_v2::has_fast_ln_v2_bwd_kernel( + weight_type, input_type, output_type, compute_type, hidden_size)) { + return LayerNormGadKernelVariant::FAST_LN_V2; + } + } +#endif + return LayerNormGadKernelVariant::GENERIC; +} template -void LayerNormGradKernel(const Context &dev_ctx, - const DenseTensor &x, - const paddle::optional &scale_opt, - const paddle::optional &bias_opt, - const DenseTensor &mean, - const DenseTensor &variance, - const DenseTensor &out_grad, - float epsilon, +void LayerNormGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const optional& scale_opt, + const optional& bias_opt, + const DenseTensor& mean, + const DenseTensor& variance, + const DenseTensor& out_grad, + double epsilon, int begin_norm_axis, - DenseTensor *x_grad, - DenseTensor *scale_grad, - DenseTensor *bias_grad) { - using U = phi::funcs::LayerNormParamType; + DenseTensor* x_grad, + DenseTensor* scale_grad, + DenseTensor* bias_grad) { + if (x.numel() == 0) { + dev_ctx.template Alloc(x_grad); + if (scale_grad) { + Full(dev_ctx, scale_grad->dims(), 0, scale_grad); + if (scale_opt.get_ptr() && x.dtype() != scale_opt.get().dtype()) { + CastKernel( + dev_ctx, *scale_grad, scale_opt.get().dtype(), scale_grad); + } + } + if (bias_grad) { + Full(dev_ctx, bias_grad->dims(), 0, bias_grad); + if (bias_opt.get_ptr() && x.dtype() != bias_opt.get().dtype()) { + CastKernel( + dev_ctx, *bias_grad, bias_opt.get().dtype(), bias_grad); + } + } + return; + } + using U = funcs::LayerNormParamType; // d_x, d_scale, d_bias may be nullptr - auto *d_x = x_grad; - auto *d_scale = scale_grad; - auto *d_bias = bias_grad; + auto* d_x = x_grad; + auto* d_scale = scale_grad; + auto* d_bias = bias_grad; - auto *scale = scale_opt.get_ptr(); - auto *bias = bias_opt.get_ptr(); - auto *d_y = &out_grad; + auto* scale = scale_opt.get_ptr(); + auto* bias = bias_opt.get_ptr(); + auto* d_y = &out_grad; - const auto &x_dims = x.dims(); + const auto& x_dims = x.dims(); auto matrix_dim = common::flatten_to_2d(x_dims, begin_norm_axis); int64_t batch_size = static_cast(matrix_dim[0]); int64_t feature_size = static_cast(matrix_dim[1]); - auto *x_data = x.data(); - auto *d_y_data = d_y->data(); + auto* x_data = x.data(); + auto* d_y_data = d_y->data(); - auto *mean_data = mean.data(); - auto *var_data = variance.data(); + auto* mean_data = mean.data(); + auto* var_data = variance.data(); - auto *d_x_data = (d_x == nullptr ? nullptr : dev_ctx.template Alloc(d_x)); + auto* d_x_data = (d_x == nullptr ? nullptr : dev_ctx.template Alloc(d_x)); auto x_dtype = x.dtype(); @@ -74,52 +125,165 @@ void LayerNormGradKernel(const Context &dev_ctx, #define PADDLE_LAUNCH_LAYERNORM_BWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \ do { \ - auto *scale_data = \ + auto* scale_data = \ (scale == nullptr ? nullptr : scale->data()); \ - auto *d_scale_data = \ + auto* d_scale_data = \ (d_scale == nullptr ? nullptr \ : dev_ctx.template Alloc(d_scale)); \ - auto *d_bias_data = \ + auto* d_bias_data = \ (d_bias == nullptr ? nullptr \ : dev_ctx.template Alloc(d_bias)); \ - auto *d_x_data = \ + auto* d_x_data = \ (d_x == nullptr ? nullptr : dev_ctx.template Alloc(d_x)); \ - phi::funcs::LayerNormBackward( \ - x_data, \ - d_y_data, \ - scale_data, \ - mean_data, \ - var_data, \ - d_x_data, \ - d_scale_data, \ - d_bias_data, \ - epsilon, \ - batch_size, \ - feature_size, \ - dev_ctx); \ + funcs::LayerNormBackward(x_data, \ + d_y_data, \ + scale_data, \ + mean_data, \ + var_data, \ + d_x_data, \ + d_scale_data, \ + d_bias_data, \ + epsilon, \ + batch_size, \ + feature_size, \ + dev_ctx); \ } while (0) - if (scale_bias_dtype == x_dtype) { - PADDLE_LAUNCH_LAYERNORM_BWD(T, true); - } else { - PADDLE_LAUNCH_LAYERNORM_BWD(U, false); +#define PADDLE_LAUNCH_FAST_LAYERNORM_V2_BWD(ScaleBiasT) \ + do { \ + auto stream = dev_ctx.stream(); \ + auto place = x.place(); \ + auto* scale_data = \ + (scale == nullptr ? nullptr : scale->data()); \ + auto* d_scale_data = \ + (d_scale == nullptr ? nullptr \ + : dev_ctx.template Alloc(d_scale)); \ + auto* d_bias_data = \ + (d_bias == nullptr ? nullptr \ + : dev_ctx.template Alloc(d_bias)); \ + auto* d_x_data = \ + (d_x == nullptr ? nullptr : dev_ctx.template Alloc(d_x)); \ + funcs::fast_ln_v2::LaunchNormBwd(dev_ctx, \ + stream, \ + place, \ + x_data, \ + scale_data, \ + mean_data, \ + var_data, \ + d_y_data, \ + d_x_data, \ + d_scale_data, \ + d_bias_data, \ + scale_bias_dtype, \ + x_dtype, \ + x_grad->dtype(), \ + compute_dtype, \ + feature_size, \ + batch_size, \ + feature_size, \ + epsilon); \ + } while (0) + + auto compute_dtype = phi::CppTypeToDataType::Type(); + auto kernel_variant = LayerNormGradKernelDispatch(scale_bias_dtype, + x_dtype, + x_dtype, + compute_dtype, + feature_size, + x.numel(), + scale, + bias); + switch (kernel_variant) { +#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) && !defined(_WIN32) + case LayerNormGadKernelVariant::FAST_LN_V2: + if (scale_bias_dtype == x_dtype) { + PADDLE_LAUNCH_FAST_LAYERNORM_V2_BWD(T); + } else { + PADDLE_LAUNCH_FAST_LAYERNORM_V2_BWD(U); + } + break; +#endif + case LayerNormGadKernelVariant::GENERIC: + default: +#ifdef PADDLE_WITH_CUDA + if ((FLAGS_use_accuracy_compatible_kernel || + (!isPowerOfTwo(feature_size) && feature_size > 1024)) && + scale_bias_dtype == x_dtype) { + auto* scale_data = (scale == nullptr ? nullptr : scale->data()); + auto* d_scale_data = + (d_scale == nullptr ? nullptr : dev_ctx.template Alloc(d_scale)); + auto* d_bias_data = + (d_bias == nullptr ? nullptr : dev_ctx.template Alloc(d_bias)); + auto* d_x_data = + (d_x == nullptr ? nullptr : dev_ctx.template Alloc(d_x)); + LayerNormBwdCompatKernel(dev_ctx, + d_y_data, + x_data, + scale_data, + mean_data, + var_data, + d_x_data, + d_scale_data, + d_bias_data, + epsilon, + batch_size, + feature_size); + } else { +#endif + if (scale_bias_dtype == x_dtype) { + PADDLE_LAUNCH_LAYERNORM_BWD(T, true); + } else { + PADDLE_LAUNCH_LAYERNORM_BWD(U, false); + } +#ifdef PADDLE_WITH_CUDA + } +#endif } #undef PADDLE_LAUNCH_LAYERNORM_BWD +#undef PADDLE_LAUNCH_FAST_LAYERNORM_V2_BWD } } // namespace phi +#ifdef PADDLE_WITH_HIP +// MIOPEN do not support double +PD_REGISTER_PLUGIN_KERNEL(layer_norm_grad, + metax_gpu, + ALL_LAYOUT, + phi::LayerNormGradKernel, + float, + phi::float16) { + if (kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + } +} +#elif CUDNN_VERSION_MIN(8, 1, 0) +PD_REGISTER_PLUGIN_KERNEL(layer_norm_grad, + metax_gpu, + ALL_LAYOUT, + phi::LayerNormGradKernel, + float, + double, + phi::float16, + phi::bfloat16) { + if (kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + } +} +#else PD_REGISTER_PLUGIN_KERNEL(layer_norm_grad, metax_gpu, ALL_LAYOUT, phi::LayerNormGradKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) { + phi::float16) { if (kernel_key.dtype() == phi::DataType::FLOAT16) { kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); } } +#endif diff --git a/backends/metax_gpu/kernels/custom_kernel/layer_norm_kernel_register.cu b/backends/metax_gpu/kernels/custom_kernel/layer_norm_kernel_register.cu index d5e5901784c..7fc6bb50c56 100644 --- a/backends/metax_gpu/kernels/custom_kernel/layer_norm_kernel_register.cu +++ b/backends/metax_gpu/kernels/custom_kernel/layer_norm_kernel_register.cu @@ -12,20 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "funcs/layer_norm_util.h" -#include "impl/layer_norm_impl.cu.h" +// #include "kernels/funcs/layer_norm_util.h" #include "paddle/common/flags.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/layer_norm_kernel.h" - +#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) && !defined(_WIN32) +#include "paddle/phi/kernels/funcs/fast_ln_v2.h" +#endif +// #include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h" +// #include "paddle/phi/kernels/funcs/layer_norm_util.h" +#include "funcs/layer_norm_util.h" +#include "impl/layer_norm_impl.cu.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h" +#endif COMMON_DECLARE_bool(use_fast_math); namespace phi { +enum class LayerNormKernelVariant { FAST_LN_V1, FAST_LN_V2, GENERIC }; + #ifdef PADDLE_WITH_CUDA template -__device__ inline void WelfordOnline(U val, U *mean, U *square, U *count) { +__device__ inline void WelfordOnline(U val, U* mean, U* square, U* count) { *count += 1; U delta1 = val - *mean; *mean += delta1 / (*count); @@ -35,7 +45,7 @@ __device__ inline void WelfordOnline(U val, U *mean, U *square, U *count) { template __device__ inline void WelfordOnline( - U b_mean, U b_square, U b_cnt, U *mean, U *square, U *count) { + U b_mean, U b_square, U b_cnt, U* mean, U* square, U* count) { if (b_cnt == 0) { return; } @@ -49,7 +59,7 @@ __device__ inline void WelfordOnline( } template -__device__ inline void WelfordWarpAllReduce(U *mean, U *square, U *count) { +__device__ inline void WelfordWarpAllReduce(U* mean, U* square, U* count) { constexpr int kWarpSize = 32; #pragma unroll for (int mask = 1; mask < kWarpSize; mask *= 2) { @@ -68,7 +78,7 @@ template struct ThreadAssigner { __device__ __forceinline__ int operator()(const int cols, const int cols_per_thread, - int32_t *last_tid_idx) { + int32_t* last_tid_idx) { return cols_per_thread; } }; @@ -77,7 +87,7 @@ template <> struct ThreadAssigner<1> { __device__ inline int operator()(const int cols, const int cols_per_thread, - int *last_tid_idx) { + int* last_tid_idx) { int cols_this_thread = cols_per_thread; int last_tid = (cols / cols_per_thread); *last_tid_idx = last_tid; @@ -92,14 +102,14 @@ struct ThreadAssigner<1> { template struct LayerNormDataReader { - __device__ inline void operator()(const T *__restrict__ row_src, - U *buffer, + __device__ inline void operator()(const T* __restrict__ row_src, + U* buffer, const int last_tid_idx, const int read_times, const int cols_this_thread) { - using VecT = phi::AlignedVector; - const VecT *__restrict__ v_src = - reinterpret_cast(row_src); + using VecT = AlignedVector; + const VecT* __restrict__ v_src = + reinterpret_cast(row_src); for (int i = 0; i < read_times; ++i) { VecT temp_src = v_src[threadIdx.x + i * blockDim.x]; @@ -113,8 +123,8 @@ struct LayerNormDataReader { template struct LayerNormDataReader { - __device__ inline void operator()(const T *__restrict__ row_src, - U *buffer, + __device__ inline void operator()(const T* __restrict__ row_src, + U* buffer, const int last_tid_idx, const int read_times, const int cols_this_thread) { @@ -134,10 +144,10 @@ struct LayerNormDataReader { template struct LayerNormDataWriter { __device__ inline void operator()( - T *__restrict__ row_dst, - const U *__restrict__ buffer, - const funcs::LayerNormScaleBiasT *__restrict__ scale, - const funcs::LayerNormScaleBiasT *__restrict__ bias, + T* __restrict__ row_dst, + const U* __restrict__ buffer, + const funcs::LayerNormScaleBiasT* __restrict__ scale, + const funcs::LayerNormScaleBiasT* __restrict__ bias, const U row_mean, const U row_inv_var, const int write_times, @@ -145,10 +155,10 @@ struct LayerNormDataWriter { const int last_tid_idx, const bool valid_scale, const bool valid_bias) { - using VecT = phi::AlignedVector; + using VecT = AlignedVector; using ScaleT = funcs::LayerNormScaleBiasT; - using VecScaleT = phi::AlignedVector; - VecT *v_dst = reinterpret_cast(row_dst); + using VecScaleT = AlignedVector; + VecT* v_dst = reinterpret_cast(row_dst); // cols_this_thread is just cols_per_thread if ((!valid_scale) && (!valid_bias)) { @@ -159,16 +169,16 @@ struct LayerNormDataWriter { temp_dst[j] = static_cast((buffer[i * VecSize + j] - row_mean) * row_inv_var); } - v_dst[threadIdx.x + blockDim.x * i] = temp_dst; + v_dst[threadIdx.x + static_cast(blockDim.x) * i] = temp_dst; } } else { - const VecScaleT *__restrict__ v_scale = - reinterpret_cast(scale); - const VecScaleT *__restrict__ v_bias = - reinterpret_cast(bias); + const VecScaleT* __restrict__ v_scale = + reinterpret_cast(scale); + const VecScaleT* __restrict__ v_bias = + reinterpret_cast(bias); if (valid_scale && valid_bias) { for (int i = 0; i < write_times; ++i) { - int idx = threadIdx.x + blockDim.x * i; + int64_t idx = threadIdx.x + static_cast(blockDim.x) * i; VecT temp_dst; VecScaleT temp_v_scale = v_scale[idx]; VecScaleT temp_v_bias = v_bias[idx]; @@ -184,7 +194,7 @@ struct LayerNormDataWriter { } else { if (valid_scale) { for (int i = 0; i < write_times; ++i) { - int idx = threadIdx.x + blockDim.x * i; + int64_t idx = threadIdx.x + static_cast(blockDim.x) * i; VecT temp_dst; VecScaleT temp_v_scale = v_scale[idx]; #pragma unroll @@ -217,10 +227,10 @@ struct LayerNormDataWriter { template struct LayerNormDataWriter { __device__ __forceinline__ void operator()( - T *__restrict__ row_dst, - U *__restrict__ buffer, - const funcs::LayerNormScaleBiasT *__restrict__ scale, - const funcs::LayerNormScaleBiasT *__restrict__ bias, + T* __restrict__ row_dst, + U* __restrict__ buffer, + const funcs::LayerNormScaleBiasT* __restrict__ scale, + const funcs::LayerNormScaleBiasT* __restrict__ bias, const U row_mean, const U row_inv_var, const int write_times, @@ -232,19 +242,19 @@ struct LayerNormDataWriter { if ((!valid_scale) && (!valid_bias)) { if (threadIdx.x < last_tid_idx) { for (int i = 0; i < cols_this_thread; ++i) { - row_dst[threadIdx.x + last_tid_idx * i] = + row_dst[threadIdx.x + static_cast(last_tid_idx) * i] = (buffer[i] - row_mean) * row_inv_var; } } else { for (int i = 0; i < cols_this_thread; ++i) { - row_dst[last_tid_idx * write_times + i] = + row_dst[static_cast(last_tid_idx) * write_times + i] = (buffer[i] - row_mean) * row_inv_var; } } } else if (valid_scale && valid_bias) { if (threadIdx.x < last_tid_idx) { for (int i = 0; i < cols_this_thread; ++i) { - int idx = threadIdx.x + last_tid_idx * i; + int64_t idx = threadIdx.x + static_cast(last_tid_idx) * i; row_dst[idx] = static_cast(static_cast(scale[idx]) * (buffer[i] - row_mean) * row_inv_var + @@ -252,7 +262,7 @@ struct LayerNormDataWriter { } } else { for (int i = 0; i < cols_this_thread; ++i) { - int idx = last_tid_idx * write_times + i; + int64_t idx = static_cast(last_tid_idx) * write_times + i; row_dst[idx] = static_cast(static_cast(scale[idx]) * (buffer[i] - row_mean) * row_inv_var + @@ -263,13 +273,13 @@ struct LayerNormDataWriter { if (valid_scale) { if (threadIdx.x < last_tid_idx) { for (int i = 0; i < cols_this_thread; ++i) { - int idx = threadIdx.x + last_tid_idx * i; + int64_t idx = threadIdx.x + static_cast(last_tid_idx) * i; row_dst[idx] = static_cast(static_cast(scale[idx]) * (buffer[i] - row_mean) * row_inv_var); } } else { for (int i = 0; i < cols_this_thread; ++i) { - int idx = last_tid_idx * write_times + i; + int64_t idx = static_cast(last_tid_idx) * write_times + i; row_dst[idx] = static_cast(static_cast(scale[idx]) * (buffer[i] - row_mean) * row_inv_var); } @@ -277,13 +287,13 @@ struct LayerNormDataWriter { } else { if (threadIdx.x < last_tid_idx) { for (int i = 0; i < cols_this_thread; ++i) { - int idx = threadIdx.x + last_tid_idx * i; + int64_t idx = threadIdx.x + static_cast(last_tid_idx) * i; row_dst[idx] = static_cast((buffer[i] - row_mean) * row_inv_var + static_cast(bias[idx])); } } else { for (int i = 0; i < cols_this_thread; ++i) { - int idx = last_tid_idx * write_times + i; + int64_t idx = static_cast(last_tid_idx) * write_times + i; row_dst[idx] = static_cast((buffer[i] - row_mean) * row_inv_var + static_cast(bias[idx])); } @@ -295,12 +305,12 @@ struct LayerNormDataWriter { template __global__ void LayerNormFwdWithWelford( - const T *__restrict__ src_data, - T *dst_data, - const funcs::LayerNormScaleBiasT *__restrict__ scale, - const funcs::LayerNormScaleBiasT *__restrict__ bias, - U *mean, - U *var, + const T* __restrict__ src_data, + T* dst_data, + const funcs::LayerNormScaleBiasT* __restrict__ scale, + const funcs::LayerNormScaleBiasT* __restrict__ bias, + U* mean, + U* var, const U epsilon, const IndexT rows, const int32_t cols, @@ -320,8 +330,8 @@ __global__ void LayerNormFwdWithWelford( U tid_mean = static_cast(0); U tid_square = static_cast(0); - const T *__restrict__ row_src = src_data + row_offset * cols; - T *row_dst = dst_data + row_offset * cols; + const T* __restrict__ row_src = src_data + row_offset * cols; + T* row_dst = dst_data + row_offset * cols; LayerNormDataReader()( row_src, buffer, last_tid_idx, read_times, cols_this_thread); @@ -358,14 +368,14 @@ __global__ void LayerNormFwdWithWelford( } template -void LaunchLayerNormKernel(const Context &dev_ctx, - const T *x_data, - T *y_data, - const void *void_scale_data, - const void *void_bias_data, - U *mean_data, - U *var_data, - float epsilon, +void LaunchLayerNormKernel(const Context& dev_ctx, + const T* x_data, + T* y_data, + const void* void_scale_data, + const void* void_bias_data, + U* mean_data, + U* var_data, + double epsilon, const int64_t rows, const int cols, const bool valid_scale, @@ -389,7 +399,8 @@ void LaunchLayerNormKernel(const Context &dev_ctx, : addr; addr = valid_bias ? (addr | reinterpret_cast(void_bias_data)) : addr; - data_vec_size = phi::GetVectorizedSize(reinterpret_cast(addr)); + data_vec_size = + std::min(4, phi::GetVectorizedSize(reinterpret_cast(addr))); } else { uint64_t bias_addr = reinterpret_cast(void_bias_data); uint64_t attr_addr = valid_scale @@ -399,8 +410,9 @@ void LaunchLayerNormKernel(const Context &dev_ctx, ? (valid_scale ? (attr_addr | bias_addr) : attr_addr) : attr_addr; data_vec_size = std::min( - phi::GetVectorizedSize(reinterpret_cast(addr)), - phi::GetVectorizedSize(reinterpret_cast(attr_addr))); + phi::GetVectorizedSize(reinterpret_cast(addr)), + phi::GetVectorizedSize(reinterpret_cast(attr_addr))); + data_vec_size = std::min(4, data_vec_size); } } for (int size = data_vec_size; size > 0; size /= 2) { @@ -417,8 +429,8 @@ void LaunchLayerNormKernel(const Context &dev_ctx, <<>>( \ x_data, \ y_data, \ - static_cast(void_scale_data), \ - static_cast(void_bias_data), \ + static_cast(void_scale_data), \ + static_cast(void_bias_data), \ mean_data, \ var_data, \ static_cast(epsilon), \ @@ -455,22 +467,24 @@ void LaunchLayerNormKernel(const Context &dev_ctx, template void LayerNormDirectCUDAFunctor::operator()( gpuStream_t stream, - const T *input, + const T* input, std::vector input_shape, - const U *bias, - const U *scale, - T *output, - U *mean, - U *variance, + const U* bias, + const U* scale, + T* output, + U* mean, + U* variance, int begin_norm_axis, float eps) { - const auto x_dims = common::make_ddim(input_shape); + const auto x_dims = make_ddim(input_shape); auto matrix_dim = common::flatten_to_2d(x_dims, begin_norm_axis); - int64_t batch_size = static_cast(matrix_dim[0]); - int64_t feature_size = static_cast(matrix_dim[1]); - switch (phi::funcs::GetDesiredBlockDim(feature_size)) { + int64_t batch_size = matrix_dim[0]; + int64_t feature_size = matrix_dim[1]; + // TODO(large-tensor): generic kernel launch uses int32 grid dim + PADDLE_ENFORCE_LE_INT_MAX(batch_size, "batch_size"); + switch (funcs::GetDesiredBlockDim(feature_size)) { FIXED_BLOCK_DIM_CASE( - phi::funcs::LayerNormForward + funcs::LayerNormForward <<>>( input, scale, bias, output, mean, variance, eps, feature_size)); default: @@ -481,38 +495,79 @@ void LayerNormDirectCUDAFunctor::operator()( } } -template class LayerNormDirectCUDAFunctor; -template class LayerNormDirectCUDAFunctor; +template class PADDLE_API LayerNormDirectCUDAFunctor; +template class PADDLE_API LayerNormDirectCUDAFunctor; #if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) -template class LayerNormDirectCUDAFunctor; +template class PADDLE_API LayerNormDirectCUDAFunctor; +#endif +static inline LayerNormKernelVariant LayerNormKernelDispatch( + const paddle::DataType weight_type, + const paddle::DataType input_type, + const paddle::DataType output_type, + const paddle::DataType compute_type, + const uint32_t hidden_size, + const int64_t x_numel, + const DenseTensor* scale, + const DenseTensor* bias) { + if (scale == nullptr || bias == nullptr) { + return LayerNormKernelVariant::GENERIC; + } +#ifdef PADDLE_WITH_CUDA + if (FLAGS_use_accuracy_compatible_kernel) { + return LayerNormKernelVariant::GENERIC; + } +#endif +#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) && !defined(_WIN32) + if (input_type != paddle::DataType::FLOAT32 && hidden_size != 4096 && + hidden_size > 1024 && hidden_size <= 10240 && + x_numel <= std::numeric_limits::max()) { + // using fast_ln_v2 only sm > 70 and x_numel <= uint32_max + auto prop = funcs::fast_ln_v2::GetDeviceProp(); + if (prop->major > 7 && + funcs::fast_ln_v2::has_fast_ln_v2_fwd_kernel( + weight_type, input_type, output_type, compute_type, hidden_size)) { + return LayerNormKernelVariant::FAST_LN_V2; + } + } #endif + if ((hidden_size >= 768 && hidden_size <= 2048 && hidden_size % 256 == 0 || + hidden_size == 4096) && + x_numel <= std::numeric_limits::max() && scale != nullptr && + bias != nullptr) { + return LayerNormKernelVariant::FAST_LN_V1; + } + + return LayerNormKernelVariant::GENERIC; +} template -void LayerNormKernel(const Context &dev_ctx, - const DenseTensor &x, - const paddle::optional &scale_opt, - const paddle::optional &bias_opt, - float epsilon, +void LayerNormKernel(const Context& dev_ctx, + const DenseTensor& x, + const optional& scale_opt, + const optional& bias_opt, + double epsilon, int begin_norm_axis, - DenseTensor *y, - DenseTensor *mean, - DenseTensor *var) { - using U = phi::funcs::LayerNormParamType; - auto *scale = scale_opt.get_ptr(); - auto *bias = bias_opt.get_ptr(); + DenseTensor* y, + DenseTensor* mean, + DenseTensor* var) { + using U = funcs::LayerNormParamType; + auto* scale = scale_opt.get_ptr(); + auto* bias = bias_opt.get_ptr(); const auto x_dims = x.dims(); - auto *x_data = x.data(); - auto *y_data = dev_ctx.template Alloc(y); - auto *mean_data = dev_ctx.template Alloc(mean); - auto *var_data = dev_ctx.template Alloc(var); + auto* x_data = x.data(); + auto* y_data = dev_ctx.template Alloc(y); + auto* mean_data = dev_ctx.template Alloc(mean); + auto* var_data = dev_ctx.template Alloc(var); + if (x.numel() == 0) return; bool valid_scale = (scale != nullptr); bool valid_bias = (bias != nullptr); - auto *void_scale_data = valid_scale ? scale->data() : nullptr; - auto *void_bias_data = valid_bias ? bias->data() : nullptr; + auto* void_scale_data = valid_scale ? scale->data() : nullptr; + auto* void_bias_data = valid_bias ? bias->data() : nullptr; auto x_dtype = x.dtype(); + auto y_dtype = y->dtype(); phi::DataType scale_bias_dtype; if (valid_scale) { scale_bias_dtype = scale->dtype(); @@ -536,33 +591,35 @@ void LayerNormKernel(const Context &dev_ctx, } auto matrix_dim = common::flatten_to_2d(x_dims, begin_norm_axis); - int64_t batch_size = static_cast(matrix_dim[0]); - int64_t feature_size = static_cast(matrix_dim[1]); + int64_t batch_size = matrix_dim[0]; + // TODO(large-tensor): generic kernel launch uses int32 grid dim + PADDLE_ENFORCE_LE_INT_MAX(batch_size, "batch_size"); + int64_t feature_size = matrix_dim[1]; auto stream = dev_ctx.stream(); - -#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \ - do { \ - switch (phi::funcs::GetDesiredBlockDim(feature_size)) { \ - FIXED_BLOCK_DIM_CASE( \ - phi::funcs:: \ - LayerNormForward \ - <<>>( \ - x_data, \ - static_cast(void_scale_data), \ - static_cast(void_bias_data), \ - y_data, \ - mean_data, \ - var_data, \ - epsilon, \ - feature_size)); \ - default: \ - PADDLE_THROW(common::errors::InvalidArgument( \ - "Product from begin_norm_axis to end must be larger than 1")); \ - break; \ - } \ + auto place = x.place(); + +#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \ + do { \ + switch (funcs::GetDesiredBlockDim(feature_size)) { \ + FIXED_BLOCK_DIM_CASE( \ + funcs::LayerNormForward \ + <<>>( \ + x_data, \ + static_cast(void_scale_data), \ + static_cast(void_bias_data), \ + y_data, \ + mean_data, \ + var_data, \ + epsilon, \ + feature_size)); \ + default: \ + PADDLE_THROW(common::errors::InvalidArgument( \ + "Product from begin_norm_axis to end must be larger than 1")); \ + break; \ + } \ } while (0) -#define PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, feature_size) \ +#define PADDLE_LAUNCH_FAST_LAYERNORM_V1_FWD_BASE(ScaleT, feature_size) \ case (feature_size): { \ constexpr int WARPS_N = feature_size < 1024 ? 1 : (feature_size / 1024); \ constexpr int WARPS_M = 4 / WARPS_N; \ @@ -573,98 +630,161 @@ void LayerNormKernel(const Context &dev_ctx, const int ROWS_PER_CTA = WARPS_M; \ const int grid = static_cast( \ std::ceil(batch_size / static_cast(ROWS_PER_CTA))); \ - phi::funcs::fast_ln_fwd_kernel \ + funcs::fast_ln_v1::fast_ln_v1_fwd_kernel \ <<>>( \ batch_size, \ feature_size, \ epsilon, \ x_data, \ - static_cast(void_scale_data), \ - static_cast(void_bias_data), \ + static_cast(void_scale_data), \ + static_cast(void_bias_data), \ mean_data, \ var_data, \ y_data); \ } break -#define PADDLE_LAUNCH_FAST_LAYERNORM_FWD(ScaleT) \ - PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 768); \ - PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1024); \ - PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1280); \ - PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1536); \ - PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1792); \ - PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 2048); \ - PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 4096) - -#ifdef PADDLE_WITH_CUDA - bool can_call_fast_kernel = false; - if ((feature_size >= 768 && feature_size <= 2048 && feature_size % 256 == 0 || - feature_size == 4096) && - scale != nullptr && bias != nullptr) { - can_call_fast_kernel = true; - } - - if (can_call_fast_kernel) { - if (is_scale_bias_same_dtype_with_x) { - switch (feature_size) { - PADDLE_LAUNCH_FAST_LAYERNORM_FWD(T); - default: - PADDLE_THROW(common::errors::InvalidArgument( - "Only when feature_size is from 256 to 4096 and is diviaible by " - "256 is supported " - "now")); - break; - } - } else { - switch (feature_size) { - PADDLE_LAUNCH_FAST_LAYERNORM_FWD(U); - default: - PADDLE_THROW(common::errors::InvalidArgument( - "Only when feature_size is from 256 to 4096 and is diviaible by " - "is supported " - "now")); - break; - } - } - } else { - // WarpShuffle intrinsics is involved in LaunchLayerNormKernel. - if (FLAGS_use_fast_math && feature_size <= 1024 && - (!std::is_same::value)) { - LaunchLayerNormKernel(dev_ctx, - x_data, - y_data, - void_scale_data, - void_bias_data, - mean_data, - var_data, - epsilon, - batch_size, - feature_size, - valid_scale, - valid_bias, - is_scale_bias_same_dtype_with_x); - } else { +#define PADDLE_LAUNCH_FAST_LAYERNORM_V1_FWD(ScaleT) \ + PADDLE_LAUNCH_FAST_LAYERNORM_V1_FWD_BASE(ScaleT, 768); \ + PADDLE_LAUNCH_FAST_LAYERNORM_V1_FWD_BASE(ScaleT, 1024); \ + PADDLE_LAUNCH_FAST_LAYERNORM_V1_FWD_BASE(ScaleT, 1280); \ + PADDLE_LAUNCH_FAST_LAYERNORM_V1_FWD_BASE(ScaleT, 1536); \ + PADDLE_LAUNCH_FAST_LAYERNORM_V1_FWD_BASE(ScaleT, 1792); \ + PADDLE_LAUNCH_FAST_LAYERNORM_V1_FWD_BASE(ScaleT, 2048); \ + PADDLE_LAUNCH_FAST_LAYERNORM_V1_FWD_BASE(ScaleT, 4096) + auto compute_dtype = phi::CppTypeToDataType::Type(); + auto kernel_variant = LayerNormKernelDispatch(scale_bias_dtype, + x_dtype, + y_dtype, + compute_dtype, + feature_size, + x.numel(), + scale, + bias); + + switch (kernel_variant) { +#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) && !defined(_WIN32) + case LayerNormKernelVariant::FAST_LN_V2: + funcs::fast_ln_v2::LaunchNormFwd(dev_ctx, + stream, + place, + x_data, + void_scale_data, + void_bias_data, + y_data, + mean_data, + var_data, + scale_bias_dtype, + x_dtype, + y_dtype, + compute_dtype, + feature_size, + batch_size, + feature_size, + epsilon); + break; #endif + case LayerNormKernelVariant::FAST_LN_V1: if (is_scale_bias_same_dtype_with_x) { - PADDLE_LAUNCH_LAYERNORM_FWD(T, true); + switch (feature_size) { + PADDLE_LAUNCH_FAST_LAYERNORM_V1_FWD(T); + default: + break; + } } else { - PADDLE_LAUNCH_LAYERNORM_FWD(U, false); + switch (feature_size) { + PADDLE_LAUNCH_FAST_LAYERNORM_V1_FWD(U); + default: + break; + } } + break; + case LayerNormKernelVariant::GENERIC: + default: #ifdef PADDLE_WITH_CUDA - } - } + if (FLAGS_use_accuracy_compatible_kernel || + (!isPowerOfTwo(feature_size) && feature_size > 1024)) { + LayerNormFwdCompatKernel( + dev_ctx, + x_data, + valid_scale ? static_cast(void_scale_data) : nullptr, + valid_bias ? static_cast(void_bias_data) : nullptr, + epsilon, + batch_size, + feature_size, + y_data, + mean_data, + var_data); + } else if (FLAGS_use_fast_math && feature_size <= 1024 && + (!std::is_same::value)) { + // WarpShuffle intrinsics is involved in LaunchLayerNormKernel. + LaunchLayerNormKernel(dev_ctx, + x_data, + y_data, + void_scale_data, + void_bias_data, + mean_data, + var_data, + epsilon, + batch_size, + feature_size, + valid_scale, + valid_bias, + is_scale_bias_same_dtype_with_x); + } else { #endif + if (is_scale_bias_same_dtype_with_x) { + PADDLE_LAUNCH_LAYERNORM_FWD(T, true); + } else { + PADDLE_LAUNCH_LAYERNORM_FWD(U, false); + } +#ifdef PADDLE_WITH_CUDA + } +#endif + break; + } #undef PADDLE_LAUNCH_LAYERNORM_FWD -#undef PADDLE_LAUNCH_FAST_LAYERNORM_FWD +#undef PADDLE_LAUNCH_FAST_LAYERNORM_V1_FWD } - +#ifdef _WIN32 +template PADDLE_API void LayerNormKernel( + const GPUContext& dev_ctx, + const DenseTensor& x, + const optional& scale_opt, + const optional& bias_opt, + double epsilon, + int begin_norm_axis, + DenseTensor* y, + DenseTensor* mean, + DenseTensor* var); +template PADDLE_API void LayerNormKernel( + const GPUContext& dev_ctx, + const DenseTensor& x, + const optional& scale_opt, + const optional& bias_opt, + double epsilon, + int begin_norm_axis, + DenseTensor* y, + DenseTensor* mean, + DenseTensor* var); +template PADDLE_API void LayerNormKernel( + const GPUContext& dev_ctx, + const DenseTensor& x, + const optional& scale_opt, + const optional& bias_opt, + double epsilon, + int begin_norm_axis, + DenseTensor* y, + DenseTensor* mean, + DenseTensor* var); +#endif } // namespace phi #ifdef PADDLE_WITH_HIP @@ -674,7 +794,7 @@ PD_REGISTER_PLUGIN_KERNEL(layer_norm, ALL_LAYOUT, phi::LayerNormKernel, float, - phi::dtype::float16) { + phi::float16) { kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); } @@ -685,8 +805,8 @@ PD_REGISTER_PLUGIN_KERNEL(layer_norm, phi::LayerNormKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) { + phi::float16, + phi::bfloat16) { kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); } @@ -697,7 +817,7 @@ PD_REGISTER_PLUGIN_KERNEL(layer_norm, phi::LayerNormKernel, float, double, - phi::dtype::float16) { + phi::float16) { kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); } diff --git a/backends/metax_gpu/kernels/impl/addmm_kernel_impl.h b/backends/metax_gpu/kernels/impl/addmm_kernel_impl.h index a2c69b6adf0..9de0b3aadd4 100644 --- a/backends/metax_gpu/kernels/impl/addmm_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/addmm_kernel_impl.h @@ -26,14 +26,11 @@ limitations under the License. */ // clang-format on namespace phi { -template -using PhiEigenTensor = EigenTensor; +template +using PhiEigenTensor = EigenTensor; -using Array1 = Eigen::DSizes; -using Array2 = Eigen::DSizes; +using Array1 = Eigen::DSizes; +using Array2 = Eigen::DSizes; template void AddmmKernel(const Context& dev_ctx, diff --git a/backends/metax_gpu/kernels/impl/layer_norm_impl.cu.h b/backends/metax_gpu/kernels/impl/layer_norm_impl.cu.h index 18e3401bcdd..53c5aff8875 100644 --- a/backends/metax_gpu/kernels/impl/layer_norm_impl.cu.h +++ b/backends/metax_gpu/kernels/impl/layer_norm_impl.cu.h @@ -33,6 +33,7 @@ namespace cub = hipcub; #include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/funcs/fast_ln_v1.h" namespace phi { namespace funcs { diff --git a/backends/metax_gpu/patch/paddle.patch b/backends/metax_gpu/patch/paddle.patch index 0d06e6aad2f..c6f5c331c9d 100755 --- a/backends/metax_gpu/patch/paddle.patch +++ b/backends/metax_gpu/patch/paddle.patch @@ -1,15 +1,19 @@ diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake -index cfada544d4..a690e97d74 100644 +index 6790d961f8..37b841bfc4 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake -@@ -42,12 +42,12 @@ endif() +@@ -42,14 +42,14 @@ endif() file(TO_NATIVE_PATH "${PADDLE_SOURCE_DIR}/patches/eigen/TensorRandom.h.patch" tensor_random_header) # See: [Why calling some `git` commands before `patch`?] --set(EIGEN_PATCH_COMMAND git checkout -- . && git checkout ${EIGEN_TAG} && git -- apply ${tensor_random_header}) -+# set(EIGEN_PATCH_COMMAND git checkout -- . && git checkout ${EIGEN_TAG} && git -+# apply ${tensor_random_header}) +-set(EIGEN_PATCH_COMMAND +- git checkout -- Eigen/src/Core/arch/SSE/Complex.h +- unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h && git checkout +- ${EIGEN_TAG} && git apply ${tensor_random_header}) ++# set(EIGEN_PATCH_COMMAND ++# git checkout -- Eigen/src/Core/arch/SSE/Complex.h ++# unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h && git checkout ++# ${EIGEN_TAG} && git apply ${tensor_random_header}) if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/Complex.h.patch complex_header) @@ -35,18 +39,19 @@ index 8d445b39ae..504e7b6293 100755 op_library(fused_gemm_epilogue_op) endif() diff --git a/paddle/phi/backends/dynload/dynamic_loader.cc b/paddle/phi/backends/dynload/dynamic_loader.cc -index 37ee00b591..f497472ad9 100644 +index 99d3733da0..d0c8783afd 100644 --- a/paddle/phi/backends/dynload/dynamic_loader.cc +++ b/paddle/phi/backends/dynload/dynamic_loader.cc -@@ -18,7 +18,6 @@ limitations under the License. */ - #include +@@ -19,7 +19,7 @@ limitations under the License. */ + #include #include #include -#include "paddle/phi/backends/dynload/cupti_lib_path.h" ++// #include "paddle/phi/backends/dynload/cupti_lib_path.h" #include "paddle/phi/common/port.h" #include "paddle/phi/core/enforce.h" -@@ -112,6 +111,10 @@ COMMON_DECLARE_string(magma_dir); +@@ -113,6 +113,10 @@ COMMON_DECLARE_string(magma_dir); #define SPARSELT_LIB_NAME "libcusparseLt.so" #endif @@ -1096,3 +1101,4 @@ index d8bc15926b..6071baf340 100644 PADDLE_ENFORCE_EQ( status, + diff --git a/backends/metax_gpu/tests/ignore.txt b/backends/metax_gpu/tests/ignore.txt index ce7ec2e3621..3ec7440dad5 100644 --- a/backends/metax_gpu/tests/ignore.txt +++ b/backends/metax_gpu/tests/ignore.txt @@ -71,6 +71,7 @@ test_householder_product test_paddle_device test_conv2d_op test_rnn_op +test_interp_antialias_op [internet] test_hapi_amp