diff --git a/common/arg.cpp b/common/arg.cpp index 0657553e4e9cf..3cf7b1a5d0612 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -40,7 +40,7 @@ using json = nlohmann::ordered_json; std::initializer_list mmproj_examples = { LLAMA_EXAMPLE_LLAVA, - // TODO: add LLAMA_EXAMPLE_SERVER when it's ready + LLAMA_EXAMPLE_SERVER, }; common_arg & common_arg::set_examples(std::initializer_list examples) { diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp index a994ef0166e6a..601d2200e0bdc 100644 --- a/examples/llava/mtmd.cpp +++ b/examples/llava/mtmd.cpp @@ -29,6 +29,7 @@ struct mtmd_context { bool print_timings; int n_threads; std::string image_marker; + bool calc_image_hash; // for minicpmv, we need special tokens in-between slices mtmd_slice_tmpl slice_tmpl = MTMD_SLICE_TMPL_NONE; diff --git a/examples/llava/mtmd.h b/examples/llava/mtmd.h index 78be192dd6eb6..4a3436245fdae 100644 --- a/examples/llava/mtmd.h +++ b/examples/llava/mtmd.h @@ -87,6 +87,7 @@ MTMD_API void mtmd_free(mtmd_context * ctx); // 2. (image tokens) // 3. "\ndescribe it in detail." // number of bitmaps must be equal to the number of image markers in the prompt +// the returned value must be freed using mtmd_input_chunks_free() // this function is thread-safe (shared ctx) // return values: // 0 on success diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index aee90388e4fb3..17109fddbd307 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -34,8 +34,9 @@ endforeach() add_executable(${TARGET} ${TARGET_SRCS}) install(TARGETS ${TARGET} RUNTIME) +target_include_directories(${TARGET} PRIVATE ../llava) target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) -target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT}) if (LLAMA_SERVER_SSL) find_package(OpenSSL REQUIRED) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c580ec123299c..9428fceb0396c 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -7,6 +7,7 @@ #include "log.h" #include "sampling.h" #include "speculative.h" +#include "mtmd.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT @@ -196,8 +197,8 @@ struct server_task { int id_target = -1; // used by SERVER_TASK_TYPE_INFERENCE - slot_params params; - llama_tokens prompt_tokens; + slot_params params; + server_tokens prompt_tokens; int id_selected_slot = -1; // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE @@ -1246,6 +1247,9 @@ struct server_slot { llama_context * ctx = nullptr; llama_context * ctx_dft = nullptr; + // multimodal + mtmd_context * mctx = nullptr; + common_speculative * spec = nullptr; std::vector lora; @@ -1273,14 +1277,14 @@ struct server_slot { int32_t n_prompt_tokens_processed = 0; // input prompt tokens - llama_tokens prompt_tokens; + server_tokens prompt_tokens; size_t last_nl_pos = 0; std::string generated_text; llama_tokens generated_tokens; - llama_tokens cache_tokens; + server_tokens cache_tokens; std::vector generated_token_probs; @@ -1474,7 +1478,7 @@ struct server_slot { {"is_processing", is_processing()}, {"non_causal", is_non_causal()}, {"params", params.to_json()}, - {"prompt", common_detokenize(ctx, prompt_tokens)}, + {"prompt", ""}, // TODO @ngxson, hacky patch, to fix before merge {"next_token", { {"has_next_token", has_next_token}, @@ -1847,13 +1851,17 @@ struct server_context { llama_model * 
model = nullptr; llama_context * ctx = nullptr; + // multimodal + mtmd_context * mctx = nullptr; + const llama_vocab * vocab = nullptr; llama_model * model_dft = nullptr; llama_context_params cparams_dft; - llama_batch batch = {}; + llama_batch batch; + server_batch_embd batch_embd; bool clean_kv_cache = true; bool add_bos_token = true; @@ -1876,6 +1884,8 @@ struct server_context { common_chat_templates_ptr chat_templates; ~server_context() { + mtmd_free(mctx); + // Clear any sampling context for (server_slot & slot : slots) { common_sampler_free(slot.smpl); @@ -1963,6 +1973,37 @@ struct server_context { chat_templates = common_chat_templates_init(model, "chatml"); } + std::string & mmproj_path = params_base.mmproj.path; + if (!mmproj_path.empty()) { + mtmd_context_params mparams{ + /* use_gpu */ params_base.mmproj_use_gpu, + /* timings */ false, + /* n_threads */ params_base.cpuparams.n_threads, + /* verbosity */ params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO, + }; + mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + if (mctx == nullptr) { + SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); + return false; + } + SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); + + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_INF("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_INF("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); + } + + if (!params_base.speculative.model.path.empty()) { + SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); + return false; + } + } + return true; } @@ -1978,6 +2019,8 @@ struct server_context { slot.ctx = ctx; slot.n_ctx = n_ctx_slot; slot.n_predict = params_base.n_predict; + slot.mctx = mctx; + slot.cache_tokens.has_mtmd = mctx != nullptr; if (model_dft) { slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); @@ -2014,9 +2057,8 @@ struct server_context { // note that n_batch can be > n_ctx (e.g. 
for non-causal attention models such as BERT where the KV cache is not used) { const int32_t n_batch = llama_n_batch(ctx); - - // only a single seq_id per token is needed - batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); + batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); + batch_embd = server_batch_embd(std::max(n_batch, params_base.n_parallel)); } metrics.init(); @@ -2052,7 +2094,7 @@ struct server_context { } // length of the Longest Common Subsequence between the current slot's prompt and the input prompt - int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens); + int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); // fraction of the common subsequence length compared to the current slot's prompt length float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); @@ -2094,16 +2136,11 @@ struct server_context { return ret; } - bool can_be_detokenized(const struct llama_context * ctx, const std::vector & tokens) { + bool can_be_detokenized(const struct llama_context * ctx, const server_tokens & inp) { const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int32_t n_vocab = llama_vocab_n_tokens(vocab); - for (const auto & token : tokens) { - if (token < 0 || token >= n_vocab) { - return false; - } - } - return true; + return inp.validate(n_vocab); } bool launch_slot_with_task(server_slot & slot, server_task && task) { @@ -2415,6 +2452,7 @@ struct server_context { void send_final_response(server_slot & slot) { auto res = std::make_unique(); + llama_tokens text_tokens = slot.prompt_tokens.get_text_tokens(); res->id = slot.id_task; res->id_slot = slot.id; @@ -2422,7 +2460,7 @@ struct server_context { res->content = std::move(slot.generated_text); res->tokens = std::move(slot.generated_tokens); res->timings = slot.get_timings(); - res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); + res->prompt = common_detokenize(ctx, text_tokens, true); res->response_fields = std::move(slot.params.response_fields); res->truncated = slot.truncated; @@ -2734,6 +2772,10 @@ struct server_context { { int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); + if (mctx) { + send_error(task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + break; + } if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); break; @@ -2751,7 +2793,8 @@ struct server_context { std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count); + const llama_tokens tokens = slot->cache_tokens.get_text_tokens(); + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); const double t_save_ms = (t_end - t_start) / 1000.0; @@ -2768,6 +2811,10 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_RESTORE: { + if (mctx) { + send_error(task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + break; + } int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2786,15 +2833,17 @@ struct server_context { std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - 
slot->cache_tokens.resize(slot->n_ctx); + llama_tokens tokens; + tokens.resize(slot->n_ctx); size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); if (nread == 0) { - slot->cache_tokens.resize(0); + slot->cache_tokens.clear(); // KV may have already been invalidated? send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); break; } - slot->cache_tokens.resize(token_count); + tokens.resize(token_count); + slot->cache_tokens.set_text_tokens(tokens); const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; @@ -2811,6 +2860,10 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_ERASE: { + if (mctx) { + send_error(task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + break; + } int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2842,6 +2895,7 @@ struct server_context { res->id = task.id; queue_results.send(std::move(res)); } break; + } } @@ -2887,6 +2941,12 @@ struct server_context { continue; } + if (mctx) { + // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded + // we don't support ctx_shift because an image chunk may contain multiple tokens + GGML_ABORT("not supported by multimodal"); + } + // Shift context const int n_keep = slot.params.n_keep + add_bos_token; const int n_left = slot.n_past - n_keep; @@ -2941,7 +3001,7 @@ struct server_context { slot.n_past += 1; if (slot.params.cache_prompt) { - slot.cache_tokens.push_back(slot.sampled); + slot.cache_tokens.add_text_token(slot.sampled); } SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", @@ -2980,7 +3040,7 @@ struct server_context { SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); // print prompt tokens (for debugging) - if (1) { + /*if (1) { // first 16 tokens (avoid flooding logs) for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); @@ -2990,7 +3050,7 @@ struct server_context { for (int i = 0; i < (int) prompt_tokens.size(); i++) { SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); } - } + }*/ // empty prompt passed -> release the slot and send empty response if (prompt_tokens.empty()) { @@ -3032,21 +3092,26 @@ struct server_context { // if input prompt is too big, truncate it if (slot.n_prompt_tokens >= slot.n_ctx) { + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + llama_tokens curr_tokens = slot.prompt_tokens.get_text_tokens(); const int n_left = slot.n_ctx - slot.params.n_keep; const int n_block_size = n_left / 2; const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; llama_tokens new_tokens( - prompt_tokens.begin(), - prompt_tokens.begin() + slot.params.n_keep); + curr_tokens.begin(), + curr_tokens.begin() + slot.params.n_keep); new_tokens.insert( new_tokens.end(), - prompt_tokens.begin() + slot.params.n_keep +
erased_blocks * n_block_size, - prompt_tokens.end()); + curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, + curr_tokens.end()); - prompt_tokens = std::move(new_tokens); + prompt_tokens.set_text_tokens(new_tokens); slot.truncated = true; slot.n_prompt_tokens = prompt_tokens.size(); @@ -3058,13 +3123,18 @@ struct server_context { if (slot.params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens); + slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); // reuse chunks from the cached prompt by shifting their KV cache in the new position if (params_base.n_cache_reuse > 0) { size_t head_c = slot.n_past; // cache size_t head_p = slot.n_past; // current prompt + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); while (head_c < slot.cache_tokens.size() && @@ -3073,7 +3143,7 @@ struct server_context { size_t n_match = 0; while (head_c + n_match < slot.cache_tokens.size() && head_p + n_match < prompt_tokens.size() && - slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { + slot.cache_tokens[head_c + n_match].txt == prompt_tokens[head_p + n_match].txt) { n_match++; } @@ -3090,7 +3160,7 @@ struct server_context { llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift); for (size_t i = 0; i < n_match; i++) { - slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; + slot.cache_tokens[head_p + i].txt = slot.cache_tokens[head_c + i].txt; slot.n_past++; } @@ -3136,23 +3206,55 @@ struct server_context { SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); // remove the non-common part from the cache - slot.cache_tokens.resize(slot.n_past); + slot.cache_tokens.keep_until(slot.n_past); + + auto & cur_tok = slot.prompt_tokens[slot.n_past]; + + // check if we should process the image + if (cur_tok.img) { + // process the image + int32_t res = server_img_process(ctx, mctx, cur_tok, batch_embd, slot.n_past, slot.id); + int32_t n_tokens = mtmd_image_tokens_get_n_tokens(cur_tok.img.get()); + if (res != 0) { + SLT_ERR(slot, "failed to process image, res = %d\n", res); + slot.release(); + send_error(slot, "failed to process image", ERROR_TYPE_SERVER); + continue; + } + + if (slot.params.cache_prompt) { + // add ALL image tokens at once + for (int32_t i = 0; i < n_tokens; i++) { + slot.cache_tokens.add_token(std::move(slot.prompt_tokens[slot.n_past + i])); + } + } + + slot.n_past += n_tokens; + slot.n_prompt_tokens_processed += n_tokens; + } // add prompt tokens for processing in the current batch while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + // get next token to process + auto & curr_chunk = slot.prompt_tokens[slot.n_past]; + if (curr_chunk.txt == LLAMA_TOKEN_NULL) { + break; // end of text chunk + } + // without pooling, we want to output the embeddings for all the tokens in the batch const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd); - + common_batch_add(batch, curr_chunk.txt, slot.n_past, { slot.id }, need_embd); if (slot.params.cache_prompt) { - slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); + slot.cache_tokens.add_text_token(curr_chunk.txt); }
slot.n_prompt_tokens_processed++; slot.n_past++; } + // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); // entire prompt has been processed @@ -3164,8 +3266,11 @@ struct server_context { common_sampler_reset(slot.smpl); // Process all prompt tokens through sampler system - for (int i = 0; i < slot.n_prompt_tokens; ++i) { - common_sampler_accept(slot.smpl, prompt_tokens[i], false); + for (size_t i = 0; i < slot.cache_tokens.size(); ++i) { + auto & cur_tok = slot.prompt_tokens[i]; + if (cur_tok.txt != LLAMA_TOKEN_NULL) { + common_sampler_accept(slot.smpl, cur_tok.txt, false); + } } // extract the logits only for the last token @@ -3311,6 +3416,11 @@ struct server_context { continue; } + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + // determine the max draft that fits the current slot state int n_draft_max = slot.params.speculative.n_max; @@ -3337,7 +3447,8 @@ struct server_context { params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; params_spec.p_min = slot.params.speculative.p_min; - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id); + llama_tokens cached_text_tokens = slot.cache_tokens.get_text_tokens(); + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); // keep track of total number of tokens generated in the draft slot.n_draft_total += draft.size(); @@ -3370,8 +3481,10 @@ struct server_context { // update how many tokens out of draft was accepted slot.n_draft_accepted += ids.size() - 1; - slot.cache_tokens.push_back(id); - slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); + slot.cache_tokens.add_text_token(id); + for (auto & t : ids) { + slot.cache_tokens.add_text_token(t); + } llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1); @@ -3941,6 +4054,7 @@ int main(int argc, char ** argv) { const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok]( server_task_type type, json & data, + std::vector & files, std::function is_connection_closed, httplib::Response & res, oaicompat_type oaicompat) { @@ -3960,15 +4074,62 @@ int main(int argc, char ** argv) { // TODO: this log can become very long, put it behind a flag or think about a more compact format //SRV_DBG("Prompt: %s\n", prompt.is_string() ? 
prompt.get().c_str() : prompt.dump(2).c_str()); - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); - tasks.reserve(tokenized_prompts.size()); - for (size_t i = 0; i < tokenized_prompts.size(); i++) { + // process files + std::vector bitmaps; + const bool has_mtmd = ctx_server.mctx != nullptr; + { + if (!has_mtmd && !files.empty()) { + throw std::runtime_error("This server does not support multimodal"); + } + for (auto & file : files) { + mtmd_bitmap bmp; + int32_t res = mtmd_helper_bitmap_init_from_buf(file.data(), file.size(), bmp); + if (res != 0) { + throw std::runtime_error("Failed to load image"); + } + // calculate bitmap hash (for KV caching) + bmp.id = server_tokens::fnv_hash(bmp.data.data(), bmp.data.size()); + bitmaps.push_back(std::move(bmp)); + } + } + + // process prompt + std::vector inputs; + if (oaicompat && !prompt.is_string()) { + throw std::runtime_error("prompt must be a string"); + + } else if (oaicompat && has_mtmd) { + // multimodal + mtmd_input_text inp_txt = { + prompt.get(), + /* add_special */ true, + /* parse_special */ true, + }; + mtmd_input_chunks chunks; + int32_t tokenized = mtmd_tokenize(ctx_server.mctx, chunks, inp_txt, bitmaps); + if (tokenized != 0) { + throw std::runtime_error("Failed to tokenize prompt"); + } + server_tokens tmp(chunks, true); + inputs.push_back(std::move(tmp)); + + } else { + // non-multimodal version + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + for (auto & p : tokenized_prompts) { + auto tmp = convert_legacy_to_mtmd(p); + inputs.push_back(std::move(tmp)); + } + } + + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = std::move(tokenized_prompts[i]); + task.prompt_tokens = std::move(inputs[i]); task.params = server_task::params_from_json_cmpl( ctx_server.ctx, ctx_server.params_base, @@ -4050,9 +4211,11 @@ int main(int argc, char ** argv) { const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { json data = json::parse(req.body); + std::vector files; // dummy return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, + files, req.is_connection_closed, res, OAICOMPAT_TYPE_NONE); @@ -4060,9 +4223,11 @@ int main(int argc, char ** argv) { const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { json data = oaicompat_completion_params_parse(json::parse(req.body)); + std::vector files; // dummy return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, + files, req.is_connection_closed, res, OAICOMPAT_TYPE_COMPLETION); @@ -4137,9 +4302,11 @@ int main(int argc, char ** argv) { tokenized_prompts[0] ); + std::vector files; // dummy return handle_completions_impl( SERVER_TASK_TYPE_INFILL, data, + files, req.is_connection_closed, res, OAICOMPAT_TYPE_NONE); // infill is not OAI compatible @@ -4153,11 +4320,13 @@ int main(int argc, char ** argv) { } auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get()); + std::vector files; + json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get(), files); return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, + files, req.is_connection_closed, res, 
OAICOMPAT_TYPE_CHAT); @@ -4166,7 +4335,8 @@ int main(int argc, char ** argv) { // same with handle_chat_completions, but without inference part const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) { auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get()); + std::vector files; // dummy, unused + json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get(), files); res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); }; @@ -4271,7 +4441,7 @@ int main(int argc, char ** argv) { } } - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); for (const auto & tokens : tokenized_prompts) { // this check is necessary for models that do not add BOS token to the input if (tokens.empty()) { @@ -4291,7 +4461,7 @@ int main(int argc, char ** argv) { task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = std::move(tokenized_prompts[i]); + task.prompt_tokens = convert_legacy_to_mtmd(tokenized_prompts[i]); // OAI-compat task.params.oaicompat = oaicompat; @@ -4385,13 +4555,14 @@ int main(int argc, char ** argv) { std::unordered_set task_ids; { std::vector tasks; - std::vector tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true); + auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true); tasks.reserve(tokenized_docs.size()); for (size_t i = 0; i < tokenized_docs.size(); i++) { + auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]); + task.prompt_tokens = convert_legacy_to_mtmd(tmp); tasks.push_back(std::move(task)); } diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index aba2f27f9b564..f048be0265e7e 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -4,6 +4,7 @@ #include "log.h" #include "llama.h" #include "base64.hpp" +#include "mtmd.h" // increase max payload length to allow use of larger context size #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 @@ -21,6 +22,7 @@ #include #include #include +#include #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" @@ -41,6 +43,8 @@ using json = nlohmann::ordered_json; #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define QUE_DBG(fmt, ...) 
LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +using raw_buffer = std::vector<uint8_t>; + template <typename T> static T json_value(const json & body, const std::string & key, const T & default_value) { // Fallback null to default value @@ -386,7 +390,7 @@ static inline bool is_base64(uint8_t c) { return (isalnum(c) || (c == '+') || (c == '/')); } -static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string) { +static inline raw_buffer base64_decode(const std::string & encoded_string) { int i = 0; int j = 0; int in_ = 0; @@ -396,7 +400,7 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str uint8_t char_array_4[4]; uint8_t char_array_3[3]; - std::vector<uint8_t> ret; + raw_buffer ret; while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { char_array_4[i++] = encoded_string[in_]; in_++; @@ -579,7 +583,8 @@ static json oaicompat_completion_params_parse( const json & body, /* openai api json semantics */ bool use_jinja, common_reasoning_format reasoning_format, - const struct common_chat_templates * tmpls) + const struct common_chat_templates * tmpls, + std::vector<raw_buffer> & out_files) { json llama_params; @@ -627,8 +632,50 @@ static json oaicompat_completion_params_parse( } } + // get input files + if (!body.contains("messages")) { + throw std::runtime_error("'messages' is required"); + } + json messages = body.at("messages"); + if (!messages.is_array()) { + throw std::runtime_error("Expected 'messages' to be an array"); + } + for (auto & msg : messages) { + json & content = msg.at("content"); + if (content.is_string()) { + continue; + } + + if (!content.is_array()) { + throw std::runtime_error("Expected 'content' to be a string or an array"); + } + + for (auto & p : content) { + std::string type = json_value(p, "type", std::string()); + json image_url = json_value(p, "image_url", json::object()); + if (type == "image_url") { + std::string url = json_value(image_url, "url", std::string()); + std::vector<std::string> parts = string_split(url, /*separator*/ ','); + if (parts.size() != 2) { + throw std::runtime_error("Invalid image_url.url value"); + } else if (!string_starts_with(parts[0], "data:image/")) { + throw std::runtime_error("Invalid image_url.url format: " + parts[0]); + } else if (!string_ends_with(parts[0], "base64")) { + throw std::runtime_error("image_url.url must be base64 encoded"); + } else { + auto base64_data = parts[1]; + auto decoded_data = base64_decode(base64_data); + out_files.push_back(decoded_data); + } + p["type"] = "text"; + p["text"] = "<__image__>"; + p.erase("image_url"); + } + } + } + common_chat_templates_inputs inputs; - inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages")); + inputs.messages = common_chat_msgs_parse_oaicompat(messages); inputs.tools = common_chat_tools_parse_oaicompat(tools); inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto"))); inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
@@ -913,3 +960,335 @@ static std::vector<common_adapter_lora_info> parse_lora_request( return lora; } + +// +// utils for interacting with libmtmd +// (may need to refactor in near future) +// + +// each chunk can contain either one SINGLE text token or a pointer to an image +// this is to simplify the logic of KV cache management +struct server_token { + llama_token txt; + std::shared_ptr<mtmd_image_tokens> img; + std::string str() const { + // for debugging + GGML_ASSERT(img || txt != LLAMA_TOKEN_NULL); + if (img) { + return " "; + } else { + return std::to_string(txt) + " "; + } + } +}; + +/** + * server_tokens is a helper to manage the input tokens and images for the server. + * it is made this way to simplify the logic of KV cache management. + * + * each token can be either a text token or a pointer to an image. + * since an image usually spans multiple tokens, each of those tokens holds a shared_ptr to the same image. + */ +struct server_tokens { + bool has_mtmd = false; + std::vector<server_token> values; + + server_tokens() = default; + ~server_tokens() = default; + + // Prevent copying + server_tokens(const server_tokens&) = delete; + server_tokens& operator=(const server_tokens&) = delete; + + // Allow moving (usually implicitly generated if members are movable) + server_tokens(server_tokens&&) = default; + server_tokens& operator=(server_tokens&&) = default; + + // Allow accessing elements using [] operator + server_token& operator[](size_t index) { return values[index]; } + const server_token& operator[](size_t index) const { return values[index]; } + + server_tokens(mtmd_input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { + for (auto & c : mtmd_chunks) { + if (c.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + add_image_tokens(std::move(c.tokens_image)); + } else if (c.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + for (auto & tok : c.tokens_text) { + add_text_token(tok); + } + } else { + GGML_ABORT("Invalid chunk type"); + } + } + } + + std::string str() const { + // for debugging + std::string ret; + for (const auto & t : values) { + ret += t.str(); + } + return ret; + } + + size_t size() const { + return values.size(); + } + + bool empty() const { + return values.empty(); + } + + void clear() { + values.clear(); + } + + void resize(size_t n) { + values.resize(n); + } + + void add_token(server_token && t) { + if (t.img) GGML_ASSERT(has_mtmd); + values.push_back(std::move(t)); + } + + void add_text_token(llama_token tok) { + GGML_ASSERT(tok != LLAMA_TOKEN_NULL); + values.push_back({tok, nullptr}); + } + + void add_image_tokens(mtmd_image_tokens_ptr && image) { + GGML_ASSERT(has_mtmd); + GGML_ASSERT(image != nullptr); + std::shared_ptr<mtmd_image_tokens> tok_image(std::move(image)); + size_t n_tokens = mtmd_image_tokens_get_n_tokens(tok_image.get()); + GGML_ASSERT(n_tokens > 0 && "Invalid image token"); // should never happen + for (size_t i = 0; i < n_tokens; ++i) { + values.push_back({LLAMA_TOKEN_NULL, tok_image}); + } + } + + size_t get_common_prefix(const server_tokens & b) const { + size_t max_idx = std::min(values.size(), b.values.size()); + for (size_t i = 0; i < max_idx; ++i) { + auto & ai = values[i]; + auto & bi = b.values[i]; + + if (ai.txt == bi.txt && !ai.img && !bi.img) { + continue; + } else if (ai.img && bi.img) { + GGML_ASSERT(has_mtmd); + std::string ai_id = mtmd_image_tokens_get_id(ai.img.get()); + std::string bi_id = mtmd_image_tokens_get_id(bi.img.get()); + if (ai_id == bi_id) { + size_t n_tokens = mtmd_image_tokens_get_n_tokens(ai.img.get()); + GGML_ASSERT(n_tokens > 0 && "Invalid image token"); // should never happen + i += mtmd_image_tokens_get_n_tokens(ai.img.get()) - 1; + continue; + } else { + return i; + } + } else { + return i; + } + } + return max_idx; // all tokens are equal + } + + // make sure all text tokens are within the vocab range + bool validate(llama_token max_vocab_id) const { + for (const auto & t : values) { + if (!t.img) { + if (t.txt < 0 || t.txt >= max_vocab_id) { + return false; + } + } + } + return true; + } + + // same idea as std::vector::resize() + void keep_until(size_t pos) { + // TODO : maybe throw an error if we remove part of an image (only allow removing the whole image) + // this cannot happen currently because get_common_prefix() never returns such a pos + values.resize(pos); + } + + // Computes FNV-1a hash of the data + static std::string fnv_hash(const uint8_t * data, size_t len) { + const uint64_t fnv_prime = 0x100000001b3ULL; + uint64_t hash = 0xcbf29ce484222325ULL; + + for (size_t i = 0; i < len; ++i) { + hash ^= data[i]; + hash *= fnv_prime; + } + return std::to_string(hash); + } + + // TODO: maybe implement a (de)serializer for this struct, so we can get rid of functions below + + // return all text tokens (for legacy code), to be used by save/load slot + llama_tokens get_text_tokens() { + llama_tokens output; + for (auto & t : values) { + if (t.txt != LLAMA_TOKEN_NULL) { + output.push_back(t.txt); + } + } + return output; + } + + // clear and set text tokens (for legacy code), to be used by save/load slot + void set_text_tokens(llama_tokens tokens) { + values.clear(); + for (auto & tok : tokens) { + add_text_token(tok); + } + } +};
+ +// helper struct to make working with embd batch easier +// note: this will be removed after llama_batch_ext refactoring +struct server_batch_embd { + std::vector<llama_pos> pos; + std::vector<llama_token> token; + std::vector<int32_t> n_seq_id; + std::vector<llama_seq_id> seq_id; + std::vector<llama_seq_id *> seq_ids; + std::vector<int8_t> logits; + + llama_batch batch; + + server_batch_embd() : server_batch_embd(1) {} + server_batch_embd(int32_t n_tokens) { + token .resize(n_tokens); + pos .resize(n_tokens); + n_seq_id.resize(n_tokens); + logits .resize(n_tokens); + seq_id .resize(n_tokens); + seq_ids .resize(n_tokens + 1); + seq_ids[n_tokens] = nullptr; + + batch = { + /*n_tokens =*/ 0, + /*tokens =*/ token.data(), + /*embd =*/ nullptr, + /*pos =*/ pos.data(), + /*n_seq_id =*/ n_seq_id.data(), + /*seq_id =*/ seq_ids.data(), + /*logits =*/ logits.data(), + }; + + for (int i = 0; i < n_tokens; i++) { + batch.n_seq_id[i] = 1; // only a single seq_id per token is needed + batch.seq_id [i] = seq_id.data() + i; + } + } + + void reserve_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { + GGML_ASSERT(n_tokens <= (int32_t)pos.size()); + batch.n_tokens = n_tokens; + batch.embd = embd; + batch.token = nullptr; + for (int i = 0; i < n_tokens; i++) { + batch.pos [i] = pos_0 + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i][0] = seq_id; + batch.logits [i] = false; + } + } + + void clear() { + batch.n_tokens = 0; + batch.embd = nullptr; + batch.token = token.data(); + } + + int32_t n_tokens() const { + return batch.n_tokens; + } + + bool has_embd() const { + return batch.embd != nullptr && batch.n_tokens > 0; + } + + bool has_text() const { + return batch.token != nullptr && batch.n_tokens > 0; + } +}; + +static int32_t server_img_process( + llama_context * ctx, + mtmd_context * mctx, + server_token & chunk, + server_batch_embd & batch, + llama_pos n_past, + int slot_id) { + GGML_ASSERT(chunk.img); + int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.img.get()); + int32_t ret;
+ // encode the image + { + int64_t t0 = ggml_time_ms(); + SRV_INF("encoding image (%d tokens)...\n", (int)n_tokens); + ret = mtmd_encode(mctx, chunk.img.get()); + if (ret != 0) { + SRV_ERR("failed to encode image, status = %d\n", ret); + return ret; + } + SRV_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); + } + + float * embd = mtmd_get_output_embd(mctx); + // decode the embeddings + int64_t t1 = ggml_time_ms(); + int32_t n_embd = llama_model_n_embd(llama_get_model(ctx)); + int32_t n_batch = batch.pos.size(); + int32_t i_batch = 0; + int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch; + // split into batches + while (i_batch < n_img_batches) { + int32_t pos_offset = i_batch*n_batch; + int32_t n_tokens_batch = std::min(n_batch, n_tokens - pos_offset); + float * embd_batch = embd + pos_offset*n_embd; + batch.clear(); + batch.reserve_embd_batch(embd_batch, n_tokens_batch, n_past, slot_id); + + SRV_INF("decoding embd batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch); + + // TODO @ngxson : maybe move this to llama_batch_ext + if (mtmd_decode_use_non_causal(mctx)) { + llama_set_causal_attn(ctx, false); + } + + ret = llama_decode(ctx, batch.batch); + if (ret != 0) { + LOG_ERR("failed to decode image\n"); + llama_set_causal_attn(ctx, true); // restore causal attn + return ret; + } + + if (mtmd_decode_use_non_causal(mctx)) { + llama_set_causal_attn(ctx, true); + } + + i_batch++; + n_past += n_tokens_batch; + } + SRV_INF("image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1); + + batch.clear(); + return 0; +} + +// hacky, support text-only for now +static server_tokens convert_legacy_to_mtmd(llama_tokens & tokenized) { + server_tokens res; + res.has_mtmd = false; + for (auto & tok : tokenized) { + res.add_text_token(tok); + } + return res; +}
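Illustrative sketch (not part of the patch): prompt-cache matching in server_tokens::get_common_prefix() walks both sequences entry by entry, compares text tokens by id and image entries by their image id, and skips over a whole image once it matches. The standalone program below reproduces that walk with a simplified stand-in for server_token (plain ints and a fake image id instead of mtmd_image_tokens), so the control flow can be followed without libmtmd; demo_token and common_prefix are made-up names.

// common_prefix_demo.cpp - simplified walk of server_tokens::get_common_prefix()
// build: g++ -std=c++17 common_prefix_demo.cpp -o common_prefix_demo
#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

// stand-in for server_token: an empty img_id means "text token"
struct demo_token {
    int         txt   = -1;
    std::string img_id;     // identifies the image this entry belongs to
    int         n_img = 0;  // number of entries the image spans
};

static size_t common_prefix(const std::vector<demo_token> & a, const std::vector<demo_token> & b) {
    const size_t max_idx = std::min(a.size(), b.size());
    for (size_t i = 0; i < max_idx; ++i) {
        const bool a_img = !a[i].img_id.empty();
        const bool b_img = !b[i].img_id.empty();
        if (!a_img && !b_img && a[i].txt == b[i].txt) {
            continue;                 // matching text token
        }
        if (a_img && b_img && a[i].img_id == b[i].img_id) {
            i += a[i].n_img - 1;      // same image: jump over the rest of its entries
            continue;
        }
        return i;                     // first mismatch
    }
    return max_idx;                   // all entries are equal
}

int main() {
    // "text text [image spanning 3 entries] text", second prompt differs only in the last token
    std::vector<demo_token> cached = { {1}, {2}, {-1, "img-A", 3}, {-1, "img-A", 3}, {-1, "img-A", 3}, {7} };
    std::vector<demo_token> prompt = { {1}, {2}, {-1, "img-A", 3}, {-1, "img-A", 3}, {-1, "img-A", 3}, {9} };
    printf("common prefix length: %zu\n", common_prefix(cached, prompt)); // prints 5
    return 0;
}

This is also why keep_until() can currently only be asked to cut at a text boundary or at the start of an image: the prefix length never lands in the middle of a matched image.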
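Illustrative sketch (not part of the patch): the bitmap id used for prompt caching is a plain 64-bit FNV-1a hash of the raw image bytes, as implemented in server_tokens::fnv_hash above and assigned to bmp.id in the completion handler. The standalone program below reproduces that computation so it can be compiled and checked in isolation; the main() driver and the sample buffer are illustrative only.

// fnv_hash_demo.cpp - standalone copy of the FNV-1a hash used for bitmap ids
// build: g++ -std=c++17 fnv_hash_demo.cpp -o fnv_hash_demo
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// same constants and update order as server_tokens::fnv_hash (64-bit FNV-1a)
static std::string fnv_hash(const uint8_t * data, size_t len) {
    const uint64_t fnv_prime = 0x100000001b3ULL;
    uint64_t hash = 0xcbf29ce484222325ULL; // FNV offset basis
    for (size_t i = 0; i < len; ++i) {
        hash ^= data[i];   // xor the byte first ...
        hash *= fnv_prime; // ... then multiply - this order is what makes it FNV-1a
    }
    return std::to_string(hash);
}

int main() {
    // stand-in for bmp.data: the hash only sees raw bytes, any buffer works
    std::vector<uint8_t> pixels = { 0x89, 0x50, 0x4e, 0x47 };
    printf("bitmap id: %s\n", fnv_hash(pixels.data(), pixels.size()).c_str());
    return 0;
}

Because the id depends only on the bytes, re-sending the same image yields the same id, which is what lets get_common_prefix() treat a previously processed image chunk as part of the common prefix (assuming mtmd propagates the bitmap id to the image tokens' id).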
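Illustrative sketch (not part of the patch): server_img_process() decodes the encoded image embedding in n_batch-sized slices at increasing positions. The program below mirrors only the index arithmetic of that loop (no llama/mtmd calls; all numbers are made up) to make the GGML_PAD(n_tokens, n_batch) / n_batch split easier to follow; round_up is a made-up helper with the same effect as GGML_PAD.

// img_batch_split_demo.cpp - index arithmetic of the embedding-decode loop in server_img_process()
// build: g++ -std=c++17 img_batch_split_demo.cpp -o img_batch_split_demo
#include <algorithm>
#include <cstdio>

// rounds x up to a multiple of n
static int round_up(int x, int n) { return (x + n - 1) / n * n; }

int main() {
    const int n_tokens = 730; // number of image tokens (illustrative)
    const int n_batch  = 512; // max tokens per llama_decode call (illustrative)
    int       n_past   = 100; // sequence position where the image starts (illustrative)

    const int n_img_batches = round_up(n_tokens, n_batch) / n_batch; // == ceil(n_tokens / n_batch)

    for (int i_batch = 0; i_batch < n_img_batches; i_batch++) {
        const int pos_offset     = i_batch * n_batch;
        const int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
        // the embeddings for this slice start at: embd + pos_offset * n_embd
        printf("batch %d/%d: embd rows [%d, %d) decoded at positions [%d, %d)\n",
               i_batch + 1, n_img_batches,
               pos_offset, pos_offset + n_tokens_batch,
               n_past, n_past + n_tokens_batch);
        n_past += n_tokens_batch;
    }
    return 0;
}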
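Illustrative sketch (not part of the patch): oaicompat_completion_params_parse() only accepts image_url.url values of the form "data:image/...;base64,payload" - it splits on the comma, checks the header, base64-decodes the payload and rewrites the message part into the "<__image__>" text marker. The program below shows an equivalent client-side validation of such a data URL; split_image_data_url is a made-up helper and only std::string operations are used.

// data_url_demo.cpp - mirrors the checks applied to image_url.url before decoding
// build: g++ -std=c++17 data_url_demo.cpp -o data_url_demo
#include <cstdio>
#include <string>

// returns true and extracts the base64 payload if url looks like "data:image/...;base64,<payload>"
static bool split_image_data_url(const std::string & url, std::string & out_b64) {
    const size_t comma = url.find(',');
    if (comma == std::string::npos || url.find(',', comma + 1) != std::string::npos) {
        return false; // the server expects exactly two comma-separated parts
    }
    const std::string header = url.substr(0, comma);
    const std::string prefix = "data:image/";
    const std::string suffix = "base64";
    if (header.compare(0, prefix.size(), prefix) != 0) {
        return false; // must start with "data:image/"
    }
    if (header.size() < suffix.size() ||
        header.compare(header.size() - suffix.size(), suffix.size(), suffix) != 0) {
        return false; // must end with "base64"
    }
    out_b64 = url.substr(comma + 1);
    return true;
}

int main() {
    std::string payload;
    const std::string url = "data:image/png;base64,iVBORw0KGgo="; // illustrative value
    printf("valid: %d payload: %s\n", split_image_data_url(url, payload) ? 1 : 0, payload.c_str());
    return 0;
}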