Skip to content

Commit 94f5b15

Browse files
authored
Load TF library before computing TString size (#322)
1 parent a24b8ca commit 94f5b15

File tree

12 files changed

+72
-36
lines changed

12 files changed

+72
-36
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,11 @@ private static void delete(TFE_Context handle) {
404404
}
405405

406406
static {
407-
TensorFlow.init();
407+
try {
408+
// Ensure that TensorFlow native library and classes are ready to be used
409+
Class.forName("org.tensorflow.TensorFlow");
410+
} catch (ClassNotFoundException e) {
411+
throw new RuntimeException(e);
412+
}
408413
}
409414
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,11 @@ private static SaverDef addVariableSaver(Graph graph) {
10701070
}
10711071

10721072
static {
1073-
TensorFlow.init();
1073+
try {
1074+
// Ensure that TensorFlow native library and classes are ready to be used
1075+
Class.forName("org.tensorflow.TensorFlow");
1076+
} catch (ClassNotFoundException e) {
1077+
throw new RuntimeException(e);
1078+
}
10741079
}
10751080
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,11 @@ private static long[] shape(TF_Tensor handle) {
228228
private ByteDataBuffer buffer = null;
229229

230230
static {
231-
TensorFlow.init();
231+
try {
232+
// Ensure that TensorFlow native library and classes are ready to be used
233+
Class.forName("org.tensorflow.TensorFlow");
234+
} catch (ClassNotFoundException e) {
235+
throw new RuntimeException(e);
236+
}
232237
}
233238
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,11 @@ private static void validateTags(String[] tags) {
435435
}
436436

437437
static {
438-
TensorFlow.init();
438+
try {
439+
// Ensure that TensorFlow native library and classes are ready to be used
440+
Class.forName("org.tensorflow.TensorFlow");
441+
} catch (ClassNotFoundException e) {
442+
throw new RuntimeException(e);
443+
}
439444
}
440445
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Server.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,11 @@ private static void delete(TF_Server nativeHandle) {
178178
private int numJoining;
179179

180180
static {
181-
TensorFlow.init();
181+
try {
182+
// Ensure that TensorFlow native library and classes are ready to be used
183+
Class.forName("org.tensorflow.TensorFlow");
184+
} catch (ClassNotFoundException e) {
185+
throw new RuntimeException(e);
186+
}
182187
}
183188
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
/** Static utility methods describing the TensorFlow runtime. */
3434
public final class TensorFlow {
35+
3536
/** Returns the version of the underlying TensorFlow runtime. */
3637
public static String version() {
3738
return TF_Version().getString();
@@ -106,7 +107,7 @@ private static OpList libraryOpList(TF_Library handle) {
106107
private TensorFlow() {}
107108

108109
/** Load the TensorFlow runtime C library. */
109-
static void init() {
110+
static {
110111
try {
111112
NativeLibrary.load();
112113
} catch (Exception e) {
@@ -121,8 +122,4 @@ static void init() {
121122
throw e;
122123
}
123124
}
124-
125-
static {
126-
init();
127-
}
128125
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.bytedeco.javacpp.Loader;
2929
import org.bytedeco.javacpp.Pointer;
3030
import org.bytedeco.javacpp.PointerScope;
31+
import org.tensorflow.TensorFlow;
3132
import org.tensorflow.ndarray.buffer.DataBuffer;
3233
import org.tensorflow.internal.c_api.TF_TString;
3334
import org.tensorflow.ndarray.impl.buffer.AbstractDataBuffer;
@@ -132,4 +133,13 @@ void writeNext(byte[] bytes) {
132133
}
133134

134135
private final TF_TString data;
136+
137+
static {
138+
try {
139+
// Ensure that TensorFlow native library and classes are ready to be used
140+
Class.forName("org.tensorflow.TensorFlow");
141+
} catch (ClassNotFoundException e) {
142+
throw new RuntimeException(e);
143+
}
144+
}
135145
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import static org.junit.jupiter.api.Assertions.assertTrue;
2121
import static org.junit.jupiter.api.Assertions.fail;
2222

23-
import org.junit.Test;
23+
import org.junit.jupiter.api.Test;
2424
import org.tensorflow.op.Ops;
2525
import org.tensorflow.types.TInt32;
2626

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import static org.junit.jupiter.api.Assertions.assertEquals;
2020

21-
import org.junit.Test;
21+
import org.junit.jupiter.api.Test;
2222
import org.tensorflow.Graph;
2323
import org.tensorflow.Operand;
2424
import org.tensorflow.Session;
@@ -29,6 +29,7 @@
2929
import org.tensorflow.types.TInt32;
3030

3131
public class BooleanMaskTest {
32+
3233
@Test
3334
public void testBooleanMask(){
3435
try (Graph g = new Graph();
@@ -43,24 +44,24 @@ public void testBooleanMask(){
4344
Operand<TInt32> output1 = BooleanMask.create(scope, input, mask);
4445
Operand<TInt32> output2 = BooleanMask.create(scope, input2, mask, BooleanMask.axis(1));
4546

46-
try (TFloat32 result = (TFloat32) sess.runner().fetch(output1).run().get(0)) {
47+
try (TInt32 result = (TInt32) sess.runner().fetch(output1).run().get(0)) {
4748
// expected shape from Python tensorflow
4849
assertEquals(Shape.of(5), result.shape());
49-
assertEquals(0, result.getFloat(0));
50-
assertEquals(1, result.getFloat(1));
51-
assertEquals(4, result.getFloat(2));
52-
assertEquals(5, result.getFloat(3));
53-
assertEquals(6, result.getFloat(4));
50+
assertEquals(0, result.getInt(0));
51+
assertEquals(1, result.getInt(1));
52+
assertEquals(4, result.getInt(2));
53+
assertEquals(5, result.getInt(3));
54+
assertEquals(6, result.getInt(4));
5455
}
5556

56-
try (TFloat32 result = (TFloat32) sess.runner().fetch(output2).run().get(0)) {
57+
try (TInt32 result = (TInt32) sess.runner().fetch(output2).run().get(0)) {
5758
// expected shape from Python tensorflow
58-
assertEquals(Shape.of(5), result.shape());
59-
assertEquals(0, result.getFloat(0));
60-
assertEquals(1, result.getFloat(1));
61-
assertEquals(4, result.getFloat(2));
62-
assertEquals(5, result.getFloat(3));
63-
assertEquals(6, result.getFloat(4));
59+
assertEquals(Shape.of(1, 5), result.shape());
60+
assertEquals(0, result.getInt(0, 0));
61+
assertEquals(1, result.getInt(0, 1));
62+
assertEquals(4, result.getInt(0, 2));
63+
assertEquals(5, result.getInt(0, 3));
64+
assertEquals(6, result.getInt(0, 4));
6465
}
6566
}
6667
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
import static org.junit.jupiter.api.Assertions.assertEquals;
2020

2121
import java.util.List;
22-
import org.junit.Test;
22+
23+
import org.junit.jupiter.api.Test;
2324
import org.tensorflow.Graph;
2425
import org.tensorflow.Operand;
2526
import org.tensorflow.Session;

0 commit comments

Comments
 (0)