@@ -1702,16 +1702,14 @@ at::Tensor XLANativeFunctions::empty_symint(
17021702 // does not actually end up doing any memory initialization, we use that and
17031703 // avoid going to CPU for it. A common PT pattern is indeed doing empty() plus
17041704 // s_copy_().
1705- XLATensorPtr xla_tensor;
1706- if (all_dims_static) {
1707- xla_tensor = tensor_methods::full (XlaHelpers::I64List (int_sizes.value ()), 0 ,
1708- GetXlaDeviceOrCurrent (device),
1709- at::dtype_or_default (dtype));
1710- } else {
1711- xla_tensor =
1712- tensor_methods::full_symint (sym_size, 0 , GetXlaDeviceOrCurrent (device),
1713- at::dtype_or_default (dtype));
1714- }
1705+ XLATensorPtr xla_tensor = GetValueOrThrow (
1706+ all_dims_static
1707+ ? tensor_methods::full (XlaHelpers::I64List (int_sizes.value ()), 0 ,
1708+ GetXlaDeviceOrCurrent (device),
1709+ at::dtype_or_default (dtype))
1710+ : tensor_methods::full_symint (sym_size, 0 ,
1711+ GetXlaDeviceOrCurrent (device),
1712+ at::dtype_or_default (dtype)));
17151713 // `tensor.to` will trigger an `empty` + `_to_copy`. In the egaer mode, the
17161714 // `full` will be evulated eagerly and got a replicated sharding. We should
17171715 // leave the sharding to be empty.
@@ -1858,9 +1856,9 @@ at::Tensor XLANativeFunctions::full(at::IntArrayRef size,
18581856 } else {
18591857 intend_dtype = fill_value.type ();
18601858 }
1861- return bridge::AtenFromXlaTensor (
1859+ return bridge::AtenFromXlaTensor (GetValueOrThrow (
18621860 tensor_methods::full (absl::Span<const int64_t >(size), fill_value,
1863- GetXlaDeviceOrCurrent (device), intend_dtype));
1861+ GetXlaDeviceOrCurrent (device), intend_dtype))) ;
18641862}
18651863
18661864at::Tensor XLANativeFunctions::gather (const at::Tensor& self, int64_t dim,
@@ -2681,8 +2679,8 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::nll_loss2d_forward(
26812679 int64_t ignore_index) {
26822680 TORCH_LAZY_FN_COUNTER_TIMED_TRACING (" xla::" );
26832681 XLATensorPtr self_tensor = GetValueOrThrow (bridge::GetXlaTensor (self));
2684- XLATensorPtr total_weight = tensor_methods::full (
2685- {}, 1 , self_tensor->GetDevice (), self_tensor->dtype ());
2682+ XLATensorPtr total_weight = GetValueOrThrow ( tensor_methods::full (
2683+ {}, 1 , self_tensor->GetDevice (), self_tensor->dtype ())) ;
26862684 return std::make_tuple (
26872685 bridge::AtenFromXlaTensor (tensor_methods::nll_loss2d (
26882686 self_tensor, GetValueOrThrow (bridge::GetXlaTensor (target)),
@@ -2716,8 +2714,8 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::nll_loss_forward(
27162714 int64_t ignore_index) {
27172715 TORCH_LAZY_FN_COUNTER_TIMED_TRACING (" xla::" );
27182716 XLATensorPtr self_tensor = GetValueOrThrow (bridge::GetXlaTensor (self));
2719- XLATensorPtr total_weight = tensor_methods::full (
2720- {}, 1 , self_tensor->GetDevice (), self_tensor->dtype ());
2717+ XLATensorPtr total_weight = GetValueOrThrow ( tensor_methods::full (
2718+ {}, 1 , self_tensor->GetDevice (), self_tensor->dtype ())) ;
27212719 return std::make_tuple (
27222720 bridge::AtenFromXlaTensor (tensor_methods::nll_loss (
27232721 self_tensor, GetValueOrThrow (bridge::GetXlaTensor (target)),
@@ -4038,10 +4036,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> XLANativeFunctions::_linalg_svd(
40384036 if (!compute_uv) {
40394037 // When compute_uv is false, torch::_linalg_svd returns an empty tensor for
40404038 // u and vh.
4041- u = tensor_methods::full ({0 }, 0 , self_tensor->GetDevice (),
4042- self_tensor->dtype ());
4043- vh = tensor_methods::full ({0 }, 0 , self_tensor->GetDevice (),
4044- self_tensor->dtype ());
4039+ u = GetValueOrThrow ( tensor_methods::full ({0 }, 0 , self_tensor->GetDevice (),
4040+ self_tensor->dtype () ));
4041+ vh = GetValueOrThrow ( tensor_methods::full ({0 }, 0 , self_tensor->GetDevice (),
4042+ self_tensor->dtype () ));
40454043 }
40464044 return std::make_tuple (bridge::AtenFromXlaTensor (u),
40474045 bridge::AtenFromXlaTensor (s),
0 commit comments