mtmd: server: Support multimodal data prompt in /completions and /embeddings endpoint of server #15108
Conversation
The proposal looks ok but there will be some edge cases:
I think proper test cases are required for this PR, similar to the existing server tests.
Let me look at those. The token case is interesting; interested in your thoughts there. The current server uses null tokens, but already knows how many to insert, which seems hard on the client (it would have to know the multimodal embedding size before sending raw tokens to the completion endpoint). A magic token could work similarly. The multiple text prompt part is also interesting from a usability perspective. I'll think about these and come back. The multi-prompt case should be straightforward to add tests for; not sure how it ought to work yet.
I have an idea that might be usable, namely that prompt can now contain an array of JSON objects, mixing text prompts and multimodal data.
Rough draft for the idea here (it compiles and passes existing tests): https://pastebin.com/8zek7ium Not complete or properly indented, but the idea is to use server_tokens in more places, so that the input tokenizer can branch and use MTMD tokenization where it makes sense to do so. As a side effect, this probably gains multimodal support in embeddings. Infill needs more work, and rerank would work if I can get server_tokens::push_back(server_tokens) to work properly, I think. There are probably better ways to do some of this than I did; feedback welcome.
Improved version of the rough draft that actually works (ignore indentation): https://pastebin.com/R6NdKQPP This works locally for my use case, and I've started adding tests. There are a few TODOs to make doc ranking and embeddings support multimodal use cases, and I think the OAI case can also be streamlined. The general approach is as described previously: use server_tokens in more places, break out MTMD prompt parsing into a function, and change various input tokenization calls to handle server_tokens instead of llama_tokens. The request format for multiple prompts would be like this:
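Roughly like this (a sketch; the field names match what the draft uses, and the contents here are hypothetical):

```python
# Sketch of the proposed multi-prompt request body (hypothetical contents).
# Each entry of "prompt" is tokenized as its own subprompt; the JSON-object
# form pairs a text prompt with base64-encoded media for MTMD tokenization.
request_body = {
    "prompt": [
        "a plain text subprompt",
        [15339, 1917],  # a raw token list (illustrative ids)
        {
            "prompt_string": "describe this image",
            "multimodal_data": ["<base64-encoded image bytes>"],
        },
    ],
}
```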
The JSON entry only supports what the existing prompt handling supports, plus the multimodal field.
Force-pushed from 744d758 to 62f3bae
Added tests, including a vision test. Should be good for a review pass. There is some potential future work, including supporting multimodal prompts in document rerank and infill. Embeddings may already work (existing tests pass), but I didn't try it, and I'm not sure whether it's expected to provide a stable embedding. Further refactoring is possible to streamline the OAI chat path into the rest, but probably as a follow-up. @ngxson let me know what you think.
Cleaned up the code quite a bit, and fixed the TODO around server_tokens.push_back(server_tokens). Now the tokenize_inputs handling reads a lot cleaner, which is nice.
Force-pushed from 5359dda to 234531f
I have tested this PR and it worked perfectly ✅ Here is a simple test with an image prompt (screenshot and prompt omitted). The details of my UI integration are here: oobabooga/text-generation-webui#7027
Looking good, can be merged after my comments are all resolved.
Can we merge this file into test_vision_api.py? We don't have many tests atm, so we should reduce the number of files.
Done. Added some tests for other functionality discussed below, and updated test_completions.py (that server doesn't actually have MTMD, so I dropped the multimodal data).
tools/server/utils.hpp
Outdated
if (json_prompt.is_array() && !json_is_array_with_tokens(json_prompt)) {
    result.reserve(json_prompt.size());
    for (const auto & p : json_prompt) {
        result.push_back(tokenize_input_subprompt(vocab,mctx, p,add_special, parse_special));
Suggested change:
-        result.push_back(tokenize_input_subprompt(vocab,mctx, p,add_special, parse_special));
+        result.push_back(tokenize_input_subprompt(vocab, mctx, p, add_special, parse_special));
Done.
tools/server/utils.hpp
Outdated
    // array of tokens
    llama_tokens tmp = json_prompt.get<llama_tokens>();
    return server_tokens(tmp, false);
} else if (json_prompt.find("prompt") != json_prompt.end()) {
Suggested change:
-    } else if (json_prompt.find("prompt") != json_prompt.end()) {
+    } else if (json_prompt.contains("prompt")) {
Done.
tools/server/utils.hpp
Outdated
    return server_tokens(tmp, false);
} else if (json_prompt.find("prompt") != json_prompt.end()) {
    // JSON object with prompt key.
    if (has_mtmd && json_prompt.find("multimodal_data") != json_prompt.end()) {
Suggested change:
-    if (has_mtmd && json_prompt.find("multimodal_data") != json_prompt.end()) {
+    if (has_mtmd && json_prompt.contains("multimodal_data")) {
Or even better, like this:
if (json_prompt.contains("multimodal_data")) {
    if (has_mtmd) { ... do the thing ... }
    else throw std::runtime_error("multimodal is not supported by this server");
}
Done. However, this leaves a trap for clients who call us, because they don't know whether we will give them an error or not. Added a capability to /models and /v1/models so the client can tell if we support multimodal; this should be sufficient to support an error rather than silently dropping the data.
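For illustration, a client-side check might look like this (a minimal sketch; exactly where the capability appears in the /v1/models response is an assumption here):

```python
import requests

# Query the model list and look for a multimodal capability flag
# (response layout assumed for illustration).
models = requests.get("http://localhost:8080/v1/models").json()
first = models["data"][0]
if "multimodal" not in first.get("capabilities", []):
    raise RuntimeError("server/model does not accept multimodal_data")
```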
tools/server/server.cpp
Outdated
@@ -4750,22 +4709,22 @@ int main(int argc, char ** argv) {
         return;
     }

-    llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0];
+    server_tokens tokenized_query = std::move(tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, query, /* add_special */ false, true)[0]);
I think std::move is unnecessary here; the compiler should be good enough to optimize this.
Done. It was necessary, at least with my recent gcc on Linux (this was triggering an attempted copy), but I refactored this so we only std::move once, below this point.
tools/server/utils.hpp
Outdated
#define JSON_STRING_PROMPT_KEY "prompt_string"
#define JSON_MTMD_DATA_KEY "multimodal_data"
it's best to define these in local scope, not globally, using const char *
Done
tools/server/utils.hpp
Outdated
// JSON object with prompt and multimodal key.
std::vector<raw_buffer> files;
for (const auto& entry : json_prompt.at(JSON_MTMD_DATA_KEY)) {
Suggested change:
-    for (const auto& entry : json_prompt.at(JSON_MTMD_DATA_KEY)) {
+    for (const auto & entry : json_prompt.at(JSON_MTMD_DATA_KEY)) {
Done.
tools/server/utils.hpp
Outdated
if (tokenized != 0) {
    throw std::runtime_error("Failed to tokenize prompt");
}
auto result = server_tokens(chunks,true);
Suggested change:
-    auto result = server_tokens(chunks,true);
+    auto result = server_tokens(chunks, true);
Done.
tools/server/tests/utils.py
Outdated
 server.n_ctx = 1024
-server.n_batch = 32
+server.n_batch = 512
 server.n_slots = 2
 server.n_predict = 4
 server.seed = 42
 server.server_embeddings = True
 return server
define this in the local test, before server.start(...); see other test files for an example
Done.
A few notable updates:
@ngxson ready for re-review, I won't resolve your comments in case you want to discuss any of the changes.
Force-pushed from 398d0fe to 58b9c3e
@ngxson I believe the second round of comments is now addressed! @oobabooga I saw you were already testing this, thanks. Please note the API has changed slightly: the client should check whether multimodal is supported via the /models or /v1/models endpoint.
Thanks for the heads up @65a, I have updated the request! oobabooga/text-generation-webui@e6447cd
Can't reproduce the sanitizer test timeout failure locally; hopefully pushing an updated commit message can retrigger CI.
… format
- Use server_tokens in more places in server and util.cpp
- Convert most functions that used llama_tokens to server_tokens
- Modify input tokenizer to handle JSON objects as subprompts
- Break out MTMD prompt parsing into utility function
- Support JSON objects with multimodal_data arrays for MTMD prompts along with other existing types
- Add capability to model endpoint to indicate if client can send multimodal data
- Add tests.
Editing first comment to match current state:
This pull adds support for multimodal data in the /completions (and, in a similar fashion, /embeddings) API endpoint. Instead of a string, list of tokens, or a mixed string/token list as currently supported by that endpoint, this pull adds support for a JSON object containing both a prompt_string and a multimodal_data field. The client should check the result of /models or /v1/models for the multimodal capability; sending multimodal data to a non-multimodal model will result in a request error. A singular request example is like this:
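A minimal sketch with Python's requests (the server URL, the n_predict field, and the media placement via libmtmd's default <__media__> marker are assumptions here):

```python
import base64
import requests

# Base64-encode the media; the server hands the decoded bytes to libmtmd.
with open("example.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://localhost:8080/completions",  # assumed local llama-server
    json={
        "prompt": {
            "prompt_string": "What is in this image? <__media__>",
            "multimodal_data": [image_b64],
        },
        "n_predict": 64,
    },
)
print(resp.json()["content"])
```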
A multiple (and mixed-type) request would look like:
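Sketched the same way, under the same assumptions; the prompt array mixes a plain string with a multimodal object, and one result per subprompt is my assumption for the response shape:

```python
resp = requests.post(
    "http://localhost:8080/completions",
    json={
        "prompt": [
            "Describe the sky in one word.",  # plain text subprompt
            {
                "prompt_string": "What is in this image? <__media__>",
                "multimodal_data": [image_b64],  # from the example above
            },
        ],
        "n_predict": 32,
    },
)
for result in resp.json():  # one result per subprompt (assumed)
    print(result["content"])
```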
All existing tests pass, and new tests are added to cover both the prompt splitting and visual inference. If multimodal data is provided and the model does not support MTMD, only the text part will be used. The multimodal part should be base64-encoded media data supported by libmtmd.
With this approach, other server endpoints can become multimodal relatively easily in the future (rerank and infill are close, but would need additional work and testing). Feedback welcome!
Implement a basic way to include multimodal data in the completions endpoint. For now, this just supports directly included data, in base64-encoded format, provided as an array of strings under the JSON key multimodal_data. Documentation updated to match. Local testing shows no regression for the without-media case, and successful image processing with media provided from a custom client. Similar to #14016, but avoids importing the URL fetching logic at all; it could be added later when factored out of the OpenAI-emulation code, but this is simpler for now and avoids the need for URL parsing and remote fetch capabilities. Original referenced issue was #13872.
@ngxson ptal, hopefully this works for you.