Skip to content

Commit 2676b19

Browse files
Add baseten integration
1 parent a392509 commit 2676b19

File tree

4 files changed

+667
-0
lines changed

4 files changed

+667
-0
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ packages = ["src/strands"]
4949
anthropic = [
5050
"anthropic>=0.21.0,<1.0.0",
5151
]
52+
baseten = [
53+
"openai>=1.68.0,<2.0.0",
54+
]
5255
dev = [
5356
"commitizen>=4.4.0,<5.0.0",
5457
"hatch>=1.0.0,<2.0.0",

src/strands/models/baseten.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
"""Baseten model provider.
2+
3+
- Docs: https://docs.baseten.co/
4+
"""
5+
6+
import logging
7+
from typing import Any, Generator, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast
8+
9+
import openai
10+
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
11+
from pydantic import BaseModel
12+
from typing_extensions import Unpack, override
13+
14+
from ..types.content import Messages
15+
from ..types.models import OpenAIModel
16+
17+
logger = logging.getLogger(__name__)
18+
19+
T = TypeVar("T", bound=BaseModel)
20+
21+
22+
class Client(Protocol):
    """Structural type for any client object exposing an OpenAI-compatible interface."""

    @property
    def chat(self) -> Any:  # pragma: no cover
        """Return the chat-completions interface of the wrapped client."""
        ...
30+
31+
32+
class BasetenModel(OpenAIModel):
    """Baseten model provider implementation.

    Wraps an OpenAI-compatible client pointed at Baseten's inference endpoints,
    either the shared Model APIs or a dedicated deployment.
    """

    client: Client

    class BasetenConfig(TypedDict, total=False):
        """Configuration options for Baseten models.

        Attributes:
            model_id: Model ID for the Baseten model.
                For Model APIs, use model slugs like "deepseek-ai/DeepSeek-R1-0528"
                or "meta-llama/Llama-4-Maverick-17B-128E-Instruct".
                For dedicated deployments, use the deployment ID.
            base_url: Base URL for the Baseten API.
                For Model APIs: https://inference.baseten.co/v1
                For dedicated deployments:
                https://model-xxxxxxx.api.baseten.co/environments/production/sync/v1
            params: Model parameters (e.g., max_tokens).
                For a complete list of supported parameters, see
                https://platform.openai.com/docs/api-reference/chat/create.
        """

        model_id: str
        base_url: Optional[str]
        params: Optional[dict[str, Any]]

    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[BasetenConfig]) -> None:
        """Initialize provider instance.

        Args:
            client_args: Arguments for the Baseten client.
                For a complete list of supported arguments, see https://pypi.org/project/openai/.
            **model_config: Configuration options for the Baseten model.
        """
        self.config = dict(model_config)

        logger.debug("config=<%s> | initializing", self.config)

        # Copy so we never mutate the caller's dict.
        client_args = dict(client_args or {})

        # Resolve the base URL: an explicit config value wins, then any value
        # already in client_args, then the shared Model APIs default.
        if "base_url" in self.config:
            client_args["base_url"] = self.config["base_url"]
        elif "base_url" not in client_args:
            client_args["base_url"] = "https://inference.baseten.co/v1"

        self.client = openai.OpenAI(**client_args)

    @override
    def update_config(self, **model_config: Unpack[BasetenConfig]) -> None:  # type: ignore[override]
        """Update the Baseten model configuration with the provided arguments.

        Args:
            **model_config: Configuration overrides.
        """
        self.config.update(model_config)

    @override
    def get_config(self) -> BasetenConfig:
        """Get the Baseten model configuration.

        Returns:
            The Baseten model configuration.
        """
        return cast(BasetenModel.BasetenConfig, self.config)

    @override
    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
        """Send the request to the Baseten model and get the streaming response.

        Args:
            request: The formatted request to send to the Baseten model.

        Returns:
            An iterable of response events from the Baseten model.
        """
        response = self.client.chat.completions.create(**request)

        yield {"chunk_type": "message_start"}
        yield {"chunk_type": "content_start", "data_type": "text"}

        tool_calls: dict[int, list[Any]] = {}
        finish_reason: Optional[str] = None
        usage = None

        for event in response:
            # Defensive: skip events with empty or missing choices. Because of this,
            # we must not reference `choice`/`event` after the loop — they may be unset.
            if not getattr(event, "choices", None):
                continue
            choice = event.choices[0]

            if choice.delta.content:
                yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}

            # Some Baseten-hosted models (e.g. DeepSeek R1) stream reasoning tokens
            # in a separate delta field.
            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
                yield {
                    "chunk_type": "content_delta",
                    "data_type": "reasoning_content",
                    "data": choice.delta.reasoning_content,
                }

            for tool_call in choice.delta.tool_calls or []:
                tool_calls.setdefault(tool_call.index, []).append(tool_call)

            if choice.finish_reason:
                finish_reason = choice.finish_reason
                break

            # Remember the most recent usage payload in case the stream ends here.
            if getattr(event, "usage", None):
                usage = event.usage

        yield {"chunk_type": "content_stop", "data_type": "text"}

        for tool_deltas in tool_calls.values():
            yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}

            for tool_delta in tool_deltas:
                yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}

            yield {"chunk_type": "content_stop", "data_type": "tool"}

        yield {"chunk_type": "message_stop", "data": finish_reason}

        # Drain the remaining events; the final event typically carries the usage payload.
        for event in response:
            if getattr(event, "usage", None):
                usage = event.usage

        yield {"chunk_type": "metadata", "data": usage}

    @override
    def structured_output(
        self, output_model: Type[T], prompt: Messages
    ) -> Generator[dict[str, Union[T, Any]], None, None]:
        """Get structured output from the model.

        Args:
            output_model: The output model to use for the agent.
            prompt: The prompt messages to use for the agent.

        Yields:
            Model events with the last being the structured output.

        Raises:
            ValueError: If the response contains multiple choices or no parsed output.
        """
        response: ParsedChatCompletion = self.client.beta.chat.completions.parse(  # type: ignore
            model=self.get_config()["model_id"],
            messages=super().format_request(prompt)["messages"],
            response_format=output_model,
        )

        if len(response.choices) > 1:
            raise ValueError("Multiple choices found in the Baseten response.")

        # Find the first choice whose message parsed into the requested model.
        parsed: Optional[T] = None
        for choice in response.choices:
            if isinstance(choice.message.parsed, output_model):
                parsed = choice.message.parsed
                break

        if parsed is None:
            raise ValueError("No valid parsed output was found in the Baseten response.")

        yield {"output": parsed}

tests-integ/test_model_baseten.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import os
2+
3+
import pytest
4+
from pydantic import BaseModel
5+
6+
import strands
7+
from strands import Agent
8+
from strands.models.baseten import BasetenModel
9+
10+
11+
@pytest.fixture
def model_model_apis():
    """Model APIs fixture using the DeepSeek V3 model ("deepseek-ai/DeepSeek-V3-0324")."""
    return BasetenModel(
        model_id="deepseek-ai/DeepSeek-V3-0324",
        client_args={
            "api_key": os.getenv("BASETEN_API_KEY"),
        },
    )
20+
21+
22+
@pytest.fixture
def model_dedicated_deployment():
    """Dedicated-deployment fixture.

    Reads the deployment ID from the BASETEN_DEPLOYMENT_ID environment variable
    (the tests that use this fixture are skipped when it is unset), instead of
    hard-coding one deployment's URL. Also supplies model_id, which
    BasetenModel.structured_output requires.
    """
    # Fall back to a placeholder so the fixture itself never raises; the
    # skipif guards on the tests prevent it from being used without the env var.
    deployment_id = os.getenv("BASETEN_DEPLOYMENT_ID", "232k7g23")
    base_url = f"https://model-{deployment_id}.api.baseten.co/environments/production/sync/v1"

    return BasetenModel(
        model_id=deployment_id,
        base_url=base_url,
        client_args={
            "api_key": os.getenv("BASETEN_API_KEY"),
        },
    )
33+
34+
35+
@pytest.fixture
def tools():
    # Minimal tool pair with fixed return values so tests can assert the agent
    # surfaced both tool results in its answer.
    # NOTE(review): deliberately left without docstrings — adding them would change
    # the tool descriptions strands sends to the model.
    @strands.tool
    def tool_time() -> str:
        return "12:00"

    @strands.tool
    def tool_weather() -> str:
        return "sunny"

    return [tool_time, tool_weather]
46+
47+
48+
@pytest.fixture
def agent_model_apis(model_model_apis, tools):
    """Agent backed by the Model APIs model plus the fixed test tools."""
    return Agent(model=model_model_apis, tools=tools)
51+
52+
53+
@pytest.fixture
def agent_dedicated(model_dedicated_deployment, tools):
    """Agent backed by the dedicated-deployment model plus the fixed test tools."""
    return Agent(model=model_dedicated_deployment, tools=tools)
56+
57+
58+
@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ,
    reason="BASETEN_API_KEY environment variable missing",
)
def test_agent_model_apis(agent_model_apis):
    """The agent should surface both fixed tool results in its answer."""
    result = agent_model_apis("What is the time and weather in New York?")
    answer = result.message["content"][0]["text"].lower()

    assert "12:00" in answer
    assert "sunny" in answer
67+
68+
69+
@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ or "BASETEN_DEPLOYMENT_ID" not in os.environ,
    reason="BASETEN_API_KEY or BASETEN_DEPLOYMENT_ID environment variable missing",
)
def test_agent_dedicated_deployment(agent_dedicated):
    """Same tool-use check as the Model APIs test, against a dedicated deployment."""
    result = agent_dedicated("What is the time and weather in New York?")
    answer = result.message["content"][0]["text"].lower()

    assert "12:00" in answer
    assert "sunny" in answer
78+
79+
80+
@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ,
    reason="BASETEN_API_KEY environment variable missing",
)
def test_structured_output_model_apis(model_model_apis):
    """Structured output should extract the exact time/weather strings."""

    class Weather(BaseModel):
        """Extracts the time and weather from the user's message with the exact strings."""

        time: str
        weather: str

    agent = Agent(model=model_model_apis)
    extracted = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")

    assert isinstance(extracted, Weather)
    assert (extracted.time, extracted.weather) == ("12:00", "sunny")
97+
98+
99+
@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ or "BASETEN_DEPLOYMENT_ID" not in os.environ,
    reason="BASETEN_API_KEY or BASETEN_DEPLOYMENT_ID environment variable missing",
)
def test_structured_output_dedicated_deployment(model_dedicated_deployment):
    """Structured-output extraction check against a dedicated deployment."""

    class Weather(BaseModel):
        """Extracts the time and weather from the user's message with the exact strings."""

        time: str
        weather: str

    agent = Agent(model=model_dedicated_deployment)
    extracted = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")

    assert isinstance(extracted, Weather)
    assert (extracted.time, extracted.weather) == ("12:00", "sunny")
116+
117+
118+
@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ,
    reason="BASETEN_API_KEY environment variable missing",
)
def test_llama_model_model_apis():
    """Smoke test: Llama 4 Maverick on Model APIs returns a non-empty reply."""
    model = BasetenModel(
        model_id="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
        client_args={"api_key": os.getenv("BASETEN_API_KEY")},
    )
    agent = Agent(model=model)

    result = agent("Hello, how are you?")
    reply = result.message["content"][0]["text"]

    assert reply is not None
    assert len(reply) > 0
136+
137+
138+
@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ,
    reason="BASETEN_API_KEY environment variable missing",
)
def test_deepseek_r1_model_apis():
    """Smoke test: DeepSeek R1 on Model APIs returns a non-empty reply."""
    model = BasetenModel(
        model_id="deepseek-ai/DeepSeek-R1-0528",
        client_args={"api_key": os.getenv("BASETEN_API_KEY")},
    )
    agent = Agent(model=model)

    result = agent("What is 2 + 2?")
    reply = result.message["content"][0]["text"]

    assert reply is not None
    assert len(reply) > 0
156+
157+
158+
@pytest.mark.skipif(
    "BASETEN_API_KEY" not in os.environ,
    reason="BASETEN_API_KEY environment variable missing",
)
def test_llama_scout_model_apis():
    """Smoke test: Llama 4 Scout on Model APIs returns a non-empty reply."""
    model = BasetenModel(
        model_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
        client_args={"api_key": os.getenv("BASETEN_API_KEY")},
    )
    agent = Agent(model=model)

    result = agent("Explain quantum computing in simple terms.")
    reply = result.message["content"][0]["text"]

    assert reply is not None
    assert len(reply) > 0

0 commit comments

Comments
 (0)