|
| 1 | +#include "common.h" |
| 2 | +#include "llama.h" |
| 3 | +#include "binding.h" |
| 4 | + |
| 5 | +#include <cassert> |
| 6 | +#include <cinttypes> |
| 7 | +#include <cmath> |
| 8 | +#include <cstdio> |
| 9 | +#include <cstring> |
| 10 | +#include <fstream> |
| 11 | +#include <iostream> |
| 12 | +#include <string> |
| 13 | +#include <vector> |
| 14 | + |
| 15 | +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) |
| 16 | +#include <signal.h> |
| 17 | +#include <unistd.h> |
| 18 | +#elif defined (_WIN32) |
| 19 | +#include <signal.h> |
| 20 | +#endif |
| 21 | + |
| 22 | +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) |
| 23 | +void sigint_handler(int signo) { |
| 24 | + if (signo == SIGINT) { |
| 25 | + _exit(130); |
| 26 | + } |
| 27 | +} |
| 28 | +#endif |
| 29 | + |
| 30 | +int llama_predict(void* params_ptr, void* state_pr, char* result) { |
| 31 | + gpt_params* params_p = (gpt_params*) params_ptr; |
| 32 | + llama_context* ctx = (llama_context*) state_pr; |
| 33 | + |
| 34 | + gpt_params params = *params_p; |
| 35 | + |
| 36 | + if (params.seed <= 0) { |
| 37 | + params.seed = time(NULL); |
| 38 | + } |
| 39 | + |
| 40 | + std::mt19937 rng(params.seed); |
| 41 | + |
| 42 | + // Add a space in front of the first character to match OG llama tokenizer behavior |
| 43 | + params.prompt.insert(0, 1, ' '); |
| 44 | + |
| 45 | + // tokenize the prompt |
| 46 | + auto embd_inp = ::llama_tokenize(ctx, params.prompt, true); |
| 47 | + |
| 48 | + const int n_ctx = llama_n_ctx(ctx); |
| 49 | + |
| 50 | + // number of tokens to keep when resetting context |
| 51 | + if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) { |
| 52 | + params.n_keep = (int)embd_inp.size(); |
| 53 | + } |
| 54 | + |
| 55 | + // determine newline token |
| 56 | + auto llama_token_newline = ::llama_tokenize(ctx, "\n", false); |
| 57 | + |
| 58 | + // TODO: replace with ring-buffer |
| 59 | + std::vector<llama_token> last_n_tokens(n_ctx); |
| 60 | + std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); |
| 61 | + |
| 62 | + int n_past = 0; |
| 63 | + int n_remain = params.n_predict; |
| 64 | + int n_consumed = 0; |
| 65 | + |
| 66 | + std::vector<llama_token> embd; |
| 67 | + std::string res = ""; |
| 68 | + |
| 69 | + while (n_remain != 0) { |
| 70 | + // predict |
| 71 | + if (embd.size() > 0) { |
| 72 | + // infinite text generation via context swapping |
| 73 | + // if we run out of context: |
| 74 | + // - take the n_keep first tokens from the original prompt (via n_past) |
| 75 | + // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch |
| 76 | + if (n_past + (int) embd.size() > n_ctx) { |
| 77 | + const int n_left = n_past - params.n_keep; |
| 78 | + |
| 79 | + n_past = params.n_keep; |
| 80 | + |
| 81 | + // insert n_left/2 tokens at the start of embd from last_n_tokens |
| 82 | + embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size()); |
| 83 | + } |
| 84 | + |
| 85 | + if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) { |
| 86 | + fprintf(stderr, "%s : failed to eval\n", __func__); |
| 87 | + return 1; |
| 88 | + } |
| 89 | + } |
| 90 | + |
| 91 | + n_past += embd.size(); |
| 92 | + embd.clear(); |
| 93 | + |
| 94 | + if ((int) embd_inp.size() <= n_consumed) { |
| 95 | + // out of user input, sample next token |
| 96 | + const int32_t top_k = params.top_k; |
| 97 | + const float top_p = params.top_p; |
| 98 | + const float temp = params.temp; |
| 99 | + const float repeat_penalty = params.repeat_penalty; |
| 100 | + |
| 101 | + llama_token id = 0; |
| 102 | + |
| 103 | + { |
| 104 | + auto logits = llama_get_logits(ctx); |
| 105 | + |
| 106 | + if (params.ignore_eos) { |
| 107 | + logits[llama_token_eos()] = 0; |
| 108 | + } |
| 109 | + |
| 110 | + id = llama_sample_top_p_top_k(ctx, |
| 111 | + last_n_tokens.data() + n_ctx - params.repeat_last_n, |
| 112 | + params.repeat_last_n, top_k, top_p, temp, repeat_penalty); |
| 113 | + |
| 114 | + last_n_tokens.erase(last_n_tokens.begin()); |
| 115 | + last_n_tokens.push_back(id); |
| 116 | + } |
| 117 | + |
| 118 | + // add it to the context |
| 119 | + embd.push_back(id); |
| 120 | + |
| 121 | + // decrement remaining sampling budget |
| 122 | + --n_remain; |
| 123 | + } else { |
| 124 | + // some user input remains from prompt or interaction, forward it to processing |
| 125 | + while ((int) embd_inp.size() > n_consumed) { |
| 126 | + embd.push_back(embd_inp[n_consumed]); |
| 127 | + last_n_tokens.erase(last_n_tokens.begin()); |
| 128 | + last_n_tokens.push_back(embd_inp[n_consumed]); |
| 129 | + ++n_consumed; |
| 130 | + if ((int) embd.size() >= params.n_batch) { |
| 131 | + break; |
| 132 | + } |
| 133 | + } |
| 134 | + } |
| 135 | + |
| 136 | + for (auto id : embd) { |
| 137 | + res += llama_token_to_str(ctx, id); |
| 138 | + } |
| 139 | + |
| 140 | + // end of text token |
| 141 | + if (embd.back() == llama_token_eos()) { |
| 142 | + break; |
| 143 | + } |
| 144 | + } |
| 145 | + |
| 146 | +#if defined (_WIN32) |
| 147 | + signal(SIGINT, SIG_DFL); |
| 148 | +#endif |
| 149 | + strcpy(result, res.c_str()); |
| 150 | + return 0; |
| 151 | +} |
| 152 | + |
| 153 | +void llama_free_model(void *state_ptr) { |
| 154 | + llama_context* ctx = (llama_context*) state_ptr; |
| 155 | + llama_free(ctx); |
| 156 | +} |
| 157 | + |
| 158 | +void llama_free_params(void* params_ptr) { |
| 159 | + gpt_params* params = (gpt_params*) params_ptr; |
| 160 | + delete params; |
| 161 | +} |
| 162 | + |
| 163 | + |
| 164 | +void* llama_allocate_params(const char *prompt, int seed, int threads, int tokens, int top_k, |
| 165 | + float top_p, float temp, float repeat_penalty, int repeat_last_n, bool ignore_eos, bool memory_f16) { |
| 166 | + gpt_params* params = new gpt_params; |
| 167 | + params->seed = seed; |
| 168 | + params->n_threads = threads; |
| 169 | + params->n_predict = tokens; |
| 170 | + params->repeat_last_n = repeat_last_n; |
| 171 | + |
| 172 | + params->top_k = top_k; |
| 173 | + params->top_p = top_p; |
| 174 | + params->memory_f16 = memory_f16; |
| 175 | + params->temp = temp; |
| 176 | + params->repeat_penalty = repeat_penalty; |
| 177 | + |
| 178 | + params->prompt = prompt; |
| 179 | + params->ignore_eos = ignore_eos; |
| 180 | + |
| 181 | + return params; |
| 182 | +} |
| 183 | + |
| 184 | +void* load_model(const char *fname, int n_ctx, int n_parts, int n_seed, bool memory_f16, bool mlock) { |
| 185 | + // load the model |
| 186 | + auto lparams = llama_context_default_params(); |
| 187 | + |
| 188 | + lparams.n_ctx = n_ctx; |
| 189 | + lparams.n_parts = n_parts; |
| 190 | + lparams.seed = n_seed; |
| 191 | + lparams.f16_kv = memory_f16; |
| 192 | + lparams.use_mlock = mlock; |
| 193 | + |
| 194 | + return llama_init_from_file(fname, lparams); |
| 195 | +} |
0 commit comments