|
5 | 5 | import random |
6 | 6 | import intel_extension_for_pytorch as ipex |
7 | 7 | from common_utils import TestCase |
| 8 | +import itertools |
8 | 9 |
|
9 | 10 | try: |
10 | 11 | import torchvision |
@@ -776,14 +777,84 @@ def test_index_select(self): |
776 | 777 | self.assertEqual(y2, y, prec=0.01) |
777 | 778 |
|
778 | 779 | def test_cat(self): |
779 | | - x = x = torch.randn(2, 3) |
780 | | - y = torch.cat((x, x, x), 0) |
781 | | - |
782 | | - # test bfloat16 |
783 | | - x2 = x.clone().detach().bfloat16() |
784 | | - y2 = torch.cat((x2, x2, x2), 0) |
785 | | - self.assertTrue(y2.dtype == torch.bfloat16) |
786 | | - self.assertEqual(y2, y, prec=0.01) |
| 780 | + for datatype in [torch.float32, torch.double, torch.bfloat16]: |
| 781 | + for dim, size in itertools.product([0, 1], [[2, 1], [2, 2], [5, 10]]): |
| 782 | + x = torch.randn(size, dtype=datatype) |
| 783 | + y = torch.cat([x, x], dim) |
| 784 | + self.assertTrue(y.dtype == datatype) |
| 785 | + |
| 786 | + # long input tensor list |
| 787 | + x1 = torch.randn((2, 2), dtype=datatype) |
| 788 | + input1 = [] |
| 789 | + for i in range(100): |
| 790 | + input1.append(x1) |
| 791 | + y1 = torch.cat(input1, 0) |
| 792 | + self.assertTrue(y1.size() == torch.Size([200, 2])) |
| 793 | + self.assertTrue(y1.dtype == datatype) |
| 794 | + |
| 795 | + # input tensors have different shapes and strides |
| 796 | + x2 = torch.randn((400, 2), dtype=datatype) |
| 797 | + input2 = [] |
| 798 | + for i in range(10): |
| 799 | + input2.append(x1) |
| 800 | + for i in range(100): |
| 801 | + input2.append(x2) |
| 802 | + y2 = torch.cat(input2, 0) |
| 803 | + self.assertTrue(y2.size() == torch.Size([40020, 2])) |
| 804 | + self.assertTrue(y2.dtype == datatype) |
| 805 | + |
| 806 | + x3 = torch.randn((4000, 2), dtype=datatype) |
| 807 | + input3 = [] |
| 808 | + for i in range(10): |
| 809 | + input3.append(x1) |
| 810 | + for i in range(10): |
| 811 | + input3.append(x3) |
| 812 | + y3 = torch.cat(input3, 0) |
| 813 | + self.assertTrue(y3.size() == torch.Size([40020, 2])) |
| 814 | + self.assertTrue(y3.dtype == datatype) |
| 815 | + |
| 816 | + x4 = torch.randn((4, 2), dtype=datatype) |
| 817 | + input4 = [] |
| 818 | + for i in range(10): |
| 819 | + input4.append(x1) |
| 820 | + for i in range(10): |
| 821 | + input4.append(x4) |
| 822 | + y4 = torch.cat(input4, 0) |
| 823 | + self.assertTrue(y4.size() == torch.Size([60, 2])) |
| 824 | + self.assertTrue(y4.dtype == datatype) |
| 825 | + |
| 826 | + # "out" arg is used but un-defined |
| 827 | + y5 = torch.cat([x4, x4], 0, out=torch.empty(0, dtype=datatype)) |
| 828 | + self.assertEqual(y5, torch.cat([x4, x4], 0)) |
| 829 | + self.assertTrue(y5.dtype == datatype) |
| 830 | + |
| 831 | + # out is defined with wrong shape |
| 832 | + ref = torch.cat([x4, x4], 0) |
| 833 | + out = torch.zeros(1) |
| 834 | + out_ptr = out.data_ptr() |
| 835 | + torch.cat([x4, x4], 0, out=out) |
| 836 | + self.assertEqual(ref, out) |
| 837 | + self.assertTrue(ref.dtype == datatype) |
| 838 | + self.assertTrue(out_ptr != out.data_ptr()) |
| 839 | + |
| 840 | + # out is defined with correct shape |
| 841 | + ref = torch.cat([x4, x4], 0) |
| 842 | + out = torch.zeros_like(ref) |
| 843 | + out_ptr = out.data_ptr() |
| 844 | + torch.cat([x4, x4], 0, out=out) |
| 845 | + self.assertEqual(ref, out) |
| 846 | + self.assertTrue(ref.dtype == datatype) |
| 847 | + self.assertTrue(out_ptr == out.data_ptr()) |
| 848 | + |
| 849 | + y6 = torch.cat([x4, x4], 0, out=torch.empty(0, dtype=torch.float32)) |
| 850 | + self.assertEqual(y6, torch.cat([x4, x4], 0)) |
| 851 | + self.assertTrue(y6.dtype == torch.float32) |
| 852 | + |
| 853 | + # one of input tensors is empty |
| 854 | + x7 = torch.empty(0, dtype=datatype) |
| 855 | + y7 = torch.cat([x4, x4, x7], 0) |
| 856 | + self.assertTrue(y7.size() == torch.Size([8, 2])) |
| 857 | + self.assertTrue(y7.dtype == datatype) |
787 | 858 |
|
788 | 859 | if __name__ == '__main__': |
789 | 860 | test = unittest.main() |
0 commit comments