Skip to content

Commit 3044d4b

Browse files
authored
Number constant ops (#210)
Signed-off-by: Ryan Nett <rnett@calpoly.edu>
1 parent 52bbff3 commit 3044d4b

File tree

3 files changed

+143
-22
lines changed

3 files changed

+143
-22
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1831,6 +1831,19 @@ public Constant<TInt32> constant(Shape shape, IntDataBuffer data) {
18311831
return Constant.tensorOf(scope, shape, data);
18321832
}
18331833

1834+
/**
1835+
* Creates a scalar of {@code type}, with the value of {@code number}.
1836+
* {@code number} may be truncated if it does not fit in the target type.
1837+
*
1838+
* @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating})
1839+
* @param number the value of the tensor
1840+
* @return a constant of the passed type
1841+
* @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or unknown.
1842+
*/
1843+
public <T extends TNumber> Constant<T> constant(Class<T> type, Number number) {
1844+
return Constant.tensorOf(scope, type, number);
1845+
}
1846+
18341847
/**
18351848
* Create a {@link TString} constant with data from the given buffer, using the given encoding.
18361849
*
@@ -1876,6 +1889,20 @@ public <T extends TType> Constant<T> constantOf(T tensor) {
18761889
return Constant.create(scope, tensor);
18771890
}
18781891

1892+
/**
1893+
* Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}.
1894+
* {@code number} may be truncated if it does not fit in the target type.
1895+
*
1896+
* @param toMatch the operand providing the target type
1897+
* @param number the value of the tensor
1898+
* @return a constant with the same type as {@code toMatch}
1899+
* @see Ops#constant(Class, Number)
1900+
* @throws IllegalArgumentException if the type is unknown (which should be impossible).
1901+
*/
1902+
public <T extends TNumber> Constant<T> constantOfSameType(Operand<T> toMatch, Number number) {
1903+
return Constant.tensorOfSameType(scope, toMatch, number);
1904+
}
1905+
18791906
/**
18801907
* This op consumes a lock created by `MutexLock`.
18811908
* <p>

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java

Lines changed: 77 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,6 @@
2020
import org.tensorflow.Operation;
2121
import org.tensorflow.Output;
2222
import org.tensorflow.Tensor;
23-
import org.tensorflow.op.RawOp;
24-
import org.tensorflow.op.Scope;
25-
import org.tensorflow.op.annotation.Endpoint;
26-
import org.tensorflow.op.annotation.Operator;
27-
import org.tensorflow.ndarray.Shape;
28-
import org.tensorflow.ndarray.buffer.BooleanDataBuffer;
29-
import org.tensorflow.ndarray.buffer.ByteDataBuffer;
30-
import org.tensorflow.ndarray.buffer.DataBuffer;
31-
import org.tensorflow.ndarray.buffer.DoubleDataBuffer;
32-
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
33-
import org.tensorflow.ndarray.buffer.IntDataBuffer;
34-
import org.tensorflow.ndarray.buffer.LongDataBuffer;
3523
import org.tensorflow.ndarray.BooleanNdArray;
3624
import org.tensorflow.ndarray.ByteNdArray;
3725
import org.tensorflow.ndarray.DoubleNdArray;
@@ -40,14 +28,30 @@
4028
import org.tensorflow.ndarray.LongNdArray;
4129
import org.tensorflow.ndarray.NdArray;
4230
import org.tensorflow.ndarray.NdArrays;
31+
import org.tensorflow.ndarray.Shape;
4332
import org.tensorflow.ndarray.StdArrays;
33+
import org.tensorflow.ndarray.buffer.BooleanDataBuffer;
34+
import org.tensorflow.ndarray.buffer.ByteDataBuffer;
35+
import org.tensorflow.ndarray.buffer.DataBuffer;
36+
import org.tensorflow.ndarray.buffer.DoubleDataBuffer;
37+
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
38+
import org.tensorflow.ndarray.buffer.IntDataBuffer;
39+
import org.tensorflow.ndarray.buffer.LongDataBuffer;
40+
import org.tensorflow.op.Ops;
41+
import org.tensorflow.op.RawOp;
42+
import org.tensorflow.op.Scope;
43+
import org.tensorflow.op.annotation.Endpoint;
44+
import org.tensorflow.op.annotation.Operator;
45+
import org.tensorflow.types.TBfloat16;
4446
import org.tensorflow.types.TBool;
47+
import org.tensorflow.types.TFloat16;
4548
import org.tensorflow.types.TFloat32;
4649
import org.tensorflow.types.TFloat64;
4750
import org.tensorflow.types.TInt32;
4851
import org.tensorflow.types.TInt64;
4952
import org.tensorflow.types.TString;
5053
import org.tensorflow.types.TUint8;
54+
import org.tensorflow.types.family.TNumber;
5155
import org.tensorflow.types.family.TType;
5256

5357
/**
@@ -1277,6 +1281,67 @@ public static Constant<TInt64> tensorOf(Scope scope, Shape shape) {
12771281
return vectorOf(scope, shape.asArray());
12781282
}
12791283

1284+
/**
1285+
* Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not
1286+
* fit in the target type.
1287+
*
1288+
* @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating})
1289+
* @param number the value of the tensor
1290+
* @return a constant of the passed type
1291+
* @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or
1292+
* unknown.
1293+
*/
1294+
@SuppressWarnings("unchecked")
1295+
@Endpoint
1296+
public static <T extends TNumber> Constant<T> tensorOf(Scope scope, Class<T> type, Number number) {
1297+
if (type.equals(TBfloat16.class)) {
1298+
try (TBfloat16 tensor = TBfloat16.scalarOf(number.floatValue())) {
1299+
return (Constant<T>) create(scope, tensor);
1300+
}
1301+
} else if (type.equals(TFloat64.class)) {
1302+
try (TFloat64 tensor = TFloat64.scalarOf(number.doubleValue())) {
1303+
return (Constant<T>) create(scope, tensor);
1304+
}
1305+
} else if (type.equals(TFloat32.class)) {
1306+
try (TFloat32 tensor = TFloat32.scalarOf(number.floatValue())) {
1307+
return (Constant<T>) create(scope, tensor);
1308+
}
1309+
} else if (type.equals(TFloat16.class)) {
1310+
try (TFloat16 tensor = TFloat16.scalarOf(number.floatValue())) {
1311+
return (Constant<T>) create(scope, tensor);
1312+
}
1313+
} else if (type.equals(TInt64.class)) {
1314+
try (TInt64 tensor = TInt64.scalarOf(number.longValue())) {
1315+
return (Constant<T>) create(scope, tensor);
1316+
}
1317+
} else if (type.equals(TInt32.class)) {
1318+
try (TInt32 tensor = TInt32.scalarOf(number.intValue())) {
1319+
return (Constant<T>) create(scope, tensor);
1320+
}
1321+
} else if (type.equals(TUint8.class)) {
1322+
try (TUint8 tensor = TUint8.scalarOf(number.byteValue())) {
1323+
return (Constant<T>) create(scope, tensor);
1324+
}
1325+
} else {
1326+
throw new IllegalArgumentException("Tensor type " + type + " is an abstract or unknown numeric type.");
1327+
}
1328+
}
1329+
1330+
/**
1331+
* Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be
1332+
* truncated if it does not fit in the target type.
1333+
*
1334+
* @param toMatch the operand providing the target type
1335+
* @param number the value of the tensor
1336+
* @return a constant with the same type as {@code toMatch}
1337+
* @throws IllegalArgumentException if the type is unknown (which should be impossible).
1338+
* @see Ops#constant(Class, Number)
1339+
*/
1340+
@Endpoint(name = "constantOfSameType")
1341+
public static <T extends TNumber> Constant<T> tensorOfSameType(Scope scope, Operand<T> toMatch, Number number) {
1342+
return tensorOf(scope, toMatch.type(), number);
1343+
}
1344+
12801345
/**
12811346
* Create a constant by making an immutable copy of {@code tensor}. {@code tensor} may be closed afterwards without
12821347
* issue.

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,35 +18,40 @@
1818
import static org.junit.jupiter.api.Assertions.assertEquals;
1919

2020
import java.io.IOException;
21-
2221
import org.junit.jupiter.api.Test;
2322
import org.tensorflow.AutoCloseableList;
2423
import org.tensorflow.EagerSession;
2524
import org.tensorflow.Graph;
25+
import org.tensorflow.Operand;
2626
import org.tensorflow.Session;
2727
import org.tensorflow.Tensor;
28-
import org.tensorflow.op.Ops;
29-
import org.tensorflow.op.Scope;
28+
import org.tensorflow.ndarray.DoubleNdArray;
29+
import org.tensorflow.ndarray.FloatNdArray;
30+
import org.tensorflow.ndarray.IntNdArray;
31+
import org.tensorflow.ndarray.LongNdArray;
32+
import org.tensorflow.ndarray.NdArray;
33+
import org.tensorflow.ndarray.NdArrays;
3034
import org.tensorflow.ndarray.Shape;
3135
import org.tensorflow.ndarray.buffer.DataBuffer;
3236
import org.tensorflow.ndarray.buffer.DataBuffers;
3337
import org.tensorflow.ndarray.buffer.DoubleDataBuffer;
3438
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
3539
import org.tensorflow.ndarray.buffer.IntDataBuffer;
3640
import org.tensorflow.ndarray.buffer.LongDataBuffer;
37-
import org.tensorflow.ndarray.DoubleNdArray;
38-
import org.tensorflow.ndarray.FloatNdArray;
39-
import org.tensorflow.ndarray.IntNdArray;
40-
import org.tensorflow.ndarray.LongNdArray;
41-
import org.tensorflow.ndarray.NdArray;
42-
import org.tensorflow.ndarray.NdArrays;
41+
import org.tensorflow.op.Ops;
42+
import org.tensorflow.op.Scope;
43+
import org.tensorflow.types.TBfloat16;
44+
import org.tensorflow.types.TFloat16;
4345
import org.tensorflow.types.TFloat32;
4446
import org.tensorflow.types.TFloat64;
4547
import org.tensorflow.types.TInt32;
4648
import org.tensorflow.types.TInt64;
4749
import org.tensorflow.types.TString;
50+
import org.tensorflow.types.TUint8;
51+
import org.tensorflow.types.family.TNumber;
4852

4953
public class ConstantTest {
54+
5055
private static final float EPSILON = 1e-7f;
5156

5257
@Test
@@ -56,7 +61,7 @@ public void createInts() {
5661
IntNdArray array = NdArrays.wrap(shape, buffer);
5762

5863
try (Graph g = new Graph();
59-
Session sess = new Session(g)) {
64+
Session sess = new Session(g)) {
6065
Scope scope = new Scope(g);
6166
Constant<TInt32> op1 = Constant.tensorOf(scope, shape, buffer);
6267
Constant<TInt32> op2 = Constant.tensorOf(scope, array);
@@ -164,4 +169,28 @@ public void createFromTensorsInEagerMode() throws IOException {
164169
assertEquals(NdArrays.vectorOf(1, 2, 3, 4), c1.asTensor());
165170
}
166171
}
172+
173+
private static void testCreateFromNumber(Ops tf, Class<? extends TNumber> type) {
174+
Operand<? extends TNumber> constant = tf.constant(type, 10);
175+
assertEquals(type, constant.type());
176+
177+
try (TFloat64 t = tf.dtypes.cast(constant, TFloat64.class).asTensor()) {
178+
assertEquals(10.0, t.getDouble());
179+
}
180+
}
181+
182+
@Test
183+
public void createFromNumber() {
184+
try (EagerSession s = EagerSession.create()) {
185+
Ops tf = Ops.create(s);
186+
187+
testCreateFromNumber(tf, TBfloat16.class);
188+
testCreateFromNumber(tf, TFloat64.class);
189+
testCreateFromNumber(tf, TFloat32.class);
190+
testCreateFromNumber(tf, TFloat16.class);
191+
testCreateFromNumber(tf, TInt64.class);
192+
testCreateFromNumber(tf, TInt32.class);
193+
testCreateFromNumber(tf, TUint8.class);
194+
}
195+
}
167196
}

0 commit comments

Comments
 (0)