-
Notifications
You must be signed in to change notification settings - Fork 1
feat: support structured outputs (response_format) in chat completions #43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -82,6 +82,18 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tools_list.append(tool) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model = model.bind_tools(tools_list) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Bind response_format if provided (json_object or json_schema) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if chat_request.response_format: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rf = chat_request.response_format | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rf_type = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rf.get("type", "text") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(rf, dict) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else getattr(rf, "type", "text") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if rf_type != "text": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rf_dict = rf if isinstance(rf, dict) else {"type": rf_type} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rf_dict = rf if isinstance(rf, dict) else {"type": rf_type} | |
| if isinstance(rf, dict): | |
| rf_dict = rf | |
| else: | |
| # Attempt to fully serialize object forms so json_schema and other | |
| # fields are preserved instead of being reduced to {"type": ...}. | |
| rf_dict = None | |
| # Prefer Pydantic-style serialization if available. | |
| model_dump = getattr(rf, "model_dump", None) | |
| if callable(model_dump): | |
| try: | |
| dumped = model_dump() | |
| if isinstance(dumped, dict): | |
| rf_dict = dumped | |
| except Exception: | |
| rf_dict = None | |
| # Fallback: use public attributes from __dict__ if present. | |
| if rf_dict is None and hasattr(rf, "__dict__"): | |
| try: | |
| rf_dict = { | |
| k: v | |
| for k, v in vars(rf).items() | |
| if not k.startswith("_") | |
| } | |
| except Exception: | |
| rf_dict = None | |
| # Final fallback: at least preserve the type. | |
| if rf_dict is None: | |
| rf_dict = {"type": rf_type} |
Copilot
AI
Apr 1, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The response_format binding logic is duplicated in both the non-streaming and streaming paths. Consider extracting it into a small helper (e.g., that takes model and response_format) so behavior stays consistent and future changes don't accidentally diverge between the two codepaths.
Copilot
AI
Apr 1, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New response_format support is added to the streaming pipeline here, but the added tests only exercise the non-streaming path. Please add a unit test that verifies model.bind(response_format=...) is applied before model.stream(...) (and that it composes correctly with bind_tools(...)).
Copilot
AI
Apr 1, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
response_format is inserted into the canonical hash dict without normalization. If response_format is ever provided as a non-dict object (which the binding logic above appears to support), json.dumps(..., sort_keys=True) will fail here and the request will 500. Consider normalizing response_format into a JSON-serializable dict in one place (and using the same normalization for both hashing and model binding).
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,340 @@ | ||
| import json | ||
| import unittest | ||
| from unittest.mock import patch, MagicMock | ||
|
|
||
| from tee_gateway.controllers.chat_controller import ( | ||
| _parse_chat_request as parse_chat_request, | ||
| _chat_request_to_dict as chat_request_to_dict, | ||
| ) | ||
| from tee_gateway.models.create_chat_completion_request import ( | ||
| CreateChatCompletionRequest, | ||
| ) | ||
|
|
||
|
|
||
class TestResponseFormatParsing(unittest.TestCase):
    """Tests for response_format parsing from request dicts."""

    def _base_request(self, **overrides):
        # Minimal valid chat request; caller-supplied fields are merged on top.
        return {
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "Hello"}],
            **overrides,
        }

    def test_no_response_format(self):
        parsed = parse_chat_request(self._base_request())
        self.assertIsNone(parsed.response_format)

    def test_text_response_format(self):
        payload = self._base_request(response_format={"type": "text"})
        parsed = parse_chat_request(payload)
        self.assertEqual(parsed.response_format, {"type": "text"})

    def test_json_object_response_format(self):
        payload = self._base_request(response_format={"type": "json_object"})
        parsed = parse_chat_request(payload)
        self.assertEqual(parsed.response_format, {"type": "json_object"})

    def test_json_schema_response_format(self):
        schema_format = {
            "type": "json_schema",
            "json_schema": {
                "name": "user_info",
                "strict": True,
                "schema": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string"},
                        "age": {"type": "integer"},
                    },
                    "required": ["name", "age"],
                    "additionalProperties": False,
                },
            },
        }
        parsed = parse_chat_request(
            self._base_request(response_format=schema_format)
        )
        self.assertEqual(parsed.response_format["type"], "json_schema")
        self.assertEqual(parsed.response_format["json_schema"]["name"], "user_info")
        self.assertTrue(parsed.response_format["json_schema"]["strict"])
|
|
||
|
|
||
class TestResponseFormatInHashDict(unittest.TestCase):
    """Tests that response_format is included in the TEE hash dict."""

    def _make_request(self, response_format=None):
        # Minimal request object; response_format defaults to absent (None).
        return CreateChatCompletionRequest(
            model="gpt-4o",
            messages=[],
            temperature=1.0,
            response_format=response_format,
        )

    def test_no_response_format_omitted(self):
        hash_dict = chat_request_to_dict(self._make_request())
        self.assertNotIn("response_format", hash_dict)

    def test_json_object_included(self):
        request = self._make_request(response_format={"type": "json_object"})
        hash_dict = chat_request_to_dict(request)
        self.assertIn("response_format", hash_dict)
        self.assertEqual(hash_dict["response_format"]["type"], "json_object")

    def test_json_schema_included(self):
        schema_format = {
            "type": "json_schema",
            "json_schema": {
                "name": "math_answer",
                "schema": {
                    "type": "object",
                    "properties": {"answer": {"type": "number"}},
                },
            },
        }
        request = self._make_request(response_format=schema_format)
        hash_dict = chat_request_to_dict(request)
        self.assertEqual(hash_dict["response_format"], schema_format)

    def test_hash_deterministic_with_response_format(self):
        request = self._make_request(response_format={"type": "json_object"})
        first = json.dumps(chat_request_to_dict(request), sort_keys=True)
        second = json.dumps(chat_request_to_dict(request), sort_keys=True)
        self.assertEqual(first, second)

    def test_hash_differs_with_and_without_response_format(self):
        plain = self._make_request()
        with_format = self._make_request(response_format={"type": "json_object"})
        self.assertNotEqual(
            json.dumps(chat_request_to_dict(plain), sort_keys=True),
            json.dumps(chat_request_to_dict(with_format), sort_keys=True),
        )
|
|
||
|
|
||
class TestResponseFormatModelBinding(unittest.TestCase):
    """Tests that response_format is bound to the model before invocation.

    The TEE-pipeline mocks (message conversion, hashing, key signing) and the
    fake model response are identical for every test, so they are wired up by
    the private helpers below instead of being copy-pasted into each test.
    """

    @staticmethod
    def _wire_tee_mocks(mock_convert, mock_tee_keys, mock_hash):
        """Configure the conversion/hash/signing mocks every test needs
        but never asserts on."""
        mock_convert.return_value = []
        mock_hash.return_value = (b"hash", "input_hex", "output_hex")
        mock_keys = MagicMock()
        mock_keys.sign_data.return_value = "sig"
        mock_keys.get_tee_id.return_value = "abc"
        mock_tee_keys.return_value = mock_keys

    @staticmethod
    def _make_response(content):
        """Return a fake model response with the given text and no tool calls."""
        response = MagicMock()
        response.content = content
        response.tool_calls = None
        return response

    @staticmethod
    def _make_request(**overrides):
        """Return a minimal chat request; extra fields via keyword overrides."""
        return CreateChatCompletionRequest(
            model="gpt-4o",
            messages=[],
            temperature=1.0,
            **overrides,
        )

    @patch("tee_gateway.controllers.chat_controller.compute_tee_msg_hash")
    @patch("tee_gateway.controllers.chat_controller.get_tee_keys")
    @patch("tee_gateway.controllers.chat_controller.convert_messages")
    @patch("tee_gateway.controllers.chat_controller.get_chat_model_cached")
    def test_json_object_binds_to_model(
        self, mock_get_model, mock_convert, mock_tee_keys, mock_hash
    ):
        from tee_gateway.controllers.chat_controller import (
            _create_non_streaming_response,
        )

        mock_model = MagicMock()
        mock_bound = MagicMock()
        mock_model.bind.return_value = mock_bound
        mock_get_model.return_value = mock_model
        mock_bound.invoke.return_value = self._make_response("test")
        self._wire_tee_mocks(mock_convert, mock_tee_keys, mock_hash)

        req = self._make_request(response_format={"type": "json_object"})

        _create_non_streaming_response(req)

        # The json_object format must be bound and the bound model invoked.
        mock_model.bind.assert_called_once_with(
            response_format={"type": "json_object"}
        )
        mock_bound.invoke.assert_called_once()

    @patch("tee_gateway.controllers.chat_controller.compute_tee_msg_hash")
    @patch("tee_gateway.controllers.chat_controller.get_tee_keys")
    @patch("tee_gateway.controllers.chat_controller.convert_messages")
    @patch("tee_gateway.controllers.chat_controller.get_chat_model_cached")
    def test_text_format_does_not_bind(
        self, mock_get_model, mock_convert, mock_tee_keys, mock_hash
    ):
        from tee_gateway.controllers.chat_controller import (
            _create_non_streaming_response,
        )

        mock_model = MagicMock()
        mock_get_model.return_value = mock_model
        mock_model.invoke.return_value = self._make_response("test")
        self._wire_tee_mocks(mock_convert, mock_tee_keys, mock_hash)

        req = self._make_request(response_format={"type": "text"})

        _create_non_streaming_response(req)

        # "text" is the default format: no binding, plain model invoked.
        mock_model.bind.assert_not_called()
        mock_model.invoke.assert_called_once()

    @patch("tee_gateway.controllers.chat_controller.compute_tee_msg_hash")
    @patch("tee_gateway.controllers.chat_controller.get_tee_keys")
    @patch("tee_gateway.controllers.chat_controller.convert_messages")
    @patch("tee_gateway.controllers.chat_controller.get_chat_model_cached")
    def test_no_format_does_not_bind(
        self, mock_get_model, mock_convert, mock_tee_keys, mock_hash
    ):
        from tee_gateway.controllers.chat_controller import (
            _create_non_streaming_response,
        )

        mock_model = MagicMock()
        mock_get_model.return_value = mock_model
        mock_model.invoke.return_value = self._make_response("result")
        self._wire_tee_mocks(mock_convert, mock_tee_keys, mock_hash)

        req = self._make_request()

        _create_non_streaming_response(req)

        # Absent response_format must leave the model unbound.
        mock_model.bind.assert_not_called()

    @patch("tee_gateway.controllers.chat_controller.compute_tee_msg_hash")
    @patch("tee_gateway.controllers.chat_controller.get_tee_keys")
    @patch("tee_gateway.controllers.chat_controller.convert_messages")
    @patch("tee_gateway.controllers.chat_controller.get_chat_model_cached")
    def test_json_schema_binds_full_schema(
        self, mock_get_model, mock_convert, mock_tee_keys, mock_hash
    ):
        from tee_gateway.controllers.chat_controller import (
            _create_non_streaming_response,
        )

        mock_model = MagicMock()
        mock_bound = MagicMock()
        mock_model.bind.return_value = mock_bound
        mock_get_model.return_value = mock_model
        mock_bound.invoke.return_value = self._make_response(
            '{"name": "Alice", "age": 30}'
        )
        self._wire_tee_mocks(mock_convert, mock_tee_keys, mock_hash)

        rf = {
            "type": "json_schema",
            "json_schema": {
                "name": "user_info",
                "strict": True,
                "schema": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string"},
                        "age": {"type": "integer"},
                    },
                    "required": ["name", "age"],
                    "additionalProperties": False,
                },
            },
        }

        req = self._make_request(response_format=rf)

        _create_non_streaming_response(req)

        # The full schema dict (not a reduced form) must reach bind().
        mock_model.bind.assert_called_once_with(response_format=rf)
|
|
||
|
|
||
class TestResponseFormatWithTools(unittest.TestCase):
    """Tests that response_format works alongside tool binding."""

    @patch("tee_gateway.controllers.chat_controller.compute_tee_msg_hash")
    @patch("tee_gateway.controllers.chat_controller.get_tee_keys")
    @patch("tee_gateway.controllers.chat_controller.convert_messages")
    @patch("tee_gateway.controllers.chat_controller.get_chat_model_cached")
    def test_tools_and_response_format_both_bind(
        self, mock_get_model, mock_convert, mock_tee_keys, mock_hash
    ):
        from tee_gateway.controllers.chat_controller import (
            _create_non_streaming_response,
        )

        # Binding chain: base model -> after bind_tools -> after bind(response_format).
        base_model = MagicMock()
        tools_bound = MagicMock()
        fully_bound = MagicMock()
        base_model.bind_tools.return_value = tools_bound
        tools_bound.bind.return_value = fully_bound
        mock_get_model.return_value = base_model

        final_response = MagicMock()
        final_response.content = '{"result": 42}'
        final_response.tool_calls = None
        fully_bound.invoke.return_value = final_response

        mock_convert.return_value = []
        mock_hash.return_value = (b"hash", "input_hex", "output_hex")
        tee_keys = MagicMock()
        tee_keys.sign_data.return_value = "sig"
        tee_keys.get_tee_id.return_value = "abc"
        mock_tee_keys.return_value = tee_keys

        request = CreateChatCompletionRequest(
            model="gpt-4o",
            messages=[],
            temperature=1.0,
            tools=[
                {"type": "function", "function": {"name": "calc", "parameters": {}}}
            ],
            response_format={"type": "json_object"},
        )

        _create_non_streaming_response(request)

        # Tools bind first, then response_format, and only the final model is invoked.
        base_model.bind_tools.assert_called_once()
        tools_bound.bind.assert_called_once_with(
            response_format={"type": "json_object"}
        )
        fully_bound.invoke.assert_called_once()
|
|
||
|
|
||
if __name__ == "__main__":
    # Allow running this test module directly with `python <file>`.
    unittest.main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment says this only binds for `json_object` or `json_schema`, but the condition actually binds for any non-`text` type. Either tighten the check to the supported types or update the comment to reflect the actual behavior (non-`text` pass-through).