@@ -37,13 +37,16 @@ extern "C" {
     // ====== Dataset ======
 
     GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
-            int64_t ne_datapoint, // number of elements per datapoint
-            int64_t ne_label,     // number of elements per label
-            int64_t ndata,        // total number of datapoints/labels
-            int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
+            enum ggml_type type_data,    // the type for the internal data tensor
+            enum ggml_type type_label,   // the type for the internal labels tensor
+            int64_t        ne_datapoint, // number of elements per datapoint
+            int64_t        ne_label,     // number of elements per label
+            int64_t        ndata,        // total number of datapoints/labels
+            int64_t        ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
     GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
 
     // get underlying tensors that store the data
+    GGML_API int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset);
     GGML_API struct ggml_tensor * ggml_opt_dataset_data  (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
     GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata]
 
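The new `ggml_opt_dataset_init` signature adds the tensor types up front. Below is a minimal sketch of how it might be used; the sizes (784-element datapoints, 10-element labels, 60000 datapoints, shards of 1000) are made up for illustration, and direct host access to the dataset tensors via `ggml_get_data_f32` is an assumption based on how ggml's own examples fill datasets.

```c
#include "ggml.h"
#include "ggml-opt.h"

// Sketch only: sizes are hypothetical and not taken from the header.
static ggml_opt_dataset_t make_example_dataset(void) {
    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
            GGML_TYPE_F32, // type_data:  internal data tensor stores floats
            GGML_TYPE_F32, // type_label: internal labels tensor stores floats
            /*ne_datapoint =*/ 784,
            /*ne_label     =*/ 10,
            /*ndata        =*/ 60000,
            /*ndata_shard  =*/ 1000); // granularity at which the dataset is shuffled/copied

    // assumption: the dataset tensors are host-accessible and can be filled directly
    float * data   = ggml_get_data_f32(ggml_opt_dataset_data  (dataset)); // shape = [ne_datapoint, ndata]
    float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset)); // shape = [ne_label,     ndata]
    (void) data; (void) labels; // fill these with your datapoints/labels here

    return dataset; // release later with ggml_opt_dataset_free(dataset)
}
```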
@@ -56,13 +59,19 @@ extern "C" {
             struct ggml_tensor * data_batch,   // shape = [ne_datapoint, ndata_batch]
             struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
             int64_t              ibatch);
+    GGML_API void ggml_opt_dataset_get_batch_host(
+            ggml_opt_dataset_t   dataset,
+            void               * data_batch,
+            size_t               nb_data_batch,
+            void               * labels_batch,
+            int64_t              ibatch);
 
     // ====== Model / Context ======
 
     enum ggml_opt_build_type {
-        GGML_OPT_BUILD_TYPE_FORWARD,
-        GGML_OPT_BUILD_TYPE_GRAD,
-        GGML_OPT_BUILD_TYPE_OPT,
+        GGML_OPT_BUILD_TYPE_FORWARD = 10,
+        GGML_OPT_BUILD_TYPE_GRAD    = 20,
+        GGML_OPT_BUILD_TYPE_OPT     = 30,
     };
 
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
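`ggml_opt_dataset_get_batch_host` copies a batch into plain host buffers instead of ggml tensors. The sketch below assumes `nb_data_batch` follows ggml's "nb" convention and gives the size in bytes of `data_batch`; the batch size and element counts are hypothetical and must match the dataset actually created.

```c
#include <stdint.h>
#include <stdlib.h>
#include "ggml-opt.h"

// Sketch: read batch ibatch from an F32 dataset into malloc'd host buffers.
static void read_batch_to_host(ggml_opt_dataset_t dataset, int64_t ibatch) {
    const int64_t ndata_batch  = 100; // hypothetical batch size
    const int64_t ne_datapoint = 784; // must match the dataset
    const int64_t ne_label     = 10;

    float * data_batch   = malloc(ndata_batch * ne_datapoint * sizeof(float));
    float * labels_batch = malloc(ndata_batch * ne_label     * sizeof(float));

    ggml_opt_dataset_get_batch_host(
            dataset,
            data_batch,
            ndata_batch * ne_datapoint * sizeof(float), // nb_data_batch: bytes in data_batch (assumed meaning)
            labels_batch,
            ibatch);

    // ... consume data_batch / labels_batch ...

    free(data_batch);
    free(labels_batch);
}
```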
@@ -81,20 +90,22 @@ extern "C" {
     // userdata can be used to pass arbitrary data
     typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
 
-    // returns the default optimizer params (constant)
+    // returns the default optimizer params (constant, hard-coded values)
     // userdata is not used
     GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);
 
+    // casts userdata to ggml_opt_optimizer_params and returns it
+    GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata);
+
     // parameters for initializing a new optimization context
     struct ggml_opt_params {
         ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
 
-        struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
-
-        // the forward graph is defined by inputs and outputs
-        // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
-        struct ggml_tensor * inputs;
-        struct ggml_tensor * outputs;
+        // by default the forward graph needs to be reconstructed for each eval
+        // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
+        struct ggml_context * ctx_compute;
+        struct ggml_tensor  * inputs;
+        struct ggml_tensor  * outputs;
 
         enum ggml_opt_loss_type  loss_type;
         enum ggml_opt_build_type build_type;
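`ggml_opt_get_constant_optimizer_params` simply returns whatever struct `userdata` points to, which is convenient when the hyperparameters are fixed up front. A hedged sketch of a custom callback is below; the `adamw.alpha` field name is taken from the rest of ggml-opt.h and is not part of this hunk.

```c
#include "ggml-opt.h"

// Sketch of a custom optimizer-params callback: start from the library
// defaults and override only the learning rate passed through userdata.
// The adamw.alpha field name is an assumption about the surrounding header.
static struct ggml_opt_optimizer_params lr_from_userdata(void * userdata) {
    struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(NULL);
    p.adamw.alpha = *(const float *) userdata; // override just the learning rate
    return p;
}
```

Alternatively, if a fully constant set of parameters is prepared ahead of time, `ggml_opt_get_constant_optimizer_params` can itself be used as the callback, with a pointer to that struct passed as `userdata`.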
@@ -107,12 +118,9 @@ extern "C" {
 
     // get parameters for an optimization context with defaults set where possible
     // parameters for which no sensible defaults exist are supplied as arguments to this function
-    GGML_API ggml_opt_params ggml_opt_default_params(
-            ggml_backend_sched_t      backend_sched,
-            struct ggml_context     * ctx_compute,
-            struct ggml_tensor      * inputs,
-            struct ggml_tensor      * outputs,
-            enum ggml_opt_loss_type   loss_type);
+    GGML_API struct ggml_opt_params ggml_opt_default_params(
+            ggml_backend_sched_t      backend_sched,
+            enum ggml_opt_loss_type   loss_type);
 
     GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
     GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
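With the slimmed-down `ggml_opt_default_params`, only the scheduler and the loss type are required; the static-graph fields can still be filled in afterwards. A sketch under the assumption that `GGML_OPT_LOSS_TYPE_CROSS_ENTROPY` is one of the loss-type enum values defined elsewhere in this header:

```c
#include "ggml-opt.h"

// Sketch: create an optimization context with dynamic graphs (the default).
static ggml_opt_context_t make_opt_ctx(ggml_backend_sched_t backend_sched) {
    struct ggml_opt_params params = ggml_opt_default_params(
            backend_sched, GGML_OPT_LOSS_TYPE_CROSS_ENTROPY); // loss type assumed from the full header
    params.build_type = GGML_OPT_BUILD_TYPE_OPT; // forward + backward + optimizer step

    // Optionally, for statically allocated graphs, also set:
    // params.ctx_compute = ctx_compute;
    // params.inputs      = inputs;
    // params.outputs     = outputs;

    return ggml_opt_init(params); // release later with ggml_opt_free()
}
```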
@@ -121,18 +129,20 @@ extern "C" {
     GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
 
     // get underlying tensors that store data
+    // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
     GGML_API struct ggml_tensor * ggml_opt_inputs(  ggml_opt_context_t opt_ctx); // forward graph input tensor
     GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
     GGML_API struct ggml_tensor * ggml_opt_labels(  ggml_opt_context_t opt_ctx); // labels to compare outputs against
     GGML_API struct ggml_tensor * ggml_opt_loss(    ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss
     GGML_API struct ggml_tensor * ggml_opt_pred(    ggml_opt_context_t opt_ctx); // predictions made by outputs
     GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
 
+    // get the gradient accumulator for a node from the forward graph
     GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
 
     // ====== Optimization Result ======
 
-    GGML_API ggml_opt_result_t ggml_opt_result_init();
+    GGML_API ggml_opt_result_t ggml_opt_result_init(void);
     GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
     GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);
 
@@ -144,11 +154,20 @@ extern "C" {
 
     // ====== Computation ======
 
-    // do forward pass, increment result if not NULL
-    GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+    // if not using static graphs, this function must be called prior to ggml_opt_alloc
+    GGML_API void ggml_opt_prepare_alloc(
+            ggml_opt_context_t    opt_ctx,
+            struct ggml_context * ctx_compute,
+            struct ggml_cgraph  * gf,
+            struct ggml_tensor  * inputs,
+            struct ggml_tensor  * outputs);
+
+    // allocate the next graph for evaluation, either forward or forward + backward
+    // must be called exactly once prior to calling ggml_opt_eval
+    GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);
 
-    // do forward pass, increment result if not NULL, do backward pass
-    GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+    // do forward pass, increment result if not NULL, do backward pass if allocated
+    GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
 
     // ############################################################################
     // ## The high-level functions start here. They do not depend on any private ##
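The `ggml_opt_forward`/`ggml_opt_forward_backward` pair is replaced by prepare_alloc + alloc + eval, which lets the forward graph be rebuilt for every evaluation. A sketch of one training step with dynamic graphs follows; the single-matmul model, the shapes, and the batch handling are stand-ins, the graph-construction helpers come from ggml.h, and the weight tensor is assumed to already be flagged as a trainable parameter.

```c
#include "ggml.h"
#include "ggml-opt.h"

// Sketch: one training step without static graphs. `weights` is assumed to
// have shape [784, 10] and to already be registered as a parameter so that
// the optimizer step updates it.
static void train_step(ggml_opt_context_t opt_ctx, ggml_opt_result_t result,
                       ggml_opt_dataset_t dataset, struct ggml_tensor * weights,
                       int64_t ibatch) {
    // fresh compute context per eval; it only holds metadata (no_alloc = true),
    // the actual tensor data is allocated by the backend scheduler
    struct ggml_init_params ip = {
        /*.mem_size   =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx_compute = ggml_init(ip);

    // stand-in forward graph: outputs = weights^T * inputs, shape [10, 100]
    struct ggml_tensor * inputs  = ggml_new_tensor_2d(ctx_compute, GGML_TYPE_F32, 784, 100);
    struct ggml_tensor * outputs = ggml_mul_mat(ctx_compute, weights, inputs);

    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
    ggml_build_forward_expand(gf, outputs);

    ggml_opt_prepare_alloc(opt_ctx, ctx_compute, gf, inputs, outputs);
    ggml_opt_alloc(opt_ctx, /*backward =*/ true); // forward + backward

    // fill inputs and labels for this batch, then evaluate
    ggml_opt_dataset_get_batch(dataset, ggml_opt_inputs(opt_ctx), ggml_opt_labels(opt_ctx), ibatch);
    ggml_opt_eval(opt_ctx, result);

    ggml_free(ctx_compute);
}
```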
@@ -200,9 +219,9 @@ extern "C" {
     // fit model defined by inputs and outputs to dataset
     GGML_API void ggml_opt_fit(
             ggml_backend_sched_t            backend_sched, // backend scheduler for constructing the compute graphs
-            ggml_context                  * ctx_compute,   // context with temporarily allocated tensors to calculate the outputs
-            ggml_tensor                   * inputs,        // input tensor with shape [ne_datapoint, ndata_batch]
-            ggml_tensor                   * outputs,       // output tensor, must have shape [ne_label, ndata_batch] if labels are used
+            struct ggml_context           * ctx_compute,   // context with temporarily allocated tensors to calculate the outputs
+            struct ggml_tensor            * inputs,        // input tensor with shape [ne_datapoint, ndata_batch]
+            struct ggml_tensor            * outputs,       // output tensor, must have shape [ne_label, ndata_batch] if labels are used
             ggml_opt_dataset_t              dataset,       // dataset with data and optionally also labels
             enum ggml_opt_loss_type         loss_type,     // loss to minimize
             ggml_opt_get_optimizer_params   get_opt_pars,  // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
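For completeness, a hedged sketch of the high-level path: the hunk ends before the trailing parameters of `ggml_opt_fit`, so the epoch count, logical batch size, validation split, and silent flag used below are assumptions about the full header rather than something shown in this excerpt; the loss-type enum value is likewise assumed.

```c
#include "ggml.h"
#include "ggml-opt.h"

// Sketch: fit a model defined by (ctx_compute, inputs, outputs) to a dataset.
// The arguments after get_opt_pars are assumptions; check the full header.
static void fit_example(ggml_backend_sched_t  backend_sched,
                        struct ggml_context * ctx_compute,
                        struct ggml_tensor  * inputs,   // shape = [ne_datapoint, ndata_batch]
                        struct ggml_tensor  * outputs,  // shape = [ne_label,     ndata_batch]
                        ggml_opt_dataset_t    dataset) {
    ggml_opt_fit(backend_sched, ctx_compute, inputs, outputs, dataset,
            GGML_OPT_LOSS_TYPE_CROSS_ENTROPY,      // loss to minimize (assumed enum value)
            ggml_opt_get_default_optimizer_params, // hard-coded default hyperparameters
            /*nepoch         =*/ 30,               // assumed trailing parameters, not part of this hunk
            /*nbatch_logical =*/ 100,
            /*val_split      =*/ 0.05f,
            /*silent         =*/ false);
}
```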