
Commit 8b18374

[BUG] Fix json_schema method in the with_structured_output function. (#54)
* Fix with_structured_output json_schema method: fix the json_schema method in the with_structured_output function.
* Remove import and define an alias for the class in the init.
* Set default structured output method to function_calling: changed the default method for structured output in ChatOCIGenAI from 'json_schema' to 'function_calling'; updated documentation to clarify the default and suggest alternatives if it fails.
* Fix assertions in unit tests to match the new json_schema method in with_structured_output.
* Clarify with_structured_output methods in README.
1 parent 8151628 commit 8b18374

3 files changed, +47 -17 lines changed

libs/oci/README.md

Lines changed: 18 additions & 0 deletions
@@ -61,6 +61,24 @@ embeddings = OCIGenAIEmbeddings()
 embeddings.embed_query("What is the meaning of life?")
 ```
 
+### 4. Use Structured Output
+`ChatOCIGenAI` supports structured output.
+
+<sub>**Note:** The default method is `function_calling`. If the default method returns `None` (e.g. for Gemini models), try `json_schema` or `json_mode`.</sub>
+
+```python
+from langchain_oci import ChatOCIGenAI
+from pydantic import BaseModel
+
+class Joke(BaseModel):
+    setup: str
+    punchline: str
+
+llm = ChatOCIGenAI()
+structured_llm = llm.with_structured_output(Joke)
+structured_llm.invoke("Tell me a joke about programming")
+```
+
 
 ## OCI Data Science Model Deployment Examples
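
For readers following along, a minimal sketch of opting into one of the alternative methods mentioned in the note above; it reuses the `Joke` model from the README example and only adds the `method` argument documented in this commit:

```python
from langchain_oci import ChatOCIGenAI
from pydantic import BaseModel


class Joke(BaseModel):
    setup: str
    punchline: str


# Explicitly request the json_schema method instead of the default
# function_calling, e.g. when the default returns None for a model.
llm = ChatOCIGenAI()
structured_llm = llm.with_structured_output(Joke, method="json_schema")
structured_llm.invoke("Tell me a joke about programming")
```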

libs/oci/langchain_oci/chat_models/oci_generative_ai.py

Lines changed: 25 additions & 13 deletions
@@ -213,6 +213,9 @@ def __init__(self) -> None:
             "SYSTEM": models.CohereSystemMessage,
             "TOOL": models.CohereToolMessage,
         }
+
+        self.oci_response_json_schema = models.ResponseJsonSchema
+        self.oci_json_schema_response_format = models.JsonSchemaResponseFormat
         self.chat_api_format = models.BaseChatRequest.API_FORMAT_COHERE
 
     def chat_response_to_text(self, response: Any) -> str:
@@ -588,6 +591,10 @@ def __init__(self) -> None:
         self.oci_tool_call = models.FunctionCall
         self.oci_tool_message = models.ToolMessage
 
+        # Response format models
+        self.oci_response_json_schema = models.ResponseJsonSchema
+        self.oci_json_schema_response_format = models.JsonSchemaResponseFormat
+
         self.chat_api_format = models.BaseChatRequest.API_FORMAT_GENERIC
 
     def chat_response_to_text(self, response: Any) -> str:
@@ -1230,14 +1237,14 @@ def with_structured_output(
                 `method` is "function_calling" and `schema` is a dict, then the dict
                 must match the OCI Generative AI function-calling spec.
             method:
-                The method for steering model generation, either "function_calling"
-                or "json_mode" or "json_schema. If "function_calling" then the schema
+                The method for steering model generation, either "function_calling" (default method)
+                or "json_mode" or "json_schema". If "function_calling" then the schema
                 will be converted to an OCI function and the returned model will make
                 use of the function-calling API. If "json_mode" then Cohere's JSON mode will be
                 used. Note that if using "json_mode" then you must include instructions
                 for formatting the output into the desired schema into the model call.
                 If "json_schema" then it allows the user to pass a json schema (or pydantic)
-                to the model for structured output. This is the default method.
+                to the model for structured output.
             include_raw:
                 If False then only the parsed structured output is returned. If
                 an error occurs during model output parsing it will be raised. If True
@@ -1288,19 +1295,24 @@ def with_structured_output(
                 else JsonOutputParser()
             )
         elif method == "json_schema":
-            response_format = (
-                dict(
-                    schema.model_json_schema().items()  # type: ignore[union-attr]
-                )
+            json_schema_dict = (
+                schema.model_json_schema()  # type: ignore[union-attr]
                 if is_pydantic_schema
                 else schema
             )
-            llm_response_format: Dict[Any, Any] = {"type": "JSON_OBJECT"}
-            llm_response_format["schema"] = {
-                k: v
-                for k, v in response_format.items()  # type: ignore[union-attr]
-            }
-            llm = self.bind(response_format=llm_response_format)
+
+            response_json_schema = self._provider.oci_response_json_schema(
+                name=json_schema_dict.get("title", "response"),
+                description=json_schema_dict.get("description", ""),
+                schema=json_schema_dict,
+                is_strict=True
+            )
+
+            response_format_obj = self._provider.oci_json_schema_response_format(
+                json_schema=response_json_schema
+            )
+
+            llm = self.bind(response_format=response_format_obj)
             if is_pydantic_schema:
                 output_parser = PydanticOutputParser(pydantic_object=schema)
             else:
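
To make the new `json_schema` branch easier to follow, here is a standalone sketch of the objects it builds, using the same constructor arguments as the diff. The `oci.generative_ai_inference.models` import path is an assumption, since the diff only shows the `models` alias exposed through the provider attributes:

```python
# Standalone sketch of the json_schema branch above (import path assumed).
from oci.generative_ai_inference import models
from pydantic import BaseModel


class Joke(BaseModel):
    """A short joke."""

    setup: str
    punchline: str


# 1. Turn the Pydantic model into a plain JSON schema dict.
json_schema_dict = Joke.model_json_schema()

# 2. Wrap it in the SDK's ResponseJsonSchema, as the provider now does.
response_json_schema = models.ResponseJsonSchema(
    name=json_schema_dict.get("title", "response"),
    description=json_schema_dict.get("description", ""),
    schema=json_schema_dict,
    is_strict=True,
)

# 3. The JsonSchemaResponseFormat object is what gets bound as response_format.
response_format_obj = models.JsonSchemaResponseFormat(
    json_schema=response_json_schema
)
```

Binding this SDK object, rather than the old raw `{"type": "JSON_OBJECT", "schema": ...}` dict, is the substance of the fix.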

libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py

Lines changed: 4 additions & 4 deletions
@@ -14,7 +14,7 @@
 
 class MockResponseDict(dict):
     def __getattr__(self, val):  # type: ignore[no-untyped-def]
-        return self[val]
+        return self.get(val)
 
 
 class MockToolCall(dict):
@@ -473,10 +473,10 @@ class WeatherResponse(BaseModel):
     llm = ChatOCIGenAI(model_id="cohere.command-latest", client=oci_gen_ai_client)
 
     def mocked_response(*args, **kwargs):  # type: ignore[no-untyped-def]
-        # Verify that response_format contains the schema
+        # Verify that response_format is a JsonSchemaResponseFormat object
         request = args[0]
-        assert request.chat_request.response_format["type"] == "JSON_OBJECT"
-        assert "schema" in request.chat_request.response_format
+        assert hasattr(request.chat_request, 'response_format')
+        assert request.chat_request.response_format is not None
 
         return MockResponseDict(
             {