diff --git a/.github/workflows/ci-data.yml b/.github/workflows/ci-data.yml index 599d325..bef5bed 100644 --- a/.github/workflows/ci-data.yml +++ b/.github/workflows/ci-data.yml @@ -17,10 +17,10 @@ jobs: with: persist-credentials: false - - name: Set up Python 3.7 + - name: Set up Python 3.9 uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.9 - name: Install dependencies run: | diff --git a/Makefile b/Makefile index 02daa9e..8872dba 100644 --- a/Makefile +++ b/Makefile @@ -96,15 +96,15 @@ install: clean ## install the package to the active Python's site-packages install-dev: clean ## install the package to the active Python's site-packages - virtualenv -p python3 venv; \ + virtualenv -p python3.9 venv; \ source venv/bin/activate; \ python -m pip install --upgrade pip; \ python setup.py install; \ pip install -e . ; \ - pip install -r reqs/dev.txt -r reqs/test.txt + SYSTEM_VERSION_COMPAT=0 CFLAGS='-std=c++20' pip install -r reqs/dev.txt -r reqs/test.txt install-test: clean ## install the package to the active Python's site-packages - virtualenv -p python3 venv; \ + virtualenv -p python3.9 venv; \ source venv/bin/activate; \ python -m pip install --upgrade pip; \ pip install -r reqs/test.txt -r reqs/base.txt diff --git a/README.md b/README.md index 2a161be..8e59257 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,41 @@ flowchart TD; ``` +Plugins for dagger +------- + +### Overview +Dagger now supports a plugin system that allows users to extend its functionality by adding custom Python classes. These plugins are integrated into the Jinja2 templating engine, enabling dynamic rendering of task configuration templates. +### Purpose +The plugin system allows users to define Python classes that can be loaded into the Jinja2 environment. When functions from these classes are invoked within a task configuration template, they are rendered dynamically using Jinja2. 
This feature enhances the flexibility of task configurations by allowing custom logic to be embedded directly in the templates. + +### Usage +1. **Creating a Plugin:** To create a new plugin, define a Python class in a folder(for example `plugins/sample_plugin/sample_plugin.py`) with the desired methods. For example: +```python +class MyCustomPlugin: + def generate_input(self, branch_name): + return [{"name": f"{branch_name}", "type": "dummy"}] +``` +This class defines a `generate_input` method that takes the branch_name from the module config and returns a dummy dagger task. + +2. **Loading the Plugin into Dagger:** To load this plugin into Dagger's Jinja2 environment, you need to register it in your `dagger_config.yaml`: +```yaml +# pipeline.yaml +plugin: + paths: + - plugins # all Python classes within this path will be loaded into the Jinja environment +``` + +3. **Using Plugin Methods in Templates:** Once the plugin is loaded, you can call its methods from within any Jinja2 template in your task configurations: +```yaml +# task_configuration.yaml +type: batch +description: sample task +inputs: # format: list | Use dagger init-io cli + {{ MyCustomPlugin.generate_input("dummy_input") }} +``` + + Credits ------- diff --git a/dagger/cli/module.py b/dagger/cli/module.py index 67fca87..931e809 100644 --- a/dagger/cli/module.py +++ b/dagger/cli/module.py @@ -1,17 +1,34 @@ import click from dagger.utilities.module import Module from dagger.utils import Printer +import json +def parse_key_value(ctx, param, value): + #print('YYY', value) + if not value: + return {} + key_value_dict = {} + for pair in value: + try: + key, val_file_path = pair.split('=', 1) + #print('YYY', key, val_file_path, pair) + val = json.load(open(val_file_path)) + key_value_dict[key] = val + except ValueError: + raise click.BadParameter(f"Key-value pair '{pair}' is not in the format key=value") + return key_value_dict + @click.command() @click.option("--config_file", "-c", help="Path to module 
config file") @click.option("--target_dir", "-t", help="Path to directory to generate the task configs to") -def generate_tasks(config_file: str, target_dir: str) -> None: +@click.option("--jinja_parameters", "-j", callback=parse_key_value, multiple=True, default=None, help="Path to jinja parameters json file in the format: =") +def generate_tasks(config_file: str, target_dir: str, jinja_parameters: dict) -> None: """ Generating tasks for a module based on config """ - module = Module(config_file, target_dir) + module = Module(config_file, target_dir, jinja_parameters) module.generate_task_configs() Printer.print_success("Tasks are successfully generated") diff --git a/dagger/conf.py b/dagger/conf.py index 667c207..df2ab8e 100644 --- a/dagger/conf.py +++ b/dagger/conf.py @@ -21,9 +21,7 @@ # Airflow parameters airflow_config = config.get('airflow', None) or {} WITH_DATA_NODES = airflow_config.get('with_data_nodes', False) -EXTERNAL_SENSOR_POKE_INTERVAL = airflow_config.get('external_sensor_poke_interval', 600) -EXTERNAL_SENSOR_TIMEOUT = airflow_config.get('external_sensor_timeout', 28800) -EXTERNAL_SENSOR_MODE = airflow_config.get('external_sensor_mode', 'reschedule') +EXTERNAL_SENSOR_DEFAULT_ARGS = airflow_config.get('external_sensor_default_args', {}) IS_DUMMY_OPERATOR_SHORT_CIRCUIT = airflow_config.get('is_dummy_operator_short_circuit', False) # Neo4j parameters @@ -100,4 +98,15 @@ # Alert parameters alert_config = config.get('alert', None) or {} SLACK_TOKEN = alert_config.get('slack_token', None) -DEFAULT_ALERT = alert_config.get('default_alert', {"type": "slack", "channel": "#airflow-jobs", "mentions": None}) \ No newline at end of file +DEFAULT_ALERT = alert_config.get('default_alert', {"type": "slack", "channel": "#airflow-jobs", "mentions": None}) + +# Plugin parameters +plugin_config = config.get('plugin', None) or {} +PLUGIN_DIRS = [os.path.join(AIRFLOW_HOME, path) for path in plugin_config.get('paths', [])] +logging.info(f"All Python classes will be 
loaded as plugins from the following directories: {PLUGIN_DIRS}") + +# ReverseETL parameters +reverse_etl_config = config.get('reverse_etl', None) or {} +REVERSE_ETL_DEFAULT_JOB_NAME = reverse_etl_config.get('default_job_name', None) +REVERSE_ETL_DEFAULT_EXECUTABLE_PREFIX = reverse_etl_config.get('default_executable_prefix', None) +REVERSE_ETL_DEFAULT_EXECUTABLE = reverse_etl_config.get('default_executable', None) diff --git a/dagger/dag_creator/airflow/dag_creator.py b/dagger/dag_creator/airflow/dag_creator.py index 70358e8..031a3a4 100644 --- a/dagger/dag_creator/airflow/dag_creator.py +++ b/dagger/dag_creator/airflow/dag_creator.py @@ -58,7 +58,7 @@ def _get_external_task_sensor_name_dict(self, from_task_id: str) -> dict: "external_sensor_name": f"{from_pipeline_name}-{from_task_name}-sensor", } - def _get_external_task_sensor(self, from_task_id: str, to_task_id: str) -> ExternalTaskSensor: + def _get_external_task_sensor(self, from_task_id: str, to_task_id: str, follow_external_dependency: dict) -> ExternalTaskSensor: """ create an object of external task sensor for a specific from_task_id and to_task_id """ @@ -72,6 +72,9 @@ def _get_external_task_sensor(self, from_task_id: str, to_task_id: str) -> Exter to_pipe_id = self._task_graph.get_node(to_task_id).obj.pipeline.name + extra_args = conf.EXTERNAL_SENSOR_DEFAULT_ARGS.copy() + extra_args.update(follow_external_dependency) + return ExternalTaskSensor( dag=self._dags[to_pipe_id], task_id=external_sensor_name, @@ -80,9 +83,7 @@ def _get_external_task_sensor(self, from_task_id: str, to_task_id: str) -> Exter execution_date_fn=self._get_execution_date_fn( from_pipeline_schedule, to_pipeline_schedule ), - mode=conf.EXTERNAL_SENSOR_MODE, - poke_interval=conf.EXTERNAL_SENSOR_POKE_INTERVAL, - timeout=conf.EXTERNAL_SENSOR_TIMEOUT, + **extra_args ) def _create_control_flow_task(self, pipe_id, dag): @@ -135,6 +136,7 @@ def _create_edge_without_data(self, from_task_id: str, to_task_ids: list, node: to_task_ids: The IDs 
of the tasks to which the edge connects. node: The current node in a task graph. """ + from_pipe = ( self._task_graph.get_node(from_task_id).obj.pipeline_name if from_task_id else None ) @@ -143,7 +145,7 @@ def _create_edge_without_data(self, from_task_id: str, to_task_ids: list, node: to_pipe = self._task_graph.get_node(to_task_id).obj.pipeline_name if from_pipe and from_pipe == to_pipe: self._tasks[from_task_id] >> self._tasks[to_task_id] - elif from_pipe and from_pipe != to_pipe and edge_properties.follow_external_dependency: + elif from_pipe and from_pipe != to_pipe and edge_properties.follow_external_dependency is not None: from_schedule = self._task_graph.get_node(from_task_id).obj.pipeline.schedule to_schedule = self._task_graph.get_node(to_task_id).obj.pipeline.schedule if not from_schedule.startswith("@") and not to_schedule.startswith("@"): @@ -155,15 +157,17 @@ def _create_edge_without_data(self, from_task_id: str, to_task_ids: list, node: not in self._sensor_dict.get(to_pipe, dict()).keys() ): external_task_sensor = self._get_external_task_sensor( - from_task_id, to_task_id + from_task_id, to_task_id, edge_properties.follow_external_dependency ) - self._sensor_dict[to_pipe] = { + + if self._sensor_dict.get(to_pipe) is None: + self._sensor_dict[to_pipe] = {} + + self._sensor_dict[to_pipe].update({ external_task_sensor_name: external_task_sensor - } - ( - self._tasks[self._get_control_flow_task_id(to_pipe)] - >> external_task_sensor - ) + }) + + self._tasks[self._get_control_flow_task_id(to_pipe)] >> external_task_sensor self._sensor_dict[to_pipe][external_task_sensor_name] >> self._tasks[to_task_id] else: self._tasks[self._get_control_flow_task_id(to_pipe)] >> self._tasks[to_task_id] diff --git a/dagger/dag_creator/airflow/operator_creator.py b/dagger/dag_creator/airflow/operator_creator.py index fc46234..b6aa036 100644 --- a/dagger/dag_creator/airflow/operator_creator.py +++ b/dagger/dag_creator/airflow/operator_creator.py @@ -1,5 +1,6 @@ from abc 
import ABC, abstractmethod from datetime import timedelta +from airflow.utils.task_group import TaskGroup TIMEDELTA_PARAMETERS = ['execution_timeout'] @@ -11,6 +12,15 @@ def __init__(self, task, dag): self._template_parameters = {} self._airflow_parameters = {} + def _get_existing_task_group_or_create_new(self): + group_id = self._task.task_group + if self._dag.task_group: + for group in self._dag.task_group.children.values(): + if isinstance(group, TaskGroup) and group.group_id == group_id: + return group + + return TaskGroup(group_id=group_id, dag=self._dag) + @abstractmethod def _create_operator(self, kwargs): raise NotImplementedError @@ -34,6 +44,9 @@ def _update_airflow_parameters(self): if self._task.timeout_in_seconds: self._airflow_parameters["execution_timeout"] = self._task.timeout_in_seconds + if self._task.task_group: + self._airflow_parameters["task_group"] = self._get_existing_task_group_or_create_new() + self._fix_timedelta_parameters() def create_operator(self): diff --git a/dagger/dag_creator/airflow/operator_creators/batch_creator.py b/dagger/dag_creator/airflow/operator_creators/batch_creator.py index a3d2534..0cfe9fb 100644 --- a/dagger/dag_creator/airflow/operator_creators/batch_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/batch_creator.py @@ -1,5 +1,9 @@ +from pathlib import Path +from datetime import timedelta + from dagger.dag_creator.airflow.operator_creator import OperatorCreator from dagger.dag_creator.airflow.operators.awsbatch_operator import AWSBatchOperator +from dagger import conf class BatchCreator(OperatorCreator): @@ -8,6 +12,20 @@ class BatchCreator(OperatorCreator): def __init__(self, task, dag): super().__init__(task, dag) + @staticmethod + def _validate_job_name(job_name, absolute_job_name): + if not absolute_job_name and not job_name: + raise Exception("Both job_name and absolute_job_name cannot be null") + + if absolute_job_name is not None: + return absolute_job_name + + job_path = Path(conf.DAGS_DIR) / 
job_name.replace("-", "/") + assert ( + job_path.is_dir() + ), f"Job name `{job_name}`, points to a non-existing folder `{job_path}`" + return job_name + def _generate_command(self): command = [self._task.executable_prefix, self._task.executable] for param_name, param_value in self._template_parameters.items(): @@ -21,16 +39,16 @@ def _create_operator(self, **kwargs): overrides = self._task.overrides overrides.update({"command": self._generate_command()}) + job_name = self._validate_job_name(self._task.job_name, self._task.absolute_job_name) batch_op = AWSBatchOperator( dag=self._dag, task_id=self._task.name, - job_name=self._task.job_name, - absolute_job_name=self._task.absolute_job_name, + job_name=self._task.name, + job_definition=job_name, region_name=self._task.region_name, - cluster_name=self._task.cluster_name, job_queue=self._task.job_queue, - overrides=overrides, + container_overrides=overrides, + awslogs_enabled=True, **kwargs, ) - return batch_op diff --git a/dagger/dag_creator/airflow/operator_creators/dbt_creator.py b/dagger/dag_creator/airflow/operator_creators/dbt_creator.py index 1c16835..9be9ee8 100644 --- a/dagger/dag_creator/airflow/operator_creators/dbt_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/dbt_creator.py @@ -13,28 +13,28 @@ def __init__(self, task, dag): self._project_dir = task.project_dir self._profile_dir = task.profile_dir self._profile_name = task.profile_name + self._target_name = task.target_name self._select = task.select self._dbt_command = task.dbt_command + self._vars = task.vars + self._create_external_athena_table = task.create_external_athena_table def _generate_command(self): command = [self._task.executable_prefix, self._task.executable] command.append(f"--project_dir={self._project_dir}") command.append(f"--profiles_dir={self._profile_dir}") command.append(f"--profile_name={self._profile_name}") + command.append(f"--target_name={self._target_name}") command.append(f"--dbt_command={self._dbt_command}") if 
self._select: command.append(f"--select={self._select}") - - if len(self._template_parameters) > 0: - dbt_vars = json.dumps(self._template_parameters) + if self._vars: + dbt_vars = json.dumps(self._vars) command.append(f"--vars='{dbt_vars}'") - + if self._create_external_athena_table: + command.append(f"--create_external_athena_table={self._create_external_athena_table}") + for param_name, param_value in self._template_parameters.items(): + command.append( + f"--{param_name}={param_value}" + ) return command - - # Overwriting function because for dbt we don't want to add inputs/outputs to the - # template parameters. - def create_operator(self): - self._template_parameters.update(self._task.template_parameters) - self._update_airflow_parameters() - - return self._create_operator(**self._airflow_parameters) diff --git a/dagger/dag_creator/airflow/operator_creators/redshift_load_creator.py b/dagger/dag_creator/airflow/operator_creators/redshift_load_creator.py index 8d14182..f1576f4 100644 --- a/dagger/dag_creator/airflow/operator_creators/redshift_load_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/redshift_load_creator.py @@ -2,9 +2,8 @@ from typing import Optional from dagger.dag_creator.airflow.operator_creator import OperatorCreator -from dagger.dag_creator.airflow.operators.redshift_sql_operator import ( - RedshiftSQLOperator, -) +from dagger.dag_creator.airflow.operators.redshift_sql_operator import RedshiftSQLOperator + class RedshiftLoadCreator(OperatorCreator): diff --git a/dagger/dag_creator/airflow/operator_creators/redshift_transform_creator.py b/dagger/dag_creator/airflow/operator_creators/redshift_transform_creator.py index 0218a6f..c8eb8dd 100644 --- a/dagger/dag_creator/airflow/operator_creators/redshift_transform_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/redshift_transform_creator.py @@ -1,7 +1,7 @@ from os.path import join from dagger.dag_creator.airflow.operator_creator import OperatorCreator -from 
dagger.dag_creator.airflow.operators.postgres_operator import PostgresOperator +from dagger.dag_creator.airflow.operators.redshift_sql_operator import RedshiftSQLOperator class RedshiftTransformCreator(OperatorCreator): @@ -22,11 +22,11 @@ def _read_sql(directory, file_path): def _create_operator(self, **kwargs): sql_string = self._read_sql(self._task.pipeline.directory, self._task.sql_file) - redshift_op = PostgresOperator( + redshift_op = RedshiftSQLOperator( dag=self._dag, task_id=self._task.name, sql=sql_string, - postgres_conn_id=self._task.postgres_conn_id, + redshift_conn_id=self._task.postgres_conn_id, params=self._template_parameters, **kwargs, ) diff --git a/dagger/dag_creator/airflow/operator_creators/redshift_unload_creator.py b/dagger/dag_creator/airflow/operator_creators/redshift_unload_creator.py index 7fd74d7..cb7be04 100644 --- a/dagger/dag_creator/airflow/operator_creators/redshift_unload_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/redshift_unload_creator.py @@ -1,7 +1,7 @@ from os.path import join from dagger.dag_creator.airflow.operator_creator import OperatorCreator -from dagger.dag_creator.airflow.operators.postgres_operator import PostgresOperator +from dagger.dag_creator.airflow.operators.redshift_sql_operator import RedshiftSQLOperator REDSHIFT_UNLOAD_CMD = """ unload ('{sql_string}') @@ -58,12 +58,13 @@ def _create_operator(self, **kwargs): unload_cmd = self._get_unload_command(sql_string) - redshift_op = PostgresOperator( + redshift_op = RedshiftSQLOperator( dag=self._dag, task_id=self._task.name, sql=unload_cmd, - postgres_conn_id=self._task.postgres_conn_id, + redshift_conn_id=self._task.postgres_conn_id, params=self._template_parameters, + autocommit=True, **kwargs, ) diff --git a/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py new file mode 100644 index 0000000..e133e40 --- /dev/null +++ 
b/dagger/dag_creator/airflow/operator_creators/reverse_etl_creator.py @@ -0,0 +1,78 @@ +import base64 + +from dagger.dag_creator.airflow.operator_creators.batch_creator import BatchCreator +from dagger.dag_creator.airflow.operators.reverse_etl_batch import ReverseEtlBatchOperator +import json + + +class ReverseEtlCreator(BatchCreator): + ref_name = "reverse_etl" + + def __init__(self, task, dag): + super().__init__(task, dag) + + self._assume_role_arn = task.assume_role_arn + self._num_threads = task.num_threads + self._batch_size = task.batch_size + self._absolute_job_name = task.absolute_job_name + self._primary_id_column = task.primary_id_column + self._secondary_id_column = task.secondary_id_column + self._custom_id_column = task.custom_id_column + self._model_name = task.model_name + self._project_name = task.project_name + self._is_deleted_column = task.is_deleted_column + self._hash_column = task.hash_column + self._updated_at_column = task.updated_at_column + self._from_time = task.from_time + self._days_to_live = task.days_to_live + self._output_type = task.output_type + self._region_name = task.region_name + + def _generate_command(self): + command = BatchCreator._generate_command(self) + + command.append(f"--num_threads={self._num_threads}") + command.append(f"--batch_size={self._batch_size}") + command.append(f"--primary_id_column={self._primary_id_column}") + command.append(f"--model_name={self._model_name}") + command.append(f"--project_name={self._project_name}") + command.append(f"--output_type={self._output_type}") + + if self._assume_role_arn: + command.append(f"--assume_role_arn={self._assume_role_arn}") + if self._secondary_id_column: + command.append(f"--secondary_id_column={self._secondary_id_column}") + if self._custom_id_column: + command.append(f"--custom_id_column={self._custom_id_column}") + if self._is_deleted_column: + command.append(f"--is_deleted_column={self._is_deleted_column}") + if self._hash_column: + 
command.append(f"--hash_column={self._hash_column}") + if self._updated_at_column: + command.append(f"--updated_at_column={self._updated_at_column}") + if self._from_time: + command.append(f"--from_time={self._from_time}") + if self._days_to_live: + command.append(f"--days_to_live={self._days_to_live}") + if self._region_name: + command.append(f"--region_name={self._region_name}") + + return command + + def _create_operator(self, **kwargs): + overrides = self._task.overrides + overrides.update({"command": self._generate_command()}) + + job_name = self._validate_job_name(self._task.job_name, self._task.absolute_job_name) + batch_op = ReverseEtlBatchOperator( + dag=self._dag, + task_id=self._task.name, + job_name=self._task.name, + job_definition=job_name, + region_name=self._task.region_name, + job_queue=self._task.job_queue, + container_overrides=overrides, + awslogs_enabled=True, + **kwargs, + ) + return batch_op diff --git a/dagger/dag_creator/airflow/operator_creators/spark_creator.py b/dagger/dag_creator/airflow/operator_creators/spark_creator.py index 2bb41e9..8212a08 100644 --- a/dagger/dag_creator/airflow/operator_creators/spark_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/spark_creator.py @@ -91,6 +91,7 @@ def _create_operator(self, **kwargs): job_args=_parse_args(self._template_parameters), spark_args=_parse_spark_args(self._task.spark_args), spark_conf_args=_parse_spark_args(self._task.spark_conf_args, '=', 'conf '), + spark_app_name=self._task.spark_conf_args.get("spark.app.name", None) if self._task.spark_conf_args else None, extra_py_files=self._task.extra_py_files, **kwargs, ) @@ -113,7 +114,7 @@ def _create_operator(self, **kwargs): job_name=job_name, region_name=self._task.region_name, job_queue=self._task.job_queue, - overrides=overrides, + container_overrides=overrides, **kwargs, ) elif self._task.spark_engine == "glue": diff --git a/dagger/dag_creator/airflow/operator_factory.py b/dagger/dag_creator/airflow/operator_factory.py 
index 706a737..f610f1e 100644 --- a/dagger/dag_creator/airflow/operator_factory.py +++ b/dagger/dag_creator/airflow/operator_factory.py @@ -10,6 +10,7 @@ redshift_load_creator, redshift_transform_creator, redshift_unload_creator, + reverse_etl_creator, spark_creator, sqoop_creator, ) diff --git a/dagger/dag_creator/airflow/operators/awsbatch_operator.py b/dagger/dag_creator/airflow/operators/awsbatch_operator.py index a267ba7..23b3596 100644 --- a/dagger/dag_creator/airflow/operators/awsbatch_operator.py +++ b/dagger/dag_creator/airflow/operators/awsbatch_operator.py @@ -1,203 +1,90 @@ -from pathlib import Path -from time import sleep - -from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.operators.batch import BatchOperator +from airflow.utils.context import Context from airflow.exceptions import AirflowException -from airflow.utils.decorators import apply_defaults - -from dagger.dag_creator.airflow.operators.dagger_base_operator import DaggerBaseOperator -from dagger.dag_creator.airflow.utils.decorators import lazy_property -from dagger import conf - - -class AWSBatchOperator(DaggerBaseOperator): - """ - Execute a job on AWS Batch Service - - .. warning: the queue parameter was renamed to job_queue to segregate the - internal CeleryExecutor queue from the AWS Batch internal queue. 
- - :param job_name: the name for the job that will run on AWS Batch - :type job_name: str - :param job_definition: the job definition name on AWS Batch - :type job_definition: str - :param job_queue: the queue name on AWS Batch - :type job_queue: str - :param overrides: the same parameter that boto3 will receive on - containerOverrides (templated): - http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job - :type overrides: dict - :param max_retries: exponential backoff retries while waiter is not - merged, 4200 = 48 hours - :type max_retries: int - :param aws_conn_id: connection id of AWS credentials / region name. If None, - credential boto3 strategy will be used - (http://boto3.readthedocs.io/en/latest/guide/configuration.html). - :type aws_conn_id: str - :param region_name: region name to use in AWS Hook. - Override the region_name in connection (if provided) - :type region_name: str - :param cluster_name: Batch cluster short name or arn - :type region_name: str - - """ - - ui_color = "#c3dae0" - client = None - arn = None - template_fields = ("overrides",) +from airflow.providers.amazon.aws.links.batch import ( + BatchJobDefinitionLink, + BatchJobQueueLink, +) +from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink - @apply_defaults - def __init__( - self, - job_queue, - job_name=None, - absolute_job_name=None, - overrides=None, - job_definition=None, - aws_conn_id=None, - region_name=None, - cluster_name=None, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.job_name = self._validate_job_name(job_name, absolute_job_name) - self.aws_conn_id = aws_conn_id - self.region_name = region_name - self.cluster_name = cluster_name - self.job_definition = job_definition or self.job_name - self.job_queue = job_queue - self.overrides = overrides or {} - self.job_id = None - - @lazy_property - def batch_client(self): - return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="batch").get_client_type( - 
region_name=self.region_name) - - @lazy_property - def logs_client(self): - return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="logs").get_client_type( - region_name=self.region_name) - - @lazy_property - def ecs_client(self): - return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="ecs").get_client_type( - region_name=self.region_name) +class AWSBatchOperator(BatchOperator): @staticmethod - def _validate_job_name(job_name, absolute_job_name): - if absolute_job_name is None and job_name is None: - raise Exception("Both job_name and absolute_job_name cannot be null") - - if absolute_job_name is not None: - return absolute_job_name - - job_path = Path(conf.DAGS_DIR) / job_name.replace("-", "/") - assert ( - job_path.is_dir() - ), f"Job name `{job_name}`, points to a non-existing folder `{job_path}`" - return job_name - - def execute(self, context): - self.task_instance = context["ti"] - self.log.info( - "\n" - f"\n\tJob name: {self.job_name}" - f"\n\tJob queue: {self.job_queue}" - f"\n\tJob definition: {self.job_definition}" - "\n" - ) - - res = self.batch_client.submit_job( - jobName=self.job_name, - jobQueue=self.job_queue, - jobDefinition=self.job_definition, - containerOverrides=self.overrides, - ) - self.job_id = res["jobId"] - self.log.info( - "\n" - f"\n\tJob ID: {self.job_id}" - "\n" - ) - self.poll_task() - - def poll_task(self): - log_offset = 0 - print_logs_url = True - - while True: - res = self.batch_client.describe_jobs(jobs=[self.job_id]) - - if len(res["jobs"]) == 0: - sleep(3) - continue - - job = res["jobs"][0] - job_status = job["status"] - log_stream_name = job["container"].get("logStreamName") - - if print_logs_url and log_stream_name: - print_logs_url = False - self.log.info( - "\n" - f"\n\tLogs at: https://{self.region_name}.console.aws.amazon.com/cloudwatch/home?" 
- f"region={self.region_name}#logEventViewer:group=/aws/batch/job;stream={log_stream_name}" - "\n" - ) + def _format_cloudwatch_link(awslogs_region: str, awslogs_group: str, awslogs_stream_name: str): + return f"https://{awslogs_region}.console.aws.amazon.com/cloudwatch/home?region={awslogs_region}#logEventViewer:group={awslogs_group};stream={awslogs_stream_name}" + + def monitor_job(self, context: Context): + """Monitor an AWS Batch job. + + This can raise an exception or an AirflowTaskTimeout if the task was + created with ``execution_timeout``. + """ + if not self.job_id: + raise AirflowException("AWS Batch job - job_id was not found") + + try: + job_desc = self.hook.get_job_description(self.job_id) + job_definition_arn = job_desc["jobDefinition"] + job_queue_arn = job_desc["jobQueue"] + self.log.info( + "AWS Batch job (%s) Job Definition ARN: %r, Job Queue ARN: %r", + self.job_id, + job_definition_arn, + job_queue_arn, + ) + except KeyError: + self.log.warning("AWS Batch job (%s) can't get Job Definition ARN and Job Queue ARN", self.job_id) + else: + BatchJobDefinitionLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + job_definition_arn=job_definition_arn, + ) + BatchJobQueueLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + job_queue_arn=job_queue_arn, + ) - if job_status in ("RUNNING", "FAILED", "SUCCEEDED") and log_stream_name: - try: - log_offset = self.print_logs(log_stream_name, log_offset) - except self.logs_client.exceptions.ResourceNotFoundException: - pass + if self.awslogs_enabled: + if self.waiters: + self.waiters.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher) else: - self.log.info(f"Job status: {job_status}") - - if job_status == "FAILED": - status_reason = res["jobs"][0]["statusReason"] - exit_code = res["jobs"][0]["container"].get("exitCode") - 
reason = res["jobs"][0]["container"].get("reason", "") - failure_msg = f"Status: {status_reason} | Exit code: {exit_code} | Reason: {reason}" - container_instance_arn = job["container"]["containerInstanceArn"] - self.retry_check(container_instance_arn) - raise AirflowException(failure_msg) - - if job_status == "SUCCEEDED": - self.log.info("AWS Batch Job has been successfully executed") - return - - sleep(7.5) + self.hook.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher) + else: + if self.waiters: + self.waiters.wait_for_job(self.job_id) + else: + self.hook.wait_for_job(self.job_id) + + awslogs = [] + try: + awslogs = self.hook.get_job_all_awslogs_info(self.job_id) + except AirflowException as ae: + self.log.warning("Cannot determine where to find the AWS logs for this Batch job: %s", ae) + + if awslogs: + self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id) + link_builder = CloudWatchEventsLink() + for log in awslogs: + self.log.info(self._format_cloudwatch_link(**log)) + if len(awslogs) > 1: + # there can be several log streams on multi-node jobs + self.log.warning( + "out of all those logs, we can only link to one in the UI. Using the first one." + ) - def retry_check(self, container_instance_arn): - res = self.ecs_client.describe_container_instances( - cluster=self.cluster_name, containerInstances=[container_instance_arn] - ) - instance_status = res["containerInstances"][0]["status"] - if instance_status != "ACTIVE": - self.log.warning( - f"Instance in {instance_status} state: setting the task up for retry..." 
+ CloudWatchEventsLink.persist( + context=context, + operator=self, + region_name=self.hook.conn_region_name, + aws_partition=self.hook.conn_partition, + **awslogs[0], ) - self.retries += self.task_instance.try_number + 1 - self.task_instance.max_tries = self.retries - - def print_logs(self, log_stream_name, log_offset): - logs = self.logs_client.get_log_events( - logGroupName="/aws/batch/job", - logStreamName=log_stream_name, - startFromHead=True, - ) - - for event in logs["events"][log_offset:]: - self.log.info(event["message"]) - - log_offset = len(logs["events"]) - return log_offset - def on_kill(self): - res = self.batch_client.terminate_job( - jobId=self.job_id, reason="Task killed by the user" - ) - self.log.info(res) + self.hook.check_job_success(self.job_id) + self.log.info("AWS Batch job (%s) succeeded", self.job_id) diff --git a/dagger/dag_creator/airflow/operators/postgres_operator.py b/dagger/dag_creator/airflow/operators/postgres_operator.py index b833516..ce90250 100644 --- a/dagger/dag_creator/airflow/operators/postgres_operator.py +++ b/dagger/dag_creator/airflow/operators/postgres_operator.py @@ -1,6 +1,6 @@ from typing import Iterable, Mapping, Optional, Union -from airflow.hooks.postgres_hook import PostgresHook +from airflow.providers.postgres.hooks.postgres import PostgresHook from airflow.utils.decorators import apply_defaults from dagger.dag_creator.airflow.operators.dagger_base_operator import DaggerBaseOperator @@ -51,6 +51,6 @@ def execute(self, context): self.hook = PostgresHook( postgres_conn_id=self.postgres_conn_id, schema=self.database ) - self.hook.run(self.sql, self.autocommit, parameters=self.parameters) + self.hook.run(self.sql, self.autocommit, parameters=self.parameters, split_statements=True) for output in self.hook.conn.notices: self.log.info(output) diff --git a/dagger/dag_creator/airflow/operators/reverse_etl_batch.py b/dagger/dag_creator/airflow/operators/reverse_etl_batch.py new file mode 100644 index 0000000..78c1619 --- 
/dev/null +++ b/dagger/dag_creator/airflow/operators/reverse_etl_batch.py @@ -0,0 +1,5 @@ +from dagger.dag_creator.airflow.operators.awsbatch_operator import AWSBatchOperator + +class ReverseEtlBatchOperator(AWSBatchOperator): + custom_operator_name = 'ReverseETL' + ui_color = "#f0ede4" diff --git a/dagger/dag_creator/airflow/operators/spark_submit_operator.py b/dagger/dag_creator/airflow/operators/spark_submit_operator.py index d9df768..31f6a70 100644 --- a/dagger/dag_creator/airflow/operators/spark_submit_operator.py +++ b/dagger/dag_creator/airflow/operators/spark_submit_operator.py @@ -1,6 +1,5 @@ import logging import os -import signal import time import boto3 @@ -19,24 +18,28 @@ class SparkSubmitOperator(DaggerBaseOperator): @apply_defaults def __init__( - self, - job_file, - cluster_name, - job_args=None, - spark_args=None, - spark_conf_args=None, - extra_py_files=None, - *args, - **kwargs, + self, + job_file, + cluster_name, + job_args=None, + spark_args=None, + spark_conf_args=None, + spark_app_name=None, + extra_py_files=None, + *args, + **kwargs, ): super().__init__(*args, **kwargs) self.job_file = job_file self.job_args = job_args self.spark_args = spark_args self.spark_conf_args = spark_conf_args + self.spark_app_name = spark_app_name self.extra_py_files = extra_py_files self.cluster_name = cluster_name - self._execution_timeout = kwargs.get('execution_timeout') + self._execution_timeout = kwargs.get("execution_timeout") + self._application_id = None + self._emr_master_instance_id = None @property def emr_client(self): @@ -71,54 +74,132 @@ def get_execution_timeout(self): return None def get_cluster_id_by_name(self, emr_cluster_name, cluster_states): - response = self.emr_client.list_clusters(ClusterStates=cluster_states) matching_clusters = list( - filter(lambda cluster: cluster['Name'] == emr_cluster_name, response['Clusters'])) + filter( + lambda cluster: cluster["Name"] == emr_cluster_name, + response["Clusters"], + ) + ) if len(matching_clusters) 
== 1: - cluster_id = matching_clusters[0]['Id'] - logging.info('Found cluster name = %s id = %s' % (emr_cluster_name, cluster_id)) + cluster_id = matching_clusters[0]["Id"] + logging.info( + "Found cluster name = %s id = %s" % (emr_cluster_name, cluster_id) + ) return cluster_id elif len(matching_clusters) > 1: - raise AirflowException('More than one cluster found for name = %s' % emr_cluster_name) + raise AirflowException( + "More than one cluster found for name = %s" % emr_cluster_name + ) else: return None - def execute(self, context): + def get_application_id_by_name(self, emr_master_instance_id, application_name): """ - See `execute` method from airflow.operators.bash_operator + Get the application ID of the Spark job """ - cluster_id = self.get_cluster_id_by_name(self.cluster_name, ["WAITING", "RUNNING"]) - emr_master_instance_id = self.emr_client.list_instances(ClusterId=cluster_id, InstanceGroupTypes=["MASTER"], - InstanceStates=["RUNNING"])["Instances"][0][ - "Ec2InstanceId"] - - command_parameters = {"commands": [self.spark_submit_cmd]} - if self._execution_timeout: - command_parameters["executionTimeout"] = [self.get_execution_timeout()] + if application_name: + command = ( + f"yarn application -list -appStates RUNNING | grep {application_name}" + ) + + response = self.ssm_client.send_command( + InstanceIds=[emr_master_instance_id], + DocumentName="AWS-RunShellScript", + Parameters={"commands": [command]}, + ) + + command_id = response["Command"]["CommandId"] + time.sleep(10) # Wait for the command to execute + + output = self.ssm_client.get_command_invocation( + CommandId=command_id, InstanceId=emr_master_instance_id + ) + + stdout = output["StandardOutputContent"] + for line in stdout.split("\n"): + if application_name in line: + application_id = line.split()[0] + return application_id + return None - response = self.ssm_client.send_command( - InstanceIds=[emr_master_instance_id], - DocumentName="AWS-RunShellScript", - Parameters= command_parameters + 
def kill_spark_job(self): + self._application_id = self.get_application_id_by_name( + self._emr_master_instance_id, self.spark_app_name ) - command_id = response['Command']['CommandId'] - status = 'Pending' - status_details = None - while status in ['Pending', 'InProgress', 'Delayed']: - time.sleep(30) - response = self.ssm_client.get_command_invocation(CommandId=command_id, InstanceId=emr_master_instance_id) - status = response['Status'] - status_details = response['StatusDetails'] - self.log.info( - self.ssm_client.get_command_invocation(CommandId=command_id, InstanceId=emr_master_instance_id)[ - 'StandardErrorContent']) - - if status != 'Success': - raise AirflowException(f"Spark command failed, check Spark job status in YARN resource manager. " - f"Response status details: {status_details}") + if self._application_id and self._emr_master_instance_id: + kill_command = f"yarn application -kill {self._application_id}" + self.ssm_client.send_command( + InstanceIds=[self._emr_master_instance_id], + DocumentName="AWS-RunShellScript", + Parameters={"commands": [kill_command]}, + ) + logging.info(f"Spark job {self._application_id} terminated successfully.") + else: + logging.warning( + "No application ID or master instance ID found to terminate." + ) def on_kill(self): - self.log.info("Sending SIGTERM signal to bash process group") - os.killpg(os.getpgid(self.sp.pid), signal.SIGTERM) + logging.info("Task killed. 
Attempting to terminate the Spark job.") + self.kill_spark_job() + + def execute(self, context): + """ + See `execute` method from airflow.operators.bash_operator + """ + try: + # Get cluster and master node information + cluster_id = self.get_cluster_id_by_name( + self.cluster_name, ["WAITING", "RUNNING"] + ) + self._emr_master_instance_id = self.emr_client.list_instances( + ClusterId=cluster_id, + InstanceGroupTypes=["MASTER"], + InstanceStates=["RUNNING"], + )["Instances"][0]["Ec2InstanceId"] + + # Build the command parameters + command_parameters = {"commands": [self.spark_submit_cmd]} + if self._execution_timeout: + command_parameters["executionTimeout"] = [self.get_execution_timeout()] + + # Send the command via SSM + response = self.ssm_client.send_command( + InstanceIds=[self._emr_master_instance_id], + DocumentName="AWS-RunShellScript", + Parameters=command_parameters, + ) + command_id = response["Command"]["CommandId"] + status = "Pending" + status_details = None + + # Monitor the command's execution + while status in ["Pending", "InProgress", "Delayed"]: + time.sleep(30) + # Check the status of the SSM command + response = self.ssm_client.get_command_invocation( + CommandId=command_id, InstanceId=self._emr_master_instance_id + ) + status = response["Status"] + status_details = response["StatusDetails"] + + self.log.info( + self.ssm_client.get_command_invocation( + CommandId=command_id, InstanceId=self._emr_master_instance_id + )["StandardErrorContent"] + ) + + # Kill the command and raise an exception if the command did not succeed + if status != "Success": + self.kill_spark_job() + raise AirflowException( + f"Spark command failed, check Spark job status in YARN resource manager. 
" + f"Response status details: {status_details}" + ) + + except Exception as e: + logging.error(f"Error encountered: {str(e)}") + self.kill_spark_job() + raise AirflowException(f"Task failed with error: {str(e)}") diff --git a/dagger/dagger_config.yaml b/dagger/dagger_config.yaml index 9eac6ff..38abccd 100644 --- a/dagger/dagger_config.yaml +++ b/dagger/dagger_config.yaml @@ -1,8 +1,10 @@ airflow: + external_sensor_default_args: + poll_interval: 30 + timeout: 28800 + mode: reschedule + deferrable: true with_data_node: false - external_sensor_poke_interval: 600 - external_sensor_timeout: 28800 - external_sensor_mode: reschedule is_dummy_operator_short_circuit: false @@ -56,3 +58,12 @@ alert: # type: slack # channel: "#airflow-jobs" # mentions: + +plugin: +# paths: +# - plugins + +reverse_etl: +# default_job_name: +# default_executable_prefix: +# default_executable: diff --git a/dagger/graph/task_graph.py b/dagger/graph/task_graph.py index c7898f6..0f14a56 100644 --- a/dagger/graph/task_graph.py +++ b/dagger/graph/task_graph.py @@ -55,7 +55,7 @@ def add_child(self, child_id): class Edge: - def __init__(self, follow_external_dependency=False): + def __init__(self, follow_external_dependency=None): self._follow_external_dependency = follow_external_dependency @property diff --git a/dagger/pipeline/io.py b/dagger/pipeline/io.py index 452798f..cdacdd0 100644 --- a/dagger/pipeline/io.py +++ b/dagger/pipeline/io.py @@ -19,9 +19,15 @@ def init_attributes(cls, orig_cls): Attribute( attribute_name="follow_external_dependency", required=False, - comment="Weather an external task sensor should be created if this dataset" - "is created in another pipeline. 
Default is False", + format_help="dictionary or boolean", + comment="External Task Sensor parameters in key value format: https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/sensors/base/index.html" ), + # Attribute( + # attribute_name="follow_external_dependency", + # required=False, + # comment="Whether an external task sensor should be created if this dataset" + # "is created in another pipeline. Default is False", + # ), ] ) @@ -34,7 +40,17 @@ def __init__(self, io_config, config_location): self._has_dependency = self.parse_attribute("has_dependency") if self._has_dependency is None: self._has_dependency = True - self._follow_external_dependency = self.parse_attribute("follow_external_dependency") or False + + follow_external_dependency = self.parse_attribute("follow_external_dependency") + if follow_external_dependency is not None: + if isinstance(follow_external_dependency, bool): + if follow_external_dependency: + follow_external_dependency = dict() + else: + follow_external_dependency = None + else: + follow_external_dependency = dict(follow_external_dependency) + self._follow_external_dependency = follow_external_dependency def __eq__(self, other): return self.alias() == other.alias() @@ -47,6 +63,10 @@ def alias(self): def name(self): return self._name + @name.setter + def name(self, value): + self._name = value + @property def has_dependency(self): return self._has_dependency diff --git a/dagger/pipeline/io_factory.py b/dagger/pipeline/io_factory.py index 782fd14..61d9fd2 100644 --- a/dagger/pipeline/io_factory.py +++ b/dagger/pipeline/io_factory.py @@ -6,7 +6,10 @@ dummy_io, gdrive_io, redshift_io, - s3_io + s3_io, + databricks_io, + dynamo_io, + sns_io, ) from dagger.utilities.classes import get_deep_obj_subclasses diff --git a/dagger/pipeline/ios/databricks_io.py b/dagger/pipeline/ios/databricks_io.py new file mode 100644 index 0000000..15be2c1 --- /dev/null +++ b/dagger/pipeline/ios/databricks_io.py @@ -0,0 +1,46 @@ +from 
dagger.pipeline.io import IO +from dagger.utilities.config_validator import Attribute + + +class DatabricksIO(IO): + ref_name = "databricks" + + @classmethod + def init_attributes(cls, orig_cls): + cls.add_config_attributes( + [ + Attribute(attribute_name="catalog"), + Attribute(attribute_name="schema"), + Attribute(attribute_name="table"), + ] + ) + + def __init__(self, io_config, config_location): + super().__init__(io_config, config_location) + + self._catalog = self.parse_attribute("catalog") + self._schema = self.parse_attribute("schema") + self._table = self.parse_attribute("table") + + def alias(self): + return f"databricks://{self._catalog}/{self._schema}/{self._table}" + + @property + def rendered_name(self): + return f"{self._catalog}.{self._schema}.{self._table}" + + @property + def airflow_name(self): + return f"databricks-{self._catalog}-{self._schema}-{self._table}" + + @property + def catalog(self): + return self._catalog + + @property + def schema(self): + return self._schema + + @property + def table(self): + return self._table diff --git a/dagger/pipeline/ios/dynamo_io.py b/dagger/pipeline/ios/dynamo_io.py new file mode 100644 index 0000000..88d822e --- /dev/null +++ b/dagger/pipeline/ios/dynamo_io.py @@ -0,0 +1,47 @@ +from dagger.pipeline.io import IO +from dagger.utilities.config_validator import Attribute + + +class DynamoIO(IO): + ref_name = "dynamo" + + @classmethod + def init_attributes(cls, orig_cls): + cls.add_config_attributes( + [ + Attribute( + attribute_name="region_name", + required=False, + comment="Only needed for cross region dynamo tables" + ), + Attribute( + attribute_name="table", + comment="The name of the dynamo table" + ), + ] + ) + + def __init__(self, io_config, config_location): + super().__init__(io_config, config_location) + + self._region_name = self.parse_attribute("region_name") + self._table = self.parse_attribute("table") + + def alias(self): + return f"dynamo://{self._region_name or ''}/{self._table}" + + @property 
+ def rendered_name(self): + return self._table + + @property + def airflow_name(self): + return f"dynamo-{'-'.join([name_part for name_part in [self._region_name, self._table] if name_part])}" + + @property + def region_name(self): + return self._region_name + + @property + def table(self): + return self._table diff --git a/dagger/pipeline/ios/sns_io.py b/dagger/pipeline/ios/sns_io.py new file mode 100644 index 0000000..67d17d4 --- /dev/null +++ b/dagger/pipeline/ios/sns_io.py @@ -0,0 +1,52 @@ +from dagger.pipeline.io import IO +from dagger.utilities.config_validator import Attribute + + +class SnsIO(IO): + ref_name = "sns" + + @classmethod + def init_attributes(cls, orig_cls): + cls.add_config_attributes( + [ + Attribute( + attribute_name="account_id", + required=False, + comment="Only needed for cross account sns topics" + ), + Attribute( + attribute_name="region_name", + required=False, + comment="Only needed for cross region sns topics" + ), + Attribute( + attribute_name="sns_topic", + comment="The name of the sns topic" + ), + ] + ) + + def __init__(self, io_config, config_location): + super().__init__(io_config, config_location) + + self._region_name = self.parse_attribute("region_name") + self._sns_topic = self.parse_attribute("sns_topic") + + def alias(self): + return f"sns://{self._region_name or ''}/{self._sns_topic}" + + @property + def rendered_name(self): + return self._sns_topic + + @property + def airflow_name(self): + return f"sns-{'-'.join([name_part for name_part in [self._region_name, self._sns_topic] if name_part])}" + + @property + def region_name(self): + return self._region_name + + @property + def sns_topic(self): + return self._sns_topic \ No newline at end of file diff --git a/dagger/pipeline/task.py b/dagger/pipeline/task.py index 26235bd..d484e49 100644 --- a/dagger/pipeline/task.py +++ b/dagger/pipeline/task.py @@ -36,6 +36,12 @@ def init_attributes(cls, orig_cls): comment="Use dagger init-io cli", ), 
Attribute(attribute_name="pool", required=False), + Attribute( + attribute_name="task_group", + required=False, + format_help="str", + comment="Task group name", + ), Attribute( attribute_name="timeout_in_seconds", required=False, @@ -73,6 +79,7 @@ def __init__(self, name: str, pipeline_name, pipeline, config: dict): self._outputs = [] self._pool = self.parse_attribute("pool") or self.default_pool self._timeout_in_seconds = self.parse_attribute("timeout_in_seconds") + self._task_group = self.parse_attribute("task_group") self.process_inputs(config["inputs"]) self.process_outputs(config["outputs"]) @@ -137,6 +144,10 @@ def pool(self): def timeout_in_seconds(self): return self._timeout_in_seconds + @property + def task_group(self): + return self._task_group + def add_input(self, task_input: IO): _logger.info("Adding input: %s to task: %s", task_input.name, self._name) self._inputs.append(task_input) diff --git a/dagger/pipeline/task_factory.py b/dagger/pipeline/task_factory.py index a9c5eef..d8a1e53 100644 --- a/dagger/pipeline/task_factory.py +++ b/dagger/pipeline/task_factory.py @@ -9,6 +9,7 @@ redshift_load_task, redshift_transform_task, redshift_unload_task, + reverse_etl_task, spark_task, sqoop_task, ) diff --git a/dagger/pipeline/tasks/dbt_task.py b/dagger/pipeline/tasks/dbt_task.py index 33b9c1a..e59ea5a 100644 --- a/dagger/pipeline/tasks/dbt_task.py +++ b/dagger/pipeline/tasks/dbt_task.py @@ -21,10 +21,14 @@ def init_attributes(cls, orig_cls): ), Attribute( attribute_name="profile_name", - required=False, + parent_fields=["task_parameters"], + comment="Which profile to load from the profiles.yml file", + ), + Attribute( + attribute_name="target_name", parent_fields=["task_parameters"], comment="Which target to load for the given profile " - "(--target dbt option). Default is 'default'", + "(--target dbt option). 
Default is 'default'", ), Attribute( attribute_name="select", @@ -37,6 +41,18 @@ def init_attributes(cls, orig_cls): parent_fields=["task_parameters"], comment="Specify the name of the DBT command to run", ), + Attribute( + attribute_name="vars", + required=False, + parent_fields=["task_parameters"], + comment="Specify the variables to pass to dbt", + ), + Attribute( + attribute_name="create_external_athena_table", + required=False, + parent_fields=["task_parameters"], + comment="Specify whether to create an external Athena table for the model", + ), ] ) @@ -45,9 +61,14 @@ def __init__(self, name, pipeline_name, pipeline, job_config): self._project_dir = self.parse_attribute("project_dir") self._profile_dir = self.parse_attribute("profile_dir") - self._profile_name = self.parse_attribute("profile_name") or "default" + self._profile_name = self.parse_attribute("profile_name") + self._target_name = self.parse_attribute("target_name") self._select = self.parse_attribute("select") self._dbt_command = self.parse_attribute("dbt_command") + self._vars = self.parse_attribute("vars") + self._create_external_athena_table = self.parse_attribute( + "create_external_athena_table" + ) @property def project_dir(self): @@ -61,6 +82,10 @@ def profile_dir(self): def profile_name(self): return self._profile_name + @property + def target_name(self): + return self._target_name + @property def select(self): return self._select @@ -68,3 +93,11 @@ def select(self): @property def dbt_command(self): return self._dbt_command + + @property + def vars(self): + return self._vars + + @property + def create_external_athena_table(self): + return self._create_external_athena_table diff --git a/dagger/pipeline/tasks/reverse_etl_task.py b/dagger/pipeline/tasks/reverse_etl_task.py new file mode 100644 index 0000000..6c9a5d2 --- /dev/null +++ b/dagger/pipeline/tasks/reverse_etl_task.py @@ -0,0 +1,239 @@ +from dagger.pipeline.tasks.batch_task import BatchTask +from dagger.utilities.config_validator 
import Attribute +from dagger import conf + +class ReverseEtlTask(BatchTask): + ref_name = "reverse_etl" + + @classmethod + def init_attributes(cls, orig_cls): + cls.add_config_attributes( + [ + Attribute( + attribute_name="executable_prefix", + required=False, + parent_fields=["task_parameters"], + comment="E.g.: python", + ), + Attribute( + attribute_name="executable", + required=False, + parent_fields=["task_parameters"], + comment="E.g.: my_code.py", + ), + Attribute( + attribute_name="assume_role_arn", + parent_fields=["task_parameters"], + required = False, + validator=str, + comment="The ARN of the role to assume before running the job", + ), + Attribute( + attribute_name="num_threads", + parent_fields=["task_parameters"], + required=False, + comment="The number of threads to use for the job", + ), + Attribute( + attribute_name="batch_size", + parent_fields=["task_parameters"], + required=False, + comment="The number of rows to fetch in each batch", + ), + Attribute( + attribute_name="primary_id_column", + parent_fields=["task_parameters"], + validator=str, + comment="The primary key column to use for the job", + ), + Attribute( + attribute_name="secondary_id_column", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="The secondary key column to use for the job", + ), + Attribute( + attribute_name="custom_id_column", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="The custom key column to use for the job", + ), + Attribute( + attribute_name="model_name", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="The name of the model. This is going to be a column on the target table. By default it is" + " set to the name of the input .", + ), + Attribute( + attribute_name="project_name", + parent_fields=["task_parameters"], + validator=str, + required=True, + comment="The name of the project. 
This is going to be a column on the target table.", + ), + Attribute( + attribute_name="is_deleted_column", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="The column that has the boolean flag to indicate if the row is deleted", + ), + Attribute( + attribute_name="hash_column", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="The column that has the hash value of the row to be used to get the diff since " + "the last export. If provided, the from_time is required. It's mutually exclusive with " + "updated_at_column", + ), + Attribute( + attribute_name="updated_at_column", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="The column that has the last updated timestamp of the row to be used to get the diff " + "since the last export. If provided, the from_time is required. It's mutually exclusive " + "with hash_column", + ), + Attribute( + attribute_name="from_time", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="Timestamp in YYYY-mm-ddTHH:MM format. It is used for incremental loads. " + "It's required when hash_column or updated_at_column is provided", + ), + Attribute( + attribute_name="days_to_live", + parent_fields=["task_parameters"], + validator=str, + required=False, + comment="The number of days to keep the data in the table. 
If provided, the time_to_live attribute " + "will be set in dynamodb", + ), + + ] + ) + + def __init__(self, name, pipeline_name, pipeline, job_config): + super().__init__(name, pipeline_name, pipeline, job_config) + + self._executable = self.executable or conf.REVERSE_ETL_DEFAULT_EXECUTABLE + self._executable_prefix = self.executable_prefix or conf.REVERSE_ETL_DEFAULT_EXECUTABLE_PREFIX + + self._assume_role_arn = self.parse_attribute("assume_role_arn") + self._num_threads = self.parse_attribute("num_threads") + self._batch_size = self.parse_attribute("batch_size") + self._absolute_job_name = self._absolute_job_name or conf.REVERSE_ETL_DEFAULT_JOB_NAME + self._primary_id_column = self.parse_attribute("primary_id_column") + self._secondary_id_column = self.parse_attribute("secondary_id_column") + self._custom_id_column = self.parse_attribute("custom_id_column") + self._model_name = self.parse_attribute("model_name") + self._project_name = self.parse_attribute("project_name") + self._is_deleted_column = self.parse_attribute("is_deleted_column") + self._hash_column = self.parse_attribute("hash_column") + self._updated_at_column = self.parse_attribute("updated_at_column") + self._from_time = self.parse_attribute("from_time") + self._days_to_live = self.parse_attribute("days_to_live") + + if self._hash_column and self._updated_at_column: + raise ValueError(f"ReverseETLTask: {self._name} hash_column and updated_at_column are mutually exclusive") + + if self._hash_column or self._updated_at_column: + if not self._from_time: + raise ValueError(f"ReverseETLTask: {self._name} from_time is required when hash_column or updated_at_column is provided") + + # Making sure the input table name is set as it is expected in the reverse etl job + input_index = self._get_io_index(self._inputs) + print('XXX', self._inputs, input_index) + if input_index is None: + raise ValueError(f"ReverseEtlTask: {self._name} must have an input") + self._inputs[input_index].name = "input_table_name" + + 
# Making sure the output name is set as it is expected in the reverse etl job + output_index = self._get_io_index(self._outputs) + if output_index is None: + raise ValueError(f"ReverseEtlTask: {self._name} must have an output") + self._outputs[output_index].name = "output_name" + + # Extracting the output type from the output definition + self._output_type = self._outputs[output_index].ref_name + + # Extracting the outputs region name from the output definition + self._region_name = self._outputs[output_index].region_name + + + def _get_io_index(self, ios): + if len([io for io in ios if io.ref_name != "dummy"]) > 1: + raise ValueError(f"ReverseEtlTask: {self._name} can only have one input or output") + + for i, io in enumerate(ios): + if io.ref_name != "dummy": + return i + return None + + + @property + def assume_role_arn(self): + return self._assume_role_arn + + @property + def num_threads(self): + return self._num_threads + + @property + def batch_size(self): + return self._batch_size + + @property + def primary_id_column(self): + return self._primary_id_column + + @property + def secondary_id_column(self): + return self._secondary_id_column + + @property + def custom_id_column(self): + return self._custom_id_column + + @property + def model_name(self): + return self._model_name + + @property + def project_name(self): + return self._project_name + + @property + def is_deleted_column(self): + return self._is_deleted_column + + @property + def hash_column(self): + return self._hash_column + + @property + def updated_at_column(self): + return self._updated_at_column + + @property + def from_time(self): + return self._from_time + + @property + def days_to_live(self): + return self._days_to_live + + @property + def output_type(self): + return self._output_type + + @property + def region_name(self): + return self._region_name diff --git a/dagger/utilities/config_validator.py b/dagger/utilities/config_validator.py index 1d68f33..e90af68 100644 --- 
a/dagger/utilities/config_validator.py +++ b/dagger/utilities/config_validator.py @@ -98,10 +98,15 @@ def init_attributes_once(cls, orig_cls: str) -> None: cls.init_attributes(orig_cls) if parent_class.__name__ != "ConfigValidator": - cls.config_attributes[cls.__name__] = ( - cls.config_attributes[parent_class.__name__] - + cls.config_attributes[cls.__name__] - ) + parent_attributes = cls.config_attributes[parent_class.__name__] + current_attributes = cls.config_attributes[cls.__name__] + + # Overwriting attributes in parent operator if they are also existing in the child operator + merged_attributes = {attr.name: attr for attr in parent_attributes} + for attr in current_attributes: + merged_attributes[attr.name] = attr + + cls.config_attributes[cls.__name__] = list(merged_attributes.values()) attributes_lookup = {} for index, attribute in enumerate(cls.config_attributes[cls.__name__]): diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py new file mode 100644 index 0000000..9a341f6 --- /dev/null +++ b/dagger/utilities/dbt_config_parser.py @@ -0,0 +1,384 @@ +import json +import yaml +from abc import ABC, abstractmethod +from collections import OrderedDict +from os import path +from os.path import join +from typing import Tuple, List, Dict +import logging + +# Task base configurations +ATHENA_TASK_BASE = {"type": "athena"} +DATABRICKS_TASK_BASE = {"type": "databricks"} +S3_TASK_BASE = {"type": "s3"} +_logger = logging.getLogger("root") + + +class DBTConfigParser(ABC): + """Abstract base class for parsing dbt manifest.json files and generating task configurations.""" + + def __init__(self, config_parameters: dict): + self._dbt_project_dir = config_parameters.get("project_dir") + self._profile_name = config_parameters.get("profile_name", "") + self._target_name = config_parameters.get("target_name", "") + self._dbt_profile_dir = config_parameters.get("profile_dir", None) + self._manifest_data = self._load_file( + 
self._get_manifest_path(), file_type="json" + ) + profile_data = self._load_file(self._get_profile_path(), file_type="yaml") + self._target_config = ( + profile_data[self._profile_name]["outputs"].get(self._target_name) + if self._profile_name == "athena" + else profile_data[self._profile_name]["outputs"]["data"] + ) # if databricks, get the default catalog and schema from the data output + self._default_schema = self._target_config.get("schema", "") + self._nodes_in_manifest = self._manifest_data.get("nodes", {}) + self._sources_in_manifest = self._manifest_data.get("sources", {}) + + @property + def nodes_in_manifest(self): + return self._nodes_in_manifest + + @property + def sources_in_manifest(self): + return self._sources_in_manifest + + @property + def dbt_default_schema(self): + return self._default_schema + + def _get_manifest_path(self) -> str: + """ + Construct path for manifest.json file based on configuration parameters. + """ + target_path = f"{self._profile_name}_target" + return path.join(self._dbt_project_dir, target_path, "manifest.json") + + def _get_profile_path(self) -> str: + """ + Construct path for profiles.yml file based on configuration parameters. 
+ """ + return path.join(self._dbt_profile_dir, "profiles.yml") + + @staticmethod + def _load_file(file_path: str, file_type: str) -> dict: + """Load a file (JSON or YAML) based on the specified type and return its contents.""" + try: + with open(file_path, "r") as file: + if file_type == "json": + return json.load(file) + elif file_type == "yaml": + return yaml.safe_load(file) + except FileNotFoundError: + _logger.error(f"File not found: {file_path}") + exit(1) + + def _get_athena_table_task( + self, node: dict, follow_external_dependency: bool = False + ) -> dict: + """Generate an athena table task for a DBT node.""" + task = ATHENA_TASK_BASE.copy() + if follow_external_dependency: + task["follow_external_dependency"] = True + + task["schema"] = node.get("schema", self._default_schema) + task["table"] = node.get("name", "") + task["name"] = f"{task['schema']}__{task['table']}_athena" + + return task + + @abstractmethod + def _get_table_task( + self, node: dict, follow_external_dependency: bool = False + ) -> dict: + """Generate a table task for a DBT node for the specific dbt-adapter. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _get_model_data_location( + self, node: dict, schema: str, model_name: str + ) -> Tuple[str, str]: + """Get the S3 path of the DBT model relative to the data bucket. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _get_s3_task(self, node: dict, is_output: bool = False) -> dict: + """ + Generate an S3 task for a DBT node for the specific dbt-adapter. Must be implemented by subclasses. 
+ """ + pass + + @staticmethod + def _get_dummy_task(node: dict, follow_external_dependency: bool = False) -> dict: + """ + Generates a dummy dagger task + Args: + node: The extracted node from the manifest.json file + + Returns: + dict: The dummy dagger task for the DBT node + + """ + task = {} + task["name"] = node.get("name", "") + task["type"] = "dummy" + + if follow_external_dependency: + task["follow_external_dependency"] = True + + return task + + @abstractmethod + def _generate_dagger_output(self, node: dict): + """Generate the dagger output for a DBT node. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _is_node_preparation_model(self, node: dict): + """Define whether it is a preparation model. Must be implemented by subclasses.""" + pass + + def _generate_dagger_tasks(self, node_name: str) -> List[Dict]: + """ + Generates dagger tasks based on the type and materialization of the DBT model node. + + - If the node is a DBT source, an Athena table task is generated. + - If the node is an ephemeral model, a dummy task is generated, and tasks for its dependent nodes are recursively generated. + - If the node is a staging model (preparation model) and not materialized as a table, a table task is generated along with tasks for its dependent nodes. + - For other nodes, a table task is generated. If the node is materialized as a table, an additional S3 task is also generated. 
+ + Args: + node_name: The name of the DBT model node + + Returns: + List[Dict]: The respective dagger tasks for the DBT model node + """ + dagger_tasks = [] + + if node_name.startswith("source"): + node = self._sources_in_manifest[node_name] + else: + node = self._nodes_in_manifest[node_name] + + resource_type = node.get("resource_type") + materialized_type = node.get("config", {}).get("materialized") + + follow_external_dependency = True + if resource_type == "seed" or (self._is_node_preparation_model(node) and materialized_type != "table"): + follow_external_dependency = False + + if resource_type == "source": + table_task = self._get_athena_table_task( + node, follow_external_dependency=follow_external_dependency + ) + dagger_tasks.append(table_task) + + elif materialized_type == "ephemeral": + task = self._get_dummy_task(node) + dagger_tasks.append(task) + for node_name in node.get("depends_on", {}).get("nodes", []): + dagger_tasks += self._generate_dagger_tasks(node_name) + + else: + table_task = self._get_table_task(node, follow_external_dependency=follow_external_dependency) + dagger_tasks.append(table_task) + + if materialized_type in ("table", "incremental"): + dagger_tasks.append(self._get_s3_task(node)) + elif self._is_node_preparation_model(node): + for dependent_node_name in node.get("depends_on", {}).get("nodes", []): + dagger_tasks.extend( + self._generate_dagger_tasks(dependent_node_name) + ) + + return dagger_tasks + + def generate_dagger_io(self, model_name: str) -> Tuple[List[dict], List[dict]]: + """ + Parse through all the parents of the DBT model and return the dagger inputs and outputs for the DBT model + Args: + model_name: The name of the DBT model + + Returns: + Tuple[list, list]: The dagger inputs and outputs for the DBT model + + """ + inputs_list = [] + model_node = self._nodes_in_manifest[f"model.main.{model_name}"] + parent_node_names = model_node.get("depends_on", {}).get("nodes", []) + + for parent_node_name in parent_node_names: + 
dagger_input = self._generate_dagger_tasks(parent_node_name) + inputs_list += dagger_input + + output_list = self._generate_dagger_output(model_node) + + unique_inputs = list( + OrderedDict( + (frozenset(item.items()), item) for item in inputs_list + ).values() + ) + + return unique_inputs, output_list + + +class AthenaDBTConfigParser(DBTConfigParser): + """Implementation for Athena configurations.""" + + def __init__(self, default_config_parameters: dict): + super().__init__(default_config_parameters) + self._profile_name = "athena" + self._default_data_bucket = default_config_parameters.get("data_bucket") + self._default_data_dir = self._target_config.get( + "s3_data_dir" + ) or self._target_config.get("s3_staging_dir") + + def _is_node_preparation_model(self, node: dict): + """Define whether it is a preparation model.""" + return node.get("name").startswith("stg_") + + def _get_table_task( + self, node: dict, follow_external_dependency: bool = False + ) -> dict: + """ + Generates the dagger athena task for the DBT model node + """ + return self._get_athena_table_task(node, follow_external_dependency) + + def _get_model_data_location( + self, node: dict, schema: str, model_name: str + ) -> Tuple[str, str]: + """ + Gets the S3 path of the dbt model relative to the data bucket. 
+ """ + location = node.get("config", {}).get("external_location") + if not location: + location = join(self._default_data_dir, schema, model_name) + + split = location.split("//")[1].split("/") + bucket_name, data_path = split[0], "/".join(split[1:]) + + return bucket_name, data_path + + def _get_s3_task(self, node: dict, is_output: bool = False) -> dict: + """ + Generates the dagger s3 task for the athena-dbt model node + """ + task = S3_TASK_BASE.copy() + + schema = node.get("schema", self._default_schema) + table = node.get("name", "") + task["name"] = f"output_s3_path" if is_output else f"s3_{table}" + task["bucket"] = self._default_data_bucket + _, task["path"] = self._get_model_data_location(node, schema, table) + + return task + + def _generate_dagger_output(self, node: dict): + """ + Generates the dagger output for the DBT model node with athena-dbt adapter. If the model is materialized as a view or ephemeral, then a dummy task is created. + Otherwise, an athena and s3 task is created for the DBT model node. 
+ Args: + node: The extracted node from the manifest.json file + + Returns: + dict: The dagger output, which is a combination of an athena and s3 task for the DBT model node + + """ + materialized_type = node.get("config", {}).get("materialized") + if materialized_type == "ephemeral": + return [self._get_dummy_task(node)] + else: + output_tasks = [self._get_table_task(node)] + if materialized_type in ("table", "incremental"): + output_tasks.append(self._get_s3_task(node, is_output=True)) + return output_tasks + + +class DatabricksDBTConfigParser(DBTConfigParser): + """Implementation for Databricks configurations.""" + + def __init__(self, default_config_parameters: dict): + super().__init__(default_config_parameters) + self._profile_name = "databricks" + self._default_catalog = self._target_config.get("catalog") + self._create_external_athena_table = default_config_parameters.get( + "create_external_athena_table", False + ) + + def _is_node_preparation_model(self, node: dict): + """ + Define whether it is a preparation model. + """ + return "preparation" in node.get("schema", "") + + def _get_table_task( + self, node: dict, follow_external_dependency: bool = False + ) -> dict: + """ + Generates the dagger databricks task for the DBT model node + """ + task = DATABRICKS_TASK_BASE.copy() + if follow_external_dependency: + task["follow_external_dependency"] = True + + task["catalog"] = node.get("database", self._default_catalog) + task["schema"] = node.get("schema", self._default_schema) + task["table"] = node.get("name", "") + task[ + "name" + ] = f"{task['catalog']}__{task['schema']}__{task['table']}_databricks" + + return task + + def _get_model_data_location( + self, node: dict, schema: str, model_name: str + ) -> Tuple[str, str]: + """ + Gets the S3 path of the dbt model relative to the data bucket. 
+ """ + location_root = node.get("config", {}).get("location_root") + location = join(location_root, model_name) + split = location.split("//")[1].split("/") + bucket_name, data_path = split[0], "/".join(split[1:]) + + return bucket_name, data_path + + def _get_s3_task(self, node: dict, is_output: bool = False) -> dict: + """ + Generates the dagger s3 task for the databricks-dbt model node + """ + task = S3_TASK_BASE.copy() + + schema = node.get("schema", self._default_schema) + table = node.get("name", "") + task["name"] = f"output_s3_path" if is_output else f"s3_{table}" + task["bucket"], task["path"] = self._get_model_data_location( + node, schema, table + ) + + return task + + def _generate_dagger_output(self, node: dict): + """ + Generates the dagger output for the DBT model node with the databricks-dbt adapter. + If the model is materialized as a view or ephemeral, then a dummy task is created. + Otherwise, and databricks and s3 task is created for the DBT model node. + And if create_external_athena_table is True te an extra athena task is created. 
+ Args: + node: The extracted node from the manifest.json file + + Returns: + dict: The dagger output, which is a combination of an athena and s3 task for the DBT model node + + """ + materialized_type = node.get("config", {}).get("materialized") + if materialized_type == "ephemeral": + return [self._get_dummy_task(node)] + else: + output_tasks = [self._get_table_task(node)] + if materialized_type in ("table", "incremental"): + output_tasks.append(self._get_s3_task(node, is_output=True)) + if self._create_external_athena_table: + output_tasks.append(self._get_athena_table_task(node)) + return output_tasks diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 6f3b395..7f33690 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -1,14 +1,21 @@ +import importlib +import inspect import logging -from os import path -from mergedeep import merge +import os +import pkgutil +from os import path, environ +import jinja2 import yaml +from dagger import conf +from mergedeep import merge + _logger = logging.getLogger("root") class Module: - def __init__(self, path_to_config, target_dir): + def __init__(self, path_to_config, target_dir, jinja_parameters=None): self._directory = path.dirname(path_to_config) self._target_dir = target_dir or "./" self._path_to_config = path_to_config @@ -16,11 +23,14 @@ def __init__(self, path_to_config, target_dir): self._tasks = {} for task in config["tasks"]: - self._tasks[task] = self.read_task_config(f"{path.join(self._directory, task)}.yaml") + self._tasks[task] = self.read_task_config( + f"{path.join(self._directory, task)}.yaml" + ) self._branches_to_generate = config["branches_to_generate"] self._override_parameters = config.get("override_parameters", {}) self._default_parameters = config.get("default_parameters", {}) + self._jinja_parameters = jinja_parameters or {} @staticmethod def read_yaml(yaml_str): @@ -40,21 +50,40 @@ def read_task_config(self, task): exit(1) return content + @staticmethod 
+    def load_plugins_to_jinja_environment(environment: jinja2.Environment) -> jinja2.Environment:
+        """
+        Dynamically load all classes (plugins) from the folders defined in the conf.PLUGIN_DIRS variable.
+        The folder contains all plugins that are part of the project.
+        Returns:
+            jinja2.Environment: the given environment with every discovered plugin class registered in its globals under the class name
+        """
+        for plugin_path in conf.PLUGIN_DIRS:
+            for root, dirs, files in os.walk(plugin_path):
+                dirs[:] = [directory for directory in dirs if not directory.lower().startswith("test")]
+                for plugin_file in files:
+                    if plugin_file.endswith(".py") and not (plugin_file.startswith("__") or plugin_file.startswith("test")):
+                        module_name = plugin_file.replace(".py", "")
+                        module_path = os.path.join(root, plugin_file)
+                        spec = importlib.util.spec_from_file_location(module_name, module_path)
+                        module = importlib.util.module_from_spec(spec)
+                        spec.loader.exec_module(module)
+
+                        for name, obj in inspect.getmembers(module, inspect.isclass):
+                            environment.globals[f"{name}"] = obj
+
+        return environment
+
     @staticmethod
     def replace_template_parameters(_task_str, _template_parameters):
-        for _key, _value in _template_parameters.items():
-            if type(_value) == str:
-                try:
-                    int_value = int(_value)
-                    _value = f"\"{_value}\""
-                except:
-                    pass
-            locals()[_key] = _value
+        environment = jinja2.Environment()
+        environment = Module.load_plugins_to_jinja_environment(environment)
+        template = environment.from_string(_task_str)
+        rendered_task = template.render(_template_parameters)
 
         return (
-            _task_str.format(**locals())
-            .replace("{", "{{")
-            .replace("}", "}}")
+            rendered_task
+            # TODO Remove this hack and use Jinja escaping instead of special expression in template files
             .replace("__CBS__", "{")
             .replace("__CBE__", "}")
         )
 
@@ -73,6 +102,8 @@ def generate_task_configs(self):
             template_parameters = {}
             template_parameters.update(self._default_parameters or {})
             template_parameters.update(attrs)
+            template_parameters['branch_name'] = branch_name
+
template_parameters.update(self._jinja_parameters) for task, task_yaml in self._tasks.items(): task_name = f"{branch_name}_{task}" @@ -86,7 +117,9 @@ def generate_task_configs(self): override_parameters = self._override_parameters or {} merge(task_dict, override_parameters.get(branch_name, {}).get(task, {})) - self.dump_yaml(task_dict, f"{path.join(self._target_dir, task_name)}.yaml") + self.dump_yaml( + task_dict, f"{path.join(self._target_dir, task_name)}.yaml" + ) @staticmethod def module_config_template(): diff --git a/dockers/airflow/airflow.cfg b/dockers/airflow/airflow.cfg index a5ace87..0b19fbd 100644 --- a/dockers/airflow/airflow.cfg +++ b/dockers/airflow/airflow.cfg @@ -434,7 +434,7 @@ backend = # The backend_kwargs param is loaded into a dictionary and passed to __init__ of secrets backend class. # See documentation for the secrets backend you are using. JSON is expected. # Example for AWS Systems Manager ParameterStore: -# ``{{"connections_prefix": "/airflow/connections", "profile_name": "default"}}`` +# ``{{"connections_prefix": "/airflow/connections", "target_name": "default"}}`` backend_kwargs = [cli] diff --git a/reqs/base.txt b/reqs/base.txt index b6c8400..279ed4b 100644 --- a/reqs/base.txt +++ b/reqs/base.txt @@ -1,7 +1,7 @@ click==8.1.3 -croniter==1.3.8 +croniter==2.0.2 envyaml==1.10.211231 mergedeep==1.3.4 slack==0.0.2 slackclient==2.9.4 -tenacity==8.1.0 +tenacity~=8.3.0 diff --git a/reqs/dev.txt b/reqs/dev.txt index 806d6c5..238b2e2 100644 --- a/reqs/dev.txt +++ b/reqs/dev.txt @@ -1,19 +1,18 @@ -apache-airflow[amazon,postgres,s3,statsd]==2.3.4 +pip==24.0 +apache-airflow[amazon,postgres,s3,statsd]==2.9.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.9.0/constraints-3.9.txt" black==22.10.0 -boto3==1.26.16 +boto3==1.34.82 bumpversion==0.6.0 -coverage==6.5.0 +coverage==7.4.4 elasticsearch==7.17.7 -flake8==5.0.4 -neo4j==5.2.1 -numpydoc==1.5.0 -pip==22.3.1 +flake8==7.0.0 +neo4j==5.19.0 +numpydoc==1.7.0 
pre-commit==2.20.0 -sphinx-rtd-theme==1.1.1 -Sphinx==4.3.2 -SQLAlchemy==1.4.44 -tox==3.27.1 -twine==4.0.1 -watchdog==2.1.9 -Werkzeug==2.2.2 -wheel==0.38.4 +sphinx-rtd-theme==2.0.0 +Sphinx==7.2.6 +SQLAlchemy +tox==4.14.2 +twine==5.0.0 +watchdog==4.0.0 +Werkzeug diff --git a/reqs/test.txt b/reqs/test.txt index c568f77..7bdc89f 100644 --- a/reqs/test.txt +++ b/reqs/test.txt @@ -1,3 +1,4 @@ -apache-airflow[amazon,postgres,s3,statsd]==2.3.4 +apache-airflow[amazon,postgres,s3,statsd]==2.9.0 pytest-cov==4.0.0 pytest==7.2.0 +graphviz diff --git a/setup.py b/setup.py index 3f80fe3..f8b4b28 100644 --- a/setup.py +++ b/setup.py @@ -45,8 +45,7 @@ def reqs(*f): classifiers=[ "Development Status :: 2 - Pre-Alpha", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.9", ], description="Config Driven ETL", entry_points={"console_scripts": ["dagger=dagger.main:cli"]}, @@ -60,6 +59,6 @@ def reqs(*f): packages=find_packages(), tests_require=test_requires, url="https://gitlab.com/goflash1/data/dagger", - version="0.9.0", + version="0.9.1", zip_safe=False, ) diff --git a/tests/fixtures/config_finder/root/dags/test_external_sensor/dummy_first.yaml b/tests/fixtures/config_finder/root/dags/test_external_sensor/dummy_first.yaml index 279dd70..e97563f 100644 --- a/tests/fixtures/config_finder/root/dags/test_external_sensor/dummy_first.yaml +++ b/tests/fixtures/config_finder/root/dags/test_external_sensor/dummy_first.yaml @@ -5,7 +5,8 @@ inputs: # format: list | Use dagger init-io cli name: redshift_input schema: dwh table: batch_table - follow_external_dependency: True + follow_external_dependency: + poke_interval: 60 outputs: # format: list | Use dagger init-io cli - type: dummy name: first_dummy_output diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py b/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py new file mode 100644 index 0000000..64fffce --- 
/dev/null +++ b/tests/fixtures/modules/dbt_config_parser_fixtures_athena.py @@ -0,0 +1,456 @@ +DBT_PROFILE_FIXTURE = { + "athena": { + "outputs": { + "data": { + "aws_profile_name": "data", + "database": "awsdatacatalog", + "num_retries": 10, + "region_name": "eu-west-1", + "s3_data_dir": "s3://bucket1-data-lake/path1/tmp", + "s3_data_naming": "schema_table", + "s3_staging_dir": "s3://bucket1-data-lake/path1/", + "schema": "analytics_engineering", + "threads": 4, + "type": "athena", + "work_group": "primary", + }, + } + } +} + +DBT_MANIFEST_FILE_FIXTURE = { + "nodes": { + "model.main.model1": { + "database": "awsdatacatalog", + "schema": "analytics_engineering", + "unique_id": "model.main.model1", + "name": "model1", + "config": { + "external_location": "s3://bucket1-data-lake/path1/model1", + "materialized": "incremental", + "incremental_strategy": "insert_overwrite", + }, + "description": "Details of revenue calculation at supplier level for each observation day", + "tags": ["daily"], + "unrendered_config": { + "materialized": "incremental", + "external_location": "s3://bucket1-data-lake/path1/model1", + "incremental_strategy": "insert_overwrite", + "partitioned_by": ["year", "month", "day", "dt"], + "tags": ["daily"], + "on_schema_change": "fail", + }, + "depends_on": { + "macros": [ + "macro.main.macro1", + "macro.main.macro2", + ], + "nodes": [ + "model.main.stg_core_schema2__table2", + "model.main.model2", + "model.main.int_model3", + "seed.main.seed_buyer_country_overwrite", + ], + }, + }, + "model.main.stg_core_schema1__table1": { + "schema": "analytics_engineering", + "unique_id": "model.main.stg_core_schema1__table1", + "name": "stg_core_schema1__table1", + "config": { + "materialized": "table", + "external_location": "s3://bucket1-data-lake/path2/stg_core_schema1__table1", + }, + "depends_on": { + "macros": [], + "nodes": ["source.main.core_schema1.table1"], + }, + }, + "model.main.stg_core_schema2__table2": { + "schema": "analytics_engineering", + 
"name": "stg_core_schema2__table2", + "unique_id": "model.main.stg_core_schema2__table2", + "config": { + "materialized": "view", + }, + "depends_on": { + "macros": [], + "nodes": [ + "source.main.core_schema2.table2", + "source.main.core_schema2.table3", + "seed.main.seed_buyer_country_overwrite", + ], + }, + }, + "model.main.model2": { + "name": "model2", + "schema": "analytics_engineering", + "unique_id": "model.main.model2", + "config": { + "external_location": "s3://bucket1-data-lake/path2/model2", + "materialized": "table", + }, + "depends_on": {"macros": [], "nodes": []}, + }, + "model.main.int_model3": { + "name": "int_model3", + "unique_id": "model.main.int_model3", + "schema": "analytics_engineering", + "config": { + "materialized": "ephemeral", + }, + "depends_on": { + "macros": [], + "nodes": ["model.main.int_model2"], + }, + }, + "seed.main.seed_buyer_country_overwrite": { + "database": "awsdatacatalog", + "schema": "analytics_engineering", + "unique_id": "seed.main.seed_buyer_country_overwrite", + "name": "seed_buyer_country_overwrite", + "resource_type": "seed", + "alias": "seed_buyer_country_overwrite", + "tags": ["analytics"], + "description": "", + "created_at": 1700216177.105391, + "depends_on": {"macros": []}, + }, + "model.main.model3": { + "name": "model3", + "schema": "analytics_engineering", + "unique_id": "model.main.model3", + "config": { + "external_location": "s3://bucket1-data-lake/path2/model3", + }, + "depends_on": { + "macros": [], + "nodes": [ + "model.main.int_model3", + "model.main.model2", + "seed.main.seed_buyer_country_overwrite", + "model.main.stg_core_schema2__table2", + ], + }, + }, + "model.main.int_model2": { + "name": "int_model2", + "unique_id": "model.main.int_model2", + "schema": "analytics_engineering", + "config": { + "materialized": "ephemeral", + }, + "depends_on": { + "macros": [], + "nodes": [ + "seed.main.seed_buyer_country_overwrite", + "model.main.stg_core_schema1__table1", + ], + }, + }, + }, + "sources": { + 
"source.main.core_schema1.table1": { + "source_name": "table1", + "database": "awsdatacatalog", + "schema": "core_schema1", + "resource_type": "source", + "unique_id": "source.main.core_schema1.table1", + "name": "table1", + "tags": ["analytics"], + "description": "", + }, + "source.main.core_schema2.table2": { + "source_name": "table2", + "database": "awsdatacatalog", + "schema": "core_schema2", + "resource_type": "source", + "unique_id": "source.main.core_schema2.table2", + "name": "table2", + "tags": ["analytics"], + "description": "", + }, + "source.main.core_schema2.table3": { + "source_name": "table3", + "database": "awsdatacatalog", + "schema": "core_schema2", + "resource_type": "source", + "unique_id": "source.main.core_schema2.table3", + "name": "table3", + "tags": ["analytics"], + "description": "", + }, + }, +} + +EXPECTED_STAGING_NODE = [ + { + "name": "analytics_engineering__stg_core_schema1__table1_athena", + "type": "athena", + "table": "stg_core_schema1__table1", + "schema": "analytics_engineering", + "follow_external_dependency": True, + }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "bucket1-data-lake", + "path": "path2/stg_core_schema1__table1", + }, +] + +EXPECTED_SEED_NODE = [ + { + "type": "athena", + "schema": "analytics_engineering", + "table": "seed_buyer_country_overwrite", + "name": "analytics_engineering__seed_buyer_country_overwrite_athena", + } +] + +EXPECTED_MODEL_MULTIPLE_DEPENDENCIES = [ + { + "type": "dummy", + "name": "int_model3", + }, + { + "type": "dummy", + "name": "int_model2", + }, + { + "type": "athena", + "schema": "analytics_engineering", + "table": "seed_buyer_country_overwrite", + "name": "analytics_engineering__seed_buyer_country_overwrite_athena", + }, + { + "name": "analytics_engineering__stg_core_schema1__table1_athena", + "type": "athena", + "table": "stg_core_schema1__table1", + "schema": "analytics_engineering", + "follow_external_dependency": True, + }, + { + "type": "s3", + "name": 
"s3_stg_core_schema1__table1", + "bucket": "bucket1-data-lake", + "path": "path2/stg_core_schema1__table1", + }, + { + "type": "athena", + "name": "analytics_engineering__model2_athena", + "schema": "analytics_engineering", + "table": "model2", + "follow_external_dependency": True, + }, + { + "bucket": "bucket1-data-lake", + "name": "s3_model2", + "path": "path2/model2", + "type": "s3", + }, + { + "type": "athena", + "schema": "analytics_engineering", + "table": "stg_core_schema2__table2", + "name": "analytics_engineering__stg_core_schema2__table2_athena", + }, + { + "type": "athena", + "name": "core_schema2__table2_athena", + "schema": "core_schema2", + "table": "table2", + "follow_external_dependency": True, + }, + { + "type": "athena", + "name": "core_schema2__table3_athena", + "schema": "core_schema2", + "table": "table3", + "follow_external_dependency": True, + }, +] + +EXPECTED_EPHEMERAL_NODE = [ + { + "type": "dummy", + "name": "int_model3", + }, + { + "type": "dummy", + "name": "int_model2", + }, + { + "type": "athena", + "schema": "analytics_engineering", + "table": "seed_buyer_country_overwrite", + "name": "analytics_engineering__seed_buyer_country_overwrite_athena", + }, + { + "name": "analytics_engineering__stg_core_schema1__table1_athena", + "type": "athena", + "table": "stg_core_schema1__table1", + "schema": "analytics_engineering", + "follow_external_dependency": True, + }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "bucket1-data-lake", + "path": "path2/stg_core_schema1__table1", + }, +] + +EXPECTED_MODEL_NODE = [ + { + "type": "athena", + "name": "analytics_engineering__model1_athena", + "schema": "analytics_engineering", + "table": "model1", + "follow_external_dependency": True, + }, + { + "bucket": "bucket1-data-lake", + "name": "s3_model1", + "path": "path1/model1", + "type": "s3", + }, +] + +EXPECTED_DAGGER_INPUTS = [ + { + "type": "athena", + "schema": "analytics_engineering", + "table": "stg_core_schema2__table2", 
+ "name": "analytics_engineering__stg_core_schema2__table2_athena", + }, + { + "type": "athena", + "name": "core_schema2__table2_athena", + "schema": "core_schema2", + "table": "table2", + "follow_external_dependency": True, + }, + { + "type": "athena", + "name": "core_schema2__table3_athena", + "schema": "core_schema2", + "table": "table3", + "follow_external_dependency": True, + }, + { + "type": "athena", + "schema": "analytics_engineering", + "table": "seed_buyer_country_overwrite", + "name": "analytics_engineering__seed_buyer_country_overwrite_athena", + }, + { + "name": "analytics_engineering__model2_athena", + "schema": "analytics_engineering", + "table": "model2", + "type": "athena", + "follow_external_dependency": True, + }, + { + "bucket": "bucket1-data-lake", + "name": "s3_model2", + "path": "path2/model2", + "type": "s3", + }, + { + "type": "dummy", + "name": "int_model3", + }, + { + "type": "dummy", + "name": "int_model2", + }, + { + "name": "analytics_engineering__stg_core_schema1__table1_athena", + "type": "athena", + "table": "stg_core_schema1__table1", + "schema": "analytics_engineering", + "follow_external_dependency": True, + }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "bucket1-data-lake", + "path": "path2/stg_core_schema1__table1", + }, +] + +EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS = [ + { + "follow_external_dependency": True, + "name": "core_schema2__table2_athena", + "schema": "core_schema2", + "table": "table2", + "type": "athena", + }, + { + "follow_external_dependency": True, + "name": "core_schema2__table3_athena", + "schema": "core_schema2", + "table": "table3", + "type": "athena", + }, + { + "type": "athena", + "schema": "analytics_engineering", + "table": "seed_buyer_country_overwrite", + "name": "analytics_engineering__seed_buyer_country_overwrite_athena", + } +] + +EXPECTED_DAGGER_OUTPUTS = [ + { + "name": "analytics_engineering__model1_athena", + "schema": "analytics_engineering", + "table": "model1", 
+ "type": "athena", + }, + { + "bucket": "bucket1-data-lake", + "name": "output_s3_path", + "path": "path1/model1", + "type": "s3", + }, +] + +EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS = [ + { + "name": "analytics_engineering__stg_core_schema1__table1_athena", + "type": "athena", + "table": "stg_core_schema1__table1", + "schema": "analytics_engineering", + }, + { + "type": "s3", + "name": "output_s3_path", + "bucket": "bucket1-data-lake", + "path": "path2/stg_core_schema1__table1", + }, +] + +EXPECTED_DBT_INT_MODEL_DAGGER_INPUTS = [ + { + "type": "athena", + "schema": "analytics_engineering", + "table": "seed_buyer_country_overwrite", + "name": "analytics_engineering__seed_buyer_country_overwrite_athena", + }, + { + "name": "analytics_engineering__stg_core_schema1__table1_athena", + "type": "athena", + "table": "stg_core_schema1__table1", + "schema": "analytics_engineering", + "follow_external_dependency": True, + }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "bucket1-data-lake", + "path": "path2/stg_core_schema1__table1", + }, +] diff --git a/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py new file mode 100644 index 0000000..ad6b912 --- /dev/null +++ b/tests/fixtures/modules/dbt_config_parser_fixtures_databricks.py @@ -0,0 +1,495 @@ +DATABRICKS_DBT_PROFILE_FIXTURE = { + "databricks": { + "outputs": { + "data": { + "type": "databricks", + "catalog": "hive_metastore", + "schema": "analytics_engineering", + "host": "xxx.databricks.com", + "http_path": "/sql/1.0/warehouses/xxx", + "token": "{{ env_var('SECRETDATABRICKS') }}", + }, + } + } +} + +DATABRICKS_DBT_MANIFEST_FILE_FIXTURE = { + "nodes": { + "model.main.model1": { + "database": "marts", + "schema": "analytics_engineering", + "name": "model1", + "unique_id": "model.main.model1", + "resource_type": "model", + "config": { + "location_root": 
"s3://chodata-data-lake/analytics_warehouse/data/marts/analytics_engineering", + "materialized": "incremental", + "incremental_strategy": "insert_overwrite", + }, + "description": "Details of revenue calculation at supplier level for each observation day", + "tags": ["daily"], + "unrendered_config": { + "materialized": "incremental", + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/marts/analytics_engineering", + "incremental_strategy": "insert_overwrite", + "partitioned_by": ["year", "month", "day", "dt"], + "tags": ["daily"], + "on_schema_change": "fail", + }, + "depends_on": { + "macros": [ + "macro.main.macro1", + "macro.main.macro2", + ], + "nodes": [ + "model.main.stg_core_schema2__table2", + "model.main.model2", + "model.main.int_model3", + "seed.main.seed_buyer_country_overwrite", + ], + }, + }, + "model.main.stg_core_schema1__table1": { + "database": "hive_metastore", + "schema": "data_preparation", + "name": "stg_core_schema1__table1", + "unique_id": "model.main.stg_core_schema1__table1", + "resource_type": "model", + "config": { + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/preparation", + "materialized": "table", + }, + "depends_on": { + "macros": [], + "nodes": ["source.main.core_schema1.table1"], + }, + }, + "model.main.stg_core_schema2__table2": { + "database": "hive_metastore", + "schema": "data_preparation", + "name": "stg_core_schema2__table2", + "unique_id": "model.main.stg_core_schema2__table2", + "resource_type": "model", + "config": { + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/preparation", + "materialized": "view", + }, + "depends_on": { + "macros": [], + "nodes": [ + "source.main.core_schema2.table2", + "source.main.core_schema2.table3", + "seed.main.seed_buyer_country_overwrite", + ], + }, + }, + "model.main.model2": { + "database": "marts", + "schema": "analytics_engineering", + "name": "model2", + "unique_id": "model.main.model2", + "resource_type": "model", + "config": { + 
"location_root": "s3://chodata-data-lake/analytics_warehouse/data/marts/analytics_engineering", + "materialized": "table", + }, + "depends_on": {"macros": [], "nodes": []}, + }, + "model.main.int_model3": { + "name": "int_model3", + "unique_id": "model.main.int_model3", + "database": "intermediate", + "schema": "analytics_engineering", + "resource_type": "model", + "config": { + "materialized": "ephemeral", + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/intermediate/analytics_engineering", + }, + "depends_on": { + "macros": [], + "nodes": ["model.main.int_model2"], + }, + }, + "seed.main.seed_buyer_country_overwrite": { + "database": "hive_metastore", + "schema": "data_preparation", + "name": "seed_buyer_country_overwrite", + "unique_id": "seed.main.seed_buyer_country_overwrite", + "resource_type": "seed", + "alias": "seed_buyer_country_overwrite", + "tags": ["analytics"], + "description": "", + "created_at": 1700216177.105391, + "depends_on": {"macros": []}, + }, + "model.main.model3": { + "name": "model3", + "database": "marts", + "schema": "analytics_engineering", + "unique_id": "model.main.model3", + "config": { + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/marts/analytics_engineering", + }, + "depends_on": { + "macros": [], + "nodes": [ + "model.main.int_model3", + "model.main.model2", + "seed.main.seed_buyer_country_overwrite", + "model.main.stg_core_schema2__table2", + ], + }, + }, + "model.main.int_model2": { + "name": "int_model2", + "unique_id": "model.main.int_model2", + "database": "intermediate", + "schema": "analytics_engineering", + "config": { + "materialized": "ephemeral", + "location_root": "s3://chodata-data-lake/analytics_warehouse/data/intermediate/analytics_engineering", + }, + "depends_on": { + "macros": [], + "nodes": [ + "seed.main.seed_buyer_country_overwrite", + "model.main.stg_core_schema1__table1", + ], + }, + }, + }, + "sources": { + "source.main.core_schema1.table1": { + "source_name": 
"table1", + "database": "hive_metastore", + "schema": "core_schema1", + "resource_type": "source", + "unique_id": "source.main.core_schema1.table1", + "name": "table1", + "tags": ["analytics"], + "description": "", + }, + "source.main.core_schema2.table2": { + "source_name": "table2", + "database": "hive_metastore", + "schema": "core_schema2", + "resource_type": "source", + "unique_id": "source.main.core_schema2.table2", + "name": "table2", + "tags": ["analytics"], + "description": "", + }, + "source.main.core_schema2.table3": { + "source_name": "table3", + "database": "hive_metastore", + "schema": "core_schema2", + "resource_type": "source", + "unique_id": "source.main.core_schema2.table3", + "name": "table3", + "tags": ["analytics"], + "description": "", + }, + }, +} + +DATABRICKS_EXPECTED_STAGING_NODE = [ + { + "type": "databricks", + "follow_external_dependency": True, + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "stg_core_schema1__table1", + "name": "hive_metastore__data_preparation__stg_core_schema1__table1_databricks", + }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "chodata-data-lake", + "path": "analytics_warehouse/data/preparation/stg_core_schema1__table1", + }, +] + +DATABRICKS_EXPECTED_SEED_NODE = [ + { + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "seed_buyer_country_overwrite", + "name": "hive_metastore__data_preparation__seed_buyer_country_overwrite_databricks", + } +] + +DATABRICKS_EXPECTED_MODEL_MULTIPLE_DEPENDENCIES = [ + { + "type": "dummy", + "name": "int_model3", + }, + { + "type": "dummy", + "name": "int_model2", + }, + { + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "seed_buyer_country_overwrite", + "name": "hive_metastore__data_preparation__seed_buyer_country_overwrite_databricks", + }, + { + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + 
"table": "stg_core_schema1__table1", + "name": "hive_metastore__data_preparation__stg_core_schema1__table1_databricks", + "follow_external_dependency": True, + }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "chodata-data-lake", + "path": "analytics_warehouse/data/preparation/stg_core_schema1__table1", + }, + { + "type": "databricks", + "name": "marts__analytics_engineering__model2_databricks", + "catalog": "marts", + "schema": "analytics_engineering", + "table": "model2", + "follow_external_dependency": True, + }, + { + "bucket": "chodata-data-lake", + "name": "s3_model2", + "path": "analytics_warehouse/data/marts/analytics_engineering/model2", + "type": "s3", + }, + { + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "stg_core_schema2__table2", + "name": "hive_metastore__data_preparation__stg_core_schema2__table2_databricks", + }, + { + "type": "athena", + "name": "core_schema2__table2_athena", + "schema": "core_schema2", + "table": "table2", + "follow_external_dependency": True, + }, + { + "type": "athena", + "name": "core_schema2__table3_athena", + "schema": "core_schema2", + "table": "table3", + "follow_external_dependency": True, + }, +] + +DATABRICKS_EXPECTED_EPHEMERAL_NODE = [ + { + "type": "dummy", + "name": "int_model3", + }, + { + "type": "dummy", + "name": "int_model2", + }, + { + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "seed_buyer_country_overwrite", + "name": "hive_metastore__data_preparation__seed_buyer_country_overwrite_databricks", + }, + { + "type": "databricks", + "follow_external_dependency": True, + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "stg_core_schema1__table1", + "name": "hive_metastore__data_preparation__stg_core_schema1__table1_databricks", + }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "chodata-data-lake", + "path": 
"analytics_warehouse/data/preparation/stg_core_schema1__table1", + }, +] + +DATABRICKS_EXPECTED_MODEL_NODE = [ + { + "type": "databricks", + "name": "marts__analytics_engineering__model1_databricks", + "catalog": "marts", + "schema": "analytics_engineering", + "table": "model1", + "follow_external_dependency": True, + }, + { + "bucket": "chodata-data-lake", + "name": "s3_model1", + "path": "analytics_warehouse/data/marts/analytics_engineering/model1", + "type": "s3", + }, +] + +DATABRICKS_EXPECTED_DAGGER_INPUTS = [ + { + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "stg_core_schema2__table2", + "name": "hive_metastore__data_preparation__stg_core_schema2__table2_databricks", + }, + { + "type": "athena", + "name": "core_schema2__table2_athena", + "schema": "core_schema2", + "table": "table2", + "follow_external_dependency": True, + }, + { + "type": "athena", + "name": "core_schema2__table3_athena", + "schema": "core_schema2", + "table": "table3", + "follow_external_dependency": True, + }, + { + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "seed_buyer_country_overwrite", + "name": "hive_metastore__data_preparation__seed_buyer_country_overwrite_databricks", + }, + { + "name": "marts__analytics_engineering__model2_databricks", + "catalog": "marts", + "schema": "analytics_engineering", + "table": "model2", + "type": "databricks", + "follow_external_dependency": True, + }, + { + "bucket": "chodata-data-lake", + "name": "s3_model2", + "path": "analytics_warehouse/data/marts/analytics_engineering/model2", + "type": "s3", + }, + { + "type": "dummy", + "name": "int_model3", + }, + { + "type": "dummy", + "name": "int_model2", + }, + { + "type": "databricks", + "follow_external_dependency": True, + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "stg_core_schema1__table1", + "name": "hive_metastore__data_preparation__stg_core_schema1__table1_databricks", 
+ }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "chodata-data-lake", + "path": "analytics_warehouse/data/preparation/stg_core_schema1__table1", + }, +] + +DATABRICKS_EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS = [ + { + "follow_external_dependency": True, + "name": "core_schema2__table2_athena", + "schema": "core_schema2", + "table": "table2", + "type": "athena", + }, + { + "follow_external_dependency": True, + "name": "core_schema2__table3_athena", + "schema": "core_schema2", + "table": "table3", + "type": "athena", + }, + { + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "seed_buyer_country_overwrite", + "name": "hive_metastore__data_preparation__seed_buyer_country_overwrite_databricks", + }, +] + +DATABRICKS_EXPECTED_DBT_INT_MODEL_DAGGER_INPUTS = [ + { + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "seed_buyer_country_overwrite", + "name": "hive_metastore__data_preparation__seed_buyer_country_overwrite_databricks", + }, + { + "type": "databricks", + "follow_external_dependency": True, + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "stg_core_schema1__table1", + "name": "hive_metastore__data_preparation__stg_core_schema1__table1_databricks", + }, + { + "type": "s3", + "name": "s3_stg_core_schema1__table1", + "bucket": "chodata-data-lake", + "path": "analytics_warehouse/data/preparation/stg_core_schema1__table1", + }, +] + +DATABRICKS_EXPECTED_DAGGER_OUTPUTS = [ + { + "name": "marts__analytics_engineering__model1_databricks", + "catalog": "marts", + "schema": "analytics_engineering", + "table": "model1", + "type": "databricks", + }, + { + "bucket": "chodata-data-lake", + "name": "output_s3_path", + "path": "analytics_warehouse/data/marts/analytics_engineering/model1", + "type": "s3", + }, + { + "name": "analytics_engineering__model1_athena", + "schema": "analytics_engineering", + "table": "model1", + "type": 
"athena", + }, +] + +DATABRICKS_EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS = [ + { + "type": "databricks", + "catalog": "hive_metastore", + "schema": "data_preparation", + "table": "stg_core_schema1__table1", + "name": "hive_metastore__data_preparation__stg_core_schema1__table1_databricks", + }, + { + "type": "s3", + "name": "output_s3_path", + "bucket": "chodata-data-lake", + "path": "analytics_warehouse/data/preparation/stg_core_schema1__table1", + }, + { + 'type': 'athena', + 'schema': 'data_preparation', + 'table': 'stg_core_schema1__table1', + 'name': 'data_preparation__stg_core_schema1__table1_athena' + } +] diff --git a/tests/fixtures/pipeline/ios/databricks_io.yaml b/tests/fixtures/pipeline/ios/databricks_io.yaml new file mode 100644 index 0000000..a8d5914 --- /dev/null +++ b/tests/fixtures/pipeline/ios/databricks_io.yaml @@ -0,0 +1,11 @@ +type: databricks +name: test +catalog: test_catalog +schema: test_schema +table: test_table + + + +# Other attributes: + +# has_dependency: # Weather this i/o should be added to the dependency graph or not. Default is True \ No newline at end of file diff --git a/tests/fixtures/pipeline/ios/dynamo_io.yaml b/tests/fixtures/pipeline/ios/dynamo_io.yaml new file mode 100644 index 0000000..d083171 --- /dev/null +++ b/tests/fixtures/pipeline/ios/dynamo_io.yaml @@ -0,0 +1,11 @@ +type: dynamo +name: dynamo_table +table: schema.table_name # The name of the dynamo table +region_name: eu_west_1 + + +# Other attributes: + +# has_dependency: # Weather this i/o should be added to the dependency graph or not. 
Default is True +# follow_external_dependency: # format: dictionary or boolean | External Task Sensor parameters in key value format: https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/sensors/base/index.html +# region_name: # Only needed for cross region dynamo tables \ No newline at end of file diff --git a/tests/fixtures/pipeline/ios/sns_io.yaml b/tests/fixtures/pipeline/ios/sns_io.yaml new file mode 100644 index 0000000..542fd8e --- /dev/null +++ b/tests/fixtures/pipeline/ios/sns_io.yaml @@ -0,0 +1,11 @@ +type: sns +name: topic_name +sns_topic: topic_name # The name of the SNS topic +region_name: eu_west_1 + + +# Other attributes: + +# has_dependency: # Whether this i/o should be added to the dependency graph or not. Default is True +# follow_external_dependency: # format: dictionary or boolean | External Task Sensor parameters in key value format: https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/sensors/base/index.html +# region_name: # Only needed for cross region SNS topics \ No newline at end of file diff --git a/tests/fixtures/plugins/sample_folder/sample_folder_plugin.py b/tests/fixtures/plugins/sample_folder/sample_folder_plugin.py new file mode 100644 index 0000000..c8b931e --- /dev/null +++ b/tests/fixtures/plugins/sample_folder/sample_folder_plugin.py @@ -0,0 +1,5 @@ +class SampleFolderPlugin: + @staticmethod + def get_inputs(): + return [{"name": "sample_folder_plugin_task", "type": "dummy"}] + diff --git a/tests/fixtures/plugins/sample_plugin.py b/tests/fixtures/plugins/sample_plugin.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/pipeline/ios/test_databricks_io.py b/tests/pipeline/ios/test_databricks_io.py new file mode 100644 index 0000000..b1d0c45 --- /dev/null +++ b/tests/pipeline/ios/test_databricks_io.py @@ -0,0 +1,17 @@ +import unittest +from dagger.pipeline.io_factory import databricks_io + +import yaml + + +class DbIOTest(unittest.TestCase): + def setUp(self) -> None: + with
open('tests/fixtures/pipeline/ios/databricks_io.yaml', "r") as stream: + config = yaml.safe_load(stream) + + self.db_io = databricks_io.DatabricksIO(config, "/") + + def test_properties(self): + self.assertEqual(self.db_io.alias(), "databricks://test_catalog/test_schema/test_table") + self.assertEqual(self.db_io.rendered_name, "test_catalog.test_schema.test_table") + self.assertEqual(self.db_io.airflow_name, "databricks-test_catalog-test_schema-test_table") diff --git a/tests/pipeline/ios/test_dynamo_io.py b/tests/pipeline/ios/test_dynamo_io.py new file mode 100644 index 0000000..0633c8b --- /dev/null +++ b/tests/pipeline/ios/test_dynamo_io.py @@ -0,0 +1,19 @@ +import unittest +from dagger.pipeline.ios.dynamo_io import DynamoIO +import yaml + + +class DynamoIOTest(unittest.TestCase): + def setUp(self) -> None: + with open("tests/fixtures/pipeline/ios/dynamo_io.yaml", "r") as stream: + config = yaml.safe_load(stream) + + self.dynamo_io = DynamoIO(config, "/") + + def test_properties(self): + self.assertEqual(self.dynamo_io.alias(), "dynamo://eu_west_1/schema.table_name") + self.assertEqual(self.dynamo_io.rendered_name, "schema.table_name") + self.assertEqual(self.dynamo_io.airflow_name, "dynamo-eu_west_1-schema.table_name") + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pipeline/ios/test_sns_io.py b/tests/pipeline/ios/test_sns_io.py new file mode 100644 index 0000000..2a5de25 --- /dev/null +++ b/tests/pipeline/ios/test_sns_io.py @@ -0,0 +1,19 @@ +import unittest +from dagger.pipeline.ios.sns_io import SnsIO +import yaml + + +class SnsIOTest(unittest.TestCase): + def setUp(self) -> None: + with open("tests/fixtures/pipeline/ios/sns_io.yaml", "r") as stream: + config = yaml.safe_load(stream) + + self.sns_io = SnsIO(config, "/") + + def test_properties(self): + self.assertEqual(self.sns_io.alias(), "sns://eu_west_1/topic_name") + self.assertEqual(self.sns_io.rendered_name, "topic_name") + self.assertEqual(self.sns_io.airflow_name,
"sns-eu_west_1-topic_name") + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utilities/test_dbt_config_parser.py b/tests/utilities/test_dbt_config_parser.py new file mode 100644 index 0000000..d4d9028 --- /dev/null +++ b/tests/utilities/test_dbt_config_parser.py @@ -0,0 +1,171 @@ +import logging +import unittest +from unittest import skip +from unittest.mock import patch, MagicMock + +from dagger.utilities.dbt_config_parser import ( + AthenaDBTConfigParser, + DatabricksDBTConfigParser, +) +from dagger.utilities.module import Module +from tests.fixtures.modules.dbt_config_parser_fixtures_athena import * +from tests.fixtures.modules.dbt_config_parser_fixtures_databricks import * + +_logger = logging.getLogger("root") + +DEFAULT_CONFIG_PARAMS = { + "data_bucket": "bucket1-data-lake", + "project_dir": "main", + "profile_dir": ".dbt", + "profile_name": "athena", + "target_name": "data", +} +DATABRICKS_DEFAULT_CONFIG_PARAMS = { + "project_dir": "main", + "profile_dir": ".dbt", + "profile_name": "databricks", + "target_name": "data", + "create_external_athena_table": True, +} +MODEL_NAME = "model1" + + +class TestAthenaDBTConfigParser(unittest.TestCase): + @patch("builtins.open", new_callable=MagicMock, read_data=DBT_MANIFEST_FILE_FIXTURE) + @patch("json.loads", return_value=DBT_MANIFEST_FILE_FIXTURE) + @patch("yaml.safe_load", return_value=DBT_PROFILE_FIXTURE) + def setUp(self, mock_open, mock_json_load, mock_safe_load): + self._dbt_config_parser = AthenaDBTConfigParser(DEFAULT_CONFIG_PARAMS) + self._sample_dbt_node = DBT_MANIFEST_FILE_FIXTURE["nodes"]["model.main.model1"] + + @skip("Run only locally") + def test_generate_task_configs(self): + module = Module( + path_to_config="./tests/fixtures/modules/dbt_test_config.yaml", + target_dir="./tests/fixtures/modules/", + ) + + module.generate_task_configs() + + def test_generate_dagger_tasks(self): + test_inputs = [ + ( + "seed.main.seed_buyer_country_overwrite", + EXPECTED_SEED_NODE, + ), + ( + 
"model.main.int_model3", + EXPECTED_EPHEMERAL_NODE, + ), + ( + "model.main.model1", + EXPECTED_MODEL_NODE, + ), + ] + for mock_input, expected_output in test_inputs: + result = self._dbt_config_parser._generate_dagger_tasks(mock_input) + self.assertListEqual(result, expected_output) + + def test_generate_io_inputs(self): + fixtures = [ + ("model1", EXPECTED_DAGGER_INPUTS), + ( + "model3", + EXPECTED_MODEL_MULTIPLE_DEPENDENCIES, + ), + ("stg_core_schema2__table2", EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS), + ("int_model2", EXPECTED_DBT_INT_MODEL_DAGGER_INPUTS), + ] + for mock_input, expected_output in fixtures: + result, _ = self._dbt_config_parser.generate_dagger_io(mock_input) + print(f"result: {result}") + print(f"expected_output: {expected_output}") + self.assertListEqual(result, expected_output) + + def test_generate_io_outputs(self): + fixtures = [ + ("model1", EXPECTED_DAGGER_OUTPUTS), + ("stg_core_schema1__table1", EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS), + ] + for mock_input, expected_output in fixtures: + _, result = self._dbt_config_parser.generate_dagger_io(mock_input) + + self.assertListEqual(result, expected_output) + + +class TestDatabricksDBTConfigParser(unittest.TestCase): + @patch( + "builtins.open", + new_callable=MagicMock, + read_data=DATABRICKS_DBT_MANIFEST_FILE_FIXTURE, + ) + @patch("json.loads", return_value=DATABRICKS_DBT_MANIFEST_FILE_FIXTURE) + @patch("yaml.safe_load", return_value=DATABRICKS_DBT_PROFILE_FIXTURE) + def setUp(self, mock_open, mock_json_load, mock_safe_load): + self._dbt_config_parser = DatabricksDBTConfigParser( + DATABRICKS_DEFAULT_CONFIG_PARAMS + ) + self._sample_dbt_node = DATABRICKS_DBT_MANIFEST_FILE_FIXTURE["nodes"][ + "model.main.model1" + ] + + @skip("Run only locally") + def test_generate_task_configs(self): + module = Module( + path_to_config="./tests/fixtures/modules/dbt_test_config.yaml", + target_dir="./tests/fixtures/modules/", + ) + + module.generate_task_configs() + + def test_generate_dagger_tasks(self): 
+ test_inputs = [ + ( + "model.main.stg_core_schema1__table1", + DATABRICKS_EXPECTED_STAGING_NODE, + ), + ( + "seed.main.seed_buyer_country_overwrite", + DATABRICKS_EXPECTED_SEED_NODE, + ), + ( + "model.main.int_model3", + DATABRICKS_EXPECTED_EPHEMERAL_NODE, + ), + ( + "model.main.model1", + DATABRICKS_EXPECTED_MODEL_NODE, + ), + ] + for mock_input, expected_output in test_inputs: + result = self._dbt_config_parser._generate_dagger_tasks(mock_input) + self.assertListEqual(result, expected_output) + + def test_generate_io_inputs(self): + fixtures = [ + ("model1", DATABRICKS_EXPECTED_DAGGER_INPUTS), + ( + "model3", + DATABRICKS_EXPECTED_MODEL_MULTIPLE_DEPENDENCIES, + ), + ( + "stg_core_schema2__table2", + DATABRICKS_EXPECTED_DBT_STAGING_MODEL_DAGGER_INPUTS, + ), + ("int_model2", DATABRICKS_EXPECTED_DBT_INT_MODEL_DAGGER_INPUTS), + ] + for mock_input, expected_output in fixtures: + result, _ = self._dbt_config_parser.generate_dagger_io(mock_input) + self.assertListEqual(result, expected_output) + + def test_generate_io_outputs(self): + fixtures = [ + ("model1", DATABRICKS_EXPECTED_DAGGER_OUTPUTS), + ( + "stg_core_schema1__table1", + DATABRICKS_EXPECTED_DBT_STAGING_MODEL_DAGGER_OUTPUTS, + ), + ] + for mock_input, expected_output in fixtures: + _, result = self._dbt_config_parser.generate_dagger_io(mock_input) + self.assertListEqual(result, expected_output) diff --git a/tests/utilities/test_plugins.py b/tests/utilities/test_plugins.py new file mode 100644 index 0000000..353a407 --- /dev/null +++ b/tests/utilities/test_plugins.py @@ -0,0 +1,43 @@ +import unittest +from pathlib import Path +from unittest.mock import patch + +import jinja2 + +from dagger.utilities.module import Module + +TESTS_ROOT = Path(__file__).parent.parent + +class TestLoadPlugins(unittest.TestCase): + + def setUp(self): + self._jinja_environment = jinja2.Environment() + self._template = self._jinja_environment.from_string("inputs: {{ SampleFolderPlugin.get_inputs() }}") + + 
@patch("dagger.conf.PLUGIN_DIRS", new=[]) + @patch("os.walk") + def test_load_plugins_no_plugin_dir(self, mock_os_walk): + # Simulate os.walk returning no Python files + mock_os_walk.return_value = [("/fake/plugin/dir", [], [])] + + result_environment = Module.load_plugins_to_jinja_environment(self._jinja_environment) + + self.assertNotIn("SampleFolderPlugin", result_environment.globals) + + @patch("dagger.conf.PLUGIN_DIRS", new=[str(TESTS_ROOT.joinpath("fixtures/plugins"))]) + def test_load_plugins(self): + result_environment = Module.load_plugins_to_jinja_environment(self._jinja_environment) + + self.assertIn("SampleFolderPlugin", result_environment.globals.keys()) + + @patch("dagger.conf.PLUGIN_DIRS", new=[str(TESTS_ROOT.joinpath("fixtures/plugins"))]) + def test_load_plugins_render_jinja(self): + result_environment = Module.load_plugins_to_jinja_environment(self._jinja_environment) + + rendered_task = self._template.render() + expected_task = "inputs: [{'name': 'sample_folder_plugin_task', 'type': 'dummy'}]" + + self.assertEqual(rendered_task, expected_task) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file