diff --git a/earth2studio/serve/server/config.py b/earth2studio/serve/server/config.py index 418616475..3501519e0 100644 --- a/earth2studio/serve/server/config.py +++ b/earth2studio/serve/server/config.py @@ -61,6 +61,7 @@ class QueueConfig: name: str = "inference" result_zip_queue_name: str = "result_zip" object_storage_queue_name: str = "object_storage" + geocatalog_ingestion_queue_name: str = "geocatalog_ingestion" finalize_metadata_queue_name: str = "finalize_metadata" max_size: int = 10 default_timeout: str = "1h" @@ -115,9 +116,10 @@ class CORSConfig: @dataclass class ObjectStorageConfig: - """Object storage configuration for S3/CloudFront""" + """Object storage configuration for S3/CloudFront and Azure Blob Storage""" enabled: bool = False + storage_type: Literal["s3", "azure"] = "s3" # Storage provider type # S3 configuration bucket: str | None = None region: str = "us-east-1" @@ -137,6 +139,27 @@ class ObjectStorageConfig: cloudfront_private_key: str | None = None # PEM private key content # Signed URL settings signed_url_expires_in: int = 86400 # Default 24 hours + # Azure Blob Storage configuration + azure_account_name: str | None = None # Azure storage account name + azure_container_name: str | None = ( + None # Azure container name (falls back to bucket if not set) + ) + # Azure Planetary Computer / GeoCatalog ingestion (optional) + azure_geocatalog_url: str | None = ( + None # When set, triggers PC ingestion after upload + ) + + +@dataclass +class WorkflowExposureConfig: + """Configuration for controlling which workflows are exposed via API endpoints""" + + exposed_workflows: list[str] = field( + default_factory=lambda: [] + ) # Empty list means all workflows are exposed + warmup_workflows: list[str] = field( + default_factory=lambda: ["example_user_workflow"] + ) # Workflows accessible for warmup even if not exposed @dataclass @@ -150,6 +173,9 @@ class AppConfig: server: ServerConfig = field(default_factory=ServerConfig) cors: CORSConfig = field(default_factory=CORSConfig) object_storage: ObjectStorageConfig = field(default_factory=ObjectStorageConfig) + workflow_exposure: WorkflowExposureConfig = field( + default_factory=WorkflowExposureConfig + ) class ConfigManager: @@ -232,6 +258,9 @@ def _dict_to_config(self, cfg_dict: dict) -> AppConfig: server=ServerConfig(**cfg_dict.get("server", {})), cors=CORSConfig(**cfg_dict.get("cors", {})), object_storage=ObjectStorageConfig(**cfg_dict.get("object_storage", {})), + workflow_exposure=WorkflowExposureConfig( + **cfg_dict.get("workflow_exposure", {}) + ), ) def _create_default_config_object(self) -> AppConfig: @@ -244,6 +273,7 @@ def _create_default_config_object(self) -> AppConfig: server=ServerConfig(), cors=CORSConfig(), object_storage=ObjectStorageConfig(), + workflow_exposure=WorkflowExposureConfig(), ) def _apply_env_overrides(self) -> None: @@ -288,6 +318,12 @@ def _apply_env_overrides(self) -> None: self._config.paths.results_zip_dir = os.getenv( "RESULTS_ZIP_DIR", default=self._config.paths.results_zip_dir ) + if os.getenv("OUTPUT_FORMAT"): + output_format = os.getenv("OUTPUT_FORMAT", "").lower() + if output_format in ["zarr", "netcdf4"]: + self._config.paths.output_format = cast( + Literal["zarr", "netcdf4"], output_format + ) # Logging overrides if os.getenv("LOG_LEVEL"): @@ -319,6 +355,13 @@ def _apply_env_overrides(self) -> None: self._config.object_storage.enabled = ( os.getenv("OBJECT_STORAGE_ENABLED", "").lower() == "true" ) + if os.getenv("OBJECT_STORAGE_TYPE"): + storage_type = os.getenv("OBJECT_STORAGE_TYPE", 
"").lower() + if storage_type in ["s3", "azure"]: + self._config.object_storage.storage_type = cast( + Literal["s3", "azure"], storage_type + ) + if os.getenv("OBJECT_STORAGE_BUCKET"): self._config.object_storage.bucket = os.getenv("OBJECT_STORAGE_BUCKET") if os.getenv("OBJECT_STORAGE_REGION"): @@ -387,6 +430,31 @@ def _apply_env_overrides(self) -> None: ) ) + # Azure Blob Storage overrides + if os.getenv("AZURE_STORAGE_ACCOUNT_NAME"): + self._config.object_storage.azure_account_name = os.getenv( + "AZURE_STORAGE_ACCOUNT_NAME" + ) + if os.getenv("AZURE_CONTAINER_NAME"): + self._config.object_storage.azure_container_name = os.getenv( + "AZURE_CONTAINER_NAME" + ) + # Support AZURE_ENDPOINT_URL for managed identity scenarios + if os.getenv("AZURE_ENDPOINT_URL"): + self._config.object_storage.endpoint_url = os.getenv("AZURE_ENDPOINT_URL") + if os.getenv("AZURE_GEOCATALOG_URL"): + self._config.object_storage.azure_geocatalog_url = os.getenv( + "AZURE_GEOCATALOG_URL" + ) + + # Workflow exposure overrides + if os.getenv("EXPOSED_WORKFLOWS"): + # Parse comma-separated list of workflow names + exposed_workflows_str = os.getenv("EXPOSED_WORKFLOWS", "") + self._config.workflow_exposure.exposed_workflows = [ + w.strip() for w in exposed_workflows_str.split(",") if w.strip() + ] + logger.debug("Environment variable overrides applied") def _ensure_paths_exist(self) -> None: diff --git a/earth2studio/serve/server/cpu_worker.py b/earth2studio/serve/server/cpu_worker.py index aadd4e6a9..b7c58b9b1 100644 --- a/earth2studio/serve/server/cpu_worker.py +++ b/earth2studio/serve/server/cpu_worker.py @@ -515,6 +515,36 @@ def process_result_zip( raise +def _primary_azure_asset_relpath(output_path: Path) -> str | None: + """ + Relative path under ``output_path`` for the primary dataset asset (GeoCatalog STAC href). + + Prefers the first ``*.nc`` file found; otherwise the first ``*.zarr`` directory (Zarr store). + If ``output_path`` is itself a ``*.zarr`` directory, returns ``""`` (store contents sync to + ``remote_prefix`` without an extra path segment). + + Returns + ------- + str | None + Relative POSIX path, empty string for a root-level Zarr directory, or ``None`` if + no NetCDF or Zarr asset was found. 
+ """ + if not output_path.is_dir(): + return None + nc_files = sorted(output_path.rglob("*.nc")) + if nc_files: + return nc_files[0].relative_to(output_path).as_posix() + zarr_dirs = sorted( + (p for p in output_path.rglob("*.zarr") if p.is_dir()), + key=lambda p: p.as_posix(), + ) + if zarr_dirs: + return zarr_dirs[0].relative_to(output_path).as_posix() + if output_path.name.endswith(".zarr"): + return "" + return None + + @check_optional_dependencies() def process_object_storage_upload( workflow_name: str, @@ -550,7 +580,11 @@ def process_object_storage_upload( # Upload to object storage if enabled - if config.object_storage.enabled and config.object_storage.bucket: + # Check if object storage is enabled and properly configured + if config.object_storage.enabled and ( + config.object_storage.bucket + or config.object_storage.storage_type == "azure" + ): from earth2studio.serve.server.object_storage import ( MSCObjectStorage, ObjectStorageError, @@ -564,44 +598,82 @@ def process_object_storage_upload( f"Output path does not exist: {output_path}", ) - # Create S3 storage instance + # Validate Azure container name is configured + if config.object_storage.storage_type == "azure": + if ( + not config.object_storage.azure_container_name + and not config.object_storage.bucket + ): + return fail_workflow( + workflow_name, + execution_id, + "Azure storage is enabled but neither 'azure_container_name' nor 'bucket' is configured", + ) + + # Create storage instance storage_kwargs: dict[str, Any] = { - "bucket": config.object_storage.bucket, - "region": config.object_storage.region, - "use_transfer_acceleration": config.object_storage.use_transfer_acceleration, + "bucket": config.object_storage.bucket + or config.object_storage.azure_container_name + or "", + "storage_type": config.object_storage.storage_type, "max_concurrency": config.object_storage.max_concurrency, "multipart_chunksize": config.object_storage.multipart_chunksize, "use_rust_client": config.object_storage.use_rust_client, } - # Add optional credentials - if ( - config.object_storage.access_key_id - and config.object_storage.secret_access_key - ): - storage_kwargs["access_key_id"] = config.object_storage.access_key_id - storage_kwargs["secret_access_key"] = ( - config.object_storage.secret_access_key - ) - if config.object_storage.session_token: - storage_kwargs["session_token"] = config.object_storage.session_token - if config.object_storage.endpoint_url: - storage_kwargs["endpoint_url"] = config.object_storage.endpoint_url - - # Add CloudFront configuration for signed URLs - if config.object_storage.cloudfront_domain: - storage_kwargs["cloudfront_domain"] = ( - config.object_storage.cloudfront_domain - ) - if config.object_storage.cloudfront_key_pair_id: - storage_kwargs["cloudfront_key_pair_id"] = ( - config.object_storage.cloudfront_key_pair_id - ) - if config.object_storage.cloudfront_private_key: - storage_kwargs["cloudfront_private_key"] = ( - config.object_storage.cloudfront_private_key + # Add storage-type-specific configuration + if config.object_storage.storage_type == "s3": + # S3-specific parameters + storage_kwargs["region"] = config.object_storage.region + storage_kwargs["use_transfer_acceleration"] = ( + config.object_storage.use_transfer_acceleration ) + # Add optional S3 credentials + if ( + config.object_storage.access_key_id + and config.object_storage.secret_access_key + ): + storage_kwargs["access_key_id"] = ( + config.object_storage.access_key_id + ) + storage_kwargs["secret_access_key"] = ( + 
config.object_storage.secret_access_key + ) + if config.object_storage.session_token: + storage_kwargs["session_token"] = ( + config.object_storage.session_token + ) + if config.object_storage.endpoint_url: + storage_kwargs["endpoint_url"] = config.object_storage.endpoint_url + + # Add CloudFront configuration for signed URLs + if config.object_storage.cloudfront_domain: + storage_kwargs["cloudfront_domain"] = ( + config.object_storage.cloudfront_domain + ) + if config.object_storage.cloudfront_key_pair_id: + storage_kwargs["cloudfront_key_pair_id"] = ( + config.object_storage.cloudfront_key_pair_id + ) + if config.object_storage.cloudfront_private_key: + storage_kwargs["cloudfront_private_key"] = ( + config.object_storage.cloudfront_private_key + ) + elif config.object_storage.storage_type == "azure": + # Azure-specific parameters (managed identity via DefaultAzureCredentials) + if config.object_storage.azure_account_name: + storage_kwargs["azure_account_name"] = ( + config.object_storage.azure_account_name + ) + if config.object_storage.azure_container_name: + storage_kwargs["azure_container_name"] = ( + config.object_storage.azure_container_name + ) + # Support endpoint_url for Azure (useful for managed identity) + if config.object_storage.endpoint_url: + storage_kwargs["endpoint_url"] = config.object_storage.endpoint_url + try: storage = MSCObjectStorage(**storage_kwargs) except Exception as e: @@ -617,8 +689,13 @@ def process_object_storage_upload( ) # Upload the output directory + storage_location = ( + f"s3://{config.object_storage.bucket}" + if config.object_storage.storage_type == "s3" + else f"azure://{config.object_storage.azure_container_name or config.object_storage.bucket}" + ) logger.info( - f"Uploading {output_path} to s3://{config.object_storage.bucket}/{remote_prefix}" + f"Uploading {output_path} to {storage_location}/{remote_prefix}" ) try: @@ -642,14 +719,14 @@ def process_object_storage_upload( f"Failed to upload to object storage: {upload_result.errors}", ) - storage_type = "s3" + storage_type = config.object_storage.storage_type logger.info( f"Successfully uploaded {upload_result.files_uploaded} files " f"({upload_result.total_bytes} bytes) to {upload_result.destination}" ) - # Generate signed URL if CloudFront is configured - cloudfront_configured = all( + # Generate signed URL for S3 only (CloudFront). Azure: clients obtain tokens to read blobs. 
+ can_generate_signed_url = storage_type == "s3" and all( [ config.object_storage.cloudfront_domain, config.object_storage.cloudfront_key_pair_id, @@ -657,8 +734,10 @@ def process_object_storage_upload( ] ) - if not cloudfront_configured: - logger.info("CloudFront not configured, skipping signed URL generation") + if not can_generate_signed_url: + logger.info( + f"Signed URL generation not configured for {storage_type}, skipping" + ) else: try: signed_url_path = f"{remote_prefix}/*" @@ -688,10 +767,38 @@ def process_object_storage_upload( storage_info = { "storage_type": storage_type, } - if storage_type == "s3" and remote_prefix: - storage_info["remote_path"] = ( - f"s3://{config.object_storage.bucket}/{remote_prefix}" - ) + if remote_prefix: + if storage_type == "s3": + storage_info["remote_path"] = ( + f"s3://{config.object_storage.bucket}/{remote_prefix}" + ) + elif storage_type == "azure": + container_name = ( + config.object_storage.azure_container_name + or config.object_storage.bucket + ) + storage_info["remote_path"] = ( + f"azure://{container_name}/{remote_prefix}" + ) + azure_account = config.object_storage.azure_account_name + if azure_account: + storage_info["azure_account_name"] = azure_account + # Build HTTPS blob URL for primary netcdf file (for GeoCatalog ingestion) + if ( + config.object_storage.azure_account_name + and config.object_storage.azure_geocatalog_url + ): + primary_rel = _primary_azure_asset_relpath(output_path) + if primary_rel is not None: + base_url = ( + f"https://{config.object_storage.azure_account_name}" + f".blob.core.windows.net/{container_name}/{remote_prefix}" + ) + storage_info["blob_url"] = ( + f"{base_url}/{primary_rel}" + if primary_rel + else f"{base_url}/" + ) if signed_url: storage_info["signed_url"] = signed_url @@ -729,7 +836,7 @@ def process_object_storage_upload( "signed_url": signed_url, } - if upload_result and storage_type == "s3": + if upload_result: result["files_uploaded"] = upload_result.files_uploaded result["total_bytes"] = upload_result.total_bytes result["destination"] = upload_result.destination @@ -780,6 +887,7 @@ def process_finalize_metadata( storage_info_json = redis_client.get(storage_info_key) if not pending_metadata_json or not results_zip_dir_str: + logger.error(f"Pending metadata not found in Redis for {request_id}") return fail_workflow( workflow_name, execution_id, @@ -799,6 +907,10 @@ def process_finalize_metadata( metadata_dict["remote_path"] = storage_info["remote_path"] if storage_info.get("signed_url"): metadata_dict["signed_url"] = storage_info["signed_url"] + if storage_info.get("azure_account_name"): + metadata_dict["azure_account_name"] = storage_info["azure_account_name"] + if storage_info.get("blob_url"): + metadata_dict["blob_url"] = storage_info["blob_url"] else: # No storage info means object storage was skipped or failed metadata_dict["storage_type"] = "server" diff --git a/earth2studio/serve/server/e2workflow.py b/earth2studio/serve/server/e2workflow.py index 0ed432010..7134c6b17 100644 --- a/earth2studio/serve/server/e2workflow.py +++ b/earth2studio/serve/server/e2workflow.py @@ -117,6 +117,7 @@ class Earth2Workflow(Workflow, metaclass=AutoParameters): def __init__(self) -> None: super().__init__() + self.execution_id: str | None = None @abstractmethod def __call__(self, io: IOBackend) -> None: @@ -141,6 +142,9 @@ def run( ) -> dict[str, Any]: """Run custom workflow""" + # Store execution_id for use in update_progress + self.execution_id = execution_id + # Validate and convert parameters parameters 
= self.validate_parameters(parameters) @@ -178,6 +182,10 @@ def run( # Consolidate zarr metadata for faster remote access if output_format == "zarr": zarr.consolidate_metadata(output_path) + elif output_format == "netcdf4": + inner = getattr(results_io, "io", results_io) + inner.root.sync() + inner.close() # Update final metadata and progress progress = WorkflowProgress(progress="Finished workflow successfully") @@ -203,6 +211,36 @@ def run( self.update_execution_data(execution_id, progress) raise + def update_progress(self, progress: WorkflowProgress) -> None: + """ + Update workflow execution progress. + + This method is intended for child workflows to update progress + information during execution. It uses the execution_id stored + during the run() method. + + If execution_id is not set (e.g., when running outside the API server), + this method is a no-op and silently returns without updating progress. + + Parameters + ---------- + progress : WorkflowProgress + WorkflowProgress object containing progress information to update. + + Examples + -------- + >>> progress = WorkflowProgress( + ... progress="Processing data...", + ... current_step=5, + ... total_steps=10 + ... ) + >>> self.update_progress(progress) + """ + if self.execution_id is None: + # No-op when running outside API server context + return + self.update_execution_data(self.execution_id, progress) + logger = logging.getLogger(__name__) @@ -213,7 +251,7 @@ class BackendProgress: def __init__( self, io: IOBackend, - workflow: Workflow, + workflow: Earth2Workflow, execution_id: str, progress_dim: str = "lead_time", ) -> None: @@ -244,7 +282,7 @@ def add_array( progress = WorkflowProgress( current_step=0, total_steps=len(self.progress_coords) ) - self.workflow.update_execution_data(self.execution_id, progress) + self.workflow.update_progress(progress) def write( self, @@ -261,8 +299,12 @@ def write( step_index = self.progress_coords.index(current_coord) # Update progress using WorkflowProgress progress = WorkflowProgress(current_step=step_index + 1) - self.workflow.update_execution_data(self.execution_id, progress) + self.workflow.update_progress(progress) def __getattr__(self, name: str) -> Any: """Allow passthrough of unwrapped attributes.""" return getattr(self.io, name) + + def __getitem__(self, key: str) -> Any: + """Allow subscripting to access underlying io object.""" + return self.io[key] diff --git a/earth2studio/serve/server/main.py b/earth2studio/serve/server/main.py index 78688098b..bf7608451 100644 --- a/earth2studio/serve/server/main.py +++ b/earth2studio/serve/server/main.py @@ -94,6 +94,7 @@ def check_admission_control() -> None: config.queue.name, config.queue.result_zip_queue_name, config.queue.object_storage_queue_name, + config.queue.geocatalog_ingestion_queue_name, config.queue.finalize_metadata_queue_name, ] for queue_name in queue_names: @@ -384,7 +385,7 @@ async def list_workflows() -> dict[str, dict[str, str]]: dict Single key ``workflows`` mapping workflow name to description. """ - workflows = workflow_registry.list_workflows() + workflows = workflow_registry.list_workflows(exposed_only=True) return {"workflows": workflows} @@ -413,12 +414,16 @@ async def get_workflow_schema(workflow_name: str) -> dict[str, Any]: HTTPException 404 if workflow not found; 500 if schema generation fails. 
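+
+    Examples
+    --------
+    A 404 is also returned when the workflow is registered but not exposed; exposure
+    is resolved by the registry from the ``workflow_exposure`` configuration (or the
+    ``EXPOSED_WORKFLOWS`` environment variable). Illustrative check, assuming the
+    default configuration where an empty ``exposed_workflows`` list exposes everything:
+
+    >>> workflow_registry.is_workflow_exposed("example_user_workflow")
+    True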
""" - # Check if workflow exists + # Check if workflow exists and is exposed workflow_class = workflow_registry.get_workflow_class(workflow_name) if not workflow_class: raise HTTPException( status_code=404, detail=f"Workflow '{workflow_name}' not found" ) + if not workflow_registry.is_workflow_exposed(workflow_name): + raise HTTPException( + status_code=404, detail=f"Workflow '{workflow_name}' is not exposed" + ) try: # Get the Parameters class from the workflow @@ -521,12 +526,16 @@ async def execute_workflow( 404 if workflow not found; 422 if parameters invalid; 429 if queues full; 503 if Redis/queue not initialized; 500 on enqueue failure. """ - # Check if workflow exists and get the workflow class for validation + # Check if workflow exists and is exposed custom_workflow_class = workflow_registry.get_workflow_class(workflow_name) if not custom_workflow_class: raise HTTPException( status_code=404, detail=f"Workflow '{workflow_name}' not found" ) + if not workflow_registry.is_workflow_exposed(workflow_name): + raise HTTPException( + status_code=404, detail=f"Workflow '{workflow_name}' is not exposed" + ) # Validate parameters early to provide immediate feedback using classmethod try: @@ -649,12 +658,16 @@ async def get_workflow_status(workflow_name: str, execution_id: str) -> Workflow # Create logger adapter with execution_id log = logging.LoggerAdapter(logger, {"execution_id": execution_id}) - # Check if workflow exists + # Check if workflow exists and is exposed custom_workflow_class = workflow_registry.get_workflow_class(workflow_name) if not custom_workflow_class: raise HTTPException( status_code=404, detail=f"Custom workflow '{workflow_name}' not found" ) + if not workflow_registry.is_workflow_exposed(workflow_name): + raise HTTPException( + status_code=404, detail=f"Workflow '{workflow_name}' is not exposed" + ) try: result = custom_workflow_class._get_execution_data( @@ -717,12 +730,16 @@ async def get_workflow_results( # Create logger adapter with execution_id log = logging.LoggerAdapter(logger, {"execution_id": execution_id}) - # Check if workflow exists + # Check if workflow exists and is exposed custom_workflow_class = workflow_registry.get_workflow_class(workflow_name) if not custom_workflow_class: raise HTTPException( status_code=404, detail=f"Custom workflow '{workflow_name}' not found" ) + if not workflow_registry.is_workflow_exposed(workflow_name): + raise HTTPException( + status_code=404, detail=f"Workflow '{workflow_name}' is not exposed" + ) # Check workflow status first try: @@ -843,12 +860,16 @@ async def get_workflow_result_file( 403 on path traversal attempt; 404 if workflow, execution, file, or zip not found or results not completed; 503 if Redis not initialized; 500 on error. 
""" - # Check if workflow exists + # Check if workflow exists and is exposed custom_workflow_class = workflow_registry.get_workflow_class(workflow_name) if not custom_workflow_class: raise HTTPException( status_code=404, detail=f"Custom workflow '{workflow_name}' not found" ) + if not workflow_registry.is_workflow_exposed(workflow_name): + raise HTTPException( + status_code=404, detail=f"Workflow '{workflow_name}' is not exposed" + ) # Check workflow status first try: diff --git a/earth2studio/serve/server/object_storage.py b/earth2studio/serve/server/object_storage.py index 8b318e372..fbd065c1d 100644 --- a/earth2studio/serve/server/object_storage.py +++ b/earth2studio/serve/server/object_storage.py @@ -22,7 +22,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import Any, Literal logger = logging.getLogger(__name__) @@ -198,25 +198,31 @@ class ObjectStorageError(Exception): class MSCObjectStorage(ObjectStorage): """ - Object storage using NVIDIA Multi-Storage Client (MSC) with Rust backend. + Object storage using NVIDIA Multi-Storage Client (MSC) with Rust backend for AWS S3 and Azure Blob Storage. MSC provides optimized parallel transfers; the Rust client bypasses Python's GIL for improved I/O performance (up to 12x faster). Uses sync_from for efficient directory uploads with parallel transfers. - Credentials are read from environment variables: AWS_ACCESS_KEY_ID, + Supports both AWS S3 and Azure Blob Storage. + + For S3, credentials are read from environment variables: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN (optional), AWS_DEFAULT_REGION. + For Azure Blob Storage, authentication uses DefaultAzureCredentials (managed identity, + Azure CLI, etc.). Provide ``endpoint_url`` and/or ``azure_account_name`` to locate the + account; do not use connection strings. + References ---------- https://nvidia.github.io/multi-storage-client/user_guide/rust.html - Initialize MSCObjectStorage with AWS credentials and configuration. - Parameters ---------- bucket : str - S3 bucket name. + S3 bucket name or Azure container name. + storage_type : str, optional + Storage provider type, either "s3" or "azure". Default is "s3". region : str, optional AWS region (e.g. 'us-east-1'). access_key_id : str, optional @@ -226,7 +232,7 @@ class MSCObjectStorage(ObjectStorage): session_token : str, optional AWS session token for temporary credentials. endpoint_url : str, optional - Custom endpoint URL for S3-compatible services. + Custom endpoint URL for S3-compatible services or Azure Blob Storage. use_transfer_acceleration : bool, optional Enable S3 Transfer Acceleration (bucket must support it). Default is False. max_concurrency : int, optional @@ -236,18 +242,23 @@ class MSCObjectStorage(ObjectStorage): use_rust_client : bool, optional Use the high-performance Rust client. Default is False. profile_name : str, optional - Name for the MSC profile. Default is 'e2studio-s3'. + Name for the MSC profile. Default is 'e2studio-s3' for S3, 'e2studio-azure' for Azure. cloudfront_domain : str, optional CloudFront distribution domain for signed URLs. cloudfront_key_pair_id : str, optional CloudFront key pair ID for signed URLs. cloudfront_private_key : str, optional PEM private key content as string for signed URLs. + azure_account_name : str, optional + Azure storage account name (used with managed identity when ``endpoint_url`` is not set). + azure_container_name : str, optional + Azure container name. 
""" def __init__( self, bucket: str, + storage_type: Literal["s3", "azure"] = "s3", region: str | None = None, access_key_id: str | None = None, secret_access_key: str | None = None, @@ -257,40 +268,68 @@ def __init__( max_concurrency: int = 16, multipart_chunksize: int = 8 * 1024 * 1024, # 8 MB use_rust_client: bool = False, - profile_name: str = "e2studio-s3", + profile_name: str | None = None, cloudfront_domain: str | None = None, cloudfront_key_pair_id: str | None = None, cloudfront_private_key: str | None = None, + # Azure-specific parameters (DefaultAzureCredentials / managed identity) + azure_account_name: str | None = None, + azure_container_name: str | None = None, ): + self.storage_type = storage_type self.bucket = bucket - self.region = region or os.environ.get("AWS_DEFAULT_REGION", "us-east-1") self.max_concurrency = max_concurrency self.multipart_chunksize = multipart_chunksize self.use_rust_client = use_rust_client - self.profile_name = profile_name - self.use_transfer_acceleration = use_transfer_acceleration - - # Use S3 Transfer Acceleration endpoint if enabled (and no custom endpoint provided) - if use_transfer_acceleration and not endpoint_url: - self.endpoint_url = f"https://{bucket}.s3-accelerate.amazonaws.com" - logger.info(f"S3 Transfer Acceleration enabled: {self.endpoint_url}") + self.profile_name = profile_name or ( + "e2studio-s3" if storage_type == "s3" else "e2studio-azure" + ) + # Initialize endpoint_url as None to allow str | None type + self.endpoint_url: str | None = None + + # S3-specific configuration + if storage_type == "s3": + self.region = region or os.environ.get("AWS_DEFAULT_REGION", "us-east-1") + self.use_transfer_acceleration = use_transfer_acceleration + + # Use S3 Transfer Acceleration endpoint if enabled (and no custom endpoint provided) + if use_transfer_acceleration and not endpoint_url: + self.endpoint_url = f"https://{bucket}.s3-accelerate.amazonaws.com" + logger.info(f"S3 Transfer Acceleration enabled: {self.endpoint_url}") + else: + self.endpoint_url = endpoint_url + + # CloudFront configuration for signed URLs + self.cloudfront_domain = cloudfront_domain + self.cloudfront_key_pair_id = cloudfront_key_pair_id + self.cloudfront_private_key = cloudfront_private_key + + # Set credentials as environment variables - MSC picks these up automatically + if access_key_id: + os.environ["AWS_ACCESS_KEY_ID"] = access_key_id + if secret_access_key: + os.environ["AWS_SECRET_ACCESS_KEY"] = secret_access_key + if session_token: + os.environ["AWS_SESSION_TOKEN"] = session_token + if region: + os.environ["AWS_DEFAULT_REGION"] = region + + # Azure-specific configuration + elif storage_type == "azure": + self.azure_container_name = azure_container_name or bucket + + # Azure Blob: DefaultAzureCredentials (managed identity, Azure CLI, etc.) + self.use_managed_identity = True + logger.info( + "Using Azure DefaultAzureCredentials (managed identity / Azure CLI). 
" + f"Account: {azure_account_name or 'will be determined from endpoint'}, " + f"Container: {self.azure_container_name}" + ) + self.azure_account_name = azure_account_name else: - self.endpoint_url = endpoint_url or "" - - # CloudFront configuration for signed URLs - self.cloudfront_domain = cloudfront_domain - self.cloudfront_key_pair_id = cloudfront_key_pair_id - self.cloudfront_private_key = cloudfront_private_key - - # Set credentials as environment variables - MSC picks these up automatically - if access_key_id: - os.environ["AWS_ACCESS_KEY_ID"] = access_key_id - if secret_access_key: - os.environ["AWS_SECRET_ACCESS_KEY"] = secret_access_key - if session_token: - os.environ["AWS_SESSION_TOKEN"] = session_token - if region: - os.environ["AWS_DEFAULT_REGION"] = region + raise ValueError( + f"Unsupported storage_type: {storage_type}. Must be 's3' or 'azure'." + ) # Import multi-storage-client try: @@ -303,44 +342,88 @@ def __init__( self._msc = msc - # Build the S3 storage provider options - s3_storage_provider_options: dict[str, Any] = { - "base_path": bucket, - "region_name": self.region, - "multipart_threshold": multipart_chunksize, - "multipart_chunksize": multipart_chunksize, - "max_concurrency": max_concurrency, - } - - # Add endpoint URL if provided (for S3-compatible services) - if endpoint_url: - s3_storage_provider_options["endpoint_url"] = endpoint_url - - # Enable Rust client for high-performance I/O - if use_rust_client: - s3_storage_provider_options["rust_client"] = { + # Build storage provider profile config based on storage_type + if storage_type == "s3": + # Build the S3 storage provider options + s3_storage_provider_options: dict[str, Any] = { + "base_path": bucket, + "region_name": self.region, + "multipart_threshold": multipart_chunksize, "multipart_chunksize": multipart_chunksize, "max_concurrency": max_concurrency, } - # Build the S3 profile config - s3_profile_config = { - "profiles": { - profile_name: { - "storage_provider": { - "type": "s3", - "options": s3_storage_provider_options, + # Add endpoint URL if provided (for S3-compatible services) + if self.endpoint_url: + s3_storage_provider_options["endpoint_url"] = self.endpoint_url + + # Enable Rust client for high-performance I/O + if use_rust_client: + s3_storage_provider_options["rust_client"] = { + "multipart_chunksize": multipart_chunksize, + "max_concurrency": max_concurrency, + } + + # Build the S3 profile config + profile_config = { + "profiles": { + self.profile_name: { + "storage_provider": { + "type": "s3", + "options": s3_storage_provider_options, + } + } + } + } + elif storage_type == "azure": + # Build the Azure storage provider options + # Derive endpoint URL from endpoint_url parameter or azure_account_name + azure_endpoint_url: str | None = None + + if endpoint_url: + azure_endpoint_url = endpoint_url.rstrip("/") + else: + account_name = self.azure_account_name + endpoint_suffix = "core.windows.net" + if not account_name: + raise ObjectStorageError( + "Azure endpoint_url cannot be determined. " + "Please provide endpoint_url or azure_account_name (managed identity)." 
+ ) + azure_endpoint_url = f"https://{account_name}.blob.{endpoint_suffix}" + logger.info( + f"Constructed Azure endpoint URL from account name: {azure_endpoint_url}" + ) + + azure_storage_provider_options = { + "base_path": self.azure_container_name, + "endpoint_url": azure_endpoint_url, + } + + # Build the Azure profile config with credentials provider + profile_config = { + "profiles": { + self.profile_name: { + "storage_provider": { + "type": "azure", + "options": azure_storage_provider_options, + } } } } - } - # Initialize the S3 StorageClient (target for uploads) - s3_client_config = msc.StorageClientConfig.from_dict( - config_dict=s3_profile_config, - profile=profile_name, + # DefaultAzureCredentials (managed identity, Azure CLI, etc.) + profile_config["profiles"][self.profile_name]["credentials_provider"] = { + "type": "DefaultAzureCredentials", + "options": {}, + } + + # Initialize the StorageClient (target for uploads) + storage_client_config = msc.StorageClientConfig.from_dict( + config_dict=profile_config, + profile=self.profile_name, ) - self._s3_client = msc.StorageClient(config=s3_client_config) + self._storage_client = msc.StorageClient(config=storage_client_config) # Initialize the local filesystem StorageClient (source for uploads) local_profile_config = { @@ -362,12 +445,18 @@ def __init__( self._local_client = msc.StorageClient(config=local_client_config) rust_status = "enabled" if use_rust_client else "disabled" - accel_status = "enabled" if use_transfer_acceleration else "disabled" - logger.info( - f"MSCObjectStorage initialized: bucket={bucket}, region={self.region}, " - f"max_concurrency={max_concurrency}, rust_client={rust_status}, " - f"transfer_acceleration={accel_status}" - ) + if storage_type == "s3": + accel_status = "enabled" if use_transfer_acceleration else "disabled" + logger.info( + f"MSCObjectStorage initialized (S3): bucket={bucket}, region={self.region}, " + f"max_concurrency={max_concurrency}, rust_client={rust_status}, " + f"transfer_acceleration={accel_status}" + ) + else: + logger.info( + f"MSCObjectStorage initialized (Azure): container={self.azure_container_name}, " + f"max_concurrency={max_concurrency}, rust_client={rust_status}" + ) def upload_directory( self, @@ -416,9 +505,14 @@ def upload_directory( total_bytes = sum(f.stat().st_size for f in files) + storage_prefix = ( + f"s3://{self.bucket}" + if self.storage_type == "s3" + else f"azure://{self.azure_container_name if self.storage_type == 'azure' else self.bucket}" + ) logger.info( f"[MSC] Syncing {len(files)} files ({total_bytes / (1024 * 1024):.2f} MB) " - f"from {local_directory} to s3://{self.bucket}/{remote_prefix}" + f"from {local_directory} to {storage_prefix}/{remote_prefix}" ) errors: list[str] = [] @@ -426,7 +520,7 @@ def upload_directory( try: # Use sync_from for efficient parallel directory upload - result = self._s3_client.sync_from( + result = self._storage_client.sync_from( source_client=self._local_client, source_path=source_path, target_path=f"/{remote_prefix}" if remote_prefix else "/", @@ -443,7 +537,15 @@ def upload_directory( elapsed_time = time.time() - start_time success = len(errors) == 0 - destination = f"s3://{self.bucket}/{remote_prefix}" + if self.storage_type == "s3": + destination = f"s3://{self.bucket}/{remote_prefix}" + else: + container = ( + self.azure_container_name + if self.storage_type == "azure" + else self.bucket + ) + destination = f"azure://{container}/{remote_prefix}" result = UploadResult( success=success, @@ -500,7 +602,7 @@ def upload_file( 
try: remote_key = f"/{remote_key.lstrip('/')}" - self._s3_client.upload_file(remote_key, str(local_path)) + self._storage_client.upload_file(remote_key, str(local_path)) return True except Exception as e: @@ -523,7 +625,7 @@ def file_exists(self, remote_key: str) -> bool: """ try: remote_path = f"/{remote_key.lstrip('/')}" - self._s3_client.info(remote_path) + self._storage_client.info(remote_path) return True except FileNotFoundError: return False @@ -544,7 +646,7 @@ def delete_file(self, remote_key: str) -> bool: """ try: remote_path = f"/{remote_key.lstrip('/')}" - self._s3_client.delete(remote_path) + self._storage_client.delete(remote_path) return True except FileNotFoundError: logger.warning(f"File not found for deletion: {remote_key}") @@ -610,25 +712,40 @@ def _url_safe_b64(data: bytes) -> str: def generate_signed_url(self, remote_key: str, expires_in: int = 86400) -> str: """ - Generate a CloudFront signed URL for accessing a file. + Generate a signed URL for accessing a file. + + For S3, generates a CloudFront signed URL. + Azure blob access is not supported here; clients should obtain tokens to read blobs. Parameters ---------- remote_key : str - S3 key/path to the file. Can include wildcards. + Storage key/path to the file. Can include wildcards for S3. expires_in : int, optional Number of seconds until the URL expires. Default is 86400. Returns ------- str - Signed CloudFront URL string. + Signed URL string. Raises ------ ObjectStorageError - If CloudFront configuration is missing. + If required configuration is missing. """ + if self.storage_type == "s3": + return self._generate_cloudfront_signed_url(remote_key, expires_in) + if self.storage_type == "azure": + raise ObjectStorageError( + "Azure blob signed URLs are not generated by the server. " + "Use remote_path / blob_url in metadata and obtain Azure AD or other " + "tokens on the client to read objects." + ) + raise ObjectStorageError(f"Unsupported storage_type: {self.storage_type}") + + def _generate_cloudfront_signed_url(self, remote_key: str, expires_in: int) -> str: + """Generate a CloudFront signed URL for S3.""" if not all( [ self.cloudfront_domain, @@ -678,5 +795,7 @@ def generate_signed_url(self, remote_key: str, expires_in: int = 86400) -> str: f"&Key-Pair-Id={self.cloudfront_key_pair_id}" ) - logger.debug(f"Generated signed URL for {remote_key}, expires in {expires_in}s") + logger.debug( + f"Generated CloudFront signed URL for {remote_key}, expires in {expires_in}s" + ) return signed_url diff --git a/earth2studio/serve/server/utils.py b/earth2studio/serve/server/utils.py index 1d0ea6b70..09460d855 100644 --- a/earth2studio/serve/server/utils.py +++ b/earth2studio/serve/server/utils.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import logging from typing import Any, Literal @@ -72,10 +70,13 @@ def get_signed_url_key(request_id: str) -> str: # ============================================================================= +Stage = Literal["inference", "result_zip", "object_storage", "geocatalog_ingestion"] + + @check_optional_dependencies() def queue_next_stage( redis_client: redis.Redis, - current_stage: Literal["inference", "result_zip", "object_storage"], + current_stage: Stage, workflow_name: str, execution_id: str, output_path_str: str, @@ -85,12 +86,12 @@ def queue_next_stage( Queue the next pipeline stage based on configuration. 
Pipeline flow: - - If result_zip_enabled: inference -> result_zip -> object_storage (if enabled) -> finalize - - If not result_zip_enabled: inference -> object_storage (if enabled) -> finalize + - If result_zip_enabled: inference -> result_zip -> object_storage (if enabled) -> [geocatalog_ingestion (if AZURE_GEOCATALOG_URL)] -> finalize + - If not result_zip_enabled: inference -> object_storage (if enabled) -> [geocatalog_ingestion (if AZURE_GEOCATALOG_URL)] -> finalize Args: redis_client: Redis client for queue connection - current_stage: The stage that just completed ("inference", "result_zip", "object_storage") + current_stage: The stage that just completed ("inference", "result_zip", "object_storage", "geocatalog_ingestion") workflow_name: Name of the workflow execution_id: Execution ID of the workflow output_path_str: Path to the output files @@ -133,6 +134,16 @@ def queue_next_stage( args = (workflow_name, execution_id) elif current_stage == "object_storage": + if config.object_storage.azure_geocatalog_url: + next_queue = "geocatalog_ingestion" + next_func = "azure_planetary_computer.geocatalog_ingestion.process_geocatalog_ingestion" + args = (workflow_name, execution_id) + else: + next_queue = "finalize_metadata" + next_func = "earth2studio.serve.server.cpu_worker.process_finalize_metadata" + args = (workflow_name, execution_id) + + elif current_stage == "geocatalog_ingestion": next_queue = "finalize_metadata" next_func = "earth2studio.serve.server.cpu_worker.process_finalize_metadata" args = (workflow_name, execution_id) diff --git a/earth2studio/serve/server/workflow.py b/earth2studio/serve/server/workflow.py index 7534fdf70..1958edebf 100644 --- a/earth2studio/serve/server/workflow.py +++ b/earth2studio/serve/server/workflow.py @@ -752,11 +752,74 @@ def get( return instance - def list_workflows(self) -> dict[str, str]: - """List all registered workflows.""" + def is_workflow_exposed(self, workflow_name: str) -> bool: + """ + Check if a workflow is exposed via API. + + A workflow is exposed if: + - The exposed_workflows list is empty (all workflows exposed by default), OR + - The workflow name is in the exposed_workflows list, OR + - The workflow name is in the warmup_workflows list (accessible for warmup) + + Parameters + ---------- + workflow_name : str + Name of the workflow to check + + Returns + ------- + bool + True if workflow should be exposed, False otherwise + """ + config = get_config() + exposed_workflows = config.workflow_exposure.exposed_workflows + warmup_workflows = config.workflow_exposure.warmup_workflows + + # Empty list means all workflows are exposed + if not exposed_workflows: + return True + + # Check if in exposed list or warmup list + return workflow_name in exposed_workflows or workflow_name in warmup_workflows + + def list_workflows(self, exposed_only: bool = True) -> dict[str, str]: + """ + List registered workflows. + + Parameters + ---------- + exposed_only : bool, optional + If True, only return workflows that are in exposed_workflows + (warmup-only workflows are excluded from public listing). + If False, return all registered workflows. 
+ + Returns + ------- + dict + Dictionary mapping workflow names to descriptions + """ + if not exposed_only: + return { + name: workflow_class.description + for name, workflow_class in self._workflows.items() + } + + config = get_config() + exposed_workflows = config.workflow_exposure.exposed_workflows + + # Empty list means all workflows are exposed (including warmup) + if not exposed_workflows: + return { + name: workflow_class.description + for name, workflow_class in self._workflows.items() + } + + # Only return workflows in the exposed_workflows list + # (warmup-only workflows are excluded from public listing) return { name: workflow_class.description for name, workflow_class in self._workflows.items() + if name in exposed_workflows } def discover_and_register_from_directories( diff --git a/pyproject.toml b/pyproject.toml index 0a40fbadf..014d21190 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,11 @@ serve = [ "python-multipart>=0.0.6", "requests>=2.25.0", "urllib3>=1.26.0", + # Object storage related + "cryptography>=41.0.0", + "multi-storage-client>=0.44.0", + "azure-storage-blob>=12.19.0", + "azure-identity>=1.15.0", ] # PX Models ace2 = [ diff --git a/serve/server/README_object_storage.md b/serve/server/README_object_storage.md index b455137d1..4947dd7e6 100644 --- a/serve/server/README_object_storage.md +++ b/serve/server/README_object_storage.md @@ -1,18 +1,28 @@ # Object Storage Support -This document describes how to configure and use object storage (AWS S3 with CloudFront) for storing -workflow results in the Earth2Studio Inference Server. +This document describes how to configure and use object storage (AWS S3 with CloudFront or +Azure Blob Storage) for storing workflow results in the Earth2Studio Inference Server. ## Overview By default, workflow results are stored locally on the inference server. When object storage is -enabled, results are automatically uploaded to S3 and served via CloudFront signed URLs. This -provides: +enabled, results are automatically uploaded to your chosen cloud storage provider (AWS S3 or +Azure Blob Storage). For S3, CloudFront signed URLs can be generated; for Azure, the server +uploads with managed identity and does **not** issue read URLs—clients obtain tokens to read blobs. - **Scalability**: Offload storage from the inference server -- **Performance**: CloudFront CDN for fast global access -- **Security**: Time-limited signed URLs for secure access -- **Seamless Client Experience**: The Python client SDK automatically handles both storage types +- **Performance**: Fast global access via CDN (CloudFront for S3) or direct Azure Blob Storage access +- **Security**: Time-limited CloudFront signed URLs (S3); Azure reads use your own token model +- **Seamless Client Experience**: The Python client SDK automatically handles S3; Azure may require + client-side auth for reads + +## Storage Provider Options + +The inference server supports two storage providers: + +- **AWS S3**: With optional CloudFront CDN for enhanced performance +- **Azure Blob Storage**: Uploads via managed identity; clients obtain Azure AD (or other) tokens to + read data ## AWS Prerequisites @@ -20,70 +30,56 @@ Before enabling object storage, you need to set up the following AWS resources: ### 1. S3 Bucket -Create an S3 bucket to store workflow results: +Create an S3 bucket to store workflow results. +**Must for performance**: Enable S3 Transfer Acceleration for faster uploads: -```bash -aws s3 mb s3://your-bucket-name --region us-east-1 -``` +### 2. 
CloudFront Distribution -**Must for performance**: Enable S3 Transfer Acceleration for faster uploads: +Create a CloudFront distribution to serve content from your S3 bucket. -```bash -aws s3api put-bucket-accelerate-configuration \ - --bucket your-bucket-name \ - --accelerate-configuration Status=Enabled -``` +### 3. CloudFront Key Pair for Signed URLs -### 2. CloudFront Distribution +To generate signed URLs, you need a CloudFront key pair. -Create a CloudFront distribution to serve content from your S3 bucket: +### 4. IAM Credentials -1. Go to AWS CloudFront Console → Create Distribution -2. Set Origin Domain to your S3 bucket (`your-bucket-name.s3.amazonaws.com`) -3. Set Origin Access to "Origin access control settings (recommended)" -4. Create a new Origin Access Control (OAC) -5. Update the S3 bucket policy to allow CloudFront access (AWS will provide the policy) +Create IAM credentials with permissions to upload to S3. -### 3. CloudFront Key Pair for Signed URLs +## Azure Prerequisites -To generate signed URLs, you need a CloudFront key pair: +Before enabling Azure Blob Storage, you need to set up the following Azure resources: -1. Go to AWS CloudFront Console → Key Management → Public Keys -2. Create a new public key by uploading a public key you generated: +### 1. Azure Storage Account -```bash -# Generate a private key -openssl genrsa -out cloudfront-private-key.pem 2048 +Create an Azure Storage Account. -# Extract the public key -openssl rsa -in cloudfront-private-key.pem -pubout -out cloudfront-public-key.pem -``` +### 2. Storage Container -Then: +Create a blob container in your storage account. -1. Upload `cloudfront-public-key.pem` to CloudFront -2. Create a Key Group containing your public key -3. Associate the Key Group with your CloudFront distribution's behavior settings (Restrict Viewer -Access → Yes, Trusted Key Groups) -4. Note the **Key Pair ID** (e.g., `KUCQGLNFR6UH1`) -5. Keep `cloudfront-private-key.pem` secure - this is used by the server to sign URLs +### 3. Managed identity (or equivalent) for uploads -### 4. IAM Credentials +The inference server writes to the container using **DefaultAzureCredential** (e.g. user-assigned +or system-assigned managed identity). Grant that identity **Storage Blob Data Contributor** (or +equivalent) on the storage account or container. -Create IAM credentials with permissions to upload to S3. See [Creating IAM -users](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_users_create.html) for -detailed instructions. The credentials need `s3:PutObject`, `s3:GetObject`, -`s3:DeleteObject`, and `s3:ListBucket` permissions on your bucket. +### 4. Client read access + +Downstream clients that read blobs should use **Azure AD** (or your chosen mechanism) to obtain +tokens; the server does not generate SAS or other signed read URLs for Azure. ## Server Configuration ### Environment Variables -Configure object storage using environment variables: +Configure object storage using environment variables. 
Choose either AWS S3 or Azure Blob Storage: + +#### AWS S3 Configuration ```bash # Enable object storage export OBJECT_STORAGE_ENABLED=true +export OBJECT_STORAGE_TYPE=s3 # S3 Configuration export OBJECT_STORAGE_BUCKET=your-bucket-name @@ -107,17 +103,57 @@ export OBJECT_STORAGE_USE_RUST_CLIENT=true # High-performance Rust client # CloudFront Signed URL Configuration export CLOUDFRONT_DOMAIN=https://d30anq61ot046p.cloudfront.net export CLOUDFRONT_KEY_PAIR_ID=KUCQGLNFR6UH1 -export CLOUDFRONT_PRIVATE_KEY_PATH=/path/to/cloudfront-private-key.pem -export OBJECT_STORAGE_SIGNED_URL_EXPIRES_IN=3600 # URL expiration in seconds +# PEM private key *content* (not a file path); use quoting / multiline env as +# supported by your shell +export CLOUDFRONT_PRIVATE_KEY="-----BEGIN RSA PRIVATE KEY-----..." +export SIGNED_URL_EXPIRES_IN=86400 # URL expiration in seconds (S3/CloudFront only) +``` + +#### Azure Blob Storage Configuration + +```bash +# Enable object storage +export OBJECT_STORAGE_ENABLED=true +export OBJECT_STORAGE_TYPE=azure + +# Azure Configuration +# Container name (used as bucket equivalent) +export OBJECT_STORAGE_BUCKET=your-container-name +export OBJECT_STORAGE_PREFIX=outputs # Optional: prefix for uploaded files + +# Azure: storage account (managed identity / DefaultAzureCredential for uploads) +export AZURE_STORAGE_ACCOUNT_NAME=mystorageaccount +# Blob service endpoint (optional). Either works; both set `endpoint_url` in config +# (see code order). +# export OBJECT_STORAGE_ENDPOINT_URL=https://mystorageaccount.blob.core.windows.net +# export AZURE_ENDPOINT_URL=https://mystorageaccount.blob.core.windows.net + +# Optional: Container name (defaults to OBJECT_STORAGE_BUCKET if not set) +export AZURE_CONTAINER_NAME=workflow-results + +# Optional: Planetary Computer Pro / GeoCatalog STAC ingestion after upload +# (queues `geocatalog_ingestion`) +# export AZURE_GEOCATALOG_URL=https://geocatalog.spatio.azure.com/ + +# Transfer Configuration +export OBJECT_STORAGE_MAX_CONCURRENCY=16 # Concurrent upload threads +export OBJECT_STORAGE_MULTIPART_CHUNKSIZE=8388608 # 8MB chunk size +export OBJECT_STORAGE_USE_RUST_CLIENT=true # High-performance Rust client + +# Blob reads: clients obtain Azure AD (or other) tokens independently; +# the server does not issue SAS/signed read URLs. ``` ### YAML Configuration Alternatively, configure via `config.yaml`: +#### AWS S3 YAML Configuration + ```yaml object_storage: enabled: true + storage_type: s3 bucket: your-bucket-name region: us-east-1 prefix: outputs @@ -127,8 +163,25 @@ object_storage: use_rust_client: true cloudfront_domain: https://d30anq61ot046p.cloudfront.net cloudfront_key_pair_id: KUCQGLNFR6UH1 - cloudfront_private_key_path: /path/to/cloudfront-private-key.pem - signed_url_expires_in: 3600 + cloudfront_private_key: null # PEM private key content + signed_url_expires_in: 86400 +``` + +#### Azure Blob Storage YAML Configuration + +```yaml +object_storage: + enabled: true + storage_type: azure + bucket: your-container-name # Container name (used as bucket equivalent) + prefix: outputs + max_concurrency: 16 + multipart_chunksize: 8388608 + use_rust_client: true + azure_account_name: mystorageaccount # Required unless endpoint_url is set + endpoint_url: null # Optional: e.g. 
https://mystorageaccount.blob.core.windows.net + azure_container_name: workflow-results # Optional, defaults to bucket + azure_geocatalog_url: null # When set, run GeoCatalog ingestion after upload ``` ### Configuration Parameters Reference @@ -137,27 +190,33 @@ object_storage: | Parameter | Environment Variable | Default | Description | |-----------|---------------------|---------|-------------| | `enabled` | `OBJECT_STORAGE_ENABLED` | `false` | Enable object storage | -| `bucket` | `OBJECT_STORAGE_BUCKET` | `null` | S3 bucket name | -| `region` | `OBJECT_STORAGE_REGION` | `us-east-1` | AWS region | +| `storage_type` | `OBJECT_STORAGE_TYPE` | `s3` | Storage provider: `s3` or `azure` | +| `bucket` | `OBJECT_STORAGE_BUCKET` | `null` | S3 bucket name or Azure container name | +| `region` | `OBJECT_STORAGE_REGION` | `us-east-1` | AWS region (S3 only) | | `prefix` | `OBJECT_STORAGE_PREFIX` | `outputs` | Remote prefix for files | -| `access_key_id` | `OBJECT_STORAGE_ACCESS_KEY_ID` | `null` | AWS access key ID | -| `secret_access_key` | `OBJECT_STORAGE_SECRET_ACCESS_KEY` | `null` | AWS secret access key | -| `session_token` | `OBJECT_STORAGE_SESSION_TOKEN` | `null` | AWS session token | -| `endpoint_url` | `OBJECT_STORAGE_ENDPOINT_URL` | `null` | Custom endpoint (S3-compatible) | -| `use_transfer_acceleration` | `OBJECT_STORAGE_TRANSFER_ACCELERATION` | `true` | Enable S3 Transfer Acceleration | +| `access_key_id` | `OBJECT_STORAGE_ACCESS_KEY_ID` | `null` | AWS access key ID (S3 only) | +| `secret_access_key` | `OBJECT_STORAGE_SECRET_ACCESS_KEY` | `null` | AWS secret access key (S3 only) | +| `session_token` | `OBJECT_STORAGE_SESSION_TOKEN` | `null` | AWS session token (S3 only) | +| `endpoint_url` | `OBJECT_STORAGE_ENDPOINT_URL` or `AZURE_ENDPOINT_URL` | `null` | Custom endpoint (S3-compatible; Azure blob URL). Both env vars map to `endpoint_url`; if both are set, `AZURE_ENDPOINT_URL` wins. 
| +| `use_transfer_acceleration` | `OBJECT_STORAGE_TRANSFER_ACCELERATION` | `true` | Enable S3 Transfer Acceleration (S3 only) | | `max_concurrency` | `OBJECT_STORAGE_MAX_CONCURRENCY` | `16` | Max concurrent transfers | | `multipart_chunksize` | `OBJECT_STORAGE_MULTIPART_CHUNKSIZE` | `8388608` | Multipart chunk size (bytes) | | `use_rust_client` | `OBJECT_STORAGE_USE_RUST_CLIENT` | `true` | Use high-performance Rust client | -| `cloudfront_domain` | `CLOUDFRONT_DOMAIN` | `null` | CloudFront distribution domain | -| `cloudfront_key_pair_id` | `CLOUDFRONT_KEY_PAIR_ID` | `null` | CloudFront key pair ID | -| `cloudfront_private_key_path` | `CLOUDFRONT_PRIVATE_KEY_PATH` | `null` | Path to private key PEM file | -| `signed_url_expires_in` | `OBJECT_STORAGE_SIGNED_URL_EXPIRES_IN` | `3600` | Signed URL expiration (seconds) | +| `cloudfront_domain` | `CLOUDFRONT_DOMAIN` | `null` | CloudFront distribution domain (S3 only) | +| `cloudfront_key_pair_id` | `CLOUDFRONT_KEY_PAIR_ID` | `null` | CloudFront key pair ID (S3 only) | +| `cloudfront_private_key` | `CLOUDFRONT_PRIVATE_KEY` | `null` | PEM private key *content* for CloudFront signing (S3 only) | +| `azure_account_name` | `AZURE_STORAGE_ACCOUNT_NAME` | `null` | Azure storage account name (Azure only; uploads use DefaultAzureCredential) | +| `azure_container_name` | `AZURE_CONTAINER_NAME` | `null` | Azure container name (Azure only, defaults to bucket) | +| `azure_geocatalog_url` | `AZURE_GEOCATALOG_URL` | `null` | When set, enqueue GeoCatalog / Planetary Computer ingestion after Azure upload | +| `signed_url_expires_in` | `SIGNED_URL_EXPIRES_IN` | `86400` | CloudFront signed URL TTL (seconds; S3 only) | ## Result Metadata When object storage is enabled, the workflow result metadata includes additional fields: +### AWS S3 Example + ```json { "request_id": "exec_1769560728_10ed9d3c", @@ -174,12 +233,31 @@ When object storage is enabled, the workflow result metadata includes additional } ``` +### Azure Blob Storage Example + +```json +{ + "request_id": "exec_1769560728_10ed9d3c", + "status": "completed", + "storage_type": "azure", + "remote_path": "azure://workflow-results/outputs/exec_1769560728_10ed9d3c", + "output_files": [ + {"path": "exec_1769560728_10ed9d3c/results.zarr/.zarray", "size": 123}, + {"path": "exec_1769560728_10ed9d3c/results.zarr/t2m/0.0.0", "size": 4567890} + ] +} +``` + +There is no `signed_url` for Azure: use `remote_path` (and optional `blob_url` for GeoCatalog-related +flows) and authorize reads with your own Azure token flow. + ### Storage Type Values | Value | Description | |-------|-------------| | `server` | Results stored locally on the inference server | | `s3` | Results stored in S3, accessible via CloudFront signed URL | +| `azure` | Results stored in Azure Blob Storage; reads use client-issued tokens (no server SAS) | ## Client Usage @@ -209,7 +287,7 @@ request_result = client.run_inference_sync( InferenceRequest(parameters={"start_time": [datetime(2025, 8, 21, 6)]}) ) -# Automatically downloads from S3 if storage_type is "s3" +# Automatically downloads from S3 or Azure if storage_type is "s3" or "azure" for file in request_result.output_files[:5]: content = client.download_result(request_result, file.path) print(f"Downloaded {file.path}: {len(content.getvalue())} bytes") @@ -217,7 +295,9 @@ for file in request_result.output_files[:5]: ### Using Signed URLs Directly -For advanced use cases, you can use the signed URL directly: +For advanced use cases, you can use the signed URL directly (S3/CloudFront only). 
+ +#### Using CloudFront Signed URLs ```python import requests @@ -237,6 +317,13 @@ file_url = f"{base_url}/{file_path}?{query_params}" response = requests.get(file_url) ``` +#### Azure blob reads + +Use the blob URL from result metadata (or construct it from account, container, and path) and +request an OAuth token from Azure AD with scope for Storage (e.g. `https://storage.azure.com/.default`), +then `GET` the blob with `Authorization: Bearer `. The inference server does not return a +pre-signed Azure URL. + ### Using with Xarray and Zarr The client provides an fsspec mapper for opening Zarr stores directly: @@ -252,63 +339,3 @@ mapper = create_cloudfront_mapper(request_result.signed_url, zarr_path="results. ds = xr.open_zarr(mapper, consolidated=True) print(ds) ``` - -## Signed URL Format - -CloudFront signed URLs contain three query parameters: - -```text -https://d30anq61ot046p.cloudfront.net/outputs/exec_123/*?Policy=eyJTdGF0ZW1lbnQiOl...\ -&Signature=ABC123...&Key-Pair-Id=KUCQGLNFR6UH1 -``` - -| Parameter | Description | -|-----------|-------------| -| `Policy` | Base64-encoded JSON policy specifying resource and expiration | -| `Signature` | RSA signature of the policy using the private key | -| `Key-Pair-Id` | CloudFront key pair ID used to verify the signature | - -The wildcard (`*`) in the URL path allows access to all files under that prefix. - -## Testing - -Run object storage integration tests: - -```bash -# Set required environment variables -export TEST_S3_BUCKET=my-test-bucket -export AWS_ACCESS_KEY_ID=AKIA... -export AWS_SECRET_ACCESS_KEY=... - -# Run S3 upload tests -pytest test/integration/test_object_storage.py -v - -# Run CloudFront signed URL tests (requires additional config) -export TEST_CLOUDFRONT_DOMAIN=https://d30anq61ot046p.cloudfront.net -export TEST_CLOUDFRONT_KEY_PAIR_ID=KUCQGLNFR6UH1 -export TEST_CLOUDFRONT_PRIVATE_KEY_PATH=/path/to/private.pem -pytest test/integration/test_object_storage.py::TestCloudFrontSignedUrl -v -``` - -## Troubleshooting - -### Common Issues - -1. **403 Forbidden from CloudFront** - - Verify the S3 bucket policy allows CloudFront OAC access - - Check that the CloudFront distribution is configured with the correct origin - - Ensure the key pair is in a Key Group associated with the distribution - -2. **Signed URL expired** - - Increase `signed_url_expires_in` configuration - - Request fresh results from the API (URLs are regenerated) - -3. **Upload failures** - - Verify IAM credentials have `s3:PutObject` permission - - Check bucket name and region are correct - - If using Transfer Acceleration, ensure it's enabled on the bucket - -4. **Slow uploads** - - Enable `use_rust_client` for better performance - - Increase `max_concurrency` for more parallel uploads - - Enable `use_transfer_acceleration` if uploading from distant regions diff --git a/serve/server/azure_planetary_computer/__init__.py b/serve/server/azure_planetary_computer/__init__.py new file mode 100644 index 000000000..0df5e8698 --- /dev/null +++ b/serve/server/azure_planetary_computer/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/serve/server/azure_planetary_computer/geocatalog_ingestion.py b/serve/server/azure_planetary_computer/geocatalog_ingestion.py new file mode 100644 index 000000000..517d3fe1c --- /dev/null +++ b/serve/server/azure_planetary_computer/geocatalog_ingestion.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +from typing import Any + +from azure_planetary_computer.pc_client import ( + PLANETARY_COMPUTER_CLIENT_WORKFLOWS, + PlanetaryComputerClient, +) +from earth2studio.serve.server.cpu_worker import ( + config, + fail_workflow, + redis_client, +) +from earth2studio.serve.server.utils import ( + get_inference_request_metadata_key, + queue_next_stage, +) + +logger = logging.getLogger(__name__) + + +def process_geocatalog_ingestion( + workflow_name: str, + execution_id: str, +) -> dict[str, Any] | None: + """Trigger ingestion of uploaded inference results into Azure Planetary Computer / GeoCatalog when AZURE_GEOCATALOG_URL is configured. + + Intended to run from the geocatalog_ingestion_queue after process_object_storage_upload. + Reads storage info and parameters from Redis and calls the Planetary Computer client to + create a STAC feature for the uploaded netcdf blob. 
+ + Parameters + ---------- + workflow_name : str + Name of the workflow + execution_id : str + Execution ID of the workflow + + Returns + ------- + dict[str, Any] | None + Dict containing result info, None on critical failure + """ + request_id = f"{workflow_name}:{execution_id}" + logger.info(f"Processing geocatalog ingestion for {request_id}") + + try: + geocatalog_url = config.object_storage.azure_geocatalog_url + if not geocatalog_url: + logger.warning( + f"AZURE_GEOCATALOG_URL not set, skipping geocatalog ingestion for {request_id}" + ) + job_id = queue_next_stage( + redis_client=redis_client, + current_stage="geocatalog_ingestion", + workflow_name=workflow_name, + execution_id=execution_id, + output_path_str="", + ) + if not job_id: + return fail_workflow( + workflow_name, + execution_id, + f"Failed to queue finalize_metadata for {request_id}", + ) + return { + "success": True, + "skipped": True, + "reason": "AZURE_GEOCATALOG_URL not set", + } + + storage_info_key = f"inference_request:{request_id}:storage_info" + metadata_key = get_inference_request_metadata_key(request_id) + storage_info_json = redis_client.get(storage_info_key) + pending_metadata_json = redis_client.get(metadata_key) + + if not storage_info_json or not pending_metadata_json: + logger.warning( + f"Storage info or pending metadata missing for {request_id}, skipping geocatalog ingestion" + ) + job_id = queue_next_stage( + redis_client=redis_client, + current_stage="geocatalog_ingestion", + workflow_name=workflow_name, + execution_id=execution_id, + output_path_str="", + ) + if not job_id: + return fail_workflow( + workflow_name, + execution_id, + f"Failed to queue finalize_metadata for {request_id}", + ) + return { + "success": True, + "skipped": True, + "reason": "missing storage/metadata", + } + + storage_info = json.loads(storage_info_json) + metadata_dict = json.loads(pending_metadata_json) + blob_url = storage_info.get("blob_url") + parameters = metadata_dict.get("parameters") or {} + + if not blob_url: + logger.warning( + f"No blob_url in storage info for {request_id} (e.g. 
not Azure or no .nc/.zarr " + f"dataset), skipping geocatalog ingestion" + ) + job_id = queue_next_stage( + redis_client=redis_client, + current_stage="geocatalog_ingestion", + workflow_name=workflow_name, + execution_id=execution_id, + output_path_str="", + ) + if not job_id: + return fail_workflow( + workflow_name, + execution_id, + f"Failed to queue finalize_metadata for {request_id}", + ) + return {"success": True, "skipped": True, "reason": "no blob_url"} + + logger.info(f"Blob URL: {blob_url}") + if workflow_name not in PLANETARY_COMPUTER_CLIENT_WORKFLOWS: + logger.info( + f"Workflow {workflow_name} not supported by Planetary Computer client, skipping ingestion for {request_id}" + ) + job_id = queue_next_stage( + redis_client=redis_client, + current_stage="geocatalog_ingestion", + workflow_name=workflow_name, + execution_id=execution_id, + output_path_str="", + ) + if not job_id: + return fail_workflow( + workflow_name, + execution_id, + f"Failed to queue finalize_metadata for {request_id}", + ) + return { + "success": True, + "skipped": True, + "reason": "workflow not supported", + } + + try: + pc_client = PlanetaryComputerClient(workflow_name=workflow_name) + collection_id = parameters.get("collection_id") + pc_client.create_feature( + geocatalog_url=geocatalog_url, + collection_id=collection_id, + parameters=parameters, + blob_url=blob_url, + ) + logger.info(f"GeoCatalog ingestion completed for {request_id}") + except Exception as e: + # Log but do not fail the pipeline; finalize_metadata should still run + logger.exception( + f"GeoCatalog ingestion failed for {request_id}: {e}. Queuing finalize_metadata anyway." + ) + + job_id = queue_next_stage( + redis_client=redis_client, + current_stage="geocatalog_ingestion", + workflow_name=workflow_name, + execution_id=execution_id, + output_path_str="", + ) + if not job_id: + return fail_workflow( + workflow_name, + execution_id, + f"Failed to queue next pipeline stage for {request_id}", + ) + logger.info( + f"Queued finalize_metadata for {workflow_name}:{execution_id} with RQ job ID: {job_id}" + ) + return {"success": True} + + except Exception as e: + logger.exception(f"Failed in geocatalog ingestion for {request_id}") + return fail_workflow( + workflow_name, + execution_id, + f"Geocatalog ingestion failed for {request_id}: {str(e)}", + ) diff --git a/serve/server/azure_planetary_computer/pc_client.py b/serve/server/azure_planetary_computer/pc_client.py new file mode 100644 index 000000000..2a3b6e4d3 --- /dev/null +++ b/serve/server/azure_planetary_computer/pc_client.py @@ -0,0 +1,387 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
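+# Minimal usage sketch (values are illustrative, not real endpoints): the geocatalog
+# ingestion job builds a client for a supported workflow and ingests one STAC item:
+#
+#     client = PlanetaryComputerClient(workflow_name="foundry_fcn3_workflow")
+#     collection_id, item_id = client.create_feature(
+#         geocatalog_url="https://example.geocatalog.spatio.azure.com",
+#         collection_id=None,  # None creates a new collection from the template
+#         parameters={"start_time": "2025-01-01T00:00:00Z"},
+#         blob_url="https://myaccount.blob.core.windows.net/container/results.nc",
+#     )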
+ +"""HTTP client for Azure Planetary Computer Pro GeoCatalog STAC ingestion.""" + +from __future__ import annotations + +import json +import logging +import os +import time +from collections.abc import Mapping +from datetime import datetime, timedelta +from time import perf_counter +from typing import Any, Final +from uuid import uuid4 + +import requests + +logger = logging.getLogger("azure_planetary_computer") +logger.setLevel(logging.INFO) + +# Workflows with templates and GeoCatalog behavior in this package (single source of truth). +PLANETARY_COMPUTER_CLIENT_WORKFLOWS: Final[frozenset[str]] = frozenset( + { + "foundry_fcn3_workflow", + "foundry_fcn3_stormscope_goes_workflow", + } +) + + +class PlanetaryComputerClient: + """Client for STAC collection/feature creation against Planetary Computer Pro. + + Templates and render options are loaded from JSON files alongside this module. + + Parameters + ---------- + workflow_name : {'foundry_fcn3_workflow', 'foundry_fcn3_stormscope_goes_workflow'} + Earth2Studio workflow identifier selecting templates and step sizing. + """ + + APPLICATION_URL = "https://geocatalog.spatio.azure.com/" + REQUESTS_TIMEOUT = 30 + CREATION_TIMEOUT = 300 + + def __init__(self, workflow_name: str) -> None: + if workflow_name not in PLANETARY_COMPUTER_CLIENT_WORKFLOWS: + raise ValueError( + f"Unsupported workflow_name for PlanetaryComputerClient: {workflow_name!r}. " + f"Supported: {sorted(PLANETARY_COMPUTER_CLIENT_WORKFLOWS)}" + ) + self.workflow_name = workflow_name + self.headers: dict[str, str] | None = None + + def update_headers(self) -> None: + """Refresh the Authorization header using a new Azure credential token.""" + try: + from azure.identity import DefaultAzureCredential + except ImportError as e: + raise ImportError( + "PlanetaryComputerClient requires 'azure-identity'. " + "Install with the serve extra or pip install azure-identity." + ) from e + credential = DefaultAzureCredential() + token = credential.get_token(self.APPLICATION_URL) + self.headers = {"Authorization": f"Bearer {token.token}"} + + def _require_headers(self) -> dict[str, str]: + if self.headers is None: + raise RuntimeError( + "Authorization headers are not set; call update_headers() first." 
+ ) + return self.headers + + def _get(self, url: str) -> requests.Response: + headers = self._require_headers() + return requests.get( + url, + headers=headers, + params={"api-version": "2025-04-30-preview"}, + timeout=self.REQUESTS_TIMEOUT, + ) + + def _post(self, url: str, body: dict | None = None) -> requests.Response: + headers = self._require_headers() + return requests.post( + url, + json=body, + headers=headers, + params={"api-version": "2025-04-30-preview"}, + timeout=self.REQUESTS_TIMEOUT, + ) + + def _put(self, url: str, body: dict | None = None) -> requests.Response: + headers = self._require_headers() + return requests.put( + url, + json=body, + headers=headers, + params={"api-version": "2025-04-30-preview"}, + timeout=self.REQUESTS_TIMEOUT, + ) + + def _create_element(self, url: str, stac_config: dict) -> None: + """Create a STAC collection or feature and wait for it to finish.""" + response = self._post( + url, + body=stac_config, + ) + location = response.headers["location"] + + logger.info("Creating '%s'...", stac_config["id"]) + start = perf_counter() + while True: + if (perf_counter() - start) > self.CREATION_TIMEOUT: + logger.error("Creation of '%s' timed out", stac_config["id"]) + return + + response = self._get(location) + status = response.json()["status"] + logger.info(status) + if status not in {"Pending", "Running"}: + break + time.sleep(5) + + if status == "Succeeded": + logger.info("Successfully created '%s'", stac_config["id"]) + else: + logger.error("Failed to create '%s': %s", stac_config["id"], response.text) + + def _get_collection_json(self, collection_id: str | None) -> dict: + """Load the STAC collection template and set the collection ID.""" + template_fns = { + "foundry_fcn3_workflow": "template-collection-fcn3.json", + "foundry_fcn3_stormscope_goes_workflow": "template-collection-fcn3-stormscope-goes.json", + } + template_fn = os.path.join( + os.path.dirname(__file__), template_fns[self.workflow_name] + ) + with open(template_fn) as f: + stac_config = json.load(f) + + if collection_id is None: + stac_config["id"] = stac_config["id"].format(uuid=uuid4()) + else: + stac_config["id"] = collection_id + return stac_config + + def _get_feature_json( + self, + start_time: datetime, + end_time: datetime, + blob_url: str, + ) -> dict: + """Load the STAC feature template and set the workflow parameters.""" + template_fns = { + "foundry_fcn3_workflow": "template-feature-fcn3.json", + "foundry_fcn3_stormscope_goes_workflow": "template-feature-fcn3-stormscope-goes.json", + } + template_fn = os.path.join( + os.path.dirname(__file__), template_fns[self.workflow_name] + ) + with open(template_fn) as f: + stac_config = json.load(f) + + iso_start = start_time.isoformat() + iso_end = end_time.isoformat() + + stac_config["id"] = stac_config["id"].format( + start_time=iso_start[:13], uuid=uuid4() + ) + stac_config["properties"]["datetime"] = iso_start + stac_config["properties"]["start_datetime"] = iso_start + stac_config["properties"]["end_datetime"] = iso_end + stac_config["assets"]["data"]["href"] = blob_url + stac_config["assets"]["data"]["description"] = stac_config["assets"]["data"][ + "description" + ].format(start_time=iso_start, end_time=iso_end) + return stac_config + + def _update_tile_settings(self, geocatalog_url: str, collection_id: str) -> None: + """Update 'minZoom' of the tile settings so user can zoom out.""" + tile_settings = { + "minZoom": 0, + "maxItemsPerTile": 35, + } + response = self._put( + 
f"{geocatalog_url}/stac/collections/{collection_id}/configurations/tile-settings", + body=tile_settings, + ) + status = response.status_code + if status not in {200, 201}: + logger.error( + "Could not update tile settings: Error %s - %s", + status, + response.text, + ) + + def _update_render_options( + self, + geocatalog_url: str, + collection_id: str, + ) -> None: + """Add example render options for a new collection.""" + if self.workflow_name == "foundry_fcn3_workflow": + render_params = [ + { + "id": "t2m", + "scale": [263, 313], + "cmap": "balance", + }, + { + "id": "t850", + "scale": [263, 313], + "cmap": "balance", + }, + { + "id": "u10m", + "scale": [-20, 20], + "cmap": "prgn", + }, + { + "id": "v10m", + "scale": [-20, 20], + "cmap": "prgn", + }, + { + "id": "z500", + "scale": [45000, 60000], + "cmap": "viridis", + }, + ] + elif self.workflow_name == "foundry_fcn3_stormscope_goes_workflow": + render_params = [ + { + "id": f"abi{aid:02}c", + "scale": [0, 1], + "cmap": "plasma", + } + for aid in [1, 2, 3, 7, 8, 9, 10, 13] + ] + else: + render_params = [] + + for params in render_params: + render_option = { + "id": f"auto-{params['id']}", + "name": params["id"], + "type": "raster-tile", + "options": ( + f"assets=data&subdataset_name={params['id']}" + "&sel=time=2100-01-01&sel=ensemble=0&sel_method=nearest" + f"&rescale={params['scale'][0]},{params['scale'][1]}" + f"&colormap_name={params['cmap']}" + ), + "minZoom": 0, + } + response = self._post( + f"{geocatalog_url}/stac/collections/{collection_id}/configurations/render-options", + body=render_option, + ) + status = response.status_code + if status not in {200, 201}: + logger.error( + "Could not update render options: Error %s - %s", + status, + response.text, + ) + + def _create_collection( + self, + geocatalog_url: str, + collection_id: str | None, + ) -> str: + """Create a new STAC collection.""" + stac_config = self._get_collection_json(collection_id) + self._create_element( + url=f"{geocatalog_url}/stac/collections", + stac_config=stac_config, + ) + self._update_tile_settings(geocatalog_url, stac_config["id"]) + self._update_render_options(geocatalog_url, stac_config["id"]) + return stac_config["id"] + + def _ensure_collection_exists( + self, + geocatalog_url: str, + collection_id: str, + ) -> str: + """Return collection ID, creating the collection if it does not exist.""" + response = self._get(f"{geocatalog_url}/stac/collections/{collection_id}") + status = response.status_code + if status == 200: + return collection_id + if status != 404: + raise RuntimeError( + f"Failed to retrieve collection: Error {status} - {response.text}" + ) + + return self._create_collection(geocatalog_url, collection_id) + + def _resolve_start_time( + self, parameters: Mapping[str, Any] | dict[str, Any] + ) -> datetime: + """Resolve start_time from workflow parameters (workflow-specific keys).""" + raw: datetime | str | None = None + if self.workflow_name == "foundry_fcn3_workflow": + raw = parameters.get("start_time") + elif self.workflow_name == "foundry_fcn3_stormscope_goes_workflow": + raw = parameters.get("start_time_stormscope") + else: + raise ValueError(f"Unsupported workflow name: {self.workflow_name}") + if raw is None: + raise ValueError( + f"Missing start time in parameters for workflow {self.workflow_name}. " + "Expected 'start_time' or (for stormscope) 'start_time_stormscope'." 
+ ) + if isinstance(raw, str): + normalized = raw.replace("Z", "+00:00") + return datetime.fromisoformat(normalized) + if isinstance(raw, datetime): + return raw + if hasattr(raw, "isoformat"): + return datetime.fromisoformat(raw.isoformat()) + raise TypeError(f"start_time must be str or datetime, got {type(raw)}") + + def create_feature( + self, + geocatalog_url: str, + collection_id: str | None, + parameters: Mapping[str, Any] | dict[str, Any], + blob_url: str, + ) -> tuple[str, str]: + """Ingest a new STAC feature into the collection. + + Parameters + ---------- + geocatalog_url : str + URL to the Planetary Computer Pro catalog + collection_id : str | None + Existing collection ID, or None to create a new collection + parameters : Mapping[str, Any] | dict[str, Any] + Workflow parameters including start time (and optional collection_id override + is passed separately) + blob_url : str + Blob location on Azure Blob Storage + + Returns + ------- + tuple[str, str] + Collection ID and feature ID + """ + self.update_headers() + + if collection_id is None: + collection_id = self._create_collection(geocatalog_url, None) + else: + self._ensure_collection_exists(geocatalog_url, collection_id) + + start_time = self._resolve_start_time(parameters) + step_sizes = { + "foundry_fcn3_workflow": 6, + "foundry_fcn3_stormscope_goes_workflow": 1, + } + end_time = start_time + timedelta(hours=step_sizes[self.workflow_name]) + stac_config = self._get_feature_json(start_time, end_time, blob_url) + + self._create_element( + url=f"{geocatalog_url}/stac/collections/{collection_id}/items", + stac_config=stac_config, + ) + + return collection_id, stac_config["id"] diff --git a/serve/server/azure_planetary_computer/template-collection-fcn3-stormscope-goes.json b/serve/server/azure_planetary_computer/template-collection-fcn3-stormscope-goes.json new file mode 100644 index 000000000..62d5205b3 --- /dev/null +++ b/serve/server/azure_planetary_computer/template-collection-fcn3-stormscope-goes.json @@ -0,0 +1,42 @@ +{ + "type": "Collection", + "stac_version": "1.0.0", + "id": "earth-2-fcn3-stormscope-goes-{uuid}", + "title": "Earth-2 StormScope-GOES conditioned on FourCastNet3", + "description": "Forecasts generated with the FourCastNet3-conditioned StormScope-GOES workflow on Microsoft Foundry.", + "links": [], + "stac_extensions": [], + "item_assets": { + "data": { + "type": "application/x-netcdf", + "roles": [] + } + }, + "extent": { + "spatial": { + "bbox": [ + [ + -135, + 20, + -60, + 53 + ] + ] + }, + "temporal": { + "interval": [ + [ + null, + null + ] + ] + } + }, + "license": "other", + "keywords": [ + "CONUS", + "Forecast", + "Earth-2" + ], + "providers": [] +} diff --git a/serve/server/azure_planetary_computer/template-collection-fcn3.json b/serve/server/azure_planetary_computer/template-collection-fcn3.json new file mode 100644 index 000000000..21fc63b56 --- /dev/null +++ b/serve/server/azure_planetary_computer/template-collection-fcn3.json @@ -0,0 +1,42 @@ +{ + "type": "Collection", + "stac_version": "1.0.0", + "id": "earth-2-fcn3-{uuid}", + "title": "Earth-2 FourCastNet3", + "description": "Forecasts generated with the FourCastNet3 workflow on Microsoft Foundry.", + "links": [], + "stac_extensions": [], + "item_assets": { + "data": { + "type": "application/x-netcdf", + "roles": [] + } + }, + "extent": { + "spatial": { + "bbox": [ + [ + -180, + -90, + 180, + 90 + ] + ] + }, + "temporal": { + "interval": [ + [ + null, + null + ] + ] + } + }, + "license": "other", + "keywords": [ + "Global", + "Forecast", + 
"Earth-2" + ], + "providers": [] +} diff --git a/serve/server/azure_planetary_computer/template-feature-fcn3-stormscope-goes.json b/serve/server/azure_planetary_computer/template-feature-fcn3-stormscope-goes.json new file mode 100644 index 000000000..3a8d389f9 --- /dev/null +++ b/serve/server/azure_planetary_computer/template-feature-fcn3-stormscope-goes.json @@ -0,0 +1,54 @@ +{ + "type": "Feature", + "stac_version": "1.1.0", + "id": "fcn3-stormscope-goes-{start_time}-{uuid}", + "stac_extensions": [], + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [ + -60, + 20 + ], + [ + -60, + 53 + ], + [ + -135, + 53 + ], + [ + -135, + 20 + ], + [ + -60, + 20 + ] + ] + ] + }, + "bbox": [ + -135, + 20, + -60, + 53 + ], + "properties": { + "datetime": "{start_time}", + "start_datetime": "{start_time}", + "end_datetime": "{end_time}" + }, + "links": [], + "assets": { + "data": { + "href": "{blob_url}", + "type": "application/x-netcdf", + "title": "StormScope-GOES forecast conditioned on FourCastNet3", + "description": "FourCastNet3-conditioned StormScope-GOES forecast from {start_time} to {end_time}.", + "roles": [] + } + } +} diff --git a/serve/server/azure_planetary_computer/template-feature-fcn3.json b/serve/server/azure_planetary_computer/template-feature-fcn3.json new file mode 100644 index 000000000..38fb9a801 --- /dev/null +++ b/serve/server/azure_planetary_computer/template-feature-fcn3.json @@ -0,0 +1,54 @@ +{ + "type": "Feature", + "stac_version": "1.1.0", + "id": "fcn3-{start_time}-{uuid}", + "stac_extensions": [], + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [ + 180, + -90 + ], + [ + 180, + 90 + ], + [ + -180, + 90 + ], + [ + -180, + -90 + ], + [ + 180, + -90 + ] + ] + ] + }, + "bbox": [ + -180, + -90, + 180, + 90 + ], + "properties": { + "datetime": "{start_time}", + "start_datetime": "{start_time}", + "end_datetime": "{end_time}" + }, + "links": [], + "assets": { + "data": { + "href": "{blob_url}", + "type": "application/x-netcdf", + "title": "FourCastNet3 forecast", + "description": "FourCastNet3 forecast from {start_time} to {end_time}.", + "roles": [] + } + } +} diff --git a/serve/server/conf/config.yaml b/serve/server/conf/config.yaml index f654305da..114068d51 100644 --- a/serve/server/conf/config.yaml +++ b/serve/server/conf/config.yaml @@ -36,11 +36,13 @@ worker: num_workers: 1 # The number of RQ inference workers to create by default zip_num_workers: 1 # The number of RQ workers for result_zip queue objstore_num_workers: 1 # The number of RQ workers for object_storage queue + geocatalog_num_workers: 1 # The number of RQ workers for geocatalog_ingestion queue (used when AZURE_GEOCATALOG_URL is set) finalize_num_workers: 1 # The number of RQ workers for finalize_metadata queue paths: default_output_dir: /outputs results_zip_dir: /workspace/earth2studio-project/examples/outputs + output_format: zarr # Output format: "zarr" or "netcdf4" result_zip_enabled: false logging: @@ -86,5 +88,18 @@ object_storage: cloudfront_domain: null cloudfront_key_pair_id: null cloudfront_private_key: null # PEM private key content - # Signed URL settings + # Signed URL settings (S3 + CloudFront only; not used for Azure blob) signed_url_expires_in: 86400 # 24 hours + # Azure Blob Storage (uploads use managed identity; no server-side SAS/signed URLs) + azure_account_name: null + azure_container_name: null + azure_geocatalog_url: null # When set, triggers Planetary Computer ingestion after Azure upload + +workflow_exposure: + # List of workflow names to expose via API endpoints + 
# Empty list means all workflows are exposed + exposed_workflows: [] + # Workflows accessible for warmup even if not in exposed_workflows + # These workflows can be called via API for warmup purposes + warmup_workflows: + - example_user_workflow diff --git a/serve/server/example_workflows/foundry_fcn3.py b/serve/server/example_workflows/foundry_fcn3.py new file mode 100644 index 000000000..86350023c --- /dev/null +++ b/serve/server/example_workflows/foundry_fcn3.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections.abc import Sequence +from datetime import datetime + +import numpy as np +import torch +import zarr +from cftime import date2num + +from earth2studio.data import PlanetaryComputerECMWFOpenDataIFS, fetch_data +from earth2studio.io import IOBackend, NetCDF4Backend, ZarrBackend +from earth2studio.models.px import FCN3 +from earth2studio.serve.server import ( + Earth2Workflow, + WorkflowProgress, + workflow_registry, +) +from earth2studio.utils.coords import CoordSystem, map_coords, split_coords +from earth2studio.utils.time import timearray_to_datetime, to_time_array + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("foundry_fcn3_workflow") + + +@workflow_registry.register +class FoundryFCN3Workflow(Earth2Workflow): + """FCN3 ensemble inference workflow for Foundry using ECMWF IFS initial conditions.""" + + name = "foundry_fcn3_workflow" + description = "FCN3 ensemble workflow for Foundry" + + def __init__( + self, + device: str = "cuda", + init_seed: int | None = None, + ): + super().__init__() + + self.device = torch.device(device) + + self.fcn3 = self.load_fcn3() + self.rng = np.random.default_rng(init_seed) + + self.data = PlanetaryComputerECMWFOpenDataIFS(verbose=False, cache=False) + + def load_fcn3(self) -> FCN3: + """Load the default FCN3 package, move it to the workflow device, and set eval mode.""" + logger.info("Loading FCN3") + package = FCN3.load_default_package() + fcn3 = FCN3.load_model(package) + fcn3.to(self.device) + fcn3.eval() + return fcn3 + + def get_seeds(self, n_seeds: int) -> list[int]: + """Sample ``n_seeds`` distinct integer RNG seeds for ensemble members.""" + seeds = self.rng.choice(2**32, size=n_seeds, replace=False) + return [int(s) for s in seeds] + + def validate_start_time(self, start_time: datetime) -> None: + """Require ``start_time`` to fall on a 6-hour boundary (FCN3 / IFS cadence).""" + if (start_time - datetime(1900, 1, 1)).total_seconds() % 21600 != 0: + raise ValueError(f"Start time needs to be 6-hour interval: {start_time}") + + def validate_samples( + self, n_samples: int, seeds: Sequence[int] | None + ) -> list[int]: + """Return ensemble seeds of length ``n_samples``, generating them if omitted.""" + if seeds is None: + seeds = self.get_seeds(n_samples) + elif len(seeds) != n_samples: + logger.warning( 
+ "Ignoring requested number of samples because it does not match number of seeds" + ) + return list(seeds) + + def validate_variables(self, variables: Sequence[str] | None) -> np.ndarray: + """Resolve output variables, defaulting to the model's variables and checking names.""" + if variables is None: + variables = self.fcn3.variables + else: + unknown_variables = set(variables) - set(self.fcn3.variables) + if len(unknown_variables): + raise ValueError(f"Unknown variable(s) {', '.join(unknown_variables)}") + variables = np.array(variables) + return variables + + def setup_io( + self, io: IOBackend, output_coords: CoordSystem, seeds: Sequence[int] + ) -> None: + """Define Zarr/NetCDF arrays, CRS metadata, and time encoding for ensemble outputs.""" + io.add_array( + {k: v for k, v in output_coords.items() if k != "variable"}, + output_coords["variable"], + ) + + # Storing seeds separately makes it easier to filter with Titiler + e_coords = {"ensemble": output_coords["ensemble"]} + io.add_array(e_coords, "seed", data=torch.tensor(seeds)) + + # Add CRS definition + io.add_array({}, "crs") + io.root["crs"].grid_mapping_name = "latitude_longitude" + io.root["crs"].longitude_of_prime_meridian = 0.0 + io.root["crs"].semi_major_axis = 6378137.0 + io.root["crs"].inverse_flattening = 298.257223563 + + for var in output_coords["variable"]: + io.root[var].grid_mapping = "crs" + + # Set attributes for automatic parsing of dimensions + io.root["ensemble"].standard_name = "realization" + io.root["time"].standard_name = "time" + io.root["time"].axis = "T" + io.root["lat"].standard_name = "latitude" + io.root["lat"].units = "degrees_north" + io.root["lat"].axis = "Y" + io.root["lon"].standard_name = "longitude" + io.root["lon"].units = "degrees_east" + io.root["lon"].axis = "X" + + # Unwrap BackendProgress (serve API) + e2io = ( + io + if isinstance(io, (NetCDF4Backend, ZarrBackend)) + else getattr(io, "io", None) + ) + + if isinstance(e2io, ZarrBackend): + zarr.consolidate_metadata(e2io.store) + + if isinstance(e2io, NetCDF4Backend): + # Planetary Computer does not like the original time format (hours since 0001-01-01). 
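+            # Re-encode the coordinate from the same datetimes with a forecast-relative
+            # reference so values match the new units, then sync to flush it to disk.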
+ ref_time = np.datetime_as_string(output_coords["time"][0], unit="s") + units = f"hours since {ref_time.replace('T', ' ')}" + tv = e2io.root["time"] + tv.units = units + tv[:] = date2num( + timearray_to_datetime(output_coords["time"]), + units=units, + calendar=tv.calendar, + ) + e2io.root.sync() + + return io + + def get_fcn3_input(self, time: datetime) -> tuple[torch.Tensor, CoordSystem]: + """Fetch FCN3 input tensors and coordinates from Planetary Computer ECMWF IFS.""" + x, coords = fetch_data( + self.data, + time=to_time_array([time]), + variable=self.fcn3.input_coords()["variable"], + device=self.device, + ) + return x, coords + + def __call__( + self, + io: IOBackend, + start_time: datetime = datetime(2025, 1, 1), + n_steps: int = 20, + n_samples: int = 16, + seeds: Sequence[int] | None = None, + variables: Sequence[str] | None = ("t2m", "u10m", "v10m"), + collection_id: str | None = None, + ) -> None: + self.validate_start_time(start_time) + lead_times = np.array([np.timedelta64(i * 6, "h") for i in range(n_steps + 1)]) + seeds = self.validate_samples(n_samples, seeds) + variables = self.validate_variables(variables) + + x_ori, coords_ori = self.get_fcn3_input(start_time) + + output_coords = CoordSystem( + { + "ensemble": np.arange(len(seeds)), + # Combine 'time' and 'lead_time' into single dimension + "time": to_time_array([start_time]) + lead_times, + "variable": variables, + "lat": np.linspace(90.0, -90.0, 721), + "lon": np.linspace(-180, 180, 1440, endpoint=False), + } + ) + self.setup_io(io, output_coords, seeds) + + logger.info("Starting inference") + total_samples = len(seeds) + n_steps += 1 # add 1 for step 0 (initial conditions) + for sample, seed in enumerate(seeds): + + self.fcn3.set_rng(seed=seed) + iterator = self.fcn3.create_iterator(x_ori.clone(), coords_ori.copy()) + for step, (x, coords) in enumerate(iterator): + # Update progress for step within sample + msg = ( + f"Processing sample {sample + 1}/{total_samples} " + f"(seed={seed}), step {step + 1}/{len(lead_times)}" + ) + progress = WorkflowProgress( + progress=msg, + current_step=step + 1, + total_steps=n_steps, + ) + self.update_progress(progress) + logger.info(msg) + + # Select variables + x_out, coords_out = map_coords( + x, coords, CoordSystem({"variable": output_coords["variable"]}) + ) + # Roll longitudes (for raster visualization) + x_out = torch.roll(x_out, 720, dims=-1) + coords_out["lon"] = np.linspace(-180, 180, 1440, endpoint=False) + # Add ensemble dimension + x_out = x_out.unsqueeze(0) + coords_out["ensemble"] = np.array([sample]) + coords_out.move_to_end("ensemble", last=False) + # Combine time and lead_time + lead_time_dim = list(coords_out).index("lead_time") + x_out = x_out.squeeze(lead_time_dim) + coords_out["time"] = coords_out["time"] + coords_out["lead_time"] + del coords_out["lead_time"] + # Write to disk + io.write(*split_coords(x_out, coords_out)) + + if step == (n_steps - 1): + break diff --git a/serve/server/example_workflows/foundry_fcn3_stormscope_goes.py b/serve/server/example_workflows/foundry_fcn3_stormscope_goes.py new file mode 100644 index 000000000..82013dd0c --- /dev/null +++ b/serve/server/example_workflows/foundry_fcn3_stormscope_goes.py @@ -0,0 +1,502 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import OrderedDict +from collections.abc import Sequence +from datetime import datetime, timedelta + +import numpy as np +import torch +import xarray as xr +import zarr +from cftime import date2num + +from earth2studio.data import ( + GOES, + InferenceOutputSource, + PlanetaryComputerECMWFOpenDataIFS, + PlanetaryComputerGOES, + fetch_data, +) +from earth2studio.io import IOBackend, NetCDF4Backend, XarrayBackend, ZarrBackend +from earth2studio.models.dx import DerivedSurfacePressure +from earth2studio.models.px import FCN3, DiagnosticWrapper, InterpModAFNO +from earth2studio.models.px.stormscope import ( + StormScopeBase, + StormScopeGOES, +) +from earth2studio.serve.server import ( + Earth2Workflow, + WorkflowProgress, + workflow_registry, +) +from earth2studio.utils.coords import CoordSystem, map_coords, split_coords +from earth2studio.utils.time import timearray_to_datetime, to_time_array + +logger = logging.getLogger("foundry_fcn3_stormscope_goes_workflow") +logger.setLevel(logging.INFO) + +GOES_MODEL_NAME = "6km_60min_natten_cos_zenith_input_eoe_v2" + + +@workflow_registry.register +class FoundryFCN3StormScopeGOESWorkflow(Earth2Workflow): + """FCN3 (with interpolation) plus StormScope GOES diagnostic ensemble for Foundry.""" + + name = "foundry_fcn3_stormscope_goes_workflow" + description = "FCN3+StormScopeGOES ensemble workflow for Foundry" + + def __init__( + self, + device: str = "cuda", + init_seed: int = 1234, + ): + super().__init__() + + self.device = torch.device(device) + + self.fcn3_interp = self.load_fcn3_interp() + self.stormscope = self.load_stormscope() + self.rng = np.random.default_rng(init_seed) + + self.data_fcn3 = PlanetaryComputerECMWFOpenDataIFS(verbose=False, cache=False) + + scan_mode = "C" + self.data_stormscope = { + satellite: PlanetaryComputerGOES( + satellite=satellite, scan_mode=scan_mode, verbose=False, cache=False + ) + for satellite in ["goes16", "goes19"] + } + + # GOES-16 and GOES19 have the same grid + goes_lat, goes_lon = GOES.grid(satellite="goes16", scan_mode=scan_mode) + coords_out = self.fcn3_interp.output_coords(self.fcn3_interp.input_coords()) + self.stormscope.build_input_interpolator(goes_lat, goes_lon) + self.stormscope.build_conditioning_interpolator( + coords_out["lat"], coords_out["lon"] + ) + + def load_fcn3_interp(self) -> InterpModAFNO: + """Load FCN3 with surface pressure diagnostics and hourly ``InterpModAFNO`` wrapping.""" + logger.info("Loading FCN3") + package = FCN3.load_default_package() + fcn3 = FCN3.load_model(package) + + # Surface pressure interpolation + orography_fn = package.resolve("orography.nc") + with xr.open_dataset(orography_fn) as ds: + z_surface = torch.as_tensor(ds["Z"][0].values) + z_surf_coords = OrderedDict({d: fcn3.input_coords()[d] for d in ["lat", "lon"]}) + sp_model = DerivedSurfacePressure( + p_levels=[50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000], + surface_geopotential=z_surface, + surface_geopotential_coords=z_surf_coords, + ) + + # Bundle surface pressure with FCN3 + fcn3_sp = DiagnosticWrapper(px_model=fcn3, dx_model=sp_model) + + # Add 
temporal interpolation to 1 hour + fcn3_interp = InterpModAFNO.from_pretrained() + fcn3_interp.px_model = fcn3_sp + fcn3_interp.to(device=self.device) + fcn3_interp.eval() + return fcn3_interp + + def load_stormscope(self) -> StormScopeGOES: + """Load the StormScope GOES model package and move it to the workflow device.""" + logger.info("Loading StormScope") + package = StormScopeBase.load_default_package() + stormscope = StormScopeGOES.load_model( + package=package, + conditioning_data_source=None, # set later + model_name=GOES_MODEL_NAME, + ) + stormscope.to(self.device) + stormscope.eval() + return stormscope + + def get_seeds(self, n_seeds: int) -> list[int]: + """Sample ``n_seeds`` distinct integer RNG seeds for ensemble members.""" + seeds = self.rng.choice(2**32, size=n_seeds, replace=False) + return [int(s) for s in seeds] + + def validate_start_times( + self, time_stormscope: datetime, time_fcn3: datetime + ) -> None: + """Check StormScope (1 h) and FCN3 (6 h) start times and their relative ordering.""" + ref = datetime(1900, 1, 1) + if (time_stormscope - ref).total_seconds() % (1 * 60 * 60) != 0: + raise ValueError( + f"Start time for StormScope must be 1-hour interval: {time_stormscope}" + ) + if (time_fcn3 - ref).total_seconds() % (6 * 60 * 60) != 0: + raise ValueError( + f"Start time for FCN3 must be 6-hour interval: {time_fcn3}" + ) + if time_stormscope < time_fcn3: + raise ValueError( + "Start time for StormScope cannot preceed start time for FCN3" + ) + if time_stormscope - time_fcn3 > timedelta(hours=12): + logger.warning( + "Start times for StormScope and FCN3 should not be more than 12 hours apart but got '%s' and '%s'", + time_stormscope, + time_fcn3, + ) + + def validate_samples( + self, n_samples: int, seeds: Sequence[int] | None + ) -> list[int]: + """Return ensemble seeds of length ``n_samples``, generating them if missing.""" + if not seeds: + return self.get_seeds(n_samples) + if len(seeds) != n_samples: + logger.warning( + "Ignoring requested number of samples because it does not match number of seeds" + ) + return list(seeds) + + def validate_variables(self, variables: Sequence[str] | None) -> np.ndarray: + """Resolve StormScope output variables, defaulting to the model's variables.""" + if variables is None: + variables = self.stormscope.variables + else: + unknown_variables = set(variables) - set(self.stormscope.variables) + if len(unknown_variables): + raise ValueError(f"Unknown variable(s) {', '.join(unknown_variables)}") + variables = np.array(variables) + return variables + + def setup_io( + self, + io: IOBackend, + output_coords: CoordSystem, + seeds_fcn3: Sequence[int], + seeds_stormscope: Sequence[int], + ) -> None: + """Define IO arrays, CRS metadata, and per-model seeds for ensemble outputs.""" + io.add_array( + {k: v for k, v in output_coords.items() if k != "variable"}, + output_coords["variable"], + ) + + # Storing seeds separately makes it easier to filter with Titiler + e_coords = {"ensemble": output_coords["ensemble"]} + n_stormscope_per_fcn3 = len(seeds_stormscope) // len(seeds_fcn3) + tiled_seeds_fcn3 = np.repeat(seeds_fcn3, n_stormscope_per_fcn3) + io.add_array(e_coords, "seed_fcn3", data=torch.tensor(tiled_seeds_fcn3)) + io.add_array(e_coords, "seed_stormscope", data=torch.tensor(seeds_stormscope)) + + # Add CRS definition + io.add_array({}, "crs") + io.root["crs"].grid_mapping_name = "lambert_conformal_conic" + io.root["crs"].standard_parallel = 38.5 + io.root["crs"].longitude_of_central_meridian = 262.5 + 
io.root["crs"].latitude_of_projection_origin = 38.5 + io.root["crs"].semi_major_axis = 6371229 + io.root["crs"].semi_minor_axis = 6371229 + + for var in output_coords["variable"]: + io.root[var].grid_mapping = "crs" + + # Set attributes for automatic parsing of dimensions + io.root["ensemble"].standard_name = "realization" + io.root["time"].standard_name = "time" + io.root["time"].axis = "T" + io.root["y"].standard_name = "projection_y_coordinate" + io.root["y"].units = "m" + io.root["y"].axis = "Y" + io.root["x"].standard_name = "projection_x_coordinate" + io.root["x"].units = "m" + io.root["x"].axis = "X" + + # Unwrap BackendProgress (serve API) + e2io = ( + io + if isinstance(io, (NetCDF4Backend, ZarrBackend)) + else getattr(io, "io", None) + ) + + if isinstance(e2io, ZarrBackend): + zarr.consolidate_metadata(e2io.store) + + if isinstance(e2io, NetCDF4Backend): + # Planetary Computer does not like the original time format (hours since 0001-01-01). + # Re-encode from the same datetimes as add_dimension so values match units, then + # sync so the coordinate is flushed. + ref_time = np.datetime_as_string(output_coords["time"][0], unit="s") + units = f"hours since {ref_time.replace('T', ' ')}" + tv = e2io.root["time"] + tv.units = units + tv[:] = date2num( + timearray_to_datetime(output_coords["time"]), + units=units, + calendar=tv.calendar, + ) + e2io.root.sync() + + return io + + def get_fcn3_input(self, time: datetime) -> tuple[torch.Tensor, CoordSystem]: + """Fetch FCN3 branch input from Planetary Computer ECMWF IFS.""" + x, coords = fetch_data( + self.data_fcn3, + time=to_time_array([time]), + variable=self.fcn3_interp.input_coords()["variable"], + device=self.device, + ) + return x, coords + + def get_stormscope_input(self, time: datetime) -> tuple[torch.Tensor, CoordSystem]: + """Fetch GOES inputs for StormScope (GOES-16 vs GOES-19 by date) and preprocess.""" + coords_in = self.stormscope.input_coords() + if time < datetime(2025, 4, 7): + data = self.data_stormscope["goes16"] + else: + data = self.data_stormscope["goes19"] + x, coords = fetch_data( + data, + time=to_time_array([time]), + variable=coords_in["variable"], + lead_time=coords_in["lead_time"], + device=self.device, + ) + + batch_size = 1 + if x.dim() == 5: + x = x.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1, 1) + coords["batch"] = np.arange(batch_size) + coords.move_to_end("batch", last=False) + + x, coords = self.stormscope.prep_input(x, coords) + x = torch.where(self.stormscope.valid_mask, x, torch.nan) + + return x, coords + + def run_fcn3( + self, + io: IOBackend, + x: torch.Tensor, + coords_x: CoordSystem, + seed_fcn3: int, + start_time_stormscope: datetime, + lead_times: np.ndarray, + sample: int, + total_samples: int, + ) -> None: + """Run FCN3 to produce conditioning fields for StormScope up to the given horizon.""" + # Create z500 conditioning with FCN3 + coords_in = self.stormscope.input_coords() + start_time_stormscope = to_time_array([start_time_stormscope]) + variables = self.stormscope.conditioning_variables + # Start time and lead times are shifted to StormScope start time + output_coords = { + "time": start_time_stormscope, + "lead_time": lead_times, + "variable": variables, + "y": coords_in["y"], + "x": coords_in["x"], + } + io.add_array( + {k: v for k, v in output_coords.items() if k != "variable"}, variables + ) + + model_gap = int( + (start_time_stormscope - coords_x["time"]) / np.timedelta64(1, "h") + ) + + self.fcn3_interp.px_model.px_model.set_rng(seed=seed_fcn3) + iterator = 
self.fcn3_interp.create_iterator(x.clone(), coords_x.copy()) + + n_steps = model_gap + len(lead_times) + for step, (x, coords_x) in enumerate(iterator): + # Update progress for FCN3 step + msg = ( + f"Processing FCN3 for sample {sample + 1}/{total_samples} " + f"(seed_fcn3={seed_fcn3}) " + f"step {step + 1}/{n_steps}" + ) + progress = WorkflowProgress( + progress=msg, + current_step=step + 1, + total_steps=n_steps, + ) + self.update_progress(progress) + logger.info(msg) + + if step < model_gap: + # Skip initial steps leading up to StormScope start time + continue + + x, coords_x = map_coords(x, coords_x, OrderedDict({"variable": variables})) + x, coords_x = self.stormscope.prep_input(x, coords_x, conditioning=True) + coords_x["time"] = start_time_stormscope + coords_x["lead_time"] = coords_x["lead_time"] - np.timedelta64( + model_gap, "h" + ) + io.write(*split_coords(x, coords_x)) + + if step == (n_steps - 1): + break + + def run_stormscope( + self, + io: IOBackend, + y: torch.Tensor, + coords_y: CoordSystem, + seed_fcn3: int, + seed_stormscope: int, + lead_times: np.ndarray, + variables: np.ndarray, + sample: int, + total_samples: int, + ) -> None: + """Run StormScope autoregressively and write outputs to ``io``.""" + n_steps = len(lead_times) + + def log_progress(step: int) -> None: + msg = ( + f"Processing sample {sample + 1}/{total_samples} " + f"(seed_fcn3={seed_fcn3}, seed_stormscope={seed_stormscope}), " + f"step {step + 1}/{n_steps}" + ) + progress = WorkflowProgress( + progress=msg, + current_step=step + 1, + total_steps=n_steps, + ) + self.update_progress(progress) + logger.info(msg) + + def prep_output( + y_pred: torch.Tensor, coords_pred: CoordSystem + ) -> tuple[torch.Tensor, CoordSystem]: + y_out, coords_out = map_coords( + y_pred, coords_pred, CoordSystem({"variable": variables}) + ) + del coords_out["batch"] + # Reuse batch dimension as ensemble dimension (squeeze/unsqueeze) + coords_out["ensemble"] = np.array([sample]) + coords_out.move_to_end("ensemble", last=False) + # Combine time and lead_time + lead_time_dim = list(coords_out).index("lead_time") + y_out = y_out.squeeze(lead_time_dim) + coords_out["time"] = coords_out["time"] + coords_out["lead_time"] + del coords_out["lead_time"] + return y_out, coords_out + + # Update progress for step within sample + log_progress(0) + + # Store initial GOES data (identical across seeds) + y_out, coords_out = prep_output(y, coords_y) + io.write(*split_coords(y_out, coords_out)) + + # Cannot use seeded Generator before torch==2.10 + # Use self.stormscope.sampler_args["randn_like"] once updated + torch.manual_seed(seed_stormscope) + + for step in range(1, n_steps): + y_pred, coords_pred = self.stormscope(y, coords_y) + + # Update progress for step within sample + log_progress(step) + + y_out, coords_out = prep_output(y_pred, coords_pred) + io.write(*split_coords(y_out, coords_out)) + + if step == (n_steps - 1): + break + + y, coords_y = self.stormscope.next_input(y_pred, coords_pred, y, coords_y) + + def __call__( + self, + io: IOBackend, + start_time_fcn3: datetime = datetime(2025, 1, 1, 18), + start_time_stormscope: datetime = datetime(2025, 1, 1, 18), + n_steps: int = 12, + n_samples_fcn3: int = 4, + n_samples_stormscope: int = 16, + seeds_fcn3: Sequence[int] | None = None, + seeds_stormscope: Sequence[int] | None = None, + variables: Sequence[str] | None = ("abi01c", "abi02c", "abi03c"), + collection_id: str | None = None, + ) -> None: + self.validate_start_times(start_time_stormscope, start_time_fcn3) + lead_times = 
np.array([np.timedelta64(i, "h") for i in range(n_steps + 1)]) + # Different StormScope seed for every trajectory + if n_samples_stormscope % n_samples_fcn3 != 0: + raise ValueError( + "'n_samples_stormscope' must be divisible by 'n_samples_fcn3'" + ) + seeds_fcn3 = self.validate_samples(n_samples_fcn3, seeds_fcn3) + seeds_stormscope = self.validate_samples(n_samples_stormscope, seeds_stormscope) + n_stormscope_per_fcn3 = len(seeds_stormscope) // len(seeds_fcn3) + variables = self.validate_variables(variables) + + x_ori, coords_x_ori = self.get_fcn3_input(start_time_fcn3) + y_ori, coords_y_ori = self.get_stormscope_input(start_time_stormscope) + + coords_out = self.stormscope.output_coords(self.stormscope.input_coords()) + output_coords = { + "ensemble": np.arange(len(seeds_stormscope)), + # Planetary Computer does not like separate 'lead_time' + "time": to_time_array([start_time_stormscope]) + lead_times, + "variable": variables, + "y": coords_out["y"], + "x": coords_out["x"], + } + self.setup_io(io, output_coords, seeds_fcn3, seeds_stormscope) + + total_samples = len(seeds_stormscope) + sample = 0 + for seed_fcn3 in seeds_fcn3: + # Generate FCN3 conditioning (z500) + logger.info("Starting FCN3 inference") + io_fcn3 = XarrayBackend() + self.run_fcn3( + io=io_fcn3, + x=x_ori.clone(), + coords_x=coords_x_ori.copy(), + seed_fcn3=seed_fcn3, + start_time_stormscope=start_time_stormscope, + lead_times=lead_times, + sample=sample, + total_samples=total_samples, + ) + self.stormscope.conditioning_data_source = InferenceOutputSource( + io_fcn3.root + ) + + # Run StormScope forecast conditioned on FCN3 + logger.info("Starting StormScope inference") + for _ in range(n_stormscope_per_fcn3): + self.run_stormscope( + io=io, + y=y_ori.clone(), + coords_y=coords_y_ori.copy(), + seed_fcn3=seed_fcn3, + seed_stormscope=seeds_stormscope[sample], + lead_times=lead_times, + variables=variables, + sample=sample, + total_samples=total_samples, + ) + sample += 1 diff --git a/serve/server/requirements.txt b/serve/server/requirements.txt index 9003236f8..69b3a0404 100644 --- a/serve/server/requirements.txt +++ b/serve/server/requirements.txt @@ -32,6 +32,13 @@ hydra-core>=1.3.0 # Cryptography for CloudFront signed URLs cryptography>=41.0.0 +# Multi-Storage Client with Azure support +multi-storage-client>=0.44.0 +# Azure Blob Storage for SAS token generation +azure-storage-blob>=12.19.0 +# Azure Identity for managed identity authentication +azure-identity>=1.15.0 + # Development dependencies pytest>=7.0.0 pytest-asyncio>=0.21.0 diff --git a/serve/server/scripts/start_api_server.sh b/serve/server/scripts/start_api_server.sh index 7bcf20fcf..506f66db7 100755 --- a/serve/server/scripts/start_api_server.sh +++ b/serve/server/scripts/start_api_server.sh @@ -25,6 +25,10 @@ CONFIG_DIR=${CONFIG_DIR:-"${SCRIPT_DIR}/../conf"} export SCRIPT_DIR export CONFIG_DIR export WORKFLOW_DIR +SERVE_SERVER_DIR="$(cd "$SCRIPT_DIR/.." && pwd)" +REPO_ROOT="$(cd "$SERVE_SERVER_DIR/../.." && pwd)" +# Repo root for earth2studio; serve/server for azure_planetary_computer.* (geocatalog RQ jobs). 
+export PYTHONPATH="${REPO_ROOT}:${REPO_ROOT}/serve/server:${PYTHONPATH:-}" CONFIG_FILE="$CONFIG_DIR/config.yaml" # Function to read config values from YAML using Python @@ -56,6 +60,7 @@ if [ -f "$CONFIG_FILE" ]; then CONFIG_RQ_NUM_WORKERS=$(read_config "worker.num_workers") CONFIG_ZIP_NUM_WORKERS=$(read_config "worker.zip_num_workers") CONFIG_OBJSTORE_NUM_WORKERS=$(read_config "worker.objstore_num_workers") + CONFIG_GEOCATALOG_NUM_WORKERS=$(read_config "worker.geocatalog_num_workers") CONFIG_FINALIZE_NUM_WORKERS=$(read_config "worker.finalize_num_workers") CONFIG_PERSISTENT_WORKER=$(read_config "worker.persistent") fi @@ -66,10 +71,11 @@ REDIS_HOST=${3:-${CONFIG_REDIS_HOST:-localhost}} # Default Redis host NUM_RQ_WORKERS=${4:-${CONFIG_RQ_NUM_WORKERS:-1}} # Default to 1 RQ workers NUM_ZIP_WORKERS=${5:-${CONFIG_ZIP_NUM_WORKERS:-1}} # Default to 1 workers for result_zip queue NUM_OBJSTORE_WORKERS=${CONFIG_OBJSTORE_NUM_WORKERS:-1} # Default to 1 object storage worker +NUM_GEOCATALOG_WORKERS=${CONFIG_GEOCATALOG_NUM_WORKERS:-1} # Default to 1 geocatalog ingestion worker NUM_FINALIZE_WORKERS=${CONFIG_FINALIZE_NUM_WORKERS:-1} # Default to 1 finalize metadata worker PERSISTENT_WORKER=${CONFIG_PERSISTENT_WORKER:-false} -echo "Starting Earth2Studio with $NUM_WORKERS API workers, $NUM_RQ_WORKERS RQ workers, $NUM_ZIP_WORKERS zip workers, $NUM_OBJSTORE_WORKERS object storage workers, and $NUM_FINALIZE_WORKERS finalize workers on port $API_PORT..." +echo "Starting Earth2Studio with $NUM_WORKERS API workers, $NUM_RQ_WORKERS RQ workers, $NUM_ZIP_WORKERS zip workers, $NUM_OBJSTORE_WORKERS object storage workers, $NUM_GEOCATALOG_WORKERS geocatalog workers, and $NUM_FINALIZE_WORKERS finalize workers on port $API_PORT..." echo "Configuration: Redis=$REDIS_HOST, Persistent Worker=$PERSISTENT_WORKER" # Function to cleanup on exit @@ -106,6 +112,12 @@ cleanup() { pkill -f "rq.*worker.*object_storage" fi + # Stop all geocatalog ingestion workers + if pgrep -f "rq.*worker.*geocatalog_ingestion" > /dev/null; then + echo "Stopping geocatalog ingestion workers..." + pkill -f "rq.*worker.*geocatalog_ingestion" + fi + # Stop all finalize metadata workers if pgrep -f "rq.*worker.*finalize_metadata" > /dev/null; then echo "Stopping finalize metadata workers..." @@ -122,7 +134,7 @@ trap cleanup SIGINT SIGTERM export EARTH2STUDIO_API_ACTIVE=1 # Start multiple workers using uvicorn with extended timeouts for large file downloads -uvicorn earth2studio.serve.server.main:app --host 0.0.0.0 --port $API_PORT --workers $NUM_WORKERS --loop asyncio --timeout-keep-alive 300 --timeout-graceful-shutdown 30 & +CUDA_VISIBLE_DEVICES="" uvicorn earth2studio.serve.server.main:app --host 0.0.0.0 --port $API_PORT --workers $NUM_WORKERS --loop asyncio --timeout-keep-alive 300 --timeout-graceful-shutdown 30 & UVICORN_PID=$! # Start RQ workers @@ -146,7 +158,7 @@ done echo "Starting $NUM_ZIP_WORKERS zip workers for result_zip queue..." ZIP_WORKER_PIDS=() for i in $(seq 1 $NUM_ZIP_WORKERS); do - rq worker -w rq.worker.SimpleWorker result_zip & + CUDA_VISIBLE_DEVICES="" rq worker -w rq.worker.SimpleWorker result_zip & ZIP_WORKER_PIDS+=($!) echo "Started zip worker $i for result_zip queue (PID: $!)" done @@ -155,22 +167,31 @@ done echo "Starting $NUM_OBJSTORE_WORKERS object storage workers..." OBJSTORE_WORKER_PIDS=() for i in $(seq 1 $NUM_OBJSTORE_WORKERS); do - rq worker -w rq.worker.SimpleWorker object_storage & + CUDA_VISIBLE_DEVICES="" rq worker -w rq.worker.SimpleWorker object_storage & OBJSTORE_WORKER_PIDS+=($!) 
echo "Started object storage worker $i (PID: $!)" done +# Start geocatalog ingestion workers (used when AZURE_GEOCATALOG_URL is set) +echo "Starting $NUM_GEOCATALOG_WORKERS geocatalog ingestion workers..." +GEOCATALOG_WORKER_PIDS=() +for i in $(seq 1 $NUM_GEOCATALOG_WORKERS); do + CUDA_VISIBLE_DEVICES="" rq worker -w rq.worker.SimpleWorker geocatalog_ingestion & + GEOCATALOG_WORKER_PIDS+=($!) + echo "Started geocatalog ingestion worker $i (PID: $!)" +done + # Start finalize metadata workers echo "Starting $NUM_FINALIZE_WORKERS finalize metadata workers..." FINALIZE_WORKER_PIDS=() for i in $(seq 1 $NUM_FINALIZE_WORKERS); do - rq worker -w rq.worker.SimpleWorker finalize_metadata & + CUDA_VISIBLE_DEVICES="" rq worker -w rq.worker.SimpleWorker finalize_metadata & FINALIZE_WORKER_PIDS+=($!) echo "Started finalize metadata worker $i (PID: $!)" done # Start cleanup daemon -python -m earth2studio.serve.server.cleanup_daemon & +CUDA_VISIBLE_DEVICES="" python -m earth2studio.serve.server.cleanup_daemon & CLEANUP_DAEMON_PID=$! echo "Started cleanup daemon (PID: $CLEANUP_DAEMON_PID)" @@ -205,6 +226,13 @@ if [ "$OBJSTORE_WORKER_COUNT" -eq 0 ]; then exit 1 fi +# Check if geocatalog ingestion workers are running +GEOCATALOG_WORKER_COUNT=$(pgrep -f "rq.*worker.*geocatalog_ingestion" | wc -l) +if [ "$GEOCATALOG_WORKER_COUNT" -eq 0 ]; then + echo "Failed to start geocatalog ingestion workers..." + exit 1 +fi + # Check if finalize metadata workers are running FINALIZE_WORKER_COUNT=$(pgrep -f "rq.*worker.*finalize_metadata" | wc -l) if [ "$FINALIZE_WORKER_COUNT" -eq 0 ]; then @@ -224,12 +252,14 @@ echo "Uvicorn PID: $UVICORN_PID" echo "RQ Worker PIDs: ${RQ_WORKER_PIDS[*]}" echo "Zip Worker PIDs: ${ZIP_WORKER_PIDS[*]}" echo "Object Storage Worker PIDs: ${OBJSTORE_WORKER_PIDS[*]}" +echo "Geocatalog Ingestion Worker PIDs: ${GEOCATALOG_WORKER_PIDS[*]}" echo "Finalize Metadata Worker PIDs: ${FINALIZE_WORKER_PIDS[*]}" echo "Cleanup Daemon PID: $CLEANUP_DAEMON_PID" echo "Active API workers: $API_WORKER_COUNT" echo "Active RQ inference workers: $RQ_WORKER_COUNT" echo "Active zip workers: $ZIP_WORKER_COUNT" echo "Active object storage workers: $OBJSTORE_WORKER_COUNT" +echo "Active geocatalog ingestion workers: $GEOCATALOG_WORKER_COUNT" echo "Active finalize metadata workers: $FINALIZE_WORKER_COUNT" echo "API available at http://localhost:$API_PORT" echo "API docs at http://localhost:$API_PORT/docs" @@ -238,7 +268,7 @@ echo "API docs at http://localhost:$API_PORT/docs" # Wait for health check to pass before invoking warmup workflow echo "" echo "Waiting for health check to pass..." -MAX_HEALTH_RETRIES=30 +MAX_HEALTH_RETRIES=60 HEALTH_RETRY_INTERVAL=2 for i in $(seq 1 $MAX_HEALTH_RETRIES); do if curl -s "http://localhost:$API_PORT/health" | grep -q '"status":"healthy"'; then diff --git a/serve/server/scripts/startup.sh b/serve/server/scripts/startup.sh index b5739fe07..b9ca82086 100755 --- a/serve/server/scripts/startup.sh +++ b/serve/server/scripts/startup.sh @@ -17,16 +17,24 @@ set -euo pipefail +# Set EARTH2STUDIO_MODEL_CACHE to use AZUREML_MODEL_DIR if available +if [ -n "${AZUREML_MODEL_DIR:-}" ]; then + echo "AZUREML_MODEL_DIR: $AZUREML_MODEL_DIR" + export EARTH2STUDIO_MODEL_CACHE="$AZUREML_MODEL_DIR/${EARTH2STUDIO_MODEL_SUBPATH:-e2s_fcn3_stormscope}" + echo "--------------------------------" + echo "EARTH2STUDIO_MODEL_CACHE: $EARTH2STUDIO_MODEL_CACHE" + ls -la $EARTH2STUDIO_MODEL_CACHE && echo "--------------------------------" +fi + # Use CONFIG_DIR/SCRIPT_DIR from env if set (e.g. 
in Docker); else resolve from script location SCRIPT_DIR="${SCRIPT_DIR:-$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)}" SERVE_SERVER_DIR="$(cd "$SCRIPT_DIR/.." && pwd)" -REPO_ROOT="$(cd "$SERVE_SERVER_DIR/../.." && pwd)" export SCRIPT_DIR export CONFIG_DIR="${CONFIG_DIR:-$SCRIPT_DIR/../conf}" export WORKFLOW_DIR="${WORKFLOW_DIR:-}" -# Ensure Python can find earth2studio (including earth2studio.serve.server) -export PYTHONPATH="${REPO_ROOT}:${PYTHONPATH:-}" +# PYTHONPATH (repo root + serve/server for azure_planetary_computer.*) is set in +# scripts/start_api_server.sh before uvicorn and RQ workers start. cd "$SERVE_SERVER_DIR" make start-redis diff --git a/serve/server/scripts/status.sh b/serve/server/scripts/status.sh index aadb6b94f..2dcdaadf4 100755 --- a/serve/server/scripts/status.sh +++ b/serve/server/scripts/status.sh @@ -65,7 +65,7 @@ fi echo "" # Check RQ workers status per queue -RQ_QUEUES=("inference" "result_zip" "object_storage" "finalize_metadata") +RQ_QUEUES=("inference" "result_zip" "object_storage" "geocatalog_ingestion" "finalize_metadata") RQ_ALL_OK=1 echo "RQ Workers:" diff --git a/serve/server/scripts/stop_api_server.sh b/serve/server/scripts/stop_api_server.sh index d5769013d..6f00e0830 100755 --- a/serve/server/scripts/stop_api_server.sh +++ b/serve/server/scripts/stop_api_server.sh @@ -33,6 +33,10 @@ pkill -f "rq.*worker.*result_zip" echo "Stopping object storage workers..." pkill -f "rq.*worker.*object_storage" +# Stop geocatalog ingestion workers +echo "Stopping geocatalog ingestion workers..." +pkill -f "rq.*worker.*geocatalog_ingestion" + # Stop finalize metadata workers echo "Stopping finalize metadata workers..." pkill -f "rq.*worker.*finalize_metadata" diff --git a/test/serve/server/conftest.py b/test/serve/server/conftest.py index 197a365e5..abb36805e 100644 --- a/test/serve/server/conftest.py +++ b/test/serve/server/conftest.py @@ -17,11 +17,16 @@ import sys from pathlib import Path -# Repo root is parent of test/serve/server; ensure project is importable -_repo_root = Path(__file__).resolve().parent.parent.parent +# Repo root: conftest lives at /test/serve/server/conftest.py +_repo_root = Path(__file__).resolve().parent.parent.parent.parent if str(_repo_root) not in sys.path: sys.path.insert(0, str(_repo_root)) +# serve/server hosts top-level package azure_planetary_computer (worker GeoCatalog client) +_serve_server_root = _repo_root / "serve" / "server" +if str(_serve_server_root) not in sys.path: + sys.path.insert(0, str(_serve_server_root)) + # Store original earth2studio.serve.server.config module to restore if mocked _original_config_module = None diff --git a/test/serve/server/test_planetary_pc_client.py b/test/serve/server/test_planetary_pc_client.py new file mode 100644 index 000000000..32e16e6b7 --- /dev/null +++ b/test/serve/server/test_planetary_pc_client.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
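The RQ_QUEUES update in status.sh above now covers the whole post-processing pipeline, including the new geocatalog_ingestion queue. A minimal sketch of checking the same queues from Python for debugging; the queue names are taken from that list, while the Redis host and port are assumed defaults:

from redis import Redis
from rq import Queue

PIPELINE_QUEUES = [
    "inference",
    "result_zip",
    "object_storage",
    "geocatalog_ingestion",  # added by this change
    "finalize_metadata",
]

conn = Redis(host="localhost", port=6379)
for name in PIPELINE_QUEUES:
    # Queue.count reports how many jobs are currently waiting on the queue
    print(f"{name}: {Queue(name, connection=conn).count} queued job(s)")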
+# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from pathlib import Path + +# Top-level package azure_planetary_computer lives under /serve/server (see start_api_server.sh PYTHONPATH). +_repo_root = Path(__file__).resolve().parent.parent.parent.parent +_serve_server = _repo_root / "serve" / "server" +if str(_serve_server) not in sys.path: + sys.path.insert(0, str(_serve_server)) + +from datetime import datetime, timedelta, timezone # noqa: E402 +from unittest.mock import MagicMock, patch # noqa: E402 + +import pytest # noqa: E402 + +pytest.importorskip("earth2studio.serve.server") + +from azure_planetary_computer.pc_client import PlanetaryComputerClient # noqa: E402 + + +@pytest.fixture +def mock_azure_credential(): + with patch("azure.identity.DefaultAzureCredential") as m: + cred = MagicMock() + token = MagicMock() + token.token = "mock-token" # noqa: S105 + cred.get_token.return_value = token + m.return_value = cred + yield m + + +@pytest.fixture +def mock_requests_pc(): + with patch("azure_planetary_computer.pc_client.requests") as m: + get_resp = MagicMock() + get_resp.status_code = 200 + get_resp.json.return_value = {"status": "Succeeded"} + get_resp.headers = {} + m.get.return_value = get_resp + + post_resp = MagicMock() + post_resp.status_code = 201 + post_resp.headers = {"location": "https://geocatalog.example/status/123"} + m.post.return_value = post_resp + + put_resp = MagicMock() + put_resp.status_code = 200 + m.put.return_value = put_resp + + yield m + + +def test_pc_client_resolve_start_time_iso_string(mock_azure_credential): + client = PlanetaryComputerClient("foundry_fcn3_workflow") + params = {"start_time": "2025-01-15T12:00:00Z"} + result = client._resolve_start_time(params) + assert result == datetime(2025, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + + +def test_pc_client_resolve_start_time_stormscope_key(mock_azure_credential): + client = PlanetaryComputerClient("foundry_fcn3_stormscope_goes_workflow") + params = {"start_time_stormscope": "2025-01-15T12:00:00Z"} + result = client._resolve_start_time(params) + assert result == datetime(2025, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + + +def test_pc_client_get_feature_json_formats_times(mock_azure_credential): + client = PlanetaryComputerClient("foundry_fcn3_workflow") + start = datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + end = start + timedelta(hours=6) + out = client._get_feature_json( + start_time=start, + end_time=end, + blob_url="https://storage.example/container/blob.nc", + ) + assert out["properties"]["datetime"] == start.isoformat() + assert out["assets"]["data"]["href"] == "https://storage.example/container/blob.nc" + assert out["id"].startswith("fcn3-") + + +def test_pc_client_create_feature_new_collection( + mock_azure_credential, mock_requests_pc +): + mock_requests_pc.get.return_value.json.return_value = {"status": "Succeeded"} + client = PlanetaryComputerClient("foundry_fcn3_workflow") + collection_id, feature_id = client.create_feature( + geocatalog_url="https://geocatalog.example/", + collection_id=None, + parameters={"start_time": "2025-01-01T00:00:00Z"}, + blob_url="https://storage.example/blob.nc", + ) + assert collection_id is not None + assert feature_id is not None + assert feature_id.startswith("fcn3-") diff --git a/test/serve/server/test_server_config.py b/test/serve/server/test_server_config.py index 9325ab653..69f810221 100644 --- a/test/serve/server/test_server_config.py +++ b/test/serve/server/test_server_config.py @@ 
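For orientation, the pc_client tests above pin down the shape of the STAC item that PlanetaryComputerClient builds for GeoCatalog ingestion. A rough sketch showing only the keys the assertions cover; the id suffix and any keys beyond these are illustrative assumptions:

example_feature = {
    "id": "fcn3-20250101T000000",                 # asserted only to start with "fcn3-"
    "properties": {
        "datetime": "2025-01-01T00:00:00+00:00",  # start_time.isoformat()
    },
    "assets": {
        "data": {
            "href": "https://storage.example/container/blob.nc",  # the uploaded blob URL
        },
    },
}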
-576,3 +576,277 @@ def test_dict_to_config_handles_missing_keys(self) -> None: assert config.redis.host == "test_host" # Other configs should have defaults assert config.queue.name == "inference" # Default value + + +class TestObjectStorageEnvOverrides: + """Test object storage and Azure environment variable overrides""" + + def setup_method(self) -> None: + self._vars = [ + "OBJECT_STORAGE_TYPE", + "OBJECT_STORAGE_BUCKET", + "OBJECT_STORAGE_REGION", + "OBJECT_STORAGE_PREFIX", + "OBJECT_STORAGE_ACCESS_KEY_ID", + "OBJECT_STORAGE_SECRET_ACCESS_KEY", + "OBJECT_STORAGE_SESSION_TOKEN", + "OBJECT_STORAGE_ENDPOINT_URL", + "OBJECT_STORAGE_TRANSFER_ACCELERATION", + "OBJECT_STORAGE_MAX_CONCURRENCY", + "OBJECT_STORAGE_MULTIPART_CHUNKSIZE", + "OBJECT_STORAGE_USE_RUST_CLIENT", + "CLOUDFRONT_DOMAIN", + "CLOUDFRONT_KEY_PAIR_ID", + "CLOUDFRONT_PRIVATE_KEY", + "SIGNED_URL_EXPIRES_IN", + "AZURE_STORAGE_ACCOUNT_NAME", + "AZURE_CONTAINER_NAME", + "AZURE_ENDPOINT_URL", + "AZURE_GEOCATALOG_URL", + "EXPOSED_WORKFLOWS", + "OUTPUT_FORMAT", + "CONFIG_DIR", + ] + for v in self._vars: + os.environ.pop(v, None) + + def teardown_method(self) -> None: + for v in self._vars: + os.environ.pop(v, None) + reset_config() + + def _get_manager(self) -> "ConfigManager": + reset_config() + return ConfigManager() + + def test_object_storage_type_s3(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_TYPE"] = "s3" + manager._apply_env_overrides() + assert manager.config.object_storage.storage_type == "s3" + + def test_object_storage_type_azure(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_TYPE"] = "azure" + manager._apply_env_overrides() + assert manager.config.object_storage.storage_type == "azure" + + def test_object_storage_type_invalid_ignored(self) -> None: + manager = self._get_manager() + original = manager.config.object_storage.storage_type + os.environ["OBJECT_STORAGE_TYPE"] = "gcs" + manager._apply_env_overrides() + assert manager.config.object_storage.storage_type == original + + def test_object_storage_bucket(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_BUCKET"] = "my-bucket" + manager._apply_env_overrides() + assert manager.config.object_storage.bucket == "my-bucket" + + def test_object_storage_region(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_REGION"] = "eu-west-1" + manager._apply_env_overrides() + assert manager.config.object_storage.region == "eu-west-1" + + def test_object_storage_prefix(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_PREFIX"] = "custom/prefix" + manager._apply_env_overrides() + assert manager.config.object_storage.prefix == "custom/prefix" + + def test_object_storage_access_key_id(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_ACCESS_KEY_ID"] = "AKID" + manager._apply_env_overrides() + assert manager.config.object_storage.access_key_id == "AKID" + + def test_object_storage_secret_access_key(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_SECRET_ACCESS_KEY"] = "SECRET" # noqa: S105 + manager._apply_env_overrides() + assert manager.config.object_storage.secret_access_key == "SECRET" # noqa: S105 + + def test_object_storage_session_token(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_SESSION_TOKEN"] = "TOKEN" # noqa: S105 + manager._apply_env_overrides() + assert manager.config.object_storage.session_token == "TOKEN" # noqa: S105 + + def 
test_object_storage_endpoint_url(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_ENDPOINT_URL"] = "https://s3.local" + manager._apply_env_overrides() + assert manager.config.object_storage.endpoint_url == "https://s3.local" + + def test_object_storage_transfer_acceleration_true(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_TRANSFER_ACCELERATION"] = "true" + manager._apply_env_overrides() + assert manager.config.object_storage.use_transfer_acceleration is True + + def test_object_storage_transfer_acceleration_false(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_TRANSFER_ACCELERATION"] = "false" + manager._apply_env_overrides() + assert manager.config.object_storage.use_transfer_acceleration is False + + def test_object_storage_max_concurrency(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_MAX_CONCURRENCY"] = "32" + manager._apply_env_overrides() + assert manager.config.object_storage.max_concurrency == 32 + + def test_object_storage_multipart_chunksize(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_MULTIPART_CHUNKSIZE"] = "8388608" + manager._apply_env_overrides() + assert manager.config.object_storage.multipart_chunksize == 8388608 + + def test_object_storage_use_rust_client_true(self) -> None: + manager = self._get_manager() + os.environ["OBJECT_STORAGE_USE_RUST_CLIENT"] = "true" + manager._apply_env_overrides() + assert manager.config.object_storage.use_rust_client is True + + def test_cloudfront_domain(self) -> None: + manager = self._get_manager() + os.environ["CLOUDFRONT_DOMAIN"] = "cdn.example.com" + manager._apply_env_overrides() + assert manager.config.object_storage.cloudfront_domain == "cdn.example.com" + + def test_cloudfront_key_pair_id(self) -> None: + manager = self._get_manager() + os.environ["CLOUDFRONT_KEY_PAIR_ID"] = "KID123" + manager._apply_env_overrides() + assert manager.config.object_storage.cloudfront_key_pair_id == "KID123" + + def test_cloudfront_private_key(self) -> None: + manager = self._get_manager() + os.environ["CLOUDFRONT_PRIVATE_KEY"] = "-----BEGIN RSA PRIVATE KEY-----" + manager._apply_env_overrides() + assert ( + manager.config.object_storage.cloudfront_private_key + == "-----BEGIN RSA PRIVATE KEY-----" + ) + + def test_signed_url_expires_in(self) -> None: + manager = self._get_manager() + os.environ["SIGNED_URL_EXPIRES_IN"] = "3600" + manager._apply_env_overrides() + assert manager.config.object_storage.signed_url_expires_in == 3600 + + def test_azure_storage_account_name(self) -> None: + manager = self._get_manager() + os.environ["AZURE_STORAGE_ACCOUNT_NAME"] = "myaccount" + manager._apply_env_overrides() + assert manager.config.object_storage.azure_account_name == "myaccount" + + def test_azure_container_name(self) -> None: + manager = self._get_manager() + os.environ["AZURE_CONTAINER_NAME"] = "mycontainer" + manager._apply_env_overrides() + assert manager.config.object_storage.azure_container_name == "mycontainer" + + def test_azure_endpoint_url(self) -> None: + manager = self._get_manager() + os.environ["AZURE_ENDPOINT_URL"] = "https://myaccount.blob.core.windows.net" + manager._apply_env_overrides() + assert ( + manager.config.object_storage.endpoint_url + == "https://myaccount.blob.core.windows.net" + ) + + def test_azure_geocatalog_url(self) -> None: + manager = self._get_manager() + os.environ["AZURE_GEOCATALOG_URL"] = "https://geocatalog.example.com" + manager._apply_env_overrides() + assert ( + 
manager.config.object_storage.azure_geocatalog_url + == "https://geocatalog.example.com" + ) + + def test_exposed_workflows_parses_comma_separated(self) -> None: + manager = self._get_manager() + os.environ["EXPOSED_WORKFLOWS"] = "workflow_a, workflow_b, workflow_c" + manager._apply_env_overrides() + assert manager.config.workflow_exposure.exposed_workflows == [ + "workflow_a", + "workflow_b", + "workflow_c", + ] + + def test_exposed_workflows_single(self) -> None: + manager = self._get_manager() + os.environ["EXPOSED_WORKFLOWS"] = "only_workflow" + manager._apply_env_overrides() + assert manager.config.workflow_exposure.exposed_workflows == ["only_workflow"] + + def test_output_format_zarr(self) -> None: + manager = self._get_manager() + os.environ["OUTPUT_FORMAT"] = "zarr" + manager._apply_env_overrides() + assert manager.config.paths.output_format == "zarr" + + def test_output_format_netcdf4(self) -> None: + manager = self._get_manager() + os.environ["OUTPUT_FORMAT"] = "netcdf4" + manager._apply_env_overrides() + assert manager.config.paths.output_format == "netcdf4" + + def test_output_format_invalid_ignored(self) -> None: + manager = self._get_manager() + original = manager.config.paths.output_format + os.environ["OUTPUT_FORMAT"] = "csv" + manager._apply_env_overrides() + assert manager.config.paths.output_format == original + + def test_apply_env_overrides_no_op_when_config_none(self) -> None: + """_apply_env_overrides returns early when _config is None""" + reset_config() + manager = ConfigManager() + manager._config = None + # Should not raise + manager._apply_env_overrides() + + def test_ensure_paths_exist_no_op_when_config_none(self) -> None: + """_ensure_paths_exist returns early when _config is None""" + reset_config() + manager = ConfigManager() + manager._config = None + # Should not raise + manager._ensure_paths_exist() + + def test_config_property_reinitializes_when_none(self) -> None: + """config property calls _initialize_config when _config is None""" + reset_config() + manager = ConfigManager() + manager._config = None + cfg = manager.config + assert isinstance(cfg, AppConfig) + + def test_workflow_config_property_reinitializes_when_none(self) -> None: + """workflow_config property calls _initialize_config when _workflow_config is None""" + reset_config() + manager = ConfigManager() + manager._workflow_config = None + expected = {"wf": {"param": 1}} + with patch.object( + manager, + "_initialize_config", + side_effect=lambda: setattr(manager, "_workflow_config", expected), + ) as mock_init: + wf_cfg = manager.workflow_config + mock_init.assert_called_once() + assert wf_cfg == expected + + def test_initialize_config_uses_config_dir_env_var(self) -> None: + """_initialize_config uses CONFIG_DIR env var when set (covers lines 203-204)""" + reset_config() + manager = ConfigManager() + manager._config = None + manager._workflow_config = None + os.environ["CONFIG_DIR"] = "/custom/conf" + manager._initialize_config() + assert isinstance(manager._config, AppConfig) diff --git a/test/serve/server/test_server_cpu_worker.py b/test/serve/server/test_server_cpu_worker.py index 45e2c9553..6094b68c1 100644 --- a/test/serve/server/test_server_cpu_worker.py +++ b/test/serve/server/test_server_cpu_worker.py @@ -49,6 +49,7 @@ class MockQueueConfig: name: str = "inference" result_zip_queue_name: str = "result_zip" object_storage_queue_name: str = "object_storage" + geocatalog_ingestion_queue_name: str = "geocatalog_ingestion" finalize_metadata_queue_name: str = "finalize_metadata" 
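A condensed usage sketch tying the environment-variable tests above together: create the manager, export the new Azure and workflow-exposure variables, and re-apply the overrides, following the same pattern as the test cases. The import path mirrors the config tests; the values are illustrative:

import os
from earth2studio.serve.server.config import ConfigManager

manager = ConfigManager()
os.environ["OBJECT_STORAGE_TYPE"] = "azure"
os.environ["AZURE_STORAGE_ACCOUNT_NAME"] = "myaccount"
os.environ["AZURE_CONTAINER_NAME"] = "mycontainer"
os.environ["AZURE_GEOCATALOG_URL"] = "https://geocatalog.example.com"
os.environ["EXPOSED_WORKFLOWS"] = "workflow_a, workflow_b"
manager._apply_env_overrides()

assert manager.config.object_storage.storage_type == "azure"
assert manager.config.object_storage.azure_geocatalog_url == "https://geocatalog.example.com"
assert manager.config.workflow_exposure.exposed_workflows == ["workflow_a", "workflow_b"]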
max_size: int = 10 default_timeout: str = "1h" @@ -116,6 +117,15 @@ class MockObjectStorageConfig: cloudfront_key_pair_id: str | None = None cloudfront_private_key_path: str | None = None signed_url_expires_in: int = 3600 + azure_geocatalog_url: str | None = None + + +@dataclass +class MockWorkflowExposureConfig: + """Mock workflow exposure configuration""" + + exposed_workflows: list = field(default_factory=list) + warmup_workflows: list = field(default_factory=lambda: ["example_user_workflow"]) @dataclass @@ -131,6 +141,9 @@ class MockAppConfig: object_storage: MockObjectStorageConfig = field( default_factory=MockObjectStorageConfig ) + workflow_exposure: MockWorkflowExposureConfig = field( + default_factory=MockWorkflowExposureConfig + ) # Create a mock config module @@ -1399,5 +1412,640 @@ def test_process_object_storage_upload_output_path_missing_returns_failure(self) assert "does not exist" in mock_fail.call_args[0][2] +class TestProcessObjectStorageUploadEnabled: + """Tests for process_object_storage_upload when storage is enabled.""" + + def _make_mock_config(self, storage_type="s3", **kwargs): + """Return a Mock config with object_storage defaults for enabled storage.""" + mock_config = Mock() + os_cfg = mock_config.object_storage + os_cfg.enabled = True + os_cfg.storage_type = storage_type + os_cfg.bucket = "my-bucket" + os_cfg.prefix = "outputs" + os_cfg.region = "us-east-1" + os_cfg.max_concurrency = 10 + os_cfg.multipart_chunksize = 8388608 + os_cfg.use_rust_client = False + os_cfg.use_transfer_acceleration = False + os_cfg.access_key_id = None + os_cfg.secret_access_key = None + os_cfg.session_token = None + os_cfg.endpoint_url = None + os_cfg.cloudfront_domain = None + os_cfg.cloudfront_key_pair_id = None + os_cfg.cloudfront_private_key = None + os_cfg.azure_container_name = None + os_cfg.azure_account_name = None + os_cfg.azure_geocatalog_url = None + os_cfg.signed_url_expires_in = 3600 + mock_config.redis.retention_ttl = 604800 + for k, v in kwargs.items(): + setattr(os_cfg, k, v) + return mock_config + + def _patch_all( + self, mock_config, mock_redis, mock_queue, mock_storage_cls, tmp_path + ): + """Context manager helper that patches everything for upload tests.""" + return ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ) + + def test_s3_upload_success_returns_result_with_files(self, tmp_path): + """S3 upload success path returns dict with files_uploaded, destination, remote_prefix.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + (output_dir / "result.nc").write_text("data") + + mock_config = self._make_mock_config(storage_type="s3") + mock_redis = Mock() + mock_queue = Mock(return_value="job_1") + mock_storage_cls = Mock() + mock_storage = mock_storage_cls.return_value + mock_upload = Mock() + mock_upload.success = True + mock_upload.files_uploaded = 1 + mock_upload.total_bytes = 4 + mock_upload.destination = "s3://my-bucket/outputs/wf/exec_1" + mock_storage.upload_directory.return_value = mock_upload + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + 
"earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None + assert result["success"] is True + assert result["storage_type"] == "s3" + assert result["files_uploaded"] == 1 + assert result["total_bytes"] == 4 + assert result["remote_prefix"] == "outputs/wf/exec_1" + mock_queue.assert_called_once() + + def test_azure_missing_container_and_bucket_returns_failure(self, tmp_path): + """Azure storage enabled but no container name and no bucket returns fail_workflow.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config( + storage_type="azure", bucket=None, azure_container_name=None + ) + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.fail_workflow") as mock_fail, + ): + mock_fail.return_value = {"success": False, "error": "no container"} + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert ( + "azure_container_name" in mock_fail.call_args[0][2].lower() + or "container" in mock_fail.call_args[0][2].lower() + ) + + def test_msc_storage_creation_fails_returns_failure(self, tmp_path): + """When MSCObjectStorage constructor raises, returns fail_workflow dict.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config(storage_type="s3") + mock_redis = Mock() + mock_storage_cls = Mock(side_effect=RuntimeError("cannot connect")) + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.fail_workflow") as mock_fail, + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + mock_fail.return_value = {"success": False, "error": "msc failed"} + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert "msc storage client" in mock_fail.call_args[0][2].lower() + + def test_upload_directory_exception_returns_failure(self, tmp_path): + """When upload_directory raises, returns fail_workflow dict.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config(storage_type="s3") + mock_redis = Mock() + mock_storage_cls = Mock() + mock_storage_cls.return_value.upload_directory.side_effect = IOError( + "network error" + ) + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.fail_workflow") as mock_fail, + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + mock_fail.return_value = {"success": False, "error": "upload failed"} + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert "upload failed" in mock_fail.call_args[0][2].lower() + + 
def test_upload_result_not_success_returns_failure(self, tmp_path): + """When upload_result.success is False, returns fail_workflow dict.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config(storage_type="s3") + mock_redis = Mock() + mock_storage_cls = Mock() + mock_upload = Mock() + mock_upload.success = False + mock_upload.errors = ["checksum mismatch"] + mock_storage_cls.return_value.upload_directory.return_value = mock_upload + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.fail_workflow") as mock_fail, + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + mock_fail.return_value = {"success": False, "error": "upload result false"} + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert "object storage" in mock_fail.call_args[0][2].lower() + + def test_s3_with_cloudfront_generates_and_stores_signed_url(self, tmp_path): + """When CloudFront is configured, signed URL is generated and stored in storage_info.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config( + storage_type="s3", + cloudfront_domain="d123.cloudfront.net", + cloudfront_key_pair_id="KP123", + cloudfront_private_key="-----BEGIN RSA PRIVATE KEY-----\nfake\n-----END RSA PRIVATE KEY-----", + ) + mock_redis = Mock() + mock_queue = Mock(return_value="job_1") + mock_storage_cls = Mock() + mock_storage = mock_storage_cls.return_value + mock_upload = Mock() + mock_upload.success = True + mock_upload.files_uploaded = 1 + mock_upload.total_bytes = 10 + mock_upload.destination = "s3://my-bucket/outputs/wf/exec_1" + mock_storage.upload_directory.return_value = mock_upload + mock_storage.generate_signed_url.return_value = ( + "https://d123.cloudfront.net/signed" + ) + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None + assert result["success"] is True + assert result["signed_url"] == "https://d123.cloudfront.net/signed" + mock_storage.generate_signed_url.assert_called_once() + # Verify signed URL is stored in Redis (setex called with signed_url_key) + setex_keys = [call[0][0] for call in mock_redis.setex.call_args_list] + assert any("signed_url" in k for k in setex_keys) + + def test_s3_signed_url_objectstorageerror_returns_failure(self, tmp_path): + """When generate_signed_url raises ObjectStorageError, returns fail_workflow.""" + from earth2studio.serve.server.object_storage import ObjectStorageError + + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config( + storage_type="s3", + cloudfront_domain="d123.cloudfront.net", + cloudfront_key_pair_id="KP123", + cloudfront_private_key="fake", + ) + mock_redis = Mock() + mock_storage_cls = Mock() + mock_storage = mock_storage_cls.return_value + mock_upload = 
Mock() + mock_upload.success = True + mock_upload.files_uploaded = 1 + mock_upload.total_bytes = 10 + mock_upload.destination = "s3://my-bucket/prefix" + mock_storage.upload_directory.return_value = mock_upload + mock_storage.generate_signed_url.side_effect = ObjectStorageError( + "signing failed" + ) + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.fail_workflow") as mock_fail, + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + mock_fail.return_value = {"success": False, "error": "signed url failed"} + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert "signed url" in mock_fail.call_args[0][2].lower() + + def test_azure_upload_success_sets_remote_path_and_blob_url(self, tmp_path): + """Azure upload success: storage_info has azure remote_path and blob_url for .nc file.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + (output_dir / "result.nc").write_bytes(b"netcdf") + + mock_config = self._make_mock_config( + storage_type="azure", + bucket=None, + azure_container_name="my-container", + azure_account_name="myaccount", + azure_geocatalog_url="https://geocatalog.example/", + ) + mock_redis = Mock() + mock_queue = Mock(return_value="job_1") + mock_storage_cls = Mock() + mock_storage = mock_storage_cls.return_value + mock_upload = Mock() + mock_upload.success = True + mock_upload.files_uploaded = 1 + mock_upload.total_bytes = 6 + mock_upload.destination = "azure://my-container/outputs/wf/exec_1" + mock_storage.upload_directory.return_value = mock_upload + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None + assert result["success"] is True + assert result["storage_type"] == "azure" + # Verify storage_info was written to Redis with azure remote_path and blob_url + storage_info_calls = [ + c for c in mock_redis.setex.call_args_list if "storage_info" in c[0][0] + ] + assert len(storage_info_calls) == 1 + stored_info = json.loads(storage_info_calls[0][0][2]) + assert stored_info["remote_path"].startswith("azure://my-container/") + assert "blob_url" in stored_info + assert "result.nc" in stored_info["blob_url"] + + def test_azure_upload_blob_url_from_nc_file_in_directory(self, tmp_path): + """Azure upload: blob_url is built from the first .nc file found in a directory.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + (output_dir / "forecast.nc").write_bytes(b"data") + (output_dir / "other.txt").write_text("text") + + mock_config = self._make_mock_config( + storage_type="azure", + bucket=None, + azure_container_name="container", + azure_account_name="myaccount", + azure_geocatalog_url="https://geocatalog.example/", + ) + mock_redis = Mock() + mock_queue = Mock(return_value="job_1") + mock_storage_cls = Mock() + mock_upload = Mock() + mock_upload.success = True + 
mock_upload.files_uploaded = 2 + mock_upload.total_bytes = 100 + mock_upload.destination = "azure://container/outputs/wf/exec_1" + mock_storage_cls.return_value.upload_directory.return_value = mock_upload + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result["success"] is True + storage_info_calls = [ + c for c in mock_redis.setex.call_args_list if "storage_info" in c[0][0] + ] + stored_info = json.loads(storage_info_calls[0][0][2]) + assert "forecast.nc" in stored_info.get("blob_url", "") + + def test_azure_upload_blob_url_from_zarr_store_in_directory(self, tmp_path): + """Azure upload: blob_url points at first *.zarr store when no .nc files exist.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + zarr_store = output_dir / "results.zarr" + zarr_store.mkdir() + (zarr_store / ".zarray").write_text("{}") + + mock_config = self._make_mock_config( + storage_type="azure", + bucket=None, + azure_container_name="container", + azure_account_name="myaccount", + azure_geocatalog_url="https://geocatalog.example/", + ) + mock_redis = Mock() + mock_queue = Mock(return_value="job_1") + mock_storage_cls = Mock() + mock_upload = Mock() + mock_upload.success = True + mock_upload.files_uploaded = 1 + mock_upload.total_bytes = 10 + mock_upload.destination = "azure://container/outputs/wf/exec_1" + mock_storage_cls.return_value.upload_directory.return_value = mock_upload + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result["success"] is True + storage_info_calls = [ + c for c in mock_redis.setex.call_args_list if "storage_info" in c[0][0] + ] + stored_info = json.loads(storage_info_calls[0][0][2]) + assert "results.zarr" in stored_info.get("blob_url", "") + + def test_queue_next_returns_none_after_upload_returns_failure(self, tmp_path): + """When queue_next_stage returns None after a successful upload, returns fail_workflow.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config(storage_type="s3") + mock_redis = Mock() + mock_queue = Mock(return_value=None) + mock_storage_cls = Mock() + mock_upload = Mock() + mock_upload.success = True + mock_upload.files_uploaded = 0 + mock_upload.total_bytes = 0 + mock_upload.destination = "s3://my-bucket/prefix" + mock_storage_cls.return_value.upload_directory.return_value = mock_upload + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch("earth2studio.serve.server.cpu_worker.fail_workflow") as mock_fail, + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): 
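The two Azure blob_url tests above pin down how the primary asset is chosen for the GeoCatalog STAC href: the first *.nc file under the output directory wins, otherwise the first *.zarr store is used. A sketch of that selection rule; the helper name and the exact traversal order used by the worker are assumptions:

from pathlib import Path

def pick_primary_asset(output_dir: Path) -> Path | None:
    # Prefer NetCDF files; fall back to a Zarr store directory.
    nc_files = sorted(output_dir.rglob("*.nc"))
    if nc_files:
        return nc_files[0]
    zarr_stores = sorted(p for p in output_dir.rglob("*.zarr") if p.is_dir())
    return zarr_stores[0] if zarr_stores else None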
+ mock_fail.return_value = {"success": False, "error": "no job"} + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None and result.get("success") is False + mock_fail.assert_called_once() + assert "pipeline stage" in mock_fail.call_args[0][2].lower() + + def test_unexpected_exception_returns_fail_workflow(self, tmp_path): + """An unexpected exception in the try block returns fail_workflow dict (lines 825-827).""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config(storage_type="s3") + # Make redis_client.setex raise to trigger the outer except after upload + mock_redis = Mock() + mock_queue = Mock(return_value="job_1") + mock_storage_cls = Mock() + mock_upload = Mock() + mock_upload.success = True + mock_upload.files_uploaded = 1 + mock_upload.total_bytes = 1 + mock_upload.destination = "s3://my-bucket/prefix" + mock_storage_cls.return_value.upload_directory.return_value = mock_upload + mock_redis.setex.side_effect = RuntimeError("redis crashed") + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + result = process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + assert result is not None + assert result.get("success") is False + + def test_azure_credentials_added_to_storage_kwargs(self, tmp_path): + """Azure-specific kwargs (account_name, container, endpoint_url) are passed for managed identity.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config( + storage_type="azure", + bucket=None, + azure_container_name="my-container", + azure_account_name="myaccount", + endpoint_url="https://myaccount.blob.core.windows.net", + ) + mock_redis = Mock() + mock_queue = Mock(return_value="job_1") + mock_storage_cls = Mock() + mock_upload = Mock() + mock_upload.success = True + mock_upload.files_uploaded = 0 + mock_upload.total_bytes = 0 + mock_upload.destination = "azure://my-container/prefix" + mock_storage_cls.return_value.upload_directory.return_value = mock_upload + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + call_kwargs = mock_storage_cls.call_args[1] + assert ( + call_kwargs.get("endpoint_url") == "https://myaccount.blob.core.windows.net" + ) + assert call_kwargs.get("azure_account_name") == "myaccount" + assert call_kwargs.get("azure_container_name") == "my-container" + + def test_s3_optional_credentials_added_when_set(self, tmp_path): + """S3 optional kwargs (access_key_id, session_token, endpoint_url) are passed when set.""" + output_dir = tmp_path / "out" + output_dir.mkdir() + + mock_config = self._make_mock_config( # noqa: S106 + storage_type="s3", + access_key_id="AK123", + secret_access_key="SK456", # noqa: S106 + session_token="ST789", # noqa: S106 + 
endpoint_url="https://s3.custom.example.com", + ) + mock_redis = Mock() + mock_queue = Mock(return_value="job_1") + mock_storage_cls = Mock() + mock_upload = Mock() + mock_upload.success = True + mock_upload.files_uploaded = 0 + mock_upload.total_bytes = 0 + mock_upload.destination = "s3://my-bucket/prefix" + mock_storage_cls.return_value.upload_directory.return_value = mock_upload + + with ( + patch("earth2studio.serve.server.cpu_worker.config", mock_config), + patch("earth2studio.serve.server.cpu_worker.redis_client", mock_redis), + patch("earth2studio.serve.server.cpu_worker.queue_next_stage", mock_queue), + patch( + "earth2studio.serve.server.object_storage.MSCObjectStorage", + mock_storage_cls, + ), + ): + process_object_storage_upload( + workflow_name="wf", + execution_id="exec_1", + output_path_str=str(output_dir), + ) + + call_kwargs = mock_storage_cls.call_args[1] + assert call_kwargs.get("access_key_id") == "AK123" + assert call_kwargs.get("secret_access_key") == "SK456" + assert call_kwargs.get("session_token") == "ST789" + assert call_kwargs.get("endpoint_url") == "https://s3.custom.example.com" + + +class TestProcessFinalizeMetadataEdgeCases: + """Tests for the exception handler in process_finalize_metadata (lines 1114-1116).""" + + def test_exception_in_try_block_returns_fail_workflow(self, tmp_path): + """When json.loads raises (corrupt metadata), the except block returns fail_workflow.""" + results_zip_dir = tmp_path / "results" + results_zip_dir.mkdir() + request_id = "my_wf:exec_1" + metadata_key = f"inference_request:{request_id}:pending_metadata" + results_zip_dir_key = f"inference_request:{request_id}:results_zip_dir" + + with patch("earth2studio.serve.server.cpu_worker.redis_client") as mock_redis: + # Return non-JSON for metadata to trigger json.loads to raise + mock_redis.get.side_effect = lambda k: { + metadata_key: "NOT_VALID_JSON{{{", + results_zip_dir_key: str(results_zip_dir), + }.get(k) + + result = process_finalize_metadata( + workflow_name="my_wf", execution_id="exec_1" + ) + + assert result is not None + assert result.get("success") is False + assert "metadata finalization" in result.get("error", "").lower() + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/serve/server/test_server_main.py b/test/serve/server/test_server_main.py index 6c9894ff0..5e2d862bf 100644 --- a/test/serve/server/test_server_main.py +++ b/test/serve/server/test_server_main.py @@ -1422,6 +1422,7 @@ def test_admission_control_queue_full_returns_429(self, client_for_admission): mock_config.queue.name = "inference" mock_config.queue.result_zip_queue_name = "result_zip" mock_config.queue.object_storage_queue_name = "object_storage" + mock_config.queue.geocatalog_ingestion_queue_name = "geocatalog_ingestion" mock_config.queue.finalize_metadata_queue_name = "finalize_metadata" response = client.post( "/v1/infer/admit_wf", @@ -1541,8 +1542,8 @@ def test_execute_workflow_inference_queue_none_503(self, client_exec): def test_execute_workflow_llen_raises_after_enqueue_500(self, client_exec): """When llen raises after enqueue (queue position lookup), returns 500.""" client, mock_redis, mock_queue = client_exec - # First 4 calls: admission (4 queues); 5th: position lookup. Make 5th raise. 
- mock_redis.llen.side_effect = [0, 0, 0, 0, RuntimeError("redis error")] + # 5 queue checks in admission control + 1 for queue position lookup + mock_redis.llen.side_effect = [0, 0, 0, 0, 0, RuntimeError("redis error")] with patch("earth2studio.serve.server.main.inference_queue", mock_queue): response = client.post( "/v1/infer/exec_wf", @@ -2113,6 +2114,496 @@ def test_get_workflow_result_file_stream_file_200(self, client_file): assert response.text == "hello world" +class TestLifespanBranches: + """Tests for lifespan startup exception branches (lines 236-242).""" + + def test_lifespan_workflow_registration_generic_exception_continues(self): + """When register_all_workflows raises non-ImportError, app still starts.""" + with ( + patch("redis.asyncio.Redis") as mock_async_redis, + patch("redis.Redis") as mock_sync_redis, + patch("rq.Queue"), + patch( + "earth2studio.serve.server.workflow.register_all_workflows", + side_effect=RuntimeError("unexpected registration error"), + ), + ): + mock_async_instance = MagicMock() + mock_async_instance.ping = AsyncMock(return_value=True) + mock_async_instance.close = AsyncMock() + mock_async_redis.return_value = mock_async_instance + mock_sync_instance = MagicMock() + mock_sync_instance.ping = MagicMock(return_value=True) + mock_sync_instance.close = MagicMock() + mock_sync_redis.return_value = mock_sync_instance + + from earth2studio.serve.server.main import app + + with TestClient(app, raise_server_exceptions=False) as c: + response = c.get("/liveness") + assert response.status_code == 200 + + def test_lifespan_redis_ping_failure_raises(self): + """When Redis ping fails, lifespan raises and app fails to start.""" + with ( + patch("redis.asyncio.Redis") as mock_async_redis, + patch("redis.Redis") as mock_sync_redis, + patch("rq.Queue"), + patch("earth2studio.serve.server.workflow.register_all_workflows"), + ): + mock_async_instance = MagicMock() + mock_async_instance.ping = AsyncMock( + side_effect=ConnectionError("Redis unavailable") + ) + mock_async_instance.close = AsyncMock() + mock_async_redis.return_value = mock_async_instance + mock_sync_instance = MagicMock() + mock_sync_instance.ping = MagicMock(return_value=True) + mock_sync_redis.return_value = mock_sync_instance + + from earth2studio.serve.server.main import app + + with pytest.raises(Exception): + with TestClient(app): + pass + + +class TestHealthCheckWithoutScriptDir: + """Tests health endpoint when SCRIPT_DIR env var is absent (lines 304-305).""" + + @pytest.fixture + def client_probes(self): + """Client for probe endpoints.""" + with ( + patch("redis.asyncio.Redis") as mock_async_redis, + patch("redis.Redis") as mock_sync_redis, + patch("rq.Queue"), + patch("earth2studio.serve.server.workflow.register_all_workflows"), + ): + mock_async_instance = MagicMock() + mock_async_instance.ping = AsyncMock(return_value=True) + mock_async_instance.close = AsyncMock() + mock_async_redis.return_value = mock_async_instance + mock_sync_instance = MagicMock() + mock_sync_instance.ping = MagicMock(return_value=True) + mock_sync_instance.close = MagicMock() + mock_sync_redis.return_value = mock_sync_instance + + from earth2studio.serve.server.main import app + + with TestClient(app, raise_server_exceptions=False) as c: + yield c + + def test_health_follows_repo_relative_path_when_script_dir_empty( + self, client_probes + ): + """Health check uses repo-relative script path when SCRIPT_DIR is empty/unset.""" + with patch.dict(os.environ, {"SCRIPT_DIR": ""}): + with patch( + 
"earth2studio.serve.server.main.asyncio.create_subprocess_exec" + ) as mock_exec: + mock_proc = MagicMock() + mock_proc.returncode = 0 + mock_proc.communicate = AsyncMock(return_value=(b"", b"")) + mock_exec.return_value = mock_proc + + response = client_probes.get("/health") + assert response.status_code == 200 + assert response.json()["status"] == "healthy" + + +class TestNotExposedWorkflowEndpoints: + """Tests for 404 when workflow is registered but not exposed (lines 427, 539, 671, 743, 876).""" + + @pytest.fixture + def client_with_workflow(self): + """Standard client with a registered workflow.""" + with ( + patch("redis.asyncio.Redis") as mock_async_redis, + patch("redis.Redis") as mock_sync_redis, + patch("rq.Queue"), + patch("earth2studio.serve.server.workflow.register_all_workflows"), + ): + mock_async_instance = MagicMock() + mock_async_instance.ping = AsyncMock(return_value=True) + mock_async_instance.close = AsyncMock() + mock_async_redis.return_value = mock_async_instance + mock_sync_instance = MagicMock() + mock_sync_instance.ping = MagicMock(return_value=True) + mock_sync_instance.close = MagicMock() + mock_sync_redis.return_value = mock_sync_instance + + from earth2studio.serve.server.main import app + from earth2studio.serve.server.workflow import ( + Workflow, + WorkflowParameters, + workflow_registry, + ) + + class TestEndpointsParams(WorkflowParameters): + x: int = Field(default=1) + + class TestEndpointsWf(Workflow): + name = "test_endpoints_wf" + description = "Test" + Parameters = TestEndpointsParams + + @classmethod + def validate_parameters(cls, parameters): + return TestEndpointsParams.validate(parameters) + + def run(self, parameters, execution_id): + return {"status": "ok"} + + workflow_registry._workflows["test_endpoints_wf"] = TestEndpointsWf + + with TestClient(app, raise_server_exceptions=False) as c: + yield c + + if "test_endpoints_wf" in workflow_registry._workflows: + del workflow_registry._workflows["test_endpoints_wf"] + + def test_schema_not_exposed_404(self, client_with_workflow): + """get_workflow_schema returns 404 when workflow is not exposed.""" + with patch( + "earth2studio.serve.server.main.workflow_registry.is_workflow_exposed", + return_value=False, + ): + response = client_with_workflow.get( + "/v1/workflows/test_endpoints_wf/schema" + ) + assert response.status_code == 404 + assert "not exposed" in response.json().get("detail", "").lower() + + def test_execute_not_exposed_404(self, client_with_workflow): + """execute_workflow returns 404 when workflow is not exposed.""" + with patch( + "earth2studio.serve.server.main.workflow_registry.is_workflow_exposed", + return_value=False, + ): + response = client_with_workflow.post( + "/v1/infer/test_endpoints_wf", json={"parameters": {}} + ) + assert response.status_code == 404 + + def test_get_status_workflow_not_found_404(self, client_with_workflow): + """get_workflow_status returns 404 when workflow not found.""" + response = client_with_workflow.get("/v1/infer/nonexistent_wf/exec_1/status") + assert response.status_code == 404 + assert "not found" in response.json().get("detail", "").lower() + + def test_get_status_not_exposed_404(self, client_with_workflow): + """get_workflow_status returns 404 when workflow is not exposed.""" + with patch( + "earth2studio.serve.server.main.workflow_registry.is_workflow_exposed", + return_value=False, + ): + response = client_with_workflow.get( + "/v1/infer/test_endpoints_wf/exec_1/status" + ) + assert response.status_code == 404 + + def 
test_get_results_workflow_not_found_404(self, client_with_workflow): + """get_workflow_results returns 404 when workflow not found.""" + response = client_with_workflow.get("/v1/infer/nonexistent_wf/exec_1/results") + assert response.status_code == 404 + + def test_get_results_not_exposed_404(self, client_with_workflow): + """get_workflow_results returns 404 when workflow is not exposed.""" + with patch( + "earth2studio.serve.server.main.workflow_registry.is_workflow_exposed", + return_value=False, + ): + response = client_with_workflow.get( + "/v1/infer/test_endpoints_wf/exec_1/results" + ) + assert response.status_code == 404 + + def test_get_result_file_not_exposed_404(self, client_with_workflow): + """get_workflow_result_file returns 404 when workflow is not exposed.""" + with patch( + "earth2studio.serve.server.main.workflow_registry.is_workflow_exposed", + return_value=False, + ): + response = client_with_workflow.get( + "/v1/infer/test_endpoints_wf/exec_1/results/file.nc" + ) + assert response.status_code == 404 + + +class TestExecuteWorkflowAdditionalBranches: + """Additional coverage for execute_workflow (lines 598, 605-609).""" + + @pytest.fixture + def client_exec2(self): + """Client with workflow for additional execute branch tests.""" + from earth2studio.serve.server.workflow import Workflow, WorkflowParameters + + with ( + patch("redis.asyncio.Redis") as mock_async_redis, + patch("redis.Redis") as mock_sync_redis, + patch("rq.Queue") as mock_queue_class, + patch("earth2studio.serve.server.workflow.register_all_workflows"), + ): + mock_async_instance = MagicMock() + mock_async_instance.ping = AsyncMock(return_value=True) + mock_async_instance.close = AsyncMock() + mock_async_instance.get = AsyncMock(return_value=None) + mock_async_redis.return_value = mock_async_instance + + mock_sync_instance = MagicMock() + mock_sync_instance.ping = MagicMock(return_value=True) + mock_sync_instance.close = MagicMock() + mock_sync_instance.setex = MagicMock() + mock_sync_instance.llen = MagicMock(return_value=0) + mock_sync_redis.return_value = mock_sync_instance + + mock_queue = MagicMock() + mock_job = MagicMock() + mock_job.id = "exec2_wf_exec_123" + mock_queue.enqueue = MagicMock(return_value=mock_job) + mock_queue_class.return_value = mock_queue + + class Exec2Params(WorkflowParameters): + x: int = Field(default=1) + + class Exec2Workflow(Workflow): + name = "exec2_wf" + description = "For additional execute tests" + Parameters = Exec2Params + + @classmethod + def validate_parameters(cls, parameters): + return Exec2Params.validate(parameters) + + def run(self, parameters, execution_id): + return {"status": "ok"} + + from earth2studio.serve.server.main import app + from earth2studio.serve.server.workflow import workflow_registry + + workflow_registry._workflows["exec2_wf"] = Exec2Workflow + with TestClient(app, raise_server_exceptions=False) as c: + yield c, mock_sync_instance, mock_queue, Exec2Workflow + if "exec2_wf" in workflow_registry._workflows: + del workflow_registry._workflows["exec2_wf"] + + def test_execute_llen_raises_after_enqueue_500(self, client_exec2): + """When llen raises during queue position lookup, returns 500.""" + client, mock_redis, mock_queue, _ = client_exec2 + # 5 queue checks in admission control + 1 for queue position lookup + mock_redis.llen.side_effect = [0, 0, 0, 0, 0, RuntimeError("redis error")] + with patch("earth2studio.serve.server.main.inference_queue", mock_queue): + response = client.post( + "/v1/infer/exec2_wf", + json={"parameters": {}}, + ) + 
assert response.status_code == 500
+
+    def test_execute_redis_none_during_queue_position_503(self, client_exec2):
+        """When redis_sync_client is None at queue position check, returns 503."""
+        client, mock_redis, mock_queue, wf_class = client_exec2
+        with patch("earth2studio.serve.server.main.check_admission_control"):
+            with patch.object(wf_class, "_save_execution_data"):
+                with patch("earth2studio.serve.server.main.redis_sync_client", None):
+                    with patch(
+                        "earth2studio.serve.server.main.inference_queue", mock_queue
+                    ):
+                        response = client.post(
+                            "/v1/infer/exec2_wf",
+                            json={"parameters": {}},
+                        )
+        assert response.status_code == 503
+        assert "Redis" in response.json().get("detail", "")
+
+
+class TestGetWorkflowResultFileAdditionalBranches:
+    """Additional tests for get_workflow_result_file (lines 895, 903, 919-952, 986, 1035, 1052, 1062-1066)."""
+
+    @pytest.fixture
+    def client_file2(self, tmp_path):
+        """Client and tmp dir for additional result file tests."""
+        with (
+            patch("redis.asyncio.Redis") as mock_async_redis,
+            patch("redis.Redis") as mock_sync_redis,
+            patch("rq.Queue") as mock_queue_class,
+            patch("earth2studio.serve.server.workflow.register_all_workflows"),
+        ):
+            mock_async_instance = MagicMock()
+            mock_async_instance.ping = AsyncMock(return_value=True)
+            mock_async_instance.close = AsyncMock()
+            mock_async_instance.get = AsyncMock(return_value=None)
+            mock_async_redis.return_value = mock_async_instance
+            mock_sync_instance = MagicMock()
+            mock_sync_instance.ping = MagicMock(return_value=True)
+            mock_sync_instance.close = MagicMock()
+            mock_sync_redis.return_value = mock_sync_instance
+            mock_queue_class.return_value = MagicMock()
+
+            from earth2studio.serve.server.main import app
+            from earth2studio.serve.server.workflow import (
+                Workflow,
+                WorkflowParameters,
+                WorkflowStatus,
+                workflow_registry,
+            )
+
+            class File2Params(WorkflowParameters):
+                x: int = Field(default=1)
+
+            class File2Workflow(Workflow):
+                name = "file2_wf"
+                description = "File2 test"
+                Parameters = File2Params
+
+                @classmethod
+                def validate_parameters(cls, parameters):
+                    return File2Params.validate(parameters)
+
+                def run(self, parameters, execution_id):
+                    return {"status": "ok"}
+
+            workflow_registry._workflows["file2_wf"] = File2Workflow
+            with TestClient(app, raise_server_exceptions=False) as c:
+                yield c, mock_async_instance, tmp_path, WorkflowStatus
+            if "file2_wf" in workflow_registry._workflows:
+                del workflow_registry._workflows["file2_wf"]
+
+    def _completed_exec_data(self):
+        from earth2studio.serve.server.workflow import WorkflowResult, WorkflowStatus
+
+        return WorkflowResult(
+            workflow_name="file2_wf",
+            execution_id="exec_1",
+            status=WorkflowStatus.COMPLETED,
+            start_time=datetime.now(timezone.utc).isoformat(),
+            end_time=datetime.now(timezone.utc).isoformat(),
+        )
+
+    def test_get_result_file_value_error_in_exec_data_404(self, client_file2):
+        """When _get_execution_data raises ValueError, returns 404."""
+        client, *_ = client_file2
+        with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg:
+            mock_wf = MagicMock()
+            mock_wf._get_execution_data = MagicMock(
+                side_effect=ValueError("Execution not found")
+            )
+            mock_reg.get_workflow_class.return_value = mock_wf
+            response = client.get("/v1/infer/file2_wf/exec_1/results/file.nc")
+        assert response.status_code == 404
+        assert "Execution not found" in response.json().get("detail", "")
+
+    def test_get_result_file_redis_none_for_zip_path_503(self, client_file2):
+        """When filepath == request_id but redis_client is None, returns 503."""
+        client, mock_async, tmp_path, _ = client_file2
+        exec_data = self._completed_exec_data()
+        with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg:
+            with patch("earth2studio.serve.server.main.redis_client", None):
+                mock_wf = MagicMock()
+                mock_wf._get_execution_data = MagicMock(return_value=exec_data)
+                mock_reg.get_workflow_class.return_value = mock_wf
+                response = client.get(
+                    "/v1/infer/file2_wf/exec_1/results/file2_wf:exec_1"
+                )
+        assert response.status_code == 503
+
+    def test_get_result_file_zip_not_on_disk_404(self, client_file2):
+        """When zip key in Redis but zip file missing from disk, returns 404."""
+        client, mock_async, tmp_path, _ = client_file2
+        exec_data = self._completed_exec_data()
+        mock_async.get = AsyncMock(return_value="missing_zip.zip")
+        with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg:
+            with patch("earth2studio.serve.server.main.redis_client", mock_async):
+                with patch("earth2studio.serve.server.main.RESULTS_ZIP_DIR", tmp_path):
+                    mock_wf = MagicMock()
+                    mock_wf._get_execution_data = MagicMock(return_value=exec_data)
+                    mock_reg.get_workflow_class.return_value = mock_wf
+                    response = client.get(
+                        "/v1/infer/file2_wf/exec_1/results/file2_wf:exec_1"
+                    )
+        assert response.status_code == 404
+        assert "disk" in response.json().get("detail", {}).get("error", "").lower()
+
+    def test_get_result_file_zip_stream_success(self, client_file2):
+        """When zip file exists on disk, returns 200 with streamed content."""
+        client, mock_async, tmp_path, _ = client_file2
+        exec_data = self._completed_exec_data()
+        zip_file = tmp_path / "results.zip"
+        zip_file.write_bytes(b"PK\x03\x04fake_zip_content")
+        mock_async.get = AsyncMock(return_value="results.zip")
+        with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg:
+            with patch("earth2studio.serve.server.main.redis_client", mock_async):
+                with patch("earth2studio.serve.server.main.RESULTS_ZIP_DIR", tmp_path):
+                    mock_wf = MagicMock()
+                    mock_wf._get_execution_data = MagicMock(return_value=exec_data)
+                    mock_reg.get_workflow_class.return_value = mock_wf
+                    response = client.get(
+                        "/v1/infer/file2_wf/exec_1/results/file2_wf:exec_1"
+                    )
+        assert response.status_code == 200
+        assert 'results.zip"' in response.headers.get("content-disposition", "")
+
+    def test_get_result_file_filepath_with_output_dir_prefix(self, client_file2):
+        """When filepath starts with output dir name, the prefix is stripped."""
+        client, mock_async, tmp_path, _ = client_file2
+        exec_data = self._completed_exec_data()
+        output_dir = tmp_path / "exec_1"
+        output_dir.mkdir()
+        data_file = output_dir / "data.txt"
+        data_file.write_text("contents")
+        mock_async.get = AsyncMock(return_value=str(output_dir))
+        with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg:
+            with patch("earth2studio.serve.server.main.redis_client", mock_async):
+                mock_wf = MagicMock()
+                mock_wf._get_execution_data = MagicMock(return_value=exec_data)
+                mock_reg.get_workflow_class.return_value = mock_wf
+                # filepath starts with output dir name (exec_1/data.txt)
+                response = client.get(
+                    "/v1/infer/file2_wf/exec_1/results/exec_1/data.txt"
+                )
+        assert response.status_code == 200
+        assert response.text == "contents"
+
+    def test_get_result_file_no_mime_type_uses_octet_stream(self, client_file2):
+        """Files with no recognized MIME type use application/octet-stream."""
+        client, mock_async, tmp_path, _ = client_file2
+        exec_data = self._completed_exec_data()
+        output_dir = tmp_path / "output"
+        output_dir.mkdir()
+        unknown_file = output_dir / "data.unknownext99999"
+        unknown_file.write_bytes(b"\x00\x01\x02\x03")
+        mock_async.get = AsyncMock(return_value=str(output_dir))
+        with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg:
+            with patch("earth2studio.serve.server.main.redis_client", mock_async):
+                mock_wf = MagicMock()
+                mock_wf._get_execution_data = MagicMock(return_value=exec_data)
+                mock_reg.get_workflow_class.return_value = mock_wf
+                with patch("mimetypes.guess_type", return_value=(None, None)):
+                    response = client.get(
+                        "/v1/infer/file2_wf/exec_1/results/data.unknownext99999"
+                    )
+        assert response.status_code == 200
+        assert "octet-stream" in response.headers.get("content-type", "")
+
+    def test_get_result_file_generic_exception_500(self, client_file2):
+        """When an unexpected exception occurs in the file handler, returns 500."""
+        client, mock_async, tmp_path, _ = client_file2
+        exec_data = self._completed_exec_data()
+        mock_async.get = AsyncMock(side_effect=RuntimeError("unexpected redis error"))
+        with patch("earth2studio.serve.server.main.workflow_registry") as mock_reg:
+            with patch("earth2studio.serve.server.main.redis_client", mock_async):
+                mock_wf = MagicMock()
+                mock_wf._get_execution_data = MagicMock(return_value=exec_data)
+                mock_reg.get_workflow_class.return_value = mock_wf
+                response = client.get(
+                    "/v1/infer/file2_wf/exec_1/results/file2_wf:exec_1"
+                )
+        assert response.status_code == 500
+        assert "Failed to retrieve file" in response.json().get("detail", {}).get(
+            "error", ""
+        )
+
+
 class TestMainEntrypoint:
     """Test main module entrypoint (covers line 1044)."""
diff --git a/test/serve/server/test_server_object_storage.py b/test/serve/server/test_server_object_storage.py
index 907441cc2..9ebc6e775 100644
--- a/test/serve/server/test_server_object_storage.py
+++ b/test/serve/server/test_server_object_storage.py
@@ -140,7 +140,7 @@ def test_upload_directory_success_returns_upload_result(self, mock_msc):
         with tempfile.TemporaryDirectory() as tmpdir:
             (Path(tmpdir) / "f1.txt").write_text("hello")
             storage = MSCObjectStorage(bucket="b", region="us-east-1")
-            storage._s3_client.sync_from.return_value = None
+            storage._storage_client.sync_from.return_value = None

             result = storage.upload_directory(
                 local_directory=tmpdir,
@@ -159,7 +159,7 @@ def test_upload_directory_failure_appends_errors(self, mock_msc):
         with tempfile.TemporaryDirectory() as tmpdir:
             (Path(tmpdir) / "f1.txt").write_text("x")
             storage = MSCObjectStorage(bucket="b", region="us-east-1")
-            storage._s3_client.sync_from.side_effect = Exception("sync failed")
+            storage._storage_client.sync_from.side_effect = Exception("sync failed")

             result = storage.upload_directory(
                 local_directory=tmpdir,
@@ -196,7 +196,7 @@ def test_upload_file_returns_true_on_success(self, mock_msc):
                 remote_key="key.txt",
             )
             assert result is True
-            storage._s3_client.upload_file.assert_called_once()
+            storage._storage_client.upload_file.assert_called_once()
         finally:
             Path(path).unlink(missing_ok=True)

@@ -206,7 +206,7 @@ def test_upload_file_returns_false_on_exception(self, mock_msc):
             path = f.name
         try:
             storage = MSCObjectStorage(bucket="b", region="us-east-1")
-            storage._s3_client.upload_file.side_effect = Exception("upload failed")
+            storage._storage_client.upload_file.side_effect = Exception("upload failed")

             result = storage.upload_file(
                 local_file=path,
@@ -219,14 +219,14 @@ def test_upload_file_returns_false_on_exception(self, mock_msc):
     def test_file_exists_returns_true_when_info_succeeds(self, mock_msc):
         """file_exists returns True when info() does not raise."""
         storage = MSCObjectStorage(bucket="b", region="us-east-1")
-        storage._s3_client.info.return_value = None
+        storage._storage_client.info.return_value = None

         assert storage.file_exists("my/key") is True

     def test_file_exists_returns_false_when_file_not_found(self, mock_msc):
         """file_exists returns False when info() raises FileNotFoundError."""
         storage = MSCObjectStorage(bucket="b", region="us-east-1")
-        storage._s3_client.info.side_effect = FileNotFoundError()
+        storage._storage_client.info.side_effect = FileNotFoundError()

         assert storage.file_exists("my/key") is False

@@ -235,12 +235,12 @@ def test_delete_file_returns_true_on_success(self, mock_msc):
         storage = MSCObjectStorage(bucket="b", region="us-east-1")

         assert storage.delete_file("my/key") is True
-        storage._s3_client.delete.assert_called_once()
+        storage._storage_client.delete.assert_called_once()

     def test_delete_file_returns_false_when_file_not_found(self, mock_msc):
         """delete_file returns False when delete raises FileNotFoundError."""
         storage = MSCObjectStorage(bucket="b", region="us-east-1")
-        storage._s3_client.delete.side_effect = FileNotFoundError()
+        storage._storage_client.delete.side_effect = FileNotFoundError()

         assert storage.delete_file("my/key") is False

@@ -292,3 +292,208 @@ def test_generate_signed_url_returns_url_when_configured(self, mock_msc):
         assert "Policy=" in url
         assert "Signature=" in url
         assert "Key-Pair-Id=KP123" in url
+
+
+class TestMSCObjectStorageS3Additional:
+    """Additional tests to cover S3 init branches and other uncovered S3 paths."""
+
+    @pytest.fixture
+    def mock_msc(self):
+        mock_module = MagicMock()
+        mock_module.StorageClientConfig.from_dict.side_effect = [
+            MagicMock(),
+            MagicMock(),
+        ]
+        mock_module.StorageClient.return_value = MagicMock()
+        with patch.dict(sys.modules, {"multistorageclient": mock_module}):
+            yield mock_module
+
+    def test_init_s3_transfer_acceleration(self, mock_msc):
+        """S3 init with use_transfer_acceleration sets accelerate endpoint URL."""
+        storage = MSCObjectStorage(bucket="my-bucket", use_transfer_acceleration=True)
+        assert storage.endpoint_url == "https://my-bucket.s3-accelerate.amazonaws.com"
+
+    def test_init_s3_with_credentials(self, mock_msc):
+        """S3 init with credentials sets AWS environment variables."""
+        import os
+
+        MSCObjectStorage(
+            bucket="b",
+            access_key_id="AKID",
+            secret_access_key="SECRET",  # noqa: S106
+            session_token="TOKEN",  # noqa: S106
+        )
+        assert os.environ.get("AWS_ACCESS_KEY_ID") == "AKID"
+        assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "SECRET"
+        assert os.environ.get("AWS_SESSION_TOKEN") == "TOKEN"
+
+    def test_init_s3_with_endpoint_url(self, mock_msc):
+        """S3 init with endpoint_url stores it and includes it in provider options."""
+        storage = MSCObjectStorage(bucket="b", endpoint_url="http://minio:9000")
+        assert storage.endpoint_url == "http://minio:9000"
+
+    def test_init_s3_with_rust_client(self, mock_msc):
+        """S3 init with use_rust_client=True adds rust_client section to config."""
+        storage = MSCObjectStorage(bucket="b", use_rust_client=True)
+        assert storage.use_rust_client is True
+
+    def test_init_unsupported_storage_type(self, mock_msc):
+        """__init__ raises ValueError for unsupported storage_type."""
+        with pytest.raises(ValueError, match="Unsupported storage_type"):
+            MSCObjectStorage(bucket="b", storage_type="gcs")
+
+    def test_upload_directory_non_recursive(self, mock_msc):
+        """upload_directory with recursive=False only counts top-level files."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            (Path(tmpdir) / "top.txt").write_text("hello")
+            subdir = Path(tmpdir) / "sub"
+            subdir.mkdir()
+            (subdir / "deep.txt").write_text("world")
+
+            storage = MSCObjectStorage(bucket="b", region="us-east-1")
+            storage._storage_client.sync_from.return_value = None
+
+            result = storage.upload_directory(
+                local_directory=tmpdir,
+                remote_prefix="prefix",
+                recursive=False,
+            )
+            assert result.success is True
+            assert result.files_uploaded == 1  # only top-level file
+
+    def test_upload_file_path_is_directory(self, mock_msc):
+        """upload_file returns False when local_file path is a directory."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            storage = MSCObjectStorage(bucket="b", region="us-east-1")
+            result = storage.upload_file(local_file=tmpdir, remote_key="key.txt")
+            assert result is False
+
+    def test_delete_file_generic_exception(self, mock_msc):
+        """delete_file returns False on an unexpected (non-FileNotFoundError) exception."""
+        storage = MSCObjectStorage(bucket="b", region="us-east-1")
+        storage._storage_client.delete.side_effect = RuntimeError("unexpected")
+        assert storage.delete_file("my/key") is False
+
+    def test_rsa_signer_no_key_raises(self, mock_msc):
+        """_rsa_signer raises ObjectStorageError when cloudfront_private_key is None."""
+        storage = MSCObjectStorage(bucket="b", region="us-east-1")
+        storage.cloudfront_private_key = None
+        with pytest.raises(ObjectStorageError, match="No CloudFront private key"):
+            storage._rsa_signer(b"message")
+
+    def test_rsa_signer_with_mocked_key(self, mock_msc):
+        """_rsa_signer signs message using the configured private key."""
+        storage = MSCObjectStorage(bucket="b", region="us-east-1")
+        storage.cloudfront_private_key = (
+            "-----BEGIN RSA PRIVATE KEY-----\nfake\n-----END RSA PRIVATE KEY-----"
+        )
+
+        mock_private_key = MagicMock()
+        mock_private_key.sign.return_value = b"fake_signature"
+        mock_serialization = MagicMock()
+        mock_serialization.load_pem_private_key.return_value = mock_private_key
+
+        crypto_mocks = {
+            "cryptography": MagicMock(),
+            "cryptography.hazmat": MagicMock(),
+            "cryptography.hazmat.primitives": MagicMock(),
+            "cryptography.hazmat.primitives.hashes": MagicMock(),
+            "cryptography.hazmat.primitives.asymmetric": MagicMock(),
+            "cryptography.hazmat.primitives.asymmetric.padding": MagicMock(),
+            "cryptography.hazmat.primitives.serialization": mock_serialization,
+        }
+        with patch.dict(sys.modules, crypto_mocks):
+            result = storage._rsa_signer(b"message")
+
+        assert result == b"fake_signature"
+
+    def test_generate_signed_url_unsupported_type(self, mock_msc):
+        """generate_signed_url raises ObjectStorageError for unsupported storage_type."""
+        storage = MSCObjectStorage(bucket="b", region="us-east-1")
+        storage.storage_type = "unsupported"
+        with pytest.raises(ObjectStorageError, match="Unsupported storage_type"):
+            storage.generate_signed_url("key.txt")
+
+
+class TestMSCObjectStorageAzure:
+    """Tests for MSCObjectStorage with Azure storage type."""
+
+    @pytest.fixture
+    def mock_msc(self):
+        mock_module = MagicMock()
+        mock_module.StorageClientConfig.from_dict.side_effect = [
+            MagicMock(),
+            MagicMock(),
+        ]
+        mock_module.StorageClient.return_value = MagicMock()
+        with patch.dict(sys.modules, {"multistorageclient": mock_module}):
+            yield mock_module
+
+    def _make_azure_storage(self, mock_msc, **kwargs):
+        """Helper to reset mock side_effect for multiple instantiations."""
+        mock_msc.StorageClientConfig.from_dict.side_effect = [
+            MagicMock(),
+            MagicMock(),
+        ]
+        return MSCObjectStorage(**kwargs)
+
+    def test_init_azure_managed_identity_with_account_name(self, mock_msc):
+        """Azure init with managed identity uses DefaultAzureCredentials."""
+        storage = MSCObjectStorage(
+            bucket="mycontainer",
+            storage_type="azure",
+            azure_account_name="myaccount",
+        )
+        assert storage.use_managed_identity is True
+        assert storage.azure_account_name == "myaccount"
+        assert storage.storage_type == "azure"
+
+    def test_init_azure_managed_identity_with_endpoint_url(self, mock_msc):
+        """Azure init with managed identity and explicit endpoint_url succeeds."""
+        storage = MSCObjectStorage(
+            bucket="mycontainer",
+            storage_type="azure",
+            endpoint_url="https://myaccount.blob.core.windows.net",
+        )
+        assert storage.use_managed_identity is True
+
+    def test_init_azure_no_account_name_no_endpoint_raises(self, mock_msc):
+        """Azure managed identity raises ObjectStorageError when neither account name nor endpoint_url is given."""
+        with pytest.raises(
+            ObjectStorageError, match="Azure endpoint_url cannot be determined"
+        ):
+            MSCObjectStorage(
+                bucket="mycontainer",
+                storage_type="azure",
+            )
+
+    def test_upload_directory_azure_destination(self, mock_msc):
+        """upload_directory for azure uses azure:// in the destination."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            (Path(tmpdir) / "f.txt").write_text("hi")
+            storage = MSCObjectStorage(
+                bucket="mycontainer",
+                storage_type="azure",
+                azure_account_name="acct",
+                endpoint_url="https://acct.blob.core.windows.net",
+            )
+            storage._storage_client.sync_from.return_value = None
+
+            result = storage.upload_directory(
+                local_directory=tmpdir,
+                remote_prefix="prefix",
+            )
+            assert result.success is True
+            assert "azure://" in result.destination
+            assert "mycontainer" in result.destination
+
+    def test_generate_signed_url_azure_not_supported(self, mock_msc):
+        """generate_signed_url for Azure raises; clients use tokens to read blobs."""
+        storage = MSCObjectStorage(
+            bucket="mycontainer",
+            storage_type="azure",
+            azure_account_name="acct",
+            endpoint_url="https://acct.blob.core.windows.net",
+        )
+        with pytest.raises(ObjectStorageError, match="not generated by the server"):
+            storage.generate_signed_url("key.txt")
diff --git a/test/serve/server/test_server_utils.py b/test/serve/server/test_server_utils.py
index 3e3b663ff..6b5837795 100644
--- a/test/serve/server/test_server_utils.py
+++ b/test/serve/server/test_server_utils.py
@@ -206,10 +206,45 @@ def test_result_zip_stage_object_storage_disabled_queues_finalize_metadata(self)
         assert "process_finalize_metadata" in mock_queue.enqueue.call_args[0][0]
         assert mock_queue.enqueue.call_args[0][1:3] == ("wf", "exec_1")

+    def test_object_storage_stage_with_geocatalog_url_queues_geocatalog(self):
+        """current_stage=object_storage with geocatalog URL enqueues process_geocatalog_ingestion."""
+        mock_redis = MagicMock()
+        mock_config = MagicMock()
+        mock_config.object_storage.azure_geocatalog_url = (
+            "https://geocatalog.example.com"
+        )
+        mock_config.queue.geocatalog_ingestion_queue_name = "geocatalog_ingestion"
+        mock_config.queue.default_timeout = "1h"
+        mock_config.queue.job_timeout = "2h"
+        mock_job = MagicMock()
+        mock_job.id = "job_geo"
+        mock_queue = MagicMock()
+        mock_queue.enqueue.return_value = mock_job
+
+        with (
+            patch(
+                "earth2studio.serve.server.utils.get_config", return_value=mock_config
+            ),
+            patch("earth2studio.serve.server.utils.Queue", return_value=mock_queue),
+        ):
+            result = queue_next_stage(
+                redis_client=mock_redis,
+                current_stage="object_storage",
+                workflow_name="wf",
+                execution_id="exec_1",
+                output_path_str="/out",
+            )
+
+        assert result == "job_geo"
+        mock_queue.enqueue.assert_called_once()
+        assert "process_geocatalog_ingestion" in mock_queue.enqueue.call_args[0][0]
+        assert mock_queue.enqueue.call_args[0][1:3] == ("wf", "exec_1")
+
     def test_object_storage_stage_queues_finalize_metadata(self):
-        """current_stage=object_storage enqueues process_finalize_metadata."""
+        """current_stage=object_storage enqueues process_finalize_metadata when geocatalog is not configured."""
         mock_redis = MagicMock()
         mock_config = MagicMock()
+        mock_config.object_storage.azure_geocatalog_url = None
         mock_config.queue.finalize_metadata_queue_name = "finalize_metadata"
         mock_config.queue.default_timeout = "1h"
         mock_config.queue.job_timeout = "2h"
@@ -236,6 +271,37 @@ def test_object_storage_stage_queues_finalize_metadata(self):
         mock_queue.enqueue.assert_called_once()
         assert "process_finalize_metadata" in mock_queue.enqueue.call_args[0][0]

+    def test_geocatalog_ingestion_stage_queues_finalize_metadata(self):
+        """current_stage=geocatalog_ingestion enqueues process_finalize_metadata."""
+        mock_redis = MagicMock()
+        mock_config = MagicMock()
+        mock_config.queue.finalize_metadata_queue_name = "finalize_metadata"
+        mock_config.queue.default_timeout = "1h"
+        mock_config.queue.job_timeout = "2h"
+        mock_job = MagicMock()
+        mock_job.id = "job_finalize"
+        mock_queue = MagicMock()
+        mock_queue.enqueue.return_value = mock_job
+
+        with (
+            patch(
+                "earth2studio.serve.server.utils.get_config", return_value=mock_config
+            ),
+            patch("earth2studio.serve.server.utils.Queue", return_value=mock_queue),
+        ):
+            result = queue_next_stage(
+                redis_client=mock_redis,
+                current_stage="geocatalog_ingestion",
+                workflow_name="wf",
+                execution_id="exec_1",
+                output_path_str="/out",
+            )
+
+        assert result == "job_finalize"
+        mock_queue.enqueue.assert_called_once()
+        assert "process_finalize_metadata" in mock_queue.enqueue.call_args[0][0]
+        assert mock_queue.enqueue.call_args[0][1:3] == ("wf", "exec_1")
+
     def test_enqueue_exception_returns_none(self):
         """When Queue.enqueue raises, queue_next_stage returns None."""
         mock_redis = MagicMock()
diff --git a/test/serve/server/test_server_workflow.py b/test/serve/server/test_server_workflow.py
index d73adfffe..b6d573691 100644
--- a/test/serve/server/test_server_workflow.py
+++ b/test/serve/server/test_server_workflow.py
@@ -1592,6 +1592,114 @@ def test_auto_register_workflows_with_error(self):
         self.registry.auto_register_workflows(mock_redis)


+class TestWorkflowRegistryExposure:
+    """Tests for WorkflowRegistry.is_workflow_exposed and list_workflows with exposure filtering."""
+
+    def setup_method(self):
+        self.registry = WorkflowRegistry()
+        self.registry.register(Workflow1)
+        self.registry.register(Workflow2)
+        self.registry.register(Workflow3)
+
+    def _make_config(self, exposed=None, warmup=None):
+        mock_config = MagicMock()
+        mock_config.workflow_exposure.exposed_workflows = (
+            exposed if exposed is not None else []
+        )
+        mock_config.workflow_exposure.warmup_workflows = (
+            warmup if warmup is not None else []
+        )
+        return mock_config
+
+    # --- is_workflow_exposed ---
+
+    def test_is_workflow_exposed_empty_list_exposes_all(self):
+        """Empty exposed_workflows means all workflows are exposed."""
+        mock_config = self._make_config(exposed=[], warmup=[])
+        with patch(
+            "earth2studio.serve.server.config.get_config", return_value=mock_config
+        ):
+            assert self.registry.is_workflow_exposed("workflow1") is True
+            assert self.registry.is_workflow_exposed("workflow2") is True
+            assert self.registry.is_workflow_exposed("unknown_wf") is True
+
+    def test_is_workflow_exposed_in_exposed_list(self):
+        """Workflow in exposed_workflows is exposed."""
+        mock_config = self._make_config(exposed=["workflow1", "workflow2"], warmup=[])
+        with patch(
+            "earth2studio.serve.server.config.get_config", return_value=mock_config
+        ):
+            assert self.registry.is_workflow_exposed("workflow1") is True
+            assert self.registry.is_workflow_exposed("workflow2") is True
+
+    def test_is_workflow_exposed_in_warmup_list_only(self):
+        """Workflow in warmup_workflows (but not exposed_workflows) is still exposed."""
+        mock_config = self._make_config(exposed=["workflow1"], warmup=["workflow2"])
+        with patch(
+            "earth2studio.serve.server.config.get_config", return_value=mock_config
+        ):
+            assert self.registry.is_workflow_exposed("workflow2") is True
+
+    def test_is_workflow_exposed_not_in_any_list(self):
+        """Workflow not in exposed or warmup lists is not exposed."""
+        mock_config = self._make_config(exposed=["workflow1"], warmup=["workflow2"])
+        with patch(
+            "earth2studio.serve.server.config.get_config", return_value=mock_config
+        ):
+            assert self.registry.is_workflow_exposed("workflow3") is False
+
+    # --- list_workflows ---
+
+    def test_list_workflows_exposed_only_empty_list_returns_all(self):
+        """Empty exposed_workflows with exposed_only=True returns all registered workflows."""
+        mock_config = self._make_config(exposed=[], warmup=[])
+        with patch(
+            "earth2studio.serve.server.config.get_config", return_value=mock_config
+        ):
+            result = self.registry.list_workflows(exposed_only=True)
+            assert set(result.keys()) == {"workflow1", "workflow2", "workflow3"}
+
+    def test_list_workflows_exposed_only_filters_to_exposed_list(self):
+        """exposed_only=True excludes warmup-only workflows from the listing."""
+        mock_config = self._make_config(
+            exposed=["workflow1", "workflow2"], warmup=["workflow3"]
+        )
+        with patch(
+            "earth2studio.serve.server.config.get_config", return_value=mock_config
+        ):
+            result = self.registry.list_workflows(exposed_only=True)
+            assert set(result.keys()) == {"workflow1", "workflow2"}
+            assert "workflow3" not in result
+
+    def test_list_workflows_exposed_only_false_returns_all(self):
+        """exposed_only=False returns all registered workflows regardless of config."""
+        mock_config = self._make_config(exposed=["workflow1"], warmup=[])
+        with patch(
+            "earth2studio.serve.server.config.get_config", return_value=mock_config
+        ):
+            result = self.registry.list_workflows(exposed_only=False)
+            assert set(result.keys()) == {"workflow1", "workflow2", "workflow3"}
+
+    def test_list_workflows_default_is_exposed_only(self):
+        """list_workflows() with no args defaults to exposed_only=True."""
+        mock_config = self._make_config(exposed=["workflow1"], warmup=[])
+        with patch(
+            "earth2studio.serve.server.config.get_config", return_value=mock_config
+        ):
+            result = self.registry.list_workflows()
+            assert set(result.keys()) == {"workflow1"}
+
+    def test_list_workflows_returns_descriptions(self):
+        """list_workflows includes the description for each returned workflow."""
+        mock_config = self._make_config(exposed=[], warmup=[])
+        with patch(
+            "earth2studio.serve.server.config.get_config", return_value=mock_config
+        ):
+            result = self.registry.list_workflows()
+            assert result["workflow1"] == "First workflow"
+            assert result["workflow2"] == "Second workflow"
+
+
 # Test helper functions
 class TestHelperFunctions:
     """Test helper functions"""