diff --git a/cecli/helpers/model_providers.py b/cecli/helpers/model_providers.py index 4092cac188a..99770fc5c7f 100644 --- a/cecli/helpers/model_providers.py +++ b/cecli/helpers/model_providers.py @@ -439,6 +439,40 @@ def _get_cache_file(self, provider: str) -> Path: fname = f"{provider}_models.json" return self.cache_dir / fname + def _normalize_models_payload(self, provider: str, payload: Dict) -> Dict: + """Normalize provider payloads into an OpenAI-style `{data: [{id: ...}]}`.""" + if not isinstance(payload, dict): + return {} + if "data" in payload and isinstance(payload.get("data"), list): + return payload + # Fireworks returns `{models: [...], nextPageToken: ..., totalSize: ...}` + models = payload.get("models") + if isinstance(models, list): + normalized = [] + for item in models: + if not isinstance(item, dict): + continue + model_id = item.get("name") or item.get("id") + if not model_id: + continue + record = {"id": model_id} + for key in ( + "max_input_tokens", + "max_output_tokens", + "max_tokens", + "context_length", + "context_window", + "mode", + "pricing", + "input_cost_per_token", + "output_cost_per_token", + ): + if key in item and item[key] is not None: + record[key] = item[key] + normalized.append(record) + return {"data": normalized} + return {} + def _load_cache(self, provider: str) -> None: if self._cache_loaded.get(provider): return @@ -460,9 +494,10 @@ def _update_cache(self, provider: str) -> None: payload = self._fetch_provider_models(provider) cache_file = self._get_cache_file(provider) if payload: - self._provider_cache[provider] = payload + normalized = self._normalize_models_payload(provider, payload) + self._provider_cache[provider] = normalized try: - cache_file.write_text(json.dumps(payload, indent=2)) + cache_file.write_text(json.dumps(normalized, indent=2)) except OSError: pass return @@ -479,6 +514,13 @@ def _fetch_provider_models(self, provider: str) -> Optional[Dict]: models_url = api_base.rstrip("/") + "/models" if not models_url: return None + # Substitute {account_id} placeholder if present + if "{account_id}" in models_url: + account_id = self._get_account_id(provider) + if not account_id: + print(f"Failed to fetch {provider} model list: account_id_env not set") + return None + models_url = models_url.replace("{account_id}", account_id) headers = {} default_headers = config.get("default_headers") or {} headers.update(default_headers) @@ -509,6 +551,13 @@ def _get_api_key(self, provider: str) -> Optional[str]: return value return None + def _get_account_id(self, provider: str) -> Optional[str]: + config = self.provider_configs[provider] + account_id_env = config.get("account_id_env") + if account_id_env: + return os.environ.get(account_id_env) + return None + def ensure_litellm_providers_registered() -> None: """One-time registration guard for LiteLLM provider metadata.""" diff --git a/cecli/resources/providers.json b/cecli/resources/providers.json index 66171d1a23e..45b717c8cb7 100644 --- a/cecli/resources/providers.json +++ b/cecli/resources/providers.json @@ -57,6 +57,15 @@ ], "display_name": "veniceai" }, + "fireworks_ai": { + "api_base": "https://api.fireworks.ai/inference/v1", + "api_key_env": [ + "FIREWORKS_AI_API_KEY" + ], + "display_name": "fireworks_ai", + "models_url": "https://api.fireworks.ai/v1/accounts/{account_id}/models", + "account_id_env": "FIREWORKS_AI_ACCOUNT_ID" + }, "xiaomi_mimo": { "api_base": "https://api.xiaomimimo.com/v1", "api_key_env": [ diff --git a/tests/basic/test_model_provider_manager.py b/tests/basic/test_model_provider_manager.py index 1465d4019c0..520713ce849 100644 --- a/tests/basic/test_model_provider_manager.py +++ b/tests/basic/test_model_provider_manager.py @@ -430,3 +430,121 @@ def _fake_get(self, model): model = Model(model_name) assert model.info["max_tokens"] == 2048 + + +def test_get_account_id_from_env(monkeypatch, tmp_path): + config = { + "fireworks_ai": { + "api_base": "https://api.fireworks.ai/inference/v1", + "api_key_env": ["FIREWORKS_AI_API_KEY"], + "account_id_env": "FIREWORKS_AI_ACCOUNT_ID", + } + } + + manager = _make_manager(tmp_path, config) + monkeypatch.setenv("FIREWORKS_AI_ACCOUNT_ID", "test-account-123") + + assert manager._get_account_id("fireworks_ai") == "test-account-123" + + +def test_get_account_id_missing_env(monkeypatch, tmp_path): + config = { + "fireworks_ai": { + "api_base": "https://api.fireworks.ai/inference/v1", + "api_key_env": ["FIREWORKS_AI_API_KEY"], + "account_id_env": "FIREWORKS_AI_ACCOUNT_ID", + } + } + + manager = _make_manager(tmp_path, config) + monkeypatch.delenv("FIREWORKS_AI_ACCOUNT_ID", raising=False) + + assert manager._get_account_id("fireworks_ai") is None + + +def test_get_account_id_no_config(tmp_path): + config = { + "demo": { + "api_base": "https://example.com/v1", + "api_key_env": ["DEMO_KEY"], + } + } + + manager = _make_manager(tmp_path, config) + + assert manager._get_account_id("demo") is None + + +def test_models_url_account_id_substitution(monkeypatch, tmp_path): + config = { + "fireworks_ai": { + "api_base": "https://api.fireworks.ai/inference/v1", + "api_key_env": ["FIREWORKS_AI_API_KEY"], + "models_url": "https://api.fireworks.ai/v1/accounts/{account_id}/models", + "account_id_env": "FIREWORKS_AI_ACCOUNT_ID", + "requires_api_key": False, + } + } + + manager = _make_manager(tmp_path, config) + monkeypatch.setenv("FIREWORKS_AI_ACCOUNT_ID", "my-account-id") + + captured = {} + + def _fake_get(url, *, headers=None, timeout=None, verify=None): + captured["url"] = url + captured["headers"] = headers + captured["timeout"] = timeout + captured["verify"] = verify + return DummyResponse({"data": []}) + + monkeypatch.setattr("requests.get", _fake_get) + + manager._fetch_provider_models("fireworks_ai") + + assert captured["url"] == "https://api.fireworks.ai/v1/accounts/my-account-id/models" + + +def test_models_url_account_id_missing_skips_fetch(monkeypatch, tmp_path, capsys): + config = { + "fireworks_ai": { + "api_base": "https://api.fireworks.ai/inference/v1", + "api_key_env": ["FIREWORKS_AI_API_KEY"], + "models_url": "https://api.fireworks.ai/v1/accounts/{account_id}/models", + "account_id_env": "FIREWORKS_AI_ACCOUNT_ID", + "requires_api_key": False, + } + } + + manager = _make_manager(tmp_path, config) + monkeypatch.delenv("FIREWORKS_AI_ACCOUNT_ID", raising=False) + + assert manager._fetch_provider_models("fireworks_ai") is None + + captured = capsys.readouterr() + assert "account_id_env not set" in captured.out + + +def test_models_url_without_placeholder_unchanged(monkeypatch, tmp_path): + config = { + "demo": { + "api_base": "https://example.com/v1", + "api_key_env": ["DEMO_KEY"], + "models_url": "https://example.com/v1/models", + "requires_api_key": False, + } + } + + manager = _make_manager(tmp_path, config) + + captured = {} + + def _fake_get(url, *, headers=None, timeout=None, verify=None): + captured["url"] = url + return DummyResponse({"data": []}) + + monkeypatch.setattr("requests.get", _fake_get) + + manager._fetch_provider_models("demo") + + assert captured["url"] == "https://example.com/v1/models"