Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions tee_gateway/controllers/chat_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest):
tools_list.append(tool)
model = model.bind_tools(tools_list)

# Bind response_format if provided (json_object or json_schema)
if chat_request.response_format:
rf = chat_request.response_format
rf_type = (
rf.get("type", "text")
if isinstance(rf, dict)
else getattr(rf, "type", "text")
)
if rf_type != "text":
Comment on lines +85 to +93
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says this only binds for json_object or json_schema, but the condition actually binds for any non-text type. Either tighten the check to the supported types or update the comment to reflect the actual behavior (non-text pass-through).

Copilot uses AI. Check for mistakes.
rf_dict = rf if isinstance(rf, dict) else {"type": rf_type}
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block attempts to support non-dict response_format values (via getattr(rf, 'type', ...)), but then constructs rf_dict as only {type: ...} and would drop any attached json_schema payload. Either enforce that response_format must be a dict (and reject/normalize earlier) or fully serialize object forms so json_schema details are preserved.

Suggested change
rf_dict = rf if isinstance(rf, dict) else {"type": rf_type}
if isinstance(rf, dict):
rf_dict = rf
else:
# Attempt to fully serialize object forms so json_schema and other
# fields are preserved instead of being reduced to {"type": ...}.
rf_dict = None
# Prefer Pydantic-style serialization if available.
model_dump = getattr(rf, "model_dump", None)
if callable(model_dump):
try:
dumped = model_dump()
if isinstance(dumped, dict):
rf_dict = dumped
except Exception:
rf_dict = None
# Fallback: use public attributes from __dict__ if present.
if rf_dict is None and hasattr(rf, "__dict__"):
try:
rf_dict = {
k: v
for k, v in vars(rf).items()
if not k.startswith("_")
}
except Exception:
rf_dict = None
# Final fallback: at least preserve the type.
if rf_dict is None:
rf_dict = {"type": rf_type}

Copilot uses AI. Check for mistakes.
model = model.bind(response_format=rf_dict)

Comment on lines +85 to +96
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The response_format binding logic is duplicated in both the non-streaming and streaming paths. Consider extracting it into a small helper (e.g., that takes model and response_format) so behavior stays consistent and future changes don't accidentally diverge between the two codepaths.

Copilot uses AI. Check for mistakes.
langchain_messages = convert_messages(chat_request.messages)
response = model.invoke(langchain_messages)

Expand Down Expand Up @@ -196,6 +208,18 @@ def _create_streaming_response(chat_request: CreateChatCompletionRequest):
tools_list.append(tool)
model = model.bind_tools(tools_list)

# Bind response_format if provided (json_object or json_schema)
if chat_request.response_format:
rf = chat_request.response_format
rf_type = (
rf.get("type", "text")
if isinstance(rf, dict)
else getattr(rf, "type", "text")
)
if rf_type != "text":
rf_dict = rf if isinstance(rf, dict) else {"type": rf_type}
model = model.bind(response_format=rf_dict)

Comment on lines +211 to +222
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New response_format support is added to the streaming pipeline here, but the added tests only exercise the non-streaming path. Please add a unit test that verifies model.bind(response_format=...) is applied before model.stream(...) (and that it composes correctly with bind_tools(...)).

Copilot uses AI. Check for mistakes.
langchain_messages = convert_messages(chat_request.messages)
tee_keys = get_tee_keys()

Expand Down Expand Up @@ -481,6 +505,8 @@ def _chat_request_to_dict(chat_request: CreateChatCompletionRequest) -> dict:
if isinstance(chat_request.tools, list)
else list(chat_request.tools)
)
if chat_request.response_format:
d["response_format"] = chat_request.response_format
Comment on lines +508 to +509
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

response_format is inserted into the canonical hash dict without normalization. If response_format is ever provided as a non-dict object (which the binding logic above appears to support), json.dumps(..., sort_keys=True) will fail here and the request will 500. Consider normalizing response_format into a JSON-serializable dict in one place (and using the same normalization for both hashing and model binding).

Copilot uses AI. Check for mistakes.
return d


Expand Down
340 changes: 340 additions & 0 deletions tests/test_structured_outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,340 @@
import json
import unittest
from unittest.mock import patch, MagicMock

from tee_gateway.controllers.chat_controller import (
_parse_chat_request as parse_chat_request,
_chat_request_to_dict as chat_request_to_dict,
)
from tee_gateway.models.create_chat_completion_request import (
CreateChatCompletionRequest,
)


class TestResponseFormatParsing(unittest.TestCase):
    """Verify that response_format survives parsing from raw request dicts."""

    def _base_request(self, **overrides):
        # Smallest request the parser accepts; overrides are merged on top.
        return {
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "Hello"}],
            **overrides,
        }

    def test_no_response_format(self):
        # Omitting the field entirely should parse to None, not a default dict.
        parsed = parse_chat_request(self._base_request())
        self.assertIsNone(parsed.response_format)

    def test_text_response_format(self):
        parsed = parse_chat_request(
            self._base_request(response_format={"type": "text"})
        )
        self.assertEqual(parsed.response_format, {"type": "text"})

    def test_json_object_response_format(self):
        parsed = parse_chat_request(
            self._base_request(response_format={"type": "json_object"})
        )
        self.assertEqual(parsed.response_format, {"type": "json_object"})

    def test_json_schema_response_format(self):
        schema_format = {
            "type": "json_schema",
            "json_schema": {
                "name": "user_info",
                "strict": True,
                "schema": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string"},
                        "age": {"type": "integer"},
                    },
                    "required": ["name", "age"],
                    "additionalProperties": False,
                },
            },
        }
        parsed = parse_chat_request(
            self._base_request(response_format=schema_format)
        )
        # Spot-check the nested schema fields rather than the whole payload.
        self.assertEqual(parsed.response_format["type"], "json_schema")
        self.assertEqual(parsed.response_format["json_schema"]["name"], "user_info")
        self.assertTrue(parsed.response_format["json_schema"]["strict"])


class TestResponseFormatInHashDict(unittest.TestCase):
    """Verify response_format is carried into the canonical TEE hash dict."""

    def _make_request(self, response_format=None):
        # Minimal concrete request object; response_format defaults to absent.
        return CreateChatCompletionRequest(
            model="gpt-4o",
            messages=[],
            temperature=1.0,
            response_format=response_format,
        )

    def test_no_response_format_omitted(self):
        # The key must not appear at all when the field is unset.
        hash_dict = chat_request_to_dict(self._make_request())
        self.assertNotIn("response_format", hash_dict)

    def test_json_object_included(self):
        hash_dict = chat_request_to_dict(
            self._make_request(response_format={"type": "json_object"})
        )
        self.assertIn("response_format", hash_dict)
        self.assertEqual(hash_dict["response_format"]["type"], "json_object")

    def test_json_schema_included(self):
        schema_format = {
            "type": "json_schema",
            "json_schema": {
                "name": "math_answer",
                "schema": {
                    "type": "object",
                    "properties": {"answer": {"type": "number"}},
                },
            },
        }
        hash_dict = chat_request_to_dict(
            self._make_request(response_format=schema_format)
        )
        # The full nested payload must pass through unmodified.
        self.assertEqual(hash_dict["response_format"], schema_format)

    def test_hash_deterministic_with_response_format(self):
        request = self._make_request(response_format={"type": "json_object"})
        first = json.dumps(chat_request_to_dict(request), sort_keys=True)
        second = json.dumps(chat_request_to_dict(request), sort_keys=True)
        self.assertEqual(first, second)

    def test_hash_differs_with_and_without_response_format(self):
        # Adding response_format must change the canonical serialization.
        plain = json.dumps(
            chat_request_to_dict(self._make_request()), sort_keys=True
        )
        with_format = json.dumps(
            chat_request_to_dict(
                self._make_request(response_format={"type": "json_object"})
            ),
            sort_keys=True,
        )
        self.assertNotEqual(plain, with_format)


class TestResponseFormatModelBinding(unittest.TestCase):
    """Tests that response_format is bound to the model before invocation.

    All four tests patch the same pipeline collaborators (model factory,
    message conversion, TEE keys, hash computation). The originally
    duplicated mock wiring is centralized in the two static helpers below
    so the tests cannot drift out of sync with each other.
    """

    @staticmethod
    def _wire_pipeline_mocks(mock_convert, mock_tee_keys, mock_hash):
        """Configure the convert/TEE-key/hash mocks identically for every test."""
        mock_convert.return_value = []
        mock_hash.return_value = (b"hash", "input_hex", "output_hex")
        mock_keys = MagicMock()
        mock_keys.sign_data.return_value = "sig"
        mock_keys.get_tee_id.return_value = "abc"
        mock_tee_keys.return_value = mock_keys

    @staticmethod
    def _fake_llm_response(content="test"):
        """Build a mock model response carrying `content` and no tool calls."""
        response = MagicMock()
        response.content = content
        response.tool_calls = None
        return response

    @patch("tee_gateway.controllers.chat_controller.compute_tee_msg_hash")
    @patch("tee_gateway.controllers.chat_controller.get_tee_keys")
    @patch("tee_gateway.controllers.chat_controller.convert_messages")
    @patch("tee_gateway.controllers.chat_controller.get_chat_model_cached")
    def test_json_object_binds_to_model(
        self, mock_get_model, mock_convert, mock_tee_keys, mock_hash
    ):
        from tee_gateway.controllers.chat_controller import (
            _create_non_streaming_response,
        )

        mock_model = MagicMock()
        mock_bound = MagicMock()
        mock_model.bind.return_value = mock_bound
        mock_get_model.return_value = mock_model
        mock_bound.invoke.return_value = self._fake_llm_response()
        self._wire_pipeline_mocks(mock_convert, mock_tee_keys, mock_hash)

        req = CreateChatCompletionRequest(
            model="gpt-4o",
            messages=[],
            temperature=1.0,
            response_format={"type": "json_object"},
        )

        _create_non_streaming_response(req)

        # Non-text formats must be bound, and the bound model must be invoked.
        mock_model.bind.assert_called_once_with(response_format={"type": "json_object"})
        mock_bound.invoke.assert_called_once()

    @patch("tee_gateway.controllers.chat_controller.compute_tee_msg_hash")
    @patch("tee_gateway.controllers.chat_controller.get_tee_keys")
    @patch("tee_gateway.controllers.chat_controller.convert_messages")
    @patch("tee_gateway.controllers.chat_controller.get_chat_model_cached")
    def test_text_format_does_not_bind(
        self, mock_get_model, mock_convert, mock_tee_keys, mock_hash
    ):
        from tee_gateway.controllers.chat_controller import (
            _create_non_streaming_response,
        )

        mock_model = MagicMock()
        mock_get_model.return_value = mock_model
        mock_model.invoke.return_value = self._fake_llm_response()
        self._wire_pipeline_mocks(mock_convert, mock_tee_keys, mock_hash)

        req = CreateChatCompletionRequest(
            model="gpt-4o",
            messages=[],
            temperature=1.0,
            response_format={"type": "text"},
        )

        _create_non_streaming_response(req)

        # "text" is the default; binding would be a needless model copy.
        mock_model.bind.assert_not_called()
        mock_model.invoke.assert_called_once()

    @patch("tee_gateway.controllers.chat_controller.compute_tee_msg_hash")
    @patch("tee_gateway.controllers.chat_controller.get_tee_keys")
    @patch("tee_gateway.controllers.chat_controller.convert_messages")
    @patch("tee_gateway.controllers.chat_controller.get_chat_model_cached")
    def test_no_format_does_not_bind(
        self, mock_get_model, mock_convert, mock_tee_keys, mock_hash
    ):
        from tee_gateway.controllers.chat_controller import (
            _create_non_streaming_response,
        )

        mock_model = MagicMock()
        mock_get_model.return_value = mock_model
        mock_model.invoke.return_value = self._fake_llm_response("result")
        self._wire_pipeline_mocks(mock_convert, mock_tee_keys, mock_hash)

        req = CreateChatCompletionRequest(
            model="gpt-4o",
            messages=[],
            temperature=1.0,
        )

        _create_non_streaming_response(req)

        mock_model.bind.assert_not_called()

    @patch("tee_gateway.controllers.chat_controller.compute_tee_msg_hash")
    @patch("tee_gateway.controllers.chat_controller.get_tee_keys")
    @patch("tee_gateway.controllers.chat_controller.convert_messages")
    @patch("tee_gateway.controllers.chat_controller.get_chat_model_cached")
    def test_json_schema_binds_full_schema(
        self, mock_get_model, mock_convert, mock_tee_keys, mock_hash
    ):
        from tee_gateway.controllers.chat_controller import (
            _create_non_streaming_response,
        )

        mock_model = MagicMock()
        mock_bound = MagicMock()
        mock_model.bind.return_value = mock_bound
        mock_get_model.return_value = mock_model
        mock_bound.invoke.return_value = self._fake_llm_response(
            '{"name": "Alice", "age": 30}'
        )
        self._wire_pipeline_mocks(mock_convert, mock_tee_keys, mock_hash)

        rf = {
            "type": "json_schema",
            "json_schema": {
                "name": "user_info",
                "strict": True,
                "schema": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string"},
                        "age": {"type": "integer"},
                    },
                    "required": ["name", "age"],
                    "additionalProperties": False,
                },
            },
        }

        req = CreateChatCompletionRequest(
            model="gpt-4o",
            messages=[],
            temperature=1.0,
            response_format=rf,
        )

        _create_non_streaming_response(req)

        # The complete json_schema payload — not just {"type": ...} — must be bound.
        mock_model.bind.assert_called_once_with(response_format=rf)


class TestResponseFormatWithTools(unittest.TestCase):
    """Verify response_format composes with tool binding on the same model."""

    @patch("tee_gateway.controllers.chat_controller.compute_tee_msg_hash")
    @patch("tee_gateway.controllers.chat_controller.get_tee_keys")
    @patch("tee_gateway.controllers.chat_controller.convert_messages")
    @patch("tee_gateway.controllers.chat_controller.get_chat_model_cached")
    def test_tools_and_response_format_both_bind(
        self, mock_get_model, mock_convert, mock_tee_keys, mock_hash
    ):
        from tee_gateway.controllers.chat_controller import (
            _create_non_streaming_response,
        )

        # Expected chain: base -> bind_tools(...) -> bind(response_format=...).
        base_model = MagicMock()
        tools_model = MagicMock()
        final_model = MagicMock()
        base_model.bind_tools.return_value = tools_model
        tools_model.bind.return_value = final_model
        mock_get_model.return_value = base_model

        llm_reply = MagicMock()
        llm_reply.content = '{"result": 42}'
        llm_reply.tool_calls = None
        final_model.invoke.return_value = llm_reply

        mock_convert.return_value = []
        mock_hash.return_value = (b"hash", "input_hex", "output_hex")
        tee_keys = MagicMock()
        tee_keys.sign_data.return_value = "sig"
        tee_keys.get_tee_id.return_value = "abc"
        mock_tee_keys.return_value = tee_keys

        request = CreateChatCompletionRequest(
            model="gpt-4o",
            messages=[],
            temperature=1.0,
            tools=[
                {"type": "function", "function": {"name": "calc", "parameters": {}}}
            ],
            response_format={"type": "json_object"},
        )

        _create_non_streaming_response(request)

        # Both binding steps must happen, and only the fully-bound model runs.
        base_model.bind_tools.assert_called_once()
        tools_model.bind.assert_called_once_with(
            response_format={"type": "json_object"}
        )
        final_model.invoke.assert_called_once()


# Allow running this test module directly (e.g. `python tests/test_structured_outputs.py`).
if __name__ == "__main__":
    unittest.main()
Loading