1- from django .db .models .aggregates import Aggregate , Count , StdDev , Variance
1+ from django .db import NotSupportedError
2+ from django .db .models .aggregates import (
3+ Aggregate ,
4+ Count ,
5+ StdDev ,
6+ StringAgg ,
7+ Variance ,
8+ )
29from django .db .models .expressions import Case , Value , When
310from django .db .models .lookups import IsNull
411from django .db .models .sql .where import WhereNode
1118
1219def aggregate (self , compiler , connection , operator = None , resolve_inner_expression = False ):
1320 agg_expression , * _ = self .get_source_expressions ()
14- if self .filter :
15- agg_expression = Case (
16- When (self .filter , then = agg_expression ),
17- # Skip rows that don't meet the criteria.
18- default = Remove (),
19- )
20- lhs_mql = agg_expression .as_mql (compiler , connection , as_expr = True )
21+ lhs_mql = None
22+ if self .filter is not None :
23+ try :
24+ lhs_mql = self .filter .as_mql (compiler , connection , as_expr = True )
25+ except NotSupportedError :
26+ # Generate a CASE statement for this AggregateFilter.
27+ agg_expression = Case (
28+ When (self .filter .condition , then = agg_expression ),
29+ # Skip rows that don't meet the criteria.
30+ default = Remove (),
31+ )
32+ if lhs_mql is None :
33+ lhs_mql = agg_expression .as_mql (compiler , connection , as_expr = True )
2134 if resolve_inner_expression :
2235 return lhs_mql
2336 operator = operator or MONGO_AGGREGATIONS .get (self .__class__ , self .function .lower ())
@@ -32,18 +45,30 @@ def count(self, compiler, connection, resolve_inner_expression=False):
3245 """
3346 agg_expression , * _ = self .get_source_expressions ()
3447 if not self .distinct or resolve_inner_expression :
48+ lhs_mql = None
3549 conditions = [IsNull (agg_expression , False )]
3650 if self .filter :
37- conditions .append (self .filter )
38- inner_expression = Case (
39- When (WhereNode (conditions ), then = agg_expression if self .distinct else Value (1 )),
40- # Skip rows that don't meet the criteria.
41- default = Remove (),
42- )
43- inner_expression = inner_expression .as_mql (compiler , connection , as_expr = True )
51+ try :
52+ lhs_mql = self .filter .as_mql (compiler , connection , as_expr = True )
53+ except NotSupportedError :
54+ # Generate a CASE statement for this AggregateFilter.
55+ conditions .append (self .filter .condition )
56+ condition = When (
57+ WhereNode (conditions ),
58+ then = agg_expression if self .distinct else Value (1 ),
59+ )
60+ inner_expression = Case (condition , default = Remove ())
61+ else :
62+ inner_expression = Case (
63+ When (WhereNode (conditions ), then = agg_expression if self .distinct else Value (1 )),
64+ # Skip rows that don't meet the criteria.
65+ default = Remove (),
66+ )
67+ if lhs_mql is None :
68+ lhs_mql = inner_expression .as_mql (compiler , connection , as_expr = True )
4469 if resolve_inner_expression :
45- return inner_expression
46- return {"$sum" : inner_expression }
70+ return lhs_mql
71+ return {"$sum" : lhs_mql }
4772 # If distinct=True or resolve_inner_expression=False, sum the size of the
4873 # set.
4974 return {"$size" : agg_expression .as_mql (compiler , connection , as_expr = True )}
@@ -57,8 +82,13 @@ def stddev_variance(self, compiler, connection):
5782 return aggregate (self , compiler , connection , operator = operator )
5883
5984
85+ def string_agg (self , compiler , connection ): # noqa: ARG001
86+ raise NotSupportedError ("StringAgg is not supported." )
87+
88+
6089def register_aggregates ():
6190 Aggregate .as_mql_expr = aggregate
6291 Count .as_mql_expr = count
6392 StdDev .as_mql_expr = stddev_variance
93+ StringAgg .as_mql_expr = string_agg
6494 Variance .as_mql_expr = stddev_variance
0 commit comments