diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py
index ed54fd9..f8d6ae7 100644
--- a/src/opengradient/client/llm.py
+++ b/src/opengradient/client/llm.py
@@ -251,8 +251,12 @@ async def completion(
 
     Raises:
         RuntimeError: If the inference fails.
+        ValueError: If the model string is not in "provider/model" format.
     """
-    model_id = model.split("/")[1]
+    parts = model.split("/")
+    if len(parts) < 2 or not parts[1]:
+        raise ValueError(f"Invalid model format {model!r}: expected 'provider/model' (e.g. 'openai/gpt-5')")
+    model_id = parts[1]
     payload: Dict = {
         "model": model_id,
         "prompt": prompt,
@@ -325,9 +329,13 @@ async def chat(
 
     Raises:
         RuntimeError: If the inference fails.
+        ValueError: If the model string is not in "provider/model" format.
     """
+    parts = model.split("/")
+    if len(parts) < 2 or not parts[1]:
+        raise ValueError(f"Invalid model format {model!r}: expected 'provider/model' (e.g. 'openai/gpt-5')")
     params = _ChatParams(
-        model=model.split("/")[1],
+        model=parts[1],
         max_tokens=max_tokens,
         temperature=temperature,
         stop_sequence=stop_sequence,
diff --git a/tests/client_test.py b/tests/client_test.py
index 6829fc9..7466ab0 100644
--- a/tests/client_test.py
+++ b/tests/client_test.py
@@ -190,3 +190,34 @@ def test_settlement_modes_values(self):
     assert x402SettlementMode.PRIVATE == "private"
     assert x402SettlementMode.BATCH_HASHED == "batch"
     assert x402SettlementMode.INDIVIDUAL_FULL == "individual"
+
+
+# --- Model Format Validation Tests ---
+
+
+class TestModelFormatValidation:
+    """Tests for model string validation in LLM.completion() and LLM.chat()."""
+
+    def test_completion_rejects_model_without_slash(self, mock_tee_registry):
+        llm = LLM(private_key=FAKE_PRIVATE_KEY)
+        with pytest.raises(ValueError, match="Invalid model format"):
+            import asyncio
+            asyncio.run(llm.completion(model="no-slash-model", prompt="hi"))
+
+    def test_completion_rejects_model_with_trailing_slash(self, mock_tee_registry):
+        llm = LLM(private_key=FAKE_PRIVATE_KEY)
+        with pytest.raises(ValueError, match="Invalid model format"):
+            import asyncio
+            asyncio.run(llm.completion(model="openai/", prompt="hi"))
+
+    def test_chat_rejects_model_without_slash(self, mock_tee_registry):
+        llm = LLM(private_key=FAKE_PRIVATE_KEY)
+        with pytest.raises(ValueError, match="Invalid model format"):
+            import asyncio
+            asyncio.run(llm.chat(model="no-slash-model", messages=[{"role": "user", "content": "hi"}]))
+
+    def test_chat_rejects_model_with_trailing_slash(self, mock_tee_registry):
+        llm = LLM(private_key=FAKE_PRIVATE_KEY)
+        with pytest.raises(ValueError, match="Invalid model format"):
+            import asyncio
+            asyncio.run(llm.chat(model="openai/", messages=[{"role": "user", "content": "hi"}]))