Skip to content

Commit e98c1c1

Browse files
JohannesGaessler authored and ggerganov committed
test: fix OPT_STEP_ADAMW for test-backend-ops (ggml/974)
1 parent cb00020 commit e98c1c1

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

ggml/include/ggml.h

+1
Original file line number | Diff line number | Diff line change
@@ -2052,6 +2052,7 @@ extern "C" {
20522052
GGML_API struct ggml_tensor * ggml_opt_step_adamw(
20532053
struct ggml_context * ctx,
20542054
struct ggml_tensor * a,
2055+
struct ggml_tensor * grad,
20552056
float alpha,
20562057
float beta1,
20572058
float beta2,

ggml/src/ggml.c

+6 -4
Original file line number | Diff line number | Diff line change
@@ -7818,12 +7818,14 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
78187818
struct ggml_tensor * ggml_opt_step_adamw(
78197819
struct ggml_context * ctx,
78207820
struct ggml_tensor * a,
7821+
struct ggml_tensor * grad,
78217822
float alpha,
78227823
float beta1,
78237824
float beta2,
78247825
float eps,
78257826
float wd) {
78267827
GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
7828+
GGML_ASSERT(ggml_are_same_shape(a, grad));
78277829
GGML_ASSERT(alpha > 0.0f);
78287830
GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
78297831
GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
@@ -7842,9 +7844,9 @@ struct ggml_tensor * ggml_opt_step_adamw(
78427844

78437845
result->op = GGML_OP_OPT_STEP_ADAMW;
78447846
result->src[0] = a;
7845-
result->src[1] = a->grad;
7846-
result->src[2] = ggml_dup_tensor(ctx, a);
7847-
result->src[3] = ggml_dup_tensor(ctx, a);
7847+
result->src[1] = grad;
7848+
result->src[2] = ggml_dup_tensor(ctx, grad);
7849+
result->src[3] = ggml_dup_tensor(ctx, grad);
78487850

78497851
return result;
78507852
}
@@ -18769,7 +18771,7 @@ void ggml_build_opt_adamw(
1876918771

1877018772
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
1877118773
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
18772-
struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
18774+
struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd);
1877318775
ggml_build_forward_expand(gb, opt_step);
1877418776
}
1877518777
}

tests/test-backend-ops.cpp

+4 -1
Original file line number | Diff line number | Diff line change
@@ -2751,7 +2751,10 @@ struct test_opt_step_adamw : public test_case {
27512751
ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not.
27522752
ggml_set_name(a, "a");
27532753

2754-
ggml_tensor * out = ggml_opt_step_adamw(ctx, a, alpha, beta1, beta2, eps, wd);
2754+
ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
2755+
ggml_set_name(grad, "grad");
2756+
2757+
ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, alpha, beta1, beta2, eps, wd);
27552758
ggml_set_name(out, "out");
27562759

27572760
return out;

0 commit comments

Comments
 (0)