1818#include < optional>
1919
2020#include " absl/log/absl_check.h"
21+ #include " status.h"
2122#include " torch/csrc/lazy/core/helpers.h"
2223#include " torch/csrc/lazy/core/shape_inference.h"
2324#include " torch/csrc/lazy/core/tensor_util.h"
@@ -317,18 +318,27 @@ int64_t GetIntegerUpperLimitForType(torch::ScalarType dtype) {
317318 }
318319}
319320
320- void CheckRangeValues (torch::ScalarType dtype, int64_t from, int64_t to) {
321- XlaHelpers::MinMax min_max;
322- // Bound the min_max by int64_t since types of "from" and "to" are int64.
323- if (IsTypeWithLargerRangeThanLong (dtype)) {
324- min_max = XlaHelpers::MinMaxValues (xla::PrimitiveType::S64);
325- } else {
326- min_max = XlaHelpers::MinMaxValues (XlaTypeFromTorchType (dtype));
321+ absl::Status CheckValueWithinTypeRange (const std::string_view op,
322+ const std::string_view arg,
323+ torch::ScalarType dtype, int64_t value) {
324+ xla::PrimitiveType type = IsTypeWithLargerRangeThanLong (dtype)
325+ ? xla::PrimitiveType::S64
326+ : XlaTypeFromTorchType (dtype);
327+
328+ XlaHelpers::MinMax mm = XlaHelpers::MinMaxValues (type);
329+ int64_t min = mm.min .toLong ();
330+ int64_t max = mm.max .toLong ();
331+
332+ if (value < min || value > max) {
333+ const std::string_view comparison = value < min ? " lower" : " greater" ;
334+ const std::string_view bound = value < min ? " lower bound" : " upper bound" ;
335+ return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (
336+ absl::StrCat (op, " (): expected `" , arg, " ` to be within the range [" ,
337+ min, " , " , max, " ]. However got value " , value,
338+ " , which is " , comparison, " than the " , bound, " ." )));
327339 }
328- XLA_CHECK_GE (from, min_max.min .toLong ());
329- XLA_CHECK_LE (from, min_max.max .toLong ());
330- XLA_CHECK_GE (to, min_max.min .toLong ());
331- XLA_CHECK_LE (to, min_max.max .toLong ());
340+
341+ return absl::OkStatus ();
332342}
333343
334344std::pair<XLATensorPtr, XLATensorPtr> GetBinaryOperands (
@@ -3025,12 +3035,14 @@ at::Tensor& XLANativeFunctions::random_(
30253035 }
30263036 XLATensorPtr self_tensor = GetValueOrThrow (bridge::GetXlaTensor (self));
30273037 at::ScalarType dtype = self_tensor->dtype ();
3038+
30283039 // Prevent "to_val" from overflowing with at::ScalarType::Long.
30293040 int64_t inc = (dtype == at::ScalarType::Long) ? 0 : 1 ;
30303041 int64_t to_val = (to) ? *to : GetIntegerUpperLimitForType (dtype) + inc;
3031- XLA_CHECK_LE (from, to_val);
3032- CheckRangeValues (self_tensor->dtype (), from, to_val - 1 );
3033- tensor_methods::random_ (self_tensor, from, to_val);
3042+
3043+ OkOrThrow (CheckValueWithinTypeRange (" random_" , " from" , dtype, from));
3044+ OkOrThrow (CheckValueWithinTypeRange (" random_" , " to" , dtype, to_val - 1 ));
3045+ OkOrThrow (tensor_methods::random_ (self_tensor, from, to_val));
30343046 return self;
30353047}
30363048
@@ -3043,10 +3055,12 @@ at::Tensor& XLANativeFunctions::random_(
30433055 ATEN_OP2 (random_, to)>::call (self, to,
30443056 generator);
30453057 }
3058+
30463059 XLATensorPtr self_tensor = GetValueOrThrow (bridge::GetXlaTensor (self));
3047- XLA_CHECK_GT (to, 0 );
3048- CheckRangeValues (self_tensor->dtype (), 0 , to - 1 );
3049- tensor_methods::random_ (self_tensor, 0 , to);
3060+ at::ScalarType dtype = self_tensor->dtype ();
3061+
3062+ OkOrThrow (CheckValueWithinTypeRange (" random_" , " to" , dtype, to - 1 ));
3063+ OkOrThrow (tensor_methods::random_ (self_tensor, 0 , to));
30503064 return self;
30513065}
30523066
@@ -3060,10 +3074,12 @@ at::Tensor& XLANativeFunctions::random_(
30603074 }
30613075 XLATensorPtr self_tensor = GetValueOrThrow (bridge::GetXlaTensor (self));
30623076 at::ScalarType dtype = self_tensor->dtype ();
3077+
30633078 // Prevent "to_val" from overflowing with at::ScalarType::Long.
30643079 int64_t inc = (dtype == at::ScalarType::Long) ? 0 : 1 ;
3065- tensor_methods::random_ (self_tensor, 0 ,
3066- GetIntegerUpperLimitForType (dtype) + inc);
3080+ int64_t to_val = GetIntegerUpperLimitForType (dtype) + inc;
3081+
3082+ OkOrThrow (tensor_methods::random_ (self_tensor, 0 , to_val));
30673083 return self;
30683084}
30693085
0 commit comments