Learning rate #106

Closed · wants to merge 16 commits
Changes from 1 commit
@@ -111,20 +111,47 @@ protected Optimizer(Graph graph, String name, float learningRate) {
setLearningRate(learningRate);
}

/**
* Creates a name by combining a variable name and a slot name
*
* @param variable the variable
* @param slotName the name of the slot
* @return the combined name
*/
public static String createName(Output<? extends TType> variable, String slotName) {
return variable.op().name() + "-" + slotName;
}
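
As a quick illustration (the names here are hypothetical), for a variable output w whose op is named "dense_kernel" and a slot called "momentum", the helper yields the hyphenated pair:

String slotVariableName = Optimizer.createName(w, "momentum");
// With w.op().name() == "dense_kernel", this returns "dense_kernel-momentum".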

/**
* Minimizes the loss by updating the variables
*
* @param loss the loss operation that returns the value to minimize
* @return an op that minimizes the loss by updating the listed variables
*/
public Op minimize(Operand<?> loss) {
return minimize(loss, getOptimizerName() + "-minimize");
}

/**
* Minimizes the loss by updating the variables
*
* @param loss the loss operation that returns the value to minimize
* @param name the name for the minimize operation
* @return op that minimizes the loss by updating the listed variables
*/
public Op minimize(Operand<?> loss, String name) {
List<GradAndVar<?>> gradsAndVars = computeGradients(loss);

return applyGradients(gradsAndVars, name);
}
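
For orientation, a rough end-to-end sketch of how minimize is meant to be driven from a session. The GradientDescent subclass, the toy quadratic loss, and the feed(Operand, Tensor) overload on Session.Runner are illustrative assumptions, not part of this diff:

try (Graph graph = new Graph()) {
  Ops tf = Ops.create(graph);
  // Toy model (hypothetical): minimize (w - 5)^2 over a single scalar variable.
  Variable<TFloat32> w = tf.variable(Shape.scalar(), TFloat32.DTYPE);
  Assign<TFloat32> init = tf.assign(w, tf.constant(0f));
  Operand<TFloat32> loss = tf.math.square(tf.math.sub(w, tf.constant(5f)));

  GradientDescent optimizer = new GradientDescent(graph, 0.1f);
  Op trainStep = optimizer.minimize(loss); // default name, "<optimizer name>-minimize"

  try (Session session = new Session(graph)) {
    session.runner().addTarget(init).run();
    Session.Runner runner = session.runner().addTarget(trainStep);
    // The learning-rate placeholder has no default value, so the scalar
    // tensor from getFeedMap() must be fed on every run:
    optimizer.getFeedMap().forEach(runner::feed);
    runner.run();
  }
}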

/**
* Computes the gradients based on a loss operand.
*
* @param loss the loss operation
* @param <T> the data type of the loss, gradients and variables.
* @return the computed gradients
*/
public <T extends TType> List<GradAndVar<?>> computeGradients(Operand<?> loss) {
List<Operation> variables = new ArrayList<>();
graph
@@ -156,6 +183,13 @@ public <T extends TType> List<GradAndVar<?>> computeGradients(Operand<?> loss) {
return gradVarPairs;
}
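
Splitting minimize into computeGradients plus applyGradients leaves room to inspect or transform the gradients in between; minimize itself is just the two calls back to back. A minimal sketch of the manual form (assuming GradAndVar and its accessors are visible to callers; the inspection step is hypothetical):

List<Optimizer.GradAndVar<?>> gradsAndVars = optimizer.computeGradients(loss);
for (Optimizer.GradAndVar<?> pair : gradsAndVars) {
  // Inspect pair.getGradient() / pair.getVariable() here, or build a
  // transformed list (clipping, scaling, ...) before applying it.
}
Op trainStep = optimizer.applyGradients(gradsAndVars, "train");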

/**
* Applies gradients to variables
*
* @param gradsAndVars the list of (gradient, variable) pairs.
* @param name the name of the apply gradients operation
* @return an Op that applies the gradients to the variables.
*/
public Op applyGradients(List<GradAndVar<? extends TType>> gradsAndVars, String name) {
List<Output<? extends TType>> variables =
gradsAndVars.stream().map(GradAndVar::getVariable).collect(Collectors.toList());
@@ -242,6 +276,13 @@ protected Optional<Op> prepare(String scopeName) {
*/
protected void createSlots(List<Output<? extends TType>> variables) {}

/**
* Generates the gradient update operations for the specific variable and gradient.
*
* @param gradVarPair the (gradient, variable) pair.
* @param <T> the data type of the gradients and variables.
* @return an Op that applies the desired optimizer update to the variable.
*/
private <T extends TType> Op applyDense(GradAndVar<T> gradVarPair) {
return applyDense(gradVarPair.getGradient(), gradVarPair.getVariable());
}
@@ -280,20 +321,20 @@ protected Op finish(List<Op> updateOperations, String name) {
/**
* Sets the learning rate
*
* @param learningRate the learning rate
* @param newLearningRate the new earning rate
Collaborator: typo - "earning"

Contributor Author: OK

*/
public final void setLearningRate(float learningRate) {
if (this.learningRatePlaceholder == null) {
this.learningRatePlaceholder =
public final void setLearningRate(float newLearningRate) {
if (learningRatePlaceholder == null) {
learningRatePlaceholder =
tf.withSubScope(LEARNING_RATE)
.placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar()));
}

if (this.learningRate != learningRate) {
if (this.learningRateTensor != null) this.learningRateTensor.close();
this.learningRate = learningRate;
this.learningRateTensor = TFloat32.scalarOf(this.learningRate);
this.feedMap = Collections.singletonMap(this.learningRatePlaceholder, learningRateTensor);
if (learningRate != newLearningRate) {
if (learningRateTensor != null) learningRateTensor.close();
learningRate = newLearningRate;
learningRateTensor = TFloat32.scalarOf(learningRate);
feedMap = Collections.singletonMap(learningRatePlaceholder, learningRateTensor);
}
}
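
A consequence worth noting: when the rate actually changes, the old scalar tensor is closed and a fresh one is installed in a new feed map, so callers should re-read getFeedMap() after every change rather than caching it. A hypothetical decay loop over an existing session and train step:

float lr = 0.1f;
for (int epoch = 0; epoch < 10; epoch++) {
  optimizer.setLearningRate(lr);
  Session.Runner runner = session.runner().addTarget(trainStep);
  optimizer.getFeedMap().forEach(runner::feed); // re-read: the tensor was replaced
  runner.run();
  lr *= 0.9f; // hypothetical exponential decay schedule
}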

@@ -303,7 +344,7 @@ public final void setLearningRate(float learningRate) {
* @return the learning rate
*/
public float getLearningRate() {
return this.learningRate;
return learningRate;
}

/**
Expand All @@ -312,7 +353,7 @@ public float getLearningRate() {
* @return the learning rate Operand
*/
protected Operand<TFloat32> getLearningRateOperand() {
return this.learningRatePlaceholder;
return learningRatePlaceholder;
}

/**
Expand All @@ -323,13 +364,15 @@ protected Operand<TFloat32> getLearningRateOperand() {
* Operand has been set.
*/
public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedMap() {
return this.feedMap;
return feedMap;
}

/** {@inheritDoc} */
public void close() {
// close the learningRate Tensor if it exists.
if (this.feedMap != null) {
this.feedMap.get(this.learningRatePlaceholder).close();
if (learningRateTensor != null) {
learningRateTensor.close();
learningRateTensor = null;
}
}
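
Since close() owns the learning-rate tensor's native lifecycle, the optimizer should be closed once training ends. Assuming the inherited contract is AutoCloseable (which the {@inheritDoc} suggests but this hunk does not show), try-with-resources is the natural shape:

try (Graph graph = new Graph();
     GradientDescent optimizer = new GradientDescent(graph, 0.01f)) {
  // ... build the loss, call minimize, and run training steps ...
} // optimizer.close() releases the learning-rate tensor here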
