"""Baseten model provider.

- Docs: https://docs.baseten.co/
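
Example:
    A minimal usage sketch. The import path assumes this module lives at
    ``strands/models/baseten.py`` (inferred from the relative imports below);
    the model slug and API key are illustrative:

        from strands.models.baseten import BasetenModel

        model = BasetenModel(
            client_args={"api_key": "<BASETEN_API_KEY>"},
            model_id="deepseek-ai/DeepSeek-R1-0528",
        )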
"""

import logging
from typing import Any, Generator, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast

import openai
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
from pydantic import BaseModel
from typing_extensions import Unpack, override

from ..types.content import Messages
from ..types.models import OpenAIModel

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


class Client(Protocol):
    """Protocol defining the OpenAI-compatible interface for the underlying provider client."""

    @property
    def chat(self) -> Any:  # pragma: no cover
        """Chat completions interface."""
        ...


class BasetenModel(OpenAIModel):
    """Baseten model provider implementation."""

    client: Client

    class BasetenConfig(TypedDict, total=False):
        """Configuration options for Baseten models.

        Attributes:
            model_id: Model ID for the Baseten model.
                For Model APIs, use model slugs like "deepseek-ai/DeepSeek-R1-0528" or
                "meta-llama/Llama-4-Maverick-17B-128E-Instruct". For dedicated
                deployments, use the deployment ID.
            base_url: Base URL for the Baseten API.
                For Model APIs: https://inference.baseten.co/v1
                For dedicated deployments: https://model-xxxxxxx.api.baseten.co/environments/production/sync/v1
            params: Model parameters (e.g., max_tokens).
                For a complete list of supported parameters, see
                https://platform.openai.com/docs/api-reference/chat/create.
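
        Example:
            Illustrative configs for the two deployment modes described above
            (IDs and URLs are placeholders):

                # Model APIs
                {"model_id": "deepseek-ai/DeepSeek-R1-0528", "params": {"max_tokens": 1024}}

                # Dedicated deployment
                {
                    "model_id": "<deployment_id>",
                    "base_url": "https://model-xxxxxxx.api.baseten.co/environments/production/sync/v1",
                }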
        """

        model_id: str
        base_url: Optional[str]
        params: Optional[dict[str, Any]]

    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[BasetenConfig]) -> None:
        """Initialize provider instance.

        Args:
            client_args: Arguments for the Baseten client.
                For a complete list of supported arguments, see https://pypi.org/project/openai/.
            **model_config: Configuration options for the Baseten model.
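
        Example:
            A sketch of a dedicated-deployment setup; ``client_args`` is forwarded
            to the underlying openai client, and the IDs are placeholders:

                model = BasetenModel(
                    client_args={"api_key": "<BASETEN_API_KEY>"},
                    model_id="<deployment_id>",
                    base_url="https://model-xxxxxxx.api.baseten.co/environments/production/sync/v1",
                )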
63
+ """
64
+ self .config = dict (model_config )
65
+
66
+ logger .debug ("config=<%s> | initializing" , self .config )
67
+
68
+ client_args = client_args or {}
69
+
70
+ # Set default base URL for Model APIs if not provided
71
+ if "base_url" not in client_args and "base_url" not in self .config :
72
+ client_args ["base_url" ] = "https://inference.baseten.co/v1"
73
+ elif "base_url" in self .config :
74
+ client_args ["base_url" ] = self .config ["base_url" ]
75
+
76
+ self .client = openai .OpenAI (** client_args )
77
+
78
+ @override
79
+ def update_config (self , ** model_config : Unpack [BasetenConfig ]) -> None : # type: ignore[override]
80
+ """Update the Baseten model configuration with the provided arguments.
81
+
82
+ Args:
83
+ **model_config: Configuration overrides.
84
+ """
85
+ self .config .update (model_config )
86
+
87
+ @override
88
+ def get_config (self ) -> BasetenConfig :
89
+ """Get the Baseten model configuration.
90
+
91
+ Returns:
92
+ The Baseten model configuration.
93
+ """
94
+ return cast (BasetenModel .BasetenConfig , self .config )
95
+
96
+ @override
97
+ def stream (self , request : dict [str , Any ]) -> Iterable [dict [str , Any ]]:
98
+ """Send the request to the Baseten model and get the streaming response.
99
+
100
+ Args:
101
+ request: The formatted request to send to the Baseten model.
102
+
103
+ Returns:
104
+ An iterable of response events from the Baseten model.
105
+ """
106
+ response = self .client .chat .completions .create (** request )
107
+
108
+ yield {"chunk_type" : "message_start" }
109
+ yield {"chunk_type" : "content_start" , "data_type" : "text" }
110
+
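        # Tool-call deltas arrive interleaved with the text stream; bucket them by
        # index here and re-emit each bucket as a tool content block further below.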
        tool_calls: dict[int, list[Any]] = {}

        for event in response:
            # Defensive: skip events with empty or missing choices.
            if not getattr(event, "choices", None):
                continue
            choice = event.choices[0]

            if choice.delta.content:
                yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}

            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
                yield {
                    "chunk_type": "content_delta",
                    "data_type": "reasoning_content",
                    "data": choice.delta.reasoning_content,
                }

            for tool_call in choice.delta.tool_calls or []:
                tool_calls.setdefault(tool_call.index, []).append(tool_call)

            if choice.finish_reason:
                break

        yield {"chunk_type": "content_stop", "data_type": "text"}

        for tool_deltas in tool_calls.values():
            yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}

            for tool_delta in tool_deltas:
                yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}

            yield {"chunk_type": "content_stop", "data_type": "tool"}

        yield {"chunk_type": "message_stop", "data": choice.finish_reason}

        # Drain the remaining events; only the final usage payload is of interest.
        for event in response:
            _ = event

        yield {"chunk_type": "metadata", "data": event.usage}

    @override
    def structured_output(
        self, output_model: Type[T], prompt: Messages
    ) -> Generator[dict[str, Union[T, Any]], None, None]:
        """Get structured output from the model.

        Args:
            output_model: The output model to use for the agent.
            prompt: The prompt messages to use for the agent.

        Yields:
            Model events with the last being the structured output.
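
        Example:
            A minimal sketch; the message shape follows this SDK's ``Messages``
            content type, and the model fields are illustrative:

                class Weather(BaseModel):
                    temperature: float
                    conditions: str

                events = model.structured_output(
                    Weather,
                    [{"role": "user", "content": [{"text": "22 degrees and sunny."}]}],
                )
                result = list(events)[-1]["output"]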
        """
        response: ParsedChatCompletion = self.client.beta.chat.completions.parse(  # type: ignore
            model=self.get_config()["model_id"],
            messages=super().format_request(prompt)["messages"],
            response_format=output_model,
        )

        parsed: Optional[T] = None
        # Expect a single choice and take its parsed message if it matches the output model.
        if len(response.choices) > 1:
            raise ValueError("Multiple choices found in the Baseten response.")

        for choice in response.choices:
            if isinstance(choice.message.parsed, output_model):
                parsed = choice.message.parsed
                break

        if parsed:
            yield {"output": parsed}
        else:
            raise ValueError("No valid structured output was found in the Baseten response.")