@@ -91,25 +91,73 @@ def remove_signature_from_tool_description(name: str, description: str) -> str:
         description = re.sub(r"(?s)(?:\n?\n\s*?)?Args:.*$", "", description)
         return description
 
+    @staticmethod
+    def oci_to_dict(obj):
+        if obj is None or isinstance(obj, (str, int, float, bool)):
+            return obj
+        if isinstance(obj, dict):
+            return {k: OCIUtils.oci_to_dict(v) for k, v in obj.items()}
+        if isinstance(obj, list):
+            return [OCIUtils.oci_to_dict(x) for x in obj]
+        if hasattr(obj, "__dict__"):
+            out = {k: OCIUtils.oci_to_dict(getattr(obj, k)) for k in obj.__dict__ if not k.startswith("_")}
+            if hasattr(obj, "attribute_map"):
+                attribute_map = getattr(obj, "attribute_map", {})
+                out = {attribute_map.get(k, k): v for k, v in out.items()}
+            return out
+        return str(obj)
+
+
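Note: a minimal sketch of what the new oci_to_dict helper produces, assuming OCIUtils from this module is in scope. FakeToolCall is invented for illustration; its attribute_map follows the OCI SDK convention of mapping Python attribute names back to wire names.

    class FakeToolCall:  # hypothetical stand-in for an OCI SDK model
        attribute_map = {"tool_name": "toolName"}  # python attr -> wire name

        def __init__(self):
            self.tool_name = "get_weather"
            self._private = "skipped"  # underscore attrs are dropped

    print(OCIUtils.oci_to_dict(FakeToolCall()))
    # {'toolName': 'get_weather'}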
     @staticmethod
     def convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall:
         """Convert an OCI tool call to a LangChain ToolCall."""
-        parsed = json.loads(tool_call.arguments)
+        # Providers expose tool-call arguments in different shapes:
+        # Cohere models return a "parameters" dict, while Meta models
+        # return an "arguments" JSON string that may be double-encoded.
+        if hasattr(tool_call, "arguments"):
+            parsed = json.loads(tool_call.arguments)
+            # If the parsed result is a string, the JSON was escaped; parse again.
+            if isinstance(parsed, str):
+                try:
+                    parsed = json.loads(parsed)
+                except json.JSONDecodeError:
+                    pass
+            args = parsed
+        elif hasattr(tool_call, "parameters"):
+            args = tool_call.parameters  # Already a dict
+        elif isinstance(tool_call, dict):
+            # Defensive: handle a raw dict payload.
+            if "arguments" in tool_call:
+                value = tool_call["arguments"]
+                args = json.loads(value) if isinstance(value, str) else value
+            elif "parameters" in tool_call:
+                args = tool_call["parameters"]
+            else:
+                args = {}
+        else:
+            args = {}
 
-        # If the parsed result is a string, it means the JSON was escaped, so parse again
-        if isinstance(parsed, str):
-            try:
-                parsed = json.loads(parsed)
-            except json.JSONDecodeError:
-                # If it's not valid JSON, keep it as a string
-                pass
+        # Use the provider-supplied id when present; otherwise generate one.
+        if hasattr(tool_call, "id"):
+            tool_call_id = tool_call.id
+        elif isinstance(tool_call, dict) and "id" in tool_call:
+            tool_call_id = tool_call["id"]
+        else:
+            tool_call_id = uuid.uuid4().hex[:]
+
+        # Use the provider-supplied name when present.
+        if hasattr(tool_call, "name"):
+            name = tool_call.name
+        elif isinstance(tool_call, dict) and "name" in tool_call:
+            name = tool_call["name"]
+        else:
+            name = "unknown_tool"
 
         return ToolCall(
-            name=tool_call.name,
-            args=parsed
-            if "arguments" in tool_call.attribute_map
-            else tool_call.parameters,
-            id=tool_call.id if "id" in tool_call.attribute_map else uuid.uuid4().hex[:],
+            name=name,
+            args=args,
+            id=tool_call_id,
         )
 
 
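Note: the double-parse in the new branch handles arguments that arrive JSON-encoded twice. A standalone sketch of that unwrapping, using only the standard library:

    import json

    escaped = json.dumps(json.dumps({"city": "Berlin"}))  # simulate an escaped payload
    parsed = json.loads(escaped)
    if isinstance(parsed, str):  # first pass only removed the outer quoting
        parsed = json.loads(parsed)
    print(parsed)  # {'city': 'Berlin'}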
@@ -223,9 +271,6 @@ def __init__(self) -> None:
223271 "SYSTEM" : models .CohereSystemMessage ,
224272 "TOOL" : models .CohereToolMessage ,
225273 }
226-
227- self .oci_response_json_schema = models .ResponseJsonSchema
228- self .oci_json_schema_response_format = models .JsonSchemaResponseFormat
229274 self .chat_api_format = models .BaseChatRequest .API_FORMAT_COHERE
230275
231276 def chat_response_to_text (self , response : Any ) -> str :
@@ -255,11 +300,6 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
255300 "is_search_required" : response .data .chat_response .is_search_required ,
256301 "finish_reason" : response .data .chat_response .finish_reason ,
257302 }
258-
259- # Include token usage if available
260- if hasattr (response .data .chat_response , "usage" ) and response .data .chat_response .usage :
261- generation_info ["total_tokens" ] = response .data .chat_response .usage .total_tokens
262-
263303 # Include tool calls if available
264304 if self .chat_tool_calls (response ):
265305 generation_info ["tool_calls" ] = self .format_response_tool_calls (
@@ -302,8 +342,12 @@ def format_response_tool_calls(
                 {
                     "id": uuid.uuid4().hex[:],
                     "function": {
-                        "name": tool_call.name,
-                        "arguments": json.dumps(tool_call.parameters),
+                        "name": getattr(tool_call, "name", "unknown_tool"),
+                        "arguments": (
+                            json.dumps(json.loads(tool_call.arguments))
+                            if hasattr(tool_call, "arguments")
+                            else json.dumps(getattr(tool_call, "parameters", {}))
+                        ),
                     },
                     "type": "function",
                 }
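Note: both provider shapes now end up as a JSON string under "arguments", and the round-trip through json.loads/json.dumps also validates the payload. A quick sketch:

    import json

    from_meta = json.dumps(json.loads('{"city": "Berlin"}'))  # .arguments (JSON string)
    from_cohere = json.dumps({"city": "Berlin"})              # .parameters (dict)
    assert from_meta == from_cohere == '{"city": "Berlin"}'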
@@ -385,7 +429,9 @@ def messages_to_oci_params(
                 self.oci_chat_message[self.get_role(msg)](
                     tool_results=[
                         self.oci_tool_result(
-                            call=self.oci_tool_call(name=msg.name, parameters={}),
+                            call=self.oci_tool_call(
+                                name=msg.name, parameters={}
+                            ),
                             outputs=[{"output": msg.content}],
                         )
                     ],
@@ -397,17 +443,9 @@ def messages_to_oci_params(
         for i, message in enumerate(messages[::-1]):
             current_turn.append(message)
             if isinstance(message, HumanMessage):
-                if len(messages) > i and isinstance(
-                    messages[len(messages) - i - 2], ToolMessage
-                ):
-                    # add dummy message REPEATING the tool_result to avoid
-                    # the error about ToolMessage needing to be followed
-                    # by an AI message
-                    oci_chat_history.append(
-                        self.oci_chat_message["CHATBOT"](
-                            message=messages[len(messages) - i - 2].content
-                        )
-                    )
+                if len(messages) > i and isinstance(messages[len(messages) - i - 2], ToolMessage):
+                    # add dummy message REPEATING the tool_result to avoid the error about ToolMessage needing to be followed by an AI message
+                    oci_chat_history.append(self.oci_chat_message['CHATBOT'](message=messages[len(messages) - i - 2].content))
                 break
         current_turn = list(reversed(current_turn))
 
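Note: the reverse-iteration index arithmetic is easier to see on a toy history, with plain strings standing in for messages:

    messages = ["human-1", "ai-tool-call", "tool-result", "human-2"]
    for i, message in enumerate(messages[::-1]):
        if message.startswith("human"):
            # messages[len(messages) - i - 2] is the element immediately before
            # this HumanMessage in original order -- here, the tool result.
            print(messages[len(messages) - i - 2])  # tool-result
            break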
@@ -601,10 +639,6 @@ def __init__(self) -> None:
         self.oci_tool_call = models.FunctionCall
         self.oci_tool_message = models.ToolMessage
 
-        # Response format models
-        self.oci_response_json_schema = models.ResponseJsonSchema
-        self.oci_json_schema_response_format = models.JsonSchemaResponseFormat
-
         self.chat_api_format = models.BaseChatRequest.API_FORMAT_GENERIC
 
     def chat_response_to_text(self, response: Any) -> str:
@@ -630,11 +664,6 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
630664 "finish_reason" : response .data .chat_response .choices [0 ].finish_reason ,
631665 "time_created" : str (response .data .chat_response .time_created ),
632666 }
633-
634- # Include token usage if available
635- if hasattr (response .data .chat_response , "usage" ) and response .data .chat_response .usage :
636- generation_info ["total_tokens" ] = response .data .chat_response .usage .total_tokens
637-
638667 if self .chat_tool_calls (response ):
639668 generation_info ["tool_calls" ] = self .format_response_tool_calls (
640669 self .chat_tool_calls (response )
@@ -668,8 +697,12 @@ def format_response_tool_calls(self, tool_calls: List[Any]) -> List[Dict]:
                 {
                     "id": tool_call.id,
                     "function": {
-                        "name": tool_call.name,
-                        "arguments": json.loads(tool_call.arguments),
+                        "name": getattr(tool_call, "name", "unknown_tool"),
+                        "arguments": (
+                            json.dumps(json.loads(tool_call.arguments))
+                            if hasattr(tool_call, "arguments")
+                            else json.dumps(getattr(tool_call, "parameters", {}))
+                        ),
                     },
                     "type": "function",
                 }
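Note the behavior change here: the old code stored the decoded dict, while the new round-trip keeps "arguments" a JSON string, matching the OpenAI-style tool-call shape:

    import json

    raw = '{"city": "Berlin"}'
    print(type(json.loads(raw)))              # <class 'dict'>  (old behavior)
    print(type(json.dumps(json.loads(raw))))  # <class 'str'>   (new behavior)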
@@ -695,7 +728,11 @@ def format_stream_tool_calls(
695728 "id" : tool_call .get ("id" , "" ),
696729 "function" : {
697730 "name" : tool_call .get ("name" , "" ),
698- "arguments" : tool_call .get ("arguments" , "" ),
731+ "arguments" : (
732+ json .dumps (json .loads (tool_call ["arguments" ]))
733+ if "arguments" in tool_call
734+ else json .dumps (tool_call .get ("parameters" , {}))
735+ ),
699736 },
700737 "type" : "function" ,
701738 }
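Note: the same round-trip is applied to streamed chunks, which assumes each accumulated chunk carries complete JSON. A sketch of that assumption:

    import json

    chunk = {"id": "call_0", "name": "get_weather", "arguments": '{"city": "Berlin"}'}
    print(json.dumps(json.loads(chunk["arguments"])))  # '{"city": "Berlin"}'
    # A partial fragment such as '{"city": "Ber' would raise json.JSONDecodeError here.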
@@ -745,8 +782,8 @@ def messages_to_oci_params(
                 )
             else:
                 oci_message = self.oci_chat_message[role](content=tool_content)
-        elif isinstance(message, AIMessage) and (
-            message.tool_calls or message.additional_kwargs.get("tool_calls")
+        elif isinstance(message, AIMessage) and message.additional_kwargs.get(
+            "tool_calls"
         ):
             # Process content and tool calls for assistant messages
             content = self._process_message_content(message.content)
@@ -904,22 +941,7 @@ def convert_to_oci_tool(
         Raises:
             ValueError: If the tool type is not supported.
         """
-        if (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool):
-            as_json_schema_function = convert_to_openai_function(tool)
-            parameters = as_json_schema_function.get("parameters", {})
-            return self.oci_function_definition(
-                name=as_json_schema_function.get("name"),
-                description=as_json_schema_function.get(
-                    "description",
-                    as_json_schema_function.get("name"),
-                ),
-                parameters={
-                    "type": "object",
-                    "properties": parameters.get("properties", {}),
-                    "required": parameters.get("required", []),
-                },
-            )
-        elif isinstance(tool, BaseTool):
+        if isinstance(tool, BaseTool):
             return self.oci_function_definition(
                 name=tool.name,
                 description=OCIUtils.remove_signature_from_tool_description(
@@ -941,6 +963,21 @@ def convert_to_oci_tool(
                 ],
             },
         )
+        elif (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool):
+            as_json_schema_function = convert_to_openai_function(tool)
+            parameters = as_json_schema_function.get("parameters", {})
+            return self.oci_function_definition(
+                name=as_json_schema_function.get("name"),
+                description=as_json_schema_function.get(
+                    "description",
+                    as_json_schema_function.get("name"),
+                ),
+                parameters={
+                    "type": "object",
+                    "properties": parameters.get("properties", {}),
+                    "required": parameters.get("required", []),
+                },
+            )
         raise ValueError(
             f"Unsupported tool type {type(tool)}. "
             "Tool must be passed in as a BaseTool "
@@ -1034,7 +1071,6 @@ def process_stream_tool_calls(
 
 class MetaProvider(GenericProvider):
     """Provider for Meta models. This provider is for backward compatibility."""
-
     pass
 
 
@@ -1151,27 +1187,14 @@ def _prepare_request(
11511187 "Please make sure you have the oci package installed."
11521188 ) from ex
11531189
1154- oci_params = self ._provider .messages_to_oci_params (
1155- messages ,
1156- max_sequential_tool_calls = self .max_sequential_tool_calls ,
1157- ** kwargs
1158- )
1190+ oci_params = self ._provider .messages_to_oci_params (messages , ** kwargs )
11591191
11601192 oci_params ["is_stream" ] = stream
11611193 _model_kwargs = self .model_kwargs or {}
11621194
11631195 if stop is not None :
11641196 _model_kwargs [self ._provider .stop_sequence_key ] = stop
11651197
1166- # Warn if using max_tokens with OpenAI models
1167- if self .model_id and self .model_id .startswith ("openai." ) and "max_tokens" in _model_kwargs :
1168- import warnings
1169- warnings .warn (
1170- f"OpenAI models require 'max_completion_tokens' instead of 'max_tokens'." ,
1171- UserWarning ,
1172- stacklevel = 2
1173- )
1174-
11751198 chat_params = {** _model_kwargs , ** kwargs , ** oci_params }
11761199
11771200 if not self .model_id :
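Note: the merge order in chat_params means later sources win, so explicit call kwargs override model_kwargs and the derived OCI params override both. For example:

    _model_kwargs = {"temperature": 0.2, "max_tokens": 256}
    kwargs = {"temperature": 0.7}
    oci_params = {"is_stream": False}
    print({**_model_kwargs, **kwargs, **oci_params})
    # {'temperature': 0.7, 'max_tokens': 256, 'is_stream': False}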
@@ -1247,14 +1270,14 @@ def with_structured_output(
             `method` is "function_calling" and `schema` is a dict, then the dict
             must match the OCI Generative AI function-calling spec.
         method:
-            The method for steering model generation, either "function_calling" (default method)
-            or "json_mode" or "json_schema". If "function_calling" then the schema
+            The method for steering model generation, either "function_calling"
+            or "json_mode" or "json_schema". If "function_calling" then the schema
             will be converted to an OCI function and the returned model will make
             use of the function-calling API. If "json_mode" then Cohere's JSON mode will be
             used. Note that if using "json_mode" then you must include instructions
             for formatting the output into the desired schema into the model call.
             If "json_schema" then it allows the user to pass a json schema (or pydantic)
-            to the model for structured output.
+            to the model for structured output. This is the default method.
         include_raw:
             If False then only the parsed structured output is returned. If
             an error occurs during model output parsing it will be raised. If True
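Note: a hypothetical usage sketch of the method described above; the ChatOCIGenAI configuration is elided and the prompt is invented:

    from pydantic import BaseModel

    class Person(BaseModel):
        """A person mentioned in the text."""
        name: str
        age: int

    # llm = ChatOCIGenAI(model_id="cohere.command-r-plus", ...)  # configured elsewhere
    structured_llm = llm.with_structured_output(Person, method="json_schema")
    result = structured_llm.invoke("Extract: Ada Lovelace, 36 years old.")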
@@ -1305,24 +1328,19 @@ def with_structured_output(
                 else JsonOutputParser()
             )
         elif method == "json_schema":
-            json_schema_dict = (
-                schema.model_json_schema()  # type: ignore[union-attr]
+            response_format = (
+                dict(
+                    schema.model_json_schema().items()  # type: ignore[union-attr]
+                )
                 if is_pydantic_schema
                 else schema
             )
-
-            response_json_schema = self._provider.oci_response_json_schema(
-                name=json_schema_dict.get("title", "response"),
-                description=json_schema_dict.get("description", ""),
-                schema=json_schema_dict,
-                is_strict=True
-            )
-
-            response_format_obj = self._provider.oci_json_schema_response_format(
-                json_schema=response_json_schema
-            )
-
-            llm = self.bind(response_format=response_format_obj)
+            llm_response_format: Dict[Any, Any] = {"type": "JSON_OBJECT"}
+            llm_response_format["schema"] = {
+                k: v
+                for k, v in response_format.items()  # type: ignore[union-attr]
+            }
+            llm = self.bind(response_format=llm_response_format)
         if is_pydantic_schema:
             output_parser = PydanticOutputParser(pydantic_object=schema)
         else:
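Note: what the new json_schema branch binds, sketched standalone for a Pydantic schema; dict(...items()) is just a shallow copy of model_json_schema():

    from typing import Any, Dict
    from pydantic import BaseModel

    class Person(BaseModel):
        name: str
        age: int

    response_format: Dict[Any, Any] = {"type": "JSON_OBJECT"}
    response_format["schema"] = dict(Person.model_json_schema().items())
    print(response_format["type"])                # JSON_OBJECT
    print(response_format["schema"]["required"])  # ['name', 'age']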
@@ -1400,7 +1418,7 @@ def _generate(
             for tool_call in self._provider.chat_tool_calls(response)
         ]
         message = AIMessage(
-            content=content or "",
+            content=content,
             additional_kwargs=generation_info,
             tool_calls=tool_calls,
         )
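Note: the AIMessage built here carries tool calls as first-class data rather than only in additional_kwargs; a standalone construction looks like:

    from langchain_core.messages import AIMessage

    msg = AIMessage(
        content="",
        tool_calls=[{"name": "get_weather", "args": {"city": "Berlin"}, "id": "call_0"}],
    )
    print(msg.tool_calls[0]["name"])  # get_weather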