132 changes: 114 additions & 18 deletions google/cloud/firestore_v1/base_query.py
@@ -1134,12 +1134,8 @@ def _build_pipeline(self, source: "PipelineSource"):
"""
Convert this query into a Pipeline

Queries containing a `cursor` or `limit_to_last` are not currently supported

Args:
source: the PipelineSource to build the pipeline off of
Raises:
- NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
Returns:
a Pipeline representing the query
"""
@@ -1162,9 +1158,10 @@ def _build_pipeline(self, source: "PipelineSource"):

# Orders
orders = self._normalize_orders()

exists = []
orderings = []
if orders:
exists = []
orderings = []
for order in orders:
field = pipeline_expressions.Field.of(order.field.field_path)
exists.append(field.exists())
@@ -1178,23 +1175,59 @@ def _build_pipeline(self, source: "PipelineSource"):
# Add exists filters to match Query's implicit orderby semantics.
if len(exists) == 1:
ppl = ppl.where(exists[0])
else:
elif len(exists) > 1:
ppl = ppl.where(pipeline_expressions.And(*exists))

# Add sort orderings
ppl = ppl.sort(*orderings)
if orderings:
# Normalize cursors to get the raw values corresponding to the orders
start_at_val = None
if self._start_at:
start_at_val = self._normalize_cursor(self._start_at, orders)

end_at_val = None
if self._end_at:
end_at_val = self._normalize_cursor(self._end_at, orders)

# If limit_to_last is set, we need to reverse the orderings to find the
# "last" N documents (which effectively become the "first" N in reverse order).
if self._limit_to_last:
actual_orderings = _reverse_orderings(orderings)
ppl = ppl.sort(*actual_orderings)

# Apply cursor conditions.
# Cursors are translated into filter conditions (e.g., field > value)
# based on the orderings.
if start_at_val:
ppl = ppl.where(
_where_conditions_from_cursor(
start_at_val, orderings, is_start_cursor=True
)
)

# Cursors, Limit and Offset
if self._start_at or self._end_at or self._limit_to_last:
raise NotImplementedError(
"Query to Pipeline conversion: cursors and limit_to_last is not supported yet."
)
else: # Limit & Offset without cursors
if self._offset:
ppl = ppl.offset(self._offset)
if self._limit:
if end_at_val:
ppl = ppl.where(
_where_conditions_from_cursor(
end_at_val, orderings, is_start_cursor=False
)
)

if not self._limit_to_last:
ppl = ppl.sort(*orderings)

if self._limit is not None:
ppl = ppl.limit(self._limit)

# If we reversed the orderings for limit_to_last, we must now re-sort
# using the original orderings to return the results in the user-requested order.
if self._limit_to_last:
ppl = ppl.sort(*orderings)
elif self._limit is not None and not self._limit_to_last:
ppl = ppl.limit(self._limit)

# Offset
if self._offset:
ppl = ppl.offset(self._offset)

return ppl
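The limit_to_last handling above leans on a simple equivalence: sort in the reversed direction, keep the first N, then re-sort in the requested direction to get the last N in order. A self-contained illustration on plain Python lists (illustrative only, no Firestore types involved):

# Query semantics: order_by("a") + limit_to_last(2) should return the two
# largest values, in ascending order.
docs = [{"a": v} for v in (1, 4, 2, 5, 3)]
expected = sorted(docs, key=lambda d: d["a"])[-2:]

# Pipeline emulation: reversed sort, limit, then re-sort as requested.
last_two = sorted(docs, key=lambda d: d["a"], reverse=True)[:2]
emulated = sorted(last_two, key=lambda d: d["a"])

assert emulated == expected  # [{'a': 4}, {'a': 5}]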

def _comparator(self, doc1, doc2) -> int:
@@ -1366,6 +1399,69 @@ def _cursor_pb(cursor_pair: Optional[Tuple[list, bool]]) -> Optional[Cursor]:
return None


def _where_conditions_from_cursor(
cursor: Tuple[List, bool],
orderings: List[pipeline_expressions.Ordering],
is_start_cursor: bool,
) -> pipeline_expressions.BooleanExpression:
"""
Converts a cursor into a filter condition for the pipeline.

Args:
cursor: The cursor values and the 'before' flag.
orderings: The list of ordering expressions used in the query.
is_start_cursor: True if this is a start_at/start_after cursor, False if it is an end_at/end_before cursor.
Returns:
A BooleanExpression representing the cursor condition.
"""
cursor_values, before = cursor
size = len(cursor_values)

if is_start_cursor:
filter_func = pipeline_expressions.Expression.greater_than
else:
filter_func = pipeline_expressions.Expression.less_than

field = orderings[size - 1].expr
value = pipeline_expressions.Constant(cursor_values[size - 1])

    # Start with the comparison on the last bound field.
    condition = filter_func(field, value)

    if (is_start_cursor and before) or (not is_start_cursor and not before):
        # An inclusive bound ("start at" / "end at") also admits documents
        # whose last bound field equals the cursor value; an exclusive bound
        # ("start after" / "end before") does not.
        condition = pipeline_expressions.Or(condition, field.equal(value))

# Iterate backwards over the remaining bounds, adding a condition for each one
for i in range(size - 2, -1, -1):
field = orderings[i].expr
value = pipeline_expressions.Constant(cursor_values[i])

        # For each remaining field, a document matches if it is strictly
        # before/after the cursor value on that field (per filter_func), or
        # equal on that field and before/after on the lower-priority fields.
condition = pipeline_expressions.Or(
filter_func(field, value),
pipeline_expressions.And(field.equal(value), condition),
)

return condition
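Concretely, for orderings (a ASC, b ASC) and an inclusive start cursor (1, 2), the loop above unrolls to a > 1 OR (a == 1 AND (b > 2 OR b == 2)), which is a lexicographic "greater than or equal" over the ordered fields. A quick plain-Python check of that equivalence (illustrative only):

def matches_inclusive_start(a, b, cursor=(1, 2)):
    ca, cb = cursor
    # Same shape as the generated filter condition.
    unrolled = a > ca or (a == ca and (b > cb or b == cb))
    assert unrolled == ((a, b) >= cursor)  # tuple comparison is lexicographic
    return unrolled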


def _reverse_orderings(
    orderings: List[pipeline_expressions.Ordering],
) -> List[pipeline_expressions.Ordering]:
    """Return a copy of ``orderings`` with each sort direction flipped."""
    reversed_orderings = []
    for o in orderings:
        if o.order_dir == pipeline_expressions.Ordering.Direction.ASCENDING:
            new_dir = "descending"
        else:
            new_dir = "ascending"
        reversed_orderings.append(pipeline_expressions.Ordering(o.expr, new_dir))
    return reversed_orderings
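For example, reversing [a ascending, b descending] yields [a descending, b ascending]. A sketch using only names that appear in this diff; it assumes, as the loop above does, that Ordering accepts the string directions "ascending" and "descending":

orderings = [
    pipeline_expressions.Ordering(pipeline_expressions.Field.of("a"), "ascending"),
    pipeline_expressions.Ordering(pipeline_expressions.Field.of("b"), "descending"),
]
# Expected: a descending, b ascending (assuming Ordering normalizes the
# string directions to the Direction enum).
flipped = _reverse_orderings(orderings)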


def _query_response_to_snapshot(
response_pb: RunQueryResponse, collection, expected_prefix: str
) -> Optional[document.DocumentSnapshot]:
16 changes: 12 additions & 4 deletions tests/system/test_system.py
@@ -1445,7 +1445,7 @@ def test_query_stream_w_field_path(query_docs, database):
verify_pipeline(query)


@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_stream_w_start_end_cursor(query_docs, database):
collection, stored, allowed_vals = query_docs
num_vals = len(allowed_vals)
@@ -1459,6 +1459,7 @@ def test_query_stream_w_start_end_cursor(query_docs, database):
for key, value in values:
assert stored[key] == value
assert value["a"] == num_vals - 2
verify_pipeline(query)


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
@@ -1869,7 +1870,7 @@ def test_pipeline_w_read_time(query_docs, cleanup, database):
assert key != new_ref.id


@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_query_with_order_dot_key(client, cleanup, database):
db = client
collection_id = "collek" + UNIQUE_RESOURCE_ID
Expand Down Expand Up @@ -1906,6 +1907,9 @@ def test_query_with_order_dot_key(client, cleanup, database):
)
cursor_with_key_data = list(query4.stream())
assert found_data == [snap.to_dict() for snap in cursor_with_key_data]
verify_pipeline(query)
verify_pipeline(query2)
verify_pipeline(query3)


@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1999,7 +2003,7 @@ def test_collection_group_queries(client, cleanup, database):
verify_pipeline(query)


@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_collection_group_queries_startat_endat(client, cleanup, database):
collection_group = "b" + UNIQUE_RESOURCE_ID

@@ -2030,6 +2034,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database):
snapshots = list(query.stream())
found = set(snapshot.id for snapshot in snapshots)
assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"])
verify_pipeline(query)

query = (
client.collection_group(collection_group)
@@ -2040,6 +2045,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database):
snapshots = list(query.stream())
found = set(snapshot.id for snapshot in snapshots)
assert found == set(["cg-doc2"])
verify_pipeline(query)


@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
@@ -2860,6 +2866,7 @@ def test_repro_429(client, cleanup, database):
for snapshot in query2.stream():
print(f"id: {snapshot.id}")
verify_pipeline(query)
verify_pipeline(query2)


@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
Expand Down Expand Up @@ -3019,7 +3026,7 @@ def test_count_query_stream_empty_aggregation(query, database):
assert "Aggregations can not be empty" in exc_info.value.message


@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
def test_count_query_with_start_at(query, database):
"""
Ensure that count aggregation queries work when chained with a start_at
@@ -3036,6 +3043,7 @@ def test_count_query_with_start_at(query, database):
for result in count_query.stream():
for aggregation_result in result:
assert aggregation_result.value == expected_count
verify_pipeline(count_query)


@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)