Skip to content

Commit 3441991

Browse files
committed
Improve error handling and error messages for uniform_.
1 parent caa809f commit 3441991

File tree

4 files changed

+33
-8
lines changed

4 files changed

+33
-8
lines changed

test/test_operations.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2531,6 +2531,19 @@ def test_mm_raises_error_on_incompatible_shapes(self):
25312531
"tensor (8).")
25322532
self.assertEqual(str(e), expected_error)
25332533

2534+
def test_uniform__raises_error_on_invalid_range(self):
2535+
device = torch_xla.device()
2536+
a = torch.empty(5, 5, device=device)
2537+
from_ = 5.
2538+
to_ = 2.
2539+
2540+
try:
2541+
a.uniform_(from_, to_)
2542+
except RuntimeError as e:
2543+
expected_error = (
2544+
"uniform_(): expected `from` (5) to be smaller or equal `to` (2).")
2545+
self.assertEqual(str(e), expected_error)
2546+
25342547

25352548
class MNISTComparator(nn.Module):
25362549

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3905,8 +3905,9 @@ at::Tensor& XLANativeFunctions::uniform_(
39053905
return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(uniform_)>::call(
39063906
self, from, to, generator);
39073907
}
3908-
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
3909-
tensor_methods::uniform_(xla_self, from, to);
3908+
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
3909+
bridge::GetXlaTensor(self));
3910+
XLA_THROW_IF_ERROR(tensor_methods::uniform_(xla_self, from, to));
39103911
return self;
39113912
}
39123913

torch_xla/csrc/tensor_methods.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "absl/status/status.h"
1515
#include "absl/strings/str_cat.h"
1616
#include "absl/strings/str_split.h"
17+
#include "status.h"
1718
#include "torch_xla/csrc/LazyIr.h"
1819
#include "torch_xla/csrc/aten_xla_bridge.h"
1920
#include "torch_xla/csrc/data_ops.h"
@@ -506,6 +507,15 @@ absl::Status CheckMMMatrixSizesAreCompatible(const XLATensorPtr& mat1,
506507
return absl::OkStatus();
507508
}
508509

510+
absl::Status CheckUniformRangeIsValid(double from, double to) {
511+
if (from > to) {
512+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
513+
absl::StrCat("uniform_(): expected `from` (", from,
514+
") to be smaller or equal `to` (", to, ").")));
515+
}
516+
return absl::OkStatus();
517+
}
518+
509519
} // namespace
510520

511521
//////////////////////////////////////////////////////////////////////////////
@@ -3653,15 +3663,16 @@ std::vector<XLATensorPtr> unbind(const XLATensorPtr& input, int64_t dim) {
36533663
return slices;
36543664
}
36553665

3656-
void uniform_(XLATensorPtr& input, double from, double to) {
3657-
XLA_CHECK_LE(from, to);
3658-
auto input_shape = input->shape();
3666+
absl::Status uniform_(XLATensorPtr& input, double from, double to) {
3667+
XLA_RETURN_IF_ERROR(CheckUniformRangeIsValid(from, to));
3668+
xla::Shape input_shape = input->shape();
36593669
input->SetInPlaceIrValue(torch_xla::MakeNode<Uniform>(
36603670
XLAGraphExecutor::Get()->GetIrValueForScalar(
3661-
from, input_shape.get().element_type(), input->GetDevice()),
3671+
from, input_shape.element_type(), input->GetDevice()),
36623672
XLAGraphExecutor::Get()->GetIrValueForScalar(
3663-
to, input_shape.get().element_type(), input->GetDevice()),
3673+
to, input_shape.element_type(), input->GetDevice()),
36643674
XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape));
3675+
return absl::OkStatus();
36653676
}
36663677

36673678
XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim) {

torch_xla/csrc/tensor_methods.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,7 @@ std::tuple<XLATensorPtr, XLATensorPtr> triangular_solve(
986986
// removed.
987987
std::vector<XLATensorPtr> unbind(const XLATensorPtr& input, int64_t dim);
988988

989-
void uniform_(XLATensorPtr& input, double from, double to);
989+
absl::Status uniform_(XLATensorPtr& input, double from, double to);
990990

991991
// Insert a dimension of size one at the specified position.
992992
XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim);

0 commit comments

Comments
 (0)