Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/observational_memory/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,10 @@ def _call_openai_direct(
import openai

client = openai.OpenAI()
token_limit_arg = _openai_token_limit_arg(model, max_tokens)
response = client.chat.completions.create(
model=model,
max_tokens=max_tokens,
**token_limit_arg,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_content},
Expand All @@ -119,6 +120,13 @@ def _call_openai_direct(
return str(content)


def _openai_token_limit_arg(model: str, max_tokens: int) -> dict[str, int]:
normalized = model.lower()
if normalized.startswith(("gpt-5", "o1", "o3", "o4")):
return {"max_completion_tokens": max_tokens}
return {"max_tokens": max_tokens}


def _extract_anthropic_text(message: object) -> str:
content = getattr(message, "content", None)
if not content:
Expand Down
36 changes: 35 additions & 1 deletion tests/test_llm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Tests for LLM provider dispatch."""

import sys
from types import SimpleNamespace

import pytest

from observational_memory.config import Config
from observational_memory.llm import compress
from observational_memory.llm import _call_openai_direct, compress


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -112,3 +115,34 @@ def fake_openai(system_prompt, user_content, model, max_tokens, cfg):
assert False, "Should have raised"
except RuntimeError as e:
assert "provider 'openai'" in str(e).lower()


@pytest.mark.parametrize(
    ("model", "expected_token_arg"),
    [
        ("gpt-5.4", "max_completion_tokens"),
        ("gpt-5.2-chat-latest", "max_completion_tokens"),
        ("o4-mini", "max_completion_tokens"),
        ("gpt-4o-mini", "max_tokens"),
    ],
)
def test_openai_token_limit_parameter_matches_model_family(monkeypatch, model, expected_token_arg):
    """Each model family must receive exactly one token-limit kwarg — the one it accepts."""
    captured = {}

    def fake_create(**kwargs):
        # Record the request so we can inspect which token-limit kwarg was sent.
        captured.update(kwargs)
        message = SimpleNamespace(content="ok")
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])

    def fake_client_factory():
        completions = SimpleNamespace(create=fake_create)
        return SimpleNamespace(chat=SimpleNamespace(completions=completions))

    # Stand in for the real SDK so no network call is made.
    monkeypatch.setitem(sys.modules, "openai", SimpleNamespace(OpenAI=fake_client_factory))

    result = _call_openai_direct("sys", "user", model, 8, Config())

    assert result == "ok"
    assert captured[expected_token_arg] == 8
    unexpected_token_arg = "max_tokens" if expected_token_arg == "max_completion_tokens" else "max_completion_tokens"
    assert unexpected_token_arg not in captured
Loading