From e04a32f88f96890e0f7a5c88a3740f1361da8c52 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 10 Sep 2025 16:14:28 -0300 Subject: [PATCH 1/2] Improve error handling and error messages for uniform_. --- test/test_operations.py | 13 +++++++++++++ torch_xla/csrc/aten_xla_type.cpp | 5 +++-- torch_xla/csrc/tensor_methods.cpp | 20 +++++++++++++++----- torch_xla/csrc/tensor_methods.h | 2 +- 4 files changed, 32 insertions(+), 8 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index a02bb746666..a8f61950136 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2367,6 +2367,19 @@ def test_isneginf_no_fallback(self): t = t.to(torch.float16) self._test_no_fallback(torch.isneginf, (t,)) + def test_uniform__raises_error_on_invalid_range(self): + device = torch_xla.device() + a = torch.empty(5, 5, device=device) + from_ = 5. + to_ = 2. + + try: + a.uniform_(from_, to_) + except RuntimeError as e: + expected_error = ( + "uniform_(): expected `from` (5) to be smaller or equal `to` (2).") + self.assertEqual(str(e), expected_error) + class MNISTComparator(nn.Module): diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 91101778d7a..e797d8e7c2c 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -3916,8 +3916,9 @@ at::Tensor& XLANativeFunctions::uniform_( return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(uniform_)>::call( self, from, to, generator); } - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - tensor_methods::uniform_(xla_self, from, to); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + XLA_THROW_IF_ERROR(tensor_methods::uniform_(xla_self, from, to)); return self; } diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 0ec204b0fff..c7627be2643 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -578,6 +578,15 @@ absl::Status CheckStackAtLeastOneTensor( return absl::OkStatus(); } +absl::Status CheckUniformRangeIsValid(double from, double to) { + if (from > to) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( + absl::StrCat("uniform_(): expected `from` (", from, + ") to be smaller or equal `to` (", to, ")."))); + } + return absl::OkStatus(); +} + } // namespace ////////////////////////////////////////////////////////////////////////////// @@ -3727,15 +3736,16 @@ std::vector unbind(const XLATensorPtr& input, int64_t dim) { return slices; } -void uniform_(XLATensorPtr& input, double from, double to) { - XLA_CHECK_LE(from, to); - auto input_shape = input->shape(); +absl::Status uniform_(XLATensorPtr& input, double from, double to) { + XLA_RETURN_IF_ERROR(CheckUniformRangeIsValid(from, to)); + xla::Shape input_shape = input->shape(); input->SetInPlaceIrValue(torch_xla::MakeNode( XLAGraphExecutor::Get()->GetIrValueForScalar( - from, input_shape.get().element_type(), input->GetDevice()), + from, input_shape.element_type(), input->GetDevice()), XLAGraphExecutor::Get()->GetIrValueForScalar( - to, input_shape.get().element_type(), input->GetDevice()), + to, input_shape.element_type(), input->GetDevice()), XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape)); + return absl::OkStatus(); } XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim) { diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index e91e92ad96e..47e2ad1e4ef 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -989,7 +989,7 @@ std::tuple triangular_solve( // removed. std::vector unbind(const XLATensorPtr& input, int64_t dim); -void uniform_(XLATensorPtr& input, double from, double to); +absl::Status uniform_(XLATensorPtr& input, double from, double to); // Insert a dimension of size one at the specified position. XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim); From d1e8007184508429d2dcff9dafc3f7139a0859a2 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 12 Sep 2025 16:24:31 -0300 Subject: [PATCH 2/2] Move test. --- test/test_operations.py | 13 ------------- test/test_ops_error_message.py | 15 +++++++++++++++ 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index a8f61950136..a02bb746666 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2367,19 +2367,6 @@ def test_isneginf_no_fallback(self): t = t.to(torch.float16) self._test_no_fallback(torch.isneginf, (t,)) - def test_uniform__raises_error_on_invalid_range(self): - device = torch_xla.device() - a = torch.empty(5, 5, device=device) - from_ = 5. - to_ = 2. - - try: - a.uniform_(from_, to_) - except RuntimeError as e: - expected_error = ( - "uniform_(): expected `from` (5) to be smaller or equal `to` (2).") - self.assertEqual(str(e), expected_error) - class MNISTComparator(nn.Module): diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py index bb23810f234..617c42da717 100644 --- a/test/test_ops_error_message.py +++ b/test/test_ops_error_message.py @@ -222,6 +222,21 @@ def test(): expect="""roll(): expected `dims` [0] (size=1) to match the size of `shifts` [2, 2] (size=2).""" ) + def test_uniform__raises_error_on_invalid_range(self): + device = torch_xla.device() + a = torch.empty(5, 5, device=device) + from_ = 5. + to_ = 2. + + def test(): + return a.uniform_(from_, to_) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""uniform_(): expected `from` (5) to be smaller or equal `to` (2).""" + ) + if __name__ == "__main__": unittest.main()