|
| 1 | +namespace Tensorflow.Keras.Optimizers |
| 2 | +{ |
| 3 | + public class AdamW : Adam |
| 4 | + { |
| 5 | + string name; |
| 6 | + float weight_decay; |
| 7 | + DeviceDType deType; |
| 8 | + List<string> no_decay_params = null; |
| 9 | + public AdamW(float learning_rate= 0.001f, |
| 10 | + float weight_decay= 0.004f, |
| 11 | + float beta_1= 0.9f, |
| 12 | + float beta_2= 0.999f, |
| 13 | + float epsilon= 1e-7f, |
| 14 | + bool amsgrad = false, |
| 15 | + List<string> no_decay_params = null, |
| 16 | + string name= "AdamW") : base(learning_rate, beta_1, beta_2, epsilon, amsgrad) |
| 17 | + { |
| 18 | + this.name = name; |
| 19 | + this.weight_decay = weight_decay; |
| 20 | + this.no_decay_params = no_decay_params; |
| 21 | + } |
| 22 | + |
| 23 | + protected Operation _decay_weights_op(IVariableV1 var, float learning_rate, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state) |
| 24 | + { |
| 25 | + var device_dtype = new DeviceDType(); |
| 26 | + device_dtype.DType = var.dtype; |
| 27 | + device_dtype.Device = var.Device; |
| 28 | + bool do_decay = _do_use_weight_decay(var.Name); |
| 29 | + if (do_decay) return var.assign_add( |
| 30 | + -learning_rate * var.AsTensor() * apply_state[deType]["weight_decay"]); |
| 31 | + return tf.no_op(); |
| 32 | + } |
| 33 | + |
| 34 | + |
| 35 | + protected bool _do_use_weight_decay(string param_name) |
| 36 | + { |
| 37 | + // Whether to use L2 weight decay for `param_name`. |
| 38 | + if (this.weight_decay == 0) |
| 39 | + return false; |
| 40 | + |
| 41 | + if (this.no_decay_params != null) |
| 42 | + { |
| 43 | + foreach (var name in no_decay_params) |
| 44 | + { |
| 45 | + if (param_name.Contains(name)) return false; |
| 46 | + } |
| 47 | + |
| 48 | + } |
| 49 | + return true; |
| 50 | + } |
| 51 | + |
| 52 | + protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state) |
| 53 | + { |
| 54 | + var decay = _decay_weights_op(var, _hyper["learning_rate"], apply_state); |
| 55 | + tf.control_dependencies(new[] { decay }); |
| 56 | + return base._resource_apply_dense(var, grad, apply_state); |
| 57 | + } |
| 58 | + |
| 59 | + protected override void _prepare_local(DeviceDType device_dtype, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state) |
| 60 | + { |
| 61 | + this.deType = device_dtype; |
| 62 | + base._prepare_local(device_dtype, apply_state); |
| 63 | + apply_state[device_dtype]["weight_decay"] = tf.constant( |
| 64 | + weight_decay, name: "adam_weight_decay_rate"); |
| 65 | + } |
| 66 | + } |
| 67 | +} |
0 commit comments