
Commit 9919a2e

Add support for langchain 1.0.4, also fix other bugs.
Signed-off-by: Joe Klein <joseph.klein@oracle.com>
1 parent 0eaca99

File tree

4 files changed: +2563 -1618 lines changed


libs/oci/langchain_oci/chat_models/oci_generative_ai.py

Lines changed: 117 additions & 99 deletions
@@ -91,25 +91,73 @@ def remove_signature_from_tool_description(name: str, description: str) -> str:
         description = re.sub(r"(?s)(?:\n?\n\s*?)?Args:.*$", "", description)
         return description

+    @staticmethod
+    def oci_to_dict(obj):
+        if obj is None or isinstance(obj, (str, int, float, bool)):
+            return obj
+        if isinstance(obj, dict):
+            return {k: OCIUtils.oci_to_dict(v) for k, v in obj.items()}
+        if isinstance(obj, list):
+            return [OCIUtils.oci_to_dict(x) for x in obj]
+        if hasattr(obj, "__dict__"):
+            out = {k: OCIUtils.oci_to_dict(getattr(obj, k)) for k in obj.__dict__ if not k.startswith("_")}
+            if hasattr(obj, "attribute_map"):
+                attribute_map = getattr(obj, "attribute_map", {})
+                out = {attribute_map.get(k, k): v for k, v in out.items()}
+            return out
+        return str(obj)
+
+
     @staticmethod
     def convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall:
         """Convert an OCI tool call to a LangChain ToolCall."""
-        parsed = json.loads(tool_call.arguments)
+        # Handle various possible provider API shapes ("arguments" or "parameters")
+        # If object has an 'arguments' attr use it, else use 'parameters'
+
+        # Cohere models return a "parameters" dict, Meta models .arguments (JSON string)
+        if hasattr(tool_call, "arguments"):
+            parsed = json.loads(tool_call.arguments)
+            # If the parsed result is a string, it means the JSON was escaped, so parse again
+            if isinstance(parsed, str):
+                try:
+                    parsed = json.loads(parsed)
+                except json.JSONDecodeError:
+                    pass
+            args = parsed
+        elif hasattr(tool_call, "parameters"):
+            args = tool_call.parameters  # Already a dict
+        elif isinstance(tool_call, dict):
+            # Defensive: handle raw dict
+            if "arguments" in tool_call:
+                value = tool_call["arguments"]
+                args = json.loads(value) if isinstance(value, str) else value
+            elif "parameters" in tool_call:
+                args = tool_call["parameters"]
+            else:
+                args = {}
+        else:
+            args = {}

-        # If the parsed result is a string, it means the JSON was escaped, so parse again
-        if isinstance(parsed, str):
-            try:
-                parsed = json.loads(parsed)
-            except json.JSONDecodeError:
-                # If it's not valid JSON, keep it as a string
-                pass
+        # Try to get id, fallback to random if missing
+        if hasattr(tool_call, "id"):
+            tool_call_id = tool_call.id
+        elif isinstance(tool_call, dict) and "id" in tool_call:
+            tool_call_id = tool_call["id"]
+        else:
+            tool_call_id = uuid.uuid4().hex[:]
+
+        # Try to get name
+        if hasattr(tool_call, "name"):
+            name = tool_call.name
+        elif isinstance(tool_call, dict) and "name" in tool_call:
+            name = tool_call["name"]
+        else:
+            name = "unknown_tool"

         return ToolCall(
-            name=tool_call.name,
-            args=parsed
-            if "arguments" in tool_call.attribute_map
-            else tool_call.parameters,
-            id=tool_call.id if "id" in tool_call.attribute_map else uuid.uuid4().hex[:],
+            name=name,
+            args=args,
+            id=tool_call_id,
         )
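
Taken together, oci_to_dict and the rewritten convert_oci_tool_call_to_langchain converge every SDK shape on the same LangChain ToolCall. A minimal sketch of that behavior, using SimpleNamespace stand-ins for the OCI SDK objects (real calls arrive as oci.generative_ai_inference.models instances, and the import path below assumes the module layout shown by this file):

    from types import SimpleNamespace

    from langchain_oci.chat_models.oci_generative_ai import OCIUtils

    # Hypothetical stand-ins: Meta-style carries an "arguments" JSON string,
    # Cohere-style a "parameters" dict (here a raw dict, the defensive branch).
    meta_style = SimpleNamespace(
        id="call_1", name="get_weather", arguments='{"city": "Berlin"}'
    )
    cohere_style = {"name": "get_weather", "parameters": {"city": "Berlin"}}

    for tc in (meta_style, cohere_style):
        call = OCIUtils.convert_oci_tool_call_to_langchain(tc)
        print(call["name"], call["args"])  # get_weather {'city': 'Berlin'} both times

    # oci_to_dict recursively unwraps objects, skipping private attributes;
    # when an attribute_map is present, keys are renamed to their wire names.
    print(OCIUtils.oci_to_dict(SimpleNamespace(finish_reason="stop", _raw=None)))
    # -> {'finish_reason': 'stop'}
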
@@ -223,9 +271,6 @@ def __init__(self) -> None:
             "SYSTEM": models.CohereSystemMessage,
             "TOOL": models.CohereToolMessage,
         }
-
-        self.oci_response_json_schema = models.ResponseJsonSchema
-        self.oci_json_schema_response_format = models.JsonSchemaResponseFormat
         self.chat_api_format = models.BaseChatRequest.API_FORMAT_COHERE

     def chat_response_to_text(self, response: Any) -> str:
@@ -255,11 +300,6 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
             "is_search_required": response.data.chat_response.is_search_required,
             "finish_reason": response.data.chat_response.finish_reason,
         }
-
-        # Include token usage if available
-        if hasattr(response.data.chat_response, "usage") and response.data.chat_response.usage:
-            generation_info["total_tokens"] = response.data.chat_response.usage.total_tokens
-
         # Include tool calls if available
         if self.chat_tool_calls(response):
             generation_info["tool_calls"] = self.format_response_tool_calls(
@@ -302,8 +342,12 @@ def format_response_tool_calls(
                 {
                     "id": uuid.uuid4().hex[:],
                     "function": {
-                        "name": tool_call.name,
-                        "arguments": json.dumps(tool_call.parameters),
+                        "name": getattr(tool_call, "name", "unknown_tool"),
+                        "arguments": (
+                            json.dumps(json.loads(tool_call.arguments))
+                            if hasattr(tool_call, "arguments")
+                            else json.dumps(getattr(tool_call, "parameters", {}))
+                        ),
                     },
                     "type": "function",
                 }
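
Note that both branches now emit "arguments" as a JSON string: a Meta-style JSON string is validated and re-serialized via json.dumps(json.loads(...)), while a Cohere-style parameters dict is serialized directly, so consumers can always json.loads the field. A small sketch mirroring the conditional above, with hypothetical inputs:

    import json
    from types import SimpleNamespace

    def normalize_arguments(tool_call) -> str:
        # Mirrors the conditional in the hunk above (illustrative helper,
        # not part of the commit).
        if hasattr(tool_call, "arguments"):
            return json.dumps(json.loads(tool_call.arguments))
        return json.dumps(getattr(tool_call, "parameters", {}))

    print(normalize_arguments(SimpleNamespace(arguments='{"city": "Berlin"}')))
    print(normalize_arguments(SimpleNamespace(parameters={"city": "Berlin"})))
    # Both print: {"city": "Berlin"}
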
@@ -385,7 +429,9 @@ def messages_to_oci_params(
                     self.oci_chat_message[self.get_role(msg)](
                         tool_results=[
                             self.oci_tool_result(
-                                call=self.oci_tool_call(name=msg.name, parameters={}),
+                                call=self.oci_tool_call(
+                                    name=msg.name, parameters={}
+                                ),
                                 outputs=[{"output": msg.content}],
                             )
                         ],
@@ -397,17 +443,9 @@ def messages_to_oci_params(
         for i, message in enumerate(messages[::-1]):
             current_turn.append(message)
             if isinstance(message, HumanMessage):
-                if len(messages) > i and isinstance(
-                    messages[len(messages) - i - 2], ToolMessage
-                ):
-                    # add dummy message REPEATING the tool_result to avoid
-                    # the error about ToolMessage needing to be followed
-                    # by an AI message
-                    oci_chat_history.append(
-                        self.oci_chat_message["CHATBOT"](
-                            message=messages[len(messages) - i - 2].content
-                        )
-                    )
+                if len(messages) > i and isinstance(messages[len(messages) - i - 2], ToolMessage):
+                    # add dummy message REPEATING the tool_result to avoid the error about ToolMessage needing to be followed by an AI message
+                    oci_chat_history.append(self.oci_chat_message['CHATBOT'](message=messages[len(messages) - i - 2].content))
                 break
         current_turn = list(reversed(current_turn))

@@ -601,10 +639,6 @@ def __init__(self) -> None:
         self.oci_tool_call = models.FunctionCall
         self.oci_tool_message = models.ToolMessage

-        # Response format models
-        self.oci_response_json_schema = models.ResponseJsonSchema
-        self.oci_json_schema_response_format = models.JsonSchemaResponseFormat
-
         self.chat_api_format = models.BaseChatRequest.API_FORMAT_GENERIC

     def chat_response_to_text(self, response: Any) -> str:
@@ -630,11 +664,6 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
             "finish_reason": response.data.chat_response.choices[0].finish_reason,
             "time_created": str(response.data.chat_response.time_created),
         }
-
-        # Include token usage if available
-        if hasattr(response.data.chat_response, "usage") and response.data.chat_response.usage:
-            generation_info["total_tokens"] = response.data.chat_response.usage.total_tokens
-
         if self.chat_tool_calls(response):
             generation_info["tool_calls"] = self.format_response_tool_calls(
                 self.chat_tool_calls(response)
@@ -668,8 +697,12 @@ def format_response_tool_calls(self, tool_calls: List[Any]) -> List[Dict]:
                 {
                     "id": tool_call.id,
                     "function": {
-                        "name": tool_call.name,
-                        "arguments": json.loads(tool_call.arguments),
+                        "name": getattr(tool_call, "name", "unknown_tool"),
+                        "arguments": (
+                            json.dumps(json.loads(tool_call.arguments))
+                            if hasattr(tool_call, "arguments")
+                            else json.dumps(getattr(tool_call, "parameters", {}))
+                        ),
                     },
                     "type": "function",
                 }
@@ -695,7 +728,11 @@ def format_stream_tool_calls(
                     "id": tool_call.get("id", ""),
                     "function": {
                         "name": tool_call.get("name", ""),
-                        "arguments": tool_call.get("arguments", ""),
+                        "arguments": (
+                            json.dumps(json.loads(tool_call["arguments"]))
+                            if "arguments" in tool_call
+                            else json.dumps(tool_call.get("parameters", {}))
+                        ),
                     },
                     "type": "function",
                 }
@@ -745,8 +782,8 @@ def messages_to_oci_params(
                     )
                 else:
                     oci_message = self.oci_chat_message[role](content=tool_content)
-            elif isinstance(message, AIMessage) and (
-                message.tool_calls or message.additional_kwargs.get("tool_calls")
+            elif isinstance(message, AIMessage) and message.additional_kwargs.get(
+                "tool_calls"
             ):
                 # Process content and tool calls for assistant messages
                 content = self._process_message_content(message.content)
@@ -904,22 +941,7 @@ def convert_to_oci_tool(
         Raises:
             ValueError: If the tool type is not supported.
         """
-        if (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool):
-            as_json_schema_function = convert_to_openai_function(tool)
-            parameters = as_json_schema_function.get("parameters", {})
-            return self.oci_function_definition(
-                name=as_json_schema_function.get("name"),
-                description=as_json_schema_function.get(
-                    "description",
-                    as_json_schema_function.get("name"),
-                ),
-                parameters={
-                    "type": "object",
-                    "properties": parameters.get("properties", {}),
-                    "required": parameters.get("required", []),
-                },
-            )
-        elif isinstance(tool, BaseTool):
+        if isinstance(tool, BaseTool):
             return self.oci_function_definition(
                 name=tool.name,
                 description=OCIUtils.remove_signature_from_tool_description(
@@ -941,6 +963,21 @@ def convert_to_oci_tool(
                 ],
             },
         )
+        elif (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool):
+            as_json_schema_function = convert_to_openai_function(tool)
+            parameters = as_json_schema_function.get("parameters", {})
+            return self.oci_function_definition(
+                name=as_json_schema_function.get("name"),
+                description=as_json_schema_function.get(
+                    "description",
+                    as_json_schema_function.get("name"),
+                ),
+                parameters={
+                    "type": "object",
+                    "properties": parameters.get("properties", {}),
+                    "required": parameters.get("required", []),
+                },
+            )
         raise ValueError(
             f"Unsupported tool type {type(tool)}. "
             "Tool must be passed in as a BaseTool "
@@ -1034,7 +1071,6 @@ def process_stream_tool_calls(

 class MetaProvider(GenericProvider):
     """Provider for Meta models. This provider is for backward compatibility."""
-
     pass


11511187
"Please make sure you have the oci package installed."
11521188
) from ex
11531189

1154-
oci_params = self._provider.messages_to_oci_params(
1155-
messages,
1156-
max_sequential_tool_calls=self.max_sequential_tool_calls,
1157-
**kwargs
1158-
)
1190+
oci_params = self._provider.messages_to_oci_params(messages, **kwargs)
11591191

11601192
oci_params["is_stream"] = stream
11611193
_model_kwargs = self.model_kwargs or {}
11621194

11631195
if stop is not None:
11641196
_model_kwargs[self._provider.stop_sequence_key] = stop
11651197

1166-
# Warn if using max_tokens with OpenAI models
1167-
if self.model_id and self.model_id.startswith("openai.") and "max_tokens" in _model_kwargs:
1168-
import warnings
1169-
warnings.warn(
1170-
f"OpenAI models require 'max_completion_tokens' instead of 'max_tokens'.",
1171-
UserWarning,
1172-
stacklevel=2
1173-
)
1174-
11751198
chat_params = {**_model_kwargs, **kwargs, **oci_params}
11761199

11771200
if not self.model_id:
@@ -1247,14 +1270,14 @@ def with_structured_output(
                 `method` is "function_calling" and `schema` is a dict, then the dict
                 must match the OCI Generative AI function-calling spec.
             method:
-                The method for steering model generation, either "function_calling" (default method)
-                or "json_mode" or "json_schema". If "function_calling" then the schema
+                The method for steering model generation, either "function_calling",
+                "json_mode", or "json_schema". If "function_calling" then the schema
                 will be converted to an OCI function and the returned model will make
                 use of the function-calling API. If "json_mode" then Cohere's JSON mode will be
                 used. Note that if using "json_mode" then you must include instructions
                 for formatting the output into the desired schema into the model call.
                 If "json_schema" then it allows the user to pass a json schema (or pydantic)
-                to the model for structured output.
+                to the model for structured output. This is the default method.
             include_raw:
                 If False then only the parsed structured output is returned. If
                 an error occurs during model output parsing it will be raised. If True

@@ -1305,24 +1328,19 @@ def with_structured_output(
                 else JsonOutputParser()
             )
         elif method == "json_schema":
-            json_schema_dict = (
-                schema.model_json_schema()  # type: ignore[union-attr]
+            response_format = (
+                dict(
+                    schema.model_json_schema().items()  # type: ignore[union-attr]
+                )
                 if is_pydantic_schema
                 else schema
             )
-
-            response_json_schema = self._provider.oci_response_json_schema(
-                name=json_schema_dict.get("title", "response"),
-                description=json_schema_dict.get("description", ""),
-                schema=json_schema_dict,
-                is_strict=True
-            )
-
-            response_format_obj = self._provider.oci_json_schema_response_format(
-                json_schema=response_json_schema
-            )
-
-            llm = self.bind(response_format=response_format_obj)
+            llm_response_format: Dict[Any, Any] = {"type": "JSON_OBJECT"}
+            llm_response_format["schema"] = {
+                k: v
+                for k, v in response_format.items()  # type: ignore[union-attr]
+            }
+            llm = self.bind(response_format=llm_response_format)
             if is_pydantic_schema:
                 output_parser = PydanticOutputParser(pydantic_object=schema)
             else:
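
After this change, method="json_schema" binds a plain response_format dict of the shape {"type": "JSON_OBJECT", "schema": {...}} instead of constructing the SDK's ResponseJsonSchema/JsonSchemaResponseFormat objects (whose provider attributes are removed above). A hedged usage sketch; model_id and compartment_id are placeholders, not values from this commit:

    from pydantic import BaseModel, Field

    from langchain_oci import ChatOCIGenAI

    class Joke(BaseModel):
        """A joke with a setup and a punchline."""
        setup: str = Field(description="The setup of the joke")
        punchline: str = Field(description="The punchline")

    llm = ChatOCIGenAI(
        model_id="cohere.command-r-plus",  # placeholder
        compartment_id="ocid1.compartment.oc1..example",  # placeholder
    )
    structured_llm = llm.with_structured_output(Joke, method="json_schema")
    # Per the hunk above, this binds
    # response_format={"type": "JSON_OBJECT", "schema": Joke.model_json_schema()}
    joke = structured_llm.invoke("Tell me a joke about parrots")
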
@@ -1400,7 +1418,7 @@ def _generate(
                 for tool_call in self._provider.chat_tool_calls(response)
             ]
             message = AIMessage(
-                content=content or "",
+                content=content,
                 additional_kwargs=generation_info,
                 tool_calls=tool_calls,
             )
)
