Skip to content

Add Constraints #197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.constraints;

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

import static org.tensorflow.framework.utils.CastHelper.cast;

/**
* Base class for Constraints. Constraint subclasses impose constraints on weight values
*
* @param <T> the date type for the weights
*/
public abstract class Constraint<T extends TNumber> {

public static final float EPSILON = 1e-7f;

private final Ops tf;

/**
* Creates a Constraint
*
* @param tf the TensorFlow Ops
*/
public Constraint(Ops tf) {
this.tf = tf;
}

/**
* Applies the constraint against the provided weights
*
* @param weights the weights
* @return the constrained weights
*/
public abstract Operand<T> call(Operand<T> weights);

/**
* Gets the TensorFlow Ops
*
* @return the TensorFlow Ops
*/
public Ops getTF() {
return tf;
}

/**
* Gets the element-wise square root.
*
* @param x the input Operand.
* @return the element-wise square root.
*/
protected Operand<T> sqrt(Operand<T> x) {
Class<T> type = x.type();
Operand<T> zero = cast(tf, tf.constant(0), type);
Operand<T> inf = cast(tf, tf.constant(Double.POSITIVE_INFINITY), type);
return tf.math.sqrt(tf.clipByValue(x, zero, inf));
}

/**
* Gets the element-wise value clipping.
*
* @param x the Operand to clip
* @param minValue the minimum value
* @param maxValue the maximum value
* @return the operand with clipped values
*/
protected Operand<T> clip(Operand<T> x, double minValue, double maxValue) {
if (x == null) throw new IllegalArgumentException("Operand x must not be null");
Ops tf = getTF();
Class<T> type = x.type();

double min = Math.min(minValue, maxValue);
double max = Math.max(minValue, maxValue);

Operand<T> minValueConstant = cast(tf, tf.constant(min), type);
Operand<T> maxValueConstant = cast(tf, tf.constant(max), type);
return tf.clipByValue(x, minValueConstant, maxValueConstant);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.constraints;

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.types.family.TNumber;

import static org.tensorflow.framework.utils.CastHelper.cast;

/**
* Constrains the weights incident to each hidden unit to have a norm less than or equal to a
* desired value.
*
* @param <T> the data type for the weights
*/
public class MaxNorm<T extends TNumber> extends Constraint<T> {
public static final double MAX_VALUE_DEFAULT = 2.0;
public static final int AXIS_DEFAULT = 0;

/** the maximum norm for the incoming weights. */
private final double maxValue;
/** integer, axis along which to calculate weight norms. */
private final int[] axes;

/**
* Create a MaxNorm constraint using {@link #MAX_VALUE_DEFAULT} for the max value and {@link
* #AXIS_DEFAULT} for the axis.
*
* @param tf the TensorFlow Ops
*/
public MaxNorm(Ops tf) {
this(tf, MAX_VALUE_DEFAULT, AXIS_DEFAULT);
}

/**
* Create a MaxNorm constraint using {@link #AXIS_DEFAULT} for the axis.
*
* @param tf the TensorFlow Ops
* @param maxValue the maximum norm for the incoming weights.
*/
public MaxNorm(Ops tf, double maxValue) {
this(tf, maxValue, AXIS_DEFAULT);
}

/**
* Create a MaxNorm constraint
*
* @param tf the TensorFlow Ops
* @param maxValue the maximum norm for the incoming weights.
* @param axis axis along which to calculate weight norms.
*/
public MaxNorm(Ops tf, double maxValue, int axis) {
this(tf, maxValue, new int[] {axis});
}

/**
* Create a MaxNorm constraint
*
* @param tf the TensorFlow Ops
* @param maxValue the maximum norm for the incoming weights.
* @param axes axes along which to calculate weight norms.
*/
public MaxNorm(Ops tf, double maxValue, int[] axes) {
super(tf);
this.maxValue = maxValue;
this.axes = axes;
}

/** {@inheritDoc} */
@Override
public Operand<T> call(Operand<T> weights) {
Ops tf = getTF();
Class<T> type = weights.type();
Operand<T> norms =
sqrt(
tf.reduceSum(
tf.math.square(weights), tf.constant(getAxes()), ReduceSum.keepDims(Boolean.TRUE)));
Operand<T> desired = clip(norms, 0f, this.getMaxValue());

return tf.math.mul(
weights, tf.math.div(desired, tf.math.add(cast(tf, tf.constant(EPSILON), type), norms)));
}

/**
* Gets the max value
*
* @return the maxValue
*/
public double getMaxValue() {
return maxValue;
}

/**
* Gets the axes
*
* @return the axes
*/
public int[] getAxes() {
return axes;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.constraints;

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.types.family.TNumber;

import static org.tensorflow.framework.utils.CastHelper.cast;

/**
* Constrains the weights to have the norm between a lower bound and an upper bound.
*
* @param <T> the data type for the weights
*/
public class MinMaxNorm<T extends TNumber> extends Constraint<T> {
public static final double MIN_VALUE_DEFAULT = 0.0;
public static final double MAX_VALUE_DEFAULT = 1.0;
public static final double RATE_DEFAULT = 1.0;
public static final int AXIS_DEFAULT = 0;

/** the minimum norm for the incoming weights. */
private final double minValue;
/** the maximum norm for the incoming weights. */
private final double maxValue;

/**
* rate for enforcing the constraint: weights will be rescaled to yield (1 - rate) * norm + rate *
* norm.clip(min_value, max_value). Effectively, this means that rate=1.0 stands for strict
* enforcement of the constraint, while rate<1.0 means that weights will be rescaled at each step
* to slowly move towards a value inside the desired interval.
*/
private final double rate;

/** axis along which to calculate weight norms. */
private final int[] axes;

/**
* Create a MaxNorm constraint using {@link #MIN_VALUE_DEFAULT} for the min value, {@link
* #MAX_VALUE_DEFAULT} for the max value, {@link #RATE_DEFAULT} for the rate and {@link
* #AXIS_DEFAULT} for the axis
*
* @param tf the TensorFlow Ops
*/
public MinMaxNorm(Ops tf) {
this(tf, MIN_VALUE_DEFAULT, MAX_VALUE_DEFAULT, RATE_DEFAULT, AXIS_DEFAULT);
}

/**
* Create a MaxNorm constraint using {@link #RATE_DEFAULT} for the rate and {@link #AXIS_DEFAULT}
* for the axis
*
* @param tf the TensorFlow Ops
* @param minValue the minimum norm for the incoming weights.
* @param maxValue the maximum norm for the incoming weights.
*/
public MinMaxNorm(Ops tf, double minValue, double maxValue) {
this(tf, minValue, maxValue, RATE_DEFAULT, AXIS_DEFAULT);
}

/**
* Create a MaxNorm constraint
*
* @param tf the TensorFlow Ops
* @param minValue the minimum norm for the incoming weights.
* @param maxValue the maximum norm for the incoming weights.
* @param rate the rate for enforcing the constraint.
* @param axis integer, axis along which to calculate weight norms.
*/
public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int axis) {
this(tf, minValue, maxValue, rate, new int[] {axis});
}
/**
* Create a MaxNorm constraint
*
* @param tf the TensorFlow Ops
* @param minValue the minimum norm for the incoming weights.
* @param maxValue the maximum norm for the incoming weights.
* @param rate the rate for enforcing the constraint.
* @param axes integer, axis along which to calculate weight norms.
*/
public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int[] axes) {
super(tf);
this.minValue = minValue;
this.maxValue = maxValue;
this.rate = rate;
this.axes = axes;
}

/** {@inheritDoc} */
@Override
public Operand<T> call(Operand<T> weights) {
Class<T> type = weights.type();
Ops tf = getTF();
Operand<T> norms =
sqrt(
tf.reduceSum(
tf.math.square(weights), tf.constant(getAxes()), ReduceSum.keepDims(Boolean.TRUE)));
Operand<T> desired =
tf.math.add(
tf.math.mul(
tf.dtypes.cast(tf.constant(this.getRate()), type),
clip(norms, this.getMinValue(), this.getMaxValue())),
tf.math.mul(
tf.math.sub(
tf.dtypes.cast(tf.constant(1), type),
tf.dtypes.cast(tf.constant(this.getRate()), type)),
norms));

return tf.math.mul(
weights, tf.math.div(desired, tf.math.add(cast(tf, tf.constant(EPSILON), type), norms)));
}

/**
* Gets the minValue
*
* @return the minValue
*/
public double getMinValue() {
return minValue;
}

/**
* Gets the maxValue
*
* @return the maxValue
*/
public double getMaxValue() {
return maxValue;
}

/**
* Gets the rate
*
* @return the rate
*/
public double getRate() {
return rate;
}

/**
* Gets the axes
*
* @return the axes
*/
public int[] getAxes() {
return axes;
}
}
Loading