
Commit 10d2af0

llama/ggml: add LLM training support (#10544)
* llama/ggml: add LLM training support
  - more compact progress bar
  - llama_save_model_to_file
  - llama_opt_param_filter
  - ggml_graph_dup force_grads
  - refactor ggml_opt, fix test-opt
* remove logits_all
* refactor CUDA implementation for ACC
* reset graph at beginning of opt period
1 parent 064cc59 commit 10d2af0

31 files changed, +1409 -353 lines

build-xcframework.sh (+1)

@@ -117,6 +117,7 @@ setup_framework_structure() {
     # Copy all required headers (common for all platforms)
     cp include/llama.h ${header_path}
     cp ggml/include/ggml.h ${header_path}
+    cp ggml/include/ggml-opt.h ${header_path}
     cp ggml/include/ggml-alloc.h ${header_path}
     cp ggml/include/ggml-backend.h ${header_path}
     cp ggml/include/ggml-metal.h ${header_path}

common/common.cpp (+17)

@@ -1565,3 +1565,20 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c

     return result;
 }
+
+ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
+    const int64_t ne_datapoint = llama_n_ctx(ctx);
+    const int64_t ndata        = (tokens.size() - ne_datapoint - 1) / stride;
+    ggml_opt_dataset_t result = ggml_opt_dataset_init(
+        GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);
+
+    llama_token * data   = (llama_token *) ggml_opt_dataset_data(result)->data;
+    llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;
+
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        memcpy(data   + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
+        memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
+    }
+
+    return result;
+}
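
Note on this helper: each datapoint is a window of llama_n_ctx(ctx) tokens and its label is the same window shifted by one token, so the dataset trains next-token prediction. A minimal usage sketch based on the signatures in this commit (the `ctx` and `text` variables and the stride choice are illustrative assumptions, not part of the diff):

```cpp
// Hypothetical sketch: build a next-token-prediction dataset from raw text.
// Assumes `ctx` is an already-initialized llama_context and `text` is the training corpus.
std::vector<llama_token> tokens = common_tokenize(ctx, text, /*add_special=*/true);

// A stride of n_ctx/2 gives 50% overlap between consecutive datapoints, as in finetune.cpp.
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx, tokens, llama_n_ctx(ctx)/2);

// ... run training, e.g. via llama_opt_epoch ...

ggml_opt_dataset_free(dataset);
```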

common/common.h (+6)

@@ -666,3 +666,9 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
 const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";

 }
+
+//
+// training utils
+//
+
+ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);

examples/CMakeLists.txt (+1)

@@ -32,6 +32,7 @@ else()
     add_subdirectory(speculative)
     add_subdirectory(speculative-simple)
     add_subdirectory(gen-docs)
+    add_subdirectory(training)
     if (NOT GGML_BACKEND_DL)
         add_subdirectory(convert-llama2c-to-ggml)
         # these examples use the backends directly and cannot be built with dynamic loading

examples/training/CMakeLists.txt (+5)

@@ -0,0 +1,5 @@
+set(TARGET llama-finetune)
+add_executable(${TARGET} finetune.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/training/README.md (+17)

@@ -0,0 +1,17 @@
+# llama.cpp/examples/training
+
+This directory contains examples related to language model training using llama.cpp/GGML.
+So far finetuning is technically functional (for FP32 models and limited hardware setups) but the code is very much WIP.
+Finetuning of Stories 260K and LLaMA 3.2 1b seems to work with 24 GB of memory.
+**For CPU training, compile llama.cpp without any additional backends such as CUDA.**
+**For CUDA training, use the maximum number of GPU layers.**
+
+Proof of concept:
+
+``` sh
+export model_name=llama_3.2-1b && export quantization=f32
+./build/bin/finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
+./build/bin/perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
+```
+
+The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs.

examples/training/finetune.cpp (+96)

@@ -0,0 +1,96 @@
+#include "arg.h"
+#include "common.h"
+#include "log.h"
+#include "llama.h"
+
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
+#include <vector>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    params.escape = false;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
+        return 1;
+    }
+
+    if (params.use_mmap) {
+        LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
+        params.use_mmap = false;
+    }
+    if (params.cache_type_k != GGML_TYPE_F32) {
+        LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+        params.cache_type_k = GGML_TYPE_F32;
+    }
+    if (params.cache_type_v != GGML_TYPE_F32) {
+        LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+        params.cache_type_v = GGML_TYPE_F32;
+    }
+
+    common_init();
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    // load the model and apply lora adapter, if any
+    common_init_result llama_init = common_init_from_params(params);
+    llama_model_ptr   & model = llama_init.model;
+    llama_context_ptr & ctx   = llama_init.context;
+
+    if (model == NULL) {
+        LOG_ERR("%s: unable to load model\n", __func__);
+        return 1;
+    }
+
+    // print system information
+    {
+        LOG_INF("\n");
+        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+    }
+
+    constexpr float val_split = 0.05f;
+
+    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
+    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
+
+    struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
+    optimizer_params.adamw.alpha = 1e-7f; // learning rate
+
+    struct llama_opt_params lopt_params {
+        /*n_ctx_train     =*/ 0,
+        /*param_filter    =*/ llama_opt_param_filter_all,
+        /*param_filter_ud =*/ nullptr,
+        /*get_opt_pars    =*/ ggml_opt_get_constant_optimizer_params,
+        /*get_opt_pars_ud =*/ &optimizer_params,
+    };
+    llama_opt_init(ctx.get(), model.get(), lopt_params);
+
+    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
+
+    ggml_opt_result_t result_train = ggml_opt_result_init();
+    ggml_opt_result_t result_eval  = ggml_opt_result_init();
+
+    for (int epoch = 0; epoch < 2; ++epoch) {
+        llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
+            ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
+        fprintf(stderr, "\n");
+
+        ggml_opt_result_reset(result_train);
+        ggml_opt_result_reset(result_eval);
+    }
+    ggml_opt_result_free(result_train);
+    ggml_opt_result_free(result_eval);
+
+    llama_model_save_to_file(model.get(), "finetuned-model.gguf");
+
+    llama_backend_free();
+
+    return 0;
+}

ggml/include/ggml-opt.h (+47, -28)

@@ -37,13 +37,16 @@ extern "C" {
     // ====== Dataset ======

     GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
-            int64_t ne_datapoint, // number of elements per datapoint
-            int64_t ne_label,     // number of elements per label
-            int64_t ndata,        // total number of datapoints/labels
-            int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
+            enum ggml_type type_data,    // the type for the internal data tensor
+            enum ggml_type type_label,   // the type for the internal labels tensor
+            int64_t        ne_datapoint, // number of elements per datapoint
+            int64_t        ne_label,     // number of elements per label
+            int64_t        ndata,        // total number of datapoints/labels
+            int64_t        ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
     GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);

     // get underlying tensors that store the data
+    GGML_API int64_t              ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset);
     GGML_API struct ggml_tensor * ggml_opt_dataset_data  (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
     GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata]

@@ -56,13 +59,19 @@ extern "C" {
             struct ggml_tensor * data_batch,   // shape = [ne_datapoint, ndata_batch]
             struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
             int64_t              ibatch);
+    GGML_API void ggml_opt_dataset_get_batch_host(
+            ggml_opt_dataset_t   dataset,
+            void               * data_batch,
+            size_t               nb_data_batch,
+            void               * labels_batch,
+            int64_t              ibatch);

     // ====== Model / Context ======

     enum ggml_opt_build_type {
-        GGML_OPT_BUILD_TYPE_FORWARD,
-        GGML_OPT_BUILD_TYPE_GRAD,
-        GGML_OPT_BUILD_TYPE_OPT,
+        GGML_OPT_BUILD_TYPE_FORWARD = 10,
+        GGML_OPT_BUILD_TYPE_GRAD    = 20,
+        GGML_OPT_BUILD_TYPE_OPT     = 30,
     };

     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss

@@ -81,20 +90,22 @@ extern "C" {
     // userdata can be used to pass arbitrary data
     typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);

-    // returns the default optimizer params (constant)
+    // returns the default optimizer params (constant, hard-coded values)
     // userdata is not used
     GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);

+    // casts userdata to ggml_opt_optimizer_params and returns it
+    GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata);
+
     // parameters for initializing a new optimization context
     struct ggml_opt_params {
         ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs

-        struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
-
-        // the forward graph is defined by inputs and outputs
-        // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
-        struct ggml_tensor * inputs;
-        struct ggml_tensor * outputs;
+        // by default the forward graph needs to be reconstructed for each eval
+        // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
+        struct ggml_context * ctx_compute;
+        struct ggml_tensor  * inputs;
+        struct ggml_tensor  * outputs;

         enum ggml_opt_loss_type  loss_type;
         enum ggml_opt_build_type build_type;

@@ -107,12 +118,9 @@ extern "C" {

     // get parameters for an optimization context with defaults set where possible
     // parameters for which no sensible defaults exist are supplied as arguments to this function
-    GGML_API ggml_opt_params ggml_opt_default_params(
-            ggml_backend_sched_t      backend_sched,
-            struct ggml_context     * ctx_compute,
-            struct ggml_tensor      * inputs,
-            struct ggml_tensor      * outputs,
-            enum ggml_opt_loss_type   loss_type);
+    GGML_API struct ggml_opt_params ggml_opt_default_params(
+            ggml_backend_sched_t    backend_sched,
+            enum ggml_opt_loss_type loss_type);

     GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
     GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);

@@ -121,18 +129,20 @@ extern "C" {
     GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);

     // get underlying tensors that store data
+    // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
     GGML_API struct ggml_tensor * ggml_opt_inputs(  ggml_opt_context_t opt_ctx); // forward graph input tensor
     GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
     GGML_API struct ggml_tensor * ggml_opt_labels(  ggml_opt_context_t opt_ctx); // labels to compare outputs against
     GGML_API struct ggml_tensor * ggml_opt_loss(    ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss
     GGML_API struct ggml_tensor * ggml_opt_pred(    ggml_opt_context_t opt_ctx); // predictions made by outputs
     GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels

+    // get the gradient accumulator for a node from the forward graph
     GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);

     // ====== Optimization Result ======

-    GGML_API ggml_opt_result_t ggml_opt_result_init();
+    GGML_API ggml_opt_result_t ggml_opt_result_init(void);
     GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
     GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);

@@ -144,11 +154,20 @@ extern "C" {

     // ====== Computation ======

-    // do forward pass, increment result if not NULL
-    GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+    // if not using static graphs, this function must be called prior to ggml_opt_alloc
+    GGML_API void ggml_opt_prepare_alloc(
+        ggml_opt_context_t    opt_ctx,
+        struct ggml_context * ctx_compute,
+        struct ggml_cgraph  * gf,
+        struct ggml_tensor  * inputs,
+        struct ggml_tensor  * outputs);
+
+    // allocate the next graph for evaluation, either forward or forward + backward
+    // must be called exactly once prior to calling ggml_opt_eval
+    GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);

-    // do forward pass, increment result if not NULL, do backward pass
-    GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+    // do forward pass, increment result if not NULL, do backward pass if allocated
+    GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);

     // ############################################################################
     // ## The high-level functions start here. They do not depend on any private ##

@@ -200,9 +219,9 @@ extern "C" {
     // fit model defined by inputs and outputs to dataset
     GGML_API void ggml_opt_fit(
             ggml_backend_sched_t          backend_sched, // backend scheduler for constructing the compute graphs
-            ggml_context                * ctx_compute,   // context with temporarily allocated tensors to calculate the outputs
-            ggml_tensor                 * inputs,        // input tensor with shape [ne_datapoint, ndata_batch]
-            ggml_tensor                 * outputs,       // output tensor, must have shape [ne_label, ndata_batch] if labels are used
+            struct ggml_context         * ctx_compute,   // context with temporarily allocated tensors to calculate the outputs
+            struct ggml_tensor          * inputs,        // input tensor with shape [ne_datapoint, ndata_batch]
+            struct ggml_tensor          * outputs,       // output tensor, must have shape [ne_label, ndata_batch] if labels are used
             ggml_opt_dataset_t            dataset,       // dataset with data and optionally also labels
             enum ggml_opt_loss_type       loss_type,     // loss to minimize
             ggml_opt_get_optimizer_params get_opt_pars,  // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
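
Note on the computation API above: ggml_opt_forward/ggml_opt_forward_backward are replaced by an explicit allocate-then-evaluate pair (with ggml_opt_prepare_alloc needed first when graphs are not static). A rough sketch of a training loop with statically allocated graphs, based only on the declarations in this header; `opt_ctx` and `nbatches` are assumed to exist and the batch copy is left as a placeholder:

```cpp
// Hypothetical sketch of the new evaluation flow with static graphs.
// Assumes `opt_ctx` was created with ctx_compute/inputs/outputs set in ggml_opt_params.
ggml_opt_result_t result = ggml_opt_result_init();

for (int64_t ibatch = 0; ibatch < nbatches; ++ibatch) {
    // ... copy batch ibatch into ggml_opt_inputs(opt_ctx) and ggml_opt_labels(opt_ctx) ...

    ggml_opt_alloc(opt_ctx, /*backward=*/true); // allocate forward + backward graph for this step
    ggml_opt_eval(opt_ctx, result);             // forward pass, accumulate stats into result, backward pass
}

ggml_opt_result_free(result);
```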

ggml/include/ggml.h (+6, -7)

@@ -768,7 +768,7 @@ extern "C" {
     // Tensor flags
     GGML_API void ggml_set_input(struct ggml_tensor * tensor);
     GGML_API void ggml_set_output(struct ggml_tensor * tensor);
-    GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API void ggml_set_param(struct ggml_tensor * tensor);
     GGML_API void ggml_set_loss(struct ggml_tensor * tensor);

     //

@@ -938,7 +938,7 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_repeat_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride

     // concat a and b along dim
     // used in stable-diffusion

@@ -2049,15 +2049,14 @@ extern "C" {

     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
     GGML_API void ggml_build_backward_expand(
-        struct ggml_context *  ctx_static,  // context for static gradients (loss + gradient accumulation)
-        struct ggml_context *  ctx_compute, // context for gradient computation
-        struct ggml_cgraph  *  cgraph,
-        bool                   accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static
+        struct ggml_context *  ctx,         // context for gradient computation
+        struct ggml_cgraph  *  cgraph,
+        struct ggml_tensor  ** grad_accs);

     // graph allocation in a context
     GGML_API struct ggml_cgraph * ggml_new_graph        (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
     GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);
-    GGML_API struct ggml_cgraph * ggml_graph_dup        (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+    GGML_API struct ggml_cgraph * ggml_graph_dup        (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads);
     GGML_API void                 ggml_graph_cpy        (struct ggml_cgraph * src, struct ggml_cgraph * dst);
     GGML_API void                 ggml_graph_reset      (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
     GGML_API void                 ggml_graph_clear      (struct ggml_cgraph * cgraph);
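
For orientation, a rough sketch of building a backward graph with the changed signatures (ggml_set_param without a context argument, gradient accumulators passed explicitly, ggml_graph_dup with force_grads). This is an illustrative guess at usage, not code from this commit; the toy forward pass, the tensor sizes, and passing NULL for grad_accs are assumptions:

```cpp
// Hypothetical sketch: mark a weight as trainable and expand a backward graph.
// Assumes `ctx` is a ggml_context with enough memory for the graph metadata.
struct ggml_tensor * weight = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
struct ggml_tensor * x      = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 64);
ggml_set_param(weight); // new signature: no ggml_context argument

struct ggml_tensor * y    = ggml_mul_mat(ctx, weight, x); // toy forward pass
struct ggml_tensor * loss = ggml_sum(ctx, y);              // scalar to differentiate
ggml_set_loss(loss);

struct ggml_cgraph * gf = ggml_new_graph(ctx);
ggml_build_forward_expand(gf, loss);

// duplicate the forward graph with gradient bookkeeping forced on,
// then expand the backward pass into the copy
struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf, /*force_grads=*/true);
ggml_build_backward_expand(ctx, gb, /*grad_accs=*/nullptr); // NULL grad_accs assumed OK when not accumulating
```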

ggml/src/ggml-backend.cpp (+1, -1)

@@ -1111,7 +1111,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg

        const int node_backend_id = tensor_backend_id(node);

-       assert(node_backend_id != -1); // all nodes should be assigned by now
+       assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback

        // check if we should start a new split based on the sources of the current node
        bool need_new_split = false;
