16 changes: 16 additions & 0 deletions test/test_ops_error_message.py
@@ -179,3 +179,19 @@ def test():
         callable=test,
         expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8)."""
     )
+
+  def test_clamp_scalar_raises_error_on_no_min_and_max(self):
+    device = torch_xla.device()
+    a = torch.rand(2, 5, device=device)
+
+    def test():
+      # Dispatch to `clamp()` overload explicitly.
+      # Otherwise, it's dispatched to `clamp.Tensor()`, which doesn't have
+      # this check.
+      return torch.ops.aten.clamp.default(a)
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""clamp(): expected at least one of `min` or `max` arguments to be specified."""
+    )
31 changes: 20 additions & 11 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1372,24 +1372,31 @@ at::Tensor XLANativeFunctions::clamp(const at::Tensor& self,
                                      const std::optional<at::Scalar>& min,
                                      const std::optional<at::Scalar>& max) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::clamp(xla_self, min, max));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, min, max));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::clamp_max(const at::Tensor& self,
                                          const at::Scalar& max) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::clamp(xla_self, std::nullopt, max));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, std::nullopt, max));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self,
                                          const at::Scalar& min) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::clamp(xla_self, min, std::nullopt));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, min, std::nullopt));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::clone(
@@ -1947,9 +1954,11 @@ at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self,
                                         const at::Scalar& min_val,
                                         const at::Scalar& max_val) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::clamp(xla_self, min_val, max_val));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, min_val, max_val));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output,
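Every caller above follows the same pattern: `bridge::GetXlaTensor` and the new `tensor_methods::clamp` return `absl::StatusOr`-style results, and `XLA_ASSIGN_OR_THROW` unwraps them at the ATen boundary, turning failures into exceptions that PyTorch surfaces as `RuntimeError`. The actual macro is defined elsewhere in torch_xla and is not shown in this diff; as a rough sketch of the mechanism (the `SKETCH_*` names are hypothetical stand-ins, not the real implementation):

```cpp
#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical stand-in for XLA_ASSIGN_OR_THROW (simplified). `lhs` may be a
// declaration such as `XLATensorPtr xla_self`, so the StatusOr temporary
// needs a unique name of its own.
#define SKETCH_CONCAT_INNER(a, b) a##b
#define SKETCH_CONCAT(a, b) SKETCH_CONCAT_INNER(a, b)
#define SKETCH_ASSIGN_OR_THROW_IMPL(tmp, lhs, expr)                 \
  auto tmp = (expr);                                                \
  if (!tmp.ok()) {                                                  \
    throw std::runtime_error(std::string(tmp.status().message())); \
  }                                                                 \
  lhs = std::move(tmp).value();
#define SKETCH_ASSIGN_OR_THROW(lhs, expr) \
  SKETCH_ASSIGN_OR_THROW_IMPL(SKETCH_CONCAT(sketch_or_, __LINE__), lhs, expr)

// Example: a StatusOr-producing function and a throwing caller.
absl::StatusOr<int> ParsePositive(int x) {
  if (x <= 0) {
    return absl::InvalidArgumentError("expected a positive value");
  }
  return x;
}

int DoubledOrThrow(int x) {
  SKETCH_ASSIGN_OR_THROW(int value, ParsePositive(x));
  return 2 * value;  // Reached only when ParsePositive(x) succeeded.
}
```

The real macro presumably also carries source-location information (compare `XLA_ERROR_WITH_LOCATION` below); the sketch only shows the check/throw/bind shape that the call sites above rely on.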
54 changes: 27 additions & 27 deletions torch_xla/csrc/tensor_methods.cpp
@@ -163,11 +163,6 @@ namespace torch_xla {
 namespace tensor_methods {
 namespace {
 
-struct MinMaxValues {
-  torch::lazy::Value min;
-  torch::lazy::Value max;
-};
-
 torch::lazy::Value MaybeExpand(const torch::lazy::Value& input,
                                const xla::Shape& target_shape) {
   if (GetXlaShape(input).dimensions() == target_shape.dimensions()) {
@@ -177,22 +172,6 @@ torch::lazy::Value MaybeExpand(const torch::lazy::Value& input,
       input, torch::lazy::ToVector<int64_t>(target_shape.dimensions()));
 }
 
-MinMaxValues GetMinMaxValues(const XLATensorPtr& tensor,
-                             const std::optional<at::Scalar>& min,
-                             const std::optional<at::Scalar>& max) {
-  XLA_CHECK(min || max)
-      << "At least one of \'min\' or \'max\' must not be None";
-  xla::PrimitiveType raw_element_type = XlaTypeFromTorchType(tensor->dtype());
-  XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(raw_element_type);
-  auto shape = tensor->shape();
-  return {XLAGraphExecutor::Get()->GetIrValueForScalar(
-              min ? *min : min_max.min, shape.get().element_type(),
-              tensor->GetDevice()),
-          XLAGraphExecutor::Get()->GetIrValueForScalar(
-              max ? *max : min_max.max, shape.get().element_type(),
-              tensor->GetDevice())};
-}
-
 void CheckRank(const XLATensorPtr& t, int64_t expected_rank,
                const std::string& tag, const std::string& arg_name,
                int arg_number) {
@@ -506,6 +485,16 @@ absl::Status CheckMMMatrixSizesAreCompatible(const XLATensorPtr& mat1,
   return absl::OkStatus();
 }
 
+absl::Status CheckClampMinOrMax(const std::optional<at::Scalar>& min,
+                                const std::optional<at::Scalar>& max) {
+  if (!min.has_value() && !max.has_value()) {
+    return XLA_ERROR_WITH_LOCATION(
+        absl::InvalidArgumentError("clamp(): expected at least one of `min` or "
+                                   "`max` arguments to be specified."));
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////
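`CheckClampMinOrMax` returns `absl::Status` rather than asserting, and the rewritten `clamp` below forwards any failure to its caller via `XLA_RETURN_IF_ERROR`. A minimal sketch of that propagation pattern, assuming a simplified macro (`SKETCH_RETURN_IF_ERROR` is a hypothetical stand-in for the real `XLA_RETURN_IF_ERROR`, whose definition is not part of this diff):

```cpp
#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical stand-in for XLA_RETURN_IF_ERROR (simplified): evaluate a
// Status expression and early-return it on failure.
#define SKETCH_RETURN_IF_ERROR(expr)     \
  do {                                   \
    absl::Status sketch_status = (expr); \
    if (!sketch_status.ok()) {           \
      return sketch_status;              \
    }                                    \
  } while (false)

absl::Status CheckNonZero(int divisor) {
  if (divisor == 0) {
    return absl::InvalidArgumentError("divisor must be non-zero");
  }
  return absl::OkStatus();
}

// Example: the validation error short-circuits the computation.
absl::StatusOr<int> SafeDiv(int a, int b) {
  SKETCH_RETURN_IF_ERROR(CheckNonZero(b));
  return a / b;
}
```

Because `absl::StatusOr<T>` is implicitly constructible from a non-OK `absl::Status`, the same macro works in functions returning either type, which is exactly the situation in `clamp` below.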
@@ -1357,12 +1346,23 @@ void celu_(XLATensorPtr& input, const at::Scalar& alpha) {
   input->SetInPlaceIrValue(Celu(input->GetIrValue(), alpha));
 }
 
-XLATensorPtr clamp(const XLATensorPtr& input,
-                   const std::optional<at::Scalar>& min,
-                   const std::optional<at::Scalar>& max) {
-  MinMaxValues min_max = GetMinMaxValues(input, min, max);
-  return input->CreateFrom(
-      Clamp(input->GetIrValue(), min_max.min, min_max.max));
+absl::StatusOr<absl_nonnull XLATensorPtr> clamp(
+    const XLATensorPtr& input, const std::optional<at::Scalar>& min,
+    const std::optional<at::Scalar>& max) {
+  XLA_RETURN_IF_ERROR(CheckClampMinOrMax(min, max));
+
+  xla::Shape shape = input->shape();
+  const torch::lazy::BackendDevice& device = input->GetDevice();
+
+  xla::PrimitiveType raw_element_type = XlaTypeFromTorchType(input->dtype());
+  XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(raw_element_type);
+
+  torch::lazy::Value min_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
+      min.value_or(min_max.min), shape.element_type(), device);
+  torch::lazy::Value max_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
+      max.value_or(min_max.max), shape.element_type(), device);
+
+  return input->CreateFrom(Clamp(input->GetIrValue(), min_value, max_value));
 }
 
 XLATensorPtr clone(const XLATensorPtr& input) {
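A note on the defaulting scheme in the new `clamp`: a missing bound falls back to the dtype's extreme value from `XlaHelpers::MinMaxValues`, so that side of the clamp becomes a no-op. A self-contained sketch of the same idea on plain floats (`ClampSketch` is a hypothetical helper for illustration only):

```cpp
#include <algorithm>
#include <limits>
#include <optional>
#include <vector>

// Sketch of the bound-defaulting above: a missing bound falls back to the
// type's lowest/highest representable value, making that side a no-op.
std::vector<float> ClampSketch(std::vector<float> values,
                               std::optional<float> min,
                               std::optional<float> max) {
  const float lo = min.value_or(std::numeric_limits<float>::lowest());
  const float hi = max.value_or(std::numeric_limits<float>::max());
  for (float& v : values) {
    v = std::clamp(v, lo, hi);  // With both defaults, v passes through unchanged.
  }
  return values;
}
```

Unlike this sketch, the real code can no longer reach the defaults with both bounds missing, since `XLA_RETURN_IF_ERROR(CheckClampMinOrMax(min, max))` rejects that case first.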
9 changes: 3 additions & 6 deletions torch_xla/csrc/tensor_methods.h
@@ -316,12 +316,9 @@ XLATensorPtr pixel_shuffle(const XLATensorPtr& self, int64_t upscale_factor);
 XLATensorPtr celu(const XLATensorPtr& input, const at::Scalar& alpha);
 void celu_(XLATensorPtr& input, const at::Scalar& alpha);
 
-XLATensorPtr clamp(const XLATensorPtr& input,
-                   const std::optional<at::Scalar>& min,
-                   const std::optional<at::Scalar>& max);
-XLATensorPtr clamp(const XLATensorPtr& input,
-                   const std::optional<at::Tensor>& min,
-                   const std::optional<at::Tensor>& max);
+absl::StatusOr<absl_nonnull XLATensorPtr> clamp(
+    const XLATensorPtr& input, const std::optional<at::Scalar>& min,
+    const std::optional<at::Scalar>& max);
 
 XLATensorPtr clone(const XLATensorPtr& input);