Skip to content
Merged
51 changes: 38 additions & 13 deletions metaflow/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ class Decorator(object):

name = "NONAME"
defaults = {}
# `allow_multiple` allows setting many decorators of the same type to a step/flow.
allow_multiple = False

def __init__(self, attributes=None, statically_defined=False):
self.attributes = self.defaults.copy()
Expand Down Expand Up @@ -255,9 +257,6 @@ class MyDecorator(StepDecorator):
pass them around with every lifecycle call.
"""

# `allow_multiple` allows setting many decorators of the same type to a step.
allow_multiple = False

def step_init(
self, flow, graph, step_name, decorators, environment, flow_datastore, logger
):
Expand Down Expand Up @@ -402,13 +401,12 @@ def _base_flow_decorator(decofunc, *args, **kwargs):
cls = args[0]
if isinstance(cls, type) and issubclass(cls, FlowSpec):
# flow decorators add attributes in the class dictionary,
# _flow_decorators.
if decofunc.name in cls._flow_decorators:
# _flow_decorators. _flow_decorators is of type `{key:[decos]}`
if decofunc.name in cls._flow_decorators and not decofunc.allow_multiple:
raise DuplicateFlowDecoratorException(decofunc.name)
else:
cls._flow_decorators[decofunc.name] = decofunc(
attributes=kwargs, statically_defined=True
)
deco_instance = decofunc(attributes=kwargs, statically_defined=True)
cls._flow_decorators.setdefault(decofunc.name, []).append(deco_instance)
else:
raise BadFlowDecoratorException(decofunc.name)
return cls
Expand Down Expand Up @@ -503,11 +501,38 @@ def _attach_decorators_to_step(step, decospecs):
def _init_flow_decorators(
flow, graph, environment, flow_datastore, metadata, logger, echo, deco_options
):
for deco in flow._flow_decorators.values():
opts = {option: deco_options[option] for option in deco.options}
deco.flow_init(
flow, graph, environment, flow_datastore, metadata, logger, echo, opts
)
# Since all flow decorators are stored as `{key:[deco]}` we iterate through each of them.
for decorators in flow._flow_decorators.values():
# First resolve the `options` for the flow decorator.
# Options are passed from cli.
# For example `@project` can take a `--name` / `--branch` from the cli as options.
deco_flow_init_options = {}
deco = decorators[0]
# If a flow decorator allow multiple of same type then we don't allow multiple options for it.
if deco.allow_multiple:
if len(deco.options) > 0:
raise MetaflowException(
"Flow decorator `@%s` has multiple options, which is not allowed. "
"Please ensure the FlowDecorator `%s` has no options since flow decorators with "
"`allow_mutiple=True` are not allowed to have options"
% (deco.name, deco.__class__.__name__)
)
else:
# Each "non-multiple" flow decorator is only allowed to have one set of options
deco_flow_init_options = {
option: deco_options[option] for option in deco.options
}
for deco in decorators:
deco.flow_init(
flow,
graph,
environment,
flow_datastore,
metadata,
logger,
echo,
deco_flow_init_options,
)


def _init_step_decorators(flow, graph, environment, flow_datastore, logger):
Expand Down
4 changes: 4 additions & 0 deletions metaflow/metaflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,10 @@
)
# This configuration sets `kubernetes_conn_id` in airflow's KubernetesPodOperator.
AIRFLOW_KUBERNETES_CONN_ID = from_conf("AIRFLOW_KUBERNETES_CONN_ID")
# Kubeconfig file and context handed to Airflow's KubernetesPodOperator
# (`config_file` / `cluster_context`) when no Kubernetes connection id is
# configured; when set, in-cluster configuration is disabled.
AIRFLOW_KUBERNETES_KUBECONFIG_FILE = from_conf("AIRFLOW_KUBERNETES_KUBECONFIG_FILE")
AIRFLOW_KUBERNETES_KUBECONFIG_CONTEXT = from_conf(
    "AIRFLOW_KUBERNETES_KUBECONFIG_CONTEXT"
)


###
Expand Down
8 changes: 8 additions & 0 deletions metaflow/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@
# Add AWS client providers here
AWS_CLIENT_PROVIDERS_DESC = [("boto3", ".aws.aws_client.Boto3ClientProvider")]

# Add Airflow sensor related flow decorators
SENSOR_FLOW_DECORATORS = [
    ("airflow_external_task_sensor", ".airflow.sensors.ExternalTaskSensorDecorator"),
    ("airflow_s3_key_sensor", ".airflow.sensors.S3KeySensorDecorator"),
]

# Sensors are ordinary flow decorators; appending them to the common
# descriptor list presumably lets `process_plugins` register them alongside
# the other flow decorators — verify against `process_plugins`.
FLOW_DECORATORS_DESC += SENSOR_FLOW_DECORATORS

process_plugins(globals())


Expand Down
38 changes: 35 additions & 3 deletions metaflow/plugins/airflow/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
DATASTORE_SYSROOT_AZURE,
CARD_AZUREROOT,
AIRFLOW_KUBERNETES_CONN_ID,
AIRFLOW_KUBERNETES_KUBECONFIG_CONTEXT,
AIRFLOW_KUBERNETES_KUBECONFIG_FILE,
DATASTORE_SYSROOT_GS,
CARD_GSROOT,
)
from metaflow.parameters import DelayedEvaluationParameter, deploy_time_eval
from metaflow.plugins.kubernetes.kubernetes import Kubernetes
Expand All @@ -35,6 +39,7 @@

from . import airflow_utils
from .exception import AirflowException
from .sensors import SUPPORTED_SENSORS
from .airflow_utils import (
TASK_ID_XCOM_KEY,
AirflowTask,
Expand Down Expand Up @@ -88,6 +93,7 @@ def __init__(
self.username = username
self.max_workers = max_workers
self.description = description
self._depends_on_upstream_sensors = False
self._file_path = file_path
_, self.graph_structure = self.graph.output_steps()
self.worker_pool = worker_pool
Expand Down Expand Up @@ -140,6 +146,7 @@ def _get_schedule(self):
schedule = self.flow._flow_decorators.get("schedule")
if not schedule:
return None
schedule = schedule[0]
if schedule.attributes["cron"]:
return schedule.attributes["cron"]
elif schedule.attributes["weekly"]:
Expand Down Expand Up @@ -358,7 +365,7 @@ def _to_job(self, node):
"METAFLOW_SERVICE_HEADERS": json.dumps(SERVICE_HEADERS),
"METAFLOW_DATASTORE_SYSROOT_S3": DATASTORE_SYSROOT_S3,
"METAFLOW_DATATOOLS_S3ROOT": DATATOOLS_S3ROOT,
"METAFLOW_DEFAULT_DATASTORE": "s3",
"METAFLOW_DEFAULT_DATASTORE": self.flow_datastore.TYPE,
"METAFLOW_DEFAULT_METADATA": "service",
"METAFLOW_KUBERNETES_WORKLOAD": str(
1
Expand All @@ -373,6 +380,9 @@ def _to_job(self, node):
"METAFLOW_AIRFLOW_JOB_ID": AIRFLOW_MACROS.AIRFLOW_JOB_ID,
"METAFLOW_PRODUCTION_TOKEN": self.production_token,
"METAFLOW_ATTEMPT_NUMBER": AIRFLOW_MACROS.ATTEMPT,
# GCP stuff
"METAFLOW_DATASTORE_SYSROOT_GS": DATASTORE_SYSROOT_GS,
"METAFLOW_CARD_GSROOT": CARD_GSROOT,
}
env[
"METAFLOW_AZURE_STORAGE_BLOB_SERVICE_ENDPOINT"
Expand Down Expand Up @@ -460,10 +470,16 @@ def _to_job(self, node):
reattach_on_restart=False,
secrets=[],
)
k8s_operator_args["in_cluster"] = True
if AIRFLOW_KUBERNETES_CONN_ID is not None:
k8s_operator_args["kubernetes_conn_id"] = AIRFLOW_KUBERNETES_CONN_ID
else:
k8s_operator_args["in_cluster"] = True
k8s_operator_args["in_cluster"] = False
if AIRFLOW_KUBERNETES_KUBECONFIG_CONTEXT is not None:
k8s_operator_args["cluster_context"] = AIRFLOW_KUBERNETES_KUBECONFIG_CONTEXT
k8s_operator_args["in_cluster"] = False
if AIRFLOW_KUBERNETES_KUBECONFIG_FILE is not None:
k8s_operator_args["config_file"] = AIRFLOW_KUBERNETES_KUBECONFIG_FILE
k8s_operator_args["in_cluster"] = False

if k8s_deco.attributes["secrets"]:
if isinstance(k8s_deco.attributes["secrets"], str):
Expand Down Expand Up @@ -584,6 +600,17 @@ def _step_cli(self, node, paths, code_package_url, user_code_retries):
cmds.append(" ".join(entrypoint + top_level + step))
return cmds

def _collect_flow_sensors(self):
    """Create an Airflow task for every sensor flow decorator on the flow.

    Also flips `self._depends_on_upstream_sensors` when at least one
    sensor task was produced.
    """
    flow_decos = self.flow._flow_decorators
    sensor_tasks = []
    for sensor_type in SUPPORTED_SENSORS:
        # A sensor decorator may appear multiple times; build one Airflow
        # task per decorator instance.
        for deco in flow_decos.get(sensor_type.name) or []:
            sensor_tasks.append(deco.create_task())
    if sensor_tasks:
        self._depends_on_upstream_sensors = True
    return sensor_tasks

def _contains_foreach(self):
for node in self.graph:
if node.type == "foreach":
Expand Down Expand Up @@ -638,6 +665,7 @@ def _visit(node, workflow, exit_node=None):
if self.workflow_timeout is not None and self.schedule is not None:
airflow_dag_args["dagrun_timeout"] = dict(seconds=self.workflow_timeout)

appending_sensors = self._collect_flow_sensors()
workflow = Workflow(
dag_id=self.name,
default_args=self._create_defaults(),
Expand All @@ -658,6 +686,10 @@ def _visit(node, workflow, exit_node=None):
workflow = _visit(self.graph["start"], workflow)

workflow.set_parameters(self.parameters)
if len(appending_sensors) > 0:
for s in appending_sensors:
workflow.add_state(s)
workflow.graph_structure.insert(0, [[s.name] for s in appending_sensors])
return self._to_airflow_dag_file(workflow.to_dict())

def _to_airflow_dag_file(self, json_dag):
Expand Down
29 changes: 19 additions & 10 deletions metaflow/plugins/airflow/airflow_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,6 @@ def make_flow(


def _validate_foreach_constraints(graph):
# Todo :Invoke this function when we integrate `foreach`s
def traverse_graph(node, state):
if node.type == "foreach" and node.is_inside_foreach:
raise NotSupportedException(
Expand All @@ -338,7 +337,7 @@ def traverse_graph(node, state):
if node.type == "linear" and node.is_inside_foreach:
state["foreach_stack"].append(node.name)

if len(state["foreach_stack"]) > 2:
if "foreach_stack" in state and len(state["foreach_stack"]) > 2:
raise NotSupportedException(
"The foreach step *%s* created by step *%s* needs to have an immediate join step. "
"Step *%s* is invalid since it is a linear step with a foreach. "
Expand Down Expand Up @@ -378,27 +377,37 @@ def _validate_workflow(flow, graph, flow_datastore, metadata, workflow_timeout):
"A default value is required for parameters when deploying flows on Airflow."
)
# check for other compute related decorators.
_validate_foreach_constraints(graph)
for node in graph:
if node.parallel_foreach:
raise AirflowException(
"Deploying flows with @parallel decorator(s) "
"to Airflow is not supported currently."
)

if node.type == "foreach":
raise NotSupportedException(
"Step *%s* is a foreach step and Foreach steps are not currently supported with Airflow."
% node.name
)
if any([d.name == "batch" for d in node.decorators]):
raise NotSupportedException(
"Step *%s* is marked for execution on AWS Batch with Airflow which isn't currently supported."
% node.name
)
SUPPORTED_DATASTORES = ("azure", "s3", "gs")
if flow_datastore.TYPE not in SUPPORTED_DATASTORES:
raise AirflowException(
"Datastore type `%s` is not supported with `airflow create`. "
"Please choose from datastore of type %s when calling `airflow create`"
% (
str(flow_datastore.TYPE),
"or ".join(["`%s`" % x for x in SUPPORTED_DATASTORES]),
)
)

schedule = flow._flow_decorators.get("schedule")
if not schedule:
return

if flow_datastore.TYPE not in ("azure", "s3"):
schedule = schedule[0]
if schedule.timezone is not None:
raise AirflowException(
'Datastore of type "s3" or "azure" required with `airflow create`'
"`airflow create` does not support scheduling with `timezone`."
)


Expand Down
58 changes: 58 additions & 0 deletions metaflow/plugins/airflow/airflow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class IncompatibleKubernetesProviderVersionException(Exception):
) % (sys.executable, KUBERNETES_PROVIDER_FOREACH_VERSION)


class AirflowSensorNotFound(Exception):
    """Raised when a generated DAG needs a sensor whose Airflow provider
    package is not installed in the Airflow environment."""

    # NOTE(review): `headline` matches the attribute used on the other
    # exceptions in this module — presumably a short user-facing summary;
    # verify against how these exceptions are rendered.
    headline = "Sensor package not found"


def create_absolute_version_number(version):
abs_version = None
# For all digits
Expand Down Expand Up @@ -189,6 +193,15 @@ def pathspec(cls, flowname, is_foreach=False):
)


class SensorNames:
    """Canonical Airflow operator class names for the supported sensors."""

    EXTERNAL_TASK_SENSOR = "ExternalTaskSensor"
    S3_SENSOR = "S3KeySensor"

    @classmethod
    def get_supported_sensors(cls):
        """Return only the sensor name strings declared on this class.

        Previously this returned `list(cls.__dict__.values())`, which also
        leaked `__module__`, `__qualname__`, this classmethod object, etc.
        into the "supported sensors" list used for membership checks.
        """
        return [
            value
            for attr, value in vars(cls).items()
            if not attr.startswith("_") and isinstance(value, str)
        ]


def run_id_creator(val):
# join `[dag-id,run-id]` of airflow dag.
return hashlib.md5("-".join([str(x) for x in val]).encode("utf-8")).hexdigest()[
Expand Down Expand Up @@ -375,6 +388,42 @@ def _kubernetes_pod_operator_args(operator_args):
return args


def _parse_sensor_args(name, kwargs):
    """Normalize deserialized sensor arguments before operator creation.

    Sensor arguments come out of the JSON-serialized DAG file, so rich
    types arrive as plain dicts. For `ExternalTaskSensor`,
    `execution_delta` is rebuilt into a `datetime.timedelta`; a non-dict
    value is dropped so Airflow falls back to its default. Mutates and
    returns `kwargs`.
    """
    if name == SensorNames.EXTERNAL_TASK_SENSOR:
        if "execution_delta" in kwargs:
            # JSON round-trips timedeltas as keyword dicts
            # (e.g. {"minutes": 5}). Use isinstance, not `type(...) ==`.
            if isinstance(kwargs["execution_delta"], dict):
                kwargs["execution_delta"] = timedelta(**kwargs["execution_delta"])
            else:
                # NOTE(review): this also drops a genuine timedelta value —
                # presumably unreachable since args always come from JSON;
                # confirm against the serializer.
                del kwargs["execution_delta"]
    return kwargs


def _get_sensor(name):
    """Return the Airflow sensor operator class for `name`.

    Imports are deferred so the generated DAG file only requires the
    provider packages for sensors it actually uses.

    Raises:
        AirflowSensorNotFound: if `S3KeySensor` is requested but the
            Airflow Amazon provider package is not installed.
    """
    if name == SensorNames.EXTERNAL_TASK_SENSOR:
        # ExternalTaskSensors use an execution_date of a dag to
        # determine the appropriate DAG.
        # This is set to the exact date the current dag gets executed on.
        # For example if "DagA" (Upstream DAG) got scheduled at
        # 12 Jan 4:00 PM PDT then "DagB"(current DAG)'s task sensor will try to
        # look for a "DagA" that got executed at 12 Jan 4:00 PM PDT **exactly**.
        from airflow.sensors.external_task_sensor import ExternalTaskSensor

        return ExternalTaskSensor
    elif name == SensorNames.S3_SENSOR:
        try:
            from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
        except ImportError:
            raise AirflowSensorNotFound(
                "This DAG requires a `S3KeySensor`. "
                "Install the Airflow AWS provider using : "
                "`pip install apache-airflow-providers-amazon`"
            )
        return S3KeySensor


def get_metaflow_kubernetes_operator():
try:
from airflow.contrib.operators.kubernetes_pod_operator import (
Expand Down Expand Up @@ -493,6 +542,13 @@ def set_operator_args(self, **kwargs):
self._operator_args = kwargs
return self

def _make_sensor(self):
    # Build the concrete Airflow sensor operator for this task: resolve
    # the operator class by type, then normalize the serialized operator
    # arguments (e.g. rebuilding `execution_delta`) before instantiation.
    TaskSensor = _get_sensor(self._operator_type)
    return TaskSensor(
        task_id=self.name,
        **_parse_sensor_args(self._operator_type, self._operator_args)
    )

def to_dict(self):
return {
"name": self.name,
Expand Down Expand Up @@ -541,6 +597,8 @@ def to_task(self):
return self._kubernetes_task()
else:
return self._kubernetes_mapper_task()
elif self._operator_type in SensorNames.get_supported_sensors():
return self._make_sensor()


class Workflow(object):
Expand Down
7 changes: 7 additions & 0 deletions metaflow/plugins/airflow/sensors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .external_task_sensor import ExternalTaskSensorDecorator
from .s3_sensor import S3KeySensorDecorator

# All sensor flow decorators that can be compiled into Airflow sensor
# tasks; consumers iterate this list to discover sensor decorator types.
SUPPORTED_SENSORS = [
    ExternalTaskSensorDecorator,
    S3KeySensorDecorator,
]
Loading