diff --git a/django_mongodb_backend/aggregates.py b/django_mongodb_backend/aggregates.py index fb41ce4fc..ad7c5f661 100644 --- a/django_mongodb_backend/aggregates.py +++ b/django_mongodb_backend/aggregates.py @@ -52,6 +52,8 @@ def count(self, compiler, connection, resolve_inner_expression=False): # If distinct=True or resolve_inner_expression=False, sum the size of the # set. lhs_mql = process_lhs(self, compiler, connection, as_expr=True) + # Wrap null results as an empty array. + lhs_mql = {"$ifNull": [lhs_mql, []]} # None shouldn't be counted, so subtract 1 if it's present. exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}} return {"$add": [{"$size": lhs_mql}, exits_null]} diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 2145fbf68..409032675 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -38,6 +38,10 @@ def __init__(self, *args, **kwargs): self.subqueries = [] # Atlas search stage. self.search_pipeline = [] + # The aggregation has no group-by fields and needs wrapping. + self.wrap_for_global_aggregation = False + # HAVING stage match (MongoDB equivalent) + self.having_match_mql = None def _get_group_alias_column(self, expr, annotation_group_idx): """Generate a dummy field for use in the ids fields in $group.""" @@ -234,21 +238,9 @@ def _build_aggregation_pipeline(self, ids, group): """Build the aggregation pipeline for grouping.""" pipeline = [] if not ids: - group["_id"] = None - pipeline.append({"$facet": {"group": [{"$group": group}]}}) - pipeline.append( - { - "$addFields": { - key: { - "$getField": { - "input": {"$arrayElemAt": ["$group", 0]}, - "field": key, - } - } - for key in group - } - } - ) + pipeline.append({"$group": {"_id": None, **group}}) + # If there are no ids and no having clause, apply a global aggregation + self.wrap_for_global_aggregation = not bool(self.having) else: group["_id"] = ids pipeline.append({"$group": group}) diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py index 84164c4d1..7645119c5 100644 --- a/django_mongodb_backend/fields/array.py +++ b/django_mongodb_backend/fields/array.py @@ -310,37 +310,24 @@ class ArrayOverlap(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): return [ + {"$project": {"tmp_name": expr.as_mql(compiler, connection, as_expr=True)}}, { - "$facet": { - "group": [ - {"$project": {"tmp_name": expr.as_mql(compiler, connection, as_expr=True)}}, - { - "$unwind": "$tmp_name", - }, - { - "$group": { - "_id": None, - "tmp_name": {"$addToSet": "$tmp_name"}, - } - }, - ] - } + "$unwind": "$tmp_name", }, { - "$project": { - field_name: { - "$ifNull": [ - { - "$getField": { - "input": {"$arrayElemAt": ["$group", 0]}, - "field": "tmp_name", - } - }, - [], - ] - } + "$group": { + "_id": None, + "tmp_name": {"$addToSet": "$tmp_name"}, } }, + # Workaround for https://jira.mongodb.org/browse/SERVER-114196: + # $$NOW becomes unavailable after $unionWith, so it must be stored + # beforehand to ensure it remains accessible later in the pipeline. + {"$addFields": {"__now": "$$NOW"}}, + # Add an empty extra document to handle default values on empty results. + {"$unionWith": {"pipeline": [{"$documents": [{"tmp_name": []}]}]}}, + {"$limit": 1}, + {"$project": {field_name: "$tmp_name"}}, ] def as_mql_expr(self, compiler, connection): diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py index 501b78428..4a06bcaf0 100644 --- a/django_mongodb_backend/fields/embedded_model_array.py +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -150,44 +150,31 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) # structure of EmbeddedModelArrayField on the RHS behaves similar to # ArrayField. return [ + {"$project": {"tmp_name": expr.as_mql(compiler, connection, as_expr=True)}}, + # To concatenate all the values from the RHS subquery, + # use an $unwind followed by a $group. { - "$facet": { - "gathered_data": [ - {"$project": {"tmp_name": expr.as_mql(compiler, connection, as_expr=True)}}, - # To concatenate all the values from the RHS subquery, - # use an $unwind followed by a $group. - { - "$unwind": "$tmp_name", - }, - # The $group stage collects values into an array using - # $addToSet. The use of {_id: null} results in a - # single grouped array. However, because arrays from - # multiple documents are aggregated, the result is a - # list of lists. - { - "$group": { - "_id": None, - "tmp_name": {"$addToSet": "$tmp_name"}, - } - }, - ] - } + "$unwind": "$tmp_name", }, + # The $group stage collects values into an array using + # $addToSet. The use of {_id: null} results in a + # single grouped array. However, because arrays from + # multiple documents are aggregated, the result is a + # list of lists. { - "$project": { - field_name: { - "$ifNull": [ - { - "$getField": { - "input": {"$arrayElemAt": ["$gathered_data", 0]}, - "field": "tmp_name", - } - }, - [], - ] - } + "$group": { + "_id": None, + "tmp_name": {"$addToSet": "$tmp_name"}, } }, + # Workaround for https://jira.mongodb.org/browse/SERVER-114196: + # $$NOW becomes unavailable after $unionWith, so it must be stored + # beforehand to ensure it remains accessible later in the pipeline. + {"$addFields": {"__now": "$$NOW"}}, + # Add a dummy document in case of empty result. + {"$unionWith": {"pipeline": [{"$documents": [{"tmp_name": []}]}]}}, + {"$limit": 1}, + {"$project": {field_name: "$tmp_name"}}, ] diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index 6b59fb961..7df4eb547 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -56,34 +56,20 @@ def inner(self, compiler, connection): def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): # noqa: ARG001 return [ { - "$facet": { - "group": [ - { - "$group": { - "_id": None, - "tmp_name": { - "$addToSet": expr.as_mql(compiler, connection, as_expr=True) - }, - } - } - ] - } - }, - { - "$project": { - field_name: { - "$ifNull": [ - { - "$getField": { - "input": {"$arrayElemAt": ["$group", 0]}, - "field": "tmp_name", - } - }, - [], - ] - } + "$group": { + "_id": None, + # use a temporal name in order to support field_name="_id" + "tmp_name": {"$addToSet": expr.as_mql(compiler, connection, as_expr=True)}, } }, + # Workaround for https://jira.mongodb.org/browse/SERVER-114196: + # $$NOW becomes unavailable after $unionWith, so it must be stored + # beforehand to ensure it remains accessible later in the pipeline. + {"$addFields": {"__now": "$$NOW"}}, + # Add an empty extra document to handle default values on empty results. + {"$unionWith": {"pipeline": [{"$documents": [{"tmp_name": []}]}]}}, + {"$limit": 1}, + {"$project": {field_name: "$tmp_name"}}, ] diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index 5b4f0ec51..4692b082d 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -56,6 +56,7 @@ def __init__(self, compiler): # $lookup stage that encapsulates the pipeline for performing a nested # subquery. self.subquery_lookup = None + self.wrap_for_global_aggregation = compiler.wrap_for_global_aggregation def __repr__(self): return f"" @@ -91,6 +92,17 @@ def get_pipeline(self): pipeline.append({"$match": self.match_mql}) if self.aggregation_pipeline: pipeline.extend(self.aggregation_pipeline) + if self.wrap_for_global_aggregation: + pipeline.extend( + [ + # Workaround for https://jira.mongodb.org/browse/SERVER-114196: + # $$NOW becomes unavailable after $unionWith, so it must be stored + # beforehand to ensure it remains accessible later in the pipeline. + {"$addFields": {"__now": "$$NOW"}}, + # Add an empty extra document to handle default values on empty results. + {"$unionWith": {"pipeline": [{"$documents": [{}]}]}}, + ] + ) if self.project_fields: pipeline.append({"$project": self.project_fields}) if self.combinator_pipeline: diff --git a/tests/lookup_/tests.py b/tests/lookup_/tests.py index b6ac8a322..b8f53dbb1 100644 --- a/tests/lookup_/tests.py +++ b/tests/lookup_/tests.py @@ -137,28 +137,11 @@ def test_subquery_filter_constant(self): "let": {}, "pipeline": [ {"$match": {"num": {"$gt": 2}}}, - { - "$facet": { - "group": [ - {"$group": {"_id": None, "tmp_name": {"$addToSet": "$num"}}} - ] - } - }, - { - "$project": { - "num": { - "$ifNull": [ - { - "$getField": { - "input": {"$arrayElemAt": ["$group", 0]}, - "field": "tmp_name", - } - }, - [], - ] - } - } - }, + {"$group": {"_id": None, "tmp_name": {"$addToSet": "$num"}}}, + {"$addFields": {"__now": "$$NOW"}}, + {"$unionWith": {"pipeline": [{"$documents": [{"tmp_name": []}]}]}}, + {"$limit": 1}, + {"$project": {"num": "$tmp_name"}}, ], } }, diff --git a/tests/model_fields_/test_arrayfield.py b/tests/model_fields_/test_arrayfield.py index e334b21dc..8d40214da 100644 --- a/tests/model_fields_/test_arrayfield.py +++ b/tests/model_fields_/test_arrayfield.py @@ -634,6 +634,21 @@ def test_overlap_values(self): self.objs[:3], ) + def test_overlap_empty_values(self): + qs = NullableIntegerArrayModel.objects.filter(order__lt=-30) + self.assertCountEqual( + NullableIntegerArrayModel.objects.filter( + field__overlap=qs.values_list("field"), + ), + [], + ) + self.assertCountEqual( + NullableIntegerArrayModel.objects.filter( + field__overlap=qs.values("field"), + ), + [], + ) + def test_index(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3] diff --git a/tests/model_fields_/test_embedded_model_array.py b/tests/model_fields_/test_embedded_model_array.py index 8453f6379..499205e9b 100644 --- a/tests/model_fields_/test_embedded_model_array.py +++ b/tests/model_fields_/test_embedded_model_array.py @@ -520,6 +520,11 @@ def test_subquery_in_lookup(self): result = Exhibit.objects.filter(sections__number__in=subquery) self.assertCountEqual(result, [self.wonders, self.new_discoveries, self.egypt]) + def test_subquery_empty_in_lookup(self): + subquery = Audit.objects.filter(section_number=10).values_list("section_number", flat=True) + result = Exhibit.objects.filter(sections__number__in=subquery) + self.assertCountEqual(result, []) + def test_array_as_rhs(self): result = Exhibit.objects.filter(main_section__number__in=models.F("sections__number")) self.assertCountEqual(result, [self.new_discoveries])