Skip to content

Commit e013a25

Browse files
committed
Move test.
1 parent 0e4c662 commit e013a25

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

test/test_operations.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2367,17 +2367,6 @@ def test_isneginf_no_fallback(self):
23672367
t = t.to(torch.float16)
23682368
self._test_no_fallback(torch.isneginf, (t,))
23692369

2370-
def test_trace_raises_error_on_non_matrix_input(self):
2371-
device = torch_xla.device()
2372-
a = torch.rand(2, 2, 2, device=device)
2373-
2374-
try:
2375-
torch.trace(a)
2376-
except RuntimeError as e:
2377-
expected_error = ("trace(): expected the input tensor f32[2,2,2] to be a "
2378-
"matrix (i.e. a 2D tensor).")
2379-
self.assertEqual(str(e), expected_error)
2380-
23812370

23822371
class MNISTComparator(nn.Module):
23832372

test/test_ops_error_message.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,19 @@ def test():
222222
expect="""roll(): expected `dims` [0] (size=1) to match the size of `shifts` [2, 2] (size=2)."""
223223
)
224224

225+
def test_trace_raises_error_on_non_matrix_input(self):
226+
device = torch_xla.device()
227+
a = torch.rand(2, 2, 2, device=device)
228+
229+
def test():
230+
torch.trace(a)
231+
232+
self.assertExpectedRaisesInline(
233+
exc_type=RuntimeError,
234+
callable=test,
235+
expect="""trace(): expected the input tensor f32[2,2,2] to be a matrix (i.e. a 2D tensor)."""
236+
)
237+
225238

226239
if __name__ == "__main__":
227240
unittest.main()

0 commit comments

Comments
 (0)