Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions django_mongodb_backend/aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def count(self, compiler, connection, resolve_inner_expression=False):
# If distinct=True or resolve_inner_expression=False, sum the size of the
# set.
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
lhs_mql = {"$ifNull": [lhs_mql, []]}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment about what fails without this?

# None shouldn't be counted, so subtract 1 if it's present.
exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}}
return {"$add": [{"$size": lhs_mql}, exits_null]}
Expand Down
22 changes: 7 additions & 15 deletions django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def __init__(self, *args, **kwargs):
self.subqueries = []
# Atlas search stage.
self.search_pipeline = []
# The aggregation has no group-by fields and needs wrapping.
self.wrap_for_global_aggregation = False
# HAVING stage match (MongoDB equivalent)
self.having_match_mql = None

def _get_group_alias_column(self, expr, annotation_group_idx):
"""Generate a dummy field for use in the ids fields in $group."""
Expand Down Expand Up @@ -234,21 +238,9 @@ def _build_aggregation_pipeline(self, ids, group):
"""Build the aggregation pipeline for grouping."""
pipeline = []
if not ids:
group["_id"] = None
pipeline.append({"$facet": {"group": [{"$group": group}]}})
pipeline.append(
{
"$addFields": {
key: {
"$getField": {
"input": {"$arrayElemAt": ["$group", 0]},
"field": key,
}
}
for key in group
}
}
)
pipeline.append({"$group": {"_id": None, **group}})
# If there are no ids and no having clause, apply a global aggregation
self.wrap_for_global_aggregation = not bool(self.having)
else:
group["_id"] = ids
pipeline.append({"$group": group})
Expand Down
34 changes: 8 additions & 26 deletions django_mongodb_backend/fields/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,37 +310,19 @@ class ArrayOverlap(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):

def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr):
return [
{"$project": {"tmp_name": expr.as_mql(compiler, connection, as_expr=True)}},
{
"$facet": {
"group": [
{"$project": {"tmp_name": expr.as_mql(compiler, connection, as_expr=True)}},
{
"$unwind": "$tmp_name",
},
{
"$group": {
"_id": None,
"tmp_name": {"$addToSet": "$tmp_name"},
}
},
]
}
"$unwind": "$tmp_name",
},
{
"$project": {
field_name: {
"$ifNull": [
{
"$getField": {
"input": {"$arrayElemAt": ["$group", 0]},
"field": "tmp_name",
}
},
[],
]
}
"$group": {
"_id": None,
"tmp_name": {"$addToSet": "$tmp_name"},
}
},
{"$unionWith": {"pipeline": [{"$documents": [{"tmp_name": []}]}]}},
{"$limit": 1},
{"$project": {field_name: "$tmp_name"}},
]

def as_mql_expr(self, compiler, connection):
Expand Down
49 changes: 16 additions & 33 deletions django_mongodb_backend/fields/embedded_model_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,44 +150,27 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr)
# structure of EmbeddedModelArrayField on the RHS behaves similar to
# ArrayField.
return [
{"$project": {"tmp_name": expr.as_mql(compiler, connection, as_expr=True)}},
# To concatenate all the values from the RHS subquery,
# use an $unwind followed by a $group.
{
"$facet": {
"gathered_data": [
{"$project": {"tmp_name": expr.as_mql(compiler, connection, as_expr=True)}},
# To concatenate all the values from the RHS subquery,
# use an $unwind followed by a $group.
{
"$unwind": "$tmp_name",
},
# The $group stage collects values into an array using
# $addToSet. The use of {_id: null} results in a
# single grouped array. However, because arrays from
# multiple documents are aggregated, the result is a
# list of lists.
{
"$group": {
"_id": None,
"tmp_name": {"$addToSet": "$tmp_name"},
}
},
]
}
"$unwind": "$tmp_name",
},
# The $group stage collects values into an array using
# $addToSet. The use of {_id: null} results in a
# single grouped array. However, because arrays from
# multiple documents are aggregated, the result is a
# list of lists.
{
"$project": {
field_name: {
"$ifNull": [
{
"$getField": {
"input": {"$arrayElemAt": ["$gathered_data", 0]},
"field": "tmp_name",
}
},
[],
]
}
"$group": {
"_id": None,
"tmp_name": {"$addToSet": "$tmp_name"},
}
},
# Add a dummy document in case of empty result.
{"$unionWith": {"pipeline": [{"$documents": [{"tmp_name": []}]}]}},
{"$limit": 1},
{"$project": {field_name: "$tmp_name"}},
]


Expand Down
33 changes: 7 additions & 26 deletions django_mongodb_backend/lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,34 +56,15 @@ def inner(self, compiler, connection):
def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): # noqa: ARG001
return [
{
"$facet": {
"group": [
{
"$group": {
"_id": None,
"tmp_name": {
"$addToSet": expr.as_mql(compiler, connection, as_expr=True)
},
}
}
]
}
},
{
"$project": {
field_name: {
"$ifNull": [
{
"$getField": {
"input": {"$arrayElemAt": ["$group", 0]},
"field": "tmp_name",
}
},
[],
]
}
"$group": {
"_id": None,
# use a temporal name in order to support field_name="_id"
"tmp_name": {"$addToSet": expr.as_mql(compiler, connection, as_expr=True)},
}
},
{"$unionWith": {"pipeline": [{"$documents": [{"tmp_name": []}]}]}},
{"$limit": 1},
{"$project": {field_name: "$tmp_name"}},
]


Expand Down
4 changes: 4 additions & 0 deletions django_mongodb_backend/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self, compiler):
# $lookup stage that encapsulates the pipeline for performing a nested
# subquery.
self.subquery_lookup = None
self.wrap_for_global_aggregation = compiler.wrap_for_global_aggregation

def __repr__(self):
return f"<MongoQuery: {self.match_mql!r} ORDER {self.ordering!r}>"
Expand Down Expand Up @@ -91,6 +92,9 @@ def get_pipeline(self):
pipeline.append({"$match": self.match_mql})
if self.aggregation_pipeline:
pipeline.extend(self.aggregation_pipeline)
if self.wrap_for_global_aggregation:
# Add an empty extra document to handle default values on empty results
pipeline.append({"$unionWith": {"pipeline": [{"$documents": [{}]}]}})
if self.project_fields:
pipeline.append({"$project": self.project_fields})
if self.combinator_pipeline:
Expand Down
26 changes: 4 additions & 22 deletions tests/lookup_/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,28 +137,10 @@ def test_subquery_filter_constant(self):
"let": {},
"pipeline": [
{"$match": {"num": {"$gt": 2}}},
{
"$facet": {
"group": [
{"$group": {"_id": None, "tmp_name": {"$addToSet": "$num"}}}
]
}
},
{
"$project": {
"num": {
"$ifNull": [
{
"$getField": {
"input": {"$arrayElemAt": ["$group", 0]},
"field": "tmp_name",
}
},
[],
]
}
}
},
{"$group": {"_id": None, "tmp_name": {"$addToSet": "$num"}}},
{"$unionWith": {"pipeline": [{"$documents": [{"tmp_name": []}]}]}},
{"$limit": 1},
{"$project": {"num": "$tmp_name"}},
],
}
},
Expand Down
15 changes: 15 additions & 0 deletions tests/model_fields_/test_arrayfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,21 @@ def test_overlap_values(self):
self.objs[:3],
)

def test_overlap_empty_values(self):
qs = NullableIntegerArrayModel.objects.filter(order__lt=-30)
self.assertCountEqual(
NullableIntegerArrayModel.objects.filter(
field__overlap=qs.values_list("field"),
),
[],
)
self.assertCountEqual(
NullableIntegerArrayModel.objects.filter(
field__overlap=qs.values("field"),
),
[],
)

def test_index(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3]
Expand Down
5 changes: 5 additions & 0 deletions tests/model_fields_/test_embedded_model_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,11 @@ def test_subquery_in_lookup(self):
result = Exhibit.objects.filter(sections__number__in=subquery)
self.assertCountEqual(result, [self.wonders, self.new_discoveries, self.egypt])

def test_subquery_empty_in_lookup(self):
subquery = Audit.objects.filter(section_number=10).values_list("section_number", flat=True)
result = Exhibit.objects.filter(sections__number__in=subquery)
self.assertCountEqual(result, [])

def test_array_as_rhs(self):
result = Exhibit.objects.filter(main_section__number__in=models.F("sections__number"))
self.assertCountEqual(result, [self.new_discoveries])
Expand Down