Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions cecli/helpers/model_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,40 @@ def _get_cache_file(self, provider: str) -> Path:
fname = f"{provider}_models.json"
return self.cache_dir / fname

def _normalize_models_payload(self, provider: str, payload: Dict) -> Dict:
"""Normalize provider payloads into an OpenAI-style `{data: [{id: ...}]}`."""
if not isinstance(payload, dict):
return {}
if "data" in payload and isinstance(payload.get("data"), list):
return payload
# Fireworks returns `{models: [...], nextPageToken: ..., totalSize: ...}`
models = payload.get("models")
if isinstance(models, list):
normalized = []
for item in models:
if not isinstance(item, dict):
continue
model_id = item.get("name") or item.get("id")
if not model_id:
continue
record = {"id": model_id}
for key in (
"max_input_tokens",
"max_output_tokens",
"max_tokens",
"context_length",
"context_window",
"mode",
"pricing",
"input_cost_per_token",
"output_cost_per_token",
):
if key in item and item[key] is not None:
record[key] = item[key]
normalized.append(record)
return {"data": normalized}
return {}

def _load_cache(self, provider: str) -> None:
if self._cache_loaded.get(provider):
return
Expand All @@ -460,9 +494,10 @@ def _update_cache(self, provider: str) -> None:
payload = self._fetch_provider_models(provider)
cache_file = self._get_cache_file(provider)
if payload:
self._provider_cache[provider] = payload
normalized = self._normalize_models_payload(provider, payload)
self._provider_cache[provider] = normalized
try:
cache_file.write_text(json.dumps(payload, indent=2))
cache_file.write_text(json.dumps(normalized, indent=2))
except OSError:
pass
return
Expand All @@ -479,6 +514,13 @@ def _fetch_provider_models(self, provider: str) -> Optional[Dict]:
models_url = api_base.rstrip("/") + "/models"
if not models_url:
return None
# Substitute {account_id} placeholder if present
if "{account_id}" in models_url:
account_id = self._get_account_id(provider)
if not account_id:
print(f"Failed to fetch {provider} model list: account_id_env not set")
return None
models_url = models_url.replace("{account_id}", account_id)
headers = {}
default_headers = config.get("default_headers") or {}
headers.update(default_headers)
Expand Down Expand Up @@ -509,6 +551,13 @@ def _get_api_key(self, provider: str) -> Optional[str]:
return value
return None

def _get_account_id(self, provider: str) -> Optional[str]:
config = self.provider_configs[provider]
account_id_env = config.get("account_id_env")
if account_id_env:
return os.environ.get(account_id_env)
return None


def ensure_litellm_providers_registered() -> None:
"""One-time registration guard for LiteLLM provider metadata."""
Expand Down
9 changes: 9 additions & 0 deletions cecli/resources/providers.json
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@
],
"display_name": "veniceai"
},
"fireworks_ai": {
"api_base": "https://api.fireworks.ai/inference/v1",
"api_key_env": [
"FIREWORKS_AI_API_KEY"
],
"display_name": "fireworks_ai",
"models_url": "https://api.fireworks.ai/v1/accounts/{account_id}/models",
"account_id_env": "FIREWORKS_AI_ACCOUNT_ID"
},
"xiaomi_mimo": {
"api_base": "https://api.xiaomimimo.com/v1",
"api_key_env": [
Expand Down
118 changes: 118 additions & 0 deletions tests/basic/test_model_provider_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,121 @@ def _fake_get(self, model):

model = Model(model_name)
assert model.info["max_tokens"] == 2048


def test_get_account_id_from_env(monkeypatch, tmp_path):
config = {
"fireworks_ai": {
"api_base": "https://api.fireworks.ai/inference/v1",
"api_key_env": ["FIREWORKS_AI_API_KEY"],
"account_id_env": "FIREWORKS_AI_ACCOUNT_ID",
}
}

manager = _make_manager(tmp_path, config)
monkeypatch.setenv("FIREWORKS_AI_ACCOUNT_ID", "test-account-123")

assert manager._get_account_id("fireworks_ai") == "test-account-123"


def test_get_account_id_missing_env(monkeypatch, tmp_path):
config = {
"fireworks_ai": {
"api_base": "https://api.fireworks.ai/inference/v1",
"api_key_env": ["FIREWORKS_AI_API_KEY"],
"account_id_env": "FIREWORKS_AI_ACCOUNT_ID",
}
}

manager = _make_manager(tmp_path, config)
monkeypatch.delenv("FIREWORKS_AI_ACCOUNT_ID", raising=False)

assert manager._get_account_id("fireworks_ai") is None


def test_get_account_id_no_config(tmp_path):
config = {
"demo": {
"api_base": "https://example.com/v1",
"api_key_env": ["DEMO_KEY"],
}
}

manager = _make_manager(tmp_path, config)

assert manager._get_account_id("demo") is None


def test_models_url_account_id_substitution(monkeypatch, tmp_path):
config = {
"fireworks_ai": {
"api_base": "https://api.fireworks.ai/inference/v1",
"api_key_env": ["FIREWORKS_AI_API_KEY"],
"models_url": "https://api.fireworks.ai/v1/accounts/{account_id}/models",
"account_id_env": "FIREWORKS_AI_ACCOUNT_ID",
"requires_api_key": False,
}
}

manager = _make_manager(tmp_path, config)
monkeypatch.setenv("FIREWORKS_AI_ACCOUNT_ID", "my-account-id")

captured = {}

def _fake_get(url, *, headers=None, timeout=None, verify=None):
captured["url"] = url
captured["headers"] = headers
captured["timeout"] = timeout
captured["verify"] = verify
return DummyResponse({"data": []})

monkeypatch.setattr("requests.get", _fake_get)

manager._fetch_provider_models("fireworks_ai")

assert captured["url"] == "https://api.fireworks.ai/v1/accounts/my-account-id/models"


def test_models_url_account_id_missing_skips_fetch(monkeypatch, tmp_path, capsys):
config = {
"fireworks_ai": {
"api_base": "https://api.fireworks.ai/inference/v1",
"api_key_env": ["FIREWORKS_AI_API_KEY"],
"models_url": "https://api.fireworks.ai/v1/accounts/{account_id}/models",
"account_id_env": "FIREWORKS_AI_ACCOUNT_ID",
"requires_api_key": False,
}
}

manager = _make_manager(tmp_path, config)
monkeypatch.delenv("FIREWORKS_AI_ACCOUNT_ID", raising=False)

assert manager._fetch_provider_models("fireworks_ai") is None

captured = capsys.readouterr()
assert "account_id_env not set" in captured.out


def test_models_url_without_placeholder_unchanged(monkeypatch, tmp_path):
config = {
"demo": {
"api_base": "https://example.com/v1",
"api_key_env": ["DEMO_KEY"],
"models_url": "https://example.com/v1/models",
"requires_api_key": False,
}
}

manager = _make_manager(tmp_path, config)

captured = {}

def _fake_get(url, *, headers=None, timeout=None, verify=None):
captured["url"] = url
return DummyResponse({"data": []})

monkeypatch.setattr("requests.get", _fake_get)

manager._fetch_provider_models("demo")

assert captured["url"] == "https://example.com/v1/models"
Loading