Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 106 additions & 31 deletions src/musher/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import asyncio
import hashlib
import threading
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

from musher._bundle import (
Asset,
Expand All @@ -16,7 +16,7 @@
)
from musher._cache import BundleCache
from musher._config import MusherConfig, get_config
from musher._errors import IntegrityError
from musher._errors import APIError, IntegrityError
from musher._http import HTTPTransport
from musher._types import AssetType, BundleRef

Expand Down Expand Up @@ -121,17 +121,27 @@ async def resolve(self, ref: str) -> ResolveResult:

return result

async def fetch_asset(self, asset_id: str, *, version: str | None = None) -> Asset:
"""Fetch a single asset by ID.

Hits ``GET /v1/runner/assets/{id}``.
async def fetch_asset(
self,
logical_path: str,
*,
namespace: str,
slug: str,
version: str | None = None,
) -> Asset:
"""Fetch a single asset by logical path.

Hits ``GET /v1/namespaces/{ns}/bundles/{slug}/assets/{path}``.
"""
import urllib.parse # noqa: PLC0415

encoded_path = urllib.parse.quote(logical_path, safe="")
params: dict[str, str] = {}
if version:
params["version"] = version

response = await self._http.get(
f"/v1/runner/assets/{asset_id}",
f"/v1/namespaces/{namespace}/bundles/{slug}/assets/{encoded_path}",
params=params or None,
)
data = _AssetResponse.model_validate(response.json())
Expand Down Expand Up @@ -165,15 +175,18 @@ async def pull(self, ref: str) -> Bundle:
resolve_result=result,
)

# Determine which assets need fetching vs cache hits
semaphore = asyncio.Semaphore(10)
assets: dict[str, Asset] = {}
# Build a lookup of manifest layers by logical_path for checksum verification
layer_map: dict[str, ManifestAsset] = {
layer.logical_path: layer for layer in result.manifest.layers
}

async def _fetch_layer(layer: ManifestAsset) -> None:
# Check blob cache first
# Check if all blobs are cached
all_cached = True
assets: dict[str, Asset] = {}
for layer in result.manifest.layers:
cached_blob = self._cache.get_blob(layer.content_sha256)
if cached_blob is not None:
assets[layer.asset_id] = Asset(
assets[layer.logical_path] = Asset(
asset_id=layer.asset_id,
logical_path=layer.logical_path,
asset_type=AssetType(layer.asset_type),
Expand All @@ -182,23 +195,21 @@ async def _fetch_layer(layer: ManifestAsset) -> None:
size_bytes=layer.size_bytes,
media_type=layer.media_type,
)
return

async with semaphore:
asset = await self.fetch_asset(layer.asset_id, version=result.version)

# Verify checksum
if self._config.verify_checksums:
actual_sha = hashlib.sha256(asset.content).hexdigest()
if actual_sha != layer.content_sha256:
raise IntegrityError(expected=layer.content_sha256, actual=actual_sha)

# Cache the blob
self._cache.put_blob(layer.content_sha256, asset.content)
else:
all_cached = False

assets[layer.asset_id] = asset
if all_cached:
return Bundle(
ref=result.ref,
version=result.version,
resolve_result=result,
_assets=assets,
)

_ = await asyncio.gather(*[_fetch_layer(layer) for layer in result.manifest.layers])
# Fetch all assets via the :pull endpoint
pull_data = await self._pull_version(result.namespace, result.slug, result.version)
pull_manifest = cast("list[dict[str, object]]", pull_data.get("manifest", []))
assets = self._build_assets_from_pull(pull_manifest, layer_map)

return Bundle(
ref=result.ref,
Expand All @@ -207,6 +218,59 @@ async def _fetch_layer(layer: ManifestAsset) -> None:
_assets=assets,
)

def _build_assets_from_pull(
    self,
    pull_manifest: list[dict[str, object]],
    layer_map: dict[str, ManifestAsset],
) -> dict[str, Asset]:
    """Build the asset mapping for a bundle from a ``:pull`` response.

    Each entry's ``contentText`` is encoded to bytes, optionally checked
    against the SHA-256 recorded for the same logical path in the resolve
    manifest, written to the blob cache, and wrapped in an :class:`Asset`.

    Args:
        pull_manifest: Entries of the ``manifest`` list from the ``:pull``
            response. ``logicalPath`` and ``assetType`` are required;
            ``contentText`` and ``mediaType`` are optional.
        layer_map: Resolve-manifest layers keyed by logical path.

    Returns:
        The assets, keyed by logical path.

    Raises:
        IntegrityError: If checksum verification is enabled and the content
            hash differs from the one in the resolve manifest.
        KeyError: If an entry lacks ``logicalPath`` or ``assetType``.
    """
    assets: dict[str, Asset] = {}
    for item in pull_manifest:
        logical_path = str(item["logicalPath"])

        # A JSON ``null`` decodes to ``None``; ``str(None)`` would yield the
        # literal text "None", so null and missing are both treated as empty.
        raw_text = item.get("contentText")
        content = ("" if raw_text is None else str(raw_text)).encode()

        layer = layer_map.get(logical_path)
        if layer is None:
            # No resolve-manifest entry to compare with: derive the digest
            # from the content itself.
            content_sha256 = hashlib.sha256(content).hexdigest()
        else:
            if self._config.verify_checksums:
                actual_sha = hashlib.sha256(content).hexdigest()
                if actual_sha != layer.content_sha256:
                    raise IntegrityError(expected=layer.content_sha256, actual=actual_sha)
            # NOTE(review): with verification disabled the blob is cached under
            # the manifest digest without being checked against it.
            content_sha256 = layer.content_sha256

        self._cache.put_blob(content_sha256, content)

        # Prefer the media type from the pull entry, then the resolve layer.
        media_type = str(item.get("mediaType") or "") or (
            layer.media_type if layer is not None else None
        )
        assets[logical_path] = Asset(
            asset_id=layer.asset_id if layer is not None else logical_path,
            logical_path=logical_path,
            asset_type=AssetType(str(item["assetType"])),
            content=content,
            content_sha256=content_sha256,
            size_bytes=layer.size_bytes if layer is not None else len(content),
            media_type=media_type or None,
        )
    return assets

async def _pull_version(self, namespace: str, slug: str, version: str) -> dict[str, object]:
    """Retrieve the ``:pull`` payload for one bundle version.

    The namespace-scoped route is requested first. If that request raises an
    :class:`APIError` whose status is 403, the hub route is requested instead
    (intended for public bundles in other namespaces). Any other
    :class:`APIError` propagates unchanged.
    """
    namespaced_route = f"/v1/namespaces/{namespace}/bundles/{slug}/versions/{version}:pull"
    hub_route = f"/v1/hub/bundles/{namespace}/{slug}/versions/{version}:pull"

    try:
        scoped = await self._http.get(
            namespaced_route,
        )
        return scoped.json()  # pyright: ignore[reportAny]
    except APIError as exc:
        if exc.status != 403:  # noqa: PLR2004
            raise

    # Only reached after a 403 from the namespaced route.
    via_hub = await self._http.get(
        hub_route,
    )
    return via_hub.json()  # pyright: ignore[reportAny]


class Client:
"""Synchronous client wrapping :class:`AsyncClient`.
Expand Down Expand Up @@ -253,6 +317,17 @@ def resolve(self, ref: str) -> ResolveResult:
"""Resolve a bundle reference (sync)."""
return self._run(self._async_client.resolve(ref))

def fetch_asset(self, asset_id: str, *, version: str | None = None) -> Asset:
"""Fetch a single asset by ID (sync)."""
return self._run(self._async_client.fetch_asset(asset_id, version=version))
def fetch_asset(
    self,
    logical_path: str,
    *,
    namespace: str,
    slug: str,
    version: str | None = None,
) -> Asset:
    """Blocking counterpart of the async client's ``fetch_asset``.

    Creates the async client's ``fetch_asset`` coroutine for the given
    logical path, namespace, slug and optional version, then passes it to
    ``self._run`` and returns that call's result.
    """
    pending = self._async_client.fetch_asset(
        logical_path,
        namespace=namespace,
        slug=slug,
        version=version,
    )
    return self._run(pending)
70 changes: 59 additions & 11 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,21 @@
"mediaType": "text/markdown",
}

_PULL_RESPONSE = {
"namespace": "myorg",
"slug": "my-bundle",
"version": "1.0.0",
"name": "My Bundle",
"manifest": [
{
"logicalPath": "skills/greet/SKILL.md",
"assetType": "skill",
"contentText": "Hello skill",
"mediaType": "text/markdown",
}
],
}


@pytest.fixture
def config(tmp_path: Path) -> MusherConfig:
Expand Down Expand Up @@ -92,11 +107,13 @@ async def test_resolve_passes_version_param(self, config: MusherConfig):

@respx.mock
async def test_fetch_asset(self, config: MusherConfig):
respx.get(f"{_BASE}/v1/runner/assets/asset-1").mock(
return_value=httpx.Response(200, json=_ASSET_RESPONSE)
)
respx.get(
f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle/assets/skills%2Fgreet%2FSKILL.md"
).mock(return_value=httpx.Response(200, json=_ASSET_RESPONSE))
async with AsyncClient(config=config) as client:
asset = await client.fetch_asset("asset-1")
asset = await client.fetch_asset(
"skills/greet/SKILL.md", namespace="myorg", slug="my-bundle"
)
assert isinstance(asset, Asset)
assert asset.asset_id == "asset-1"
assert asset.content == b"Hello skill"
Expand All @@ -106,8 +123,8 @@ async def test_pull(self, config: MusherConfig):
respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle:resolve").mock(
return_value=httpx.Response(200, json=_RESOLVE_RESPONSE)
)
respx.get(f"{_BASE}/v1/runner/assets/asset-1").mock(
return_value=httpx.Response(200, json=_ASSET_RESPONSE)
respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle/versions/1.0.0:pull").mock(
return_value=httpx.Response(200, json=_PULL_RESPONSE)
)
async with AsyncClient(config=config) as client:
bundle = await client.pull("myorg/my-bundle:1.0.0")
Expand All @@ -129,12 +146,15 @@ async def test_pull_empty_manifest(self, config: MusherConfig):

@respx.mock
async def test_pull_checksum_mismatch(self, config: MusherConfig):
bad_asset = {**_ASSET_RESPONSE, "contentText": "WRONG CONTENT"}
bad_pull = {
**_PULL_RESPONSE,
"manifest": [{**_PULL_RESPONSE["manifest"][0], "contentText": "WRONG CONTENT"}],
}
respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle:resolve").mock(
return_value=httpx.Response(200, json=_RESOLVE_RESPONSE)
)
respx.get(f"{_BASE}/v1/runner/assets/asset-1").mock(
return_value=httpx.Response(200, json=bad_asset)
respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle/versions/1.0.0:pull").mock(
return_value=httpx.Response(200, json=bad_pull)
)
async with AsyncClient(config=config) as client:
with pytest.raises(IntegrityError):
Expand Down Expand Up @@ -205,6 +225,34 @@ async def test_oci_digest_flows_to_meta(self, config: MusherConfig):
meta = json.loads(meta_path.read_text())
assert meta["ociDigest"] == "sha256:abc123"

@respx.mock
async def test_pull_hub_fallback(self, config: MusherConfig):
    """A 403 from the namespaced ``:pull`` route makes ``pull`` use the hub route."""
    resolve_url = f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle:resolve"
    namespaced_pull_url = f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle/versions/1.0.0:pull"
    hub_pull_url = f"{_BASE}/v1/hub/bundles/myorg/my-bundle/versions/1.0.0:pull"

    # Problem-details body returned by the namespaced route (caller not authorized).
    forbidden_body = {
        "type": "https://api.platform.musher.dev/errors/forbidden",
        "title": "Forbidden",
        "status": 403,
        "detail": "Not authorized",
    }

    respx.get(resolve_url).mock(return_value=httpx.Response(200, json=_RESOLVE_RESPONSE))
    respx.get(namespaced_pull_url).mock(return_value=httpx.Response(403, json=forbidden_body))
    # The hub route is the one that succeeds.
    respx.get(hub_pull_url).mock(return_value=httpx.Response(200, json=_PULL_RESPONSE))

    async with AsyncClient(config=config) as client:
        bundle = await client.pull("myorg/my-bundle:1.0.0")
        assert isinstance(bundle, Bundle)
        assert bundle.version == "1.0.0"
        assert len(bundle.files()) == 1


class TestClient:
def test_instantiation(self, config: MusherConfig):
Expand All @@ -221,8 +269,8 @@ def test_sync_pull(self, config: MusherConfig):
respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle:resolve").mock(
return_value=httpx.Response(200, json=_RESOLVE_RESPONSE)
)
respx.get(f"{_BASE}/v1/runner/assets/asset-1").mock(
return_value=httpx.Response(200, json=_ASSET_RESPONSE)
respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle/versions/1.0.0:pull").mock(
return_value=httpx.Response(200, json=_PULL_RESPONSE)
)
with Client(config=config) as client:
bundle = client.pull("myorg/my-bundle:1.0.0")
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading