Skip to content

Commit 7c6980a

Browse files
committed
Make invalid schema a user error (400)
1 parent ddf9f94 commit 7c6980a

File tree

4 files changed

+27
-18
lines changed

4 files changed

+27
-18
lines changed

common/json-schema-to-grammar.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,7 @@ class SchemaConverter {
974974

975975
void check_errors() {
976976
if (!_errors.empty()) {
977-
throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
977+
throw std::invalid_argument("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
978978
}
979979
if (!_warnings.empty()) {
980980
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());

tests/test-json-schema-to-grammar.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1349,7 +1349,7 @@ int main() {
13491349
try {
13501350
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true));
13511351
tc.verify_status(SUCCESS);
1352-
} catch (const std::runtime_error & ex) {
1352+
} catch (const std::invalid_argument & ex) {
13531353
fprintf(stderr, "Error: %s\n", ex.what());
13541354
tc.verify_status(FAILURE);
13551355
}

tools/server/server.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2979,10 +2979,20 @@ struct server_routes {
29792979
server_http_context::handler_t post_chat_completions = [this](const server_http_req & req) {
29802980
std::vector<raw_buffer> files;
29812981
json body = json::parse(req.body);
2982-
json body_parsed = oaicompat_chat_params_parse(
2983-
body,
2984-
ctx_server.oai_parser_opt,
2985-
files);
2982+
2983+
json body_parsed;
2984+
try {
2985+
body_parsed = oaicompat_chat_params_parse(
2986+
body,
2987+
ctx_server.oai_parser_opt,
2988+
files
2989+
);
2990+
} catch (const std::invalid_argument& e) {
2991+
auto res = std::make_unique<server_res_generator>(ctx_server);
2992+
res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
2993+
return res;
2994+
}
2995+
29862996
return handle_completions_impl(
29872997
SERVER_TASK_TYPE_COMPLETION,
29882998
body_parsed,

tools/server/tests/unit/test_chat_completion.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -171,18 +171,18 @@ def test_apply_chat_template():
171171
assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
172172

173173

174-
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
175-
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
176-
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
177-
({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""),
178-
({"type": "json_object"}, 10, "(\\{|John)+"),
179-
({"type": "sound"}, 0, None),
174+
@pytest.mark.parametrize("response_format,n_predicted,re_content,expected_status_code", [
175+
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\"", 200),
176+
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]", 200),
177+
({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\"", 200),
178+
({"type": "json_object"}, 10, "(\\{|John)+", 200),
179+
({"type": "sound"}, 0, None, 500),
180180
# invalid response format (expected to fail)
181-
({"type": "json_object", "schema": 123}, 0, None),
182-
({"type": "json_object", "schema": {"type": 123}}, 0, None),
183-
({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
181+
({"type": "json_object", "schema": 123}, 0, None, 400),
182+
({"type": "json_object", "schema": {"type": 123}}, 0, None, 400),
183+
({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None, 400),
184184
])
185-
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
185+
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None, expected_status_code: int):
186186
global server
187187
server.start()
188188
res = server.make_request("POST", "/chat/completions", data={
@@ -193,12 +193,11 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
193193
],
194194
"response_format": response_format,
195195
})
196+
assert res.status_code == expected_status_code
196197
if re_content is not None:
197-
assert res.status_code == 200
198198
choice = res.body["choices"][0]
199199
assert match_regex(re_content, choice["message"]["content"])
200200
else:
201-
assert res.status_code != 200
202201
assert "error" in res.body
203202

204203

0 commit comments

Comments
 (0)