@@ -4240,6 +4240,74 @@ diopiError_t diopiGroupNormGB(diopiContextHandle_t ctx, diopiTensorHandle_t out,
42404240 return diopiSuccess;
42414241}
42424242
4243+ diopiError_t diopiGroupNormGBBackward (diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight, diopiTensorHandle_t grad_bias,
4244+ diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight,
4245+ diopiConstTensorHandle_t mean, diopiConstTensorHandle_t rstd, int64_t num_groups, diopiSize_t reduced_axes, const int64_t channel_axis) {
4246+ impl::aten::setCurStream (ctx);
4247+ auto atGradOutput = impl::aten::buildATen (grad_output);
4248+ auto atInput = impl::aten::buildATen (input);
4249+ auto atWeight = impl::aten::buildATen (weight);
4250+ auto atSaveMean = impl::aten::buildATen (mean);
4251+ auto atSaveVar = impl::aten::buildATen (rstd);
4252+ auto atGradWeight = impl::aten::buildATen (grad_weight);
4253+ auto atGradBias = impl::aten::buildATen (grad_bias);
4254+ auto axisSize = atInput.size (channel_axis);
4255+ auto k = axisSize / num_groups;
4256+ at::IntArrayRef atReducedAxes = impl::aten::buildAtIntArray (reduced_axes);
4257+ std::vector<int64_t > dims;
4258+ int64_t N = 1 ;
4259+ for (int i = 0 ; i < atInput.dim (); i++) {
4260+ if (i == channel_axis) {
4261+ continue ;
4262+ } else {
4263+ bool is_reduced_axis = false ;
4264+ for (int m = 0 ; m < reduced_axes.len ; m++) {
4265+ if (i == reduced_axes.data [m]) {
4266+ is_reduced_axis = true ;
4267+ break ;
4268+ }
4269+ }
4270+ if (is_reduced_axis) {
4271+ continue ;
4272+ } else {
4273+ dims.push_back (i);
4274+ N *= atInput.size (i);
4275+ }
4276+ }
4277+ }
4278+ dims.push_back (channel_axis);
4279+ int64_t HxW = 1 ;
4280+ for (auto i = 0 ; i < reduced_axes.len ; i++) {
4281+ dims.push_back (reduced_axes.data [i]);
4282+ HxW *= atInput.size (reduced_axes.data [i]);
4283+ }
4284+ auto C = atInput.size (channel_axis);
4285+ auto permutedInput = atInput.permute (dims);
4286+ auto permutedShape = permutedInput.sizes ();
4287+ auto reshapedInput = permutedInput.reshape ({N, C, HxW, 1 }).contiguous ();
4288+
4289+ std::vector<int64_t > reverse_order (dims.size ());
4290+ for (auto i = 0 ; i < atInput.dim (); i++) {
4291+ reverse_order[dims[i]] = i;
4292+ }
4293+
4294+ if (grad_weight && grad_bias) {
4295+ auto atGradInput = impl::aten::buildATen (grad_input).permute (dims).reshape ({N, C, HxW, 1 });
4296+
4297+ at::native_group_norm_backward_out (
4298+ atGradInput, atGradWeight, atGradBias, atGradOutput.permute (dims).reshape ({N, C, HxW, 1 }), reshapedInput, atSaveMean, atSaveVar, atWeight, N, C, HxW, num_groups, {true , true , true });
4299+ atGradInput = atGradInput.reshape (permutedShape).permute (reverse_order);
4300+ } else {
4301+ auto atOuts = at::native_group_norm_backward (
4302+ atGradOutput.permute (dims).reshape ({N, C, HxW, 1 }), reshapedInput, atSaveMean, atSaveVar, atWeight, N, C, HxW, num_groups, {true , grad_weight != nullptr , grad_bias != nullptr });
4303+ impl::aten::updateATen2Tensor (ctx, std::get<0 >(atOuts).reshape (permutedShape).permute (reverse_order), grad_input);
4304+ impl::aten::updateATen2Tensor (ctx, std::get<1 >(atOuts), grad_weight);
4305+ impl::aten::updateATen2Tensor (ctx, std::get<2 >(atOuts), grad_bias);
4306+ }
4307+
4308+ return diopiSuccess;
4309+ }
4310+
42434311diopiError_t diopiGroupNorm (diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
42444312 diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
42454313 double eps) {
0 commit comments