From 57a77effa0089eea4445af463fca7d381d0be448 Mon Sep 17 00:00:00 2001 From: guozixu2001 Date: Fri, 9 Aug 2024 10:50:14 +0800 Subject: [PATCH] Fix: Correct backward convolution function call with output mask --- impl/torch/functions/functions.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp index 94b6705bd..bc0acae1a 100644 --- a/impl/torch/functions/functions.cpp +++ b/impl/torch/functions/functions.cpp @@ -2254,8 +2254,11 @@ diopiError_t diopiConvolution2dBackward(diopiContextHandle_t ctx, diopiTensorHan at::native::copy_(atGradWeight, std::get<1>(tempOut), true); at::native::copy_(atGradBias, std::get<2>(tempOut), true); } else { - auto results = at::convolution_backward( - atGrad, atInput, atWeight, c10::nullopt, atStride, atPadding, atDilation, false, outputPadding, groups, {true, true, false}); + std::array output_mask{true, true, false}; + if (!grad_input) output_mask[0] = false; + if (!grad_weight) output_mask[1] = false; + auto results = at::native::convolution_backward( + atGrad, atInput, atWeight, c10::nullopt, atStride, atPadding, atDilation, false, outputPadding, groups, output_mask); impl::aten::updateATen2Tensor(ctx, std::get<0>(results), grad_input); impl::aten::updateATen2Tensor(ctx, std::get<1>(results), grad_weight); if (bias_sizes && grad_bias) {