Skip to content

Commit 4381f91

Browse files
authored
fix cat with out args (#1053) (#1074)
1 parent 36e0b30 commit 4381f91

File tree

2 files changed

+120
-9
lines changed

2 files changed

+120
-9
lines changed

intel_extension_for_pytorch/csrc/aten/cpu/TensorShape.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,42 @@
3838
namespace torch_ipex {
3939
namespace cpu {
4040

41+
using namespace at;
42+
43+
// NOTE(review): this block is reproduced as it appears in the rendered commit
// view; the bare `NN+` lines are diff-gutter line numbers, not C++ source.
//
// Prepares a caller-supplied `out` tensor to receive a result of shape `sizes`.
//   out     - tensor to validate; it is resized/restrided in place by the
//             calls below even though it is taken by const reference
//   sizes   - target shape handed to at::native::resize_output
//   strides - explicit strides to apply after a resize; empty means none given
//   options - carries the expected dtype and device, and optionally a memory
//             format to restride with
// Fails a TORCH_CHECK (with the messages below) when `out`'s dtype or device
// does not match `options`.
void resize_out(
44+
const Tensor& out,
45+
IntArrayRef sizes,
46+
IntArrayRef strides,
47+
const TensorOptions& options) {
48+
// The out tensor must already have the dtype the operation would produce.
TORCH_CHECK(
49+
options.dtype() == out.dtype(),
50+
"Expected out tensor to have dtype ",
51+
options.dtype(),
52+
", but got ",
53+
out.dtype(),
54+
" instead");
55+
// Likewise the out tensor must live on the expected device.
TORCH_CHECK(
56+
options.device() == out.device(),
57+
"Expected out tensor to have device ",
58+
options.device(),
59+
", but got ",
60+
out.device(),
61+
" instead");
62+
// `resized` records the boolean result of resize_output; per the comment
// below it indicates whether a resize actually took place.
const bool resized = at::native::resize_output(out, sizes);
63+
// Only restride if a resize occurred; otherwise we ignore the (advisory)
64+
// strides from the meta function and directly use the output tensor's
65+
// preexisting strides
66+
if (resized) {
67+
if (!strides.empty()) {
68+
// Explicit strides and a memory format must not both be supplied.
TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
69+
at::native::as_strided_(out, sizes, strides);
70+
} else if (options.memory_format_opt().has_value()) {
71+
// No explicit strides: restride using the requested memory format.
out.unsafeGetTensorImpl()->empty_tensor_restride(
72+
*options.memory_format_opt());
73+
}
74+
}
75+
}
76+
4177
DEFINE_DISPATCH(cat_contig_stub);
4278

4379
inline void cat_check_no_zero_dim(
@@ -169,7 +205,11 @@ at::Tensor& cat_out_cpu(
169205
memory_format);
170206
}
171207

172-
result = at::empty(sizes, options);
208+
if (result.defined()) {
209+
resize_out(result, sizes, /*strides=*/{}, options);
210+
} else {
211+
result = at::empty(sizes, options);
212+
}
173213
// Checks for overlaps between the inputs and the output tensor.
174214
if (is_out_defined && found_valid_tensor) {
175215
at::assert_no_internal_overlap(result);

tests/cpu/test_cpu_ops.py

Lines changed: 79 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import random
66
import intel_extension_for_pytorch as ipex
77
from common_utils import TestCase
8+
import itertools
89

910
try:
1011
import torchvision
@@ -776,14 +777,84 @@ def test_index_select(self):
776777
self.assertEqual(y2, y, prec=0.01)
777778

778779
# NOTE(review): this block is reproduced as it appears in the rendered commit
# view. Lines of the form `NNN-` / `NNN+` are diff-gutter line numbers; the
# statements following `NNN-` markers are the removed (old) test body and the
# statements following `NNN+` markers are the added (new) test body.
#
# test_cat exercises torch.cat on CPU: result dtype/shape for several input
# layouts, long input lists, the `out=` argument in several states, and an
# empty tensor among the inputs.
def test_cat(self):
779-
x = x = torch.randn(2, 3)
780-
y = torch.cat((x, x, x), 0)
781-
782-
# test bfloat16
783-
x2 = x.clone().detach().bfloat16()
784-
y2 = torch.cat((x2, x2, x2), 0)
785-
self.assertTrue(y2.dtype == torch.bfloat16)
786-
self.assertEqual(y2, y, prec=0.01)
780+
# New body: repeat every case below for each of the three dtypes.
for datatype in [torch.float32, torch.double, torch.bfloat16]:
781+
# Concatenate two copies of x along dim 0 and dim 1 for three shapes;
# only the result dtype is asserted here.
for dim, size in itertools.product([0, 1], [[2, 1], [2, 2], [5, 10]]):
782+
x = torch.randn(size, dtype=datatype)
783+
y = torch.cat([x, x], dim)
784+
self.assertTrue(y.dtype == datatype)
785+
786+
# long input tensor list
# 100 copies of a (2, 2) tensor along dim 0 -> (200, 2)
787+
x1 = torch.randn((2, 2), dtype=datatype)
788+
input1 = []
789+
for i in range(100):
790+
input1.append(x1)
791+
y1 = torch.cat(input1, 0)
792+
self.assertTrue(y1.size() == torch.Size([200, 2]))
793+
self.assertTrue(y1.dtype == datatype)
794+
795+
# input tensors have different shapes and strides
# 10 x (2, 2) plus 100 x (400, 2) -> 20 + 40000 = 40020 rows
796+
x2 = torch.randn((400, 2), dtype=datatype)
797+
input2 = []
798+
for i in range(10):
799+
input2.append(x1)
800+
for i in range(100):
801+
input2.append(x2)
802+
y2 = torch.cat(input2, 0)
803+
self.assertTrue(y2.size() == torch.Size([40020, 2]))
804+
self.assertTrue(y2.dtype == datatype)
805+
806+
# 10 x (2, 2) plus 10 x (4000, 2) -> 20 + 40000 = 40020 rows
x3 = torch.randn((4000, 2), dtype=datatype)
807+
input3 = []
808+
for i in range(10):
809+
input3.append(x1)
810+
for i in range(10):
811+
input3.append(x3)
812+
y3 = torch.cat(input3, 0)
813+
self.assertTrue(y3.size() == torch.Size([40020, 2]))
814+
self.assertTrue(y3.dtype == datatype)
815+
816+
# 10 x (2, 2) plus 10 x (4, 2) -> 20 + 40 = 60 rows
x4 = torch.randn((4, 2), dtype=datatype)
817+
input4 = []
818+
for i in range(10):
819+
input4.append(x1)
820+
for i in range(10):
821+
input4.append(x4)
822+
y4 = torch.cat(input4, 0)
823+
self.assertTrue(y4.size() == torch.Size([60, 2]))
824+
self.assertTrue(y4.dtype == datatype)
825+
826+
# "out" arg is used but un-defined
# (here: `out` is a zero-element tensor of the same dtype as the inputs)
827+
y5 = torch.cat([x4, x4], 0, out=torch.empty(0, dtype=datatype))
828+
self.assertEqual(y5, torch.cat([x4, x4], 0))
829+
self.assertTrue(y5.dtype == datatype)
830+
831+
# out is defined with wrong shape
# The data pointer is expected to change, i.e. `out` gets new storage.
# NOTE(review): `out` is created without an explicit dtype, so it
# presumably does not match `datatype` on the double/bfloat16
# iterations, and the dtype assertion below checks `ref`, not `out`
# - confirm this is intended.
832+
ref = torch.cat([x4, x4], 0)
833+
out = torch.zeros(1)
834+
out_ptr = out.data_ptr()
835+
torch.cat([x4, x4], 0, out=out)
836+
self.assertEqual(ref, out)
837+
self.assertTrue(ref.dtype == datatype)
838+
self.assertTrue(out_ptr != out.data_ptr())
839+
840+
# out is defined with correct shape
# The data pointer is expected to stay the same, i.e. `out` is written
# in place without reallocation.
841+
ref = torch.cat([x4, x4], 0)
842+
out = torch.zeros_like(ref)
843+
out_ptr = out.data_ptr()
844+
torch.cat([x4, x4], 0, out=out)
845+
self.assertEqual(ref, out)
846+
self.assertTrue(ref.dtype == datatype)
847+
self.assertTrue(out_ptr == out.data_ptr())
848+
849+
# `out` is always float32 here regardless of `datatype`; the result is
# expected to keep the dtype of `out`.
y6 = torch.cat([x4, x4], 0, out=torch.empty(0, dtype=torch.float32))
850+
self.assertEqual(y6, torch.cat([x4, x4], 0))
851+
self.assertTrue(y6.dtype == torch.float32)
852+
853+
# one of input tensors is empty
# The zero-element input contributes no rows: 4 + 4 + 0 = 8.
854+
x7 = torch.empty(0, dtype=datatype)
855+
y7 = torch.cat([x4, x4, x7], 0)
856+
self.assertTrue(y7.size() == torch.Size([8, 2]))
857+
self.assertTrue(y7.dtype == datatype)
787858

788859
if __name__ == '__main__':
789860
test = unittest.main()

0 commit comments

Comments
 (0)