Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions docs/providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,9 @@ deepseek_client = openai.OpenAI(
base_url="https://api.deepseek.com/v1"
)
tracked_deepseek = LLMTrackingProxy(
    deepseek_client,
    provider=Provider.OPENAI,  # OpenAI-compatible API
    base_url="https://api.deepseek.com/v1",
    api_key=os.environ.get("LLMCOSTS_API_KEY"),
)

Expand All @@ -188,13 +188,13 @@ from llmcosts.tracker import LLMTrackingProxy, Provider
import openai

grok_client = openai.OpenAI(
    api_key=os.environ.get("XAI_API_KEY"),
    base_url="https://api.x.ai/v1"
)
tracked_grok = LLMTrackingProxy(
    grok_client,
    provider=Provider.OPENAI,  # OpenAI-compatible API
    base_url="https://api.x.ai/v1",
    api_key=os.environ.get("LLMCOSTS_API_KEY"),
)

Expand Down
4 changes: 4 additions & 0 deletions llmcosts/tracker/openai_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def extract_usage_payload(self, obj: Any, attr: Any, **kwargs) -> Optional[Dict]
# attribute but no `choices` attribute.
method_owner = str(getattr(attr, "__self__", ""))
is_responses_api = "responses" in method_owner
base_url = kwargs.get("base_url")

if is_responses_api:
# For Responses API streaming, usage is nested in obj.response.usage
Expand Down Expand Up @@ -110,6 +111,9 @@ def extract_usage_payload(self, obj: Any, attr: Any, **kwargs) -> Optional[Dict]
if hasattr(obj, "service_tier") and obj.service_tier is not None:
payload["service_tier"] = obj.service_tier

if base_url:
payload["base_url"] = base_url

return self._add_common_fields(payload, obj)

def validate_streaming_options(self, target: Any, kw: Dict, attr: Any) -> None:
Expand Down
38 changes: 31 additions & 7 deletions llmcosts/tracker/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
response_callback: Optional[Callable[[Any], None]] = None,
api_key: Optional[str] = None,
client_customer_key: Optional[str] = None,
base_url: Optional[str] = None,
):
"""Initialize the tracking proxy.

Expand All @@ -88,6 +89,9 @@ def __init__(
api_key: Optional LLMCOSTS_API_KEY. If None, will check environment variables.
If not found in environment either, will raise an error.
client_customer_key: Optional customer key for multi-tenant applications.
base_url: Optional base URL for OpenAI-compatible endpoints. If set,
this value will be included in usage payloads for completions
tracking.
"""
self._target = target
self._provider = provider
Expand All @@ -98,6 +102,7 @@ def __init__(
self._context = context.copy() if context else None
self._response_callback = response_callback
self._client_customer_key = client_customer_key
self._base_url = base_url
self._usage_handler = get_usage_handler(target)
self._langchain_mode = False # Initialize LangChain compatibility mode

Expand Down Expand Up @@ -331,6 +336,16 @@ def client_customer_key(self, value: Optional[str]) -> None:
"""Set the client_customer_key setting."""
self._client_customer_key = value

@property
def base_url(self) -> Optional[str]:
    """Return the base URL attached to usage payloads, or ``None``.

    Only meaningful for OpenAI-compatible endpoints (e.g. DeepSeek,
    Grok/xAI): when set, the usage handler copies this value into each
    tracked payload so the backend can tell which endpoint served the call.
    """
    return self._base_url

@base_url.setter
def base_url(self, value: Optional[str]) -> None:
    """Set (or clear, with ``None``) the base URL used for subsequent calls."""
    self._base_url = value

def enable_langchain_mode(self) -> None:
"""Enable LangChain compatibility mode.

Expand Down Expand Up @@ -397,6 +412,7 @@ def __getattr__(self, item: str):
response_callback=self._response_callback,
api_key=None, # Child proxies use the already-set global tracker
client_customer_key=self._client_customer_key,
base_url=self._base_url,
)
# Pass the client reference to the child proxy for threshold checking
child_proxy._llm_costs_client = self._llm_costs_client
Expand Down Expand Up @@ -441,7 +457,7 @@ async def agen():

async for chunk in iterator:
payload = self._usage_handler.extract_usage_payload(
chunk, **kw, attr=attr
chunk, **kw, attr=attr, base_url=self._base_url
)
if payload:
usage_found = True
Expand All @@ -455,14 +471,14 @@ async def agen():

if last_chunk and not usage_found:
payload = self._usage_handler.extract_usage_payload(
last_chunk, **kw, attr=attr
last_chunk, **kw, attr=attr, base_url=self._base_url
)
self._track_usage(payload)

return agen()

payload = self._usage_handler.extract_usage_payload(
res, **kw, attr=attr
res, **kw, attr=attr, base_url=self._base_url
)
self._track_usage(payload)
return res
Expand Down Expand Up @@ -548,7 +564,10 @@ def _create_tracking_iterator(self, stream):
for chunk in stream:
# Track usage for this chunk
payload = self._usage_handler.extract_usage_payload(
chunk, **self._kw, attr=self._attr
chunk,
**self._kw,
attr=self._attr,
base_url=self._base_url,
)
self._tracker_func(payload)
# Call response callback with each chunk
Expand Down Expand Up @@ -578,7 +597,7 @@ def gen():

for chunk in iterator:
payload = self._usage_handler.extract_usage_payload(
chunk, **kw, attr=attr
chunk, **kw, attr=attr, base_url=self._base_url
)
if payload:
usage_found = True
Expand All @@ -592,13 +611,18 @@ def gen():

if last_chunk and not usage_found:
payload = self._usage_handler.extract_usage_payload(
last_chunk, **kw, attr=attr
last_chunk,
**kw,
attr=attr,
base_url=self._base_url,
)
self._track_usage(payload)

return gen()

payload = self._usage_handler.extract_usage_payload(res, **kw, attr=attr)
payload = self._usage_handler.extract_usage_payload(
res, **kw, attr=attr, base_url=self._base_url
)
self._track_usage(payload)
return res

Expand Down
7 changes: 6 additions & 1 deletion tests/test_deepseek_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ def client():

@pytest.fixture
def tracked_client(client):
    """Wrap the DeepSeek client for tracking.

    DeepSeek exposes an OpenAI-compatible API, so it is tracked as
    Provider.OPENAI with an explicit base_url identifying the endpoint.
    """
    return LLMTrackingProxy(
        client,
        provider=Provider.OPENAI,
        base_url="https://api.deepseek.com/v1",
        debug=True,
    )


def _allow():
Expand Down
7 changes: 6 additions & 1 deletion tests/test_deepseek_nonstreaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ def deepseek_client(self):
@pytest.fixture
def tracked_deepseek_client(self, deepseek_client):
    """Create a tracked DeepSeek client.

    DeepSeek is OpenAI-compatible, so it is tracked as Provider.OPENAI
    with base_url distinguishing the actual endpoint.
    """
    return LLMTrackingProxy(
        deepseek_client,
        provider=Provider.OPENAI,
        base_url="https://api.deepseek.com/v1",
        debug=True,
    )

def test_deepseek_chat_completions_non_streaming(
self, tracked_deepseek_client, caplog
Expand Down
7 changes: 6 additions & 1 deletion tests/test_deepseek_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ def deepseek_client(self):
@pytest.fixture
def tracked_deepseek_client(self, deepseek_client):
    """Create a tracked DeepSeek client.

    DeepSeek is OpenAI-compatible, so it is tracked as Provider.OPENAI
    with base_url distinguishing the actual endpoint.
    """
    return LLMTrackingProxy(
        deepseek_client,
        provider=Provider.OPENAI,
        base_url="https://api.deepseek.com/v1",
        debug=True,
    )

def test_deepseek_chat_completions_streaming(self, tracked_deepseek_client, caplog):
"""Test streaming chat completion with DeepSeek captures usage."""
Expand Down
7 changes: 6 additions & 1 deletion tests/test_grok_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ def client():

@pytest.fixture
def tracked_client(client):
    """Wrap the Grok/xAI client for tracking.

    xAI exposes an OpenAI-compatible API, so it is tracked as
    Provider.OPENAI with an explicit base_url identifying the endpoint.
    """
    return LLMTrackingProxy(
        client,
        provider=Provider.OPENAI,
        base_url="https://api.x.ai/v1",
        debug=True,
    )


def _allow():
Expand Down
7 changes: 6 additions & 1 deletion tests/test_grok_nonstreaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ def grok_client(self):
@pytest.fixture
def tracked_grok_client(self, grok_client):
    """Create a tracked Grok client.

    Grok/xAI is OpenAI-compatible, so it is tracked as Provider.OPENAI
    with base_url distinguishing the actual endpoint.
    """
    return LLMTrackingProxy(
        grok_client,
        provider=Provider.OPENAI,
        base_url="https://api.x.ai/v1",
        debug=True,
    )

def test_grok_3_mini_model(self, tracked_grok_client, caplog):
"""Test Grok 3 mini model captures usage correctly."""
Expand Down
7 changes: 6 additions & 1 deletion tests/test_grok_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ def grok_client(self):
@pytest.fixture
def tracked_grok_client(self, grok_client):
    """Create a tracked Grok client.

    Grok/xAI is OpenAI-compatible, so it is tracked as Provider.OPENAI
    with base_url distinguishing the actual endpoint.
    """
    return LLMTrackingProxy(
        grok_client,
        provider=Provider.OPENAI,
        base_url="https://api.x.ai/v1",
        debug=True,
    )

def test_grok_streaming_basic(self, tracked_grok_client, caplog):
"""Test basic Grok streaming functionality."""
Expand Down
78 changes: 78 additions & 0 deletions tests/test_proxy_new_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self):
"completion_tokens": 20,
"total_tokens": 30,
}
self._is_openai_mock = True
self.model = "test-model"
self.id = "test-response-id"

Expand All @@ -29,6 +30,20 @@ def chat_completion(self, **kwargs):
return self


class MockResponsesClient(MockClient):
    """Mock client mimicking OpenAI responses API.

    The OpenAI usage handler decides it is looking at the Responses API by
    checking whether "responses" appears in str(bound_method.__self__)
    (see extract_usage_payload). This mock exposes itself as its own
    ``responses`` attribute and stringifies to "responses" so that methods
    bound to it are routed down the Responses-API code path.
    """

    def __init__(self):
        super().__init__()
        # Make proxy.responses.create(...) resolve back onto this object.
        self.responses = self

    def __str__(self):
        # The handler inspects str(attr.__self__); returning "responses"
        # makes its is_responses_api check evaluate True for this mock.
        return "responses"

    def create(self, **kwargs):  # noqa: D401 - mimic openai responses.create
        return self


class TestProxyNewFeatures:
"""Test the new LLMTrackingProxy features."""

Expand Down Expand Up @@ -280,3 +295,66 @@ def test_context_changes_between_calls(self):
proxy.chat_completion()
second_payload = mock_tracker.track.call_args_list[1][0][0]
assert second_payload["context"] == {"second": True}

def test_base_url_not_in_payload_by_default(self):
    """A proxy created without base_url must not emit one in payloads."""
    client = MockClient()
    tracker = MagicMock()

    with patch(
        "llmcosts.tracker.proxy.get_usage_tracker", return_value=tracker
    ):
        proxy = LLMTrackingProxy(client, provider=Provider.OPENAI)
        proxy.chat_completion()

    tracked_payload = tracker.track.call_args.args[0]
    assert "base_url" not in tracked_payload

def test_base_url_added_when_set(self):
    """A base_url given at construction shows up in the tracked payload."""
    client = MockClient()
    tracker = MagicMock()

    with patch(
        "llmcosts.tracker.proxy.get_usage_tracker", return_value=tracker
    ):
        proxy = LLMTrackingProxy(
            client,
            provider=Provider.OPENAI,
            base_url="https://api.example.com/v1",
        )
        proxy.chat_completion()

    tracked_payload = tracker.track.call_args.args[0]
    assert tracked_payload["base_url"] == "https://api.example.com/v1"

def test_base_url_setter(self):
    """base_url can be set after construction and reaches the payload."""
    mock_client = MockClient()
    with patch("llmcosts.tracker.proxy.get_usage_tracker") as mock_get_tracker:
        mock_tracker = MagicMock()
        mock_get_tracker.return_value = mock_tracker

        proxy = LLMTrackingProxy(mock_client, provider=Provider.OPENAI)

        # Default is None; the property setter round-trips the new value.
        assert proxy.base_url is None
        proxy.base_url = "https://api.deepseek.com/v1"
        assert proxy.base_url == "https://api.deepseek.com/v1"

        # The value set via the property is forwarded into tracked payloads.
        proxy.chat_completion()
        payload = mock_tracker.track.call_args[0][0]
        assert payload["base_url"] == "https://api.deepseek.com/v1"

def test_base_url_ignored_for_responses_api(self):
    """Responses-API payloads omit base_url even when one is configured.

    NOTE(review): MockResponsesClient stringifies to "responses" so the
    handler takes its Responses-API branch; this test assumes that branch
    builds its payload without the base_url field — confirm against the
    full openai_handler implementation.
    """
    mock_client = MockResponsesClient()

    with patch("llmcosts.tracker.proxy.get_usage_tracker") as mock_get_tracker:
        mock_tracker = MagicMock()
        mock_get_tracker.return_value = mock_tracker

        proxy = LLMTrackingProxy(
            mock_client,
            provider=Provider.OPENAI,
            base_url="https://api.example.com/v1",
        )

        # Even with base_url set, the responses.create path must not add it.
        proxy.responses.create(model="gpt-4")
        payload = mock_tracker.track.call_args[0][0]
        assert "base_url" not in payload