Skip to content

Commit 4802fd2

Browse files
Generic cleanup Metrics and Losses (#204)
* Simplify generic parameters across losses and metrics. * Reformat code * Change order of TrainOps and QuantizationOps. For some reason, when I build it reverses these 2 from master's version. * Fix LossMetric to change abstract "call" method to use generic parameter for predictions instead of <T>. * Reformat code, fix javadoc * Remove trailing character Co-authored-by: Karl Lessard <karl.lessard@gmail.com>
1 parent 8b36d7e commit 4802fd2

File tree

68 files changed

+595
-555
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+595
-555
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -345,10 +345,10 @@ public final class Ops {
345345

346346
public final SignalOps signal;
347347

348-
public final TrainOps train;
349-
350348
public final QuantizationOps quantization;
351349

350+
public final TrainOps train;
351+
352352
private final Scope scope;
353353

354354
private Ops(Scope scope) {
@@ -370,8 +370,8 @@ private Ops(Scope scope) {
370370
math = new MathOps(this);
371371
audio = new AudioOps(this);
372372
signal = new SignalOps(this);
373-
train = new TrainOps(this);
374373
quantization = new QuantizationOps(this);
374+
train = new TrainOps(this);
375375
}
376376

377377
/**

tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@
6262
* VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter.
6363
* <p>For a GlorotUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM}
6464
* for the distribution parameter.
65-
* <p></p>
6665
*
6766
* @param <T> The TType for the call operation
6867
* @see VarianceScaling.Distribution

tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
* VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter.
5858
* <p>For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM}
5959
* for the distribution parameter.
60-
* <p></p>
6160
*
6261
* @param <T> The TType for the call operation
6362
* @see <a

tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,12 @@ public BinaryCrossentropy(
202202
* predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss
203203
* functions reduce by 1 dimension, usually axis=-1.)
204204
* @param <T> The data type of the predictions, sampleWeights and loss.
205-
* @param <U> The data type of the labels.
206205
* @return the loss
207206
* @throws IllegalArgumentException if the predictions are outside the range [0.-1.].
208207
*/
209208
@Override
210-
public <T extends TNumber, U extends TNumber> Operand<T> call(
211-
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
209+
public <T extends TNumber> Operand<T> call(
210+
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
212211
Operand<T> lPredictions;
213212
if (!fromLogits) {
214213
// add predictions range check for 0 - 1

tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -154,24 +154,26 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) {
154154
*
155155
* @param tf the TensorFlow Ops
156156
* @param fromLogits Whether to interpret predictions as a tensor of logit values
157-
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are smoothed, meaning the
158-
* confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2</code> means that we will use a
159-
* value of <code>0.1</code> for label <code>0</code> and <code>0.9</code> for label <code>1</code>
157+
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are
158+
* smoothed, meaning the confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2
159+
* </code> means that we will use a value of <code>0.1</code> for label <code>0</code> and
160+
* <code>0.9</code> for label <code>1</code>
160161
*/
161162
public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) {
162163
this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS);
163164
}
164165

165166
/**
166-
* Creates a categorical cross entropy Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT},
167-
* and a channel axis of {@link #DEFAULT_AXIS}
167+
* Creates a categorical cross entropy Loss using a Loss Reduction of {@link
168+
* Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS}
168169
*
169170
* @param tf the TensorFlow Ops
170171
* @param name the name of this loss
171172
* @param fromLogits Whether to interpret predictions as a tensor of logit values
172-
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are smoothed, meaning the
173-
* confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2</code> means that we will use a
174-
* value of <code>0.1</code> for label <code>0</code> and <code>0.9</code> for label <code>1</code>
173+
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are
174+
* smoothed, meaning the confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2
175+
* </code> means that we will use a value of <code>0.1</code> for label <code>0</code> and
176+
* <code>0.9</code> for label <code>1</code>
175177
*/
176178
public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) {
177179
this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS);
@@ -183,9 +185,10 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float la
183185
*
184186
* @param tf the TensorFlow Ops
185187
* @param fromLogits Whether to interpret predictions as a tensor of logit values
186-
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are smoothed, meaning the
187-
* confidence on label values are relaxed. e.g. <code>x=0.2</code> means that we will use a
188-
* value of <code>0.1</code> for label <code>0</code> and <code>0.9</code> for label <code>1</code>
188+
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are
189+
* smoothed, meaning the confidence on label values are relaxed. e.g. <code>x=0.2</code> means
190+
* that we will use a value of <code>0.1</code> for label <code>0</code> and <code>0.9</code>
191+
* for label <code>1</code>
189192
* @param reduction Type of Reduction to apply to loss.
190193
*/
191194
public CategoricalCrossentropy(
@@ -199,13 +202,14 @@ public CategoricalCrossentropy(
199202
* @param tf the TensorFlow Ops
200203
* @param name the name of this loss
201204
* @param fromLogits Whether to interpret predictions as a tensor of logit values
202-
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are smoothed, meaning the
203-
* confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2</code> means that we will use a
204-
* value of <code>0.1</code> for label <code>0</code> and <code>0.9</code> for label <code>1</code>
205+
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are
206+
* smoothed, meaning the confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2
207+
* </code> means that we will use a value of <code>0.1</code> for label <code>0</code> and
208+
* <code>0.9</code> for label <code>1</code>
205209
* @param reduction Type of Reduction to apply to loss.
206210
* @param axis The channels axis. <code>axis=-1</code> corresponds to data format "Channels Last"
207-
* and <code>axis=1</code> corresponds to data format "Channels First".
208-
* {@link Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST}
211+
* and <code>axis=1</code> corresponds to data format "Channels First". {@link
212+
* Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST}
209213
* @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1.
210214
*/
211215
public CategoricalCrossentropy(
@@ -242,13 +246,12 @@ public CategoricalCrossentropy(
242246
* predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss
243247
* functions reduce by 1 dimension, usually axis=-1.)
244248
* @param <T> The data type of the predictions, sampleWeights and loss.
245-
* @param <U> The data type of the labels.
246249
* @return the loss
247250
* @throws IllegalArgumentException if the predictions are outside the range [0.-1.].
248251
*/
249252
@Override
250-
public <T extends TNumber, U extends TNumber> Operand<T> call(
251-
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
253+
public <T extends TNumber> Operand<T> call(
254+
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
252255
Operand<T> lPredictions;
253256
if (!fromLogits) {
254257
// add predictions range check for 0 - 1

tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
* <p><code>loss = maximum(neg - pos + 1, 0)</code> where <code>neg=maximum((1-labels)*predictions)
2626
* </code> and <code>pos=sum(labels*predictions)</code>
2727
*
28-
* <p><code>labels</code> values are expected to be 0 or 1.</p>
28+
* <p><code>labels</code> values are expected to be 0 or 1.
2929
*
3030
* <p>Standalone usage:
3131
*
@@ -99,8 +99,8 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) {
9999

100100
/** {@inheritDoc} */
101101
@Override
102-
public <T extends TNumber, U extends TNumber> Operand<T> call(
103-
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
102+
public <T extends TNumber> Operand<T> call(
103+
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
104104
Operand<T> losses = Losses.categoricalHinge(getTF(), labels, predictions);
105105
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
106106
}

tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@
2222
/**
2323
* Computes the cosine similarity between labels and predictions.
2424
*
25-
* <p>Note that it is a number between <code>-1</code> and <code>1</code>. When it is a negative number between <code>-1</code> and <code>0</code>, <code>0</code>
26-
* indicates orthogonality and values closer to <code>-1</code>indicate greater similarity. The values closer to
27-
* <code>1</code> indicate greater dissimilarity. This makes it usable as a loss function in a setting where you
28-
* try to maximize the proximity between predictions and targets. If either <code>labels</code> or <code>predictions</code> is
29-
* a zero vector, cosine similarity will be <code>0</code> regardless of the proximity between predictions and
30-
* targets.
25+
* <p>Note that it is a number between <code>-1</code> and <code>1</code>. When it is a negative
26+
* number between <code>-1</code> and <code>0</code>, <code>0</code> indicates orthogonality and
27+
* values closer to <code>-1</code> indicate greater similarity. The values closer to <code>1</code>
28+
* indicate greater dissimilarity. This makes it usable as a loss function in a setting where you
29+
* try to maximize the proximity between predictions and targets. If either <code>labels</code> or
30+
* <code>predictions</code> is a zero vector, cosine similarity will be <code>0</code> regardless of
31+
* the proximity between predictions and targets.
3132
*
3233
* <p><code>loss = -sum(l2Norm(labels) * l2Norm(predictions))</code>
3334
*
@@ -71,7 +72,7 @@ public class CosineSimilarity extends Loss {
7172
public static final int DEFAULT_AXIS = -1;
7273
public static final Reduction DEFAULT_REDUCTION = Reduction.AUTO;
7374

74-
private final int axis;
75+
private final int[] axis;
7576

7677
/**
7778
* Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, an axis
@@ -107,6 +108,17 @@ public CosineSimilarity(Ops tf, int axis) {
107108

108109
this(tf, null, axis, DEFAULT_REDUCTION);
109110
}
111+
/**
112+
* Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, and a
113+
* Loss Reduction of {@link #DEFAULT_REDUCTION}
114+
*
115+
* @param tf the TensorFlow Ops
116+
* @param axis The dimension along which the cosine similarity is computed.
117+
*/
118+
public CosineSimilarity(Ops tf, int[] axis) {
119+
120+
this(tf, null, axis, DEFAULT_REDUCTION);
121+
}
110122

111123
/**
112124
* Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION}
@@ -120,6 +132,18 @@ public CosineSimilarity(Ops tf, String name, int axis) {
120132
this(tf, name, axis, DEFAULT_REDUCTION);
121133
}
122134

135+
/**
136+
* Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION}
137+
*
138+
* @param tf the TensorFlow Ops
139+
* @param name the name of the loss
140+
* @param axis The dimension along which the cosine similarity is computed.
141+
*/
142+
public CosineSimilarity(Ops tf, String name, int[] axis) {
143+
144+
this(tf, name, axis, DEFAULT_REDUCTION);
145+
}
146+
123147
/**
124148
* Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name and an
125149
* axis of {@link #DEFAULT_AXIS}
@@ -153,6 +177,18 @@ public CosineSimilarity(Ops tf, String name, Reduction reduction) {
153177
*/
154178
public CosineSimilarity(Ops tf, int axis, Reduction reduction) {
155179

180+
this(tf, null, new int[] {axis}, reduction);
181+
}
182+
183+
/**
184+
* Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name
185+
*
186+
* @param tf the TensorFlow Ops
187+
* @param axis The dimension along which the cosine similarity is computed.
188+
* @param reduction Type of Reduction to apply to the loss.
189+
*/
190+
public CosineSimilarity(Ops tf, int[] axis, Reduction reduction) {
191+
156192
this(tf, null, axis, reduction);
157193
}
158194

@@ -165,15 +201,28 @@ public CosineSimilarity(Ops tf, int axis, Reduction reduction) {
165201
* @param reduction Type of Reduction to apply to the loss.
166202
*/
167203
public CosineSimilarity(Ops tf, String name, int axis, Reduction reduction) {
204+
this(tf, name, new int[] {axis}, reduction);
205+
}
206+
207+
/**
208+
* Creates a Cosine Similarity Loss
209+
*
210+
* @param tf the TensorFlow Ops
211+
* @param name the name of the loss
212+
* @param axis The dimension along which the cosine similarity is computed.
213+
* @param reduction Type of Reduction to apply to the loss.
214+
*/
215+
public CosineSimilarity(Ops tf, String name, int[] axis, Reduction reduction) {
168216
super(tf, name, reduction);
169217
this.axis = axis;
170218
}
171219

172220
/** {@inheritDoc} */
173221
@Override
174-
public <T extends TNumber, U extends TNumber> Operand<T> call(
175-
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
222+
public <T extends TNumber> Operand<T> call(
223+
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
176224
Operand<T> losses = Losses.cosineSimilarity(getTF(), labels, predictions, axis);
225+
losses = tf.math.neg(losses);
177226
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
178227
}
179228
}

tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@
1818
import org.tensorflow.framework.losses.impl.LossesHelper;
1919
import org.tensorflow.op.Ops;
2020
import org.tensorflow.types.family.TNumber;
21+
2122
import static org.tensorflow.framework.utils.CastHelper.cast;
2223

2324
/**
2425
* Computes the hinge loss between labels and predictions.
2526
*
26-
* <p><code>loss = maximum(1 - labels * predictions, 0)</code></p>.
27+
* <p><code>loss = maximum(1 - labels * predictions, 0)</code>.
2728
*
28-
* <p><code>labels</code> values are expected to be -1 or 1.
29-
* If binary (0 or 1) labels are provided, they will be converted to -1 or 1.</p>
29+
* <p><code>labels</code> values are expected to be -1 or 1. If binary (0 or 1) labels are provided,
30+
* they will be converted to -1 or 1.
3031
*
3132
* <p>Standalone usage:
3233
*
@@ -106,7 +107,7 @@ public Hinge(Ops tf, String name, Reduction reduction) {
106107
* label values are not in the set [-1., 0., 1.].
107108
*
108109
* @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be
109-
* -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1.
110+
* -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1.
110111
* @param predictions the predictions, values must be in the range [0. to 1.] inclusive.
111112
* @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is
112113
* provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor
@@ -116,21 +117,19 @@ public Hinge(Ops tf, String name, Reduction reduction) {
116117
* predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss
117118
* functions reduce by 1 dimension, usually axis=-1.)
118119
* @param <T> The data type of the predictions, sampleWeights and loss.
119-
* @param <U> The data type of the labels.
120120
* @return the loss
121121
* @throws IllegalArgumentException if the predictions are outside the range [0.-1.].
122122
*/
123123
@Override
124-
public <T extends TNumber, U extends TNumber> Operand<T> call(
125-
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
126-
@SuppressWarnings("unchecked")
127-
Operand<T> tLabels = predictions.type() == labels.type() ?
128-
(Operand<T>)labels : cast(tf, labels, predictions.type());
129-
tLabels = LossesHelper.valueCheck(
124+
public <T extends TNumber> Operand<T> call(
125+
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
126+
Operand<T> tLabels = cast(tf, labels, predictions.type());
127+
tLabels =
128+
LossesHelper.valueCheck(
130129
getTF(),
131130
"labels value check [-1, 0, 1]",
132131
tLabels,
133-
cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type()));
132+
cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type()));
134133
Operand<T> losses = Losses.hinge(getTF(), tLabels, predictions);
135134
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
136135
}

tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ public Huber(Ops tf) {
8989
* Loss#REDUCTION_DEFAULT}
9090
*
9191
* @param tf the TensorFlow Ops
92+
* @param name the name of the loss, if null then {@link Class#getSimpleName()} is used.
9293
*/
9394
public Huber(Ops tf, String name) {
9495
this(tf, name, DELTA_DEFAULT, Reduction.AUTO);
@@ -109,6 +110,7 @@ public Huber(Ops tf, Reduction reduction) {
109110
* Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta
110111
*
111112
* @param tf the TensorFlow Ops
113+
* @param name the name of the loss, if null then {@link Class#getSimpleName()} is used.
112114
* @param reduction Type of Reduction to apply to the loss.
113115
*/
114116
public Huber(Ops tf, String name, Reduction reduction) {
@@ -119,7 +121,7 @@ public Huber(Ops tf, String name, Reduction reduction) {
119121
* Creates a Huber Loss
120122
*
121123
* @param tf the TensorFlow Ops
122-
* @param name the name of the loss
124+
* @param name the name of the loss, if null then {@link Class#getSimpleName()} is used.
123125
* @param delta the point where the Huber loss function changes from quadratic to linear.
124126
* @param reduction Type of Reduction to apply to the loss.
125127
*/
@@ -130,8 +132,8 @@ public Huber(Ops tf, String name, float delta, Reduction reduction) {
130132

131133
/** {@inheritDoc} */
132134
@Override
133-
public <T extends TNumber, U extends TNumber> Operand<T> call(
134-
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
135+
public <T extends TNumber> Operand<T> call(
136+
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
135137
Operand<T> losses = Losses.huber(getTF(), labels, predictions, delta);
136138
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
137139
}

tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ public KLDivergence(Ops tf, String name, Reduction reduction) {
9999

100100
/** {@inheritDoc} */
101101
@Override
102-
public <T extends TNumber, U extends TNumber> Operand<T> call(
103-
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
102+
public <T extends TNumber> Operand<T> call(
103+
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
104104
Operand<T> losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions);
105105
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
106106
}

0 commit comments

Comments
 (0)