Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions metaflow/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@
# Add AWS client providers here
AWS_CLIENT_PROVIDERS_DESC = [("boto3", ".aws.aws_client.Boto3ClientProvider")]

# Add Airflow sensor related flow decorators.
# Each entry pairs a decorator name with the relative import path of the
# class implementing it.
SENSOR_FLOW_DECORATORS = [
    ("airflow_external_task_sensor", ".airflow.sensors.ExternalTaskSensorDecorator"),
    ("airflow_sql_sensor", ".airflow.sensors.SQLSensorDecorator"),
    ("airflow_s3_key_sensor", ".airflow.sensors.S3KeySensorDecorator"),
]

# Extend the flow-decorator descriptions with the sensor decorators; this runs
# before `process_plugins(globals())` is invoked below.
FLOW_DECORATORS_DESC += SENSOR_FLOW_DECORATORS

process_plugins(globals())


Expand Down
18 changes: 18 additions & 0 deletions metaflow/plugins/airflow/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from . import airflow_utils
from .exception import AirflowException
from .sensors import SUPPORTED_SENSORS
from .airflow_utils import (
TASK_ID_XCOM_KEY,
AirflowTask,
Expand Down Expand Up @@ -88,6 +89,7 @@ def __init__(
self.username = username
self.max_workers = max_workers
self.description = description
self._depends_on_upstream_sensors = False
self._file_path = file_path
_, self.graph_structure = self.graph.output_steps()
self.worker_pool = worker_pool
Expand Down Expand Up @@ -585,6 +587,17 @@ def _step_cli(self, node, paths, code_package_url, user_code_retries):
cmds.append(" ".join(entrypoint + top_level + step))
return cmds

def _collect_flow_sensors(self):
    """
    Build one Airflow task per sensor decorator attached to the flow.

    Every class in `SUPPORTED_SENSORS` is looked up by its `name` in
    `self.flow._flow_decorators`; each decorator found contributes the task
    returned by its `create_task()`. When at least one task is produced,
    `self._depends_on_upstream_sensors` is switched to `True`.

    Returns the list of created tasks (empty when the flow has no sensors).
    """
    flow_decorators = self.flow._flow_decorators
    sensor_tasks = []
    for sensor_cls in SUPPORTED_SENSORS:
        attached = flow_decorators.get(sensor_cls.name)
        if attached is None:
            # This sensor type is not used by the flow.
            continue
        for deco in attached:
            sensor_tasks.append(deco.create_task())
    if sensor_tasks:
        self._depends_on_upstream_sensors = True
    return sensor_tasks

def _contains_foreach(self):
for node in self.graph:
if node.type == "foreach":
Expand Down Expand Up @@ -639,6 +652,7 @@ def _visit(node, workflow, exit_node=None):
if self.workflow_timeout is not None and self.schedule is not None:
airflow_dag_args["dagrun_timeout"] = dict(seconds=self.workflow_timeout)

appending_sensors = self._collect_flow_sensors()
workflow = Workflow(
dag_id=self.name,
default_args=self._create_defaults(),
Expand All @@ -659,6 +673,10 @@ def _visit(node, workflow, exit_node=None):
workflow = _visit(self.graph["start"], workflow)

workflow.set_parameters(self.parameters)
if len(appending_sensors) > 0:
for s in appending_sensors:
workflow.add_state(s)
workflow.graph_structure.insert(0, [[s.name] for s in appending_sensors])
return self._to_airflow_dag_file(workflow.to_dict())

def _to_airflow_dag_file(self, json_dag):
Expand Down
10 changes: 2 additions & 8 deletions metaflow/plugins/airflow/airflow_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,6 @@ def make_flow(


def _validate_foreach_constraints(graph):
# Todo :Invoke this function when we integrate `foreach`s
def traverse_graph(node, state):
if node.type == "foreach" and node.is_inside_foreach:
raise NotSupportedException(
Expand All @@ -338,7 +337,7 @@ def traverse_graph(node, state):
if node.type == "linear" and node.is_inside_foreach:
state["foreach_stack"].append(node.name)

if len(state["foreach_stack"]) > 2:
if "foreach_stack" in state and len(state["foreach_stack"]) > 2:
raise NotSupportedException(
"The foreach step *%s* created by step *%s* needs to have an immediate join step. "
"Step *%s* is invalid since it is a linear step with a foreach. "
Expand Down Expand Up @@ -378,18 +377,13 @@ def _validate_workflow(flow, graph, flow_datastore, metadata, workflow_timeout):
"A default value is required for parameters when deploying flows on Airflow."
)
# check for other compute related decorators.
_validate_foreach_constraints(graph)
for node in graph:
if node.parallel_foreach:
raise AirflowException(
"Deploying flows with @parallel decorator(s) "
"to Airflow is not supported currently."
)

if node.type == "foreach":
raise NotSupportedException(
"Step *%s* is a foreach step and Foreach steps are not currently supported with Airflow."
% node.name
)
if any([d.name == "batch" for d in node.decorators]):
raise NotSupportedException(
"Step *%s* is marked for execution on AWS Batch with Airflow which isn't currently supported."
Expand Down
63 changes: 63 additions & 0 deletions metaflow/plugins/airflow/airflow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class IncompatibleKubernetesProviderVersionException(Exception):
) % (sys.executable, KUBERNETES_PROVIDER_FOREACH_VERSION)


class AirflowSensorNotFound(Exception):
    """Raised when the Airflow package that provides a requested sensor cannot be imported."""

    # Short title for this error.
    headline = "Sensor package not found"


def create_absolute_version_number(version):
abs_version = None
# For all digits
Expand Down Expand Up @@ -189,6 +193,16 @@ def pathspec(cls, flowname, is_foreach=False):
)


class SensorNames:
    """Operator-type names of the Airflow sensors supported by this module."""

    EXTERNAL_TASK_SENSOR = "ExternalTaskSensor"
    S3_SENSOR = "S3KeySensor"
    SQL_SENSOR = "SQLSensor"

    @classmethod
    def get_supported_sensors(cls):
        """
        Return the sensor names declared on this class, in declaration order.

        Only public, string-valued class attributes are returned. Dumping
        every value of the class `__dict__` would also include `__module__`,
        `__doc__`, this classmethod object and other internals, which are not
        sensor names and must not take part in membership checks.
        """
        return [
            value
            for attr, value in vars(cls).items()
            if not attr.startswith("_") and isinstance(value, str)
        ]


def run_id_creator(val):
# join `[dag-id,run-id]` of airflow dag.
return hashlib.md5("-".join([str(x) for x in val]).encode("utf-8")).hexdigest()[
Expand Down Expand Up @@ -375,6 +389,46 @@ def _kubernetes_pod_operator_args(operator_args):
return args


def _parse_sensor_args(name, kwargs):
    """
    Convert serialized sensor arguments back into the objects the sensor expects.

    Only `ExternalTaskSensor` needs any work: an `execution_delta` given as a
    plain dict is expanded into a `timedelta`, while an `execution_delta` of
    any other type is removed. `kwargs` is modified in place and returned.
    """
    if name != SensorNames.EXTERNAL_TASK_SENSOR:
        return kwargs
    if "execution_delta" not in kwargs:
        return kwargs

    delta_spec = kwargs["execution_delta"]
    if type(delta_spec) == dict:
        kwargs["execution_delta"] = timedelta(**delta_spec)
    else:
        del kwargs["execution_delta"]
    return kwargs


def _get_sensor(name):
    """
    Return the Airflow sensor operator class registered under `name`.

    The Airflow modules are imported inside each branch, so only the module
    backing the requested sensor has to be importable. `None` is returned
    when `name` matches none of the `SensorNames` values.

    Raises `AirflowSensorNotFound` when `S3KeySensor` is requested but it
    cannot be imported from the Airflow Amazon provider package.
    """
    if name == SensorNames.EXTERNAL_TASK_SENSOR:
        # ExternalTaskSensor uses the execution_date of a DAG run to locate
        # the upstream run it waits for, and that date is the exact date the
        # current DAG gets executed on. For example, if "DagA" (upstream) was
        # scheduled at 12 Jan 4:00 PM PDT, the sensor in "DagB" (current DAG)
        # looks for a "DagA" run executed at 12 Jan 4:00 PM PDT **exactly**.
        from airflow.sensors.external_task_sensor import ExternalTaskSensor

        return ExternalTaskSensor

    if name == SensorNames.S3_SENSOR:
        try:
            from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
        except ImportError:
            raise AirflowSensorNotFound(
                "This DAG requires a `S3KeySensor`. "
                "Install the Airflow AWS provider using : "
                "`pip install apache-airflow-providers-amazon`"
            )
        return S3KeySensor

    if name == SensorNames.SQL_SENSOR:
        from airflow.sensors.sql import SqlSensor

        return SqlSensor

    # Unrecognized sensor name.
    return None


def get_metaflow_kubernetes_operator():
try:
from airflow.contrib.operators.kubernetes_pod_operator import (
Expand Down Expand Up @@ -493,6 +547,13 @@ def set_operator_args(self, **kwargs):
self._operator_args = kwargs
return self

def _make_sensor(self):
    """
    Instantiate the Airflow sensor described by this task.

    The sensor class is resolved from `self._operator_type`; the stored
    operator arguments are passed through `_parse_sensor_args` and forwarded
    as keyword arguments together with `task_id=self.name`.
    """
    sensor_cls = _get_sensor(self._operator_type)
    sensor_kwargs = _parse_sensor_args(self._operator_type, self._operator_args)
    return sensor_cls(task_id=self.name, **sensor_kwargs)

def to_dict(self):
return {
"name": self.name,
Expand Down Expand Up @@ -541,6 +602,8 @@ def to_task(self):
return self._kubernetes_task()
else:
return self._kubernetes_mapper_task()
elif self._operator_type in SensorNames.get_supported_sensors():
return self._make_sensor()


class Workflow(object):
Expand Down
9 changes: 9 additions & 0 deletions metaflow/plugins/airflow/sensors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .external_task_sensor import ExternalTaskSensorDecorator
from .s3_sensor import S3KeySensorDecorator
from .sql_sensor import SQLSensorDecorator

# Sensor decorator classes exported by this package as the supported set.
SUPPORTED_SENSORS = [
    ExternalTaskSensorDecorator,
    S3KeySensorDecorator,
    SQLSensorDecorator,
]
74 changes: 74 additions & 0 deletions metaflow/plugins/airflow/sensors/base_sensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import uuid
from metaflow.decorators import FlowDecorator
from ..exception import AirflowException
from ..airflow_utils import AirflowTask, id_creator, TASK_ID_HASH_LEN


class AirflowSensorDecorator(FlowDecorator):
    """
    Common base for the flow-level decorators that attach an Airflow sensor.

    Subclasses set `operator_type` and may extend `defaults`, `validate` and
    `serialize_operator_args`.
    """

    # Several sensor decorators may be stacked on the same flow.
    allow_multiple = True

    defaults = {
        "timeout": 3600,
        "poke_interval": 60,
        "mode": "reschedule",
        "exponential_backoff": True,
        "pool": None,
        "soft_fail": False,
        "name": None,
        "description": None,
    }

    # Operator type handed to `AirflowTask`; filled in by subclasses.
    operator_type = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Unique id used by `validate` to find this decorator's position
        # among the sensor decorators.
        self._id = str(uuid.uuid4())
        # Resolved by `validate`.
        self._airflow_task_name = None

    def serialize_operator_args(self):
        """
        Turn the decorator attributes into arguments for the Airflow task.

        `name` is dropped, `description` is renamed to `doc` when it is set
        (and dropped otherwise), and `do_xcom_push` is forced to `True`.
        Subclasses extend this to serialize their own attributes.
        """
        operator_args = {**self.attributes}
        operator_args.pop("name")
        description = operator_args["description"]
        if description is not None:
            operator_args["doc"] = description
        operator_args.pop("description")
        operator_args["do_xcom_push"] = True
        return operator_args

    def create_task(self):
        """
        Return an `AirflowTask` for this sensor.

        The task is named `self._airflow_task_name` and receives every
        serialized operator argument whose value is not `None`.
        """
        present_args = {
            key: value
            for key, value in self.serialize_operator_args().items()
            if value is not None
        }
        task = AirflowTask(
            self._airflow_task_name,
            operator_type=self.operator_type,
        )
        return task.set_operator_args(**present_args)

    def validate(self):
        """
        Resolve the Airflow task name for this sensor.

        An explicit `name` attribute is used as is. Otherwise a name is
        generated from the operator type and this decorator's position among
        the sensor decorators, because the same sensor type can be attached
        more than once.
        """
        explicit_name = self.attributes["name"]
        if explicit_name is not None:
            self._airflow_task_name = explicit_name
            return

        sensor_ids = [
            deco._id
            for deco in self._flow_decorators
            if isinstance(deco, AirflowSensorDecorator)
        ]
        position = sensor_ids.index(self._id)
        name_hash = id_creator([self.operator_type, str(position)], TASK_ID_HASH_LEN)
        self._airflow_task_name = "%s-%s" % (self.operator_type, name_hash)

    def flow_init(
        self, flow, graph, environment, flow_datastore, metadata, logger, echo, options
    ):
        # Validation runs when the flow is initialized.
        self.validate()
94 changes: 94 additions & 0 deletions metaflow/plugins/airflow/sensors/external_task_sensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from .base_sensor import AirflowSensorDecorator
from ..airflow_utils import SensorNames
from ..exception import AirflowException
from datetime import timedelta


# State strings accepted when validating the `allowed_states` argument of the
# external task sensor decorator.
AIRFLOW_STATES = {
    "QUEUED": "queued",
    "RUNNING": "running",
    "SUCCESS": "success",
    "SHUTDOWN": "shutdown",  # External request to shut down
    "FAILED": "failed",
    "UP_FOR_RETRY": "up_for_retry",
    "UP_FOR_RESCHEDULE": "up_for_reschedule",
    "UPSTREAM_FAILED": "upstream_failed",
    "SKIPPED": "skipped",
}


class ExternalTaskSensorDecorator(AirflowSensorDecorator):
    """
    Flow decorator that attaches an Airflow `ExternalTaskSensor` to the flow.

    Attributes checked by `validate`:
      - `external_dag_id`: required; must not be `None`.
      - `allowed_states`: a single state string or a list of state strings,
        each of which must be one of the `AIRFLOW_STATES` values. A single
        string is normalized to a one-element list; a value of any other
        type is replaced with `["success"]`.
      - `execution_delta`: `None` or a `datetime.timedelta`.
    """

    operator_type = SensorNames.EXTERNAL_TASK_SENSOR
    # Docs:
    # https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/sensors/external_task/index.html#airflow.sensors.external_task.ExternalTaskSensor
    name = "airflow_external_task_sensor"
    defaults = dict(
        **AirflowSensorDecorator.defaults,
        external_dag_id=None,
        external_task_ids=None,
        allowed_states=[AIRFLOW_STATES["SUCCESS"]],
        failed_states=None,
        execution_delta=None,
        check_existence=True,
        # We cannot add `execution_date_fn` as it requires a python callable.
        # Passing around a python callable is non-trivial since we are passing a
        # callable from metaflow-code to airflow python script. Since we cannot
        # transfer dependencies of the callable, we cannot guarantee that the
        # callable behaves exactly as the user expects.
    )

    def serialize_operator_args(self):
        """
        Serialize the decorator attributes for the Airflow task.

        Extends the base serialization by turning a `timedelta`
        `execution_delta` into a `dict(seconds=...)` so that it can be
        stored as plain data.
        """
        task_args = super().serialize_operator_args()
        execution_delta = task_args["execution_delta"]
        if execution_delta is not None:
            task_args["execution_delta"] = dict(
                seconds=execution_delta.total_seconds()
            )
        return task_args

    def _invalid_allowed_states_error(self, offending):
        # Build the exception reporting unsupported `allowed_states` values.
        return AirflowException(
            "`%s` is an invalid input of `%s` for `@%s`. Accepted values are %s"
            % (
                offending,
                "allowed_states",
                self.name,
                ", ".join(list(AIRFLOW_STATES.values())),
            )
        )

    def validate(self):
        """
        Check the decorator arguments, then run the base-class validation.

        Raises `AirflowException` when `external_dag_id` is missing, when
        `allowed_states` contains a value outside `AIRFLOW_STATES`, or when
        `execution_delta` is set but is not a `datetime.timedelta`.
        """
        if self.attributes["external_dag_id"] is None:
            raise AirflowException(
                "`%s` argument of `@%s`cannot be `None`."
                % ("external_dag_id", self.name)
            )

        valid_states = list(AIRFLOW_STATES.values())
        allowed_states = self.attributes["allowed_states"]
        if isinstance(allowed_states, str):
            if allowed_states not in valid_states:
                raise self._invalid_allowed_states_error(str(allowed_states))
            # The sensor expects an iterable of states. A bare string would be
            # iterated character by character, so wrap it in a list.
            self.attributes["allowed_states"] = [allowed_states]
        elif isinstance(allowed_states, list):
            unknown_states = [s for s in allowed_states if s not in valid_states]
            if unknown_states:
                raise self._invalid_allowed_states_error(
                    " OR ".join(["'%s'" % s for s in unknown_states])
                )
        else:
            # Any other type falls back to waiting for success.
            self.attributes["allowed_states"] = [AIRFLOW_STATES["SUCCESS"]]

        execution_delta = self.attributes["execution_delta"]
        if execution_delta is not None and not isinstance(execution_delta, timedelta):
            raise AirflowException(
                "`%s` is an invalid input type of `execution_delta` for `@%s`. Accepted type is `datetime.timedelta`"
                % (
                    str(type(execution_delta)),
                    self.name,
                )
            )
        super().validate()
Loading