Commit 23d6f0b

activations, constraints, initializers, losses, regularizers: move Ops param from CTOR to call method (#329)

* Move Ops from CTOR to call method
* JavaDoc fixes, including Dataset
* Results of running mvn spotless:apply

1 parent 19e1c8d · commit 23d6f0b
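
The net effect of the change shows up at the call site: an activation is now constructed without an Ops and receives one each time it is invoked. A minimal sketch of the new pattern (the eager-session setup and input values here are illustrative, not taken from the commit):

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;
import org.tensorflow.framework.activations.ELU;

public class EluExample {
  public static void main(String[] args) {
    // Ops.create() with no arguments runs against the default eager session.
    Ops tf = Ops.create();
    Operand<TFloat32> input = tf.constant(new float[] {-1f, 0f, 1f});

    // Before this commit: new ELU<TFloat32>(tf).call(input)
    // After this commit: the Ops is supplied at call time, so one
    // activation instance can be reused across execution environments.
    Operand<TFloat32> result = new ELU<TFloat32>().call(tf, input);
  }
}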

File tree

138 files changed: +2865 −3066 lines

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/AbstractActivation.java

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.activations;
+
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/** Abstract base class for Activations */
+public abstract class AbstractActivation<T extends TNumber> implements Activation<T> {
+
+  /** The TensorFlow Ops */
+  protected Ops tf;
+
+  /** Creates the abstract class for an AbstractActivation */
+  protected AbstractActivation() {}
+
+  /**
+   * Gets the TensorFlow Ops
+   *
+   * @return the TensorFlow Ops
+   */
+  protected Ops getTF() {
+    return this.tf;
+  }
+
+  /**
+   * Sets the TensorFlow Ops
+   *
+   * @param tf the TensorFlow Ops
+   */
+  protected void setTF(Ops tf) {
+    this.tf = tf;
+  }
+}
tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java

Lines changed: 7 additions & 38 deletions

@@ -1,4 +1,4 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -19,50 +19,19 @@
 import org.tensorflow.types.family.TNumber;
 
 /**
- * Abstract base class for Activations
+ * Interface for Activations
  *
- * <p><b>Note:</b> The {@link #tf} attribute must be set prior to invoking the call method. See
- * {@link #setTF(Ops)} and the constructor {@link #Activation(Ops)}.
- *
- * @param <T> the data type of the activation
+ * @param <T> the data type of the input and the result
  */
-public abstract class Activation<T extends TNumber> {
-
-  /** The TensorFlow Ops */
-  protected Ops tf;
-
-  /**
-   * Creates the abstract class for an Activation
-   *
-   * @param tf the TensorFlow Ops
-   */
-  protected Activation(Ops tf) {
-    this.tf = tf;
-  }
-
-  /**
-   * Sets the TensorFlow Ops
-   *
-   * @param tf the TensorFlow Ops
-   */
-  protected void setTF(Ops tf) {
-    this.tf = tf;
-  }
-
-  /**
-   * Gets the TensorFlow Ops
-   *
-   * @return the TensorFlow Ops
-   */
-  protected Ops getTF() {
-    return this.tf;
-  }
+@FunctionalInterface
+public interface Activation<T extends TNumber> {
 
   /**
    * Gets the calculation operation for the activation.
    *
+   * @param tf the TensorFlow Ops
    * @param input the input tensor
    * @return The operand for the activation
    */
-  public abstract Operand<T> call(Operand<T> input);
+  Operand<T> call(Ops tf, Operand<T> input);
 }
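
Because Activation is now a @FunctionalInterface with a single abstract call method, a custom activation can be supplied as a lambda. A hypothetical sketch, reusing the tf and input from the example above (the swish-style function is an illustration, not part of this commit):

// A custom activation as a lambda: x * sigmoid(x).
Activation<TFloat32> swish = (ops, x) -> ops.math.mul(x, ops.math.sigmoid(x));
Operand<TFloat32> y = swish.call(tf, input);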

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java

Lines changed: 12 additions & 22 deletions

@@ -14,6 +14,8 @@
 =======================================================================*/
 package org.tensorflow.framework.activations;
 
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
 import org.tensorflow.types.TBool;
@@ -44,53 +46,41 @@
  * Operand&lt;TFloat32&gt; result = elu.call(input);
  * </pre>
  *
- * @param <T> the data type of the activation
  * @see <a href="https://arxiv.org/abs/1511.07289">Clevert et al, 2016, Fast and Accurate Deep
  *     Network Learning by Exponential Linear Units (ELUs)</a>
  */
-public class ELU<T extends TFloating> extends Activation<T> {
+public class ELU<T extends TFloating> extends AbstractActivation<T> {
 
   private static final double ALPHA_DEFAULT = 1.0;
 
   /** A scalar, slope of negative section. */
   private final double alpha;
 
-  /**
-   * Creates a new ELU with alpha={@link #ALPHA_DEFAULT}.
-   *
-   * @param tf the TensorFlow Ops
-   */
-  public ELU(Ops tf) {
-    this(tf, ALPHA_DEFAULT);
+  /** Creates a new ELU with alpha={@link #ALPHA_DEFAULT}. */
+  public ELU() {
+    this(ALPHA_DEFAULT);
   }
 
   /**
    * Creates a new ELU
    *
-   * @param tf the TensorFlow Ops
    * @param alpha A scalar, slope of negative section. It controls the value to which an ELU
    *     saturates for negative net inputs.
    */
-  public ELU(Ops tf, double alpha) {
-    super(tf);
+  public ELU(double alpha) {
+    super();
     this.alpha = alpha;
   }
 
-  /**
-   * Gets the calculation operation for the activation.
-   *
-   * @param input the input tensor
-   * @return The operand for the activation
-   */
+  /** {@inheritDoc} */
   @Override
-  public Operand<T> call(Operand<T> input) {
-
+  public Operand<T> call(Ops tf, Operand<T> input) {
     Operand<T> result = tf.nn.elu(input);
     if (alpha == 1.0) return result;
     else {
       Class<T> inputType = input.type();
-      Operand<T> y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), inputType));
-      Operand<TBool> cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), inputType));
+      Operand<T> y = tf.math.mul(result, cast(tf, tf.constant(alpha), inputType));
+      Operand<TBool> cond = tf.math.greater(result, cast(tf, tf.constant(0), inputType));
       return tf.select(cond, result, y);
     }
   }
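
When alpha is not 1.0, the select above keeps positive outputs of tf.nn.elu unchanged and scales only the negative branch, since tf.nn.elu already computes exp(x) - 1 for negative inputs. A rough numeric check, reusing the setup from the first example (values are approximate and not taken from the commit):

// ELU with alpha = 0.5: for x < 0 the result is 0.5 * (exp(x) - 1);
// for x >= 0 the input passes through unchanged.
Operand<TFloat32> out =
    new ELU<TFloat32>(0.5).call(tf, tf.constant(new float[] {-1f, 2f}));
// Expected, approximately: [-0.316f, 2f], since 0.5 * (e^-1 - 1) ≈ -0.316.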

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java

Lines changed: 6 additions & 17 deletions

@@ -30,28 +30,17 @@
  * Operand&lt;TFloat32&gt; result = exp.call(input);
  * // result is [0.04978707f, 0.36787945f, 1.f, 2.7182817f, 20.085537f]
  * </pre>
- *
- * @param <T> the data type of the activation
  */
-public class Exponential<T extends TFloating> extends Activation<T> {
+public class Exponential<T extends TFloating> extends AbstractActivation<T> {
 
-  /**
-   * Creates an Exponential activation.
-   *
-   * @param tf the TensorFlow Ops
-   */
-  public Exponential(Ops tf) {
-    super(tf);
+  /** Creates an Exponential activation. */
+  public Exponential() {
+    super();
   }
 
-  /**
-   * Calculates the Exponential activation.
-   *
-   * @param input the input tensor
-   * @return an Operand for the exponential activation: <code>exp(x)</code>.
-   */
+  /** {@inheritDoc} */
   @Override
-  public Operand<T> call(Operand<T> input) {
+  public Operand<T> call(Ops tf, Operand<T> input) {
     return tf.math.exp(input);
   }
 }

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java

Lines changed: 11 additions & 20 deletions

@@ -14,6 +14,8 @@
 =======================================================================*/
 package org.tensorflow.framework.activations;
 
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
 import org.tensorflow.types.family.TFloating;
@@ -40,34 +42,23 @@
  * Operand&lt;TFloat32&gt; result = hardSigmoid.call(input);
  * // result is [0.f , 0.3f, 0.5f, 0.7f, 1.f]
  * </pre>
- *
- * @param <T> the data type of the result
  */
-public class HardSigmoid<T extends TFloating> extends Activation<T> {
+public class HardSigmoid<T extends TFloating> extends AbstractActivation<T> {
 
-  /**
-   * Creates Hard sigmoid activation.
-   *
-   * @param tf the TensorFlow Ops
-   */
-  public HardSigmoid(Ops tf) {
-    super(tf);
+  /** Creates Hard sigmoid activation. */
+  public HardSigmoid() {
+    super();
   }
 
-  /**
-   * Gets the calculation operation for the activation.
-   *
-   * @param input the input tensor
-   * @return The operand for the activation
-   */
+  /** {@inheritDoc} */
   @Override
-  public Operand<T> call(Operand<T> input) {
+  public Operand<T> call(Ops tf, Operand<T> input) {
     Class<T> inputType = input.type();
-    Operand<T> point2 = tf.dtypes.cast(tf.constant(0.2), inputType);
-    Operand<T> point5 = tf.dtypes.cast(tf.constant(0.5), inputType);
+    Operand<T> point2 = cast(tf, tf.constant(0.2), inputType);
+    Operand<T> point5 = cast(tf, tf.constant(0.5), inputType);
 
     Operand<T> x = tf.math.add(tf.math.mul(input, point2), point5);
     return tf.clipByValue(
-        x, tf.dtypes.cast(tf.constant(0), inputType), tf.dtypes.cast(tf.constant(1), inputType));
+        x, cast(tf, tf.constant(0), inputType), cast(tf, tf.constant(1), inputType));
   }
 }
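
The body above implements a piecewise-linear approximation of the sigmoid, clip(0.2 * x + 0.5, 0, 1). As a quick sanity check against the Javadoc example (plain arithmetic, not part of the commit): for x = 1.0f, 0.2 * 1.0 + 0.5 = 0.7, which matches the 0.7f entry in the expected result.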

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java

Lines changed: 7 additions & 11 deletions

@@ -19,9 +19,9 @@
 import org.tensorflow.types.family.TNumber;
 
 /**
- * Linear activation function (pass-through).
+ * Linear activation function (pass-through).
  *
- * <p>The linear activation returns its input. It is also known as the Identity activation function.</p>
+ * <p>The linear activation returns its input. It is also known as the Identity activation function.
  *
  * <p>For example:
  *
@@ -33,20 +33,16 @@
  * // result is [-3.0f,-1.0f, 0.0f,1.0f,3.0f]
  * </pre>
  */
-public class Linear<U extends TNumber> extends Activation<U> {
+public class Linear<U extends TNumber> extends AbstractActivation<U> {
 
-  /**
-   * Creates a linear activation.
-   *
-   * @param tf the TensorFlow Ops
-   */
-  public Linear(Ops tf) {
-    super(tf);
+  /** Creates a linear activation. */
+  public Linear() {
+    super();
   }
 
   /** {@inheritDoc} */
   @Override
-  public Operand<U> call(Operand<U> input) {
+  public Operand<U> call(Ops tf, Operand<U> input) {
     return input;
   }
 }