@@ -442,6 +442,43 @@ absl::Status CheckFullConcreteSizesArePositive(at::SymIntArrayRef sym_sizes) {
   });
 }
 
+absl::Status CheckGatherRanksAreEqual(const XLATensorPtr& input,
+                                      const XLATensorPtr& index) {
+  int64_t input_rank = input->shape().get().dimensions_size();
+  int64_t index_rank = index->shape().get().dimensions_size();
+  if (input_rank != index_rank) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "gather(): expected rank of input (", input_rank, ") and index (",
+        index_rank, ") tensors to be the same.")));
+  }
+  return absl::OkStatus();
+}
+
+// Checks that all index dimensions are smaller than or equal to those of
+// input, except on dimension canonical_dim.
+absl::Status CheckGatherDimensionsAreCompatible(const XLATensorPtr& input,
+                                                const XLATensorPtr& index,
+                                                int64_t canonical_dim) {
+  // Dimensions that fail the "smaller than or equal" condition.
+  std::vector<int64_t> bad_dims;
+  for (int64_t dim = 0; dim < input->shape().get().dimensions_size(); dim++) {
+    if (dim != canonical_dim && input->size(dim) < index->size(dim)) {
+      bad_dims.push_back(dim);
+    }
+  }
+  if (!bad_dims.empty()) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "gather(): expected sizes of index [",
+        absl::StrJoin(index->shape().get().dimensions(), /*sep=*/", "),
+        "] to be smaller than or equal to those of input [",
+        absl::StrJoin(input->shape().get().dimensions(), /*sep=*/", "),
+        "] on all dimensions, except on dimension ", canonical_dim,
+        ". However, that's not true on dimensions [",
+        absl::StrJoin(bad_dims, /*sep=*/", "), "].")));
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////
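
Note: the sketch below is a minimal, self-contained analogue of the new dimension check, rewritten over plain shape vectors so that it compiles with Abseil alone. CheckDimensionsAreCompatible, the example shapes, and the main() driver are illustrative stand-ins, not part of this patch.

#include <cstdint>
#include <iostream>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

// Analogue of CheckGatherDimensionsAreCompatible over plain shape vectors:
// every index dimension must be <= the input dimension, except canonical_dim.
absl::Status CheckDimensionsAreCompatible(const std::vector<int64_t>& input,
                                          const std::vector<int64_t>& index,
                                          int64_t canonical_dim) {
  std::vector<int64_t> bad_dims;
  for (int64_t dim = 0; dim < static_cast<int64_t>(input.size()); dim++) {
    if (dim != canonical_dim && input[dim] < index[dim]) {
      bad_dims.push_back(dim);
    }
  }
  if (!bad_dims.empty()) {
    return absl::InvalidArgumentError(absl::StrCat(
        "gather(): expected sizes of index [", absl::StrJoin(index, ", "),
        "] to be smaller than or equal to those of input [",
        absl::StrJoin(input, ", "), "] on all dimensions, except on dimension ",
        canonical_dim, ". However, that's not true on dimensions [",
        absl::StrJoin(bad_dims, ", "), "]."));
  }
  return absl::OkStatus();
}

int main() {
  // index exceeds input on dimensions 0 and 2; dimension 1 is exempt.
  absl::Status status =
      CheckDimensionsAreCompatible({2, 3, 4}, {5, 9, 7}, /*canonical_dim=*/1);
  std::cout << status << "\n";  // INVALID_ARGUMENT listing dimensions [0, 2].
}
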
@@ -1838,18 +1875,14 @@ absl::StatusOr<absl_nonnull XLATensorPtr> full_symint(
       device, scalar_type);
 }
 
-XLATensorPtr gather(const XLATensorPtr& input, int64_t dim,
-                    const XLATensorPtr& index) {
-  xla::Shape input_shape = input->shape();
-  xla::Shape index_shape = index->shape();
-  XLA_CHECK_EQ(input_shape.dimensions_size(), index_shape.dimensions_size());
+absl::StatusOr<absl_nonnull XLATensorPtr> gather(const XLATensorPtr& input,
+                                                 int64_t dim,
+                                                 const XLATensorPtr& index) {
   int64_t canonical_dim = torch::lazy::GetCanonicalDimensionIndex(
-      dim, input_shape.dimensions_size());
-  for (size_t dim = 0; dim < input_shape.dimensions_size(); dim++) {
-    if (dim != canonical_dim) {
-      XLA_CHECK_LE(index->size(dim), input->size(dim));
-    }
-  }
+      dim, input->shape().get().dimensions_size());
+  XLA_RETURN_IF_ERROR(CheckGatherRanksAreEqual(input, index));
+  XLA_RETURN_IF_ERROR(
+      CheckGatherDimensionsAreCompatible(input, index, canonical_dim));
   return input->CreateFrom(torch_xla::MakeNode<Gather>(
       input->GetIrValue(), canonical_dim, index->GetIrValue()));
 }
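
Note: below is a minimal sketch of the StatusOr-based error propagation that the new gather() signature adopts, compilable with Abseil alone. Tensor, CheckRanksAreEqual, and Gather are hypothetical stand-ins for the torch_xla types, and the early return is roughly what a macro like XLA_RETURN_IF_ERROR does; callers now receive the validation failure as a value instead of a process-terminating XLA_CHECK.

#include <iostream>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical stand-in for XLATensorPtr; only the rank matters here.
struct Tensor {
  int rank;
};

absl::Status CheckRanksAreEqual(const Tensor& input, const Tensor& index) {
  if (input.rank != index.rank) {
    return absl::InvalidArgumentError("gather(): rank mismatch.");
  }
  return absl::OkStatus();
}

// Mirrors the new signature: validation failures surface as a non-OK
// StatusOr instead of crashing inside the function.
absl::StatusOr<Tensor> Gather(const Tensor& input, const Tensor& index) {
  if (absl::Status s = CheckRanksAreEqual(input, index); !s.ok()) {
    return s;  // Roughly what XLA_RETURN_IF_ERROR does.
  }
  return Tensor{input.rank};
}

int main() {
  absl::StatusOr<Tensor> result = Gather(Tensor{2}, Tensor{3});
  if (!result.ok()) {
    // The caller decides how to surface the error, e.g. as a Python exception.
    std::cout << result.status() << "\n";
  }
}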