Skip to content

Commit 6c3e64f

Browse files
committed
Move test.
1 parent 18676c9 commit 6c3e64f

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

test/test_operations.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2367,17 +2367,6 @@ def test_isneginf_no_fallback(self):
23672367
t = t.to(torch.float16)
23682368
self._test_no_fallback(torch.isneginf, (t,))
23692369

2370-
def test_trace_raises_error_on_non_matrix_input(self):
2371-
device = torch_xla.device()
2372-
a = torch.rand(2, 2, 2, device=device)
2373-
2374-
try:
2375-
torch.trace(a)
2376-
except RuntimeError as e:
2377-
expected_error = ("trace(): expected the input tensor f32[2,2,2] to be a "
2378-
"matrix (i.e. a 2D tensor).")
2379-
self.assertEqual(str(e), expected_error)
2380-
23812370

23822371
class MNISTComparator(nn.Module):
23832372

test/test_ops_error_message.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,16 @@ def test():
179179
callable=test,
180180
expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8)."""
181181
)
182+
183+
def test_trace_raises_error_on_non_matrix_input(self):
184+
device = torch_xla.device()
185+
a = torch.rand(2, 2, 2, device=device)
186+
187+
def test():
188+
torch.trace(a)
189+
190+
self.assertExpectedRaisesInline(
191+
exc_type=RuntimeError,
192+
callable=test,
193+
expect="""trace(): expected the input tensor f32[2,2,2] to be a matrix (i.e. a 2D tensor)."""
194+
)

0 commit comments

Comments
 (0)