From ea360045c9e2c1e13d218136c1c6492568cc422e Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Sat, 30 Sep 2023 15:45:44 +0200 Subject: [PATCH 001/189] Turning follow_external_dependency into bool or dict property --- dagger/pipeline/io.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/dagger/pipeline/io.py b/dagger/pipeline/io.py index 452798f..32ae303 100644 --- a/dagger/pipeline/io.py +++ b/dagger/pipeline/io.py @@ -19,9 +19,15 @@ def init_attributes(cls, orig_cls): Attribute( attribute_name="follow_external_dependency", required=False, - comment="Weather an external task sensor should be created if this dataset" - "is created in another pipeline. Default is False", + format_help="dictionary or boolean", + comment="External Task Sensor parameters in key value format: https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/sensors/base/index.html" ), + # Attribute( + # attribute_name="follow_external_dependency", + # required=False, + # comment="Weather an external task sensor should be created if this dataset" + # "is created in another pipeline. Default is False", + # ), ] ) @@ -34,7 +40,17 @@ def __init__(self, io_config, config_location): self._has_dependency = self.parse_attribute("has_dependency") if self._has_dependency is None: self._has_dependency = True - self._follow_external_dependency = self.parse_attribute("follow_external_dependency") or False + + follow_external_dependency = self.parse_attribute("follow_external_dependency") + if follow_external_dependency is not None: + if isinstance(follow_external_dependency, bool): + if follow_external_dependency: + follow_external_dependency = dict() + else: + follow_external_dependency = None + else: + follow_external_dependency = dict(follow_external_dependency) + self._follow_external_dependency = follow_external_dependency def __eq__(self, other): return self.alias() == other.alias() From 5113752cee666d61d1e52901d9f1b0121cf4de52 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Sat, 30 Sep 2023 15:47:07 +0200 Subject: [PATCH 002/189] Handling the new dict format of follow_external_dependency --- dagger/dag_creator/airflow/dag_creator.py | 18 ++++++++++++------ dagger/graph/task_graph.py | 2 +- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/dagger/dag_creator/airflow/dag_creator.py b/dagger/dag_creator/airflow/dag_creator.py index 70358e8..5db416f 100644 --- a/dagger/dag_creator/airflow/dag_creator.py +++ b/dagger/dag_creator/airflow/dag_creator.py @@ -58,7 +58,7 @@ def _get_external_task_sensor_name_dict(self, from_task_id: str) -> dict: "external_sensor_name": f"{from_pipeline_name}-{from_task_name}-sensor", } - def _get_external_task_sensor(self, from_task_id: str, to_task_id: str) -> ExternalTaskSensor: + def _get_external_task_sensor(self, from_task_id: str, to_task_id: str, follow_external_dependency: dict) -> ExternalTaskSensor: """ create an object of external task sensor for a specific from_task_id and to_task_id """ @@ -72,6 +72,14 @@ def _get_external_task_sensor(self, from_task_id: str, to_task_id: str) -> Exter to_pipe_id = self._task_graph.get_node(to_task_id).obj.pipeline.name + + extra_args = { + 'mode': conf.EXTERNAL_SENSOR_MODE, + 'poke_interval': conf.EXTERNAL_SENSOR_POKE_INTERVAL, + 'timeout': conf.EXTERNAL_SENSOR_TIMEOUT, + } + extra_args.update(follow_external_dependency) + return ExternalTaskSensor( dag=self._dags[to_pipe_id], task_id=external_sensor_name, @@ -80,9 +88,7 @@ def _get_external_task_sensor(self, from_task_id: str, to_task_id: str) -> Exter execution_date_fn=self._get_execution_date_fn( from_pipeline_schedule, to_pipeline_schedule ), - mode=conf.EXTERNAL_SENSOR_MODE, - poke_interval=conf.EXTERNAL_SENSOR_POKE_INTERVAL, - timeout=conf.EXTERNAL_SENSOR_TIMEOUT, + **extra_args ) def _create_control_flow_task(self, pipe_id, dag): @@ -143,7 +149,7 @@ def _create_edge_without_data(self, from_task_id: str, to_task_ids: list, node: to_pipe = self._task_graph.get_node(to_task_id).obj.pipeline_name if from_pipe and from_pipe == to_pipe: self._tasks[from_task_id] >> self._tasks[to_task_id] - elif from_pipe and from_pipe != to_pipe and edge_properties.follow_external_dependency: + elif from_pipe and from_pipe != to_pipe and edge_properties.follow_external_dependency is not None: from_schedule = self._task_graph.get_node(from_task_id).obj.pipeline.schedule to_schedule = self._task_graph.get_node(to_task_id).obj.pipeline.schedule if not from_schedule.startswith("@") and not to_schedule.startswith("@"): @@ -155,7 +161,7 @@ def _create_edge_without_data(self, from_task_id: str, to_task_ids: list, node: not in self._sensor_dict.get(to_pipe, dict()).keys() ): external_task_sensor = self._get_external_task_sensor( - from_task_id, to_task_id + from_task_id, to_task_id, edge_properties.follow_external_dependency ) self._sensor_dict[to_pipe] = { external_task_sensor_name: external_task_sensor diff --git a/dagger/graph/task_graph.py b/dagger/graph/task_graph.py index c7898f6..0f14a56 100644 --- a/dagger/graph/task_graph.py +++ b/dagger/graph/task_graph.py @@ -55,7 +55,7 @@ def add_child(self, child_id): class Edge: - def __init__(self, follow_external_dependency=False): + def __init__(self, follow_external_dependency=None): self._follow_external_dependency = follow_external_dependency @property From fc2a2e2eb530de94e64681ca2f62915186faac7a Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Sat, 30 Sep 2023 15:48:01 +0200 Subject: [PATCH 003/189] Changing test case to see if it handles sensor parameters properly --- .../root/dags/test_external_sensor/dummy_first.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/fixtures/config_finder/root/dags/test_external_sensor/dummy_first.yaml b/tests/fixtures/config_finder/root/dags/test_external_sensor/dummy_first.yaml index 279dd70..e97563f 100644 --- a/tests/fixtures/config_finder/root/dags/test_external_sensor/dummy_first.yaml +++ b/tests/fixtures/config_finder/root/dags/test_external_sensor/dummy_first.yaml @@ -5,7 +5,8 @@ inputs: # format: list | Use dagger init-io cli name: redshift_input schema: dwh table: batch_table - follow_external_dependency: True + follow_external_dependency: + poke_interval: 60 outputs: # format: list | Use dagger init-io cli - type: dummy name: first_dummy_output From db6dee2281f029093b72d6fe6e19847e01825d63 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Mon, 23 Oct 2023 14:54:59 +0200 Subject: [PATCH 004/189] upgrade version of tenacity --- reqs/base.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reqs/base.txt b/reqs/base.txt index b6c8400..877d102 100644 --- a/reqs/base.txt +++ b/reqs/base.txt @@ -4,4 +4,4 @@ envyaml==1.10.211231 mergedeep==1.3.4 slack==0.0.2 slackclient==2.9.4 -tenacity==8.1.0 +tenacity==8.2.0 From d4634383f84f55492621a41697a94c8b43b6ed0a Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 15 Nov 2023 14:49:45 +0100 Subject: [PATCH 005/189] added new dbt config parser module --- dagger/utilities/dbt_config_parser.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 dagger/utilities/dbt_config_parser.py diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py new file mode 100644 index 0000000..e69de29 From a91a58fc8b9351f3bb3aa9f5918eb7ba5cb311f1 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 15 Nov 2023 14:50:48 +0100 Subject: [PATCH 006/189] added class with constructor --- dagger/utilities/dbt_config_parser.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index e69de29..983e150 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -0,0 +1,24 @@ +from os import path +from os.path import join +from typing import Union +import json +import yaml + +ATHENA_IO_BASE = {"type": "athena"} +S3_IO_BASE = {"type": "s3"} + +class DBTConfigParser: + + def __init__(self, default_config_parameters:dict): + self._default_data_bucket = default_config_parameters["data_bucket"] + self._dbt_project_dir = default_config_parameters.get("project_dir", None) + dbt_manifest_path = path.join(self._dbt_project_dir, "target","manifest.json") + self._dbt_profile_dir = default_config_parameters.get("profile_dir", None) + dbt_profile_path = path.join(self._dbt_profile_dir, "profiles.yml") + + with open(dbt_manifest_path, "r") as f: + data = f.read() + self._manifest_data = json.loads(data) + profile_yaml = yaml.safe_load(open(dbt_profile_path, "r")) + prod_dbt_profile = profile_yaml[self._dbt_project_dir]['outputs']['data'] + self._default_data_dir = prod_dbt_profile.get('s3_data_dir') or prod_dbt_profile.get('s3_staging_dir') \ No newline at end of file From 102620e30a7bda042f3cfadb46a64a1b13258c53 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 15 Nov 2023 14:53:16 +0100 Subject: [PATCH 007/189] added method to parse dbt model inputs --- dagger/utilities/dbt_config_parser.py | 42 ++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 983e150..a3ef0af 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -21,4 +21,44 @@ def __init__(self, default_config_parameters:dict): self._manifest_data = json.loads(data) profile_yaml = yaml.safe_load(open(dbt_profile_path, "r")) prod_dbt_profile = profile_yaml[self._dbt_project_dir]['outputs']['data'] - self._default_data_dir = prod_dbt_profile.get('s3_data_dir') or prod_dbt_profile.get('s3_staging_dir') \ No newline at end of file + self._default_data_dir = prod_dbt_profile.get('s3_data_dir') or prod_dbt_profile.get('s3_staging_dir') + def _get_model_data_location(self, node: dict, schema: str, dbt_model_name: str) -> str: + location = node.get("unrendered_config", {}).get("external_location") + if not location: + location = join(self._default_data_dir, schema, dbt_model_name) + + return location.split("data-lake/")[1] + + def _parse_dbt_model_inputs(self, model_name: str) -> dict: + inputs_dict = {} + inputs_list = [] + dbt_ref_to_model = f'model.{self._dbt_project_dir}.{model_name}' + + nodes = self._manifest_data['nodes'] + model_info = nodes[f'model.main.{model_name}'] + + parents_as_full_selectors = model_info.get('depends_on', {}).get('nodes', []) + inputs = [x.split('.')[-1] for x in parents_as_full_selectors] + + for index, node_name in enumerate(parents_as_full_selectors): + if not (".int_" in node_name): + dbt_parent_model_name = node_name.split('.')[-1] + parent_model_node = nodes.get(node_name) + parent_schema = parent_model_node.get('schema') + + model_data_location = self._get_model_data_location(parent_model_node, parent_schema, + dbt_parent_model_name) + + inputs_list.append({ + "schema": parent_schema, + "model_name": inputs[index], + "relative_s3_path": model_data_location + }) + + inputs_dict['model_name'] = model_name + inputs_dict['node_name'] = dbt_ref_to_model + inputs_dict['inputs'] = inputs_list + inputs_dict['schema'] = model_info['schema'] + inputs_dict['relative_s3_path'] = self._get_model_data_location(model_info, model_info['schema'], model_name) + + return inputs_dict \ No newline at end of file From c61e656d364aedcf966ae93b2cb52fbddaf89413 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 15 Nov 2023 14:54:03 +0100 Subject: [PATCH 008/189] added functions to generate dagger input and outputs for dbt models --- dagger/utilities/dbt_config_parser.py | 45 +++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index a3ef0af..5acb3b7 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -22,6 +22,51 @@ def __init__(self, default_config_parameters:dict): profile_yaml = yaml.safe_load(open(dbt_profile_path, "r")) prod_dbt_profile = profile_yaml[self._dbt_project_dir]['outputs']['data'] self._default_data_dir = prod_dbt_profile.get('s3_data_dir') or prod_dbt_profile.get('s3_staging_dir') + + def parse_dbt_staging_model(self, dbt_staging_model: str) -> Union[str, str]: + _model_split, core_table = dbt_staging_model.split('__') + core_schema = _model_split.split('_')[-1] + + return core_schema, core_table + + def generate_dagger_inputs(self, dbt_inputs: dict) -> Union[list[dict], None]: + dagger_inputs = [] + for dbt_input in dbt_inputs['inputs']: + model_name = dbt_input['model_name'] + athena_input = ATHENA_IO_BASE.copy() + s3_input = S3_IO_BASE.copy() + + if (model_name.startswith("stg_")): + athena_input['name'] = model_name + athena_input['schema'], athena_input['table'] = self.parse_dbt_staging_model(model_name) + + dagger_inputs.append(athena_input) + else: + athena_input['name'] = athena_input['table'] = model_name + athena_input['schema'] = dbt_input['schema'] + + s3_input['name'] = model_name + s3_input['bucket'] = self._default_data_bucket + s3_input['path'] = dbt_input['relative_s3_path'] + + dagger_inputs.append(athena_input) + dagger_inputs.append(s3_input) + + return dagger_inputs or None + + def generate_dagger_outputs(self, dbt_inputs: dict) -> list[dict]: + athena_input = ATHENA_IO_BASE.copy() + s3_input = S3_IO_BASE.copy() + + athena_input['name'] = athena_input['table'] = dbt_inputs['model_name'] + athena_input['schema'] = dbt_inputs['schema'] + + s3_input['name'] = dbt_inputs['model_name'] + s3_input['bucket'] = self._default_data_bucket + s3_input['relative_s3_path'] = dbt_inputs['relative_s3_path'] + + return [athena_input, s3_input] + def _get_model_data_location(self, node: dict, schema: str, dbt_model_name: str) -> str: location = node.get("unrendered_config", {}).get("external_location") if not location: From 29858a975c795b240555637de0d2fe079b63ea56 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 15 Nov 2023 14:54:20 +0100 Subject: [PATCH 009/189] added fn to generate io for dbt task --- dagger/utilities/dbt_config_parser.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 5acb3b7..1b2fce1 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -23,6 +23,12 @@ def __init__(self, default_config_parameters:dict): prod_dbt_profile = profile_yaml[self._dbt_project_dir]['outputs']['data'] self._default_data_dir = prod_dbt_profile.get('s3_data_dir') or prod_dbt_profile.get('s3_staging_dir') + def generate_io(self, model_name: str) -> tuple[list[dict], list[dict]]: + model_inputs = self._parse_dbt_model_inputs(model_name) + model_dagger_inputs = self.generate_dagger_inputs(model_inputs) + model_dagger_outputs = self.generate_dagger_outputs(model_inputs) + return model_dagger_inputs, model_dagger_outputs + def parse_dbt_staging_model(self, dbt_staging_model: str) -> Union[str, str]: _model_split, core_table = dbt_staging_model.split('__') core_schema = _model_split.split('_')[-1] From 1a93f056d7dd4d87c02b1c761d6cfb919b6f5ab4 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 15 Nov 2023 14:54:30 +0100 Subject: [PATCH 010/189] black format --- dagger/utilities/dbt_config_parser.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 1b2fce1..d30949b 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -7,12 +7,13 @@ ATHENA_IO_BASE = {"type": "athena"} S3_IO_BASE = {"type": "s3"} + class DBTConfigParser: - def __init__(self, default_config_parameters:dict): + def __init__(self, default_config_parameters: dict): self._default_data_bucket = default_config_parameters["data_bucket"] self._dbt_project_dir = default_config_parameters.get("project_dir", None) - dbt_manifest_path = path.join(self._dbt_project_dir, "target","manifest.json") + dbt_manifest_path = path.join(self._dbt_project_dir, "target", "manifest.json") self._dbt_profile_dir = default_config_parameters.get("profile_dir", None) dbt_profile_path = path.join(self._dbt_profile_dir, "profiles.yml") From 847f6229bba0f0a40d0fc294cc0f944330f06617 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 15 Nov 2023 14:55:02 +0100 Subject: [PATCH 011/189] use new module in fn to generate configs --- dagger/utilities/module.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 6f3b395..c697ef8 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -1,6 +1,7 @@ import logging from os import path from mergedeep import merge +from dbt_config_parser import DBTConfigParser import yaml @@ -21,6 +22,7 @@ def __init__(self, path_to_config, target_dir): self._branches_to_generate = config["branches_to_generate"] self._override_parameters = config.get("override_parameters", {}) self._default_parameters = config.get("default_parameters", {}) + self._dbt_module = DBTConfigParser(self._default_parameters) @staticmethod def read_yaml(yaml_str): @@ -82,6 +84,11 @@ def generate_task_configs(self): ) task_dict = yaml.safe_load(task_str) + if task == 'dbt': + inputs, outputs = self._dbt_module.generate_io(branch_name) + task_dict['inputs'] = inputs + task_dict['outputs'] = outputs + task_dict["autogenerated_by_dagger"] = self._path_to_config override_parameters = self._override_parameters or {} merge(task_dict, override_parameters.get(branch_name, {}).get(task, {})) From 104013af971b2aaa0cb71451f56921472714df02 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 15 Nov 2023 14:59:13 +0100 Subject: [PATCH 012/189] changed output bucket --- dagger/utilities/dbt_config_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index d30949b..d15b02b 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -69,7 +69,7 @@ def generate_dagger_outputs(self, dbt_inputs: dict) -> list[dict]: athena_input['schema'] = dbt_inputs['schema'] s3_input['name'] = dbt_inputs['model_name'] - s3_input['bucket'] = self._default_data_bucket + s3_input['bucket'] = "cho${ENV}-data-lake" s3_input['relative_s3_path'] = dbt_inputs['relative_s3_path'] return [athena_input, s3_input] From de82c14c60011b15e0743bf102a884a675591a6f Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 15 Nov 2023 16:22:16 +0100 Subject: [PATCH 013/189] fixed import --- dagger/utilities/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index c697ef8..bba316c 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -1,7 +1,7 @@ import logging from os import path from mergedeep import merge -from dbt_config_parser import DBTConfigParser +from dagger.utilities.dbt_config_parser import DBTConfigParser import yaml From dcb6854d6f1dfa80c626dc8d6283fa3c7c3a5149 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 15 Nov 2023 16:24:10 +0100 Subject: [PATCH 014/189] renamed functions and variables --- dagger/utilities/dbt_config_parser.py | 35 +++++++++++++++------------ 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index d15b02b..9912353 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -1,6 +1,6 @@ from os import path from os.path import join -from typing import Union +from typing import Union, Tuple import json import yaml @@ -25,21 +25,25 @@ def __init__(self, default_config_parameters: dict): self._default_data_dir = prod_dbt_profile.get('s3_data_dir') or prod_dbt_profile.get('s3_staging_dir') def generate_io(self, model_name: str) -> tuple[list[dict], list[dict]]: - model_inputs = self._parse_dbt_model_inputs(model_name) - model_dagger_inputs = self.generate_dagger_inputs(model_inputs) - model_dagger_outputs = self.generate_dagger_outputs(model_inputs) + """ + Generates the dagger inputs and outputs for the respective dbt model + Args: + model_parents = self._get_dbt_model_parents(model_name) + model_dagger_inputs = self.generate_dagger_inputs(model_parents) + model_dagger_outputs = self.generate_dagger_outputs(model_parents['model_name'], model_parents['schema'], model_parents['relative_s3_path']) + return model_dagger_inputs, model_dagger_outputs - def parse_dbt_staging_model(self, dbt_staging_model: str) -> Union[str, str]: + def parse_dbt_staging_model(self, dbt_staging_model: str) -> Tuple[str, str]: _model_split, core_table = dbt_staging_model.split('__') core_schema = _model_split.split('_')[-1] return core_schema, core_table - def generate_dagger_inputs(self, dbt_inputs: dict) -> Union[list[dict], None]: + def generate_dagger_inputs(self, dbt_model_parents: dict) -> Union[list[dict], None]: dagger_inputs = [] - for dbt_input in dbt_inputs['inputs']: - model_name = dbt_input['model_name'] + for parent in dbt_model_parents['inputs']: + model_name = parent['model_name'] athena_input = ATHENA_IO_BASE.copy() s3_input = S3_IO_BASE.copy() @@ -50,27 +54,26 @@ def generate_dagger_inputs(self, dbt_inputs: dict) -> Union[list[dict], None]: dagger_inputs.append(athena_input) else: athena_input['name'] = athena_input['table'] = model_name - athena_input['schema'] = dbt_input['schema'] + athena_input['schema'] = parent['schema'] s3_input['name'] = model_name s3_input['bucket'] = self._default_data_bucket - s3_input['path'] = dbt_input['relative_s3_path'] + s3_input['path'] = parent['relative_s3_path'] dagger_inputs.append(athena_input) dagger_inputs.append(s3_input) return dagger_inputs or None - def generate_dagger_outputs(self, dbt_inputs: dict) -> list[dict]: + def generate_dagger_outputs(self, model_name: str, schema: str, relative_s3_path: str) -> list[dict]: athena_input = ATHENA_IO_BASE.copy() s3_input = S3_IO_BASE.copy() - athena_input['name'] = athena_input['table'] = dbt_inputs['model_name'] - athena_input['schema'] = dbt_inputs['schema'] + athena_input['name'] = athena_input['table'] = s3_input['name'] = model_name + athena_input['schema'] = schema - s3_input['name'] = dbt_inputs['model_name'] s3_input['bucket'] = "cho${ENV}-data-lake" - s3_input['relative_s3_path'] = dbt_inputs['relative_s3_path'] + s3_input['relative_s3_path'] = relative_s3_path return [athena_input, s3_input] @@ -81,7 +84,7 @@ def _get_model_data_location(self, node: dict, schema: str, dbt_model_name: str) return location.split("data-lake/")[1] - def _parse_dbt_model_inputs(self, model_name: str) -> dict: + def _get_dbt_model_parents(self, model_name: str) -> dict: inputs_dict = {} inputs_list = [] dbt_ref_to_model = f'model.{self._dbt_project_dir}.{model_name}' From fb6f8f6be543e6bbea91f227126af0dc1fac0b8e Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 15 Nov 2023 16:24:26 +0100 Subject: [PATCH 015/189] added type hints and docstrings --- dagger/utilities/dbt_config_parser.py | 63 ++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 9912353..8262933 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -9,7 +9,9 @@ class DBTConfigParser: - + """ + Module that parses the manifest.json file generated by dbt and generates the dagger inputs and outputs for the respective dbt model + """ def __init__(self, default_config_parameters: dict): self._default_data_bucket = default_config_parameters["data_bucket"] self._dbt_project_dir = default_config_parameters.get("project_dir", None) @@ -28,6 +30,12 @@ def generate_io(self, model_name: str) -> tuple[list[dict], list[dict]]: """ Generates the dagger inputs and outputs for the respective dbt model Args: + model_name: name of the dbt model + + Returns: + tuple[list[dict], list[dict]]: dagger inputs and outputs for the respective dbt model + + """ model_parents = self._get_dbt_model_parents(model_name) model_dagger_inputs = self.generate_dagger_inputs(model_parents) model_dagger_outputs = self.generate_dagger_outputs(model_parents['model_name'], model_parents['schema'], model_parents['relative_s3_path']) @@ -35,12 +43,31 @@ def generate_io(self, model_name: str) -> tuple[list[dict], list[dict]]: return model_dagger_inputs, model_dagger_outputs def parse_dbt_staging_model(self, dbt_staging_model: str) -> Tuple[str, str]: + """ + Parses the dbt staging model to get the core schema and table name + Args: + dbt_staging_model: name of the DBT staging model + + Returns: + Tuple[str, str]: core schema and table name + """ _model_split, core_table = dbt_staging_model.split('__') core_schema = _model_split.split('_')[-1] return core_schema, core_table def generate_dagger_inputs(self, dbt_model_parents: dict) -> Union[list[dict], None]: + """ + Generates the dagger inputs for the respective dbt model. This means that all parents of the dbt model are added as dagger inputs. + Staging models are added as Athena inputs and core models are added as Athena and S3 inputs. + Intermediate models are not added as an input. + Args: + dbt_model_parents: All parents of the dbt model + + Returns: + Union[list[dict], None]: dagger inputs for the respective dbt model. If there are no parents, returns None + + """ dagger_inputs = [] for parent in dbt_model_parents['inputs']: model_name = parent['model_name'] @@ -66,6 +93,18 @@ def generate_dagger_inputs(self, dbt_model_parents: dict) -> Union[list[dict], N return dagger_inputs or None def generate_dagger_outputs(self, model_name: str, schema: str, relative_s3_path: str) -> list[dict]: + """ + Generates the dagger outputs for the respective dbt model. + This means that an Athena and S3 output is added for the dbt model. + Args: + model_name: The name of the dbt model + schema: The schema of the dbt model + relative_s3_path: The S3 path of the dbt model relative to the data bucket + + Returns: + list[dict]: dagger S3 and Athena outputs for the respective dbt model + + """ athena_input = ATHENA_IO_BASE.copy() s3_input = S3_IO_BASE.copy() @@ -78,6 +117,19 @@ def generate_dagger_outputs(self, model_name: str, schema: str, relative_s3_path return [athena_input, s3_input] def _get_model_data_location(self, node: dict, schema: str, dbt_model_name: str) -> str: + """ + Gets the S3 path of the dbt model relative to the data bucket. + If external location is not specified in the DBT model config, then the default data directory from the + DBT profiles configuration is used. + Args: + node: The extracted node from the manifest.json file + schema: The schema of the dbt model + dbt_model_name: The name of the dbt model + + Returns: + str: The S3 path of the dbt model relative to the data bucket + + """ location = node.get("unrendered_config", {}).get("external_location") if not location: location = join(self._default_data_dir, schema, dbt_model_name) @@ -85,6 +137,15 @@ def _get_model_data_location(self, node: dict, schema: str, dbt_model_name: str) return location.split("data-lake/")[1] def _get_dbt_model_parents(self, model_name: str) -> dict: + """ + Gets all parents of a single dbt model from the manifest.json file + Args: + model_name: The name of the DBT model + + Returns: + dict: All parents of the dbt model along with the name, schema and S3 path of the dbt model itself + + """ inputs_dict = {} inputs_list = [] dbt_ref_to_model = f'model.{self._dbt_project_dir}.{model_name}' From 06dc86bf47584712eaef1267124278d28ff26ebe Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 15 Nov 2023 16:24:55 +0100 Subject: [PATCH 016/189] changed how models are selected when config is generated --- dagger/utilities/module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index bba316c..ee7540d 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -88,6 +88,7 @@ def generate_task_configs(self): inputs, outputs = self._dbt_module.generate_io(branch_name) task_dict['inputs'] = inputs task_dict['outputs'] = outputs + task_dict['task_parameters']['select'] = branch_name task_dict["autogenerated_by_dagger"] = self._path_to_config override_parameters = self._override_parameters or {} From 19a13bce8cbeb7ac9d9d444cc8dd623e19ce7564 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 15 Nov 2023 16:28:25 +0100 Subject: [PATCH 017/189] black --- dagger/utilities/dbt_config_parser.py | 113 +++++++++++++++----------- 1 file changed, 67 insertions(+), 46 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 8262933..03d3d66 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -12,6 +12,7 @@ class DBTConfigParser: """ Module that parses the manifest.json file generated by dbt and generates the dagger inputs and outputs for the respective dbt model """ + def __init__(self, default_config_parameters: dict): self._default_data_bucket = default_config_parameters["data_bucket"] self._dbt_project_dir = default_config_parameters.get("project_dir", None) @@ -23,8 +24,10 @@ def __init__(self, default_config_parameters: dict): data = f.read() self._manifest_data = json.loads(data) profile_yaml = yaml.safe_load(open(dbt_profile_path, "r")) - prod_dbt_profile = profile_yaml[self._dbt_project_dir]['outputs']['data'] - self._default_data_dir = prod_dbt_profile.get('s3_data_dir') or prod_dbt_profile.get('s3_staging_dir') + prod_dbt_profile = profile_yaml[self._dbt_project_dir]["outputs"]["data"] + self._default_data_dir = prod_dbt_profile.get( + "s3_data_dir" + ) or prod_dbt_profile.get("s3_staging_dir") def generate_io(self, model_name: str) -> tuple[list[dict], list[dict]]: """ @@ -38,7 +41,11 @@ def generate_io(self, model_name: str) -> tuple[list[dict], list[dict]]: """ model_parents = self._get_dbt_model_parents(model_name) model_dagger_inputs = self.generate_dagger_inputs(model_parents) - model_dagger_outputs = self.generate_dagger_outputs(model_parents['model_name'], model_parents['schema'], model_parents['relative_s3_path']) + model_dagger_outputs = self.generate_dagger_outputs( + model_parents["model_name"], + model_parents["schema"], + model_parents["relative_s3_path"], + ) return model_dagger_inputs, model_dagger_outputs @@ -51,12 +58,14 @@ def parse_dbt_staging_model(self, dbt_staging_model: str) -> Tuple[str, str]: Returns: Tuple[str, str]: core schema and table name """ - _model_split, core_table = dbt_staging_model.split('__') - core_schema = _model_split.split('_')[-1] + _model_split, core_table = dbt_staging_model.split("__") + core_schema = _model_split.split("_")[-1] return core_schema, core_table - def generate_dagger_inputs(self, dbt_model_parents: dict) -> Union[list[dict], None]: + def generate_dagger_inputs( + self, dbt_model_parents: dict + ) -> Union[list[dict], None]: """ Generates the dagger inputs for the respective dbt model. This means that all parents of the dbt model are added as dagger inputs. Staging models are added as Athena inputs and core models are added as Athena and S3 inputs. @@ -69,30 +78,35 @@ def generate_dagger_inputs(self, dbt_model_parents: dict) -> Union[list[dict], N """ dagger_inputs = [] - for parent in dbt_model_parents['inputs']: - model_name = parent['model_name'] + for parent in dbt_model_parents["inputs"]: + model_name = parent["model_name"] athena_input = ATHENA_IO_BASE.copy() s3_input = S3_IO_BASE.copy() - if (model_name.startswith("stg_")): - athena_input['name'] = model_name - athena_input['schema'], athena_input['table'] = self.parse_dbt_staging_model(model_name) + if model_name.startswith("stg_"): + athena_input["name"] = model_name + ( + athena_input["schema"], + athena_input["table"], + ) = self.parse_dbt_staging_model(model_name) dagger_inputs.append(athena_input) else: - athena_input['name'] = athena_input['table'] = model_name - athena_input['schema'] = parent['schema'] + athena_input["name"] = athena_input["table"] = model_name + athena_input["schema"] = parent["schema"] - s3_input['name'] = model_name - s3_input['bucket'] = self._default_data_bucket - s3_input['path'] = parent['relative_s3_path'] + s3_input["name"] = model_name + s3_input["bucket"] = self._default_data_bucket + s3_input["path"] = parent["relative_s3_path"] dagger_inputs.append(athena_input) dagger_inputs.append(s3_input) return dagger_inputs or None - def generate_dagger_outputs(self, model_name: str, schema: str, relative_s3_path: str) -> list[dict]: + def generate_dagger_outputs( + self, model_name: str, schema: str, relative_s3_path: str + ) -> list[dict]: """ Generates the dagger outputs for the respective dbt model. This means that an Athena and S3 output is added for the dbt model. @@ -108,15 +122,17 @@ def generate_dagger_outputs(self, model_name: str, schema: str, relative_s3_path athena_input = ATHENA_IO_BASE.copy() s3_input = S3_IO_BASE.copy() - athena_input['name'] = athena_input['table'] = s3_input['name'] = model_name - athena_input['schema'] = schema + athena_input["name"] = athena_input["table"] = s3_input["name"] = model_name + athena_input["schema"] = schema - s3_input['bucket'] = "cho${ENV}-data-lake" - s3_input['relative_s3_path'] = relative_s3_path + s3_input["bucket"] = "cho${ENV}-data-lake" + s3_input["relative_s3_path"] = relative_s3_path return [athena_input, s3_input] - def _get_model_data_location(self, node: dict, schema: str, dbt_model_name: str) -> str: + def _get_model_data_location( + self, node: dict, schema: str, dbt_model_name: str + ) -> str: """ Gets the S3 path of the dbt model relative to the data bucket. If external location is not specified in the DBT model config, then the default data directory from the @@ -148,33 +164,38 @@ def _get_dbt_model_parents(self, model_name: str) -> dict: """ inputs_dict = {} inputs_list = [] - dbt_ref_to_model = f'model.{self._dbt_project_dir}.{model_name}' + dbt_ref_to_model = f"model.{self._dbt_project_dir}.{model_name}" - nodes = self._manifest_data['nodes'] - model_info = nodes[f'model.main.{model_name}'] + nodes = self._manifest_data["nodes"] + model_info = nodes[f"model.main.{model_name}"] - parents_as_full_selectors = model_info.get('depends_on', {}).get('nodes', []) - inputs = [x.split('.')[-1] for x in parents_as_full_selectors] + parents_as_full_selectors = model_info.get("depends_on", {}).get("nodes", []) + inputs = [x.split(".")[-1] for x in parents_as_full_selectors] for index, node_name in enumerate(parents_as_full_selectors): if not (".int_" in node_name): - dbt_parent_model_name = node_name.split('.')[-1] + dbt_parent_model_name = node_name.split(".")[-1] parent_model_node = nodes.get(node_name) - parent_schema = parent_model_node.get('schema') - - model_data_location = self._get_model_data_location(parent_model_node, parent_schema, - dbt_parent_model_name) - - inputs_list.append({ - "schema": parent_schema, - "model_name": inputs[index], - "relative_s3_path": model_data_location - }) - - inputs_dict['model_name'] = model_name - inputs_dict['node_name'] = dbt_ref_to_model - inputs_dict['inputs'] = inputs_list - inputs_dict['schema'] = model_info['schema'] - inputs_dict['relative_s3_path'] = self._get_model_data_location(model_info, model_info['schema'], model_name) - - return inputs_dict \ No newline at end of file + parent_schema = parent_model_node.get("schema") + + model_data_location = self._get_model_data_location( + parent_model_node, parent_schema, dbt_parent_model_name + ) + + inputs_list.append( + { + "schema": parent_schema, + "model_name": inputs[index], + "relative_s3_path": model_data_location, + } + ) + + inputs_dict["model_name"] = model_name + inputs_dict["node_name"] = dbt_ref_to_model + inputs_dict["inputs"] = inputs_list + inputs_dict["schema"] = model_info["schema"] + inputs_dict["relative_s3_path"] = self._get_model_data_location( + model_info, model_info["schema"], model_name + ) + + return inputs_dict From 5f83797d424b14dfe373b5efa660cba10e13725d Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 16 Nov 2023 11:45:30 +0100 Subject: [PATCH 018/189] fix project dir to load the correct profile --- dagger/utilities/dbt_config_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 03d3d66..5f81dbf 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -24,7 +24,7 @@ def __init__(self, default_config_parameters: dict): data = f.read() self._manifest_data = json.loads(data) profile_yaml = yaml.safe_load(open(dbt_profile_path, "r")) - prod_dbt_profile = profile_yaml[self._dbt_project_dir]["outputs"]["data"] + prod_dbt_profile = profile_yaml[self._dbt_project_dir.split("/")[-1]]["outputs"]["data"] self._default_data_dir = prod_dbt_profile.get( "s3_data_dir" ) or prod_dbt_profile.get("s3_staging_dir") From 209efd605d1947e6ff40eaef80647fa54227da84 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 16 Nov 2023 11:46:11 +0100 Subject: [PATCH 019/189] get external_location from config instead of unrendered config --- dagger/utilities/dbt_config_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 5f81dbf..cfa3c81 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -146,7 +146,7 @@ def _get_model_data_location( str: The S3 path of the dbt model relative to the data bucket """ - location = node.get("unrendered_config", {}).get("external_location") + location = node.get("config", {}).get("external_location") if not location: location = join(self._default_data_dir, schema, dbt_model_name) From 3c4e4c369d36a2e5b619d2aeb32a621e862e4344 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 16 Nov 2023 11:46:49 +0100 Subject: [PATCH 020/189] renamed variables for better understanding --- dagger/utilities/dbt_config_parser.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index cfa3c81..801bb3d 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -169,23 +169,23 @@ def _get_dbt_model_parents(self, model_name: str) -> dict: nodes = self._manifest_data["nodes"] model_info = nodes[f"model.main.{model_name}"] - parents_as_full_selectors = model_info.get("depends_on", {}).get("nodes", []) - inputs = [x.split(".")[-1] for x in parents_as_full_selectors] + parent_node_names = model_info.get("depends_on", {}).get("nodes", []) + parent_model_names = [x.split(".")[-1] for x in parent_node_names] - for index, node_name in enumerate(parents_as_full_selectors): - if not (".int_" in node_name): - dbt_parent_model_name = node_name.split(".")[-1] - parent_model_node = nodes.get(node_name) + for index, parent_node_name in enumerate(parent_node_names): + if not (".int_" in parent_node_name): + parent_model_name = parent_node_name.split(".")[-1] + parent_model_node = nodes.get(parent_node_name) parent_schema = parent_model_node.get("schema") model_data_location = self._get_model_data_location( - parent_model_node, parent_schema, dbt_parent_model_name + parent_model_node, parent_schema, parent_model_name ) inputs_list.append( { "schema": parent_schema, - "model_name": inputs[index], + "model_name": parent_model_names[index], "relative_s3_path": model_data_location, } ) From 1bd525cdc0761080d0f8d42e2cb51b09d17d55c3 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 16 Nov 2023 11:47:15 +0100 Subject: [PATCH 021/189] added doctest for fn --- dagger/utilities/dbt_config_parser.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 801bb3d..12f5e92 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -57,6 +57,11 @@ def parse_dbt_staging_model(self, dbt_staging_model: str) -> Tuple[str, str]: Returns: Tuple[str, str]: core schema and table name + + >>> parse_dbt_staging_model("schema_name__table") + ('schema_name', 'table') + >>> parse_dbt_staging_model("another_schema__another_table") + ('another_schema', 'another_table') """ _model_split, core_table = dbt_staging_model.split("__") core_schema = _model_split.split("_")[-1] From 4b651ad24ba3598f2e342ce070367ccd4adf757a Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 16 Nov 2023 11:48:09 +0100 Subject: [PATCH 022/189] added files for fixtures and tests --- tests/fixtures/modules/dbt_config_parser_fixtures.py | 0 tests/utilities/test_dbt_config_parser.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/fixtures/modules/dbt_config_parser_fixtures.py create mode 100644 tests/utilities/test_dbt_config_parser.py diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py new file mode 100644 index 0000000..e69de29 From 7123cbcddbc0830d5440ec263d067acc9d9cfa47 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 16 Nov 2023 11:49:08 +0100 Subject: [PATCH 023/189] added fixture for manifest and profiles --- .../modules/dbt_config_parser_fixtures.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures.py index e69de29..96ac1fb 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures.py @@ -0,0 +1,67 @@ +DBT_MANIFEST_FILE_FIXTURE = { + "nodes": { + "model.main.model1": { + "database": "awsdatacatalog", + "schema": "analytics_engineering", + "name": "fct_supplier_revenue", + "config": { + "external_location": "s3://bucket1-data-lake/path1/model1", + "materialized": "incremental", + "incremental_strategy": "insert_overwrite", + }, + "description": "Details of revenue calculation at supplier level for each observation day", + "tags": ["daily"], + "unrendered_config": { + "materialized": "incremental", + "external_location": "s3://bucket1-data-lake/path1/model1", + "incremental_strategy": "insert_overwrite", + "partitioned_by": ["year", "month", "day", "dt"], + "tags": ["daily"], + "on_schema_change": "fail", + }, + "depends_on": { + "macros": [ + "macro.main.macro1", + "macro.main.macro2", + ], + "nodes": [ + "model.main.stg_core_schema1__table1", + "model.main.model2", + "model.main.int_model3", + ], + }, + }, + "model.main.stg_core_schema1__table1": { + "schema": "analytics_engineering", + }, + "model.main.model2": { + "schema": "analytics_engineering", + "config": { + "external_location": "s3://bucket1-data-lake/path2/model2", + }, + }, + "model.main.int_model3": { + "schema": "analytics_engineering", + }, + } +} + +DBT_PROFILE_FIXTURE = { + "main": { + "outputs": { + "data": { + "aws_profile_name": "data", + "database": "awsdatacatalog", + "num_retries": 10, + "region_name": "eu-west-1", + "s3_data_dir": "s3://bucket1-data-lake/path1/tmp", + "s3_data_naming": "schema_table", + "s3_staging_dir": "s3://bucket1-data-lake/path1/", + "schema": "analytics_engineering", + "threads": 4, + "type": "athena", + "work_group": "primary", + }, + } + } +} \ No newline at end of file From 77fb2b24b4d83bb2b8c742daf4ae5b6c5be4a7d3 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 16 Nov 2023 11:51:09 +0100 Subject: [PATCH 024/189] setUp test class --- tests/utilities/test_dbt_config_parser.py | 30 +++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index e69de29..d383c1b 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -0,0 +1,30 @@ +import logging +import unittest +from unittest import skip +from unittest.mock import patch, MagicMock + +from dagger.utilities.dbt_config_parser import DBTConfigParser +from dagger.utilities.module import Module +from tests.fixtures.modules.dbt_config_parser_fixtures import ( + EXPECTED_DBT_MODEL_PARENTS, + EXPECTED_DAGGER_INPUTS, + DBT_MANIFEST_FILE_FIXTURE, + DBT_PROFILE_FIXTURE, + EXPECTED_DAGGER_OUTPUTS, +) + +_logger = logging.getLogger("root") + +DEFAULT_CONFIG_PARAMS = { + "data_bucket": "bucket1-data-lake", + "project_dir": "main", + "profile_dir": ".dbt", +} + + +class TestDBTConfigParser(unittest.TestCase): + @patch("builtins.open", new_callable=MagicMock, read_data=DBT_MANIFEST_FILE_FIXTURE) + @patch("json.loads", return_value=DBT_MANIFEST_FILE_FIXTURE) + @patch("yaml.safe_load", return_value=DBT_PROFILE_FIXTURE) + def setUp(self, mock_open, mock_json_load, mock_safe_load): + self._dbt_config_parser = DBTConfigParser(DEFAULT_CONFIG_PARAMS) From 47c063856b7aa8addf8d3f0fdfd7779eaf510441 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 16 Nov 2023 11:53:08 +0100 Subject: [PATCH 025/189] added test for get_dbt_model_parents --- .../modules/dbt_config_parser_fixtures.py | 19 +++++++++++++++++++ tests/utilities/test_dbt_config_parser.py | 7 +++++++ 2 files changed, 26 insertions(+) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures.py index 96ac1fb..372153a 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures.py @@ -64,4 +64,23 @@ }, } } +} + +EXPECTED_DBT_MODEL_PARENTS = { + "inputs": [ + { + "model_name": "stg_core_schema1__table1", + "relative_s3_path": "path1/tmp/analytics_engineering/stg_core_schema1__table1", + "schema": "analytics_engineering", + }, + { + "model_name": "model2", + "relative_s3_path": "path2/model2", + "schema": "analytics_engineering", + }, + ], + "model_name": "model1", + "node_name": "model.main.model1", + "relative_s3_path": "path1/model1", + "schema": "analytics_engineering", } \ No newline at end of file diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index d383c1b..4835b2b 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -20,6 +20,7 @@ "project_dir": "main", "profile_dir": ".dbt", } +MODEL_NAME = "model1" class TestDBTConfigParser(unittest.TestCase): @@ -28,3 +29,9 @@ class TestDBTConfigParser(unittest.TestCase): @patch("yaml.safe_load", return_value=DBT_PROFILE_FIXTURE) def setUp(self, mock_open, mock_json_load, mock_safe_load): self._dbt_config_parser = DBTConfigParser(DEFAULT_CONFIG_PARAMS) + + + def test_get_dbt_model_parents(self): + result = self._dbt_config_parser._get_dbt_model_parents(MODEL_NAME) + + self.assertDictEqual(result, EXPECTED_DBT_MODEL_PARENTS) \ No newline at end of file From 4eced6b6685758bb9e7bc3080c930afe9d813ff6 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 16 Nov 2023 11:53:40 +0100 Subject: [PATCH 026/189] added tests for generate_dagger_inputs --- .../modules/dbt_config_parser_fixtures.py | 23 ++++++++++++++++++- tests/utilities/test_dbt_config_parser.py | 9 +++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures.py index 372153a..ea6acd3 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures.py @@ -83,4 +83,25 @@ "node_name": "model.main.model1", "relative_s3_path": "path1/model1", "schema": "analytics_engineering", -} \ No newline at end of file +} + +EXPECTED_DAGGER_INPUTS = [ + { + "name": "stg_core_schema1__table1", + "schema": "schema1", + "table": "table1", + "type": "athena", + }, + { + "name": "model2", + "schema": "analytics_engineering", + "table": "model2", + "type": "athena", + }, + { + "bucket": "bucket1-data-lake", + "name": "model2", + "path": "path2/model2", + "type": "s3", + }, +] \ No newline at end of file diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index 4835b2b..a882760 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -34,4 +34,11 @@ def setUp(self, mock_open, mock_json_load, mock_safe_load): def test_get_dbt_model_parents(self): result = self._dbt_config_parser._get_dbt_model_parents(MODEL_NAME) - self.assertDictEqual(result, EXPECTED_DBT_MODEL_PARENTS) \ No newline at end of file + self.assertDictEqual(result, EXPECTED_DBT_MODEL_PARENTS) + + def test_generate_dagger_inputs(self): + result_inputs = self._dbt_config_parser.generate_dagger_inputs( + EXPECTED_DBT_MODEL_PARENTS + ) + + self.assertListEqual(result_inputs, EXPECTED_DAGGER_INPUTS) \ No newline at end of file From aa767e8d670b99deac0b6568ccae5fa0a1a4701a Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 16 Nov 2023 11:53:56 +0100 Subject: [PATCH 027/189] added test for generate_dagger_outputs --- .../modules/dbt_config_parser_fixtures.py | 17 ++++++++++++++++- tests/utilities/test_dbt_config_parser.py | 11 ++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures.py index ea6acd3..821c5b8 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures.py @@ -104,4 +104,19 @@ "path": "path2/model2", "type": "s3", }, -] \ No newline at end of file +] + +EXPECTED_DAGGER_OUTPUTS = [ + { + "name": "model1", + "schema": "analytics_engineering", + "table": "model1", + "type": "athena", + }, + { + "bucket": "cho${ENV}-data-lake", + "name": "model1", + "relative_s3_path": "path1/model1", + "type": "s3", + }, +] diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index a882760..a09bbc4 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -41,4 +41,13 @@ def test_generate_dagger_inputs(self): EXPECTED_DBT_MODEL_PARENTS ) - self.assertListEqual(result_inputs, EXPECTED_DAGGER_INPUTS) \ No newline at end of file + self.assertListEqual(result_inputs, EXPECTED_DAGGER_INPUTS) + + def test_generate_dagger_outputs(self): + result_outputs = self._dbt_config_parser.generate_dagger_outputs( + EXPECTED_DBT_MODEL_PARENTS["model_name"], + EXPECTED_DBT_MODEL_PARENTS["schema"], + EXPECTED_DBT_MODEL_PARENTS["relative_s3_path"], + ) + + self.assertListEqual(result_outputs, EXPECTED_DAGGER_OUTPUTS) From 8cecbe67f0584758e2369250be74726580526145 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 16 Nov 2023 12:07:27 +0100 Subject: [PATCH 028/189] fixed type hint --- dagger/utilities/dbt_config_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 12f5e92..afc011a 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -29,7 +29,7 @@ def __init__(self, default_config_parameters: dict): "s3_data_dir" ) or prod_dbt_profile.get("s3_staging_dir") - def generate_io(self, model_name: str) -> tuple[list[dict], list[dict]]: + def generate_io(self, model_name: str) -> Tuple[list[dict], list[dict]]: """ Generates the dagger inputs and outputs for the respective dbt model Args: From 04eafe6aa68154d8afe80d2e8c8d6f782e3382f0 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 16 Nov 2023 12:20:13 +0100 Subject: [PATCH 029/189] fixed type hints --- dagger/utilities/dbt_config_parser.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index afc011a..4d444b8 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -1,6 +1,7 @@ from os import path from os.path import join -from typing import Union, Tuple +from typing import Union, Tuple, List, Dict + import json import yaml @@ -8,6 +9,7 @@ S3_IO_BASE = {"type": "s3"} + class DBTConfigParser: """ Module that parses the manifest.json file generated by dbt and generates the dagger inputs and outputs for the respective dbt model @@ -29,14 +31,14 @@ def __init__(self, default_config_parameters: dict): "s3_data_dir" ) or prod_dbt_profile.get("s3_staging_dir") - def generate_io(self, model_name: str) -> Tuple[list[dict], list[dict]]: + def generate_io(self, model_name: str) -> Tuple[List[Dict], List[Dict]]: """ Generates the dagger inputs and outputs for the respective dbt model Args: model_name: name of the dbt model Returns: - tuple[list[dict], list[dict]]: dagger inputs and outputs for the respective dbt model + tuple[List[Dict], List[Dict]]: dagger inputs and outputs for the respective dbt model """ model_parents = self._get_dbt_model_parents(model_name) @@ -70,7 +72,7 @@ def parse_dbt_staging_model(self, dbt_staging_model: str) -> Tuple[str, str]: def generate_dagger_inputs( self, dbt_model_parents: dict - ) -> Union[list[dict], None]: + ) -> Union[List[Dict], None]: """ Generates the dagger inputs for the respective dbt model. This means that all parents of the dbt model are added as dagger inputs. Staging models are added as Athena inputs and core models are added as Athena and S3 inputs. @@ -79,7 +81,7 @@ def generate_dagger_inputs( dbt_model_parents: All parents of the dbt model Returns: - Union[list[dict], None]: dagger inputs for the respective dbt model. If there are no parents, returns None + Union[List[Dict], None]: dagger inputs for the respective dbt model. If there are no parents, returns None """ dagger_inputs = [] @@ -111,7 +113,7 @@ def generate_dagger_inputs( def generate_dagger_outputs( self, model_name: str, schema: str, relative_s3_path: str - ) -> list[dict]: + ) -> List[Dict]: """ Generates the dagger outputs for the respective dbt model. This means that an Athena and S3 output is added for the dbt model. @@ -121,7 +123,7 @@ def generate_dagger_outputs( relative_s3_path: The S3 path of the dbt model relative to the data bucket Returns: - list[dict]: dagger S3 and Athena outputs for the respective dbt model + List[Dict]: dagger S3 and Athena outputs for the respective dbt model """ athena_input = ATHENA_IO_BASE.copy() From 745311dd57d85de980eea6d643816d6a2d627dcf Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Fri, 17 Nov 2023 21:18:40 +0100 Subject: [PATCH 030/189] refactored for simplicity --- dagger/utilities/dbt_config_parser.py | 186 ++++++++------------------ dagger/utilities/module.py | 2 +- 2 files changed, 56 insertions(+), 132 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index afc011a..d811650 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -1,11 +1,13 @@ +import json from os import path from os.path import join -from typing import Union, Tuple -import json +from pprint import pprint +from typing import Tuple, List, Dict + import yaml -ATHENA_IO_BASE = {"type": "athena"} -S3_IO_BASE = {"type": "s3"} +ATHENA_TASK_BASE = {"type": "athena"} +S3_TASK_BASE = {"type": "s3"} class DBTConfigParser: @@ -24,116 +26,58 @@ def __init__(self, default_config_parameters: dict): data = f.read() self._manifest_data = json.loads(data) profile_yaml = yaml.safe_load(open(dbt_profile_path, "r")) - prod_dbt_profile = profile_yaml[self._dbt_project_dir.split("/")[-1]]["outputs"]["data"] + prod_dbt_profile = profile_yaml[self._dbt_project_dir.split("/")[-1]][ + "outputs" + ]["data"] self._default_data_dir = prod_dbt_profile.get( "s3_data_dir" ) or prod_dbt_profile.get("s3_staging_dir") - def generate_io(self, model_name: str) -> Tuple[list[dict], list[dict]]: + def _generate_dagger_dependency(self, node: dict) -> List[Dict]: """ - Generates the dagger inputs and outputs for the respective dbt model + Generates the dagger task based on whether the DBT model node is a staging model or not. + If the DBT model node represents a staging model, then a dagger athena task is generated for each source of the DBT model. + If the DBT model node is not a staging model, then a dagger athena task and an s3 task is generated for the DBT model node itself. Args: - model_name: name of the dbt model - - Returns: - tuple[list[dict], list[dict]]: dagger inputs and outputs for the respective dbt model - - """ - model_parents = self._get_dbt_model_parents(model_name) - model_dagger_inputs = self.generate_dagger_inputs(model_parents) - model_dagger_outputs = self.generate_dagger_outputs( - model_parents["model_name"], - model_parents["schema"], - model_parents["relative_s3_path"], - ) - - return model_dagger_inputs, model_dagger_outputs - - def parse_dbt_staging_model(self, dbt_staging_model: str) -> Tuple[str, str]: - """ - Parses the dbt staging model to get the core schema and table name - Args: - dbt_staging_model: name of the DBT staging model + node: The extracted node from the manifest.json file Returns: - Tuple[str, str]: core schema and table name - - >>> parse_dbt_staging_model("schema_name__table") - ('schema_name', 'table') - >>> parse_dbt_staging_model("another_schema__another_table") - ('another_schema', 'another_table') - """ - _model_split, core_table = dbt_staging_model.split("__") - core_schema = _model_split.split("_")[-1] - - return core_schema, core_table + List[Dict]: The respective dagger tasks for the DBT model node - def generate_dagger_inputs( - self, dbt_model_parents: dict - ) -> Union[list[dict], None]: """ - Generates the dagger inputs for the respective dbt model. This means that all parents of the dbt model are added as dagger inputs. - Staging models are added as Athena inputs and core models are added as Athena and S3 inputs. - Intermediate models are not added as an input. - Args: - dbt_model_parents: All parents of the dbt model + model_name = node["name"] - Returns: - Union[list[dict], None]: dagger inputs for the respective dbt model. If there are no parents, returns None + s3_task = S3_TASK_BASE.copy() + dagger_tasks = [] - """ - dagger_inputs = [] - for parent in dbt_model_parents["inputs"]: - model_name = parent["model_name"] - athena_input = ATHENA_IO_BASE.copy() - s3_input = S3_IO_BASE.copy() - - if model_name.startswith("stg_"): - athena_input["name"] = model_name - ( - athena_input["schema"], - athena_input["table"], - ) = self.parse_dbt_staging_model(model_name) - - dagger_inputs.append(athena_input) - else: - athena_input["name"] = athena_input["table"] = model_name - athena_input["schema"] = parent["schema"] - - s3_input["name"] = model_name - s3_input["bucket"] = self._default_data_bucket - s3_input["path"] = parent["relative_s3_path"] - - dagger_inputs.append(athena_input) - dagger_inputs.append(s3_input) - - return dagger_inputs or None - - def generate_dagger_outputs( - self, model_name: str, schema: str, relative_s3_path: str - ) -> list[dict]: - """ - Generates the dagger outputs for the respective dbt model. - This means that an Athena and S3 output is added for the dbt model. - Args: - model_name: The name of the dbt model - schema: The schema of the dbt model - relative_s3_path: The S3 path of the dbt model relative to the data bucket + if model_name.startswith("stg_"): + source_nodes = node.get("depends_on", {}).get("nodes", []) + for source_node in source_nodes: + _, project_name, schema_name, table_name = source_node.split(".") + athena_task = ATHENA_TASK_BASE.copy() - Returns: - list[dict]: dagger S3 and Athena outputs for the respective dbt model + athena_task["name"] = f"stg_{schema_name}__{table_name}" + athena_task["schema"] = schema_name + athena_task["table"] = table_name - """ - athena_input = ATHENA_IO_BASE.copy() - s3_input = S3_IO_BASE.copy() + dagger_tasks.append(athena_task) + else: + athena_task = ATHENA_TASK_BASE.copy() + model_schema = node["schema"] + athena_task["name"] = f"{model_schema}_{model_name}_athena" + athena_task["table"] = model_name + athena_task["schema"] = node["schema"] - athena_input["name"] = athena_input["table"] = s3_input["name"] = model_name - athena_input["schema"] = schema + s3_task["name"] = f"{model_schema}_{model_name}_s3" + s3_task["bucket"] = self._default_data_bucket + s3_task["path"] = self._get_model_data_location( + node, model_schema, model_name + ) - s3_input["bucket"] = "cho${ENV}-data-lake" - s3_input["relative_s3_path"] = relative_s3_path + dagger_tasks.append(athena_task) + dagger_tasks.append(s3_task) - return [athena_input, s3_input] + return dagger_tasks def _get_model_data_location( self, node: dict, schema: str, dbt_model_name: str @@ -148,59 +92,39 @@ def _get_model_data_location( dbt_model_name: The name of the dbt model Returns: - str: The S3 path of the dbt model relative to the data bucket + str: The relative S3 path of the dbt model relative to the data bucket """ location = node.get("config", {}).get("external_location") if not location: location = join(self._default_data_dir, schema, dbt_model_name) - return location.split("data-lake/")[1] + return location.split(self._default_data_bucket)[1].lstrip("/") - def _get_dbt_model_parents(self, model_name: str) -> dict: + def generate_dagger_io(self, model_name: str) -> Tuple[list, list]: """ - Gets all parents of a single dbt model from the manifest.json file + Parse through all the parents of the DBT model and return the dagger inputs and outputs for the DBT model Args: model_name: The name of the DBT model Returns: - dict: All parents of the dbt model along with the name, schema and S3 path of the dbt model itself + Tuple[list, list]: The dagger inputs and outputs for the DBT model """ - inputs_dict = {} inputs_list = [] - dbt_ref_to_model = f"model.{self._dbt_project_dir}.{model_name}" nodes = self._manifest_data["nodes"] - model_info = nodes[f"model.main.{model_name}"] + model_node = nodes[f"model.main.{model_name}"] - parent_node_names = model_info.get("depends_on", {}).get("nodes", []) - parent_model_names = [x.split(".")[-1] for x in parent_node_names] + parent_node_names = model_node.get("depends_on", {}).get("nodes", []) for index, parent_node_name in enumerate(parent_node_names): if not (".int_" in parent_node_name): - parent_model_name = parent_node_name.split(".")[-1] parent_model_node = nodes.get(parent_node_name) - parent_schema = parent_model_node.get("schema") - - model_data_location = self._get_model_data_location( - parent_model_node, parent_schema, parent_model_name - ) - - inputs_list.append( - { - "schema": parent_schema, - "model_name": parent_model_names[index], - "relative_s3_path": model_data_location, - } - ) - - inputs_dict["model_name"] = model_name - inputs_dict["node_name"] = dbt_ref_to_model - inputs_dict["inputs"] = inputs_list - inputs_dict["schema"] = model_info["schema"] - inputs_dict["relative_s3_path"] = self._get_model_data_location( - model_info, model_info["schema"], model_name - ) - - return inputs_dict + dagger_input = self._generate_dagger_dependency(parent_model_node) + + inputs_list += dagger_input + + output_list = self._generate_dagger_dependency(model_node) + + return inputs_list, output_list diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index ee7540d..7b954e3 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -85,7 +85,7 @@ def generate_task_configs(self): task_dict = yaml.safe_load(task_str) if task == 'dbt': - inputs, outputs = self._dbt_module.generate_io(branch_name) + inputs, outputs = self._dbt_module.generate_dagger_io(branch_name) task_dict['inputs'] = inputs task_dict['outputs'] = outputs task_dict['task_parameters']['select'] = branch_name From b52e0b540dc023cd64574c43ad8590629a757e1a Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Fri, 17 Nov 2023 21:19:56 +0100 Subject: [PATCH 031/189] updated tests --- .../modules/dbt_config_parser_fixtures.py | 86 ++++++++++++------- tests/utilities/test_dbt_config_parser.py | 54 ++++++++---- 2 files changed, 94 insertions(+), 46 deletions(-) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures.py index 821c5b8..d3ec583 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures.py @@ -25,7 +25,7 @@ "macro.main.macro2", ], "nodes": [ - "model.main.stg_core_schema1__table1", + "model.main.stg_core_schema2__table2", "model.main.model2", "model.main.int_model3", ], @@ -33,14 +33,32 @@ }, "model.main.stg_core_schema1__table1": { "schema": "analytics_engineering", + "name": "stg_core_schema1__table1", + "depends_on": { + "macros": [], + "nodes": ["source.main.core_schema1.table1"], + }, + }, + "model.main.stg_core_schema2__table2": { + "schema": "analytics_engineering", + "name": "stg_core_schema2__table2", + "depends_on": { + "macros": [], + "nodes": [ + "source.main.core_schema2.table2", + "source.main.core_schema2.table3", + ], + }, }, "model.main.model2": { + "name": "model2", "schema": "analytics_engineering", "config": { "external_location": "s3://bucket1-data-lake/path2/model2", }, }, "model.main.int_model3": { + "name": "int_model3", "schema": "analytics_engineering", }, } @@ -66,41 +84,51 @@ } } -EXPECTED_DBT_MODEL_PARENTS = { - "inputs": [ - { - "model_name": "stg_core_schema1__table1", - "relative_s3_path": "path1/tmp/analytics_engineering/stg_core_schema1__table1", - "schema": "analytics_engineering", - }, - { - "model_name": "model2", - "relative_s3_path": "path2/model2", - "schema": "analytics_engineering", - }, - ], - "model_name": "model1", - "node_name": "model.main.model1", - "relative_s3_path": "path1/model1", - "schema": "analytics_engineering", -} - -EXPECTED_DAGGER_INPUTS = [ +EXPECTED_STAGING_NODE = [ { + "type": "athena", "name": "stg_core_schema1__table1", - "schema": "schema1", + "schema": "core_schema1", "table": "table1", + } +] +EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES = [ + { + "type": "athena", + "name": "stg_core_schema2__table2", + "schema": "core_schema2", + "table": "table2", + }, + { + "type": "athena", + "name": "stg_core_schema2__table3", + "schema": "core_schema2", + "table": "table3", + }, +] + +EXPECTED_DAGGER_INPUTS = [ + { + "name": "stg_core_schema2__table2", + "schema": "core_schema2", + "table": "table2", "type": "athena", }, { - "name": "model2", + "name": "stg_core_schema2__table3", + "schema": "core_schema2", + "table": "table3", + "type": "athena", + }, + { + "name": "analytics_engineering_model2_athena", "schema": "analytics_engineering", "table": "model2", "type": "athena", }, { "bucket": "bucket1-data-lake", - "name": "model2", + "name": "analytics_engineering_model2_s3", "path": "path2/model2", "type": "s3", }, @@ -108,15 +136,15 @@ EXPECTED_DAGGER_OUTPUTS = [ { - "name": "model1", + "name": "analytics_engineering_fct_supplier_revenue_athena", "schema": "analytics_engineering", - "table": "model1", + "table": "fct_supplier_revenue", "type": "athena", }, { - "bucket": "cho${ENV}-data-lake", - "name": "model1", - "relative_s3_path": "path1/model1", + "bucket": "bucket1-data-lake", + "name": "analytics_engineering_fct_supplier_revenue_s3", + "path": "path1/model1", "type": "s3", }, ] diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index a09bbc4..e05fe4e 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -6,11 +6,12 @@ from dagger.utilities.dbt_config_parser import DBTConfigParser from dagger.utilities.module import Module from tests.fixtures.modules.dbt_config_parser_fixtures import ( - EXPECTED_DBT_MODEL_PARENTS, + EXPECTED_DAGGER_OUTPUTS, EXPECTED_DAGGER_INPUTS, DBT_MANIFEST_FILE_FIXTURE, DBT_PROFILE_FIXTURE, - EXPECTED_DAGGER_OUTPUTS, + EXPECTED_STAGING_NODE, + EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES, ) _logger = logging.getLogger("root") @@ -29,25 +30,44 @@ class TestDBTConfigParser(unittest.TestCase): @patch("yaml.safe_load", return_value=DBT_PROFILE_FIXTURE) def setUp(self, mock_open, mock_json_load, mock_safe_load): self._dbt_config_parser = DBTConfigParser(DEFAULT_CONFIG_PARAMS) + self._sample_dbt_node = DBT_MANIFEST_FILE_FIXTURE["nodes"][ + "model.main.stg_core_schema1__table1" + ] + @skip("Run only locally") + def test_generate_task_configs(self): + module = Module( + path_to_config="./tests/fixtures/modules/dbt_test_config.yaml", + target_dir="./tests/fixtures/modules/", + ) - def test_get_dbt_model_parents(self): - result = self._dbt_config_parser._get_dbt_model_parents(MODEL_NAME) + module.generate_task_configs() - self.assertDictEqual(result, EXPECTED_DBT_MODEL_PARENTS) + def test_generate_dagger_dependency(self): + test_inputs = [ + ( + DBT_MANIFEST_FILE_FIXTURE["nodes"][ + "model.main.stg_core_schema1__table1" + ], + EXPECTED_STAGING_NODE, + ), + ( + DBT_MANIFEST_FILE_FIXTURE["nodes"][ + "model.main.stg_core_schema2__table2" + ], + EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES, + ), + ] + for mock_input, expected_output in test_inputs: + result = self._dbt_config_parser._generate_dagger_dependency(mock_input) + self.assertListEqual(result, expected_output) - def test_generate_dagger_inputs(self): - result_inputs = self._dbt_config_parser.generate_dagger_inputs( - EXPECTED_DBT_MODEL_PARENTS - ) + def test_generate_io_inputs(self): + result, _ = self._dbt_config_parser.generate_dagger_io(MODEL_NAME) - self.assertListEqual(result_inputs, EXPECTED_DAGGER_INPUTS) + self.assertListEqual(result, EXPECTED_DAGGER_INPUTS) - def test_generate_dagger_outputs(self): - result_outputs = self._dbt_config_parser.generate_dagger_outputs( - EXPECTED_DBT_MODEL_PARENTS["model_name"], - EXPECTED_DBT_MODEL_PARENTS["schema"], - EXPECTED_DBT_MODEL_PARENTS["relative_s3_path"], - ) + def test_generate_io_outputs(self): + _, result = self._dbt_config_parser.generate_dagger_io(MODEL_NAME) - self.assertListEqual(result_outputs, EXPECTED_DAGGER_OUTPUTS) + self.assertListEqual(result, EXPECTED_DAGGER_OUTPUTS) From 6e7ced264e0ae3e65472b56fc810e59947966b7e Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Mon, 20 Nov 2023 09:40:26 +0100 Subject: [PATCH 032/189] added dbt profile to default parameters for dbt task --- dagger/utilities/dbt_config_parser.py | 3 ++- tests/utilities/test_dbt_config_parser.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index bb0d332..4725afd 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -17,6 +17,7 @@ class DBTConfigParser: """ def __init__(self, default_config_parameters: dict): + self._dbt_profile = default_config_parameters.get("dbt_profile", "data") self._default_data_bucket = default_config_parameters["data_bucket"] self._dbt_project_dir = default_config_parameters.get("project_dir", None) dbt_manifest_path = path.join(self._dbt_project_dir, "target", "manifest.json") @@ -29,7 +30,7 @@ def __init__(self, default_config_parameters: dict): profile_yaml = yaml.safe_load(open(dbt_profile_path, "r")) prod_dbt_profile = profile_yaml[self._dbt_project_dir.split("/")[-1]][ "outputs" - ]["data"] + ][self._dbt_profile] self._default_data_dir = prod_dbt_profile.get( "s3_data_dir" ) or prod_dbt_profile.get("s3_staging_dir") diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index e05fe4e..34d37a1 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -20,6 +20,7 @@ "data_bucket": "bucket1-data-lake", "project_dir": "main", "profile_dir": ".dbt", + "dbt_profile": "data", } MODEL_NAME = "model1" From e8cbc6f90a9654d283893876a9c0f1537f4b573a Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Mon, 20 Nov 2023 20:01:43 +0100 Subject: [PATCH 033/189] added follow external dependency as true as default for athena task --- dagger/utilities/dbt_config_parser.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 4725afd..bb06f3a 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -6,11 +6,10 @@ import yaml -ATHENA_TASK_BASE = {"type": "athena"} +ATHENA_TASK_BASE = {"type": "athena", "follow_external_dependency": True} S3_TASK_BASE = {"type": "s3"} - class DBTConfigParser: """ Module that parses the manifest.json file generated by dbt and generates the dagger inputs and outputs for the respective dbt model From e62d35cd054e00b2b7f132da347a58718677da39 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Mon, 20 Nov 2023 20:03:04 +0100 Subject: [PATCH 034/189] add fn to process seed input --- dagger/utilities/dbt_config_parser.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index bb06f3a..73fac6b 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -34,6 +34,22 @@ def __init__(self, default_config_parameters: dict): "s3_data_dir" ) or prod_dbt_profile.get("s3_staging_dir") + def _process_seed_input(self, seed_node: dict) -> dict: + """ + Generates a dummy dagger task for the DBT seed node + Args: + seed_node: The extracted seed node from the manifest.json file + + Returns: + dict: The dummy dagger task for the DBT seed node + + """ + task = {} + task["name"] = seed_node.get("name", "") + task["type"] = "dummy" + + return task + def _generate_dagger_dependency(self, node: dict) -> List[Dict]: """ Generates the dagger task based on whether the DBT model node is a staging model or not. From 8de68213c7506ff4ad3ec4e040c2743efc7f28c3 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Mon, 20 Nov 2023 20:03:34 +0100 Subject: [PATCH 035/189] refactor code to incorporate seeds --- dagger/utilities/dbt_config_parser.py | 40 +++++++++++++++++---------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 73fac6b..caa3f19 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -33,6 +33,10 @@ def __init__(self, default_config_parameters: dict): self._default_data_dir = prod_dbt_profile.get( "s3_data_dir" ) or prod_dbt_profile.get("s3_staging_dir") + self._default_schema = prod_dbt_profile.get("schema") + + self._nodes_in_manifest = self._manifest_data["nodes"] + self._sources_in_manifest = self._manifest_data["sources"] def _process_seed_input(self, seed_node: dict) -> dict: """ @@ -67,23 +71,30 @@ def _generate_dagger_dependency(self, node: dict) -> List[Dict]: s3_task = S3_TASK_BASE.copy() dagger_tasks = [] - if model_name.startswith("stg_"): - source_nodes = node.get("depends_on", {}).get("nodes", []) - for source_node in source_nodes: - _, project_name, schema_name, table_name = source_node.split(".") - athena_task = ATHENA_TASK_BASE.copy() - - athena_task["name"] = f"stg_{schema_name}__{table_name}" - athena_task["schema"] = schema_name - athena_task["table"] = table_name - - dagger_tasks.append(athena_task) + if node.get("resource_type") == "seed": + task = self._process_seed_input(node) + dagger_tasks.append(task) + elif model_name.startswith("stg_"): + source_node_names = node.get("depends_on", {}).get("nodes", []) + for source_node_name in source_node_names: + if source_node_name.startswith("seed"): + source_node = self._nodes_in_manifest[source_node_name] + task = self._process_seed_input(source_node) + else: + source_node = self._sources_in_manifest[source_node_name] + task = ATHENA_TASK_BASE.copy() + + task["schema"] = source_node.get("schema", self._default_schema) + task["table"] = source_node.get("name", "") + task["name"] = f"stg_{task['schema']}__{task['table']}" + + dagger_tasks.append(task) else: athena_task = ATHENA_TASK_BASE.copy() model_schema = node["schema"] athena_task["name"] = f"{model_schema}_{model_name}_athena" athena_task["table"] = model_name - athena_task["schema"] = node["schema"] + athena_task["schema"] = node.get("schema", self._default_schema) s3_task["name"] = f"{model_schema}_{model_name}_s3" s3_task["bucket"] = self._default_data_bucket @@ -130,14 +141,13 @@ def generate_dagger_io(self, model_name: str) -> Tuple[list, list]: """ inputs_list = [] - nodes = self._manifest_data["nodes"] - model_node = nodes[f"model.main.{model_name}"] + model_node = self._nodes_in_manifest[f"model.main.{model_name}"] parent_node_names = model_node.get("depends_on", {}).get("nodes", []) for index, parent_node_name in enumerate(parent_node_names): if not (".int_" in parent_node_name): - parent_model_node = nodes.get(parent_node_name) + parent_model_node = self._nodes_in_manifest.get(parent_node_name) dagger_input = self._generate_dagger_dependency(parent_model_node) inputs_list += dagger_input From 5f26723d449b606ceaf09781c5f8ec418deac3c5 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Mon, 20 Nov 2023 20:03:47 +0100 Subject: [PATCH 036/189] updates fixtures --- .../modules/dbt_config_parser_fixtures.py | 46 ++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures.py index d3ec583..ff17e07 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures.py @@ -61,7 +61,44 @@ "name": "int_model3", "schema": "analytics_engineering", }, - } + "seed.main.seed_buyer_country_overwrite": { + "database": "awsdatacatalog", + "schema": "analytics_engineering", + "name": "seed_buyer_country_overwrite", + "resource_type": "seed", + "alias": "seed_buyer_country_overwrite", + "tags": ["analytics"], + "description": "", + "created_at": 1700216177.105391, + "depends_on": {"macros": []}, + }, + }, + "sources": { + "source.main.core_schema1.table1": { + "source_name": "table1", + "database": "awsdatacatalog", + "schema": "core_schema1", + "name": "table1", + "tags": ["analytics"], + "description": "", + }, + "source.main.core_schema2.table2": { + "source_name": "table2", + "database": "awsdatacatalog", + "schema": "core_schema2", + "name": "table2", + "tags": ["analytics"], + "description": "", + }, + "source.main.core_schema2.table3": { + "source_name": "table3", + "database": "awsdatacatalog", + "schema": "core_schema2", + "name": "table3", + "tags": ["analytics"], + "description": "", + }, + }, } DBT_PROFILE_FIXTURE = { @@ -90,6 +127,7 @@ "name": "stg_core_schema1__table1", "schema": "core_schema1", "table": "table1", + "follow_external_dependency": True, } ] EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES = [ @@ -98,12 +136,14 @@ "name": "stg_core_schema2__table2", "schema": "core_schema2", "table": "table2", + "follow_external_dependency": True, }, { "type": "athena", "name": "stg_core_schema2__table3", "schema": "core_schema2", "table": "table3", + "follow_external_dependency": True, }, ] @@ -113,18 +153,21 @@ "schema": "core_schema2", "table": "table2", "type": "athena", + "follow_external_dependency": True, }, { "name": "stg_core_schema2__table3", "schema": "core_schema2", "table": "table3", "type": "athena", + "follow_external_dependency": True, }, { "name": "analytics_engineering_model2_athena", "schema": "analytics_engineering", "table": "model2", "type": "athena", + "follow_external_dependency": True, }, { "bucket": "bucket1-data-lake", @@ -140,6 +183,7 @@ "schema": "analytics_engineering", "table": "fct_supplier_revenue", "type": "athena", + "follow_external_dependency": True, }, { "bucket": "bucket1-data-lake", From 9acaff5de7b454d1a6487b7d6cd1a6661342c310 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Mon, 20 Nov 2023 20:13:47 +0100 Subject: [PATCH 037/189] added test for dbt seed --- tests/utilities/test_dbt_config_parser.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index 34d37a1..d9a85d3 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -12,6 +12,7 @@ DBT_PROFILE_FIXTURE, EXPECTED_STAGING_NODE, EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES, + EXPECTED_SEED_NODE, ) _logger = logging.getLogger("root") @@ -58,6 +59,12 @@ def test_generate_dagger_dependency(self): ], EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES, ), + ( + DBT_MANIFEST_FILE_FIXTURE["nodes"][ + "seed.main.seed_buyer_country_overwrite" + ], + EXPECTED_SEED_NODE, + ), ] for mock_input, expected_output in test_inputs: result = self._dbt_config_parser._generate_dagger_dependency(mock_input) From 876010c9d49910f1c9a263f5e112e60d1b15922e Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Mon, 20 Nov 2023 20:14:09 +0100 Subject: [PATCH 038/189] modified tests for model containing dbt seed as a dependency --- .../fixtures/modules/dbt_config_parser_fixtures.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures.py index ff17e07..3033def 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures.py @@ -47,6 +47,7 @@ "nodes": [ "source.main.core_schema2.table2", "source.main.core_schema2.table3", + "seed.main.seed_buyer_country_overwrite", ], }, }, @@ -145,6 +146,17 @@ "table": "table3", "follow_external_dependency": True, }, + { + "type": "dummy", + "name": "seed_buyer_country_overwrite", + }, +] + +EXPECTED_SEED_NODE = [ + { + "type": "dummy", + "name": "seed_buyer_country_overwrite", + } ] EXPECTED_DAGGER_INPUTS = [ @@ -162,6 +174,7 @@ "type": "athena", "follow_external_dependency": True, }, + {"name": "seed_buyer_country_overwrite", "type": "dummy"}, { "name": "analytics_engineering_model2_athena", "schema": "analytics_engineering", From 9cd0a51b615809975e29942e03ca1234f0bb2dc7 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Fri, 24 Nov 2023 12:52:44 +0100 Subject: [PATCH 039/189] refactor * deduplicate list of input dictionaries * created functions that generate the seed input and athena and s3 tasks * removed the follow_external_dependency as true as default for all athena inputs --- dagger/utilities/dbt_config_parser.py | 79 ++++++++++++++++++--------- 1 file changed, 52 insertions(+), 27 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index caa3f19..d3ce6c7 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -1,12 +1,12 @@ import json +from collections import OrderedDict from os import path from os.path import join -from pprint import pprint from typing import Tuple, List, Dict import yaml -ATHENA_TASK_BASE = {"type": "athena", "follow_external_dependency": True} +ATHENA_TASK_BASE = {"type": "athena"} S3_TASK_BASE = {"type": "s3"} @@ -38,7 +38,7 @@ def __init__(self, default_config_parameters: dict): self._nodes_in_manifest = self._manifest_data["nodes"] self._sources_in_manifest = self._manifest_data["sources"] - def _process_seed_input(self, seed_node: dict) -> dict: + def _generate_seed_input(self, seed_node: dict) -> dict: """ Generates a dummy dagger task for the DBT seed node Args: @@ -54,7 +54,39 @@ def _process_seed_input(self, seed_node: dict) -> dict: return task - def _generate_dagger_dependency(self, node: dict) -> List[Dict]: + def _get_athena_task( + self, node: dict, follow_external_dependency: bool = False + ) -> dict: + node_name = node.get("unique_id", "") + + task = ATHENA_TASK_BASE.copy() + if follow_external_dependency: + task["follow_external_dependency"] = True + + task["schema"] = node.get("schema", self._default_schema) + task["table"] = node.get("name", "") + task["name"] = f"{task['schema']}__{task['table']}_athena" + + return task + + def _get_s3_task(self, node: dict) -> dict: + task = S3_TASK_BASE.copy() + + schema = node.get("schema", self._default_schema) + table = node.get("name", "") + task["name"] = f"{schema}__{table}_s3" + task["bucket"] = self._default_data_bucket + task["path"] = self._get_model_data_location(node, schema, table) + + return task + + def _generate_dagger_output(self, node: dict): + return [self._get_athena_task(node), self._get_s3_task(node)] + + def _generate_dagger_inputs( + self, + node: dict, + ) -> List[Dict]: """ Generates the dagger task based on whether the DBT model node is a staging model or not. If the DBT model node represents a staging model, then a dagger athena task is generated for each source of the DBT model. @@ -67,40 +99,27 @@ def _generate_dagger_dependency(self, node: dict) -> List[Dict]: """ model_name = node["name"] - - s3_task = S3_TASK_BASE.copy() dagger_tasks = [] if node.get("resource_type") == "seed": - task = self._process_seed_input(node) + task = self._generate_seed_input(node) dagger_tasks.append(task) elif model_name.startswith("stg_"): source_node_names = node.get("depends_on", {}).get("nodes", []) for source_node_name in source_node_names: if source_node_name.startswith("seed"): source_node = self._nodes_in_manifest[source_node_name] - task = self._process_seed_input(source_node) + task = self._generate_seed_input(source_node) else: source_node = self._sources_in_manifest[source_node_name] - task = ATHENA_TASK_BASE.copy() - - task["schema"] = source_node.get("schema", self._default_schema) - task["table"] = source_node.get("name", "") - task["name"] = f"stg_{task['schema']}__{task['table']}" + task = self._get_athena_task( + source_node, follow_external_dependency=True + ) dagger_tasks.append(task) else: - athena_task = ATHENA_TASK_BASE.copy() - model_schema = node["schema"] - athena_task["name"] = f"{model_schema}_{model_name}_athena" - athena_task["table"] = model_name - athena_task["schema"] = node.get("schema", self._default_schema) - - s3_task["name"] = f"{model_schema}_{model_name}_s3" - s3_task["bucket"] = self._default_data_bucket - s3_task["path"] = self._get_model_data_location( - node, model_schema, model_name - ) + athena_task = self._get_athena_task(node, follow_external_dependency=True) + s3_task = self._get_s3_task(node) dagger_tasks.append(athena_task) dagger_tasks.append(s3_task) @@ -148,10 +167,16 @@ def generate_dagger_io(self, model_name: str) -> Tuple[list, list]: for index, parent_node_name in enumerate(parent_node_names): if not (".int_" in parent_node_name): parent_model_node = self._nodes_in_manifest.get(parent_node_name) - dagger_input = self._generate_dagger_dependency(parent_model_node) + dagger_input = self._generate_dagger_inputs(parent_model_node) inputs_list += dagger_input - output_list = self._generate_dagger_dependency(model_node) + output_list = self._generate_dagger_output(model_node) + + unique_inputs = list( + OrderedDict( + (frozenset(item.items()), item) for item in inputs_list + ).values() + ) - return inputs_list, output_list + return unique_inputs, output_list From 2e6abf6f3bd3ebee59ffd6ad29f1aff00c13c3c1 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Fri, 24 Nov 2023 12:52:59 +0100 Subject: [PATCH 040/189] updated tests and fixtures --- .../modules/dbt_config_parser_fixtures.py | 34 ++++++++++++------- tests/utilities/test_dbt_config_parser.py | 21 +++++++----- 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures.py index 3033def..3e73c0a 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures.py @@ -3,7 +3,8 @@ "model.main.model1": { "database": "awsdatacatalog", "schema": "analytics_engineering", - "name": "fct_supplier_revenue", + "unique_id": "model.main.model1", + "name": "model1", "config": { "external_location": "s3://bucket1-data-lake/path1/model1", "materialized": "incremental", @@ -28,11 +29,13 @@ "model.main.stg_core_schema2__table2", "model.main.model2", "model.main.int_model3", + "seed.main.seed_buyer_country_overwrite", ], }, }, "model.main.stg_core_schema1__table1": { "schema": "analytics_engineering", + "unique_id": "model.main.stg_core_schema1__table1", "name": "stg_core_schema1__table1", "depends_on": { "macros": [], @@ -42,6 +45,7 @@ "model.main.stg_core_schema2__table2": { "schema": "analytics_engineering", "name": "stg_core_schema2__table2", + "unique_id": "model.main.stg_core_schema2__table2", "depends_on": { "macros": [], "nodes": [ @@ -54,17 +58,21 @@ "model.main.model2": { "name": "model2", "schema": "analytics_engineering", + "unique_id": "model.main.model2", "config": { "external_location": "s3://bucket1-data-lake/path2/model2", }, + "depends_on": {"macros": [], "nodes": []}, }, "model.main.int_model3": { "name": "int_model3", + "unique_id": "model.main.int_model3", "schema": "analytics_engineering", }, "seed.main.seed_buyer_country_overwrite": { "database": "awsdatacatalog", "schema": "analytics_engineering", + "unique_id": "seed.main.seed_buyer_country_overwrite", "name": "seed_buyer_country_overwrite", "resource_type": "seed", "alias": "seed_buyer_country_overwrite", @@ -79,6 +87,7 @@ "source_name": "table1", "database": "awsdatacatalog", "schema": "core_schema1", + "unique_id": "source.main.core_schema1.table1", "name": "table1", "tags": ["analytics"], "description": "", @@ -87,6 +96,7 @@ "source_name": "table2", "database": "awsdatacatalog", "schema": "core_schema2", + "unique_id": "source.main.core_schema2.table2", "name": "table2", "tags": ["analytics"], "description": "", @@ -95,6 +105,7 @@ "source_name": "table3", "database": "awsdatacatalog", "schema": "core_schema2", + "unique_id": "source.main.core_schema2.table3", "name": "table3", "tags": ["analytics"], "description": "", @@ -125,7 +136,7 @@ EXPECTED_STAGING_NODE = [ { "type": "athena", - "name": "stg_core_schema1__table1", + "name": "core_schema1__table1_athena", "schema": "core_schema1", "table": "table1", "follow_external_dependency": True, @@ -134,14 +145,14 @@ EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES = [ { "type": "athena", - "name": "stg_core_schema2__table2", + "name": "core_schema2__table2_athena", "schema": "core_schema2", "table": "table2", "follow_external_dependency": True, }, { "type": "athena", - "name": "stg_core_schema2__table3", + "name": "core_schema2__table3_athena", "schema": "core_schema2", "table": "table3", "follow_external_dependency": True, @@ -161,14 +172,14 @@ EXPECTED_DAGGER_INPUTS = [ { - "name": "stg_core_schema2__table2", + "name": "core_schema2__table2_athena", "schema": "core_schema2", "table": "table2", "type": "athena", "follow_external_dependency": True, }, { - "name": "stg_core_schema2__table3", + "name": "core_schema2__table3_athena", "schema": "core_schema2", "table": "table3", "type": "athena", @@ -176,7 +187,7 @@ }, {"name": "seed_buyer_country_overwrite", "type": "dummy"}, { - "name": "analytics_engineering_model2_athena", + "name": "analytics_engineering__model2_athena", "schema": "analytics_engineering", "table": "model2", "type": "athena", @@ -184,7 +195,7 @@ }, { "bucket": "bucket1-data-lake", - "name": "analytics_engineering_model2_s3", + "name": "analytics_engineering__model2_s3", "path": "path2/model2", "type": "s3", }, @@ -192,15 +203,14 @@ EXPECTED_DAGGER_OUTPUTS = [ { - "name": "analytics_engineering_fct_supplier_revenue_athena", + "name": "analytics_engineering__model1_athena", "schema": "analytics_engineering", - "table": "fct_supplier_revenue", + "table": "model1", "type": "athena", - "follow_external_dependency": True, }, { "bucket": "bucket1-data-lake", - "name": "analytics_engineering_fct_supplier_revenue_s3", + "name": "analytics_engineering__model1_s3", "path": "path1/model1", "type": "s3", }, diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index d9a85d3..752eaaa 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -32,9 +32,7 @@ class TestDBTConfigParser(unittest.TestCase): @patch("yaml.safe_load", return_value=DBT_PROFILE_FIXTURE) def setUp(self, mock_open, mock_json_load, mock_safe_load): self._dbt_config_parser = DBTConfigParser(DEFAULT_CONFIG_PARAMS) - self._sample_dbt_node = DBT_MANIFEST_FILE_FIXTURE["nodes"][ - "model.main.stg_core_schema1__table1" - ] + self._sample_dbt_node = DBT_MANIFEST_FILE_FIXTURE["nodes"]["model.main.model1"] @skip("Run only locally") def test_generate_task_configs(self): @@ -45,37 +43,44 @@ def test_generate_task_configs(self): module.generate_task_configs() - def test_generate_dagger_dependency(self): + def test_generate_dagger_inputs(self): test_inputs = [ ( DBT_MANIFEST_FILE_FIXTURE["nodes"][ "model.main.stg_core_schema1__table1" ], EXPECTED_STAGING_NODE, + True, ), ( DBT_MANIFEST_FILE_FIXTURE["nodes"][ "model.main.stg_core_schema2__table2" ], EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES, + True, ), ( DBT_MANIFEST_FILE_FIXTURE["nodes"][ "seed.main.seed_buyer_country_overwrite" ], EXPECTED_SEED_NODE, + False, ), ] - for mock_input, expected_output in test_inputs: - result = self._dbt_config_parser._generate_dagger_dependency(mock_input) + for mock_input, expected_output, follow_external_dependency in test_inputs: + result = self._dbt_config_parser._generate_dagger_inputs(mock_input) self.assertListEqual(result, expected_output) def test_generate_io_inputs(self): - result, _ = self._dbt_config_parser.generate_dagger_io(MODEL_NAME) + result, _ = self._dbt_config_parser.generate_dagger_io( + self._sample_dbt_node.get("name") + ) self.assertListEqual(result, EXPECTED_DAGGER_INPUTS) def test_generate_io_outputs(self): - _, result = self._dbt_config_parser.generate_dagger_io(MODEL_NAME) + _, result = self._dbt_config_parser.generate_dagger_io( + self._sample_dbt_node.get("name") + ) self.assertListEqual(result, EXPECTED_DAGGER_OUTPUTS) From c652d56873fecf94051a4345f8386ec65d77f348 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Fri, 24 Nov 2023 13:02:58 +0100 Subject: [PATCH 041/189] changed name of seed task generating fn --- dagger/utilities/dbt_config_parser.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index d3ce6c7..dfdcf81 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -38,7 +38,8 @@ def __init__(self, default_config_parameters: dict): self._nodes_in_manifest = self._manifest_data["nodes"] self._sources_in_manifest = self._manifest_data["sources"] - def _generate_seed_input(self, seed_node: dict) -> dict: + @staticmethod + def _generate_seed_task(seed_node: dict) -> dict: """ Generates a dummy dagger task for the DBT seed node Args: @@ -102,14 +103,14 @@ def _generate_dagger_inputs( dagger_tasks = [] if node.get("resource_type") == "seed": - task = self._generate_seed_input(node) + task = self._generate_seed_task(node) dagger_tasks.append(task) elif model_name.startswith("stg_"): source_node_names = node.get("depends_on", {}).get("nodes", []) for source_node_name in source_node_names: if source_node_name.startswith("seed"): source_node = self._nodes_in_manifest[source_node_name] - task = self._generate_seed_input(source_node) + task = self._generate_seed_task(source_node) else: source_node = self._sources_in_manifest[source_node_name] task = self._get_athena_task( From ac5f48f1e0615e3310602128dfc231a86cd0b24a Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Fri, 24 Nov 2023 13:03:05 +0100 Subject: [PATCH 042/189] removed unused line --- dagger/utilities/dbt_config_parser.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index dfdcf81..9f1a2a0 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -58,8 +58,6 @@ def _generate_seed_task(seed_node: dict) -> dict: def _get_athena_task( self, node: dict, follow_external_dependency: bool = False ) -> dict: - node_name = node.get("unique_id", "") - task = ATHENA_TASK_BASE.copy() if follow_external_dependency: task["follow_external_dependency"] = True From 46a833ded562622b1f0768482606ceb09001066e Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Mon, 27 Nov 2023 13:16:06 +0100 Subject: [PATCH 043/189] added docstrings --- dagger/utilities/dbt_config_parser.py | 28 +++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 9f1a2a0..1aa6b4e 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -58,6 +58,16 @@ def _generate_seed_task(seed_node: dict) -> dict: def _get_athena_task( self, node: dict, follow_external_dependency: bool = False ) -> dict: + """ + Generates the dagger athena task for the DBT model node + Args: + node: The extracted node from the manifest.json file + follow_external_dependency: Whether to follow external airflow dependencies or not + + Returns: + dict: The dagger athena task for the DBT model node + + """ task = ATHENA_TASK_BASE.copy() if follow_external_dependency: task["follow_external_dependency"] = True @@ -69,6 +79,15 @@ def _get_athena_task( return task def _get_s3_task(self, node: dict) -> dict: + """ + Generates the dagger s3 task for the DBT model node + Args: + node: The extracted node from the manifest.json file + + Returns: + dict: The dagger s3 task for the DBT model node + + """ task = S3_TASK_BASE.copy() schema = node.get("schema", self._default_schema) @@ -80,6 +99,15 @@ def _get_s3_task(self, node: dict) -> dict: return task def _generate_dagger_output(self, node: dict): + """ + Generates the dagger output for the DBT model node + Args: + node: The extracted node from the manifest.json file + + Returns: + dict: The dagger output, which is a combination of an athena and s3 task for the DBT model node + + """ return [self._get_athena_task(node), self._get_s3_task(node)] def _generate_dagger_inputs( From cdd0fea29c0d6154232d4ae52963e0f25a037848 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Mon, 27 Nov 2023 13:17:09 +0100 Subject: [PATCH 044/189] removed unused test parameter --- tests/utilities/test_dbt_config_parser.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index 752eaaa..76ecd30 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -50,24 +50,21 @@ def test_generate_dagger_inputs(self): "model.main.stg_core_schema1__table1" ], EXPECTED_STAGING_NODE, - True, ), ( DBT_MANIFEST_FILE_FIXTURE["nodes"][ "model.main.stg_core_schema2__table2" ], EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES, - True, ), ( DBT_MANIFEST_FILE_FIXTURE["nodes"][ "seed.main.seed_buyer_country_overwrite" ], EXPECTED_SEED_NODE, - False, ), ] - for mock_input, expected_output, follow_external_dependency in test_inputs: + for mock_input, expected_output in test_inputs: result = self._dbt_config_parser._generate_dagger_inputs(mock_input) self.assertListEqual(result, expected_output) From 829ac94eb03bddd45d94bff9cf010eeb2bd16eb4 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Mon, 27 Nov 2023 14:30:17 +0100 Subject: [PATCH 045/189] changed name of function for better understanding --- dagger/utilities/dbt_config_parser.py | 4 ++-- tests/utilities/test_dbt_config_parser.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 1aa6b4e..28d1761 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -110,7 +110,7 @@ def _generate_dagger_output(self, node: dict): """ return [self._get_athena_task(node), self._get_s3_task(node)] - def _generate_dagger_inputs( + def _generate_dagger_tasks( self, node: dict, ) -> List[Dict]: @@ -194,7 +194,7 @@ def generate_dagger_io(self, model_name: str) -> Tuple[list, list]: for index, parent_node_name in enumerate(parent_node_names): if not (".int_" in parent_node_name): parent_model_node = self._nodes_in_manifest.get(parent_node_name) - dagger_input = self._generate_dagger_inputs(parent_model_node) + dagger_input = self._generate_dagger_tasks(parent_model_node) inputs_list += dagger_input diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index 76ecd30..7df0a66 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -65,7 +65,7 @@ def test_generate_dagger_inputs(self): ), ] for mock_input, expected_output in test_inputs: - result = self._dbt_config_parser._generate_dagger_inputs(mock_input) + result = self._dbt_config_parser._generate_dagger_tasks(mock_input) self.assertListEqual(result, expected_output) def test_generate_io_inputs(self): From eaf42865b48c6ad2bd264eefd77e5655fbd0d8ea Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Mon, 27 Nov 2023 14:30:47 +0100 Subject: [PATCH 046/189] added test to check for de-duplication of inputs --- .../modules/dbt_config_parser_fixtures.py | 51 +++++++++++++++++++ tests/utilities/test_dbt_config_parser.py | 17 +++++-- 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures.py index 3e73c0a..84e8a81 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures.py @@ -81,6 +81,23 @@ "created_at": 1700216177.105391, "depends_on": {"macros": []}, }, + "model.main.model3": { + "name": "model3", + "schema": "analytics_engineering", + "unique_id": "model.main.model3", + "config": { + "external_location": "s3://bucket1-data-lake/path2/model3", + }, + "depends_on": { + "macros": [], + "nodes": [ + "model.main.int_model3", + "model.main.model2", + "seed.main.seed_buyer_country_overwrite", + "model.main.stg_core_schema2__table2", + ], + }, + }, }, "sources": { "source.main.core_schema1.table1": { @@ -170,6 +187,40 @@ } ] +EXPECTED_MODEL_MULTIPLE_DEPENDENCIES = [ + { + "type": "athena", + "name": "analytics_engineering__model2_athena", + "schema": "analytics_engineering", + "table": "model2", + "follow_external_dependency": True, + }, + { + "bucket": "bucket1-data-lake", + "name": "analytics_engineering__model2_s3", + "path": "path2/model2", + "type": "s3", + }, + { + "type": "dummy", + "name": "seed_buyer_country_overwrite", + }, + { + "type": "athena", + "name": "core_schema2__table2_athena", + "schema": "core_schema2", + "table": "table2", + "follow_external_dependency": True, + }, + { + "type": "athena", + "name": "core_schema2__table3_athena", + "schema": "core_schema2", + "table": "table3", + "follow_external_dependency": True, + }, +] + EXPECTED_DAGGER_INPUTS = [ { "name": "core_schema2__table2_athena", diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index 7df0a66..ecc7fc7 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -13,6 +13,7 @@ EXPECTED_STAGING_NODE, EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES, EXPECTED_SEED_NODE, + EXPECTED_MODEL_MULTIPLE_DEPENDENCIES, ) _logger = logging.getLogger("root") @@ -69,11 +70,19 @@ def test_generate_dagger_inputs(self): self.assertListEqual(result, expected_output) def test_generate_io_inputs(self): - result, _ = self._dbt_config_parser.generate_dagger_io( - self._sample_dbt_node.get("name") - ) + fixtures = [ + ("model1", EXPECTED_DAGGER_INPUTS), + ( + "model3", + EXPECTED_MODEL_MULTIPLE_DEPENDENCIES, + ), + ] + for mock_input, expected_output in fixtures: + result, _ = self._dbt_config_parser.generate_dagger_io( + mock_input + ) - self.assertListEqual(result, EXPECTED_DAGGER_INPUTS) + self.assertListEqual(result, expected_output) def test_generate_io_outputs(self): _, result = self._dbt_config_parser.generate_dagger_io( From ab9a2c892a227ec918a32162a6224e2fc5705adc Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 29 Nov 2023 11:28:59 +0100 Subject: [PATCH 047/189] refactored getting model location function this was done because the bucket name in the main module config can be different for how the manifest file is compiled --- dagger/utilities/dbt_config_parser.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 28d1761..4cd71d7 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -94,7 +94,7 @@ def _get_s3_task(self, node: dict) -> dict: table = node.get("name", "") task["name"] = f"{schema}__{table}_s3" task["bucket"] = self._default_data_bucket - task["path"] = self._get_model_data_location(node, schema, table) + task["path"] = self._get_model_data_location(node, schema, table)[1] return task @@ -155,7 +155,7 @@ def _generate_dagger_tasks( def _get_model_data_location( self, node: dict, schema: str, dbt_model_name: str - ) -> str: + ) -> Tuple[str, str]: """ Gets the S3 path of the dbt model relative to the data bucket. If external location is not specified in the DBT model config, then the default data directory from the @@ -173,7 +173,10 @@ def _get_model_data_location( if not location: location = join(self._default_data_dir, schema, dbt_model_name) - return location.split(self._default_data_bucket)[1].lstrip("/") + split = location.split("//")[1].split("/") + bucket_name, data_path = split[0], "/".join(split[1:]) + + return bucket_name, data_path def generate_dagger_io(self, model_name: str) -> Tuple[list, list]: """ From 8ef70503dd7302d7abe4c3dacfea0297efa06c59 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 29 Nov 2023 11:30:14 +0100 Subject: [PATCH 048/189] refactor dummy task generation --- dagger/utilities/dbt_config_parser.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 4cd71d7..d9fe1dd 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -39,20 +39,23 @@ def __init__(self, default_config_parameters: dict): self._sources_in_manifest = self._manifest_data["sources"] @staticmethod - def _generate_seed_task(seed_node: dict) -> dict: + def _get_dummy_task(node: dict, follow_external_dependency: bool = False) -> dict: """ - Generates a dummy dagger task for the DBT seed node + Generates a dummy dagger task Args: - seed_node: The extracted seed node from the manifest.json file + node: The extracted node from the manifest.json file Returns: - dict: The dummy dagger task for the DBT seed node + dict: The dummy dagger task for the DBT node """ task = {} - task["name"] = seed_node.get("name", "") + task["name"] = node.get("name", "") task["type"] = "dummy" + if follow_external_dependency: + task["follow_external_dependency"] = True + return task def _get_athena_task( @@ -116,6 +119,7 @@ def _generate_dagger_tasks( ) -> List[Dict]: """ Generates the dagger task based on whether the DBT model node is a staging model or not. + If the DBT model node represents a DBT seed or an ephemeral model, then a dagger dummy task is generated. If the DBT model node represents a staging model, then a dagger athena task is generated for each source of the DBT model. If the DBT model node is not a staging model, then a dagger athena task and an s3 task is generated for the DBT model node itself. Args: @@ -129,14 +133,17 @@ def _generate_dagger_tasks( dagger_tasks = [] if node.get("resource_type") == "seed": - task = self._generate_seed_task(node) + task = self._get_dummy_task(node) + dagger_tasks.append(task) + elif node.get("config",{}).get("materialized") == "ephemeral": + task = self._get_dummy_task(node, follow_external_dependency=True) dagger_tasks.append(task) elif model_name.startswith("stg_"): source_node_names = node.get("depends_on", {}).get("nodes", []) for source_node_name in source_node_names: if source_node_name.startswith("seed"): source_node = self._nodes_in_manifest[source_node_name] - task = self._generate_seed_task(source_node) + task = self._get_dummy_task(source_node) else: source_node = self._sources_in_manifest[source_node_name] task = self._get_athena_task( From 65a07d76877734177805ced9fffc1ba0e9965471 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 29 Nov 2023 11:31:12 +0100 Subject: [PATCH 049/189] generate inputs for intermediate models and updated tests --- dagger/utilities/dbt_config_parser.py | 7 ++--- .../modules/dbt_config_parser_fixtures.py | 28 +++++++++++++++++++ tests/utilities/test_dbt_config_parser.py | 11 +++++--- 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index d9fe1dd..a4390fb 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -202,11 +202,10 @@ def generate_dagger_io(self, model_name: str) -> Tuple[list, list]: parent_node_names = model_node.get("depends_on", {}).get("nodes", []) for index, parent_node_name in enumerate(parent_node_names): - if not (".int_" in parent_node_name): - parent_model_node = self._nodes_in_manifest.get(parent_node_name) - dagger_input = self._generate_dagger_tasks(parent_model_node) + parent_model_node = self._nodes_in_manifest.get(parent_node_name) + dagger_input = self._generate_dagger_tasks(parent_model_node) - inputs_list += dagger_input + inputs_list += dagger_input output_list = self._generate_dagger_output(model_node) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures.py index 84e8a81..ab887d4 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures.py @@ -37,6 +37,9 @@ "schema": "analytics_engineering", "unique_id": "model.main.stg_core_schema1__table1", "name": "stg_core_schema1__table1", + "config": { + "materialized": "view", + }, "depends_on": { "macros": [], "nodes": ["source.main.core_schema1.table1"], @@ -46,6 +49,9 @@ "schema": "analytics_engineering", "name": "stg_core_schema2__table2", "unique_id": "model.main.stg_core_schema2__table2", + "config": { + "materialized": "view", + }, "depends_on": { "macros": [], "nodes": [ @@ -61,6 +67,7 @@ "unique_id": "model.main.model2", "config": { "external_location": "s3://bucket1-data-lake/path2/model2", + "materialized": "table", }, "depends_on": {"macros": [], "nodes": []}, }, @@ -68,6 +75,9 @@ "name": "int_model3", "unique_id": "model.main.int_model3", "schema": "analytics_engineering", + "config": { + "materialized": "ephemeral", + }, }, "seed.main.seed_buyer_country_overwrite": { "database": "awsdatacatalog", @@ -188,6 +198,11 @@ ] EXPECTED_MODEL_MULTIPLE_DEPENDENCIES = [ + { + "type": "dummy", + "name": "int_model3", + "follow_external_dependency": True, + }, { "type": "athena", "name": "analytics_engineering__model2_athena", @@ -221,6 +236,14 @@ }, ] +EXPECTED_EPHEMERAL_NODE = [ + { + "type": "dummy", + "name": "int_model3", + "follow_external_dependency": True, + } +] + EXPECTED_DAGGER_INPUTS = [ { "name": "core_schema2__table2_athena", @@ -250,6 +273,11 @@ "path": "path2/model2", "type": "s3", }, + { + "type": "dummy", + "name": "int_model3", + "follow_external_dependency": True, + }, ] EXPECTED_DAGGER_OUTPUTS = [ diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index ecc7fc7..549e41c 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -14,6 +14,7 @@ EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES, EXPECTED_SEED_NODE, EXPECTED_MODEL_MULTIPLE_DEPENDENCIES, + EXPECTED_EPHEMERAL_NODE, ) _logger = logging.getLogger("root") @@ -35,7 +36,7 @@ def setUp(self, mock_open, mock_json_load, mock_safe_load): self._dbt_config_parser = DBTConfigParser(DEFAULT_CONFIG_PARAMS) self._sample_dbt_node = DBT_MANIFEST_FILE_FIXTURE["nodes"]["model.main.model1"] - @skip("Run only locally") + # @skip("Run only locally") def test_generate_task_configs(self): module = Module( path_to_config="./tests/fixtures/modules/dbt_test_config.yaml", @@ -64,6 +65,10 @@ def test_generate_dagger_inputs(self): ], EXPECTED_SEED_NODE, ), + ( + DBT_MANIFEST_FILE_FIXTURE["nodes"]["model.main.int_model3"], + EXPECTED_EPHEMERAL_NODE, + ), ] for mock_input, expected_output in test_inputs: result = self._dbt_config_parser._generate_dagger_tasks(mock_input) @@ -78,9 +83,7 @@ def test_generate_io_inputs(self): ), ] for mock_input, expected_output in fixtures: - result, _ = self._dbt_config_parser.generate_dagger_io( - mock_input - ) + result, _ = self._dbt_config_parser.generate_dagger_io(mock_input) self.assertListEqual(result, expected_output) From e5afcfa980d93129e20638c4778b5b9c39220408 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 29 Nov 2023 11:32:54 +0100 Subject: [PATCH 050/189] uncomment skipping local test --- tests/utilities/test_dbt_config_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index 549e41c..7c3557c 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -36,7 +36,7 @@ def setUp(self, mock_open, mock_json_load, mock_safe_load): self._dbt_config_parser = DBTConfigParser(DEFAULT_CONFIG_PARAMS) self._sample_dbt_node = DBT_MANIFEST_FILE_FIXTURE["nodes"]["model.main.model1"] - # @skip("Run only locally") + @skip("Run only locally") def test_generate_task_configs(self): module = Module( path_to_config="./tests/fixtures/modules/dbt_test_config.yaml", From e28a078a573339aed405187adfc67497cfbf3020 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 29 Nov 2023 15:31:51 +0100 Subject: [PATCH 051/189] refactor generate_dagger_tasks fn to make recursive --- dagger/utilities/dbt_config_parser.py | 45 ++++++++++++++------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index a4390fb..6c80cb4 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -103,7 +103,8 @@ def _get_s3_task(self, node: dict) -> dict: def _generate_dagger_output(self, node: dict): """ - Generates the dagger output for the DBT model node + Generates the dagger output for the DBT model node. If the model is materialized as a view or ephemeral, then a dummy task is created. + Otherwise, an athena and s3 task is created for the DBT model node. Args: node: The extracted node from the manifest.json file @@ -111,16 +112,19 @@ def _generate_dagger_output(self, node: dict): dict: The dagger output, which is a combination of an athena and s3 task for the DBT model node """ - return [self._get_athena_task(node), self._get_s3_task(node)] + if node.get("config", {}).get("materialized") in ("view", "ephemeral"): + return [self._get_dummy_task(node)] + else: + return [self._get_athena_task(node), self._get_s3_task(node)] def _generate_dagger_tasks( self, - node: dict, + node_name: str, ) -> List[Dict]: """ Generates the dagger task based on whether the DBT model node is a staging model or not. If the DBT model node represents a DBT seed or an ephemeral model, then a dagger dummy task is generated. - If the DBT model node represents a staging model, then a dagger athena task is generated for each source of the DBT model. + If the DBT model node represents a staging model, then a dagger athena task is generated for each source of the DBT model. Apart from this, a dummy task is also generated for the staging model itself. If the DBT model node is not a staging model, then a dagger athena task and an s3 task is generated for the DBT model node itself. Args: node: The extracted node from the manifest.json file @@ -129,28 +133,28 @@ def _generate_dagger_tasks( List[Dict]: The respective dagger tasks for the DBT model node """ - model_name = node["name"] dagger_tasks = [] + if node_name.startswith("source"): + node = self._sources_in_manifest[node_name] + else: + node = self._nodes_in_manifest[node_name] + if node.get("resource_type") == "seed": task = self._get_dummy_task(node) dagger_tasks.append(task) - elif node.get("config",{}).get("materialized") == "ephemeral": + elif node.get("resource_type") == 'source': + athena_task = self._get_athena_task(node, follow_external_dependency=True) + dagger_tasks.append(athena_task) + elif node.get("config", {}).get("materialized") == "ephemeral": task = self._get_dummy_task(node, follow_external_dependency=True) dagger_tasks.append(task) - elif model_name.startswith("stg_"): + elif node.get("name").startswith("stg_"): source_node_names = node.get("depends_on", {}).get("nodes", []) + dagger_tasks.append(self._get_dummy_task(node)) for source_node_name in source_node_names: - if source_node_name.startswith("seed"): - source_node = self._nodes_in_manifest[source_node_name] - task = self._get_dummy_task(source_node) - else: - source_node = self._sources_in_manifest[source_node_name] - task = self._get_athena_task( - source_node, follow_external_dependency=True - ) - - dagger_tasks.append(task) + task = self._generate_dagger_tasks(source_node_name) + dagger_tasks.extend(task) else: athena_task = self._get_athena_task(node, follow_external_dependency=True) s3_task = self._get_s3_task(node) @@ -185,7 +189,7 @@ def _get_model_data_location( return bucket_name, data_path - def generate_dagger_io(self, model_name: str) -> Tuple[list, list]: + def generate_dagger_io(self, model_name: str) -> Tuple[List[dict], List[dict]]: """ Parse through all the parents of the DBT model and return the dagger inputs and outputs for the DBT model Args: @@ -201,9 +205,8 @@ def generate_dagger_io(self, model_name: str) -> Tuple[list, list]: parent_node_names = model_node.get("depends_on", {}).get("nodes", []) - for index, parent_node_name in enumerate(parent_node_names): - parent_model_node = self._nodes_in_manifest.get(parent_node_name) - dagger_input = self._generate_dagger_tasks(parent_model_node) + for parent_node_name in parent_node_names: + dagger_input = self._generate_dagger_tasks(parent_node_name) inputs_list += dagger_input From ab482292545cfb1b3b57912c56714a01235e90d1 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 29 Nov 2023 15:32:49 +0100 Subject: [PATCH 052/189] updated fixtures and tests --- .../modules/dbt_config_parser_fixtures.py | 34 ++++++++++++++++++- tests/utilities/test_dbt_config_parser.py | 28 +++++++-------- 2 files changed, 47 insertions(+), 15 deletions(-) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures.py index ab887d4..432b2a3 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures.py @@ -114,6 +114,7 @@ "source_name": "table1", "database": "awsdatacatalog", "schema": "core_schema1", + "resource_type": "source", "unique_id": "source.main.core_schema1.table1", "name": "table1", "tags": ["analytics"], @@ -123,6 +124,7 @@ "source_name": "table2", "database": "awsdatacatalog", "schema": "core_schema2", + "resource_type": "source", "unique_id": "source.main.core_schema2.table2", "name": "table2", "tags": ["analytics"], @@ -132,6 +134,7 @@ "source_name": "table3", "database": "awsdatacatalog", "schema": "core_schema2", + "resource_type": "source", "unique_id": "source.main.core_schema2.table3", "name": "table3", "tags": ["analytics"], @@ -161,15 +164,17 @@ } EXPECTED_STAGING_NODE = [ + {"name": "stg_core_schema1__table1", "type": "dummy"}, { "type": "athena", "name": "core_schema1__table1_athena", "schema": "core_schema1", "table": "table1", "follow_external_dependency": True, - } + }, ] EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES = [ + {"name": "stg_core_schema2__table2", "type": "dummy"}, { "type": "athena", "name": "core_schema2__table2_athena", @@ -220,6 +225,7 @@ "type": "dummy", "name": "seed_buyer_country_overwrite", }, + {"name": "stg_core_schema2__table2", "type": "dummy"}, { "type": "athena", "name": "core_schema2__table2_athena", @@ -245,6 +251,7 @@ ] EXPECTED_DAGGER_INPUTS = [ + {"name": "stg_core_schema2__table2", "type": "dummy"}, { "name": "core_schema2__table2_athena", "schema": "core_schema2", @@ -280,6 +287,24 @@ }, ] +EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS = [ + { + "follow_external_dependency": True, + "name": "core_schema2__table2_athena", + "schema": "core_schema2", + "table": "table2", + "type": "athena", + }, + { + "follow_external_dependency": True, + "name": "core_schema2__table3_athena", + "schema": "core_schema2", + "table": "table3", + "type": "athena", + }, + {"name": "seed_buyer_country_overwrite", "type": "dummy"}, +] + EXPECTED_DAGGER_OUTPUTS = [ { "name": "analytics_engineering__model1_athena", @@ -294,3 +319,10 @@ "type": "s3", }, ] + +EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS = [ + { + "type": "dummy", + "name": "stg_core_schema2__table2", + }, +] diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index 7c3557c..be9b3dc 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -15,6 +15,8 @@ EXPECTED_SEED_NODE, EXPECTED_MODEL_MULTIPLE_DEPENDENCIES, EXPECTED_EPHEMERAL_NODE, + EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS, + EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS, ) _logger = logging.getLogger("root") @@ -48,25 +50,19 @@ def test_generate_task_configs(self): def test_generate_dagger_inputs(self): test_inputs = [ ( - DBT_MANIFEST_FILE_FIXTURE["nodes"][ - "model.main.stg_core_schema1__table1" - ], + "model.main.stg_core_schema1__table1", EXPECTED_STAGING_NODE, ), ( - DBT_MANIFEST_FILE_FIXTURE["nodes"][ - "model.main.stg_core_schema2__table2" - ], + "model.main.stg_core_schema2__table2", EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES, ), ( - DBT_MANIFEST_FILE_FIXTURE["nodes"][ - "seed.main.seed_buyer_country_overwrite" - ], + "seed.main.seed_buyer_country_overwrite", EXPECTED_SEED_NODE, ), ( - DBT_MANIFEST_FILE_FIXTURE["nodes"]["model.main.int_model3"], + "model.main.int_model3", EXPECTED_EPHEMERAL_NODE, ), ] @@ -81,6 +77,7 @@ def test_generate_io_inputs(self): "model3", EXPECTED_MODEL_MULTIPLE_DEPENDENCIES, ), + ("stg_core_schema2__table2", EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS), ] for mock_input, expected_output in fixtures: result, _ = self._dbt_config_parser.generate_dagger_io(mock_input) @@ -88,8 +85,11 @@ def test_generate_io_inputs(self): self.assertListEqual(result, expected_output) def test_generate_io_outputs(self): - _, result = self._dbt_config_parser.generate_dagger_io( - self._sample_dbt_node.get("name") - ) + fixtures = [ + ("model1", EXPECTED_DAGGER_OUTPUTS), + ("stg_core_schema2__table2", EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS), + ] + for mock_input, expected_output in fixtures: + _, result = self._dbt_config_parser.generate_dagger_io(mock_input) - self.assertListEqual(result, EXPECTED_DAGGER_OUTPUTS) + self.assertListEqual(result, expected_output) From d653977097c11342f015ab5c0b39e202da55250b Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Mon, 4 Dec 2023 18:47:01 +0100 Subject: [PATCH 053/189] bugfix --- dagger/dag_creator/airflow/dag_creator.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/dagger/dag_creator/airflow/dag_creator.py b/dagger/dag_creator/airflow/dag_creator.py index 5db416f..2b208b9 100644 --- a/dagger/dag_creator/airflow/dag_creator.py +++ b/dagger/dag_creator/airflow/dag_creator.py @@ -163,13 +163,15 @@ def _create_edge_without_data(self, from_task_id: str, to_task_ids: list, node: external_task_sensor = self._get_external_task_sensor( from_task_id, to_task_id, edge_properties.follow_external_dependency ) - self._sensor_dict[to_pipe] = { + + if self._sensor_dict.get(to_pipe) is None: + self._sensor_dict[to_pipe] = {} + + self._sensor_dict[to_pipe].update({ external_task_sensor_name: external_task_sensor - } - ( - self._tasks[self._get_control_flow_task_id(to_pipe)] - >> external_task_sensor - ) + }) + + self._tasks[self._get_control_flow_task_id(to_pipe)] >> external_task_sensor self._sensor_dict[to_pipe][external_task_sensor_name] >> self._tasks[to_task_id] else: self._tasks[self._get_control_flow_task_id(to_pipe)] >> self._tasks[to_task_id] From 48a102fcecfe3dda9c5e280e33ab081f4319f4f1 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Mon, 4 Dec 2023 19:16:58 +0100 Subject: [PATCH 054/189] only return dummy when stg model --- dagger/utilities/dbt_config_parser.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 6c80cb4..dd16008 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -143,18 +143,14 @@ def _generate_dagger_tasks( if node.get("resource_type") == "seed": task = self._get_dummy_task(node) dagger_tasks.append(task) - elif node.get("resource_type") == 'source': + elif node.get("resource_type") == "source": athena_task = self._get_athena_task(node, follow_external_dependency=True) dagger_tasks.append(athena_task) elif node.get("config", {}).get("materialized") == "ephemeral": task = self._get_dummy_task(node, follow_external_dependency=True) dagger_tasks.append(task) elif node.get("name").startswith("stg_"): - source_node_names = node.get("depends_on", {}).get("nodes", []) dagger_tasks.append(self._get_dummy_task(node)) - for source_node_name in source_node_names: - task = self._generate_dagger_tasks(source_node_name) - dagger_tasks.extend(task) else: athena_task = self._get_athena_task(node, follow_external_dependency=True) s3_task = self._get_s3_task(node) From f6c7226606956e74e91cb19e58207f8897fe97a1 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Mon, 4 Dec 2023 19:19:37 +0100 Subject: [PATCH 055/189] adapted tests --- .../modules/dbt_config_parser_fixtures.py | 68 ++++--------------- tests/utilities/test_dbt_config_parser.py | 12 ++-- 2 files changed, 20 insertions(+), 60 deletions(-) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures.py index 432b2a3..2f42778 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures.py @@ -165,34 +165,6 @@ EXPECTED_STAGING_NODE = [ {"name": "stg_core_schema1__table1", "type": "dummy"}, - { - "type": "athena", - "name": "core_schema1__table1_athena", - "schema": "core_schema1", - "table": "table1", - "follow_external_dependency": True, - }, -] -EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES = [ - {"name": "stg_core_schema2__table2", "type": "dummy"}, - { - "type": "athena", - "name": "core_schema2__table2_athena", - "schema": "core_schema2", - "table": "table2", - "follow_external_dependency": True, - }, - { - "type": "athena", - "name": "core_schema2__table3_athena", - "schema": "core_schema2", - "table": "table3", - "follow_external_dependency": True, - }, - { - "type": "dummy", - "name": "seed_buyer_country_overwrite", - }, ] EXPECTED_SEED_NODE = [ @@ -225,21 +197,7 @@ "type": "dummy", "name": "seed_buyer_country_overwrite", }, - {"name": "stg_core_schema2__table2", "type": "dummy"}, - { - "type": "athena", - "name": "core_schema2__table2_athena", - "schema": "core_schema2", - "table": "table2", - "follow_external_dependency": True, - }, - { - "type": "athena", - "name": "core_schema2__table3_athena", - "schema": "core_schema2", - "table": "table3", - "follow_external_dependency": True, - }, + {"name": "stg_core_schema2__table2", "type": "dummy"} ] EXPECTED_EPHEMERAL_NODE = [ @@ -250,23 +208,24 @@ } ] -EXPECTED_DAGGER_INPUTS = [ - {"name": "stg_core_schema2__table2", "type": "dummy"}, +EXPECTED_MODEL_NODE = [ { - "name": "core_schema2__table2_athena", - "schema": "core_schema2", - "table": "table2", "type": "athena", + "name": "analytics_engineering__model1_athena", + "schema": "analytics_engineering", + "table": "model1", "follow_external_dependency": True, }, { - "name": "core_schema2__table3_athena", - "schema": "core_schema2", - "table": "table3", - "type": "athena", - "follow_external_dependency": True, + "bucket": "bucket1-data-lake", + "name": "analytics_engineering__model1_s3", + "path": "path1/model1", + "type": "s3", }, - {"name": "seed_buyer_country_overwrite", "type": "dummy"}, +] + +EXPECTED_DAGGER_INPUTS = [ + {"name": "stg_core_schema2__table2", "type": "dummy"}, { "name": "analytics_engineering__model2_athena", "schema": "analytics_engineering", @@ -285,6 +244,7 @@ "name": "int_model3", "follow_external_dependency": True, }, + {"name": "seed_buyer_country_overwrite", "type": "dummy"}, ] EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS = [ diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index be9b3dc..c03976d 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -11,12 +11,12 @@ DBT_MANIFEST_FILE_FIXTURE, DBT_PROFILE_FIXTURE, EXPECTED_STAGING_NODE, - EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES, EXPECTED_SEED_NODE, EXPECTED_MODEL_MULTIPLE_DEPENDENCIES, EXPECTED_EPHEMERAL_NODE, EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS, EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS, + EXPECTED_MODEL_NODE, ) _logger = logging.getLogger("root") @@ -47,16 +47,12 @@ def test_generate_task_configs(self): module.generate_task_configs() - def test_generate_dagger_inputs(self): + def test_generate_dagger_tasks(self): test_inputs = [ ( "model.main.stg_core_schema1__table1", EXPECTED_STAGING_NODE, ), - ( - "model.main.stg_core_schema2__table2", - EXPECTED_STAGING_NODE_MULTIPLE_DEPENDENCIES, - ), ( "seed.main.seed_buyer_country_overwrite", EXPECTED_SEED_NODE, @@ -65,6 +61,10 @@ def test_generate_dagger_inputs(self): "model.main.int_model3", EXPECTED_EPHEMERAL_NODE, ), + ( + "model.main.model1", + EXPECTED_MODEL_NODE, + ), ] for mock_input, expected_output in test_inputs: result = self._dbt_config_parser._generate_dagger_tasks(mock_input) From 146846e7953db1999265510cbc74dc3344f57de2 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 6 Dec 2023 12:57:57 +0100 Subject: [PATCH 056/189] initialize dbt module only when its a dbt pipeline config --- dagger/utilities/module.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 7b954e3..5c81e04 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -22,7 +22,13 @@ def __init__(self, path_to_config, target_dir): self._branches_to_generate = config["branches_to_generate"] self._override_parameters = config.get("override_parameters", {}) self._default_parameters = config.get("default_parameters", {}) - self._dbt_module = DBTConfigParser(self._default_parameters) + + if ( + "dbt_profile" in self._default_parameters.keys() + and "project_dir" in self._default_parameters.keys() + and "profile_dir" in self._default_parameters.keys() + ): + self._dbt_module = DBTConfigParser(self._default_parameters) @staticmethod def read_yaml(yaml_str): From fd2bd0f365ad94b2dc837b0ac39a0419be7bcf1f Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 6 Dec 2023 12:58:16 +0100 Subject: [PATCH 057/189] format --- dagger/utilities/module.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 5c81e04..ade94e3 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -17,7 +17,9 @@ def __init__(self, path_to_config, target_dir): self._tasks = {} for task in config["tasks"]: - self._tasks[task] = self.read_task_config(f"{path.join(self._directory, task)}.yaml") + self._tasks[task] = self.read_task_config( + f"{path.join(self._directory, task)}.yaml" + ) self._branches_to_generate = config["branches_to_generate"] self._override_parameters = config.get("override_parameters", {}) @@ -54,7 +56,7 @@ def replace_template_parameters(_task_str, _template_parameters): if type(_value) == str: try: int_value = int(_value) - _value = f"\"{_value}\"" + _value = f'"{_value}"' except: pass locals()[_key] = _value @@ -90,17 +92,19 @@ def generate_task_configs(self): ) task_dict = yaml.safe_load(task_str) - if task == 'dbt': + if task == "dbt": inputs, outputs = self._dbt_module.generate_dagger_io(branch_name) - task_dict['inputs'] = inputs - task_dict['outputs'] = outputs - task_dict['task_parameters']['select'] = branch_name + task_dict["inputs"] = inputs + task_dict["outputs"] = outputs + task_dict["task_parameters"]["select"] = branch_name task_dict["autogenerated_by_dagger"] = self._path_to_config override_parameters = self._override_parameters or {} merge(task_dict, override_parameters.get(branch_name, {}).get(task, {})) - self.dump_yaml(task_dict, f"{path.join(self._target_dir, task_name)}.yaml") + self.dump_yaml( + task_dict, f"{path.join(self._target_dir, task_name)}.yaml" + ) @staticmethod def module_config_template(): From fb0c2e9cadc6fe6a1e3c6a933101a088bc943ec6 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 6 Dec 2023 16:14:59 +0100 Subject: [PATCH 058/189] made logic to check for dbt task easier --- dagger/utilities/module.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index ade94e3..d565ffe 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -25,11 +25,7 @@ def __init__(self, path_to_config, target_dir): self._override_parameters = config.get("override_parameters", {}) self._default_parameters = config.get("default_parameters", {}) - if ( - "dbt_profile" in self._default_parameters.keys() - and "project_dir" in self._default_parameters.keys() - and "profile_dir" in self._default_parameters.keys() - ): + if 'dbt' in self._tasks.keys(): self._dbt_module = DBTConfigParser(self._default_parameters) @staticmethod From a2b3fdda30196e09c84392de83c1f304191df950 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 7 Dec 2023 19:54:18 +0100 Subject: [PATCH 059/189] fix: follow external dependency for staging models --- dagger/utilities/dbt_config_parser.py | 9 +++++++-- .../modules/dbt_config_parser_fixtures.py | 18 +++++++++++++++--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index dd16008..15f31df 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -112,7 +112,10 @@ def _generate_dagger_output(self, node: dict): dict: The dagger output, which is a combination of an athena and s3 task for the DBT model node """ - if node.get("config", {}).get("materialized") in ("view", "ephemeral"): + if node.get("config", {}).get("materialized") in ( + "view", + "ephemeral", + ) or node.get("name").startswith("stg_"): return [self._get_dummy_task(node)] else: return [self._get_athena_task(node), self._get_s3_task(node)] @@ -150,7 +153,9 @@ def _generate_dagger_tasks( task = self._get_dummy_task(node, follow_external_dependency=True) dagger_tasks.append(task) elif node.get("name").startswith("stg_"): - dagger_tasks.append(self._get_dummy_task(node)) + dagger_tasks.append( + self._get_dummy_task(node, follow_external_dependency=True) + ) else: athena_task = self._get_athena_task(node, follow_external_dependency=True) s3_task = self._get_s3_task(node) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures.py index 2f42778..90ebf03 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures.py @@ -164,7 +164,11 @@ } EXPECTED_STAGING_NODE = [ - {"name": "stg_core_schema1__table1", "type": "dummy"}, + { + "name": "stg_core_schema1__table1", + "type": "dummy", + "follow_external_dependency": True, + }, ] EXPECTED_SEED_NODE = [ @@ -197,7 +201,11 @@ "type": "dummy", "name": "seed_buyer_country_overwrite", }, - {"name": "stg_core_schema2__table2", "type": "dummy"} + { + "name": "stg_core_schema2__table2", + "type": "dummy", + "follow_external_dependency": True, + }, ] EXPECTED_EPHEMERAL_NODE = [ @@ -225,7 +233,11 @@ ] EXPECTED_DAGGER_INPUTS = [ - {"name": "stg_core_schema2__table2", "type": "dummy"}, + { + "name": "stg_core_schema2__table2", + "type": "dummy", + "follow_external_dependency": True, + }, { "name": "analytics_engineering__model2_athena", "schema": "analytics_engineering", From 365364687a7baa0b6e6326be1c8a541a10dafbf6 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Fri, 23 Feb 2024 11:36:26 +0530 Subject: [PATCH 060/189] added logic to get parents of int models this is done to keep track of dependencies of int models that are ephemeral --- dagger/utilities/dbt_config_parser.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 15f31df..f12cccf 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -152,6 +152,10 @@ def _generate_dagger_tasks( elif node.get("config", {}).get("materialized") == "ephemeral": task = self._get_dummy_task(node, follow_external_dependency=True) dagger_tasks.append(task) + + ephemeral_parent_node_names = node.get("depends_on", {}).get("nodes", []) + for node_name in ephemeral_parent_node_names: + dagger_tasks += self._generate_dagger_tasks(node_name) elif node.get("name").startswith("stg_"): dagger_tasks.append( self._get_dummy_task(node, follow_external_dependency=True) From 7dec65aa51fa310cf8d1dfe210464efd15e48610 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Fri, 23 Feb 2024 11:37:26 +0530 Subject: [PATCH 061/189] updated tests and fixtures --- .../modules/dbt_config_parser_fixtures.py | 70 +++++++++++++++++-- tests/utilities/test_dbt_config_parser.py | 2 + 2 files changed, 68 insertions(+), 4 deletions(-) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures.py index 90ebf03..a28d871 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures.py @@ -78,6 +78,10 @@ "config": { "materialized": "ephemeral", }, + "depends_on": { + "macros": [], + "nodes": ["model.main.int_model2"], + }, }, "seed.main.seed_buyer_country_overwrite": { "database": "awsdatacatalog", @@ -108,6 +112,21 @@ ], }, }, + "model.main.int_model2": { + "name": "int_model2", + "unique_id": "model.main.int_model2", + "schema": "analytics_engineering", + "config": { + "materialized": "ephemeral", + }, + "depends_on": { + "macros": [], + "nodes": [ + "seed.main.seed_buyer_country_overwrite", + "model.main.stg_core_schema1__table1", + ], + }, + }, }, "sources": { "source.main.core_schema1.table1": { @@ -184,6 +203,20 @@ "name": "int_model3", "follow_external_dependency": True, }, + { + "type": "dummy", + "name": "int_model2", + "follow_external_dependency": True, + }, + { + "type": "dummy", + "name": "seed_buyer_country_overwrite", + }, + { + "name": "stg_core_schema1__table1", + "type": "dummy", + "follow_external_dependency": True, + }, { "type": "athena", "name": "analytics_engineering__model2_athena", @@ -197,10 +230,6 @@ "path": "path2/model2", "type": "s3", }, - { - "type": "dummy", - "name": "seed_buyer_country_overwrite", - }, { "name": "stg_core_schema2__table2", "type": "dummy", @@ -213,6 +242,20 @@ "type": "dummy", "name": "int_model3", "follow_external_dependency": True, + }, + { + "type": "dummy", + "name": "int_model2", + "follow_external_dependency": True, + }, + { + "type": "dummy", + "name": "seed_buyer_country_overwrite", + }, + { + "name": "stg_core_schema1__table1", + "type": "dummy", + "follow_external_dependency": True, } ] @@ -256,7 +299,17 @@ "name": "int_model3", "follow_external_dependency": True, }, + { + "type": "dummy", + "name": "int_model2", + "follow_external_dependency": True, + }, {"name": "seed_buyer_country_overwrite", "type": "dummy"}, + { + "name": "stg_core_schema1__table1", + "type": "dummy", + "follow_external_dependency": True, + }, ] EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS = [ @@ -298,3 +351,12 @@ "name": "stg_core_schema2__table2", }, ] + +EXPECTED_DBT_INT_MODEL_DAGGER_INPUTS = [ + {"name": "seed_buyer_country_overwrite", "type": "dummy"}, + { + "name": "stg_core_schema1__table1", + "type": "dummy", + "follow_external_dependency": True, + }, +] diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index c03976d..3fd6394 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -17,6 +17,7 @@ EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS, EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS, EXPECTED_MODEL_NODE, + EXPECTED_DBT_INT_MODEL_DAGGER_INPUTS ) _logger = logging.getLogger("root") @@ -78,6 +79,7 @@ def test_generate_io_inputs(self): EXPECTED_MODEL_MULTIPLE_DEPENDENCIES, ), ("stg_core_schema2__table2", EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS), + ("int_model2", EXPECTED_DBT_INT_MODEL_DAGGER_INPUTS), ] for mock_input, expected_output in fixtures: result, _ = self._dbt_config_parser.generate_dagger_io(mock_input) From bdc3e57babf7a74caa8bc3d99b3515f0c72a9b30 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Mon, 25 Mar 2024 20:33:40 +0100 Subject: [PATCH 062/189] Turing split_statements on by default --- dagger/dag_creator/airflow/operators/postgres_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/dag_creator/airflow/operators/postgres_operator.py b/dagger/dag_creator/airflow/operators/postgres_operator.py index b833516..c01b255 100644 --- a/dagger/dag_creator/airflow/operators/postgres_operator.py +++ b/dagger/dag_creator/airflow/operators/postgres_operator.py @@ -51,6 +51,6 @@ def execute(self, context): self.hook = PostgresHook( postgres_conn_id=self.postgres_conn_id, schema=self.database ) - self.hook.run(self.sql, self.autocommit, parameters=self.parameters) + self.hook.run(self.sql, self.autocommit, parameters=self.parameters, split_statements=True) for output in self.hook.conn.notices: self.log.info(output) From 96ad27b584ba27acbb9dbe54f418426fd47764f7 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Sat, 13 Apr 2024 20:59:23 +0200 Subject: [PATCH 063/189] Moving to python3.9; Upgrading airflow version; removing legacy postgres operator --- Makefile | 6 ++-- .../redshift_load_creator.py | 5 ++- .../redshift_transform_creator.py | 6 ++-- .../redshift_unload_creator.py | 7 +++-- .../airflow/operators/postgres_operator.py | 2 +- reqs/dev.txt | 31 +++++++++---------- reqs/test.txt | 2 +- setup.py | 3 +- 8 files changed, 30 insertions(+), 32 deletions(-) diff --git a/Makefile b/Makefile index 02daa9e..8872dba 100644 --- a/Makefile +++ b/Makefile @@ -96,15 +96,15 @@ install: clean ## install the package to the active Python's site-packages install-dev: clean ## install the package to the active Python's site-packages - virtualenv -p python3 venv; \ + virtualenv -p python3.9 venv; \ source venv/bin/activate; \ python -m pip install --upgrade pip; \ python setup.py install; \ pip install -e . ; \ - pip install -r reqs/dev.txt -r reqs/test.txt + SYSTEM_VERSION_COMPAT=0 CFLAGS='-std=c++20' pip install -r reqs/dev.txt -r reqs/test.txt install-test: clean ## install the package to the active Python's site-packages - virtualenv -p python3 venv; \ + virtualenv -p python3.9 venv; \ source venv/bin/activate; \ python -m pip install --upgrade pip; \ pip install -r reqs/test.txt -r reqs/base.txt diff --git a/dagger/dag_creator/airflow/operator_creators/redshift_load_creator.py b/dagger/dag_creator/airflow/operator_creators/redshift_load_creator.py index 8d14182..f1576f4 100644 --- a/dagger/dag_creator/airflow/operator_creators/redshift_load_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/redshift_load_creator.py @@ -2,9 +2,8 @@ from typing import Optional from dagger.dag_creator.airflow.operator_creator import OperatorCreator -from dagger.dag_creator.airflow.operators.redshift_sql_operator import ( - RedshiftSQLOperator, -) +from dagger.dag_creator.airflow.operators.redshift_sql_operator import RedshiftSQLOperator + class RedshiftLoadCreator(OperatorCreator): diff --git a/dagger/dag_creator/airflow/operator_creators/redshift_transform_creator.py b/dagger/dag_creator/airflow/operator_creators/redshift_transform_creator.py index 0218a6f..c8eb8dd 100644 --- a/dagger/dag_creator/airflow/operator_creators/redshift_transform_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/redshift_transform_creator.py @@ -1,7 +1,7 @@ from os.path import join from dagger.dag_creator.airflow.operator_creator import OperatorCreator -from dagger.dag_creator.airflow.operators.postgres_operator import PostgresOperator +from dagger.dag_creator.airflow.operators.redshift_sql_operator import RedshiftSQLOperator class RedshiftTransformCreator(OperatorCreator): @@ -22,11 +22,11 @@ def _read_sql(directory, file_path): def _create_operator(self, **kwargs): sql_string = self._read_sql(self._task.pipeline.directory, self._task.sql_file) - redshift_op = PostgresOperator( + redshift_op = RedshiftSQLOperator( dag=self._dag, task_id=self._task.name, sql=sql_string, - postgres_conn_id=self._task.postgres_conn_id, + redshift_conn_id=self._task.postgres_conn_id, params=self._template_parameters, **kwargs, ) diff --git a/dagger/dag_creator/airflow/operator_creators/redshift_unload_creator.py b/dagger/dag_creator/airflow/operator_creators/redshift_unload_creator.py index 7fd74d7..cb7be04 100644 --- a/dagger/dag_creator/airflow/operator_creators/redshift_unload_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/redshift_unload_creator.py @@ -1,7 +1,7 @@ from os.path import join from dagger.dag_creator.airflow.operator_creator import OperatorCreator -from dagger.dag_creator.airflow.operators.postgres_operator import PostgresOperator +from dagger.dag_creator.airflow.operators.redshift_sql_operator import RedshiftSQLOperator REDSHIFT_UNLOAD_CMD = """ unload ('{sql_string}') @@ -58,12 +58,13 @@ def _create_operator(self, **kwargs): unload_cmd = self._get_unload_command(sql_string) - redshift_op = PostgresOperator( + redshift_op = RedshiftSQLOperator( dag=self._dag, task_id=self._task.name, sql=unload_cmd, - postgres_conn_id=self._task.postgres_conn_id, + redshift_conn_id=self._task.postgres_conn_id, params=self._template_parameters, + autocommit=True, **kwargs, ) diff --git a/dagger/dag_creator/airflow/operators/postgres_operator.py b/dagger/dag_creator/airflow/operators/postgres_operator.py index c01b255..ce90250 100644 --- a/dagger/dag_creator/airflow/operators/postgres_operator.py +++ b/dagger/dag_creator/airflow/operators/postgres_operator.py @@ -1,6 +1,6 @@ from typing import Iterable, Mapping, Optional, Union -from airflow.hooks.postgres_hook import PostgresHook +from airflow.providers.postgres.hooks.postgres import PostgresHook from airflow.utils.decorators import apply_defaults from dagger.dag_creator.airflow.operators.dagger_base_operator import DaggerBaseOperator diff --git a/reqs/dev.txt b/reqs/dev.txt index 806d6c5..c52136a 100644 --- a/reqs/dev.txt +++ b/reqs/dev.txt @@ -1,19 +1,18 @@ -apache-airflow[amazon,postgres,s3,statsd]==2.3.4 +pip==24.0 +apache-airflow[amazon,postgres,s3,statsd]==2.9.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.9.0/constraints-3.9.txt" black==22.10.0 -boto3==1.26.16 +boto3==1.34.82 bumpversion==0.6.0 -coverage==6.5.0 -elasticsearch==7.17.7 -flake8==5.0.4 -neo4j==5.2.1 -numpydoc==1.5.0 -pip==22.3.1 +coverage==7.4.4 +#elasticsearch==7.17.7 +flake8==7.0.0 +#neo4j==5.19.0 +numpydoc==1.7.0 pre-commit==2.20.0 -sphinx-rtd-theme==1.1.1 -Sphinx==4.3.2 -SQLAlchemy==1.4.44 -tox==3.27.1 -twine==4.0.1 -watchdog==2.1.9 -Werkzeug==2.2.2 -wheel==0.38.4 +sphinx-rtd-theme==2.0.0 +Sphinx==7.2.6 +SQLAlchemy +tox==4.14.2 +twine==5.0.0 +watchdog==4.0.0 +Werkzeug diff --git a/reqs/test.txt b/reqs/test.txt index c568f77..195932d 100644 --- a/reqs/test.txt +++ b/reqs/test.txt @@ -1,3 +1,3 @@ -apache-airflow[amazon,postgres,s3,statsd]==2.3.4 +apache-airflow[amazon,postgres,s3,statsd]==2.9.0 pytest-cov==4.0.0 pytest==7.2.0 diff --git a/setup.py b/setup.py index 3f80fe3..080a5bb 100644 --- a/setup.py +++ b/setup.py @@ -45,8 +45,7 @@ def reqs(*f): classifiers=[ "Development Status :: 2 - Pre-Alpha", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.9", ], description="Config Driven ETL", entry_points={"console_scripts": ["dagger=dagger.main:cli"]}, From 3e62d78068b396779e89b2ddedd32c87eff3cf57 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Mon, 15 Apr 2024 11:00:22 +0200 Subject: [PATCH 064/189] Making sensor default args more flexible --- dagger/conf.py | 4 +--- dagger/dag_creator/airflow/dag_creator.py | 8 ++------ dagger/dagger_config.yaml | 8 +++++--- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/dagger/conf.py b/dagger/conf.py index 667c207..6b5488f 100644 --- a/dagger/conf.py +++ b/dagger/conf.py @@ -21,9 +21,7 @@ # Airflow parameters airflow_config = config.get('airflow', None) or {} WITH_DATA_NODES = airflow_config.get('with_data_nodes', False) -EXTERNAL_SENSOR_POKE_INTERVAL = airflow_config.get('external_sensor_poke_interval', 600) -EXTERNAL_SENSOR_TIMEOUT = airflow_config.get('external_sensor_timeout', 28800) -EXTERNAL_SENSOR_MODE = airflow_config.get('external_sensor_mode', 'reschedule') +EXTERNAL_SENSOR_DEFAULT_ARGS = airflow_config.get('external_sensor_default_args', {}) IS_DUMMY_OPERATOR_SHORT_CIRCUIT = airflow_config.get('is_dummy_operator_short_circuit', False) # Neo4j parameters diff --git a/dagger/dag_creator/airflow/dag_creator.py b/dagger/dag_creator/airflow/dag_creator.py index 2b208b9..031a3a4 100644 --- a/dagger/dag_creator/airflow/dag_creator.py +++ b/dagger/dag_creator/airflow/dag_creator.py @@ -72,12 +72,7 @@ def _get_external_task_sensor(self, from_task_id: str, to_task_id: str, follow_e to_pipe_id = self._task_graph.get_node(to_task_id).obj.pipeline.name - - extra_args = { - 'mode': conf.EXTERNAL_SENSOR_MODE, - 'poke_interval': conf.EXTERNAL_SENSOR_POKE_INTERVAL, - 'timeout': conf.EXTERNAL_SENSOR_TIMEOUT, - } + extra_args = conf.EXTERNAL_SENSOR_DEFAULT_ARGS.copy() extra_args.update(follow_external_dependency) return ExternalTaskSensor( @@ -141,6 +136,7 @@ def _create_edge_without_data(self, from_task_id: str, to_task_ids: list, node: to_task_ids: The IDs of the tasks to which the edge connects. node: The current node in a task graph. """ + from_pipe = ( self._task_graph.get_node(from_task_id).obj.pipeline_name if from_task_id else None ) diff --git a/dagger/dagger_config.yaml b/dagger/dagger_config.yaml index 9eac6ff..3366828 100644 --- a/dagger/dagger_config.yaml +++ b/dagger/dagger_config.yaml @@ -1,8 +1,10 @@ airflow: + external_sensor_default_args: + poll_interval: 30 + timeout: 28800 + mode: reschedule + deferrable: true with_data_node: false - external_sensor_poke_interval: 600 - external_sensor_timeout: 28800 - external_sensor_mode: reschedule is_dummy_operator_short_circuit: false From 41f15544787651da6b9a2b3e085a766656f93226 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Mon, 15 Apr 2024 14:15:44 +0200 Subject: [PATCH 065/189] Upgrading python in CI --- .github/workflows/ci-data.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-data.yml b/.github/workflows/ci-data.yml index 599d325..bef5bed 100644 --- a/.github/workflows/ci-data.yml +++ b/.github/workflows/ci-data.yml @@ -17,10 +17,10 @@ jobs: with: persist-credentials: false - - name: Set up Python 3.7 + - name: Set up Python 3.9 uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.9 - name: Install dependencies run: | From d8145ab6b0ccda834b58f107ef3fd7ed5a1a83a0 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Mon, 15 Apr 2024 14:25:25 +0200 Subject: [PATCH 066/189] Adding graphviz dependency to test --- reqs/dev.txt | 4 ++-- reqs/test.txt | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/reqs/dev.txt b/reqs/dev.txt index c52136a..238b2e2 100644 --- a/reqs/dev.txt +++ b/reqs/dev.txt @@ -4,9 +4,9 @@ black==22.10.0 boto3==1.34.82 bumpversion==0.6.0 coverage==7.4.4 -#elasticsearch==7.17.7 +elasticsearch==7.17.7 flake8==7.0.0 -#neo4j==5.19.0 +neo4j==5.19.0 numpydoc==1.7.0 pre-commit==2.20.0 sphinx-rtd-theme==2.0.0 diff --git a/reqs/test.txt b/reqs/test.txt index 195932d..7bdc89f 100644 --- a/reqs/test.txt +++ b/reqs/test.txt @@ -1,3 +1,4 @@ apache-airflow[amazon,postgres,s3,statsd]==2.9.0 pytest-cov==4.0.0 pytest==7.2.0 +graphviz From 8cc2ec2640220ec2f5e605a71dc050a68cc3df9c Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Tue, 16 Apr 2024 12:59:04 +0200 Subject: [PATCH 067/189] Upgrading some package versions to remove warnings --- reqs/base.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/reqs/base.txt b/reqs/base.txt index 877d102..d9cc38a 100644 --- a/reqs/base.txt +++ b/reqs/base.txt @@ -1,7 +1,7 @@ click==8.1.3 -croniter==1.3.8 +croniter==2.0.2 envyaml==1.10.211231 mergedeep==1.3.4 slack==0.0.2 slackclient==2.9.4 -tenacity==8.2.0 +tenacity==8.2.3 From f56e6b62890bb41fabb9cae394da7b777d43c16b Mon Sep 17 00:00:00 2001 From: claudiazi Date: Wed, 17 Apr 2024 10:31:57 +0200 Subject: [PATCH 068/189] feat: rename profile_name to target_name --- .../dag_creator/airflow/operator_creators/dbt_creator.py | 4 ++-- dagger/pipeline/tasks/dbt_task.py | 8 ++++---- dockers/airflow/airflow.cfg | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dagger/dag_creator/airflow/operator_creators/dbt_creator.py b/dagger/dag_creator/airflow/operator_creators/dbt_creator.py index 1c16835..38b9c34 100644 --- a/dagger/dag_creator/airflow/operator_creators/dbt_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/dbt_creator.py @@ -12,7 +12,7 @@ def __init__(self, task, dag): self._project_dir = task.project_dir self._profile_dir = task.profile_dir - self._profile_name = task.profile_name + self._target_name = task.target_name self._select = task.select self._dbt_command = task.dbt_command @@ -20,7 +20,7 @@ def _generate_command(self): command = [self._task.executable_prefix, self._task.executable] command.append(f"--project_dir={self._project_dir}") command.append(f"--profiles_dir={self._profile_dir}") - command.append(f"--profile_name={self._profile_name}") + command.append(f"--target_name={self._target_name}") command.append(f"--dbt_command={self._dbt_command}") if self._select: command.append(f"--select={self._select}") diff --git a/dagger/pipeline/tasks/dbt_task.py b/dagger/pipeline/tasks/dbt_task.py index 33b9c1a..c59cdd6 100644 --- a/dagger/pipeline/tasks/dbt_task.py +++ b/dagger/pipeline/tasks/dbt_task.py @@ -20,7 +20,7 @@ def init_attributes(cls, orig_cls): comment="Which directory to look in for the profiles.yml file", ), Attribute( - attribute_name="profile_name", + attribute_name="target_name", required=False, parent_fields=["task_parameters"], comment="Which target to load for the given profile " @@ -45,7 +45,7 @@ def __init__(self, name, pipeline_name, pipeline, job_config): self._project_dir = self.parse_attribute("project_dir") self._profile_dir = self.parse_attribute("profile_dir") - self._profile_name = self.parse_attribute("profile_name") or "default" + self._target_name = self.parse_attribute("target_name") or "default" self._select = self.parse_attribute("select") self._dbt_command = self.parse_attribute("dbt_command") @@ -58,8 +58,8 @@ def profile_dir(self): return self._profile_dir @property - def profile_name(self): - return self._profile_name + def target_name(self): + return self._target_name @property def select(self): diff --git a/dockers/airflow/airflow.cfg b/dockers/airflow/airflow.cfg index a5ace87..0b19fbd 100644 --- a/dockers/airflow/airflow.cfg +++ b/dockers/airflow/airflow.cfg @@ -434,7 +434,7 @@ backend = # The backend_kwargs param is loaded into a dictionary and passed to __init__ of secrets backend class. # See documentation for the secrets backend you are using. JSON is expected. # Example for AWS Systems Manager ParameterStore: -# ``{{"connections_prefix": "/airflow/connections", "profile_name": "default"}}`` +# ``{{"connections_prefix": "/airflow/connections", "target_name": "default"}}`` backend_kwargs = [cli] From 1b99357a25766888d334b9a2858130070c1b7f1c Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 26 Apr 2024 10:38:35 +0200 Subject: [PATCH 069/189] feat: register new databricks_io --- dagger/pipeline/io_factory.py | 3 +- dagger/pipeline/ios/databricks_io.py | 48 +++++++++++++++++++ .../fixtures/pipeline/ios/databricks_io.yaml | 11 +++++ tests/pipeline/ios/test_databricks_io.py | 17 +++++++ 4 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 dagger/pipeline/ios/databricks_io.py create mode 100644 tests/fixtures/pipeline/ios/databricks_io.yaml create mode 100644 tests/pipeline/ios/test_databricks_io.py diff --git a/dagger/pipeline/io_factory.py b/dagger/pipeline/io_factory.py index 782fd14..5454f31 100644 --- a/dagger/pipeline/io_factory.py +++ b/dagger/pipeline/io_factory.py @@ -6,7 +6,8 @@ dummy_io, gdrive_io, redshift_io, - s3_io + s3_io, + databricks_io ) from dagger.utilities.classes import get_deep_obj_subclasses diff --git a/dagger/pipeline/ios/databricks_io.py b/dagger/pipeline/ios/databricks_io.py new file mode 100644 index 0000000..dd9041b --- /dev/null +++ b/dagger/pipeline/ios/databricks_io.py @@ -0,0 +1,48 @@ +from dagger.pipeline.io import IO +from dagger.utilities.config_validator import Attribute + + +class DatabricksIO(IO): + ref_name = "databricks" + + @classmethod + def init_attributes(cls, orig_cls): + cls.add_config_attributes( + [ + Attribute(attribute_name="catalog"), + Attribute( + attribute_name="schema" + ), + Attribute(attribute_name="table"), + ] + ) + + def __init__(self, io_config, config_location): + super().__init__(io_config, config_location) + + self._catalog = self.parse_attribute("catalog") + self._schema = self.parse_attribute("schema") + self._table = self.parse_attribute("table") + + def alias(self): + return f"databricks://{self._catalog}/{self._schema}/{self._table}" + + @property + def rendered_name(self): + return f"{self._catalog}.{self._schema}.{self._table}" + + @property + def airflow_name(self): + return f"databricks-{self._catalog}-{self._schema}-{self._table}" + + @property + def catalog(self): + return self._catalog + + @property + def schema(self): + return self._schema + + @property + def table(self): + return self._table diff --git a/tests/fixtures/pipeline/ios/databricks_io.yaml b/tests/fixtures/pipeline/ios/databricks_io.yaml new file mode 100644 index 0000000..a8d5914 --- /dev/null +++ b/tests/fixtures/pipeline/ios/databricks_io.yaml @@ -0,0 +1,11 @@ +type: databricks +name: test +catalog: test_catalog +schema: test_schema +table: test_table + + + +# Other attributes: + +# has_dependency: # Weather this i/o should be added to the dependency graph or not. Default is True \ No newline at end of file diff --git a/tests/pipeline/ios/test_databricks_io.py b/tests/pipeline/ios/test_databricks_io.py new file mode 100644 index 0000000..b1d0c45 --- /dev/null +++ b/tests/pipeline/ios/test_databricks_io.py @@ -0,0 +1,17 @@ +import unittest +from dagger.pipeline.io_factory import databricks_io + +import yaml + + +class DbIOTest(unittest.TestCase): + def setUp(self) -> None: + with open('tests/fixtures/pipeline/ios/databricks_io.yaml', "r") as stream: + config = yaml.safe_load(stream) + + self.db_io = databricks_io.DatabricksIO(config, "/") + + def test_properties(self): + self.assertEqual(self.db_io.alias(), "databricks://test_catalog/test_schema/test_table") + self.assertEqual(self.db_io.rendered_name, "test_catalog.test_schema.test_table") + self.assertEqual(self.db_io.airflow_name, "databricks-test_catalog-test_schema-test_table") From e00f555d138f5d1bdfe661a62037d66583a98b00 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 26 Apr 2024 10:48:06 +0200 Subject: [PATCH 070/189] feat: refactor the DBTParseConfig to parse databricks-dbt manifest --- .../airflow/operator_creators/dbt_creator.py | 19 +- dagger/pipeline/tasks/dbt_task.py | 35 +- dagger/utilities/dbt_config_parser.py | 372 ++++++++++++------ dagger/utilities/module.py | 7 +- 4 files changed, 299 insertions(+), 134 deletions(-) diff --git a/dagger/dag_creator/airflow/operator_creators/dbt_creator.py b/dagger/dag_creator/airflow/operator_creators/dbt_creator.py index 38b9c34..4b88fe3 100644 --- a/dagger/dag_creator/airflow/operator_creators/dbt_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/dbt_creator.py @@ -12,29 +12,26 @@ def __init__(self, task, dag): self._project_dir = task.project_dir self._profile_dir = task.profile_dir + self._profile_name = task.profile_name self._target_name = task.target_name self._select = task.select self._dbt_command = task.dbt_command + self._vars = task.vars + # self._create_external_athena_table = task.create_external_athena_table def _generate_command(self): command = [self._task.executable_prefix, self._task.executable] command.append(f"--project_dir={self._project_dir}") command.append(f"--profiles_dir={self._profile_dir}") + command.append(f"--profile_name={self._profile_name}") command.append(f"--target_name={self._target_name}") command.append(f"--dbt_command={self._dbt_command}") if self._select: command.append(f"--select={self._select}") - - if len(self._template_parameters) > 0: - dbt_vars = json.dumps(self._template_parameters) + if self._vars: + dbt_vars = json.dumps(self._vars) command.append(f"--vars='{dbt_vars}'") + # if self._create_external_athena_table: + # command.append(f"--create_external_athena_table={self._create_external_athena_table}") return command - - # Overwriting function because for dbt we don't want to add inputs/outputs to the - # template parameters. - def create_operator(self): - self._template_parameters.update(self._task.template_parameters) - self._update_airflow_parameters() - - return self._create_operator(**self._airflow_parameters) diff --git a/dagger/pipeline/tasks/dbt_task.py b/dagger/pipeline/tasks/dbt_task.py index c59cdd6..aea0945 100644 --- a/dagger/pipeline/tasks/dbt_task.py +++ b/dagger/pipeline/tasks/dbt_task.py @@ -19,9 +19,13 @@ def init_attributes(cls, orig_cls): parent_fields=["task_parameters"], comment="Which directory to look in for the profiles.yml file", ), + Attribute( + attribute_name="profile_name", + parent_fields=["task_parameters"], + comment="Which profile to load from the profiles.yml file", + ), Attribute( attribute_name="target_name", - required=False, parent_fields=["task_parameters"], comment="Which target to load for the given profile " "(--target dbt option). Default is 'default'", @@ -37,6 +41,18 @@ def init_attributes(cls, orig_cls): parent_fields=["task_parameters"], comment="Specify the name of the DBT command to run", ), + Attribute( + attribute_name="vars", + required=False, + parent_fields=["task_parameters"], + comment="Specify the variables to pass to dbt", + ), + Attribute( + attribute_name="create_external_athena_table", + required=False, + parent_fields=["task_parameters"], + comment="Specify whether to create an external Athena table for the model", + ) ] ) @@ -45,9 +61,12 @@ def __init__(self, name, pipeline_name, pipeline, job_config): self._project_dir = self.parse_attribute("project_dir") self._profile_dir = self.parse_attribute("profile_dir") - self._target_name = self.parse_attribute("target_name") or "default" + self._profile_name = self.parse_attribute("profile_name") + self._target_name = self.parse_attribute("target_name") self._select = self.parse_attribute("select") self._dbt_command = self.parse_attribute("dbt_command") + self._vars = self.parse_attribute("vars") + self._create_external_athena_table = self.parse_attribute("create_external_athena_table") @property def project_dir(self): @@ -57,6 +76,10 @@ def project_dir(self): def profile_dir(self): return self._profile_dir + @property + def profile_name(self): + return self._profile_name + @property def target_name(self): return self._target_name @@ -68,3 +91,11 @@ def select(self): @property def dbt_command(self): return self._dbt_command + + @property + def vars(self): + return self._vars + + @property + def create_external_athena_table(self): + return self._create_external_athena_table \ No newline at end of file diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index f12cccf..6c2ae5d 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -1,42 +1,89 @@ import json +import yaml +from abc import ABC, abstractmethod from collections import OrderedDict from os import path from os.path import join from typing import Tuple, List, Dict +import logging -import yaml - +# Task base configurations ATHENA_TASK_BASE = {"type": "athena"} +DATABRICKS_TASK_BASE = {"type": "databricks"} S3_TASK_BASE = {"type": "s3"} +_logger = logging.getLogger("root") -class DBTConfigParser: - """ - Module that parses the manifest.json file generated by dbt and generates the dagger inputs and outputs for the respective dbt model - """ +class DBTConfigParser(ABC): + """Abstract base class for parsing dbt manifest.json files and generating task configurations.""" - def __init__(self, default_config_parameters: dict): - self._dbt_profile = default_config_parameters.get("dbt_profile", "data") - self._default_data_bucket = default_config_parameters["data_bucket"] - self._dbt_project_dir = default_config_parameters.get("project_dir", None) - dbt_manifest_path = path.join(self._dbt_project_dir, "target", "manifest.json") - self._dbt_profile_dir = default_config_parameters.get("profile_dir", None) - dbt_profile_path = path.join(self._dbt_profile_dir, "profiles.yml") - - with open(dbt_manifest_path, "r") as f: - data = f.read() - self._manifest_data = json.loads(data) - profile_yaml = yaml.safe_load(open(dbt_profile_path, "r")) - prod_dbt_profile = profile_yaml[self._dbt_project_dir.split("/")[-1]][ - "outputs" - ][self._dbt_profile] - self._default_data_dir = prod_dbt_profile.get( - "s3_data_dir" - ) or prod_dbt_profile.get("s3_staging_dir") - self._default_schema = prod_dbt_profile.get("schema") + def __init__(self, config_parameters: dict): + self._dbt_project_dir = config_parameters.get("project_dir") + self._profile_name = config_parameters.get("profile_name", "") + self._target_name = config_parameters.get("target_name", "") + self._dbt_profile_dir = config_parameters.get("profile_dir", None) + self._manifest_data = self._load_file( + self._get_manifest_path(), file_type="json" + ) + profile_data = self._load_file(self._get_profile_path(), file_type="yaml") + self._target_config = profile_data[self._profile_name]["outputs"][ + self._target_name + ] + self._default_schema = self._target_config.get("schema", "") + self._nodes_in_manifest = self._manifest_data.get("nodes", {}) + self._sources_in_manifest = self._manifest_data.get("sources", {}) + + def _get_manifest_path(self) -> str: + """ + Construct path for manifest.json file based on configuration parameters. + """ + target_path = f"{self._profile_name}_target" + return path.join(self._dbt_project_dir, target_path, "manifest.json") - self._nodes_in_manifest = self._manifest_data["nodes"] - self._sources_in_manifest = self._manifest_data["sources"] + def _get_profile_path(self) -> str: + """ + Construct path for profiles.yml file based on configuration parameters. + """ + return path.join(self._dbt_profile_dir, "profiles.yml") + + @staticmethod + def _load_file(file_path: str, file_type: str) -> dict: + """Load a file (JSON or YAML) based on the specified type and return its contents.""" + try: + with open(file_path, "r") as file: + if file_type == "json": + return json.load(file) + elif file_type == "yaml": + return yaml.safe_load(file) + except FileNotFoundError: + _logger.error(f"File not found: {file_path}") + exit(1) + + @abstractmethod + def _get_athena_table_task( + self, node: dict, follow_external_dependency: bool = False + ) -> dict: + """Generate an athena table task for a DBT node. Must be implemented by subclasses. This function should be deprecated after the source connects with databricks directly""" + pass + + @abstractmethod + def _get_table_task( + self, node: dict, follow_external_dependency: bool = False + ) -> dict: + """Generate a table task for a DBT node for the specific dbt-adapter. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _get_model_data_location( + self, node: dict, schema: str, model_name: str + ) -> Tuple[str, str]: + """Get the S3 path of the DBT model relative to the data bucket. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _get_s3_task(self, node: dict) -> dict: + """Generate an S3 task configuration based on a DBT node. Must be implemented by subclasses.""" + pass @staticmethod def _get_dummy_task(node: dict, follow_external_dependency: bool = False) -> dict: @@ -58,18 +105,107 @@ def _get_dummy_task(node: dict, follow_external_dependency: bool = False) -> dic return task - def _get_athena_task( - self, node: dict, follow_external_dependency: bool = False - ) -> dict: + @abstractmethod + def _generate_dagger_output(self, node: dict): + """Generate the dagger output for a DBT node. Must be implemented by subclasses.""" + pass + + def _generate_dagger_tasks(self, node_name: str) -> List[Dict]: """ - Generates the dagger athena task for the DBT model node + Generates the dagger task based on whether the DBT model node is a staging model or not. + If the DBT model node represents a DBT seed or an ephemeral model, then a dagger dummy task is generated. + If the DBT model node represents a staging model, then a dagger athena task is generated for each source of the DBT model. Apart from this, a dummy task is also generated for the staging model itself. + If the DBT model node is not a staging model, then a dagger athena task and an s3 task is generated for the DBT model node itself. Args: - node: The extracted node from the manifest.json file - follow_external_dependency: Whether to follow external airflow dependencies or not + node_name: The name of the DBT model node Returns: - dict: The dagger athena task for the DBT model node + List[Dict]: The respective dagger tasks for the DBT model node + """ + dagger_tasks = [] + + if node_name.startswith("source"): + node = self._sources_in_manifest[node_name] + else: + node = self._nodes_in_manifest[node_name] + + if node.get("resource_type") == "seed": + task = self._get_dummy_task(node) + dagger_tasks.append(task) + elif node.get("resource_type") == "source": + table_task = self._get_athena_table_task(node, follow_external_dependency=True) + dagger_tasks.append(table_task) + elif node.get("config", {}).get("materialized") == "ephemeral": + task = self._get_dummy_task(node, follow_external_dependency=True) + dagger_tasks.append(task) + + ephemeral_parent_node_names = node.get("depends_on", {}).get("nodes", []) + for node_name in ephemeral_parent_node_names: + dagger_tasks += self._generate_dagger_tasks(node_name) + elif node.get("name").startswith("stg_") or "preparation" in node.get( + "schema", "" + ): + dagger_tasks.append( + self._get_dummy_task(node, follow_external_dependency=True) + ) + else: + table_task = self._get_table_task(node, follow_external_dependency=True) + s3_task = self._get_s3_task(node) + + dagger_tasks.append(table_task) + dagger_tasks.append(s3_task) + + return dagger_tasks + + def generate_dagger_io(self, model_name: str) -> Tuple[List[dict], List[dict]]: + """ + Parse through all the parents of the DBT model and return the dagger inputs and outputs for the DBT model + Args: + model_name: The name of the DBT model + + Returns: + Tuple[list, list]: The dagger inputs and outputs for the DBT model + + """ + inputs_list = [] + model_node = self._nodes_in_manifest[f"model.main.{model_name}"] + parent_node_names = model_node.get("depends_on", {}).get("nodes", []) + print(f"parent node name: {parent_node_names}") + + for parent_node_name in parent_node_names: + dagger_input = self._generate_dagger_tasks(parent_node_name) + inputs_list += dagger_input + + output_list = self._generate_dagger_output(model_node) + + unique_inputs = list( + OrderedDict( + (frozenset(item.items()), item) for item in inputs_list + ).values() + ) + + print(unique_inputs) + + return unique_inputs, output_list + + +class AthenaDBTConfigParser(DBTConfigParser): + """Implementation for Athena configurations.""" + def __init__(self, default_config_parameters: dict): + super().__init__(default_config_parameters) + self._profile_name = "athena" + self._default_data_bucket = default_config_parameters.get("data_bucket") + self._default_data_dir = self._target_config.get( + "s3_data_dir" + ) or self._target_config.get("s3_staging_dir") + + + def _get_table_task( + self, node: dict, follow_external_dependency: bool = False + ) -> dict: + """ + Generates the dagger athena task for the DBT model node """ task = ATHENA_TASK_BASE.copy() if follow_external_dependency: @@ -81,6 +217,24 @@ def _get_athena_task( return task + def _get_athena_table_task(self, node: dict, follow_external_dependency: bool = False) -> dict: + return self._get_table_task(node, follow_external_dependency) + + def _get_model_data_location( + self, node: dict, schema: str, model_name: str + ) -> Tuple[str, str]: + """ + Gets the S3 path of the dbt model relative to the data bucket. + """ + location = node.get("config", {}).get("external_location") + if not location: + location = join(self._default_data_dir, schema, model_name) + + split = location.split("//")[1].split("/") + bucket_name, data_path = split[0], "/".join(split[1:]) + + return bucket_name, data_path + def _get_s3_task(self, node: dict) -> dict: """ Generates the dagger s3 task for the DBT model node @@ -93,17 +247,16 @@ def _get_s3_task(self, node: dict) -> dict: """ task = S3_TASK_BASE.copy() + schema = node.get("schema", self._default_schema) table = node.get("name", "") task["name"] = f"{schema}__{table}_s3" - task["bucket"] = self._default_data_bucket - task["path"] = self._get_model_data_location(node, schema, table)[1] - + task["bucket"], task["path"] = self._get_model_data_location(node, schema, table) return task def _generate_dagger_output(self, node: dict): """ - Generates the dagger output for the DBT model node. If the model is materialized as a view or ephemeral, then a dummy task is created. + Generates the dagger output for the DBT model node with athena-dbt adapter. If the model is materialized as a view or ephemeral, then a dummy task is created. Otherwise, an athena and s3 task is created for the DBT model node. Args: node: The extracted node from the manifest.json file @@ -118,109 +271,90 @@ def _generate_dagger_output(self, node: dict): ) or node.get("name").startswith("stg_"): return [self._get_dummy_task(node)] else: - return [self._get_athena_task(node), self._get_s3_task(node)] - - def _generate_dagger_tasks( - self, - node_name: str, - ) -> List[Dict]: - """ - Generates the dagger task based on whether the DBT model node is a staging model or not. - If the DBT model node represents a DBT seed or an ephemeral model, then a dagger dummy task is generated. - If the DBT model node represents a staging model, then a dagger athena task is generated for each source of the DBT model. Apart from this, a dummy task is also generated for the staging model itself. - If the DBT model node is not a staging model, then a dagger athena task and an s3 task is generated for the DBT model node itself. - Args: - node: The extracted node from the manifest.json file + return [self._get_table_task(node), self._get_s3_task(node)] - Returns: - List[Dict]: The respective dagger tasks for the DBT model node - """ - dagger_tasks = [] +class DatabricksDBTConfigParser(DBTConfigParser): + """Implementation for Databricks configurations.""" - if node_name.startswith("source"): - node = self._sources_in_manifest[node_name] - else: - node = self._nodes_in_manifest[node_name] + def __init__(self, default_config_parameters: dict): + super().__init__(default_config_parameters) + self._profile_name = "databricks" + self._default_catalog = self._target_config.get("catalog") + self._athena_dbt_parser = AthenaDBTConfigParser(default_config_parameters) + self._create_external_athena_table = default_config_parameters.get("create_external_athena_table", False) - if node.get("resource_type") == "seed": - task = self._get_dummy_task(node) - dagger_tasks.append(task) - elif node.get("resource_type") == "source": - athena_task = self._get_athena_task(node, follow_external_dependency=True) - dagger_tasks.append(athena_task) - elif node.get("config", {}).get("materialized") == "ephemeral": - task = self._get_dummy_task(node, follow_external_dependency=True) - dagger_tasks.append(task) + def _get_table_task( + self, node: dict, follow_external_dependency: bool = False + ) -> dict: + """ + Generates the dagger databricks task for the DBT model node + """ + task = DATABRICKS_TASK_BASE.copy() + if follow_external_dependency: + task["follow_external_dependency"] = True - ephemeral_parent_node_names = node.get("depends_on", {}).get("nodes", []) - for node_name in ephemeral_parent_node_names: - dagger_tasks += self._generate_dagger_tasks(node_name) - elif node.get("name").startswith("stg_"): - dagger_tasks.append( - self._get_dummy_task(node, follow_external_dependency=True) - ) - else: - athena_task = self._get_athena_task(node, follow_external_dependency=True) - s3_task = self._get_s3_task(node) + task["catalog"] = node.get("database", self._default_catalog) + task["schema"] = node.get("schema", self._default_schema) + task["table"] = node.get("name", "") + task[ + "name" + ] = f"{task['catalog']}__{task['schema']}__{task['table']}_databricks" - dagger_tasks.append(athena_task) - dagger_tasks.append(s3_task) + return task - return dagger_tasks + def _get_athena_table_task(self, node: dict, follow_external_dependency: bool = False) -> dict: + return self._athena_dbt_parser._get_table_task(node, follow_external_dependency) def _get_model_data_location( - self, node: dict, schema: str, dbt_model_name: str + self, node: dict, schema: str, model_name: str ) -> Tuple[str, str]: """ Gets the S3 path of the dbt model relative to the data bucket. - If external location is not specified in the DBT model config, then the default data directory from the - DBT profiles configuration is used. - Args: - node: The extracted node from the manifest.json file - schema: The schema of the dbt model - dbt_model_name: The name of the dbt model - - Returns: - str: The relative S3 path of the dbt model relative to the data bucket - """ - location = node.get("config", {}).get("external_location") - if not location: - location = join(self._default_data_dir, schema, dbt_model_name) - + location_root = node.get("config", {}).get("location_root") + location = join(location_root, schema, model_name) split = location.split("//")[1].split("/") bucket_name, data_path = split[0], "/".join(split[1:]) return bucket_name, data_path - def generate_dagger_io(self, model_name: str) -> Tuple[List[dict], List[dict]]: + def _get_s3_task(self, node: dict) -> dict: """ - Parse through all the parents of the DBT model and return the dagger inputs and outputs for the DBT model - Args: - model_name: The name of the DBT model - - Returns: - Tuple[list, list]: The dagger inputs and outputs for the DBT model - + Generates the dagger s3 task for the databricks-dbt model node """ - inputs_list = [] - - model_node = self._nodes_in_manifest[f"model.main.{model_name}"] - - parent_node_names = model_node.get("depends_on", {}).get("nodes", []) + task = S3_TASK_BASE.copy() - for parent_node_name in parent_node_names: - dagger_input = self._generate_dagger_tasks(parent_node_name) + catalog = node.get("database", self._default_catalog) + schema = node.get("schema", self._default_schema) + table = node.get("name", "") + task["name"] = f"{catalog}__{schema}__{table}_s3" + task["bucket"], task["path"] = self._get_model_data_location( + node, schema, table + ) - inputs_list += dagger_input + return task - output_list = self._generate_dagger_output(model_node) + def _generate_dagger_output(self, node: dict): + """ + Generates the dagger output for the DBT model node with the databricks-dbt adapter. + If the model is materialized as a view or ephemeral, then a dummy task is created. + Otherwise, and databricks and s3 task is created for the DBT model node. + And if create_external_athena_table is True te an extra athena task is created. + Args: + node: The extracted node from the manifest.json file - unique_inputs = list( - OrderedDict( - (frozenset(item.items()), item) for item in inputs_list - ).values() - ) + Returns: + dict: The dagger output, which is a combination of an athena and s3 task for the DBT model node - return unique_inputs, output_list + """ + if node.get("config", {}).get("materialized") in ( + "view", + "ephemeral", + ) or node.get("name").startswith("stg_"): + return [self._get_dummy_task(node)] + else: + output_tasks = [self._get_table_task(node), self._get_s3_task(node)] + if self._create_external_athena_table: + output_tasks.append(self._get_athena_table_task(node)) + return output_tasks diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index d565ffe..3cb261d 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -1,7 +1,7 @@ import logging from os import path from mergedeep import merge -from dagger.utilities.dbt_config_parser import DBTConfigParser +from dagger.utilities.dbt_config_parser import AthenaDBTConfigParser, DatabricksDBTConfigParser import yaml @@ -26,7 +26,10 @@ def __init__(self, path_to_config, target_dir): self._default_parameters = config.get("default_parameters", {}) if 'dbt' in self._tasks.keys(): - self._dbt_module = DBTConfigParser(self._default_parameters) + if self._default_parameters.get('profile_name') == 'athena': + self._dbt_module = AthenaDBTConfigParser(self._default_parameters) + if self._default_parameters.get('profile_name') == 'databricks': + self._dbt_module = DatabricksDBTConfigParser(self._default_parameters) @staticmethod def read_yaml(yaml_str): From 9274fe222618c7dabb620aac3b468d41d9f5179d Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 26 Apr 2024 10:48:31 +0200 Subject: [PATCH 071/189] feat: add unit test for databricks config parser --- ...y => dbt_config_parser_fixtures_athena.py} | 40 +- .../dbt_config_parser_fixtures_databricks.py | 385 ++++++++++++++++++ tests/utilities/test_dbt_config_parser.py | 98 ++++- 3 files changed, 485 insertions(+), 38 deletions(-) rename tests/fixtures/modules/{dbt_config_parser_fixtures.py => dbt_config_parser_fixtures_athena.py} (99%) create mode 100644 tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures.py b/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py similarity index 99% rename from tests/fixtures/modules/dbt_config_parser_fixtures.py rename to tests/fixtures/modules/dbt_config_parser_fixtures_athena.py index a28d871..66005e6 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py @@ -1,3 +1,23 @@ +DBT_PROFILE_FIXTURE = { + "athena": { + "outputs": { + "data": { + "aws_profile_name": "data", + "database": "awsdatacatalog", + "num_retries": 10, + "region_name": "eu-west-1", + "s3_data_dir": "s3://bucket1-data-lake/path1/tmp", + "s3_data_naming": "schema_table", + "s3_staging_dir": "s3://bucket1-data-lake/path1/", + "schema": "analytics_engineering", + "threads": 4, + "type": "athena", + "work_group": "primary", + }, + } + } +} + DBT_MANIFEST_FILE_FIXTURE = { "nodes": { "model.main.model1": { @@ -162,26 +182,6 @@ }, } -DBT_PROFILE_FIXTURE = { - "main": { - "outputs": { - "data": { - "aws_profile_name": "data", - "database": "awsdatacatalog", - "num_retries": 10, - "region_name": "eu-west-1", - "s3_data_dir": "s3://bucket1-data-lake/path1/tmp", - "s3_data_naming": "schema_table", - "s3_staging_dir": "s3://bucket1-data-lake/path1/", - "schema": "analytics_engineering", - "threads": 4, - "type": "athena", - "work_group": "primary", - }, - } - } -} - EXPECTED_STAGING_NODE = [ { "name": "stg_core_schema1__table1", diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py new file mode 100644 index 0000000..5c6e0a4 --- /dev/null +++ b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py @@ -0,0 +1,385 @@ +DATABRICKS_DBT_PROFILE_FIXTURE = { + "databricks": { + "outputs": { + "data": { + "type": "databricks", + "catalog": "hive_metastore", + "schema": "analytics_engineering", + "host": "xxx.databricks.com", + "http_path": "/sql/1.0/warehouses/xxx", + "token": "{{ env_var('SECRETDATABRICKS') }}" + }, + } + + } +} + +DATABRICKS_DBT_MANIFEST_FILE_FIXTURE = { + "nodes": { + "model.main.model1": { + "database": "marts", + "schema": "analytics_engineering", + "name": "model1", + "unique_id": "model.main.model1", + "resource_type": "model", + "config": { + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/marts", + "materialized": "incremental", + "incremental_strategy": "insert_overwrite", + }, + "description": "Details of revenue calculation at supplier level for each observation day", + "tags": ["daily"], + "unrendered_config": { + "materialized": "incremental", + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/marts", + "incremental_strategy": "insert_overwrite", + "partitioned_by": ["year", "month", "day", "dt"], + "tags": ["daily"], + "on_schema_change": "fail", + }, + "depends_on": { + "macros": [ + "macro.main.macro1", + "macro.main.macro2", + ], + "nodes": [ + "model.main.stg_core_schema2__table2", + "model.main.model2", + "model.main.int_model3", + "seed.main.seed_buyer_country_overwrite", + ], + }, + }, + "model.main.stg_core_schema1__table1": { + "database": "hive_metastore", + "schema": "data_preparation", + "name": "stg_core_schema1__table1", + "unique_id": "model.main.stg_core_schema1__table1", + "resource_type": "model", + "config": { + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/preparation", + "materialized": "view", + }, + "depends_on": { + "macros": [], + "nodes": ["source.main.core_schema1.table1"], + }, + }, + "model.main.stg_core_schema2__table2": { + "database": "hive_metastore", + "schema": "data_preparation", + "name": "stg_core_schema2__table2", + "unique_id": "model.main.stg_core_schema2__table2", + "resource_type": "model", + "config": { + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/preparation", + "materialized": "view", + }, + "depends_on": { + "macros": [], + "nodes": [ + "source.main.core_schema2.table2", + "source.main.core_schema2.table3", + "seed.main.seed_buyer_country_overwrite", + ], + }, + }, + "model.main.model2": { + "database": "marts", + "schema": "analytics_engineering", + "name": "model2", + "unique_id": "model.main.model2", + "resource_type": "model", + "config": { + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/marts", + "materialized": "table", + }, + "depends_on": {"macros": [], "nodes": []}, + }, + "model.main.int_model3": { + "name": "int_model3", + "unique_id": "model.main.int_model3", + "database": "intermediate", + "schema": "analytics_engineering", + "resource_type": "model", + "config": { + "materialized": "ephemeral", + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/intermediate", + }, + "depends_on": { + "macros": [], + "nodes": ["model.main.int_model2"], + }, + }, + "seed.main.seed_buyer_country_overwrite": { + "database": "hive_metastore", + "schema": "datastg_preparation", + "name": "seed_buyer_country_overwrite", + "unique_id": "seed.main.seed_buyer_country_overwrite", + "resource_type": "seed", + "alias": "seed_buyer_country_overwrite", + "tags": ["analytics"], + "description": "", + "created_at": 1700216177.105391, + "depends_on": {"macros": []}, + }, + "model.main.model3": { + "name": "model3", + "database": "marts", + "schema": "analytics_engineering", + "unique_id": "model.main.model3", + "config": { + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/marts", + }, + "depends_on": { + "macros": [], + "nodes": [ + "model.main.int_model3", + "model.main.model2", + "seed.main.seed_buyer_country_overwrite", + "model.main.stg_core_schema2__table2", + ], + }, + }, + "model.main.int_model2": { + "name": "int_model2", + "unique_id": "model.main.int_model2", + "database": "intermediate", + "schema": "analytics_engineering", + "config": { + "materialized": "ephemeral", + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/intermediate", + }, + "depends_on": { + "macros": [], + "nodes": [ + "seed.main.seed_buyer_country_overwrite", + "model.main.stg_core_schema1__table1", + ], + }, + }, + }, + "sources": { + "source.main.core_schema1.table1": { + "source_name": "table1", + "database": "hive_metastore", + "schema": "core_schema1", + "resource_type": "source", + "unique_id": "source.main.core_schema1.table1", + "name": "table1", + "tags": ["analytics"], + "description": "", + }, + "source.main.core_schema2.table2": { + "source_name": "table2", + "database": "hive_metastore", + "schema": "core_schema2", + "resource_type": "source", + "unique_id": "source.main.core_schema2.table2", + "name": "table2", + "tags": ["analytics"], + "description": "", + }, + "source.main.core_schema2.table3": { + "source_name": "table3", + "database": "hive_metastore", + "schema": "core_schema2", + "resource_type": "source", + "unique_id": "source.main.core_schema2.table3", + "name": "table3", + "tags": ["analytics"], + "description": "", + }, + }, +} + +DATABRICKS_EXPECTED_STAGING_NODE = [ + { + "name": "stg_core_schema1__table1", + "type": "dummy", + "follow_external_dependency": True, + }, +] + +DATABRICKS_EXPECTED_SEED_NODE = [ + { + "type": "dummy", + "name": "seed_buyer_country_overwrite", + } +] + +DATABRICKS_EXPECTED_MODEL_MULTIPLE_DEPENDENCIES = [ + { + "type": "dummy", + "name": "int_model3", + "follow_external_dependency": True, + }, + { + "type": "dummy", + "name": "int_model2", + "follow_external_dependency": True, + }, + { + "type": "dummy", + "name": "seed_buyer_country_overwrite", + }, + { + "name": "stg_core_schema1__table1", + "type": "dummy", + "follow_external_dependency": True, + }, + { + "type": "databricks", + "name": "marts__analytics_engineering__model2_databricks", + "catalog": "marts", + "schema": "analytics_engineering", + "table": "model2", + "follow_external_dependency": True, + }, + { + "bucket": "chodata-data-lake", + "name": "marts__analytics_engineering__model2_s3", + "path": "analytics_warehouse/data/marts/analytics_engineering/model2", + "type": "s3", + }, + { + "name": "stg_core_schema2__table2", + "type": "dummy", + "follow_external_dependency": True, + }, +] + +DATABRICKS_EXPECTED_EPHEMERAL_NODE = [ + { + "type": "dummy", + "name": "int_model3", + "follow_external_dependency": True, + }, + { + "type": "dummy", + "name": "int_model2", + "follow_external_dependency": True, + }, + { + "type": "dummy", + "name": "seed_buyer_country_overwrite", + }, + { + "name": "stg_core_schema1__table1", + "type": "dummy", + "follow_external_dependency": True, + } +] + +DATABRICKS_EXPECTED_MODEL_NODE = [ + { + "type": "databricks", + "name": "marts__analytics_engineering__model1_databricks", + "catalog": "marts", + "schema": "analytics_engineering", + "table": "model1", + "follow_external_dependency": True, + }, + { + "bucket": "chodata-data-lake", + "name": "marts__analytics_engineering__model1_s3", + "path": "analytics_warehouse/data/marts/analytics_engineering/model1", + "type": "s3", + }, +] + +DATABRICKS_EXPECTED_DAGGER_INPUTS = [ + { + "name": "stg_core_schema2__table2", + "type": "dummy", + "follow_external_dependency": True, + }, + { + "name": "marts__analytics_engineering__model2_databricks", + "catalog": "marts", + "schema": "analytics_engineering", + "table": "model2", + "type": "databricks", + "follow_external_dependency": True, + }, + { + "bucket": "chodata-data-lake", + "name": "marts__analytics_engineering__model2_s3", + "path": "analytics_warehouse/data/marts/analytics_engineering/model2", + "type": "s3", + }, + { + "type": "dummy", + "name": "int_model3", + "follow_external_dependency": True, + }, + { + "type": "dummy", + "name": "int_model2", + "follow_external_dependency": True, + }, + {"name": "seed_buyer_country_overwrite", "type": "dummy"}, + { + "name": "stg_core_schema1__table1", + "type": "dummy", + "follow_external_dependency": True, + }, +] + +DATABRICKS_EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS = [ + { + "follow_external_dependency": True, + "name": "core_schema2__table2_athena", + "schema": "core_schema2", + "table": "table2", + "type": "athena", + }, + { + "follow_external_dependency": True, + "name": "core_schema2__table3_athena", + "schema": "core_schema2", + "table": "table3", + "type": "athena", + }, + {"name": "seed_buyer_country_overwrite", "type": "dummy"}, +] + +DATABRICKS_EXPECTED_DBT_INT_MODEL_DAGGER_INPUTS = [ + {"name": "seed_buyer_country_overwrite", "type": "dummy"}, + { + "name": "stg_core_schema1__table1", + "type": "dummy", + "follow_external_dependency": True, + }, +] + +DATABRICKS_EXPECTED_DAGGER_OUTPUTS = [ + { + "name": "marts__analytics_engineering__model1_databricks", + "catalog": "marts", + "schema": "analytics_engineering", + "table": "model1", + "type": "databricks", + }, + { + "bucket": "chodata-data-lake", + "name": "marts__analytics_engineering__model1_s3", + "path": "analytics_warehouse/data/marts/analytics_engineering/model1", + "type": "s3", + }, + { + "name": "analytics_engineering__model1_athena", + "schema": "analytics_engineering", + "table": "model1", + "type": "athena", + } +] + +DATABRICKS_EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS = [ + { + "type": "dummy", + "name": "stg_core_schema2__table2", + }, +] + + diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index 3fd6394..8c188d3 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -3,22 +3,10 @@ from unittest import skip from unittest.mock import patch, MagicMock -from dagger.utilities.dbt_config_parser import DBTConfigParser +from dagger.utilities.dbt_config_parser import AthenaDBTConfigParser, DatabricksDBTConfigParser from dagger.utilities.module import Module -from tests.fixtures.modules.dbt_config_parser_fixtures import ( - EXPECTED_DAGGER_OUTPUTS, - EXPECTED_DAGGER_INPUTS, - DBT_MANIFEST_FILE_FIXTURE, - DBT_PROFILE_FIXTURE, - EXPECTED_STAGING_NODE, - EXPECTED_SEED_NODE, - EXPECTED_MODEL_MULTIPLE_DEPENDENCIES, - EXPECTED_EPHEMERAL_NODE, - EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS, - EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS, - EXPECTED_MODEL_NODE, - EXPECTED_DBT_INT_MODEL_DAGGER_INPUTS -) +from tests.fixtures.modules.dbt_config_parser_fixtures_athena import * +from tests.fixtures.modules.dbt_config_parser_fixtures_databricks import * _logger = logging.getLogger("root") @@ -26,17 +14,25 @@ "data_bucket": "bucket1-data-lake", "project_dir": "main", "profile_dir": ".dbt", - "dbt_profile": "data", + "profile_name": "athena", + "target_name": "data", +} +DATABRICKS_DEFAULT_CONFIG_PARAMS = { + "project_dir": "main", + "profile_dir": ".dbt", + "profile_name": "databricks", + "target_name": "data", + "create_external_athena_table": True, } MODEL_NAME = "model1" -class TestDBTConfigParser(unittest.TestCase): +class TestAthenaDBTConfigParser(unittest.TestCase): @patch("builtins.open", new_callable=MagicMock, read_data=DBT_MANIFEST_FILE_FIXTURE) @patch("json.loads", return_value=DBT_MANIFEST_FILE_FIXTURE) @patch("yaml.safe_load", return_value=DBT_PROFILE_FIXTURE) def setUp(self, mock_open, mock_json_load, mock_safe_load): - self._dbt_config_parser = DBTConfigParser(DEFAULT_CONFIG_PARAMS) + self._dbt_config_parser = AthenaDBTConfigParser(DEFAULT_CONFIG_PARAMS) self._sample_dbt_node = DBT_MANIFEST_FILE_FIXTURE["nodes"]["model.main.model1"] @skip("Run only locally") @@ -95,3 +91,69 @@ def test_generate_io_outputs(self): _, result = self._dbt_config_parser.generate_dagger_io(mock_input) self.assertListEqual(result, expected_output) + + +class TestDatabricksDBTConfigParser(unittest.TestCase): + @patch("builtins.open", new_callable=MagicMock, read_data=DATABRICKS_DBT_MANIFEST_FILE_FIXTURE) + @patch("json.loads", return_value=DATABRICKS_DBT_MANIFEST_FILE_FIXTURE) + @patch("yaml.safe_load", return_value=DATABRICKS_DBT_PROFILE_FIXTURE) + def setUp(self, mock_open, mock_json_load, mock_safe_load): + self._dbt_config_parser = DatabricksDBTConfigParser(DATABRICKS_DEFAULT_CONFIG_PARAMS) + self._sample_dbt_node = DATABRICKS_DBT_MANIFEST_FILE_FIXTURE["nodes"]["model.main.model1"] + + @skip("Run only locally") + def test_generate_task_configs(self): + module = Module( + path_to_config="./tests/fixtures/modules/dbt_test_config.yaml", + target_dir="./tests/fixtures/modules/", + ) + + module.generate_task_configs() + + def test_generate_dagger_tasks(self): + test_inputs = [ + ( + "model.main.stg_core_schema1__table1", + DATABRICKS_EXPECTED_STAGING_NODE, + ), + ( + "seed.main.seed_buyer_country_overwrite", + DATABRICKS_EXPECTED_SEED_NODE, + ), + ( + "model.main.int_model3", + DATABRICKS_EXPECTED_EPHEMERAL_NODE, + ), + ( + "model.main.model1", + DATABRICKS_EXPECTED_MODEL_NODE, + ), + ] + for mock_input, expected_output in test_inputs: + result = self._dbt_config_parser._generate_dagger_tasks(mock_input) + self.assertListEqual(result, expected_output) + + def test_generate_io_inputs(self): + fixtures = [ + ("model1", DATABRICKS_EXPECTED_DAGGER_INPUTS), + ( + "model3", + DATABRICKS_EXPECTED_MODEL_MULTIPLE_DEPENDENCIES, + ), + ("stg_core_schema2__table2", DATABRICKS_EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS), + ("int_model2", DATABRICKS_EXPECTED_DBT_INT_MODEL_DAGGER_INPUTS), + ] + for mock_input, expected_output in fixtures: + result, _ = self._dbt_config_parser.generate_dagger_io(mock_input) + + self.assertListEqual(result, expected_output) + + def test_generate_io_outputs(self): + fixtures = [ + ("model1", DATABRICKS_EXPECTED_DAGGER_OUTPUTS), + ("stg_core_schema2__table2", DATABRICKS_EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS), + ] + for mock_input, expected_output in fixtures: + _, result = self._dbt_config_parser.generate_dagger_io(mock_input) + + self.assertListEqual(result, expected_output) From e951c474b82975ba6a40bb3039ff300719a9441d Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 26 Apr 2024 10:53:11 +0200 Subject: [PATCH 072/189] chore: black --- dagger/pipeline/ios/databricks_io.py | 4 +-- dagger/pipeline/tasks/dbt_task.py | 10 ++++--- dagger/utilities/dbt_config_parser.py | 23 ++++++++++----- dagger/utilities/module.py | 11 ++++--- .../dbt_config_parser_fixtures_athena.py | 2 +- .../dbt_config_parser_fixtures_databricks.py | 9 ++---- tests/utilities/test_dbt_config_parser.py | 29 +++++++++++++++---- 7 files changed, 57 insertions(+), 31 deletions(-) diff --git a/dagger/pipeline/ios/databricks_io.py b/dagger/pipeline/ios/databricks_io.py index dd9041b..15be2c1 100644 --- a/dagger/pipeline/ios/databricks_io.py +++ b/dagger/pipeline/ios/databricks_io.py @@ -10,9 +10,7 @@ def init_attributes(cls, orig_cls): cls.add_config_attributes( [ Attribute(attribute_name="catalog"), - Attribute( - attribute_name="schema" - ), + Attribute(attribute_name="schema"), Attribute(attribute_name="table"), ] ) diff --git a/dagger/pipeline/tasks/dbt_task.py b/dagger/pipeline/tasks/dbt_task.py index aea0945..e59ea5a 100644 --- a/dagger/pipeline/tasks/dbt_task.py +++ b/dagger/pipeline/tasks/dbt_task.py @@ -28,7 +28,7 @@ def init_attributes(cls, orig_cls): attribute_name="target_name", parent_fields=["task_parameters"], comment="Which target to load for the given profile " - "(--target dbt option). Default is 'default'", + "(--target dbt option). Default is 'default'", ), Attribute( attribute_name="select", @@ -52,7 +52,7 @@ def init_attributes(cls, orig_cls): required=False, parent_fields=["task_parameters"], comment="Specify whether to create an external Athena table for the model", - ) + ), ] ) @@ -66,7 +66,9 @@ def __init__(self, name, pipeline_name, pipeline, job_config): self._select = self.parse_attribute("select") self._dbt_command = self.parse_attribute("dbt_command") self._vars = self.parse_attribute("vars") - self._create_external_athena_table = self.parse_attribute("create_external_athena_table") + self._create_external_athena_table = self.parse_attribute( + "create_external_athena_table" + ) @property def project_dir(self): @@ -98,4 +100,4 @@ def vars(self): @property def create_external_athena_table(self): - return self._create_external_athena_table \ No newline at end of file + return self._create_external_athena_table diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 6c2ae5d..11c4325 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -134,7 +134,9 @@ def _generate_dagger_tasks(self, node_name: str) -> List[Dict]: task = self._get_dummy_task(node) dagger_tasks.append(task) elif node.get("resource_type") == "source": - table_task = self._get_athena_table_task(node, follow_external_dependency=True) + table_task = self._get_athena_table_task( + node, follow_external_dependency=True + ) dagger_tasks.append(table_task) elif node.get("config", {}).get("materialized") == "ephemeral": task = self._get_dummy_task(node, follow_external_dependency=True) @@ -192,6 +194,7 @@ def generate_dagger_io(self, model_name: str) -> Tuple[List[dict], List[dict]]: class AthenaDBTConfigParser(DBTConfigParser): """Implementation for Athena configurations.""" + def __init__(self, default_config_parameters: dict): super().__init__(default_config_parameters) self._profile_name = "athena" @@ -200,7 +203,6 @@ def __init__(self, default_config_parameters: dict): "s3_data_dir" ) or self._target_config.get("s3_staging_dir") - def _get_table_task( self, node: dict, follow_external_dependency: bool = False ) -> dict: @@ -217,7 +219,9 @@ def _get_table_task( return task - def _get_athena_table_task(self, node: dict, follow_external_dependency: bool = False) -> dict: + def _get_athena_table_task( + self, node: dict, follow_external_dependency: bool = False + ) -> dict: return self._get_table_task(node, follow_external_dependency) def _get_model_data_location( @@ -247,11 +251,12 @@ def _get_s3_task(self, node: dict) -> dict: """ task = S3_TASK_BASE.copy() - schema = node.get("schema", self._default_schema) table = node.get("name", "") task["name"] = f"{schema}__{table}_s3" - task["bucket"], task["path"] = self._get_model_data_location(node, schema, table) + task["bucket"], task["path"] = self._get_model_data_location( + node, schema, table + ) return task def _generate_dagger_output(self, node: dict): @@ -282,7 +287,9 @@ def __init__(self, default_config_parameters: dict): self._profile_name = "databricks" self._default_catalog = self._target_config.get("catalog") self._athena_dbt_parser = AthenaDBTConfigParser(default_config_parameters) - self._create_external_athena_table = default_config_parameters.get("create_external_athena_table", False) + self._create_external_athena_table = default_config_parameters.get( + "create_external_athena_table", False + ) def _get_table_task( self, node: dict, follow_external_dependency: bool = False @@ -303,7 +310,9 @@ def _get_table_task( return task - def _get_athena_table_task(self, node: dict, follow_external_dependency: bool = False) -> dict: + def _get_athena_table_task( + self, node: dict, follow_external_dependency: bool = False + ) -> dict: return self._athena_dbt_parser._get_table_task(node, follow_external_dependency) def _get_model_data_location( diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 3cb261d..6b6aa86 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -1,7 +1,10 @@ import logging from os import path from mergedeep import merge -from dagger.utilities.dbt_config_parser import AthenaDBTConfigParser, DatabricksDBTConfigParser +from dagger.utilities.dbt_config_parser import ( + AthenaDBTConfigParser, + DatabricksDBTConfigParser, +) import yaml @@ -25,10 +28,10 @@ def __init__(self, path_to_config, target_dir): self._override_parameters = config.get("override_parameters", {}) self._default_parameters = config.get("default_parameters", {}) - if 'dbt' in self._tasks.keys(): - if self._default_parameters.get('profile_name') == 'athena': + if "dbt" in self._tasks.keys(): + if self._default_parameters.get("profile_name") == "athena": self._dbt_module = AthenaDBTConfigParser(self._default_parameters) - if self._default_parameters.get('profile_name') == 'databricks': + if self._default_parameters.get("profile_name") == "databricks": self._dbt_module = DatabricksDBTConfigParser(self._default_parameters) @staticmethod diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py b/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py index 66005e6..5f44af4 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py @@ -256,7 +256,7 @@ "name": "stg_core_schema1__table1", "type": "dummy", "follow_external_dependency": True, - } + }, ] EXPECTED_MODEL_NODE = [ diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py index 5c6e0a4..94a387f 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py @@ -7,10 +7,9 @@ "schema": "analytics_engineering", "host": "xxx.databricks.com", "http_path": "/sql/1.0/warehouses/xxx", - "token": "{{ env_var('SECRETDATABRICKS') }}" + "token": "{{ env_var('SECRETDATABRICKS') }}", }, } - } } @@ -268,7 +267,7 @@ "name": "stg_core_schema1__table1", "type": "dummy", "follow_external_dependency": True, - } + }, ] DATABRICKS_EXPECTED_MODEL_NODE = [ @@ -372,7 +371,7 @@ "schema": "analytics_engineering", "table": "model1", "type": "athena", - } + }, ] DATABRICKS_EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS = [ @@ -381,5 +380,3 @@ "name": "stg_core_schema2__table2", }, ] - - diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index 8c188d3..9e4d18f 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -3,7 +3,10 @@ from unittest import skip from unittest.mock import patch, MagicMock -from dagger.utilities.dbt_config_parser import AthenaDBTConfigParser, DatabricksDBTConfigParser +from dagger.utilities.dbt_config_parser import ( + AthenaDBTConfigParser, + DatabricksDBTConfigParser, +) from dagger.utilities.module import Module from tests.fixtures.modules.dbt_config_parser_fixtures_athena import * from tests.fixtures.modules.dbt_config_parser_fixtures_databricks import * @@ -94,12 +97,20 @@ def test_generate_io_outputs(self): class TestDatabricksDBTConfigParser(unittest.TestCase): - @patch("builtins.open", new_callable=MagicMock, read_data=DATABRICKS_DBT_MANIFEST_FILE_FIXTURE) + @patch( + "builtins.open", + new_callable=MagicMock, + read_data=DATABRICKS_DBT_MANIFEST_FILE_FIXTURE, + ) @patch("json.loads", return_value=DATABRICKS_DBT_MANIFEST_FILE_FIXTURE) @patch("yaml.safe_load", return_value=DATABRICKS_DBT_PROFILE_FIXTURE) def setUp(self, mock_open, mock_json_load, mock_safe_load): - self._dbt_config_parser = DatabricksDBTConfigParser(DATABRICKS_DEFAULT_CONFIG_PARAMS) - self._sample_dbt_node = DATABRICKS_DBT_MANIFEST_FILE_FIXTURE["nodes"]["model.main.model1"] + self._dbt_config_parser = DatabricksDBTConfigParser( + DATABRICKS_DEFAULT_CONFIG_PARAMS + ) + self._sample_dbt_node = DATABRICKS_DBT_MANIFEST_FILE_FIXTURE["nodes"][ + "model.main.model1" + ] @skip("Run only locally") def test_generate_task_configs(self): @@ -140,7 +151,10 @@ def test_generate_io_inputs(self): "model3", DATABRICKS_EXPECTED_MODEL_MULTIPLE_DEPENDENCIES, ), - ("stg_core_schema2__table2", DATABRICKS_EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS), + ( + "stg_core_schema2__table2", + DATABRICKS_EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS, + ), ("int_model2", DATABRICKS_EXPECTED_DBT_INT_MODEL_DAGGER_INPUTS), ] for mock_input, expected_output in fixtures: @@ -151,7 +165,10 @@ def test_generate_io_inputs(self): def test_generate_io_outputs(self): fixtures = [ ("model1", DATABRICKS_EXPECTED_DAGGER_OUTPUTS), - ("stg_core_schema2__table2", DATABRICKS_EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS), + ( + "stg_core_schema2__table2", + DATABRICKS_EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS, + ), ] for mock_input, expected_output in fixtures: _, result = self._dbt_config_parser.generate_dagger_io(mock_input) From 4567b3e02266ab897c1a26c2edb1a85342065b5a Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Fri, 26 Apr 2024 12:48:23 +0200 Subject: [PATCH 073/189] Switching to official batch operator --- .../operator_creators/batch_creator.py | 28 +- .../operator_creators/spark_creator.py | 2 +- .../airflow/operators/awsbatch_operator.py | 275 ++++++------------ 3 files changed, 105 insertions(+), 200 deletions(-) diff --git a/dagger/dag_creator/airflow/operator_creators/batch_creator.py b/dagger/dag_creator/airflow/operator_creators/batch_creator.py index a3d2534..0cfe9fb 100644 --- a/dagger/dag_creator/airflow/operator_creators/batch_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/batch_creator.py @@ -1,5 +1,9 @@ +from pathlib import Path +from datetime import timedelta + from dagger.dag_creator.airflow.operator_creator import OperatorCreator from dagger.dag_creator.airflow.operators.awsbatch_operator import AWSBatchOperator +from dagger import conf class BatchCreator(OperatorCreator): @@ -8,6 +12,20 @@ class BatchCreator(OperatorCreator): def __init__(self, task, dag): super().__init__(task, dag) + @staticmethod + def _validate_job_name(job_name, absolute_job_name): + if not absolute_job_name and not job_name: + raise Exception("Both job_name and absolute_job_name cannot be null") + + if absolute_job_name is not None: + return absolute_job_name + + job_path = Path(conf.DAGS_DIR) / job_name.replace("-", "/") + assert ( + job_path.is_dir() + ), f"Job name `{job_name}`, points to a non-existing folder `{job_path}`" + return job_name + def _generate_command(self): command = [self._task.executable_prefix, self._task.executable] for param_name, param_value in self._template_parameters.items(): @@ -21,16 +39,16 @@ def _create_operator(self, **kwargs): overrides = self._task.overrides overrides.update({"command": self._generate_command()}) + job_name = self._validate_job_name(self._task.job_name, self._task.absolute_job_name) batch_op = AWSBatchOperator( dag=self._dag, task_id=self._task.name, - job_name=self._task.job_name, - absolute_job_name=self._task.absolute_job_name, + job_name=self._task.name, + job_definition=job_name, region_name=self._task.region_name, - cluster_name=self._task.cluster_name, job_queue=self._task.job_queue, - overrides=overrides, + container_overrides=overrides, + awslogs_enabled=True, **kwargs, ) - return batch_op diff --git a/dagger/dag_creator/airflow/operator_creators/spark_creator.py b/dagger/dag_creator/airflow/operator_creators/spark_creator.py index 2bb41e9..c48ebda 100644 --- a/dagger/dag_creator/airflow/operator_creators/spark_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/spark_creator.py @@ -113,7 +113,7 @@ def _create_operator(self, **kwargs): job_name=job_name, region_name=self._task.region_name, job_queue=self._task.job_queue, - overrides=overrides, + container_overrides=overrides, **kwargs, ) elif self._task.spark_engine == "glue": diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py index a267ba7..b2f4bb3 100644 --- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py +++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py @@ -1,203 +1,90 @@ -from pathlib import Path -from time import sleep - -from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.operators.batch import BatchOperator +from airflow.utils.context import Context from airflow.exceptions import AirflowException -from airflow.utils.decorators import apply_defaults - -from dagger.dag_creator.airflow.operators.dagger_base_operator import DaggerBaseOperator -from dagger.dag_creator.airflow.utils.decorators import lazy_property -from dagger import conf - - -class AWSBatchOperator(DaggerBaseOperator): - """ - Execute a job on AWS Batch Service - - .. warning: the queue parameter was renamed to job_queue to segregate the - internal CeleryExecutor queue from the AWS Batch internal queue. - - :param job_name: the name for the job that will run on AWS Batch - :type job_name: str - :param job_definition: the job definition name on AWS Batch - :type job_definition: str - :param job_queue: the queue name on AWS Batch - :type job_queue: str - :param overrides: the same parameter that boto3 will receive on - containerOverrides (templated): - http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job - :type overrides: dict - :param max_retries: exponential backoff retries while waiter is not - merged, 4200 = 48 hours - :type max_retries: int - :param aws_conn_id: connection id of AWS credentials / region name. If None, - credential boto3 strategy will be used - (http://boto3.readthedocs.io/en/latest/guide/configuration.html). - :type aws_conn_id: str - :param region_name: region name to use in AWS Hook. - Override the region_name in connection (if provided) - :type region_name: str - :param cluster_name: Batch cluster short name or arn - :type region_name: str - - """ - - ui_color = "#c3dae0" - client = None - arn = None - template_fields = ("overrides",) +from airflow.providers.amazon.aws.links.batch import ( + BatchJobDefinitionLink, + BatchJobQueueLink, +) +from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink - @apply_defaults - def __init__( - self, - job_queue, - job_name=None, - absolute_job_name=None, - overrides=None, - job_definition=None, - aws_conn_id=None, - region_name=None, - cluster_name=None, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.job_name = self._validate_job_name(job_name, absolute_job_name) - self.aws_conn_id = aws_conn_id - self.region_name = region_name - self.cluster_name = cluster_name - self.job_definition = job_definition or self.job_name - self.job_queue = job_queue - self.overrides = overrides or {} - self.job_id = None - - @lazy_property - def batch_client(self): - return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="batch").get_client_type( - region_name=self.region_name) - - @lazy_property - def logs_client(self): - return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="logs").get_client_type( - region_name=self.region_name) - - @lazy_property - def ecs_client(self): - return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="ecs").get_client_type( - region_name=self.region_name) +class AWSBatchOperator(AWSBatchOperator): @staticmethod - def _validate_job_name(job_name, absolute_job_name): - if absolute_job_name is None and job_name is None: - raise Exception("Both job_name and absolute_job_name cannot be null") - - if absolute_job_name is not None: - return absolute_job_name - - job_path = Path(conf.DAGS_DIR) / job_name.replace("-", "/") - assert ( - job_path.is_dir() - ), f"Job name `{job_name}`, points to a non-existing folder `{job_path}`" - return job_name - - def execute(self, context): - self.task_instance = context["ti"] - self.log.info( - "\n" - f"\n\tJob name: {self.job_name}" - f"\n\tJob queue: {self.job_queue}" - f"\n\tJob definition: {self.job_definition}" - "\n" - ) - - res = self.batch_client.submit_job( - jobName=self.job_name, - jobQueue=self.job_queue, - jobDefinition=self.job_definition, - containerOverrides=self.overrides, - ) - self.job_id = res["jobId"] - self.log.info( - "\n" - f"\n\tJob ID: {self.job_id}" - "\n" - ) - self.poll_task() - - def poll_task(self): - log_offset = 0 - print_logs_url = True - - while True: - res = self.batch_client.describe_jobs(jobs=[self.job_id]) - - if len(res["jobs"]) == 0: - sleep(3) - continue - - job = res["jobs"][0] - job_status = job["status"] - log_stream_name = job["container"].get("logStreamName") - - if print_logs_url and log_stream_name: - print_logs_url = False - self.log.info( - "\n" - f"\n\tLogs at: https://{self.region_name}.console.aws.amazon.com/cloudwatch/home?" - f"region={self.region_name}#logEventViewer:group=/aws/batch/job;stream={log_stream_name}" - "\n" - ) + def _format_cloudwatch_link(awslogs_region: str, awslogs_group: str, awslogs_stream_name: str): + return f"https://{awslogs_region}.console.aws.amazon.com/cloudwatch/home?region={awslogs_region}#logEventViewer:group={awslogs_group};stream={awslogs_stream_name}" + + def monitor_job(self, context: Context): + """Monitor an AWS Batch job. + + This can raise an exception or an AirflowTaskTimeout if the task was + created with ``execution_timeout``. + """ + if not self.job_id: + raise AirflowException("AWS Batch job - job_id was not found") + + try: + job_desc = self.hook.get_job_description(self.job_id) + job_definition_arn = job_desc["jobDefinition"] + job_queue_arn = job_desc["jobQueue"] + self.log.info( + "AWS Batch job (%s) Job Definition ARN: %r, Job Queue ARN: %r", + self.job_id, + job_definition_arn, + job_queue_arn, + ) + except KeyError: + self.log.warning("AWS Batch job (%s) can't get Job Definition ARN and Job Queue ARN", self.job_id) + else: + BatchJobDefinitionLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + job_definition_arn=job_definition_arn, + ) + BatchJobQueueLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + job_queue_arn=job_queue_arn, + ) - if job_status in ("RUNNING", "FAILED", "SUCCEEDED") and log_stream_name: - try: - log_offset = self.print_logs(log_stream_name, log_offset) - except self.logs_client.exceptions.ResourceNotFoundException: - pass + if self.awslogs_enabled: + if self.waiters: + self.waiters.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher) else: - self.log.info(f"Job status: {job_status}") - - if job_status == "FAILED": - status_reason = res["jobs"][0]["statusReason"] - exit_code = res["jobs"][0]["container"].get("exitCode") - reason = res["jobs"][0]["container"].get("reason", "") - failure_msg = f"Status: {status_reason} | Exit code: {exit_code} | Reason: {reason}" - container_instance_arn = job["container"]["containerInstanceArn"] - self.retry_check(container_instance_arn) - raise AirflowException(failure_msg) - - if job_status == "SUCCEEDED": - self.log.info("AWS Batch Job has been successfully executed") - return - - sleep(7.5) + self.hook.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher) + else: + if self.waiters: + self.waiters.wait_for_job(self.job_id) + else: + self.hook.wait_for_job(self.job_id) + + awslogs = [] + try: + awslogs = self.hook.get_job_all_awslogs_info(self.job_id) + except AirflowException as ae: + self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae) + + if awslogs: + self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id) + link_builder = CloudWatchEventsLink() + for log in awslogs: + self.log.info(self._format_cloudwatch_link(**log)) + if len(awslogs) > 1: + # there can be several log streams on multi-node jobs + self.log.warning( + "out of all those logs, we can only link to one in the UI. Using the first one." + ) - def retry_check(self, container_instance_arn): - res = self.ecs_client.describe_container_instances( - cluster=self.cluster_name, containerInstances=[container_instance_arn] - ) - instance_status = res["containerInstances"][0]["status"] - if instance_status != "ACTIVE": - self.log.warning( - f"Instance in {instance_status} state: setting the task up for retry..." + CloudWatchEventsLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + **awslogs[0], ) - self.retries += self.task_instance.try_number + 1 - self.task_instance.max_tries = self.retries - - def print_logs(self, log_stream_name, log_offset): - logs = self.logs_client.get_log_events( - logGroupName="/aws/batch/job", - logStreamName=log_stream_name, - startFromHead=True, - ) - - for event in logs["events"][log_offset:]: - self.log.info(event["message"]) - - log_offset = len(logs["events"]) - return log_offset - def on_kill(self): - res = self.batch_client.terminate_job( - jobId=self.job_id, reason="Task killed by the user" - ) - self.log.info(res) + self.hook.check_job_success(self.job_id) + self.log.info("AWS Batch job (%s) succeeded", self.job_id) From 1c2ad82672d0837dfbaf3a63b95c13e866fde8d3 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 26 Apr 2024 12:54:26 +0200 Subject: [PATCH 074/189] feat: add another param in dbt task --- dagger/dag_creator/airflow/operator_creators/dbt_creator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dagger/dag_creator/airflow/operator_creators/dbt_creator.py b/dagger/dag_creator/airflow/operator_creators/dbt_creator.py index 4b88fe3..60866c8 100644 --- a/dagger/dag_creator/airflow/operator_creators/dbt_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/dbt_creator.py @@ -17,7 +17,7 @@ def __init__(self, task, dag): self._select = task.select self._dbt_command = task.dbt_command self._vars = task.vars - # self._create_external_athena_table = task.create_external_athena_table + self._create_external_athena_table = task.create_external_athena_table def _generate_command(self): command = [self._task.executable_prefix, self._task.executable] @@ -31,7 +31,7 @@ def _generate_command(self): if self._vars: dbt_vars = json.dumps(self._vars) command.append(f"--vars='{dbt_vars}'") - # if self._create_external_athena_table: - # command.append(f"--create_external_athena_table={self._create_external_athena_table}") + if self._create_external_athena_table: + command.append(f"--create_external_athena_table={self._create_external_athena_table}") return command From a8e647184a76c5dc3be96b4f6938c9a038356282 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Fri, 26 Apr 2024 12:55:35 +0200 Subject: [PATCH 075/189] Complete renaming of classes --- dagger/dag_creator/airflow/operators/awsbatch_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py index b2f4bb3..23b3596 100644 --- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py +++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py @@ -8,7 +8,7 @@ from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink -class AWSBatchOperator(AWSBatchOperator): +class AWSBatchOperator(BatchOperator): @staticmethod def _format_cloudwatch_link(awslogs_region: str, awslogs_group: str, awslogs_stream_name: str): return f"https://{awslogs_region}.console.aws.amazon.com/cloudwatch/home?region={awslogs_region}#logEventViewer:group={awslogs_group};stream={awslogs_stream_name}" From 7e29420c3739f16185dbde1674016baed5e63fcf Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 26 Apr 2024 13:35:59 +0200 Subject: [PATCH 076/189] feat: refactor --- dagger/utilities/dbt_config_parser.py | 37 ++++++++------------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 11c4325..8f761e3 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -59,12 +59,19 @@ def _load_file(file_path: str, file_type: str) -> dict: _logger.error(f"File not found: {file_path}") exit(1) - @abstractmethod def _get_athena_table_task( self, node: dict, follow_external_dependency: bool = False ) -> dict: - """Generate an athena table task for a DBT node. Must be implemented by subclasses. This function should be deprecated after the source connects with databricks directly""" - pass + """Generate an athena table task for a DBT node.""" + task = ATHENA_TASK_BASE.copy() + if follow_external_dependency: + task["follow_external_dependency"] = True + + task["schema"] = node.get("schema", self._default_schema) + task["table"] = node.get("name", "") + task["name"] = f"{task['schema']}__{task['table']}_athena" + + return task @abstractmethod def _get_table_task( @@ -173,7 +180,6 @@ def generate_dagger_io(self, model_name: str) -> Tuple[List[dict], List[dict]]: inputs_list = [] model_node = self._nodes_in_manifest[f"model.main.{model_name}"] parent_node_names = model_node.get("depends_on", {}).get("nodes", []) - print(f"parent node name: {parent_node_names}") for parent_node_name in parent_node_names: dagger_input = self._generate_dagger_tasks(parent_node_name) @@ -187,8 +193,6 @@ def generate_dagger_io(self, model_name: str) -> Tuple[List[dict], List[dict]]: ).values() ) - print(unique_inputs) - return unique_inputs, output_list @@ -209,20 +213,7 @@ def _get_table_task( """ Generates the dagger athena task for the DBT model node """ - task = ATHENA_TASK_BASE.copy() - if follow_external_dependency: - task["follow_external_dependency"] = True - - task["schema"] = node.get("schema", self._default_schema) - task["table"] = node.get("name", "") - task["name"] = f"{task['schema']}__{task['table']}_athena" - - return task - - def _get_athena_table_task( - self, node: dict, follow_external_dependency: bool = False - ) -> dict: - return self._get_table_task(node, follow_external_dependency) + return self._get_athena_table_task(node, follow_external_dependency) def _get_model_data_location( self, node: dict, schema: str, model_name: str @@ -286,7 +277,6 @@ def __init__(self, default_config_parameters: dict): super().__init__(default_config_parameters) self._profile_name = "databricks" self._default_catalog = self._target_config.get("catalog") - self._athena_dbt_parser = AthenaDBTConfigParser(default_config_parameters) self._create_external_athena_table = default_config_parameters.get( "create_external_athena_table", False ) @@ -310,11 +300,6 @@ def _get_table_task( return task - def _get_athena_table_task( - self, node: dict, follow_external_dependency: bool = False - ) -> dict: - return self._athena_dbt_parser._get_table_task(node, follow_external_dependency) - def _get_model_data_location( self, node: dict, schema: str, model_name: str ) -> Tuple[str, str]: From 5659c689bf76ce1082079da075532e5dd7ecbf51 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 26 Apr 2024 17:56:33 +0200 Subject: [PATCH 077/189] feat: adjust the s3 tasks --- dagger/utilities/dbt_config_parser.py | 61 ++++++------------- .../dbt_config_parser_fixtures_athena.py | 8 +-- .../dbt_config_parser_fixtures_databricks.py | 8 +-- 3 files changed, 27 insertions(+), 50 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 8f761e3..0973444 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -87,10 +87,20 @@ def _get_model_data_location( """Get the S3 path of the DBT model relative to the data bucket. Must be implemented by subclasses.""" pass - @abstractmethod - def _get_s3_task(self, node: dict) -> dict: - """Generate an S3 task configuration based on a DBT node. Must be implemented by subclasses.""" - pass + def _get_s3_task(self, node: dict, is_output: bool = False) -> dict: + """ + Generates the dagger s3 task for the databricks-dbt model node + """ + task = S3_TASK_BASE.copy() + + schema = node.get("schema", self._default_schema) + table = node.get("name", "") + task["name"] = f"output_s3_path" if is_output else f"s3_{table}" + task["bucket"], task["path"] = self._get_model_data_location( + node, schema, table + ) + + return task @staticmethod def _get_dummy_task(node: dict, follow_external_dependency: bool = False) -> dict: @@ -230,25 +240,6 @@ def _get_model_data_location( return bucket_name, data_path - def _get_s3_task(self, node: dict) -> dict: - """ - Generates the dagger s3 task for the DBT model node - Args: - node: The extracted node from the manifest.json file - - Returns: - dict: The dagger s3 task for the DBT model node - - """ - task = S3_TASK_BASE.copy() - - schema = node.get("schema", self._default_schema) - table = node.get("name", "") - task["name"] = f"{schema}__{table}_s3" - task["bucket"], task["path"] = self._get_model_data_location( - node, schema, table - ) - return task def _generate_dagger_output(self, node: dict): """ @@ -267,7 +258,7 @@ def _generate_dagger_output(self, node: dict): ) or node.get("name").startswith("stg_"): return [self._get_dummy_task(node)] else: - return [self._get_table_task(node), self._get_s3_task(node)] + return [self._get_table_task(node), self._get_s3_task(node, is_output=True)] class DatabricksDBTConfigParser(DBTConfigParser): @@ -313,22 +304,6 @@ def _get_model_data_location( return bucket_name, data_path - def _get_s3_task(self, node: dict) -> dict: - """ - Generates the dagger s3 task for the databricks-dbt model node - """ - task = S3_TASK_BASE.copy() - - catalog = node.get("database", self._default_catalog) - schema = node.get("schema", self._default_schema) - table = node.get("name", "") - task["name"] = f"{catalog}__{schema}__{table}_s3" - task["bucket"], task["path"] = self._get_model_data_location( - node, schema, table - ) - - return task - def _generate_dagger_output(self, node: dict): """ Generates the dagger output for the DBT model node with the databricks-dbt adapter. @@ -345,10 +320,12 @@ def _generate_dagger_output(self, node: dict): if node.get("config", {}).get("materialized") in ( "view", "ephemeral", - ) or node.get("name").startswith("stg_"): + ) or node.get("name").startswith("stg_") or "preparation" in "preparation" in node.get( + "schema", "" + ): return [self._get_dummy_task(node)] else: - output_tasks = [self._get_table_task(node), self._get_s3_task(node)] + output_tasks = [self._get_table_task(node), self._get_s3_task(node, is_output=True)] if self._create_external_athena_table: output_tasks.append(self._get_athena_table_task(node)) return output_tasks diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py b/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py index 5f44af4..072fb41 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py @@ -226,7 +226,7 @@ }, { "bucket": "bucket1-data-lake", - "name": "analytics_engineering__model2_s3", + "name": "s3_model2", "path": "path2/model2", "type": "s3", }, @@ -269,7 +269,7 @@ }, { "bucket": "bucket1-data-lake", - "name": "analytics_engineering__model1_s3", + "name": "s3_model1", "path": "path1/model1", "type": "s3", }, @@ -290,7 +290,7 @@ }, { "bucket": "bucket1-data-lake", - "name": "analytics_engineering__model2_s3", + "name": "s3_model2", "path": "path2/model2", "type": "s3", }, @@ -339,7 +339,7 @@ }, { "bucket": "bucket1-data-lake", - "name": "analytics_engineering__model1_s3", + "name": "output_s3_path", "path": "path1/model1", "type": "s3", }, diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py index 94a387f..b415c60 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py @@ -237,7 +237,7 @@ }, { "bucket": "chodata-data-lake", - "name": "marts__analytics_engineering__model2_s3", + "name": "s3_model2", "path": "analytics_warehouse/data/marts/analytics_engineering/model2", "type": "s3", }, @@ -281,7 +281,7 @@ }, { "bucket": "chodata-data-lake", - "name": "marts__analytics_engineering__model1_s3", + "name": "s3_model1", "path": "analytics_warehouse/data/marts/analytics_engineering/model1", "type": "s3", }, @@ -303,7 +303,7 @@ }, { "bucket": "chodata-data-lake", - "name": "marts__analytics_engineering__model2_s3", + "name": "s3_model2", "path": "analytics_warehouse/data/marts/analytics_engineering/model2", "type": "s3", }, @@ -362,7 +362,7 @@ }, { "bucket": "chodata-data-lake", - "name": "marts__analytics_engineering__model1_s3", + "name": "output_s3_path", "path": "analytics_warehouse/data/marts/analytics_engineering/model1", "type": "s3", }, From 81697d48d6030695f1e8a29f410ff778928965b4 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Mon, 29 Apr 2024 11:55:53 +0200 Subject: [PATCH 078/189] feat: adjust the _get_s3_task for different dbt adapters --- dagger/utilities/dbt_config_parser.py | 44 ++++++++++++++++++++------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 0973444..00f58c1 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -87,20 +87,12 @@ def _get_model_data_location( """Get the S3 path of the DBT model relative to the data bucket. Must be implemented by subclasses.""" pass + @abstractmethod def _get_s3_task(self, node: dict, is_output: bool = False) -> dict: """ - Generates the dagger s3 task for the databricks-dbt model node + Generate an S3 task for a DBT node for the specific dbt-adapter. Must be implemented by subclasses. """ - task = S3_TASK_BASE.copy() - - schema = node.get("schema", self._default_schema) - table = node.get("name", "") - task["name"] = f"output_s3_path" if is_output else f"s3_{table}" - task["bucket"], task["path"] = self._get_model_data_location( - node, schema, table - ) - - return task + pass @staticmethod def _get_dummy_task(node: dict, follow_external_dependency: bool = False) -> dict: @@ -240,6 +232,21 @@ def _get_model_data_location( return bucket_name, data_path + def _get_s3_task(self, node: dict, is_output: bool = False) -> dict: + """ + Generates the dagger s3 task for the athena-dbt model node + """ + task = S3_TASK_BASE.copy() + + schema = node.get("schema", self._default_schema) + table = node.get("name", "") + task["name"] = f"output_s3_path" if is_output else f"s3_{table}" + task["bucket"] = self._default_data_bucket + _, task["path"] = self._get_model_data_location( + node, schema, table + ) + + return task def _generate_dagger_output(self, node: dict): """ @@ -304,6 +311,21 @@ def _get_model_data_location( return bucket_name, data_path + def _get_s3_task(self, node: dict, is_output: bool = False) -> dict: + """ + Generates the dagger s3 task for the databricks-dbt model node + """ + task = S3_TASK_BASE.copy() + + schema = node.get("schema", self._default_schema) + table = node.get("name", "") + task["name"] = f"output_s3_path" if is_output else f"s3_{table}" + task["bucket"], task["path"] = self._get_model_data_location( + node, schema, table + ) + + return task + def _generate_dagger_output(self, node: dict): """ Generates the dagger output for the DBT model node with the databricks-dbt adapter. From 7c1aa22917786068a85eec953e97078ca656501a Mon Sep 17 00:00:00 2001 From: claudiazi Date: Mon, 29 Apr 2024 14:09:12 +0200 Subject: [PATCH 079/189] fix: define the correct target_config for databricks adapter --- dagger/utilities/dbt_config_parser.py | 30 ++++++++++++++++----------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 00f58c1..757ad1a 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -26,9 +26,11 @@ def __init__(self, config_parameters: dict): self._get_manifest_path(), file_type="json" ) profile_data = self._load_file(self._get_profile_path(), file_type="yaml") - self._target_config = profile_data[self._profile_name]["outputs"][ - self._target_name - ] + self._target_config = ( + profile_data[self._profile_name]["outputs"].get(self._target_name) + if self._profile_name == "athena" + else profile_data[self._profile_name]["outputs"]["data"] + ) # if databricks, get the default catalog and schema from the data output self._default_schema = self._target_config.get("schema", "") self._nodes_in_manifest = self._manifest_data.get("nodes", {}) self._sources_in_manifest = self._manifest_data.get("sources", {}) @@ -242,9 +244,7 @@ def _get_s3_task(self, node: dict, is_output: bool = False) -> dict: table = node.get("name", "") task["name"] = f"output_s3_path" if is_output else f"s3_{table}" task["bucket"] = self._default_data_bucket - _, task["path"] = self._get_model_data_location( - node, schema, table - ) + _, task["path"] = self._get_model_data_location(node, schema, table) return task @@ -339,15 +339,21 @@ def _generate_dagger_output(self, node: dict): dict: The dagger output, which is a combination of an athena and s3 task for the DBT model node """ - if node.get("config", {}).get("materialized") in ( - "view", - "ephemeral", - ) or node.get("name").startswith("stg_") or "preparation" in "preparation" in node.get( - "schema", "" + if ( + node.get("config", {}).get("materialized") + in ( + "view", + "ephemeral", + ) + or node.get("name").startswith("stg_") + or "preparation" in "preparation" in node.get("schema", "") ): return [self._get_dummy_task(node)] else: - output_tasks = [self._get_table_task(node), self._get_s3_task(node, is_output=True)] + output_tasks = [ + self._get_table_task(node), + self._get_s3_task(node, is_output=True), + ] if self._create_external_athena_table: output_tasks.append(self._get_athena_table_task(node)) return output_tasks From 40b28ef9298d5151fd73e894606a71bc18409493 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Mon, 29 Apr 2024 14:18:12 +0200 Subject: [PATCH 080/189] fix: _generate_dagger_output --- dagger/utilities/dbt_config_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 757ad1a..b801379 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -346,7 +346,7 @@ def _generate_dagger_output(self, node: dict): "ephemeral", ) or node.get("name").startswith("stg_") - or "preparation" in "preparation" in node.get("schema", "") + or "preparation" in node.get("schema", "") ): return [self._get_dummy_task(node)] else: From 9c0311f2d1daee5d210c26ea5dadd226921feb62 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Mon, 29 Apr 2024 23:12:31 +0200 Subject: [PATCH 081/189] extend: command --- dagger/dag_creator/airflow/operator_creators/dbt_creator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dagger/dag_creator/airflow/operator_creators/dbt_creator.py b/dagger/dag_creator/airflow/operator_creators/dbt_creator.py index 60866c8..c4e250a 100644 --- a/dagger/dag_creator/airflow/operator_creators/dbt_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/dbt_creator.py @@ -33,5 +33,6 @@ def _generate_command(self): command.append(f"--vars='{dbt_vars}'") if self._create_external_athena_table: command.append(f"--create_external_athena_table={self._create_external_athena_table}") + command.append(super()._generate_command()) return command From 14f9d1f0dcf81328cb0aee47b1644f5f4e6643b8 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Tue, 30 Apr 2024 09:47:12 +0200 Subject: [PATCH 082/189] fix: _generate_command --- dagger/dag_creator/airflow/operator_creators/dbt_creator.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dagger/dag_creator/airflow/operator_creators/dbt_creator.py b/dagger/dag_creator/airflow/operator_creators/dbt_creator.py index c4e250a..9be9ee8 100644 --- a/dagger/dag_creator/airflow/operator_creators/dbt_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/dbt_creator.py @@ -33,6 +33,8 @@ def _generate_command(self): command.append(f"--vars='{dbt_vars}'") if self._create_external_athena_table: command.append(f"--create_external_athena_table={self._create_external_athena_table}") - command.append(super()._generate_command()) - + for param_name, param_value in self._template_parameters.items(): + command.append( + f"--{param_name}={param_value}" + ) return command From f7cc9830f97d210a21b74f87c87957b8f16799a1 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Tue, 30 Apr 2024 17:12:09 +0200 Subject: [PATCH 083/189] fix: _get_model_data_location in DatabricksDBTConfigParser --- dagger/utilities/dbt_config_parser.py | 2 +- .../modules/dbt_config_parser_fixtures_databricks.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index b801379..1b64132 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -305,7 +305,7 @@ def _get_model_data_location( Gets the S3 path of the dbt model relative to the data bucket. """ location_root = node.get("config", {}).get("location_root") - location = join(location_root, schema, model_name) + location = join(location_root, model_name) split = location.split("//")[1].split("/") bucket_name, data_path = split[0], "/".join(split[1:]) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py index b415c60..232fe8c 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py @@ -22,7 +22,7 @@ "unique_id": "model.main.model1", "resource_type": "model", "config": { - "location_root": "s3://chodata-data-lake/analytics_warehouse/data/marts", + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/marts/analytics_engineering", "materialized": "incremental", "incremental_strategy": "insert_overwrite", }, @@ -30,7 +30,7 @@ "tags": ["daily"], "unrendered_config": { "materialized": "incremental", - "location_root": "s3://chodata-data-lake/analytics_warehouse/data/marts", + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/marts/analytics_engineering", "incremental_strategy": "insert_overwrite", "partitioned_by": ["year", "month", "day", "dt"], "tags": ["daily"], @@ -90,7 +90,7 @@ "unique_id": "model.main.model2", "resource_type": "model", "config": { - "location_root": "s3://chodata-data-lake/analytics_warehouse/data/marts", + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/marts/analytics_engineering", "materialized": "table", }, "depends_on": {"macros": [], "nodes": []}, @@ -103,7 +103,7 @@ "resource_type": "model", "config": { "materialized": "ephemeral", - "location_root": "s3://chodata-data-lake/analytics_warehouse/data/intermediate", + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/intermediate/analytics_engineering", }, "depends_on": { "macros": [], @@ -128,7 +128,7 @@ "schema": "analytics_engineering", "unique_id": "model.main.model3", "config": { - "location_root": "s3://chodata-data-lake/analytics_warehouse/data/marts", + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/marts/analytics_engineering", }, "depends_on": { "macros": [], @@ -147,7 +147,7 @@ "schema": "analytics_engineering", "config": { "materialized": "ephemeral", - "location_root": "s3://chodata-data-lake/analytics_warehouse/data/intermediate", + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/intermediate/analytics_engineering", }, "depends_on": { "macros": [], From 59cf8473a738c7ea57e72702f80a1997d83b909d Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 3 May 2024 15:42:18 +0200 Subject: [PATCH 084/189] fix: generate_task for dbt tasks --- dagger/utilities/module.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 6b6aa86..968e196 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -28,12 +28,6 @@ def __init__(self, path_to_config, target_dir): self._override_parameters = config.get("override_parameters", {}) self._default_parameters = config.get("default_parameters", {}) - if "dbt" in self._tasks.keys(): - if self._default_parameters.get("profile_name") == "athena": - self._dbt_module = AthenaDBTConfigParser(self._default_parameters) - if self._default_parameters.get("profile_name") == "databricks": - self._dbt_module = DatabricksDBTConfigParser(self._default_parameters) - @staticmethod def read_yaml(yaml_str): try: @@ -85,6 +79,11 @@ def generate_task_configs(self): template_parameters = {} template_parameters.update(self._default_parameters or {}) template_parameters.update(attrs) + if "dbt" in self._tasks.keys(): + if template_parameters.get("profile_name") == "athena": + self._dbt_module = AthenaDBTConfigParser(template_parameters) + if template_parameters.get("profile_name") == "databricks": + self._dbt_module = DatabricksDBTConfigParser(template_parameters) for task, task_yaml in self._tasks.items(): task_name = f"{branch_name}_{task}" From c44a7b5ccb83e7e53810766d1a0b57c5f50d78c4 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Wed, 12 Jun 2024 11:37:16 +0200 Subject: [PATCH 085/189] Replacing string replacement with jinja in module processor --- dagger/utilities/dbt_config_parser.py | 12 ++++++++++++ dagger/utilities/module.py | 28 ++++++++++++++++----------- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 1b64132..3be57fe 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -35,6 +35,18 @@ def __init__(self, config_parameters: dict): self._nodes_in_manifest = self._manifest_data.get("nodes", {}) self._sources_in_manifest = self._manifest_data.get("sources", {}) + @property + def nodes_in_manifest(self): + return self._nodes_in_manifest + + @property + def sources_in_manifest(self): + return self._sources_in_manifest + + @property + def dbt_default_schema(self): + return self._default_schema + def _get_manifest_path(self) -> str: """ Construct path for manifest.json file based on configuration parameters. diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 968e196..ff1329f 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -6,6 +6,8 @@ DatabricksDBTConfigParser, ) +import jinja2 + import yaml _logger = logging.getLogger("root") @@ -48,19 +50,13 @@ def read_task_config(self, task): @staticmethod def replace_template_parameters(_task_str, _template_parameters): - for _key, _value in _template_parameters.items(): - if type(_value) == str: - try: - int_value = int(_value) - _value = f'"{_value}"' - except: - pass - locals()[_key] = _value + environment = jinja2.Environment() + template = environment.from_string(_task_str) + rendered_task = template.render(_template_parameters) return ( - _task_str.format(**locals()) - .replace("{", "{{") - .replace("}", "}}") + rendered_task + # TODO Remove this hack and use Jinja escaping instead of special expression in template files .replace("__CBS__", "{") .replace("__CBE__", "}") ) @@ -79,12 +75,22 @@ def generate_task_configs(self): template_parameters = {} template_parameters.update(self._default_parameters or {}) template_parameters.update(attrs) + template_parameters['branch_name'] = branch_name + + dbt_manifest = None if "dbt" in self._tasks.keys(): if template_parameters.get("profile_name") == "athena": self._dbt_module = AthenaDBTConfigParser(template_parameters) if template_parameters.get("profile_name") == "databricks": self._dbt_module = DatabricksDBTConfigParser(template_parameters) + dbt_manifest = {} + dbt_manifest['nodes'] = self._dbt_module.nodes_in_manifest + dbt_manifest['sources'] = self._dbt_module.sources_in_manifest + + template_parameters["dbt_manifest"] = dbt_manifest + template_parameters["dbt_default_schema"] = self._dbt_module.dbt_default_schema + for task, task_yaml in self._tasks.items(): task_name = f"{branch_name}_{task}" _logger.info(f"Generating task {task_name}") From 1bc6bf28e2286f78e3b30902a7e60a63f24a182f Mon Sep 17 00:00:00 2001 From: claudiazi Date: Mon, 17 Jun 2024 12:30:37 +0200 Subject: [PATCH 086/189] feat: adjust the dbt config parser so that view/ephemeral staging layer doesnt need a task --- dagger/utilities/dbt_config_parser.py | 8 +++-- .../dbt_config_parser_fixtures_athena.py | 36 ++++++++++++++++--- .../dbt_config_parser_fixtures_databricks.py | 32 +++++++++++++++-- tests/utilities/test_dbt_config_parser.py | 5 --- 4 files changed, 67 insertions(+), 14 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 3be57fe..bb79b39 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -161,16 +161,18 @@ def _generate_dagger_tasks(self, node_name: str) -> List[Dict]: node, follow_external_dependency=True ) dagger_tasks.append(table_task) - elif node.get("config", {}).get("materialized") == "ephemeral": + elif node.get("config", {}).get("materialized") == "ephemeral" or ((node.get("name").startswith("stg_") or "preparation" in node.get( + "schema", "" + )) and node.get("config", {}).get("materialized") != "table"): task = self._get_dummy_task(node, follow_external_dependency=True) dagger_tasks.append(task) ephemeral_parent_node_names = node.get("depends_on", {}).get("nodes", []) for node_name in ephemeral_parent_node_names: dagger_tasks += self._generate_dagger_tasks(node_name) - elif node.get("name").startswith("stg_") or "preparation" in node.get( + elif (node.get("name").startswith("stg_") or "preparation" in node.get( "schema", "" - ): + ) and node.get("config", {}).get("materialized") == "table"): dagger_tasks.append( self._get_dummy_task(node, follow_external_dependency=True) ) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py b/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py index 072fb41..2f1c6ee 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py @@ -58,7 +58,7 @@ "unique_id": "model.main.stg_core_schema1__table1", "name": "stg_core_schema1__table1", "config": { - "materialized": "view", + "materialized": "table", }, "depends_on": { "macros": [], @@ -235,6 +235,20 @@ "type": "dummy", "follow_external_dependency": True, }, + { + "type": "athena", + "name": "core_schema2__table2_athena", + "schema": "core_schema2", + "table": "table2", + "follow_external_dependency": True, + }, + { + "type": "athena", + "name": "core_schema2__table3_athena", + "schema": "core_schema2", + "table": "table3", + "follow_external_dependency": True, + }, ] EXPECTED_EPHEMERAL_NODE = [ @@ -256,7 +270,7 @@ "name": "stg_core_schema1__table1", "type": "dummy", "follow_external_dependency": True, - }, + } ] EXPECTED_MODEL_NODE = [ @@ -281,6 +295,21 @@ "type": "dummy", "follow_external_dependency": True, }, + { + "type": "athena", + "name": "core_schema2__table2_athena", + "schema": "core_schema2", + "table": "table2", + "follow_external_dependency": True, + }, + { + "type": "athena", + "name": "core_schema2__table3_athena", + "schema": "core_schema2", + "table": "table3", + "follow_external_dependency": True, + }, + {"name": "seed_buyer_country_overwrite", "type": "dummy"}, { "name": "analytics_engineering__model2_athena", "schema": "analytics_engineering", @@ -304,12 +333,11 @@ "name": "int_model2", "follow_external_dependency": True, }, - {"name": "seed_buyer_country_overwrite", "type": "dummy"}, { "name": "stg_core_schema1__table1", "type": "dummy", "follow_external_dependency": True, - }, + } ] EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS = [ diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py index 232fe8c..342c32a 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py @@ -57,7 +57,7 @@ "resource_type": "model", "config": { "location_root": "s3://chodata-data-lake/analytics_warehouse/data/preparation", - "materialized": "view", + "materialized": "table", }, "depends_on": { "macros": [], @@ -246,6 +246,20 @@ "type": "dummy", "follow_external_dependency": True, }, + { + "type": "athena", + "name": "core_schema2__table2_athena", + "schema": "core_schema2", + "table": "table2", + "follow_external_dependency": True, + }, + { + "type": "athena", + "name": "core_schema2__table3_athena", + "schema": "core_schema2", + "table": "table3", + "follow_external_dependency": True, + }, ] DATABRICKS_EXPECTED_EPHEMERAL_NODE = [ @@ -293,6 +307,21 @@ "type": "dummy", "follow_external_dependency": True, }, + { + "type": "athena", + "name": "core_schema2__table2_athena", + "schema": "core_schema2", + "table": "table2", + "follow_external_dependency": True, + }, + { + "type": "athena", + "name": "core_schema2__table3_athena", + "schema": "core_schema2", + "table": "table3", + "follow_external_dependency": True, + }, + {"name": "seed_buyer_country_overwrite", "type": "dummy"}, { "name": "marts__analytics_engineering__model2_databricks", "catalog": "marts", @@ -317,7 +346,6 @@ "name": "int_model2", "follow_external_dependency": True, }, - {"name": "seed_buyer_country_overwrite", "type": "dummy"}, { "name": "stg_core_schema1__table1", "type": "dummy", diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index 9e4d18f..d401e4b 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -49,10 +49,6 @@ def test_generate_task_configs(self): def test_generate_dagger_tasks(self): test_inputs = [ - ( - "model.main.stg_core_schema1__table1", - EXPECTED_STAGING_NODE, - ), ( "seed.main.seed_buyer_country_overwrite", EXPECTED_SEED_NODE, @@ -159,7 +155,6 @@ def test_generate_io_inputs(self): ] for mock_input, expected_output in fixtures: result, _ = self._dbt_config_parser.generate_dagger_io(mock_input) - self.assertListEqual(result, expected_output) def test_generate_io_outputs(self): From 22ca09b45ab35072897ba2023420806b0f43f387 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Tue, 18 Jun 2024 17:10:49 +0200 Subject: [PATCH 087/189] feat: adjust io for the materalised staging model --- dagger/utilities/dbt_config_parser.py | 16 ++--- .../dbt_config_parser_fixtures_athena.py | 65 +++++++++++++++---- .../dbt_config_parser_fixtures_databricks.py | 65 ++++++++++++++++--- 3 files changed, 115 insertions(+), 31 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index bb79b39..8a62fb4 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -161,21 +161,19 @@ def _generate_dagger_tasks(self, node_name: str) -> List[Dict]: node, follow_external_dependency=True ) dagger_tasks.append(table_task) - elif node.get("config", {}).get("materialized") == "ephemeral" or ((node.get("name").startswith("stg_") or "preparation" in node.get( - "schema", "" - )) and node.get("config", {}).get("materialized") != "table"): + elif node.get("config", {}).get("materialized") == "ephemeral" or ( + ( + node.get("name").startswith("stg_") + or "preparation" in node.get("schema", "") + ) + and node.get("config", {}).get("materialized") != "table" + ): task = self._get_dummy_task(node, follow_external_dependency=True) dagger_tasks.append(task) ephemeral_parent_node_names = node.get("depends_on", {}).get("nodes", []) for node_name in ephemeral_parent_node_names: dagger_tasks += self._generate_dagger_tasks(node_name) - elif (node.get("name").startswith("stg_") or "preparation" in node.get( - "schema", "" - ) and node.get("config", {}).get("materialized") == "table"): - dagger_tasks.append( - self._get_dummy_task(node, follow_external_dependency=True) - ) else: table_task = self._get_table_task(node, follow_external_dependency=True) s3_task = self._get_s3_task(node) diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py b/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py index 2f1c6ee..f1afd52 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py @@ -59,6 +59,7 @@ "name": "stg_core_schema1__table1", "config": { "materialized": "table", + "external_location": "s3://bucket1-data-lake/path2/stg_core_schema1__table1", }, "depends_on": { "macros": [], @@ -184,10 +185,18 @@ EXPECTED_STAGING_NODE = [ { - "name": "stg_core_schema1__table1", - "type": "dummy", + "name": "analytics_engineering__stg_core_schema1__table1_athena", + "type": "athena", + "table": "stg_core_schema1__table1", + "schema": "analytics_engineering", "follow_external_dependency": True, }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "bucket1-data-lake", + "path": "path2/stg_core_schema1__table1", + }, ] EXPECTED_SEED_NODE = [ @@ -213,10 +222,18 @@ "name": "seed_buyer_country_overwrite", }, { - "name": "stg_core_schema1__table1", - "type": "dummy", + "name": "analytics_engineering__stg_core_schema1__table1_athena", + "type": "athena", + "table": "stg_core_schema1__table1", + "schema": "analytics_engineering", "follow_external_dependency": True, }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "bucket1-data-lake", + "path": "path2/stg_core_schema1__table1", + }, { "type": "athena", "name": "analytics_engineering__model2_athena", @@ -267,10 +284,18 @@ "name": "seed_buyer_country_overwrite", }, { - "name": "stg_core_schema1__table1", - "type": "dummy", + "name": "analytics_engineering__stg_core_schema1__table1_athena", + "type": "athena", + "table": "stg_core_schema1__table1", + "schema": "analytics_engineering", "follow_external_dependency": True, - } + }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "bucket1-data-lake", + "path": "path2/stg_core_schema1__table1", + }, ] EXPECTED_MODEL_NODE = [ @@ -334,10 +359,18 @@ "follow_external_dependency": True, }, { - "name": "stg_core_schema1__table1", - "type": "dummy", + "name": "analytics_engineering__stg_core_schema1__table1_athena", + "type": "athena", + "table": "stg_core_schema1__table1", + "schema": "analytics_engineering", "follow_external_dependency": True, - } + }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "bucket1-data-lake", + "path": "path2/stg_core_schema1__table1", + }, ] EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS = [ @@ -383,8 +416,16 @@ EXPECTED_DBT_INT_MODEL_DAGGER_INPUTS = [ {"name": "seed_buyer_country_overwrite", "type": "dummy"}, { - "name": "stg_core_schema1__table1", - "type": "dummy", + "name": "analytics_engineering__stg_core_schema1__table1_athena", + "type": "athena", + "table": "stg_core_schema1__table1", + "schema": "analytics_engineering", "follow_external_dependency": True, }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "bucket1-data-lake", + "path": "path2/stg_core_schema1__table1", + }, ] diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py index 342c32a..2538e25 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py @@ -194,9 +194,18 @@ DATABRICKS_EXPECTED_STAGING_NODE = [ { - "name": "stg_core_schema1__table1", - "type": "dummy", + "type": "databricks", "follow_external_dependency": True, + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "stg_core_schema1__table1", + "name": "hive_metastore__data_preparation__stg_core_schema1__table1_databricks", + }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "chodata-data-lake", + "path": "analytics_warehouse/data/preparation/stg_core_schema1__table1", }, ] @@ -223,9 +232,18 @@ "name": "seed_buyer_country_overwrite", }, { - "name": "stg_core_schema1__table1", - "type": "dummy", + "type": "databricks", "follow_external_dependency": True, + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "stg_core_schema1__table1", + "name": "hive_metastore__data_preparation__stg_core_schema1__table1_databricks", + }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "chodata-data-lake", + "path": "analytics_warehouse/data/preparation/stg_core_schema1__table1", }, { "type": "databricks", @@ -278,9 +296,18 @@ "name": "seed_buyer_country_overwrite", }, { - "name": "stg_core_schema1__table1", - "type": "dummy", + "type": "databricks", "follow_external_dependency": True, + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "stg_core_schema1__table1", + "name": "hive_metastore__data_preparation__stg_core_schema1__table1_databricks", + }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "chodata-data-lake", + "path": "analytics_warehouse/data/preparation/stg_core_schema1__table1", }, ] @@ -347,9 +374,18 @@ "follow_external_dependency": True, }, { - "name": "stg_core_schema1__table1", - "type": "dummy", + "type": "databricks", "follow_external_dependency": True, + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "stg_core_schema1__table1", + "name": "hive_metastore__data_preparation__stg_core_schema1__table1_databricks", + }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "chodata-data-lake", + "path": "analytics_warehouse/data/preparation/stg_core_schema1__table1", }, ] @@ -374,9 +410,18 @@ DATABRICKS_EXPECTED_DBT_INT_MODEL_DAGGER_INPUTS = [ {"name": "seed_buyer_country_overwrite", "type": "dummy"}, { - "name": "stg_core_schema1__table1", - "type": "dummy", + "type": "databricks", "follow_external_dependency": True, + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "stg_core_schema1__table1", + "name": "hive_metastore__data_preparation__stg_core_schema1__table1_databricks", + }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "chodata-data-lake", + "path": "analytics_warehouse/data/preparation/stg_core_schema1__table1", }, ] From f2f8015d94bbda0681a6bbb07baafaf8078d8824 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Wed, 19 Jun 2024 14:45:55 +0200 Subject: [PATCH 088/189] feat: restructure the dbt dagger task input & output --- dagger/utilities/dbt_config_parser.py | 99 +++++++++++-------- .../dbt_config_parser_fixtures_athena.py | 71 ++++++++----- .../dbt_config_parser_fixtures_databricks.py | 90 ++++++++++++----- tests/utilities/test_dbt_config_parser.py | 8 +- 4 files changed, 173 insertions(+), 95 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 8a62fb4..9a341f6 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -133,18 +133,25 @@ def _generate_dagger_output(self, node: dict): """Generate the dagger output for a DBT node. Must be implemented by subclasses.""" pass + @abstractmethod + def _is_node_preparation_model(self, node: dict): + """Define whether it is a preparation model. Must be implemented by subclasses.""" + pass + def _generate_dagger_tasks(self, node_name: str) -> List[Dict]: """ - Generates the dagger task based on whether the DBT model node is a staging model or not. - If the DBT model node represents a DBT seed or an ephemeral model, then a dagger dummy task is generated. - If the DBT model node represents a staging model, then a dagger athena task is generated for each source of the DBT model. Apart from this, a dummy task is also generated for the staging model itself. - If the DBT model node is not a staging model, then a dagger athena task and an s3 task is generated for the DBT model node itself. + Generates dagger tasks based on the type and materialization of the DBT model node. + + - If the node is a DBT source, an Athena table task is generated. + - If the node is an ephemeral model, a dummy task is generated, and tasks for its dependent nodes are recursively generated. + - If the node is a staging model (preparation model) and not materialized as a table, a table task is generated along with tasks for its dependent nodes. + - For other nodes, a table task is generated. If the node is materialized as a table, an additional S3 task is also generated. + Args: node_name: The name of the DBT model node Returns: List[Dict]: The respective dagger tasks for the DBT model node - """ dagger_tasks = [] @@ -153,33 +160,36 @@ def _generate_dagger_tasks(self, node_name: str) -> List[Dict]: else: node = self._nodes_in_manifest[node_name] - if node.get("resource_type") == "seed": - task = self._get_dummy_task(node) - dagger_tasks.append(task) - elif node.get("resource_type") == "source": + resource_type = node.get("resource_type") + materialized_type = node.get("config", {}).get("materialized") + + follow_external_dependency = True + if resource_type == "seed" or (self._is_node_preparation_model(node) and materialized_type != "table"): + follow_external_dependency = False + + if resource_type == "source": table_task = self._get_athena_table_task( - node, follow_external_dependency=True + node, follow_external_dependency=follow_external_dependency ) dagger_tasks.append(table_task) - elif node.get("config", {}).get("materialized") == "ephemeral" or ( - ( - node.get("name").startswith("stg_") - or "preparation" in node.get("schema", "") - ) - and node.get("config", {}).get("materialized") != "table" - ): - task = self._get_dummy_task(node, follow_external_dependency=True) - dagger_tasks.append(task) - ephemeral_parent_node_names = node.get("depends_on", {}).get("nodes", []) - for node_name in ephemeral_parent_node_names: + elif materialized_type == "ephemeral": + task = self._get_dummy_task(node) + dagger_tasks.append(task) + for node_name in node.get("depends_on", {}).get("nodes", []): dagger_tasks += self._generate_dagger_tasks(node_name) - else: - table_task = self._get_table_task(node, follow_external_dependency=True) - s3_task = self._get_s3_task(node) + else: + table_task = self._get_table_task(node, follow_external_dependency=follow_external_dependency) dagger_tasks.append(table_task) - dagger_tasks.append(s3_task) + + if materialized_type in ("table", "incremental"): + dagger_tasks.append(self._get_s3_task(node)) + elif self._is_node_preparation_model(node): + for dependent_node_name in node.get("depends_on", {}).get("nodes", []): + dagger_tasks.extend( + self._generate_dagger_tasks(dependent_node_name) + ) return dagger_tasks @@ -223,6 +233,10 @@ def __init__(self, default_config_parameters: dict): "s3_data_dir" ) or self._target_config.get("s3_staging_dir") + def _is_node_preparation_model(self, node: dict): + """Define whether it is a preparation model.""" + return node.get("name").startswith("stg_") + def _get_table_task( self, node: dict, follow_external_dependency: bool = False ) -> dict: @@ -271,13 +285,14 @@ def _generate_dagger_output(self, node: dict): dict: The dagger output, which is a combination of an athena and s3 task for the DBT model node """ - if node.get("config", {}).get("materialized") in ( - "view", - "ephemeral", - ) or node.get("name").startswith("stg_"): + materialized_type = node.get("config", {}).get("materialized") + if materialized_type == "ephemeral": return [self._get_dummy_task(node)] else: - return [self._get_table_task(node), self._get_s3_task(node, is_output=True)] + output_tasks = [self._get_table_task(node)] + if materialized_type in ("table", "incremental"): + output_tasks.append(self._get_s3_task(node, is_output=True)) + return output_tasks class DatabricksDBTConfigParser(DBTConfigParser): @@ -291,6 +306,12 @@ def __init__(self, default_config_parameters: dict): "create_external_athena_table", False ) + def _is_node_preparation_model(self, node: dict): + """ + Define whether it is a preparation model. + """ + return "preparation" in node.get("schema", "") + def _get_table_task( self, node: dict, follow_external_dependency: bool = False ) -> dict: @@ -351,21 +372,13 @@ def _generate_dagger_output(self, node: dict): dict: The dagger output, which is a combination of an athena and s3 task for the DBT model node """ - if ( - node.get("config", {}).get("materialized") - in ( - "view", - "ephemeral", - ) - or node.get("name").startswith("stg_") - or "preparation" in node.get("schema", "") - ): + materialized_type = node.get("config", {}).get("materialized") + if materialized_type == "ephemeral": return [self._get_dummy_task(node)] else: - output_tasks = [ - self._get_table_task(node), - self._get_s3_task(node, is_output=True), - ] + output_tasks = [self._get_table_task(node)] + if materialized_type in ("table", "incremental"): + output_tasks.append(self._get_s3_task(node, is_output=True)) if self._create_external_athena_table: output_tasks.append(self._get_athena_table_task(node)) return output_tasks diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py b/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py index f1afd52..64fffce 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py @@ -201,8 +201,10 @@ EXPECTED_SEED_NODE = [ { - "type": "dummy", - "name": "seed_buyer_country_overwrite", + "type": "athena", + "schema": "analytics_engineering", + "table": "seed_buyer_country_overwrite", + "name": "analytics_engineering__seed_buyer_country_overwrite_athena", } ] @@ -210,16 +212,16 @@ { "type": "dummy", "name": "int_model3", - "follow_external_dependency": True, }, { "type": "dummy", "name": "int_model2", - "follow_external_dependency": True, }, { - "type": "dummy", - "name": "seed_buyer_country_overwrite", + "type": "athena", + "schema": "analytics_engineering", + "table": "seed_buyer_country_overwrite", + "name": "analytics_engineering__seed_buyer_country_overwrite_athena", }, { "name": "analytics_engineering__stg_core_schema1__table1_athena", @@ -248,9 +250,10 @@ "type": "s3", }, { - "name": "stg_core_schema2__table2", - "type": "dummy", - "follow_external_dependency": True, + "type": "athena", + "schema": "analytics_engineering", + "table": "stg_core_schema2__table2", + "name": "analytics_engineering__stg_core_schema2__table2_athena", }, { "type": "athena", @@ -272,16 +275,16 @@ { "type": "dummy", "name": "int_model3", - "follow_external_dependency": True, }, { "type": "dummy", "name": "int_model2", - "follow_external_dependency": True, }, { - "type": "dummy", - "name": "seed_buyer_country_overwrite", + "type": "athena", + "schema": "analytics_engineering", + "table": "seed_buyer_country_overwrite", + "name": "analytics_engineering__seed_buyer_country_overwrite_athena", }, { "name": "analytics_engineering__stg_core_schema1__table1_athena", @@ -316,9 +319,10 @@ EXPECTED_DAGGER_INPUTS = [ { - "name": "stg_core_schema2__table2", - "type": "dummy", - "follow_external_dependency": True, + "type": "athena", + "schema": "analytics_engineering", + "table": "stg_core_schema2__table2", + "name": "analytics_engineering__stg_core_schema2__table2_athena", }, { "type": "athena", @@ -334,7 +338,12 @@ "table": "table3", "follow_external_dependency": True, }, - {"name": "seed_buyer_country_overwrite", "type": "dummy"}, + { + "type": "athena", + "schema": "analytics_engineering", + "table": "seed_buyer_country_overwrite", + "name": "analytics_engineering__seed_buyer_country_overwrite_athena", + }, { "name": "analytics_engineering__model2_athena", "schema": "analytics_engineering", @@ -351,12 +360,10 @@ { "type": "dummy", "name": "int_model3", - "follow_external_dependency": True, }, { "type": "dummy", "name": "int_model2", - "follow_external_dependency": True, }, { "name": "analytics_engineering__stg_core_schema1__table1_athena", @@ -388,7 +395,12 @@ "table": "table3", "type": "athena", }, - {"name": "seed_buyer_country_overwrite", "type": "dummy"}, + { + "type": "athena", + "schema": "analytics_engineering", + "table": "seed_buyer_country_overwrite", + "name": "analytics_engineering__seed_buyer_country_overwrite_athena", + } ] EXPECTED_DAGGER_OUTPUTS = [ @@ -408,13 +420,26 @@ EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS = [ { - "type": "dummy", - "name": "stg_core_schema2__table2", + "name": "analytics_engineering__stg_core_schema1__table1_athena", + "type": "athena", + "table": "stg_core_schema1__table1", + "schema": "analytics_engineering", + }, + { + "type": "s3", + "name": "output_s3_path", + "bucket": "bucket1-data-lake", + "path": "path2/stg_core_schema1__table1", }, ] EXPECTED_DBT_INT_MODEL_DAGGER_INPUTS = [ - {"name": "seed_buyer_country_overwrite", "type": "dummy"}, + { + "type": "athena", + "schema": "analytics_engineering", + "table": "seed_buyer_country_overwrite", + "name": "analytics_engineering__seed_buyer_country_overwrite_athena", + }, { "name": "analytics_engineering__stg_core_schema1__table1_athena", "type": "athena", diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py index 2538e25..ad6b912 100644 --- a/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py +++ b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py @@ -112,7 +112,7 @@ }, "seed.main.seed_buyer_country_overwrite": { "database": "hive_metastore", - "schema": "datastg_preparation", + "schema": "data_preparation", "name": "seed_buyer_country_overwrite", "unique_id": "seed.main.seed_buyer_country_overwrite", "resource_type": "seed", @@ -211,8 +211,11 @@ DATABRICKS_EXPECTED_SEED_NODE = [ { - "type": "dummy", - "name": "seed_buyer_country_overwrite", + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "seed_buyer_country_overwrite", + "name": "hive_metastore__data_preparation__seed_buyer_country_overwrite_databricks", } ] @@ -220,24 +223,25 @@ { "type": "dummy", "name": "int_model3", - "follow_external_dependency": True, }, { "type": "dummy", "name": "int_model2", - "follow_external_dependency": True, }, { - "type": "dummy", - "name": "seed_buyer_country_overwrite", + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "seed_buyer_country_overwrite", + "name": "hive_metastore__data_preparation__seed_buyer_country_overwrite_databricks", }, { "type": "databricks", - "follow_external_dependency": True, "catalog": "hive_metastore", "schema": "data_preparation", "table": "stg_core_schema1__table1", "name": "hive_metastore__data_preparation__stg_core_schema1__table1_databricks", + "follow_external_dependency": True, }, { "type": "s3", @@ -260,9 +264,11 @@ "type": "s3", }, { - "name": "stg_core_schema2__table2", - "type": "dummy", - "follow_external_dependency": True, + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "stg_core_schema2__table2", + "name": "hive_metastore__data_preparation__stg_core_schema2__table2_databricks", }, { "type": "athena", @@ -284,16 +290,17 @@ { "type": "dummy", "name": "int_model3", - "follow_external_dependency": True, }, { "type": "dummy", "name": "int_model2", - "follow_external_dependency": True, }, { - "type": "dummy", - "name": "seed_buyer_country_overwrite", + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "seed_buyer_country_overwrite", + "name": "hive_metastore__data_preparation__seed_buyer_country_overwrite_databricks", }, { "type": "databricks", @@ -330,9 +337,11 @@ DATABRICKS_EXPECTED_DAGGER_INPUTS = [ { - "name": "stg_core_schema2__table2", - "type": "dummy", - "follow_external_dependency": True, + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "stg_core_schema2__table2", + "name": "hive_metastore__data_preparation__stg_core_schema2__table2_databricks", }, { "type": "athena", @@ -348,7 +357,13 @@ "table": "table3", "follow_external_dependency": True, }, - {"name": "seed_buyer_country_overwrite", "type": "dummy"}, + { + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "seed_buyer_country_overwrite", + "name": "hive_metastore__data_preparation__seed_buyer_country_overwrite_databricks", + }, { "name": "marts__analytics_engineering__model2_databricks", "catalog": "marts", @@ -366,12 +381,10 @@ { "type": "dummy", "name": "int_model3", - "follow_external_dependency": True, }, { "type": "dummy", "name": "int_model2", - "follow_external_dependency": True, }, { "type": "databricks", @@ -404,11 +417,23 @@ "table": "table3", "type": "athena", }, - {"name": "seed_buyer_country_overwrite", "type": "dummy"}, + { + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "seed_buyer_country_overwrite", + "name": "hive_metastore__data_preparation__seed_buyer_country_overwrite_databricks", + }, ] DATABRICKS_EXPECTED_DBT_INT_MODEL_DAGGER_INPUTS = [ - {"name": "seed_buyer_country_overwrite", "type": "dummy"}, + { + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "seed_buyer_country_overwrite", + "name": "hive_metastore__data_preparation__seed_buyer_country_overwrite_databricks", + }, { "type": "databricks", "follow_external_dependency": True, @@ -449,7 +474,22 @@ DATABRICKS_EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS = [ { - "type": "dummy", - "name": "stg_core_schema2__table2", + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "stg_core_schema1__table1", + "name": "hive_metastore__data_preparation__stg_core_schema1__table1_databricks", }, + { + "type": "s3", + "name": "output_s3_path", + "bucket": "chodata-data-lake", + "path": "analytics_warehouse/data/preparation/stg_core_schema1__table1", + }, + { + 'type': 'athena', + 'schema': 'data_preparation', + 'table': 'stg_core_schema1__table1', + 'name': 'data_preparation__stg_core_schema1__table1_athena' + } ] diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py index d401e4b..d4d9028 100644 --- a/tests/utilities/test_dbt_config_parser.py +++ b/tests/utilities/test_dbt_config_parser.py @@ -78,13 +78,14 @@ def test_generate_io_inputs(self): ] for mock_input, expected_output in fixtures: result, _ = self._dbt_config_parser.generate_dagger_io(mock_input) - + print(f"result: {result}") + print(f"expected_output: {expected_output}") self.assertListEqual(result, expected_output) def test_generate_io_outputs(self): fixtures = [ ("model1", EXPECTED_DAGGER_OUTPUTS), - ("stg_core_schema2__table2", EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS), + ("stg_core_schema1__table1", EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS), ] for mock_input, expected_output in fixtures: _, result = self._dbt_config_parser.generate_dagger_io(mock_input) @@ -161,11 +162,10 @@ def test_generate_io_outputs(self): fixtures = [ ("model1", DATABRICKS_EXPECTED_DAGGER_OUTPUTS), ( - "stg_core_schema2__table2", + "stg_core_schema1__table1", DATABRICKS_EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS, ), ] for mock_input, expected_output in fixtures: _, result = self._dbt_config_parser.generate_dagger_io(mock_input) - self.assertListEqual(result, expected_output) From 698358dfcd9231f6097d44736c738ab4279e9b29 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Thu, 4 Jul 2024 14:23:36 +0200 Subject: [PATCH 089/189] Module generation with generalised jinja parameters --- dagger/cli/module.py | 21 +++++++++++++++++++-- dagger/utilities/module.py | 4 +++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/dagger/cli/module.py b/dagger/cli/module.py index 67fca87..931e809 100644 --- a/dagger/cli/module.py +++ b/dagger/cli/module.py @@ -1,17 +1,34 @@ import click from dagger.utilities.module import Module from dagger.utils import Printer +import json +def parse_key_value(ctx, param, value): + #print('YYY', value) + if not value: + return {} + key_value_dict = {} + for pair in value: + try: + key, val_file_path = pair.split('=', 1) + #print('YYY', key, val_file_path, pair) + val = json.load(open(val_file_path)) + key_value_dict[key] = val + except ValueError: + raise click.BadParameter(f"Key-value pair '{pair}' is not in the format key=value") + return key_value_dict + @click.command() @click.option("--config_file", "-c", help="Path to module config file") @click.option("--target_dir", "-t", help="Path to directory to generate the task configs to") -def generate_tasks(config_file: str, target_dir: str) -> None: +@click.option("--jinja_parameters", "-j", callback=parse_key_value, multiple=True, default=None, help="Path to jinja parameters json file in the format: =") +def generate_tasks(config_file: str, target_dir: str, jinja_parameters: dict) -> None: """ Generating tasks for a module based on config """ - module = Module(config_file, target_dir) + module = Module(config_file, target_dir, jinja_parameters) module.generate_task_configs() Printer.print_success("Tasks are successfully generated") diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index ff1329f..242e2f5 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -14,7 +14,7 @@ class Module: - def __init__(self, path_to_config, target_dir): + def __init__(self, path_to_config, target_dir, jinja_parameters): self._directory = path.dirname(path_to_config) self._target_dir = target_dir or "./" self._path_to_config = path_to_config @@ -29,6 +29,7 @@ def __init__(self, path_to_config, target_dir): self._branches_to_generate = config["branches_to_generate"] self._override_parameters = config.get("override_parameters", {}) self._default_parameters = config.get("default_parameters", {}) + self._jinja_parameters = jinja_parameters @staticmethod def read_yaml(yaml_str): @@ -76,6 +77,7 @@ def generate_task_configs(self): template_parameters.update(self._default_parameters or {}) template_parameters.update(attrs) template_parameters['branch_name'] = branch_name + template_parameters.update(self._jinja_parameters) dbt_manifest = None if "dbt" in self._tasks.keys(): From cd12ddabc2a6991a54cab61b274335b5b67a63af Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Thu, 4 Jul 2024 16:19:44 +0200 Subject: [PATCH 090/189] Now it's possible to assign task to task groups --- dagger/dag_creator/airflow/operator_creator.py | 13 +++++++++++++ dagger/pipeline/task.py | 11 +++++++++++ 2 files changed, 24 insertions(+) diff --git a/dagger/dag_creator/airflow/operator_creator.py b/dagger/dag_creator/airflow/operator_creator.py index fc46234..b6aa036 100644 --- a/dagger/dag_creator/airflow/operator_creator.py +++ b/dagger/dag_creator/airflow/operator_creator.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from datetime import timedelta +from airflow.utils.task_group import TaskGroup TIMEDELTA_PARAMETERS = ['execution_timeout'] @@ -11,6 +12,15 @@ def __init__(self, task, dag): self._template_parameters = {} self._airflow_parameters = {} + def _get_existing_task_group_or_create_new(self): + group_id = self._task.task_group + if self._dag.task_group: + for group in self._dag.task_group.children.values(): + if isinstance(group, TaskGroup) and group.group_id == group_id: + return group + + return TaskGroup(group_id=group_id, dag=self._dag) + @abstractmethod def _create_operator(self, kwargs): raise NotImplementedError @@ -34,6 +44,9 @@ def _update_airflow_parameters(self): if self._task.timeout_in_seconds: self._airflow_parameters["execution_timeout"] = self._task.timeout_in_seconds + if self._task.task_group: + self._airflow_parameters["task_group"] = self._get_existing_task_group_or_create_new() + self._fix_timedelta_parameters() def create_operator(self): diff --git a/dagger/pipeline/task.py b/dagger/pipeline/task.py index 26235bd..ce07aec 100644 --- a/dagger/pipeline/task.py +++ b/dagger/pipeline/task.py @@ -36,6 +36,12 @@ def init_attributes(cls, orig_cls): comment="Use dagger init-io cli", ), Attribute(attribute_name="pool", required=False), + Attribute( + attribute_name="task_group", + required=False, + format_help=str, + comment="Task group name", + ), Attribute( attribute_name="timeout_in_seconds", required=False, @@ -73,6 +79,7 @@ def __init__(self, name: str, pipeline_name, pipeline, config: dict): self._outputs = [] self._pool = self.parse_attribute("pool") or self.default_pool self._timeout_in_seconds = self.parse_attribute("timeout_in_seconds") + self._task_group = self.parse_attribute("task_group") self.process_inputs(config["inputs"]) self.process_outputs(config["outputs"]) @@ -137,6 +144,10 @@ def pool(self): def timeout_in_seconds(self): return self._timeout_in_seconds + @property + def task_group(self): + return self._task_group + def add_input(self, task_input: IO): _logger.info("Adding input: %s to task: %s", task_input.name, self._name) self._inputs.append(task_input) From 0e6a6a638838ea14ff1607ee26a0ef9cb6778fbe Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Thu, 4 Jul 2024 17:54:32 +0200 Subject: [PATCH 091/189] Adding default value to the parameter --- dagger/utilities/module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 242e2f5..8697efa 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -14,7 +14,7 @@ class Module: - def __init__(self, path_to_config, target_dir, jinja_parameters): + def __init__(self, path_to_config, target_dir, jinja_parameters=None): self._directory = path.dirname(path_to_config) self._target_dir = target_dir or "./" self._path_to_config = path_to_config @@ -29,7 +29,7 @@ def __init__(self, path_to_config, target_dir, jinja_parameters): self._branches_to_generate = config["branches_to_generate"] self._override_parameters = config.get("override_parameters", {}) self._default_parameters = config.get("default_parameters", {}) - self._jinja_parameters = jinja_parameters + self._jinja_parameters = jinja_parameters or {} @staticmethod def read_yaml(yaml_str): From 8d0e23feaa6f324901440d76b68574904c039c81 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 6 Nov 2024 13:32:02 +0100 Subject: [PATCH 092/189] removed custom dbt task generation logic from Module --- dagger/utilities/module.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 8697efa..4172fce 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -79,20 +79,6 @@ def generate_task_configs(self): template_parameters['branch_name'] = branch_name template_parameters.update(self._jinja_parameters) - dbt_manifest = None - if "dbt" in self._tasks.keys(): - if template_parameters.get("profile_name") == "athena": - self._dbt_module = AthenaDBTConfigParser(template_parameters) - if template_parameters.get("profile_name") == "databricks": - self._dbt_module = DatabricksDBTConfigParser(template_parameters) - - dbt_manifest = {} - dbt_manifest['nodes'] = self._dbt_module.nodes_in_manifest - dbt_manifest['sources'] = self._dbt_module.sources_in_manifest - - template_parameters["dbt_manifest"] = dbt_manifest - template_parameters["dbt_default_schema"] = self._dbt_module.dbt_default_schema - for task, task_yaml in self._tasks.items(): task_name = f"{branch_name}_{task}" _logger.info(f"Generating task {task_name}") @@ -101,12 +87,6 @@ def generate_task_configs(self): ) task_dict = yaml.safe_load(task_str) - if task == "dbt": - inputs, outputs = self._dbt_module.generate_dagger_io(branch_name) - task_dict["inputs"] = inputs - task_dict["outputs"] = outputs - task_dict["task_parameters"]["select"] = branch_name - task_dict["autogenerated_by_dagger"] = self._path_to_config override_parameters = self._override_parameters or {} merge(task_dict, override_parameters.get(branch_name, {}).get(task, {})) From 77f6026a030ef1ee1821aa9feec0d8a1e19f1b25 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 7 Nov 2024 10:23:46 +0100 Subject: [PATCH 093/189] add plugins path to dagger config --- dagger/conf.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dagger/conf.py b/dagger/conf.py index 6b5488f..5322ce3 100644 --- a/dagger/conf.py +++ b/dagger/conf.py @@ -98,4 +98,8 @@ # Alert parameters alert_config = config.get('alert', None) or {} SLACK_TOKEN = alert_config.get('slack_token', None) -DEFAULT_ALERT = alert_config.get('default_alert', {"type": "slack", "channel": "#airflow-jobs", "mentions": None}) \ No newline at end of file +DEFAULT_ALERT = alert_config.get('default_alert', {"type": "slack", "channel": "#airflow-jobs", "mentions": None}) + +# Plugin parameters +plugin_config = config.get('plugin', None) or {} +PLUGIN_DIRS = [os.path.join(AIRFLOW_HOME, path) for path in plugin_config.get('paths', [])] From fbb648b034974c6b4cb9683aea80667fa1f60918 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 7 Nov 2024 10:26:00 +0100 Subject: [PATCH 094/189] added function to load plugins --- dagger/utilities/module.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 4172fce..7aa38ba 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -1,15 +1,16 @@ +import importlib +import inspect import logging +import os +import pkgutil from os import path -from mergedeep import merge -from dagger.utilities.dbt_config_parser import ( - AthenaDBTConfigParser, - DatabricksDBTConfigParser, -) import jinja2 - import yaml +from dagger import conf +from mergedeep import merge + _logger = logging.getLogger("root") @@ -49,6 +50,29 @@ def read_task_config(self, task): exit(1) return content + @staticmethod + def load_plugins() -> dict: + """ + Dynamically load all classes(plugins) from the folders defined in the conf.PLUGIN_DIRS variable. + The folder contains all plugins that are part of the project. + Returns: + dict: A dictionary with the class name as key and the class object as value + """ + classes = {} + + for module_info in pkgutil.iter_modules(conf.PLUGIN_DIRS): + module_name = module_info.name + module_path = os.path.join(module_info.module_finder.path, f"{module_name}.py") + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + for name, obj in inspect.getmembers(module, inspect.isclass): + classes[f"{name}"] = obj + + return classes + + @staticmethod def replace_template_parameters(_task_str, _template_parameters): environment = jinja2.Environment() From 90bb16bfc0777bf7f0ca6fea2879298bba0d8bf3 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 7 Nov 2024 10:26:12 +0100 Subject: [PATCH 095/189] load plugins and render jinja --- dagger/utilities/module.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 7aa38ba..41da171 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -76,6 +76,10 @@ def load_plugins() -> dict: @staticmethod def replace_template_parameters(_task_str, _template_parameters): environment = jinja2.Environment() + loaded_classes = Module.load_plugins() + for class_name, class_obj in loaded_classes.items(): + environment.globals[class_name] = class_obj + template = environment.from_string(_task_str) rendered_task = template.render(_template_parameters) From 1d09c57ef9e5a570e24052f0f4fe8c4f106c7262 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 7 Nov 2024 11:49:48 +0100 Subject: [PATCH 096/189] iterate over multiple folders and their subfolders --- dagger/utilities/module.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 41da171..4a85d37 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -60,15 +60,18 @@ def load_plugins() -> dict: """ classes = {} - for module_info in pkgutil.iter_modules(conf.PLUGIN_DIRS): - module_name = module_info.name - module_path = os.path.join(module_info.module_finder.path, f"{module_name}.py") - spec = importlib.util.spec_from_file_location(module_name, module_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - for name, obj in inspect.getmembers(module, inspect.isclass): - classes[f"{name}"] = obj + for plugin_path in conf.PLUGIN_DIRS: + for root, dirs, files in os.walk(plugin_path): + for plugin_file in files: + if plugin_file.endswith(".py") and not plugin_file.startswith("__init__"): + module_name = plugin_file.replace(".py", "") + module_path = os.path.join(root, plugin_file) + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + for name, obj in inspect.getmembers(module, inspect.isclass): + classes[f"{name}"] = obj return classes From da0528ada6d08159470ab67289ff9ce5bb235d51 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 7 Nov 2024 12:50:55 +0100 Subject: [PATCH 097/189] exclude all files starting with __ --- dagger/utilities/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 4a85d37..6bf2c35 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -63,7 +63,7 @@ def load_plugins() -> dict: for plugin_path in conf.PLUGIN_DIRS: for root, dirs, files in os.walk(plugin_path): for plugin_file in files: - if plugin_file.endswith(".py") and not plugin_file.startswith("__init__"): + if plugin_file.endswith(".py") and not plugin_file.startswith("__"): module_name = plugin_file.replace(".py", "") module_path = os.path.join(root, plugin_file) spec = importlib.util.spec_from_file_location(module_name, module_path) From a8aeff7612e6c67453e163bb175288bb400fb31f Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 7 Nov 2024 12:51:03 +0100 Subject: [PATCH 098/189] added plugin to dagger config --- dagger/dagger_config.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dagger/dagger_config.yaml b/dagger/dagger_config.yaml index 3366828..69c3d54 100644 --- a/dagger/dagger_config.yaml +++ b/dagger/dagger_config.yaml @@ -58,3 +58,7 @@ alert: # type: slack # channel: "#airflow-jobs" # mentions: + +plugin: +# paths: +# - plugins From 6addf9857cd8f8cc6a50a71532b9f0aa201f52f1 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 7 Nov 2024 12:52:05 +0100 Subject: [PATCH 099/189] added logging for plugins --- dagger/conf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dagger/conf.py b/dagger/conf.py index 5322ce3..5036da3 100644 --- a/dagger/conf.py +++ b/dagger/conf.py @@ -101,5 +101,6 @@ DEFAULT_ALERT = alert_config.get('default_alert', {"type": "slack", "channel": "#airflow-jobs", "mentions": None}) # Plugin parameters -plugin_config = config.get('plugin', None) or {} +plugin_config = config.get('plugin', {}) PLUGIN_DIRS = [os.path.join(AIRFLOW_HOME, path) for path in plugin_config.get('paths', [])] +logging.info(f"All Python classes will be loaded as plugins from the following directories: {PLUGIN_DIRS}") From b8feafb4506ee35dc176dd0c2d530b35a236c69e Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 7 Nov 2024 15:44:58 +0100 Subject: [PATCH 100/189] fix --- dagger/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/conf.py b/dagger/conf.py index 5036da3..cbb075d 100644 --- a/dagger/conf.py +++ b/dagger/conf.py @@ -101,6 +101,6 @@ DEFAULT_ALERT = alert_config.get('default_alert', {"type": "slack", "channel": "#airflow-jobs", "mentions": None}) # Plugin parameters -plugin_config = config.get('plugin', {}) +plugin_config = config.get('plugin', None) or {} PLUGIN_DIRS = [os.path.join(AIRFLOW_HOME, path) for path in plugin_config.get('paths', [])] logging.info(f"All Python classes will be loaded as plugins from the following directories: {PLUGIN_DIRS}") From d8c6384257d3efd8f464c72fee9692621c01dd79 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 7 Nov 2024 15:45:45 +0100 Subject: [PATCH 101/189] added tests for plugins --- .../sample_folder/sample_folder_plugin.py | 5 ++ tests/fixtures/plugins/sample_plugin.py | 0 tests/utilities/test_plugins.py | 60 +++++++++++++++++++ 3 files changed, 65 insertions(+) create mode 100644 tests/fixtures/plugins/sample_folder/sample_folder_plugin.py create mode 100644 tests/fixtures/plugins/sample_plugin.py create mode 100644 tests/utilities/test_plugins.py diff --git a/tests/fixtures/plugins/sample_folder/sample_folder_plugin.py b/tests/fixtures/plugins/sample_folder/sample_folder_plugin.py new file mode 100644 index 0000000..c8b931e --- /dev/null +++ b/tests/fixtures/plugins/sample_folder/sample_folder_plugin.py @@ -0,0 +1,5 @@ +class SampleFolderPlugin: + @staticmethod + def get_inputs(): + return [{"name": "sample_folder_plugin_task", "type": "dummy"}] + diff --git a/tests/fixtures/plugins/sample_plugin.py b/tests/fixtures/plugins/sample_plugin.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utilities/test_plugins.py b/tests/utilities/test_plugins.py new file mode 100644 index 0000000..1591317 --- /dev/null +++ b/tests/utilities/test_plugins.py @@ -0,0 +1,60 @@ +import inspect +import shutil +import unittest +from pathlib import Path +from unittest.mock import patch +import os +import importlib.util + +import jinja2 + +from dagger.utilities.module import Module # Adjust this import according to your actual module structure +from dagger import conf + +TESTS_ROOT = Path(__file__).parent.parent + +class TestLoadPlugins(unittest.TestCase): + + def setUp(self): + self._jinja_environment = jinja2.Environment() + loaded_classes = Module.load_plugins() + for class_name, class_obj in loaded_classes.items(): + self._jinja_environment.globals[class_name] = class_obj + + self._template = self._jinja_environment.from_string("inputs: {{ SampleFolderPlugin.get_inputs() }}") + + @patch("dagger.conf.PLUGIN_DIRS", new=[]) + @patch("os.walk") + def test_load_plugins_no_plugin_dir(self, mock_os_walk): + # Simulate os.walk returning no Python files + mock_os_walk.return_value = [("/fake/plugin/dir", [], [])] + + result = Module.load_plugins() + + # Expecting an empty dictionary since no plugins were found + self.assertEqual(result, {}) + + @patch("dagger.conf.PLUGIN_DIRS", new=[str(TESTS_ROOT.joinpath("fixtures/plugins"))]) + def test_load_plugins(self): + + result = Module.load_plugins() + for name, plugin_class in result.items(): + result[name] = str(plugin_class) + + expected_classes = {"SampleFolderPlugin": ""} + + self.assertEqual(result, expected_classes) + + @patch("dagger.conf.PLUGIN_DIRS", new=[str(TESTS_ROOT.joinpath("fixtures/plugins"))]) + def test_load_plugins_in_jinja(self): + result = Module.load_plugins() + for class_name, class_obj in result.items(): + self._jinja_environment.globals[class_name] = class_obj + + rendered_task = self._template.render() + expected_task = "inputs: [{'name': 'sample_folder_plugin_task', 'type': 'dummy'}]" + + self.assertEqual(rendered_task, expected_task) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From a3d7f47b780f7d7a5ddcb021028b7976692dc292 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 7 Nov 2024 17:49:17 +0100 Subject: [PATCH 102/189] refactor code * add the plugins into jinja env directly --- dagger/utilities/module.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 6bf2c35..169123e 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -3,7 +3,7 @@ import logging import os import pkgutil -from os import path +from os import path, environ import jinja2 import yaml @@ -51,7 +51,7 @@ def read_task_config(self, task): return content @staticmethod - def load_plugins() -> dict: + def load_plugins_to_jinja_environment(environment: jinja2.Environment) -> jinja2.Environment: """ Dynamically load all classes(plugins) from the folders defined in the conf.PLUGIN_DIRS variable. The folder contains all plugins that are part of the project. @@ -71,18 +71,14 @@ def load_plugins() -> dict: spec.loader.exec_module(module) for name, obj in inspect.getmembers(module, inspect.isclass): - classes[f"{name}"] = obj - - return classes + environment.globals[f"{name}"] = obj + return environment @staticmethod def replace_template_parameters(_task_str, _template_parameters): environment = jinja2.Environment() - loaded_classes = Module.load_plugins() - for class_name, class_obj in loaded_classes.items(): - environment.globals[class_name] = class_obj - + environment = Module.load_plugins_to_jinja_environment(environment) template = environment.from_string(_task_str) rendered_task = template.render(_template_parameters) From 450b54491fd9d8f7d827cd72f8aabd8073c9229b Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Thu, 7 Nov 2024 17:55:34 +0100 Subject: [PATCH 103/189] refactor tests --- tests/utilities/test_plugins.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/tests/utilities/test_plugins.py b/tests/utilities/test_plugins.py index 1591317..6a4b343 100644 --- a/tests/utilities/test_plugins.py +++ b/tests/utilities/test_plugins.py @@ -8,7 +8,7 @@ import jinja2 -from dagger.utilities.module import Module # Adjust this import according to your actual module structure +from dagger.utilities.module import Module from dagger import conf TESTS_ROOT = Path(__file__).parent.parent @@ -17,10 +17,6 @@ class TestLoadPlugins(unittest.TestCase): def setUp(self): self._jinja_environment = jinja2.Environment() - loaded_classes = Module.load_plugins() - for class_name, class_obj in loaded_classes.items(): - self._jinja_environment.globals[class_name] = class_obj - self._template = self._jinja_environment.from_string("inputs: {{ SampleFolderPlugin.get_inputs() }}") @patch("dagger.conf.PLUGIN_DIRS", new=[]) @@ -29,27 +25,19 @@ def test_load_plugins_no_plugin_dir(self, mock_os_walk): # Simulate os.walk returning no Python files mock_os_walk.return_value = [("/fake/plugin/dir", [], [])] - result = Module.load_plugins() + result_environment = Module.load_plugins_to_jinja_environment(self._jinja_environment) - # Expecting an empty dictionary since no plugins were found - self.assertEqual(result, {}) + self.assertNotIn("SampleFolderPlugin", result_environment.globals) @patch("dagger.conf.PLUGIN_DIRS", new=[str(TESTS_ROOT.joinpath("fixtures/plugins"))]) def test_load_plugins(self): + result_environment = Module.load_plugins_to_jinja_environment(self._jinja_environment) - result = Module.load_plugins() - for name, plugin_class in result.items(): - result[name] = str(plugin_class) - - expected_classes = {"SampleFolderPlugin": ""} - - self.assertEqual(result, expected_classes) + self.assertIn("SampleFolderPlugin", result_environment.globals.keys()) @patch("dagger.conf.PLUGIN_DIRS", new=[str(TESTS_ROOT.joinpath("fixtures/plugins"))]) - def test_load_plugins_in_jinja(self): - result = Module.load_plugins() - for class_name, class_obj in result.items(): - self._jinja_environment.globals[class_name] = class_obj + def test_load_plugins_render_jinja(self): + result_environment = Module.load_plugins_to_jinja_environment(self._jinja_environment) rendered_task = self._template.render() expected_task = "inputs: [{'name': 'sample_folder_plugin_task', 'type': 'dummy'}]" From 98a6b55db23a92379289e362c95be54660eefb02 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Fri, 8 Nov 2024 13:32:39 +0100 Subject: [PATCH 104/189] exclude test folders from directory walk --- dagger/utilities/module.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 169123e..7f33690 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -58,12 +58,11 @@ def load_plugins_to_jinja_environment(environment: jinja2.Environment) -> jinja2 Returns: dict: A dictionary with the class name as key and the class object as value """ - classes = {} - for plugin_path in conf.PLUGIN_DIRS: for root, dirs, files in os.walk(plugin_path): + dirs[:] = [directory for directory in dirs if not directory.lower().startswith("test")] for plugin_file in files: - if plugin_file.endswith(".py") and not plugin_file.startswith("__"): + if plugin_file.endswith(".py") and not (plugin_file.startswith("__") or plugin_file.startswith("test")): module_name = plugin_file.replace(".py", "") module_path = os.path.join(root, plugin_file) spec = importlib.util.spec_from_file_location(module_name, module_path) From ab6ff156f784b137f115187a736ce70aba4af39e Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Fri, 8 Nov 2024 13:44:12 +0100 Subject: [PATCH 105/189] remove unused imports --- tests/utilities/test_plugins.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/utilities/test_plugins.py b/tests/utilities/test_plugins.py index 6a4b343..353a407 100644 --- a/tests/utilities/test_plugins.py +++ b/tests/utilities/test_plugins.py @@ -1,15 +1,10 @@ -import inspect -import shutil import unittest from pathlib import Path from unittest.mock import patch -import os -import importlib.util import jinja2 from dagger.utilities.module import Module -from dagger import conf TESTS_ROOT = Path(__file__).parent.parent From 0fa5389968dd8342f985132ddae8984248833f41 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Mon, 11 Nov 2024 15:56:35 +0100 Subject: [PATCH 106/189] update readme --- README.md | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/README.md b/README.md index 2a161be..d70a203 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,40 @@ flowchart TD; ``` +Plugins for dagger +------- + +### Overview +Dagger now supports a plugin system that allows users to extend its functionality by adding custom Python classes. These plugins are integrated into the Jinja2 templating engine, enabling dynamic rendering of task configuration templates. +### Purpose +The plugin system allows users to define Python classes that can be loaded into the Jinja2 environment. When functions from these classes are invoked within a task configuration template, they are rendered dynamically using Jinja2. This feature enhances the flexibility of task configurations by allowing custom logic to be embedded directly in the templates. + +### Usage +1. **Creating a Plugin:** To create a new plugin, define a Python class in a folder(for example `plugins/sample_plugin/sample_plugin.py`) with the desired methods. For example: +```python +class MyCustomPlugin: + def generate_input(self, branch_name): + return [{"name": f"{branch_name}", "type": "dummy"}] +``` +This class defines a `generate_input` method that takes the branch_name from the module config and returns a dummy dagger task. +2. **Loading the Plugin into Dagger:** To load this plugin into Dagger's Jinja2 environment, you need to register it in your `dagger_config.yaml`: +```yaml +# pipeline.yaml +plugin: + paths: + - plugins # all Python classes within this path will be loaded into the Jinja environment +``` + +3. **Using Plugin Methods in Templates:** Once the plugin is loaded, you can call its methods from within any Jinja2 template in your task configurations: +```yaml +# task_configuration.yaml +type: batch +description: sample task +inputs: # format: list | Use dagger init-io cli + {{ MyCustomPlugin.generate_input("dummy_input") }} +``` + + Credits ------- From e7f08cd0fd841b62d755878325c8934bf3fb2317 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Tue, 12 Nov 2024 15:43:21 +0100 Subject: [PATCH 107/189] fix readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index d70a203..8e59257 100644 --- a/README.md +++ b/README.md @@ -103,6 +103,7 @@ class MyCustomPlugin: return [{"name": f"{branch_name}", "type": "dummy"}] ``` This class defines a `generate_input` method that takes the branch_name from the module config and returns a dummy dagger task. + 2. **Loading the Plugin into Dagger:** To load this plugin into Dagger's Jinja2 environment, you need to register it in your `dagger_config.yaml`: ```yaml # pipeline.yaml From 7d871b6ce7deb6f016e5e5299510325d17db9142 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Thu, 2 Jan 2025 13:15:13 +0100 Subject: [PATCH 108/189] Adding new reverse etl operator to dagger inherited from batch operator --- .../operator_creators/reverse_etl_creator.py | 55 +++++ .../airflow/operators/reverse_etl_batch.py | 8 + dagger/pipeline/tasks/reverse_etl_task.py | 203 ++++++++++++++++++ 3 files changed, 266 insertions(+) create mode 100644 dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py create mode 100644 dagger/dag_creator/airflow/operators/reverse_etl_batch.py create mode 100644 dagger/pipeline/tasks/reverse_etl_task.py diff --git a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py new file mode 100644 index 0000000..f6f9095 --- /dev/null +++ b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py @@ -0,0 +1,55 @@ +import base64 + +from dagger.dag_creator.airflow.operator_creators.batch_creator import BatchCreator +import json + + +class ReverseEtlCreator(BatchCreator): + ref_name = "reverse_etl" + + def __init__(self, task, dag): + super().__init__(task, dag) + + self._assume_role_arn = task.assume_role_arn + self._num_threads = task.num_threads + self._batch_size = task.batch_size + self._absolute_job_name = task.absolute_job_name + self._primary_id_column = task.primary_id_column + self._secondary_id_column = task.secondary_id_column + self._custom_id_column = task.custom_id_column + self._model_name = task.model_name + self._project_name = task.project_name + self._is_deleted_column = task.is_deleted_column + self._hash_column = task.hash_column + self._updated_at_column = task.updated_at_column + self._from_time = task.from_time + self._days_to_live = task.days_to_live + + def _generate_command(self): + command = [self._task.executable_prefix, self._task.executable] + + + command.append(f"--num_threads={self._num_threads}") + command.append(f"--batch_size={self._batch_size}") + command.append(f"--primary_id_column={self._primary_id_column}") + command.append(f"--model_name={self._model_name}") + command.append(f"--project_name={self._project_name}") + + if self._assume_role_arn: + command.append(f"--assume_role_arn={self._assume_role_arn}") + if self._secondary_id_column: + command.append(f"--secondary_id_column={self._secondary_id_column}") + if self._custom_id_column: + command.append(f"--custom_id_column={self._custom_id_column}") + if self._is_deleted_column: + command.append(f"--is_deleted_column={self._is_deleted_column}") + if self._hash_column: + command.append(f"--hash_column={self._hash_column}") + if self._updated_at_column: + command.append(f"--updated_at_column={self._updated_at_column}") + if self._from_time: + command.append(f"--from_time={self._from_time}") + if self._days_to_live: + command.append(f"--days_to_live={self._days_to_live}") + + return command diff --git a/dagger/dag_creator/airflow/operators/reverse_etl_batch.py b/dagger/dag_creator/airflow/operators/reverse_etl_batch.py new file mode 100644 index 0000000..abb775b --- /dev/null +++ b/dagger/dag_creator/airflow/operators/reverse_etl_batch.py @@ -0,0 +1,8 @@ +from dagger.dag_creator.airflow.operators.awsbatch_operator import AWSBatchOperator + +class ReverseEtlBatchOperator(AWSBatchOperator): + custom_operator_name = 'ReverseETL' + ui_color = "#f0ede4" + + def __init__(self, *args, **kwargs): + super().__init__(args, kwargs) diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py new file mode 100644 index 0000000..1e21d49 --- /dev/null +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -0,0 +1,203 @@ +from dagger.pipeline.tasks.batch_task import BatchTask +from dagger.utilities.config_validator import Attribute + +class ReverseEtlTask(BatchTask): + ref_name = "reverse_etl" + + @classmethod + def init_attributes(cls, orig_cls): + cls.add_config_attributes( + [ + Attribute( + attribute_name="executable_prefix", + required=False, + parent_fields=["task_parameters"], + comment="E.g.: python", + ), + Attribute( + attribute_name="executable", + required=False, + parent_fields=["task_parameters"], + comment="E.g.: my_code.py", + ), + Attribute( + attribute_name="assume_role_arn", + parent_fields=["task_parameters"], + required = False, + validator=str, + comment="The ARN of the role to assume before running the job", + ), + Attribute( + attribute_name="num_threads", + parent_fields=["task_parameters"], + required=False, + validator=int, + comment="The number of threads to use for the job", + ), + Attribute( + attribute_name="batch_size", + parent_fields=["task_parameters"], + required=False, + validator=int, + comment="The number of rows to fetch in each batch", + ), + Attribute( + attribute_name="primary_id_column", + parent_fields=["task_parameters"], + validator=str, + comment="The primary key column to use for the job", + ), + Attribute( + attribute_name="secondary_id_column", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="The secondary key column to use for the job", + ), + Attribute( + attribute_name="custom_id_column", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="The custom key column to use for the job", + ), + Attribute( + attribute_name="model_name", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="The name of the model. This is going to be a column on the target table. By default it is" + " set to the name of the input .", + ), + Attribute( + attribute_name="project_name", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="The name of the project. This is going to be a column on the target table. By default it is" + " set to feature_store", + ), + Attribute( + attribute_name="is_deleted_column", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="The column that has the boolean flag to indicate if the row is deleted", + ), + Attribute( + attribute_name="hash_column", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="The column that has the the hash value of the row to be used to get the diff since " + "the last export. If provided, the from_time is required. It's mutually exclusive with " + "updated_at_column", + ), + Attribute( + attribute_name="updated_at_column", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="The column that has the last updated timestamp of the row to be used to get the diff " + "since the last export. If provided, the from_time is required. It's mutually exclusive " + "with hash_column", + ), + Attribute( + attribute_name="from_time", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="Timestamp in YYYY-mm-ddTHH:MM format. It is used for incremental loads." + "It's required when hash_column or updated_at_column is provided", + ), + Attribute( + attribute_name="days_to_live", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="The number of days to keep the data in the table. If provided, the time_to_live attribute " + "will be set in dynamodb", + ), + + ] + ) + + def __init__(self, name, pipeline_name, pipeline, job_config): + super().__init__(name, pipeline_name, pipeline, job_config) + + self.executable = self.executable or "reverse_etl.py" + self.executable_prefix = self.executable_prefix or "python" + + self._assume_role_arn = self.parse_attribute("assume_role_arn") + self._num_threads = self.parse_attribute("num_threads") or 4 + self._batch_size = self.parse_attribute("batch_size") or 10000 + self._absolute_job_name = self._absolute_job_name or "common_batch_jobs/reverse_etl" + self._primary_id_column = self.parse_attribute("primary_id_column") + self._secondary_id_column = self.parse_attribute("secondary_id_column") + self._custom_id_column = self.parse_attribute("custom_id_column") + self._model_name = self.parse_attribute("model_name") + self._project_name = self.parse_attribute("project_name") or "feature_store" + self._is_deleted_column = self.parse_attribute("is_deleted_column") + self._hash_column = self.parse_attribute("hash_column") + self._updated_at_column = self.parse_attribute("updated_at_column") + self._from_time = self.parse_attribute("from_time") + self._days_to_live = self.parse_attribute("days_to_live") + + if self._hash_column and self._updated_at_column: + raise ValueError("hash_column and updated_at_column are mutually exclusive") + + if self._hash_column or self._updated_at_column: + if not self._from_time: + raise ValueError("from_time is required when hash_column or updated_at_column is provided") + + @property + def assume_role_arn(self): + return self._assume_role_arn + + @property + def num_threads(self): + return self._num_threads + + @@property + def batch_size(self): + return self._batch_size + + @property + def primary_id_column(self): + return self._primary_id_column + + @property + def secondary_id_column(self): + return self._secondary_id_column + + @property + def custom_id_column(self): + return self._custom_id_column + + @property + def model_name(self): + return self._model_name + + @property + def project_name(self): + return self._project_name + + @property + def is_deleted_column(self): + return self._is_deleted_column + + @property + def hash_column(self): + return self._hash_column + + @property + def updated_at_column(self): + return self._updated_at_column + + @property + def from_time(self): + return self._from_time + + @property + def days_to_live(self): + return self._days_to_live From bfceb811de2bc837ff42c8c080205069ce4f60fb Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Thu, 2 Jan 2025 13:15:45 +0100 Subject: [PATCH 109/189] Registering the new operator with dagger --- dagger/dag_creator/airflow/operator_factory.py | 1 + dagger/pipeline/task_factory.py | 1 + 2 files changed, 2 insertions(+) diff --git a/dagger/dag_creator/airflow/operator_factory.py b/dagger/dag_creator/airflow/operator_factory.py index 706a737..f610f1e 100644 --- a/dagger/dag_creator/airflow/operator_factory.py +++ b/dagger/dag_creator/airflow/operator_factory.py @@ -10,6 +10,7 @@ redshift_load_creator, redshift_transform_creator, redshift_unload_creator, + reverse_etl_creator, spark_creator, sqoop_creator, ) diff --git a/dagger/pipeline/task_factory.py b/dagger/pipeline/task_factory.py index a9c5eef..d8a1e53 100644 --- a/dagger/pipeline/task_factory.py +++ b/dagger/pipeline/task_factory.py @@ -9,6 +9,7 @@ redshift_load_task, redshift_transform_task, redshift_unload_task, + reverse_etl_task, spark_task, sqoop_task, ) From 0b2ac50ac6b0d97628e715bba01e6ac8b0fbcead Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Thu, 2 Jan 2025 13:16:19 +0100 Subject: [PATCH 110/189] Small type fix to resolve broken cli help command --- dagger/pipeline/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/pipeline/task.py b/dagger/pipeline/task.py index ce07aec..d484e49 100644 --- a/dagger/pipeline/task.py +++ b/dagger/pipeline/task.py @@ -39,7 +39,7 @@ def init_attributes(cls, orig_cls): Attribute( attribute_name="task_group", required=False, - format_help=str, + format_help="str", comment="Task group name", ), Attribute( From e414b84cb9781a5e850c9e9d6cbb15416dbac375 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Thu, 2 Jan 2025 13:17:21 +0100 Subject: [PATCH 111/189] Adding the possibility that inherited operator can overwrite attribute of base operator --- dagger/utilities/config_validator.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/dagger/utilities/config_validator.py b/dagger/utilities/config_validator.py index 1d68f33..1f70c1b 100644 --- a/dagger/utilities/config_validator.py +++ b/dagger/utilities/config_validator.py @@ -98,10 +98,14 @@ def init_attributes_once(cls, orig_cls: str) -> None: cls.init_attributes(orig_cls) if parent_class.__name__ != "ConfigValidator": - cls.config_attributes[cls.__name__] = ( - cls.config_attributes[parent_class.__name__] - + cls.config_attributes[cls.__name__] - ) + parent_attributes = cls.config_attributes[parent_class.__name__] + current_attributes = cls.config_attributes[cls.__name__] + + merged_attributes = {attr.name: attr for attr in parent_attributes} + for attr in current_attributes: + merged_attributes[attr.name] = attr + + cls.config_attributes[cls.__name__] = list(merged_attributes.values()) attributes_lookup = {} for index, attribute in enumerate(cls.config_attributes[cls.__name__]): From 67a8142cbdd8a143870224f4cfa7616b68114d87 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Thu, 2 Jan 2025 14:38:33 +0100 Subject: [PATCH 112/189] Smaller fixes; syntax fix; Fixing command creation by extending the existing solution in base class --- .../operator_creators/reverse_etl_creator.py | 22 +++++++++++++++++-- .../airflow/operators/reverse_etl_batch.py | 3 --- dagger/pipeline/tasks/reverse_etl_task.py | 8 +++---- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py index f6f9095..be94f74 100644 --- a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py @@ -1,6 +1,7 @@ import base64 from dagger.dag_creator.airflow.operator_creators.batch_creator import BatchCreator +from dagger.dag_creator.airflow.operators.reverse_etl_batch import ReverseEtlBatchOperator import json @@ -26,8 +27,7 @@ def __init__(self, task, dag): self._days_to_live = task.days_to_live def _generate_command(self): - command = [self._task.executable_prefix, self._task.executable] - + command = BatchCreator._generate_command(self) command.append(f"--num_threads={self._num_threads}") command.append(f"--batch_size={self._batch_size}") @@ -53,3 +53,21 @@ def _generate_command(self): command.append(f"--days_to_live={self._days_to_live}") return command + + def _create_operator(self, **kwargs): + overrides = self._task.overrides + overrides.update({"command": self._generate_command()}) + + job_name = self._validate_job_name(self._task.job_name, self._task.absolute_job_name) + batch_op = ReverseEtlBatchOperator( + dag=self._dag, + task_id=self._task.name, + job_name=self._task.name, + job_definition=job_name, + region_name=self._task.region_name, + job_queue=self._task.job_queue, + container_overrides=overrides, + awslogs_enabled=True, + **kwargs, + ) + return batch_op diff --git a/dagger/dag_creator/airflow/operators/reverse_etl_batch.py b/dagger/dag_creator/airflow/operators/reverse_etl_batch.py index abb775b..78c1619 100644 --- a/dagger/dag_creator/airflow/operators/reverse_etl_batch.py +++ b/dagger/dag_creator/airflow/operators/reverse_etl_batch.py @@ -3,6 +3,3 @@ class ReverseEtlBatchOperator(AWSBatchOperator): custom_operator_name = 'ReverseETL' ui_color = "#f0ede4" - - def __init__(self, *args, **kwargs): - super().__init__(args, kwargs) diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py index 1e21d49..47cf555 100644 --- a/dagger/pipeline/tasks/reverse_etl_task.py +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -31,14 +31,12 @@ def init_attributes(cls, orig_cls): attribute_name="num_threads", parent_fields=["task_parameters"], required=False, - validator=int, comment="The number of threads to use for the job", ), Attribute( attribute_name="batch_size", parent_fields=["task_parameters"], required=False, - validator=int, comment="The number of rows to fetch in each batch", ), Attribute( @@ -125,8 +123,8 @@ def init_attributes(cls, orig_cls): def __init__(self, name, pipeline_name, pipeline, job_config): super().__init__(name, pipeline_name, pipeline, job_config) - self.executable = self.executable or "reverse_etl.py" - self.executable_prefix = self.executable_prefix or "python" + self._executable = self.executable or "reverse_etl.py" + self._executable_prefix = self.executable_prefix or "python" self._assume_role_arn = self.parse_attribute("assume_role_arn") self._num_threads = self.parse_attribute("num_threads") or 4 @@ -158,7 +156,7 @@ def assume_role_arn(self): def num_threads(self): return self._num_threads - @@property + @property def batch_size(self): return self._batch_size From 0557bab09a16a35acd62df24f20c398b367af318 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Fri, 3 Jan 2025 12:25:18 +0100 Subject: [PATCH 113/189] Adding dynamo and sns io types --- dagger/pipeline/ios/dynamo_io.py | 60 ++++++++++++++++++++++++++++++++ dagger/pipeline/ios/sns_io.py | 60 ++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 dagger/pipeline/ios/dynamo_io.py create mode 100644 dagger/pipeline/ios/sns_io.py diff --git a/dagger/pipeline/ios/dynamo_io.py b/dagger/pipeline/ios/dynamo_io.py new file mode 100644 index 0000000..c10459c --- /dev/null +++ b/dagger/pipeline/ios/dynamo_io.py @@ -0,0 +1,60 @@ +from dagger.pipeline.io import IO +from dagger.utilities.config_validator import Attribute + + +class DynamoIO(IO): + ref_name = "dynamo" + + @classmethod + def init_attributes(cls, orig_cls): + cls.add_config_attributes( + [ + Attribute( + attribute_name="account_id", + required=False, + comment="Only needed for cross account dynamo tables" + ), + Attribute( + attribute_name="region", + required=False, + comment="Only needed for cross region dynamo tables" + ), + Attribute( + attribute_name="table", + comment="The name of the dynamo table" + ), + ] + ) + + def __init__(self, io_config, config_location): + super().__init__(io_config, config_location) + + self._account_id = self.parse_attribute("account_id") + self._region = self.parse_attribute("region") + self._table = self.parse_attribute("table") + + def alias(self): + return f"dynamo://{self._account_id or ''}/{self._region or ''}/{self._table}" + + @property + def rendered_name(self): + if not self._account_id and not self._region: + return self._table + else: + return ":".join([self._account_id or '', self._region or '', self._table]) + + @property + def airflow_name(self): + return f"dynamo-{'-'.join([name_part for name_part in [self._account_id, self._region, self._table] if name_part])}" + + @property + def account_id(self): + return self._account_id + + @property + def region(self): + return self._region + + @property + def table(self): + return self._table diff --git a/dagger/pipeline/ios/sns_io.py b/dagger/pipeline/ios/sns_io.py new file mode 100644 index 0000000..14b4112 --- /dev/null +++ b/dagger/pipeline/ios/sns_io.py @@ -0,0 +1,60 @@ +from dagger.pipeline.io import IO +from dagger.utilities.config_validator import Attribute + + +class SnsIO(IO): + ref_name = "sns" + + @classmethod + def init_attributes(cls, orig_cls): + cls.add_config_attributes( + [ + Attribute( + attribute_name="account_id", + required=False, + comment="Only needed for cross account dynamo tables" + ), + Attribute( + attribute_name="region", + required=False, + comment="Only needed for cross region dynamo tables" + ), + Attribute( + attribute_name="sns_topic", + comment="The name of the sns topic" + ), + ] + ) + + def __init__(self, io_config, config_location): + super().__init__(io_config, config_location) + + self._account_id = self.parse_attribute("account_id") + self._region = self.parse_attribute("region") + self._sns_topic = self.parse_attribute("sns_topic") + + def alias(self): + return f"dynamo://{self._account_id or ''}/{self._region or ''}/{self._sns_topic}" + + @property + def rendered_name(self): + if not self._account_id and not self._region: + return self._sns_topic + else: + return ":".join([self._account_id or '', self._region or '', self._sns_topic]) + + @property + def airflow_name(self): + return f"dynamo-{'-'.join([name_part for name_part in [self._account_id, self._region, self._sns_topic] if name_part])}" + + @property + def account_id(self): + return self._account_id + + @property + def region(self): + return self._region + + @property + def sns_topic(self): + return self._sns_topic \ No newline at end of file From 4ddfc338e346e11e052c99fe6676f9fa8633b53c Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Fri, 3 Jan 2025 12:25:38 +0100 Subject: [PATCH 114/189] Adding dynamo and sns io types --- dagger/pipeline/io_factory.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dagger/pipeline/io_factory.py b/dagger/pipeline/io_factory.py index 5454f31..61d9fd2 100644 --- a/dagger/pipeline/io_factory.py +++ b/dagger/pipeline/io_factory.py @@ -7,7 +7,9 @@ gdrive_io, redshift_io, s3_io, - databricks_io + databricks_io, + dynamo_io, + sns_io, ) from dagger.utilities.classes import get_deep_obj_subclasses From 8038d3cd494f2b2f5ced1e741e7c920de3aeface Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Fri, 3 Jan 2025 12:27:17 +0100 Subject: [PATCH 115/189] Fixing input/output name for reverse etl so it matches the batch job expected format; Inferring output type (dynamo/sns) based on output type and passing as an argument to the batch job --- .../operator_creators/reverse_etl_creator.py | 2 ++ dagger/pipeline/io.py | 4 +++ dagger/pipeline/tasks/reverse_etl_task.py | 27 +++++++++++++++++++ 3 files changed, 33 insertions(+) diff --git a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py index be94f74..d81d706 100644 --- a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py @@ -25,6 +25,7 @@ def __init__(self, task, dag): self._updated_at_column = task.updated_at_column self._from_time = task.from_time self._days_to_live = task.days_to_live + self._output_type = task.output_type def _generate_command(self): command = BatchCreator._generate_command(self) @@ -34,6 +35,7 @@ def _generate_command(self): command.append(f"--primary_id_column={self._primary_id_column}") command.append(f"--model_name={self._model_name}") command.append(f"--project_name={self._project_name}") + command.append(f"--output_type={self._output_type}") if self._assume_role_arn: command.append(f"--assume_role_arn={self._assume_role_arn}") diff --git a/dagger/pipeline/io.py b/dagger/pipeline/io.py index 32ae303..cdacdd0 100644 --- a/dagger/pipeline/io.py +++ b/dagger/pipeline/io.py @@ -63,6 +63,10 @@ def alias(self): def name(self): return self._name + @name.setter + def name(self, value): + self._name = value + @property def has_dependency(self): return self._has_dependency diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py index 47cf555..29e70dd 100644 --- a/dagger/pipeline/tasks/reverse_etl_task.py +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -148,6 +148,29 @@ def __init__(self, name, pipeline_name, pipeline, job_config): if not self._from_time: raise ValueError("from_time is required when hash_column or updated_at_column is provided") + # Making sure the input table name is set as it is expected in the reverse etl job + input_index = self._get_io_index(self._inputs) + self._inputs[input_index].name = "input_table_name" + + # Making sure the output name is set as it is expected in the reverse etl job + output_index = self._get_io_index(self._outputs) + self._outputs[output_index].name = "output_name" + + # Extracting the output type from the output definition + self._output_type = self._outputs[output_index].ref_name + if not self._output_type: + raise ValueError("ReverseEtlTask must have an output") + + @staticmethod + def _get_io_index(ios): + if len([io for io in ios if io.ref_name != "dummy"]) > 1: + raise ValueError("ReverseEtlTask can only have one input or output") + + for i, io in enumerate(ios): + if io.ref_name != "dummy": + return i + + @property def assume_role_arn(self): return self._assume_role_arn @@ -199,3 +222,7 @@ def from_time(self): @property def days_to_live(self): return self._days_to_live + + @property + def output_type(self): + return self._output_type From 60c4736efc6a4d17567a09685c9e3481f8f01dbf Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Fri, 3 Jan 2025 12:35:25 +0100 Subject: [PATCH 116/189] Handling hard wired constants as local parameters of the task --- dagger/pipeline/tasks/reverse_etl_task.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py index 29e70dd..3c87a0d 100644 --- a/dagger/pipeline/tasks/reverse_etl_task.py +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -4,6 +4,13 @@ class ReverseEtlTask(BatchTask): ref_name = "reverse_etl" + DEFAULT_EXECUTABLE_PREFIX = "python" + DEFAULT_EXECUTABLE = "reverse_etl.py" + DEFAULT_NUM_THREADS = 4 + DEFAULT_BATCH_SIZE = 10000 + DEFAULT_JOB_NAME = "common_batch_jobs/reverse_etl" + DEFAULT_PROJECT_NAME = "feature_store" + @classmethod def init_attributes(cls, orig_cls): cls.add_config_attributes( @@ -123,18 +130,18 @@ def init_attributes(cls, orig_cls): def __init__(self, name, pipeline_name, pipeline, job_config): super().__init__(name, pipeline_name, pipeline, job_config) - self._executable = self.executable or "reverse_etl.py" - self._executable_prefix = self.executable_prefix or "python" + self._executable = self.executable or self.DEFAULT_EXECUTABLE + self._executable_prefix = self.executable_prefix or self.DEFAULT_EXECUTABLE_PREFIX self._assume_role_arn = self.parse_attribute("assume_role_arn") - self._num_threads = self.parse_attribute("num_threads") or 4 - self._batch_size = self.parse_attribute("batch_size") or 10000 - self._absolute_job_name = self._absolute_job_name or "common_batch_jobs/reverse_etl" + self._num_threads = self.parse_attribute("num_threads") or self.DEFAULT_NUM_THREADS + self._batch_size = self.parse_attribute("batch_size") or self.DEFAULT_BATCH_SIZE + self._absolute_job_name = self._absolute_job_name or self.DEFAULT_JOB_NAME self._primary_id_column = self.parse_attribute("primary_id_column") self._secondary_id_column = self.parse_attribute("secondary_id_column") self._custom_id_column = self.parse_attribute("custom_id_column") self._model_name = self.parse_attribute("model_name") - self._project_name = self.parse_attribute("project_name") or "feature_store" + self._project_name = self.parse_attribute("project_name") or self.DEFAULT_PROJECT_NAME self._is_deleted_column = self.parse_attribute("is_deleted_column") self._hash_column = self.parse_attribute("hash_column") self._updated_at_column = self.parse_attribute("updated_at_column") From f64910f749f74773083218aeb47d8b8b15def9a6 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Fri, 3 Jan 2025 16:38:17 +0100 Subject: [PATCH 117/189] Removing account_id from io; naming convention; fixing small issues --- .../operator_creators/reverse_etl_creator.py | 3 +++ dagger/pipeline/ios/dynamo_io.py | 27 +++++-------------- dagger/pipeline/ios/sns_io.py | 22 +++++---------- dagger/pipeline/tasks/reverse_etl_task.py | 25 ++++++++++++----- 4 files changed, 35 insertions(+), 42 deletions(-) diff --git a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py index d81d706..e133e40 100644 --- a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py @@ -26,6 +26,7 @@ def __init__(self, task, dag): self._from_time = task.from_time self._days_to_live = task.days_to_live self._output_type = task.output_type + self._region_name = task.region_name def _generate_command(self): command = BatchCreator._generate_command(self) @@ -53,6 +54,8 @@ def _generate_command(self): command.append(f"--from_time={self._from_time}") if self._days_to_live: command.append(f"--days_to_live={self._days_to_live}") + if self._region_name: + command.append(f"--region_name={self._region_name}") return command diff --git a/dagger/pipeline/ios/dynamo_io.py b/dagger/pipeline/ios/dynamo_io.py index c10459c..88d822e 100644 --- a/dagger/pipeline/ios/dynamo_io.py +++ b/dagger/pipeline/ios/dynamo_io.py @@ -10,12 +10,7 @@ def init_attributes(cls, orig_cls): cls.add_config_attributes( [ Attribute( - attribute_name="account_id", - required=False, - comment="Only needed for cross account dynamo tables" - ), - Attribute( - attribute_name="region", + attribute_name="region_name", required=False, comment="Only needed for cross region dynamo tables" ), @@ -29,31 +24,23 @@ def init_attributes(cls, orig_cls): def __init__(self, io_config, config_location): super().__init__(io_config, config_location) - self._account_id = self.parse_attribute("account_id") - self._region = self.parse_attribute("region") + self._region_name = self.parse_attribute("region_name") self._table = self.parse_attribute("table") def alias(self): - return f"dynamo://{self._account_id or ''}/{self._region or ''}/{self._table}" + return f"dynamo://{self._region_name or ''}/{self._table}" @property def rendered_name(self): - if not self._account_id and not self._region: - return self._table - else: - return ":".join([self._account_id or '', self._region or '', self._table]) + return self._table @property def airflow_name(self): - return f"dynamo-{'-'.join([name_part for name_part in [self._account_id, self._region, self._table] if name_part])}" - - @property - def account_id(self): - return self._account_id + return f"dynamo-{'-'.join([name_part for name_part in [self._region_name, self._table] if name_part])}" @property - def region(self): - return self._region + def region_name(self): + return self._region_name @property def table(self): diff --git a/dagger/pipeline/ios/sns_io.py b/dagger/pipeline/ios/sns_io.py index 14b4112..3be660d 100644 --- a/dagger/pipeline/ios/sns_io.py +++ b/dagger/pipeline/ios/sns_io.py @@ -15,7 +15,7 @@ def init_attributes(cls, orig_cls): comment="Only needed for cross account dynamo tables" ), Attribute( - attribute_name="region", + attribute_name="region_name", required=False, comment="Only needed for cross region dynamo tables" ), @@ -29,31 +29,23 @@ def init_attributes(cls, orig_cls): def __init__(self, io_config, config_location): super().__init__(io_config, config_location) - self._account_id = self.parse_attribute("account_id") - self._region = self.parse_attribute("region") + self._region_name = self.parse_attribute("region_name") self._sns_topic = self.parse_attribute("sns_topic") def alias(self): - return f"dynamo://{self._account_id or ''}/{self._region or ''}/{self._sns_topic}" + return f"dynamo://{self._region_name or ''}/{self._sns_topic}" @property def rendered_name(self): - if not self._account_id and not self._region: - return self._sns_topic - else: - return ":".join([self._account_id or '', self._region or '', self._sns_topic]) + return self._sns_topic @property def airflow_name(self): - return f"dynamo-{'-'.join([name_part for name_part in [self._account_id, self._region, self._sns_topic] if name_part])}" + return f"dynamo-{'-'.join([name_part for name_part in [self._region_name, self._sns_topic] if name_part])}" @property - def account_id(self): - return self._account_id - - @property - def region(self): - return self._region + def region_name(self): + return self._region_name @property def sns_topic(self): diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py index 3c87a0d..b62bba5 100644 --- a/dagger/pipeline/tasks/reverse_etl_task.py +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -149,33 +149,40 @@ def __init__(self, name, pipeline_name, pipeline, job_config): self._days_to_live = self.parse_attribute("days_to_live") if self._hash_column and self._updated_at_column: - raise ValueError("hash_column and updated_at_column are mutually exclusive") + raise ValueError(f"ReverseETLTask: {self._name} hash_column and updated_at_column are mutually exclusive") if self._hash_column or self._updated_at_column: if not self._from_time: - raise ValueError("from_time is required when hash_column or updated_at_column is provided") + raise ValueError(f"ReverseETLTask: {self._name} from_time is required when hash_column or updated_at_column is provided") # Making sure the input table name is set as it is expected in the reverse etl job input_index = self._get_io_index(self._inputs) + print('XXX', self._inputs, input_index) + if input_index is None: + raise ValueError(f"ReverseEtlTask: {self._name} must have an input") self._inputs[input_index].name = "input_table_name" # Making sure the output name is set as it is expected in the reverse etl job output_index = self._get_io_index(self._outputs) + if output_index is None: + raise ValueError(f"ReverseEtlTask: {self._name} must have an output") self._outputs[output_index].name = "output_name" # Extracting the output type from the output definition self._output_type = self._outputs[output_index].ref_name - if not self._output_type: - raise ValueError("ReverseEtlTask must have an output") - @staticmethod - def _get_io_index(ios): + # Extracting the outputs region name from the output definition + self._region_name = self._outputs[output_index].region_name + + + def _get_io_index(self, ios): if len([io for io in ios if io.ref_name != "dummy"]) > 1: - raise ValueError("ReverseEtlTask can only have one input or output") + raise ValueError(f"ReverseEtlTask: {self._name} can only have one input or output") for i, io in enumerate(ios): if io.ref_name != "dummy": return i + return None @property @@ -233,3 +240,7 @@ def days_to_live(self): @property def output_type(self): return self._output_type + + @property + def region_name(self): + return self._region_name From 2ea6f7e6ca0d8ce36861fc40545166c5f9ad58cb Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Fri, 3 Jan 2025 17:41:34 +0100 Subject: [PATCH 118/189] Fixing batch job name --- dagger/pipeline/tasks/reverse_etl_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py index b62bba5..98094e6 100644 --- a/dagger/pipeline/tasks/reverse_etl_task.py +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -8,7 +8,7 @@ class ReverseEtlTask(BatchTask): DEFAULT_EXECUTABLE = "reverse_etl.py" DEFAULT_NUM_THREADS = 4 DEFAULT_BATCH_SIZE = 10000 - DEFAULT_JOB_NAME = "common_batch_jobs/reverse_etl" + DEFAULT_JOB_NAME = "common_batch_jobs-reverse_etl" DEFAULT_PROJECT_NAME = "feature_store" @classmethod From 85c9c2aa2b896ccdf8617ea8d2cd59dd58a63480 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Tue, 7 Jan 2025 13:46:30 +0100 Subject: [PATCH 119/189] Removing choco specific parameters and moving them to conf file; Making some arguments default value handled by the job itslef; Making some arguments required instead hardwiring default value --- dagger/conf.py | 6 ++++++ dagger/dagger_config.yaml | 5 +++++ dagger/pipeline/tasks/reverse_etl_task.py | 25 ++++++++--------------- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/dagger/conf.py b/dagger/conf.py index cbb075d..df2ab8e 100644 --- a/dagger/conf.py +++ b/dagger/conf.py @@ -104,3 +104,9 @@ plugin_config = config.get('plugin', None) or {} PLUGIN_DIRS = [os.path.join(AIRFLOW_HOME, path) for path in plugin_config.get('paths', [])] logging.info(f"All Python classes will be loaded as plugins from the following directories: {PLUGIN_DIRS}") + +# ReverseETL parameters +reverse_etl_config = config.get('reverse_etl', None) or {} +REVERSE_ETL_DEFAULT_JOB_NAME = reverse_etl_config.get('default_job_name', None) +REVERSE_ETL_DEFAULT_EXECUTABLE_PREFIX = reverse_etl_config.get('default_executable_prefix', None) +REVERSE_ETL_DEFAULT_EXECUTABLE = reverse_etl_config.get('default_executable', None) diff --git a/dagger/dagger_config.yaml b/dagger/dagger_config.yaml index 69c3d54..38abccd 100644 --- a/dagger/dagger_config.yaml +++ b/dagger/dagger_config.yaml @@ -62,3 +62,8 @@ alert: plugin: # paths: # - plugins + +reverse_etl: +# default_job_name: +# default_executable_prefix: +# default_executable: diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py index 98094e6..6c9a5d2 100644 --- a/dagger/pipeline/tasks/reverse_etl_task.py +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -1,16 +1,10 @@ from dagger.pipeline.tasks.batch_task import BatchTask from dagger.utilities.config_validator import Attribute +from dagger import conf class ReverseEtlTask(BatchTask): ref_name = "reverse_etl" - DEFAULT_EXECUTABLE_PREFIX = "python" - DEFAULT_EXECUTABLE = "reverse_etl.py" - DEFAULT_NUM_THREADS = 4 - DEFAULT_BATCH_SIZE = 10000 - DEFAULT_JOB_NAME = "common_batch_jobs-reverse_etl" - DEFAULT_PROJECT_NAME = "feature_store" - @classmethod def init_attributes(cls, orig_cls): cls.add_config_attributes( @@ -78,9 +72,8 @@ def init_attributes(cls, orig_cls): attribute_name="project_name", parent_fields=["task_parameters"], validator=str, - required=False, - comment="The name of the project. This is going to be a column on the target table. By default it is" - " set to feature_store", + required=True, + comment="The name of the project. This is going to be a column on the target table.", ), Attribute( attribute_name="is_deleted_column", @@ -130,18 +123,18 @@ def init_attributes(cls, orig_cls): def __init__(self, name, pipeline_name, pipeline, job_config): super().__init__(name, pipeline_name, pipeline, job_config) - self._executable = self.executable or self.DEFAULT_EXECUTABLE - self._executable_prefix = self.executable_prefix or self.DEFAULT_EXECUTABLE_PREFIX + self._executable = self.executable or conf.REVERSE_ETL_DEFAULT_EXECUTABLE + self._executable_prefix = self.executable_prefix or conf.REVERSE_ETL_DEFAULT_EXECUTABLE_PREFIX self._assume_role_arn = self.parse_attribute("assume_role_arn") - self._num_threads = self.parse_attribute("num_threads") or self.DEFAULT_NUM_THREADS - self._batch_size = self.parse_attribute("batch_size") or self.DEFAULT_BATCH_SIZE - self._absolute_job_name = self._absolute_job_name or self.DEFAULT_JOB_NAME + self._num_threads = self.parse_attribute("num_threads") + self._batch_size = self.parse_attribute("batch_size") + self._absolute_job_name = self._absolute_job_name or conf.REVERSE_ETL_DEFAULT_JOB_NAME self._primary_id_column = self.parse_attribute("primary_id_column") self._secondary_id_column = self.parse_attribute("secondary_id_column") self._custom_id_column = self.parse_attribute("custom_id_column") self._model_name = self.parse_attribute("model_name") - self._project_name = self.parse_attribute("project_name") or self.DEFAULT_PROJECT_NAME + self._project_name = self.parse_attribute("project_name") self._is_deleted_column = self.parse_attribute("is_deleted_column") self._hash_column = self.parse_attribute("hash_column") self._updated_at_column = self.parse_attribute("updated_at_column") From 3c5a61072cc17388491286aa278f75841936eb52 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Tue, 7 Jan 2025 15:08:08 +0100 Subject: [PATCH 120/189] Adding unit tests and a small fix --- dagger/pipeline/ios/sns_io.py | 4 ++-- tests/fixtures/pipeline/ios/dynamo_io.yaml | 11 +++++++++++ tests/fixtures/pipeline/ios/sns_io.yaml | 11 +++++++++++ tests/pipeline/ios/test_dynamo_io.py | 19 +++++++++++++++++++ tests/pipeline/ios/test_sns_io.py | 19 +++++++++++++++++++ 5 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 tests/fixtures/pipeline/ios/dynamo_io.yaml create mode 100644 tests/fixtures/pipeline/ios/sns_io.yaml create mode 100644 tests/pipeline/ios/test_dynamo_io.py create mode 100644 tests/pipeline/ios/test_sns_io.py diff --git a/dagger/pipeline/ios/sns_io.py b/dagger/pipeline/ios/sns_io.py index 3be660d..67d17d4 100644 --- a/dagger/pipeline/ios/sns_io.py +++ b/dagger/pipeline/ios/sns_io.py @@ -33,7 +33,7 @@ def __init__(self, io_config, config_location): self._sns_topic = self.parse_attribute("sns_topic") def alias(self): - return f"dynamo://{self._region_name or ''}/{self._sns_topic}" + return f"sns://{self._region_name or ''}/{self._sns_topic}" @property def rendered_name(self): @@ -41,7 +41,7 @@ def rendered_name(self): @property def airflow_name(self): - return f"dynamo-{'-'.join([name_part for name_part in [self._region_name, self._sns_topic] if name_part])}" + return f"sns-{'-'.join([name_part for name_part in [self._region_name, self._sns_topic] if name_part])}" @property def region_name(self): diff --git a/tests/fixtures/pipeline/ios/dynamo_io.yaml b/tests/fixtures/pipeline/ios/dynamo_io.yaml new file mode 100644 index 0000000..d083171 --- /dev/null +++ b/tests/fixtures/pipeline/ios/dynamo_io.yaml @@ -0,0 +1,11 @@ +type: dynamo +name: dynamo_table +table: schema.table_name # The name of the dynamo table +region_name: eu_west_1 + + +# Other attributes: + +# has_dependency: # Weather this i/o should be added to the dependency graph or not. Default is True +# follow_external_dependency: # format: dictionary or boolean | External Task Sensor parameters in key value format: https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/sensors/base/index.html +# region_name: # Only needed for cross region dynamo tables \ No newline at end of file diff --git a/tests/fixtures/pipeline/ios/sns_io.yaml b/tests/fixtures/pipeline/ios/sns_io.yaml new file mode 100644 index 0000000..542fd8e --- /dev/null +++ b/tests/fixtures/pipeline/ios/sns_io.yaml @@ -0,0 +1,11 @@ +type: sns +name: topic_name +sns_topic: topic_name # The name of the dynamo table +region_name: eu_west_1 + + +# Other attributes: + +# has_dependency: # Weather this i/o should be added to the dependency graph or not. Default is True +# follow_external_dependency: # format: dictionary or boolean | External Task Sensor parameters in key value format: https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/sensors/base/index.html +# region_name: # Only needed for cross region dynamo tables \ No newline at end of file diff --git a/tests/pipeline/ios/test_dynamo_io.py b/tests/pipeline/ios/test_dynamo_io.py new file mode 100644 index 0000000..0633c8b --- /dev/null +++ b/tests/pipeline/ios/test_dynamo_io.py @@ -0,0 +1,19 @@ +import unittest +from dagger.pipeline.ios.dynamo_io import DynamoIO +import yaml + + +class DynamoIOTest(unittest.TestCase): + def setUp(self) -> None: + with open("tests/fixtures/pipeline/ios/dynamo_io.yaml", "r") as stream: + config = yaml.safe_load(stream) + + self.dynamo_io = DynamoIO(config, "/") + + def test_properties(self): + self.assertEqual(self.dynamo_io.alias(), "dynamo://eu_west_1/schema.table_name") + self.assertEqual(self.dynamo_io.rendered_name, "schema.table_name") + self.assertEqual(self.dynamo_io.airflow_name,"dynamo-eu_west_1-schema.table_name") + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pipeline/ios/test_sns_io.py b/tests/pipeline/ios/test_sns_io.py new file mode 100644 index 0000000..2a5de25 --- /dev/null +++ b/tests/pipeline/ios/test_sns_io.py @@ -0,0 +1,19 @@ +import unittest +from dagger.pipeline.ios.sns_io import SnsIO +import yaml + + +class SnsIOTest(unittest.TestCase): + def setUp(self) -> None: + with open("tests/fixtures/pipeline/ios/sns_io.yaml", "r") as stream: + config = yaml.safe_load(stream) + + self.sns_io = SnsIO(config, "/") + + def test_properties(self): + self.assertEqual(self.sns_io.alias(), f"sns://eu_west_1/topic_name") + self.assertEqual(self.sns_io.rendered_name, "topic_name") + self.assertEqual(self.sns_io.airflow_name, "sns-eu_west_1-topic_name") + +if __name__ == "__main__": + unittest.main() From f45c84a34fb3608e50f8cecf47f3f87173c958e3 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Tue, 7 Jan 2025 15:10:56 +0100 Subject: [PATCH 121/189] Adding comments --- dagger/utilities/config_validator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dagger/utilities/config_validator.py b/dagger/utilities/config_validator.py index 1f70c1b..e90af68 100644 --- a/dagger/utilities/config_validator.py +++ b/dagger/utilities/config_validator.py @@ -101,6 +101,7 @@ def init_attributes_once(cls, orig_cls: str) -> None: parent_attributes = cls.config_attributes[parent_class.__name__] current_attributes = cls.config_attributes[cls.__name__] + # Overwriting attributes in parent operator if they are also existing in the child operator merged_attributes = {attr.name: attr for attr in parent_attributes} for attr in current_attributes: merged_attributes[attr.name] = attr From e62890b5d729f881afe25ef0087b73910d139525 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Thu, 16 Jan 2025 09:26:36 +0100 Subject: [PATCH 122/189] feat: add application name to the spark job & add kill spark job when timeout --- .../operators/spark_submit_operator.py | 49 ++++++++++++++++++- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/spark_submit_operator.py b/dagger/dag_creator/airflow/operators/spark_submit_operator.py index d9df768..af8c451 100644 --- a/dagger/dag_creator/airflow/operators/spark_submit_operator.py +++ b/dagger/dag_creator/airflow/operators/spark_submit_operator.py @@ -85,10 +85,50 @@ def get_cluster_id_by_name(self, emr_cluster_name, cluster_states): else: return None + + def get_application_id_by_name(self, emr_master_instance_id, application_name): + command = f"yarn application -list -appStates RUNNING | grep {application_name}" + + response = self.ssm_client.send_command( + InstanceIds=[emr_master_instance_id], + DocumentName="AWS-RunShellScript", + Parameters={"commands": [command]} + ) + + command_id = response['Command']['CommandId'] + time.sleep(10) # Wait for the command to execute + + output = self.ssm_client.get_command_invocation( + CommandId=command_id, + InstanceId=emr_master_instance_id + ) + + stdout = output['StandardOutputContent'] + for line in stdout.split('\n'): + if application_name in line: + application_id = line.split()[0] + return application_id + return None + + def kill_spark_job(self, emr_master_instance_id, application_id): + """ + Kill the Spark job using YARN + """ + kill_command = f"yarn application -kill {application_id}" + self.ssm_client.send_command( + InstanceIds=[emr_master_instance_id], + DocumentName="AWS-RunShellScript", + Parameters={"commands": [kill_command]} + ) + raise AirflowException( + f"Spark job exceeded the execution timeout of {self._execution_timeout} seconds and was terminated.") + + def execute(self, context): """ See `execute` method from airflow.operators.bash_operator """ + start_time = time.time() cluster_id = self.get_cluster_id_by_name(self.cluster_name, ["WAITING", "RUNNING"]) emr_master_instance_id = self.emr_client.list_instances(ClusterId=cluster_id, InstanceGroupTypes=["MASTER"], InstanceStates=["RUNNING"])["Instances"][0][ @@ -101,20 +141,25 @@ def execute(self, context): response = self.ssm_client.send_command( InstanceIds=[emr_master_instance_id], DocumentName="AWS-RunShellScript", - Parameters= command_parameters + Parameters=command_parameters ) command_id = response['Command']['CommandId'] status = 'Pending' status_details = None while status in ['Pending', 'InProgress', 'Delayed']: time.sleep(30) + elapsed_time = time.time() - start_time + if self._execution_timeout and elapsed_time > self._execution_timeout: + application_id = self.get_application_id_by_name(emr_master_instance_id, + self.spark_conf_args["application_name"]) + if application_id: + self.kill_spark_job(emr_master_instance_id, application_id) response = self.ssm_client.get_command_invocation(CommandId=command_id, InstanceId=emr_master_instance_id) status = response['Status'] status_details = response['StatusDetails'] self.log.info( self.ssm_client.get_command_invocation(CommandId=command_id, InstanceId=emr_master_instance_id)[ 'StandardErrorContent']) - if status != 'Success': raise AirflowException(f"Spark command failed, check Spark job status in YARN resource manager. " f"Response status details: {status_details}") From 31085acd5aea5a2f18f3094fdc3b72180a6a0907 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Thu, 16 Jan 2025 12:09:24 +0100 Subject: [PATCH 123/189] fix: type of _execution_timeout --- .../dag_creator/airflow/operators/spark_submit_operator.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/spark_submit_operator.py b/dagger/dag_creator/airflow/operators/spark_submit_operator.py index af8c451..f3ef0a4 100644 --- a/dagger/dag_creator/airflow/operators/spark_submit_operator.py +++ b/dagger/dag_creator/airflow/operators/spark_submit_operator.py @@ -71,7 +71,6 @@ def get_execution_timeout(self): return None def get_cluster_id_by_name(self, emr_cluster_name, cluster_states): - response = self.emr_client.list_clusters(ClusterStates=cluster_states) matching_clusters = list( filter(lambda cluster: cluster['Name'] == emr_cluster_name, response['Clusters'])) @@ -87,6 +86,9 @@ def get_cluster_id_by_name(self, emr_cluster_name, cluster_states): def get_application_id_by_name(self, emr_master_instance_id, application_name): + """ + Get the application ID of the Spark job + """ command = f"yarn application -list -appStates RUNNING | grep {application_name}" response = self.ssm_client.send_command( @@ -149,7 +151,7 @@ def execute(self, context): while status in ['Pending', 'InProgress', 'Delayed']: time.sleep(30) elapsed_time = time.time() - start_time - if self._execution_timeout and elapsed_time > self._execution_timeout: + if self._execution_timeout and elapsed_time > self._execution_timeout.total_seconds(): application_id = self.get_application_id_by_name(emr_master_instance_id, self.spark_conf_args["application_name"]) if application_id: From 02283d824a110117ecf246c8f6dab72cbdf1301e Mon Sep 17 00:00:00 2001 From: claudiazi Date: Thu, 16 Jan 2025 12:33:59 +0100 Subject: [PATCH 124/189] fix: remove the wrong and uncessary function --- dagger/dag_creator/airflow/operators/spark_submit_operator.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/spark_submit_operator.py b/dagger/dag_creator/airflow/operators/spark_submit_operator.py index f3ef0a4..61ab3f4 100644 --- a/dagger/dag_creator/airflow/operators/spark_submit_operator.py +++ b/dagger/dag_creator/airflow/operators/spark_submit_operator.py @@ -165,7 +165,3 @@ def execute(self, context): if status != 'Success': raise AirflowException(f"Spark command failed, check Spark job status in YARN resource manager. " f"Response status details: {status_details}") - - def on_kill(self): - self.log.info("Sending SIGTERM signal to bash process group") - os.killpg(os.getpgid(self.sp.pid), signal.SIGTERM) From c9a1c6f223c54193add2a13777e217adac4f30c4 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 17 Jan 2025 02:38:43 +0100 Subject: [PATCH 125/189] fix: timeout logic --- .../operators/spark_submit_operator.py | 87 +++++++++++-------- 1 file changed, 52 insertions(+), 35 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/spark_submit_operator.py b/dagger/dag_creator/airflow/operators/spark_submit_operator.py index 61ab3f4..d3156ed 100644 --- a/dagger/dag_creator/airflow/operators/spark_submit_operator.py +++ b/dagger/dag_creator/airflow/operators/spark_submit_operator.py @@ -125,43 +125,60 @@ def kill_spark_job(self, emr_master_instance_id, application_id): raise AirflowException( f"Spark job exceeded the execution timeout of {self._execution_timeout} seconds and was terminated.") - def execute(self, context): """ See `execute` method from airflow.operators.bash_operator """ start_time = time.time() - cluster_id = self.get_cluster_id_by_name(self.cluster_name, ["WAITING", "RUNNING"]) - emr_master_instance_id = self.emr_client.list_instances(ClusterId=cluster_id, InstanceGroupTypes=["MASTER"], - InstanceStates=["RUNNING"])["Instances"][0][ - "Ec2InstanceId"] - - command_parameters = {"commands": [self.spark_submit_cmd]} - if self._execution_timeout: - command_parameters["executionTimeout"] = [self.get_execution_timeout()] - - response = self.ssm_client.send_command( - InstanceIds=[emr_master_instance_id], - DocumentName="AWS-RunShellScript", - Parameters=command_parameters - ) - command_id = response['Command']['CommandId'] - status = 'Pending' - status_details = None - while status in ['Pending', 'InProgress', 'Delayed']: - time.sleep(30) - elapsed_time = time.time() - start_time - if self._execution_timeout and elapsed_time > self._execution_timeout.total_seconds(): - application_id = self.get_application_id_by_name(emr_master_instance_id, - self.spark_conf_args["application_name"]) - if application_id: - self.kill_spark_job(emr_master_instance_id, application_id) - response = self.ssm_client.get_command_invocation(CommandId=command_id, InstanceId=emr_master_instance_id) - status = response['Status'] - status_details = response['StatusDetails'] - self.log.info( - self.ssm_client.get_command_invocation(CommandId=command_id, InstanceId=emr_master_instance_id)[ - 'StandardErrorContent']) - if status != 'Success': - raise AirflowException(f"Spark command failed, check Spark job status in YARN resource manager. " - f"Response status details: {status_details}") + try: + # Get cluster and master node information + cluster_id = self.get_cluster_id_by_name(self.cluster_name, ["WAITING", "RUNNING"]) + emr_master_instance_id = self.emr_client.list_instances( + ClusterId=cluster_id, InstanceGroupTypes=["MASTER"], InstanceStates=["RUNNING"] + )["Instances"][0]["Ec2InstanceId"] + + # Build the command parameters + command_parameters = {"commands": [self.spark_submit_cmd]} + if self._execution_timeout: + command_parameters["executionTimeout"] = [self.get_execution_timeout()] + + # Send the command via SSM + response = self.ssm_client.send_command( + InstanceIds=[emr_master_instance_id], + DocumentName="AWS-RunShellScript", + Parameters=command_parameters + ) + command_id = response['Command']['CommandId'] + status = 'Pending' + status_details = None + + # Monitor the command's execution + while status in ['Pending', 'InProgress', 'Delayed']: + time.sleep(30) + # Check the status of the SSM command + response = self.ssm_client.get_command_invocation( + CommandId=command_id, InstanceId=emr_master_instance_id + ) + status = response['Status'] + status_details = response['StatusDetails'] + + self.log.info( + self.ssm_client.get_command_invocation( + CommandId=command_id, InstanceId=emr_master_instance_id + )['StandardErrorContent'] + ) + + # Raise an exception if the command did not succeed + if status != 'Success': + raise AirflowException(f"Spark command failed, check Spark job status in YARN resource manager. " + f"Response status details: {status_details}") + + except AirflowTaskTimeout: + # Handle task timeout + self.log.error("Task timed out. Attempting to terminate the Spark job.") + application_id = self.get_application_id_by_name( + emr_master_instance_id, self.spark_conf_args["application_name"] + ) + if application_id: + self.kill_spark_job(emr_master_instance_id, application_id) + raise AirflowException("Task timed out and the Spark job was terminated.") From bfb935c61485555618b3e30d5173c502f39a2655 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 17 Jan 2025 03:07:15 +0100 Subject: [PATCH 126/189] fix: add missing import --- dagger/dag_creator/airflow/operators/spark_submit_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/dag_creator/airflow/operators/spark_submit_operator.py b/dagger/dag_creator/airflow/operators/spark_submit_operator.py index d3156ed..a87e225 100644 --- a/dagger/dag_creator/airflow/operators/spark_submit_operator.py +++ b/dagger/dag_creator/airflow/operators/spark_submit_operator.py @@ -4,7 +4,7 @@ import time import boto3 -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowTaskTimeout from airflow.utils.decorators import apply_defaults from dagger.dag_creator.airflow.operators.dagger_base_operator import DaggerBaseOperator From 7fba8d564bed48543e54c4f5ef4cdc5c61e63940 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 17 Jan 2025 04:45:44 +0100 Subject: [PATCH 127/189] fix: spark_app_name --- dagger/dag_creator/airflow/operator_creators/spark_creator.py | 1 + dagger/dag_creator/airflow/operators/spark_submit_operator.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/dagger/dag_creator/airflow/operator_creators/spark_creator.py b/dagger/dag_creator/airflow/operator_creators/spark_creator.py index c48ebda..2d1aae8 100644 --- a/dagger/dag_creator/airflow/operator_creators/spark_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/spark_creator.py @@ -91,6 +91,7 @@ def _create_operator(self, **kwargs): job_args=_parse_args(self._template_parameters), spark_args=_parse_spark_args(self._task.spark_args), spark_conf_args=_parse_spark_args(self._task.spark_conf_args, '=', 'conf '), + spark_app_name=self._task.spark_conf_args.get("spark.app.name", ""), extra_py_files=self._task.extra_py_files, **kwargs, ) diff --git a/dagger/dag_creator/airflow/operators/spark_submit_operator.py b/dagger/dag_creator/airflow/operators/spark_submit_operator.py index a87e225..6501fcb 100644 --- a/dagger/dag_creator/airflow/operators/spark_submit_operator.py +++ b/dagger/dag_creator/airflow/operators/spark_submit_operator.py @@ -25,6 +25,7 @@ def __init__( job_args=None, spark_args=None, spark_conf_args=None, + spark_app_name=None, extra_py_files=None, *args, **kwargs, @@ -34,6 +35,7 @@ def __init__( self.job_args = job_args self.spark_args = spark_args self.spark_conf_args = spark_conf_args + self.spark_app_name = spark_app_name self.extra_py_files = extra_py_files self.cluster_name = cluster_name self._execution_timeout = kwargs.get('execution_timeout') @@ -177,7 +179,7 @@ def execute(self, context): # Handle task timeout self.log.error("Task timed out. Attempting to terminate the Spark job.") application_id = self.get_application_id_by_name( - emr_master_instance_id, self.spark_conf_args["application_name"] + emr_master_instance_id, self.spark_app_name ) if application_id: self.kill_spark_job(emr_master_instance_id, application_id) From 37a31f86d205d907031e37968c80d708454462fc Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 17 Jan 2025 10:25:18 +0100 Subject: [PATCH 128/189] fix: default spark_app_name --- .../operator_creators/spark_creator.py | 2 +- .../operators/spark_submit_operator.py | 35 ++++++++++--------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/dagger/dag_creator/airflow/operator_creators/spark_creator.py b/dagger/dag_creator/airflow/operator_creators/spark_creator.py index 2d1aae8..8212a08 100644 --- a/dagger/dag_creator/airflow/operator_creators/spark_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/spark_creator.py @@ -91,7 +91,7 @@ def _create_operator(self, **kwargs): job_args=_parse_args(self._template_parameters), spark_args=_parse_spark_args(self._task.spark_args), spark_conf_args=_parse_spark_args(self._task.spark_conf_args, '=', 'conf '), - spark_app_name=self._task.spark_conf_args.get("spark.app.name", ""), + spark_app_name=self._task.spark_conf_args.get("spark.app.name", None) if self._task.spark_conf_args else None, extra_py_files=self._task.extra_py_files, **kwargs, ) diff --git a/dagger/dag_creator/airflow/operators/spark_submit_operator.py b/dagger/dag_creator/airflow/operators/spark_submit_operator.py index 6501fcb..1bb4b19 100644 --- a/dagger/dag_creator/airflow/operators/spark_submit_operator.py +++ b/dagger/dag_creator/airflow/operators/spark_submit_operator.py @@ -91,27 +91,28 @@ def get_application_id_by_name(self, emr_master_instance_id, application_name): """ Get the application ID of the Spark job """ - command = f"yarn application -list -appStates RUNNING | grep {application_name}" + if application_name: + command = f"yarn application -list -appStates RUNNING | grep {application_name}" - response = self.ssm_client.send_command( - InstanceIds=[emr_master_instance_id], - DocumentName="AWS-RunShellScript", - Parameters={"commands": [command]} - ) + response = self.ssm_client.send_command( + InstanceIds=[emr_master_instance_id], + DocumentName="AWS-RunShellScript", + Parameters={"commands": [command]} + ) - command_id = response['Command']['CommandId'] - time.sleep(10) # Wait for the command to execute + command_id = response['Command']['CommandId'] + time.sleep(10) # Wait for the command to execute - output = self.ssm_client.get_command_invocation( - CommandId=command_id, - InstanceId=emr_master_instance_id - ) + output = self.ssm_client.get_command_invocation( + CommandId=command_id, + InstanceId=emr_master_instance_id + ) - stdout = output['StandardOutputContent'] - for line in stdout.split('\n'): - if application_name in line: - application_id = line.split()[0] - return application_id + stdout = output['StandardOutputContent'] + for line in stdout.split('\n'): + if application_name in line: + application_id = line.split()[0] + return application_id return None def kill_spark_job(self, emr_master_instance_id, application_id): From a3422db71b3b30df61ebf7d5d1fda36d61651ddd Mon Sep 17 00:00:00 2001 From: claudiazi Date: Mon, 20 Jan 2025 08:13:30 +0100 Subject: [PATCH 129/189] feat: improve the logic to kill the spark job --- .../operators/spark_submit_operator.py | 61 ++++++++++--------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/spark_submit_operator.py b/dagger/dag_creator/airflow/operators/spark_submit_operator.py index 1bb4b19..a5df9cc 100644 --- a/dagger/dag_creator/airflow/operators/spark_submit_operator.py +++ b/dagger/dag_creator/airflow/operators/spark_submit_operator.py @@ -1,10 +1,9 @@ import logging import os -import signal import time import boto3 -from airflow.exceptions import AirflowException, AirflowTaskTimeout +from airflow.exceptions import AirflowException from airflow.utils.decorators import apply_defaults from dagger.dag_creator.airflow.operators.dagger_base_operator import DaggerBaseOperator @@ -39,6 +38,8 @@ def __init__( self.extra_py_files = extra_py_files self.cluster_name = cluster_name self._execution_timeout = kwargs.get('execution_timeout') + self._application_id = None + self._emr_master_instance_id = None @property def emr_client(self): @@ -115,18 +116,26 @@ def get_application_id_by_name(self, emr_master_instance_id, application_name): return application_id return None - def kill_spark_job(self, emr_master_instance_id, application_id): - """ - Kill the Spark job using YARN - """ - kill_command = f"yarn application -kill {application_id}" - self.ssm_client.send_command( - InstanceIds=[emr_master_instance_id], - DocumentName="AWS-RunShellScript", - Parameters={"commands": [kill_command]} - ) - raise AirflowException( - f"Spark job exceeded the execution timeout of {self._execution_timeout} seconds and was terminated.") + + def kill_spark_job(self): + if self._application_id and self._emr_master_instance_id: + kill_command = f"yarn application -kill {self._application_id}" + self.ssm_client.send_command( + InstanceIds=[self._emr_master_instance_id], + DocumentName="AWS-RunShellScript", + Parameters={"commands": [kill_command]}, + ) + logging.info( + f"Spark job {self._application_id} terminated successfully." + ) + else: + logging.warning("No application ID or master instance ID found to terminate.") + + + def on_kill(self): + logging.info("Task killed. Attempting to terminate the Spark job.") + self.kill_spark_job() + def execute(self, context): """ @@ -136,7 +145,7 @@ def execute(self, context): try: # Get cluster and master node information cluster_id = self.get_cluster_id_by_name(self.cluster_name, ["WAITING", "RUNNING"]) - emr_master_instance_id = self.emr_client.list_instances( + self._emr_master_instance_id = self.emr_client.list_instances( ClusterId=cluster_id, InstanceGroupTypes=["MASTER"], InstanceStates=["RUNNING"] )["Instances"][0]["Ec2InstanceId"] @@ -147,7 +156,7 @@ def execute(self, context): # Send the command via SSM response = self.ssm_client.send_command( - InstanceIds=[emr_master_instance_id], + InstanceIds=[self._emr_master_instance_id], DocumentName="AWS-RunShellScript", Parameters=command_parameters ) @@ -160,28 +169,24 @@ def execute(self, context): time.sleep(30) # Check the status of the SSM command response = self.ssm_client.get_command_invocation( - CommandId=command_id, InstanceId=emr_master_instance_id + CommandId=command_id, InstanceId=self._emr_master_instance_id ) status = response['Status'] status_details = response['StatusDetails'] self.log.info( self.ssm_client.get_command_invocation( - CommandId=command_id, InstanceId=emr_master_instance_id + CommandId=command_id, InstanceId=self._emr_master_instance_id )['StandardErrorContent'] ) - # Raise an exception if the command did not succeed + # Kill the command and raise an exception if the command did not succeed if status != 'Success': + self.kill_spark_job() raise AirflowException(f"Spark command failed, check Spark job status in YARN resource manager. " f"Response status details: {status_details}") - except AirflowTaskTimeout: - # Handle task timeout - self.log.error("Task timed out. Attempting to terminate the Spark job.") - application_id = self.get_application_id_by_name( - emr_master_instance_id, self.spark_app_name - ) - if application_id: - self.kill_spark_job(emr_master_instance_id, application_id) - raise AirflowException("Task timed out and the Spark job was terminated.") + except Exception as e: + logging.error(f"Error encountered: {str(e)}") + self.kill_spark_job() + raise AirflowException(f"Task failed with error: {str(e)}") From 6046a01243d82010e50cbf9d657131d7dec12082 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Mon, 20 Jan 2025 08:56:54 +0100 Subject: [PATCH 130/189] chore: black + add log + missing function --- .../operators/spark_submit_operator.py | 102 ++++++++++-------- 1 file changed, 59 insertions(+), 43 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/spark_submit_operator.py b/dagger/dag_creator/airflow/operators/spark_submit_operator.py index a5df9cc..d94b9b5 100644 --- a/dagger/dag_creator/airflow/operators/spark_submit_operator.py +++ b/dagger/dag_creator/airflow/operators/spark_submit_operator.py @@ -18,16 +18,16 @@ class SparkSubmitOperator(DaggerBaseOperator): @apply_defaults def __init__( - self, - job_file, - cluster_name, - job_args=None, - spark_args=None, - spark_conf_args=None, - spark_app_name=None, - extra_py_files=None, - *args, - **kwargs, + self, + job_file, + cluster_name, + job_args=None, + spark_args=None, + spark_conf_args=None, + spark_app_name=None, + extra_py_files=None, + *args, + **kwargs, ): super().__init__(*args, **kwargs) self.job_file = job_file @@ -37,7 +37,7 @@ def __init__( self.spark_app_name = spark_app_name self.extra_py_files = extra_py_files self.cluster_name = cluster_name - self._execution_timeout = kwargs.get('execution_timeout') + self._execution_timeout = kwargs.get("execution_timeout") self._application_id = None self._emr_master_instance_id = None @@ -76,47 +76,54 @@ def get_execution_timeout(self): def get_cluster_id_by_name(self, emr_cluster_name, cluster_states): response = self.emr_client.list_clusters(ClusterStates=cluster_states) matching_clusters = list( - filter(lambda cluster: cluster['Name'] == emr_cluster_name, response['Clusters'])) + filter( + lambda cluster: cluster["Name"] == emr_cluster_name, + response["Clusters"], + ) + ) if len(matching_clusters) == 1: - cluster_id = matching_clusters[0]['Id'] - logging.info('Found cluster name = %s id = %s' % (emr_cluster_name, cluster_id)) + cluster_id = matching_clusters[0]["Id"] + logging.info( + "Found cluster name = %s id = %s" % (emr_cluster_name, cluster_id) + ) return cluster_id elif len(matching_clusters) > 1: - raise AirflowException('More than one cluster found for name = %s' % emr_cluster_name) + raise AirflowException( + "More than one cluster found for name = %s" % emr_cluster_name + ) else: return None - def get_application_id_by_name(self, emr_master_instance_id, application_name): """ Get the application ID of the Spark job """ if application_name: - command = f"yarn application -list -appStates RUNNING | grep {application_name}" + command = ( + f"yarn application -list -appStates RUNNING | grep {application_name}" + ) response = self.ssm_client.send_command( InstanceIds=[emr_master_instance_id], DocumentName="AWS-RunShellScript", - Parameters={"commands": [command]} + Parameters={"commands": [command]}, ) - command_id = response['Command']['CommandId'] + command_id = response["Command"]["CommandId"] time.sleep(10) # Wait for the command to execute output = self.ssm_client.get_command_invocation( - CommandId=command_id, - InstanceId=emr_master_instance_id + CommandId=command_id, InstanceId=emr_master_instance_id ) - stdout = output['StandardOutputContent'] - for line in stdout.split('\n'): + stdout = output["StandardOutputContent"] + for line in stdout.split("\n"): if application_name in line: application_id = line.split()[0] return application_id return None - def kill_spark_job(self): if self._application_id and self._emr_master_instance_id: kill_command = f"yarn application -kill {self._application_id}" @@ -125,28 +132,29 @@ def kill_spark_job(self): DocumentName="AWS-RunShellScript", Parameters={"commands": [kill_command]}, ) - logging.info( - f"Spark job {self._application_id} terminated successfully." - ) + logging.info(f"Spark job {self._application_id} terminated successfully.") else: - logging.warning("No application ID or master instance ID found to terminate.") - + logging.warning( + "No application ID or master instance ID found to terminate." + ) def on_kill(self): logging.info("Task killed. Attempting to terminate the Spark job.") self.kill_spark_job() - def execute(self, context): """ See `execute` method from airflow.operators.bash_operator """ - start_time = time.time() try: # Get cluster and master node information - cluster_id = self.get_cluster_id_by_name(self.cluster_name, ["WAITING", "RUNNING"]) + cluster_id = self.get_cluster_id_by_name( + self.cluster_name, ["WAITING", "RUNNING"] + ) self._emr_master_instance_id = self.emr_client.list_instances( - ClusterId=cluster_id, InstanceGroupTypes=["MASTER"], InstanceStates=["RUNNING"] + ClusterId=cluster_id, + InstanceGroupTypes=["MASTER"], + InstanceStates=["RUNNING"], )["Instances"][0]["Ec2InstanceId"] # Build the command parameters @@ -158,33 +166,41 @@ def execute(self, context): response = self.ssm_client.send_command( InstanceIds=[self._emr_master_instance_id], DocumentName="AWS-RunShellScript", - Parameters=command_parameters + Parameters=command_parameters, ) - command_id = response['Command']['CommandId'] - status = 'Pending' + command_id = response["Command"]["CommandId"] + status = "Pending" status_details = None + self._application_id = self.get_application_id_by_name( + self._emr_master_instance_id, self.spark_app_name + ) + self.log.info( + f"emr:{self._emr_master_instance_id}, application_name:{self.spark_app_name}, application_id: {self._application_id}" + ) # Monitor the command's execution - while status in ['Pending', 'InProgress', 'Delayed']: + while status in ["Pending", "InProgress", "Delayed"]: time.sleep(30) # Check the status of the SSM command response = self.ssm_client.get_command_invocation( CommandId=command_id, InstanceId=self._emr_master_instance_id ) - status = response['Status'] - status_details = response['StatusDetails'] + status = response["Status"] + status_details = response["StatusDetails"] self.log.info( self.ssm_client.get_command_invocation( CommandId=command_id, InstanceId=self._emr_master_instance_id - )['StandardErrorContent'] + )["StandardErrorContent"] ) # Kill the command and raise an exception if the command did not succeed - if status != 'Success': + if status != "Success": self.kill_spark_job() - raise AirflowException(f"Spark command failed, check Spark job status in YARN resource manager. " - f"Response status details: {status_details}") + raise AirflowException( + f"Spark command failed, check Spark job status in YARN resource manager. " + f"Response status details: {status_details}" + ) except Exception as e: logging.error(f"Error encountered: {str(e)}") From 57c394b6db3a47d3cec71a84fb7f6b18f7db1c9f Mon Sep 17 00:00:00 2001 From: claudiazi Date: Mon, 20 Jan 2025 09:46:31 +0100 Subject: [PATCH 131/189] chore: add info to debug --- .../operators/spark_submit_operator.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/dagger/dag_creator/airflow/operators/spark_submit_operator.py b/dagger/dag_creator/airflow/operators/spark_submit_operator.py index d94b9b5..fd26f12 100644 --- a/dagger/dag_creator/airflow/operators/spark_submit_operator.py +++ b/dagger/dag_creator/airflow/operators/spark_submit_operator.py @@ -117,13 +117,46 @@ def get_application_id_by_name(self, emr_master_instance_id, application_name): CommandId=command_id, InstanceId=emr_master_instance_id ) + self.log.info(f"ouotput: {output}") + stdout = output["StandardOutputContent"] + self.log.info(f"stdout: {stdout}") for line in stdout.split("\n"): if application_name in line: application_id = line.split()[0] return application_id return None + def get_application_id_by_name(self, emr_master_instance_id, application_name): + """ + Get the application ID of the Spark job + """ + if application_name: + command = f"yarn application -list -appStates RUNNING | grep {application_name}" + + response = self.ssm_client.send_command( + InstanceIds=[emr_master_instance_id], + DocumentName="AWS-RunShellScript", + Parameters={"commands": [command]} + ) + + command_id = response['Command']['CommandId'] + time.sleep(10) # Wait for the command to execute + + output = self.ssm_client.get_command_invocation( + CommandId=command_id, + InstanceId=emr_master_instance_id + ) + + stdout = output['StandardOutputContent'] + for line in stdout.split('\n'): + if application_name in line: + application_id = line.split()[0] + return application_id + return None + + + def kill_spark_job(self): if self._application_id and self._emr_master_instance_id: kill_command = f"yarn application -kill {self._application_id}" @@ -203,6 +236,12 @@ def execute(self, context): ) except Exception as e: + self._application_id = self.get_application_id_by_name( + self._emr_master_instance_id, self.spark_app_name + ) + self.log.info( + f"emr:{self._emr_master_instance_id}, application_name:{self.spark_app_name}, application_id: {self._application_id}" + ) logging.error(f"Error encountered: {str(e)}") self.kill_spark_job() raise AirflowException(f"Task failed with error: {str(e)}") From 6fdb6740b5c338c4d48dbf10abc2bfcdad925432 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Mon, 20 Jan 2025 10:05:32 +0100 Subject: [PATCH 132/189] chore: add info to debug --- .../dag_creator/airflow/operators/spark_submit_operator.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dagger/dag_creator/airflow/operators/spark_submit_operator.py b/dagger/dag_creator/airflow/operators/spark_submit_operator.py index fd26f12..aff9bfb 100644 --- a/dagger/dag_creator/airflow/operators/spark_submit_operator.py +++ b/dagger/dag_creator/airflow/operators/spark_submit_operator.py @@ -156,8 +156,13 @@ def get_application_id_by_name(self, emr_master_instance_id, application_name): return None - def kill_spark_job(self): + self._application_id = self.get_application_id_by_name( + self._emr_master_instance_id, self.spark_app_name + ) + self.log.info( + f"emr:{self._emr_master_instance_id}, application_name:{self.spark_app_name}, application_id: {self._application_id}" + ) if self._application_id and self._emr_master_instance_id: kill_command = f"yarn application -kill {self._application_id}" self.ssm_client.send_command( From a5d627cc4e4ab8e0cab2356ebc81ff93859646fd Mon Sep 17 00:00:00 2001 From: claudiazi Date: Mon, 20 Jan 2025 10:53:11 +0100 Subject: [PATCH 133/189] chore: reformat --- .../operators/spark_submit_operator.py | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/spark_submit_operator.py b/dagger/dag_creator/airflow/operators/spark_submit_operator.py index aff9bfb..31f6a70 100644 --- a/dagger/dag_creator/airflow/operators/spark_submit_operator.py +++ b/dagger/dag_creator/airflow/operators/spark_submit_operator.py @@ -117,52 +117,17 @@ def get_application_id_by_name(self, emr_master_instance_id, application_name): CommandId=command_id, InstanceId=emr_master_instance_id ) - self.log.info(f"ouotput: {output}") - stdout = output["StandardOutputContent"] - self.log.info(f"stdout: {stdout}") for line in stdout.split("\n"): if application_name in line: application_id = line.split()[0] return application_id return None - def get_application_id_by_name(self, emr_master_instance_id, application_name): - """ - Get the application ID of the Spark job - """ - if application_name: - command = f"yarn application -list -appStates RUNNING | grep {application_name}" - - response = self.ssm_client.send_command( - InstanceIds=[emr_master_instance_id], - DocumentName="AWS-RunShellScript", - Parameters={"commands": [command]} - ) - - command_id = response['Command']['CommandId'] - time.sleep(10) # Wait for the command to execute - - output = self.ssm_client.get_command_invocation( - CommandId=command_id, - InstanceId=emr_master_instance_id - ) - - stdout = output['StandardOutputContent'] - for line in stdout.split('\n'): - if application_name in line: - application_id = line.split()[0] - return application_id - return None - - def kill_spark_job(self): self._application_id = self.get_application_id_by_name( self._emr_master_instance_id, self.spark_app_name ) - self.log.info( - f"emr:{self._emr_master_instance_id}, application_name:{self.spark_app_name}, application_id: {self._application_id}" - ) if self._application_id and self._emr_master_instance_id: kill_command = f"yarn application -kill {self._application_id}" self.ssm_client.send_command( @@ -209,12 +174,6 @@ def execute(self, context): command_id = response["Command"]["CommandId"] status = "Pending" status_details = None - self._application_id = self.get_application_id_by_name( - self._emr_master_instance_id, self.spark_app_name - ) - self.log.info( - f"emr:{self._emr_master_instance_id}, application_name:{self.spark_app_name}, application_id: {self._application_id}" - ) # Monitor the command's execution while status in ["Pending", "InProgress", "Delayed"]: @@ -241,12 +200,6 @@ def execute(self, context): ) except Exception as e: - self._application_id = self.get_application_id_by_name( - self._emr_master_instance_id, self.spark_app_name - ) - self.log.info( - f"emr:{self._emr_master_instance_id}, application_name:{self.spark_app_name}, application_id: {self._application_id}" - ) logging.error(f"Error encountered: {str(e)}") self.kill_spark_job() raise AirflowException(f"Task failed with error: {str(e)}") From 3ef83b3fba580bba026d66a75ea4e57319c7a2ba Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Thu, 13 Feb 2025 13:43:51 +0100 Subject: [PATCH 134/189] Bumping tenacity version --- reqs/base.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/reqs/base.txt b/reqs/base.txt index d9cc38a..279ed4b 100644 --- a/reqs/base.txt +++ b/reqs/base.txt @@ -4,4 +4,4 @@ envyaml==1.10.211231 mergedeep==1.3.4 slack==0.0.2 slackclient==2.9.4 -tenacity==8.2.3 +tenacity~=8.3.0 diff --git a/setup.py b/setup.py index 080a5bb..f8b4b28 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,6 @@ def reqs(*f): packages=find_packages(), tests_require=test_requires, url="https://gitlab.com/goflash1/data/dagger", - version="0.9.0", + version="0.9.1", zip_safe=False, ) From 10f77cf86c12e3b5d8f591659be698c07d1e172b Mon Sep 17 00:00:00 2001 From: claudiazi Date: Wed, 19 Feb 2025 13:59:51 +0100 Subject: [PATCH 135/189] feat: add param _full_refresh in the reverse_etl task --- .../airflow/operator_creators/reverse_etl_creator.py | 3 +++ dagger/pipeline/tasks/reverse_etl_task.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py index e133e40..7ac32d7 100644 --- a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py @@ -27,6 +27,7 @@ def __init__(self, task, dag): self._days_to_live = task.days_to_live self._output_type = task.output_type self._region_name = task.region_name + self._full_refresh = task.full_refresh def _generate_command(self): command = BatchCreator._generate_command(self) @@ -56,6 +57,8 @@ def _generate_command(self): command.append(f"--days_to_live={self._days_to_live}") if self._region_name: command.append(f"--region_name={self._region_name}") + if self._full_refresh: + command.append(f"--full_refresh={self._full_refresh}") return command diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py index 6c9a5d2..30ce352 100644 --- a/dagger/pipeline/tasks/reverse_etl_task.py +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -116,6 +116,13 @@ def init_attributes(cls, orig_cls): comment="The number of days to keep the data in the table. If provided, the time_to_live attribute " "will be set in dynamodb", ), + Attribute( + attribute_name="full_refresh", + parent_fields=["task_parameters"], + validator=bool, + required=False, + comment="If set to True, the job will perform a full refresh instead of an incremental one", + ) ] ) @@ -140,6 +147,7 @@ def __init__(self, name, pipeline_name, pipeline, job_config): self._updated_at_column = self.parse_attribute("updated_at_column") self._from_time = self.parse_attribute("from_time") self._days_to_live = self.parse_attribute("days_to_live") + self._full_refresh = self.parse_attribute("full_refresh") if self._hash_column and self._updated_at_column: raise ValueError(f"ReverseETLTask: {self._name} hash_column and updated_at_column are mutually exclusive") @@ -237,3 +245,7 @@ def output_type(self): @property def region_name(self): return self._region_name + + @property + def full_refresh(self): + return self._full_refresh From 13ccf9e3ed4c786ec98dff93265109c0579cd4a1 Mon Sep 17 00:00:00 2001 From: raimundovidaljunior Date: Tue, 18 Mar 2025 17:19:54 +0100 Subject: [PATCH 136/189] Create soda runner --- dagger/conf.py | 10 ++ .../airflow/operator_creators/soda_creator.py | 59 ++++++++ .../dag_creator/airflow/operator_factory.py | 1 + .../airflow/operators/soda_batch.py | 5 + dagger/pipeline/task_factory.py | 1 + dagger/pipeline/tasks/soda_task.py | 143 ++++++++++++++++++ 6 files changed, 219 insertions(+) create mode 100644 dagger/dag_creator/airflow/operator_creators/soda_creator.py create mode 100644 dagger/dag_creator/airflow/operators/soda_batch.py create mode 100644 dagger/pipeline/tasks/soda_task.py diff --git a/dagger/conf.py b/dagger/conf.py index df2ab8e..c024e6b 100644 --- a/dagger/conf.py +++ b/dagger/conf.py @@ -110,3 +110,13 @@ REVERSE_ETL_DEFAULT_JOB_NAME = reverse_etl_config.get('default_job_name', None) REVERSE_ETL_DEFAULT_EXECUTABLE_PREFIX = reverse_etl_config.get('default_executable_prefix', None) REVERSE_ETL_DEFAULT_EXECUTABLE = reverse_etl_config.get('default_executable', None) + +# Soda parameters +SODA_DEFAULT_JOB_NAME = reverse_etl_config.get('default_job_name', None) +SODA_DEFAULT_EXECUTABLE_PREFIX = reverse_etl_config.get('default_executable_prefix', None) +SODA_DEFAULT_EXECUTABLE = reverse_etl_config.get('default_executable', None) +SODA_DEFAULT_PROJECT_DIR = reverse_etl_config.get('default_project_dir', None) +SODA_DEFAULT_PROFILES_DIR = reverse_etl_config.get('default_profiles_dir', None) +SODA_DEFAULT_PROFILE_NAME = reverse_etl_config.get('default_profile_name', None) +SODA_DEFAULT_OUTPUT_TABLE = reverse_etl_config.get('default_output_table', None) +SODA_DEFAULT_OUTPUT_S3_PATH = reverse_etl_config.get('default_output_s3_path', None) diff --git a/dagger/dag_creator/airflow/operator_creators/soda_creator.py b/dagger/dag_creator/airflow/operator_creators/soda_creator.py new file mode 100644 index 0000000..a0c593d --- /dev/null +++ b/dagger/dag_creator/airflow/operator_creators/soda_creator.py @@ -0,0 +1,59 @@ +import base64 + +from dagger.dag_creator.airflow.operator_creators.batch_creator import BatchCreator +from dagger.dag_creator.airflow.operators.soda_batch import SodaBatchOperator +import json + + +class SodaCreator(BatchCreator): + ref_name = "soda" + + def __init__(self, task, dag): + super().__init__(task, dag) + + self._absolute_job_name = task.absolute_job_name + self._project_dir = task.project_dir + self._profiles_dir = task.profiles_dir + self._profile_name = task.profile_name + self._target_name = task.target_name + self._table_name = task.table_name + self._model_name = task.model_name + self._output_s3_path = task.output_s3_path + self._output_table = task.output_table + self._vars = task.vars + + def _generate_command(self): + command = BatchCreator._generate_command(self) + + command.append(f"--project_dir={self._project_dir}") + command.append(f"--profiles_dir={self._profiles_dir}") + command.append(f"--profile_name={self._profile_name}") + command.append(f"--target_name={self._target_name}") + command.append(f"--output_s3_path={self._output_s3_path}") + command.append(f"--output_table={self._output_table}") + + if self._table_name: + command.append(f"--table_name={self._table_name}") + if self._model_name: + command.append(f"--model_name={self._model_name}") + if self._vars: + command.append(f"--vars={self._vars}") + return command + + def _create_operator(self, **kwargs): + overrides = self._task.overrides + overrides.update({"command": self._generate_command()}) + + job_name = self._validate_job_name(self._task.job_name, self._task.absolute_job_name) + batch_op = SodaBatchOperator( + dag=self._dag, + task_id=self._task.name, + job_name=self._task.name, + job_definition=job_name, + region_name=self._task.region_name, + job_queue=self._task.job_queue, + container_overrides=overrides, + awslogs_enabled=True, + **kwargs, + ) + return batch_op diff --git a/dagger/dag_creator/airflow/operator_factory.py b/dagger/dag_creator/airflow/operator_factory.py index f610f1e..2a1654a 100644 --- a/dagger/dag_creator/airflow/operator_factory.py +++ b/dagger/dag_creator/airflow/operator_factory.py @@ -13,6 +13,7 @@ reverse_etl_creator, spark_creator, sqoop_creator, + soda_creator, ) from dagger.dag_creator.airflow.utils.operator_factories import make_control_flow from dagger.utilities.classes import get_deep_obj_subclasses diff --git a/dagger/dag_creator/airflow/operators/soda_batch.py b/dagger/dag_creator/airflow/operators/soda_batch.py new file mode 100644 index 0000000..d67a26c --- /dev/null +++ b/dagger/dag_creator/airflow/operators/soda_batch.py @@ -0,0 +1,5 @@ +from dagger.dag_creator.airflow.operators.awsbatch_operator import AWSBatchOperator + +class SodaBatchOperator(AWSBatchOperator): + custom_operator_name = 'Soda' + ui_color = "#e4f0e7" diff --git a/dagger/pipeline/task_factory.py b/dagger/pipeline/task_factory.py index d8a1e53..9ed79e7 100644 --- a/dagger/pipeline/task_factory.py +++ b/dagger/pipeline/task_factory.py @@ -12,6 +12,7 @@ reverse_etl_task, spark_task, sqoop_task, + soda_task ) from dagger.utilities.classes import get_deep_obj_subclasses diff --git a/dagger/pipeline/tasks/soda_task.py b/dagger/pipeline/tasks/soda_task.py new file mode 100644 index 0000000..25a90f7 --- /dev/null +++ b/dagger/pipeline/tasks/soda_task.py @@ -0,0 +1,143 @@ +from dagger.pipeline.tasks.batch_task import BatchTask +from dagger.utilities.config_validator import Attribute +from dagger import conf + +class SodaTask(BatchTask): + ref_name = "soda" + + @classmethod + def init_attributes(cls, orig_cls): + cls.add_config_attributes( + [ + Attribute( + attribute_name="executable_prefix", + required=False, + parent_fields=["task_parameters"], + comment="E.g.: python", + ), + Attribute( + attribute_name="executable", + required=False, + parent_fields=["task_parameters"], + comment="E.g.: my_code.py", + ), + Attribute( + attribute_name="project_dir", + parent_fields=["task_parameters"], + required = True, + validator=str, + comment="Directory containing the dbt_project.yml file", + ), + Attribute( + attribute_name="profiles_dir", + parent_fields=["task_parameters"], + required=True, + comment="Directory containing the profiles.yml file", + ), + Attribute( + attribute_name="profile_name", + parent_fields=["task_parameters"], + required=True, + comment="Profile name to load from the profiles.yml file.", + ), + Attribute( + attribute_name="target_name", + parent_fields=["task_parameters"], + validator=str, + required=True, + comment="Target to load for the given profile. By default use 'ENV' environment variable.", + ), + Attribute( + attribute_name="table_name", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="Full table name in the format 'database.schema.table'", + ), + Attribute( + attribute_name="model_name", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="Name of dbt model to be scanned by soda", + ), + Attribute( + attribute_name="output_s3_path", + parent_fields=["task_parameters"], + validator=str, + required=True, + comment="S3 location to upload the scan results", + + ), + Attribute( + attribute_name="output_table", + parent_fields=["task_parameters"], + validator=str, + required=True, + comment="Athena table that will contain the scan results.", + ), + Attribute( + attribute_name="vars", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="Variables needed to run soda scan", + ) + + ] + ) + + def __init__(self, name, pipeline_name, pipeline, job_config): + super().__init__(name, pipeline_name, pipeline, job_config) + + self._executable = self.executable or conf.SODA_DEFAULT_EXECUTABLE + self._executable_prefix = self.executable_prefix or conf.SODA_DEFAULT_EXECUTABLE_PREFIX + self._absolute_job_name = self._absolute_job_name or conf.SODA_DEFAULT_JOB_NAME + self._project_dir = self.parse_attribute("project_dir") or conf.SODA_DEFAULT_PROJECT_DIR + self._profiles_dir = self.parse_attribute("profiles_dir") or conf.SODA_DEFAULT_PROFILES_DIR + self._profile_name = self.parse_attribute("profile_name") or conf.SODA_DEFAULT_PROFILE_NAME + self._output_table = self.parse_attribute("output_path") or conf.SODA_DEFAULT_OUTPUT_TABLE + self._output_s3_path = self.parse_attribute("output_s3_path") or conf.SODA_DEFAULT_OUTPUT_S3_PATH + + self._table_name = self.parse_attribute("table_name") + self._model_name = self.parse_attribute("model_name") + self._vars = self.parse_attribute("vars") + + + if self._table_name and self._model_name: + raise ValueError(f"SodaTask: {self._name} table_name and model_name are mutually exclusive") + + + + @property + def project_dir(self): + return self._project_dir + + @property + def profiles_dir(self): + return self._profiles_dir + + @property + def profile_name(self): + return self._profile_name + + @property + def output_table(self): + return self._output_table + + @property + def output_s3_path(self): + return self._output_s3_path + + @property + def table_name(self): + return self._table_name + + @property + def model_name(self): + return self._model_name + + @property + def vars(self): + return self._vars + From 2f7cc55c4731b7b0d0c1c66c9a812d789ba43cc9 Mon Sep 17 00:00:00 2001 From: raimundovidaljunior Date: Wed, 19 Mar 2025 14:40:41 +0100 Subject: [PATCH 137/189] Fixing configs --- dagger/conf.py | 17 +++++++++-------- dagger/dagger_config.yaml | 10 ++++++++++ dagger/pipeline/tasks/soda_task.py | 17 +++++++++++------ 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/dagger/conf.py b/dagger/conf.py index c024e6b..ae3d9ec 100644 --- a/dagger/conf.py +++ b/dagger/conf.py @@ -112,11 +112,12 @@ REVERSE_ETL_DEFAULT_EXECUTABLE = reverse_etl_config.get('default_executable', None) # Soda parameters -SODA_DEFAULT_JOB_NAME = reverse_etl_config.get('default_job_name', None) -SODA_DEFAULT_EXECUTABLE_PREFIX = reverse_etl_config.get('default_executable_prefix', None) -SODA_DEFAULT_EXECUTABLE = reverse_etl_config.get('default_executable', None) -SODA_DEFAULT_PROJECT_DIR = reverse_etl_config.get('default_project_dir', None) -SODA_DEFAULT_PROFILES_DIR = reverse_etl_config.get('default_profiles_dir', None) -SODA_DEFAULT_PROFILE_NAME = reverse_etl_config.get('default_profile_name', None) -SODA_DEFAULT_OUTPUT_TABLE = reverse_etl_config.get('default_output_table', None) -SODA_DEFAULT_OUTPUT_S3_PATH = reverse_etl_config.get('default_output_s3_path', None) +soda_config = config.get('soda', None) or {} +SODA_DEFAULT_JOB_NAME = soda_config.get('default_job_name', None) +SODA_DEFAULT_EXECUTABLE_PREFIX = soda_config.get('default_executable_prefix', None) +SODA_DEFAULT_EXECUTABLE = soda_config.get('default_executable', None) +SODA_DEFAULT_PROJECT_DIR = soda_config.get('default_project_dir', None) +SODA_DEFAULT_PROFILES_DIR = soda_config.get('default_profiles_dir', None) +SODA_DEFAULT_PROFILE_NAME = soda_config.get('default_profile_name', None) +SODA_DEFAULT_OUTPUT_TABLE = soda_config.get('default_output_table', None) +SODA_DEFAULT_OUTPUT_S3_PATH = soda_config.get('default_output_s3_path', None) diff --git a/dagger/dagger_config.yaml b/dagger/dagger_config.yaml index 38abccd..209f110 100644 --- a/dagger/dagger_config.yaml +++ b/dagger/dagger_config.yaml @@ -67,3 +67,13 @@ reverse_etl: # default_job_name: # default_executable_prefix: # default_executable: + +soda: +# default_job_name: +# default_executable_prefix: +# default_executable: +# default_project_dir: +# default_profiles_dir: +# default_profile_name: +# default_output_table: +# default_output_s3_path: \ No newline at end of file diff --git a/dagger/pipeline/tasks/soda_task.py b/dagger/pipeline/tasks/soda_task.py index 25a90f7..7aeebf9 100644 --- a/dagger/pipeline/tasks/soda_task.py +++ b/dagger/pipeline/tasks/soda_task.py @@ -24,20 +24,20 @@ def init_attributes(cls, orig_cls): Attribute( attribute_name="project_dir", parent_fields=["task_parameters"], - required = True, + required = False, validator=str, comment="Directory containing the dbt_project.yml file", ), Attribute( attribute_name="profiles_dir", parent_fields=["task_parameters"], - required=True, + required=False, comment="Directory containing the profiles.yml file", ), Attribute( attribute_name="profile_name", parent_fields=["task_parameters"], - required=True, + required=False, comment="Profile name to load from the profiles.yml file.", ), Attribute( @@ -65,7 +65,7 @@ def init_attributes(cls, orig_cls): attribute_name="output_s3_path", parent_fields=["task_parameters"], validator=str, - required=True, + required=False, comment="S3 location to upload the scan results", ), @@ -73,7 +73,7 @@ def init_attributes(cls, orig_cls): attribute_name="output_table", parent_fields=["task_parameters"], validator=str, - required=True, + required=False, comment="Athena table that will contain the scan results.", ), Attribute( @@ -96,8 +96,9 @@ def __init__(self, name, pipeline_name, pipeline, job_config): self._project_dir = self.parse_attribute("project_dir") or conf.SODA_DEFAULT_PROJECT_DIR self._profiles_dir = self.parse_attribute("profiles_dir") or conf.SODA_DEFAULT_PROFILES_DIR self._profile_name = self.parse_attribute("profile_name") or conf.SODA_DEFAULT_PROFILE_NAME - self._output_table = self.parse_attribute("output_path") or conf.SODA_DEFAULT_OUTPUT_TABLE + self._output_table = self.parse_attribute("output_table") or conf.SODA_DEFAULT_OUTPUT_TABLE self._output_s3_path = self.parse_attribute("output_s3_path") or conf.SODA_DEFAULT_OUTPUT_S3_PATH + self._target_name = self.parse_attribute("target_name") self._table_name = self.parse_attribute("table_name") self._model_name = self.parse_attribute("model_name") @@ -141,3 +142,7 @@ def model_name(self): def vars(self): return self._vars + @property + def target_name(self): + return self._target_name + From 47f4d83dc6533c7d61a72c87cc2be34f489d798d Mon Sep 17 00:00:00 2001 From: raimundovidaljunior Date: Mon, 24 Mar 2025 17:27:06 +0100 Subject: [PATCH 138/189] Remove unnecessary vars --- dagger/conf.py | 3 - .../airflow/operator_creators/soda_creator.py | 12 +--- dagger/dagger_config.yaml | 3 - dagger/pipeline/tasks/soda_task.py | 64 +------------------ 4 files changed, 3 insertions(+), 79 deletions(-) diff --git a/dagger/conf.py b/dagger/conf.py index ae3d9ec..b750d84 100644 --- a/dagger/conf.py +++ b/dagger/conf.py @@ -116,8 +116,5 @@ SODA_DEFAULT_JOB_NAME = soda_config.get('default_job_name', None) SODA_DEFAULT_EXECUTABLE_PREFIX = soda_config.get('default_executable_prefix', None) SODA_DEFAULT_EXECUTABLE = soda_config.get('default_executable', None) -SODA_DEFAULT_PROJECT_DIR = soda_config.get('default_project_dir', None) -SODA_DEFAULT_PROFILES_DIR = soda_config.get('default_profiles_dir', None) -SODA_DEFAULT_PROFILE_NAME = soda_config.get('default_profile_name', None) SODA_DEFAULT_OUTPUT_TABLE = soda_config.get('default_output_table', None) SODA_DEFAULT_OUTPUT_S3_PATH = soda_config.get('default_output_s3_path', None) diff --git a/dagger/dag_creator/airflow/operator_creators/soda_creator.py b/dagger/dag_creator/airflow/operator_creators/soda_creator.py index a0c593d..5f145c8 100644 --- a/dagger/dag_creator/airflow/operator_creators/soda_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/soda_creator.py @@ -12,12 +12,7 @@ def __init__(self, task, dag): super().__init__(task, dag) self._absolute_job_name = task.absolute_job_name - self._project_dir = task.project_dir - self._profiles_dir = task.profiles_dir - self._profile_name = task.profile_name - self._target_name = task.target_name self._table_name = task.table_name - self._model_name = task.model_name self._output_s3_path = task.output_s3_path self._output_table = task.output_table self._vars = task.vars @@ -25,17 +20,12 @@ def __init__(self, task, dag): def _generate_command(self): command = BatchCreator._generate_command(self) - command.append(f"--project_dir={self._project_dir}") - command.append(f"--profiles_dir={self._profiles_dir}") - command.append(f"--profile_name={self._profile_name}") - command.append(f"--target_name={self._target_name}") + command.append(f"--output_s3_path={self._output_s3_path}") command.append(f"--output_table={self._output_table}") if self._table_name: command.append(f"--table_name={self._table_name}") - if self._model_name: - command.append(f"--model_name={self._model_name}") if self._vars: command.append(f"--vars={self._vars}") return command diff --git a/dagger/dagger_config.yaml b/dagger/dagger_config.yaml index 209f110..c7b291a 100644 --- a/dagger/dagger_config.yaml +++ b/dagger/dagger_config.yaml @@ -72,8 +72,5 @@ soda: # default_job_name: # default_executable_prefix: # default_executable: -# default_project_dir: -# default_profiles_dir: -# default_profile_name: # default_output_table: # default_output_s3_path: \ No newline at end of file diff --git a/dagger/pipeline/tasks/soda_task.py b/dagger/pipeline/tasks/soda_task.py index 7aeebf9..443f98a 100644 --- a/dagger/pipeline/tasks/soda_task.py +++ b/dagger/pipeline/tasks/soda_task.py @@ -21,45 +21,13 @@ def init_attributes(cls, orig_cls): parent_fields=["task_parameters"], comment="E.g.: my_code.py", ), - Attribute( - attribute_name="project_dir", - parent_fields=["task_parameters"], - required = False, - validator=str, - comment="Directory containing the dbt_project.yml file", - ), - Attribute( - attribute_name="profiles_dir", - parent_fields=["task_parameters"], - required=False, - comment="Directory containing the profiles.yml file", - ), - Attribute( - attribute_name="profile_name", - parent_fields=["task_parameters"], - required=False, - comment="Profile name to load from the profiles.yml file.", - ), - Attribute( - attribute_name="target_name", - parent_fields=["task_parameters"], - validator=str, - required=True, - comment="Target to load for the given profile. By default use 'ENV' environment variable.", - ), Attribute( attribute_name="table_name", parent_fields=["task_parameters"], validator=str, required=False, - comment="Full table name in the format 'database.schema.table'", - ), - Attribute( - attribute_name="model_name", - parent_fields=["task_parameters"], - validator=str, - required=False, - comment="Name of dbt model to be scanned by soda", + comment="Full table name in the format 'database.schema.table' By default it is" + " set to the name of the input .
", ), Attribute( attribute_name="output_s3_path", @@ -93,34 +61,13 @@ def __init__(self, name, pipeline_name, pipeline, job_config): self._executable = self.executable or conf.SODA_DEFAULT_EXECUTABLE self._executable_prefix = self.executable_prefix or conf.SODA_DEFAULT_EXECUTABLE_PREFIX self._absolute_job_name = self._absolute_job_name or conf.SODA_DEFAULT_JOB_NAME - self._project_dir = self.parse_attribute("project_dir") or conf.SODA_DEFAULT_PROJECT_DIR - self._profiles_dir = self.parse_attribute("profiles_dir") or conf.SODA_DEFAULT_PROFILES_DIR - self._profile_name = self.parse_attribute("profile_name") or conf.SODA_DEFAULT_PROFILE_NAME self._output_table = self.parse_attribute("output_table") or conf.SODA_DEFAULT_OUTPUT_TABLE self._output_s3_path = self.parse_attribute("output_s3_path") or conf.SODA_DEFAULT_OUTPUT_S3_PATH - self._target_name = self.parse_attribute("target_name") - self._table_name = self.parse_attribute("table_name") - self._model_name = self.parse_attribute("model_name") self._vars = self.parse_attribute("vars") - if self._table_name and self._model_name: - raise ValueError(f"SodaTask: {self._name} table_name and model_name are mutually exclusive") - - - @property - def project_dir(self): - return self._project_dir - - @property - def profiles_dir(self): - return self._profiles_dir - - @property - def profile_name(self): - return self._profile_name @property def output_table(self): @@ -134,15 +81,8 @@ def output_s3_path(self): def table_name(self): return self._table_name - @property - def model_name(self): - return self._model_name - @property def vars(self): return self._vars - @property - def target_name(self): - return self._target_name From 75bb61f7b773a0fac2af8a22e2f7cc0b06ddc54b Mon Sep 17 00:00:00 2001 From: raimundovidaljunior Date: Fri, 28 Mar 2025 15:48:19 +0100 Subject: [PATCH 139/189] Adding column mapping and case conversion --- .../operator_creators/reverse_etl_creator.py | 9 +++++ dagger/pipeline/tasks/reverse_etl_task.py | 37 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py index 7ac32d7..6ee4637 100644 --- a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py @@ -28,6 +28,9 @@ def __init__(self, task, dag): self._output_type = task.output_type self._region_name = task.region_name self._full_refresh = task.full_refresh + self._target_case = task.target_case + self._source_case = task.source_case + self._column_mapping = task.column_mapping def _generate_command(self): command = BatchCreator._generate_command(self) @@ -59,6 +62,12 @@ def _generate_command(self): command.append(f"--region_name={self._region_name}") if self._full_refresh: command.append(f"--full_refresh={self._full_refresh}") + if self._target_case: + command.append(f"--target_case={self._target_case}") + if self._source_case: + command.append(f"--source_case={self._source_case}") + if self._column_mapping: + command.append(f"--column_mapping={self._column_mapping}") return command diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py index 30ce352..d4ead6a 100644 --- a/dagger/pipeline/tasks/reverse_etl_task.py +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -122,8 +122,30 @@ def init_attributes(cls, orig_cls): validator=bool, required=False, comment="If set to True, the job will perform a full refresh instead of an incremental one", + ), + Attribute( + attribute_name="target_case", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="Target column case for DynamoDB. 'snake' leaves columns in snake_case; 'camel' converts to camelCase.", + ), + Attribute( + attribute_name="source_case", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="Source dataset column case. Specify the case of the incoming dataset." + ), + Attribute( + attribute_name="column_mapping", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment='Optional JSON string for column mappings. Example: \'{"id": "chat_id"}\'', ) + ] ) @@ -148,6 +170,9 @@ def __init__(self, name, pipeline_name, pipeline, job_config): self._from_time = self.parse_attribute("from_time") self._days_to_live = self.parse_attribute("days_to_live") self._full_refresh = self.parse_attribute("full_refresh") + self._target_case = self.parse_attribute("target_case") + self._source_case = self.parse_attribute("source_case") + self._column_mapping = self.parse_attribute("column_mapping") if self._hash_column and self._updated_at_column: raise ValueError(f"ReverseETLTask: {self._name} hash_column and updated_at_column are mutually exclusive") @@ -249,3 +274,15 @@ def region_name(self): @property def full_refresh(self): return self._full_refresh + + @property + def target_case(self): + return self._target_case + + @property + def source_case(self): + return self._source_case + + @property + def column_mapping(self): + return self._column_mapping From 57f3c35e90bb13b9636b0a43fa0b482707d59132 Mon Sep 17 00:00:00 2001 From: raimundovidaljunior Date: Tue, 1 Apr 2025 11:31:13 +0200 Subject: [PATCH 140/189] Adjust reverse_etl --- .../operator_creators/reverse_etl_creator.py | 16 +++- dagger/pipeline/tasks/reverse_etl_task.py | 73 ++++++++++++------- 2 files changed, 60 insertions(+), 29 deletions(-) diff --git a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py index 6ee4637..0bd9355 100644 --- a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py @@ -18,8 +18,6 @@ def __init__(self, task, dag): self._primary_id_column = task.primary_id_column self._secondary_id_column = task.secondary_id_column self._custom_id_column = task.custom_id_column - self._model_name = task.model_name - self._project_name = task.project_name self._is_deleted_column = task.is_deleted_column self._hash_column = task.hash_column self._updated_at_column = task.updated_at_column @@ -31,6 +29,10 @@ def __init__(self, task, dag): self._target_case = task.target_case self._source_case = task.source_case self._column_mapping = task.column_mapping + self._glue_registry_name = self.parse_attribute("glue_registry_name") + self._glue_schema_name = self.parse_attribute("glue_schema_name") + self._sort_key = self.parse_attribute("sort_key") + self._custom_columns = self.parse_attribute("custom_columns") def _generate_command(self): command = BatchCreator._generate_command(self) @@ -38,9 +40,8 @@ def _generate_command(self): command.append(f"--num_threads={self._num_threads}") command.append(f"--batch_size={self._batch_size}") command.append(f"--primary_id_column={self._primary_id_column}") - command.append(f"--model_name={self._model_name}") - command.append(f"--project_name={self._project_name}") command.append(f"--output_type={self._output_type}") + command.append(f"--glue_registry_name={self._glue_registry_name}") if self._assume_role_arn: command.append(f"--assume_role_arn={self._assume_role_arn}") @@ -68,6 +69,13 @@ def _generate_command(self): command.append(f"--source_case={self._source_case}") if self._column_mapping: command.append(f"--column_mapping={self._column_mapping}") + if self._glue_schema_name: + command.append(f"--glue_schema_name={self._glue_schema_name}") + if self._sort_key: + command.append(f"--sort_key={self._sort_key}") + if self._custom_columns: + command.append(f"--custom_columns={json.dumps(self._custom_columns)}") + return command diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py index d4ead6a..c0c6290 100644 --- a/dagger/pipeline/tasks/reverse_etl_task.py +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -60,21 +60,6 @@ def init_attributes(cls, orig_cls): required=False, comment="The custom key column to use for the job", ), - Attribute( - attribute_name="model_name", - parent_fields=["task_parameters"], - validator=str, - required=False, - comment="The name of the model. This is going to be a column on the target table. By default it is" - " set to the name of the input .
", - ), - Attribute( - attribute_name="project_name", - parent_fields=["task_parameters"], - validator=str, - required=True, - comment="The name of the project. This is going to be a column on the target table.", - ), Attribute( attribute_name="is_deleted_column", parent_fields=["task_parameters"], @@ -143,6 +128,34 @@ def init_attributes(cls, orig_cls): validator=str, required=False, comment='Optional JSON string for column mappings. Example: \'{"id": "chat_id"}\'', + ), + Attribute( + attribute_name="glue_registry_name", + parent_fields=["task_parameters"], + validator=str, + required=True, + comment='AWS Glue Registry name', + ), + Attribute( + attribute_name="glue_schema_name", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment='AWS Glue Schema name. output_name will be used if not provided', + ), + Attribute( + attribute_name="sort_key", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment='Optional JSON string for sort key composition using #.join(). Example: \'{"sort_key": ["project", "model_name", "secondary_id", "custom_id"]}\'', + ), + Attribute( + attribute_name="custom_columns", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment='Optional JSON string for additional custom columns from static values. Example: \'{"custom_project": "ProjectXYZ", "model_name": "ModelABC"}\'' ) @@ -162,8 +175,6 @@ def __init__(self, name, pipeline_name, pipeline, job_config): self._primary_id_column = self.parse_attribute("primary_id_column") self._secondary_id_column = self.parse_attribute("secondary_id_column") self._custom_id_column = self.parse_attribute("custom_id_column") - self._model_name = self.parse_attribute("model_name") - self._project_name = self.parse_attribute("project_name") self._is_deleted_column = self.parse_attribute("is_deleted_column") self._hash_column = self.parse_attribute("hash_column") self._updated_at_column = self.parse_attribute("updated_at_column") @@ -173,6 +184,10 @@ def __init__(self, name, pipeline_name, pipeline, job_config): self._target_case = self.parse_attribute("target_case") self._source_case = self.parse_attribute("source_case") self._column_mapping = self.parse_attribute("column_mapping") + self._glue_registry_name = self.parse_attribute("glue_registry_name") + self._glue_schema_name = self.parse_attribute("glue_schema_name") + self._sort_key = self.parse_attribute("sort_key") + self._custom_columns = self.parse_attribute("custom_columns") if self._hash_column and self._updated_at_column: raise ValueError(f"ReverseETLTask: {self._name} hash_column and updated_at_column are mutually exclusive") @@ -235,14 +250,6 @@ def secondary_id_column(self): def custom_id_column(self): return self._custom_id_column - @property - def model_name(self): - return self._model_name - - @property - def project_name(self): - return self._project_name - @property def is_deleted_column(self): return self._is_deleted_column @@ -286,3 +293,19 @@ def source_case(self): @property def column_mapping(self): return self._column_mapping + + @property + def glue_registry_name(self): + return self._glue_registry_name + + @property + def glue_schema_name(self): + return self._glue_schema_name + + @property + def sort_key(self): + return self._sort_key + + @property + def custom_columns(self): + return self._custom_columns From 7e10650b96ece93b4eb53c6265da73de2b8573c7 Mon Sep 17 00:00:00 2001 From: raimundovidaljunior Date: Tue, 1 Apr 2025 12:07:56 +0200 Subject: [PATCH 141/189] fixing error --- .../airflow/operator_creators/reverse_etl_creator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py index 0bd9355..8c3385d 100644 --- a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py @@ -29,10 +29,10 @@ def __init__(self, task, dag): self._target_case = task.target_case self._source_case = task.source_case self._column_mapping = task.column_mapping - self._glue_registry_name = self.parse_attribute("glue_registry_name") - self._glue_schema_name = self.parse_attribute("glue_schema_name") - self._sort_key = self.parse_attribute("sort_key") - self._custom_columns = self.parse_attribute("custom_columns") + self._glue_registry_name = task.glue_registry_name + self._glue_schema_name = task.glue_schema_name + self._sort_key = task.sort_key + self._custom_columns = task.custom_columns def _generate_command(self): command = BatchCreator._generate_command(self) From 78ce5b2f8f2a7798e8eb021f1f82a4146974c8cc Mon Sep 17 00:00:00 2001 From: raimundovidaljunior Date: Tue, 1 Apr 2025 14:28:23 +0200 Subject: [PATCH 142/189] fixing custom columns --- .../airflow/operator_creators/reverse_etl_creator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py index 8c3385d..5ab3910 100644 --- a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py @@ -74,7 +74,7 @@ def _generate_command(self): if self._sort_key: command.append(f"--sort_key={self._sort_key}") if self._custom_columns: - command.append(f"--custom_columns={json.dumps(self._custom_columns)}") + command.append(f"--custom_columns={self._custom_columns}") return command From 64bb82210823df26c78f456da6a30da7be0fafcc Mon Sep 17 00:00:00 2001 From: raimundovidaljunior Date: Tue, 15 Apr 2025 12:19:36 +0200 Subject: [PATCH 143/189] Removing unnecessary inputs/outputs from command --- .../dag_creator/airflow/operator_creators/dbt_creator.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dagger/dag_creator/airflow/operator_creators/dbt_creator.py b/dagger/dag_creator/airflow/operator_creators/dbt_creator.py index 9be9ee8..b461306 100644 --- a/dagger/dag_creator/airflow/operator_creators/dbt_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/dbt_creator.py @@ -34,7 +34,8 @@ def _generate_command(self): if self._create_external_athena_table: command.append(f"--create_external_athena_table={self._create_external_athena_table}") for param_name, param_value in self._template_parameters.items(): - command.append( - f"--{param_name}={param_value}" - ) + if param_name == 'output_s3_path': + command.append( + f"--{param_name}={param_value}" + ) return command From 0e008dc0b44f7d645151731ce050bec44bdafb91 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Tue, 22 Apr 2025 15:32:12 +0200 Subject: [PATCH 144/189] feat: add param _columns_to_include and _columns_to_exclude in the reverse_etl task --- .../operator_creators/reverse_etl_creator.py | 7 ++++- dagger/pipeline/tasks/reverse_etl_task.py | 26 +++++++++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py index 5ab3910..b8b90d5 100644 --- a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py @@ -33,6 +33,8 @@ def __init__(self, task, dag): self._glue_schema_name = task.glue_schema_name self._sort_key = task.sort_key self._custom_columns = task.custom_columns + self._columns_to_include = task.columns_to_include + self._columns_to_exclude = task.columns_to_exclude def _generate_command(self): command = BatchCreator._generate_command(self) @@ -75,7 +77,10 @@ def _generate_command(self): command.append(f"--sort_key={self._sort_key}") if self._custom_columns: command.append(f"--custom_columns={self._custom_columns}") - + if self._columns_to_include: + command.append(f"--columns_to_include={self._columns_to_include}") + if self._columns_to_exclude: + command.append(f"--columns_to_exclude={self._columns_to_exclude}") return command diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py index c0c6290..59463ab 100644 --- a/dagger/pipeline/tasks/reverse_etl_task.py +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -156,9 +156,21 @@ def init_attributes(cls, orig_cls): validator=str, required=False, comment='Optional JSON string for additional custom columns from static values. Example: \'{"custom_project": "ProjectXYZ", "model_name": "ModelABC"}\'' + ), + Attribute( + attribute_name="columns_to_include", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment='Optional comma-separated list of columns to include in the job. Example: \'column1,column2,column3\', if not provided, all columns will be included', + ), + Attribute( + attribute_name="columns_to_exclude", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment='Optional comma-separated list of columns to exclude from the job. Example: \'column1,column2,column3\', if not provided, all columns will be included', ) - - ] ) @@ -188,6 +200,8 @@ def __init__(self, name, pipeline_name, pipeline, job_config): self._glue_schema_name = self.parse_attribute("glue_schema_name") self._sort_key = self.parse_attribute("sort_key") self._custom_columns = self.parse_attribute("custom_columns") + self._columns_to_include = self.parse_attribute("columns_to_include") + self._columns_to_exclude = self.parse_attribute("columns_to_exclude") if self._hash_column and self._updated_at_column: raise ValueError(f"ReverseETLTask: {self._name} hash_column and updated_at_column are mutually exclusive") @@ -309,3 +323,11 @@ def sort_key(self): @property def custom_columns(self): return self._custom_columns + + @property + def columns_to_include(self): + return self._columns_to_include + + @property + def columns_to_exclude(self): + return self._columns_to_exclude From 001f19c805789ca6daf35135ba870ff9a2b10acd Mon Sep 17 00:00:00 2001 From: claudiazi Date: Wed, 23 Apr 2025 10:19:52 +0200 Subject: [PATCH 145/189] chore: improve the naming of new params --- dagger/pipeline/tasks/reverse_etl_task.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py index 59463ab..7442d42 100644 --- a/dagger/pipeline/tasks/reverse_etl_task.py +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -158,18 +158,18 @@ def init_attributes(cls, orig_cls): comment='Optional JSON string for additional custom columns from static values. Example: \'{"custom_project": "ProjectXYZ", "model_name": "ModelABC"}\'' ), Attribute( - attribute_name="columns_to_include", + attribute_name="input_table_columns_to_include", parent_fields=["task_parameters"], validator=str, required=False, - comment='Optional comma-separated list of columns to include in the job. Example: \'column1,column2,column3\', if not provided, all columns will be included', + comment='Optional comma-separated list of columns to include in the job. Example: \'column1,column2,column3\', if not provided, all columns of input table will be included', ), Attribute( - attribute_name="columns_to_exclude", + attribute_name="input_table_columns_to_exclude", parent_fields=["task_parameters"], validator=str, required=False, - comment='Optional comma-separated list of columns to exclude from the job. Example: \'column1,column2,column3\', if not provided, all columns will be included', + comment='Optional comma-separated list of columns to exclude from the job. Example: \'column1,column2,column3\', if not provided, all columns of input table will be included', ) ] ) @@ -200,12 +200,15 @@ def __init__(self, name, pipeline_name, pipeline, job_config): self._glue_schema_name = self.parse_attribute("glue_schema_name") self._sort_key = self.parse_attribute("sort_key") self._custom_columns = self.parse_attribute("custom_columns") - self._columns_to_include = self.parse_attribute("columns_to_include") - self._columns_to_exclude = self.parse_attribute("columns_to_exclude") + self._input_table_columns_to_include = self.parse_attribute("input_table_columns_to_include") + self._input_table_columns_to_exclude = self.parse_attribute("input_table_columns_to_exclude") if self._hash_column and self._updated_at_column: raise ValueError(f"ReverseETLTask: {self._name} hash_column and updated_at_column are mutually exclusive") + if self._input_table_columns_to_include and not self._input_table_columns_to_exclude: + raise ValueError(f"ReverseETLTask: {self._name} _input_table_columns_to_include and _input_table_columns_to_exclude are mutually exclusive") + if self._hash_column or self._updated_at_column: if not self._from_time: raise ValueError(f"ReverseETLTask: {self._name} from_time is required when hash_column or updated_at_column is provided") @@ -325,9 +328,9 @@ def custom_columns(self): return self._custom_columns @property - def columns_to_include(self): - return self._columns_to_include + def input_table_columns_to_include(self): + return self._input_table_columns_to_include @property - def columns_to_exclude(self): - return self._columns_to_exclude + def input_table_columns_to_exclude(self): + return self._input_table_columns_to_exclude From b3a96dc277ad55deb4b6536f75360cfb99147d42 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Wed, 23 Apr 2025 10:35:43 +0200 Subject: [PATCH 146/189] fix: error message for _input_table_columns_to_include and _input_table_columns_to_exclude --- dagger/pipeline/tasks/reverse_etl_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py index 7442d42..6160bfb 100644 --- a/dagger/pipeline/tasks/reverse_etl_task.py +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -206,7 +206,7 @@ def __init__(self, name, pipeline_name, pipeline, job_config): if self._hash_column and self._updated_at_column: raise ValueError(f"ReverseETLTask: {self._name} hash_column and updated_at_column are mutually exclusive") - if self._input_table_columns_to_include and not self._input_table_columns_to_exclude: + if self._input_table_columns_to_include and self._input_table_columns_to_exclude: raise ValueError(f"ReverseETLTask: {self._name} _input_table_columns_to_include and _input_table_columns_to_exclude are mutually exclusive") if self._hash_column or self._updated_at_column: From 18d0ca2efeee4084e3fbc2ef60c2b314a676351a Mon Sep 17 00:00:00 2001 From: claudiazi Date: Wed, 23 Apr 2025 10:46:09 +0200 Subject: [PATCH 147/189] chore: improve the naming of params --- .../airflow/operator_creators/reverse_etl_creator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py index b8b90d5..8446bcb 100644 --- a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py @@ -33,8 +33,8 @@ def __init__(self, task, dag): self._glue_schema_name = task.glue_schema_name self._sort_key = task.sort_key self._custom_columns = task.custom_columns - self._columns_to_include = task.columns_to_include - self._columns_to_exclude = task.columns_to_exclude + self._input_table_columns_to_include = task.input_table_columns_to_include + self._input_table_columns_to_exclude = task.input_table_columns_to_exclude def _generate_command(self): command = BatchCreator._generate_command(self) @@ -77,10 +77,10 @@ def _generate_command(self): command.append(f"--sort_key={self._sort_key}") if self._custom_columns: command.append(f"--custom_columns={self._custom_columns}") - if self._columns_to_include: - command.append(f"--columns_to_include={self._columns_to_include}") - if self._columns_to_exclude: - command.append(f"--columns_to_exclude={self._columns_to_exclude}") + if self._input_table_columns_to_include: + command.append(f"--input_table_columns_to_include={self._input_table_columns_to_include}") + if self._input_table_columns_to_exclude: + command.append(f"--input_table_columns_to_exclude={self._input_table_columns_to_exclude}") return command From c775839f0f83560979fb28a720bf7d999cc527cd Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 25 Apr 2025 09:47:52 +0200 Subject: [PATCH 148/189] fix: wrong validator for full_refresh --- dagger/pipeline/tasks/reverse_etl_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py index c0c6290..10c013d 100644 --- a/dagger/pipeline/tasks/reverse_etl_task.py +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -104,7 +104,7 @@ def init_attributes(cls, orig_cls): Attribute( attribute_name="full_refresh", parent_fields=["task_parameters"], - validator=bool, + validator=str, required=False, comment="If set to True, the job will perform a full refresh instead of an incremental one", ), From 164f7bb38ca1996349dd7b047ae5b357e26e4119 Mon Sep 17 00:00:00 2001 From: raimundovidaljunior Date: Mon, 28 Apr 2025 11:53:10 +0200 Subject: [PATCH 149/189] Adding critical test flag --- .../dag_creator/airflow/operator_creators/soda_creator.py | 4 +++- dagger/pipeline/tasks/soda_task.py | 8 ++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/dagger/dag_creator/airflow/operator_creators/soda_creator.py b/dagger/dag_creator/airflow/operator_creators/soda_creator.py index 5f145c8..8b1f68a 100644 --- a/dagger/dag_creator/airflow/operator_creators/soda_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/soda_creator.py @@ -15,6 +15,7 @@ def __init__(self, task, dag): self._table_name = task.table_name self._output_s3_path = task.output_s3_path self._output_table = task.output_table + self._is_critical_test = task.is_critical_test self._vars = task.vars def _generate_command(self): @@ -23,7 +24,8 @@ def _generate_command(self): command.append(f"--output_s3_path={self._output_s3_path}") command.append(f"--output_table={self._output_table}") - + if self._is_critical_test: + command.append(f"--is_critical_test={self._is_critical_test}") if self._table_name: command.append(f"--table_name={self._table_name}") if self._vars: diff --git a/dagger/pipeline/tasks/soda_task.py b/dagger/pipeline/tasks/soda_task.py index 443f98a..b07a2b9 100644 --- a/dagger/pipeline/tasks/soda_task.py +++ b/dagger/pipeline/tasks/soda_task.py @@ -44,6 +44,14 @@ def init_attributes(cls, orig_cls): required=False, comment="Athena table that will contain the scan results.", ), + Attribute( + attribute_name="is_critical_test", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="True if test run is critical test. Defaults to False", + + ), Attribute( attribute_name="vars", parent_fields=["task_parameters"], From 4adb4e60e5b5c5528d72b356117d07ef6b6d8fc5 Mon Sep 17 00:00:00 2001 From: raimundovidaljunior Date: Mon, 28 Apr 2025 16:39:52 +0200 Subject: [PATCH 150/189] ADding missing params --- dagger/pipeline/tasks/soda_task.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dagger/pipeline/tasks/soda_task.py b/dagger/pipeline/tasks/soda_task.py index b07a2b9..e4c0343 100644 --- a/dagger/pipeline/tasks/soda_task.py +++ b/dagger/pipeline/tasks/soda_task.py @@ -72,6 +72,7 @@ def __init__(self, name, pipeline_name, pipeline, job_config): self._output_table = self.parse_attribute("output_table") or conf.SODA_DEFAULT_OUTPUT_TABLE self._output_s3_path = self.parse_attribute("output_s3_path") or conf.SODA_DEFAULT_OUTPUT_S3_PATH self._table_name = self.parse_attribute("table_name") + self._is_critical_test = self.parse_attribute("is_critical_test") self._vars = self.parse_attribute("vars") @@ -88,6 +89,9 @@ def output_s3_path(self): @property def table_name(self): return self._table_name + @property + def is_critical_test(self): + return self._is_critical_test @property def vars(self): From abd2b4c3a55ca118d9f9556a2711f490cdfe5e3f Mon Sep 17 00:00:00 2001 From: claudiazi Date: Wed, 4 Jun 2025 16:13:35 +0200 Subject: [PATCH 151/189] feat: new param for s3 destination --- .../operator_creators/reverse_etl_creator.py | 6 +++++ dagger/pipeline/tasks/reverse_etl_task.py | 24 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py index 8446bcb..c276201 100644 --- a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py @@ -35,6 +35,8 @@ def __init__(self, task, dag): self._custom_columns = task.custom_columns self._input_table_columns_to_include = task.input_table_columns_to_include self._input_table_columns_to_exclude = task.input_table_columns_to_exclude + self._file_format = task.file_format + self._file_prefix = task.file_prefix def _generate_command(self): command = BatchCreator._generate_command(self) @@ -81,6 +83,10 @@ def _generate_command(self): command.append(f"--input_table_columns_to_include={self._input_table_columns_to_include}") if self._input_table_columns_to_exclude: command.append(f"--input_table_columns_to_exclude={self._input_table_columns_to_exclude}") + if self._file_format: + command.append(f"--file_format={self._file_format}") + if self._file_prefix: + command.append(f"--file_prefix={self._file_prefix}") return command diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py index 3211195..e242cfc 100644 --- a/dagger/pipeline/tasks/reverse_etl_task.py +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -170,6 +170,20 @@ def init_attributes(cls, orig_cls): validator=str, required=False, comment='Optional comma-separated list of columns to exclude from the job. Example: \'column1,column2,column3\', if not provided, all columns of input table will be included', + ), + Attribute( + attribute_name="file_format", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="File format for S3 output: 'json' or 'parquet' (required when output_type is 's3')", + ), + Attribute( + attribute_name="file_prefix", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="File prefix for S3 output files", ) ] ) @@ -202,6 +216,8 @@ def __init__(self, name, pipeline_name, pipeline, job_config): self._custom_columns = self.parse_attribute("custom_columns") self._input_table_columns_to_include = self.parse_attribute("input_table_columns_to_include") self._input_table_columns_to_exclude = self.parse_attribute("input_table_columns_to_exclude") + self._file_format = self.parse_attribute("file_format") + self._file_prefix = self.parse_attribute("file_prefix") if self._hash_column and self._updated_at_column: raise ValueError(f"ReverseETLTask: {self._name} hash_column and updated_at_column are mutually exclusive") @@ -334,3 +350,11 @@ def input_table_columns_to_include(self): @property def input_table_columns_to_exclude(self): return self._input_table_columns_to_exclude + + @property + def file_format(self): + return self._file_format + + @property + def file_prefix(self): + return self._file_prefix From 9c90caba9e6754c5a970e3449403b6b1057e70c7 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Wed, 4 Jun 2025 17:28:39 +0200 Subject: [PATCH 152/189] feat: add region_name for s3 io --- dagger/pipeline/ios/s3_io.py | 17 ++++++++++++----- tests/fixtures/pipeline/ios/s3_io.yaml | 3 ++- tests/pipeline/ios/test_s3_io.py | 17 +++++++++++++---- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/dagger/pipeline/ios/s3_io.py b/dagger/pipeline/ios/s3_io.py index decb0fa..f834a22 100644 --- a/dagger/pipeline/ios/s3_io.py +++ b/dagger/pipeline/ios/s3_io.py @@ -11,6 +11,11 @@ class S3IO(IO): def init_attributes(cls, orig_cls): cls.add_config_attributes( [ + Attribute( + attribute_name="region_name", + required=False, + comment="Only needed for cross region S3 buckets" + ), Attribute( attribute_name="s3_protocol", required=False, @@ -24,22 +29,21 @@ def init_attributes(cls, orig_cls): def __init__(self, io_config, config_location): super().__init__(io_config, config_location) + self._region_name = self.parse_attribute("region_name") self._s3_protocol = self.parse_attribute("s3_protocol") or "s3" self._bucket = normpath(self.parse_attribute("bucket")) self._path = normpath(self.parse_attribute("path")) def alias(self): - return "s3://{path}".format(path=join(self._bucket, self._path)) + return f"s3://{self._region_name or ''}/{join(self._bucket, self._path)}" @property def rendered_name(self): - return "{protocol}://{path}".format( - protocol=self._s3_protocol, path=join(self._bucket, self._path) - ) + return f"{self._s3_protocol}://{join(self._bucket, self._path)}" @property def airflow_name(self): - return "s3-{}".format(join(self._bucket, self._path).replace("/", "-")) + return f"s3-{'-'.join([name_part for name_part in [self._region_name, join(self._bucket, self._path).replace('/', '-')] if name_part])}" @property def bucket(self): @@ -49,3 +53,6 @@ def bucket(self): def path(self): return self._path + @property + def region_name(self): + return self._region_name diff --git a/tests/fixtures/pipeline/ios/s3_io.yaml b/tests/fixtures/pipeline/ios/s3_io.yaml index e5e213f..082fadb 100644 --- a/tests/fixtures/pipeline/ios/s3_io.yaml +++ b/tests/fixtures/pipeline/ios/s3_io.yaml @@ -1,4 +1,5 @@ type: s3 name: test_s3 bucket: test_bucket -path: test_path \ No newline at end of file +path: test_path +region_name: eu_west_1 diff --git a/tests/pipeline/ios/test_s3_io.py b/tests/pipeline/ios/test_s3_io.py index 61e8cd4..0391880 100644 --- a/tests/pipeline/ios/test_s3_io.py +++ b/tests/pipeline/ios/test_s3_io.py @@ -12,14 +12,23 @@ def setUp(self) -> None: def test_properties(self): db_io = s3_io.S3IO(self.config, "/") - self.assertEqual(db_io.alias(), "s3://test_bucket/test_path") + self.assertEqual(db_io.alias(), "s3://eu_west_1/test_bucket/test_path") self.assertEqual(db_io.rendered_name, "s3://test_bucket/test_path") - self.assertEqual(db_io.airflow_name, "s3-test_bucket-test_path") + self.assertEqual(db_io.airflow_name, "s3-eu_west_1-test_bucket-test_path") def test_with_protocol(self): self.config['s3_protocol'] = 's3a' db_io = s3_io.S3IO(self.config, "/") - self.assertEqual(db_io.alias(), "s3://test_bucket/test_path") + self.assertEqual(db_io.alias(), "s3://eu_west_1/test_bucket/test_path") self.assertEqual(db_io.rendered_name, "s3a://test_bucket/test_path") - self.assertEqual(db_io.airflow_name, "s3-test_bucket-test_path") + self.assertEqual(db_io.airflow_name, "s3-eu_west_1-test_bucket-test_path") + + def test_with_region_name(self): + self.config['region_name'] = 'us-west-2' + db_io = s3_io.S3IO(self.config, "/") + + self.assertEqual(db_io.alias(), "s3://us-west-2/test_bucket/test_path") + self.assertEqual(db_io.rendered_name, "s3://test_bucket/test_path") + self.assertEqual(db_io.airflow_name, "s3-us-west-2-test_bucket-test_path") + self.assertEqual(db_io.region_name, "us-west-2") From 3fc8ffd2c890d5eb9b20c048884a0cfa9ce92334 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Thu, 5 Jun 2025 10:58:44 +0200 Subject: [PATCH 153/189] feat: make batch_size and num_threads optional --- .../airflow/operator_creators/reverse_etl_creator.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py index c276201..58b4b65 100644 --- a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py @@ -40,13 +40,14 @@ def __init__(self, task, dag): def _generate_command(self): command = BatchCreator._generate_command(self) - - command.append(f"--num_threads={self._num_threads}") - command.append(f"--batch_size={self._batch_size}") command.append(f"--primary_id_column={self._primary_id_column}") command.append(f"--output_type={self._output_type}") command.append(f"--glue_registry_name={self._glue_registry_name}") + if self._num_threads: + command.append(f"--num_threads={self._num_threads}") + if self._batch_size: + command.append(f"--batch_size={self._batch_size}") if self._assume_role_arn: command.append(f"--assume_role_arn={self._assume_role_arn}") if self._secondary_id_column: From cd80d5edfc4a2638ebf4bbaf3e723b32fec00412 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Wed, 18 Jun 2025 17:44:54 +0200 Subject: [PATCH 154/189] feat: convert soda to deferrable --- dagger/dag_creator/airflow/operator_creators/soda_creator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dagger/dag_creator/airflow/operator_creators/soda_creator.py b/dagger/dag_creator/airflow/operator_creators/soda_creator.py index 8b1f68a..fdf9931 100644 --- a/dagger/dag_creator/airflow/operator_creators/soda_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/soda_creator.py @@ -11,6 +11,7 @@ class SodaCreator(BatchCreator): def __init__(self, task, dag): super().__init__(task, dag) + self.deferrable = True self._absolute_job_name = task.absolute_job_name self._table_name = task.table_name self._output_s3_path = task.output_s3_path From f961d1530456ab49b9937101d4254e24d0c15a17 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Thu, 19 Jun 2025 14:05:04 +0200 Subject: [PATCH 155/189] feat: convert soda to deferrable --- dagger/dag_creator/airflow/operator_creators/soda_creator.py | 1 - dagger/dag_creator/airflow/operators/soda_batch.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/dag_creator/airflow/operator_creators/soda_creator.py b/dagger/dag_creator/airflow/operator_creators/soda_creator.py index fdf9931..8b1f68a 100644 --- a/dagger/dag_creator/airflow/operator_creators/soda_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/soda_creator.py @@ -11,7 +11,6 @@ class SodaCreator(BatchCreator): def __init__(self, task, dag): super().__init__(task, dag) - self.deferrable = True self._absolute_job_name = task.absolute_job_name self._table_name = task.table_name self._output_s3_path = task.output_s3_path diff --git a/dagger/dag_creator/airflow/operators/soda_batch.py b/dagger/dag_creator/airflow/operators/soda_batch.py index d67a26c..f5d492f 100644 --- a/dagger/dag_creator/airflow/operators/soda_batch.py +++ b/dagger/dag_creator/airflow/operators/soda_batch.py @@ -3,3 +3,4 @@ class SodaBatchOperator(AWSBatchOperator): custom_operator_name = 'Soda' ui_color = "#e4f0e7" + deferrable = True From 16493761dd4e8174510c6f719443c27624cd7f9b Mon Sep 17 00:00:00 2001 From: claudiazi Date: Thu, 19 Jun 2025 14:06:50 +0200 Subject: [PATCH 156/189] feat: convert soda to deferrable --- dagger/dag_creator/airflow/operator_creators/soda_creator.py | 1 + dagger/dag_creator/airflow/operators/soda_batch.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/dagger/dag_creator/airflow/operator_creators/soda_creator.py b/dagger/dag_creator/airflow/operator_creators/soda_creator.py index 8b1f68a..9aa2321 100644 --- a/dagger/dag_creator/airflow/operator_creators/soda_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/soda_creator.py @@ -46,6 +46,7 @@ def _create_operator(self, **kwargs): job_queue=self._task.job_queue, container_overrides=overrides, awslogs_enabled=True, + deferrable=True, **kwargs, ) return batch_op diff --git a/dagger/dag_creator/airflow/operators/soda_batch.py b/dagger/dag_creator/airflow/operators/soda_batch.py index f5d492f..d67a26c 100644 --- a/dagger/dag_creator/airflow/operators/soda_batch.py +++ b/dagger/dag_creator/airflow/operators/soda_batch.py @@ -3,4 +3,3 @@ class SodaBatchOperator(AWSBatchOperator): custom_operator_name = 'Soda' ui_color = "#e4f0e7" - deferrable = True From fca31a6ebaa6be2a358a23a07bdac88d677c9eb2 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Thu, 19 Jun 2025 14:24:25 +0200 Subject: [PATCH 157/189] fix: add extra package --- reqs/base.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/reqs/base.txt b/reqs/base.txt index 279ed4b..18ac07e 100644 --- a/reqs/base.txt +++ b/reqs/base.txt @@ -1,3 +1,4 @@ +aiobotocore>=2.5.0 click==8.1.3 croniter==2.0.2 envyaml==1.10.211231 From 4c009d31aef3975a96c205dd2fa028357383b785 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Tue, 24 Jun 2025 12:20:30 +0200 Subject: [PATCH 158/189] feat: add real logs for deferrable batch job --- .../airflow/operators/awsbatch_operator.py | 51 +++++++++++++++++-- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py index 23b3596..952aae6 100644 --- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py +++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py @@ -1,11 +1,13 @@ -from airflow.providers.amazon.aws.operators.batch import BatchOperator -from airflow.utils.context import Context +from typing import Any, Optional + from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.links.batch import ( BatchJobDefinitionLink, BatchJobQueueLink, ) from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink +from airflow.providers.amazon.aws.operators.batch import BatchOperator +from airflow.utils.context import Context class AWSBatchOperator(BatchOperator): @@ -69,7 +71,6 @@ def monitor_job(self, context: Context): if awslogs: self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id) - link_builder = CloudWatchEventsLink() for log in awslogs: self.log.info(self._format_cloudwatch_link(**log)) if len(awslogs) > 1: @@ -88,3 +89,47 @@ def monitor_job(self, context: Context): self.hook.check_job_success(self.job_id) self.log.info("AWS Batch job (%s) succeeded", self.job_id) + + def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str: + """Execute when the trigger fires - fetch logs and complete the task.""" + # Call parent's execute_complete first + job_id = super().execute_complete(context, event) + + # Only fetch logs if we're in deferrable mode and awslogs are enabled + # In non-deferrable mode, logs are already fetched by monitor_job() + if self.deferrable and self.awslogs_enabled and job_id: + # Set job_id for our log fetching methods + self.job_id = job_id + + # Get job logs and display them + try: + # Use the log fetcher to display container logs + log_fetcher = self._get_batch_log_fetcher() + if log_fetcher: + log_fetcher.get_all_logs() + except Exception as e: + self.log.warning("Could not fetch batch job logs: %s", e) + + # Get CloudWatch log links + awslogs = [] + try: + awslogs = self.hook.get_job_all_awslogs_info(self.job_id) + except AirflowException as ae: + self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae) + + if awslogs: + self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id) + for log in awslogs: + self.log.info(self._format_cloudwatch_link(**log)) + + CloudWatchEventsLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + **awslogs[0], + ) + + self.log.info("AWS Batch job (%s) succeeded", self.job_id) + + return job_id From 7a5af26a32deae53919ae2e3100328a616c2cb03 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Tue, 24 Jun 2025 13:35:48 +0200 Subject: [PATCH 159/189] feat: improve the log message --- dagger/dag_creator/airflow/operators/awsbatch_operator.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py index 952aae6..95e4a1b 100644 --- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py +++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py @@ -104,9 +104,13 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N # Get job logs and display them try: # Use the log fetcher to display container logs - log_fetcher = self._get_batch_log_fetcher() + log_fetcher = self._get_batch_log_fetcher(job_id) if log_fetcher: - log_fetcher.get_all_logs() + # Get the last 10,000 log messages (CloudWatch limit) + self.log.info("Fetch the latest 100 messages from cloudwatch:") + log_messages = log_fetcher.get_last_log_messages(100) + for message in log_messages: + self.log.info(message) except Exception as e: self.log.warning("Could not fetch batch job logs: %s", e) From b38ede94955ebdf7eb4dbdd335dad61bb31c0057 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Tue, 24 Jun 2025 13:48:24 +0200 Subject: [PATCH 160/189] chore: decrease the size of logs --- dagger/dag_creator/airflow/operators/awsbatch_operator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py index 95e4a1b..1b56ca7 100644 --- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py +++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py @@ -106,9 +106,9 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N # Use the log fetcher to display container logs log_fetcher = self._get_batch_log_fetcher(job_id) if log_fetcher: - # Get the last 10,000 log messages (CloudWatch limit) - self.log.info("Fetch the latest 100 messages from cloudwatch:") - log_messages = log_fetcher.get_last_log_messages(100) + # Get the last 50 log messages + self.log.info("Fetch the latest 50 messages from cloudwatch:") + log_messages = log_fetcher.get_last_log_messages(50) for message in log_messages: self.log.info(message) except Exception as e: From 9603fb51a7bb6efdf2e0b9facccc0e85904f67eb Mon Sep 17 00:00:00 2001 From: "raimundo.vidal" Date: Wed, 16 Jul 2025 15:11:58 +0200 Subject: [PATCH 161/189] upgrading flask and requests --- dockers/dagger_ui/requirements.txt | 4 ++-- reqs/ui.txt | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dockers/dagger_ui/requirements.txt b/dockers/dagger_ui/requirements.txt index 52aae65..10f95e3 100644 --- a/dockers/dagger_ui/requirements.txt +++ b/dockers/dagger_ui/requirements.txt @@ -1,6 +1,6 @@ elasticsearch==7.17.7 Flask-WTF==0.15.1 -flask==2.2.2 +flask==2.2.5 python-dotenv==0.21.0 -requests==2.28.1 +requests==2.32.4 WTForms==2.3.3 diff --git a/reqs/ui.txt b/reqs/ui.txt index 52aae65..10f95e3 100644 --- a/reqs/ui.txt +++ b/reqs/ui.txt @@ -1,6 +1,6 @@ elasticsearch==7.17.7 Flask-WTF==0.15.1 -flask==2.2.2 +flask==2.2.5 python-dotenv==0.21.0 -requests==2.28.1 +requests==2.32.4 WTForms==2.3.3 From 7e4b663bbc17c53c1bd571722b79290547615251 Mon Sep 17 00:00:00 2001 From: Kiran Vasudev Date: Wed, 6 Aug 2025 18:12:35 +0200 Subject: [PATCH 162/189] parse dagrun_timeout as a timedelta instead of int --- dagger/pipeline/pipeline.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dagger/pipeline/pipeline.py b/dagger/pipeline/pipeline.py index 32c45a2..d0d28ee 100644 --- a/dagger/pipeline/pipeline.py +++ b/dagger/pipeline/pipeline.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timedelta from os.path import join, relpath from dagger import conf @@ -71,6 +71,7 @@ def __init__(self, directory: str, config: dict): self._alerts = [] self._alert_factory = AlertFactory() self.process_alerts(config["alerts"] or []) + self.process_dag_parameters(self._parameters) @property def directory(self): @@ -129,3 +130,8 @@ def process_alerts(self, alert_configs): alert_type, join(self.directory, "pipeline.yaml"), alert_config ) ) + def process_dag_parameters(self, dag_parameters):#TODO: create long term fix for this + if dag_parameters is not None: + for key, value in dag_parameters.items(): + if key == 'dagrun_timeout': + self._parameters[key] = eval(value) From cd1c90e3840074863b8c5648353b38d980b67632 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Thu, 7 Aug 2025 14:38:29 +0200 Subject: [PATCH 163/189] feat: fetch the logs from cloudwatch before BatchOperator.execute_complete --- .../airflow/operators/awsbatch_operator.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py index 1b56ca7..2b6950f 100644 --- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py +++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py @@ -92,16 +92,13 @@ def monitor_job(self, context: Context): def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str: """Execute when the trigger fires - fetch logs and complete the task.""" - # Call parent's execute_complete first - job_id = super().execute_complete(context, event) - - # Only fetch logs if we're in deferrable mode and awslogs are enabled - # In non-deferrable mode, logs are already fetched by monitor_job() - if self.deferrable and self.awslogs_enabled and job_id: + # Fetch logs before calling parent's execute_complete for both success and failure cases + if self.deferrable and self.awslogs_enabled and event and event.get("job_id"): + job_id = event["job_id"] # Set job_id for our log fetching methods self.job_id = job_id - # Get job logs and display them + # Get job logs and display them for both successful and failed jobs try: # Use the log fetcher to display container logs log_fetcher = self._get_batch_log_fetcher(job_id) @@ -133,7 +130,6 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N aws_partition=self.hook.conn_partition, **awslogs[0], ) - - self.log.info("AWS Batch job (%s) succeeded", self.job_id) - - return job_id + + # Call parent's execute_complete which will handle success/failure logic + return super().execute_complete(context, event) From 6733f4a703da64a05d7edb86aac8947b76893bfb Mon Sep 17 00:00:00 2001 From: claudiazi Date: Thu, 7 Aug 2025 16:37:42 +0200 Subject: [PATCH 164/189] feat: add fetch logs if the execute fails --- .../airflow/operators/awsbatch_operator.py | 48 ++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py index 2b6950f..9eca269 100644 --- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py +++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py @@ -1,6 +1,6 @@ from typing import Any, Optional -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferralError from airflow.providers.amazon.aws.links.batch import ( BatchJobDefinitionLink, BatchJobQueueLink, @@ -133,3 +133,49 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N # Call parent's execute_complete which will handle success/failure logic return super().execute_complete(context, event) + + def _fetch_batch_logs(self): + """Fetch and display batch job logs for debugging failed jobs.""" + if not self.job_id or not self.awslogs_enabled: + return + + try: + # Use the log fetcher to display container logs + log_fetcher = self._get_batch_log_fetcher(self.job_id) + if log_fetcher: + # Get the last 50 log messages + self.log.info("Fetch the latest 50 messages from cloudwatch:") + log_messages = log_fetcher.get_last_log_messages(50) + for message in log_messages: + self.log.info(message) + except Exception as e: + self.log.warning("Could not fetch batch job logs: %s", e) + + # Get CloudWatch log links + awslogs = [] + try: + awslogs = self.hook.get_job_all_awslogs_info(self.job_id) + except AirflowException as ae: + self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae) + + if awslogs: + self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id) + for log in awslogs: + self.log.info(self._format_cloudwatch_link(**log)) + + def execute(self, context: Context): + """Override execute to handle failures and fetch logs.""" + try: + return super().execute(context) + except (TaskDeferralError, AirflowException) as e: + # When deferred task fails or other batch-related errors occur, fetch logs if we have a job_id + if self.deferrable and self.job_id and self.awslogs_enabled: + self.log.info("Task failed (deferrable mode), attempting to fetch batch job logs...") + self._fetch_batch_logs() + raise + except Exception as e: + # For any other unexpected exception, still try to fetch logs if we have job info + if self.deferrable and self.job_id and self.awslogs_enabled: + self.log.info("Unexpected error in deferrable batch task, attempting to fetch logs...") + self._fetch_batch_logs() + raise From abfeac2a01d557c622ea0345cc0802ac4c09c1a2 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Thu, 7 Aug 2025 14:38:29 +0200 Subject: [PATCH 165/189] feat: fetch the logs from cloudwatch before BatchOperator.execute_complete --- .../airflow/operators/awsbatch_operator.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py index 1b56ca7..2b6950f 100644 --- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py +++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py @@ -92,16 +92,13 @@ def monitor_job(self, context: Context): def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str: """Execute when the trigger fires - fetch logs and complete the task.""" - # Call parent's execute_complete first - job_id = super().execute_complete(context, event) - - # Only fetch logs if we're in deferrable mode and awslogs are enabled - # In non-deferrable mode, logs are already fetched by monitor_job() - if self.deferrable and self.awslogs_enabled and job_id: + # Fetch logs before calling parent's execute_complete for both success and failure cases + if self.deferrable and self.awslogs_enabled and event and event.get("job_id"): + job_id = event["job_id"] # Set job_id for our log fetching methods self.job_id = job_id - # Get job logs and display them + # Get job logs and display them for both successful and failed jobs try: # Use the log fetcher to display container logs log_fetcher = self._get_batch_log_fetcher(job_id) @@ -133,7 +130,6 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N aws_partition=self.hook.conn_partition, **awslogs[0], ) - - self.log.info("AWS Batch job (%s) succeeded", self.job_id) - - return job_id + + # Call parent's execute_complete which will handle success/failure logic + return super().execute_complete(context, event) From a65e25e0bb1f2968da65ced8b13a1a5ca64fa0c7 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Thu, 7 Aug 2025 16:37:42 +0200 Subject: [PATCH 166/189] feat: add fetch logs if the execute fails --- .../airflow/operators/awsbatch_operator.py | 48 ++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py index 2b6950f..9eca269 100644 --- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py +++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py @@ -1,6 +1,6 @@ from typing import Any, Optional -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferralError from airflow.providers.amazon.aws.links.batch import ( BatchJobDefinitionLink, BatchJobQueueLink, @@ -133,3 +133,49 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N # Call parent's execute_complete which will handle success/failure logic return super().execute_complete(context, event) + + def _fetch_batch_logs(self): + """Fetch and display batch job logs for debugging failed jobs.""" + if not self.job_id or not self.awslogs_enabled: + return + + try: + # Use the log fetcher to display container logs + log_fetcher = self._get_batch_log_fetcher(self.job_id) + if log_fetcher: + # Get the last 50 log messages + self.log.info("Fetch the latest 50 messages from cloudwatch:") + log_messages = log_fetcher.get_last_log_messages(50) + for message in log_messages: + self.log.info(message) + except Exception as e: + self.log.warning("Could not fetch batch job logs: %s", e) + + # Get CloudWatch log links + awslogs = [] + try: + awslogs = self.hook.get_job_all_awslogs_info(self.job_id) + except AirflowException as ae: + self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae) + + if awslogs: + self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id) + for log in awslogs: + self.log.info(self._format_cloudwatch_link(**log)) + + def execute(self, context: Context): + """Override execute to handle failures and fetch logs.""" + try: + return super().execute(context) + except (TaskDeferralError, AirflowException) as e: + # When deferred task fails or other batch-related errors occur, fetch logs if we have a job_id + if self.deferrable and self.job_id and self.awslogs_enabled: + self.log.info("Task failed (deferrable mode), attempting to fetch batch job logs...") + self._fetch_batch_logs() + raise + except Exception as e: + # For any other unexpected exception, still try to fetch logs if we have job info + if self.deferrable and self.job_id and self.awslogs_enabled: + self.log.info("Unexpected error in deferrable batch task, attempting to fetch logs...") + self._fetch_batch_logs() + raise From 007dcbac0655054c061805765988b4bc1164c51d Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 8 Aug 2025 12:51:15 +0200 Subject: [PATCH 167/189] fix: handle all failure scenarios --- .../airflow/operators/awsbatch_operator.py | 155 +++++++++--------- 1 file changed, 79 insertions(+), 76 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py index 9eca269..b4173e2 100644 --- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py +++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, Union from airflow.exceptions import AirflowException, TaskDeferralError from airflow.providers.amazon.aws.links.batch import ( @@ -10,6 +10,18 @@ from airflow.utils.context import Context +def _format_extra_info(error_msg: str, last_logs: list[str], cloudwatch_link: Optional[str]) -> str: + """Format the enhanced error message with logs and link.""" + extra_info = [] + if cloudwatch_link: + extra_info.append(f"CloudWatch Logs: {cloudwatch_link}") + if last_logs: + extra_info.append("Last log lines:\n" + "\n".join(last_logs[-5:])) + if extra_info: + return f"{error_msg}\n\n" + "\n".join(extra_info) + return error_msg + + class AWSBatchOperator(BatchOperator): @staticmethod def _format_cloudwatch_link(awslogs_region: str, awslogs_group: str, awslogs_stream_name: str): @@ -90,92 +102,83 @@ def monitor_job(self, context: Context): self.hook.check_job_success(self.job_id) self.log.info("AWS Batch job (%s) succeeded", self.job_id) - def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str: - """Execute when the trigger fires - fetch logs and complete the task.""" - # Fetch logs before calling parent's execute_complete for both success and failure cases - if self.deferrable and self.awslogs_enabled and event and event.get("job_id"): - job_id = event["job_id"] - # Set job_id for our log fetching methods - self.job_id = job_id - - # Get job logs and display them for both successful and failed jobs + def _fetch_and_log_cloudwatch(self, context: Context, job_id: str) -> tuple[list[str], Optional[str]]: + """ + Fetch CloudWatch logs for the given job_id, log them to Airflow, + and return (last_logs, cloudwatch_link). + """ + last_logs: list[str] = [] + cloudwatch_link: Optional[str] = None + + if self.awslogs_enabled: + # Fetch last 50 log messages try: - # Use the log fetcher to display container logs log_fetcher = self._get_batch_log_fetcher(job_id) if log_fetcher: - # Get the last 50 log messages - self.log.info("Fetch the latest 50 messages from cloudwatch:") - log_messages = log_fetcher.get_last_log_messages(50) - for message in log_messages: + self.log.info("Fetching the latest 50 messages from CloudWatch:") + last_logs = log_fetcher.get_last_log_messages(50) + for message in last_logs: self.log.info(message) except Exception as e: self.log.warning("Could not fetch batch job logs: %s", e) - - # Get CloudWatch log links - awslogs = [] + + # Fetch CloudWatch log link try: - awslogs = self.hook.get_job_all_awslogs_info(self.job_id) + awslogs = self.hook.get_job_all_awslogs_info(job_id) except AirflowException as ae: - self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae) - - if awslogs: - self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id) - for log in awslogs: - self.log.info(self._format_cloudwatch_link(**log)) - - CloudWatchEventsLink.persist( - context=context, - operator=self, - region_name=self.hook.conn_region_name, - aws_partition=self.hook.conn_partition, - **awslogs[0], + self.log.warning("Cannot determine where to find the AWS logs: %s", ae) + awslogs = [] + else: + if awslogs: + cloudwatch_link = self._format_cloudwatch_link(**awslogs[0]) + self.log.info("AWS Batch job (%s) CloudWatch Events details found:", job_id) + for log in awslogs: + self.log.info(self._format_cloudwatch_link(**log)) + CloudWatchEventsLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + **awslogs[0], + ) + + return last_logs, cloudwatch_link + + def execute(self, context: Context) -> Union[str, None]: + """Submit and monitor an AWS Batch job, including early failures.""" + try: + result = super().execute(context) + return result + except TaskDeferralError as e: + # Trigger itself failed — try to fetch logs if job_id is available + if self.deferrable and self.job_id: + last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, self.job_id) + raise AirflowException( + _format_extra_info(f"Trigger failed for job {self.job_id}: {e}", last_logs, cloudwatch_link) ) + raise + except AirflowException as e: + # Covers immediate failures before deferral (job already FAILED) + if self.job_id: + last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, self.job_id) + raise AirflowException(_format_extra_info(str(e), last_logs, cloudwatch_link)) + raise - # Call parent's execute_complete which will handle success/failure logic - return super().execute_complete(context, event) + def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str: + """Execute when the trigger fires - fetch logs first, then check job status.""" + job_id = event.get("job_id") if event else None + if not job_id: + raise AirflowException("No job_id found in event data from trigger.") - def _fetch_batch_logs(self): - """Fetch and display batch job logs for debugging failed jobs.""" - if not self.job_id or not self.awslogs_enabled: - return + self.job_id = job_id - try: - # Use the log fetcher to display container logs - log_fetcher = self._get_batch_log_fetcher(self.job_id) - if log_fetcher: - # Get the last 50 log messages - self.log.info("Fetch the latest 50 messages from cloudwatch:") - log_messages = log_fetcher.get_last_log_messages(50) - for message in log_messages: - self.log.info(message) - except Exception as e: - self.log.warning("Could not fetch batch job logs: %s", e) - - # Get CloudWatch log links - awslogs = [] - try: - awslogs = self.hook.get_job_all_awslogs_info(self.job_id) - except AirflowException as ae: - self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae) + # Always fetch logs before checking status + last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, job_id) - if awslogs: - self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id) - for log in awslogs: - self.log.info(self._format_cloudwatch_link(**log)) - - def execute(self, context: Context): - """Override execute to handle failures and fetch logs.""" try: - return super().execute(context) - except (TaskDeferralError, AirflowException) as e: - # When deferred task fails or other batch-related errors occur, fetch logs if we have a job_id - if self.deferrable and self.job_id and self.awslogs_enabled: - self.log.info("Task failed (deferrable mode), attempting to fetch batch job logs...") - self._fetch_batch_logs() - raise - except Exception as e: - # For any other unexpected exception, still try to fetch logs if we have job info - if self.deferrable and self.job_id and self.awslogs_enabled: - self.log.info("Unexpected error in deferrable batch task, attempting to fetch logs...") - self._fetch_batch_logs() - raise + self.hook.check_job_success(job_id) + except AirflowException as e: + raise AirflowException(_format_extra_info(str(e), last_logs, cloudwatch_link)) + + self.log.info("AWS Batch job (%s) succeeded", job_id) + return job_id From 78b47726ff1bc54171f42fe37a6f47ad9245a99a Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 8 Aug 2025 14:16:43 +0200 Subject: [PATCH 168/189] fix: overwrite resume_execution --- .../airflow/operators/awsbatch_operator.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py index b4173e2..9f80dec 100644 --- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py +++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py @@ -182,3 +182,24 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N self.log.info("AWS Batch job (%s) succeeded", job_id) return job_id + + def resume_execution(self, next_method: str, next_kwargs: Optional[dict[str, Any]], context: Context): + """Override resume_execution to handle trigger failures and fetch logs.""" + self.log.info(f"AWSBatchOperator.resume_execution called with next_method='{next_method}'") + self.log.info(f"job_id available: {hasattr(self, 'job_id') and bool(self.job_id)}") + self.log.info(f"awslogs_enabled: {getattr(self, 'awslogs_enabled', False)}") + + try: + return super().resume_execution(next_method, next_kwargs, context) + except TaskDeferralError as e: + # When trigger fails, try to fetch logs if job_id is available + if hasattr(self, 'job_id') and self.job_id and self.awslogs_enabled: + self.log.info("Trigger failed - attempting to fetch batch job logs...") + last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, self.job_id) + # Re-raise with enhanced error message + raise AirflowException( + _format_extra_info(f"Trigger failed for job {self.job_id}: {e}", last_logs, cloudwatch_link) + ) + else: + self.log.warning(f"Cannot fetch logs: job_id={getattr(self, 'job_id', None)}, awslogs_enabled={getattr(self, 'awslogs_enabled', False)}") + raise From 892e6189cba22a36216af8032f5cc2b0d3fee464 Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 8 Aug 2025 14:51:05 +0200 Subject: [PATCH 169/189] fix: save job_id in xcom --- .../airflow/operators/awsbatch_operator.py | 60 +++++++++++-------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py index 9f80dec..9374813 100644 --- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py +++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py @@ -146,23 +146,27 @@ def _fetch_and_log_cloudwatch(self, context: Context, job_id: str) -> tuple[list def execute(self, context: Context) -> Union[str, None]: """Submit and monitor an AWS Batch job, including early failures.""" - try: - result = super().execute(context) - return result - except TaskDeferralError as e: - # Trigger itself failed — try to fetch logs if job_id is available - if self.deferrable and self.job_id: - last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, self.job_id) - raise AirflowException( - _format_extra_info(f"Trigger failed for job {self.job_id}: {e}", last_logs, cloudwatch_link) - ) - raise - except AirflowException as e: - # Covers immediate failures before deferral (job already FAILED) - if self.job_id: - last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, self.job_id) - raise AirflowException(_format_extra_info(str(e), last_logs, cloudwatch_link)) - raise + # First call parent execute, which will submit the job and possibly defer + result = super().execute(context) + + # If we reach here without exception, the task completed (didn't defer) + return result + + def defer(self, *, trigger, method_name: str = "execute_complete", kwargs=None, timeout=None): + """Override defer to store job_id in XCom before deferring.""" + # Store job_id in XCom so it's available when the task resumes + if hasattr(self, 'job_id') and self.job_id: + # Get task instance from current context + from airflow.operators.python import get_current_context + try: + context = get_current_context() + context['task_instance'].xcom_push(key='batch_job_id', value=self.job_id) + self.log.info(f"Stored job_id in XCom before deferring: {self.job_id}") + except Exception as e: + self.log.warning(f"Could not store job_id in XCom: {e}") + + # Call parent defer method + super().defer(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str: """Execute when the trigger fires - fetch logs first, then check job status.""" @@ -185,21 +189,29 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N def resume_execution(self, next_method: str, next_kwargs: Optional[dict[str, Any]], context: Context): """Override resume_execution to handle trigger failures and fetch logs.""" - self.log.info(f"AWSBatchOperator.resume_execution called with next_method='{next_method}'") - self.log.info(f"job_id available: {hasattr(self, 'job_id') and bool(self.job_id)}") - self.log.info(f"awslogs_enabled: {getattr(self, 'awslogs_enabled', False)}") + # Retrieve job_id from XCom if not available on the instance + if not hasattr(self, 'job_id') or not self.job_id: + task_instance = context.get('task_instance') + if task_instance: + try: + stored_job_id = task_instance.xcom_pull(task_ids=task_instance.task_id, key='batch_job_id') + if stored_job_id: + self.job_id = stored_job_id + self.log.info(f"Retrieved job_id from XCom: {stored_job_id}") + except Exception as e: + self.log.debug(f"Could not retrieve job_id from XCom: {e}") try: return super().resume_execution(next_method, next_kwargs, context) except TaskDeferralError as e: # When trigger fails, try to fetch logs if job_id is available if hasattr(self, 'job_id') and self.job_id and self.awslogs_enabled: - self.log.info("Trigger failed - attempting to fetch batch job logs...") + self.log.info("Batch job trigger failed - fetching CloudWatch logs...") last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, self.job_id) - # Re-raise with enhanced error message + # Re-raise with enhanced error message including logs raise AirflowException( - _format_extra_info(f"Trigger failed for job {self.job_id}: {e}", last_logs, cloudwatch_link) + _format_extra_info(f"Batch job {self.job_id} failed: {e}", last_logs, cloudwatch_link) ) else: - self.log.warning(f"Cannot fetch logs: job_id={getattr(self, 'job_id', None)}, awslogs_enabled={getattr(self, 'awslogs_enabled', False)}") + self.log.warning("Cannot fetch logs for failed batch job - job_id or awslogs_enabled not available") raise From 3e6fe6d61cc5105ed6da9f9fe0b40649c771e3ed Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 8 Aug 2025 15:12:53 +0200 Subject: [PATCH 170/189] chore: simplify the logic --- .../airflow/operators/awsbatch_operator.py | 28 ++++--------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py index 9374813..e052724 100644 --- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py +++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py @@ -152,22 +152,6 @@ def execute(self, context: Context) -> Union[str, None]: # If we reach here without exception, the task completed (didn't defer) return result - def defer(self, *, trigger, method_name: str = "execute_complete", kwargs=None, timeout=None): - """Override defer to store job_id in XCom before deferring.""" - # Store job_id in XCom so it's available when the task resumes - if hasattr(self, 'job_id') and self.job_id: - # Get task instance from current context - from airflow.operators.python import get_current_context - try: - context = get_current_context() - context['task_instance'].xcom_push(key='batch_job_id', value=self.job_id) - self.log.info(f"Stored job_id in XCom before deferring: {self.job_id}") - except Exception as e: - self.log.warning(f"Could not store job_id in XCom: {e}") - - # Call parent defer method - super().defer(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) - def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str: """Execute when the trigger fires - fetch logs first, then check job status.""" job_id = event.get("job_id") if event else None @@ -189,17 +173,17 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N def resume_execution(self, next_method: str, next_kwargs: Optional[dict[str, Any]], context: Context): """Override resume_execution to handle trigger failures and fetch logs.""" - # Retrieve job_id from XCom if not available on the instance + # Retrieve job_id from batch_job_details XCom if not available on the instance if not hasattr(self, 'job_id') or not self.job_id: task_instance = context.get('task_instance') if task_instance: try: - stored_job_id = task_instance.xcom_pull(task_ids=task_instance.task_id, key='batch_job_id') - if stored_job_id: - self.job_id = stored_job_id - self.log.info(f"Retrieved job_id from XCom: {stored_job_id}") + batch_job_details = task_instance.xcom_pull(task_ids=task_instance.task_id, key='batch_job_details') + if batch_job_details and 'job_id' in batch_job_details: + self.job_id = batch_job_details['job_id'] + self.log.info(f"Retrieved job_id from batch_job_details XCom: {self.job_id}") except Exception as e: - self.log.debug(f"Could not retrieve job_id from XCom: {e}") + self.log.debug(f"Could not retrieve job_id from batch_job_details XCom: {e}") try: return super().resume_execution(next_method, next_kwargs, context) From 8db638a24d2fc1f83ecefb2e215402d6f0aa6a1b Mon Sep 17 00:00:00 2001 From: claudiazi Date: Fri, 8 Aug 2025 15:42:11 +0200 Subject: [PATCH 171/189] chore: clean the logic --- .../airflow/operators/awsbatch_operator.py | 71 +++++++------------ 1 file changed, 26 insertions(+), 45 deletions(-) diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py index e052724..3297e85 100644 --- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py +++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Optional from airflow.exceptions import AirflowException, TaskDeferralError from airflow.providers.amazon.aws.links.batch import ( @@ -102,55 +102,36 @@ def monitor_job(self, context: Context): self.hook.check_job_success(self.job_id) self.log.info("AWS Batch job (%s) succeeded", self.job_id) - def _fetch_and_log_cloudwatch(self, context: Context, job_id: str) -> tuple[list[str], Optional[str]]: - """ - Fetch CloudWatch logs for the given job_id, log them to Airflow, - and return (last_logs, cloudwatch_link). - """ + def _fetch_and_log_cloudwatch(self, job_id: str) -> tuple[list[str], Optional[str]]: + """Fetch CloudWatch logs for the given job_id and return (last_logs, cloudwatch_link).""" last_logs: list[str] = [] cloudwatch_link: Optional[str] = None - if self.awslogs_enabled: - # Fetch last 50 log messages - try: - log_fetcher = self._get_batch_log_fetcher(job_id) - if log_fetcher: - self.log.info("Fetching the latest 50 messages from CloudWatch:") - last_logs = log_fetcher.get_last_log_messages(50) + if not self.awslogs_enabled: + return last_logs, cloudwatch_link + + # Fetch last log messages + try: + log_fetcher = self._get_batch_log_fetcher(job_id) + if log_fetcher: + last_logs = log_fetcher.get_last_log_messages(50) + if last_logs: + self.log.info("CloudWatch logs (last 50 messages):") for message in last_logs: self.log.info(message) - except Exception as e: - self.log.warning("Could not fetch batch job logs: %s", e) - - # Fetch CloudWatch log link - try: - awslogs = self.hook.get_job_all_awslogs_info(job_id) - except AirflowException as ae: - self.log.warning("Cannot determine where to find the AWS logs: %s", ae) - awslogs = [] - else: - if awslogs: - cloudwatch_link = self._format_cloudwatch_link(**awslogs[0]) - self.log.info("AWS Batch job (%s) CloudWatch Events details found:", job_id) - for log in awslogs: - self.log.info(self._format_cloudwatch_link(**log)) - CloudWatchEventsLink.persist( - context=context, - operator=self, - region_name=self.hook.conn_region_name, - aws_partition=self.hook.conn_partition, - **awslogs[0], - ) + except Exception as e: + self.log.warning("Could not fetch batch job logs: %s", e) - return last_logs, cloudwatch_link + # Get CloudWatch log link + try: + awslogs = self.hook.get_job_all_awslogs_info(job_id) + if awslogs: + cloudwatch_link = self._format_cloudwatch_link(**awslogs[0]) + self.log.info("CloudWatch link: %s", cloudwatch_link) + except AirflowException as e: + self.log.warning("Cannot determine CloudWatch log link: %s", e) - def execute(self, context: Context) -> Union[str, None]: - """Submit and monitor an AWS Batch job, including early failures.""" - # First call parent execute, which will submit the job and possibly defer - result = super().execute(context) - - # If we reach here without exception, the task completed (didn't defer) - return result + return last_logs, cloudwatch_link def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = None) -> str: """Execute when the trigger fires - fetch logs first, then check job status.""" @@ -161,7 +142,7 @@ def execute_complete(self, context: Context, event: Optional[dict[str, Any]] = N self.job_id = job_id # Always fetch logs before checking status - last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, job_id) + last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(job_id) try: self.hook.check_job_success(job_id) @@ -191,7 +172,7 @@ def resume_execution(self, next_method: str, next_kwargs: Optional[dict[str, Any # When trigger fails, try to fetch logs if job_id is available if hasattr(self, 'job_id') and self.job_id and self.awslogs_enabled: self.log.info("Batch job trigger failed - fetching CloudWatch logs...") - last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(context, self.job_id) + last_logs, cloudwatch_link = self._fetch_and_log_cloudwatch(self.job_id) # Re-raise with enhanced error message including logs raise AirflowException( _format_extra_info(f"Batch job {self.job_id} failed: {e}", last_logs, cloudwatch_link) From 925029f30a84af8bd72a4b1fd0ab4fb9bb5cdf7f Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Sun, 19 Oct 2025 11:25:43 +0200 Subject: [PATCH 172/189] Relaxing click version to avoid conflict with dbt --- reqs/base.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reqs/base.txt b/reqs/base.txt index 18ac07e..f88ab98 100644 --- a/reqs/base.txt +++ b/reqs/base.txt @@ -1,5 +1,5 @@ aiobotocore>=2.5.0 -click==8.1.3 +click>=8.1.3 croniter==2.0.2 envyaml==1.10.211231 mergedeep==1.3.4 From 29203d309006bc9c89bb536d1e9cbbc6fa5854f0 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Mon, 3 Nov 2025 13:31:00 +0100 Subject: [PATCH 173/189] Upgrading local env --- Makefile | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 8872dba..32f1f93 100644 --- a/Makefile +++ b/Makefile @@ -96,17 +96,16 @@ install: clean ## install the package to the active Python's site-packages install-dev: clean ## install the package to the active Python's site-packages - virtualenv -p python3.9 venv; \ + virtualenv -p python3.12 venv; \ source venv/bin/activate; \ - python -m pip install --upgrade pip; \ - python setup.py install; \ + python -m pip install --upgrade pip setuptools wheel; \ pip install -e . ; \ SYSTEM_VERSION_COMPAT=0 CFLAGS='-std=c++20' pip install -r reqs/dev.txt -r reqs/test.txt install-test: clean ## install the package to the active Python's site-packages - virtualenv -p python3.9 venv; \ + virtualenv -p python3.12 venv; \ source venv/bin/activate; \ - python -m pip install --upgrade pip; \ + python -m pip install --upgrade pip setuptools wheel; \ pip install -r reqs/test.txt -r reqs/base.txt install-ui: clean ## install the package to the active Python's site-packages From 2f665c5d3bbc7fcef099be09ba5a1a00b2c8b6dd Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Mon, 3 Nov 2025 13:31:39 +0100 Subject: [PATCH 174/189] Fixing dependency issues --- reqs/base.txt | 1 - reqs/dev.txt | 3 +-- setup.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/reqs/base.txt b/reqs/base.txt index f88ab98..d5b400f 100644 --- a/reqs/base.txt +++ b/reqs/base.txt @@ -1,4 +1,3 @@ -aiobotocore>=2.5.0 click>=8.1.3 croniter==2.0.2 envyaml==1.10.211231 diff --git a/reqs/dev.txt b/reqs/dev.txt index 238b2e2..bfb4a44 100644 --- a/reqs/dev.txt +++ b/reqs/dev.txt @@ -1,7 +1,6 @@ pip==24.0 -apache-airflow[amazon,postgres,s3,statsd]==2.9.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.9.0/constraints-3.9.txt" +apache-airflow[amazon,postgres,s3,statsd]==2.9.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.9.0/constraints-3.12.txt" black==22.10.0 -boto3==1.34.82 bumpversion==0.6.0 coverage==7.4.4 elasticsearch==7.17.7 diff --git a/setup.py b/setup.py index f8b4b28..a9580a1 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ def reqs(*f): classifiers=[ "Development Status :: 2 - Pre-Alpha", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.12", ], description="Config Driven ETL", entry_points={"console_scripts": ["dagger=dagger.main:cli"]}, From 57d300ab039ab944ee9f1186565cf56740424a52 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Mon, 3 Nov 2025 13:32:02 +0100 Subject: [PATCH 175/189] Upgrading git workflow python --- .github/workflows/ci-data.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-data.yml b/.github/workflows/ci-data.yml index bef5bed..2d3044c 100644 --- a/.github/workflows/ci-data.yml +++ b/.github/workflows/ci-data.yml @@ -17,10 +17,10 @@ jobs: with: persist-credentials: false - - name: Set up Python 3.9 + - name: Set up Python 3.12 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.12 - name: Install dependencies run: | From 735279f40a861a14a430a4f75e5bad6fd9348fae Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Mon, 3 Nov 2025 14:00:43 +0100 Subject: [PATCH 176/189] Bumping airflow version --- reqs/dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reqs/dev.txt b/reqs/dev.txt index bfb4a44..b39d00a 100644 --- a/reqs/dev.txt +++ b/reqs/dev.txt @@ -1,5 +1,5 @@ pip==24.0 -apache-airflow[amazon,postgres,s3,statsd]==2.9.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.9.0/constraints-3.12.txt" +apache-airflow[amazon,postgres,s3,statsd]==2.11.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.11.0/constraints-3.12.txt" black==22.10.0 bumpversion==0.6.0 coverage==7.4.4 From d72b81230decd2ade25109b2560c79a8664cacf2 Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Mon, 3 Nov 2025 14:15:07 +0100 Subject: [PATCH 177/189] Fixing version mismatch between test and dev --- reqs/test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reqs/test.txt b/reqs/test.txt index 7bdc89f..6bb6c2e 100644 --- a/reqs/test.txt +++ b/reqs/test.txt @@ -1,4 +1,4 @@ -apache-airflow[amazon,postgres,s3,statsd]==2.9.0 +apache-airflow[amazon,postgres,s3,statsd]==2.11.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.11.0/constraints-3.12.txt" pytest-cov==4.0.0 pytest==7.2.0 graphviz From 08fb0ab8e88c85452d9e361645a32a2e644f5f2a Mon Sep 17 00:00:00 2001 From: David Siklosi Date: Mon, 3 Nov 2025 14:49:40 +0100 Subject: [PATCH 178/189] Fixing unit tests caused by airflow version bump --- reqs/ui.txt | 1 + tests/fixtures/dag_creator/airflow/dag_test_external_sensor.dot | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/reqs/ui.txt b/reqs/ui.txt index 10f95e3..46d5546 100644 --- a/reqs/ui.txt +++ b/reqs/ui.txt @@ -4,3 +4,4 @@ flask==2.2.5 python-dotenv==0.21.0 requests==2.32.4 WTForms==2.3.3 + diff --git a/tests/fixtures/dag_creator/airflow/dag_test_external_sensor.dot b/tests/fixtures/dag_creator/airflow/dag_test_external_sensor.dot index 9824739..1add9af 100644 --- a/tests/fixtures/dag_creator/airflow/dag_test_external_sensor.dot +++ b/tests/fixtures/dag_creator/airflow/dag_test_external_sensor.dot @@ -3,7 +3,7 @@ digraph test_external_sensor { "dummy-control-flow" [color="#000000" fillcolor="#ffefeb" label="dummy-control-flow" shape=rectangle style="filled,rounded"] dummy_first [color="#000000" fillcolor="#e8f7e4" label=dummy_first shape=rectangle style="filled,rounded"] dummy_second [color="#000000" fillcolor="#e8f7e4" label=dummy_second shape=rectangle style="filled,rounded"] - "test_batch-batch-sensor" [color="#000000" fillcolor="#19647e" label="test_batch-batch-sensor" shape=rectangle style="filled,rounded"] + "test_batch-batch-sensor" [color="#000000" fillcolor="#4db7db" label="test_batch-batch-sensor" shape=rectangle style="filled,rounded"] "dummy-control-flow" -> "test_batch-batch-sensor" dummy_first -> dummy_second "test_batch-batch-sensor" -> dummy_first From f0640fdfd38481c19cadb74a8362e7be946e1bb7 Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Fri, 9 Jan 2026 14:19:18 +0100 Subject: [PATCH 179/189] Implement DLT creator --- CLAUDE.md | 82 ++++ .../databricks_dlt_creator.py | 118 ++++++ .../dag_creator/airflow/operator_factory.py | 1 + dagger/pipeline/task_factory.py | 3 +- dagger/pipeline/tasks/databricks_dlt_task.py | 164 ++++++++ dagger/plugins/__init__.py | 1 + dagger/plugins/dlt_task_generator/__init__.py | 6 + .../dlt_task_generator/bundle_parser.py | 309 +++++++++++++++ .../dlt_task_generator/dlt_task_generator.py | 374 ++++++++++++++++++ dagger/utilities/dbt_config_parser.py | 180 ++++++++- dagger/utilities/module.py | 23 +- 11 files changed, 1234 insertions(+), 27 deletions(-) create mode 100644 CLAUDE.md create mode 100644 dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py create mode 100644 dagger/pipeline/tasks/databricks_dlt_task.py create mode 100644 dagger/plugins/__init__.py create mode 100644 dagger/plugins/dlt_task_generator/__init__.py create mode 100644 dagger/plugins/dlt_task_generator/bundle_parser.py create mode 100644 dagger/plugins/dlt_task_generator/dlt_task_generator.py diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..ffd1ebd --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,82 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Dagger is a configuration-driven framework that transforms YAML definitions into Apache Airflow DAGs. It uses dataset lineage (matching inputs/outputs) to automatically build dependency graphs across workflows. + +## Common Commands + +### Development Setup +```bash +make install-dev # Create venv, install package in editable mode with dev/test deps +source venv/bin/activate +``` + +### Testing +```bash +make test # Run all tests with coverage (sets AIRFLOW_HOME automatically) + +# Run a single test file +AIRFLOW_HOME=$(pwd)/tests/fixtures/config_finder/root/ ENV=local pytest -s tests/path/to/test_file.py + +# Run a specific test +AIRFLOW_HOME=$(pwd)/tests/fixtures/config_finder/root/ ENV=local pytest -s tests/path/to/test_file.py::test_function_name +``` + +### Linting +```bash +make lint # Run flake8 on dagger and tests directories +black dagger tests # Format code +``` + +### Local Airflow Testing +```bash +make test-airflow # Build and start Airflow in Docker (localhost:8080, user: dev_user, pass: dev_user) +make stop-airflow # Stop Airflow containers +``` + +### CLI +```bash +dagger --help +dagger list-tasks # Show available task types +dagger list-ios # Show available IO types +dagger init-pipeline # Create a new pipeline.yaml +dagger init-task --type= # Add a task configuration +dagger init-io --type= # Add an IO definition +dagger print-graph # Visualize dependency graph +``` + +## Architecture + +### Core Flow +1. **ConfigFinder** discovers pipeline directories (each with `pipeline.yaml` + task YAML files) +2. **ConfigProcessor** loads YAML configs with environment variable support +3. **TaskFactory/IOFactory** use reflection to instantiate task/IO objects from YAML +4. **TaskGraph** builds a 3-layer graph: Pipeline → Task → Dataset nodes +5. **DagCreator** traverses the graph and generates Airflow DAGs using **OperatorFactory** + +### Key Directories +- `dagger/pipeline/tasks/` - Task type definitions (DbtTask, SparkTask, AthenaTransformTask, etc.) +- `dagger/pipeline/ios/` - IO type definitions (S3, Redshift, Athena, Databricks, etc.) +- `dagger/dag_creator/airflow/operator_creators/` - One creator per task type, translates tasks to Airflow operators +- `dagger/graph/` - Graph construction from task inputs/outputs +- `dagger/config_finder/` - YAML discovery and loading +- `tests/fixtures/config_finder/root/dags/` - Example DAG configurations for testing + +### Adding a New Task Type +1. Create task definition in `dagger/pipeline/tasks/` (subclass of Task) +2. Create any needed IOs in `dagger/pipeline/ios/` (if new data sources) +3. Create operator creator in `dagger/dag_creator/airflow/operator_creators/` +4. Register in `dagger/dag_creator/airflow/operator_factory.py` + +### Configuration Files +- `pipeline.yaml` - Pipeline metadata (owner, schedule, alerts, airflow_parameters) +- `[taskname].yaml` - Task configs (type, inputs, outputs, task-specific params) +- `dagger_config.yaml` - System config (Neo4j, Elasticsearch, Spark settings) + +### Key Patterns +- **Factory Pattern**: TaskFactory/IOFactory auto-discover types via reflection +- **Strategy Pattern**: OperatorCreator subclasses handle task-specific operator creation +- **Dataset Aliasing**: IO `alias()` method enables automatic dependency detection across pipelines diff --git a/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py b/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py new file mode 100644 index 0000000..66034a6 --- /dev/null +++ b/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py @@ -0,0 +1,118 @@ +"""Operator creator for Databricks DLT (Delta Live Tables) pipelines.""" + +import logging +from typing import Any + +from airflow.models import BaseOperator, DAG + +from dagger.dag_creator.airflow.operator_creator import OperatorCreator +from dagger.pipeline.tasks.databricks_dlt_task import DatabricksDLTTask + +_logger = logging.getLogger(__name__) + + +def _cancel_databricks_run(context: dict[str, Any]) -> None: + """Cancel a Databricks job run when task fails or is cleared. + + This callback retrieves the run_id from XCom and cancels the corresponding + Databricks job run. Used as on_failure_callback to ensure jobs are cancelled + when tasks are marked as failed. + + Args: + context: Airflow context dictionary containing task instance and other metadata. + """ + from airflow.providers.databricks.hooks.databricks import DatabricksHook + + ti = context.get("task_instance") + if not ti: + _logger.warning("No task instance in context, cannot cancel Databricks run") + return + + # Get run_id from XCom (pushed by DatabricksRunNowOperator) + run_id = ti.xcom_pull(task_ids=ti.task_id, key="run_id") + if not run_id: + _logger.warning(f"No run_id found in XCom for task {ti.task_id}") + return + + # Get the databricks_conn_id from the operator + databricks_conn_id = getattr(ti.task, "databricks_conn_id", "databricks_default") + + try: + hook = DatabricksHook(databricks_conn_id=databricks_conn_id) + hook.cancel_run(run_id) + _logger.info(f"Cancelled Databricks run {run_id} for task {ti.task_id}") + except Exception as e: + _logger.error(f"Failed to cancel Databricks run {run_id}: {e}") + + +class DatabricksDLTCreator(OperatorCreator): + """Creates operators for triggering Databricks DLT pipelines via Jobs. + + This creator uses DatabricksRunNowOperator to trigger a Databricks Job + that wraps the DLT pipeline. The job is identified by name and must be + defined in the Databricks Asset Bundle. + + Attributes: + ref_name: Reference name used by OperatorFactory to match this creator + with DatabricksDLTTask instances. + """ + + ref_name: str = "databricks_dlt" + + def __init__(self, task: DatabricksDLTTask, dag: DAG) -> None: + """Initialize the DatabricksDLTCreator. + + Args: + task: The DatabricksDLTTask containing pipeline configuration. + dag: The Airflow DAG this operator will belong to. + """ + super().__init__(task, dag) + + def _create_operator(self, **kwargs: Any) -> BaseOperator: + """Create a DatabricksRunNowOperator for the DLT pipeline. + + Creates an Airflow operator that triggers an existing Databricks Job + by name. The job must have a pipeline_task that references the DLT + pipeline. + + Args: + **kwargs: Additional keyword arguments passed to the operator. + + Returns: + A configured DatabricksRunNowOperator instance. + """ + # Import here to avoid import errors if databricks provider not installed + from datetime import timedelta + + from airflow.providers.databricks.operators.databricks import ( + DatabricksRunNowOperator, + ) + + # Get task parameters + job_name: str = self._task.job_name + databricks_conn_id: str = getattr( + self._task, "databricks_conn_id", "databricks_default" + ) + wait_for_completion: bool = getattr(self._task, "wait_for_completion", True) + poll_interval_seconds: int = getattr(self._task, "poll_interval_seconds", 30) + timeout_seconds: int = getattr(self._task, "timeout_seconds", 3600) + + # DatabricksRunNowOperator triggers an existing Databricks Job by name + # The job must have a pipeline_task that references the DLT pipeline + # Note: timeout is handled via Airflow's execution_timeout, not a direct parameter + # Note: on_kill() is already implemented in DatabricksRunNowOperator to cancel runs + # We add on_failure_callback to also cancel when task is marked as failed + operator: BaseOperator = DatabricksRunNowOperator( + dag=self._dag, + task_id=self._task.name, + databricks_conn_id=databricks_conn_id, + job_name=job_name, + wait_for_termination=wait_for_completion, + polling_period_seconds=poll_interval_seconds, + execution_timeout=timedelta(seconds=timeout_seconds), + do_xcom_push=True, # Required to store run_id for cancellation callback + on_failure_callback=_cancel_databricks_run, + **kwargs, + ) + + return operator diff --git a/dagger/dag_creator/airflow/operator_factory.py b/dagger/dag_creator/airflow/operator_factory.py index 2a1654a..dd7344e 100644 --- a/dagger/dag_creator/airflow/operator_factory.py +++ b/dagger/dag_creator/airflow/operator_factory.py @@ -4,6 +4,7 @@ airflow_op_creator, athena_transform_creator, batch_creator, + databricks_dlt_creator, dbt_creator, dummy_creator, python_creator, diff --git a/dagger/pipeline/task_factory.py b/dagger/pipeline/task_factory.py index 9ed79e7..f5f80bb 100644 --- a/dagger/pipeline/task_factory.py +++ b/dagger/pipeline/task_factory.py @@ -3,6 +3,7 @@ airflow_op_task, athena_transform_task, batch_task, + databricks_dlt_task, dbt_task, dummy_task, python_task, @@ -12,7 +13,7 @@ reverse_etl_task, spark_task, sqoop_task, - soda_task + soda_task, ) from dagger.utilities.classes import get_deep_obj_subclasses diff --git a/dagger/pipeline/tasks/databricks_dlt_task.py b/dagger/pipeline/tasks/databricks_dlt_task.py new file mode 100644 index 0000000..4f0b113 --- /dev/null +++ b/dagger/pipeline/tasks/databricks_dlt_task.py @@ -0,0 +1,164 @@ +"""Task configuration for Databricks DLT (Delta Live Tables) pipelines.""" + +from typing import Any, Optional + +from dagger.pipeline.task import Task +from dagger.utilities.config_validator import Attribute + + +class DatabricksDLTTask(Task): + """Task configuration for triggering Databricks DLT pipelines via Jobs. + + This task type uses DatabricksRunNowOperator to trigger a Databricks Job + that wraps the DLT pipeline. The job is identified by name and must be + defined in the Databricks Asset Bundle. + + Attributes: + ref_name: Reference name used by TaskFactory to instantiate this task type. + job_name: Databricks Job name that triggers the DLT pipeline. + databricks_conn_id: Airflow connection ID for Databricks. + wait_for_completion: Whether to wait for job completion. + poll_interval_seconds: Polling interval in seconds. + timeout_seconds: Timeout in seconds. + cancel_on_kill: Whether to cancel Databricks job if Airflow task is killed. + + Example YAML configuration: + type: databricks_dlt + description: Run DLT pipeline users + inputs: + - type: athena + schema: ddb_changelogs + table: order_preference + follow_external_dependency: true + outputs: + - type: databricks + catalog: ${ENV_MARTS} + schema: dlt_users + table: silver_order_preference + task_parameters: + job_name: dlt-users + databricks_conn_id: databricks_default + wait_for_completion: true + poll_interval_seconds: 30 + timeout_seconds: 3600 + """ + + ref_name: str = "databricks_dlt" + + @classmethod + def init_attributes(cls, orig_cls: type) -> None: + """Initialize configuration attributes for YAML parsing. + + Registers all task_parameters attributes that can be specified in the + YAML configuration file. Called by the Task metaclass during class creation. + + Args: + orig_cls: The original class being initialized (used for attribute registration). + """ + cls.add_config_attributes( + [ + Attribute( + attribute_name="job_name", + parent_fields=["task_parameters"], + comment="Databricks Job name that triggers the DLT pipeline", + ), + Attribute( + attribute_name="databricks_conn_id", + parent_fields=["task_parameters"], + required=False, + comment="Airflow connection ID for Databricks (default: databricks_default)", + ), + Attribute( + attribute_name="wait_for_completion", + parent_fields=["task_parameters"], + required=False, + validator=bool, + comment="Wait for job to complete (default: true)", + ), + Attribute( + attribute_name="poll_interval_seconds", + parent_fields=["task_parameters"], + required=False, + validator=int, + comment="Polling interval in seconds (default: 30)", + ), + Attribute( + attribute_name="timeout_seconds", + parent_fields=["task_parameters"], + required=False, + validator=int, + comment="Timeout in seconds (default: 3600)", + ), + Attribute( + attribute_name="cancel_on_kill", + parent_fields=["task_parameters"], + required=False, + validator=bool, + comment="Cancel Databricks job if Airflow task is killed (default: true)", + ), + ] + ) + + def __init__( + self, + name: str, + pipeline_name: str, + pipeline: Any, + job_config: dict[str, Any], + ) -> None: + """Initialize a DatabricksDLTTask instance. + + Args: + name: The task name (used as task_id in Airflow). + pipeline_name: Name of the Dagger pipeline this task belongs to. + pipeline: The parent Pipeline object. + job_config: Dictionary containing the task configuration from YAML. + """ + super().__init__(name, pipeline_name, pipeline, job_config) + + self._job_name: str = self.parse_attribute("job_name") + self._databricks_conn_id: str = ( + self.parse_attribute("databricks_conn_id") or "databricks_default" + ) + wait_for_completion: Optional[bool] = self.parse_attribute("wait_for_completion") + self._wait_for_completion: bool = ( + wait_for_completion if wait_for_completion is not None else True + ) + self._poll_interval_seconds: int = ( + self.parse_attribute("poll_interval_seconds") or 30 + ) + self._timeout_seconds: int = self.parse_attribute("timeout_seconds") or 3600 + cancel_on_kill: Optional[bool] = self.parse_attribute("cancel_on_kill") + self._cancel_on_kill: bool = ( + cancel_on_kill if cancel_on_kill is not None else True + ) + + @property + def job_name(self) -> str: + """Databricks Job name that triggers the DLT pipeline.""" + return self._job_name + + @property + def databricks_conn_id(self) -> str: + """Airflow connection ID for Databricks.""" + return self._databricks_conn_id + + @property + def wait_for_completion(self) -> bool: + """Whether to wait for job completion.""" + return self._wait_for_completion + + @property + def poll_interval_seconds(self) -> int: + """Polling interval in seconds.""" + return self._poll_interval_seconds + + @property + def timeout_seconds(self) -> int: + """Timeout in seconds.""" + return self._timeout_seconds + + @property + def cancel_on_kill(self) -> bool: + """Whether to cancel Databricks job if Airflow task is killed.""" + return self._cancel_on_kill diff --git a/dagger/plugins/__init__.py b/dagger/plugins/__init__.py new file mode 100644 index 0000000..26acb8c --- /dev/null +++ b/dagger/plugins/__init__.py @@ -0,0 +1 @@ +"""Dagger plugins for task generation.""" diff --git a/dagger/plugins/dlt_task_generator/__init__.py b/dagger/plugins/dlt_task_generator/__init__.py new file mode 100644 index 0000000..49e4a17 --- /dev/null +++ b/dagger/plugins/dlt_task_generator/__init__.py @@ -0,0 +1,6 @@ +"""DLT Task Generator plugin for generating Dagger configs from Databricks Asset Bundles.""" + +from dagger.plugins.dlt_task_generator.bundle_parser import DatabricksBundleParser +from dagger.plugins.dlt_task_generator.dlt_task_generator import DLTTaskGenerator + +__all__ = ["DatabricksBundleParser", "DLTTaskGenerator"] diff --git a/dagger/plugins/dlt_task_generator/bundle_parser.py b/dagger/plugins/dlt_task_generator/bundle_parser.py new file mode 100644 index 0000000..5a15ed7 --- /dev/null +++ b/dagger/plugins/dlt_task_generator/bundle_parser.py @@ -0,0 +1,309 @@ +"""Parse Databricks Asset Bundle YAML files for DLT pipeline configuration.""" + +import logging +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Optional + +import yaml + +_logger = logging.getLogger(__name__) + + +@dataclass +class TableConfig: + """Configuration for a single table in a DLT pipeline. + + Attributes: + database: Source database name. + table: Source table name. + changelog_type: Type of changelog source ('dynamodb' or 'postgres'). + unique_keys: List of columns that uniquely identify a row. + scd_type: Slowly changing dimension type (1 or 2). + """ + + database: str + table: str + changelog_type: str # 'dynamodb' or 'postgres' + unique_keys: list[str] = field(default_factory=list) + scd_type: int = 1 + + @property + def source_schema(self) -> str: + """Get the source schema name for Athena/Glue catalog. + + For DynamoDB: ddb_changelogs + For PostgreSQL: pg_changelogs_kafka_{database_normalized} + """ + if self.changelog_type == "dynamodb": + return self.database + elif self.changelog_type == "postgres": + # Normalize database name (replace hyphens with underscores) + db_normalized = self.database.replace("-", "_") + return f"pg_changelogs_kafka_{db_normalized}" + else: + return self.database + + @property + def silver_table_name(self) -> str: + """Get the silver table name produced by DLT.""" + return f"silver_{self.table}" + + @property + def bronze_table_name(self) -> str: + """Get the bronze table name produced by DLT.""" + return f"bronze_{self.table}" + + +@dataclass +class PipelineConfig: + """Configuration for a DLT pipeline parsed from Databricks Asset Bundle. + + Attributes: + name: Pipeline/bundle name. + catalog: Target Unity Catalog name. + schema: Target schema name. + tables: List of table configurations for the pipeline. + targets: Target environment configurations (dev/prod). + variables: Variable definitions from databricks.yml. + tags: Pipeline tags. + """ + + name: str + catalog: str + schema: str + tables: list[TableConfig] = field(default_factory=list) + targets: dict[str, Any] = field(default_factory=dict) + variables: dict[str, Any] = field(default_factory=dict) + tags: dict[str, str] = field(default_factory=dict) + + +class DatabricksBundleParser: + """Parse Databricks Asset Bundle YAML files (databricks.yml and tables.yml). + + This parser extracts pipeline configuration from Databricks Asset Bundles, + including the target catalog, schema, table definitions, and job configuration. + It resolves variable references to Dagger environment variable format. + + Attributes: + _databricks_yml_path: Path to the databricks.yml file. + _tables_yml_path: Path to the tables.yml file. + _databricks_config: Parsed databricks.yml content. + _tables_config: Parsed tables.yml content. + _pipeline_config: Cached PipelineConfig instance. + """ + + def __init__( + self, + databricks_yml_path: Path, + tables_yml_path: Optional[Path] = None, + ) -> None: + """Initialize the parser with paths to bundle YAML files. + + Args: + databricks_yml_path: Path to the databricks.yml file. + tables_yml_path: Optional path to the tables.yml file. If not provided, + will look for tables.yml in the same directory. + """ + self._databricks_yml_path = Path(databricks_yml_path) + self._tables_yml_path = ( + Path(tables_yml_path) + if tables_yml_path + else self._databricks_yml_path.parent / "tables.yml" + ) + + self._databricks_config = self._load_yaml(self._databricks_yml_path) + self._tables_config = ( + self._load_yaml(self._tables_yml_path) + if self._tables_yml_path.exists() + else {} + ) + + self._pipeline_config: Optional[PipelineConfig] = None + + @staticmethod + def _load_yaml(path: Path) -> dict[str, Any]: + """Load and parse a YAML file. + + Args: + path: Path to the YAML file. + + Returns: + Parsed YAML content as a dictionary. + + Raises: + yaml.YAMLError: If the YAML file is malformed. + """ + try: + with open(path, "r") as f: + return yaml.safe_load(f) or {} + except FileNotFoundError: + _logger.warning(f"YAML file not found: {path}") + return {} + except yaml.YAMLError as e: + _logger.error(f"Error parsing YAML file {path}: {e}") + raise + + def _resolve_variable(self, value: str) -> str: + """Resolve Databricks bundle variable references like ${var.catalog}. + + Args: + value: String that may contain variable references + + Returns: + Resolved string with environment variable format for Dagger + """ + if not isinstance(value, str): + return value + + # Match ${var.variable_name} pattern + var_pattern = re.compile(r"\$\{var\.(\w+)\}") + + def replace_var(match): + var_name = match.group(1) + # Get the default value from variables section + var_config = self._databricks_config.get("variables", {}).get(var_name, {}) + default_value = var_config.get("default", "") + + # Map to Dagger environment variables + if var_name == "catalog": + # Map catalog to Dagger's ${ENV_MARTS} pattern + return "${ENV_MARTS}" + return default_value + + return var_pattern.sub(replace_var, value) + + def _parse_tables(self) -> list[TableConfig]: + """Parse table configurations from tables.yml. + + Returns: + List of TableConfig instances for each table defined in the bundle. + """ + tables = [] + defaults = self._tables_config.get("defaults", {}) + default_scd_type = defaults.get("scd_type", 1) + + for table_config in self._tables_config.get("tables", []): + tables.append( + TableConfig( + database=table_config.get("database", ""), + table=table_config.get("table", ""), + changelog_type=table_config.get("changelog_type", "dynamodb"), + unique_keys=table_config.get("unique_keys", []), + scd_type=table_config.get("scd_type", default_scd_type), + ) + ) + + return tables + + def _parse_pipeline(self) -> PipelineConfig: + """Parse pipeline configuration from databricks.yml. + + Extracts bundle name, variables, targets, and pipeline-specific settings + from the Databricks Asset Bundle configuration. + + Returns: + PipelineConfig instance with all parsed configuration. + """ + bundle_name = self._databricks_config.get("bundle", {}).get("name", "") + variables = self._databricks_config.get("variables", {}) + targets = self._databricks_config.get("targets", {}) + + # Get pipeline configuration from resources + resources = self._databricks_config.get("resources", {}) + pipelines = resources.get("pipelines", {}) + + # Get the first pipeline (usually matches bundle name) + pipeline_key = bundle_name or next(iter(pipelines.keys()), "") + pipeline_config = pipelines.get(pipeline_key, {}) + + catalog = self._resolve_variable(pipeline_config.get("catalog", "")) + schema = pipeline_config.get("schema", "") + tags = pipeline_config.get("tags", {}) + + return PipelineConfig( + name=bundle_name, + catalog=catalog, + schema=schema, + tables=self._parse_tables(), + targets=targets, + variables=variables, + tags=tags, + ) + + def parse(self) -> PipelineConfig: + """Parse the Databricks Asset Bundle and return pipeline configuration. + + Returns: + PipelineConfig with all parsed configuration + """ + if self._pipeline_config is None: + self._pipeline_config = self._parse_pipeline() + return self._pipeline_config + + def get_bundle_name(self) -> str: + """Return the bundle/pipeline name. + + Returns: + The bundle name from databricks.yml. + """ + return self.parse().name + + def get_catalog(self) -> str: + """Return the target catalog with Dagger environment variable format. + + Returns: + Catalog name with variables resolved to Dagger format (e.g., ${ENV_MARTS}). + """ + return self.parse().catalog + + def get_schema(self) -> str: + """Return the target schema. + + Returns: + Target schema name for the DLT pipeline. + """ + return self.parse().schema + + def get_tables(self) -> list[TableConfig]: + """Return the list of table configurations. + + Returns: + List of TableConfig instances for all tables in the pipeline. + """ + return self.parse().tables + + def get_targets(self) -> dict[str, Any]: + """Return target configurations (dev/prod). + + Returns: + Dictionary of target environment configurations. + """ + return self.parse().targets + + def get_variables(self) -> dict[str, Any]: + """Return variable definitions. + + Returns: + Dictionary of variable definitions from databricks.yml. + """ + return self.parse().variables + + def get_job_name(self) -> str: + """Get the Databricks Job name that triggers this pipeline. + + Looks for a job defined in resources.jobs that wraps the pipeline. + Falls back to a default naming convention if no job is defined. + + Returns: + The job name to use with DatabricksRunNowOperator + """ + resources = self._databricks_config.get("resources", {}) + jobs = resources.get("jobs", {}) + + # Return the first job's name, or use default naming convention + for job_config in jobs.values(): + return job_config.get("name", f"dlt-{self.get_bundle_name()}") + + return f"dlt-{self.get_bundle_name()}" diff --git a/dagger/plugins/dlt_task_generator/dlt_task_generator.py b/dagger/plugins/dlt_task_generator/dlt_task_generator.py new file mode 100644 index 0000000..6136227 --- /dev/null +++ b/dagger/plugins/dlt_task_generator/dlt_task_generator.py @@ -0,0 +1,374 @@ +"""Generate Dagger task configurations from Databricks DLT bundle definitions.""" + +import logging +import os +from pathlib import Path +from typing import Any, Optional + +import yaml + +from dagger.plugins.dlt_task_generator.bundle_parser import ( + DatabricksBundleParser, + PipelineConfig, + TableConfig, +) + +_logger = logging.getLogger(__name__) + +# Default path to the DLT pipelines repository (can be overridden via env var) +DEFAULT_DLT_PIPELINES_REPO = os.getenv( + "DLT_PIPELINES_REPO", + str(Path(__file__).parent.parent.parent.parent.parent / "dataeng-databricks-dlt-pipelines"), +) + + +class DLTTaskGenerator: + """Generate Dagger task configurations from Databricks DLT bundle definitions. + + This generator reads Databricks Asset Bundle configurations and produces + Dagger-compatible YAML task configurations for DLT pipelines. + + Attributes: + ATHENA_TASK_BASE: Base configuration for Athena input tasks. + DATABRICKS_TASK_BASE: Base configuration for Databricks output tasks. + DUMMY_TASK_BASE: Base configuration for dummy tasks. + """ + + ATHENA_TASK_BASE: dict[str, str] = {"type": "athena"} + DATABRICKS_TASK_BASE: dict[str, str] = {"type": "databricks"} + DUMMY_TASK_BASE: dict[str, str] = {"type": "dummy"} + + def __init__(self, dlt_repo_path: Optional[str] = None) -> None: + """Initialize the generator with path to DLT pipelines repository. + + Args: + dlt_repo_path: Path to the dataeng-databricks-dlt-pipelines repository. + Defaults to DLT_PIPELINES_REPO env var or sibling directory. + """ + self._dlt_repo_path = Path(dlt_repo_path or DEFAULT_DLT_PIPELINES_REPO) + self._pipelines: dict[str, DatabricksBundleParser] = {} + self._load_all_pipelines() + + def _load_all_pipelines(self) -> None: + """Load all pipeline bundles from the DLT repository. + + Scans the pipelines directory and loads each valid Databricks Asset Bundle + found. Bundles are identified by the presence of a databricks.yml file. + """ + pipelines_dir = self._dlt_repo_path / "pipelines" + + if not pipelines_dir.exists(): + _logger.warning(f"DLT pipelines directory not found: {pipelines_dir}") + return + + for pipeline_dir in pipelines_dir.iterdir(): + if not pipeline_dir.is_dir(): + continue + + databricks_yml = pipeline_dir / "databricks.yml" + if not databricks_yml.exists(): + continue + + tables_yml = pipeline_dir / "tables.yml" + try: + parser = DatabricksBundleParser(databricks_yml, tables_yml) + pipeline_name = parser.get_bundle_name() or pipeline_dir.name + self._pipelines[pipeline_name] = parser + _logger.info(f"Loaded DLT pipeline: {pipeline_name}") + except Exception as e: + _logger.error(f"Error loading pipeline from {pipeline_dir}: {e}") + + def get_pipeline_names(self) -> list[str]: + """Return list of available DLT pipeline names.""" + return list(self._pipelines.keys()) + + def get_pipeline_config(self, pipeline_name: str) -> PipelineConfig: + """Get the parsed pipeline configuration. + + Args: + pipeline_name: Name of the pipeline + + Returns: + PipelineConfig object + + Raises: + ValueError: If pipeline not found + """ + if pipeline_name not in self._pipelines: + raise ValueError( + f"Unknown pipeline: {pipeline_name}. " + f"Available pipelines: {self.get_pipeline_names()}" + ) + return self._pipelines[pipeline_name].parse() + + def _get_athena_input( + self, table: TableConfig, follow_external_dependency: bool = True + ) -> dict[str, Any]: + """Generate an Athena input task for a source changelog table. + + Args: + table: Table configuration from the DLT bundle. + follow_external_dependency: Whether to create an ExternalTaskSensor + for cross-pipeline dependency tracking. + + Returns: + Dagger Athena task configuration dict. + """ + task = self.ATHENA_TASK_BASE.copy() + task.update( + { + "schema": table.source_schema, + "table": table.table, + "name": f"{table.source_schema}__{table.table}_athena", + } + ) + if follow_external_dependency: + task["follow_external_dependency"] = True + return task + + def _get_databricks_output( + self, table: TableConfig, catalog: str, schema: str + ) -> dict[str, Any]: + """Generate a Databricks output task for a silver table. + + Args: + table: Table configuration from the DLT bundle. + catalog: Target Unity Catalog name (e.g., ${ENV_MARTS}). + schema: Target schema name. + + Returns: + Dagger Databricks task configuration dict. + """ + task = self.DATABRICKS_TASK_BASE.copy() + # Normalize catalog name for task naming + catalog_name = catalog.replace("${", "").replace("}", "").lower() + task.update( + { + "catalog": catalog, + "schema": schema, + "table": table.silver_table_name, + "name": f"{catalog_name}__{schema}__{table.silver_table_name}_databricks", + } + ) + return task + + def get_inputs( + self, pipeline_name: str, follow_external_dependency: bool = True + ) -> list[dict[str, Any]]: + """Generate input dependencies for a DLT pipeline task. + + These are the source changelog tables that the DLT pipeline reads from. + + Args: + pipeline_name: Name of the DLT pipeline. + follow_external_dependency: Whether to create ExternalTaskSensors + for cross-pipeline dependency tracking. + + Returns: + List of Dagger input task configurations. + """ + config = self.get_pipeline_config(pipeline_name) + inputs = [] + + for table in config.tables: + input_task = self._get_athena_input(table, follow_external_dependency) + inputs.append(input_task) + + return inputs + + def get_outputs(self, pipeline_name: str) -> list[dict[str, Any]]: + """Generate output declarations for a DLT pipeline task. + + These are the silver tables produced by the DLT pipeline. + + Args: + pipeline_name: Name of the DLT pipeline. + + Returns: + List of Dagger output task configurations. + """ + config = self.get_pipeline_config(pipeline_name) + outputs = [] + + for table in config.tables: + output_task = self._get_databricks_output( + table, config.catalog, config.schema + ) + outputs.append(output_task) + + return outputs + + def get_task_parameters(self, pipeline_name: str) -> dict[str, Any]: + """Generate task parameters for triggering a DLT pipeline via Databricks Job. + + Args: + pipeline_name: Name of the DLT pipeline. + + Returns: + Dict of task parameters for the DatabricksRunNowOperator. + """ + parser = self._pipelines[pipeline_name] + return { + "job_name": parser.get_job_name(), + "databricks_conn_id": "${DATABRICKS_CONN_ID}", + "wait_for_completion": True, + "poll_interval_seconds": 30, + "timeout_seconds": 3600, + } + + def generate_task_config( + self, + pipeline_name: str, + description: Optional[str] = None, + follow_external_dependency: bool = True, + ) -> dict[str, Any]: + """Generate a complete Dagger task configuration for a DLT pipeline. + + Args: + pipeline_name: Name of the DLT pipeline. + description: Optional task description. Defaults to auto-generated. + follow_external_dependency: Whether to create ExternalTaskSensors for inputs. + + Returns: + Complete Dagger task configuration dict ready for YAML serialization. + """ + config = self.get_pipeline_config(pipeline_name) + + task_config = { + "type": "databricks_dlt", + "description": description or f"Run DLT pipeline {pipeline_name}", + "inputs": self.get_inputs(pipeline_name, follow_external_dependency), + "outputs": self.get_outputs(pipeline_name), + "airflow_task_parameters": { + "retries": 2, + "retry_delay": 300, + }, + "template_parameters": {}, + "task_parameters": self.get_task_parameters(pipeline_name), + } + + return task_config + + def generate_pipeline_config( + self, + pipeline_name: str, + schedule: str = "0 * * * *", + owner: str = "dataeng@choco.com", + ) -> dict[str, Any]: + """Generate a Dagger pipeline.yaml configuration for a DLT pipeline DAG. + + Args: + pipeline_name: Name of the DLT pipeline. + schedule: Cron schedule expression. Defaults to hourly. + owner: Pipeline owner email address. + + Returns: + Dagger pipeline.yaml configuration dict ready for YAML serialization. + """ + config = self.get_pipeline_config(pipeline_name) + + return { + "owner": owner, + "description": f"DLT Pipeline - {pipeline_name}", + "schedule": schedule, + "start_date": "2024-01-01T00:00", + "airflow_parameters": { + "default_args": { + "retries": 2, + "retry_delay": 180, + "depends_on_past": False, + }, + "dag_parameters": { + "catchup": False, + "max_active_runs": 1, + "tags": ["dlt", "databricks", pipeline_name], + }, + }, + "alerts": [ + { + "type": "slack", + "channel": "#${ENV}-airflow-alerts", + "mentions": ["@dataeng-oncall"], + } + ], + } + + def write_task_config( + self, pipeline_name: str, output_path: Path, **kwargs: Any + ) -> Path: + """Write a task configuration to a YAML file. + + Args: + pipeline_name: Name of the DLT pipeline. + output_path: Directory to write the file to. + **kwargs: Additional arguments passed to generate_task_config. + + Returns: + Path to the written YAML file. + """ + output_path = Path(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + task_config = self.generate_task_config(pipeline_name, **kwargs) + file_path = output_path / f"{pipeline_name}_dlt.yaml" + + with open(file_path, "w") as f: + # Add autogenerated marker + task_config["autogenerated_by_dagger"] = f"dlt_task_generator:{pipeline_name}" + yaml.dump(task_config, f, default_flow_style=False, sort_keys=False) + + _logger.info(f"Generated task config: {file_path}") + return file_path + + def write_pipeline_config( + self, pipeline_name: str, output_path: Path, **kwargs: Any + ) -> Path: + """Write a pipeline configuration to a YAML file. + + Args: + pipeline_name: Name of the DLT pipeline. + output_path: Directory to write the file to. + **kwargs: Additional arguments passed to generate_pipeline_config. + + Returns: + Path to the written YAML file. + """ + output_path = Path(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + pipeline_config = self.generate_pipeline_config(pipeline_name, **kwargs) + file_path = output_path / "pipeline.yaml" + + with open(file_path, "w") as f: + yaml.dump(pipeline_config, f, default_flow_style=False, sort_keys=False) + + _logger.info(f"Generated pipeline config: {file_path}") + return file_path + + def generate_all(self, output_base_path: Path) -> list[Path]: + """Generate all DLT pipeline configurations. + + Creates pipeline.yaml and task configuration files for each loaded + DLT pipeline in the repository. + + Args: + output_base_path: Base directory for output (e.g., dags/dlt/). + + Returns: + List of paths to all generated files. + """ + output_base_path = Path(output_base_path) + generated_files = [] + + for pipeline_name in self.get_pipeline_names(): + pipeline_output_path = output_base_path / pipeline_name + + # Generate pipeline.yaml + pipeline_file = self.write_pipeline_config(pipeline_name, pipeline_output_path) + generated_files.append(pipeline_file) + + # Generate task config + task_file = self.write_task_config(pipeline_name, pipeline_output_path) + generated_files.append(task_file) + + return generated_files diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 9a341f6..9b86d5d 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -296,7 +296,20 @@ def _generate_dagger_output(self, node: dict): class DatabricksDBTConfigParser(DBTConfigParser): - """Implementation for Databricks configurations.""" + """DBT config parser implementation for Databricks Unity Catalog. + + Parses dbt manifest.json files for projects using the databricks-dbt adapter + and generates Dagger task configurations. Handles both Unity Catalog sources + (accessed via Databricks) and legacy Hive metastore sources (accessed via Athena). + + Attributes: + LEGACY_HIVE_DATABASES: Set of database names that indicate legacy Hive + metastore tables accessed via Athena rather than Unity Catalog. + """ + + # Schemas that indicate sources are in legacy Hive metastore (accessed via Athena) + # rather than Unity Catalog (accessed via Databricks) + LEGACY_HIVE_DATABASES: set[str] = {"hive_metastore"} def __init__(self, default_config_parameters: dict): super().__init__(default_config_parameters) @@ -306,17 +319,132 @@ def __init__(self, default_config_parameters: dict): "create_external_athena_table", False ) - def _is_node_preparation_model(self, node: dict): + def _is_databricks_source(self, node: dict) -> bool: + """Check if a source is a Unity Catalog table (accessed via Databricks). + + Sources with database 'hive_metastore' are legacy tables accessed via Athena. + Sources with other databases (e.g., Unity Catalog like ${ENV_MARTS}) are + Databricks tables that should create databricks input tasks. + + Args: + node: The source node from dbt manifest + + Returns: + True if the source is a Unity Catalog table, False otherwise """ - Define whether it is a preparation model. + database = node.get("database", "") + return database not in self.LEGACY_HIVE_DATABASES + + def _is_node_preparation_model(self, node: dict) -> bool: + """Determine whether a node is a preparation model. + + Preparation models are intermediate models in the transformation pipeline + that should not create external dependencies. + + Args: + node: The dbt node from manifest.json. + + Returns: + True if the node's schema contains 'preparation', False otherwise. """ return "preparation" in node.get("schema", "") - def _get_table_task( + def _get_databricks_source_task( self, node: dict, follow_external_dependency: bool = False ) -> dict: + """Generate a databricks input task for a Unity Catalog source. + + This is used for sources that point to Unity Catalog tables (e.g., DLT outputs) + rather than legacy Hive metastore tables. + + Args: + node: The source node from dbt manifest + follow_external_dependency: Whether to create an ExternalTaskSensor + + Returns: + Dagger databricks task configuration dict """ - Generates the dagger databricks task for the DBT model node + task = DATABRICKS_TASK_BASE.copy() + if follow_external_dependency: + task["follow_external_dependency"] = True + + task["catalog"] = node.get("database", self._default_catalog) + task["schema"] = node.get("schema", self._default_schema) + task["table"] = node.get("name", "") + task["name"] = f"{task['catalog']}__{task['schema']}__{task['table']}_databricks" + + return task + + def _generate_dagger_tasks(self, node_name: str) -> List[Dict]: + """Generate dagger tasks, with special handling for Databricks Unity Catalog sources. + + Overrides the base class method to handle sources that are in Unity Catalog + (e.g., DLT output tables) by creating databricks input tasks instead of athena tasks. + + Args: + node_name: The name of the DBT model node + + Returns: + List[Dict]: The respective dagger tasks for the DBT model node + """ + dagger_tasks = [] + + if node_name.startswith("source"): + node = self._sources_in_manifest[node_name] + else: + node = self._nodes_in_manifest[node_name] + + resource_type = node.get("resource_type") + materialized_type = node.get("config", {}).get("materialized") + + follow_external_dependency = True + if resource_type == "seed" or (self._is_node_preparation_model(node) and materialized_type != "table"): + follow_external_dependency = False + + if resource_type == "source": + # Check if this source is a Unity Catalog table (e.g., DLT outputs) + if self._is_databricks_source(node): + table_task = self._get_databricks_source_task( + node, follow_external_dependency=follow_external_dependency + ) + else: + # Legacy Hive metastore sources use Athena + table_task = self._get_athena_table_task( + node, follow_external_dependency=follow_external_dependency + ) + dagger_tasks.append(table_task) + + elif materialized_type == "ephemeral": + task = self._get_dummy_task(node) + dagger_tasks.append(task) + for dependent_node_name in node.get("depends_on", {}).get("nodes", []): + dagger_tasks += self._generate_dagger_tasks(dependent_node_name) + + else: + table_task = self._get_table_task(node, follow_external_dependency=follow_external_dependency) + dagger_tasks.append(table_task) + + if materialized_type in ("table", "incremental"): + dagger_tasks.append(self._get_s3_task(node)) + elif self._is_node_preparation_model(node): + for dependent_node_name in node.get("depends_on", {}).get("nodes", []): + dagger_tasks.extend( + self._generate_dagger_tasks(dependent_node_name) + ) + + return dagger_tasks + + def _get_table_task( + self, node: dict, follow_external_dependency: bool = False + ) -> dict: + """Generate a Databricks table task for a dbt model node. + + Args: + node: The dbt model node from manifest.json. + follow_external_dependency: Whether to create an ExternalTaskSensor. + + Returns: + Dagger databricks task configuration dict. """ task = DATABRICKS_TASK_BASE.copy() if follow_external_dependency: @@ -334,8 +462,15 @@ def _get_table_task( def _get_model_data_location( self, node: dict, schema: str, model_name: str ) -> Tuple[str, str]: - """ - Gets the S3 path of the dbt model relative to the data bucket. + """Get the S3 path of a dbt model relative to the data bucket. + + Args: + node: The dbt model node from manifest.json. + schema: The schema name (unused for Databricks, kept for interface compatibility). + model_name: The model name. + + Returns: + Tuple of (bucket_name, data_path). """ location_root = node.get("config", {}).get("location_root") location = join(location_root, model_name) @@ -345,32 +480,39 @@ def _get_model_data_location( return bucket_name, data_path def _get_s3_task(self, node: dict, is_output: bool = False) -> dict: - """ - Generates the dagger s3 task for the databricks-dbt model node + """Generate an S3 task for a databricks-dbt model node. + + Args: + node: The dbt model node from manifest.json. + is_output: If True, names the task 'output_s3_path' for output declarations. + + Returns: + Dagger S3 task configuration dict. """ task = S3_TASK_BASE.copy() schema = node.get("schema", self._default_schema) table = node.get("name", "") - task["name"] = f"output_s3_path" if is_output else f"s3_{table}" + task["name"] = "output_s3_path" if is_output else f"s3_{table}" task["bucket"], task["path"] = self._get_model_data_location( node, schema, table ) return task - def _generate_dagger_output(self, node: dict): - """ - Generates the dagger output for the DBT model node with the databricks-dbt adapter. - If the model is materialized as a view or ephemeral, then a dummy task is created. - Otherwise, and databricks and s3 task is created for the DBT model node. - And if create_external_athena_table is True te an extra athena task is created. + def _generate_dagger_output(self, node: dict) -> List[Dict]: + """Generate dagger output tasks for a databricks-dbt model node. + + Creates output task configurations based on the model's materialization type: + - Ephemeral models produce a dummy task + - Table/incremental models produce databricks + S3 tasks + - Optionally adds an Athena task if create_external_athena_table is True + Args: - node: The extracted node from the manifest.json file + node: The dbt model node from manifest.json. Returns: - dict: The dagger output, which is a combination of an athena and s3 task for the DBT model node - + List of dagger output task configuration dicts. """ materialized_type = node.get("config", {}).get("materialized") if materialized_type == "ephemeral": diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 7f33690..a12c25f 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -51,7 +51,9 @@ def read_task_config(self, task): return content @staticmethod - def load_plugins_to_jinja_environment(environment: jinja2.Environment) -> jinja2.Environment: + def load_plugins_to_jinja_environment( + environment: jinja2.Environment, + ) -> jinja2.Environment: """ Dynamically load all classes(plugins) from the folders defined in the conf.PLUGIN_DIRS variable. The folder contains all plugins that are part of the project. @@ -60,12 +62,20 @@ def load_plugins_to_jinja_environment(environment: jinja2.Environment) -> jinja2 """ for plugin_path in conf.PLUGIN_DIRS: for root, dirs, files in os.walk(plugin_path): - dirs[:] = [directory for directory in dirs if not directory.lower().startswith("test")] + dirs[:] = [ + directory + for directory in dirs + if not directory.lower().startswith("test") + ] for plugin_file in files: - if plugin_file.endswith(".py") and not (plugin_file.startswith("__") or plugin_file.startswith("test")): + if plugin_file.endswith(".py") and not ( + plugin_file.startswith("__") or plugin_file.startswith("test") + ): module_name = plugin_file.replace(".py", "") module_path = os.path.join(root, plugin_file) - spec = importlib.util.spec_from_file_location(module_name, module_path) + spec = importlib.util.spec_from_file_location( + module_name, module_path + ) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) @@ -84,8 +94,7 @@ def replace_template_parameters(_task_str, _template_parameters): return ( rendered_task # TODO Remove this hack and use Jinja escaping instead of special expression in template files - .replace("__CBS__", "{") - .replace("__CBE__", "}") + .replace("__CBS__", "{").replace("__CBE__", "}") ) @staticmethod @@ -102,7 +111,7 @@ def generate_task_configs(self): template_parameters = {} template_parameters.update(self._default_parameters or {}) template_parameters.update(attrs) - template_parameters['branch_name'] = branch_name + template_parameters["branch_name"] = branch_name template_parameters.update(self._jinja_parameters) for task, task_yaml in self._tasks.items(): From 1a7ff6ecdf0f79f2a6dd434e91ab60f956d77e85 Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Fri, 9 Jan 2026 15:07:40 +0100 Subject: [PATCH 180/189] Add databricks provider to Airflow dependencies Required for the new DLT creator which uses DatabricksRunNowOperator and DatabricksHook from apache-airflow-providers-databricks. --- reqs/dev.txt | 2 +- reqs/test.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/reqs/dev.txt b/reqs/dev.txt index b39d00a..cb96f9e 100644 --- a/reqs/dev.txt +++ b/reqs/dev.txt @@ -1,5 +1,5 @@ pip==24.0 -apache-airflow[amazon,postgres,s3,statsd]==2.11.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.11.0/constraints-3.12.txt" +apache-airflow[amazon,databricks,postgres,s3,statsd]==2.11.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.11.0/constraints-3.12.txt" black==22.10.0 bumpversion==0.6.0 coverage==7.4.4 diff --git a/reqs/test.txt b/reqs/test.txt index 6bb6c2e..3b97347 100644 --- a/reqs/test.txt +++ b/reqs/test.txt @@ -1,4 +1,4 @@ -apache-airflow[amazon,postgres,s3,statsd]==2.11.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.11.0/constraints-3.12.txt" +apache-airflow[amazon,databricks,postgres,s3,statsd]==2.11.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.11.0/constraints-3.12.txt" pytest-cov==4.0.0 pytest==7.2.0 graphviz From 435beedec013ba2e77072d31f09bd24648b2ccb1 Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Fri, 9 Jan 2026 15:10:30 +0100 Subject: [PATCH 181/189] Add databricks provider to production Airflow image Required for DLT creator to work in production. --- dockers/airflow/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dockers/airflow/Dockerfile b/dockers/airflow/Dockerfile index 2bd40d5..71e73d7 100644 --- a/dockers/airflow/Dockerfile +++ b/dockers/airflow/Dockerfile @@ -52,7 +52,7 @@ RUN curl -Ls "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awsc RUN pip install -U --progress-bar off --no-cache-dir pip setuptools wheel COPY requirements.txt requirements.txt -RUN pip install --progress-bar off --no-cache-dir apache-airflow[amazon,postgres,s3,statsd]==$AIRFLOW_VERSION --constraint $AIRFLOW_CONSTRAINTS && \ +RUN pip install --progress-bar off --no-cache-dir apache-airflow[amazon,databricks,postgres,s3,statsd]==$AIRFLOW_VERSION --constraint $AIRFLOW_CONSTRAINTS && \ pip install --progress-bar off --no-cache-dir -r requirements.txt && \ apt-get purge --auto-remove -yq $BUILD_DEPS && \ apt-get autoremove --purge -yq && \ From 1a82c9d58cbd6c95904cf56ad18faa8fa9fbef4f Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Mon, 12 Jan 2026 16:39:30 +0100 Subject: [PATCH 182/189] Remove plugin --- .../databricks_dlt_creator.py | 16 +- dagger/plugins/dlt_task_generator/__init__.py | 6 - .../dlt_task_generator/bundle_parser.py | 309 --------------- .../dlt_task_generator/dlt_task_generator.py | 374 ------------------ 4 files changed, 7 insertions(+), 698 deletions(-) delete mode 100644 dagger/plugins/dlt_task_generator/__init__.py delete mode 100644 dagger/plugins/dlt_task_generator/bundle_parser.py delete mode 100644 dagger/plugins/dlt_task_generator/dlt_task_generator.py diff --git a/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py b/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py index 66034a6..d3f538c 100644 --- a/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py @@ -34,8 +34,8 @@ def _cancel_databricks_run(context: dict[str, Any]) -> None: _logger.warning(f"No run_id found in XCom for task {ti.task_id}") return - # Get the databricks_conn_id from the operator - databricks_conn_id = getattr(ti.task, "databricks_conn_id", "databricks_default") + # Get the databricks_conn_id from the operator (set during operator creation) + databricks_conn_id = ti.task.databricks_conn_id try: hook = DatabricksHook(databricks_conn_id=databricks_conn_id) @@ -88,14 +88,12 @@ def _create_operator(self, **kwargs: Any) -> BaseOperator: DatabricksRunNowOperator, ) - # Get task parameters + # Get task parameters - defaults are handled in DatabricksDLTTask job_name: str = self._task.job_name - databricks_conn_id: str = getattr( - self._task, "databricks_conn_id", "databricks_default" - ) - wait_for_completion: bool = getattr(self._task, "wait_for_completion", True) - poll_interval_seconds: int = getattr(self._task, "poll_interval_seconds", 30) - timeout_seconds: int = getattr(self._task, "timeout_seconds", 3600) + databricks_conn_id: str = self._task.databricks_conn_id + wait_for_completion: bool = self._task.wait_for_completion + poll_interval_seconds: int = self._task.poll_interval_seconds + timeout_seconds: int = self._task.timeout_seconds # DatabricksRunNowOperator triggers an existing Databricks Job by name # The job must have a pipeline_task that references the DLT pipeline diff --git a/dagger/plugins/dlt_task_generator/__init__.py b/dagger/plugins/dlt_task_generator/__init__.py deleted file mode 100644 index 49e4a17..0000000 --- a/dagger/plugins/dlt_task_generator/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""DLT Task Generator plugin for generating Dagger configs from Databricks Asset Bundles.""" - -from dagger.plugins.dlt_task_generator.bundle_parser import DatabricksBundleParser -from dagger.plugins.dlt_task_generator.dlt_task_generator import DLTTaskGenerator - -__all__ = ["DatabricksBundleParser", "DLTTaskGenerator"] diff --git a/dagger/plugins/dlt_task_generator/bundle_parser.py b/dagger/plugins/dlt_task_generator/bundle_parser.py deleted file mode 100644 index 5a15ed7..0000000 --- a/dagger/plugins/dlt_task_generator/bundle_parser.py +++ /dev/null @@ -1,309 +0,0 @@ -"""Parse Databricks Asset Bundle YAML files for DLT pipeline configuration.""" - -import logging -import re -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Optional - -import yaml - -_logger = logging.getLogger(__name__) - - -@dataclass -class TableConfig: - """Configuration for a single table in a DLT pipeline. - - Attributes: - database: Source database name. - table: Source table name. - changelog_type: Type of changelog source ('dynamodb' or 'postgres'). - unique_keys: List of columns that uniquely identify a row. - scd_type: Slowly changing dimension type (1 or 2). - """ - - database: str - table: str - changelog_type: str # 'dynamodb' or 'postgres' - unique_keys: list[str] = field(default_factory=list) - scd_type: int = 1 - - @property - def source_schema(self) -> str: - """Get the source schema name for Athena/Glue catalog. - - For DynamoDB: ddb_changelogs - For PostgreSQL: pg_changelogs_kafka_{database_normalized} - """ - if self.changelog_type == "dynamodb": - return self.database - elif self.changelog_type == "postgres": - # Normalize database name (replace hyphens with underscores) - db_normalized = self.database.replace("-", "_") - return f"pg_changelogs_kafka_{db_normalized}" - else: - return self.database - - @property - def silver_table_name(self) -> str: - """Get the silver table name produced by DLT.""" - return f"silver_{self.table}" - - @property - def bronze_table_name(self) -> str: - """Get the bronze table name produced by DLT.""" - return f"bronze_{self.table}" - - -@dataclass -class PipelineConfig: - """Configuration for a DLT pipeline parsed from Databricks Asset Bundle. - - Attributes: - name: Pipeline/bundle name. - catalog: Target Unity Catalog name. - schema: Target schema name. - tables: List of table configurations for the pipeline. - targets: Target environment configurations (dev/prod). - variables: Variable definitions from databricks.yml. - tags: Pipeline tags. - """ - - name: str - catalog: str - schema: str - tables: list[TableConfig] = field(default_factory=list) - targets: dict[str, Any] = field(default_factory=dict) - variables: dict[str, Any] = field(default_factory=dict) - tags: dict[str, str] = field(default_factory=dict) - - -class DatabricksBundleParser: - """Parse Databricks Asset Bundle YAML files (databricks.yml and tables.yml). - - This parser extracts pipeline configuration from Databricks Asset Bundles, - including the target catalog, schema, table definitions, and job configuration. - It resolves variable references to Dagger environment variable format. - - Attributes: - _databricks_yml_path: Path to the databricks.yml file. - _tables_yml_path: Path to the tables.yml file. - _databricks_config: Parsed databricks.yml content. - _tables_config: Parsed tables.yml content. - _pipeline_config: Cached PipelineConfig instance. - """ - - def __init__( - self, - databricks_yml_path: Path, - tables_yml_path: Optional[Path] = None, - ) -> None: - """Initialize the parser with paths to bundle YAML files. - - Args: - databricks_yml_path: Path to the databricks.yml file. - tables_yml_path: Optional path to the tables.yml file. If not provided, - will look for tables.yml in the same directory. - """ - self._databricks_yml_path = Path(databricks_yml_path) - self._tables_yml_path = ( - Path(tables_yml_path) - if tables_yml_path - else self._databricks_yml_path.parent / "tables.yml" - ) - - self._databricks_config = self._load_yaml(self._databricks_yml_path) - self._tables_config = ( - self._load_yaml(self._tables_yml_path) - if self._tables_yml_path.exists() - else {} - ) - - self._pipeline_config: Optional[PipelineConfig] = None - - @staticmethod - def _load_yaml(path: Path) -> dict[str, Any]: - """Load and parse a YAML file. - - Args: - path: Path to the YAML file. - - Returns: - Parsed YAML content as a dictionary. - - Raises: - yaml.YAMLError: If the YAML file is malformed. - """ - try: - with open(path, "r") as f: - return yaml.safe_load(f) or {} - except FileNotFoundError: - _logger.warning(f"YAML file not found: {path}") - return {} - except yaml.YAMLError as e: - _logger.error(f"Error parsing YAML file {path}: {e}") - raise - - def _resolve_variable(self, value: str) -> str: - """Resolve Databricks bundle variable references like ${var.catalog}. - - Args: - value: String that may contain variable references - - Returns: - Resolved string with environment variable format for Dagger - """ - if not isinstance(value, str): - return value - - # Match ${var.variable_name} pattern - var_pattern = re.compile(r"\$\{var\.(\w+)\}") - - def replace_var(match): - var_name = match.group(1) - # Get the default value from variables section - var_config = self._databricks_config.get("variables", {}).get(var_name, {}) - default_value = var_config.get("default", "") - - # Map to Dagger environment variables - if var_name == "catalog": - # Map catalog to Dagger's ${ENV_MARTS} pattern - return "${ENV_MARTS}" - return default_value - - return var_pattern.sub(replace_var, value) - - def _parse_tables(self) -> list[TableConfig]: - """Parse table configurations from tables.yml. - - Returns: - List of TableConfig instances for each table defined in the bundle. - """ - tables = [] - defaults = self._tables_config.get("defaults", {}) - default_scd_type = defaults.get("scd_type", 1) - - for table_config in self._tables_config.get("tables", []): - tables.append( - TableConfig( - database=table_config.get("database", ""), - table=table_config.get("table", ""), - changelog_type=table_config.get("changelog_type", "dynamodb"), - unique_keys=table_config.get("unique_keys", []), - scd_type=table_config.get("scd_type", default_scd_type), - ) - ) - - return tables - - def _parse_pipeline(self) -> PipelineConfig: - """Parse pipeline configuration from databricks.yml. - - Extracts bundle name, variables, targets, and pipeline-specific settings - from the Databricks Asset Bundle configuration. - - Returns: - PipelineConfig instance with all parsed configuration. - """ - bundle_name = self._databricks_config.get("bundle", {}).get("name", "") - variables = self._databricks_config.get("variables", {}) - targets = self._databricks_config.get("targets", {}) - - # Get pipeline configuration from resources - resources = self._databricks_config.get("resources", {}) - pipelines = resources.get("pipelines", {}) - - # Get the first pipeline (usually matches bundle name) - pipeline_key = bundle_name or next(iter(pipelines.keys()), "") - pipeline_config = pipelines.get(pipeline_key, {}) - - catalog = self._resolve_variable(pipeline_config.get("catalog", "")) - schema = pipeline_config.get("schema", "") - tags = pipeline_config.get("tags", {}) - - return PipelineConfig( - name=bundle_name, - catalog=catalog, - schema=schema, - tables=self._parse_tables(), - targets=targets, - variables=variables, - tags=tags, - ) - - def parse(self) -> PipelineConfig: - """Parse the Databricks Asset Bundle and return pipeline configuration. - - Returns: - PipelineConfig with all parsed configuration - """ - if self._pipeline_config is None: - self._pipeline_config = self._parse_pipeline() - return self._pipeline_config - - def get_bundle_name(self) -> str: - """Return the bundle/pipeline name. - - Returns: - The bundle name from databricks.yml. - """ - return self.parse().name - - def get_catalog(self) -> str: - """Return the target catalog with Dagger environment variable format. - - Returns: - Catalog name with variables resolved to Dagger format (e.g., ${ENV_MARTS}). - """ - return self.parse().catalog - - def get_schema(self) -> str: - """Return the target schema. - - Returns: - Target schema name for the DLT pipeline. - """ - return self.parse().schema - - def get_tables(self) -> list[TableConfig]: - """Return the list of table configurations. - - Returns: - List of TableConfig instances for all tables in the pipeline. - """ - return self.parse().tables - - def get_targets(self) -> dict[str, Any]: - """Return target configurations (dev/prod). - - Returns: - Dictionary of target environment configurations. - """ - return self.parse().targets - - def get_variables(self) -> dict[str, Any]: - """Return variable definitions. - - Returns: - Dictionary of variable definitions from databricks.yml. - """ - return self.parse().variables - - def get_job_name(self) -> str: - """Get the Databricks Job name that triggers this pipeline. - - Looks for a job defined in resources.jobs that wraps the pipeline. - Falls back to a default naming convention if no job is defined. - - Returns: - The job name to use with DatabricksRunNowOperator - """ - resources = self._databricks_config.get("resources", {}) - jobs = resources.get("jobs", {}) - - # Return the first job's name, or use default naming convention - for job_config in jobs.values(): - return job_config.get("name", f"dlt-{self.get_bundle_name()}") - - return f"dlt-{self.get_bundle_name()}" diff --git a/dagger/plugins/dlt_task_generator/dlt_task_generator.py b/dagger/plugins/dlt_task_generator/dlt_task_generator.py deleted file mode 100644 index 6136227..0000000 --- a/dagger/plugins/dlt_task_generator/dlt_task_generator.py +++ /dev/null @@ -1,374 +0,0 @@ -"""Generate Dagger task configurations from Databricks DLT bundle definitions.""" - -import logging -import os -from pathlib import Path -from typing import Any, Optional - -import yaml - -from dagger.plugins.dlt_task_generator.bundle_parser import ( - DatabricksBundleParser, - PipelineConfig, - TableConfig, -) - -_logger = logging.getLogger(__name__) - -# Default path to the DLT pipelines repository (can be overridden via env var) -DEFAULT_DLT_PIPELINES_REPO = os.getenv( - "DLT_PIPELINES_REPO", - str(Path(__file__).parent.parent.parent.parent.parent / "dataeng-databricks-dlt-pipelines"), -) - - -class DLTTaskGenerator: - """Generate Dagger task configurations from Databricks DLT bundle definitions. - - This generator reads Databricks Asset Bundle configurations and produces - Dagger-compatible YAML task configurations for DLT pipelines. - - Attributes: - ATHENA_TASK_BASE: Base configuration for Athena input tasks. - DATABRICKS_TASK_BASE: Base configuration for Databricks output tasks. - DUMMY_TASK_BASE: Base configuration for dummy tasks. - """ - - ATHENA_TASK_BASE: dict[str, str] = {"type": "athena"} - DATABRICKS_TASK_BASE: dict[str, str] = {"type": "databricks"} - DUMMY_TASK_BASE: dict[str, str] = {"type": "dummy"} - - def __init__(self, dlt_repo_path: Optional[str] = None) -> None: - """Initialize the generator with path to DLT pipelines repository. - - Args: - dlt_repo_path: Path to the dataeng-databricks-dlt-pipelines repository. - Defaults to DLT_PIPELINES_REPO env var or sibling directory. - """ - self._dlt_repo_path = Path(dlt_repo_path or DEFAULT_DLT_PIPELINES_REPO) - self._pipelines: dict[str, DatabricksBundleParser] = {} - self._load_all_pipelines() - - def _load_all_pipelines(self) -> None: - """Load all pipeline bundles from the DLT repository. - - Scans the pipelines directory and loads each valid Databricks Asset Bundle - found. Bundles are identified by the presence of a databricks.yml file. - """ - pipelines_dir = self._dlt_repo_path / "pipelines" - - if not pipelines_dir.exists(): - _logger.warning(f"DLT pipelines directory not found: {pipelines_dir}") - return - - for pipeline_dir in pipelines_dir.iterdir(): - if not pipeline_dir.is_dir(): - continue - - databricks_yml = pipeline_dir / "databricks.yml" - if not databricks_yml.exists(): - continue - - tables_yml = pipeline_dir / "tables.yml" - try: - parser = DatabricksBundleParser(databricks_yml, tables_yml) - pipeline_name = parser.get_bundle_name() or pipeline_dir.name - self._pipelines[pipeline_name] = parser - _logger.info(f"Loaded DLT pipeline: {pipeline_name}") - except Exception as e: - _logger.error(f"Error loading pipeline from {pipeline_dir}: {e}") - - def get_pipeline_names(self) -> list[str]: - """Return list of available DLT pipeline names.""" - return list(self._pipelines.keys()) - - def get_pipeline_config(self, pipeline_name: str) -> PipelineConfig: - """Get the parsed pipeline configuration. - - Args: - pipeline_name: Name of the pipeline - - Returns: - PipelineConfig object - - Raises: - ValueError: If pipeline not found - """ - if pipeline_name not in self._pipelines: - raise ValueError( - f"Unknown pipeline: {pipeline_name}. " - f"Available pipelines: {self.get_pipeline_names()}" - ) - return self._pipelines[pipeline_name].parse() - - def _get_athena_input( - self, table: TableConfig, follow_external_dependency: bool = True - ) -> dict[str, Any]: - """Generate an Athena input task for a source changelog table. - - Args: - table: Table configuration from the DLT bundle. - follow_external_dependency: Whether to create an ExternalTaskSensor - for cross-pipeline dependency tracking. - - Returns: - Dagger Athena task configuration dict. - """ - task = self.ATHENA_TASK_BASE.copy() - task.update( - { - "schema": table.source_schema, - "table": table.table, - "name": f"{table.source_schema}__{table.table}_athena", - } - ) - if follow_external_dependency: - task["follow_external_dependency"] = True - return task - - def _get_databricks_output( - self, table: TableConfig, catalog: str, schema: str - ) -> dict[str, Any]: - """Generate a Databricks output task for a silver table. - - Args: - table: Table configuration from the DLT bundle. - catalog: Target Unity Catalog name (e.g., ${ENV_MARTS}). - schema: Target schema name. - - Returns: - Dagger Databricks task configuration dict. - """ - task = self.DATABRICKS_TASK_BASE.copy() - # Normalize catalog name for task naming - catalog_name = catalog.replace("${", "").replace("}", "").lower() - task.update( - { - "catalog": catalog, - "schema": schema, - "table": table.silver_table_name, - "name": f"{catalog_name}__{schema}__{table.silver_table_name}_databricks", - } - ) - return task - - def get_inputs( - self, pipeline_name: str, follow_external_dependency: bool = True - ) -> list[dict[str, Any]]: - """Generate input dependencies for a DLT pipeline task. - - These are the source changelog tables that the DLT pipeline reads from. - - Args: - pipeline_name: Name of the DLT pipeline. - follow_external_dependency: Whether to create ExternalTaskSensors - for cross-pipeline dependency tracking. - - Returns: - List of Dagger input task configurations. - """ - config = self.get_pipeline_config(pipeline_name) - inputs = [] - - for table in config.tables: - input_task = self._get_athena_input(table, follow_external_dependency) - inputs.append(input_task) - - return inputs - - def get_outputs(self, pipeline_name: str) -> list[dict[str, Any]]: - """Generate output declarations for a DLT pipeline task. - - These are the silver tables produced by the DLT pipeline. - - Args: - pipeline_name: Name of the DLT pipeline. - - Returns: - List of Dagger output task configurations. - """ - config = self.get_pipeline_config(pipeline_name) - outputs = [] - - for table in config.tables: - output_task = self._get_databricks_output( - table, config.catalog, config.schema - ) - outputs.append(output_task) - - return outputs - - def get_task_parameters(self, pipeline_name: str) -> dict[str, Any]: - """Generate task parameters for triggering a DLT pipeline via Databricks Job. - - Args: - pipeline_name: Name of the DLT pipeline. - - Returns: - Dict of task parameters for the DatabricksRunNowOperator. - """ - parser = self._pipelines[pipeline_name] - return { - "job_name": parser.get_job_name(), - "databricks_conn_id": "${DATABRICKS_CONN_ID}", - "wait_for_completion": True, - "poll_interval_seconds": 30, - "timeout_seconds": 3600, - } - - def generate_task_config( - self, - pipeline_name: str, - description: Optional[str] = None, - follow_external_dependency: bool = True, - ) -> dict[str, Any]: - """Generate a complete Dagger task configuration for a DLT pipeline. - - Args: - pipeline_name: Name of the DLT pipeline. - description: Optional task description. Defaults to auto-generated. - follow_external_dependency: Whether to create ExternalTaskSensors for inputs. - - Returns: - Complete Dagger task configuration dict ready for YAML serialization. - """ - config = self.get_pipeline_config(pipeline_name) - - task_config = { - "type": "databricks_dlt", - "description": description or f"Run DLT pipeline {pipeline_name}", - "inputs": self.get_inputs(pipeline_name, follow_external_dependency), - "outputs": self.get_outputs(pipeline_name), - "airflow_task_parameters": { - "retries": 2, - "retry_delay": 300, - }, - "template_parameters": {}, - "task_parameters": self.get_task_parameters(pipeline_name), - } - - return task_config - - def generate_pipeline_config( - self, - pipeline_name: str, - schedule: str = "0 * * * *", - owner: str = "dataeng@choco.com", - ) -> dict[str, Any]: - """Generate a Dagger pipeline.yaml configuration for a DLT pipeline DAG. - - Args: - pipeline_name: Name of the DLT pipeline. - schedule: Cron schedule expression. Defaults to hourly. - owner: Pipeline owner email address. - - Returns: - Dagger pipeline.yaml configuration dict ready for YAML serialization. - """ - config = self.get_pipeline_config(pipeline_name) - - return { - "owner": owner, - "description": f"DLT Pipeline - {pipeline_name}", - "schedule": schedule, - "start_date": "2024-01-01T00:00", - "airflow_parameters": { - "default_args": { - "retries": 2, - "retry_delay": 180, - "depends_on_past": False, - }, - "dag_parameters": { - "catchup": False, - "max_active_runs": 1, - "tags": ["dlt", "databricks", pipeline_name], - }, - }, - "alerts": [ - { - "type": "slack", - "channel": "#${ENV}-airflow-alerts", - "mentions": ["@dataeng-oncall"], - } - ], - } - - def write_task_config( - self, pipeline_name: str, output_path: Path, **kwargs: Any - ) -> Path: - """Write a task configuration to a YAML file. - - Args: - pipeline_name: Name of the DLT pipeline. - output_path: Directory to write the file to. - **kwargs: Additional arguments passed to generate_task_config. - - Returns: - Path to the written YAML file. - """ - output_path = Path(output_path) - output_path.mkdir(parents=True, exist_ok=True) - - task_config = self.generate_task_config(pipeline_name, **kwargs) - file_path = output_path / f"{pipeline_name}_dlt.yaml" - - with open(file_path, "w") as f: - # Add autogenerated marker - task_config["autogenerated_by_dagger"] = f"dlt_task_generator:{pipeline_name}" - yaml.dump(task_config, f, default_flow_style=False, sort_keys=False) - - _logger.info(f"Generated task config: {file_path}") - return file_path - - def write_pipeline_config( - self, pipeline_name: str, output_path: Path, **kwargs: Any - ) -> Path: - """Write a pipeline configuration to a YAML file. - - Args: - pipeline_name: Name of the DLT pipeline. - output_path: Directory to write the file to. - **kwargs: Additional arguments passed to generate_pipeline_config. - - Returns: - Path to the written YAML file. - """ - output_path = Path(output_path) - output_path.mkdir(parents=True, exist_ok=True) - - pipeline_config = self.generate_pipeline_config(pipeline_name, **kwargs) - file_path = output_path / "pipeline.yaml" - - with open(file_path, "w") as f: - yaml.dump(pipeline_config, f, default_flow_style=False, sort_keys=False) - - _logger.info(f"Generated pipeline config: {file_path}") - return file_path - - def generate_all(self, output_base_path: Path) -> list[Path]: - """Generate all DLT pipeline configurations. - - Creates pipeline.yaml and task configuration files for each loaded - DLT pipeline in the repository. - - Args: - output_base_path: Base directory for output (e.g., dags/dlt/). - - Returns: - List of paths to all generated files. - """ - output_base_path = Path(output_base_path) - generated_files = [] - - for pipeline_name in self.get_pipeline_names(): - pipeline_output_path = output_base_path / pipeline_name - - # Generate pipeline.yaml - pipeline_file = self.write_pipeline_config(pipeline_name, pipeline_output_path) - generated_files.append(pipeline_file) - - # Generate task config - task_file = self.write_task_config(pipeline_name, pipeline_output_path) - generated_files.append(task_file) - - return generated_files From 6cc7438e862c9b6404e8efa453a9202f46eea582 Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Mon, 12 Jan 2026 16:41:44 +0100 Subject: [PATCH 183/189] Add Claude to repository --- .github/workflows/claude.yml | 49 ++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 .github/workflows/claude.yml diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml new file mode 100644 index 0000000..d199848 --- /dev/null +++ b/.github/workflows/claude.yml @@ -0,0 +1,49 @@ +name: Claude Code + +on: + issue_comment: + types: [created] + pull_request_review_comment: + types: [created] + issues: + types: [opened, assigned] + pull_request_review: + types: [submitted] + +jobs: + claude: + if: | + (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) || + (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) || + (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) || + (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude'))) + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: read + issues: read + id-token: write + actions: read # Required for Claude to read CI results on PRs + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 1 + + - name: Run Claude Code + id: claude + uses: anthropics/claude-code-action@v1 + with: + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + + # This is an optional setting that allows Claude to read CI results on PRs + additional_permissions: | + actions: read + + # Optional: Give a custom prompt to Claude. If this is not specified, Claude will perform the instructions specified in the comment that tagged it. + # prompt: 'Update the pull request description to include a summary of changes.' + + # Optional: Add claude_args to customize behavior and configuration + # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md + # or https://code.claude.com/docs/en/cli-reference for available options + # claude_args: '--allowed-tools Bash(gh pr:*)' From bfeb772df1acbe648205231bac4bc745b0485f75 Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Mon, 12 Jan 2026 17:34:49 +0100 Subject: [PATCH 184/189] Update the module to support yaml --- dagger/cli/module.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/dagger/cli/module.py b/dagger/cli/module.py index 931e809..d77897d 100644 --- a/dagger/cli/module.py +++ b/dagger/cli/module.py @@ -1,19 +1,34 @@ +import json + import click +import yaml + from dagger.utilities.module import Module from dagger.utils import Printer -import json def parse_key_value(ctx, param, value): - #print('YYY', value) + """Parse key=value pairs where value is a path to JSON or YAML file. + + Args: + ctx: Click context. + param: Click parameter. + value: List of key=value pairs. + + Returns: + Dictionary mapping variable names to parsed file contents. + """ if not value: return {} key_value_dict = {} for pair in value: try: key, val_file_path = pair.split('=', 1) - #print('YYY', key, val_file_path, pair) - val = json.load(open(val_file_path)) + with open(val_file_path, 'r') as f: + if val_file_path.endswith(('.yaml', '.yml')): + val = yaml.safe_load(f) + else: + val = json.load(f) key_value_dict[key] = val except ValueError: raise click.BadParameter(f"Key-value pair '{pair}' is not in the format key=value") @@ -22,7 +37,7 @@ def parse_key_value(ctx, param, value): @click.command() @click.option("--config_file", "-c", help="Path to module config file") @click.option("--target_dir", "-t", help="Path to directory to generate the task configs to") -@click.option("--jinja_parameters", "-j", callback=parse_key_value, multiple=True, default=None, help="Path to jinja parameters json file in the format: =") +@click.option("--jinja_parameters", "-j", callback=parse_key_value, multiple=True, default=None, help="Jinja parameters file in the format: =") def generate_tasks(config_file: str, target_dir: str, jinja_parameters: dict) -> None: """ Generating tasks for a module based on config From 63d03452c5f78e299712d512740ede9479bb18cf Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Tue, 13 Jan 2026 17:21:39 +0100 Subject: [PATCH 185/189] Add comprehensive tests and improve Databricks DLT components - Add type hints and docstrings to DatabricksIO - Improve error handling in DatabricksDLTCreator with ImportError support - Add validation for empty job_name in DatabricksDLTCreator - Add comprehensive test coverage for DatabricksIO, DatabricksDLTTask, and DatabricksDLTCreator - All Databricks components now have 100% test coverage --- .../databricks_dlt_creator.py | 17 +- dagger/pipeline/ios/databricks_io.py | 106 ++++++- .../airflow/operator_creators/__init__.py | 0 .../test_databricks_dlt_creator.py | 276 ++++++++++++++++++ .../pipeline/tasks/databricks_dlt_task.yaml | 22 ++ tests/pipeline/ios/test_databricks_io.py | 222 +++++++++++++- .../tasks/test_databricks_dlt_task.py | 176 +++++++++++ 7 files changed, 798 insertions(+), 21 deletions(-) create mode 100644 tests/dag_creator/airflow/operator_creators/__init__.py create mode 100644 tests/dag_creator/airflow/operator_creators/test_databricks_dlt_creator.py create mode 100644 tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml create mode 100644 tests/pipeline/tasks/test_databricks_dlt_task.py diff --git a/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py b/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py index d3f538c..87a11ac 100644 --- a/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py @@ -21,8 +21,6 @@ def _cancel_databricks_run(context: dict[str, Any]) -> None: Args: context: Airflow context dictionary containing task instance and other metadata. """ - from airflow.providers.databricks.hooks.databricks import DatabricksHook - ti = context.get("task_instance") if not ti: _logger.warning("No task instance in context, cannot cancel Databricks run") @@ -37,10 +35,18 @@ def _cancel_databricks_run(context: dict[str, Any]) -> None: # Get the databricks_conn_id from the operator (set during operator creation) databricks_conn_id = ti.task.databricks_conn_id + # Import here to avoid import errors if databricks provider not installed + # and to only import when actually needed (after early returns) try: + from airflow.providers.databricks.hooks.databricks import DatabricksHook + hook = DatabricksHook(databricks_conn_id=databricks_conn_id) hook.cancel_run(run_id) _logger.info(f"Cancelled Databricks run {run_id} for task {ti.task_id}") + except ImportError: + _logger.error( + "airflow-providers-databricks is not installed, cannot cancel run" + ) except Exception as e: _logger.error(f"Failed to cancel Databricks run {run_id}: {e}") @@ -80,6 +86,9 @@ def _create_operator(self, **kwargs: Any) -> BaseOperator: Returns: A configured DatabricksRunNowOperator instance. + + Raises: + ValueError: If job_name is empty or not provided. """ # Import here to avoid import errors if databricks provider not installed from datetime import timedelta @@ -90,6 +99,10 @@ def _create_operator(self, **kwargs: Any) -> BaseOperator: # Get task parameters - defaults are handled in DatabricksDLTTask job_name: str = self._task.job_name + if not job_name: + raise ValueError( + f"job_name is required for DatabricksDLTTask '{self._task.name}'" + ) databricks_conn_id: str = self._task.databricks_conn_id wait_for_completion: bool = self._task.wait_for_completion poll_interval_seconds: int = self._task.poll_interval_seconds diff --git a/dagger/pipeline/ios/databricks_io.py b/dagger/pipeline/ios/databricks_io.py index 15be2c1..7c7b4d2 100644 --- a/dagger/pipeline/ios/databricks_io.py +++ b/dagger/pipeline/ios/databricks_io.py @@ -1,12 +1,45 @@ +"""IO representation for Databricks Unity Catalog tables.""" + +from typing import Any + from dagger.pipeline.io import IO from dagger.utilities.config_validator import Attribute class DatabricksIO(IO): - ref_name = "databricks" + """IO representation for Databricks Unity Catalog tables. + + Represents a table in Databricks Unity Catalog with catalog.schema.table naming. + Used to define inputs and outputs for tasks that read from or write to + Databricks tables. + + Attributes: + ref_name: Reference name used by IOFactory to instantiate this IO type. + catalog: Databricks Unity Catalog name. + schema: Schema/database name within the catalog. + table: Table name. + + Example YAML configuration: + type: databricks + name: my_output_table + catalog: prod_catalog + schema: analytics + table: user_metrics + """ + + ref_name: str = "databricks" @classmethod - def init_attributes(cls, orig_cls): + def init_attributes(cls, orig_cls: type) -> None: + """Initialize configuration attributes for YAML parsing. + + Registers all attributes that can be specified in the YAML configuration. + Called by the IO metaclass during class creation. + + Args: + orig_cls: The original class being initialized (used for attribute + registration). + """ cls.add_config_attributes( [ Attribute(attribute_name="catalog"), @@ -15,32 +48,81 @@ def init_attributes(cls, orig_cls): ] ) - def __init__(self, io_config, config_location): + def __init__(self, io_config: dict[str, Any], config_location: str) -> None: + """Initialize a DatabricksIO instance. + + Args: + io_config: Dictionary containing the IO configuration from YAML. + config_location: Path to the configuration file for error reporting. + + Raises: + DaggerMissingFieldException: If required fields (catalog, schema, table) + are missing from the configuration. + """ super().__init__(io_config, config_location) - self._catalog = self.parse_attribute("catalog") - self._schema = self.parse_attribute("schema") - self._table = self.parse_attribute("table") + self._catalog: str = self.parse_attribute("catalog") + self._schema: str = self.parse_attribute("schema") + self._table: str = self.parse_attribute("table") - def alias(self): + def alias(self) -> str: + """Return the unique alias for this IO in databricks:// URI format. + + The alias is used for dataset lineage tracking and dependency resolution + across pipelines. + + Returns: + A unique identifier string in the format + 'databricks://{catalog}/{schema}/{table}'. + """ return f"databricks://{self._catalog}/{self._schema}/{self._table}" @property - def rendered_name(self): + def rendered_name(self) -> str: + """Return the fully qualified table name in dot notation. + + This format is used in SQL queries and Databricks API calls. + + Returns: + The table name in '{catalog}.{schema}.{table}' format. + """ return f"{self._catalog}.{self._schema}.{self._table}" @property - def airflow_name(self): + def airflow_name(self) -> str: + """Return an Airflow-safe identifier for this table. + + Airflow task/dataset IDs cannot contain dots, so this returns a + hyphen-separated format suitable for use in Airflow contexts. + + Returns: + The table name in 'databricks-{catalog}-{schema}-{table}' format. + """ return f"databricks-{self._catalog}-{self._schema}-{self._table}" @property - def catalog(self): + def catalog(self) -> str: + """Return the Databricks Unity Catalog name. + + Returns: + The catalog name. + """ return self._catalog @property - def schema(self): + def schema(self) -> str: + """Return the schema/database name within the catalog. + + Returns: + The schema name. + """ return self._schema @property - def table(self): + def table(self) -> str: + """Return the table name. + + Returns: + The table name. + """ return self._table diff --git a/tests/dag_creator/airflow/operator_creators/__init__.py b/tests/dag_creator/airflow/operator_creators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/dag_creator/airflow/operator_creators/test_databricks_dlt_creator.py b/tests/dag_creator/airflow/operator_creators/test_databricks_dlt_creator.py new file mode 100644 index 0000000..39de91b --- /dev/null +++ b/tests/dag_creator/airflow/operator_creators/test_databricks_dlt_creator.py @@ -0,0 +1,276 @@ +"""Unit tests for DatabricksDLTCreator.""" + +import sys +import unittest +from datetime import timedelta +from unittest.mock import MagicMock, patch + +from dagger.dag_creator.airflow.operator_creators.databricks_dlt_creator import ( + DatabricksDLTCreator, + _cancel_databricks_run, +) + + +class TestDatabricksDLTCreator(unittest.TestCase): + """Test cases for DatabricksDLTCreator.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.mock_task = MagicMock() + self.mock_task.name = "test_dlt_task" + self.mock_task.job_name = "test-dlt-job" + self.mock_task.databricks_conn_id = "databricks_default" + self.mock_task.wait_for_completion = True + self.mock_task.poll_interval_seconds = 30 + self.mock_task.timeout_seconds = 3600 + self.mock_task.cancel_on_kill = True + + self.mock_dag = MagicMock() + + # Set up mock for DatabricksRunNowOperator + self.mock_operator = MagicMock() + self.mock_operator_class = MagicMock(return_value=self.mock_operator) + self.mock_databricks_module = MagicMock() + self.mock_databricks_module.DatabricksRunNowOperator = self.mock_operator_class + + def test_ref_name(self) -> None: + """Test that ref_name is correctly set.""" + self.assertEqual(DatabricksDLTCreator.ref_name, "databricks_dlt") + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator(self) -> None: + """Test operator creation returns an operator instance.""" + mock_operator = MagicMock() + mock_operator_class = MagicMock(return_value=mock_operator) + sys.modules[ + "airflow.providers.databricks.operators.databricks" + ].DatabricksRunNowOperator = mock_operator_class + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + operator = creator._create_operator() + + mock_operator_class.assert_called_once() + self.assertEqual(operator, mock_operator) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator_maps_task_properties(self) -> None: + """Test that task properties are correctly mapped to operator.""" + mock_operator_class = MagicMock() + sys.modules[ + "airflow.providers.databricks.operators.databricks" + ].DatabricksRunNowOperator = mock_operator_class + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + creator._create_operator() + + call_kwargs = mock_operator_class.call_args[1] + + self.assertEqual(call_kwargs["dag"], self.mock_dag) + self.assertEqual(call_kwargs["task_id"], "test_dlt_task") + self.assertEqual(call_kwargs["databricks_conn_id"], "databricks_default") + self.assertEqual(call_kwargs["job_name"], "test-dlt-job") + self.assertEqual(call_kwargs["wait_for_termination"], True) + self.assertEqual(call_kwargs["polling_period_seconds"], 30) + self.assertEqual(call_kwargs["execution_timeout"], timedelta(seconds=3600)) + self.assertTrue(call_kwargs["do_xcom_push"]) + self.assertEqual(call_kwargs["on_failure_callback"], _cancel_databricks_run) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator_with_custom_values(self) -> None: + """Test operator creation with non-default values.""" + self.mock_task.databricks_conn_id = "custom_conn" + self.mock_task.wait_for_completion = False + self.mock_task.poll_interval_seconds = 60 + self.mock_task.timeout_seconds = 7200 + + mock_operator_class = MagicMock() + sys.modules[ + "airflow.providers.databricks.operators.databricks" + ].DatabricksRunNowOperator = mock_operator_class + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + creator._create_operator() + + call_kwargs = mock_operator_class.call_args[1] + + self.assertEqual(call_kwargs["databricks_conn_id"], "custom_conn") + self.assertEqual(call_kwargs["wait_for_termination"], False) + self.assertEqual(call_kwargs["polling_period_seconds"], 60) + self.assertEqual(call_kwargs["execution_timeout"], timedelta(seconds=7200)) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator_empty_job_name_raises_error(self) -> None: + """Test that empty job_name raises ValueError.""" + self.mock_task.job_name = "" + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + + with self.assertRaises(ValueError) as context: + creator._create_operator() + + self.assertIn("job_name is required", str(context.exception)) + self.assertIn("test_dlt_task", str(context.exception)) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator_none_job_name_raises_error(self) -> None: + """Test that None job_name raises ValueError.""" + self.mock_task.job_name = None + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + + with self.assertRaises(ValueError) as context: + creator._create_operator() + + self.assertIn("job_name is required", str(context.exception)) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator_passes_kwargs(self) -> None: + """Test that additional kwargs are passed to operator.""" + mock_operator_class = MagicMock() + sys.modules[ + "airflow.providers.databricks.operators.databricks" + ].DatabricksRunNowOperator = mock_operator_class + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + creator._create_operator(retries=3, retry_delay=60) + + call_kwargs = mock_operator_class.call_args[1] + + self.assertEqual(call_kwargs["retries"], 3) + self.assertEqual(call_kwargs["retry_delay"], 60) + + +class TestCancelDatabricksRun(unittest.TestCase): + """Test cases for _cancel_databricks_run callback.""" + + def test_cancel_run_no_task_instance(self) -> None: + """Test callback handles missing task instance gracefully.""" + context: dict = {} + + # Should not raise, just log warning + _cancel_databricks_run(context) + + def test_cancel_run_no_run_id(self) -> None: + """Test callback handles missing run_id gracefully.""" + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = None + + context = {"task_instance": mock_ti} + + # Should not raise, just log warning + _cancel_databricks_run(context) + + mock_ti.xcom_pull.assert_called_once_with(task_ids="test_task", key="run_id") + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.hooks.databricks": MagicMock()}, + ) + def test_cancel_run_success(self) -> None: + """Test successful cancellation of Databricks run.""" + mock_hook = MagicMock() + mock_hook_class = MagicMock(return_value=mock_hook) + sys.modules[ + "airflow.providers.databricks.hooks.databricks" + ].DatabricksHook = mock_hook_class + + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = "run_12345" + mock_ti.task.databricks_conn_id = "databricks_default" + + context = {"task_instance": mock_ti} + + _cancel_databricks_run(context) + + mock_hook_class.assert_called_once_with(databricks_conn_id="databricks_default") + mock_hook.cancel_run.assert_called_once_with("run_12345") + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.hooks.databricks": MagicMock()}, + ) + def test_cancel_run_handles_exception(self) -> None: + """Test callback handles cancellation errors gracefully.""" + mock_hook = MagicMock() + mock_hook.cancel_run.side_effect = Exception("API Error") + mock_hook_class = MagicMock(return_value=mock_hook) + sys.modules[ + "airflow.providers.databricks.hooks.databricks" + ].DatabricksHook = mock_hook_class + + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = "run_12345" + mock_ti.task.databricks_conn_id = "databricks_default" + + context = {"task_instance": mock_ti} + + # Should not raise, just log error + _cancel_databricks_run(context) + + mock_hook.cancel_run.assert_called_once_with("run_12345") + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.hooks.databricks": MagicMock()}, + ) + def test_cancel_run_with_custom_conn_id(self) -> None: + """Test cancellation uses correct connection ID.""" + mock_hook = MagicMock() + mock_hook_class = MagicMock(return_value=mock_hook) + sys.modules[ + "airflow.providers.databricks.hooks.databricks" + ].DatabricksHook = mock_hook_class + + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = "run_67890" + mock_ti.task.databricks_conn_id = "custom_databricks_conn" + + context = {"task_instance": mock_ti} + + _cancel_databricks_run(context) + + mock_hook_class.assert_called_once_with( + databricks_conn_id="custom_databricks_conn" + ) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.hooks.databricks": None}, + ) + def test_cancel_run_handles_import_error(self) -> None: + """Test callback handles missing databricks provider gracefully.""" + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = "run_12345" + mock_ti.task.databricks_conn_id = "databricks_default" + + context = {"task_instance": mock_ti} + + # Should not raise, just log error + _cancel_databricks_run(context) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml b/tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml new file mode 100644 index 0000000..1902cf0 --- /dev/null +++ b/tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml @@ -0,0 +1,22 @@ +type: databricks_dlt +description: Test DLT pipeline task +inputs: + - type: athena + name: input_table + schema: test_schema + table: input_table +outputs: + - type: databricks + name: output_table + catalog: test_catalog + schema: test_schema + table: output_table +airflow_task_parameters: +template_parameters: +task_parameters: + job_name: test-dlt-job + databricks_conn_id: databricks_test + wait_for_completion: true + poll_interval_seconds: 60 + timeout_seconds: 7200 + cancel_on_kill: true diff --git a/tests/pipeline/ios/test_databricks_io.py b/tests/pipeline/ios/test_databricks_io.py index b1d0c45..e4a9456 100644 --- a/tests/pipeline/ios/test_databricks_io.py +++ b/tests/pipeline/ios/test_databricks_io.py @@ -1,17 +1,225 @@ +"""Unit tests for DatabricksIO.""" + import unittest -from dagger.pipeline.io_factory import databricks_io import yaml +from dagger.pipeline.ios import databricks_io +from dagger.utilities.exceptions import DaggerMissingFieldException + + +class TestDatabricksIO(unittest.TestCase): + """Test cases for DatabricksIO.""" -class DbIOTest(unittest.TestCase): def setUp(self) -> None: - with open('tests/fixtures/pipeline/ios/databricks_io.yaml', "r") as stream: + """Set up test fixtures.""" + with open("tests/fixtures/pipeline/ios/databricks_io.yaml", "r") as stream: config = yaml.safe_load(stream) self.db_io = databricks_io.DatabricksIO(config, "/") - def test_properties(self): - self.assertEqual(self.db_io.alias(), "databricks://test_catalog/test_schema/test_table") - self.assertEqual(self.db_io.rendered_name, "test_catalog.test_schema.test_table") - self.assertEqual(self.db_io.airflow_name, "databricks-test_catalog-test_schema-test_table") + def test_ref_name(self) -> None: + """Test that ref_name is correctly set.""" + self.assertEqual(databricks_io.DatabricksIO.ref_name, "databricks") + + def test_catalog(self) -> None: + """Test catalog property.""" + self.assertEqual(self.db_io.catalog, "test_catalog") + + def test_schema(self) -> None: + """Test schema property.""" + self.assertEqual(self.db_io.schema, "test_schema") + + def test_table(self) -> None: + """Test table property.""" + self.assertEqual(self.db_io.table, "test_table") + + def test_alias(self) -> None: + """Test alias method returns databricks:// URI format.""" + self.assertEqual( + self.db_io.alias(), "databricks://test_catalog/test_schema/test_table" + ) + + def test_rendered_name(self) -> None: + """Test rendered_name returns dot-separated format.""" + self.assertEqual( + self.db_io.rendered_name, "test_catalog.test_schema.test_table" + ) + + def test_airflow_name(self) -> None: + """Test airflow_name returns hyphen-separated format.""" + self.assertEqual( + self.db_io.airflow_name, "databricks-test_catalog-test_schema-test_table" + ) + + def test_name(self) -> None: + """Test name property from base IO class.""" + self.assertEqual(self.db_io.name, "test") + + def test_has_dependency_default(self) -> None: + """Test that has_dependency defaults to True.""" + self.assertTrue(self.db_io.has_dependency) + + +class TestDatabricksIOInlineConfig(unittest.TestCase): + """Test cases for DatabricksIO with inline configuration.""" + + def test_with_minimal_config(self) -> None: + """Test DatabricksIO with minimal required configuration.""" + config = { + "type": "databricks", + "name": "minimal_table", + "catalog": "my_catalog", + "schema": "my_schema", + "table": "my_table", + } + + db_io = databricks_io.DatabricksIO(config, "/test/path") + + self.assertEqual(db_io.catalog, "my_catalog") + self.assertEqual(db_io.schema, "my_schema") + self.assertEqual(db_io.table, "my_table") + self.assertEqual(db_io.name, "minimal_table") + + def test_alias_format_with_special_characters(self) -> None: + """Test alias format with underscores and numbers.""" + config = { + "type": "databricks", + "name": "output_123", + "catalog": "prod_catalog_v2", + "schema": "analytics_schema", + "table": "user_events_2024", + } + + db_io = databricks_io.DatabricksIO(config, "/") + + self.assertEqual( + db_io.alias(), + "databricks://prod_catalog_v2/analytics_schema/user_events_2024", + ) + self.assertEqual( + db_io.rendered_name, "prod_catalog_v2.analytics_schema.user_events_2024" + ) + self.assertEqual( + db_io.airflow_name, + "databricks-prod_catalog_v2-analytics_schema-user_events_2024", + ) + + def test_has_dependency_false(self) -> None: + """Test that has_dependency can be set to False.""" + config = { + "type": "databricks", + "name": "no_dep_table", + "catalog": "cat", + "schema": "sch", + "table": "tbl", + "has_dependency": False, + } + + db_io = databricks_io.DatabricksIO(config, "/") + + self.assertFalse(db_io.has_dependency) + + +class TestDatabricksIOMissingFields(unittest.TestCase): + """Test cases for DatabricksIO error handling.""" + + def test_missing_catalog_raises_exception(self) -> None: + """Test that missing catalog raises DaggerMissingFieldException.""" + config = { + "type": "databricks", + "name": "test_table", + "schema": "test_schema", + "table": "test_table", + } + + with self.assertRaises(DaggerMissingFieldException): + databricks_io.DatabricksIO(config, "/test/config.yaml") + + def test_missing_schema_raises_exception(self) -> None: + """Test that missing schema raises DaggerMissingFieldException.""" + config = { + "type": "databricks", + "name": "test_table", + "catalog": "test_catalog", + "table": "test_table", + } + + with self.assertRaises(DaggerMissingFieldException): + databricks_io.DatabricksIO(config, "/test/config.yaml") + + def test_missing_table_raises_exception(self) -> None: + """Test that missing table raises DaggerMissingFieldException.""" + config = { + "type": "databricks", + "name": "test_table", + "catalog": "test_catalog", + "schema": "test_schema", + } + + with self.assertRaises(DaggerMissingFieldException): + databricks_io.DatabricksIO(config, "/test/config.yaml") + + def test_missing_name_raises_exception(self) -> None: + """Test that missing name raises DaggerMissingFieldException.""" + config = { + "type": "databricks", + "catalog": "test_catalog", + "schema": "test_schema", + "table": "test_table", + } + + with self.assertRaises(DaggerMissingFieldException): + databricks_io.DatabricksIO(config, "/test/config.yaml") + + +class TestDatabricksIOEquality(unittest.TestCase): + """Test cases for DatabricksIO equality comparison.""" + + def test_equal_ios_are_equal(self) -> None: + """Test that two IOs with same alias are equal.""" + config1 = { + "type": "databricks", + "name": "table1", + "catalog": "cat", + "schema": "sch", + "table": "tbl", + } + config2 = { + "type": "databricks", + "name": "table2", # Different name, same catalog.schema.table + "catalog": "cat", + "schema": "sch", + "table": "tbl", + } + + io1 = databricks_io.DatabricksIO(config1, "/") + io2 = databricks_io.DatabricksIO(config2, "/") + + self.assertEqual(io1, io2) + + def test_different_ios_are_not_equal(self) -> None: + """Test that two IOs with different aliases are not equal.""" + config1 = { + "type": "databricks", + "name": "table1", + "catalog": "cat1", + "schema": "sch", + "table": "tbl", + } + config2 = { + "type": "databricks", + "name": "table2", + "catalog": "cat2", + "schema": "sch", + "table": "tbl", + } + + io1 = databricks_io.DatabricksIO(config1, "/") + io2 = databricks_io.DatabricksIO(config2, "/") + + self.assertNotEqual(io1, io2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pipeline/tasks/test_databricks_dlt_task.py b/tests/pipeline/tasks/test_databricks_dlt_task.py new file mode 100644 index 0000000..a222148 --- /dev/null +++ b/tests/pipeline/tasks/test_databricks_dlt_task.py @@ -0,0 +1,176 @@ +"""Unit tests for DatabricksDLTTask.""" + +import unittest +from unittest.mock import MagicMock + +import yaml + +from dagger.pipeline.tasks.databricks_dlt_task import DatabricksDLTTask + + +class TestDatabricksDLTTask(unittest.TestCase): + """Test cases for DatabricksDLTTask.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + with open( + "tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml", "r" + ) as stream: + self.config = yaml.safe_load(stream) + + # Create a mock pipeline object + self.mock_pipeline = MagicMock() + self.mock_pipeline.directory = "tests/fixtures/pipeline/tasks" + + self.task = DatabricksDLTTask( + name="test_dlt_task", + pipeline_name="test_pipeline", + pipeline=self.mock_pipeline, + job_config=self.config, + ) + + def test_ref_name(self) -> None: + """Test that ref_name is correctly set.""" + self.assertEqual(DatabricksDLTTask.ref_name, "databricks_dlt") + + def test_job_name(self) -> None: + """Test job_name property.""" + self.assertEqual(self.task.job_name, "test-dlt-job") + + def test_databricks_conn_id(self) -> None: + """Test databricks_conn_id property.""" + self.assertEqual(self.task.databricks_conn_id, "databricks_test") + + def test_wait_for_completion(self) -> None: + """Test wait_for_completion property.""" + self.assertTrue(self.task.wait_for_completion) + + def test_poll_interval_seconds(self) -> None: + """Test poll_interval_seconds property.""" + self.assertEqual(self.task.poll_interval_seconds, 60) + + def test_timeout_seconds(self) -> None: + """Test timeout_seconds property.""" + self.assertEqual(self.task.timeout_seconds, 7200) + + def test_cancel_on_kill(self) -> None: + """Test cancel_on_kill property.""" + self.assertTrue(self.task.cancel_on_kill) + + def test_task_name(self) -> None: + """Test that task name is correctly set.""" + self.assertEqual(self.task.name, "test_dlt_task") + + def test_pipeline_name(self) -> None: + """Test that pipeline_name is correctly set.""" + self.assertEqual(self.task.pipeline_name, "test_pipeline") + + +class TestDatabricksDLTTaskDefaults(unittest.TestCase): + """Test cases for DatabricksDLTTask default values.""" + + def setUp(self) -> None: + """Set up test fixtures with minimal config.""" + self.config = { + "type": "databricks_dlt", + "description": "Test DLT task with defaults", + "inputs": [], + "outputs": [], + "airflow_task_parameters": None, + "template_parameters": None, + "task_parameters": { + "job_name": "minimal-dlt-job", + }, + } + + self.mock_pipeline = MagicMock() + self.mock_pipeline.directory = "tests/fixtures/pipeline/tasks" + + self.task = DatabricksDLTTask( + name="minimal_dlt_task", + pipeline_name="test_pipeline", + pipeline=self.mock_pipeline, + job_config=self.config, + ) + + def test_default_databricks_conn_id(self) -> None: + """Test default databricks_conn_id value.""" + self.assertEqual(self.task.databricks_conn_id, "databricks_default") + + def test_default_wait_for_completion(self) -> None: + """Test default wait_for_completion value.""" + self.assertTrue(self.task.wait_for_completion) + + def test_default_poll_interval_seconds(self) -> None: + """Test default poll_interval_seconds value.""" + self.assertEqual(self.task.poll_interval_seconds, 30) + + def test_default_timeout_seconds(self) -> None: + """Test default timeout_seconds value.""" + self.assertEqual(self.task.timeout_seconds, 3600) + + def test_default_cancel_on_kill(self) -> None: + """Test default cancel_on_kill value.""" + self.assertTrue(self.task.cancel_on_kill) + + +class TestDatabricksDLTTaskBooleanHandling(unittest.TestCase): + """Test cases for boolean parameter handling edge cases.""" + + def test_wait_for_completion_false(self) -> None: + """Test that wait_for_completion=false is correctly handled.""" + config = { + "type": "databricks_dlt", + "description": "Test", + "inputs": [], + "outputs": [], + "airflow_task_parameters": None, + "template_parameters": None, + "task_parameters": { + "job_name": "test-job", + "wait_for_completion": False, + }, + } + + mock_pipeline = MagicMock() + mock_pipeline.directory = "tests/fixtures/pipeline/tasks" + + task = DatabricksDLTTask( + name="test_task", + pipeline_name="test_pipeline", + pipeline=mock_pipeline, + job_config=config, + ) + + self.assertFalse(task.wait_for_completion) + + def test_cancel_on_kill_false(self) -> None: + """Test that cancel_on_kill=false is correctly handled.""" + config = { + "type": "databricks_dlt", + "description": "Test", + "inputs": [], + "outputs": [], + "airflow_task_parameters": None, + "template_parameters": None, + "task_parameters": { + "job_name": "test-job", + "cancel_on_kill": False, + }, + } + + mock_pipeline = MagicMock() + mock_pipeline.directory = "tests/fixtures/pipeline/tasks" + + task = DatabricksDLTTask( + name="test_task", + pipeline_name="test_pipeline", + pipeline=mock_pipeline, + job_config=config, + ) + + self.assertFalse(task.cancel_on_kill) + + +if __name__ == "__main__": + unittest.main() From 40bdea4b037182e0107184bb9461a3d8cbb8805e Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Tue, 13 Jan 2026 17:22:40 +0100 Subject: [PATCH 186/189] Add coding standard to avoid getattr in CLAUDE.md Prefer explicit properties over getattr for type safety and better IDE support. --- CLAUDE.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index ffd1ebd..9cf1b75 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -80,3 +80,19 @@ dagger print-graph # Visualize dependency graph - **Factory Pattern**: TaskFactory/IOFactory auto-discover types via reflection - **Strategy Pattern**: OperatorCreator subclasses handle task-specific operator creation - **Dataset Aliasing**: IO `alias()` method enables automatic dependency detection across pipelines + +## Coding Standards + +### Avoid getattr +Do not use `getattr` for accessing task or IO properties. Instead, define explicit properties on the class. This ensures: +- Type safety and IDE autocompletion +- Clear interface contracts +- Easier debugging and testing + +```python +# Bad - avoid this pattern +value = getattr(self._task, 'some_property', default) + +# Good - use explicit properties +value = self._task.some_property # Property defined on task class +``` From 99898b47c1a87f7cde3c25bb659b8ea0a3563aa3 Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Tue, 13 Jan 2026 17:26:54 +0100 Subject: [PATCH 187/189] Revert dbt_config_parser.py changes --- dagger/utilities/dbt_config_parser.py | 180 +++----------------------- 1 file changed, 19 insertions(+), 161 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 9b86d5d..9a341f6 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -296,20 +296,7 @@ def _generate_dagger_output(self, node: dict): class DatabricksDBTConfigParser(DBTConfigParser): - """DBT config parser implementation for Databricks Unity Catalog. - - Parses dbt manifest.json files for projects using the databricks-dbt adapter - and generates Dagger task configurations. Handles both Unity Catalog sources - (accessed via Databricks) and legacy Hive metastore sources (accessed via Athena). - - Attributes: - LEGACY_HIVE_DATABASES: Set of database names that indicate legacy Hive - metastore tables accessed via Athena rather than Unity Catalog. - """ - - # Schemas that indicate sources are in legacy Hive metastore (accessed via Athena) - # rather than Unity Catalog (accessed via Databricks) - LEGACY_HIVE_DATABASES: set[str] = {"hive_metastore"} + """Implementation for Databricks configurations.""" def __init__(self, default_config_parameters: dict): super().__init__(default_config_parameters) @@ -319,132 +306,17 @@ def __init__(self, default_config_parameters: dict): "create_external_athena_table", False ) - def _is_databricks_source(self, node: dict) -> bool: - """Check if a source is a Unity Catalog table (accessed via Databricks). - - Sources with database 'hive_metastore' are legacy tables accessed via Athena. - Sources with other databases (e.g., Unity Catalog like ${ENV_MARTS}) are - Databricks tables that should create databricks input tasks. - - Args: - node: The source node from dbt manifest - - Returns: - True if the source is a Unity Catalog table, False otherwise + def _is_node_preparation_model(self, node: dict): """ - database = node.get("database", "") - return database not in self.LEGACY_HIVE_DATABASES - - def _is_node_preparation_model(self, node: dict) -> bool: - """Determine whether a node is a preparation model. - - Preparation models are intermediate models in the transformation pipeline - that should not create external dependencies. - - Args: - node: The dbt node from manifest.json. - - Returns: - True if the node's schema contains 'preparation', False otherwise. + Define whether it is a preparation model. """ return "preparation" in node.get("schema", "") - def _get_databricks_source_task( - self, node: dict, follow_external_dependency: bool = False - ) -> dict: - """Generate a databricks input task for a Unity Catalog source. - - This is used for sources that point to Unity Catalog tables (e.g., DLT outputs) - rather than legacy Hive metastore tables. - - Args: - node: The source node from dbt manifest - follow_external_dependency: Whether to create an ExternalTaskSensor - - Returns: - Dagger databricks task configuration dict - """ - task = DATABRICKS_TASK_BASE.copy() - if follow_external_dependency: - task["follow_external_dependency"] = True - - task["catalog"] = node.get("database", self._default_catalog) - task["schema"] = node.get("schema", self._default_schema) - task["table"] = node.get("name", "") - task["name"] = f"{task['catalog']}__{task['schema']}__{task['table']}_databricks" - - return task - - def _generate_dagger_tasks(self, node_name: str) -> List[Dict]: - """Generate dagger tasks, with special handling for Databricks Unity Catalog sources. - - Overrides the base class method to handle sources that are in Unity Catalog - (e.g., DLT output tables) by creating databricks input tasks instead of athena tasks. - - Args: - node_name: The name of the DBT model node - - Returns: - List[Dict]: The respective dagger tasks for the DBT model node - """ - dagger_tasks = [] - - if node_name.startswith("source"): - node = self._sources_in_manifest[node_name] - else: - node = self._nodes_in_manifest[node_name] - - resource_type = node.get("resource_type") - materialized_type = node.get("config", {}).get("materialized") - - follow_external_dependency = True - if resource_type == "seed" or (self._is_node_preparation_model(node) and materialized_type != "table"): - follow_external_dependency = False - - if resource_type == "source": - # Check if this source is a Unity Catalog table (e.g., DLT outputs) - if self._is_databricks_source(node): - table_task = self._get_databricks_source_task( - node, follow_external_dependency=follow_external_dependency - ) - else: - # Legacy Hive metastore sources use Athena - table_task = self._get_athena_table_task( - node, follow_external_dependency=follow_external_dependency - ) - dagger_tasks.append(table_task) - - elif materialized_type == "ephemeral": - task = self._get_dummy_task(node) - dagger_tasks.append(task) - for dependent_node_name in node.get("depends_on", {}).get("nodes", []): - dagger_tasks += self._generate_dagger_tasks(dependent_node_name) - - else: - table_task = self._get_table_task(node, follow_external_dependency=follow_external_dependency) - dagger_tasks.append(table_task) - - if materialized_type in ("table", "incremental"): - dagger_tasks.append(self._get_s3_task(node)) - elif self._is_node_preparation_model(node): - for dependent_node_name in node.get("depends_on", {}).get("nodes", []): - dagger_tasks.extend( - self._generate_dagger_tasks(dependent_node_name) - ) - - return dagger_tasks - def _get_table_task( self, node: dict, follow_external_dependency: bool = False ) -> dict: - """Generate a Databricks table task for a dbt model node. - - Args: - node: The dbt model node from manifest.json. - follow_external_dependency: Whether to create an ExternalTaskSensor. - - Returns: - Dagger databricks task configuration dict. + """ + Generates the dagger databricks task for the DBT model node """ task = DATABRICKS_TASK_BASE.copy() if follow_external_dependency: @@ -462,15 +334,8 @@ def _get_table_task( def _get_model_data_location( self, node: dict, schema: str, model_name: str ) -> Tuple[str, str]: - """Get the S3 path of a dbt model relative to the data bucket. - - Args: - node: The dbt model node from manifest.json. - schema: The schema name (unused for Databricks, kept for interface compatibility). - model_name: The model name. - - Returns: - Tuple of (bucket_name, data_path). + """ + Gets the S3 path of the dbt model relative to the data bucket. """ location_root = node.get("config", {}).get("location_root") location = join(location_root, model_name) @@ -480,39 +345,32 @@ def _get_model_data_location( return bucket_name, data_path def _get_s3_task(self, node: dict, is_output: bool = False) -> dict: - """Generate an S3 task for a databricks-dbt model node. - - Args: - node: The dbt model node from manifest.json. - is_output: If True, names the task 'output_s3_path' for output declarations. - - Returns: - Dagger S3 task configuration dict. + """ + Generates the dagger s3 task for the databricks-dbt model node """ task = S3_TASK_BASE.copy() schema = node.get("schema", self._default_schema) table = node.get("name", "") - task["name"] = "output_s3_path" if is_output else f"s3_{table}" + task["name"] = f"output_s3_path" if is_output else f"s3_{table}" task["bucket"], task["path"] = self._get_model_data_location( node, schema, table ) return task - def _generate_dagger_output(self, node: dict) -> List[Dict]: - """Generate dagger output tasks for a databricks-dbt model node. - - Creates output task configurations based on the model's materialization type: - - Ephemeral models produce a dummy task - - Table/incremental models produce databricks + S3 tasks - - Optionally adds an Athena task if create_external_athena_table is True - + def _generate_dagger_output(self, node: dict): + """ + Generates the dagger output for the DBT model node with the databricks-dbt adapter. + If the model is materialized as a view or ephemeral, then a dummy task is created. + Otherwise, and databricks and s3 task is created for the DBT model node. + And if create_external_athena_table is True te an extra athena task is created. Args: - node: The dbt model node from manifest.json. + node: The extracted node from the manifest.json file Returns: - List of dagger output task configuration dicts. + dict: The dagger output, which is a combination of an athena and s3 task for the DBT model node + """ materialized_type = node.get("config", {}).get("materialized") if materialized_type == "ephemeral": From a0286d342503ede4f606b79d99bfe5237eb9cb3e Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Thu, 15 Jan 2026 16:43:50 +0100 Subject: [PATCH 188/189] Rename databricks_dlt to declarative_pipeline task type Rename task type and operator creator to match the generated task YAML files from the DLT task generator plugin. Files renamed: - databricks_dlt_task.py -> declarative_pipeline_task.py - databricks_dlt_creator.py -> declarative_pipeline_creator.py - test_databricks_dlt_task.py -> test_declarative_pipeline_task.py - test_databricks_dlt_creator.py -> test_declarative_pipeline_creator.py - databricks_dlt_task.yaml -> declarative_pipeline_task.yaml Changes: - Class names: DatabricksDLTTask -> DeclarativePipelineTask - Class names: DatabricksDLTCreator -> DeclarativePipelineCreator - ref_name: "databricks_dlt" -> "declarative_pipeline" - Updated imports in factory files --- ...tor.py => declarative_pipeline_creator.py} | 24 +++++++------- .../dag_creator/airflow/operator_factory.py | 2 +- dagger/pipeline/task_factory.py | 2 +- ...t_task.py => declarative_pipeline_task.py} | 32 +++++++++++-------- ...y => test_declarative_pipeline_creator.py} | 24 +++++++------- ...sk.yaml => declarative_pipeline_task.yaml} | 2 +- ...k.py => test_declarative_pipeline_task.py} | 32 +++++++++---------- 7 files changed, 61 insertions(+), 57 deletions(-) rename dagger/dag_creator/airflow/operator_creators/{databricks_dlt_creator.py => declarative_pipeline_creator.py} (83%) rename dagger/pipeline/tasks/{databricks_dlt_task.py => declarative_pipeline_task.py} (87%) rename tests/dag_creator/airflow/operator_creators/{test_databricks_dlt_creator.py => test_declarative_pipeline_creator.py} (92%) rename tests/fixtures/pipeline/tasks/{databricks_dlt_task.yaml => declarative_pipeline_task.yaml} (94%) rename tests/pipeline/tasks/{test_databricks_dlt_task.py => test_declarative_pipeline_task.py} (84%) diff --git a/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py b/dagger/dag_creator/airflow/operator_creators/declarative_pipeline_creator.py similarity index 83% rename from dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py rename to dagger/dag_creator/airflow/operator_creators/declarative_pipeline_creator.py index 87a11ac..94ba721 100644 --- a/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/declarative_pipeline_creator.py @@ -1,4 +1,4 @@ -"""Operator creator for Databricks DLT (Delta Live Tables) pipelines.""" +"""Operator creator for declarative pipelines (DLT/Delta Live Tables).""" import logging from typing import Any @@ -6,7 +6,7 @@ from airflow.models import BaseOperator, DAG from dagger.dag_creator.airflow.operator_creator import OperatorCreator -from dagger.pipeline.tasks.databricks_dlt_task import DatabricksDLTTask +from dagger.pipeline.tasks.declarative_pipeline_task import DeclarativePipelineTask _logger = logging.getLogger(__name__) @@ -51,8 +51,8 @@ def _cancel_databricks_run(context: dict[str, Any]) -> None: _logger.error(f"Failed to cancel Databricks run {run_id}: {e}") -class DatabricksDLTCreator(OperatorCreator): - """Creates operators for triggering Databricks DLT pipelines via Jobs. +class DeclarativePipelineCreator(OperatorCreator): + """Creates operators for triggering declarative pipelines via Databricks Jobs. This creator uses DatabricksRunNowOperator to trigger a Databricks Job that wraps the DLT pipeline. The job is identified by name and must be @@ -60,22 +60,22 @@ class DatabricksDLTCreator(OperatorCreator): Attributes: ref_name: Reference name used by OperatorFactory to match this creator - with DatabricksDLTTask instances. + with DeclarativePipelineTask instances. """ - ref_name: str = "databricks_dlt" + ref_name: str = "declarative_pipeline" - def __init__(self, task: DatabricksDLTTask, dag: DAG) -> None: - """Initialize the DatabricksDLTCreator. + def __init__(self, task: DeclarativePipelineTask, dag: DAG) -> None: + """Initialize the DeclarativePipelineCreator. Args: - task: The DatabricksDLTTask containing pipeline configuration. + task: The DeclarativePipelineTask containing pipeline configuration. dag: The Airflow DAG this operator will belong to. """ super().__init__(task, dag) def _create_operator(self, **kwargs: Any) -> BaseOperator: - """Create a DatabricksRunNowOperator for the DLT pipeline. + """Create a DatabricksRunNowOperator for the declarative pipeline. Creates an Airflow operator that triggers an existing Databricks Job by name. The job must have a pipeline_task that references the DLT @@ -97,11 +97,11 @@ def _create_operator(self, **kwargs: Any) -> BaseOperator: DatabricksRunNowOperator, ) - # Get task parameters - defaults are handled in DatabricksDLTTask + # Get task parameters - defaults are handled in DeclarativePipelineTask job_name: str = self._task.job_name if not job_name: raise ValueError( - f"job_name is required for DatabricksDLTTask '{self._task.name}'" + f"job_name is required for DeclarativePipelineTask '{self._task.name}'" ) databricks_conn_id: str = self._task.databricks_conn_id wait_for_completion: bool = self._task.wait_for_completion diff --git a/dagger/dag_creator/airflow/operator_factory.py b/dagger/dag_creator/airflow/operator_factory.py index dd7344e..079e3a0 100644 --- a/dagger/dag_creator/airflow/operator_factory.py +++ b/dagger/dag_creator/airflow/operator_factory.py @@ -4,7 +4,7 @@ airflow_op_creator, athena_transform_creator, batch_creator, - databricks_dlt_creator, + declarative_pipeline_creator, dbt_creator, dummy_creator, python_creator, diff --git a/dagger/pipeline/task_factory.py b/dagger/pipeline/task_factory.py index f5f80bb..f7a89f2 100644 --- a/dagger/pipeline/task_factory.py +++ b/dagger/pipeline/task_factory.py @@ -3,7 +3,7 @@ airflow_op_task, athena_transform_task, batch_task, - databricks_dlt_task, + declarative_pipeline_task, dbt_task, dummy_task, python_task, diff --git a/dagger/pipeline/tasks/databricks_dlt_task.py b/dagger/pipeline/tasks/declarative_pipeline_task.py similarity index 87% rename from dagger/pipeline/tasks/databricks_dlt_task.py rename to dagger/pipeline/tasks/declarative_pipeline_task.py index 4f0b113..5cb4e81 100644 --- a/dagger/pipeline/tasks/databricks_dlt_task.py +++ b/dagger/pipeline/tasks/declarative_pipeline_task.py @@ -1,4 +1,4 @@ -"""Task configuration for Databricks DLT (Delta Live Tables) pipelines.""" +"""Task configuration for declarative pipelines (DLT/Delta Live Tables).""" from typing import Any, Optional @@ -6,8 +6,8 @@ from dagger.utilities.config_validator import Attribute -class DatabricksDLTTask(Task): - """Task configuration for triggering Databricks DLT pipelines via Jobs. +class DeclarativePipelineTask(Task): + """Task configuration for triggering declarative pipelines via Databricks Jobs. This task type uses DatabricksRunNowOperator to trigger a Databricks Job that wraps the DLT pipeline. The job is identified by name and must be @@ -23,27 +23,31 @@ class DatabricksDLTTask(Task): cancel_on_kill: Whether to cancel Databricks job if Airflow task is killed. Example YAML configuration: - type: databricks_dlt + type: declarative_pipeline description: Run DLT pipeline users inputs: - - type: athena - schema: ddb_changelogs - table: order_preference - follow_external_dependency: true + - type: s3 + name: input_order_service_public_users + bucket: cho${ENV}-data-lake + path: pg_changelogs/kafka/order-service/order_service.public.users outputs: - type: databricks - catalog: ${ENV_MARTS} - schema: dlt_users - table: silver_order_preference + catalog: changelogs + schema: order_service_public + table: pg_users + - type: databricks + catalog: core + schema: order_service + table: pg_users task_parameters: - job_name: dlt-users + job_name: "[JOB] order-service-pipeline" databricks_conn_id: databricks_default wait_for_completion: true poll_interval_seconds: 30 timeout_seconds: 3600 """ - ref_name: str = "databricks_dlt" + ref_name: str = "declarative_pipeline" @classmethod def init_attributes(cls, orig_cls: type) -> None: @@ -106,7 +110,7 @@ def __init__( pipeline: Any, job_config: dict[str, Any], ) -> None: - """Initialize a DatabricksDLTTask instance. + """Initialize a DeclarativePipelineTask instance. Args: name: The task name (used as task_id in Airflow). diff --git a/tests/dag_creator/airflow/operator_creators/test_databricks_dlt_creator.py b/tests/dag_creator/airflow/operator_creators/test_declarative_pipeline_creator.py similarity index 92% rename from tests/dag_creator/airflow/operator_creators/test_databricks_dlt_creator.py rename to tests/dag_creator/airflow/operator_creators/test_declarative_pipeline_creator.py index 39de91b..5addd03 100644 --- a/tests/dag_creator/airflow/operator_creators/test_databricks_dlt_creator.py +++ b/tests/dag_creator/airflow/operator_creators/test_declarative_pipeline_creator.py @@ -1,18 +1,18 @@ -"""Unit tests for DatabricksDLTCreator.""" +"""Unit tests for DeclarativePipelineCreator.""" import sys import unittest from datetime import timedelta from unittest.mock import MagicMock, patch -from dagger.dag_creator.airflow.operator_creators.databricks_dlt_creator import ( - DatabricksDLTCreator, +from dagger.dag_creator.airflow.operator_creators.declarative_pipeline_creator import ( + DeclarativePipelineCreator, _cancel_databricks_run, ) -class TestDatabricksDLTCreator(unittest.TestCase): - """Test cases for DatabricksDLTCreator.""" +class TestDeclarativePipelineCreator(unittest.TestCase): + """Test cases for DeclarativePipelineCreator.""" def setUp(self) -> None: """Set up test fixtures.""" @@ -35,7 +35,7 @@ def setUp(self) -> None: def test_ref_name(self) -> None: """Test that ref_name is correctly set.""" - self.assertEqual(DatabricksDLTCreator.ref_name, "databricks_dlt") + self.assertEqual(DeclarativePipelineCreator.ref_name, "declarative_pipeline") @patch.dict( sys.modules, @@ -49,7 +49,7 @@ def test_create_operator(self) -> None: "airflow.providers.databricks.operators.databricks" ].DatabricksRunNowOperator = mock_operator_class - creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + creator = DeclarativePipelineCreator(self.mock_task, self.mock_dag) operator = creator._create_operator() mock_operator_class.assert_called_once() @@ -66,7 +66,7 @@ def test_create_operator_maps_task_properties(self) -> None: "airflow.providers.databricks.operators.databricks" ].DatabricksRunNowOperator = mock_operator_class - creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + creator = DeclarativePipelineCreator(self.mock_task, self.mock_dag) creator._create_operator() call_kwargs = mock_operator_class.call_args[1] @@ -97,7 +97,7 @@ def test_create_operator_with_custom_values(self) -> None: "airflow.providers.databricks.operators.databricks" ].DatabricksRunNowOperator = mock_operator_class - creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + creator = DeclarativePipelineCreator(self.mock_task, self.mock_dag) creator._create_operator() call_kwargs = mock_operator_class.call_args[1] @@ -115,7 +115,7 @@ def test_create_operator_empty_job_name_raises_error(self) -> None: """Test that empty job_name raises ValueError.""" self.mock_task.job_name = "" - creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + creator = DeclarativePipelineCreator(self.mock_task, self.mock_dag) with self.assertRaises(ValueError) as context: creator._create_operator() @@ -131,7 +131,7 @@ def test_create_operator_none_job_name_raises_error(self) -> None: """Test that None job_name raises ValueError.""" self.mock_task.job_name = None - creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + creator = DeclarativePipelineCreator(self.mock_task, self.mock_dag) with self.assertRaises(ValueError) as context: creator._create_operator() @@ -149,7 +149,7 @@ def test_create_operator_passes_kwargs(self) -> None: "airflow.providers.databricks.operators.databricks" ].DatabricksRunNowOperator = mock_operator_class - creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + creator = DeclarativePipelineCreator(self.mock_task, self.mock_dag) creator._create_operator(retries=3, retry_delay=60) call_kwargs = mock_operator_class.call_args[1] diff --git a/tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml b/tests/fixtures/pipeline/tasks/declarative_pipeline_task.yaml similarity index 94% rename from tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml rename to tests/fixtures/pipeline/tasks/declarative_pipeline_task.yaml index 1902cf0..be20955 100644 --- a/tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml +++ b/tests/fixtures/pipeline/tasks/declarative_pipeline_task.yaml @@ -1,4 +1,4 @@ -type: databricks_dlt +type: declarative_pipeline description: Test DLT pipeline task inputs: - type: athena diff --git a/tests/pipeline/tasks/test_databricks_dlt_task.py b/tests/pipeline/tasks/test_declarative_pipeline_task.py similarity index 84% rename from tests/pipeline/tasks/test_databricks_dlt_task.py rename to tests/pipeline/tasks/test_declarative_pipeline_task.py index a222148..5904351 100644 --- a/tests/pipeline/tasks/test_databricks_dlt_task.py +++ b/tests/pipeline/tasks/test_declarative_pipeline_task.py @@ -1,20 +1,20 @@ -"""Unit tests for DatabricksDLTTask.""" +"""Unit tests for DeclarativePipelineTask.""" import unittest from unittest.mock import MagicMock import yaml -from dagger.pipeline.tasks.databricks_dlt_task import DatabricksDLTTask +from dagger.pipeline.tasks.declarative_pipeline_task import DeclarativePipelineTask -class TestDatabricksDLTTask(unittest.TestCase): - """Test cases for DatabricksDLTTask.""" +class TestDeclarativePipelineTask(unittest.TestCase): + """Test cases for DeclarativePipelineTask.""" def setUp(self) -> None: """Set up test fixtures.""" with open( - "tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml", "r" + "tests/fixtures/pipeline/tasks/declarative_pipeline_task.yaml", "r" ) as stream: self.config = yaml.safe_load(stream) @@ -22,7 +22,7 @@ def setUp(self) -> None: self.mock_pipeline = MagicMock() self.mock_pipeline.directory = "tests/fixtures/pipeline/tasks" - self.task = DatabricksDLTTask( + self.task = DeclarativePipelineTask( name="test_dlt_task", pipeline_name="test_pipeline", pipeline=self.mock_pipeline, @@ -31,7 +31,7 @@ def setUp(self) -> None: def test_ref_name(self) -> None: """Test that ref_name is correctly set.""" - self.assertEqual(DatabricksDLTTask.ref_name, "databricks_dlt") + self.assertEqual(DeclarativePipelineTask.ref_name, "declarative_pipeline") def test_job_name(self) -> None: """Test job_name property.""" @@ -66,13 +66,13 @@ def test_pipeline_name(self) -> None: self.assertEqual(self.task.pipeline_name, "test_pipeline") -class TestDatabricksDLTTaskDefaults(unittest.TestCase): - """Test cases for DatabricksDLTTask default values.""" +class TestDeclarativePipelineTaskDefaults(unittest.TestCase): + """Test cases for DeclarativePipelineTask default values.""" def setUp(self) -> None: """Set up test fixtures with minimal config.""" self.config = { - "type": "databricks_dlt", + "type": "declarative_pipeline", "description": "Test DLT task with defaults", "inputs": [], "outputs": [], @@ -86,7 +86,7 @@ def setUp(self) -> None: self.mock_pipeline = MagicMock() self.mock_pipeline.directory = "tests/fixtures/pipeline/tasks" - self.task = DatabricksDLTTask( + self.task = DeclarativePipelineTask( name="minimal_dlt_task", pipeline_name="test_pipeline", pipeline=self.mock_pipeline, @@ -114,13 +114,13 @@ def test_default_cancel_on_kill(self) -> None: self.assertTrue(self.task.cancel_on_kill) -class TestDatabricksDLTTaskBooleanHandling(unittest.TestCase): +class TestDeclarativePipelineTaskBooleanHandling(unittest.TestCase): """Test cases for boolean parameter handling edge cases.""" def test_wait_for_completion_false(self) -> None: """Test that wait_for_completion=false is correctly handled.""" config = { - "type": "databricks_dlt", + "type": "declarative_pipeline", "description": "Test", "inputs": [], "outputs": [], @@ -135,7 +135,7 @@ def test_wait_for_completion_false(self) -> None: mock_pipeline = MagicMock() mock_pipeline.directory = "tests/fixtures/pipeline/tasks" - task = DatabricksDLTTask( + task = DeclarativePipelineTask( name="test_task", pipeline_name="test_pipeline", pipeline=mock_pipeline, @@ -147,7 +147,7 @@ def test_wait_for_completion_false(self) -> None: def test_cancel_on_kill_false(self) -> None: """Test that cancel_on_kill=false is correctly handled.""" config = { - "type": "databricks_dlt", + "type": "declarative_pipeline", "description": "Test", "inputs": [], "outputs": [], @@ -162,7 +162,7 @@ def test_cancel_on_kill_false(self) -> None: mock_pipeline = MagicMock() mock_pipeline.directory = "tests/fixtures/pipeline/tasks" - task = DatabricksDLTTask( + task = DeclarativePipelineTask( name="test_task", pipeline_name="test_pipeline", pipeline=mock_pipeline, From 6120683d31fa1adc937781bf178bb61f3dd0300e Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Thu, 12 Feb 2026 11:11:36 +0100 Subject: [PATCH 189/189] Fix Slack alerts not firing for declarative pipeline tasks The on_failure_callback on DatabricksRunNowOperator was overriding the DAG-level callback (Slack alerts) set in default_args. Build a composite callback that cancels the Databricks run AND invokes the DAG-level alert callback. --- .../declarative_pipeline_creator.py | 37 +++++++- .../test_declarative_pipeline_creator.py | 87 +++++++++++++++++++ 2 files changed, 122 insertions(+), 2 deletions(-) diff --git a/dagger/dag_creator/airflow/operator_creators/declarative_pipeline_creator.py b/dagger/dag_creator/airflow/operator_creators/declarative_pipeline_creator.py index 94ba721..0f58808 100644 --- a/dagger/dag_creator/airflow/operator_creators/declarative_pipeline_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/declarative_pipeline_creator.py @@ -1,7 +1,7 @@ """Operator creator for declarative pipelines (DLT/Delta Live Tables).""" import logging -from typing import Any +from typing import Any, Callable, Optional from airflow.models import BaseOperator, DAG @@ -51,6 +51,33 @@ def _cancel_databricks_run(context: dict[str, Any]) -> None: _logger.error(f"Failed to cancel Databricks run {run_id}: {e}") +def _build_failure_callback( + dag_callback: Optional[Callable[[dict[str, Any]], None]], +) -> Callable[[dict[str, Any]], None]: + """Build a failure callback that cancels the Databricks run and invokes the DAG-level callback. + + When a DAG defines an ``on_failure_callback`` in its ``default_args`` (e.g. for + Slack alerts), that callback is normally overridden by operator-level callbacks. + This helper produces a single callback that always cancels the Databricks run + **and** forwards to the DAG-level callback so that alerts still fire. + + Args: + dag_callback: The DAG-level ``on_failure_callback`` (from ``default_args``), + or ``None`` if the DAG has no failure callback configured. + + Returns: + A callback suitable for ``on_failure_callback`` on the operator. + """ + if dag_callback is None: + return _cancel_databricks_run + + def _composite_failure_callback(context: dict[str, Any]) -> None: + _cancel_databricks_run(context) + dag_callback(context) + + return _composite_failure_callback + + class DeclarativePipelineCreator(OperatorCreator): """Creates operators for triggering declarative pipelines via Databricks Jobs. @@ -108,6 +135,12 @@ def _create_operator(self, **kwargs: Any) -> BaseOperator: poll_interval_seconds: int = self._task.poll_interval_seconds timeout_seconds: int = self._task.timeout_seconds + # Build the on_failure_callback: always cancel the Databricks run, + # and also invoke the DAG-level callback (e.g. Slack alerts) if one exists. + on_failure: Callable[[dict[str, Any]], None] = _build_failure_callback( + self._dag.default_args.get("on_failure_callback"), + ) + # DatabricksRunNowOperator triggers an existing Databricks Job by name # The job must have a pipeline_task that references the DLT pipeline # Note: timeout is handled via Airflow's execution_timeout, not a direct parameter @@ -122,7 +155,7 @@ def _create_operator(self, **kwargs: Any) -> BaseOperator: polling_period_seconds=poll_interval_seconds, execution_timeout=timedelta(seconds=timeout_seconds), do_xcom_push=True, # Required to store run_id for cancellation callback - on_failure_callback=_cancel_databricks_run, + on_failure_callback=on_failure, **kwargs, ) diff --git a/tests/dag_creator/airflow/operator_creators/test_declarative_pipeline_creator.py b/tests/dag_creator/airflow/operator_creators/test_declarative_pipeline_creator.py index 5addd03..cf013f4 100644 --- a/tests/dag_creator/airflow/operator_creators/test_declarative_pipeline_creator.py +++ b/tests/dag_creator/airflow/operator_creators/test_declarative_pipeline_creator.py @@ -7,6 +7,7 @@ from dagger.dag_creator.airflow.operator_creators.declarative_pipeline_creator import ( DeclarativePipelineCreator, + _build_failure_callback, _cancel_databricks_run, ) @@ -26,6 +27,7 @@ def setUp(self) -> None: self.mock_task.cancel_on_kill = True self.mock_dag = MagicMock() + self.mock_dag.default_args = {} # Set up mock for DatabricksRunNowOperator self.mock_operator = MagicMock() @@ -79,8 +81,33 @@ def test_create_operator_maps_task_properties(self) -> None: self.assertEqual(call_kwargs["polling_period_seconds"], 30) self.assertEqual(call_kwargs["execution_timeout"], timedelta(seconds=3600)) self.assertTrue(call_kwargs["do_xcom_push"]) + # No DAG-level on_failure_callback, so only _cancel_databricks_run is used self.assertEqual(call_kwargs["on_failure_callback"], _cancel_databricks_run) + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator_with_dag_failure_callback_uses_composite(self) -> None: + """Test that DAG-level on_failure_callback is chained with cancel callback.""" + mock_dag_callback = MagicMock() + self.mock_dag.default_args = {"on_failure_callback": mock_dag_callback} + + mock_operator_class = MagicMock() + sys.modules[ + "airflow.providers.databricks.operators.databricks" + ].DatabricksRunNowOperator = mock_operator_class + + creator = DeclarativePipelineCreator(self.mock_task, self.mock_dag) + creator._create_operator() + + call_kwargs = mock_operator_class.call_args[1] + + # Should be a composite callback, not _cancel_databricks_run directly + callback = call_kwargs["on_failure_callback"] + self.assertNotEqual(callback, _cancel_databricks_run) + self.assertNotEqual(callback, mock_dag_callback) + @patch.dict( sys.modules, {"airflow.providers.databricks.operators.databricks": MagicMock()}, @@ -272,5 +299,65 @@ def test_cancel_run_handles_import_error(self) -> None: _cancel_databricks_run(context) +class TestBuildFailureCallback(unittest.TestCase): + """Test cases for _build_failure_callback.""" + + def test_returns_cancel_callback_when_no_dag_callback(self) -> None: + """Test that _cancel_databricks_run is returned when there is no DAG callback.""" + result = _build_failure_callback(None) + self.assertIs(result, _cancel_databricks_run) + + def test_returns_composite_when_dag_callback_exists(self) -> None: + """Test that a composite callback is returned when DAG callback exists.""" + dag_callback = MagicMock() + result = _build_failure_callback(dag_callback) + + self.assertIsNot(result, _cancel_databricks_run) + self.assertIsNot(result, dag_callback) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.hooks.databricks": MagicMock()}, + ) + def test_composite_calls_both_callbacks(self) -> None: + """Test that composite callback invokes both cancel and DAG callbacks.""" + mock_hook = MagicMock() + mock_hook_class = MagicMock(return_value=mock_hook) + sys.modules[ + "airflow.providers.databricks.hooks.databricks" + ].DatabricksHook = mock_hook_class + + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = "run_123" + mock_ti.task.databricks_conn_id = "databricks_default" + + context: dict = {"task_instance": mock_ti} + dag_callback = MagicMock() + + composite = _build_failure_callback(dag_callback) + composite(context) + + # _cancel_databricks_run was invoked (hook was called) + mock_hook.cancel_run.assert_called_once_with("run_123") + # DAG-level callback (e.g. Slack alert) was also invoked + dag_callback.assert_called_once_with(context) + + def test_composite_calls_dag_callback_even_if_cancel_has_no_run_id(self) -> None: + """Test that DAG callback fires even when there is no Databricks run to cancel.""" + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = None # no run_id + + context: dict = {"task_instance": mock_ti} + dag_callback = MagicMock() + + composite = _build_failure_callback(dag_callback) + composite(context) + + # DAG callback should still be called + dag_callback.assert_called_once_with(context) + + if __name__ == "__main__": unittest.main()