From a1edca6d7e625afd0d362fb604cb656080a5c4ef Mon Sep 17 00:00:00 2001
From: Valay Dave
Date: Tue, 24 Jan 2023 00:26:27 +0000
Subject: [PATCH] GCP support with Airflow

---
 metaflow/plugins/airflow/airflow.py     |  7 ++++++-
 metaflow/plugins/airflow/airflow_cli.py | 11 ++++++++---
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/metaflow/plugins/airflow/airflow.py b/metaflow/plugins/airflow/airflow.py
index 9e8d0a0d398..0e8cbad4da4 100644
--- a/metaflow/plugins/airflow/airflow.py
+++ b/metaflow/plugins/airflow/airflow.py
@@ -23,6 +23,8 @@
     DATASTORE_SYSROOT_AZURE,
     CARD_AZUREROOT,
     AIRFLOW_KUBERNETES_CONN_ID,
+    DATASTORE_SYSROOT_GS,
+    CARD_GSROOT,
 )
 from metaflow.parameters import DelayedEvaluationParameter, deploy_time_eval
 from metaflow.plugins.kubernetes.kubernetes import Kubernetes
@@ -361,7 +363,7 @@ def _to_job(self, node):
             "METAFLOW_SERVICE_HEADERS": json.dumps(SERVICE_HEADERS),
             "METAFLOW_DATASTORE_SYSROOT_S3": DATASTORE_SYSROOT_S3,
             "METAFLOW_DATATOOLS_S3ROOT": DATATOOLS_S3ROOT,
-            "METAFLOW_DEFAULT_DATASTORE": "s3",
+            "METAFLOW_DEFAULT_DATASTORE": self.flow_datastore.TYPE,
             "METAFLOW_DEFAULT_METADATA": "service",
             "METAFLOW_KUBERNETES_WORKLOAD": str(
                 1
@@ -376,6 +378,9 @@ def _to_job(self, node):
             "METAFLOW_AIRFLOW_JOB_ID": AIRFLOW_MACROS.AIRFLOW_JOB_ID,
             "METAFLOW_PRODUCTION_TOKEN": self.production_token,
             "METAFLOW_ATTEMPT_NUMBER": AIRFLOW_MACROS.ATTEMPT,
+            # GCP datastore and card roots
+            "METAFLOW_DATASTORE_SYSROOT_GS": DATASTORE_SYSROOT_GS,
+            "METAFLOW_CARD_GSROOT": CARD_GSROOT,
         }
         env[
             "METAFLOW_AZURE_STORAGE_BLOB_SERVICE_ENDPOINT"
diff --git a/metaflow/plugins/airflow/airflow_cli.py b/metaflow/plugins/airflow/airflow_cli.py
index 1f48a1fa481..e9ea4521e28 100644
--- a/metaflow/plugins/airflow/airflow_cli.py
+++ b/metaflow/plugins/airflow/airflow_cli.py
@@ -389,10 +389,15 @@ def _validate_workflow(flow, graph, flow_datastore, metadata, workflow_timeout):
                 "Step *%s* is marked for execution on AWS Batch with Airflow which isn't currently supported."
                 % node.name
             )
-
-    if flow_datastore.TYPE not in ("azure", "s3"):
+    SUPPORTED_DATASTORES = ("azure", "s3", "gs")
+    if flow_datastore.TYPE not in SUPPORTED_DATASTORES:
         raise AirflowException(
-            'Datastore of type "s3" or "azure" required with `airflow create`'
+            "Datastore type `%s` is not supported with `airflow create`. "
+            "Please choose a datastore of type %s when calling `airflow create`"
+            % (
+                str(flow_datastore.TYPE),
+                " or ".join(["`%s`" % x for x in SUPPORTED_DATASTORES]),
+            )
         )
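
Reviewer note (not part of the patch): a minimal, self-contained sketch of the
message the new validation emits, assuming a hypothetical flow configured with
an unsupported datastore type such as "local":

    # Standalone reproduction of the error-message construction above;
    # "local" is a hypothetical datastore type used only for illustration.
    SUPPORTED_DATASTORES = ("azure", "s3", "gs")
    datastore_type = "local"

    print(
        "Datastore type `%s` is not supported with `airflow create`. "
        "Please choose a datastore of type %s when calling `airflow create`"
        % (
            str(datastore_type),
            " or ".join(["`%s`" % x for x in SUPPORTED_DATASTORES]),
        )
    )
    # -> Datastore type `local` is not supported with `airflow create`.
    #    Please choose a datastore of type `azure` or `s3` or `gs` when
    #    calling `airflow create`

The " or " separator carries a leading space so the rendered list reads
`azure` or `s3` or `gs` rather than running the items together.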