Skip to content

Commit b711292

Browse files
committed
Add AggregateFilter, StringAgg.as_mql()
django/django@4b977a5
1 parent 598f0dc commit b711292

File tree

1 file changed

+40
-17
lines changed

1 file changed

+40
-17
lines changed

django_mongodb_backend/aggregates.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
from django.db.models.aggregates import Aggregate, Count, StdDev, Variance
1+
from django.core.exceptions import EmptyResultSet, FullResultSet
2+
from django.db import NotSupportedError
3+
from django.db.models.aggregates import (
4+
Aggregate,
5+
Count,
6+
StdDev,
7+
StringAgg,
8+
Variance,
9+
)
210
from django.db.models.expressions import Case, Value, When
311
from django.db.models.lookups import IsNull
412

@@ -9,15 +17,20 @@
917

1018

1119
def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False):
12-
if self.filter:
13-
node = self.copy()
14-
node.filter = None
15-
source_expressions = node.get_source_expressions()
16-
condition = When(self.filter, then=source_expressions[0])
17-
node.set_source_expressions([Case(condition), *source_expressions[1:]])
20+
if self.filter is not None:
21+
# Generate a CASE statement for this aggregate.
22+
try:
23+
lhs_mql = self.filter.as_mql(compiler, connection, as_expr=True)
24+
except NotSupportedError:
25+
source_expressions = self.get_source_expressions()
26+
condition = Case(When(self.filter.condition, then=source_expressions[0]))
27+
lhs_mql = condition.as_mql(compiler, connection, as_expr=True)
28+
except FullResultSet:
29+
lhs_mql = source_expressions[0].as_mql(compiler, connection, as_expr=True)
30+
except EmptyResultSet:
31+
lhs_mql = Value(None).as_mql(compiler, connection, as_expr=True)
1832
else:
19-
node = self
20-
lhs_mql = process_lhs(node, compiler, connection, as_expr=True)
33+
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
2134
if resolve_inner_expression:
2235
return lhs_mql
2336
operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
@@ -32,14 +45,19 @@ def count(self, compiler, connection, resolve_inner_expression=False):
3245
"""
3346
if not self.distinct or resolve_inner_expression:
3447
if self.filter:
35-
node = self.copy()
36-
node.filter = None
37-
source_expressions = node.get_source_expressions()
38-
condition = When(
39-
self.filter, then=Case(When(IsNull(source_expressions[0], False), then=Value(1)))
40-
)
41-
node.set_source_expressions([Case(condition), *source_expressions[1:]])
42-
inner_expression = process_lhs(node, compiler, connection, as_expr=True)
48+
try:
49+
inner_expression = self.filter.as_mql(compiler, connection, as_expr=True)
50+
except NotSupportedError:
51+
source_expressions = self.get_source_expressions()
52+
condition = When(
53+
self.filter.condition,
54+
then=Case(When(IsNull(source_expressions[0], False), then=Value(1))),
55+
)
56+
inner_expression = Case(condition).as_mql(compiler, connection, as_expr=True)
57+
except FullResultSet:
58+
inner_expression = {"$sum": 1}
59+
except EmptyResultSet:
60+
inner_expression = {"$sum": 0}
4361
else:
4462
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
4563
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
@@ -65,8 +83,13 @@ def stddev_variance(self, compiler, connection):
6583
return aggregate(self, compiler, connection, operator=operator)
6684

6785

86+
def string_agg(self, compiler, connection): # noqa: ARG001
87+
raise NotSupportedError("StringAgg is not supported.")
88+
89+
6890
def register_aggregates():
6991
Aggregate.as_mql_expr = aggregate
7092
Count.as_mql_expr = count
7193
StdDev.as_mql_expr = stddev_variance
94+
StringAgg.as_mql_expr = string_agg
7295
Variance.as_mql_expr = stddev_variance

0 commit comments

Comments
 (0)