Skip to content

Commit 9db64e6

Browse files
tugsbayasgalan authored and facebook-github-bot committed
Revert "Striding for lists Part 2 (pytorch#49352)" (pytorch#58523)
Summary: Pull Request resolved: pytorch#58523. This reverts commit fee7e8b. Test Plan: Imported from OSS. Reviewed By: gmagogsfm. Differential Revision: D28528023. Pulled By: tugsbayasgalan. fbshipit-source-id: 9fa1d86f0c81fcc6fd3798e0d51a712a3c9b3952
1 parent 9123229 commit 9db64e6

File tree

14 files changed

+53
-252
lines changed

14 files changed

+53
-252
lines changed

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3642,7 +3642,7 @@
36423642
device_check: NoCheck
36433643
device_guard: False
36443644

3645-
- func: slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)
3645+
- func: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)
36463646
variants: function, method
36473647
device_check: NoCheck
36483648
device_guard: False

aten/src/ATen/templates/RegisterSchema.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ TORCH_LIBRARY(aten, m) {
2323
// Implementations located in torch/csrc/jit/runtime/register_prim_ops.cpp
2424
m.def(TORCH_SELECTIVE_SCHEMA("aten::splitlines(str self, bool keepends=False) -> str[]"));
2525
m.def(TORCH_SELECTIVE_SCHEMA(
26-
"aten::slice.str(str string, int? start=None, int? end=None, int step=1) -> str"));
26+
"aten::slice.str(str string, int? start=0, int? end=9223372036854775807, int step=1) -> str"));
2727
m.def(TORCH_SELECTIVE_SCHEMA("aten::isupper(str self) -> bool"));
2828
m.def(TORCH_SELECTIVE_SCHEMA("aten::islower(str self) -> bool"));
2929
m.def(TORCH_SELECTIVE_SCHEMA("aten::capitalize(str self) -> str"));

test/backward_compatibility/check_backward_compatibility.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383
("aten::cumprod_backward", datetime.date(2021, 5, 1)),
8484
("aten::_triangular_solve_helper", datetime.date(9999, 1, 1)),
8585
("aten::_addmv_impl_", datetime.date(2021, 5, 15)),
86-
("aten::slice", datetime.date(2021, 5, 31)),
86+
("aten::slice", datetime.date(2021, 6, 15)),
8787
("aten::adaptive_avg_pool3d_backward", datetime.date(9999, 1, 1)),
8888
("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)),
8989
("aten::_amp_update_scale", datetime.date(2021, 6, 1)),

test/cpp/jit/test_interpreter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ TEST(InterpreterTest, IgnorableArgsInSchema) {
159159
auto op_to_specified_args = function.op_to_num_specified_args();
160160
ASSERT_TRUE(op_to_specified_args.size() == 2);
161161
ASSERT_TRUE(op_to_specified_args["aten::slice.Tensor"] == 4);
162-
ASSERT_TRUE(op_to_specified_args["aten::slice.str"] == 4);
162+
ASSERT_TRUE(op_to_specified_args["aten::slice.str"] == 1);
163163
auto graph_vararg = build_mobile_export_analysis_graph_with_vararg();
164164
MobileCode function_vararg(graph_vararg, "");
165165
auto op_to_specified_args_vararg = function_vararg.op_to_num_specified_args();
@@ -172,7 +172,7 @@ TEST(InterpreterTest, IgnorableArgsInSchema) {
172172
MobileCode function_nested(graph_nested, "");
173173
auto op_to_specified_args_nested = function_nested.op_to_num_specified_args();
174174
ASSERT_TRUE(op_to_specified_args_nested["aten::slice.Tensor"] == 4);
175-
ASSERT_TRUE(op_to_specified_args_nested["aten::slice.str"] == 4);
175+
ASSERT_TRUE(op_to_specified_args_nested["aten::slice.str"] == 1);
176176

177177
auto graph_non_const = build_mobile_export_analysis_graph_non_const();
178178
MobileCode function_non_const(graph_non_const, "");

test/cpp/jit/test_utils.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ std::shared_ptr<Graph> build_lstm() {
9595

9696
std::shared_ptr<Graph> build_mobile_export_analysis_graph() {
9797
// We use following two schemas for this graph:
98-
// 1. slice.Tensor(Tensor(a) self, int dim=0, int? start=None,
99-
// int? end=None, int step=1) -> Tensor(a)
100-
// 2. slice.str(str string, int? start=None, int? end=None,
98+
// 1. slice.Tensor(Tensor(a) self, int dim=0, int? start=0,
99+
// int? end=9223372036854775807, int step=1) -> Tensor(a)
100+
// 2. slice.str(str string, int? start=0, int? end=9223372036854775807,
101101
// int step=1) -> str
102102
// %3 and %4 use slice.Tensor while %5 use slice.str.
103103
// Since we can see %3 and %4 have the same last argument that is never used
@@ -114,7 +114,7 @@ std::shared_ptr<Graph> build_mobile_export_analysis_graph() {
114114
%22 : str = prim::Constant[value="value"]()
115115
%3 : Tensor = aten::slice(%0, %1, %20, %2, %1)
116116
%4 : Tensor = aten::slice(%0, %2, %20, %21, %1)
117-
%5 : str = aten::slice(%22, %20, %21, %2)
117+
%5 : str = aten::slice(%22, %20, %21, %1)
118118
return (%3, %4, %5))IR";
119119

120120
auto g = std::make_shared<Graph>();
@@ -139,7 +139,7 @@ std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested() {
139139
%c : Tensor = prim::If(%23)
140140
block0():
141141
%4 : Tensor = aten::slice(%0, %2, %20, %21, %1)
142-
%5 : str = aten::slice(%22, %20, %21, %2)
142+
%5 : str = aten::slice(%22, %20, %21, %1)
143143
%c.1 : Tensor = aten::slice(%0, %1, %20, %2, %1)
144144
-> (%c.1)
145145
block1():

test/jit/test_ignorable_args.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import sys
3+
34
from torch._C import parse_ir
45
from torch.testing import FileCheck
56

@@ -17,6 +18,7 @@
1718
class TestIgnorableArgs(JitTestCase):
1819
def test_slice_ignorable_args_for_slice(self):
1920
graph_str = """graph():
21+
%15 : int = prim::Constant[value=9223372036854775807]()
2022
%13 : int = prim::Constant[value=0]()
2123
%10 : bool = prim::Constant[value=0]()
2224
%8 : NoneType = prim::Constant()
@@ -29,8 +31,8 @@ def test_slice_ignorable_args_for_slice(self):
2931
%6 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
3032
%7 : int[][] = prim::ListConstruct(%5, %6)
3133
%val.1 : Tensor = aten::tensor(%7, %8, %8, %10)
32-
%16 : Tensor = aten::slice(%val.1, %13, %1, %8, %0)
33-
%20 : Tensor = aten::slice(%16, %0, %8, %0, %0)
34+
%16 : Tensor = aten::slice(%val.1, %13, %1, %15, %0)
35+
%20 : Tensor = aten::slice(%16, %0, %13, %0, %0)
3436
return (%20)"""
3537
graph = parse_ir(graph_str)
3638
function = self.createFunctionFromGraph(graph)
@@ -41,5 +43,5 @@ def test_slice_ignorable_args_for_slice(self):
4143
# We ignore trailing arguments after start=2 for dim 0
4244
# and after end=1 for dim 1
4345
# because in %16, %15 and %0 are default values for the schema.
44-
FileCheck().check("torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)").run(src)
46+
FileCheck().check("torch.slice(torch.tensor(_0), 0, 2), 1, 0, 1)").run(src)
4547
self.assertEqual(function(), function_copy())

test/test_jit.py

Lines changed: 0 additions & 187 deletions
Original file line numberDiff line numberDiff line change
@@ -4626,193 +4626,6 @@ def test(backward=False):
46264626
test(backward=True)
46274627
test(backward=True)
46284628

4629-
def test_index(self):
4630-
def consec(size, start=0):
4631-
numel = torch.tensor(size).prod().item()
4632-
return torch.arange(numel).view(size)
4633-
4634-
def consec_list(size):
4635-
return list(range(size))
4636-
4637-
def random_string(size):
4638-
letters = string.ascii_lowercase
4639-
return "".join(random.choice(letters) for i in range(size))
4640-
4641-
def check_indexing(indexing, tensor):
4642-
template = dedent("""
4643-
def func(x):
4644-
return x{}
4645-
""")
4646-
4647-
self._check_code(template.format(indexing), "func", [tensor])
4648-
4649-
def check_dynamic_indexing(indexing, tensor, value1, value2):
4650-
value1 = torch.tensor(value1)
4651-
value2 = torch.tensor(value2)
4652-
4653-
template = dedent("""
4654-
def func(x, value1, value2):
4655-
i = int(value1)
4656-
j = int(value2)
4657-
return x{}
4658-
""")
4659-
4660-
self._check_code(template.format(indexing), "func", [tensor, value1, value2])
4661-
4662-
# Torchscript assumes type Tensor by default, so we need this explicit
4663-
# declaration.
4664-
def check_indexing_list_int(indexing, list):
4665-
template = dedent("""
4666-
def func(x):
4667-
# type: (List[int]) -> Any
4668-
return x{}
4669-
""")
4670-
4671-
self._check_code(template.format(indexing), "func", [list])
4672-
4673-
def check_indexing_str(indexing, str):
4674-
template = dedent("""
4675-
def func(x):
4676-
# type: (str) -> Any
4677-
return x{}
4678-
""")
4679-
4680-
self._check_code(template.format(indexing), "func", [str])
4681-
4682-
# basic slices
4683-
check_indexing('[0]', consec((3, 3)))
4684-
check_indexing('[1]', consec((3, 3), 10))
4685-
check_indexing('[2]', consec((3, 3), 19))
4686-
check_indexing('[2]', consec((3,)))
4687-
check_indexing('[-1]', consec((3, 3), 19))
4688-
check_indexing('[0:2]', consec((3, 3, 3)))
4689-
check_indexing('[1:-1]', consec((3, 3, 3)))
4690-
check_indexing('[-3:-1]', consec((6, 3)))
4691-
check_indexing('[1:]', consec((3, 3)))
4692-
check_indexing('[:1]', consec((3, 3)))
4693-
check_indexing('[:]', consec((3, 2)))
4694-
4695-
# multi-dim: indexes
4696-
check_indexing('[0, 1]', consec((3, 3)))
4697-
check_indexing('[0, 1]', consec((3, 3, 2)))
4698-
check_indexing('[1, 0, 2]', consec((3, 3, 3)))
4699-
check_indexing('[2, -1]', consec((3, 3)))
4700-
4701-
# multi-dim: mixed slicing and indexing
4702-
check_indexing('[0, 1:2]', consec((3, 3)))
4703-
check_indexing('[0, :1]', consec((3, 3, 2)))
4704-
check_indexing('[1, 2:]', consec((3, 3, 3)))
4705-
check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
4706-
check_indexing('[1:, -1, 0]', consec((3, 3, 3, 3)))
4707-
check_indexing('[-1, 2:, 1:2]', consec((3, 3, 3, 3)))
4708-
check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
4709-
check_indexing('[-1, :, 0, 2]', consec((3, 3, 3, 3)))
4710-
4711-
# zero-sized slices
4712-
check_indexing('[0:0]', consec((2, 2)))
4713-
check_indexing('[0:0, 1]', consec((3, 3)))
4714-
4715-
# trivial expression usage
4716-
check_indexing('[1+1]', consec((3, 3)))
4717-
check_indexing('[1:(0 + 2)]', consec((3, 3, 3)))
4718-
4719-
# None for new dimensions
4720-
check_indexing('[None, 0]', consec((3, 3)))
4721-
check_indexing('[1, None]', consec((3, 3), 10))
4722-
check_indexing('[None, None, 2]', consec((3, 3), 19))
4723-
check_indexing('[None, 2, None]', consec((3,)))
4724-
check_indexing('[0:2, None]', consec((3, 3, 3)))
4725-
check_indexing('[None, 1:-1]', consec((3, 3, 3)))
4726-
check_indexing('[None, -3:-1, None]', consec((6, 3)))
4727-
check_indexing('[-1, None, 2:, None, 1:2]', consec((3, 3, 3, 3)))
4728-
check_indexing('[None, -1, None, 2:, None, 1:2, None]', consec((3, 3, 3, 3)))
4729-
4730-
# dynamic expression usage
4731-
check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
4732-
check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)
4733-
4734-
# positive striding
4735-
check_indexing_list_int('[0]', consec_list(6))
4736-
check_indexing_list_int('[1]', consec_list(7))
4737-
check_indexing_list_int('[2]', consec_list(8))
4738-
check_indexing_list_int('[2]', consec_list(9))
4739-
check_indexing_list_int('[-1]', consec_list(10))
4740-
check_indexing_list_int('[0:2]', consec_list(11))
4741-
check_indexing_list_int('[1:-1]', consec_list(12))
4742-
check_indexing_list_int('[-3:-1]', consec_list(13))
4743-
check_indexing_list_int('[1:]', consec_list(15))
4744-
check_indexing_list_int('[:1]', consec_list(16))
4745-
check_indexing_list_int('[:]', consec_list(17))
4746-
check_indexing_list_int('[::]', consec_list(0))
4747-
check_indexing_list_int('[1000::]', consec_list(0))
4748-
check_indexing_list_int('[:1000:]', consec_list(0))
4749-
4750-
# negative striding
4751-
check_indexing_list_int('[::-1]', consec_list(7))
4752-
check_indexing_list_int('[:3:-1]', consec_list(7))
4753-
check_indexing_list_int('[3::-1]', consec_list(7))
4754-
check_indexing_list_int('[1000::-1]', consec_list(7))
4755-
check_indexing_list_int('[3:0:-1]', consec_list(7))
4756-
check_indexing_list_int('[3:-1000:-1]', consec_list(7))
4757-
check_indexing_list_int('[0:0:-1]', consec_list(7))
4758-
check_indexing_list_int('[0:-1000:-1]', consec_list(7))
4759-
4760-
# only step is specified
4761-
check_indexing_list_int('[::-1]', consec_list(0))
4762-
check_indexing_list_int('[::-1]', consec_list(7))
4763-
check_indexing_list_int('[::-2]', consec_list(7))
4764-
check_indexing_list_int('[::2]', consec_list(7))
4765-
check_indexing_list_int('[::42]', consec_list(7))
4766-
check_indexing_list_int('[::-42]', consec_list(7))
4767-
check_indexing_list_int('[::42]', consec_list(0))
4768-
check_indexing_list_int('[::-42]', consec_list(0))
4769-
check_indexing_list_int('[::9223372036854775807]', consec_list(42))
4770-
check_indexing_list_int('[::-9223372036854775807]', consec_list(42))
4771-
with self.assertRaisesRegex(RuntimeError, "out of bounds"):
4772-
check_indexing_list_int('[::-9223372036854775808]', consec_list(42))
4773-
with self.assertRaisesRegex(RuntimeError, "should have non-zero step"):
4774-
check_indexing_list_int('[::0]', consec_list(42))
4775-
4776-
# striding strings
4777-
check_indexing_str('[0]', random_string(6))
4778-
check_indexing_str('[1]', random_string(7))
4779-
check_indexing_str('[2]', random_string(8))
4780-
check_indexing_str('[2]', random_string(9))
4781-
check_indexing_str('[-1]', random_string(10))
4782-
check_indexing_str('[0:2]', random_string(11))
4783-
check_indexing_str('[1:-1]', random_string(12))
4784-
check_indexing_str('[-3:-1]', random_string(13))
4785-
check_indexing_str('[1:]', random_string(15))
4786-
check_indexing_str('[:1]', random_string(16))
4787-
check_indexing_str('[:]', random_string(17))
4788-
check_indexing_str('[::]', random_string(0))
4789-
check_indexing_str('[1000::]', random_string(0))
4790-
check_indexing_str('[:1000:]', random_string(0))
4791-
4792-
check_indexing_str('[::-1]', random_string(7))
4793-
check_indexing_str('[:3:-1]', random_string(7))
4794-
check_indexing_str('[3::-1]', random_string(7))
4795-
check_indexing_str('[1000::-1]', random_string(7))
4796-
check_indexing_str('[3:0:-1]', random_string(7))
4797-
check_indexing_str('[3:-1000:-1]', random_string(7))
4798-
check_indexing_str('[0:0:-1]', random_string(7))
4799-
check_indexing_str('[0:-1000:-1]', random_string(7))
4800-
4801-
check_indexing_str('[::-1]', random_string(0))
4802-
check_indexing_str('[::-1]', random_string(7))
4803-
check_indexing_str('[::-2]', random_string(7))
4804-
check_indexing_str('[::2]', random_string(7))
4805-
check_indexing_str('[::42]', random_string(7))
4806-
check_indexing_str('[::-42]', random_string(7))
4807-
check_indexing_str('[::42]', random_string(0))
4808-
check_indexing_str('[::-42]', random_string(0))
4809-
check_indexing_str('[::9223372036854775807]', random_string(42))
4810-
check_indexing_str('[::-9223372036854775807]', random_string(42))
4811-
with self.assertRaisesRegex(RuntimeError, "out of bounds"):
4812-
check_indexing_str('[::-9223372036854775808]', random_string(42))
4813-
with self.assertRaisesRegex(RuntimeError, "should have non-zero step"):
4814-
check_indexing_str('[::0]', random_string(42))
4815-
48164629
def test_module_copy_with_attributes(self):
48174630
class Vocabulary(torch.jit.ScriptModule):
48184631
def __init__(self, vocab_list):

tools/autograd/derivatives.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,7 +1052,7 @@
10521052
- name: sinh(Tensor self) -> Tensor
10531053
self: grad * self.cosh().conj()
10541054

1055-
- name: slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)
1055+
- name: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)
10561056
self: slice_backward_wrapper(grad, self.sizes(), dim, start, end, step)
10571057
result: auto_linear
10581058

torch/csrc/jit/frontend/ir_emitter.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3758,7 +3758,7 @@ struct to_ir {
37583758
Value* end,
37593759
Value* step) {
37603760
std::vector<NamedValue> args;
3761-
args.reserve(5);
3761+
args.reserve(4);
37623762
args.emplace_back(loc, "self", sliceable);
37633763

37643764
// XXX: If list slicing becomes more complicated or stops using
@@ -3770,10 +3770,11 @@ struct to_ir {
37703770
} else {
37713771
AT_ASSERT(!sliceable->type()->isSubtypeOf(TensorType::get()));
37723772
}
3773-
3773+
// TODO for now let's deal with TupleType first. Ideally all list, tensor,
3774+
// string, and tuple slicing should be same (tugsbayasgalan)
37743775
if (sliceable->type()->cast<TupleType>()) {
37753776
std::vector<at::optional<NamedValue>> tuple_args;
3776-
// since we are only dealing with tuple slicing, we try to keep
3777+
// since we are only dealing with tuple slicing for now, we try to keep
37773778
// tuple args seperate for now
37783779
tuple_args.reserve(3);
37793780

@@ -3787,15 +3788,22 @@ struct to_ir {
37873788
return emitTupleSlice(loc, args[0], tuple_args);
37883789
}
37893790

3790-
// handling cases like x[0:2]. x[0:2:] is already handled from python
3791+
// TODO this needs to be cleaned for list slicing
3792+
// Default value for start is 0.
3793+
if (!start) {
3794+
start = graph->insertConstant(0, loc);
3795+
}
3796+
args.emplace_back(loc, "start", start);
3797+
3798+
if (end) {
3799+
args.emplace_back(loc, "end", end);
3800+
}
3801+
37913802
if (!step) {
37923803
step = graph->insertConstant(1, loc);
37933804
}
3794-
3795-
args.emplace_back(loc, "start", start);
3796-
args.emplace_back(loc, "end", end);
3797-
args.emplace_back(loc, "step", step);
3798-
return emitBuiltinCall(loc, *graph, aten::slice, args, {});
3805+
NamedValue step_nv = NamedValue(loc, "step", step);
3806+
return emitBuiltinCall(loc, *graph, aten::slice, args, {step_nv});
37993807
}
38003808

38013809
// Desugars slice indexing: tensor[begin:end] -> tensor.slice(dim, begin, end,

torch/csrc/jit/passes/shape_analysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ class ShapePropagator {
884884
"aten::trunc(Tensor self) -> Tensor",
885885
"aten::rot90(Tensor self, int k, int[] dims) -> Tensor",
886886
"aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
887-
"aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor",
887+
"aten::slice(Tensor self, int dim, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor",
888888
"aten::alias(Tensor self) -> Tensor",
889889
},
890890
[](Node* node) -> type_vec_t {

0 commit comments

Comments
 (0)