Skip to content

Conversation

ysiraichi
Copy link
Collaborator

This PR refactors the bmm operation implementation by improving its error message, and returning a status type value.

Key Changes:

  • Make tensor_methods::bmm return StatusOr<XLATensorPtr>
  • Improve error messages and error handling
    • Remove CheckBmmDimension, CheckDimensionSize, and CheckRank
    • Create new CheckBmmInputsAreValid, and CheckInputIs3DTensor functions

Example 1: not a 3D tensor

a = torch.rand(2, 3, device=device)
b = torch.rand(2, 3, 3, device=device)
torch.bmm(a, b)
Comparison

Before:

Traceback (most recent call last):
  File "examples/bmm.py", line 39, in <module>
    torch.bmm(a, b)
RuntimeError: Check failed: actual_rank == expected_rank (2 vs. 3)Expected 3-dimensional tensor, but got 2-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm) (at torch_xla/csrc/tensor_methods.cpp:200)

Exception raised from operator& at torch_xla/csrc/runtime/tf_logging.cpp:26 (most recent call first):

After:

Traceback (most recent call last):
  File "examples/bmm.py", line 39, in <module>
    torch.bmm(a, b)
RuntimeError: bmm(): expected `input` f32[2,3] (a 2D tensor), the 1st input tensor, to be a 3D tensor.

Status Propagation Trace:
    From: CheckInputIs3DTensor at torch_xla/csrc/tensor_methods.cpp:492 (error: bmm(): expected `input` f32[2,3] (a 2D tensor), the 1st input tensor, to be a 3D tensor.)
    From: CheckBMMInputsAreValid at torch_xla/csrc/tensor_methods.cpp:503
    From: bmm at torch_xla/csrc/tensor_methods.cpp:1300
    From: bmm at torch_xla/csrc/aten_xla_type.cpp:1349

Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:128 (most recent call first):

Example 2: batch dimension size don't match

a = torch.rand(2, 3, 4, device=device)
b = torch.rand(4, 4, 3, device=device)
torch.bmm(a, b)
Comparison

Before:

Traceback (most recent call last):
  File "examples/bmm.py", line 39, in <module>
    torch.bmm(a, b)
RuntimeError: Check failed: t->size(dim) == expected_size (4 vs. 2)Expected tensor to have size 2 at dimension 0, but got size 4 for argument #2 'batch2' (while checking arguments for bmm) (at torch_xla/csrc/tensor_methods.cpp:218)

Exception raised from operator& at torch_xla/csrc/runtime/tf_logging.cpp:26 (most recent call first):

After:

Traceback (most recent call last):
  File "examples/bmm.py", line 39, in <module>
    torch.bmm(a, b)
RuntimeError: bmm(): expected the size of the batch dimension (i.e. dimension 0) of `input` f32[2,3,4] (batch dimension size: 2), the 1st input tensor, to be the same as the size of the batch dimension of `mat2` f32[4,4,3] (batch dimension size: 4), the 2nd input tensor.

Status Propagation Trace:
    From: CheckBMMInputsAreValid at torch_xla/csrc/tensor_methods.cpp:507 (error: bmm(): expected the size of the batch dimension (i.e. dimension 0) of `input` f32[2,3,4] (batch dimension size: 2), the 1st input tensor, to be the same as the size of the batch dimension of `mat2` f32[4,4,3] (batch dimension size: 4), the 2nd input tensor.)
    From: bmm at torch_xla/csrc/tensor_methods.cpp:1300
    From: bmm at torch_xla/csrc/aten_xla_type.cpp:1349

Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:128 (most recent call first):

Example 3: invalid shapes

a = torch.rand(2, 3, 4, device=device)
b = torch.rand(2, 2, 3, device=device)
torch.bmm(a, b)
Comparison

Before:

Traceback (most recent call last):
  File "examples/bmm.py", line 39, in <module>
    torch.bmm(a, b)
RuntimeError: Check failed: t->size(dim) == expected_size (2 vs. 4)Expected tensor to have size 4 at dimension 1, but got size 2 for argument #2 'batch2' (while checking arguments for bmm) (at torch_xla/csrc/tensor_methods.cpp:218)

Exception raised from operator& at torch_xla/csrc/runtime/tf_logging.cpp:26 (most recent call first):

After:

Traceback (most recent call last):
  File "examples/bmm.py", line 39, in <module>
    torch.bmm(a, b)
RuntimeError: bmm(): cannot apply batch matrix-multiplication to `input` f32[2,3,4], the 1st input tensor, and to `mat2` f32[2,2,3], the 2nd input tensor. Expected the size of dimension 2 of `input` (4) to be equal the size of dimension 1 of `mat2` (2).

Status Propagation Trace:
    From: CheckBMMInputsAreValid at torch_xla/csrc/tensor_methods.cpp:519 (error: bmm(): cannot apply batch matrix-multiplication to `input` f32[2,3,4], the 1st input tensor, and to `mat2` f32[2,2,3], the 2nd input tensor. Expected the size of dimension 2 of `input` (4) to be equal the size of dimension 1 of `mat2` (2).)
    From: bmm at torch_xla/csrc/tensor_methods.cpp:1300
    From: bmm at torch_xla/csrc/aten_xla_type.cpp:1349

Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:128 (most recent call first):

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants