@@ -590,49 +590,60 @@ def _make_unwind_project_stage(only: list):
590
590
}
591
591
592
592
@classmethod
593
- def _stat_with_unwind (
593
+ def _stat_with_pipeline (
594
594
cls ,
595
- unwind : list ,
595
+ lookup : list = None ,
596
+ unwind : dict = None ,
597
+ add_fields : dict = None ,
596
598
only : list = None ,
597
599
filter : list = None ,
598
600
filter_or : list = None ,
599
601
sort : list = None ,
600
602
page : dict = None ,
601
603
target : str = None ,
602
604
):
603
- if only is None :
604
- raise ERROR_DB_QUERY (reason = "unwind option requires only option." )
605
+ if unwind :
606
+ if only is None :
607
+ raise ERROR_DB_QUERY (reason = "unwind option requires only option." )
605
608
606
- if not isinstance (unwind , dict ):
607
- raise ERROR_DB_QUERY (reason = "unwind option should be dict type." )
609
+ if not isinstance (unwind , dict ):
610
+ raise ERROR_DB_QUERY (reason = "unwind option should be dict type." )
608
611
609
- if "path" not in unwind :
610
- raise ERROR_DB_QUERY (reason = "unwind option should have path key." )
612
+ if "path" not in unwind :
613
+ raise ERROR_DB_QUERY (reason = "unwind option should have path key." )
611
614
612
- unwind_path = unwind ["path" ]
613
- aggregate = [{"unwind" : unwind }]
615
+ aggregate = []
614
616
615
- # Add project stage
616
- project_fields = []
617
- for key in only :
618
- project_fields .append (
617
+ if lookup :
618
+ for lu in lookup :
619
+ aggregate .append ({"lookup" : lu })
620
+
621
+ if unwind :
622
+ aggregate .append ({"unwind" : unwind })
623
+
624
+ if add_fields :
625
+ aggregate .append ({"add_fields" : add_fields })
626
+
627
+ if only :
628
+ project_fields = []
629
+ for key in only :
630
+ project_fields .append (
631
+ {
632
+ "key" : key ,
633
+ "name" : key ,
634
+ }
635
+ )
636
+
637
+ aggregate .append (
619
638
{
620
- "key" : key ,
621
- "name" : key ,
639
+ "project" : {
640
+ "exclude_keys" : True ,
641
+ "only_keys" : True ,
642
+ "fields" : project_fields ,
643
+ }
622
644
}
623
645
)
624
646
625
- aggregate .append (
626
- {
627
- "project" : {
628
- "exclude_keys" : True ,
629
- "only_keys" : True ,
630
- "fields" : project_fields ,
631
- }
632
- }
633
- )
634
-
635
- # Add sort stage
636
647
if sort :
637
648
aggregate .append ({"sort" : sort })
638
649
@@ -641,21 +652,23 @@ def _stat_with_unwind(
641
652
filter = filter ,
642
653
filter_or = filter_or ,
643
654
page = page ,
644
- tageet = target ,
655
+ target = target ,
645
656
allow_disk_use = True ,
646
657
)
647
658
648
659
try :
649
660
vos = []
650
661
total_count = response .get ("total_count" , 0 )
651
662
for result in response .get ("results" , []):
652
- unwind_data = utils .get_dict_value (result , unwind_path )
653
- result = utils .change_dict_value (result , unwind_path , [unwind_data ])
663
+ if unwind :
664
+ unwind_path = unwind ["path" ]
665
+ unwind_data = utils .get_dict_value (result , unwind_path )
666
+ result = utils .change_dict_value (result , unwind_path , [unwind_data ])
654
667
655
668
vo = cls (** result )
656
669
vos .append (vo )
657
670
except Exception as e :
658
- raise ERROR_DB_QUERY (reason = f"Failed to convert unwind result: { e } " )
671
+ raise ERROR_DB_QUERY (reason = f"Failed to convert pipeline result: { e } " )
659
672
660
673
return vos , total_count
661
674
@@ -672,7 +685,9 @@ def query(
672
685
minimal = False ,
673
686
include_count = True ,
674
687
count_only = False ,
688
+ lookup = None ,
675
689
unwind = None ,
690
+ add_fields = None ,
676
691
reference_filter = None ,
677
692
target = None ,
678
693
hint = None ,
@@ -683,9 +698,17 @@ def query(
683
698
sort = sort or []
684
699
page = page or {}
685
700
686
- if unwind :
687
- return cls ._stat_with_unwind (
688
- unwind , only , filter , filter_or , sort , page , target
701
+ if unwind or lookup or add_fields :
702
+ return cls ._stat_with_pipeline (
703
+ lookup = lookup ,
704
+ unwind = unwind ,
705
+ add_fields = add_fields ,
706
+ only = only ,
707
+ filter = filter ,
708
+ filter_or = filter_or ,
709
+ sort = sort ,
710
+ page = page ,
711
+ target = target ,
689
712
)
690
713
691
714
else :
@@ -1075,6 +1098,44 @@ def _make_match_rule(cls, options):
1075
1098
1076
1099
return {"$match" : match_options }
1077
1100
1101
+ @classmethod
1102
+ def _make_lookup_rule (cls , options ):
1103
+ return {"$lookup" : options }
1104
+
1105
+ @classmethod
1106
+ def _make_add_fields_rule (cls , options ):
1107
+ add_fields_options = {}
1108
+
1109
+ for field , conditional in options .items ():
1110
+ add_fields_options .update (
1111
+ {field : cls ._process_conditional_expression (conditional )}
1112
+ )
1113
+
1114
+ return {"$addFields" : add_fields_options }
1115
+
1116
+ @classmethod
1117
+ def _process_conditional_expression (cls , expression ):
1118
+ if isinstance (expression , dict ):
1119
+ if_expression = expression ["if" ]
1120
+
1121
+ if isinstance (if_expression , dict ):
1122
+ replaced = {}
1123
+ for k , v in if_expression .items ():
1124
+ new_k = k .replace ("__" , "$" )
1125
+ replaced [new_k ] = v
1126
+
1127
+ if_expression = replaced
1128
+
1129
+ return {
1130
+ "$cond" : {
1131
+ "if" : if_expression ,
1132
+ "then" : cls ._process_conditional_expression (expression ["then" ]),
1133
+ "else" : cls ._process_conditional_expression (expression ["else" ]),
1134
+ }
1135
+ }
1136
+
1137
+ return expression
1138
+
1078
1139
@classmethod
1079
1140
def _make_aggregate_rules (cls , aggregate ):
1080
1141
_aggregate_rules = []
@@ -1116,6 +1177,12 @@ def _make_aggregate_rules(cls, aggregate):
1116
1177
elif "match" in stage :
1117
1178
rule = cls ._make_match_rule (stage ["match" ])
1118
1179
_aggregate_rules .append (rule )
1180
+ elif "lookup" in stage :
1181
+ rule = cls ._make_lookup_rule (stage ["lookup" ])
1182
+ _aggregate_rules .append (rule )
1183
+ elif "add_fields" in stage :
1184
+ rule = cls ._make_add_fields_rule (stage ["add_fields" ])
1185
+ _aggregate_rules .append (rule )
1119
1186
else :
1120
1187
raise ERROR_REQUIRED_PARAMETER (
1121
1188
key = "aggregate.unwind or aggregate.group or "
@@ -1514,7 +1581,9 @@ def analyze(
1514
1581
sort = None ,
1515
1582
start = None ,
1516
1583
end = None ,
1584
+ lookup = None ,
1517
1585
unwind = None ,
1586
+ add_fields = None ,
1518
1587
date_field = "date" ,
1519
1588
date_field_format = "%Y-%m-%d" ,
1520
1589
reference_filter = None ,
@@ -1552,9 +1621,16 @@ def analyze(
1552
1621
1553
1622
aggregate = []
1554
1623
1624
+ if lookup :
1625
+ for lu in lookup :
1626
+ aggregate .append ({"lookup" : lu })
1627
+
1555
1628
if unwind :
1556
1629
aggregate .append ({"unwind" : unwind })
1557
1630
1631
+ if add_fields :
1632
+ aggregate .append ({"add_fields" : add_fields })
1633
+
1558
1634
aggregate .append ({"group" : {"keys" : group_keys , "fields" : group_fields }})
1559
1635
1560
1636
query = {
0 commit comments