Skip to content

Conversation

ysiraichi
Copy link
Collaborator

@ysiraichi ysiraichi commented Sep 17, 2025

This PR refactors {avg,max}_pool{2,3}d operation implementations by improving their error message, and returning a status type value.

Key Changes:

  • Make tensor_methods::avg_pool_nd{,_backward} return StatusOr<XLATensorPtr>
  • Make tensor_methods::max_pool_nd{,_backward} return StatusOr<std::tuple<XLATensorPtr, XLATensorPtr>>
  • Improve error messages and error handling
    • Remove CheckIntList
    • Create the following new functions:
      • RepeatIfSingleElement(span, n): if span is a single-element list, create a new one repeating it n times. Otherwise return the elements in span.
      • CheckPoolNdInputHasSize(...): check that the given list has a specific size
      • FillAndCheckPoolNdInputs(...): runs the 2 functions above for *_pool*d common inputs, i.e. kernel_size, stride, and padding

Example

a = torch.rand(1, 1, 4, 4, 4, device=device)
kernel_size = [2, 2]
stride = []
padding = [0]
torch.nn.functional.avg_pool3d(a, kernel_size, stride, padding)

Before:

Traceback (most recent call last):
  File "examples/pool.py", line 26, in <module>
    torch.nn.functional.avg_pool3d(a, kernel_size, stride, padding)
RuntimeError: Check failed: result.size() == length (2 vs. 3)Invalid length for the 'kernel_size' attribute (at torch_xla/csrc/tensor_methods.cpp:267)

Exception raised from operator& at torch_xla/csrc/runtime/tf_logging.cpp:26 (most recent call first):

After:

Traceback (most recent call last):
  File "examples/pool.py", line 26, in <module>
    torch.nn.functional.avg_pool3d(a, kernel_size, stride, padding)
RuntimeError: avg_pool3d(): expected argument kernel_size [2, 2] (size: 2) to have size of 3.

Status Propagation Trace:
    From: CheckPoolNdInputHasSize at torch_xla/csrc/tensor_methods.cpp:275 (error: avg_pool3d(): expected argument kernel_size [2, 2] (size: 2) to have size of 3.)
    From: FillAndCheckPoolNdInputs at torch_xla/csrc/tensor_methods.cpp:292
    From: avg_pool_nd at torch_xla/csrc/tensor_methods.cpp:1257
    From: avg_pool3d at torch_xla/csrc/aten_xla_type.cpp:1228

Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:128 (most recent call first):

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant