1- from django .db .models .aggregates import Aggregate , Count , StdDev , Variance
1+ from django .core .exceptions import EmptyResultSet , FullResultSet
2+ from django .db import NotSupportedError
3+ from django .db .models .aggregates import (
4+ Aggregate ,
5+ Count ,
6+ StdDev ,
7+ StringAgg ,
8+ Variance ,
9+ )
210from django .db .models .expressions import Case , Value , When
311from django .db .models .lookups import IsNull
412
917
1018
1119def aggregate (self , compiler , connection , operator = None , resolve_inner_expression = False ):
12- if self .filter :
13- node = self .copy ()
14- node .filter = None
15- source_expressions = node .get_source_expressions ()
16- condition = When (self .filter , then = source_expressions [0 ])
17- node .set_source_expressions ([Case (condition ), * source_expressions [1 :]])
20+ if self .filter is not None :
21+ # Generate a CASE statement for this aggregate.
22+ try :
23+ lhs_mql = self .filter .as_mql (compiler , connection , as_expr = True )
24+ except NotSupportedError :
25+ source_expressions = self .get_source_expressions ()
26+ condition = Case (When (self .filter .condition , then = source_expressions [0 ]))
27+ lhs_mql = condition .as_mql (compiler , connection , as_expr = True )
28+ except FullResultSet :
29+ lhs_mql = source_expressions [0 ].as_mql (compiler , connection , as_expr = True )
30+ except EmptyResultSet :
31+ lhs_mql = Value (None ).as_mql (compiler , connection , as_expr = True )
1832 else :
19- node = self
20- lhs_mql = process_lhs (node , compiler , connection , as_expr = True )
33+ lhs_mql = process_lhs (self , compiler , connection , as_expr = True )
2134 if resolve_inner_expression :
2235 return lhs_mql
2336 operator = operator or MONGO_AGGREGATIONS .get (self .__class__ , self .function .lower ())
@@ -32,14 +45,19 @@ def count(self, compiler, connection, resolve_inner_expression=False):
3245 """
3346 if not self .distinct or resolve_inner_expression :
3447 if self .filter :
35- node = self .copy ()
36- node .filter = None
37- source_expressions = node .get_source_expressions ()
38- condition = When (
39- self .filter , then = Case (When (IsNull (source_expressions [0 ], False ), then = Value (1 )))
40- )
41- node .set_source_expressions ([Case (condition ), * source_expressions [1 :]])
42- inner_expression = process_lhs (node , compiler , connection , as_expr = True )
48+ try :
49+ inner_expression = self .filter .as_mql (compiler , connection , as_expr = True )
50+ except NotSupportedError :
51+ source_expressions = self .get_source_expressions ()
52+ condition = When (
53+ self .filter .condition ,
54+ then = Case (When (IsNull (source_expressions [0 ], False ), then = Value (1 ))),
55+ )
56+ inner_expression = Case (condition ).as_mql (compiler , connection , as_expr = True )
57+ except FullResultSet :
58+ inner_expression = {"$sum" : 1 }
59+ except EmptyResultSet :
60+ inner_expression = {"$sum" : 0 }
4361 else :
4462 lhs_mql = process_lhs (self , compiler , connection , as_expr = True )
4563 null_cond = {"$in" : [{"$type" : lhs_mql }, ["missing" , "null" ]]}
@@ -65,8 +83,13 @@ def stddev_variance(self, compiler, connection):
6583 return aggregate (self , compiler , connection , operator = operator )
6684
6785
86+ def string_agg (self , compiler , connection ): # noqa: ARG001
87+ raise NotSupportedError ("StringAgg is not supported." )
88+
89+
6890def register_aggregates ():
6991 Aggregate .as_mql_expr = aggregate
7092 Count .as_mql_expr = count
7193 StdDev .as_mql_expr = stddev_variance
94+ StringAgg .as_mql_expr = string_agg
7295 Variance .as_mql_expr = stddev_variance
0 commit comments