diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index fbaf8a1f756b..71b433822e32 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -3165,6 +3165,44 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( supportedScaleFactors = noneVal; supportedSizes = createScalarSublist( loc, proposedSizes, assumedForemostSpatialDim, rewriter); + + // supportedSizes is the spatial list. For resize/interpolate op [0,0] + // is not valid. Replace 0s with the dynamic sentinel + auto supportedSizesTensorType = + cast(supportedSizes.getType()); + + auto sizesOfsupportedSizesTensor = + supportedSizesTensorType.getSizes(); + auto lengthOfFullList = sizesOfsupportedSizesTensor[0]; + SmallVector newSupportedSizes; + + for (int indexOfEachScalar = 0; indexOfEachScalar < lengthOfFullList; + indexOfEachScalar++) { + + Value proposedDim = extractTorchScalar(loc, indexOfEachScalar, + supportedSizes, rewriter); + + Value zero = rewriter.create(loc, 0); + Value dynamic = + rewriter.create(loc, unknownSize); + + // Check if the proposed size is 0 + Value isZero = rewriter.create( + loc, boolType, proposedDim, zero); + + // If proposed spatial dim is 0, change it to use torch dynamic dim + auto corrected = rewriter.create( + loc, proposedDim.getType(), isZero, dynamic, proposedDim); + + newSupportedSizes.push_back(corrected); + } + + auto someTorchScalarType = newSupportedSizes.front().getType(); + Type someTorchScalarListType = + Torch::ListType::get(someTorchScalarType); + + supportedSizes = rewriter.create( + loc, someTorchScalarListType, newSupportedSizes); } else return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode");