@@ -4626,193 +4626,6 @@ def test(backward=False):
46264626 test(backward=True)
46274627 test(backward=True)
46284628
4629- def test_index(self):
4630- def consec(size, start=0):
4631- numel = torch.tensor(size).prod().item()
4632- return torch.arange(numel).view(size)
4633-
4634- def consec_list(size):
4635- return list(range(size))
4636-
4637- def random_string(size):
4638- letters = string.ascii_lowercase
4639- return "".join(random.choice(letters) for i in range(size))
4640-
4641- def check_indexing(indexing, tensor):
4642- template = dedent("""
4643- def func(x):
4644- return x{}
4645- """)
4646-
4647- self._check_code(template.format(indexing), "func", [tensor])
4648-
4649- def check_dynamic_indexing(indexing, tensor, value1, value2):
4650- value1 = torch.tensor(value1)
4651- value2 = torch.tensor(value2)
4652-
4653- template = dedent("""
4654- def func(x, value1, value2):
4655- i = int(value1)
4656- j = int(value2)
4657- return x{}
4658- """)
4659-
4660- self._check_code(template.format(indexing), "func", [tensor, value1, value2])
4661-
4662- # Torchscript assumes type Tensor by default, so we need this explicit
4663- # declaration.
4664- def check_indexing_list_int(indexing, list):
4665- template = dedent("""
4666- def func(x):
4667- # type: (List[int]) -> Any
4668- return x{}
4669- """)
4670-
4671- self._check_code(template.format(indexing), "func", [list])
4672-
4673- def check_indexing_str(indexing, str):
4674- template = dedent("""
4675- def func(x):
4676- # type: (str) -> Any
4677- return x{}
4678- """)
4679-
4680- self._check_code(template.format(indexing), "func", [str])
4681-
4682- # basic slices
4683- check_indexing('[0]', consec((3, 3)))
4684- check_indexing('[1]', consec((3, 3), 10))
4685- check_indexing('[2]', consec((3, 3), 19))
4686- check_indexing('[2]', consec((3,)))
4687- check_indexing('[-1]', consec((3, 3), 19))
4688- check_indexing('[0:2]', consec((3, 3, 3)))
4689- check_indexing('[1:-1]', consec((3, 3, 3)))
4690- check_indexing('[-3:-1]', consec((6, 3)))
4691- check_indexing('[1:]', consec((3, 3)))
4692- check_indexing('[:1]', consec((3, 3)))
4693- check_indexing('[:]', consec((3, 2)))
4694-
4695- # multi-dim: indexes
4696- check_indexing('[0, 1]', consec((3, 3)))
4697- check_indexing('[0, 1]', consec((3, 3, 2)))
4698- check_indexing('[1, 0, 2]', consec((3, 3, 3)))
4699- check_indexing('[2, -1]', consec((3, 3)))
4700-
4701- # multi-dim: mixed slicing and indexing
4702- check_indexing('[0, 1:2]', consec((3, 3)))
4703- check_indexing('[0, :1]', consec((3, 3, 2)))
4704- check_indexing('[1, 2:]', consec((3, 3, 3)))
4705- check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
4706- check_indexing('[1:, -1, 0]', consec((3, 3, 3, 3)))
4707- check_indexing('[-1, 2:, 1:2]', consec((3, 3, 3, 3)))
4708- check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
4709- check_indexing('[-1, :, 0, 2]', consec((3, 3, 3, 3)))
4710-
4711- # zero-sized slices
4712- check_indexing('[0:0]', consec((2, 2)))
4713- check_indexing('[0:0, 1]', consec((3, 3)))
4714-
4715- # trivial expression usage
4716- check_indexing('[1+1]', consec((3, 3)))
4717- check_indexing('[1:(0 + 2)]', consec((3, 3, 3)))
4718-
4719- # None for new dimensions
4720- check_indexing('[None, 0]', consec((3, 3)))
4721- check_indexing('[1, None]', consec((3, 3), 10))
4722- check_indexing('[None, None, 2]', consec((3, 3), 19))
4723- check_indexing('[None, 2, None]', consec((3,)))
4724- check_indexing('[0:2, None]', consec((3, 3, 3)))
4725- check_indexing('[None, 1:-1]', consec((3, 3, 3)))
4726- check_indexing('[None, -3:-1, None]', consec((6, 3)))
4727- check_indexing('[-1, None, 2:, None, 1:2]', consec((3, 3, 3, 3)))
4728- check_indexing('[None, -1, None, 2:, None, 1:2, None]', consec((3, 3, 3, 3)))
4729-
4730- # dynamic expression usage
4731- check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
4732- check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)
4733-
4734- # positive striding
4735- check_indexing_list_int('[0]', consec_list(6))
4736- check_indexing_list_int('[1]', consec_list(7))
4737- check_indexing_list_int('[2]', consec_list(8))
4738- check_indexing_list_int('[2]', consec_list(9))
4739- check_indexing_list_int('[-1]', consec_list(10))
4740- check_indexing_list_int('[0:2]', consec_list(11))
4741- check_indexing_list_int('[1:-1]', consec_list(12))
4742- check_indexing_list_int('[-3:-1]', consec_list(13))
4743- check_indexing_list_int('[1:]', consec_list(15))
4744- check_indexing_list_int('[:1]', consec_list(16))
4745- check_indexing_list_int('[:]', consec_list(17))
4746- check_indexing_list_int('[::]', consec_list(0))
4747- check_indexing_list_int('[1000::]', consec_list(0))
4748- check_indexing_list_int('[:1000:]', consec_list(0))
4749-
4750- # negative striding
4751- check_indexing_list_int('[::-1]', consec_list(7))
4752- check_indexing_list_int('[:3:-1]', consec_list(7))
4753- check_indexing_list_int('[3::-1]', consec_list(7))
4754- check_indexing_list_int('[1000::-1]', consec_list(7))
4755- check_indexing_list_int('[3:0:-1]', consec_list(7))
4756- check_indexing_list_int('[3:-1000:-1]', consec_list(7))
4757- check_indexing_list_int('[0:0:-1]', consec_list(7))
4758- check_indexing_list_int('[0:-1000:-1]', consec_list(7))
4759-
4760- # only step is specified
4761- check_indexing_list_int('[::-1]', consec_list(0))
4762- check_indexing_list_int('[::-1]', consec_list(7))
4763- check_indexing_list_int('[::-2]', consec_list(7))
4764- check_indexing_list_int('[::2]', consec_list(7))
4765- check_indexing_list_int('[::42]', consec_list(7))
4766- check_indexing_list_int('[::-42]', consec_list(7))
4767- check_indexing_list_int('[::42]', consec_list(0))
4768- check_indexing_list_int('[::-42]', consec_list(0))
4769- check_indexing_list_int('[::9223372036854775807]', consec_list(42))
4770- check_indexing_list_int('[::-9223372036854775807]', consec_list(42))
4771- with self.assertRaisesRegex(RuntimeError, "out of bounds"):
4772- check_indexing_list_int('[::-9223372036854775808]', consec_list(42))
4773- with self.assertRaisesRegex(RuntimeError, "should have non-zero step"):
4774- check_indexing_list_int('[::0]', consec_list(42))
4775-
4776- # striding strings
4777- check_indexing_str('[0]', random_string(6))
4778- check_indexing_str('[1]', random_string(7))
4779- check_indexing_str('[2]', random_string(8))
4780- check_indexing_str('[2]', random_string(9))
4781- check_indexing_str('[-1]', random_string(10))
4782- check_indexing_str('[0:2]', random_string(11))
4783- check_indexing_str('[1:-1]', random_string(12))
4784- check_indexing_str('[-3:-1]', random_string(13))
4785- check_indexing_str('[1:]', random_string(15))
4786- check_indexing_str('[:1]', random_string(16))
4787- check_indexing_str('[:]', random_string(17))
4788- check_indexing_str('[::]', random_string(0))
4789- check_indexing_str('[1000::]', random_string(0))
4790- check_indexing_str('[:1000:]', random_string(0))
4791-
4792- check_indexing_str('[::-1]', random_string(7))
4793- check_indexing_str('[:3:-1]', random_string(7))
4794- check_indexing_str('[3::-1]', random_string(7))
4795- check_indexing_str('[1000::-1]', random_string(7))
4796- check_indexing_str('[3:0:-1]', random_string(7))
4797- check_indexing_str('[3:-1000:-1]', random_string(7))
4798- check_indexing_str('[0:0:-1]', random_string(7))
4799- check_indexing_str('[0:-1000:-1]', random_string(7))
4800-
4801- check_indexing_str('[::-1]', random_string(0))
4802- check_indexing_str('[::-1]', random_string(7))
4803- check_indexing_str('[::-2]', random_string(7))
4804- check_indexing_str('[::2]', random_string(7))
4805- check_indexing_str('[::42]', random_string(7))
4806- check_indexing_str('[::-42]', random_string(7))
4807- check_indexing_str('[::42]', random_string(0))
4808- check_indexing_str('[::-42]', random_string(0))
4809- check_indexing_str('[::9223372036854775807]', random_string(42))
4810- check_indexing_str('[::-9223372036854775807]', random_string(42))
4811- with self.assertRaisesRegex(RuntimeError, "out of bounds"):
4812- check_indexing_str('[::-9223372036854775808]', random_string(42))
4813- with self.assertRaisesRegex(RuntimeError, "should have non-zero step"):
4814- check_indexing_str('[::0]', random_string(42))
4815-
48164629 def test_module_copy_with_attributes(self):
48174630 class Vocabulary(torch.jit.ScriptModule):
48184631 def __init__(self, vocab_list):
0 commit comments