From 4f52f8541c84b2b346e87641cd3d3787c5b52633 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Sat, 13 Sep 2025 11:14:31 -0300
Subject: [PATCH 1/3] Improve error handling for clamp.

---
 test/test_ops_error_message.py     |  16 +++
 test/test_ops_error_message.py.bak | 194 +++++++++++++++++++++++++++++
 torch_xla/csrc/aten_xla_type.cpp   |  31 +++--
 torch_xla/csrc/tensor_methods.cpp  |  49 ++++----
 torch_xla/csrc/tensor_methods.h    |   9 +-
 5 files changed, 260 insertions(+), 39 deletions(-)
 create mode 100644 test/test_ops_error_message.py.bak

diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py
index bbb9f4b95b7..baedcb0f831 100644
--- a/test/test_ops_error_message.py
+++ b/test/test_ops_error_message.py
@@ -179,3 +179,19 @@ def test():
         callable=test,
         expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8)."""
     )
+
+  def test_clamp_scalar_raises_error_on_no_min_and_max(self):
+    device = torch_xla.device()
+    a = torch.rand(2, 5, device=device)
+
+    def test():
+      # Dispatch to `clamp()` overload explicitly.
+      # Otherwise, it's dispatched to `clamp.Tensor()`, which doesn't have
+      # this check.
+      return torch.ops.aten.clamp.default(a)
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""clamp(): expected at least one of `min` or `max` arguments to be specified."""
+    )
diff --git a/test/test_ops_error_message.py.bak b/test/test_ops_error_message.py.bak
new file mode 100644
index 00000000000..65f933a44b1
--- /dev/null
+++ b/test/test_ops_error_message.py.bak
@@ -0,0 +1,194 @@
+import expecttest
+import os
+import torch
+import torch_xla
+import unittest
+
+
+def onlyOnCPU(fn):
+  accelerator = os.environ.get("PJRT_DEVICE").lower()
+  return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn)
+
+
+class TestOpsErrorMessage(expecttest.TestCase):
+
+  def test_add_broadcast_error(self):
+    a = torch.rand(2, 2, 4, 4, device="xla")
+    b = torch.rand(2, 2, device="xla")
+
+    def test():
+      return torch.add(a, b)
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""Shapes are not compatible for broadcasting: f32[2,2,4,4] vs. f32[2,2]. Expected dimension 2 of shape f32[2,2,4,4] (4) to match dimension 0 of shape f32[2,2] (2). Either that or that any of them is either 1 or unbounded. Try reshaping one of the tensors to match the other."""
+    )
+
+  @onlyOnCPU
+  def test_construct_large_tensor_raises_error(self):
+
+    def test():
+      # When eager-mode is enabled, OOM is triggered here.
+      a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device())
+      b = a.sum()
+      # OOM is raised when we try to bring data from the device.
+      return b.cpu()
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""Error preparing computation: Out of memory allocating 4503599761588224 bytes."""
+    )
+
+  def test_cat_raises_error_on_incompatible_shapes(self):
+    a = torch.rand(2, 2, device=torch_xla.device())
+    b = torch.rand(5, 1, device=torch_xla.device())
+
+    def test():
+      return torch.cat([a, b])
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""cat(): cannot concatenate tensors of shape f32[2,2] with f32[5,1] at dimension 0. Expected shapes to be equal (except at dimension 0) or that either of them was a 1D empty tensor of size (0,)."""
+    )
+
+  def test_div_raises_error_on_invalid_rounding_mode(self):
+    a = torch.rand(2, 2, device=torch_xla.device())
+
+    def test():
+      return torch.div(a, 2, rounding_mode="bad")
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""div(): invalid rounding mode `bad`. Expected it to be either 'trunc', 'floor', or be left unspecified."""
+    )
+
+  def test_flip_raises_error_on_duplicated_dims(self):
+    a = torch.rand(2, 2, 2, 2, device=torch_xla.device())
+    dims = [0, 0, 0, 1, 2, 3, -1]
+
+    def test():
+      return torch.flip(a, dims=dims)
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""flip(): expected each dimension to appear at most once. Found dimensions: 0 (3 times), 3 (2 times). Consider changing dims from [0, 0, 0, 1, 2, 3, -1] to [0, 1, 2, 3]."""
+    )
+
+  def test_full_raises_error_on_negative_size(self):
+    shape = [2, -2, 2]
+
+    def test():
+      return torch.full(shape, 1.5, device="xla")
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""full(): expected concrete sizes (i.e. non-symbolic) to be positive values. However found negative ones: [2, -2, 2]."""
+    )
+
+  def test_gather_raises_error_on_rank_mismatch(self):
+    S = 2
+
+    input = torch.arange(4, device=torch_xla.device()).view(S, S)
+    index = torch.randint(0, S, (S, S, S), device=torch_xla.device())
+    dim = 1
+
+    def test():
+      return torch.gather(input, dim, index)
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""gather(): expected rank of input (2) and index (3) tensors to be the same."""
+    )
+
+  def test_gather_raises_error_on_invalid_index_size(self):
+    S = 2
+    X = S + 2
+
+    input = torch.arange(16, device=torch_xla.device()).view(S, S, S, S)
+    index = torch.randint(0, S, (X, S, X, S), device=torch_xla.device())
+    dim = 1
+
+    def test():
+      return torch.gather(input, dim, index)
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""gather(): expected sizes of index [4, 2, 4, 2] to be smaller or equal those of input [2, 2, 2, 2] on all dimensions, except on dimension 1. However, that's not true on dimensions [0, 2]."""
+    )
+
+  def test_random__raises_error_on_empty_interval(self):
+    a = torch.empty(10, device=torch_xla.device())
+    from_ = 3
+    to_ = 1
+
+    def test():
+      return a.random_(from_, to_)
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""random_(): expected `from` (3) to be smaller than `to` (1)."""
+    )
+
+  def test_random__raises_error_on_value_out_of_type_value_range(self):
+    a = torch.empty(10, device=torch_xla.device(), dtype=torch.float16)
+    from_ = 3
+    to_ = 65_504 + 2
+
+    def test():
+      return a.random_(from_, to_)
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""random_(): expected `to` to be within the range [-65504, 65504]. However got value 65505, which is greater than the upper bound."""
+    )
+
+  def test_mm_raises_error_on_non_matrix_input(self):
+    device = torch_xla.device()
+    a = torch.rand(2, 2, 2, device=device)
+    b = torch.rand(2, 2, device=device)
+
+    def test():
+      torch.mm(a, b)
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""mm(): expected the first input tensor f32[2,2,2] to be a matrix (i.e. a 2D tensor)."""
a 2D tensor).""" + ) + + def test_mm_raises_error_on_incompatible_shapes(self): + device = torch_xla.device() + a = torch.rand(2, 5, device=device) + b = torch.rand(8, 2, device=device) + + def test(): + torch.mm(a, b) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8).""" + ) + + def test_clamp_raises_error_on_no_min_and_max(self): + device = torch_xla.device() + a = torch.rand(2, 5, device=device) + + def test(): + return torch.ops.aten.clamp.default(a) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8).""" + ) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index c042d703aa3..6654a808425 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1372,24 +1372,31 @@ at::Tensor XLANativeFunctions::clamp(const at::Tensor& self, const std::optional& min, const std::optional& max) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor(tensor_methods::clamp(xla_self, min, max)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output, + tensor_methods::clamp(xla_self, min, max)); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::clamp_max(const at::Tensor& self, const at::Scalar& max) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor( - tensor_methods::clamp(xla_self, std::nullopt, max)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output, + tensor_methods::clamp(xla_self, std::nullopt, max)); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self, const at::Scalar& min) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor( - tensor_methods::clamp(xla_self, min, std::nullopt)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output, + tensor_methods::clamp(xla_self, min, std::nullopt)); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::clone( @@ -1947,9 +1954,11 @@ at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self, const at::Scalar& min_val, const at::Scalar& max_val) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor( - tensor_methods::clamp(xla_self, min_val, max_val)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output, + tensor_methods::clamp(xla_self, min_val, max_val)); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor 
XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 21f4db59713..21fe9ec7556 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -177,22 +177,6 @@ torch::lazy::Value MaybeExpand(const torch::lazy::Value& input, input, torch::lazy::ToVector(target_shape.dimensions())); } -MinMaxValues GetMinMaxValues(const XLATensorPtr& tensor, - const std::optional& min, - const std::optional& max) { - XLA_CHECK(min || max) - << "At least one of \'min\' or \'max\' must not be None"; - xla::PrimitiveType raw_element_type = XlaTypeFromTorchType(tensor->dtype()); - XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(raw_element_type); - auto shape = tensor->shape(); - return {XLAGraphExecutor::Get()->GetIrValueForScalar( - min ? *min : min_max.min, shape.get().element_type(), - tensor->GetDevice()), - XLAGraphExecutor::Get()->GetIrValueForScalar( - max ? *max : min_max.max, shape.get().element_type(), - tensor->GetDevice())}; -} - void CheckRank(const XLATensorPtr& t, int64_t expected_rank, const std::string& tag, const std::string& arg_name, int arg_number) { @@ -506,6 +490,16 @@ absl::Status CheckMMMatrixSizesAreCompatible(const XLATensorPtr& mat1, return absl::OkStatus(); } +absl::Status CheckClampMinOrMax(const std::optional& min, + const std::optional& max) { + if (!min.has_value() && !max.has_value()) { + return XLA_ERROR_WITH_LOCATION( + absl::InvalidArgumentError("clamp(): expected at least one of `min` or " + "`max` arguments to be specified.")); + } + return absl::OkStatus(); +} + } // namespace ////////////////////////////////////////////////////////////////////////////// @@ -1357,12 +1351,23 @@ void celu_(XLATensorPtr& input, const at::Scalar& alpha) { input->SetInPlaceIrValue(Celu(input->GetIrValue(), alpha)); } -XLATensorPtr clamp(const XLATensorPtr& input, - const std::optional& min, - const std::optional& max) { - MinMaxValues min_max = GetMinMaxValues(input, min, max); - return input->CreateFrom( - Clamp(input->GetIrValue(), min_max.min, min_max.max)); +absl::StatusOr clamp( + const XLATensorPtr& input, const std::optional& min, + const std::optional& max) { + XLA_RETURN_IF_ERROR(CheckClampMinOrMax(min, max)); + + xla::Shape shape = input->shape(); + const torch::lazy::BackendDevice& device = input->GetDevice(); + + xla::PrimitiveType raw_element_type = XlaTypeFromTorchType(input->dtype()); + XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(raw_element_type); + + torch::lazy::Value min_value = XLAGraphExecutor::Get()->GetIrValueForScalar( + min.value_or(min_max.min), shape.element_type(), device); + torch::lazy::Value max_value = XLAGraphExecutor::Get()->GetIrValueForScalar( + max.value_or(min_max.max), shape.element_type(), device); + + return input->CreateFrom(Clamp(input->GetIrValue(), min_value, max_value)); } XLATensorPtr clone(const XLATensorPtr& input) { diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index b25b423d49c..e980d483502 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -316,12 +316,9 @@ XLATensorPtr pixel_shuffle(const XLATensorPtr& self, int64_t upscale_factor); XLATensorPtr celu(const XLATensorPtr& input, const at::Scalar& alpha); void celu_(XLATensorPtr& input, const at::Scalar& alpha); -XLATensorPtr clamp(const XLATensorPtr& input, - const std::optional& min, - const std::optional& max); -XLATensorPtr clamp(const XLATensorPtr& input, - const std::optional& 
min, - const std::optional& max); +absl::StatusOr clamp( + const XLATensorPtr& input, const std::optional& min, + const std::optional& max); XLATensorPtr clone(const XLATensorPtr& input); From 62938780fccd5cce03b34ca16ef97b25c912bcc1 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 13 Sep 2025 11:22:59 -0300 Subject: [PATCH 2/3] Remove MinMaxValues struct. --- torch_xla/csrc/tensor_methods.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 21fe9ec7556..ee6587f40e8 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -163,11 +163,6 @@ namespace torch_xla { namespace tensor_methods { namespace { -struct MinMaxValues { - torch::lazy::Value min; - torch::lazy::Value max; -}; - torch::lazy::Value MaybeExpand(const torch::lazy::Value& input, const xla::Shape& target_shape) { if (GetXlaShape(input).dimensions() == target_shape.dimensions()) { From 214b375eaa4f8480578b5ae430800198b71c4838 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 13 Sep 2025 12:42:33 -0300 Subject: [PATCH 3/3] Remove temp file. --- test/test_ops_error_message.py.bak | 194 ----------------------------- 1 file changed, 194 deletions(-) delete mode 100644 test/test_ops_error_message.py.bak diff --git a/test/test_ops_error_message.py.bak b/test/test_ops_error_message.py.bak deleted file mode 100644 index 65f933a44b1..00000000000 --- a/test/test_ops_error_message.py.bak +++ /dev/null @@ -1,194 +0,0 @@ -import expecttest -import os -import torch -import torch_xla -import unittest - - -def onlyOnCPU(fn): - accelerator = os.environ.get("PJRT_DEVICE").lower() - return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn) - - -class TestOpsErrorMessage(expecttest.TestCase): - - def test_add_broadcast_error(self): - a = torch.rand(2, 2, 4, 4, device="xla") - b = torch.rand(2, 2, device="xla") - - def test(): - return torch.add(a, b) - - self.assertExpectedRaisesInline( - exc_type=RuntimeError, - callable=test, - expect="""Shapes are not compatible for broadcasting: f32[2,2,4,4] vs. f32[2,2]. Expected dimension 2 of shape f32[2,2,4,4] (4) to match dimension 0 of shape f32[2,2] (2). Either that or that any of them is either 1 or unbounded. Try reshaping one of the tensors to match the other.""" - ) - - @onlyOnCPU - def test_construct_large_tensor_raises_error(self): - - def test(): - # When eager-mode is enabled, OOM is triggered here. - a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device()) - b = a.sum() - # OOM is raised when we try to bring data from the device. - return b.cpu() - - self.assertExpectedRaisesInline( - exc_type=RuntimeError, - callable=test, - expect="""Error preparing computation: Out of memory allocating 4503599761588224 bytes.""" - ) - - def test_cat_raises_error_on_incompatible_shapes(self): - a = torch.rand(2, 2, device=torch_xla.device()) - b = torch.rand(5, 1, device=torch_xla.device()) - - def test(): - return torch.cat([a, b]) - - self.assertExpectedRaisesInline( - exc_type=RuntimeError, - callable=test, - expect="""cat(): cannot concatenate tensors of shape f32[2,2] with f32[5,1] at dimension 0. 
-    )
-
-  def test_div_raises_error_on_invalid_rounding_mode(self):
-    a = torch.rand(2, 2, device=torch_xla.device())
-
-    def test():
-      return torch.div(a, 2, rounding_mode="bad")
-
-    self.assertExpectedRaisesInline(
-        exc_type=RuntimeError,
-        callable=test,
-        expect="""div(): invalid rounding mode `bad`. Expected it to be either 'trunc', 'floor', or be left unspecified."""
-    )
-
-  def test_flip_raises_error_on_duplicated_dims(self):
-    a = torch.rand(2, 2, 2, 2, device=torch_xla.device())
-    dims = [0, 0, 0, 1, 2, 3, -1]
-
-    def test():
-      return torch.flip(a, dims=dims)
-
-    self.assertExpectedRaisesInline(
-        exc_type=RuntimeError,
-        callable=test,
-        expect="""flip(): expected each dimension to appear at most once. Found dimensions: 0 (3 times), 3 (2 times). Consider changing dims from [0, 0, 0, 1, 2, 3, -1] to [0, 1, 2, 3]."""
-    )
-
-  def test_full_raises_error_on_negative_size(self):
-    shape = [2, -2, 2]
-
-    def test():
-      return torch.full(shape, 1.5, device="xla")
-
-    self.assertExpectedRaisesInline(
-        exc_type=RuntimeError,
-        callable=test,
-        expect="""full(): expected concrete sizes (i.e. non-symbolic) to be positive values. However found negative ones: [2, -2, 2]."""
-    )
-
-  def test_gather_raises_error_on_rank_mismatch(self):
-    S = 2
-
-    input = torch.arange(4, device=torch_xla.device()).view(S, S)
-    index = torch.randint(0, S, (S, S, S), device=torch_xla.device())
-    dim = 1
-
-    def test():
-      return torch.gather(input, dim, index)
-
-    self.assertExpectedRaisesInline(
-        exc_type=RuntimeError,
-        callable=test,
-        expect="""gather(): expected rank of input (2) and index (3) tensors to be the same."""
-    )
-
-  def test_gather_raises_error_on_invalid_index_size(self):
-    S = 2
-    X = S + 2
-
-    input = torch.arange(16, device=torch_xla.device()).view(S, S, S, S)
-    index = torch.randint(0, S, (X, S, X, S), device=torch_xla.device())
-    dim = 1
-
-    def test():
-      return torch.gather(input, dim, index)
-
-    self.assertExpectedRaisesInline(
-        exc_type=RuntimeError,
-        callable=test,
-        expect="""gather(): expected sizes of index [4, 2, 4, 2] to be smaller or equal those of input [2, 2, 2, 2] on all dimensions, except on dimension 1. However, that's not true on dimensions [0, 2]."""
-    )
-
-  def test_random__raises_error_on_empty_interval(self):
-    a = torch.empty(10, device=torch_xla.device())
-    from_ = 3
-    to_ = 1
-
-    def test():
-      return a.random_(from_, to_)
-
-    self.assertExpectedRaisesInline(
-        exc_type=RuntimeError,
-        callable=test,
-        expect="""random_(): expected `from` (3) to be smaller than `to` (1)."""
-    )
-
-  def test_random__raises_error_on_value_out_of_type_value_range(self):
-    a = torch.empty(10, device=torch_xla.device(), dtype=torch.float16)
-    from_ = 3
-    to_ = 65_504 + 2
-
-    def test():
-      return a.random_(from_, to_)
-
-    self.assertExpectedRaisesInline(
-        exc_type=RuntimeError,
-        callable=test,
-        expect="""random_(): expected `to` to be within the range [-65504, 65504]. However got value 65505, which is greater than the upper bound."""
-    )
-
-  def test_mm_raises_error_on_non_matrix_input(self):
-    device = torch_xla.device()
-    a = torch.rand(2, 2, 2, device=device)
-    b = torch.rand(2, 2, device=device)
-
-    def test():
-      torch.mm(a, b)
-
-    self.assertExpectedRaisesInline(
-        exc_type=RuntimeError,
-        callable=test,
-        expect="""mm(): expected the first input tensor f32[2,2,2] to be a matrix (i.e. a 2D tensor)."""
-    )
-
-  def test_mm_raises_error_on_incompatible_shapes(self):
-    device = torch_xla.device()
-    a = torch.rand(2, 5, device=device)
-    b = torch.rand(8, 2, device=device)
-
-    def test():
-      torch.mm(a, b)
-
-    self.assertExpectedRaisesInline(
-        exc_type=RuntimeError,
-        callable=test,
-        expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8)."""
-    )
-
-  def test_clamp_raises_error_on_no_min_and_max(self):
-    device = torch_xla.device()
-    a = torch.rand(2, 5, device=device)
-
-    def test():
-      return torch.ops.aten.clamp.default(a)
-
-    self.assertExpectedRaisesInline(
-        exc_type=RuntimeError,
-        callable=test,
-        expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8)."""
-    )