Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions src/musher/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from musher._cache import BundleCache
from musher._config import MusherConfig, get_config
from musher._errors import APIError, IntegrityError
from musher._errors import APIError, AuthenticationError, IntegrityError
from musher._http import HTTPTransport
from musher._types import AssetType, BundleRef

Expand Down Expand Up @@ -102,10 +102,23 @@ async def resolve(self, ref: str) -> ResolveResult:
if parsed.digest:
params["digest"] = parsed.digest

response = await self._http.get(
f"/v1/namespaces/{parsed.namespace}/bundles/{parsed.slug}:resolve",
params=params or None,
)
try:
response = await self._http.get(
f"/v1/namespaces/{parsed.namespace}/bundles/{parsed.slug}:resolve",
params=params or None,
)
except AuthenticationError:
response = await self._http.get(
f"/v1/hub/bundles/{parsed.namespace}/{parsed.slug}:resolve",
params=params or None,
)
except APIError as exc:
if exc.status != 403: # noqa: PLR2004
raise
response = await self._http.get(
f"/v1/hub/bundles/{parsed.namespace}/{parsed.slug}:resolve",
params=params or None,
)
response_data: dict[str, object] = response.json() # pyright: ignore[reportAny]
result = ResolveResult.model_validate(response_data)

Expand Down Expand Up @@ -277,6 +290,8 @@ async def _pull_version(self, namespace: str, slug: str, version: str) -> dict[s
f"/v1/namespaces/{namespace}/bundles/{slug}/versions/{version}:pull",
)
return response.json() # pyright: ignore[reportAny]
except AuthenticationError:
pass # No token or invalid — try public hub endpoint
except APIError as exc:
if exc.status != 403: # noqa: PLR2004
raise
Expand Down
56 changes: 56 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,62 @@ async def test_pull_hub_fallback(self, config: MusherConfig):
assert bundle.version == "1.0.0"
assert len(bundle.files()) == 1

@respx.mock
async def test_pull_hub_fallback_401(self, config: MusherConfig):
"""When namespaced :pull returns 401, falls back to hub :pull."""
respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle:resolve").mock(
return_value=httpx.Response(200, json=_RESOLVE_RESPONSE)
)
# Namespaced :pull returns 401 (no API key)
respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle/versions/1.0.0:pull").mock(
return_value=httpx.Response(401, json={"detail": "Invalid or missing API token"})
)
# Hub :pull succeeds
respx.get(f"{_BASE}/v1/hub/bundles/myorg/my-bundle/versions/1.0.0:pull").mock(
return_value=httpx.Response(200, json=_PULL_RESPONSE)
)
async with AsyncClient(config=config) as client:
bundle = await client.pull("myorg/my-bundle:1.0.0")
assert isinstance(bundle, Bundle)
assert bundle.version == "1.0.0"
assert len(bundle.files()) == 1

@respx.mock
async def test_resolve_hub_fallback_401(self, config: MusherConfig):
"""When namespaced :resolve returns 401, falls back to hub :resolve."""
respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle:resolve").mock(
return_value=httpx.Response(401, json={"detail": "Invalid or missing API token"})
)
respx.get(f"{_BASE}/v1/hub/bundles/myorg/my-bundle:resolve").mock(
return_value=httpx.Response(200, json=_RESOLVE_RESPONSE)
)
async with AsyncClient(config=config) as client:
result = await client.resolve("myorg/my-bundle:1.0.0")
assert isinstance(result, ResolveResult)
assert result.version == "1.0.0"

@respx.mock
async def test_resolve_hub_fallback_403(self, config: MusherConfig):
"""When namespaced :resolve returns 403, falls back to hub :resolve."""
respx.get(f"{_BASE}/v1/namespaces/myorg/bundles/my-bundle:resolve").mock(
return_value=httpx.Response(
403,
json={
"type": "https://api.platform.musher.dev/errors/forbidden",
"title": "Forbidden",
"status": 403,
"detail": "Not authorized",
},
)
)
respx.get(f"{_BASE}/v1/hub/bundles/myorg/my-bundle:resolve").mock(
return_value=httpx.Response(200, json=_RESOLVE_RESPONSE)
)
async with AsyncClient(config=config) as client:
result = await client.resolve("myorg/my-bundle:1.0.0")
assert isinstance(result, ResolveResult)
assert result.version == "1.0.0"


class TestClient:
def test_instantiation(self, config: MusherConfig):
Expand Down
Loading