Skip to content

Commit c4357dc

Browse files
authored
Server: Change Invalid Schema from Server Error (500) to User Error (400) (#17572)
* Make invalid schema a user error (400)
* Move invalid_argument exception handler to ex_wrapper
* Fix test
* Simplify test back to original pattern
1 parent e148380 commit c4357dc

File tree

6 files changed

+44
-38
lines changed

6 files changed

+44
-38
lines changed

common/chat.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin
163163
if (tool_choice == "required") {
164164
return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
165165
}
166-
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
166+
throw std::invalid_argument("Invalid tool_choice: " + tool_choice);
167167
}
168168

169169
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) {
@@ -186,17 +186,17 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
186186
try {
187187

188188
if (!messages.is_array()) {
189-
throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
189+
throw std::invalid_argument("Expected 'messages' to be an array, got " + messages.dump());
190190
}
191191

192192
for (const auto & message : messages) {
193193
if (!message.is_object()) {
194-
throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
194+
throw std::invalid_argument("Expected 'message' to be an object, got " + message.dump());
195195
}
196196

197197
common_chat_msg msg;
198198
if (!message.contains("role")) {
199-
throw std::runtime_error("Missing 'role' in message: " + message.dump());
199+
throw std::invalid_argument("Missing 'role' in message: " + message.dump());
200200
}
201201
msg.role = message.at("role");
202202

@@ -209,37 +209,37 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
209209
} else if (content.is_array()) {
210210
for (const auto & part : content) {
211211
if (!part.contains("type")) {
212-
throw std::runtime_error("Missing content part type: " + part.dump());
212+
throw std::invalid_argument("Missing content part type: " + part.dump());
213213
}
214214
const auto & type = part.at("type");
215215
if (type != "text") {
216-
throw std::runtime_error("Unsupported content part type: " + type.dump());
216+
throw std::invalid_argument("Unsupported content part type: " + type.dump());
217217
}
218218
common_chat_msg_content_part msg_part;
219219
msg_part.type = type;
220220
msg_part.text = part.at("text");
221221
msg.content_parts.push_back(msg_part);
222222
}
223223
} else if (!content.is_null()) {
224-
throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
224+
throw std::invalid_argument("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
225225
}
226226
}
227227
if (has_tool_calls) {
228228
for (const auto & tool_call : message.at("tool_calls")) {
229229
common_chat_tool_call tc;
230230
if (!tool_call.contains("type")) {
231-
throw std::runtime_error("Missing tool call type: " + tool_call.dump());
231+
throw std::invalid_argument("Missing tool call type: " + tool_call.dump());
232232
}
233233
const auto & type = tool_call.at("type");
234234
if (type != "function") {
235-
throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
235+
throw std::invalid_argument("Unsupported tool call type: " + tool_call.dump());
236236
}
237237
if (!tool_call.contains("function")) {
238-
throw std::runtime_error("Missing tool call function: " + tool_call.dump());
238+
throw std::invalid_argument("Missing tool call function: " + tool_call.dump());
239239
}
240240
const auto & fc = tool_call.at("function");
241241
if (!fc.contains("name")) {
242-
throw std::runtime_error("Missing tool call name: " + tool_call.dump());
242+
throw std::invalid_argument("Missing tool call name: " + tool_call.dump());
243243
}
244244
tc.name = fc.at("name");
245245
tc.arguments = fc.at("arguments");
@@ -250,7 +250,7 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
250250
}
251251
}
252252
if (!has_content && !has_tool_calls) {
253-
throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)");
253+
throw std::invalid_argument("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)");
254254
}
255255
if (message.contains("reasoning_content")) {
256256
msg.reasoning_content = message.at("reasoning_content");
@@ -353,18 +353,18 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & too
353353
try {
354354
if (!tools.is_null()) {
355355
if (!tools.is_array()) {
356-
throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
356+
throw std::invalid_argument("Expected 'tools' to be an array, got " + tools.dump());
357357
}
358358
for (const auto & tool : tools) {
359359
if (!tool.contains("type")) {
360-
throw std::runtime_error("Missing tool type: " + tool.dump());
360+
throw std::invalid_argument("Missing tool type: " + tool.dump());
361361
}
362362
const auto & type = tool.at("type");
363363
if (!type.is_string() || type != "function") {
364-
throw std::runtime_error("Unsupported tool type: " + tool.dump());
364+
throw std::invalid_argument("Unsupported tool type: " + tool.dump());
365365
}
366366
if (!tool.contains("function")) {
367-
throw std::runtime_error("Missing tool function: " + tool.dump());
367+
throw std::invalid_argument("Missing tool function: " + tool.dump());
368368
}
369369

370370
const auto & function = tool.at("function");

common/json-schema-to-grammar.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,7 @@ class SchemaConverter {
974974

975975
void check_errors() {
976976
if (!_errors.empty()) {
977-
throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
977+
throw std::invalid_argument("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
978978
}
979979
if (!_warnings.empty()) {
980980
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());

tests/test-json-schema-to-grammar.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1375,7 +1375,7 @@ int main() {
13751375
try {
13761376
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true));
13771377
tc.verify_status(SUCCESS);
1378-
} catch (const std::runtime_error & ex) {
1378+
} catch (const std::invalid_argument & ex) {
13791379
fprintf(stderr, "Error: %s\n", ex.what());
13801380
tc.verify_status(FAILURE);
13811381
}

tools/server/server-common.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -819,26 +819,26 @@ json oaicompat_chat_params_parse(
819819
auto schema_wrapper = json_value(response_format, "json_schema", json::object());
820820
json_schema = json_value(schema_wrapper, "schema", json::object());
821821
} else if (!response_type.empty() && response_type != "text") {
822-
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
822+
throw std::invalid_argument("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
823823
}
824824
}
825825

826826
// get input files
827827
if (!body.contains("messages")) {
828-
throw std::runtime_error("'messages' is required");
828+
throw std::invalid_argument("'messages' is required");
829829
}
830830
json & messages = body.at("messages");
831831
if (!messages.is_array()) {
832-
throw std::runtime_error("Expected 'messages' to be an array");
832+
throw std::invalid_argument("Expected 'messages' to be an array");
833833
}
834834
for (auto & msg : messages) {
835835
std::string role = json_value(msg, "role", std::string());
836836
if (role != "assistant" && !msg.contains("content")) {
837-
throw std::runtime_error("All non-assistant messages must contain 'content'");
837+
throw std::invalid_argument("All non-assistant messages must contain 'content'");
838838
}
839839
if (role == "assistant") {
840840
if (!msg.contains("content") && !msg.contains("tool_calls")) {
841-
throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!");
841+
throw std::invalid_argument("Assistant message must contain either 'content' or 'tool_calls'!");
842842
}
843843
if (!msg.contains("content")) {
844844
continue; // avoid errors with no content
@@ -850,7 +850,7 @@ json oaicompat_chat_params_parse(
850850
}
851851

852852
if (!content.is_array()) {
853-
throw std::runtime_error("Expected 'content' to be a string or an array");
853+
throw std::invalid_argument("Expected 'content' to be a string or an array");
854854
}
855855

856856
for (auto & p : content) {
@@ -884,11 +884,11 @@ json oaicompat_chat_params_parse(
884884
// try to decode base64 image
885885
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
886886
if (parts.size() != 2) {
887-
throw std::runtime_error("Invalid image_url.url value");
887+
throw std::invalid_argument("Invalid image_url.url value");
888888
} else if (!string_starts_with(parts[0], "data:image/")) {
889-
throw std::runtime_error("Invalid image_url.url format: " + parts[0]);
889+
throw std::invalid_argument("Invalid image_url.url format: " + parts[0]);
890890
} else if (!string_ends_with(parts[0], "base64")) {
891-
throw std::runtime_error("image_url.url must be base64 encoded");
891+
throw std::invalid_argument("image_url.url must be base64 encoded");
892892
} else {
893893
auto base64_data = parts[1];
894894
auto decoded_data = base64_decode(base64_data);
@@ -911,7 +911,7 @@ json oaicompat_chat_params_parse(
911911
std::string format = json_value(input_audio, "format", std::string());
912912
// while we also support flac, we don't allow it here so we matches the OAI spec
913913
if (format != "wav" && format != "mp3") {
914-
throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'");
914+
throw std::invalid_argument("input_audio.format must be either 'wav' or 'mp3'");
915915
}
916916
auto decoded_data = base64_decode(data); // expected to be base64 encoded
917917
out_files.push_back(decoded_data);
@@ -922,7 +922,7 @@ json oaicompat_chat_params_parse(
922922
p.erase("input_audio");
923923

924924
} else if (type != "text") {
925-
throw std::runtime_error("unsupported content[].type");
925+
throw std::invalid_argument("unsupported content[].type");
926926
}
927927
}
928928
}
@@ -940,7 +940,7 @@ json oaicompat_chat_params_parse(
940940
inputs.enable_thinking = opt.enable_thinking;
941941
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
942942
if (body.contains("grammar")) {
943-
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
943+
throw std::invalid_argument("Cannot use custom grammar constraints with tools.");
944944
}
945945
llama_params["parse_tool_calls"] = true;
946946
}
@@ -959,7 +959,7 @@ json oaicompat_chat_params_parse(
959959
} else if (enable_thinking_kwarg == "false") {
960960
inputs.enable_thinking = false;
961961
} else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') {
962-
throw std::runtime_error("invalid type for \"enable_thinking\" (expected boolean, got string)");
962+
throw std::invalid_argument("invalid type for \"enable_thinking\" (expected boolean, got string)");
963963
}
964964

965965
// if the assistant message appears at the end of list, we do not add end-of-turn token
@@ -972,14 +972,14 @@ json oaicompat_chat_params_parse(
972972

973973
/* sanity check, max one assistant message at the end of the list */
974974
if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){
975-
throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
975+
throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list.");
976976
}
977977

978978
/* TODO: test this properly */
979979
inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
980980

981981
if ( inputs.enable_thinking ) {
982-
throw std::runtime_error("Assistant response prefill is incompatible with enable_thinking.");
982+
throw std::invalid_argument("Assistant response prefill is incompatible with enable_thinking.");
983983
}
984984

985985
inputs.add_generation_prompt = true;
@@ -1020,18 +1020,18 @@ json oaicompat_chat_params_parse(
10201020
// Handle "n" field
10211021
int n_choices = json_value(body, "n", 1);
10221022
if (n_choices != 1) {
1023-
throw std::runtime_error("Only one completion choice is allowed");
1023+
throw std::invalid_argument("Only one completion choice is allowed");
10241024
}
10251025

10261026
// Handle "logprobs" field
10271027
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
10281028
if (json_value(body, "logprobs", false)) {
10291029
if (has_tools && stream) {
1030-
throw std::runtime_error("logprobs is not supported with tools + stream");
1030+
throw std::invalid_argument("logprobs is not supported with tools + stream");
10311031
}
10321032
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
10331033
} else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
1034-
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
1034+
throw std::invalid_argument("top_logprobs requires logprobs to be set to true");
10351035
}
10361036

10371037
// Copy remaining properties to llama_params

tools/server/server.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,24 @@ static inline void signal_handler(int signal) {
3434
static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
3535
return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr {
3636
std::string message;
37+
error_type error;
3738
try {
3839
return func(req);
40+
} catch (const std::invalid_argument & e) {
41+
error = ERROR_TYPE_INVALID_REQUEST;
42+
message = e.what();
3943
} catch (const std::exception & e) {
44+
error = ERROR_TYPE_SERVER;
4045
message = e.what();
4146
} catch (...) {
47+
error = ERROR_TYPE_SERVER;
4248
message = "unknown error";
4349
}
4450

4551
auto res = std::make_unique<server_http_res>();
4652
res->status = 500;
4753
try {
48-
json error_data = format_error_response(message, ERROR_TYPE_SERVER);
54+
json error_data = format_error_response(message, error);
4955
res->status = json_value(error_data, "code", 500);
5056
res->data = safe_json_to_str({{ "error", error_data }});
5157
SRV_WRN("got exception: %s\n", res->data.c_str());

tools/server/tests/unit/test_chat_completion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
199199
choice = res.body["choices"][0]
200200
assert match_regex(re_content, choice["message"]["content"])
201201
else:
202-
assert res.status_code != 200
202+
assert res.status_code == 400
203203
assert "error" in res.body
204204

205205

0 commit comments

Comments (0)