@@ -1955,6 +1955,17 @@ void llama_context::opt_epoch_iter(
         // }
         llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
 
+        n_outputs = ubatch.n_tokens;
+
+        printf("ubatch.n_tokens = %d\n", ubatch.n_tokens);
+
+        // TODO: not sure if this is needed
+        if (!kv_self->find_slot(ubatch)) {
+            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+            GGML_ABORT("TODO: handle this error");
+        }
+
         auto * gf = graph_init();
         auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
@@ -1970,7 +1981,7 @@ void llama_context::opt_epoch_iter(
             };
             ctx_compute_opt = ggml_init(params);
         }
-        ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), ggml_graph_node(gf, -1));
+        ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
         ggml_opt_alloc(opt_ctx, train);
         // llama_set_inputs(*lctx, ubatch);
         res->set_inputs(&ubatch);
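
For context, this is the per-ubatch training step the two hunks arrange (a minimal sketch, not verbatim source; it assumes the ggml-opt API used in the diff itself, plus a ggml_opt_eval call with a ggml_opt_result_t that follows later in the function and is not shown in these hunks):

    // Build the forward graph for this ubatch.
    auto * gf  = graph_init();
    auto   res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);

    // Hand ggml-opt the graph plus the input/output tensors to attach the
    // optimizer to. Passing res->get_logits() instead of ggml_graph_node(gf, -1)
    // ties the loss to the logits tensor itself, rather than to whatever node
    // happens to be last in the graph.
    ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());

    // Allocate the forward (and, if train == true, backward) graph, then feed
    // the ubatch and evaluate.
    ggml_opt_alloc(opt_ctx, train);
    res->set_inputs(&ubatch);
    ggml_opt_eval(opt_ctx, result); // `result` is assumed here; the eval call is outside these hunks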