From 59fa92aca693179c7559715a9a4390c0a0618fae Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 6 Jan 2026 21:23:53 -0800 Subject: [PATCH 1/6] implemented cursors --- google/cloud/firestore_v1/base_query.py | 127 ++++++++++++++++++++---- 1 file changed, 109 insertions(+), 18 deletions(-) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index b1b74fcf1..cfdee2a8e 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -1134,12 +1134,8 @@ def _build_pipeline(self, source: "PipelineSource"): """ Convert this query into a Pipeline - Queries containing a `cursor` or `limit_to_last` are not currently supported - Args: source: the PipelineSource to build the pipeline off of - Raises: - - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` Returns: a Pipeline representing the query """ @@ -1162,9 +1158,10 @@ def _build_pipeline(self, source: "PipelineSource"): # Orders orders = self._normalize_orders() + + exists = [] + orderings = [] if orders: - exists = [] - orderings = [] for order in orders: field = pipeline_expressions.Field.of(order.field.field_path) exists.append(field.exists()) @@ -1178,23 +1175,58 @@ def _build_pipeline(self, source: "PipelineSource"): # Add exists filters to match Query's implicit orderby semantics. if len(exists) == 1: ppl = ppl.where(exists[0]) - else: + elif len(exists) > 1: ppl = ppl.where(pipeline_expressions.And(*exists)) - # Add sort orderings - ppl = ppl.sort(*orderings) + if orderings: + # Normalize cursors to get the raw values corresponding to the orders + start_at_val = None + if self._start_at: + start_at_val = self._normalize_cursor(self._start_at, orders) + + end_at_val = None + if self._end_at: + end_at_val = self._normalize_cursor(self._end_at, orders) + + # If limit_to_last is set, we need to reverse the orderings to find the + # "last" N documents (which effectively become the "first" N in reverse order). + if self._limit_to_last: + actual_orderings = _reverse_orderings(orderings) + ppl = ppl.sort(*actual_orderings) + else: + ppl = ppl.sort(*orderings) + + # Apply cursor conditions. + # Cursors are translated into filter conditions (e.g., field > value) + # based on the orderings. + if start_at_val: + ppl = ppl.where( + _where_conditions_from_cursor( + start_at_val, orderings, is_start_cursor=True + ) + ) - # Cursors, Limit and Offset - if self._start_at or self._end_at or self._limit_to_last: - raise NotImplementedError( - "Query to Pipeline conversion: cursors and limit_to_last is not supported yet." - ) - else: # Limit & Offset without cursors - if self._offset: - ppl = ppl.offset(self._offset) - if self._limit: + if end_at_val: + ppl = ppl.where( + _where_conditions_from_cursor( + end_at_val, orderings, is_start_cursor=False + ) + ) + + if self._limit is not None: ppl = ppl.limit(self._limit) + # If we reversed the orderings for limit_to_last, we must now re-sort + # using the original orderings to return the results in the user-requested order. + if self._limit_to_last: + ppl = ppl.sort(*orderings) + elif self._limit is not None and not self._limit_to_last: + ppl = ppl.limit(self._limit) + + # Offset + if self._offset: + ppl = ppl.offset(self._offset) + return ppl def _comparator(self, doc1, doc2) -> int: @@ -1366,6 +1398,65 @@ def _cursor_pb(cursor_pair: Optional[Tuple[list, bool]]) -> Optional[Cursor]: return None +def _where_conditions_from_cursor( + cursor: Tuple[List, bool], + orderings: List[pipeline_expressions.Ordering], + is_start_cursor: bool, +) -> pipeline_expressions.BooleanExpression: + """ + Converts a cursor into a filter condition for the pipeline. + + Args: + cursor: The cursor values and the 'before' flag. + orderings: The list of ordering expressions used in the query. + is_start_cursor: True if this is a start_at/start_after cursor, False if it is an end_at/end_before cursor. + Returns: + A BooleanExpression representing the cursor condition. + """ + cursor_values, before = cursor + size = len(cursor_values) + + field = orderings[size - 1].expr + value = pipeline_expressions.Constant(cursor_values[size - 1]) + + if not is_start_cursor: + condition = field.less_than(value) + else: + condition = field.greater_than(value) + + if (is_start_cursor and before) or (not is_start_cursor and not before): + condition = pipeline_expressions.Or(condition, field.equal(value)) + + for i in range(size - 2, -1, -1): + field = orderings[i].expr + value = pipeline_expressions.Constant(cursor_values[i]) + + if not is_start_cursor: + current_filter = field.less_than(value) + else: + current_filter = field.greater_than(value) + + condition = pipeline_expressions.Or( + current_filter, + pipeline_expressions.And(field.equal(value), condition), + ) + + return condition + + +def _reverse_orderings( + orderings: List[pipeline_expressions.Ordering], +) -> List[pipeline_expressions.Ordering]: + reversed_orderings = [] + for o in orderings: + if o.order_dir == pipeline_expressions.Ordering.Direction.ASCENDING: + new_dir = "descending" + else: + new_dir = "ascending" + reversed_orderings.append(pipeline_expressions.Ordering(o.expr, new_dir)) + return reversed_orderings + + def _query_response_to_snapshot( response_pb: RunQueryResponse, collection, expected_prefix: str ) -> Optional[document.DocumentSnapshot]: From f9687a0aabe80252af1f8e679a59f0fba765b826 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 6 Jan 2026 21:31:20 -0800 Subject: [PATCH 2/6] add unit tests --- tests/unit/v1/test_base_query.py | 59 ++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 10 deletions(-) diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 4a4dac727..b9ce1fde9 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -2116,19 +2116,58 @@ def test__query_pipeline_order_sorts(): assert sort_stage.orders[1].order_dir == expr.Ordering.Direction.DESCENDING -def test__query_pipeline_unsupported(): +def test__query_pipeline_cursors(): client = make_client() - query_start = client.collection("my_col").start_at({"field_a": "value"}) - with pytest.raises(NotImplementedError, match="cursors"): - query_start._build_pipeline(client.pipeline()) + query_start = ( + client.collection("my_col").order_by("field_a").start_at({"field_a": "value"}) + ) + pipeline = query_start._build_pipeline(client.pipeline()) + assert len(pipeline.stages) >= 2 + + query_end = ( + client.collection("my_col").order_by("field_a").end_at({"field_a": "value"}) + ) + pipeline = query_end._build_pipeline(client.pipeline()) + assert len(pipeline.stages) >= 2 - query_end = client.collection("my_col").end_at({"field_a": "value"}) - with pytest.raises(NotImplementedError, match="cursors"): - query_end._build_pipeline(client.pipeline()) - query_limit_last = client.collection("my_col").limit_to_last(10) - with pytest.raises(NotImplementedError, match="limit_to_last"): - query_limit_last._build_pipeline(client.pipeline()) +def test__query_pipeline_limit_to_last(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + + client = make_client() + query_limit_last = ( + client.collection("my_col").order_by("field_a").limit_to_last(10) + ) + pipeline = query_limit_last._build_pipeline(client.pipeline()) + # stages: collection, exists, sort(desc), limit, sort(asc) + + assert len(pipeline.stages) == 5 + + # 0. Collection + assert pipeline.stages[0].path == "/my_col" + + # 1. Exists + exists_stage = pipeline.stages[1] + assert isinstance(exists_stage, stages.Where) + + # 2. Sort DESCENDING (reversed) + sort_desc = pipeline.stages[2] + assert isinstance(sort_desc, stages.Sort) + assert len(sort_desc.orders) == 1 + assert sort_desc.orders[0].expr.path == "field_a" + assert sort_desc.orders[0].order_dir == expr.Ordering.Direction.DESCENDING + + # 3. Limit + limit_stage = pipeline.stages[3] + assert isinstance(limit_stage, stages.Limit) + assert limit_stage.limit == 10 + + # 4. Sort ASCENDING (original) + sort_asc = pipeline.stages[4] + assert isinstance(sort_asc, stages.Sort) + assert len(sort_asc.orders) == 1 + assert sort_asc.orders[0].expr.path == "field_a" + assert sort_asc.orders[0].order_dir == expr.Ordering.Direction.ASCENDING def test__query_pipeline_limit(): From c8d3cec871ad78a3752aa07ebdc0fc8e57462501 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 6 Jan 2026 21:36:19 -0800 Subject: [PATCH 3/6] added more verify pipeline lines to system tests --- tests/system/test_system.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 61b1a983c..1b6910bb7 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -1444,7 +1444,7 @@ def test_query_stream_w_field_path(query_docs, database): verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_start_end_cursor(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1458,6 +1458,7 @@ def test_query_stream_w_start_end_cursor(query_docs, database): for key, value in values: assert stored[key] == value assert value["a"] == num_vals - 2 + verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) @@ -1733,7 +1734,7 @@ def test_pipeline_w_read_time(query_docs, cleanup, database): assert key != new_ref.id -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_with_order_dot_key(client, cleanup, database): db = client collection_id = "collek" + UNIQUE_RESOURCE_ID @@ -1770,6 +1771,9 @@ def test_query_with_order_dot_key(client, cleanup, database): ) cursor_with_key_data = list(query4.stream()) assert found_data == [snap.to_dict() for snap in cursor_with_key_data] + verify_pipeline(query) + verify_pipeline(query2) + verify_pipeline(query3) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -1863,7 +1867,7 @@ def test_collection_group_queries(client, cleanup, database): verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_collection_group_queries_startat_endat(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1894,6 +1898,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database): snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"]) + verify_pipeline(query) query = ( client.collection_group(collection_group) @@ -1904,6 +1909,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database): snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2"]) + verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) @@ -2724,6 +2730,7 @@ def test_repro_429(client, cleanup, database): for snapshot in query2.stream(): print(f"id: {snapshot.id}") verify_pipeline(query) + verify_pipeline(query2) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -2883,7 +2890,7 @@ def test_count_query_stream_empty_aggregation(query, database): assert "Aggregations can not be empty" in exc_info.value.message -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_count_query_with_start_at(query, database): """ Ensure that count aggregation queries work when chained with a start_at @@ -2900,6 +2907,7 @@ def test_count_query_with_start_at(query, database): for result in count_query.stream(): for aggregation_result in result: assert aggregation_result.value == expected_count + verify_pipeline(count_query) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) From 326315ab40a6b86e2ee1df958f4fe887a2d8f67e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 6 Jan 2026 22:03:43 -0800 Subject: [PATCH 4/6] re-ordered sort --- google/cloud/firestore_v1/base_query.py | 5 +- tests/unit/v1/test_base_query.py | 96 +++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 2 deletions(-) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index cfdee2a8e..56ce4eb25 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -1193,8 +1193,6 @@ def _build_pipeline(self, source: "PipelineSource"): if self._limit_to_last: actual_orderings = _reverse_orderings(orderings) ppl = ppl.sort(*actual_orderings) - else: - ppl = ppl.sort(*orderings) # Apply cursor conditions. # Cursors are translated into filter conditions (e.g., field > value) @@ -1213,6 +1211,9 @@ def _build_pipeline(self, source: "PipelineSource"): ) ) + if not self._limit_to_last: + ppl = ppl.sort(*orderings) + if self._limit is not None: ppl = ppl.limit(self._limit) diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index b9ce1fde9..1ed88bf6c 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -2337,3 +2337,99 @@ def _make_snapshot(docref, values): from google.cloud.firestore_v1 import document return document.DocumentSnapshot(docref, values, True, None, None, None) + + +def test__build_pipeline_limit_to_last_ordering(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + + # Verify that for limit_to_last=True: + # 1. Sort (reversed) + # 2. Where (cursor condition) + + client = make_client() + # Query: Order by 'a' ASC, StartAt(10), LimitToLast(5) + query = ( + client.collection("my_col").order_by("a").start_at({"a": 10}).limit_to_last(5) + ) + + pipeline = query._build_pipeline(client.pipeline()) + + # Expected stages: + # 0. Collection + # 1. Exists (for 'a') + # 2. Sort (DESCENDING) -> This must come BEFORE the cursor filter + # 3. Where (a > 10 condition or similar) + # 4. Limit (5) + # 5. Sort (ASCENDING) + + assert len(pipeline.stages) >= 4 + + # Find indices + sort_reversed_idx = -1 + cursor_where_idx = -1 + + for i, stage in enumerate(pipeline.stages): + if isinstance(stage, stages.Sort): + # Check if it is the reversed sort (DESCENDING) + if ( + len(stage.orders) > 0 + and stage.orders[0].order_dir == expr.Ordering.Direction.DESCENDING + ): + if sort_reversed_idx == -1: + sort_reversed_idx = i + + if isinstance(stage, stages.Where): + # Check if this is the cursor condition. + # Cursor condition for start_at({"a": 10}) should be related to 'a' and 10. + # usually an OR or Comparison. + # The Exists filter is also a Where, but it's usually `exists(a)`. + + # Simple check: The condition is not just an 'exists' function call. + cond = stage.condition + if not (hasattr(cond, "name") and cond.name == "exists"): + # Assume this is the cursor filter + cursor_where_idx = i + + assert sort_reversed_idx != -1, "Reversed sort stage not found" + assert cursor_where_idx != -1, "Cursor filter stage not found" + + # Reversed Sort must happen BEFORE Cursor Filter + assert sort_reversed_idx < cursor_where_idx + + +def test__build_pipeline_normal_ordering(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + + # Verify that for limit_to_last=False (Normal): + # 1. Where (cursor condition) + # 2. Sort + + client = make_client() + # Query: Order by 'a' ASC, StartAt(10) + query = client.collection("my_col").order_by("a").start_at({"a": 10}) + + pipeline = query._build_pipeline(client.pipeline()) + + # Expected stages: + # 0. Collection + # 1. Exists (for 'a') + # 2. Where (cursor condition) + # 3. Sort (ASCENDING) + + sort_idx = -1 + cursor_where_idx = -1 + + for i, stage in enumerate(pipeline.stages): + if isinstance(stage, stages.Sort): + sort_idx = i + + if isinstance(stage, stages.Where): + cond = stage.condition + if not (hasattr(cond, "name") and cond.name == "exists"): + cursor_where_idx = i + + assert sort_idx != -1, "Sort stage not found" + assert cursor_where_idx != -1, "Cursor filter stage not found" + + # Cursor Filter must happen BEFORE Sort + assert cursor_where_idx < sort_idx From 1b518a885b453d0d32335912bd682241386001c3 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 6 Jan 2026 22:17:18 -0800 Subject: [PATCH 5/6] improved test --- tests/unit/v1/test_base_query.py | 42 ++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 1ed88bf6c..13756f972 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -2117,18 +2117,56 @@ def test__query_pipeline_order_sorts(): def test__query_pipeline_cursors(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + client = make_client() query_start = ( client.collection("my_col").order_by("field_a").start_at({"field_a": "value"}) ) pipeline = query_start._build_pipeline(client.pipeline()) - assert len(pipeline.stages) >= 2 + + # Expected stages: Collection, Exists, Where(Cursor), Sort + assert len(pipeline.stages) == 4 + assert pipeline.stages[0].path == "/my_col" + assert isinstance(pipeline.stages[1], stages.Where) # Exists + + cursor_stage = pipeline.stages[2] + assert isinstance(cursor_stage, stages.Where) + condition = cursor_stage.condition + # start_at({"field_a": "value"}) -> field_a >= "value" + # Implemented as Or(field_a > "value", field_a == "value") + assert isinstance(condition, expr.Or) + assert len(condition.params) == 2 + assert condition.params[0].name == "greater_than" + assert isinstance(condition.params[0].params[0], expr.Field) + assert condition.params[0].params[0].path == "field_a" + assert condition.params[1].name == "equal" + + assert isinstance(pipeline.stages[3], stages.Sort) query_end = ( client.collection("my_col").order_by("field_a").end_at({"field_a": "value"}) ) pipeline = query_end._build_pipeline(client.pipeline()) - assert len(pipeline.stages) >= 2 + + # Expected stages: Collection, Exists, Where(Cursor), Sort + assert len(pipeline.stages) == 4 + assert pipeline.stages[0].path == "/my_col" + assert isinstance(pipeline.stages[1], stages.Where) # Exists + + cursor_stage = pipeline.stages[2] + assert isinstance(cursor_stage, stages.Where) + condition = cursor_stage.condition + # end_at({"field_a": "value"}) -> field_a <= "value" + # Implemented as Or(field_a < "value", field_a == "value") + assert isinstance(condition, expr.Or) + assert len(condition.params) == 2 + assert condition.params[0].name == "less_than" + assert isinstance(condition.params[0].params[0], expr.Field) + assert condition.params[0].params[0].path == "field_a" + assert condition.params[1].name == "equal" + + assert isinstance(pipeline.stages[3], stages.Sort) def test__query_pipeline_limit_to_last(): From edc857bce9e688472655f22a17fa53bfe17225f9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 6 Jan 2026 22:49:31 -0800 Subject: [PATCH 6/6] match node implementation --- google/cloud/firestore_v1/base_query.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 56ce4eb25..30e1c8fb7 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -1417,28 +1417,32 @@ def _where_conditions_from_cursor( cursor_values, before = cursor size = len(cursor_values) + if is_start_cursor: + filter_func = pipeline_expressions.Expression.greater_than + else: + filter_func = pipeline_expressions.Expression.less_than + field = orderings[size - 1].expr value = pipeline_expressions.Constant(cursor_values[size - 1]) - if not is_start_cursor: - condition = field.less_than(value) - else: - condition = field.greater_than(value) + # Add condition for last bound + condition = filter_func(field, value) if (is_start_cursor and before) or (not is_start_cursor and not before): + # When the cursor bound is inclusive, then the last bound + # can be equal to the value, otherwise it's not equal condition = pipeline_expressions.Or(condition, field.equal(value)) + # Iterate backwards over the remaining bounds, adding a condition for each one for i in range(size - 2, -1, -1): field = orderings[i].expr value = pipeline_expressions.Constant(cursor_values[i]) - if not is_start_cursor: - current_filter = field.less_than(value) - else: - current_filter = field.greater_than(value) - + # For each field in the orderings, the condition is either + # a) lessThan|greaterThan the cursor value, + # b) or equal the cursor value and lessThan|greaterThan the cursor values for other fields condition = pipeline_expressions.Or( - current_filter, + filter_func(field, value), pipeline_expressions.And(field.equal(value), condition), )