Skip to content

Commit 1450980

Browse files
committed
Fix framework tests.
1 parent 155fdd6 commit 1450980

File tree

6 files changed

+70
-68
lines changed

6 files changed

+70
-68
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ public Map<String, Tensor> call(Map<String, Tensor> arguments) {
113113

114114
signature.getOutputs().values().forEach(x -> runner.fetch(x.name));
115115

116-
List<Tensor> results = runner.run();
116+
Session.Result results = runner.run();
117117

118118
Map<String, Tensor> outputs = new LinkedHashMap<>(results.size());
119119
int i = 0;

tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,10 @@ public void testGraphIteration() {
5151

5252
int batches = 0;
5353
while (true) {
54-
try {
55-
List<?> outputs = session.runner().fetch(x).fetch(y).run();
56-
57-
try (TInt32 xBatch = (TInt32) outputs.get(0);
58-
TInt32 yBatch = (TInt32) outputs.get(1)) {
59-
assertEquals(testMatrix1.get(batches), xBatch);
60-
assertEquals(testMatrix2.get(batches), yBatch);
61-
batches++;
62-
}
54+
try (Session.Result outputs = session.runner().fetch(x).fetch(y).run()) {
55+
assertEquals(testMatrix1.get(batches), outputs.get(0));
56+
assertEquals(testMatrix2.get(batches), outputs.get(1));
57+
batches++;
6358
} catch (TFOutOfRangeException e) {
6459
break;
6560
}

tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,11 @@ public void testGraphIteration() {
7676

7777
int batches = 0;
7878
while (true) {
79-
try {
80-
List<?> outputs = session.runner().fetch(X).fetch(y).run();
81-
82-
try (TInt32 XBatch = (TInt32) outputs.get(0);
83-
TInt32 yBatch = (TInt32) outputs.get(1)) {
84-
85-
assertEquals(mapped1.get(batches), XBatch);
86-
assertEquals(mapped2.get(batches), yBatch);
79+
try (Session.Result outputs = session.runner().fetch(X).fetch(y).run()) {
80+
assertEquals(mapped1.get(batches), outputs.get(0));
81+
assertEquals(mapped2.get(batches), outputs.get(1));
8782

8883
batches++;
89-
}
9084
} catch (TFOutOfRangeException e) {
9185
break;
9286
}

tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import org.junit.jupiter.api.Test;
1818
import org.tensorflow.Operand;
19+
import org.tensorflow.Session;
1920
import org.tensorflow.Tensor;
2021
import org.tensorflow.framework.utils.TestSession;
2122
import org.tensorflow.op.Op;
@@ -69,10 +70,10 @@ private <T extends TNumber> void testValid(
6970
Operand<T> weightsPlaceholder = tf.placeholder(type);
7071
Operand<T> valuesPlaceholder = tf.placeholder(type);
7172

72-
List<Tensor> tensors =
73-
testSession.getGraphSession().runner().fetch(weights).fetch(values).run();
74-
try (Tensor weightsTensor = tensors.get(0);
75-
Tensor valuesTensor = tensors.get(1)) {
73+
try (Session.Result tensors =
74+
testSession.getGraphSession().runner().fetch(weights).fetch(values).run()) {
75+
Tensor weightsTensor = tensors.get(0);
76+
Tensor valuesTensor = tensors.get(1);
7677
Op dynamicOp = MetricsHelper.assertBroadcastable(tf, weightsPlaceholder, valuesPlaceholder);
7778

7879
testSession

tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import org.junit.jupiter.api.Test;
1818
import org.tensorflow.Operand;
19+
import org.tensorflow.Session;
1920
import org.tensorflow.Tensor;
2021
import org.tensorflow.framework.utils.TestSession;
2122
import org.tensorflow.op.Ops;
@@ -78,55 +79,56 @@ private <T extends TNumber> void testValid(
7879
Operand<T> weightsPlaceholder = tf.placeholder(type);
7980
Operand<T> valuesPlaceholder = tf.placeholder(type);
8081

81-
List<Tensor> tensors =
82-
testSession.getGraphSession().runner().fetch(weights).fetch(values).run();
83-
try (Tensor weightsTensor = tensors.get(0);
84-
Tensor valuesTensor = tensors.get(1)) {
82+
try (Session.Result tensors =
83+
testSession.getGraphSession().runner().fetch(weights).fetch(values).run()) {
84+
Tensor weightsTensor = tensors.get(0);
85+
Tensor valuesTensor = tensors.get(1);
8586

8687
Operand<T> dynamicOp =
8788
MetricsHelper.broadcastWeights(tf, weightsPlaceholder, valuesPlaceholder);
8889

89-
List<Tensor> result =
90+
try (Session.Result result =
9091
testSession
9192
.getGraphSession()
9293
.runner()
9394
.feed(weightsPlaceholder, weightsTensor)
9495
.feed(valuesPlaceholder, valuesTensor)
9596
.fetch(dynamicOp)
96-
.run();
97-
98-
if (expected != null) {
99-
if (type.equals(TInt32.class)) {
100-
TInt32 intT = (TInt32) result.get(0);
101-
AtomicInteger i = new AtomicInteger();
102-
intT.scalars()
103-
.forEachIndexed(
104-
(idx, f) -> assertEquals(expected[i.getAndIncrement()].intValue(), f.getInt()));
105-
} else if (type.equals(TInt64.class)) {
106-
TInt64 floatT = (TInt64) result.get(0);
107-
AtomicInteger i = new AtomicInteger();
108-
floatT
109-
.scalars()
110-
.forEachIndexed(
111-
(idx, f) -> assertEquals(expected[i.getAndIncrement()].longValue(), f.getLong()));
112-
} else if (type.equals(TFloat32.class)) {
113-
TFloat32 floatT = (TFloat32) result.get(0);
114-
AtomicInteger i = new AtomicInteger();
115-
floatT
116-
.scalars()
117-
.forEachIndexed(
118-
(idx, f) ->
119-
assertEquals(
120-
expected[i.getAndIncrement()].floatValue(), f.getFloat(), 1e-5F));
121-
} else if (type.equals(TFloat64.class)) {
122-
TFloat64 doubleT = (TFloat64) result.get(0);
123-
AtomicInteger i = new AtomicInteger();
124-
doubleT
125-
.scalars()
126-
.forEachIndexed(
127-
(idx, f) ->
128-
assertEquals(
129-
expected[i.getAndIncrement()].doubleValue(), f.getDouble(), 1e-5F));
97+
.run()) {
98+
99+
if (expected != null) {
100+
if (type.equals(TInt32.class)) {
101+
TInt32 intT = (TInt32) result.get(0);
102+
AtomicInteger i = new AtomicInteger();
103+
intT.scalars()
104+
.forEachIndexed(
105+
(idx, f) -> assertEquals(expected[i.getAndIncrement()].intValue(), f.getInt()));
106+
} else if (type.equals(TInt64.class)) {
107+
TInt64 floatT = (TInt64) result.get(0);
108+
AtomicInteger i = new AtomicInteger();
109+
floatT
110+
.scalars()
111+
.forEachIndexed(
112+
(idx, f) -> assertEquals(expected[i.getAndIncrement()].longValue(), f.getLong()));
113+
} else if (type.equals(TFloat32.class)) {
114+
TFloat32 floatT = (TFloat32) result.get(0);
115+
AtomicInteger i = new AtomicInteger();
116+
floatT
117+
.scalars()
118+
.forEachIndexed(
119+
(idx, f) ->
120+
assertEquals(
121+
expected[i.getAndIncrement()].floatValue(), f.getFloat(), 1e-5F));
122+
} else if (type.equals(TFloat64.class)) {
123+
TFloat64 doubleT = (TFloat64) result.get(0);
124+
AtomicInteger i = new AtomicInteger();
125+
doubleT
126+
.scalars()
127+
.forEachIndexed(
128+
(idx, f) ->
129+
assertEquals(
130+
expected[i.getAndIncrement()].doubleValue(), f.getDouble(), 1e-5F));
131+
}
130132
}
131133
}
132134
}

tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import java.util.ArrayList;
66
import java.util.List;
7+
import java.util.Map;
8+
79
import org.junit.jupiter.api.AfterAll;
810
import org.junit.jupiter.api.AfterEach;
911
import org.junit.jupiter.api.BeforeAll;
@@ -189,13 +191,17 @@ public void testDeterminism() {
189191
g.importGraphDef(def);
190192
s.initialize();
191193

192-
initialized.add(
193-
s.runner()
194+
Session.Result initializationRes = s.runner()
194195
.fetch(fcWeightName)
195196
.fetch(fcBiasName)
196197
.fetch(outputWeightName)
197198
.fetch(outputBiasName)
198-
.run());
199+
.run();
200+
List<Tensor> initializedRun = new ArrayList<>();
201+
for (Map.Entry<String, Tensor> e : initializationRes) {
202+
initializedRun.add(e.getValue());
203+
}
204+
initialized.add(initializedRun);
199205

200206
TFloat32 lossVal =
201207
(TFloat32)
@@ -209,13 +215,17 @@ public void testDeterminism() {
209215
initialLoss[i] = lossVal.getFloat();
210216
lossVal.close();
211217

212-
trained.add(
213-
s.runner()
218+
Session.Result trainedRes = s.runner()
214219
.fetch(fcWeightName)
215220
.fetch(fcBiasName)
216221
.fetch(outputWeightName)
217222
.fetch(outputBiasName)
218-
.run());
223+
.run();
224+
List<Tensor> trainedRun = new ArrayList<>();
225+
for (Map.Entry<String, Tensor> e : trainedRes) {
226+
trainedRun.add(e.getValue());
227+
}
228+
trained.add(trainedRun);
219229

220230
lossVal =
221231
(TFloat32)

0 commit comments

Comments
 (0)