1- from typing import List , Union , cast
1+ from typing import Dict , List , Optional , Union , cast
22
33from forestadmin .agent_toolkit .utils .context import User
44from forestadmin .datasource_toolkit .decorators .collection_decorator import CollectionDecorator
55from forestadmin .datasource_toolkit .interfaces .fields import ManyToOne , is_many_to_one
6+ from forestadmin .datasource_toolkit .interfaces .query .aggregation import AggregateResult , Aggregation
67from forestadmin .datasource_toolkit .interfaces .query .condition_tree .nodes .leaf import ConditionTreeLeaf
78from forestadmin .datasource_toolkit .interfaces .query .filter .paginated import PaginatedFilter
89from forestadmin .datasource_toolkit .interfaces .query .filter .unpaginated import Filter
@@ -17,7 +18,7 @@ async def list(self, caller: User, filter_: PaginatedFilter, projection: Project
1718 refined_filter = cast (PaginatedFilter , await self ._refine_filter (caller , filter_ ))
1819 ret = await self .child_collection .list (caller , refined_filter , simplified_projection )
1920
20- return self ._apply_joins_on_records (projection , simplified_projection , ret )
21+ return self ._apply_joins_on_simplified_records (projection , simplified_projection , ret )
2122
2223 async def _refine_filter (
2324 self , caller : User , _filter : Union [Filter , PaginatedFilter , None ]
@@ -28,18 +29,39 @@ async def _refine_filter(
2829 _filter .condition_tree = _filter .condition_tree .replace (
2930 lambda leaf : (
3031 ConditionTreeLeaf (
31- self ._get_fk_field_for_projection (leaf .field ),
32+ self ._get_fk_field_for_many_to_one_projection (leaf .field ),
3233 leaf .operator ,
3334 leaf .value ,
3435 )
35- if self ._is_useless_join (leaf .field .split (":" )[0 ], _filter .condition_tree .projection )
36+ if self ._is_useless_join_for_projection (leaf .field .split (":" )[0 ], _filter .condition_tree .projection )
3637 else leaf
3738 )
3839 )
3940
4041 return _filter
4142
42- def _is_useless_join (self , relation : str , projection : Projection ) -> bool :
43+ async def aggregate (
44+ self , caller : User , filter_ : Union [Filter , None ], aggregation : Aggregation , limit : Optional [int ] = None
45+ ) -> List [AggregateResult ]:
46+ replaced = {} # new_name -> old_name; for a simpler reconciliation
47+
48+ def replacer (field_name : str ) -> str :
49+ if self ._is_useless_join_for_projection (field_name .split (":" )[0 ], aggregation .projection ):
50+ new_field_name = self ._get_fk_field_for_many_to_one_projection (field_name )
51+ replaced [new_field_name ] = field_name
52+ return new_field_name
53+ return field_name
54+
55+ new_aggregation = aggregation .replace_fields (replacer )
56+
57+ aggregate_results = await self .child_collection .aggregate (
58+ caller , cast (Filter , await self ._refine_filter (caller , filter_ )), new_aggregation , limit
59+ )
60+ if aggregation == new_aggregation :
61+ return aggregate_results
62+ return self ._replace_fields_in_aggregate_group (aggregate_results , replaced )
63+
64+ def _is_useless_join_for_projection (self , relation : str , projection : Projection ) -> bool :
4365 relation_schema = self .schema ["fields" ][relation ]
4466 sub_projections = projection .relations [relation ]
4567
@@ -49,7 +71,7 @@ def _is_useless_join(self, relation: str, projection: Projection) -> bool:
4971 and sub_projections [0 ] == relation_schema ["foreign_key_target" ]
5072 )
5173
52- def _get_fk_field_for_projection (self , projection : str ) -> str :
74+ def _get_fk_field_for_many_to_one_projection (self , projection : str ) -> str :
5375 relation_name = projection .split (":" )[0 ]
5476 relation_schema = cast (ManyToOne , self .schema ["fields" ][relation_name ])
5577
@@ -58,18 +80,18 @@ def _get_fk_field_for_projection(self, projection: str) -> str:
5880 def _get_projection_without_useless_joins (self , projection : Projection ) -> Projection :
5981 returned_projection = Projection (* projection )
6082 for relation , relation_projections in projection .relations .items ():
61- if self ._is_useless_join (relation , projection ):
83+ if self ._is_useless_join_for_projection (relation , projection ):
6284 # remove foreign key target from projection
6385 returned_projection .remove (f"{ relation } :{ relation_projections [0 ]} " )
6486
6587 # add foreign keys to projection
66- fk_field = self ._get_fk_field_for_projection ( relation )
88+ fk_field = self ._get_fk_field_for_many_to_one_projection ( f" { relation } : { relation_projections [ 0 ] } " )
6789 if fk_field not in returned_projection :
6890 returned_projection .append (fk_field )
6991
7092 return returned_projection
7193
72- def _apply_joins_on_records (
94+ def _apply_joins_on_simplified_records (
7395 self , initial_projection : Projection , requested_projection : Projection , records : List [RecordsDataAlias ]
7496 ) -> List [RecordsDataAlias ]:
7597 if requested_projection == initial_projection :
@@ -84,11 +106,27 @@ def _apply_joins_on_records(
84106 relation_schema = self .schema ["fields" ][relation ]
85107
86108 if is_many_to_one (relation_schema ):
87- fk_value = record [self ._get_fk_field_for_projection (f"{ relation } :{ relation_projections [0 ]} " )]
109+ fk_value = record [
110+ self ._get_fk_field_for_many_to_one_projection (f"{ relation } :{ relation_projections [0 ]} " )
111+ ]
88112 record [relation ] = {relation_projections [0 ]: fk_value } if fk_value else None
89113
90114 # remove foreign keys
91115 for projection in projections_to_rm :
92116 del record [projection ]
93117
94118 return records
119+
120+ def _replace_fields_in_aggregate_group (
121+ self , aggregate_results : List [AggregateResult ], field_to_replace : Dict [str , str ]
122+ ) -> List [AggregateResult ]:
123+ for aggregate_result in aggregate_results :
124+ group = {}
125+ for field , value in aggregate_result ["group" ].items ():
126+ if field in field_to_replace :
127+ group [field_to_replace [field ]] = value
128+ else :
129+ group [field ] = value
130+ aggregate_result ["group" ] = group
131+
132+ return aggregate_results
0 commit comments