Learning rate #106
Conversation
…lder into each Optimizer. Also, added to each Optimizer a corresponding Tensor that holds the value of the learning rate, and a feed dictionary that maps the placeholder to the Tensor, so that it can be fed into the runner when running or evaluating. When setLearningRate is called, the learning rate tensor and the feed dictionary are updated.
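For illustration, a minimal sketch of the mechanism described above, assuming the current tensorflow-java API (class and field names such as OptimizerSketch and feedDict are illustrative, not necessarily the ones in this PR):

import java.util.HashMap;
import java.util.Map;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.types.TFloat32;

abstract class OptimizerSketch {
  protected final Ops tf;
  // Placeholder that stands in for the learning rate inside the graph.
  protected final Placeholder<TFloat32> learningRatePlaceholder;
  // Tensor holding the current learning-rate value, fed to the placeholder on each run.
  protected Tensor learningRateTensor;
  // Feed dictionary handed to the session runner when running or evaluating.
  protected final Map<Operand<TFloat32>, Tensor> feedDict = new HashMap<>();

  protected OptimizerSketch(Graph graph, float learningRate) {
    this.tf = Ops.create(graph);
    this.learningRatePlaceholder = tf.placeholder(TFloat32.class);
    setLearningRate(learningRate);
  }

  public final void setLearningRate(float newLearningRate) {
    // Updating the learning rate replaces the backing tensor and re-points the feed dictionary at it.
    learningRateTensor = TFloat32.scalarOf(newLearningRate);
    feedDict.put(learningRatePlaceholder, learningRateTensor);
  }
}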
A couple of quick observations as I start reading this code.
}

/** Returns true if this data type represents a floating point type */
public boolean isFloating() {
This pattern is very uncomfortable to me: DataType being omniscient about TType and dispatching on a string NAME. What's our motivation? If we think it's the best pattern for this situation, perhaps we could document why?
Never mind, in this context -- I see in my local diff that this delta is unrelated to this PR. I'll raise this as an issue.
 * @param graph the TensorFlow Graph
 * @param name the name for this Optimizer (defaults to 'Adadelta')
 * @param learningRate the learning rate
 */
public AdaDelta(Graph graph, String name, float learningRate) {
  this(graph, name, learningRate, 0.95f, 1e-8f);
-> RHO_DEFAULT, EPSILON_DEFAULT
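For concreteness, the suggestion amounts to naming the literals in the call above (constant names as proposed, values taken from the existing defaults):

public static final float RHO_DEFAULT = 0.95f;
public static final float EPSILON_DEFAULT = 1e-8f;

public AdaDelta(Graph graph, String name, float learningRate) {
  this(graph, name, learningRate, RHO_DEFAULT, EPSILON_DEFAULT);
}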
Never mind, in this context.
 */
protected Optimizer(Graph graph, String name) {
protected Optimizer(Graph graph, float learningRate) {
Can we have both of these constructors call into the Optimizer(Graph,float,String) one?
Being that Optimizer is abstract, we really only need one constructor, protected Optimizer(Graph graph, String name, float learningRate). Of course, we would have to handle a null name, with something like:
this.tf = Ops.create(graph).withName(name == null ? getOptimizerName() : name);
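A sketch of that single constructor (getOptimizerName() is assumed here to be a method each subclass implements to supply its default name, e.g. "Adadelta"):

protected Optimizer(Graph graph, String name, float learningRate) {
  // A null name falls back to the subclass-provided default.
  this.tf = Ops.create(graph).withName(name == null ? getOptimizerName() : name);
  setLearningRate(learningRate);
}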
CTORS have been changed
public static String createName(Output<? extends TType> variable, String slotName) {
  return variable.op().name() + "-" + slotName;
}

/**
Why'd the Javadoc go away?
I am not sure what happened. I had a local copy that I saved and it was there, so I will add it back in.
Update pushed
@@ -305,41 +350,20 @@ private Options() {}
}
}

/**
Where'd the javadoc go?
I have added it back in. Update pushed
 * @param learningRate the learning rate
 */
public final void setLearningRate(float learningRate) {
  if (this.learningRatePlaceholder == null) {
Everything seems to have grown a this reference. I don't think that's particularly necessary in these methods, as the argument could be newLearningRate rather than learningRate and then there is no aliasing.
I do this out of habit. I can easily change it as you suggest.
Changed setLearningRate to setLearningRate(float newLearningRate) and removed the spurious this. references.
Update pushed
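The resulting setter reads roughly like this (a sketch of the renamed method, reusing the illustrative field names from the sketch above):

public final void setLearningRate(float newLearningRate) {
  // The parameter no longer shadows the field, so no this. qualifier is needed.
  learningRateTensor = TFloat32.scalarOf(newLearningRate);
  feedDict.put(learningRatePlaceholder, learningRateTensor);
}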
…ewLearningRate), eliminated spurious "this."
@@ -280,20 +321,20 @@ protected Op finish(List<Op> updateOperations, String name) {
/**
 * Sets the learning rate
 *
 * @param learningRate the learning rate
 * @param newLearningRate the new earning rate
typo - "earning"
OK
…raph graph, String name, float learningRate)", change all the subclass ctors to use this one.
Add Operand<TFloat32> learningRateOperand as an option for learning rate.
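For context, the learningRateOperand option mentioned in that commit could take a shape like this hypothetical overload (not the exact signature in the PR):

// Hypothetical: let the caller wire the learning rate to an existing graph operand,
// e.g. the output of a decay schedule, instead of a plain float.
protected Optimizer(Graph graph, String name, Operand<TFloat32> learningRateOperand) {
  this.tf = Ops.create(graph).withName(name == null ? getOptimizerName() : name);
  this.learningRateOperand = learningRateOperand;
}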
Just a few snake_case variable names in the tests that need converting to camelCase, and then I'll merge this in.
new RMSProp(session.getGraph(), learningRate, decay, momentum, epsilon, centered)) {
  Ops tf = session.getTF();
  session.setEpsilon(1e-2f);
  float[] var0_init = {1.0F, 2.0F};
Please switch the python style variable names to camelCase.
FloatNdArray mul1 = ND.mul(v, beta);
FloatNdArray squareG = ND.square(gT);
FloatNdArray mul2 = ND.mul((1 - beta), squareG);
return ND.add(mul1, mul2);
}

private FloatNdArray calculateParam(
    FloatNdArray param, float lrT, FloatNdArray m, FloatNdArray v, float epsilon) {
  // param - lrT * mT / (np.sqrt(vT) + epsilon)
    FloatNdArray param, float lr_t, FloatNdArray m, FloatNdArray v, float epsilon) {
Switch python style name to camelCase.
This PR requires some rework due to #174.
Whatever you think is easiest is fine by me.
I am closing for now, and will reopen after we get further along on Model.
This PR requires PR "Initial checkin of Keras Optimzers and helper classes" to be merged first.
Added changeable learning rate to Optimizers. This was done by adding a Placeholder for the learning rate, a Tensor to track the actual learning rate, and a Map from the Placeholder to the Tensor that can be used to "feed" the runner.
Test Sessions were modified to accept a "FeedDict" Map to populate the feed() of the runner.
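A rough sketch of how such a feed map might be applied to a session runner (Session.Runner.feed and addTarget are part of the public tensorflow-java API; feedDict and trainOp are illustrative names):

Session.Runner runner = session.runner();
// Feed the learning-rate placeholder (and any other entries in the map) before running.
for (Map.Entry<Operand<TFloat32>, Tensor> entry : feedDict.entrySet()) {
  runner.feed(entry.getKey(), entry.getValue());
}
runner.addTarget(trainOp).run();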