Skip to content

Commit 0a1a868

Browse files
authored
Add Losses (#129)
* Initial checkin to rebase to Initialziers to pick up changes to ndarry Shape * Initial Checkin for losses * Fix reshape in sparseCategoricalCrossentropy() * Apply various fixes to JavaDoc * Change Tuple to LossTuple * Repair JavaDOx * Fixed AllAxis to hanlde dynamic shape when static shape rank is unknown. * change method name allAxis to allAxes * change private method binaryCrossentropy to binaryCrossentropyHelper * Fixed squeezeOrExpandDimensions to make sure the updated labels, predictions and weights are returned in LossTuple * Fix JavaDoc, Add in rangeCheck and valueCheck Misc fixes based on review * Fix unused imports and add @SuppressWarnings("unchecked") for casts. * Add copyright * Add CastHelper and used that for all casts * Fix JavaDoc, change snake case to camel case. * Change class LossesImpl to LossesHelper * Remove commented out JavaDoc * Changed method name from smoothLabelsBinaryX to smoothBinaryLabels, smoothLabelsCatX to smoothCategoricalLabels. Added clarification oin JavaDoc for cosineSimilarity to describe the difference between the mathematical definition for cosine similarity and the loss definition. * Fixed JavaDoc for labelSmoothing * Fixed JavaDoc to change label_smoothing to labelSmoothing. * Fix formatting * replace label_smoothing with labelSmoothing. fix typo error in JavaDoc comment * Add copyright to test cases * Fix copyright to attribute TensorFlow Authors. * Fix typo on broadcast in JavaDoc * Fix typo on broadcast in JavaDoc
1 parent 91400f4 commit 0a1a868

36 files changed

+6126
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.framework.losses;
16+
17+
import org.tensorflow.Operand;
18+
import org.tensorflow.framework.losses.impl.LossesHelper;
19+
import org.tensorflow.op.Ops;
20+
import org.tensorflow.types.family.TNumber;
21+
22+
import static org.tensorflow.framework.utils.CastHelper.cast;
23+
24+
/**
25+
* Computes the cross-entropy loss between true labels and predicted labels.
26+
*
27+
* <p>Use this cross-entropy loss when there are only two label classes (assumed to be 0 and 1). For
28+
* each example, there should be a single floating-point value per prediction.
29+
*
30+
* <p>Standalone usage:
31+
*
32+
* <pre>
33+
* Operand&lt;TFloat32&gt; labels =
34+
* tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}});
35+
* Operand&lt;TFloat32&gt; predictions =
36+
* tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}});
37+
* BinaryCrossentropy bce = new BinaryCrossentropy(tf);
38+
* Operand&lt;TFloat32&gt; result = bce.call(labels, predictions);
39+
* // produces 0.815
40+
* </pre>
41+
*
42+
* <p>Calling with sample weight:
43+
*
44+
* <pre>
45+
* Operand&lt;TFloat32&gt; sampleWeight = tf.constant(new float[] {1.f, 0.f});
46+
* Operand&lt;TFloat32&gt; result = bce.call(labels, predictions, sampleWeight);
47+
* // produces 0.458f
48+
* </pre>
49+
*
50+
* <p>Using <code>SUM</code> reduction type:
51+
*
52+
* <pre>
53+
* BinaryCrossentropy bce = new BinaryCrossentropy(tf, Reduction.SUM);
54+
* Operand&lt;TFloat32&gt; result = bce.call(labels, predictions);
55+
* // produces 1.630f
56+
* </pre>
57+
*
58+
* <p>Using <code>NONE</code> reduction type:
59+
*
60+
* <pre>
61+
* BinaryCrossentropy bce = new BinaryCrossentropy(tf, Reduction.NONE);
62+
* Operand&lt;TFloat32&gt; result = bce.call(labels, predictions);
63+
* // produces [0.916f, 0.714f]
64+
* </pre>
65+
*/
66+
public class BinaryCrossentropy extends Loss {
67+
public static final boolean FROM_LOGITS_DEFAULT = false;
68+
public static final float LABEL_SMOOTHING_DEFAULT = 0.0f;
69+
70+
private final boolean fromLogits;
71+
private final float labelSmoothing;
72+
73+
/**
74+
* Creates a Binary Crossentropy Loss using {@link Class#getSimpleName()} as the loss name, {@link
75+
* #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing and a
76+
* Loss Reduction of {@link Loss#REDUCTION_DEFAULT}
77+
*
78+
* @param tf the TensorFlow Ops
79+
*/
80+
public BinaryCrossentropy(Ops tf) {
81+
this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT);
82+
}
83+
84+
/**
85+
* Creates a Binary Crossentropy loss using {@link Class#getSimpleName()} as the loss name, {@link
86+
* #FROM_LOGITS_DEFAULT} for fromLogits, and {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing
87+
*
88+
* @param tf the TensorFlow Ops
89+
* @param reduction Type of Reduction to apply to the loss.
90+
*/
91+
public BinaryCrossentropy(Ops tf, Reduction reduction) {
92+
this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction);
93+
}
94+
95+
/**
96+
* Creates a Binary Crossentropy loss using using {@link Class#getSimpleName()} as the loss name,
97+
* labelSmoothing of {@link #LABEL_SMOOTHING_DEFAULT}, a reduction of {@link
98+
* Loss#REDUCTION_DEFAULT},
99+
*
100+
* @param tf the TensorFlow Ops
101+
* @param fromLogits Whether to interpret predictions as a tensor of logit values
102+
*/
103+
public BinaryCrossentropy(Ops tf, boolean fromLogits) {
104+
this(tf, null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT);
105+
}
106+
107+
/**
108+
* Creates a Binary Crossentropy loss using labelSmoothing of {@link #LABEL_SMOOTHING_DEFAULT} a
109+
* reduction of {@link Loss#REDUCTION_DEFAULT}.
110+
*
111+
* @param tf the TensorFlow Ops
112+
* @param name the name of the loss
113+
* @param fromLogits Whether to interpret predictions as a tensor of logit values
114+
*/
115+
public BinaryCrossentropy(Ops tf, String name, boolean fromLogits) {
116+
this(tf, name, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT);
117+
}
118+
119+
/**
120+
* Creates a Binary Crossentropy loss using using {@link Class#getSimpleName()} as the loss name,
121+
* and a reduction of {@link Loss#REDUCTION_DEFAULT}.
122+
*
123+
* @param tf the TensorFlow Ops
124+
* @param fromLogits Whether to interpret predictions as a tensor of logit values
125+
* @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When &gt; 0,
126+
* compute the loss between the predicted labels and a smoothed version of the true labels,
127+
* where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing
128+
* correspond to heavier smoothing.
129+
*/
130+
public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) {
131+
this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT);
132+
}
133+
134+
/**
135+
* Creates a Binary Crossentropy loss using a reduction of {@link Loss#REDUCTION_DEFAULT}.
136+
*
137+
* @param tf the TensorFlow Ops
138+
* @param name the name of the loss
139+
* @param fromLogits Whether to interpret predictions as a tensor of logit values
140+
* @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When &gt; 0,
141+
* compute the loss between the predicted labels and a smoothed version of the true labels,
142+
* where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing
143+
* correspond to heavier smoothing.
144+
*/
145+
public BinaryCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) {
146+
this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT);
147+
}
148+
149+
/**
150+
* Creates a Binary Crossentropy loss
151+
*
152+
* @param tf the TensorFlow Ops
153+
* @param fromLogits Whether to interpret predictions as a tensor of logit values
154+
* @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When &gt; 0,
155+
* compute the loss between the predicted labels and a smoothed version of the true labels,
156+
* where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing
157+
* correspond to heavier smoothing.
158+
* @param reduction Type of Reduction to apply to the loss.
159+
*/
160+
public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing, Reduction reduction) {
161+
this(tf, null, fromLogits, labelSmoothing, reduction);
162+
}
163+
164+
/**
165+
* Creates a Binary Crossentropy loss
166+
*
167+
* @param tf the TensorFlow Ops
168+
* @param name the name of the loss
169+
* @param fromLogits Whether to interpret predictions as a tensor of logit values
170+
* @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When &gt; 0,
171+
* compute the loss between the predicted labels and a smoothed version of the true labels,
172+
* where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing
173+
* correspond to heavier smoothing.
174+
* @param reduction Type of Reduction to apply to the loss.
175+
* @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1.
176+
*/
177+
public BinaryCrossentropy(
178+
Ops tf, String name, boolean fromLogits, float labelSmoothing, Reduction reduction) {
179+
super(tf, name, reduction);
180+
if (labelSmoothing < 0 || labelSmoothing > 1)
181+
throw new IllegalArgumentException(
182+
"labelSmoothing must be >= 0. and <= 1, found " + labelSmoothing);
183+
this.fromLogits = fromLogits;
184+
this.labelSmoothing = labelSmoothing;
185+
}
186+
187+
/**
188+
* Generates an Operand that calculates the loss.
189+
*
190+
* <p>If run in Graph mode, the computation will throw {@link
191+
* org.tensorflow.exceptions.TFInvalidArgumentException} if the predictions values are outside the
192+
* range o [0. to 1.]. In Eager Mode, this call will throw {@link IllegalArgumentException}, if
193+
* the predictions values are outside the range o [0. to 1.]
194+
*
195+
* @param labels the truth values or labels
196+
* @param predictions the predictions, values must be in the range [0. to 1.] inclusive.
197+
* @param sampleWeights Optional SampleWeights acts as a coefficient for the loss. If a scalar is
198+
* provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor
199+
* of size [batch_size], then the total loss for each sample of the batch is rescaled by the
200+
* corresponding element in the SampleWeights vector. If the shape of SampleWeights is
201+
* [batch_size, d0, .. dN-1] (or can be broadcast to this shape), then each loss element of
202+
* predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss
203+
* functions reduce by 1 dimension, usually axis=-1.)
204+
* @param <T> The data type of the predictions, sampleWeights and loss.
205+
* @param <U> The data type of the labels.
206+
* @return the loss
207+
* @throws IllegalArgumentException if the predictions are outside the range [0.-1.].
208+
*/
209+
@Override
210+
public <T extends TNumber, U extends TNumber> Operand<T> call(
211+
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
212+
Operand<T> lPredictions;
213+
if (!fromLogits) {
214+
// add predictions range check for 0 - 1
215+
lPredictions =
216+
LossesHelper.rangeCheck(
217+
getTF(),
218+
"predictions range check [0-1]",
219+
predictions,
220+
cast(getTF(), getTF().constant(0), predictions.asOutput().dataType()),
221+
cast(getTF(), getTF().constant(1), predictions.asOutput().dataType()));
222+
223+
} else {
224+
lPredictions = predictions;
225+
}
226+
227+
Operand<T> losses =
228+
Losses.binaryCrossentropy(getTF(), labels, lPredictions, fromLogits, labelSmoothing);
229+
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
230+
}
231+
}

0 commit comments

Comments
 (0)