Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 110 additions & 34 deletions src/spaceone/core/model/mongo_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,49 +590,60 @@ def _make_unwind_project_stage(only: list):
}

@classmethod
def _stat_with_unwind(
def _stat_with_pipeline(
cls,
unwind: list,
lookup: list = None,
unwind: dict = None,
add_fields: dict = None,
only: list = None,
filter: list = None,
filter_or: list = None,
sort: list = None,
page: dict = None,
target: str = None,
):
if only is None:
raise ERROR_DB_QUERY(reason="unwind option requires only option.")
if unwind:
if only is None:
raise ERROR_DB_QUERY(reason="unwind option requires only option.")

if not isinstance(unwind, dict):
raise ERROR_DB_QUERY(reason="unwind option should be dict type.")
if not isinstance(unwind, dict):
raise ERROR_DB_QUERY(reason="unwind option should be dict type.")

if "path" not in unwind:
raise ERROR_DB_QUERY(reason="unwind option should have path key.")
if "path" not in unwind:
raise ERROR_DB_QUERY(reason="unwind option should have path key.")

unwind_path = unwind["path"]
aggregate = [{"unwind": unwind}]
aggregate = []

# Add project stage
project_fields = []
for key in only:
project_fields.append(
if lookup:
for lu in lookup:
aggregate.append({"lookup": lu})

if unwind:
aggregate.append({"unwind": unwind})

if add_fields:
aggregate.append({"add_fields": add_fields})

if only:
project_fields = []
for key in only:
project_fields.append(
{
"key": key,
"name": key,
}
)

aggregate.append(
{
"key": key,
"name": key,
"project": {
"exclude_keys": True,
"only_keys": True,
"fields": project_fields,
}
}
)

aggregate.append(
{
"project": {
"exclude_keys": True,
"only_keys": True,
"fields": project_fields,
}
}
)

# Add sort stage
if sort:
aggregate.append({"sort": sort})

Expand All @@ -641,21 +652,23 @@ def _stat_with_unwind(
filter=filter,
filter_or=filter_or,
page=page,
tageet=target,
target=target,
allow_disk_use=True,
)

try:
vos = []
total_count = response.get("total_count", 0)
for result in response.get("results", []):
unwind_data = utils.get_dict_value(result, unwind_path)
result = utils.change_dict_value(result, unwind_path, [unwind_data])
if unwind:
unwind_path = unwind["path"]
unwind_data = utils.get_dict_value(result, unwind_path)
result = utils.change_dict_value(result, unwind_path, [unwind_data])

vo = cls(**result)
vos.append(vo)
except Exception as e:
raise ERROR_DB_QUERY(reason=f"Failed to convert unwind result: {e}")
raise ERROR_DB_QUERY(reason=f"Failed to convert pipeline result: {e}")

return vos, total_count

Expand All @@ -672,7 +685,9 @@ def query(
minimal=False,
include_count=True,
count_only=False,
lookup=None,
unwind=None,
add_fields=None,
reference_filter=None,
target=None,
hint=None,
Expand All @@ -683,9 +698,17 @@ def query(
sort = sort or []
page = page or {}

if unwind:
return cls._stat_with_unwind(
unwind, only, filter, filter_or, sort, page, target
if unwind or lookup or add_fields:
return cls._stat_with_pipeline(
lookup=lookup,
unwind=unwind,
add_fields=add_fields,
only=only,
filter=filter,
filter_or=filter_or,
sort=sort,
page=page,
target=target,
)

else:
Expand Down Expand Up @@ -1075,6 +1098,44 @@ def _make_match_rule(cls, options):

return {"$match": match_options}

@classmethod
def _make_lookup_rule(cls, options):
return {"$lookup": options}

@classmethod
def _make_add_fields_rule(cls, options):
add_fields_options = {}

for field, conditional in options.items():
add_fields_options.update(
{field: cls._process_conditional_expression(conditional)}
)

return {"$addFields": add_fields_options}

@classmethod
def _process_conditional_expression(cls, expression):
if isinstance(expression, dict):
if_expression = expression["if"]

if isinstance(if_expression, dict):
replaced = {}
for k, v in if_expression.items():
new_k = k.replace("__", "$")
replaced[new_k] = v

if_expression = replaced

return {
"$cond": {
"if": if_expression,
"then": cls._process_conditional_expression(expression["then"]),
"else": cls._process_conditional_expression(expression["else"]),
}
}

return expression

@classmethod
def _make_aggregate_rules(cls, aggregate):
_aggregate_rules = []
Expand Down Expand Up @@ -1116,6 +1177,12 @@ def _make_aggregate_rules(cls, aggregate):
elif "match" in stage:
rule = cls._make_match_rule(stage["match"])
_aggregate_rules.append(rule)
elif "lookup" in stage:
rule = cls._make_lookup_rule(stage["lookup"])
_aggregate_rules.append(rule)
elif "add_fields" in stage:
rule = cls._make_add_fields_rule(stage["add_fields"])
_aggregate_rules.append(rule)
else:
raise ERROR_REQUIRED_PARAMETER(
key="aggregate.unwind or aggregate.group or "
Expand Down Expand Up @@ -1514,7 +1581,9 @@ def analyze(
sort=None,
start=None,
end=None,
lookup=None,
unwind=None,
add_fields=None,
date_field="date",
date_field_format="%Y-%m-%d",
reference_filter=None,
Expand Down Expand Up @@ -1552,9 +1621,16 @@ def analyze(

aggregate = []

if lookup:
for lu in lookup:
aggregate.append({"lookup": lu})

if unwind:
aggregate.append({"unwind": unwind})

if add_fields:
aggregate.append({"add_fields": add_fields})

aggregate.append({"group": {"keys": group_keys, "fields": group_fields}})

query = {
Expand Down
Loading