|
20 | 20 | import org.tensorflow.Operation; |
21 | 21 | import org.tensorflow.Output; |
22 | 22 | import org.tensorflow.Tensor; |
23 | | -import org.tensorflow.op.RawOp; |
24 | | -import org.tensorflow.op.Scope; |
25 | | -import org.tensorflow.op.annotation.Endpoint; |
26 | | -import org.tensorflow.op.annotation.Operator; |
27 | | -import org.tensorflow.ndarray.Shape; |
28 | | -import org.tensorflow.ndarray.buffer.BooleanDataBuffer; |
29 | | -import org.tensorflow.ndarray.buffer.ByteDataBuffer; |
30 | | -import org.tensorflow.ndarray.buffer.DataBuffer; |
31 | | -import org.tensorflow.ndarray.buffer.DoubleDataBuffer; |
32 | | -import org.tensorflow.ndarray.buffer.FloatDataBuffer; |
33 | | -import org.tensorflow.ndarray.buffer.IntDataBuffer; |
34 | | -import org.tensorflow.ndarray.buffer.LongDataBuffer; |
35 | 23 | import org.tensorflow.ndarray.BooleanNdArray; |
36 | 24 | import org.tensorflow.ndarray.ByteNdArray; |
37 | 25 | import org.tensorflow.ndarray.DoubleNdArray; |
|
40 | 28 | import org.tensorflow.ndarray.LongNdArray; |
41 | 29 | import org.tensorflow.ndarray.NdArray; |
42 | 30 | import org.tensorflow.ndarray.NdArrays; |
| 31 | +import org.tensorflow.ndarray.Shape; |
43 | 32 | import org.tensorflow.ndarray.StdArrays; |
| 33 | +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; |
| 34 | +import org.tensorflow.ndarray.buffer.ByteDataBuffer; |
| 35 | +import org.tensorflow.ndarray.buffer.DataBuffer; |
| 36 | +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; |
| 37 | +import org.tensorflow.ndarray.buffer.FloatDataBuffer; |
| 38 | +import org.tensorflow.ndarray.buffer.IntDataBuffer; |
| 39 | +import org.tensorflow.ndarray.buffer.LongDataBuffer; |
| 40 | +import org.tensorflow.op.Ops; |
| 41 | +import org.tensorflow.op.RawOp; |
| 42 | +import org.tensorflow.op.Scope; |
| 43 | +import org.tensorflow.op.annotation.Endpoint; |
| 44 | +import org.tensorflow.op.annotation.Operator; |
| 45 | +import org.tensorflow.types.TBfloat16; |
44 | 46 | import org.tensorflow.types.TBool; |
| 47 | +import org.tensorflow.types.TFloat16; |
45 | 48 | import org.tensorflow.types.TFloat32; |
46 | 49 | import org.tensorflow.types.TFloat64; |
47 | 50 | import org.tensorflow.types.TInt32; |
48 | 51 | import org.tensorflow.types.TInt64; |
49 | 52 | import org.tensorflow.types.TString; |
50 | 53 | import org.tensorflow.types.TUint8; |
| 54 | +import org.tensorflow.types.family.TNumber; |
51 | 55 | import org.tensorflow.types.family.TType; |
52 | 56 |
|
53 | 57 | /** |
@@ -1277,6 +1281,67 @@ public static Constant<TInt64> tensorOf(Scope scope, Shape shape) { |
1277 | 1281 | return vectorOf(scope, shape.asArray()); |
1278 | 1282 | } |
1279 | 1283 |
|
| 1284 | + /** |
| 1285 | + * Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not |
| 1286 | + * fit in the target type. |
| 1287 | + * |
| 1288 | + * @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating}) |
| 1289 | + * @param number the value of the tensor |
| 1290 | + * @return a constant of the passed type |
| 1291 | + * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or |
| 1292 | + * unknown. |
| 1293 | + */ |
| 1294 | + @SuppressWarnings("unchecked") |
| 1295 | + @Endpoint |
| 1296 | + public static <T extends TNumber> Constant<T> tensorOf(Scope scope, Class<T> type, Number number) { |
| 1297 | + if (type.equals(TBfloat16.class)) { |
| 1298 | + try (TBfloat16 tensor = TBfloat16.scalarOf(number.floatValue())) { |
| 1299 | + return (Constant<T>) create(scope, tensor); |
| 1300 | + } |
| 1301 | + } else if (type.equals(TFloat64.class)) { |
| 1302 | + try (TFloat64 tensor = TFloat64.scalarOf(number.doubleValue())) { |
| 1303 | + return (Constant<T>) create(scope, tensor); |
| 1304 | + } |
| 1305 | + } else if (type.equals(TFloat32.class)) { |
| 1306 | + try (TFloat32 tensor = TFloat32.scalarOf(number.floatValue())) { |
| 1307 | + return (Constant<T>) create(scope, tensor); |
| 1308 | + } |
| 1309 | + } else if (type.equals(TFloat16.class)) { |
| 1310 | + try (TFloat16 tensor = TFloat16.scalarOf(number.floatValue())) { |
| 1311 | + return (Constant<T>) create(scope, tensor); |
| 1312 | + } |
| 1313 | + } else if (type.equals(TInt64.class)) { |
| 1314 | + try (TInt64 tensor = TInt64.scalarOf(number.longValue())) { |
| 1315 | + return (Constant<T>) create(scope, tensor); |
| 1316 | + } |
| 1317 | + } else if (type.equals(TInt32.class)) { |
| 1318 | + try (TInt32 tensor = TInt32.scalarOf(number.intValue())) { |
| 1319 | + return (Constant<T>) create(scope, tensor); |
| 1320 | + } |
| 1321 | + } else if (type.equals(TUint8.class)) { |
| 1322 | + try (TUint8 tensor = TUint8.scalarOf(number.byteValue())) { |
| 1323 | + return (Constant<T>) create(scope, tensor); |
| 1324 | + } |
| 1325 | + } else { |
| 1326 | + throw new IllegalArgumentException("Tensor type " + type + " is an abstract or unknown numeric type."); |
| 1327 | + } |
| 1328 | + } |
| 1329 | + |
| 1330 | + /** |
| 1331 | + * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be |
| 1332 | + * truncated if it does not fit in the target type. |
| 1333 | + * |
| 1334 | + * @param toMatch the operand providing the target type |
| 1335 | + * @param number the value of the tensor |
| 1336 | + * @return a constant with the same type as {@code toMatch} |
| 1337 | + * @throws IllegalArgumentException if the type is unknown (which should be impossible). |
| 1338 | + * @see Ops#constant(Class, Number) |
| 1339 | + */ |
| 1340 | + @Endpoint(name = "constantOfSameType") |
| 1341 | + public static <T extends TNumber> Constant<T> tensorOfSameType(Scope scope, Operand<T> toMatch, Number number) { |
| 1342 | + return tensorOf(scope, toMatch.type(), number); |
| 1343 | + } |
| 1344 | + |
1280 | 1345 | /** |
1281 | 1346 | * Create a constant by making an immutable copy of {@code tensor}. {@code tensor} may be closed afterwards without |
1282 | 1347 | * issue. |
|
0 commit comments