From 466c6cddba76c95780c2d098ae9c691203b6c416 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 11 Apr 2025 17:46:04 +0200 Subject: [PATCH 01/18] server : (experimental) vision support via libmtmd --- common/arg.cpp | 8 +- examples/server/CMakeLists.txt | 3 +- examples/server/server.cpp | 248 ++++++++++++++++++++--------- examples/server/utils.hpp | 280 ++++++++++++++++++++++++++++++++- 4 files changed, 458 insertions(+), 81 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 642fefb57548f..17955872e61ef 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -834,9 +834,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context // allow --mmproj to be set from -hf // assuming that mmproj is always in the same repo as text model - if (!params.model.hf_repo.empty() && ctx_arg.ex == LLAMA_EXAMPLE_LLAVA) { + if (!params.model.hf_repo.empty() && ( + ctx_arg.ex == LLAMA_EXAMPLE_LLAVA || ctx_arg.ex == LLAMA_EXAMPLE_SERVER)) { params.mmproj.hf_repo = params.model.hf_repo; } + // TODO @ngxson : this will break non-vision model with -hf, need to fix before merging common_params_handle_model(params.mmproj, params.hf_token, "", true); if (params.escape) { @@ -2101,14 +2103,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.mmproj.path = value; } - ).set_examples({LLAMA_EXAMPLE_LLAVA})); + ).set_examples({LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--mmproj-url"}, "URL", "URL to a multimodal projector file for LLaVA. see examples/llava/README.md", [](common_params & params, const std::string & value) { params.mmproj.url = value; } - ).set_examples({LLAMA_EXAMPLE_LLAVA})); + ).set_examples({LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--image"}, "FILE", "path to an image file. use with multimodal models. 
Specify multiple times for batching", diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index aee90388e4fb3..17109fddbd307 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -34,8 +34,9 @@ endforeach() add_executable(${TARGET} ${TARGET_SRCS}) install(TARGETS ${TARGET} RUNTIME) +target_include_directories(${TARGET} PRIVATE ../llava) target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) -target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT}) if (LLAMA_SERVER_SSL) find_package(OpenSSL REQUIRED) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1bf1ee876b40f..17b0ccfa108e1 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -7,6 +7,7 @@ #include "log.h" #include "sampling.h" #include "speculative.h" +#include "mtmd.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT @@ -196,8 +197,8 @@ struct server_task { int id_target = -1; // used by SERVER_TASK_TYPE_INFERENCE - slot_params params; - llama_tokens prompt_tokens; + slot_params params; + server_inputs prompt_tokens; int id_selected_slot = -1; // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE @@ -1246,6 +1247,9 @@ struct server_slot { llama_context * ctx = nullptr; llama_context * ctx_dft = nullptr; + // multimodal + mtmd_context * mctx = nullptr; + common_speculative * spec = nullptr; std::vector lora; @@ -1273,14 +1277,14 @@ struct server_slot { int32_t n_prompt_tokens_processed = 0; // input prompt tokens - llama_tokens prompt_tokens; + server_inputs prompt_tokens; size_t last_nl_pos = 0; std::string generated_text; llama_tokens generated_tokens; - llama_tokens cache_tokens; + server_inputs cache_tokens; std::vector generated_token_probs; @@ -1474,7 +1478,7 @@ struct server_slot { {"is_processing", is_processing()}, {"non_causal", is_non_causal()}, {"params", params.to_json()}, - {"prompt", common_detokenize(ctx, prompt_tokens)}, + {"prompt", ""}, // TODO @ngxson, hacky patch, to fix before merge {"next_token", { {"has_next_token", has_next_token}, @@ -1552,11 +1556,11 @@ struct server_queue { std::condition_variable condition_tasks; // callback functions - std::function callback_new_task; - std::function callback_update_slots; + std::function callback_new_task; + std::function callback_update_slots; // Add a new task to the end of the queue - int post(server_task task, bool front = false) { + int post(server_task & task, bool front = false) { std::unique_lock lock(mutex_tasks); GGML_ASSERT(task.id != -1); // if this is cancel task make sure to clean up pending tasks @@ -1596,7 +1600,7 @@ struct server_queue { } // Add a new task, but defer until one slot is available - void defer(server_task task) { + void defer(server_task & task) { std::unique_lock lock(mutex_tasks); QUE_DBG("defer task, id = %d\n", task.id); queue_tasks_deferred.push_back(std::move(task)); @@ -1611,7 +1615,7 @@ struct server_queue { } // Register function to process a new task - void on_new_task(std::function callback) { + void on_new_task(std::function callback) { callback_new_task = std::move(callback); } @@ -1660,12 +1664,12 @@ struct server_queue { lock.unlock(); break; } - server_task task = queue_tasks.front(); + server_task task = std::move(queue_tasks.front()); queue_tasks.pop_front(); lock.unlock(); QUE_DBG("processing task, id = %d\n", task.id); - 
callback_new_task(std::move(task)); + callback_new_task(task); } // all tasks in the current loop is processed, slots data is now ready @@ -1846,6 +1850,9 @@ struct server_context { llama_model * model = nullptr; llama_context * ctx = nullptr; + // multimodal + mtmd_context * mctx = nullptr; + const llama_vocab * vocab = nullptr; llama_model * model_dft = nullptr; @@ -1875,6 +1882,8 @@ struct server_context { common_chat_templates_ptr chat_templates; ~server_context() { + mtmd_free(mctx); + // Clear any sampling context for (server_slot & slot : slots) { common_sampler_free(slot.smpl); @@ -1962,6 +1971,18 @@ struct server_context { chat_templates = common_chat_templates_init(model, "chatml"); } + std::string & mmproj_path = params_base.mmproj.path; + if (!mmproj_path.empty()) { + mtmd_context_params mparams; + mparams.n_threads = params_base.cpuparams.n_threads; + mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + if (mctx == nullptr) { + SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); + return false; + } + SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); + } + return true; } @@ -1977,6 +1998,7 @@ struct server_context { slot.ctx = ctx; slot.n_ctx = n_ctx_slot; slot.n_predict = params_base.n_predict; + slot.mctx = mctx; if (model_dft) { slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); @@ -2004,7 +2026,7 @@ struct server_context { slot.reset(); - slots.push_back(slot); + slots.push_back(std::move(slot)); } default_generation_settings_for_props = slots[0].to_json(); @@ -2051,10 +2073,10 @@ struct server_context { } // length of the Longest Common Subsequence between the current slot's prompt and the input prompt - int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens); + int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); // fraction of the common subsequence length compared to the current slot's prompt length - float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); + float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.n_tokens()); // select the current slot if the criteria match if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { @@ -2093,19 +2115,14 @@ struct server_context { return ret; } - bool can_be_detokenized(const struct llama_context * ctx, const std::vector & tokens) { + bool can_be_detokenized(const struct llama_context * ctx, const server_inputs & inp) { const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int32_t n_vocab = llama_vocab_n_tokens(vocab); - for (const auto & token : tokens) { - if (token < 0 || token >= n_vocab) { - return false; - } - } - return true; + return inp.validate(n_vocab); } - bool launch_slot_with_task(server_slot & slot, const server_task & task) { + bool launch_slot_with_task(server_slot & slot, server_task & task) { slot.reset(); slot.id_task = task.id; slot.index = task.index; @@ -2421,7 +2438,7 @@ struct server_context { res->content = std::move(slot.generated_text); res->tokens = std::move(slot.generated_tokens); res->timings = slot.get_timings(); - res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); + //res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); // TODO @ngxson : hacky, need to fix res->response_fields = std::move(slot.params.response_fields); res->truncated = slot.truncated; @@ -2547,7 +2564,7 @@ struct server_context { server_task 
task(SERVER_TASK_TYPE_CANCEL); task.id_target = id_task; queue_results.remove_waiting_task_id(id_task); - cancel_tasks.push_back(task); + cancel_tasks.push_back(std::move(task)); } // push to beginning of the queue, so it has highest priority queue_tasks.post(cancel_tasks, true); @@ -2637,7 +2654,7 @@ struct server_context { // Functions to process the task // - void process_single_task(server_task task) { + void process_single_task(server_task & task) { switch (task.type) { case SERVER_TASK_TYPE_COMPLETION: case SERVER_TASK_TYPE_INFILL: @@ -2729,7 +2746,7 @@ struct server_context { } queue_results.send(std::move(res)); } break; - case SERVER_TASK_TYPE_SLOT_SAVE: + /*case SERVER_TASK_TYPE_SLOT_SAVE: { int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); @@ -2833,7 +2850,11 @@ struct server_context { res->id_slot = id_slot; res->n_erased = n_erased; queue_results.send(std::move(res)); - } break; + } break;*/ + case SERVER_TASK_TYPE_SLOT_SAVE: + case SERVER_TASK_TYPE_SLOT_RESTORE: + case SERVER_TASK_TYPE_SLOT_ERASE: + GGML_ASSERT(false && "TODO @ngxson : removed due to not compat with multimodal"); case SERVER_TASK_TYPE_SET_LORA: { params_base.lora_adapters = std::move(task.set_lora); @@ -2841,6 +2862,7 @@ struct server_context { res->id = task.id; queue_results.send(std::move(res)); } break; + } } @@ -2876,7 +2898,8 @@ struct server_context { // apply context-shift if needed // TODO: simplify and improve - for (server_slot & slot : slots) { + // TODO @ngxson : hacky, need to disable context shift for multimodal + /*for (server_slot & slot : slots) { if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) { if (!params_base.ctx_shift) { // this check is redundant (for good) @@ -2908,7 +2931,7 @@ struct server_context { slot.truncated = true; } - } + }*/ // start populating the batch for this iteration common_batch_clear(batch); @@ -2940,17 +2963,21 @@ struct server_context { slot.n_past += 1; if (slot.params.cache_prompt) { - slot.cache_tokens.push_back(slot.sampled); + slot.cache_tokens.add_text_token(slot.sampled); } SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated); + slot.n_ctx, slot.n_past, (int) slot.cache_tokens.n_tokens(), slot.truncated); } // process in chunks of params.n_batch int32_t n_batch = llama_n_batch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx); + // for multimodal + bool is_decoding_embd = false; + server_embd_batch batch_embd; + // next, batch any pending prompts without exceeding n_batch if (params_base.cont_batching || batch.n_tokens == 0) { for (auto & slot : slots) { @@ -2973,23 +3000,23 @@ struct server_context { slot.t_start_generation = 0; slot.n_past = 0; - slot.n_prompt_tokens = prompt_tokens.size(); + slot.n_prompt_tokens = prompt_tokens.n_tokens(); slot.state = SLOT_STATE_PROCESSING_PROMPT; SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); // print prompt tokens (for debugging) - if (1) { - // first 16 tokens (avoid flooding logs) - for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); - } - } else { - // all - for (int i = 0; i < (int) prompt_tokens.size(); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); 
- } - } + // if (1) { + // // first 16 tokens (avoid flooding logs) + // for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { + // SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + // } + // } else { + // // all + // for (int i = 0; i < (int) prompt_tokens.size(); i++) { + // SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + // } + // } // empty prompt passed -> release the slot and send empty response if (prompt_tokens.empty()) { @@ -3030,7 +3057,8 @@ struct server_context { slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); // if input prompt is too big, truncate it - if (slot.n_prompt_tokens >= slot.n_ctx) { + // TODO @ngxson : this won't work with multimodal + /*if (slot.n_prompt_tokens >= slot.n_ctx) { const int n_left = slot.n_ctx - slot.params.n_keep; const int n_block_size = n_left / 2; @@ -3053,14 +3081,15 @@ struct server_context { SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); - } + }*/ if (slot.params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens); + slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); // reuse chunks from the cached prompt by shifting their KV cache in the new position - if (params_base.n_cache_reuse > 0) { + // TODO @ngxson : this won't work with multimodal + /*if (params_base.n_cache_reuse > 0) { size_t head_c = slot.n_past; // cache size_t head_p = slot.n_past; // current prompt @@ -3101,7 +3130,7 @@ struct server_context { } SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past); - } + }*/ } } @@ -3135,17 +3164,26 @@ struct server_context { SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); // remove the non-common part from the cache - slot.cache_tokens.resize(slot.n_past); + slot.cache_tokens.keep_until(slot.n_past); // add prompt tokens for processing in the current batch while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { // without pooling, we want to output the embeddings for all the tokens in the batch const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd); - - if (slot.params.cache_prompt) { - slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); + auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past); + if (curr_chunk.tok_image) { + // decode image + server_encode_image(slot.mctx, batch_embd, curr_chunk, slot.n_past, slot.id); + is_decoding_embd = true; + SLT_INF(slot, "decoding image, n_past = %d, n_tokens = %d\n", slot.n_past, batch_embd.batch.n_tokens); + slot.n_past += batch_embd.batch.n_tokens; + break; // do not process any other slots + } else { + common_batch_add(batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd); + if (slot.params.cache_prompt) { + slot.cache_tokens.add_text_token(curr_chunk.tok_text); + } } slot.n_prompt_tokens_processed++; @@ -3163,8 +3201,11 @@ struct server_context { common_sampler_reset(slot.smpl); // Process all prompt tokens through sampler system - for (int i = 0; i < slot.n_prompt_tokens; ++i) { - common_sampler_accept(slot.smpl, 
prompt_tokens[i], false); + for (size_t i = 0; i < slot.cache_tokens.n_tokens(); ++i) { + auto & curr_chunk = slot.cache_tokens.get_chunk(i); + if (curr_chunk.tok_text != LLAMA_TOKEN_NULL) { + common_sampler_accept(slot.smpl, curr_chunk.tok_text, false); + } } // extract the logits only for the last token @@ -3201,7 +3242,7 @@ struct server_context { for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); - llama_batch batch_view = { + llama_batch batch_view = is_decoding_embd ? batch_embd.batch : llama_batch{ n_tokens, batch.token + i, nullptr, @@ -3211,9 +3252,18 @@ struct server_context { batch.logits + i, }; + // TODO @ngxson : maybe move this to llama_batch_ext + if (is_decoding_embd && mtmd_decode_use_non_causal(mctx)) { + llama_set_causal_attn(ctx, false); + } + const int ret = llama_decode(ctx, batch_view); metrics.on_decoded(slots); + if (is_decoding_embd && mtmd_decode_use_non_causal(mctx)) { + llama_set_causal_attn(ctx, true); + } + if (ret != 0) { if (n_batch == 1 || ret < 0) { // if you get here, it means the KV cache is full - try increasing it via the context size @@ -3301,7 +3351,8 @@ struct server_context { } // do speculative decoding - for (auto & slot : slots) { + // TODO @ngxson : remove speculative decoding for multimodal + /*for (auto & slot : slots) { if (!slot.is_processing() || !slot.can_speculate()) { continue; } @@ -3394,7 +3445,7 @@ struct server_context { } SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past); - } + }*/ } SRV_DBG("%s", "run slots completed\n"); @@ -3912,6 +3963,7 @@ int main(int argc, char ** argv) { const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok]( server_task_type type, json & data, + std::vector & files, std::function is_connection_closed, httplib::Response & res, oaicompat_type oaicompat) { @@ -3930,15 +3982,55 @@ int main(int argc, char ** argv) { // TODO: this log can become very long, put it behind a flag or think about a more compact format //SRV_DBG("Prompt: %s\n", prompt.is_string() ? 
prompt.get().c_str() : prompt.dump(2).c_str()); - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); - tasks.reserve(tokenized_prompts.size()); - for (size_t i = 0; i < tokenized_prompts.size(); i++) { + // process files + std::vector bitmaps; + { + for (auto & file : files) { + mtmd_bitmap bmp; + int32_t res = mtmd_helper_bitmap_init_from_buf(file.data(), file.size(), bmp); + if (res != 0) { + throw std::runtime_error("Failed to load image"); + } + bitmaps.push_back(std::move(bmp)); + } + } + + std::vector inputs; + if (oaicompat) { + if (!prompt.is_string()) { + throw std::runtime_error("prompt must be a string"); + } else { + printf("prompt: %s\n", prompt.get().c_str()); + mtmd_input_text inp_txt = { + prompt.get(), + /* add_special */ true, + /* parse_special */ true, + }; + mtmd_input_chunks * tokenized = mtmd_tokenize(ctx_server.mctx, inp_txt, bitmaps); + if (!tokenized) { + throw std::runtime_error("Failed to tokenize prompt"); + } + server_inputs tmp(tokenized); + inputs.push_back(std::move(tmp)); + mtmd_input_chunks_free(tokenized, false); // only free the container, not the images + } + } else { + // non-multimodal version + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + for (auto & p : tokenized_prompts) { + auto tmp = convert_legacy_to_mtmd(p); + inputs.push_back(std::move(tmp)); + } + } + + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = std::move(tokenized_prompts[i]); + task.prompt_tokens = std::move(inputs[i]); task.params = server_task::params_from_json_cmpl( ctx_server.ctx, ctx_server.params_base, @@ -3950,7 +4042,7 @@ int main(int argc, char ** argv) { task.params.oaicompat_cmpl_id = completion_id; // oaicompat_model is already populated by params_from_json_cmpl - tasks.push_back(task); + tasks.push_back(std::move(task)); } } catch (const std::exception & e) { res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); @@ -4020,9 +4112,11 @@ int main(int argc, char ** argv) { const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { json data = json::parse(req.body); + std::vector files; // dummy return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, + files, req.is_connection_closed, res, OAICOMPAT_TYPE_NONE); @@ -4030,9 +4124,11 @@ int main(int argc, char ** argv) { const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { json data = oaicompat_completion_params_parse(json::parse(req.body)); + std::vector files; // dummy return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, + files, req.is_connection_closed, res, OAICOMPAT_TYPE_COMPLETION); @@ -4107,9 +4203,11 @@ int main(int argc, char ** argv) { tokenized_prompts[0] ); + std::vector files; // dummy return handle_completions_impl( SERVER_TASK_TYPE_INFILL, data, + files, req.is_connection_closed, res, OAICOMPAT_TYPE_NONE); // infill is not OAI compatible @@ -4123,11 +4221,13 @@ int main(int argc, char ** argv) { } auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get()); + std::vector files; + json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get(), 
files); return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, + files, req.is_connection_closed, res, OAICOMPAT_TYPE_CHAT); @@ -4136,7 +4236,8 @@ int main(int argc, char ** argv) { // same with handle_chat_completions, but without inference part const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) { auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get()); + std::vector files; // dummy, unused + json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get(), files); res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); }; @@ -4241,7 +4342,7 @@ int main(int argc, char ** argv) { } } - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); for (const auto & tokens : tokenized_prompts) { // this check is necessary for models that do not add BOS token to the input if (tokens.empty()) { @@ -4260,12 +4361,12 @@ int main(int argc, char ** argv) { task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = std::move(tokenized_prompts[i]); + task.prompt_tokens = convert_legacy_to_mtmd(tokenized_prompts[i]); // OAI-compat task.params.oaicompat = oaicompat; - tasks.push_back(task); + tasks.push_back(std::move(task)); } ctx_server.queue_results.add_waiting_tasks(tasks); @@ -4354,14 +4455,15 @@ int main(int argc, char ** argv) { bool error = false; { std::vector tasks; - std::vector tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true); + auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true); tasks.reserve(tokenized_docs.size()); for (size_t i = 0; i < tokenized_docs.size(); i++) { + auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]); - tasks.push_back(task); + task.prompt_tokens = convert_legacy_to_mtmd(tmp); + tasks.push_back(std::move(task)); } ctx_server.queue_results.add_waiting_tasks(tasks); @@ -4566,7 +4668,7 @@ int main(int argc, char ** argv) { common_chat_templates_source(ctx_server.chat_templates.get()), common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str()); - ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) { + ctx_server.queue_tasks.on_new_task([&ctx_server](server_task & task) { ctx_server.process_single_task(task); }); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index aba2f27f9b564..5103e22e163dd 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -4,6 +4,7 @@ #include "log.h" #include "llama.h" #include "base64.hpp" +#include "mtmd.h" // increase max payload length to allow use of larger context size #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 @@ -21,6 +22,7 @@ #include #include #include +#include #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" @@ -41,6 +43,8 @@ using json = nlohmann::ordered_json; #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define QUE_DBG(fmt, ...) 
LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +using raw_buffer = std::vector; + template static T json_value(const json & body, const std::string & key, const T & default_value) { // Fallback null to default value @@ -386,7 +390,7 @@ static inline bool is_base64(uint8_t c) { return (isalnum(c) || (c == '+') || (c == '/')); } -static inline std::vector base64_decode(const std::string & encoded_string) { +static inline raw_buffer base64_decode(const std::string & encoded_string) { int i = 0; int j = 0; int in_ = 0; @@ -396,7 +400,7 @@ static inline std::vector base64_decode(const std::string & encoded_str uint8_t char_array_4[4]; uint8_t char_array_3[3]; - std::vector ret; + raw_buffer ret; while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { char_array_4[i++] = encoded_string[in_]; in_++; @@ -579,7 +583,8 @@ static json oaicompat_completion_params_parse( const json & body, /* openai api json semantics */ bool use_jinja, common_reasoning_format reasoning_format, - const struct common_chat_templates * tmpls) + const struct common_chat_templates * tmpls, + std::vector & out_files) { json llama_params; @@ -627,8 +632,47 @@ static json oaicompat_completion_params_parse( } } + // get input files + json messages = json_value(body, "messages", json::array()); + if (!messages.is_array()) { + throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump()); + } + for (auto & msg : messages) { + json & content = msg.at("content"); + if (content.is_string()) { + continue; + } + + if (!content.is_array()) { + throw std::runtime_error("Expected 'content' to be a string or an array"); + } + + for (auto & p : content) { + std::string type = json_value(p, "type", std::string()); + json image_url = json_value(p, "image_url", json::object()); + if (type == "image_url") { + std::string url = json_value(image_url, "url", std::string()); + std::vector parts = string_split(url, /*separator*/ ','); + if (parts.size() != 2) { + throw std::runtime_error("Invalid image_url.url value"); + } else if (!string_starts_with(parts[0], "data:image/")) { + throw std::runtime_error("Invalid image_url.url format: " + parts[0]); + } else if (!string_ends_with(parts[0], "base64")) { + throw std::runtime_error("image_url.url must be base64 encoded"); + } else { + auto base64_data = parts[1]; + auto decoded_data = base64_decode(base64_data); + out_files.push_back(decoded_data); + } + p["type"] = "text"; + p["text"] = "<__image__>"; + p.erase("image_url"); + } + } + } + common_chat_templates_inputs inputs; - inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages")); + inputs.messages = common_chat_msgs_parse_oaicompat(messages); inputs.tools = common_chat_tools_parse_oaicompat(tools); inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto"))); inputs.json_schema = json_schema.is_null() ? 
"" : json_schema.dump(); @@ -913,3 +957,231 @@ static std::vector parse_lora_request( return lora; } + +// +// utils for interacting with libmtmd +// (may need to refactor in near future) +// + +struct server_inp_chunk { + llama_token tok_text; + mtmd_image_tokens_ptr tok_image; + std::string str() { + // for debugging + if (tok_image) { + return " "; + } else { + return std::to_string(tok_text) + " "; + } + } +}; + +struct server_inputs { + std::vector chunks; + + server_inputs() = default; + ~server_inputs() = default; // Important if unique_ptr is used + + // Prevent copying + server_inputs(const server_inputs&) = delete; + server_inputs& operator=(const server_inputs&) = delete; + + // Allow moving (usually implicitly generated if members are movable) + server_inputs(server_inputs&&) = default; + server_inputs& operator=(server_inputs&&) = default; + + server_inputs(mtmd_input_chunks * mtmd_chunks) { + for (auto & c : *mtmd_chunks) { + if (c.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + chunks.push_back({LLAMA_TOKEN_NULL, mtmd_image_tokens_ptr(c.tokens_image)}); + } else if (c.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + for (auto & tok : c.tokens_text) { + chunks.push_back({tok, nullptr}); + } + } else { + GGML_ASSERT(false && "Invalid chunk type"); + } + } + } + + size_t n_tokens() const { + size_t res = 0; + for (const auto & chunk : chunks) { + if (chunk.tok_image) { + res += mtmd_image_tokens_get_n_tokens(chunk.tok_image.get()); + } else { + res++; + } + } + return res; + } + + bool empty() const { + return n_tokens() == 0; + } + + void clear() { + chunks.clear(); + } + + void add_text_token(llama_token tok) { + GGML_ASSERT(tok != LLAMA_TOKEN_NULL); + chunks.push_back({tok, nullptr}); + } + + size_t get_common_prefix(const server_inputs & b) const { + size_t ret = 0; + size_t max_idx = std::min(chunks.size(), b.chunks.size()); + for (size_t i = 0; i < max_idx; ++i) { + auto & ai = chunks[i]; + auto & bi = b.chunks[i]; + + if (ai.tok_text == bi.tok_text && !ai.tok_image && !bi.tok_image) { + ret++; + continue; + } else if (ai.tok_image && bi.tok_image) { + // TODO check image hash + break; + } else { + break; + } + } + return ret; + } + + bool validate(llama_token max_vocab_id) const { + for (const auto & chunk : chunks) { + if (!chunk.tok_image) { + if (chunk.tok_text < 0 || chunk.tok_text >= max_vocab_id) { + return false; + } + } + } + return true; + } + + server_inp_chunk & get_chunk(size_t pos) { + return chunks[get_chunk_idx(pos)]; + } + + size_t get_chunk_idx(size_t pos) const { + size_t current_pos = 0; + for (size_t i = 0; i < chunks.size(); ++i) { + const auto & chunk = chunks[i]; + size_t chunk_size = chunk.tok_image ? mtmd_image_tokens_get_n_tokens(chunk.tok_image.get()) : 1; + size_t chunk_end_pos = current_pos + chunk_size; + if (pos < chunk_end_pos) { + // The target position 'pos' falls within this chunk + return i; + } + + current_pos = chunk_end_pos; + } + // If the loop finishes, 'pos' is >= the total number of logical positions + return chunks.size(); + } + + // same idea with std::vector resize() + void keep_until(size_t pos) { + if (pos == 0) { + chunks.clear(); + return; + } + + size_t current_pos = 0; + for (size_t i = 0; i < chunks.size(); ++i) { + const auto & chunk = chunks[i]; + size_t chunk_size = chunk.tok_image ? mtmd_image_tokens_get_n_tokens(chunk.tok_image.get()) : 1; + size_t chunk_end_pos = current_pos + chunk_size; + if (pos <= current_pos) { + // Truncation point is exactly at or before the start of this chunk. + // Keep only chunks before index 'i'. 
+ chunks.resize(i); + return; + } + if (pos < chunk_end_pos) { + // Truncation point 'pos' falls within this chunk. + if (chunk.tok_image) { + // It's an image chunk, keep the whole chunk. + // Keep chunks up to and including index 'i'. + chunks.resize(i + 1); + } else { + // It's a text chunk. Since pos < chunk_end_pos and chunk_size is 1, + // this means pos == current_pos. + // Keep only chunks before index 'i'. + chunks.resize(i); + } + return; + } + // pos >= chunk_end_pos, so keep this chunk entirely and continue. + current_pos = chunk_end_pos; + } + // If the loop completes, it means 'pos' is >= the total logical size. + // No truncation needed, the vector remains unchanged. + } +}; + +// helper struct to make working with embd batch easier +// note: this will be removed after llama_batch_ext refactoring +struct server_embd_batch { + std::vector pos; + std::vector n_seq_id; + std::vector seq_id_0; + std::vector seq_ids; + std::vector logits; + llama_batch batch; + server_embd_batch() = default; + server_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { + pos .resize(n_tokens); + n_seq_id.resize(n_tokens); + seq_ids .resize(n_tokens + 1); + logits .resize(n_tokens); + seq_id_0.resize(1); + seq_id_0[0] = seq_id; + seq_ids [n_tokens] = nullptr; + batch = { + /*n_tokens =*/ n_tokens, + /*tokens =*/ nullptr, + /*embd =*/ embd, + /*pos =*/ pos.data(), + /*n_seq_id =*/ n_seq_id.data(), + /*seq_id =*/ seq_ids.data(), + /*logits =*/ logits.data(), + }; + for (int i = 0; i < n_tokens; i++) { + batch.pos [i] = pos_0 + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i] = seq_id_0.data(); + batch.logits [i] = false; + } + } +}; + +// TODO @ngxson : quite hacky for now, but just to see if it works +static int32_t server_encode_image(mtmd_context * mctx, server_embd_batch & batch_out, server_inp_chunk & chunk, llama_pos n_past, llama_seq_id seq_id) { + GGML_ASSERT(chunk.tok_image); + + int64_t t0 = ggml_time_ms(); + LOG_INF("encoding image...\n"); + int32_t ret = mtmd_encode(mctx, chunk.tok_image.get()); + if (ret != 0) { + LOG_ERR("failed to encode image\n"); + batch_out = server_embd_batch{}; + return ret; + } + LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); + + int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tok_image.get()); + float * embd = mtmd_get_output_embd(mctx); + batch_out = server_embd_batch(embd, n_tokens, n_past, seq_id); + return ret; +} + +// hacky, support text-only for now +static server_inputs convert_legacy_to_mtmd(llama_tokens & tokenized) { + server_inputs res; + for (auto & tok : tokenized) { + res.add_text_token(tok); + } + return res; +} From 2317e618b5df417eaf3980ddf6114caae7392dfb Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 11 Apr 2025 17:46:25 +0200 Subject: [PATCH 02/18] mtmd : add more api around mtmd_image_tokens --- examples/llava/mtmd.cpp | 39 ++++++++++++++++++++++++++++++++++----- examples/llava/mtmd.h | 23 ++++++++++++++++++++--- 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp index 58503d0b22c33..be856c0fa9ed6 100644 --- a/examples/llava/mtmd.cpp +++ b/examples/llava/mtmd.cpp @@ -166,15 +166,36 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, return output; } -void mtmd_input_chunks_free(mtmd_input_chunks * chunks) { - for (auto & chunk : *chunks) { - if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) { - delete chunk.tokens_image; +void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) { + if 
(image_tokens) { + delete image_tokens; + } +} + +void mtmd_input_chunks_free(mtmd_input_chunks * chunks, bool free_images) { + if (free_images) { + for (auto & chunk : *chunks) { + if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) { + mtmd_image_tokens_free(chunk.tokens_image); + chunk.tokens_image = nullptr; + } } } delete chunks; } +size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) { + return image_tokens->n_tokens(); +} + +size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) { + return image_tokens->nx; +} + +size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) { + return image_tokens->ny; +} + int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) { int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip); ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); @@ -289,7 +310,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx, LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); } - int32_t n_tokens = chunk.tokens_image->n_tokens(); + int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image); float * embd = mtmd_get_output_embd(ctx); decode_embd_batch batch_img(embd, n_tokens, n_past, 0); int64_t t1 = ggml_time_ms(); @@ -339,3 +360,11 @@ int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & outp std::memcpy(output.data.data(), data, output.nx * output.ny * 3); return 0; } + +bool mtmd_decode_use_non_causal(mtmd_context * ctx) { + projector_type proj_type = clip_get_projector_type(ctx->ctx_clip); + if (proj_type == PROJECTOR_TYPE_GEMMA3) { + return true; + } + return false; +} diff --git a/examples/llava/mtmd.h b/examples/llava/mtmd.h index 598f6947bb092..ca3fb6fdc7960 100644 --- a/examples/llava/mtmd.h +++ b/examples/llava/mtmd.h @@ -81,13 +81,20 @@ MTMD_API void mtmd_free(mtmd_context * ctx); // 2. (image tokens) // 3. "\ndescribe it in detail." 
// number of bitmaps must be equal to the number of image markers in the prompt +// the returned value must be freed using mtmd_input_chunks_free() // this function is thread-safe (shared ctx) MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, const mtmd_input_text & text, const std::vector & bitmaps); -// free image chunk data -MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks); +// if free_images = true, free the image tokens ; otherwise, you must free them using mtmd_image_free() +MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks, bool free_images); + +// access mtmd_image_tokens +MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens); +MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens); // returns 0 on success MTMD_API int32_t mtmd_encode(mtmd_context * ctx, @@ -96,6 +103,11 @@ MTMD_API int32_t mtmd_encode(mtmd_context * ctx, // get output embeddings from the last encode pass MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx); +// whether we need to set non-causal mask before llama_decode +MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx); + + + // // helper functions (can be implemented based on other functions) // @@ -133,10 +145,15 @@ struct mtmd_context_deleter { using mtmd_context_ptr = std::unique_ptr; struct mtmd_input_chunks_deleter { - void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); } + void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val, true); } }; using mtmd_input_chunks_ptr = std::unique_ptr; +struct mtmd_image_tokens_deleter { + void operator()(mtmd_image_tokens * val) { mtmd_image_tokens_free(val); } +}; +using mtmd_image_tokens_ptr = std::unique_ptr; + #else static_assert(false && "C header is not yet supported by this library"); From a46b6db6844c2d213965d7450a7eb0d2588d88e3 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 11 Apr 2025 17:46:25 +0200 Subject: [PATCH 03/18] mtmd : add more api around mtmd_image_tokens --- examples/llava/mtmd.cpp | 39 ++++++++++++++++++++++++++++++++++----- examples/llava/mtmd.h | 23 ++++++++++++++++++++--- 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp index 114c274bc1250..98d660a643809 100644 --- a/examples/llava/mtmd.cpp +++ b/examples/llava/mtmd.cpp @@ -166,15 +166,36 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, return output; } -void mtmd_input_chunks_free(mtmd_input_chunks * chunks) { - for (auto & chunk : *chunks) { - if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) { - delete chunk.tokens_image; +void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) { + if (image_tokens) { + delete image_tokens; + } +} + +void mtmd_input_chunks_free(mtmd_input_chunks * chunks, bool free_images) { + if (free_images) { + for (auto & chunk : *chunks) { + if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) { + mtmd_image_tokens_free(chunk.tokens_image); + chunk.tokens_image = nullptr; + } } } delete chunks; } +size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) { + return image_tokens->n_tokens(); +} + +size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) { + return image_tokens->nx; +} + +size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) { + 
return image_tokens->ny; +} + int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) { int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip); ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); @@ -289,7 +310,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx, LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); } - int32_t n_tokens = chunk.tokens_image->n_tokens(); + int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image); float * embd = mtmd_get_output_embd(ctx); decode_embd_batch batch_img(embd, n_tokens, n_past, 0); int64_t t1 = ggml_time_ms(); @@ -339,3 +360,11 @@ int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & outp std::memcpy(output.data.data(), data, output.nx * output.ny * 3); return 0; } + +bool mtmd_decode_use_non_causal(mtmd_context * ctx) { + projector_type proj_type = clip_get_projector_type(ctx->ctx_clip); + if (proj_type == PROJECTOR_TYPE_GEMMA3) { + return true; + } + return false; +} diff --git a/examples/llava/mtmd.h b/examples/llava/mtmd.h index 598f6947bb092..ca3fb6fdc7960 100644 --- a/examples/llava/mtmd.h +++ b/examples/llava/mtmd.h @@ -81,13 +81,20 @@ MTMD_API void mtmd_free(mtmd_context * ctx); // 2. (image tokens) // 3. "\ndescribe it in detail." // number of bitmaps must be equal to the number of image markers in the prompt +// the returned value must be freed using mtmd_input_chunks_free() // this function is thread-safe (shared ctx) MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, const mtmd_input_text & text, const std::vector & bitmaps); -// free image chunk data -MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks); +// if free_images = true, free the image tokens ; otherwise, you must free them using mtmd_image_free() +MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks, bool free_images); + +// access mtmd_image_tokens +MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens); +MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens); // returns 0 on success MTMD_API int32_t mtmd_encode(mtmd_context * ctx, @@ -96,6 +103,11 @@ MTMD_API int32_t mtmd_encode(mtmd_context * ctx, // get output embeddings from the last encode pass MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx); +// whether we need to set non-causal mask before llama_decode +MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx); + + + // // helper functions (can be implemented based on other functions) // @@ -133,10 +145,15 @@ struct mtmd_context_deleter { using mtmd_context_ptr = std::unique_ptr; struct mtmd_input_chunks_deleter { - void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); } + void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val, true); } }; using mtmd_input_chunks_ptr = std::unique_ptr; +struct mtmd_image_tokens_deleter { + void operator()(mtmd_image_tokens * val) { mtmd_image_tokens_free(val); } +}; +using mtmd_image_tokens_ptr = std::unique_ptr; + #else static_assert(false && "C header is not yet supported by this library"); From 7ac0b7b7b0433eacd8c9cabf3734f092637e6212 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 11 Apr 2025 22:17:47 +0200 Subject: [PATCH 04/18] mtmd : ability to calc image hash --- examples/llava/gemma3-cli.cpp | 1 + examples/llava/mtmd.cpp | 29 
++++++++++++++++++++++++++++- examples/llava/mtmd.h | 12 ++++++++---- 3 files changed, 37 insertions(+), 5 deletions(-) diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp index 91a07e2a8f40d..b200d8f111918 100644 --- a/examples/llava/gemma3-cli.cpp +++ b/examples/llava/gemma3-cli.cpp @@ -89,6 +89,7 @@ struct gemma3_context { ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{ /* use_gpu */ true, /* timings */ true, + /* hash */ false, /* n_threads */ params.cpuparams.n_threads, /* verbosity */ GGML_LOG_LEVEL_INFO, })); diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp index 98d660a643809..1691a71bf27fc 100644 --- a/examples/llava/mtmd.cpp +++ b/examples/llava/mtmd.cpp @@ -16,15 +16,22 @@ struct mtmd_context { struct clip_ctx * ctx_clip; const struct llama_model * text_model; std::vector image_embd_v; // image embedding vector + bool print_timings; int n_threads; std::string image_marker; + bool calc_image_hash; // TODO @ngxson : add timings mtmd_context(const char * mmproj_fname, const llama_model * text_model, - const mtmd_context_params & ctx_params) : print_timings(ctx_params.print_timings), n_threads(ctx_params.n_threads), image_marker(ctx_params.image_marker) { + const mtmd_context_params & ctx_params) : + print_timings (ctx_params.print_timings), + n_threads (ctx_params.n_threads), + image_marker (ctx_params.image_marker), + calc_image_hash(ctx_params.calc_image_hash) + { clip_context_params ctx_clip_params; ctx_clip_params.use_gpu = ctx_params.use_gpu; ctx_clip_params.verbosity = ctx_params.verbosity; @@ -49,6 +56,7 @@ struct mtmd_image_tokens { uint32_t ny; // number of tokens in y direction uint32_t n_tokens() const { return nx * ny; } clip_image_f32_batch batch_f32; // preprocessed image patches + size_t image_hash = 0; // hash of the image, useful for KV cache tracking }; mtmd_context * mtmd_init_from_file(const char * mmproj_fname, @@ -88,6 +96,16 @@ static std::vector mtmd_tokenize_text_internal( return result; } +static uint64_t hash_vector_float(const std::vector & vec) { + uint64_t seed = vec.size(); + std::hash hasher; + for (float val : vec) { + // inspired by boost::hash_combine + seed ^= hasher(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + return seed; +} + mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, const mtmd_input_text & text, const std::vector & bitmaps) { @@ -153,6 +171,11 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, image_tokens->ny = 1; // TODO image_tokens->batch_f32 = std::move(batch_f32); + // optionally calculate the hash + if (ctx->calc_image_hash) { + image_tokens->image_hash = hash_vector_float(image_tokens->batch_f32.entries[0]->buf); + } + mtmd_input_chunk chunk{ MTMD_INPUT_CHUNK_TYPE_IMAGE, {}, @@ -196,6 +219,10 @@ size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) { return image_tokens->ny; } +uint64_t mtmd_image_tokens_get_hash(const mtmd_image_tokens * image_tokens) { + return image_tokens->image_hash; +} + int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) { int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip); ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); diff --git a/examples/llava/mtmd.h b/examples/llava/mtmd.h index ca3fb6fdc7960..cadcfa16fdceb 100644 --- a/examples/llava/mtmd.h +++ b/examples/llava/mtmd.h @@ -52,6 +52,9 @@ using mtmd_input_chunks = std::vector; struct mtmd_context_params { bool use_gpu = true; bool print_timings = true; + // calc_image_hash is useful for tracking KV 
cache + // if not set, mtmd_image_tokens_get_hash will return 0 + bool calc_image_hash = false; int n_threads = 4; enum ggml_log_level verbosity = GGML_LOG_LEVEL_INFO; const char * image_marker = "<__image__>"; @@ -91,10 +94,11 @@ MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks, bool free_images); // access mtmd_image_tokens -MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); -MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens); -MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens); -MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens); +MTMD_API uint64_t mtmd_image_tokens_get_hash(const mtmd_image_tokens * image_tokens); +MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens); // returns 0 on success MTMD_API int32_t mtmd_encode(mtmd_context * ctx, From 58c47674aac9704cfbc2f8e44ebbbe318edc432e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 12 Apr 2025 10:34:12 +0200 Subject: [PATCH 05/18] shared_ptr for mtmd_image_tokens --- examples/llava/gemma3-cli.cpp | 11 +++---- examples/llava/mtmd.cpp | 56 +++++++++++++++-------------------- examples/llava/mtmd.h | 32 +++++++++----------- 3 files changed, 44 insertions(+), 55 deletions(-) diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp index b200d8f111918..34296c87132b0 100644 --- a/examples/llava/gemma3-cli.cpp +++ b/examples/llava/gemma3-cli.cpp @@ -185,18 +185,19 @@ static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector text.text = formatted_chat.prompt; text.add_special = add_bos; text.parse_special = true; - mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx.ctx_vision.get(), text, bitmaps)); - if (chunks == nullptr) { - LOG_ERR("Unable to tokenize prompt\n"); + mtmd_input_chunks chunks; + int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps); + if (res != 0) { + LOG_ERR("Unable to tokenize prompt, res = %d\n", res); return 1; } - if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks.get(), ctx.n_past, 0, ctx.n_batch)) { + if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) { LOG_ERR("Unable to eval prompt\n"); return 1; } - ctx.n_past += mtmd_helper_get_n_tokens(chunks.get()); + ctx.n_past += mtmd_helper_get_n_tokens(chunks); return 0; } diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp index 1691a71bf27fc..44e48c7270368 100644 --- a/examples/llava/mtmd.cpp +++ b/examples/llava/mtmd.cpp @@ -106,10 +106,10 @@ static uint64_t hash_vector_float(const std::vector & vec) { return seed; } -mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, - const mtmd_input_text & text, - const std::vector & bitmaps) { - mtmd_input_chunks * output = new mtmd_input_chunks; +int32_t mtmd_tokenize(mtmd_context * ctx, + std::vector & output, + const mtmd_input_text & text, + const std::vector & bitmaps) { auto vocab = llama_model_get_vocab(ctx->text_model); std::string prompt_modified(text.text); @@ -124,8 +124,8 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, } std::vector parts = string_split_str(text.text, ctx->image_marker); - output->clear(); - 
output->reserve(parts.size()); + output.clear(); + output.reserve(parts.size()); size_t i_img = 0; @@ -141,14 +141,14 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, std::move(tokens), {}, }; - output->emplace_back(std::move(chunk)); + output.emplace_back(std::move(chunk)); if (&parts.back() != &part) { // add image token to middle of 2 parts if (i_img >= bitmaps.size()) { LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size()); - return nullptr; + return 1; } // shim layer @@ -163,10 +163,10 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32); if (!ok) { LOG_ERR("Unable to preprocess image\n"); - return nullptr; + return 2; } - mtmd_image_tokens * image_tokens = new mtmd_image_tokens; + mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens); image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image image_tokens->ny = 1; // TODO image_tokens->batch_f32 = std::move(batch_f32); @@ -179,14 +179,14 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, mtmd_input_chunk chunk{ MTMD_INPUT_CHUNK_TYPE_IMAGE, {}, - image_tokens, + std::move(image_tokens), }; - output->emplace_back(std::move(chunk)); + output.emplace_back(std::move(chunk)); i_img++; } } - return output; + return 0; } void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) { @@ -195,18 +195,6 @@ void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) { } } -void mtmd_input_chunks_free(mtmd_input_chunks * chunks, bool free_images) { - if (free_images) { - for (auto & chunk : *chunks) { - if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) { - mtmd_image_tokens_free(chunk.tokens_image); - chunk.tokens_image = nullptr; - } - } - } - delete chunks; -} - size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) { return image_tokens->n_tokens(); } @@ -238,9 +226,9 @@ float * mtmd_get_output_embd(mtmd_context * ctx) { return ctx->image_embd_v.data(); } -size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks) { +size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) { size_t n_tokens = 0; - for (auto & chunk : *chunks) { + for (auto & chunk : chunks) { if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { n_tokens += chunk.tokens_text.size(); } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { @@ -289,7 +277,7 @@ struct decode_embd_batch { int32_t mtmd_helper_eval(mtmd_context * ctx, llama_context * lctx, - mtmd_input_chunks * chunks, + mtmd_input_chunks & chunks, llama_pos pos0, llama_seq_id seq_id, int32_t n_batch) { @@ -297,8 +285,8 @@ int32_t mtmd_helper_eval(mtmd_context * ctx, llama_pos n_past = pos0; llama_batch text_batch = llama_batch_init(n_batch, 0, 1); - for (auto & chunk : *chunks) { - bool is_last = &chunk == &chunks->back(); + for (auto & chunk : chunks) { + bool is_last = &chunk == &chunks.back(); if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { // TODO @ngxson : may need to split into smaller batches text_batch.n_tokens = chunk.tokens_text.size(); @@ -327,7 +315,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx, if (ctx->print_timings) { LOG_INF("encoding image...\n"); } - ret = mtmd_encode(ctx, chunk.tokens_image); + ret = mtmd_encode(ctx, chunk.tokens_image.get()); if (ret != 0) { LOG_ERR("failed to encode image\n"); llama_batch_free(text_batch); @@ -337,7 +325,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx, LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); } - int32_t n_tokens = 
mtmd_image_tokens_get_n_tokens(chunk.tokens_image); + int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get()); float * embd = mtmd_get_output_embd(ctx); decode_embd_batch batch_img(embd, n_tokens, n_past, 0); int64_t t1 = ggml_time_ms(); @@ -395,3 +383,7 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx) { } return false; } + +void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) { + mtmd_image_tokens_free(val); +} diff --git a/examples/llava/mtmd.h b/examples/llava/mtmd.h index cadcfa16fdceb..f07814a56208c 100644 --- a/examples/llava/mtmd.h +++ b/examples/llava/mtmd.h @@ -41,10 +41,15 @@ struct mtmd_bitmap { std::vector data; }; +struct mtmd_image_tokens_deleter { + void operator()(mtmd_image_tokens * val); // forward declaration +}; +using mtmd_image_tokens_ptr = std::unique_ptr; + struct mtmd_input_chunk { mtmd_input_chunk_type type; std::vector tokens_text; - mtmd_image_tokens * tokens_image = nullptr; + mtmd_image_tokens_ptr tokens_image; }; using mtmd_input_chunks = std::vector; @@ -84,15 +89,16 @@ MTMD_API void mtmd_free(mtmd_context * ctx); // 2. (image tokens) // 3. "\ndescribe it in detail." // number of bitmaps must be equal to the number of image markers in the prompt -// the returned value must be freed using mtmd_input_chunks_free() // this function is thread-safe (shared ctx) -MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, +// return values: +// 0 on success +// 1 on number of images not matching the number of markers +// 2 on image preprocessing error +MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx, + std::vector & output, const mtmd_input_text & text, const std::vector & bitmaps); -// if free_images = true, free the image tokens ; otherwise, you must free them using mtmd_image_free() -MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks, bool free_images); - // access mtmd_image_tokens MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens); @@ -117,7 +123,7 @@ MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx); // // helper to count the total number of tokens from a list of chunks, useful to keep track of n_past -MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks); +MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks); // helper function that automatically: // 1. 
run llama_decode() on text chunks @@ -126,7 +132,7 @@ MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks); // otherwise, returns 0 on success MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx, llama_context * lctx, - mtmd_input_chunks * chunks, + mtmd_input_chunks & chunks, llama_pos pos0, llama_seq_id seq_id, int32_t n_batch); @@ -148,16 +154,6 @@ struct mtmd_context_deleter { }; using mtmd_context_ptr = std::unique_ptr; -struct mtmd_input_chunks_deleter { - void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val, true); } -}; -using mtmd_input_chunks_ptr = std::unique_ptr; - -struct mtmd_image_tokens_deleter { - void operator()(mtmd_image_tokens * val) { mtmd_image_tokens_free(val); } -}; -using mtmd_image_tokens_ptr = std::unique_ptr; - #else static_assert(false && "C header is not yet supported by this library"); From d3c3e20c424b02fedbef8d2fdddd0061c6255348 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 12 Apr 2025 11:03:38 +0200 Subject: [PATCH 06/18] move hash to user-define ID (fixed) --- examples/llava/gemma3-cli.cpp | 1 - examples/llava/mtmd.cpp | 25 +++++-------------------- examples/llava/mtmd.h | 14 ++++++-------- 3 files changed, 11 insertions(+), 29 deletions(-) diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp index 34296c87132b0..de206c85ae80c 100644 --- a/examples/llava/gemma3-cli.cpp +++ b/examples/llava/gemma3-cli.cpp @@ -89,7 +89,6 @@ struct gemma3_context { ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{ /* use_gpu */ true, /* timings */ true, - /* hash */ false, /* n_threads */ params.cpuparams.n_threads, /* verbosity */ GGML_LOG_LEVEL_INFO, })); diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp index 44e48c7270368..0898439d11d48 100644 --- a/examples/llava/mtmd.cpp +++ b/examples/llava/mtmd.cpp @@ -29,8 +29,7 @@ struct mtmd_context { const mtmd_context_params & ctx_params) : print_timings (ctx_params.print_timings), n_threads (ctx_params.n_threads), - image_marker (ctx_params.image_marker), - calc_image_hash(ctx_params.calc_image_hash) + image_marker (ctx_params.image_marker) { clip_context_params ctx_clip_params; ctx_clip_params.use_gpu = ctx_params.use_gpu; @@ -56,7 +55,7 @@ struct mtmd_image_tokens { uint32_t ny; // number of tokens in y direction uint32_t n_tokens() const { return nx * ny; } clip_image_f32_batch batch_f32; // preprocessed image patches - size_t image_hash = 0; // hash of the image, useful for KV cache tracking + std::string id; // optional user-defined ID, useful for KV cache tracking }; mtmd_context * mtmd_init_from_file(const char * mmproj_fname, @@ -96,16 +95,6 @@ static std::vector mtmd_tokenize_text_internal( return result; } -static uint64_t hash_vector_float(const std::vector & vec) { - uint64_t seed = vec.size(); - std::hash hasher; - for (float val : vec) { - // inspired by boost::hash_combine - seed ^= hasher(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2); - } - return seed; -} - int32_t mtmd_tokenize(mtmd_context * ctx, std::vector & output, const mtmd_input_text & text, @@ -170,11 +159,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image image_tokens->ny = 1; // TODO image_tokens->batch_f32 = std::move(batch_f32); - - // optionally calculate the hash - if (ctx->calc_image_hash) { - image_tokens->image_hash = hash_vector_float(image_tokens->batch_f32.entries[0]->buf); - } + image_tokens->id = bitmaps[i_img].id; // optional 
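For review purposes, a minimal usage sketch of the reworked tokenizer API above. It is not part of the patch: the wrapper name tokenize_and_eval and its arguments are invented for illustration, while the mtmd_* calls, the 0/1/2 return convention and the default "<__image__>" marker follow the declarations in mtmd.h as modified by this series.

#include "llama.h"
#include "mtmd.h"

#include <string>
#include <vector>

// hypothetical caller, shown only to illustrate the call sequence and error codes
static int32_t tokenize_and_eval(mtmd_context * mctx, llama_context * lctx,
                                 const std::string & prompt,
                                 const std::vector<mtmd_bitmap> & bitmaps,
                                 int32_t n_batch) {
    mtmd_input_text text = {
        prompt,                  // may contain one "<__image__>" marker per bitmap
        /* add_special   */ true,
        /* parse_special */ true,
    };

    mtmd_input_chunks chunks;
    int32_t res = mtmd_tokenize(mctx, chunks, text, bitmaps);
    if (res == 1) {
        return res; // number of bitmaps does not match the number of image markers
    }
    if (res == 2) {
        return res; // image preprocessing failed
    }

    // the chunks own their image tokens through mtmd_image_tokens_ptr, so no explicit free is needed
    return mtmd_helper_eval(mctx, lctx, chunks, /* pos0 */ 0, /* seq_id */ 0, n_batch);
}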
mtmd_input_chunk chunk{ MTMD_INPUT_CHUNK_TYPE_IMAGE, @@ -207,8 +192,8 @@ size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) { return image_tokens->ny; } -uint64_t mtmd_image_tokens_get_hash(const mtmd_image_tokens * image_tokens) { - return image_tokens->image_hash; +std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) { + return image_tokens->id; } int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) { diff --git a/examples/llava/mtmd.h b/examples/llava/mtmd.h index f07814a56208c..78be192dd6eb6 100644 --- a/examples/llava/mtmd.h +++ b/examples/llava/mtmd.h @@ -39,6 +39,7 @@ struct mtmd_bitmap { uint32_t nx; uint32_t ny; std::vector data; + std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking }; struct mtmd_image_tokens_deleter { @@ -57,9 +58,6 @@ using mtmd_input_chunks = std::vector; struct mtmd_context_params { bool use_gpu = true; bool print_timings = true; - // calc_image_hash is useful for tracking KV cache - // if not set, mtmd_image_tokens_get_hash will return 0 - bool calc_image_hash = false; int n_threads = 4; enum ggml_log_level verbosity = GGML_LOG_LEVEL_INFO; const char * image_marker = "<__image__>"; @@ -100,11 +98,11 @@ MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx, const std::vector & bitmaps); // access mtmd_image_tokens -MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); -MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens); -MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens); -MTMD_API uint64_t mtmd_image_tokens_get_hash(const mtmd_image_tokens * image_tokens); -MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens); +MTMD_API std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens); +MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens); // returns 0 on success MTMD_API int32_t mtmd_encode(mtmd_context * ctx, From 5e6c7ba4a8f765639dc947afac94bac629fec6cd Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 13 Apr 2025 23:38:32 +0200 Subject: [PATCH 07/18] abstract out the batch management --- examples/llava/mtmd.cpp | 14 +---- examples/server/server.cpp | 111 +++++++++++++++++++++---------------- examples/server/utils.hpp | 110 ++++++++++++++++++++++++++---------- 3 files changed, 147 insertions(+), 88 deletions(-) diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp index 687abfbc472ee..fe6d769095011 100644 --- a/examples/llava/mtmd.cpp +++ b/examples/llava/mtmd.cpp @@ -112,7 +112,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, string_replace_all(prompt_modified, ctx->image_marker, marker_modified); } - std::vector parts = string_split_str(text.text, ctx->image_marker); + std::vector parts = string_split_str(prompt_modified, ctx->image_marker); output.clear(); output.reserve(parts.size()); @@ -196,18 +196,6 @@ std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) { return image_tokens->id; } -size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) { - return image_tokens->n_tokens(); -} - -size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) { - return 
image_tokens->nx; -} - -size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) { - return image_tokens->ny; -} - int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) { int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip); ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 17b0ccfa108e1..2c4b0b876d576 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1859,7 +1859,7 @@ struct server_context { llama_context_params cparams_dft; - llama_batch batch = {}; + server_batch batch; bool clean_kv_cache = true; bool add_bos_token = true; @@ -1897,8 +1897,6 @@ struct server_context { llama_batch_free(slot.batch_spec); } - - llama_batch_free(batch); } bool load_model(const common_params & params) { @@ -2035,9 +2033,7 @@ struct server_context { // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) { const int32_t n_batch = llama_n_batch(ctx); - - // only a single seq_id per token is needed - batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); + batch = server_batch(std::max(n_batch, params_base.n_parallel)); } metrics.init(); @@ -2934,7 +2930,7 @@ struct server_context { }*/ // start populating the batch for this iteration - common_batch_clear(batch); + batch.clear(); // track if given slot can be batched with slots already in the batch server_slot * slot_batched = nullptr; @@ -2956,9 +2952,9 @@ struct server_context { continue; } - slot.i_batch = batch.n_tokens; + slot.i_batch = batch.n_tokens(); - common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); + common_batch_add(batch.batch, slot.sampled, slot.n_past, { slot.id }, true); slot.n_past += 1; @@ -2974,12 +2970,8 @@ struct server_context { int32_t n_batch = llama_n_batch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx); - // for multimodal - bool is_decoding_embd = false; - server_embd_batch batch_embd; - // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || batch.n_tokens == 0) { + if (params_base.cont_batching || batch.n_tokens() == 0) { for (auto & slot : slots) { // check if we can batch this slot with the previous one if (slot.is_processing()) { @@ -3147,7 +3139,7 @@ struct server_context { // non-causal tasks require to fit the entire prompt in the physical batch if (slot.is_non_causal()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { + if (batch.n_tokens() + slot.n_prompt_tokens > n_batch) { continue; } } @@ -3167,36 +3159,55 @@ struct server_context { slot.cache_tokens.keep_until(slot.n_past); // add prompt tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens() < n_batch) { // without pooling, we want to output the embeddings for all the tokens in the batch const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past); if (curr_chunk.tok_image) { - // decode image - server_encode_image(slot.mctx, batch_embd, curr_chunk, slot.n_past, slot.id); - is_decoding_embd = true; - SLT_INF(slot, "decoding image, n_past = %d, n_tokens = %d\n", slot.n_past, batch_embd.batch.n_tokens); - slot.n_past += batch_embd.batch.n_tokens; 
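As an aside on the flow above: server_encode_image and decode_embd_batch both boil down to running the image encoder once and then feeding the resulting float embeddings to llama_decode through a llama_batch whose embd pointer is set instead of token. A condensed sketch, not part of the patch; the helper name decode_image_embd is invented, error handling is minimal, and unlike the real code the batch is not split to respect n_batch, so it assumes the whole image fits in a single batch:

#include "llama.h"
#include "mtmd.h"

#include <cstdint>
#include <vector>

static int32_t decode_image_embd(llama_context * lctx, mtmd_context * mctx,
                                 const mtmd_image_tokens * img,
                                 llama_pos pos0, llama_seq_id seq_id) {
    if (mtmd_encode(mctx, img) != 0) {
        return -1; // encoder failed
    }
    const int32_t n_tokens = (int32_t) mtmd_image_tokens_get_n_tokens(img);
    float * embd = mtmd_get_output_embd(mctx); // n_tokens embedding vectors, one per image token

    std::vector<llama_pos>      pos(n_tokens);
    std::vector<int32_t>        n_seq_id(n_tokens, 1);
    std::vector<llama_seq_id>   seq(n_tokens, seq_id);
    std::vector<llama_seq_id *> seq_ptr(n_tokens);
    std::vector<int8_t>         logits(n_tokens, 0); // no logits needed for image positions
    for (int32_t i = 0; i < n_tokens; i++) {
        pos[i]     = pos0 + i;
        seq_ptr[i] = &seq[i];
    }

    llama_batch batch = {
        /* n_tokens */ n_tokens,
        /* token    */ nullptr,
        /* embd     */ embd,
        /* pos      */ pos.data(),
        /* n_seq_id */ n_seq_id.data(),
        /* seq_id   */ seq_ptr.data(),
        /* logits   */ logits.data(),
    };

    const bool non_causal = mtmd_decode_use_non_causal(mctx);
    if (non_causal) {
        llama_set_causal_attn(lctx, false); // vision embeddings attend bidirectionally
    }
    const int32_t ret = llama_decode(lctx, batch);
    if (non_causal) {
        llama_set_causal_attn(lctx, true);
    }
    return ret;
}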
- break; // do not process any other slots + // if there are already TEXT tokens in the batch, we need to process them first + if (batch.batch.n_tokens > 0) { + break; + } + // encode the image + server_encode_image(slot.mctx, batch, curr_chunk, slot.n_past, slot.id); + GGML_ASSERT(batch.has_embd()); + SLT_INF(slot, "image encoded, n_past = %d, n_embd_tokens = %d\n", slot.n_past, batch.n_tokens()); + + if (slot.params.cache_prompt) { + slot.cache_tokens.add_image_tokens(curr_chunk.tok_image); + } + + slot.n_past += batch.n_tokens(); + slot.n_prompt_tokens_processed += batch.n_tokens(); + break; // we cannot have both text batch and image batch + } else { - common_batch_add(batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd); + GGML_ASSERT(!batch.has_embd()); + common_batch_add(batch.batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd); if (slot.params.cache_prompt) { slot.cache_tokens.add_text_token(curr_chunk.tok_text); } + + slot.n_prompt_tokens_processed++; + slot.n_past++; } + } + + SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); - slot.n_prompt_tokens_processed++; - slot.n_past++; + if (batch.has_embd()) { + // currently, we can only process one image at a time, so we skip other slots + break; } - SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); // entire prompt has been processed if (slot.n_past == slot.n_prompt_tokens) { slot.state = SLOT_STATE_DONE_PROMPT; - GGML_ASSERT(batch.n_tokens > 0); + GGML_ASSERT(batch.n_tokens() > 0); common_sampler_reset(slot.smpl); @@ -3209,27 +3220,32 @@ struct server_context { } // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; + batch.logits[batch.n_tokens() - 1] = true; slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; + slot.i_batch = batch.n_tokens() - 1; - SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); + SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens()); } } - if (batch.n_tokens >= n_batch) { + if (batch.n_tokens() >= n_batch) { break; } } } - if (batch.n_tokens == 0) { + if (batch.n_tokens() == 0) { SRV_WRN("%s", "no tokens to decode\n"); return; } - SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + // debug + if (batch.has_embd()) { + SRV_INF("decoding embd batch, n_tokens = %d\n", batch.n_tokens()); + } else { + SRV_INF("decoding batch, n_tokens = %d\n", batch.n_tokens()); + } if (slot_batched) { // make sure we're in the right embedding mode @@ -3239,28 +3255,29 @@ struct server_context { } // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); + for (int32_t i = 0; i < batch.n_tokens(); i += n_batch) { + const int32_t n_tokens = std::min(n_batch, batch.n_tokens() - i); - llama_batch batch_view = is_decoding_embd ? batch_embd.batch : llama_batch{ + // TODO @ngxson : hacky here, we don't want to split the embd batch + llama_batch batch_view = batch.has_embd() ? 
batch.batch : llama_batch{ n_tokens, - batch.token + i, + batch.batch.token + i, nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, + batch.batch.pos + i, + batch.batch.n_seq_id + i, + batch.batch.seq_id + i, + batch.batch.logits + i, }; // TODO @ngxson : maybe move this to llama_batch_ext - if (is_decoding_embd && mtmd_decode_use_non_causal(mctx)) { + if (batch.has_embd() && mtmd_decode_use_non_causal(mctx)) { llama_set_causal_attn(ctx, false); } const int ret = llama_decode(ctx, batch_view); metrics.on_decoded(slots); - if (is_decoding_embd && mtmd_decode_use_non_causal(mctx)) { + if (batch.has_embd() && mtmd_decode_use_non_causal(mctx)) { llama_set_causal_attn(ctx, true); } @@ -4006,13 +4023,13 @@ int main(int argc, char ** argv) { /* add_special */ true, /* parse_special */ true, }; - mtmd_input_chunks * tokenized = mtmd_tokenize(ctx_server.mctx, inp_txt, bitmaps); - if (!tokenized) { + mtmd_input_chunks chunks; + int32_t tokenized = mtmd_tokenize(ctx_server.mctx, chunks, inp_txt, bitmaps); + if (tokenized != 0) { throw std::runtime_error("Failed to tokenize prompt"); } - server_inputs tmp(tokenized); + server_inputs tmp(chunks); inputs.push_back(std::move(tmp)); - mtmd_input_chunks_free(tokenized, false); // only free the container, not the images } } else { // non-multimodal version diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 5103e22e163dd..3bc0d0da17ec3 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -964,18 +964,26 @@ static std::vector parse_lora_request( // struct server_inp_chunk { + size_t n_tokens = 1; // always 1 in case of text llama_token tok_text; mtmd_image_tokens_ptr tok_image; - std::string str() { + std::string str() const { // for debugging if (tok_image) { - return " "; + return string_format("( at %p) ", (void *)tok_image.get()); } else { return std::to_string(tok_text) + " "; } } }; +/** + * server_inputs is a helper to manage the input tokens and image for the server. + * + * the difference between server_inputs and mtmd_input_chunks is that each chunk of server_inputs only contains a single text token, but text chunk of mtmd_input_chunks can contain multiple tokens. + * + * it is made this way to simplify the logic of KV cache management. 
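One practical consequence of this one-token-per-chunk layout, worth keeping in mind when reading the KV cache handling in server.cpp: prompt caching becomes a chunk-by-chunk comparison. get_common_prefix() below counts how many leading chunks of a new prompt match the cached ones, and the server reuses exactly that many KV cache positions (slot.n_past) instead of re-evaluating them; at this point in the series an image chunk never extends the match, since the hash/ID-based comparison for images is only added in a later commit.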
+ */ struct server_inputs { std::vector chunks; @@ -990,13 +998,14 @@ struct server_inputs { server_inputs(server_inputs&&) = default; server_inputs& operator=(server_inputs&&) = default; - server_inputs(mtmd_input_chunks * mtmd_chunks) { - for (auto & c : *mtmd_chunks) { + server_inputs(mtmd_input_chunks & mtmd_chunks) { + for (auto & c : mtmd_chunks) { if (c.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { - chunks.push_back({LLAMA_TOKEN_NULL, mtmd_image_tokens_ptr(c.tokens_image)}); + size_t n_tokens = mtmd_image_tokens_get_n_tokens(c.tokens_image.get()); + chunks.push_back({n_tokens, LLAMA_TOKEN_NULL, std::move(c.tokens_image)}); } else if (c.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { for (auto & tok : c.tokens_text) { - chunks.push_back({tok, nullptr}); + chunks.push_back({1, tok, nullptr}); } } else { GGML_ASSERT(false && "Invalid chunk type"); @@ -1004,11 +1013,20 @@ struct server_inputs { } } + std::string str() { + // for debugging + std::string ret; + for (const auto & chunk : chunks) { + ret += chunk.str(); + } + return ret; + } + size_t n_tokens() const { size_t res = 0; for (const auto & chunk : chunks) { if (chunk.tok_image) { - res += mtmd_image_tokens_get_n_tokens(chunk.tok_image.get()); + res += chunk.n_tokens; } else { res++; } @@ -1026,7 +1044,13 @@ struct server_inputs { void add_text_token(llama_token tok) { GGML_ASSERT(tok != LLAMA_TOKEN_NULL); - chunks.push_back({tok, nullptr}); + chunks.push_back({1, tok, nullptr}); + } + + void add_image_tokens(mtmd_image_tokens_ptr & image) { + GGML_ASSERT(image != nullptr); + size_t n_tokens = mtmd_image_tokens_get_n_tokens(image.get()); + chunks.push_back({n_tokens, LLAMA_TOKEN_NULL, std::move(image)}); } size_t get_common_prefix(const server_inputs & b) const { @@ -1068,8 +1092,7 @@ struct server_inputs { size_t current_pos = 0; for (size_t i = 0; i < chunks.size(); ++i) { const auto & chunk = chunks[i]; - size_t chunk_size = chunk.tok_image ? 
mtmd_image_tokens_get_n_tokens(chunk.tok_image.get()) : 1; - size_t chunk_end_pos = current_pos + chunk_size; + size_t chunk_end_pos = current_pos + chunk.n_tokens; if (pos < chunk_end_pos) { // The target position 'pos' falls within this chunk return i; @@ -1123,57 +1146,88 @@ struct server_inputs { // helper struct to make working with embd batch easier // note: this will be removed after llama_batch_ext refactoring -struct server_embd_batch { +struct server_batch { std::vector pos; + std::vector token; std::vector n_seq_id; - std::vector seq_id_0; + std::vector seq_id; std::vector seq_ids; std::vector logits; + llama_batch batch; - server_embd_batch() = default; - server_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { + + server_batch() : server_batch(1) {} + server_batch(int32_t n_tokens) { + token .resize(n_tokens); pos .resize(n_tokens); n_seq_id.resize(n_tokens); - seq_ids .resize(n_tokens + 1); logits .resize(n_tokens); - seq_id_0.resize(1); - seq_id_0[0] = seq_id; - seq_ids [n_tokens] = nullptr; + seq_id .resize(n_tokens); + seq_ids .resize(n_tokens + 1); + seq_ids[n_tokens] = nullptr; + batch = { - /*n_tokens =*/ n_tokens, - /*tokens =*/ nullptr, - /*embd =*/ embd, + /*n_tokens =*/ 0, + /*tokens =*/ token.data(), + /*embd =*/ nullptr, /*pos =*/ pos.data(), /*n_seq_id =*/ n_seq_id.data(), /*seq_id =*/ seq_ids.data(), /*logits =*/ logits.data(), }; + + for (int i = 0; i < n_tokens; i++) { + batch.n_seq_id[i] = 1; // only a single seq_id per token is needed + batch.seq_id [i] = seq_id.data() + i; + } + } + + void reserve_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { + GGML_ASSERT(n_tokens <= (int32_t)pos.size()); + seq_ids[n_tokens] = nullptr; + batch.n_tokens = n_tokens; + batch.embd = embd; + batch.token = nullptr; for (int i = 0; i < n_tokens; i++) { - batch.pos [i] = pos_0 + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; + batch.pos [i] = pos_0 + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i][0] = seq_id; + batch.logits [i] = false; } } + + void clear() { + batch.n_tokens = 0; + batch.embd = nullptr; + batch.token = token.data(); + } + + int32_t n_tokens() const { + return batch.n_tokens; + } + + bool has_embd() const { + return batch.embd != nullptr; + } }; // TODO @ngxson : quite hacky for now, but just to see if it works -static int32_t server_encode_image(mtmd_context * mctx, server_embd_batch & batch_out, server_inp_chunk & chunk, llama_pos n_past, llama_seq_id seq_id) { +static int32_t server_encode_image(mtmd_context * mctx, server_batch & batch_out, server_inp_chunk & chunk, llama_pos n_past, llama_seq_id seq_id) { GGML_ASSERT(chunk.tok_image); + batch_out.clear(); int64_t t0 = ggml_time_ms(); LOG_INF("encoding image...\n"); int32_t ret = mtmd_encode(mctx, chunk.tok_image.get()); if (ret != 0) { LOG_ERR("failed to encode image\n"); - batch_out = server_embd_batch{}; return ret; } LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tok_image.get()); float * embd = mtmd_get_output_embd(mctx); - batch_out = server_embd_batch(embd, n_tokens, n_past, seq_id); + batch_out.reserve_embd_batch(embd, n_tokens, n_past, seq_id); return ret; } From a6a36537d2018417b51a83f767719ddd7abe8672 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 21 Apr 2025 22:41:04 +0200 Subject: [PATCH 08/18] small fix --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/examples/server/server.cpp b/examples/server/server.cpp index a356e8180fcbb..fbabd2872b0c9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1670,7 +1670,7 @@ struct server_queue { lock.unlock(); QUE_DBG("processing task, id = %d\n", task.id); - callback_new_task(task); + callback_new_task(std::move(task)); } // all tasks in the current loop is processed, slots data is now ready From f8bc46629fa4e669697106f251dd9e7d957ff218 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 21 Apr 2025 23:18:44 +0200 Subject: [PATCH 09/18] refactor logic adding tokens to batch --- examples/server/server.cpp | 70 ++++++++++++++++++++------------------ examples/server/utils.hpp | 13 ++++--- 2 files changed, 44 insertions(+), 39 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index fbabd2872b0c9..e9b1de10cd1a5 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2859,7 +2859,7 @@ struct server_context { res->id = task.id; queue_results.send(std::move(res)); } break; - + } } @@ -3159,49 +3159,51 @@ struct server_context { // remove the non-common part from the cache slot.cache_tokens.keep_until(slot.n_past); - // add prompt tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens() < n_batch) { - // without pooling, we want to output the embeddings for all the tokens in the batch - const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; + auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past); - auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past); - if (curr_chunk.tok_image) { - // if there are already TEXT tokens in the batch, we need to process them first - if (batch.batch.n_tokens > 0) { - break; - } - // encode the image - server_encode_image(slot.mctx, batch, curr_chunk, slot.n_past, slot.id); - GGML_ASSERT(batch.has_embd()); - SLT_INF(slot, "image encoded, n_past = %d, n_embd_tokens = %d\n", slot.n_past, batch.n_tokens()); + // check if we should process the image + if (curr_chunk.tok_image) { + if (batch.has_text()) { + continue; // we cannot have both text batch and image batch + } - if (slot.params.cache_prompt) { - slot.cache_tokens.add_image_tokens(curr_chunk.tok_image); - } + // encode the image + server_encode_image(slot.mctx, batch, curr_chunk, slot.n_past, slot.id); + GGML_ASSERT(batch.has_embd()); + SLT_INF(slot, "image encoded, n_past = %d, n_embd_tokens = %d\n", slot.n_past, batch.n_tokens()); - slot.n_past += batch.n_tokens(); - slot.n_prompt_tokens_processed += batch.n_tokens(); - break; // we cannot have both text batch and image batch + if (slot.params.cache_prompt) { + slot.cache_tokens.add_image_tokens(curr_chunk.tok_image); + } - } else { - GGML_ASSERT(!batch.has_embd()); - common_batch_add(batch.batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd); - if (slot.params.cache_prompt) { - slot.cache_tokens.add_text_token(curr_chunk.tok_text); - } + slot.n_past += batch.n_tokens(); + slot.n_prompt_tokens_processed += batch.n_tokens(); - slot.n_prompt_tokens_processed++; - slot.n_past++; - } + break; // currently, we can only process one image at a time, so we skip ALL other slots } - SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); + // add prompt tokens for processing in the current batch + while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens() < n_batch) { + GGML_ASSERT(!batch.has_embd()); + auto & curr_chunk = 
slot.prompt_tokens.get_chunk(slot.n_past); + if (curr_chunk.tok_text == LLAMA_TOKEN_NULL) { + break; // end of text chunk + } - if (batch.has_embd()) { - // currently, we can only process one image at a time, so we skip other slots - break; + // without pooling, we want to output the embeddings for all the tokens in the batch + const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; + + common_batch_add(batch.batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd); + if (slot.params.cache_prompt) { + slot.cache_tokens.add_text_token(curr_chunk.tok_text); + } + + slot.n_prompt_tokens_processed++; + slot.n_past++; } + SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); // entire prompt has been processed diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 3bc0d0da17ec3..ce7e2780e3c16 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -668,7 +668,7 @@ static json oaicompat_completion_params_parse( p["text"] = "<__image__>"; p.erase("image_url"); } - } + } } common_chat_templates_inputs inputs; @@ -979,9 +979,9 @@ struct server_inp_chunk { /** * server_inputs is a helper to manage the input tokens and image for the server. - * + * * the difference between server_inputs and mtmd_input_chunks is that each chunk of server_inputs only contains a single text token, but text chunk of mtmd_input_chunks can contain multiple tokens. - * + * * it is made this way to simplify the logic of KV cache management. */ struct server_inputs { @@ -1184,7 +1184,6 @@ struct server_batch { void reserve_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { GGML_ASSERT(n_tokens <= (int32_t)pos.size()); - seq_ids[n_tokens] = nullptr; batch.n_tokens = n_tokens; batch.embd = embd; batch.token = nullptr; @@ -1207,7 +1206,11 @@ struct server_batch { } bool has_embd() const { - return batch.embd != nullptr; + return batch.embd != nullptr && batch.n_tokens > 0; + } + + bool has_text() const { + return batch.token != nullptr && batch.n_tokens > 0; } }; From f5420e1d90bf7228c12bb5f8cd85808c4cb00ba8 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 21 Apr 2025 23:35:20 +0200 Subject: [PATCH 10/18] implement hashing image --- examples/server/CMakeLists.txt | 3 ++- examples/server/server.cpp | 19 ++++++++++++------- examples/server/utils.hpp | 10 ++++++++-- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index 17109fddbd307..0ff77b0944881 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -35,8 +35,9 @@ add_executable(${TARGET} ${TARGET_SRCS}) install(TARGETS ${TARGET} RUNTIME) target_include_directories(${TARGET} PRIVATE ../llava) +target_include_directories(${TARGET} PRIVATE ../gguf-hash/deps/sha1) # TODO @ngxson : this is a hacky way to get this working, to be fixed before merging target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) -target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE common mtmd sha1 ${CMAKE_THREAD_LIBS_INIT}) if (LLAMA_SERVER_SSL) find_package(OpenSSL REQUIRED) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 
e9b1de10cd1a5..af9e1270d40b0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -8,6 +8,7 @@ #include "sampling.h" #include "speculative.h" #include "mtmd.h" +#include "sha1.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT @@ -3202,7 +3203,7 @@ struct server_context { slot.n_past++; } - SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); + // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); @@ -3244,11 +3245,7 @@ struct server_context { } // debug - if (batch.has_embd()) { - SRV_INF("decoding embd batch, n_tokens = %d\n", batch.n_tokens()); - } else { - SRV_INF("decoding batch, n_tokens = %d\n", batch.n_tokens()); - } + SRV_DBG("decoding %s batch, n_tokens = %d\n", batch.has_embd() ? "embd" : "text", batch.n_tokens()); if (slot_batched) { // make sure we're in the right embedding mode @@ -4036,6 +4033,14 @@ int main(int argc, char ** argv) { { for (auto & file : files) { mtmd_bitmap bmp; + // calculate hash (for KV caching) + { + SHA1_CTX sha1_ctx; + SHA1Update(&sha1_ctx, (unsigned char const *)file.data(), file.size()); + unsigned char result[21]; + SHA1Final(result, &sha1_ctx); + bmp.id = std::string((char *)result, 20); + } int32_t res = mtmd_helper_bitmap_init_from_buf(file.data(), file.size(), bmp); if (res != 0) { throw std::runtime_error("Failed to load image"); @@ -4049,7 +4054,7 @@ int main(int argc, char ** argv) { if (!prompt.is_string()) { throw std::runtime_error("prompt must be a string"); } else { - printf("prompt: %s\n", prompt.get().c_str()); + // SRV_INF("prompt: %s\n", prompt.get().c_str()); mtmd_input_text inp_txt = { prompt.get(), /* add_special */ true, diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index ce7e2780e3c16..d642d7831893c 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -1064,8 +1064,14 @@ struct server_inputs { ret++; continue; } else if (ai.tok_image && bi.tok_image) { - // TODO check image hash - break; + std::string ai_id = mtmd_image_tokens_get_id(ai.tok_image.get()); + std::string bi_id = mtmd_image_tokens_get_id(bi.tok_image.get()); + if (ai_id == bi_id) { + ret += mtmd_image_tokens_get_n_tokens(ai.tok_image.get()); + continue; + } else { + break; + } } else { break; } From cd115854786e5d83424cc97809d0b96746da5af6 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 23 Apr 2025 20:57:22 +0200 Subject: [PATCH 11/18] use FNV hash, now hash bitmap instead of file data --- examples/server/CMakeLists.txt | 3 +-- examples/server/server.cpp | 11 ++--------- examples/server/utils.hpp | 12 ++++++++++++ 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index 0ff77b0944881..17109fddbd307 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -35,9 +35,8 @@ add_executable(${TARGET} ${TARGET_SRCS}) install(TARGETS ${TARGET} RUNTIME) target_include_directories(${TARGET} PRIVATE ../llava) -target_include_directories(${TARGET} PRIVATE ../gguf-hash/deps/sha1) # TODO @ngxson : this is a hacky way to get this working, to be fixed before merging target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) -target_link_libraries(${TARGET} PRIVATE common mtmd sha1 ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} 
PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT}) if (LLAMA_SERVER_SSL) find_package(OpenSSL REQUIRED) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index af9e1270d40b0..c9c33a0778c87 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -8,7 +8,6 @@ #include "sampling.h" #include "speculative.h" #include "mtmd.h" -#include "sha1.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT @@ -4033,18 +4032,12 @@ int main(int argc, char ** argv) { { for (auto & file : files) { mtmd_bitmap bmp; - // calculate hash (for KV caching) - { - SHA1_CTX sha1_ctx; - SHA1Update(&sha1_ctx, (unsigned char const *)file.data(), file.size()); - unsigned char result[21]; - SHA1Final(result, &sha1_ctx); - bmp.id = std::string((char *)result, 20); - } int32_t res = mtmd_helper_bitmap_init_from_buf(file.data(), file.size(), bmp); if (res != 0) { throw std::runtime_error("Failed to load image"); } + // calculate bitmap hash (for KV caching) + bmp.id = server_inputs::fnv_hash(bmp.data.data(), bmp.data.size()); bitmaps.push_back(std::move(bmp)); } } diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index d642d7831893c..6a711376f0a95 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -1148,6 +1148,18 @@ struct server_inputs { // If the loop completes, it means 'pos' is >= the total logical size. // No truncation needed, the vector remains unchanged. } + + // Computes FNV-1a hash of the data + static std::string fnv_hash(const uint8_t * data, size_t len) { + const uint64_t fnv_prime = 0x100000001b3ULL; + uint64_t hash = 0xcbf29ce484222325ULL; + + for (size_t i = 0; i < len; ++i) { + hash ^= data[i]; + hash *= fnv_prime; + } + return std::to_string(hash); + } }; // helper struct to make working with embd batch easier From 8afa9528371fe7b0e7e5edd9fd9bdee06bb4f327 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 23 Apr 2025 22:20:52 +0200 Subject: [PATCH 12/18] allow decoding image embedding to be split into batches --- examples/server/server.cpp | 88 +++++++++++++++----------------- examples/server/utils.hpp | 102 +++++++++++++++++++++++++++++-------- 2 files changed, 121 insertions(+), 69 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c9c33a0778c87..ec571de136c3e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1860,7 +1860,8 @@ struct server_context { llama_context_params cparams_dft; - server_batch batch; + llama_batch batch; + server_batch_embd batch_embd; bool clean_kv_cache = true; bool add_bos_token = true; @@ -1898,6 +1899,8 @@ struct server_context { llama_batch_free(slot.batch_spec); } + + llama_batch_free(batch); } bool load_model(const common_params & params) { @@ -2034,7 +2037,8 @@ struct server_context { // note that n_batch can be > n_ctx (e.g. 
for non-causal attention models such as BERT where the KV cache is not used) { const int32_t n_batch = llama_n_batch(ctx); - batch = server_batch(std::max(n_batch, params_base.n_parallel)); + batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); + batch_embd = server_batch_embd(std::max(n_batch, params_base.n_parallel)); } metrics.init(); @@ -2931,7 +2935,7 @@ struct server_context { }*/ // start populating the batch for this iteration - batch.clear(); + common_batch_clear(batch); // track if given slot can be batched with slots already in the batch server_slot * slot_batched = nullptr; @@ -2953,9 +2957,9 @@ struct server_context { continue; } - slot.i_batch = batch.n_tokens(); + slot.i_batch = batch.n_tokens; - common_batch_add(batch.batch, slot.sampled, slot.n_past, { slot.id }, true); + common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); slot.n_past += 1; @@ -2972,7 +2976,7 @@ struct server_context { int32_t n_ubatch = llama_n_ubatch(ctx); // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || batch.n_tokens() == 0) { + if (params_base.cont_batching || batch.n_tokens == 0) { for (auto & slot : slots) { // check if we can batch this slot with the previous one if (slot.is_processing()) { @@ -3140,7 +3144,7 @@ struct server_context { // non-causal tasks require to fit the entire prompt in the physical batch if (slot.is_non_causal()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens() + slot.n_prompt_tokens > n_batch) { + if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { continue; } } @@ -3163,28 +3167,26 @@ struct server_context { // check if we should process the image if (curr_chunk.tok_image) { - if (batch.has_text()) { - continue; // we cannot have both text batch and image batch + // process the image + int32_t res = server_img_process(ctx, mctx, curr_chunk, batch_embd, slot.n_past, slot.id); + if (res != 0) { + SLT_ERR(slot, "failed to process image, res = %d\n", res); + slot.release(); + send_error(slot, "failed to process image", ERROR_TYPE_SERVER); + continue; } - // encode the image - server_encode_image(slot.mctx, batch, curr_chunk, slot.n_past, slot.id); - GGML_ASSERT(batch.has_embd()); - SLT_INF(slot, "image encoded, n_past = %d, n_embd_tokens = %d\n", slot.n_past, batch.n_tokens()); - if (slot.params.cache_prompt) { slot.cache_tokens.add_image_tokens(curr_chunk.tok_image); } - slot.n_past += batch.n_tokens(); - slot.n_prompt_tokens_processed += batch.n_tokens(); - - break; // currently, we can only process one image at a time, so we skip ALL other slots + slot.n_past += curr_chunk.n_tokens; + slot.n_prompt_tokens_processed += curr_chunk.n_tokens; } // add prompt tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens() < n_batch) { - GGML_ASSERT(!batch.has_embd()); + while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + // get next token to process auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past); if (curr_chunk.tok_text == LLAMA_TOKEN_NULL) { break; // end of text chunk @@ -3193,7 +3195,7 @@ struct server_context { // without pooling, we want to output the embeddings for all the tokens in the batch const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - common_batch_add(batch.batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd); + common_batch_add(batch, curr_chunk.tok_text, 
slot.n_past, { slot.id }, need_embd); if (slot.params.cache_prompt) { slot.cache_tokens.add_text_token(curr_chunk.tok_text); } @@ -3204,47 +3206,47 @@ struct server_context { // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); - SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); // entire prompt has been processed if (slot.n_past == slot.n_prompt_tokens) { slot.state = SLOT_STATE_DONE_PROMPT; - GGML_ASSERT(batch.n_tokens() > 0); + GGML_ASSERT(batch.n_tokens > 0); common_sampler_reset(slot.smpl); // Process all prompt tokens through sampler system for (size_t i = 0; i < slot.cache_tokens.n_tokens(); ++i) { - auto & curr_chunk = slot.cache_tokens.get_chunk(i); + auto & curr_chunk = slot.prompt_tokens.get_chunk(i); if (curr_chunk.tok_text != LLAMA_TOKEN_NULL) { common_sampler_accept(slot.smpl, curr_chunk.tok_text, false); } } // extract the logits only for the last token - batch.logits[batch.n_tokens() - 1] = true; + batch.logits[batch.n_tokens - 1] = true; slot.n_decoded = 0; - slot.i_batch = batch.n_tokens() - 1; + slot.i_batch = batch.n_tokens - 1; - SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens()); + SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); } } - if (batch.n_tokens() >= n_batch) { + if (batch.n_tokens >= n_batch) { break; } } } - if (batch.n_tokens() == 0) { + if (batch.n_tokens == 0) { SRV_WRN("%s", "no tokens to decode\n"); return; } // debug - SRV_DBG("decoding %s batch, n_tokens = %d\n", batch.has_embd() ? "embd" : "text", batch.n_tokens()); + SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); if (slot_batched) { // make sure we're in the right embedding mode @@ -3254,32 +3256,22 @@ struct server_context { } // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens(); i += n_batch) { - const int32_t n_tokens = std::min(n_batch, batch.n_tokens() - i); + for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); - // TODO @ngxson : hacky here, we don't want to split the embd batch - llama_batch batch_view = batch.has_embd() ? 
batch.batch : llama_batch{ + llama_batch batch_view = llama_batch{ n_tokens, - batch.batch.token + i, + batch.token + i, nullptr, - batch.batch.pos + i, - batch.batch.n_seq_id + i, - batch.batch.seq_id + i, - batch.batch.logits + i, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, }; - // TODO @ngxson : maybe move this to llama_batch_ext - if (batch.has_embd() && mtmd_decode_use_non_causal(mctx)) { - llama_set_causal_attn(ctx, false); - } - const int ret = llama_decode(ctx, batch_view); metrics.on_decoded(slots); - if (batch.has_embd() && mtmd_decode_use_non_causal(mctx)) { - llama_set_causal_attn(ctx, true); - } - if (ret != 0) { if (n_batch == 1 || ret < 0) { // if you get here, it means the KV cache is full - try increasing it via the context size diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 6a711376f0a95..e6a67e2febd2b 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -963,6 +963,8 @@ static std::vector parse_lora_request( // (may need to refactor in near future) // +// each chunk can contain either one SINGLE text token or an image (multiple token embeddings) +// this is to simplify the logic of KV cache management struct server_inp_chunk { size_t n_tokens = 1; // always 1 in case of text llama_token tok_text; @@ -981,6 +983,15 @@ struct server_inp_chunk { * server_inputs is a helper to manage the input tokens and image for the server. * * the difference between server_inputs and mtmd_input_chunks is that each chunk of server_inputs only contains a single text token, but text chunk of mtmd_input_chunks can contain multiple tokens. + * + * for example, server_inputs may contain 5 text tokens followed by 1 image chunk: + * 1 41 2635 325 463 + * + * in this example: + * - n_tokens() returns 5+15 = 20 total tokens + * - get_chunk(1) returns chunk containing token ID 41 + * - get_chunk(5) returns image chunk (15 tokens) + * - get_chunk(7) returns same image chunk * * it is made this way to simplify the logic of KV cache management. 
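To make the example above concrete against the lookup code that follows: get_chunk_physical_idx() accumulates n_tokens chunk by chunk, so logical position 3 resolves to the fourth text chunk (token 325), while every logical position from 5 to 19 resolves to the single image chunk stored at physical index 5; n_tokens() likewise sums 1+1+1+1+1+15 = 20, the total quoted above.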
*/ @@ -1079,6 +1090,7 @@ struct server_inputs { return ret; } + // make sure all text tokens are within the vocab range bool validate(llama_token max_vocab_id) const { for (const auto & chunk : chunks) { if (!chunk.tok_image) { @@ -1090,24 +1102,26 @@ struct server_inputs { return true; } + // pos is also referred as logical index server_inp_chunk & get_chunk(size_t pos) { - return chunks[get_chunk_idx(pos)]; + size_t physical_idx = get_chunk_physical_idx(pos); + return chunks[physical_idx]; } - size_t get_chunk_idx(size_t pos) const { + // returns physical_index + size_t get_chunk_physical_idx(size_t logical_idx) const { size_t current_pos = 0; for (size_t i = 0; i < chunks.size(); ++i) { const auto & chunk = chunks[i]; size_t chunk_end_pos = current_pos + chunk.n_tokens; - if (pos < chunk_end_pos) { + if (logical_idx < chunk_end_pos) { // The target position 'pos' falls within this chunk return i; } - current_pos = chunk_end_pos; } // If the loop finishes, 'pos' is >= the total number of logical positions - return chunks.size(); + throw std::out_of_range("Position out of range"); } // same idea with std::vector resize() @@ -1164,7 +1178,7 @@ struct server_inputs { // helper struct to make working with embd batch easier // note: this will be removed after llama_batch_ext refactoring -struct server_batch { +struct server_batch_embd { std::vector pos; std::vector token; std::vector n_seq_id; @@ -1174,8 +1188,8 @@ struct server_batch { llama_batch batch; - server_batch() : server_batch(1) {} - server_batch(int32_t n_tokens) { + server_batch_embd() : server_batch_embd(1) {} + server_batch_embd(int32_t n_tokens) { token .resize(n_tokens); pos .resize(n_tokens); n_seq_id.resize(n_tokens); @@ -1233,23 +1247,69 @@ struct server_batch { }; // TODO @ngxson : quite hacky for now, but just to see if it works -static int32_t server_encode_image(mtmd_context * mctx, server_batch & batch_out, server_inp_chunk & chunk, llama_pos n_past, llama_seq_id seq_id) { +static int32_t server_img_process( + llama_context * ctx, + mtmd_context * mctx, + server_inp_chunk & chunk, + server_batch_embd & batch, + llama_pos n_past, + int slot_id) { GGML_ASSERT(chunk.tok_image); - batch_out.clear(); - - int64_t t0 = ggml_time_ms(); - LOG_INF("encoding image...\n"); - int32_t ret = mtmd_encode(mctx, chunk.tok_image.get()); - if (ret != 0) { - LOG_ERR("failed to encode image\n"); - return ret; + int32_t ret; + + // encode the image + { + int64_t t0 = ggml_time_ms(); + SRV_INF("encoding image (%d tokens)...\n", (int)chunk.n_tokens); + ret = mtmd_encode(mctx, chunk.tok_image.get()); + if (ret != 0) { + SRV_ERR("failed to encode image, status = %d\n", ret); + return ret; + } + SRV_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); } - LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); - int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tok_image.get()); float * embd = mtmd_get_output_embd(mctx); - batch_out.reserve_embd_batch(embd, n_tokens, n_past, seq_id); - return ret; + // decode the embeddings + int64_t t1 = ggml_time_ms(); + int32_t n_embd = llama_model_n_embd(llama_get_model(ctx)); + int32_t n_tokens = chunk.n_tokens; + int32_t n_batch = batch.pos.size(); + int32_t i_batch = 0; + int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch; + // split into batches + while (i_batch < n_img_batches) { + int32_t pos_offset = i_batch*n_batch; + int32_t n_tokens_batch = std::min(n_batch, n_tokens - pos_offset); + float * embd_batch = embd + pos_offset*n_embd; + batch.clear(); + 
batch.reserve_embd_batch(embd_batch, n_tokens_batch, n_past, slot_id); + + SRV_INF("decoding embd batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch); + + // TODO @ngxson : maybe move this to llama_batch_ext + if (mtmd_decode_use_non_causal(mctx)) { + llama_set_causal_attn(ctx, false); + } + + ret = llama_decode(ctx, batch.batch); + if (ret != 0) { + LOG_ERR("failed to decode image\n"); + llama_set_causal_attn(ctx, true); // restore causal attn + return ret; + } + + if (mtmd_decode_use_non_causal(mctx)) { + llama_set_causal_attn(ctx, true); + } + + i_batch++; + n_past += n_tokens_batch; + } + SRV_INF("image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1); + + batch.clear(); + return 0; } // hacky, support text-only for now From 989730c6e1b87b495f2fc36ab986094a8295d7e1 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 23 Apr 2025 22:21:40 +0200 Subject: [PATCH 13/18] rm whitespace --- examples/server/utils.hpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index e6a67e2febd2b..0914a9c425cf8 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -983,10 +983,10 @@ struct server_inp_chunk { * server_inputs is a helper to manage the input tokens and image for the server. * * the difference between server_inputs and mtmd_input_chunks is that each chunk of server_inputs only contains a single text token, but text chunk of mtmd_input_chunks can contain multiple tokens. - * + * * for example, server_inputs may contain 5 text tokens followed by 1 image chunk: * 1 41 2635 325 463 - * + * * in this example: * - n_tokens() returns 5+15 = 20 total tokens * - get_chunk(1) returns chunk containing token ID 41 @@ -1163,16 +1163,16 @@ struct server_inputs { // No truncation needed, the vector remains unchanged. 
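A worked example of the sub-batch split in server_img_process above, with illustrative sizes: a 576-token image decoded with n_batch = 256 gives n_img_batches = GGML_PAD(576, 256) / 256 = 768 / 256 = 3, so the embeddings go through llama_decode as sub-batches of 256, 256 and 64 tokens; each iteration reads the encoder output at embd + pos_offset * n_embd and advances n_past by n_tokens_batch, so the image ends up occupying one contiguous range of KV cache positions.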
} - // Computes FNV-1a hash of the data - static std::string fnv_hash(const uint8_t * data, size_t len) { - const uint64_t fnv_prime = 0x100000001b3ULL; - uint64_t hash = 0xcbf29ce484222325ULL; + // Computes FNV-1a hash of the data + static std::string fnv_hash(const uint8_t * data, size_t len) { + const uint64_t fnv_prime = 0x100000001b3ULL; + uint64_t hash = 0xcbf29ce484222325ULL; - for (size_t i = 0; i < len; ++i) { - hash ^= data[i]; - hash *= fnv_prime; - } - return std::to_string(hash); + for (size_t i = 0; i < len; ++i) { + hash ^= data[i]; + hash *= fnv_prime; + } + return std::to_string(hash); } }; From 2df8c1a4b422fbf30c01ae1de75506826dd3499f Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 24 Apr 2025 23:14:13 +0200 Subject: [PATCH 14/18] disable some features when mtmd is on --- examples/server/server.cpp | 135 +++++++++++++++++++++++++------------ examples/server/utils.hpp | 21 ++++++ 2 files changed, 112 insertions(+), 44 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ec571de136c3e..dde300decef44 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1983,6 +1983,21 @@ struct server_context { return false; } SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); + + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_INF("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_INF("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); + } + + if (!params_base.speculative.model.path.empty()) { + SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); + return false; + } } return true; @@ -2432,6 +2447,7 @@ struct server_context { void send_final_response(server_slot & slot) { auto res = std::make_unique(); + llama_tokens text_tokens = slot.prompt_tokens.get_text_tokens(); res->id = slot.id_task; res->id_slot = slot.id; @@ -2439,7 +2455,7 @@ struct server_context { res->content = std::move(slot.generated_text); res->tokens = std::move(slot.generated_tokens); res->timings = slot.get_timings(); - //res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); // TODO @ngxson : hacky, need to fix + res->prompt = common_detokenize(ctx, text_tokens, true); res->response_fields = std::move(slot.params.response_fields); res->truncated = slot.truncated; @@ -2747,10 +2763,14 @@ struct server_context { } queue_results.send(std::move(res)); } break; - /*case SERVER_TASK_TYPE_SLOT_SAVE: + case SERVER_TASK_TYPE_SLOT_SAVE: { int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); + if (mctx) { + send_error(task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + break; + } if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); break; @@ -2762,13 +2782,14 @@ struct server_context { break; } - const size_t token_count = slot->cache_tokens.size(); + const size_t token_count = slot->cache_tokens.n_tokens(); const int64_t t_start = ggml_time_us(); std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count); + const llama_tokens tokens = slot->cache_tokens.get_text_tokens(); + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = 
ggml_time_us(); const double t_save_ms = (t_end - t_start) / 1000.0; @@ -2785,6 +2806,10 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_RESTORE: { + if (mctx) { + send_error(task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + break; + } int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2803,15 +2828,17 @@ struct server_context { std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - slot->cache_tokens.resize(slot->n_ctx); + llama_tokens tokens; + tokens.resize(slot->n_ctx); size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); if (nread == 0) { - slot->cache_tokens.resize(0); + slot->cache_tokens.clear(); // KV may already been invalidated? send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); break; } - slot->cache_tokens.resize(token_count); + tokens.resize(token_count); + slot->cache_tokens.set_text_tokens(tokens); const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; @@ -2828,6 +2855,10 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_ERASE: { + if (mctx) { + send_error(task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + break; + } int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2842,7 +2873,7 @@ struct server_context { } // Erase token cache - const size_t n_erased = slot->cache_tokens.size(); + const size_t n_erased = slot->cache_tokens.n_tokens(); llama_kv_self_seq_rm(ctx, slot->id, -1, -1); slot->cache_tokens.clear(); @@ -2851,11 +2882,7 @@ struct server_context { res->id_slot = id_slot; res->n_erased = n_erased; queue_results.send(std::move(res)); - } break;*/ - case SERVER_TASK_TYPE_SLOT_SAVE: - case SERVER_TASK_TYPE_SLOT_RESTORE: - case SERVER_TASK_TYPE_SLOT_ERASE: - GGML_ASSERT(false && "TODO @ngxson : removed due to not compat with multimodal"); + } break; case SERVER_TASK_TYPE_SET_LORA: { params_base.lora_adapters = std::move(task.set_lora); @@ -2899,8 +2926,7 @@ struct server_context { // apply context-shift if needed // TODO: simplify and improve - // TODO @ngxson : hacky, need to disable context shift for multimodal - /*for (server_slot & slot : slots) { + for (server_slot & slot : slots) { if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) { if (!params_base.ctx_shift) { // this check is redundant (for good) @@ -2910,6 +2936,12 @@ struct server_context { continue; } + if (mctx) { + // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded + // we don't support ctx_shift because an image chunk may contains multiple tokens + GGML_ABORT("not supported by multimodal"); + } + // Shift context const int n_keep = slot.params.n_keep + add_bos_token; const int n_left = slot.n_past - n_keep; @@ -2921,18 +2953,18 @@ struct server_context { llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); if (slot.params.cache_prompt) { - for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { - slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; + for (size_t i = n_keep + n_discard; i 
< slot.cache_tokens.chunks.size(); i++) { + slot.cache_tokens.chunks[i - n_discard] = std::move(slot.cache_tokens.chunks[i]); } - slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); + slot.cache_tokens.chunks.resize(slot.cache_tokens.chunks.size() - n_discard); } slot.n_past -= n_discard; slot.truncated = true; } - }*/ + } // start populating the batch for this iteration common_batch_clear(batch); @@ -3054,51 +3086,59 @@ struct server_context { slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); // if input prompt is too big, truncate it - // TODO @ngxson : this won't work with multimodal - /*if (slot.n_prompt_tokens >= slot.n_ctx) { + if (slot.n_prompt_tokens >= slot.n_ctx) { + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + llama_tokens curr_tokens = slot.prompt_tokens.get_text_tokens(); const int n_left = slot.n_ctx - slot.params.n_keep; const int n_block_size = n_left / 2; const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; llama_tokens new_tokens( - prompt_tokens.begin(), - prompt_tokens.begin() + slot.params.n_keep); + curr_tokens.begin(), + curr_tokens.begin() + slot.params.n_keep); new_tokens.insert( new_tokens.end(), - prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, - prompt_tokens.end()); + curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, + curr_tokens.end()); - prompt_tokens = std::move(new_tokens); + prompt_tokens.set_text_tokens(new_tokens); slot.truncated = true; - slot.n_prompt_tokens = prompt_tokens.size(); + slot.n_prompt_tokens = prompt_tokens.n_tokens(); SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); - }*/ + } if (slot.params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); // reuse chunks from the cached prompt by shifting their KV cache in the new position - // TODO @ngxson : this won't work with multimodal - /*if (params_base.n_cache_reuse > 0) { + if (params_base.n_cache_reuse > 0) { size_t head_c = slot.n_past; // cache size_t head_p = slot.n_past; // current prompt + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); - while (head_c < slot.cache_tokens.size() && - head_p < prompt_tokens.size()) { + while (head_c < slot.cache_tokens.chunks.size() && + head_p < prompt_tokens.chunks.size()) { size_t n_match = 0; - while (head_c + n_match < slot.cache_tokens.size() && - head_p + n_match < prompt_tokens.size() && - slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { + while (head_c + n_match < slot.cache_tokens.chunks.size() && + head_p + n_match < prompt_tokens.chunks.size() && + slot.cache_tokens.chunks[head_c + n_match].tok_text == prompt_tokens.chunks[head_p + n_match].tok_text) { n_match++; } @@ -3115,7 +3155,7 @@ struct server_context { llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift); for (size_t i = 0; i < n_match; i++) { - slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; + slot.cache_tokens.chunks[head_p + i].tok_text = slot.cache_tokens.chunks[head_c + i].tok_text; slot.n_past++; } @@ -3127,7 +3167,7 @@ struct 
server_context { } SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past); - }*/ + } } } @@ -3359,8 +3399,7 @@ struct server_context { } // do speculative decoding - // TODO @ngxson : remove speculative decoding for multimodal - /*for (auto & slot : slots) { + for (auto & slot : slots) { if (!slot.is_processing() || !slot.can_speculate()) { continue; } @@ -3369,6 +3408,11 @@ struct server_context { continue; } + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + // determine the max draft that fits the current slot state int n_draft_max = slot.params.speculative.n_max; @@ -3395,7 +3439,8 @@ struct server_context { params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; params_spec.p_min = slot.params.speculative.p_min; - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id); + llama_tokens cached_text_tokens = slot.cache_tokens.get_text_tokens(); + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); // keep track of total number of tokens generated in the draft slot.n_draft_total += draft.size(); @@ -3428,8 +3473,10 @@ struct server_context { // update how many tokens out of draft was accepted slot.n_draft_accepted += ids.size() - 1; - slot.cache_tokens.push_back(id); - slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); + slot.cache_tokens.add_text_token(id); + for (auto & t : ids) { + slot.cache_tokens.add_text_token(t); + } llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1); @@ -3453,7 +3500,7 @@ struct server_context { } SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past); - }*/ + } } SRV_DBG("%s", "run slots completed\n"); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 0914a9c425cf8..425d48ba5ab9d 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -1174,6 +1174,27 @@ struct server_inputs { } return std::to_string(hash); } + + // TODO: maybe implement a (de)seralizer for this struct, so we can get rid of functions below + + // return all text tokens (for legacy code), to be used by save/load slot + llama_tokens get_text_tokens() { + llama_tokens output; + for (auto & chunk : chunks) { + if (chunk.tok_text != LLAMA_TOKEN_NULL) { + output.push_back(chunk.tok_text); + } + } + return output; + } + + // clear and set text tokens (for legacy code), to be used by save/load slot + void set_text_tokens(llama_tokens tokens) { + chunks.clear(); + for (auto & tok : tokens) { + add_text_token(tok); + } + } }; // helper struct to make working with embd batch easier From b9ef895fd779cc9c29d76910fa12f9372f215386 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 25 Apr 2025 10:41:25 +0200 Subject: [PATCH 15/18] fix --no-mmproj-offload --- examples/server/server.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index dde300decef44..185af7bcabd1f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1975,8 +1975,12 @@ struct server_context { std::string & mmproj_path = params_base.mmproj.path; if (!mmproj_path.empty()) { - mtmd_context_params mparams; - mparams.n_threads = params_base.cpuparams.n_threads; + mtmd_context_params mparams{ + /* use_gpu */ params_base.mmproj_use_gpu, + /* timings */ true, + /* n_threads */ params_base.cpuparams.n_threads, + /* verbosity */ 
params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO, + }; mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); if (mctx == nullptr) { SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); From add9e215026b6d5465757fb369ba469da20db70e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 25 Apr 2025 11:55:03 +0200 Subject: [PATCH 16/18] mtmd_context_params no timings --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 185af7bcabd1f..3eaf01b1409de 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1977,7 +1977,7 @@ struct server_context { if (!mmproj_path.empty()) { mtmd_context_params mparams{ /* use_gpu */ params_base.mmproj_use_gpu, - /* timings */ true, + /* timings */ false, /* n_threads */ params_base.cpuparams.n_threads, /* verbosity */ params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO, }; From 58100b393d8c288c6f06fb6385d7a1127f1fc753 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 25 Apr 2025 17:57:11 +0200 Subject: [PATCH 17/18] refactor server_inp to server_tokens --- examples/server/server.cpp | 139 +++++++++++----------- examples/server/utils.hpp | 228 ++++++++++++++----------------------- 2 files changed, 162 insertions(+), 205 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 3eaf01b1409de..9428fceb0396c 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -198,7 +198,7 @@ struct server_task { // used by SERVER_TASK_TYPE_INFERENCE slot_params params; - server_inputs prompt_tokens; + server_tokens prompt_tokens; int id_selected_slot = -1; // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE @@ -1277,14 +1277,14 @@ struct server_slot { int32_t n_prompt_tokens_processed = 0; // input prompt tokens - server_inputs prompt_tokens; + server_tokens prompt_tokens; size_t last_nl_pos = 0; std::string generated_text; llama_tokens generated_tokens; - server_inputs cache_tokens; + server_tokens cache_tokens; std::vector generated_token_probs; @@ -2020,6 +2020,7 @@ struct server_context { slot.n_ctx = n_ctx_slot; slot.n_predict = params_base.n_predict; slot.mctx = mctx; + slot.cache_tokens.has_mtmd = mctx != nullptr; if (model_dft) { slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); @@ -2096,7 +2097,7 @@ struct server_context { int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); // fraction of the common subsequence length compared to the current slot's prompt length - float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.n_tokens()); + float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); // select the current slot if the criteria match if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { @@ -2135,7 +2136,7 @@ struct server_context { return ret; } - bool can_be_detokenized(const struct llama_context * ctx, const server_inputs & inp) { + bool can_be_detokenized(const struct llama_context * ctx, const server_tokens & inp) { const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int32_t n_vocab = llama_vocab_n_tokens(vocab); @@ -2786,7 +2787,7 @@ struct server_context { break; } - const size_t token_count = slot->cache_tokens.n_tokens(); + const size_t token_count = 
slot->cache_tokens.size(); const int64_t t_start = ggml_time_us(); std::string filename = task.slot_action.filename; @@ -2877,7 +2878,7 @@ struct server_context { } // Erase token cache - const size_t n_erased = slot->cache_tokens.n_tokens(); + const size_t n_erased = slot->cache_tokens.size(); llama_kv_self_seq_rm(ctx, slot->id, -1, -1); slot->cache_tokens.clear(); @@ -2957,11 +2958,11 @@ struct server_context { llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); if (slot.params.cache_prompt) { - for (size_t i = n_keep + n_discard; i < slot.cache_tokens.chunks.size(); i++) { - slot.cache_tokens.chunks[i - n_discard] = std::move(slot.cache_tokens.chunks[i]); + for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { + slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; } - slot.cache_tokens.chunks.resize(slot.cache_tokens.chunks.size() - n_discard); + slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); } slot.n_past -= n_discard; @@ -3004,7 +3005,7 @@ struct server_context { } SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.n_past, (int) slot.cache_tokens.n_tokens(), slot.truncated); + slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated); } // process in chunks of params.n_batch @@ -3033,23 +3034,23 @@ struct server_context { slot.t_start_generation = 0; slot.n_past = 0; - slot.n_prompt_tokens = prompt_tokens.n_tokens(); + slot.n_prompt_tokens = prompt_tokens.size(); slot.state = SLOT_STATE_PROCESSING_PROMPT; SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); // print prompt tokens (for debugging) - // if (1) { - // // first 16 tokens (avoid flooding logs) - // for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { - // SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); - // } - // } else { - // // all - // for (int i = 0; i < (int) prompt_tokens.size(); i++) { - // SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); - // } - // } + /*if (1) { + // first 16 tokens (avoid flooding logs) + for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + } + } else { + // all + for (int i = 0; i < (int) prompt_tokens.size(); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + } + }*/ // empty prompt passed -> release the slot and send empty response if (prompt_tokens.empty()) { @@ -3113,7 +3114,7 @@ struct server_context { prompt_tokens.set_text_tokens(new_tokens); slot.truncated = true; - slot.n_prompt_tokens = prompt_tokens.n_tokens(); + slot.n_prompt_tokens = prompt_tokens.size(); SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); @@ -3136,13 +3137,13 @@ struct server_context { SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); - while (head_c < slot.cache_tokens.chunks.size() && - head_p < prompt_tokens.chunks.size()) { + while (head_c < slot.cache_tokens.size() && + head_p < prompt_tokens.size()) { size_t n_match = 0; - while (head_c + 
n_match < slot.cache_tokens.chunks.size() && - head_p + n_match < prompt_tokens.chunks.size() && - slot.cache_tokens.chunks[head_c + n_match].tok_text == prompt_tokens.chunks[head_p + n_match].tok_text) { + while (head_c + n_match < slot.cache_tokens.size() && + head_p + n_match < prompt_tokens.size() && + slot.cache_tokens[head_c + n_match].txt == prompt_tokens[head_p + n_match].txt) { n_match++; } @@ -3159,7 +3160,7 @@ struct server_context { llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift); for (size_t i = 0; i < n_match; i++) { - slot.cache_tokens.chunks[head_p + i].tok_text = slot.cache_tokens.chunks[head_c + i].tok_text; + slot.cache_tokens[head_p + i].txt = slot.cache_tokens[head_c + i].txt; slot.n_past++; } @@ -3207,12 +3208,13 @@ struct server_context { // remove the non-common part from the cache slot.cache_tokens.keep_until(slot.n_past); - auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past); + auto & cur_tok = slot.prompt_tokens[slot.n_past]; // check if we should process the image - if (curr_chunk.tok_image) { + if (cur_tok.img) { // process the image - int32_t res = server_img_process(ctx, mctx, curr_chunk, batch_embd, slot.n_past, slot.id); + int32_t res = server_img_process(ctx, mctx, cur_tok, batch_embd, slot.n_past, slot.id); + int32_t n_tokens = mtmd_image_tokens_get_n_tokens(cur_tok.img.get()); if (res != 0) { SLT_ERR(slot, "failed to process image, res = %d\n", res); slot.release(); @@ -3221,27 +3223,30 @@ struct server_context { } if (slot.params.cache_prompt) { - slot.cache_tokens.add_image_tokens(curr_chunk.tok_image); + // add ALL image tokens at once + for (int32_t i = 0; i < n_tokens; i++) { + slot.cache_tokens.add_token(std::move(slot.prompt_tokens[slot.n_past + i])); + } } - slot.n_past += curr_chunk.n_tokens; - slot.n_prompt_tokens_processed += curr_chunk.n_tokens; + slot.n_past += n_tokens; + slot.n_prompt_tokens_processed += n_tokens; } // add prompt tokens for processing in the current batch while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { // get next token to process - auto & curr_chunk = slot.prompt_tokens.get_chunk(slot.n_past); - if (curr_chunk.tok_text == LLAMA_TOKEN_NULL) { + auto & curr_chunk = slot.prompt_tokens[slot.n_past]; + if (curr_chunk.txt == LLAMA_TOKEN_NULL) { break; // end of text chunk } // without pooling, we want to output the embeddings for all the tokens in the batch const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - common_batch_add(batch, curr_chunk.tok_text, slot.n_past, { slot.id }, need_embd); + common_batch_add(batch, curr_chunk.txt, slot.n_past, { slot.id }, need_embd); if (slot.params.cache_prompt) { - slot.cache_tokens.add_text_token(curr_chunk.tok_text); + slot.cache_tokens.add_text_token(curr_chunk.txt); } slot.n_prompt_tokens_processed++; @@ -3261,10 +3266,10 @@ struct server_context { common_sampler_reset(slot.smpl); // Process all prompt tokens through sampler system - for (size_t i = 0; i < slot.cache_tokens.n_tokens(); ++i) { - auto & curr_chunk = slot.prompt_tokens.get_chunk(i); - if (curr_chunk.tok_text != LLAMA_TOKEN_NULL) { - common_sampler_accept(slot.smpl, curr_chunk.tok_text, false); + for (size_t i = 0; i < slot.cache_tokens.size(); ++i) { + auto & cur_tok = slot.prompt_tokens[i]; + if (cur_tok.txt != LLAMA_TOKEN_NULL) { + common_sampler_accept(slot.smpl, cur_tok.txt, false); } } @@ -3289,7 +3294,6 @@ struct server_context { return; } - // debug SRV_DBG("decoding batch, n_tokens
= %d\n", batch.n_tokens); if (slot_batched) { @@ -3303,7 +3307,7 @@ struct server_context { for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); - llama_batch batch_view = llama_batch{ + llama_batch batch_view = { n_tokens, batch.token + i, nullptr, @@ -4072,7 +4076,11 @@ int main(int argc, char ** argv) { // process files std::vector bitmaps; + const bool has_mtmd = ctx_server.mctx != nullptr; { + if (!has_mtmd && !files.empty()) { + throw std::runtime_error("This server does not support multimodal"); + } for (auto & file : files) { mtmd_bitmap bmp; int32_t res = mtmd_helper_bitmap_init_from_buf(file.data(), file.size(), bmp); @@ -4080,30 +4088,31 @@ int main(int argc, char ** argv) { throw std::runtime_error("Failed to load image"); } // calculate bitmap hash (for KV caching) - bmp.id = server_inputs::fnv_hash(bmp.data.data(), bmp.data.size()); + bmp.id = server_tokens::fnv_hash(bmp.data.data(), bmp.data.size()); bitmaps.push_back(std::move(bmp)); } } - std::vector inputs; - if (oaicompat) { - if (!prompt.is_string()) { - throw std::runtime_error("prompt must be a string"); - } else { - // SRV_INF("prompt: %s\n", prompt.get().c_str()); - mtmd_input_text inp_txt = { - prompt.get(), - /* add_special */ true, - /* parse_special */ true, - }; - mtmd_input_chunks chunks; - int32_t tokenized = mtmd_tokenize(ctx_server.mctx, chunks, inp_txt, bitmaps); - if (tokenized != 0) { - throw std::runtime_error("Failed to tokenize prompt"); - } - server_inputs tmp(chunks); - inputs.push_back(std::move(tmp)); + // process prompt + std::vector inputs; + if (oaicompat && !prompt.is_string()) { + throw std::runtime_error("prompt must be a string"); + + } else if (oaicompat && has_mtmd) { + // multimodal + mtmd_input_text inp_txt = { + prompt.get(), + /* add_special */ true, + /* parse_special */ true, + }; + mtmd_input_chunks chunks; + int32_t tokenized = mtmd_tokenize(ctx_server.mctx, chunks, inp_txt, bitmaps); + if (tokenized != 0) { + throw std::runtime_error("Failed to tokenize prompt"); } + server_tokens tmp(chunks, true); + inputs.push_back(std::move(tmp)); + } else { // non-multimodal version auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 425d48ba5ab9d..fb4ce9c0fb2b2 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -963,138 +963,140 @@ static std::vector parse_lora_request( // (may need to refactor in near future) // -// each chunk can contain either one SINGLE text token or an image (multiple token embeddings) +// each chunk can contain either one SINGLE text token or pointer to image // this is to simplify the logic of KV cache management -struct server_inp_chunk { - size_t n_tokens = 1; // always 1 in case of text - llama_token tok_text; - mtmd_image_tokens_ptr tok_image; +struct server_token { + llama_token txt; + std::shared_ptr img; std::string str() const { // for debugging - if (tok_image) { - return string_format("( at %p) ", (void *)tok_image.get()); + GGML_ASSERT(img || txt != LLAMA_TOKEN_NULL); + if (img) { + return " "; } else { - return std::to_string(tok_text) + " "; + return std::to_string(txt) + " "; } } }; /** - * server_inputs is a helper to manage the input tokens and image for the server. 
- * - * the difference between server_inputs and mtmd_input_chunks is that each chunk of server_inputs only contains a single text token, but text chunk of mtmd_input_chunks can contain multiple tokens. - * - * for example, server_inputs may contain 5 text tokens followed by 1 image chunk: - * 1 41 2635 325 463 - * - * in this example: - * - n_tokens() returns 5+15 = 20 total tokens - * - get_chunk(1) returns chunk containing token ID 41 - * - get_chunk(5) returns image chunk (15 tokens) - * - get_chunk(7) returns same image chunk - * + * server_tokens is a helper to manage the input tokens and image for the server. * it is made this way to simplify the logic of KV cache management. + * + * each token can be either a text token or a pointer to an image. + * since an image usually spans multiple tokens, each of those tokens holds a shared_ptr to the same image. */ -struct server_inputs { - std::vector chunks; +struct server_tokens { + bool has_mtmd = false; + std::vector values; - server_inputs() = default; - ~server_inputs() = default; // Important if unique_ptr is used + server_tokens() = default; + ~server_tokens() = default; // Prevent copying - server_inputs(const server_inputs&) = delete; - server_inputs& operator=(const server_inputs&) = delete; + server_tokens(const server_tokens&) = delete; + server_tokens& operator=(const server_tokens&) = delete; // Allow moving (usually implicitly generated if members are movable) - server_inputs(server_inputs&&) = default; - server_inputs& operator=(server_inputs&&) = default; + server_tokens(server_tokens&&) = default; + server_tokens& operator=(server_tokens&&) = default; + + // Allow accessing elements using [] operator + server_token& operator[](size_t index) { return values[index]; } + const server_token& operator[](size_t index) const { return values[index]; } - server_inputs(mtmd_input_chunks & mtmd_chunks) { + server_tokens(mtmd_input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { for (auto & c : mtmd_chunks) { if (c.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { - size_t n_tokens = mtmd_image_tokens_get_n_tokens(c.tokens_image.get()); - chunks.push_back({n_tokens, LLAMA_TOKEN_NULL, std::move(c.tokens_image)}); + add_image_tokens(std::move(c.tokens_image)); } else if (c.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { for (auto & tok : c.tokens_text) { - chunks.push_back({1, tok, nullptr}); + add_text_token(tok); } } else { - GGML_ASSERT(false && "Invalid chunk type"); + GGML_ABORT("Invalid chunk type"); } } } - std::string str() { + std::string str() const { // for debugging std::string ret; - for (const auto & chunk : chunks) { - ret += chunk.str(); + for (const auto & t : values) { + ret += t.str(); } return ret; } - size_t n_tokens() const { - size_t res = 0; - for (const auto & chunk : chunks) { - if (chunk.tok_image) { - res += chunk.n_tokens; - } else { - res++; - } - } - return res; + size_t size() const { + return values.size(); } bool empty() const { - return n_tokens() == 0; + return values.empty(); } void clear() { - chunks.clear(); + values.clear(); + } + + void resize(size_t n) { + values.resize(n); + } + + void add_token(server_token && t) { + if (t.img) GGML_ASSERT(has_mtmd); + values.push_back(std::move(t)); } void add_text_token(llama_token tok) { GGML_ASSERT(tok != LLAMA_TOKEN_NULL); - chunks.push_back({1, tok, nullptr}); + values.push_back({tok, nullptr}); } - void add_image_tokens(mtmd_image_tokens_ptr & image) { + void add_image_tokens(mtmd_image_tokens_ptr && image) { + GGML_ASSERT(has_mtmd); GGML_ASSERT(image != nullptr); -
size_t n_tokens = mtmd_image_tokens_get_n_tokens(image.get()); - chunks.push_back({n_tokens, LLAMA_TOKEN_NULL, std::move(image)}); + std::shared_ptr tok_image(std::move(image)); + size_t n_tokens = mtmd_image_tokens_get_n_tokens(tok_image.get()); + GGML_ASSERT(n_tokens > 0 && "Invalid image token"); // should never happen + for (size_t i = 0; i < n_tokens; ++i) { + values.push_back({LLAMA_TOKEN_NULL, tok_image}); + } } - size_t get_common_prefix(const server_inputs & b) const { - size_t ret = 0; - size_t max_idx = std::min(chunks.size(), b.chunks.size()); + size_t get_common_prefix(const server_tokens & b) const { + size_t max_idx = std::min(values.size(), b.values.size()); for (size_t i = 0; i < max_idx; ++i) { - auto & ai = chunks[i]; - auto & bi = b.chunks[i]; + auto & ai = values[i]; + auto & bi = b.values[i]; - if (ai.tok_text == bi.tok_text && !ai.tok_image && !bi.tok_image) { - ret++; + if (ai.txt == bi.txt && !ai.img && !bi.img) { continue; - } else if (ai.tok_image && bi.tok_image) { - std::string ai_id = mtmd_image_tokens_get_id(ai.tok_image.get()); - std::string bi_id = mtmd_image_tokens_get_id(bi.tok_image.get()); + } else if (ai.img && bi.img) { + GGML_ASSERT(has_mtmd); + std::string ai_id = mtmd_image_tokens_get_id(ai.img.get()); + std::string bi_id = mtmd_image_tokens_get_id(bi.img.get()); if (ai_id == bi_id) { - ret += mtmd_image_tokens_get_n_tokens(ai.tok_image.get()); + size_t n_tokens = mtmd_image_tokens_get_n_tokens(ai.img.get()); + GGML_ASSERT(n_tokens > 0 && "Invalid image token"); // should never happen + i += mtmd_image_tokens_get_n_tokens(ai.img.get()) - 1; continue; } else { - break; + return i; } } else { - break; + return i; } } - return ret; + return max_idx; // all tokens are equal } // make sure all text tokens are within the vocab range bool validate(llama_token max_vocab_id) const { - for (const auto & chunk : chunks) { - if (!chunk.tok_image) { - if (chunk.tok_text < 0 || chunk.tok_text >= max_vocab_id) { + for (const auto & t : values) { + if (!t.img) { + if (t.txt < 0 || t.txt >= max_vocab_id) { return false; } } @@ -1102,65 +1104,11 @@ struct server_inputs { return true; } - // pos is also referred as logical index - server_inp_chunk & get_chunk(size_t pos) { - size_t physical_idx = get_chunk_physical_idx(pos); - return chunks[physical_idx]; - } - - // returns physical_index - size_t get_chunk_physical_idx(size_t logical_idx) const { - size_t current_pos = 0; - for (size_t i = 0; i < chunks.size(); ++i) { - const auto & chunk = chunks[i]; - size_t chunk_end_pos = current_pos + chunk.n_tokens; - if (logical_idx < chunk_end_pos) { - // The target position 'pos' falls within this chunk - return i; - } - current_pos = chunk_end_pos; - } - // If the loop finishes, 'pos' is >= the total number of logical positions - throw std::out_of_range("Position out of range"); - } - - // same idea with std::vector resize() + // same idea with std::vector::resize() void keep_until(size_t pos) { - if (pos == 0) { - chunks.clear(); - return; - } - - size_t current_pos = 0; - for (size_t i = 0; i < chunks.size(); ++i) { - const auto & chunk = chunks[i]; - size_t chunk_size = chunk.tok_image ? mtmd_image_tokens_get_n_tokens(chunk.tok_image.get()) : 1; - size_t chunk_end_pos = current_pos + chunk_size; - if (pos <= current_pos) { - // Truncation point is exactly at or before the start of this chunk. - // Keep only chunks before index 'i'. - chunks.resize(i); - return; - } - if (pos < chunk_end_pos) { - // Truncation point 'pos' falls within this chunk. 
- if (chunk.tok_image) { - // It's an image chunk, keep the whole chunk. - // Keep chunks up to and including index 'i'. - chunks.resize(i + 1); - } else { - // It's a text chunk. Since pos < chunk_end_pos and chunk_size is 1, - // this means pos == current_pos. - // Keep only chunks before index 'i'. - chunks.resize(i); - } - return; - } - // pos >= chunk_end_pos, so keep this chunk entirely and continue. - current_pos = chunk_end_pos; - } - // If the loop completes, it means 'pos' is >= the total logical size. - // No truncation needed, the vector remains unchanged. + // TODO : maybe throw an error if we remove part of an image (only allow removing a whole image) + // this cannot happen currently because get_common_prefix() never returns such a pos + values.resize(pos); } // Computes FNV-1a hash of the data @@ -1180,9 +1128,9 @@ struct server_inputs { // return all text tokens (for legacy code), to be used by save/load slot llama_tokens get_text_tokens() { llama_tokens output; - for (auto & chunk : chunks) { - if (chunk.tok_text != LLAMA_TOKEN_NULL) { - output.push_back(chunk.tok_text); + for (auto & t : values) { + if (t.txt != LLAMA_TOKEN_NULL) { + output.push_back(t.txt); } } return output; @@ -1190,7 +1138,7 @@ struct server_inputs { // clear and set text tokens (for legacy code), to be used by save/load slot void set_text_tokens(llama_tokens tokens) { - chunks.clear(); + values.clear(); for (auto & tok : tokens) { add_text_token(tok); } @@ -1267,22 +1215,22 @@ struct server_batch_embd { } }; -// TODO @ngxson : quite hacky for now, but just to see if it works static int32_t server_img_process( llama_context * ctx, mtmd_context * mctx, - server_inp_chunk & chunk, + server_token & chunk, server_batch_embd & batch, llama_pos n_past, int slot_id) { - GGML_ASSERT(chunk.tok_image); + GGML_ASSERT(chunk.img); + int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.img.get()); int32_t ret; // encode the image { int64_t t0 = ggml_time_ms(); - SRV_INF("encoding image (%d tokens)...\n", (int)chunk.n_tokens); - ret = mtmd_encode(mctx, chunk.tok_image.get()); + SRV_INF("encoding image (%d tokens)...\n", (int)n_tokens); + ret = mtmd_encode(mctx, chunk.img.get()); if (ret != 0) { SRV_ERR("failed to encode image, status = %d\n", ret); return ret; @@ -1294,7 +1242,6 @@ static int32_t server_img_process( // decode the embeddings int64_t t1 = ggml_time_ms(); int32_t n_embd = llama_model_n_embd(llama_get_model(ctx)); - int32_t n_tokens = chunk.n_tokens; int32_t n_batch = batch.pos.size(); int32_t i_batch = 0; int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch; @@ -1334,8 +1281,9 @@ static int32_t server_img_process( } // hacky, support text-only for now -static server_inputs convert_legacy_to_mtmd(llama_tokens & tokenized) { - server_inputs res; +static server_tokens convert_legacy_to_mtmd(llama_tokens & tokenized) { + server_tokens res; + res.has_mtmd = false; for (auto & tok : tokenized) { res.add_text_token(tok); } From e82fea8f0e95f38146a5d7be1935602162a2f5f4 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 25 Apr 2025 22:41:51 +0200 Subject: [PATCH 18/18] fix the failing test case --- examples/server/utils.hpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index fb4ce9c0fb2b2..f048be0265e7e 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -633,9 +633,12 @@ static json oaicompat_completion_params_parse( } // get input files - json messages = json_value(body, "messages",
json::array()); + if (!body.contains("messages")) { + throw std::runtime_error("'messages' is required"); + } + json messages = body.at("messages"); if (!messages.is_array()) { - throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump()); + throw std::runtime_error("Expected 'messages' to be an array"); } for (auto & msg : messages) { json & content = msg.at("content");
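// ---------------------------------------------------------------------------
// Illustrative sketch (assumption, not part of the patch series above): the
// kind of OAI-compat request body that the 'messages' validation in
// oaicompat_completion_params_parse() is guarding once multimodal input is
// enabled. The "image_url" content-part layout follows the OpenAI chat format
// and is assumed here rather than taken verbatim from this revision; the
// snippet uses nlohmann::json, the JSON library the server already builds on.
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    using json = nlohmann::json;

    // one text part and one image part inside a single user message
    json part_text;
    part_text["type"] = "text";
    part_text["text"] = "What is in this image?";

    json part_img;
    part_img["type"] = "image_url";
    part_img["image_url"]["url"] = "data:image/png;base64,...."; // base64 payload elided

    json msg;
    msg["role"]    = "user";
    msg["content"] = json::array({ part_text, part_img });

    json body;
    body["messages"] = json::array({ msg });

    // mirrors the checks added in the hunk above: 'messages' must exist and be an array
    if (!body.contains("messages") || !body.at("messages").is_array()) {
        std::cerr << "invalid request\n";
        return 1;
    }

    std::cout << body.dump(2) << std::endl;
    return 0;
}
// ---------------------------------------------------------------------------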