Skip to content

Commit 63d7e65

Browse files
reset graph at beginning of opt period
1 parent e27f93e commit 63d7e65

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

ggml/src/ggml-opt.cpp

+3-4
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,9 @@ void ggml_opt_prepare_alloc(
694694

695695
void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
696696
GGML_ASSERT(!opt_ctx->eval_ready);
697+
if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {
698+
ggml_graph_reset(opt_ctx->gb_grad);
699+
}
697700
if (backward) {
698701
const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
699702
opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD;
@@ -743,10 +746,6 @@ void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
743746
ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
744747
opt_ctx->allocated_graph = graph;
745748

746-
if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {
747-
ggml_graph_reset(opt_ctx->gb_grad);
748-
}
749-
750749
opt_ctx->eval_ready = true;
751750
}
752751

ggml/src/ggml.c

+3
Original file line numberDiff line numberDiff line change
@@ -6038,6 +6038,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
60386038
}
60396039

60406040
void ggml_graph_reset(struct ggml_cgraph * cgraph) {
6041+
if (!cgraph) {
6042+
return;
6043+
}
60416044
GGML_ASSERT(cgraph->grads != NULL);
60426045

60436046
for (int i = 0; i < cgraph->n_nodes; i++) {

tests/test-opt.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,6 @@ static std::pair<int, int> test_gradient_accumulation(
584584

585585
struct helper_ctx_data cd = helper_get_ctx_data(
586586
backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false, /*nbatch_logical =*/ 6, nbatch_physical, loss_type);
587-
struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);
588587

589588
std::vector<float> grad_history(ndata);
590589
for (int64_t idata = 0; idata < ndata; ++idata) {

0 commit comments

Comments
 (0)