Skip to content

Commit cee751c

Browse files
committed
opt : fix n_outputs
ggml-ci
1 parent 4e73b81 commit cee751c

File tree

1 file changed: src/llama-context.cpp

+12 lines added
-1 line removed

1 file changed

+12
-1
lines changed

src/llama-context.cpp

+12-1
Original file line number | Diff line number | Diff line change
@@ -1955,6 +1955,17 @@ void llama_context::opt_epoch_iter(
19551955
//}
19561956
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
19571957

1958+
n_outputs = ubatch.n_tokens;
1959+
1960+
printf("ubatch.n_tokens = %d\n", ubatch.n_tokens);
1961+
1962+
// TODO: not sure if this is needed
1963+
if (!kv_self->find_slot(ubatch)) {
1964+
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
1965+
1966+
GGML_ABORT("TODO: handle this error");
1967+
}
1968+
19581969
auto * gf = graph_init();
19591970
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
19601971

@@ -1970,7 +1981,7 @@ void llama_context::opt_epoch_iter(
19701981
};
19711982
ctx_compute_opt = ggml_init(params);
19721983
}
1973-
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), ggml_graph_node(gf, -1));
1984+
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
19741985
ggml_opt_alloc(opt_ctx, train);
19751986
//llama_set_inputs(*lctx, ubatch);
19761987
res->set_inputs(&ubatch);

0 commit comments

Comments
 (0)