diff --git a/application_sdk/constants.py b/application_sdk/constants.py index 2b52f9e2d..94cab0fe6 100644 --- a/application_sdk/constants.py +++ b/application_sdk/constants.py @@ -159,6 +159,11 @@ #: Maximum number of activities that can run concurrently MAX_CONCURRENT_ACTIVITIES = int(os.getenv("ATLAN_MAX_CONCURRENT_ACTIVITIES", "5")) +#: Maximum concurrent object-store transfers (uploads / downloads) +MAX_CONCURRENT_STORAGE_TRANSFERS = int( + os.getenv("ATLAN_MAX_CONCURRENT_STORAGE_TRANSFERS", "4") +) + #: Build ID for worker versioning (injected by TWD controller via Kubernetes Downward API). #: When set, workers identify themselves with this build ID so the Temporal server can #: route tasks to the correct version during versioned deployments. diff --git a/application_sdk/storage/transfer.py b/application_sdk/storage/transfer.py index d8767bca5..612802533 100644 --- a/application_sdk/storage/transfer.py +++ b/application_sdk/storage/transfer.py @@ -17,12 +17,14 @@ from __future__ import annotations +import asyncio import hashlib import os import tempfile from pathlib import Path, PurePosixPath from typing import TYPE_CHECKING +from application_sdk.constants import MAX_CONCURRENT_STORAGE_TRANSFERS from application_sdk.contracts.types import FileReference, StorageTier if TYPE_CHECKING: @@ -168,6 +170,7 @@ async def upload( store: "ObjectStore | None" = None, _app_prefix: str = "", _tier: StorageTier = StorageTier.RETAINED, + max_concurrency: int = MAX_CONCURRENT_STORAGE_TRANSFERS, ) -> "UploadOutput": """Upload a local file or directory to the object store. @@ -188,6 +191,8 @@ async def upload( skip_if_exists: Skip files whose SHA-256 matches the stored sidecar. store: Object store to use, or ``None`` to resolve from infrastructure. _app_prefix: Internal prefix injected by the ``App.upload`` task. + max_concurrency: Maximum parallel uploads for directory mode + (default :data:`~application_sdk.constants.MAX_CONCURRENT_STORAGE_TRANSFERS`). Returns: :class:`~application_sdk.contracts.storage.UploadOutput` @@ -249,16 +254,30 @@ async def upload( else: prefix = src.name + sem = asyncio.Semaphore(max_concurrency) + files = [p for p in src.rglob("*") if p.is_file()] - transferred_count = 0 + + async def _bounded_upload(file_path: Path, key: str) -> bool: + async with sem: + ok, _ = await _upload_one( + resolved, file_path, key, skip_if_exists=skip_if_exists + ) + return ok + + keys = [] for file_path in files: relative = str(file_path.relative_to(src)).replace(os.sep, "/") - key = f"{prefix}/{relative}" if prefix else relative - ok, _ = await _upload_one( - resolved, file_path, key, skip_if_exists=skip_if_exists - ) - if ok: - transferred_count += 1 + keys.append(f"{prefix}/{relative}" if prefix else relative) + + results = await asyncio.gather( + *[_bounded_upload(fp, k) for fp, k in zip(files, keys)], + return_exceptions=True, + ) + errors = [r for r in results if isinstance(r, BaseException)] + if errors: + raise errors[0] + transferred_count = sum(1 for ok in results if ok) store_prefix = (prefix.rstrip("/") + "/") if prefix else "" reason = "uploaded" if transferred_count > 0 else "skipped:hash_match" diff --git a/tests/integration/test_storage_io.py b/tests/integration/test_storage_io.py index fb346bc7a..fd5c7aae7 100644 --- a/tests/integration/test_storage_io.py +++ b/tests/integration/test_storage_io.py @@ -250,3 +250,93 @@ async def test_delete_prefix_empty(store): """delete_prefix on nonexistent prefix returns 0.""" deleted = await delete_prefix("nonexistent-prefix", store) assert deleted == 0 + + +# ------------------------------------------------------------------ +# transfer.upload / download — concurrent directory path +# ------------------------------------------------------------------ + + +@pytest.mark.integration +async def test_transfer_upload_directory_concurrent(store, tmp_path): + """transfer.upload handles multi-file directories via asyncio.gather.""" + from application_sdk.storage.transfer import upload + + src = tmp_path / "src" + src.mkdir() + for i in range(15): + (src / f"part_{i}.csv").write_text(f"row-{i}") + + out = await upload(str(src), "transfer-conc/", store=store) + + assert out.ref.file_count == 15 + assert out.synced is True + assert out.reason == "uploaded" + + # Verify all data keys exist in the store (excludes .sha256 sidecars) + keys = await list_keys("transfer-conc", store, suffix=".csv") + assert len(keys) == 15 + + +@pytest.mark.integration +async def test_transfer_upload_download_directory_roundtrip(store, tmp_path): + """Full roundtrip: upload a directory concurrently, download and verify.""" + from application_sdk.storage.transfer import download, upload + + src = tmp_path / "src" + src.mkdir() + sub = src / "nested" + sub.mkdir() + (src / "root.txt").write_bytes(b"root-content") + (sub / "child.txt").write_bytes(b"child-content") + + await upload(str(src), "rt-dir/", store=store) + + dest = tmp_path / "dest" + dl = await download("rt-dir/", str(dest), store=store) + + assert dl.ref.file_count == 2 + assert dl.synced is True + assert (dest / "root.txt").read_bytes() == b"root-content" + assert (dest / "nested" / "child.txt").read_bytes() == b"child-content" + + +@pytest.mark.integration +async def test_transfer_upload_directory_skip_partial(store, tmp_path): + """skip_if_exists skips unchanged files and re-uploads changed ones.""" + from application_sdk.storage.transfer import upload + + src = tmp_path / "src" + src.mkdir() + (src / "stable.txt").write_bytes(b"same") + (src / "changing.txt").write_bytes(b"v1") + + out1 = await upload(str(src), "partial/", store=store, skip_if_exists=True) + assert out1.synced is True + + # Second upload with same content → all skipped + out2 = await upload(str(src), "partial/", store=store, skip_if_exists=True) + assert out2.synced is False + assert out2.reason == "skipped:hash_match" + + # Change one file → partial transfer + (src / "changing.txt").write_bytes(b"v2") + out3 = await upload(str(src), "partial/", store=store, skip_if_exists=True) + assert out3.synced is True + assert out3.reason == "uploaded" + + +@pytest.mark.integration +async def test_transfer_upload_directory_max_concurrency(store, tmp_path): + """max_concurrency parameter is respected (runs with low concurrency).""" + from application_sdk.storage.transfer import upload + + src = tmp_path / "src" + src.mkdir() + for i in range(8): + (src / f"f{i}.bin").write_bytes(f"data-{i}".encode()) + + out = await upload(str(src), "low-conc/", store=store, max_concurrency=2) + + assert out.ref.file_count == 8 + assert out.synced is True diff --git a/tests/unit/storage/test_transfer.py b/tests/unit/storage/test_transfer.py index 2e28089ce..5c4cc6ee5 100644 --- a/tests/unit/storage/test_transfer.py +++ b/tests/unit/storage/test_transfer.py @@ -80,6 +80,60 @@ async def test_upload_directory_skip_unchanged(self, store, tmp_path) -> None: assert out2.synced is False assert out2.reason == "skipped:hash_match" + async def test_upload_directory_concurrent_completes(self, store, tmp_path) -> None: + """Multi-file directory upload completes correctly via concurrent path.""" + for i in range(10): + (tmp_path / f"file_{i}.txt").write_bytes(f"content_{i}".encode()) + out = await upload(str(tmp_path), "conc", store=store) + assert out.ref.file_count == 10 + assert out.synced is True + assert out.reason == "uploaded" + + # Verify all files are downloadable + dest = tmp_path / "dest" + dl = await download("conc/", str(dest), store=store) + assert dl.ref.file_count == 10 + + async def test_upload_directory_partial_skip_count(self, store, tmp_path) -> None: + """transferred_count is accurate when some files are skipped.""" + (tmp_path / "a.txt").write_bytes(b"aaa") + (tmp_path / "b.txt").write_bytes(b"bbb") + (tmp_path / "c.txt").write_bytes(b"ccc") + + # Upload once so all files get sidecars + await upload(str(tmp_path), "partial", store=store, skip_if_exists=True) + + # Change only one file + (tmp_path / "b.txt").write_bytes(b"bbb_v2") + out = await upload(str(tmp_path), "partial", store=store, skip_if_exists=True) + + # Only the changed file should have been transferred + assert out.synced is True + assert out.reason == "uploaded" + + async def test_upload_directory_error_propagation( + self, store, tmp_path, monkeypatch + ) -> None: + """Error in one upload propagates correctly from asyncio.gather.""" + (tmp_path / "ok.txt").write_bytes(b"fine") + (tmp_path / "fail.txt").write_bytes(b"boom") + + from application_sdk.storage import transfer as transfer_mod + + _original = transfer_mod._upload_one + + async def _failing_upload_one(st, local_file, store_key, *, skip_if_exists): + if "fail.txt" in str(local_file): + raise RuntimeError("simulated upload failure") + return await _original( + st, local_file, store_key, skip_if_exists=skip_if_exists + ) + + monkeypatch.setattr(transfer_mod, "_upload_one", _failing_upload_one) + + with pytest.raises(RuntimeError, match="simulated upload failure"): + await upload(str(tmp_path), "errtest", store=store) + class TestUploadStorageSubdir: """Tests for the storage_subdir parameter on upload."""