From 20c0f5044e8f86786b403942f7bc37dd42f206dc Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 14 Jan 2021 12:08:02 -0500 Subject: [PATCH 1/2] Initial Checkin --- .../framework/constraints/Constraint.java | 95 ++++++++++ .../framework/constraints/MaxNorm.java | 115 +++++++++++++ .../framework/constraints/MinMaxNorm.java | 162 ++++++++++++++++++ .../framework/constraints/NonNeg.java | 46 +++++ .../framework/constraints/UnitNorm.java | 90 ++++++++++ .../framework/constraints/MaxNormTest.java | 65 +++++++ .../framework/constraints/MinMaxNormTest.java | 65 +++++++ .../framework/constraints/NonNegTest.java | 110 ++++++++++++ .../framework/constraints/UnitNormTest.java | 66 +++++++ .../org/tensorflow/framework/utils/ND.java | 101 ++++++++++- .../framework/utils/TestSession.java | 12 ++ 11 files changed, 923 insertions(+), 4 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java new file mode 100644 index 00000000000..bf6f97b463a --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java @@ -0,0 +1,95 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Base class for Constraints. Constraint subclasses impose constraints on weight values + * + * @param the date type for the weights + */ +public abstract class Constraint { + + public static final float EPSILON = 1e-7f; + + private final Ops tf; + + /** + * Creates a Constraint + * + * @param tf the TensorFlow Ops + */ + public Constraint(Ops tf) { + this.tf = tf; + } + + /** + * Applies the constraint against the provided weights + * + * @param weights the weights + * @return the constrained weights + */ + public abstract Operand call(Operand weights); + + /** + * Gets the TensorFlow Ops + * + * @return the TensorFlow Ops + */ + public Ops getTF() { + return tf; + } + + /** + * Get the element-wise square root. + * + * @param x the input Operand. + * @return the element-wise square root. + */ + protected Operand sqrt(Operand x) { + Class type = x.type(); + Operand zero = cast(tf, tf.constant(0), type); + Operand inf = cast(tf, tf.constant(Float.POSITIVE_INFINITY), type); + x = tf.clipByValue(x, zero, inf); + return tf.math.sqrt(x); + } + + /** + * Element-wise value clipping. + * + * @param x the Operand to clip + * @param minValue the minimum value + * @param maxValue the maximum value + * @return the operand with clipped values + */ + protected Operand clip(Operand x, double minValue, double maxValue) { + if (x == null) throw new IllegalArgumentException("Operand x must not be null"); + Ops tf = getTF(); + Class type = x.type(); + if (maxValue < minValue) { + double tmp = maxValue; + maxValue = minValue; + minValue = tmp; + } + Operand minValueConstant = cast(tf, tf.constant(minValue), type); + Operand maxValueConstant = cast(tf, tf.constant(maxValue), type); + return tf.clipByValue(x, minValueConstant, maxValueConstant); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java new file mode 100644 index 00000000000..f55a9998ff0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java @@ -0,0 +1,115 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Constrains the weights incident to each hidden unit to have a norm less than or equal to a + * desired value. + * + * @param the data type for the weights + */ +public class MaxNorm extends Constraint { + public static final float MAX_VALUE_DEFAULT = 2.0f; + public static final int AXIS_DEFAULT = 0; + + /** the maximum norm for the incoming weights. */ + private final float maxValue; + /** integer, axis along which to calculate weight norms. */ + private final int[] axes; + + /** + * Create a MaxNorm constraint using {@link #MAX_VALUE_DEFAULT} for the max value and {@link + * #AXIS_DEFAULT} for the axis. + * + * @param tf the TensorFlow Ops + */ + public MaxNorm(Ops tf) { + this(tf, MAX_VALUE_DEFAULT, AXIS_DEFAULT); + } + + /** + * Create a MaxNorm constraint using {@link #AXIS_DEFAULT} for the axis. + * + * @param tf the TensorFlow Ops + * @param maxValue the maximum norm for the incoming weights. + */ + public MaxNorm(Ops tf, float maxValue) { + this(tf, maxValue, AXIS_DEFAULT); + } + + /** + * Create a MaxNorm constraint + * + * @param tf the TensorFlow Ops + * @param maxValue the maximum norm for the incoming weights. + * @param axis axis along which to calculate weight norms. + */ + public MaxNorm(Ops tf, float maxValue, int axis) { + this(tf, maxValue, new int[] {axis}); + } + + /** + * Create a MaxNorm constraint + * + * @param tf the TensorFlow Ops + * @param maxValue the maximum norm for the incoming weights. + * @param axes axes along which to calculate weight norms. + */ + public MaxNorm(Ops tf, float maxValue, int[] axes) { + super(tf); + this.maxValue = maxValue; + this.axes = axes; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand weights) { + Ops tf = getTF(); + Class type = weights.type(); + Operand norms = + sqrt( + tf.reduceSum( + tf.math.square(weights), tf.constant(getAxes()), ReduceSum.keepDims(Boolean.TRUE))); + Operand desired = clip(norms, 0f, this.getMaxValue()); + + return tf.math.mul( + weights, tf.math.div(desired, tf.math.add(cast(tf, tf.constant(EPSILON), type), norms))); + } + + /** + * Gets the max value + * + * @return the maxValue + */ + public float getMaxValue() { + return maxValue; + } + + /** + * Gets the axes + * + * @return the axes + */ + public int[] getAxes() { + return axes; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java new file mode 100644 index 00000000000..8388d651225 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java @@ -0,0 +1,162 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Constrains the weights to have the norm between a lower bound and an upper bound. + * + * @param the data type for the weights + */ +public class MinMaxNorm extends Constraint { + public static final float MIN_VALUE_DEFAULT = 0.0F; + public static final float MAX_VALUE_DEFAULT = 1.0F; + public static final float RATE_DEFAULT = 1.0F; + public static final int AXIS_DEFAULT = 0; + + /** the minimum norm for the incoming weights. */ + private final float minValue; + /** the maximum norm for the incoming weights. */ + private final float maxValue; + + /** + * rate for enforcing the constraint: weights will be rescaled to yield (1 - rate) * norm + rate * + * norm.clip(min_value, max_value). Effectively, this means that rate=1.0 stands for strict + * enforcement of the constraint, while rate<1.0 means that weights will be rescaled at each step + * to slowly move towards a value inside the desired interval. + */ + private final float rate; + + /** axis along which to calculate weight norms. */ + private final int[] axes; + + /** + * Create a MaxNorm constraint using {@link #MIN_VALUE_DEFAULT} for the min value, {@link + * #MAX_VALUE_DEFAULT} for the max value, {@link #RATE_DEFAULT} for the rate and {@link + * #AXIS_DEFAULT} for the axis + * + * @param tf the TensorFlow Ops + */ + public MinMaxNorm(Ops tf) { + this(tf, MIN_VALUE_DEFAULT, MAX_VALUE_DEFAULT, RATE_DEFAULT, AXIS_DEFAULT); + } + + /** + * Create a MaxNorm constraint using {@link #RATE_DEFAULT} for the rate and {@link #AXIS_DEFAULT} + * for the axis + * + * @param tf the TensorFlow Ops + * @param minValue the minimum norm for the incoming weights. + * @param maxValue the maximum norm for the incoming weights. + */ + public MinMaxNorm(Ops tf, float minValue, float maxValue) { + this(tf, minValue, maxValue, RATE_DEFAULT, AXIS_DEFAULT); + } + + /** + * Create a MaxNorm constraint + * + * @param tf the TensorFlow Ops + * @param minValue the minimum norm for the incoming weights. + * @param maxValue the maximum norm for the incoming weights. + * @param rate the rate for enforcing the constraint. + * @param axis integer, axis along which to calculate weight norms. + */ + public MinMaxNorm(Ops tf, float minValue, float maxValue, float rate, int axis) { + this(tf, minValue, maxValue, rate, new int[] {axis}); + } + /** + * Create a MaxNorm constraint + * + * @param tf the TensorFlow Ops + * @param minValue the minimum norm for the incoming weights. + * @param maxValue the maximum norm for the incoming weights. + * @param rate the rate for enforcing the constraint. + * @param axes integer, axis along which to calculate weight norms. + */ + public MinMaxNorm(Ops tf, float minValue, float maxValue, float rate, int[] axes) { + super(tf); + this.minValue = minValue; + this.maxValue = maxValue; + this.rate = rate; + this.axes = axes; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand weights) { + Class type = weights.type(); + Ops tf = getTF(); + Operand norms = + sqrt( + tf.reduceSum( + tf.math.square(weights), tf.constant(getAxes()), ReduceSum.keepDims(Boolean.TRUE))); + Operand desired = + tf.math.add( + tf.math.mul( + tf.dtypes.cast(tf.constant(this.getRate()), type), + clip(norms, this.getMinValue(), this.getMaxValue())), + tf.math.mul( + tf.math.sub( + tf.dtypes.cast(tf.constant(1), type), + tf.dtypes.cast(tf.constant(this.getRate()), type)), + norms)); + + return tf.math.mul( + weights, tf.math.div(desired, tf.math.add(cast(tf, tf.constant(EPSILON), type), norms))); + } + + /** + * Gets the minValue + * + * @return the minValue + */ + public float getMinValue() { + return minValue; + } + + /** + * Gets the maxValue + * + * @return the maxValue + */ + public float getMaxValue() { + return maxValue; + } + + /** + * Gets the rate + * + * @return the rate + */ + public float getRate() { + return rate; + } + + /** + * Gets the axes + * + * @return the axes + */ + public int[] getAxes() { + return axes; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java new file mode 100644 index 00000000000..3edfa1c036b --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Constrains the weights to be non-negative. + * + * @param the data type for the weights + */ +public class NonNeg extends Constraint { + + /** + * Create a NonNeg constraint + * + * @param tf the TensorFlow Ops + */ + public NonNeg(Ops tf) { + super(tf); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand weights) { + Ops tf = getTF(); + Class type = weights.type(); + return tf.math.mul( + weights, + tf.dtypes.cast(tf.math.greaterEqual(weights, tf.dtypes.cast(tf.constant(0), type)), type)); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java new file mode 100644 index 00000000000..4eba2fd98c0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java @@ -0,0 +1,90 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Constrains the weights to have unit norm. + * + * @param the data type for the weights + */ +public class UnitNorm extends Constraint { + public static final int AXIS_DEFAULT = 0; + + /** integer, axis along which to calculate weight norms. */ + private final int[] axes; + + /** + * Create a UnitNorm Constraint with the axis set to {@link #AXIS_DEFAULT} + * + * @param tf the TensorFlow Ops + */ + public UnitNorm(Ops tf) { + this(tf, AXIS_DEFAULT); + } + + /** + * Create a UnitNorm Constraint + * + * @param tf the TensorFlow Ops + * @param axis axis along which to calculate weight norms. + */ + public UnitNorm(Ops tf, int axis) { + this(tf, new int[] {axis}); + } + + /** + * Create a UnitNorm Constraint + * + * @param tf the TensorFlow Ops + * @param axes axes along which to calculate weight norms. + */ + public UnitNorm(Ops tf, int[] axes) { + super(tf); + this.axes = axes; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand weights) { + Class type = weights.type(); + + Ops tf = getTF(); + return tf.math.div( + weights, + tf.math.add( + cast(tf, tf.constant(EPSILON), type), + sqrt( + tf.reduceSum( + tf.math.square(weights), + tf.constant(getAxes()), + ReduceSum.keepDims(Boolean.TRUE))))); + } + + /** + * Gets the axes + * + * @return the axes + */ + public int[] getAxes() { + return axes; + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java new file mode 100644 index 00000000000..fa61e097b42 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java @@ -0,0 +1,65 @@ +package org.tensorflow.framework.constraints; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; + +class MaxNormTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + private float[] getSampleArray() { + Random rand = new Random(3537L); + float[] result = new float[100 * 100]; + for (int i = 0; i < result.length; i++) { + result[i] = rand.nextFloat() * 100 - 50; + } + result[0] = 0; + return result; + } + + /** Test of call method, of class MaxNorm. */ + @Test + public void testCall() { + float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + final float[] array = getSampleArray(); + Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); + for (AtomicInteger i = new AtomicInteger(); + i.get() < testValues.length; + i.getAndIncrement()) { + MaxNorm instance = new MaxNorm<>(tf, testValues[i.get()]); + Operand result = instance.call(weights); + session.evaluate(result, (Number v) -> v.floatValue() <= testValues[i.get()]); + } + } + } + /** Test of call method, of class MaxNorm. */ + @Test + public void testCall1() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MaxNorm instance = new MaxNorm<>(tf, 2.0f); + Operand weights = + tf.constant( + new float[][] { + {0, 1, 3, 3}, {0, 0, 0, 3}, {0, 0, 0, 3}, + }); + Operand result = instance.call(weights); + float[] expected = { + 0, 1, 2, 1.1547005f, + 0, 0, 0, 1.1547005f, + 0, 0, 0, 1.1547005f + }; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java new file mode 100644 index 00000000000..70bae6b9c83 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java @@ -0,0 +1,65 @@ +package org.tensorflow.framework.constraints; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.ND; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; + +class MinMaxNormTest { + + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + private float[] getSampleArray() { + Random rand = new Random(3537L); + float[] result = new float[100 * 100]; + for (int i = 0; i < result.length; i++) { + result[i] = rand.nextFloat() * 100 - 50; + } + result[0] = 0; + return result; + } + + /** Test of call method, of class MinMaxNorm. */ + @Test + public void testCall() { + float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + final float[] array = getSampleArray(); + Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); + for (AtomicInteger i = new AtomicInteger(); + i.get() < testValues.length; + i.getAndIncrement()) { + MinMaxNorm instance = + new MinMaxNorm<>(tf, testValues[i.get()], testValues[i.get()] * 2); + Operand result = instance.call(weights); + if (tfMode == TestSession.Mode.EAGER) + evaluate(session, result.asTensor(), testValues[i.get()]); + else + try (TFloat32 tensor = + (TFloat32) session.getGraphSession().runner().fetch(result).run().get(0)) { + evaluate(session, tensor, testValues[i.get()]); + } + } + } + } + + private void evaluate(TestSession session, TFloat32 tensor, float m) { + FloatNdArray tensorArray = NdArrays.ofFloats(tensor.shape()); + tensor.copyTo(tensorArray); + tensorArray = ND.square(tensorArray); + FloatNdArray normArray = ND.sum(tensorArray, 0); + FloatNdArray normOfNormalized = ND.sqrt(normArray); + session.evaluate( + normOfNormalized, (f) -> f.floatValue() >= m && f.floatValue() <= m * 2f + 1e-5f); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java new file mode 100644 index 00000000000..3942629d6ed --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java @@ -0,0 +1,110 @@ +package org.tensorflow.framework.constraints; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.ND; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; + +class NonNegTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + private float[] getSampleArray() { + Random rand = new Random(3537L); + float[] result = new float[100 * 100]; + for (int i = 0; i < result.length; i++) { + result[i] = rand.nextFloat() * 100 - 50; + } + result[0] = 0; + return result; + } + + private double[] getSampleDArray() { + Random rand = new Random(3537L); + double[] result = new double[100 * 100]; + for (int i = 0; i < result.length; i++) { + result[i] = rand.nextDouble() * 100 - 50; + } + result[0] = 0; + return result; + } + + @Test + public void testTFloat32() { + float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + final float[] array = getSampleArray(); + Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); + for (AtomicInteger i = new AtomicInteger(); + i.get() < testValues.length; + i.getAndIncrement()) { + MinMaxNorm instance = + new MinMaxNorm<>(tf, testValues[i.get()], testValues[i.get()] * 2); + Operand result = instance.call(weights); + if (tfMode == TestSession.Mode.EAGER) + evaluate(session, result.asTensor(), testValues[i.get()]); + else + try (TFloat32 tensor = + (TFloat32) session.getGraphSession().runner().fetch(result).run().get(0)) { + evaluate(session, tensor, testValues[i.get()]); + } + } + } + } + + @Test + public void testTFloat64() { + float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + final double[] array = getSampleDArray(); + Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); + for (AtomicInteger i = new AtomicInteger(); + i.get() < testValues.length; + i.getAndIncrement()) { + MinMaxNorm instance = + new MinMaxNorm<>(tf, testValues[i.get()], testValues[i.get()] * 2); + Operand result = instance.call(weights); + if (tfMode == TestSession.Mode.EAGER) + evaluate(session, result.asTensor(), testValues[i.get()]); + else + try (TFloat64 tensor = + (TFloat64) session.getGraphSession().runner().fetch(result).run().get(0)) { + evaluate(session, tensor, testValues[i.get()]); + } + } + } + } + + private void evaluate(TestSession session, TFloat32 tensor, float m) { + FloatNdArray tensorArray = NdArrays.ofFloats(tensor.shape()); + tensor.copyTo(tensorArray); + tensorArray = ND.square(tensorArray); + FloatNdArray normArray = ND.sum(tensorArray, 0); + FloatNdArray normOfNormalized = ND.sqrt(normArray); + session.evaluate( + normOfNormalized, (f) -> f.floatValue() >= m && f.floatValue() <= m * 2f + 1e-5f); + } + + private void evaluate(TestSession session, TFloat64 tensor, float m) { + DoubleNdArray tensorArray = NdArrays.ofDoubles(tensor.shape()); + tensor.copyTo(tensorArray); + tensorArray = ND.square(tensorArray); + DoubleNdArray normArray = ND.sum(tensorArray, 0); + DoubleNdArray normOfNormalized = ND.sqrt(normArray); + session.evaluate( + normOfNormalized, (f) -> f.doubleValue() >= m && f.doubleValue() <= m * 2 + 1e-5); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java new file mode 100644 index 00000000000..7b6359bcf6c --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java @@ -0,0 +1,66 @@ +package org.tensorflow.framework.constraints; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; + +class UnitNormTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + private float[] getSampleArray() { + Random rand = new Random(3537L); + float[] result = new float[100 * 100]; + for (int i = 0; i < result.length; i++) { + result[i] = rand.nextFloat() * 100 - 50; + } + result[0] = 0; + return result; + } + + /** Test of call method, of class MaxNorm. */ + @Test + public void testTFloat32() { + float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + final float[] array = getSampleArray(); + Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); + for (AtomicInteger i = new AtomicInteger(); + i.get() < testValues.length; + i.getAndIncrement()) { + MaxNorm instance = new MaxNorm<>(tf, testValues[i.get()]); + Operand result = instance.call(weights); + session.evaluate(result, (Number v) -> v.floatValue() <= testValues[i.get()]); + } + } + } + /** Test of call method, of class MaxNorm. */ + @Test + public void testCallTFloat64() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MaxNorm instance = new MaxNorm<>(tf, 2.0f); + Operand weights = + tf.constant( + new double[][] { + {0, 1, 3, 3}, {0, 0, 0, 3}, {0, 0, 0, 3}, + }); + Operand result = instance.call(weights); + double[] expected = { + 0, 1, 2, 1.1547005, + 0, 0, 0, 1.1547005, + 0, 0, 0, 1.1547005 + }; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java index 0503a41dfc2..ef8bb71d724 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java @@ -14,10 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.*; import java.util.Arrays; import java.util.concurrent.atomic.AtomicBoolean; @@ -103,6 +100,23 @@ public static FloatNdArray sqrt(FloatNdArray a) { return result; } + /** + * Gets the square root of an array. + * + * @param a the array + * @return the square root of the array. + */ + public static DoubleNdArray sqrt(DoubleNdArray a) { + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + int nDims = a.shape().numDimensions(); + a.elements(nDims - 1) + .forEachIndexed( + (idx, v) -> { + result.setDouble(Math.sqrt(v.getDouble()), idx); + }); + return result; + } + /** * Gets the square of an array. * @@ -120,6 +134,23 @@ public static FloatNdArray square(FloatNdArray a) { return result; } + /** + * Gets the square of an array. + * + * @param a the array + * @return the square of the array. + */ + public static DoubleNdArray square(DoubleNdArray a) { + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + int nDims = a.shape().numDimensions(); + a.elements(nDims - 1) + .forEachIndexed( + (idx, v) -> { + result.setDouble(v.getDouble() * v.getDouble(), idx); + }); + return result; + } + /** * Adds two arrays * @@ -568,6 +599,18 @@ public static FloatNdArray sum(FloatNdArray a) { return NdArrays.scalarOf(sum.get()); } + /** + * Sum all elements of an array + * + * @param a the array + * @return an a array with one element containing the sum. + */ + public static DoubleNdArray sum(DoubleNdArray a) { + AtomicReference sum = new AtomicReference(0D); + a.scalars().forEach(f -> sum.set(sum.get() + f.getDouble())); + return NdArrays.scalarOf(sum.get()); + } + /** * Sum all elements of an array based on the specified axis * @@ -579,6 +622,17 @@ public static FloatNdArray sum(FloatNdArray a, int axis) { return sum(a, axis, false); } + /** + * Sum all elements of an array based on the specified axis + * + * @param a the array + * @param axis the axis to sum + * @return an a array the sum over the axis less the diemsnion + */ + public static DoubleNdArray sum(DoubleNdArray a, int axis) { + return sum(a, axis, false); + } + /** * Sum all elements of an array based on the specified axis * @@ -618,6 +672,45 @@ public static FloatNdArray sum(FloatNdArray a, int axis, boolean keepDims) { } } + /** + * Sum all elements of an array based on the specified axis + * + * @param a the array + * @param axis the axis to sum + * @param keepDims indicates whether the dimensions over the sum should be kept or not. + * @return an a array the sum over the axis + */ + public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) { + Shape shape = a.shape(); + int nDims = shape.numDimensions(); + int xis = nDims - 1 - axis; + long totalSize = shape.size(); + long axisSize = shape.size(xis); + final double[] sums = new double[(int) axisSize]; + + a.scalars() + .forEachIndexed( + (idx, f) -> { + sums[(int) idx[xis]] += f.getDouble(); + }); + + if (keepDims) { + long[] newDims = shape.asArray(); + newDims[axis] = 1; + final AtomicInteger counter = new AtomicInteger(); + DoubleNdArray arrayK = NdArrays.ofDoubles(Shape.of(newDims)); + arrayK + .elements(newDims.length - 1) + .forEachIndexed( + (idx, v) -> { + v.setDouble(sums[counter.getAndAdd(1)]); + }); + return arrayK; + } else { + return NdArrays.vectorOf(sums); + } + } + /** * Sum all elements of an array based on the specified axis * diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java index 3fccd0f0506..88316203310 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.utils; import org.tensorflow.*; +import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -492,6 +493,17 @@ public void evaluate(FloatNdArray input, Predicate predicate) { input.scalars().forEach(f -> assertTrue(predicate.test(f.getFloat()))); } + /** + * Evaluates the input against the expected value + * + * @param input the operand to evaluate + * @param predicate The Predicate that evaluates the each value from input + * @throws org.opentest4j.AssertionFailedError if the evaluation fails + */ + public void evaluate(DoubleNdArray input, Predicate predicate) { + input.scalars().forEach(f -> assertTrue(predicate.test(f.getDouble()))); + } + /** * Print the input * From 5837d6d02d6361537dc9a5c426a2665d132c32ec Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 26 Jan 2021 11:21:36 -0500 Subject: [PATCH 2/2] Clean up JavaDoc Change float attributes to double --- .../framework/constraints/Constraint.java | 22 ++++++++--------- .../framework/constraints/MaxNorm.java | 12 +++++----- .../framework/constraints/MinMaxNorm.java | 24 +++++++++---------- .../framework/constraints/MaxNormTest.java | 6 ++--- 4 files changed, 31 insertions(+), 33 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java index bf6f97b463a..1bcd3bd04ad 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java @@ -58,7 +58,7 @@ public Ops getTF() { } /** - * Get the element-wise square root. + * Gets the element-wise square root. * * @param x the input Operand. * @return the element-wise square root. @@ -66,13 +66,12 @@ public Ops getTF() { protected Operand sqrt(Operand x) { Class type = x.type(); Operand zero = cast(tf, tf.constant(0), type); - Operand inf = cast(tf, tf.constant(Float.POSITIVE_INFINITY), type); - x = tf.clipByValue(x, zero, inf); - return tf.math.sqrt(x); + Operand inf = cast(tf, tf.constant(Double.POSITIVE_INFINITY), type); + return tf.math.sqrt(tf.clipByValue(x, zero, inf)); } /** - * Element-wise value clipping. + * Gets the element-wise value clipping. * * @param x the Operand to clip * @param minValue the minimum value @@ -83,13 +82,12 @@ protected Operand clip(Operand x, double minValue, double maxValue) { if (x == null) throw new IllegalArgumentException("Operand x must not be null"); Ops tf = getTF(); Class type = x.type(); - if (maxValue < minValue) { - double tmp = maxValue; - maxValue = minValue; - minValue = tmp; - } - Operand minValueConstant = cast(tf, tf.constant(minValue), type); - Operand maxValueConstant = cast(tf, tf.constant(maxValue), type); + + double min = Math.min(minValue, maxValue); + double max = Math.max(minValue, maxValue); + + Operand minValueConstant = cast(tf, tf.constant(min), type); + Operand maxValueConstant = cast(tf, tf.constant(max), type); return tf.clipByValue(x, minValueConstant, maxValueConstant); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java index f55a9998ff0..13a7ee9eb16 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java @@ -28,11 +28,11 @@ * @param the data type for the weights */ public class MaxNorm extends Constraint { - public static final float MAX_VALUE_DEFAULT = 2.0f; + public static final double MAX_VALUE_DEFAULT = 2.0; public static final int AXIS_DEFAULT = 0; /** the maximum norm for the incoming weights. */ - private final float maxValue; + private final double maxValue; /** integer, axis along which to calculate weight norms. */ private final int[] axes; @@ -52,7 +52,7 @@ public MaxNorm(Ops tf) { * @param tf the TensorFlow Ops * @param maxValue the maximum norm for the incoming weights. */ - public MaxNorm(Ops tf, float maxValue) { + public MaxNorm(Ops tf, double maxValue) { this(tf, maxValue, AXIS_DEFAULT); } @@ -63,7 +63,7 @@ public MaxNorm(Ops tf, float maxValue) { * @param maxValue the maximum norm for the incoming weights. * @param axis axis along which to calculate weight norms. */ - public MaxNorm(Ops tf, float maxValue, int axis) { + public MaxNorm(Ops tf, double maxValue, int axis) { this(tf, maxValue, new int[] {axis}); } @@ -74,7 +74,7 @@ public MaxNorm(Ops tf, float maxValue, int axis) { * @param maxValue the maximum norm for the incoming weights. * @param axes axes along which to calculate weight norms. */ - public MaxNorm(Ops tf, float maxValue, int[] axes) { + public MaxNorm(Ops tf, double maxValue, int[] axes) { super(tf); this.maxValue = maxValue; this.axes = axes; @@ -100,7 +100,7 @@ public Operand call(Operand weights) { * * @return the maxValue */ - public float getMaxValue() { + public double getMaxValue() { return maxValue; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java index 8388d651225..9cc39bfcf99 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java @@ -27,15 +27,15 @@ * @param the data type for the weights */ public class MinMaxNorm extends Constraint { - public static final float MIN_VALUE_DEFAULT = 0.0F; - public static final float MAX_VALUE_DEFAULT = 1.0F; - public static final float RATE_DEFAULT = 1.0F; + public static final double MIN_VALUE_DEFAULT = 0.0; + public static final double MAX_VALUE_DEFAULT = 1.0; + public static final double RATE_DEFAULT = 1.0; public static final int AXIS_DEFAULT = 0; /** the minimum norm for the incoming weights. */ - private final float minValue; + private final double minValue; /** the maximum norm for the incoming weights. */ - private final float maxValue; + private final double maxValue; /** * rate for enforcing the constraint: weights will be rescaled to yield (1 - rate) * norm + rate * @@ -43,7 +43,7 @@ public class MinMaxNorm extends Constraint { * enforcement of the constraint, while rate<1.0 means that weights will be rescaled at each step * to slowly move towards a value inside the desired interval. */ - private final float rate; + private final double rate; /** axis along which to calculate weight norms. */ private final int[] axes; @@ -67,7 +67,7 @@ public MinMaxNorm(Ops tf) { * @param minValue the minimum norm for the incoming weights. * @param maxValue the maximum norm for the incoming weights. */ - public MinMaxNorm(Ops tf, float minValue, float maxValue) { + public MinMaxNorm(Ops tf, double minValue, double maxValue) { this(tf, minValue, maxValue, RATE_DEFAULT, AXIS_DEFAULT); } @@ -80,7 +80,7 @@ public MinMaxNorm(Ops tf, float minValue, float maxValue) { * @param rate the rate for enforcing the constraint. * @param axis integer, axis along which to calculate weight norms. */ - public MinMaxNorm(Ops tf, float minValue, float maxValue, float rate, int axis) { + public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int axis) { this(tf, minValue, maxValue, rate, new int[] {axis}); } /** @@ -92,7 +92,7 @@ public MinMaxNorm(Ops tf, float minValue, float maxValue, float rate, int axis) * @param rate the rate for enforcing the constraint. * @param axes integer, axis along which to calculate weight norms. */ - public MinMaxNorm(Ops tf, float minValue, float maxValue, float rate, int[] axes) { + public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int[] axes) { super(tf); this.minValue = minValue; this.maxValue = maxValue; @@ -129,7 +129,7 @@ public Operand call(Operand weights) { * * @return the minValue */ - public float getMinValue() { + public double getMinValue() { return minValue; } @@ -138,7 +138,7 @@ public float getMinValue() { * * @return the maxValue */ - public float getMaxValue() { + public double getMaxValue() { return maxValue; } @@ -147,7 +147,7 @@ public float getMaxValue() { * * @return the rate */ - public float getRate() { + public double getRate() { return rate; } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java index fa61e097b42..08d693c9432 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java @@ -26,7 +26,7 @@ private float[] getSampleArray() { /** Test of call method, of class MaxNorm. */ @Test public void testCall() { - float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; + double[] testValues = {0.1, 0.5, 3, 8, 1e-7}; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -37,7 +37,7 @@ public void testCall() { i.getAndIncrement()) { MaxNorm instance = new MaxNorm<>(tf, testValues[i.get()]); Operand result = instance.call(weights); - session.evaluate(result, (Number v) -> v.floatValue() <= testValues[i.get()]); + session.evaluate(result, v -> v.floatValue() <= testValues[i.get()]); } } } @@ -47,7 +47,7 @@ public void testCall1() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MaxNorm instance = new MaxNorm<>(tf, 2.0f); + MaxNorm instance = new MaxNorm<>(tf, 2.0); Operand weights = tf.constant( new float[][] {