diff --git a/passl/optimizer/adamw.py b/passl/optimizer/adamw.py
index f6355d6..92a3210 100644
--- a/passl/optimizer/adamw.py
+++ b/passl/optimizer/adamw.py
@@ -110,7 +110,7 @@ def step(self):
                         sub_exp_avg_sq = paddle.gather(
                             exp_avg_sq, index, axis=axis)
 
-                        _, _, _, _, _, _ = _C_ops.adamw(
+                        _, _, _, _, _, *_ = _C_ops.adamw(
                             sub_p, grad, paddle.to_tensor(lr),
                             sub_exp_avg, sub_exp_avg_sq, beta1_pow,
                             beta2_pow, master_param, sub_p, sub_exp_avg,
@@ -126,7 +126,7 @@ def step(self):
                         exp_avg_sq.scatter_(index, sub_exp_avg_sq)
                     else:
 
-                        _, _, _, _, _, _ = _C_ops.adamw(
+                        _, _, _, _, _, *_ = _C_ops.adamw(
                             p, grad, paddle.to_tensor(lr), exp_avg,
                             exp_avg_sq, beta1_pow, beta2_pow, master_param,
                             p, exp_avg, exp_avg_sq,