diff --git a/src/musher/_client.py b/src/musher/_client.py index 3fecdc8..d560fd1 100644 --- a/src/musher/_client.py +++ b/src/musher/_client.py @@ -5,7 +5,7 @@ import asyncio import hashlib import threading -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from musher._bundle import ( Asset, @@ -16,7 +16,7 @@ ) from musher._cache import BundleCache from musher._config import MusherConfig, get_config -from musher._errors import IntegrityError +from musher._errors import APIError, IntegrityError from musher._http import HTTPTransport from musher._types import AssetType, BundleRef @@ -121,17 +121,27 @@ async def resolve(self, ref: str) -> ResolveResult: return result - async def fetch_asset(self, asset_id: str, *, version: str | None = None) -> Asset: - """Fetch a single asset by ID. - - Hits ``GET /v1/runner/assets/{id}``. + async def fetch_asset( + self, + logical_path: str, + *, + namespace: str, + slug: str, + version: str | None = None, + ) -> Asset: + """Fetch a single asset by logical path. + + Hits ``GET /v1/namespaces/{ns}/bundles/{slug}/assets/{path}``. """ + import urllib.parse # noqa: PLC0415 + + encoded_path = urllib.parse.quote(logical_path, safe="") params: dict[str, str] = {} if version: params["version"] = version response = await self._http.get( - f"/v1/runner/assets/{asset_id}", + f"/v1/namespaces/{namespace}/bundles/{slug}/assets/{encoded_path}", params=params or None, ) data = _AssetResponse.model_validate(response.json()) @@ -165,15 +175,18 @@ async def pull(self, ref: str) -> Bundle: resolve_result=result, ) - # Determine which assets need fetching vs cache hits - semaphore = asyncio.Semaphore(10) - assets: dict[str, Asset] = {} + # Build a lookup of manifest layers by logical_path for checksum verification + layer_map: dict[str, ManifestAsset] = { + layer.logical_path: layer for layer in result.manifest.layers + } - async def _fetch_layer(layer: ManifestAsset) -> None: - # Check blob cache first + # Check if all blobs are cached + all_cached = True + assets: dict[str, Asset] = {} + for layer in result.manifest.layers: cached_blob = self._cache.get_blob(layer.content_sha256) if cached_blob is not None: - assets[layer.asset_id] = Asset( + assets[layer.logical_path] = Asset( asset_id=layer.asset_id, logical_path=layer.logical_path, asset_type=AssetType(layer.asset_type), @@ -182,23 +195,21 @@ async def _fetch_layer(layer: ManifestAsset) -> None: size_bytes=layer.size_bytes, media_type=layer.media_type, ) - return - - async with semaphore: - asset = await self.fetch_asset(layer.asset_id, version=result.version) - - # Verify checksum - if self._config.verify_checksums: - actual_sha = hashlib.sha256(asset.content).hexdigest() - if actual_sha != layer.content_sha256: - raise IntegrityError(expected=layer.content_sha256, actual=actual_sha) - - # Cache the blob - self._cache.put_blob(layer.content_sha256, asset.content) + else: + all_cached = False - assets[layer.asset_id] = asset + if all_cached: + return Bundle( + ref=result.ref, + version=result.version, + resolve_result=result, + _assets=assets, + ) - _ = await asyncio.gather(*[_fetch_layer(layer) for layer in result.manifest.layers]) + # Fetch all assets via the :pull endpoint + pull_data = await self._pull_version(result.namespace, result.slug, result.version) + pull_manifest = cast("list[dict[str, object]]", pull_data.get("manifest", [])) + assets = self._build_assets_from_pull(pull_manifest, layer_map) return Bundle( ref=result.ref, @@ -207,6 +218,59 @@ async def _fetch_layer(layer: ManifestAsset) -> None: _assets=assets, ) + def _build_assets_from_pull( + self, + pull_manifest: list[dict[str, object]], + layer_map: dict[str, ManifestAsset], + ) -> dict[str, Asset]: + """Build Asset dict from a :pull response, verifying checksums against resolve manifest.""" + assets: dict[str, Asset] = {} + for item in pull_manifest: + logical_path = str(item["logicalPath"]) + content = str(item.get("contentText", "")).encode() + layer = layer_map.get(logical_path) + + # Verify checksum against the resolve manifest + if layer and self._config.verify_checksums: + actual_sha = hashlib.sha256(content).hexdigest() + if actual_sha != layer.content_sha256: + raise IntegrityError(expected=layer.content_sha256, actual=actual_sha) + + content_sha256 = layer.content_sha256 if layer else hashlib.sha256(content).hexdigest() + self._cache.put_blob(content_sha256, content) + + media_type = str(item.get("mediaType") or "") or (layer.media_type if layer else None) + assets[logical_path] = Asset( + asset_id=layer.asset_id if layer else logical_path, + logical_path=logical_path, + asset_type=AssetType(str(item["assetType"])), + content=content, + content_sha256=content_sha256, + size_bytes=layer.size_bytes if layer else len(content), + media_type=media_type or None, + ) + return assets + + async def _pull_version(self, namespace: str, slug: str, version: str) -> dict[str, object]: + """Fetch all assets for a version via the :pull endpoint. + + Tries the namespaced endpoint first, falls back to the hub endpoint + for public bundles in other namespaces. + """ + try: + response = await self._http.get( + f"/v1/namespaces/{namespace}/bundles/{slug}/versions/{version}:pull", + ) + return response.json() # pyright: ignore[reportAny] + except APIError as exc: + if exc.status != 403: # noqa: PLR2004 + raise + # Fall back to the hub endpoint for public bundles + response = await self._http.get( + f"/v1/hub/bundles/{namespace}/{slug}/versions/{version}:pull", + ) + return response.json() # pyright: ignore[reportAny] + class Client: """Synchronous client wrapping :class:`AsyncClient`. @@ -253,6 +317,17 @@ def resolve(self, ref: str) -> ResolveResult: """Resolve a bundle reference (sync).""" return self._run(self._async_client.resolve(ref)) - def fetch_asset(self, asset_id: str, *, version: str | None = None) -> Asset: - """Fetch a single asset by ID (sync).""" - return self._run(self._async_client.fetch_asset(asset_id, version=version)) + def fetch_asset( + self, + logical_path: str, + *, + namespace: str, + slug: str, + version: str | None = None, + ) -> Asset: + """Fetch a single asset by logical path (sync).""" + return self._run( + self._async_client.fetch_asset( + logical_path, namespace=namespace, slug=slug, version=version + ) + ) diff --git a/tests/test_client.py b/tests/test_client.py index b4d24ce..d2748ce 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -47,6 +47,21 @@ "mediaType": "text/markdown", } +_PULL_RESPONSE = { + "namespace": "myorg", + "slug": "my-bundle", + "version": "1.0.0", + "name": "My Bundle", + "manifest": [ + { + "logicalPath": "skills/greet/SKILL.md", + "assetType": "skill", + "contentText": "Hello skill", + "mediaType": "text/markdown", + } + ], +} + @pytest.fixture def config(tmp_path: Path) -> MusherConfig: @@ -92,11 +107,13 @@ async def test_resolve_passes_version_param(self, config: MusherConfig): @respx.mock async def test_fetch_asset(self, config: MusherConfig): - respx.get(f"{_BASE}/v1/runner/assets/asset-1").mock( - return_value=httpx.Response(200, json=_ASSET_RESPONSE) - ) + respx.get( + f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle/assets/skills%2Fgreet%2FSKILL.md" + ).mock(return_value=httpx.Response(200, json=_ASSET_RESPONSE)) async with AsyncClient(config=config) as client: - asset = await client.fetch_asset("asset-1") + asset = await client.fetch_asset( + "skills/greet/SKILL.md", namespace="myorg", slug="my-bundle" + ) assert isinstance(asset, Asset) assert asset.asset_id == "asset-1" assert asset.content == b"Hello skill" @@ -106,8 +123,8 @@ async def test_pull(self, config: MusherConfig): respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle:resolve").mock( return_value=httpx.Response(200, json=_RESOLVE_RESPONSE) ) - respx.get(f"{_BASE}/v1/runner/assets/asset-1").mock( - return_value=httpx.Response(200, json=_ASSET_RESPONSE) + respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle/versions/1.0.0:pull").mock( + return_value=httpx.Response(200, json=_PULL_RESPONSE) ) async with AsyncClient(config=config) as client: bundle = await client.pull("myorg/my-bundle:1.0.0") @@ -129,12 +146,15 @@ async def test_pull_empty_manifest(self, config: MusherConfig): @respx.mock async def test_pull_checksum_mismatch(self, config: MusherConfig): - bad_asset = {**_ASSET_RESPONSE, "contentText": "WRONG CONTENT"} + bad_pull = { + **_PULL_RESPONSE, + "manifest": [{**_PULL_RESPONSE["manifest"][0], "contentText": "WRONG CONTENT"}], + } respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle:resolve").mock( return_value=httpx.Response(200, json=_RESOLVE_RESPONSE) ) - respx.get(f"{_BASE}/v1/runner/assets/asset-1").mock( - return_value=httpx.Response(200, json=bad_asset) + respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle/versions/1.0.0:pull").mock( + return_value=httpx.Response(200, json=bad_pull) ) async with AsyncClient(config=config) as client: with pytest.raises(IntegrityError): @@ -205,6 +225,34 @@ async def test_oci_digest_flows_to_meta(self, config: MusherConfig): meta = json.loads(meta_path.read_text()) assert meta["ociDigest"] == "sha256:abc123" + @respx.mock + async def test_pull_hub_fallback(self, config: MusherConfig): + """When namespaced :pull returns 403, falls back to hub :pull.""" + respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle:resolve").mock( + return_value=httpx.Response(200, json=_RESOLVE_RESPONSE) + ) + # Namespaced :pull returns 403 (not authorized for this namespace) + respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle/versions/1.0.0:pull").mock( + return_value=httpx.Response( + 403, + json={ + "type": "https://api.platform.musher.dev/errors/forbidden", + "title": "Forbidden", + "status": 403, + "detail": "Not authorized", + }, + ) + ) + # Hub :pull succeeds + respx.get(f"{_BASE}/v1/hub/bundles/myorg/my-bundle/versions/1.0.0:pull").mock( + return_value=httpx.Response(200, json=_PULL_RESPONSE) + ) + async with AsyncClient(config=config) as client: + bundle = await client.pull("myorg/my-bundle:1.0.0") + assert isinstance(bundle, Bundle) + assert bundle.version == "1.0.0" + assert len(bundle.files()) == 1 + class TestClient: def test_instantiation(self, config: MusherConfig): @@ -221,8 +269,8 @@ def test_sync_pull(self, config: MusherConfig): respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle:resolve").mock( return_value=httpx.Response(200, json=_RESOLVE_RESPONSE) ) - respx.get(f"{_BASE}/v1/runner/assets/asset-1").mock( - return_value=httpx.Response(200, json=_ASSET_RESPONSE) + respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle/versions/1.0.0:pull").mock( + return_value=httpx.Response(200, json=_PULL_RESPONSE) ) with Client(config=config) as client: bundle = client.pull("myorg/my-bundle:1.0.0") diff --git a/uv.lock b/uv.lock index fdd78e1..a176aa8 100644 --- a/uv.lock +++ b/uv.lock @@ -1430,7 +1430,7 @@ wheels = [ [[package]] name = "musher-sdk" -version = "0.1.1" +version = "0.2.0" source = { editable = "." } dependencies = [ { name = "httpx" },