From 359a2030e305c8325107681a3f6ebcec08fc8126 Mon Sep 17 00:00:00 2001 From: Keyton Weissinger <125780187+keyton-weissinger@users.noreply.github.com> Date: Tue, 8 Jul 2025 14:27:03 -0400 Subject: [PATCH] feat: support custom base_url for OpenAI proxy --- docs/providers.md | 14 +++--- llmcosts/tracker/openai_handler.py | 4 ++ llmcosts/tracker/proxy.py | 38 +++++++++++--- tests/test_deepseek_limit.py | 7 ++- tests/test_deepseek_nonstreaming.py | 7 ++- tests/test_deepseek_streaming.py | 7 ++- tests/test_grok_limit.py | 7 ++- tests/test_grok_nonstreaming.py | 7 ++- tests/test_grok_streaming.py | 7 ++- tests/test_proxy_new_features.py | 78 +++++++++++++++++++++++++++++ 10 files changed, 156 insertions(+), 20 deletions(-) diff --git a/docs/providers.md b/docs/providers.md index 34d5bba..4cc692c 100644 --- a/docs/providers.md +++ b/docs/providers.md @@ -168,9 +168,9 @@ deepseek_client = openai.OpenAI( base_url="https://api.deepseek.com/v1" ) tracked_deepseek = LLMTrackingProxy( - deepseek_client, - provider=Provider.DEEPSEEK, # REQUIRED: Specifies this is DeepSeek - # framework=None by default for direct DeepSeek API usage + deepseek_client, + provider=Provider.OPENAI, # OpenAI-compatible API + base_url="https://api.deepseek.com/v1", api_key=os.environ.get("LLMCOSTS_API_KEY"), ) @@ -188,13 +188,13 @@ from llmcosts.tracker import LLMTrackingProxy, Provider import openai grok_client = openai.OpenAI( - api_key=os.environ.get("XAI_API_KEY"), + api_key=os.environ.get("XAI_API_KEY"), base_url="https://api.x.ai/v1" ) tracked_grok = LLMTrackingProxy( - grok_client, - provider=Provider.XAI, # REQUIRED: Specifies this is Grok/xAI - # framework=None by default for direct Grok API usage + grok_client, + provider=Provider.OPENAI, # OpenAI-compatible API + base_url="https://api.x.ai/v1", api_key=os.environ.get("LLMCOSTS_API_KEY"), ) diff --git a/llmcosts/tracker/openai_handler.py b/llmcosts/tracker/openai_handler.py index 86e5920..7924f23 100644 --- a/llmcosts/tracker/openai_handler.py +++ b/llmcosts/tracker/openai_handler.py @@ -45,6 +45,7 @@ def extract_usage_payload(self, obj: Any, attr: Any, **kwargs) -> Optional[Dict] # attribute but no `choices` attribute. method_owner = str(getattr(attr, "__self__", "")) is_responses_api = "responses" in method_owner + base_url = kwargs.get("base_url") if is_responses_api: # For Responses API streaming, usage is nested in obj.response.usage @@ -110,6 +111,9 @@ def extract_usage_payload(self, obj: Any, attr: Any, **kwargs) -> Optional[Dict] if hasattr(obj, "service_tier") and obj.service_tier is not None: payload["service_tier"] = obj.service_tier + if base_url: + payload["base_url"] = base_url + return self._add_common_fields(payload, obj) def validate_streaming_options(self, target: Any, kw: Dict, attr: Any) -> None: diff --git a/llmcosts/tracker/proxy.py b/llmcosts/tracker/proxy.py index 0f22d6a..c44e89a 100644 --- a/llmcosts/tracker/proxy.py +++ b/llmcosts/tracker/proxy.py @@ -69,6 +69,7 @@ def __init__( response_callback: Optional[Callable[[Any], None]] = None, api_key: Optional[str] = None, client_customer_key: Optional[str] = None, + base_url: Optional[str] = None, ): """Initialize the tracking proxy. @@ -88,6 +89,9 @@ def __init__( api_key: Optional LLMCOSTS_API_KEY. If None, will check environment variables. If not found in environment either, will raise an error. client_customer_key: Optional customer key for multi-tenant applications. + base_url: Optional base URL for OpenAI-compatible endpoints. If set, + this value will be included in usage payloads for completions + tracking. """ self._target = target self._provider = provider @@ -98,6 +102,7 @@ def __init__( self._context = context.copy() if context else None self._response_callback = response_callback self._client_customer_key = client_customer_key + self._base_url = base_url self._usage_handler = get_usage_handler(target) self._langchain_mode = False # Initialize LangChain compatibility mode @@ -331,6 +336,16 @@ def client_customer_key(self, value: Optional[str]) -> None: """Set the client_customer_key setting.""" self._client_customer_key = value + @property + def base_url(self) -> Optional[str]: + """Get the base_url used for OpenAI-compatible calls.""" + return self._base_url + + @base_url.setter + def base_url(self, value: Optional[str]) -> None: + """Set the base_url for subsequent calls.""" + self._base_url = value + def enable_langchain_mode(self) -> None: """Enable LangChain compatibility mode. @@ -397,6 +412,7 @@ def __getattr__(self, item: str): response_callback=self._response_callback, api_key=None, # Child proxies use the already-set global tracker client_customer_key=self._client_customer_key, + base_url=self._base_url, ) # Pass the client reference to the child proxy for threshold checking child_proxy._llm_costs_client = self._llm_costs_client @@ -441,7 +457,7 @@ async def agen(): async for chunk in iterator: payload = self._usage_handler.extract_usage_payload( - chunk, **kw, attr=attr + chunk, **kw, attr=attr, base_url=self._base_url ) if payload: usage_found = True @@ -455,14 +471,14 @@ async def agen(): if last_chunk and not usage_found: payload = self._usage_handler.extract_usage_payload( - last_chunk, **kw, attr=attr + last_chunk, **kw, attr=attr, base_url=self._base_url ) self._track_usage(payload) return agen() payload = self._usage_handler.extract_usage_payload( - res, **kw, attr=attr + res, **kw, attr=attr, base_url=self._base_url ) self._track_usage(payload) return res @@ -548,7 +564,10 @@ def _create_tracking_iterator(self, stream): for chunk in stream: # Track usage for this chunk payload = self._usage_handler.extract_usage_payload( - chunk, **self._kw, attr=self._attr + chunk, + **self._kw, + attr=self._attr, + base_url=self._base_url, ) self._tracker_func(payload) # Call response callback with each chunk @@ -578,7 +597,7 @@ def gen(): for chunk in iterator: payload = self._usage_handler.extract_usage_payload( - chunk, **kw, attr=attr + chunk, **kw, attr=attr, base_url=self._base_url ) if payload: usage_found = True @@ -592,13 +611,18 @@ def gen(): if last_chunk and not usage_found: payload = self._usage_handler.extract_usage_payload( - last_chunk, **kw, attr=attr + last_chunk, + **kw, + attr=attr, + base_url=self._base_url, ) self._track_usage(payload) return gen() - payload = self._usage_handler.extract_usage_payload(res, **kw, attr=attr) + payload = self._usage_handler.extract_usage_payload( + res, **kw, attr=attr, base_url=self._base_url + ) self._track_usage(payload) return res diff --git a/tests/test_deepseek_limit.py b/tests/test_deepseek_limit.py index a811ffd..699e3c9 100644 --- a/tests/test_deepseek_limit.py +++ b/tests/test_deepseek_limit.py @@ -26,7 +26,12 @@ def client(): @pytest.fixture def tracked_client(client): - return LLMTrackingProxy(client, provider=Provider.DEEPSEEK, debug=True) + return LLMTrackingProxy( + client, + provider=Provider.OPENAI, + base_url="https://api.deepseek.com/v1", + debug=True, + ) def _allow(): diff --git a/tests/test_deepseek_nonstreaming.py b/tests/test_deepseek_nonstreaming.py index 20903f3..b52378b 100644 --- a/tests/test_deepseek_nonstreaming.py +++ b/tests/test_deepseek_nonstreaming.py @@ -40,7 +40,12 @@ def deepseek_client(self): @pytest.fixture def tracked_deepseek_client(self, deepseek_client): """Create a tracked DeepSeek client.""" - return LLMTrackingProxy(deepseek_client, provider=Provider.DEEPSEEK, debug=True) + return LLMTrackingProxy( + deepseek_client, + provider=Provider.OPENAI, + base_url="https://api.deepseek.com/v1", + debug=True, + ) def test_deepseek_chat_completions_non_streaming( self, tracked_deepseek_client, caplog diff --git a/tests/test_deepseek_streaming.py b/tests/test_deepseek_streaming.py index 55dc855..749ac6a 100644 --- a/tests/test_deepseek_streaming.py +++ b/tests/test_deepseek_streaming.py @@ -40,7 +40,12 @@ def deepseek_client(self): @pytest.fixture def tracked_deepseek_client(self, deepseek_client): """Create a tracked DeepSeek client.""" - return LLMTrackingProxy(deepseek_client, provider=Provider.DEEPSEEK, debug=True) + return LLMTrackingProxy( + deepseek_client, + provider=Provider.OPENAI, + base_url="https://api.deepseek.com/v1", + debug=True, + ) def test_deepseek_chat_completions_streaming(self, tracked_deepseek_client, caplog): """Test streaming chat completion with DeepSeek captures usage.""" diff --git a/tests/test_grok_limit.py b/tests/test_grok_limit.py index b9aeda5..56ed8c6 100644 --- a/tests/test_grok_limit.py +++ b/tests/test_grok_limit.py @@ -34,7 +34,12 @@ def client(): @pytest.fixture def tracked_client(client): - return LLMTrackingProxy(client, provider=Provider.XAI, debug=True) + return LLMTrackingProxy( + client, + provider=Provider.OPENAI, + base_url="https://api.x.ai/v1", + debug=True, + ) def _allow(): diff --git a/tests/test_grok_nonstreaming.py b/tests/test_grok_nonstreaming.py index ace1954..e723a8a 100644 --- a/tests/test_grok_nonstreaming.py +++ b/tests/test_grok_nonstreaming.py @@ -40,7 +40,12 @@ def grok_client(self): @pytest.fixture def tracked_grok_client(self, grok_client): """Create a tracked Grok client.""" - return LLMTrackingProxy(grok_client, provider=Provider.XAI, debug=True) + return LLMTrackingProxy( + grok_client, + provider=Provider.OPENAI, + base_url="https://api.x.ai/v1", + debug=True, + ) def test_grok_3_mini_model(self, tracked_grok_client, caplog): """Test Grok 3 mini model captures usage correctly.""" diff --git a/tests/test_grok_streaming.py b/tests/test_grok_streaming.py index b0c5504..b0c112f 100644 --- a/tests/test_grok_streaming.py +++ b/tests/test_grok_streaming.py @@ -40,7 +40,12 @@ def grok_client(self): @pytest.fixture def tracked_grok_client(self, grok_client): """Create a tracked Grok client.""" - return LLMTrackingProxy(grok_client, provider=Provider.XAI, debug=True) + return LLMTrackingProxy( + grok_client, + provider=Provider.OPENAI, + base_url="https://api.x.ai/v1", + debug=True, + ) def test_grok_streaming_basic(self, tracked_grok_client, caplog): """Test basic Grok streaming functionality.""" diff --git a/tests/test_proxy_new_features.py b/tests/test_proxy_new_features.py index 71abcec..6711eb9 100644 --- a/tests/test_proxy_new_features.py +++ b/tests/test_proxy_new_features.py @@ -21,6 +21,7 @@ def __init__(self): "completion_tokens": 20, "total_tokens": 30, } + self._is_openai_mock = True self.model = "test-model" self.id = "test-response-id" @@ -29,6 +30,20 @@ def chat_completion(self, **kwargs): return self +class MockResponsesClient(MockClient): + """Mock client mimicking OpenAI responses API.""" + + def __init__(self): + super().__init__() + self.responses = self + + def __str__(self): + return "responses" + + def create(self, **kwargs): # noqa: D401 - mimic openai responses.create + return self + + class TestProxyNewFeatures: """Test the new LLMTrackingProxy features.""" @@ -280,3 +295,66 @@ def test_context_changes_between_calls(self): proxy.chat_completion() second_payload = mock_tracker.track.call_args_list[1][0][0] assert second_payload["context"] == {"second": True} + + def test_base_url_not_in_payload_by_default(self): + mock_client = MockClient() + + with patch("llmcosts.tracker.proxy.get_usage_tracker") as mock_get_tracker: + mock_tracker = MagicMock() + mock_get_tracker.return_value = mock_tracker + + proxy = LLMTrackingProxy(mock_client, provider=Provider.OPENAI) + + proxy.chat_completion() + payload = mock_tracker.track.call_args[0][0] + assert "base_url" not in payload + + def test_base_url_added_when_set(self): + mock_client = MockClient() + + with patch("llmcosts.tracker.proxy.get_usage_tracker") as mock_get_tracker: + mock_tracker = MagicMock() + mock_get_tracker.return_value = mock_tracker + + proxy = LLMTrackingProxy( + mock_client, + provider=Provider.OPENAI, + base_url="https://api.example.com/v1", + ) + + proxy.chat_completion() + payload = mock_tracker.track.call_args[0][0] + assert payload["base_url"] == "https://api.example.com/v1" + + def test_base_url_setter(self): + mock_client = MockClient() + with patch("llmcosts.tracker.proxy.get_usage_tracker") as mock_get_tracker: + mock_tracker = MagicMock() + mock_get_tracker.return_value = mock_tracker + + proxy = LLMTrackingProxy(mock_client, provider=Provider.OPENAI) + + assert proxy.base_url is None + proxy.base_url = "https://api.deepseek.com/v1" + assert proxy.base_url == "https://api.deepseek.com/v1" + + proxy.chat_completion() + payload = mock_tracker.track.call_args[0][0] + assert payload["base_url"] == "https://api.deepseek.com/v1" + + def test_base_url_ignored_for_responses_api(self): + mock_client = MockResponsesClient() + + with patch("llmcosts.tracker.proxy.get_usage_tracker") as mock_get_tracker: + mock_tracker = MagicMock() + mock_get_tracker.return_value = mock_tracker + + proxy = LLMTrackingProxy( + mock_client, + provider=Provider.OPENAI, + base_url="https://api.example.com/v1", + ) + + proxy.responses.create(model="gpt-4") + payload = mock_tracker.track.call_args[0][0] + assert "base_url" not in payload