diff --git a/src/observational_memory/llm.py b/src/observational_memory/llm.py
index 162afb9..309592f 100644
--- a/src/observational_memory/llm.py
+++ b/src/observational_memory/llm.py
@@ -102,9 +102,10 @@ def _call_openai_direct(
     import openai
 
     client = openai.OpenAI()
+    token_limit_arg = _openai_token_limit_arg(model, max_tokens)
     response = client.chat.completions.create(
         model=model,
-        max_tokens=max_tokens,
+        **token_limit_arg,
         messages=[
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": user_content},
@@ -119,6 +120,13 @@
     return str(content)
 
 
+def _openai_token_limit_arg(model: str, max_tokens: int) -> dict[str, int]:
+    normalized = model.lower()
+    if normalized.startswith(("gpt-5", "o1", "o3", "o4")):
+        return {"max_completion_tokens": max_tokens}
+    return {"max_tokens": max_tokens}
+
+
 def _extract_anthropic_text(message: object) -> str:
     content = getattr(message, "content", None)
     if not content:
diff --git a/tests/test_llm.py b/tests/test_llm.py
index 0e055de..70b56ba 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -1,9 +1,12 @@
 """Tests for LLM provider dispatch."""
 
+import sys
+from types import SimpleNamespace
+
 import pytest
 
 from observational_memory.config import Config
-from observational_memory.llm import compress
+from observational_memory.llm import _call_openai_direct, compress
 
 
 @pytest.fixture(autouse=True)
@@ -112,3 +115,34 @@
         assert False, "Should have raised"
     except RuntimeError as e:
         assert "provider 'openai'" in str(e).lower()
+
+
+@pytest.mark.parametrize(
+    ("model", "expected_token_arg"),
+    [
+        ("gpt-5.4", "max_completion_tokens"),
+        ("gpt-5.2-chat-latest", "max_completion_tokens"),
+        ("o4-mini", "max_completion_tokens"),
+        ("gpt-4o-mini", "max_tokens"),
+    ],
+)
+def test_openai_token_limit_parameter_matches_model_family(monkeypatch, model, expected_token_arg):
+    request = {}
+
+    class FakeCompletions:
+        def create(self, **kwargs):
+            request.update(kwargs)
+            return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content="ok"))])
+
+    class FakeOpenAI:
+        def __init__(self):
+            self.chat = SimpleNamespace(completions=FakeCompletions())
+
+    monkeypatch.setitem(sys.modules, "openai", SimpleNamespace(OpenAI=FakeOpenAI))
+
+    result = _call_openai_direct("sys", "user", model, 8, Config())
+
+    assert result == "ok"
+    assert request[expected_token_arg] == 8
+    unexpected_token_arg = "max_tokens" if expected_token_arg == "max_completion_tokens" else "max_completion_tokens"
+    assert unexpected_token_arg not in request