diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py
index bbb9f4b95b7..baedcb0f831 100644
--- a/test/test_ops_error_message.py
+++ b/test/test_ops_error_message.py
@@ -179,3 +179,19 @@ def test():
         callable=test,
         expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8)."""
     )
+
+  def test_clamp_scalar_raises_error_on_no_min_and_max(self):
+    device = torch_xla.device()
+    a = torch.rand(2, 5, device=device)
+
+    def test():
+      # Dispatch to `clamp()` overload explicitly.
+      # Otherwise, it's dispatched to `clamp.Tensor()`, which doesn't have
+      # this check.
+      return torch.ops.aten.clamp.default(a)
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""clamp(): expected at least one of `min` or `max` arguments to be specified."""
+    )
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index c042d703aa3..6654a808425 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1372,24 +1372,31 @@ at::Tensor XLANativeFunctions::clamp(const at::Tensor& self,
                                      const std::optional<at::Scalar>& min,
                                      const std::optional<at::Scalar>& max) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::clamp(xla_self, min, max));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, min, max));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::clamp_max(const at::Tensor& self,
                                          const at::Scalar& max) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::clamp(xla_self, std::nullopt, max));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, std::nullopt, max));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self,
                                          const at::Scalar& min) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::clamp(xla_self, min, std::nullopt));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, min, std::nullopt));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::clone(
@@ -1947,9 +1954,11 @@ at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self,
                                         const at::Scalar& min_val,
                                         const at::Scalar& max_val) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::clamp(xla_self, min_val, max_val));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, min_val, max_val));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output,
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 21f4db59713..ee6587f40e8 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -163,11 +163,6 @@ namespace torch_xla {
 namespace tensor_methods {
 namespace {
 
-struct MinMaxValues {
-  torch::lazy::Value min;
-  torch::lazy::Value max;
-};
-
 torch::lazy::Value MaybeExpand(const torch::lazy::Value& input,
                                const xla::Shape& target_shape) {
   if (GetXlaShape(input).dimensions() == target_shape.dimensions()) {
@@ -177,22 +172,6 @@ torch::lazy::Value MaybeExpand(const torch::lazy::Value& input,
       input, torch::lazy::ToVector<int64_t>(target_shape.dimensions()));
 }
 
-MinMaxValues GetMinMaxValues(const XLATensorPtr& tensor,
-                             const std::optional<at::Scalar>& min,
-                             const std::optional<at::Scalar>& max) {
-  XLA_CHECK(min || max)
-      << "At least one of \'min\' or \'max\' must not be None";
-  xla::PrimitiveType raw_element_type = XlaTypeFromTorchType(tensor->dtype());
-  XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(raw_element_type);
-  auto shape = tensor->shape();
-  return {XLAGraphExecutor::Get()->GetIrValueForScalar(
-              min ? *min : min_max.min, shape.get().element_type(),
-              tensor->GetDevice()),
-          XLAGraphExecutor::Get()->GetIrValueForScalar(
-              max ? *max : min_max.max, shape.get().element_type(),
-              tensor->GetDevice())};
-}
-
 void CheckRank(const XLATensorPtr& t, int64_t expected_rank,
                const std::string& tag, const std::string& arg_name,
                int arg_number) {
@@ -506,6 +485,16 @@ absl::Status CheckMMMatrixSizesAreCompatible(const XLATensorPtr& mat1,
   return absl::OkStatus();
 }
 
+absl::Status CheckClampMinOrMax(const std::optional<at::Scalar>& min,
+                                const std::optional<at::Scalar>& max) {
+  if (!min.has_value() && !max.has_value()) {
+    return XLA_ERROR_WITH_LOCATION(
+        absl::InvalidArgumentError("clamp(): expected at least one of `min` or "
+                                   "`max` arguments to be specified."));
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////
@@ -1357,12 +1346,23 @@ void celu_(XLATensorPtr& input, const at::Scalar& alpha) {
   input->SetInPlaceIrValue(Celu(input->GetIrValue(), alpha));
 }
 
-XLATensorPtr clamp(const XLATensorPtr& input,
-                   const std::optional<at::Scalar>& min,
-                   const std::optional<at::Scalar>& max) {
-  MinMaxValues min_max = GetMinMaxValues(input, min, max);
-  return input->CreateFrom(
-      Clamp(input->GetIrValue(), min_max.min, min_max.max));
+absl::StatusOr<absl_nonnull XLATensorPtr> clamp(
+    const XLATensorPtr& input, const std::optional<at::Scalar>& min,
+    const std::optional<at::Scalar>& max) {
+  XLA_RETURN_IF_ERROR(CheckClampMinOrMax(min, max));
+
+  xla::Shape shape = input->shape();
+  const torch::lazy::BackendDevice& device = input->GetDevice();
+
+  xla::PrimitiveType raw_element_type = XlaTypeFromTorchType(input->dtype());
+  XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(raw_element_type);
+
+  torch::lazy::Value min_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
+      min.value_or(min_max.min), shape.element_type(), device);
+  torch::lazy::Value max_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
+      max.value_or(min_max.max), shape.element_type(), device);
+
+  return input->CreateFrom(Clamp(input->GetIrValue(), min_value, max_value));
 }
 
 XLATensorPtr clone(const XLATensorPtr& input) {
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index b25b423d49c..e980d483502 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -316,12 +316,9 @@ XLATensorPtr pixel_shuffle(const XLATensorPtr& self, int64_t upscale_factor);
 
 XLATensorPtr celu(const XLATensorPtr& input, const at::Scalar& alpha);
 void celu_(XLATensorPtr& input, const at::Scalar& alpha);
-XLATensorPtr clamp(const XLATensorPtr& input,
-                   const std::optional<at::Scalar>& min,
-                   const std::optional<at::Scalar>& max);
-XLATensorPtr clamp(const XLATensorPtr& input,
-                   const std::optional<at::Scalar>& min,
-                   const std::optional<at::Scalar>& max);
+absl::StatusOr<absl_nonnull XLATensorPtr> clamp(
+    const XLATensorPtr& input, const std::optional<at::Scalar>& min,
+    const std::optional<at::Scalar>& max);
 
 XLATensorPtr clone(const XLATensorPtr& input);
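Illustrative usage sketch (not part of the diff; assumes an XLA device is reachable): with this change, the scalar `clamp` overload propagates the new `absl::InvalidArgumentError` through `XLA_RETURN_IF_ERROR`/`XLA_ASSIGN_OR_THROW`, and it surfaces in Python as a `RuntimeError` carrying the message asserted by the added test.

    import torch
    import torch_xla

    a = torch.rand(2, 5, device=torch_xla.device())
    try:
        # Neither `min` nor `max` is passed, so the new check fails.
        torch.ops.aten.clamp.default(a)
    except RuntimeError as e:
        # clamp(): expected at least one of `min` or `max` arguments to be specified.
        print(e)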