Skip to content

Commit e04a32f

Browse files
committed
Improve error handling and error messages for uniform_.
1 parent 0c0ae2d commit e04a32f

File tree

4 files changed

+32
-8
lines changed

4 files changed

+32
-8
lines changed

test/test_operations.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2367,6 +2367,19 @@ def test_isneginf_no_fallback(self):
23672367
t = t.to(torch.float16)
23682368
self._test_no_fallback(torch.isneginf, (t,))
23692369

2370+
def test_uniform__raises_error_on_invalid_range(self):
2371+
device = torch_xla.device()
2372+
a = torch.empty(5, 5, device=device)
2373+
from_ = 5.
2374+
to_ = 2.
2375+
2376+
try:
2377+
a.uniform_(from_, to_)
2378+
except RuntimeError as e:
2379+
expected_error = (
2380+
"uniform_(): expected `from` (5) to be smaller or equal `to` (2).")
2381+
self.assertEqual(str(e), expected_error)
2382+
23702383

23712384
class MNISTComparator(nn.Module):
23722385

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3916,8 +3916,9 @@ at::Tensor& XLANativeFunctions::uniform_(
39163916
return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(uniform_)>::call(
39173917
self, from, to, generator);
39183918
}
3919-
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
3920-
tensor_methods::uniform_(xla_self, from, to);
3919+
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
3920+
bridge::GetXlaTensor(self));
3921+
XLA_THROW_IF_ERROR(tensor_methods::uniform_(xla_self, from, to));
39213922
return self;
39223923
}
39233924

torch_xla/csrc/tensor_methods.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,15 @@ absl::Status CheckStackAtLeastOneTensor(
578578
return absl::OkStatus();
579579
}
580580

581+
absl::Status CheckUniformRangeIsValid(double from, double to) {
582+
if (from > to) {
583+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
584+
absl::StrCat("uniform_(): expected `from` (", from,
585+
") to be smaller or equal `to` (", to, ").")));
586+
}
587+
return absl::OkStatus();
588+
}
589+
581590
} // namespace
582591

583592
//////////////////////////////////////////////////////////////////////////////
@@ -3727,15 +3736,16 @@ std::vector<XLATensorPtr> unbind(const XLATensorPtr& input, int64_t dim) {
37273736
return slices;
37283737
}
37293738

3730-
void uniform_(XLATensorPtr& input, double from, double to) {
3731-
XLA_CHECK_LE(from, to);
3732-
auto input_shape = input->shape();
3739+
absl::Status uniform_(XLATensorPtr& input, double from, double to) {
3740+
XLA_RETURN_IF_ERROR(CheckUniformRangeIsValid(from, to));
3741+
xla::Shape input_shape = input->shape();
37333742
input->SetInPlaceIrValue(torch_xla::MakeNode<Uniform>(
37343743
XLAGraphExecutor::Get()->GetIrValueForScalar(
3735-
from, input_shape.get().element_type(), input->GetDevice()),
3744+
from, input_shape.element_type(), input->GetDevice()),
37363745
XLAGraphExecutor::Get()->GetIrValueForScalar(
3737-
to, input_shape.get().element_type(), input->GetDevice()),
3746+
to, input_shape.element_type(), input->GetDevice()),
37383747
XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape));
3748+
return absl::OkStatus();
37393749
}
37403750

37413751
XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim) {

torch_xla/csrc/tensor_methods.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -989,7 +989,7 @@ std::tuple<XLATensorPtr, XLATensorPtr> triangular_solve(
989989
// removed.
990990
std::vector<XLATensorPtr> unbind(const XLATensorPtr& input, int64_t dim);
991991

992-
void uniform_(XLATensorPtr& input, double from, double to);
992+
absl::Status uniform_(XLATensorPtr& input, double from, double to);
993993

994994
// Insert a dimension of size one at the specified position.
995995
XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim);

0 commit comments

Comments
 (0)