Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Paddle
Submodule Paddle updated 627 files
6 changes: 3 additions & 3 deletions backends/metax_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/grid_sample_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/instance_norm_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rnn_grad_kernel.cu.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rnn_kernel.cu.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rnn_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rnn_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/ctc_align_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/yolo_box_head_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/stft_grad_kernel.cu
Expand Down Expand Up @@ -290,7 +290,7 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/unsqueeze_kernel.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/squeeze_grad_kernel.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/squeeze_kernel.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/sign_kernel.cu.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/sign_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/split_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/scatter_nd_add_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/scatter_nd_add_grad_kernel.cu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ PD_CUSTOM_KERNEL_REGISTER(sign,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::float16,
phi::bfloat16,
phi::complex64,
phi::complex128) {}
Original file line number Diff line number Diff line change
Expand Up @@ -14,48 +14,99 @@

#include "funcs/layer_norm_util.h"
#include "impl/layer_norm_impl.cu.h"
#include "paddle/common/flags.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/layer_norm_grad_kernel.h"
#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) && !defined(_WIN32)
#include "paddle/phi/kernels/funcs/fast_ln_v2.h"
#endif
#ifdef PADDLE_WITH_CUDA
#include "paddle/phi/kernels/gpu/rms_norm_cuda_kernel.h"
#endif

namespace phi {
enum class LayerNormGadKernelVariant { FAST_LN_V2, GENERIC };
static inline LayerNormGadKernelVariant LayerNormGradKernelDispatch(
const paddle::DataType weight_type,
const paddle::DataType input_type,
const paddle::DataType output_type,
const paddle::DataType compute_type,
const uint32_t hidden_size,
const int64_t x_numel,
const DenseTensor* scale,
const DenseTensor* bias) {
#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) && !defined(_WIN32)
if (scale != nullptr && bias != nullptr &&
input_type != paddle::DataType::FLOAT32 && hidden_size != 4096 &&
hidden_size > 1024 && hidden_size <= 10240 &&
x_numel <= std::numeric_limits<uint32_t>::max()) {
// using fast_ln_v2 only sm > 70 and x_numel <= uint32_max
auto prop = funcs::fast_ln_v2::GetDeviceProp();
if (prop->major > 7 &&
funcs::fast_ln_v2::has_fast_ln_v2_bwd_kernel(
weight_type, input_type, output_type, compute_type, hidden_size)) {
return LayerNormGadKernelVariant::FAST_LN_V2;
}
}
#endif
return LayerNormGadKernelVariant::GENERIC;
}

template <typename T, typename Context>
void LayerNormGradKernel(const Context &dev_ctx,
const DenseTensor &x,
const paddle::optional<DenseTensor> &scale_opt,
const paddle::optional<DenseTensor> &bias_opt,
const DenseTensor &mean,
const DenseTensor &variance,
const DenseTensor &out_grad,
float epsilon,
void LayerNormGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const optional<DenseTensor>& scale_opt,
const optional<DenseTensor>& bias_opt,
const DenseTensor& mean,
const DenseTensor& variance,
const DenseTensor& out_grad,
double epsilon,
int begin_norm_axis,
DenseTensor *x_grad,
DenseTensor *scale_grad,
DenseTensor *bias_grad) {
using U = phi::funcs::LayerNormParamType<T>;
DenseTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad) {
if (x.numel() == 0) {
dev_ctx.template Alloc<T>(x_grad);
if (scale_grad) {
Full<T, Context>(dev_ctx, scale_grad->dims(), 0, scale_grad);
if (scale_opt.get_ptr() && x.dtype() != scale_opt.get().dtype()) {
CastKernel<T, Context>(
dev_ctx, *scale_grad, scale_opt.get().dtype(), scale_grad);
}
}
if (bias_grad) {
Full<T, Context>(dev_ctx, bias_grad->dims(), 0, bias_grad);
if (bias_opt.get_ptr() && x.dtype() != bias_opt.get().dtype()) {
CastKernel<T, Context>(
dev_ctx, *bias_grad, bias_opt.get().dtype(), bias_grad);
}
}
return;
}
using U = funcs::LayerNormParamType<T>;
// d_x, d_scale, d_bias may be nullptr
auto *d_x = x_grad;
auto *d_scale = scale_grad;
auto *d_bias = bias_grad;
auto* d_x = x_grad;
auto* d_scale = scale_grad;
auto* d_bias = bias_grad;

auto *scale = scale_opt.get_ptr();
auto *bias = bias_opt.get_ptr();
auto *d_y = &out_grad;
auto* scale = scale_opt.get_ptr();
auto* bias = bias_opt.get_ptr();
auto* d_y = &out_grad;

const auto &x_dims = x.dims();
const auto& x_dims = x.dims();
auto matrix_dim = common::flatten_to_2d(x_dims, begin_norm_axis);
int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);

auto *x_data = x.data<T>();
auto *d_y_data = d_y->data<T>();
auto* x_data = x.data<T>();
auto* d_y_data = d_y->data<T>();

auto *mean_data = mean.data<U>();
auto *var_data = variance.data<U>();
auto* mean_data = mean.data<U>();
auto* var_data = variance.data<U>();

auto *d_x_data = (d_x == nullptr ? nullptr : dev_ctx.template Alloc<T>(d_x));
auto* d_x_data = (d_x == nullptr ? nullptr : dev_ctx.template Alloc<T>(d_x));

auto x_dtype = x.dtype();

Expand All @@ -74,52 +125,165 @@ void LayerNormGradKernel(const Context &dev_ctx,

#define PADDLE_LAUNCH_LAYERNORM_BWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
do { \
auto *scale_data = \
auto* scale_data = \
(scale == nullptr ? nullptr : scale->data<ScaleBiasT>()); \
auto *d_scale_data = \
auto* d_scale_data = \
(d_scale == nullptr ? nullptr \
: dev_ctx.template Alloc<ScaleBiasT>(d_scale)); \
auto *d_bias_data = \
auto* d_bias_data = \
(d_bias == nullptr ? nullptr \
: dev_ctx.template Alloc<ScaleBiasT>(d_bias)); \
auto *d_x_data = \
auto* d_x_data = \
(d_x == nullptr ? nullptr : dev_ctx.template Alloc<T>(d_x)); \
phi::funcs::LayerNormBackward<T, U, IsScaleBiasSameDTypeWithX>( \
x_data, \
d_y_data, \
scale_data, \
mean_data, \
var_data, \
d_x_data, \
d_scale_data, \
d_bias_data, \
epsilon, \
batch_size, \
feature_size, \
dev_ctx); \
funcs::LayerNormBackward<T, U, IsScaleBiasSameDTypeWithX>(x_data, \
d_y_data, \
scale_data, \
mean_data, \
var_data, \
d_x_data, \
d_scale_data, \
d_bias_data, \
epsilon, \
batch_size, \
feature_size, \
dev_ctx); \
} while (0)

if (scale_bias_dtype == x_dtype) {
PADDLE_LAUNCH_LAYERNORM_BWD(T, true);
} else {
PADDLE_LAUNCH_LAYERNORM_BWD(U, false);
#define PADDLE_LAUNCH_FAST_LAYERNORM_V2_BWD(ScaleBiasT) \
do { \
auto stream = dev_ctx.stream(); \
auto place = x.place(); \
auto* scale_data = \
(scale == nullptr ? nullptr : scale->data<ScaleBiasT>()); \
auto* d_scale_data = \
(d_scale == nullptr ? nullptr \
: dev_ctx.template Alloc<ScaleBiasT>(d_scale)); \
auto* d_bias_data = \
(d_bias == nullptr ? nullptr \
: dev_ctx.template Alloc<ScaleBiasT>(d_bias)); \
auto* d_x_data = \
(d_x == nullptr ? nullptr : dev_ctx.template Alloc<T>(d_x)); \
funcs::fast_ln_v2::LaunchNormBwd<T, Context>(dev_ctx, \
stream, \
place, \
x_data, \
scale_data, \
mean_data, \
var_data, \
d_y_data, \
d_x_data, \
d_scale_data, \
d_bias_data, \
scale_bias_dtype, \
x_dtype, \
x_grad->dtype(), \
compute_dtype, \
feature_size, \
batch_size, \
feature_size, \
epsilon); \
} while (0)

auto compute_dtype = phi::CppTypeToDataType<U>::Type();
auto kernel_variant = LayerNormGradKernelDispatch(scale_bias_dtype,
x_dtype,
x_dtype,
compute_dtype,
feature_size,
x.numel(),
scale,
bias);
switch (kernel_variant) {
#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) && !defined(_WIN32)
case LayerNormGadKernelVariant::FAST_LN_V2:
if (scale_bias_dtype == x_dtype) {
PADDLE_LAUNCH_FAST_LAYERNORM_V2_BWD(T);
} else {
PADDLE_LAUNCH_FAST_LAYERNORM_V2_BWD(U);
}
break;
#endif
case LayerNormGadKernelVariant::GENERIC:
default:
#ifdef PADDLE_WITH_CUDA
if ((FLAGS_use_accuracy_compatible_kernel ||
(!isPowerOfTwo(feature_size) && feature_size > 1024)) &&
scale_bias_dtype == x_dtype) {
auto* scale_data = (scale == nullptr ? nullptr : scale->data<T>());
auto* d_scale_data =
(d_scale == nullptr ? nullptr : dev_ctx.template Alloc<T>(d_scale));
auto* d_bias_data =
(d_bias == nullptr ? nullptr : dev_ctx.template Alloc<T>(d_bias));
auto* d_x_data =
(d_x == nullptr ? nullptr : dev_ctx.template Alloc<T>(d_x));
LayerNormBwdCompatKernel<T, Context>(dev_ctx,
d_y_data,
x_data,
scale_data,
mean_data,
var_data,
d_x_data,
d_scale_data,
d_bias_data,
epsilon,
batch_size,
feature_size);
} else {
#endif
if (scale_bias_dtype == x_dtype) {
PADDLE_LAUNCH_LAYERNORM_BWD(T, true);
} else {
PADDLE_LAUNCH_LAYERNORM_BWD(U, false);
}
#ifdef PADDLE_WITH_CUDA
}
#endif
}

#undef PADDLE_LAUNCH_LAYERNORM_BWD
#undef PADDLE_LAUNCH_FAST_LAYERNORM_V2_BWD
}

} // namespace phi

#ifdef PADDLE_WITH_HIP
// MIOPEN do not support double
PD_REGISTER_PLUGIN_KERNEL(layer_norm_grad,
metax_gpu,
ALL_LAYOUT,
phi::LayerNormGradKernel,
float,
phi::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
}
}
#elif CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_PLUGIN_KERNEL(layer_norm_grad,
metax_gpu,
ALL_LAYOUT,
phi::LayerNormGradKernel,
float,
double,
phi::float16,
phi::bfloat16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
}
}
#else
PD_REGISTER_PLUGIN_KERNEL(layer_norm_grad,
metax_gpu,
ALL_LAYOUT,
phi::LayerNormGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {
phi::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
}
}
#endif
Loading