
Commit 398d0fe

65a authored and committed
server : Support multimodal completion and embeddings prompts in JSON format
- Use server_tokens in more places in server and util.cpp
- Convert most functions that used llama_tokens to server_tokens
- Modify input tokenizer to handle JSON objects as subprompts
- Break out MTMD prompt parsing into utility function
- Support JSON objects with multimodal_data arrays for MTMD prompts along with other existing types
- Add capability to model endpoint to indicate if client can send multimodal data
- Add tests
1 parent 9515c61 commit 398d0fe

5 files changed: +324, -138 lines changed

tools/server/README.md

Lines changed: 11 additions & 6 deletions

@@ -226,6 +226,10 @@ services:
 ### Multimodal support
 
 Multimodal support was added in [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) and is currently an experimental feature.
+It is currently available in the following endpoints:
+- The OAI-compatible chat endpoint.
+- The non-OAI-compatible completions endpoint.
+- The non-OAI-compatible embeddings endpoint.
 
 For more details, please refer to [multimodal documentation](../../docs/multimodal.md)
 
@@ -400,12 +404,15 @@ These input shapes and data type are allowed for `prompt`:
 - Single string: `"string"`
 - Single sequence of tokens: `[12, 34, 56]`
 - Mixed tokens and strings: `[12, 34, "string", 56, 78]`
+- A JSON object which optionally contains multimodal data: `{ "prompt_string": "string", "multimodal_data": ["base64"] }`
 
 Multiple prompts are also supported. In this case, the completion result will be an array.
 
 - Only strings: `["string1", "string2"]`
-- Strings and sequences of tokens: `["string1", [12, 34, 56]]`
-- Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string"]`
+- Strings, JSON objects, and sequences of tokens: `["string1", [12, 34, 56], { "prompt_string": "string", "multimodal_data": ["base64"]}]`
+- Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string", { "prompt_string": "string" }]`
+
+Note for `multimodal_data` in JSON object prompts: this should be an array of strings, containing base64-encoded multimodal data such as images and audio. There must be an identical number of MTMD media markers in the string prompt element, which act as placeholders for the data provided to this parameter; the multimodal data files will be substituted in order. The marker string (e.g. `<__media__>`) can be found by calling `mtmd_default_marker()` defined in [the MTMD C API](https://github.com/ggml-org/llama.cpp/blob/5fd160bbd9d70b94b5b11b0001fd7f477005e4a0/tools/mtmd/mtmd.h#L87). A client *must not* specify this field unless the server has the multimodal capability. Clients should check `/models` or `/v1/models` for the `multimodal` capability before sending a multimodal request.
 
 `temperature`: Adjust the randomness of the generated text. Default: `0.8`
 
@@ -477,8 +484,6 @@ These words will not be included in the completion, so make sure to add them to
 
 `t_max_predict_ms`: Set a time limit in milliseconds for the prediction (a.k.a. text-generation) phase. The timeout will trigger if the generation takes more than the specified time (measured since the first token was generated) and if a new-line character has already been generated. Useful for FIM applications. Default: `0`, which is disabled.
 
-`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
-
 `id_slot`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot. Default: `-1`
 
 `cache_prompt`: Re-use KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests. Because (depending on the backend) the logits are **not** guaranteed to be bit-for-bit identical for different batch sizes (prompt processing vs. token generation) enabling this option can cause nondeterministic results. Default: `true`
@@ -638,12 +643,12 @@ Returns a JSON object with a field `prompt` containing a string of the input mes
 
 The same as [the embedding example](../embedding) does.
 
+This endpoint also supports multimodal embeddings. See the documentation for [completions prompts](../completions) for details on how to send a multimodal prompt.
+
 *Options:*
 
 `content`: Set the text to process.
 
-`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
-
 `embd_normalize`: Normalization for pooled embeddings. Can be one of the following values:
 ```
 -1: No normalization
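
To make the new prompt shape concrete, here is a minimal sketch of a `/completion` request using the JSON object form documented above. It assumes a llama-server instance with a multimodal model (and its mmproj) listening on `http://localhost:8080`; the file name `cat.png`, the port, and the single `<__media__>` marker are illustrative only and not part of the commit.

```python
# Minimal sketch of a /completion request using the new JSON prompt object.
# Assumes a multimodal llama-server on http://localhost:8080 and a local cat.png.
import base64
import json
import urllib.request

with open("cat.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("ascii")

payload = {
    "prompt": {
        # One <__media__> marker per entry in multimodal_data, substituted in order.
        "prompt_string": "Describe this image: <__media__>",
        "multimodal_data": [image_b64],
    },
    "n_predict": 64,
}

req = urllib.request.Request(
    "http://localhost:8080/completion",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["content"])
```

Because the markers are substituted positionally, a prompt string with two `<__media__>` markers must supply exactly two entries in `multimodal_data`.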

tools/server/server.cpp

Lines changed: 20 additions & 57 deletions

@@ -4181,6 +4181,7 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_api_show = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+        bool has_mtmd = ctx_server.mctx != nullptr;
         json data = {
             {
                 "template", common_chat_templates_source(ctx_server.chat_templates.get()),
@@ -4202,7 +4203,7 @@ int main(int argc, char ** argv) {
                 {"quantization_level", ""}
             }},
             {"model_info", ""},
-            {"capabilities", {"completion"}}
+            {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}
         };
 
         res_ok(res, data);
@@ -4228,56 +4229,15 @@ int main(int argc, char ** argv) {
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
 
-        // process files
-        mtmd::bitmaps bitmaps;
-        const bool has_mtmd = ctx_server.mctx != nullptr;
-        {
-            if (!has_mtmd && !files.empty()) {
-                throw std::runtime_error("This server does not support multimodal");
-            }
-            for (auto & file : files) {
-                mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(ctx_server.mctx, file.data(), file.size()));
-                if (!bmp.ptr) {
-                    throw std::runtime_error("Failed to load image or audio file");
-                }
-                // calculate bitmap hash (for KV caching)
-                std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
-                bmp.set_id(hash.c_str());
-                bitmaps.entries.push_back(std::move(bmp));
-            }
-        }
-
         // process prompt
         std::vector<server_tokens> inputs;
 
-        if (oaicompat && has_mtmd) {
-            // multimodal
-            std::string prompt_str = prompt.get<std::string>();
-            mtmd_input_text inp_txt = {
-                prompt_str.c_str(),
-                /* add_special */ true,
-                /* parse_special */ true,
-            };
-            mtmd::input_chunks chunks(mtmd_input_chunks_init());
-            auto bitmaps_c_ptr = bitmaps.c_ptr();
-            int32_t tokenized = mtmd_tokenize(ctx_server.mctx,
-                                              chunks.ptr.get(),
-                                              &inp_txt,
-                                              bitmaps_c_ptr.data(),
-                                              bitmaps_c_ptr.size());
-            if (tokenized != 0) {
-                throw std::runtime_error("Failed to tokenize prompt");
-            }
-
-            server_tokens tmp(chunks, true);
-            inputs.push_back(std::move(tmp));
+        if (oaicompat && ctx_server.mctx != nullptr) {
+            // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
+            inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
         } else {
-            // non-multimodal version
-            auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
-            for (auto & p : tokenized_prompts) {
-                auto tmp = server_tokens(p, ctx_server.mctx != nullptr);
-                inputs.push_back(std::move(tmp));
-            }
+            // Everything else, including multimodal completions.
+            inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
 
         tasks.reserve(inputs.size());
@@ -4446,7 +4406,7 @@ int main(int argc, char ** argv) {
         data["input_extra"] = input_extra; // default to empty array if it's not exist
 
         std::string prompt = json_value(data, "prompt", std::string());
-        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, false, true);
+        std::vector<server_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true);
         SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
         data["prompt"] = format_infill(
             ctx_server.vocab,
@@ -4457,7 +4417,7 @@ int main(int argc, char ** argv) {
             ctx_server.params_base.n_predict,
             ctx_server.slots[0].n_ctx, // TODO: there should be a better way
             ctx_server.params_base.spm_infill,
-            tokenized_prompts[0]
+            tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal.
         );
 
         std::vector<raw_buffer> files; // dummy
@@ -4506,7 +4466,7 @@ int main(int argc, char ** argv) {
         if (current_state == SERVER_STATE_READY) {
            model_meta = ctx_server.model_meta();
         }
-
+        bool has_mtmd = ctx_server.mctx != nullptr;
         json models = {
             {"models", {
                 {
@@ -4518,7 +4478,7 @@ int main(int argc, char ** argv) {
                     {"type", "model"},
                     {"description", ""},
                     {"tags", {""}},
-                    {"capabilities", {"completion"}},
+                    {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})},
                     {"parameters", ""},
                     {"details", {
                         {"parent_model", ""},
@@ -4635,7 +4595,7 @@ int main(int argc, char ** argv) {
             }
         }
 
-        auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+        auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         for (const auto & tokens : tokenized_prompts) {
             // this check is necessary for models that do not add BOS token to the input
             if (tokens.empty()) {
@@ -4663,7 +4623,7 @@ int main(int argc, char ** argv) {
 
             task.id = ctx_server.queue_tasks.get_new_id();
             task.index = i;
-            task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr);
+            task.prompt_tokens = std::move(tokenized_prompts[i]);
 
             // OAI-compat
             task.params.oaicompat = oaicompat;
@@ -4750,22 +4710,25 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0];
+        std::vector<server_tokens> tokenized_queries = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, query, /* add_special */ false, true);
+        if (tokenized_queries.size() != 1) {
+            res_error(res, format_error_response("\"query\" must contain only a single prompt", ERROR_TYPE_INVALID_REQUEST));
+        }
 
         // create and queue the task
         json responses = json::array();
         bool error = false;
        std::unordered_set<int> task_ids;
        {
            std::vector<server_task> tasks;
-            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
+            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, documents, /* add_special */ false, true);
            tasks.reserve(tokenized_docs.size());
            for (size_t i = 0; i < tokenized_docs.size(); i++) {
-                auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
+                auto tmp = format_rerank(ctx_server.vocab, tokenized_queries[0], tokenized_docs[i]);
                server_task task = server_task(SERVER_TASK_TYPE_RERANK);
                task.id = ctx_server.queue_tasks.get_new_id();
                task.index = i;
-                task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
+                task.prompt_tokens = std::move(tmp);
                tasks.push_back(std::move(task));
            }
 
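
The capability flag added to `handle_api_show` and to the models listing above is what lets clients decide whether they may send `multimodal_data`. Below is a rough sketch of that client-side check, assuming a llama-server on `http://localhost:8080` and the `models`/`capabilities` layout shown in the handler; the exact response shape of `/models` versus `/v1/models` is not spelled out in this diff, so treat the parsing as an assumption.

```python
# Sketch of the capability check the README recommends: ask the server which
# capabilities it advertises before sending multimodal_data.
import json
import urllib.request

def server_supports_multimodal(base_url: str = "http://localhost:8080") -> bool:
    # Assumes the response contains a "models" array whose entries carry a
    # "capabilities" list, mirroring the JSON built in the handler above.
    with urllib.request.urlopen(f"{base_url}/v1/models") as resp:
        body = json.loads(resp.read())
    for model in body.get("models", []):
        if "multimodal" in model.get("capabilities", []):
            return True
    return False

if __name__ == "__main__":
    print("multimodal supported:", server_supports_multimodal())
```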

tools/server/tests/unit/test_completion.py

Lines changed: 38 additions & 0 deletions

@@ -6,6 +6,8 @@
 
 server = ServerPreset.tinyllama2()
 
+JSON_MULTIMODAL_KEY = "multimodal_data"
+JSON_PROMPT_STRING_KEY = "prompt_string"
 
 @pytest.fixture(scope="module", autouse=True)
 def create_server():
@@ -231,6 +233,28 @@ def test_nocache_long_input_prompt():
     })
     assert res.status_code == 200
 
+def test_json_prompt_no_mtmd():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is" },
+        "seed": 42,
+        "temperature": 1.0,
+        "cache_prompt": False,
+    })
+    assert res.status_code == 200
+
+def test_json_prompt_mtm_error_when_not_supported():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is <__media__>", JSON_MULTIMODAL_KEY: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" },
+        "seed": 42,
+        "temperature": 1.0,
+        "cache_prompt": False,
+    })
+    # MTMD is disabled on this model, so this should fail.
+    assert res.status_code != 200
 
 def test_completion_with_tokens_input():
     global server
@@ -269,6 +293,20 @@ def test_completion_with_tokens_input():
     assert len(res.body) == 2
     assert res.body[0]["content"] == res.body[1]["content"]
 
+    # mixed JSON and tokens
+    res = server.make_request("POST", "/completion", data={
+        "prompt": [
+            tokens,
+            {
+                JSON_PROMPT_STRING_KEY: "I believe the meaning of life is",
+            },
+        ],
+    })
+    assert res.status_code == 200
+    assert type(res.body) == list
+    assert len(res.body) == 2
+    assert res.body[0]["content"] == res.body[1]["content"]
+
     # mixed string and tokens in one sequence
     res = server.make_request("POST", "/completion", data={
         "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
