Commit 6862d3a

Beacontownfc committed: Add AdamW optimizer

1 parent fac48f5 commit 6862d3a

File tree

3 files changed (+104, -0 lines)

src/TensorFlowNET.Core/Keras/IOptimizerApi.cs (+21)
@@ -25,6 +25,27 @@ IOptimizer Adam(float learning_rate = 0.001f,
            bool amsgrad = false,
            string name = "Adam");

        /// <summary>
        /// AdamW optimizer: Adam with L2 weight decay applied directly to the parameters.
        /// </summary>
        /// <param name="learning_rate">Learning rate.</param>
        /// <param name="weight_decay">Weight decay rate.</param>
        /// <param name="beta_1"></param>
        /// <param name="beta_2"></param>
        /// <param name="epsilon"></param>
        /// <param name="amsgrad"></param>
        /// <param name="no_decay_params">Variable-name substrings excluded from weight decay.</param>
        /// <param name="name"></param>
        /// <returns></returns>
        IOptimizer AdamW(float learning_rate = 0.001f,
            float weight_decay = 0.004f,
            float beta_1 = 0.9f,
            float beta_2 = 0.999f,
            float epsilon = 1e-7f,
            bool amsgrad = false,
            List<string> no_decay_params = null,
            string name = "AdamW");

        /// <summary>
        /// Construct a new RMSprop optimizer.
        /// </summary>
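For orientation, the step that AdamW adds on top of plain Adam is decoupled weight decay: each decayed variable is shrunk by learning_rate * weight_decay * var alongside the usual Adam update (see _decay_weights_op in the new class below). A minimal standalone sketch of that rule on a plain array, just the arithmetic rather than TensorFlow.NET API:

    // Toy illustration of the decoupled weight-decay step; `param` stands in for a variable's values.
    static void DecayWeights(float[] param, float learningRate, float weightDecay)
    {
        for (int i = 0; i < param.Length; i++)
            param[i] -= learningRate * weightDecay * param[i];   // matches var.assign_add(-lr * var * wd)
    }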
New file (+67): Tensorflow.Keras.Optimizers.AdamW

@@ -0,0 +1,67 @@
namespace Tensorflow.Keras.Optimizers
{
    public class AdamW : Adam
    {
        string name;
        float weight_decay;
        DeviceDType deType;
        List<string> no_decay_params = null;

        public AdamW(float learning_rate = 0.001f,
            float weight_decay = 0.004f,
            float beta_1 = 0.9f,
            float beta_2 = 0.999f,
            float epsilon = 1e-7f,
            bool amsgrad = false,
            List<string> no_decay_params = null,
            string name = "AdamW") : base(learning_rate, beta_1, beta_2, epsilon, amsgrad)
        {
            this.name = name;
            this.weight_decay = weight_decay;
            this.no_decay_params = no_decay_params;
        }

        // Shrink the variable by learning_rate * weight_decay * var; a no-op for excluded variables.
        protected Operation _decay_weights_op(IVariableV1 var, float learning_rate, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            var device_dtype = new DeviceDType();
            device_dtype.DType = var.dtype;
            device_dtype.Device = var.Device;
            bool do_decay = _do_use_weight_decay(var.Name);
            if (do_decay) return var.assign_add(
                -learning_rate * var.AsTensor() * apply_state[deType]["weight_decay"]);
            return tf.no_op();
        }

        // Whether to use L2 weight decay for `param_name`: decay unless the name contains
        // any of the substrings listed in `no_decay_params`.
        protected bool _do_use_weight_decay(string param_name)
        {
            if (this.weight_decay == 0)
                return false;

            if (this.no_decay_params != null)
            {
                foreach (var name in no_decay_params)
                {
                    if (param_name.Contains(name)) return false;
                }
            }
            return true;
        }

        // Apply the weight-decay step, then the regular Adam update.
        protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            var decay = _decay_weights_op(var, _hyper["learning_rate"], apply_state);
            tf.control_dependencies(new[] { decay });
            return base._resource_apply_dense(var, grad, apply_state);
        }

        // Cache the weight-decay constant alongside the other per-device/dtype values.
        protected override void _prepare_local(DeviceDType device_dtype, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            this.deType = device_dtype;
            base._prepare_local(device_dtype, apply_state);
            apply_state[device_dtype]["weight_decay"] = tf.constant(
                weight_decay, name: "adam_weight_decay_rate");
        }
    }
}
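Note that _do_use_weight_decay matches by substring, so an entry such as "bias" in no_decay_params exempts every variable whose name contains "bias". A small self-contained sketch of that filter; the variable names and exclusion list below are illustrative, not taken from this commit:

    using System;
    using System.Collections.Generic;
    using System.Linq;

    class NoDecayFilterDemo
    {
        // Mirrors AdamW._do_use_weight_decay: decay unless the name contains an excluded substring.
        static bool UseWeightDecay(string varName, float weightDecay, List<string> noDecayParams)
        {
            if (weightDecay == 0) return false;
            return noDecayParams == null || !noDecayParams.Any(s => varName.Contains(s));
        }

        static void Main()
        {
            var excluded = new List<string> { "bias", "LayerNorm" };
            Console.WriteLine(UseWeightDecay("dense/kernel", 0.004f, excluded));     // True
            Console.WriteLine(UseWeightDecay("dense/bias", 0.004f, excluded));       // False
            Console.WriteLine(UseWeightDecay("LayerNorm/gamma", 0.004f, excluded));  // False
        }
    }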

src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs (+16)
@@ -29,6 +29,22 @@ public IOptimizer Adam(float learning_rate = 0.001f,
                amsgrad: amsgrad,
                name: name);

        public IOptimizer AdamW(float learning_rate = 0.001f,
            float weight_decay = 0.004f,
            float beta_1 = 0.9f,
            float beta_2 = 0.999f,
            float epsilon = 1e-7f,
            bool amsgrad = false,
            List<string> no_decay_params = null,
            string name = "AdamW") => new AdamW(learning_rate: learning_rate,
                beta_1: beta_1,
                beta_2: beta_2,
                epsilon: epsilon,
                amsgrad: amsgrad,
                name: name,
                weight_decay: weight_decay,
                no_decay_params: no_decay_params);

        /// <summary>
        /// Construct a new RMSprop optimizer.
        /// </summary>
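With the facade method in place, the optimizer is reachable from the Keras optimizer API. A hedged usage sketch: the `using static Tensorflow.KerasApi;` import is the usual TensorFlow.NET entry point and the exclusion list is illustrative; neither is part of this commit.

    using System.Collections.Generic;
    using static Tensorflow.KerasApi;

    // Construct AdamW through the new facade; variables whose names contain "bias"
    // or "LayerNorm" are skipped by the weight-decay step.
    var optimizer = keras.optimizers.AdamW(
        learning_rate: 0.001f,
        weight_decay: 0.004f,
        no_decay_params: new List<string> { "bias", "LayerNorm" });

    // The result is an IOptimizer, so it can be passed to model.compile(...) like Adam.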
