Skip to content

Commit 155fdd6

Browse files
committed
Adding an autocloseable result class for the output of Session.Runner.run.
1 parent d518678 commit 155fdd6

File tree

2 files changed

+141
-32
lines changed

2 files changed

+141
-32
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

Lines changed: 126 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,19 @@
2121
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;
2222

2323
import com.google.protobuf.InvalidProtocolBufferException;
24+
25+
import java.sql.Array;
2426
import java.util.ArrayList;
2527
import java.util.Collections;
28+
import java.util.Iterator;
29+
import java.util.LinkedHashMap;
2630
import java.util.LinkedHashSet;
2731
import java.util.List;
2832
import java.util.Map;
33+
import java.util.Optional;
2934
import java.util.Set;
35+
import java.util.logging.Logger;
36+
3037
import org.bytedeco.javacpp.BytePointer;
3138
import org.bytedeco.javacpp.Pointer;
3239
import org.bytedeco.javacpp.PointerPointer;
@@ -490,13 +497,13 @@ private void doInit() {
490497
*
491498
* @return list of resulting tensors fetched by this session runner
492499
*/
493-
public List<Tensor> run() {
500+
public Result run() {
494501
doInit();
495502
return runNoInit();
496503
}
497504

498-
List<Tensor> runNoInit() {
499-
return runHelper(false).outputs;
505+
Result runNoInit() {
506+
return runHelper(false);
500507
}
501508

502509
/**
@@ -509,18 +516,19 @@ List<Tensor> runNoInit() {
509516
*
510517
* @return list of resulting tensors fetched by this session runner, with execution metadata
511518
*/
512-
public Run runAndFetchMetadata() {
519+
public Result runAndFetchMetadata() {
513520
doInit();
514521
return runHelper(true);
515522
}
516523

517-
private Run runHelper(boolean wantMetadata) {
524+
private Result runHelper(boolean wantMetadata) {
518525
TF_Tensor[] inputTensorHandles = new TF_Tensor[inputTensors.size()];
519526
TF_Operation[] inputOpHandles = new TF_Operation[inputs.size()];
520527
int[] inputOpIndices = new int[inputs.size()];
521528
TF_Operation[] outputOpHandles = new TF_Operation[outputs.size()];
522529
int[] outputOpIndices = new int[outputs.size()];
523530
TF_Operation[] targetOpHandles = new TF_Operation[targets.size()];
531+
List<String> outputNames = new ArrayList<>();
524532

525533
// It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the
526534
// validity of the Graph and graphRef ensures that.
@@ -538,6 +546,7 @@ private Run runHelper(boolean wantMetadata) {
538546
for (Output<?> o : outputs) {
539547
outputOpHandles[idx] = (TF_Operation) o.getUnsafeNativeHandle();
540548
outputOpIndices[idx] = o.index();
549+
outputNames.add(o.name());
541550
idx++;
542551
}
543552
idx = 0;
@@ -569,10 +578,7 @@ private Run runHelper(boolean wantMetadata) {
569578
} finally {
570579
runRef.close();
571580
}
572-
Run ret = new Run();
573-
ret.outputs = outputs;
574-
ret.metadata = metadata;
575-
return ret;
581+
return new Result(outputNames,outputs,metadata);
576582
}
577583

578584
private class Reference implements AutoCloseable {
@@ -699,14 +705,117 @@ public void restore(String prefix) {
699705
}
700706

701707
/**
702-
* Output tensors and metadata obtained when executing a session.
708+
* An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s.
703709
*
704-
* <p>See {@link Runner#runAndFetchMetadata()}
710+
* <p>When this is closed it closes all the {@link Tensor}s inside it. If you maintain a
711+
* reference to a value after this object has been closed it will throw an {@link
712+
* IllegalStateException} upon access.
705713
*/
706-
public static final class Run {
714+
public static class Result implements AutoCloseable, Iterable<Map.Entry<String, Tensor>> {
715+
716+
private static final Logger logger = Logger.getLogger(Result.class.getName());
717+
718+
private final Map<String, Tensor> map;
719+
720+
private final List<Tensor> list;
721+
722+
private final RunMetadata metadata;
723+
724+
private boolean closed;
725+
726+
/**
727+
* Creates a Result from the names and values produced by {@link Session.Runner#run()}.
728+
*
729+
* @param names The output names.
730+
* @param values The output values.
731+
* @param metadata The run metadata, may be null.
732+
*/
733+
Result(List<String> names, List<Tensor> values, RunMetadata metadata) {
734+
this.map = new LinkedHashMap<>();
735+
this.list = new ArrayList<>(values);
736+
737+
if (names.size() != values.size()) {
738+
throw new IllegalArgumentException(
739+
"Expected same number of names and values, found names.length = "
740+
+ names.size()
741+
+ ", values.length = "
742+
+ values.size());
743+
}
707744

708-
/** Tensors from requested fetches. */
709-
public List<Tensor> outputs;
745+
for (int i = 0; i < names.size(); i++) {
746+
this.map.put(names.get(i), values.get(i));
747+
}
748+
this.metadata = metadata;
749+
this.closed = false;
750+
}
751+
752+
@Override
753+
public void close() {
754+
if (!closed) {
755+
closed = true;
756+
for (Tensor t : map.values()) {
757+
t.close();
758+
}
759+
} else {
760+
logger.warning("Closing an already closed Result");
761+
}
762+
}
763+
764+
@Override
765+
public Iterator<Map.Entry<String, Tensor>> iterator() {
766+
if (!closed) {
767+
return map.entrySet().iterator();
768+
} else {
769+
throw new IllegalStateException("Result is closed");
770+
}
771+
}
772+
773+
/**
774+
* Gets the value from the container at the specified index.
775+
*
776+
* <p>Throws {@link IllegalStateException} if the container has been closed, and {@link
777+
* IndexOutOfBoundsException} if the index is invalid.
778+
*
779+
* @param index The index to lookup.
780+
* @return The value at the index.
781+
*/
782+
public Tensor get(int index) {
783+
if (!closed) {
784+
return list.get(index);
785+
} else {
786+
throw new IllegalStateException("Result is closed");
787+
}
788+
}
789+
790+
/**
791+
* Returns the number of outputs in this Result.
792+
*
793+
* @return The number of outputs.
794+
*/
795+
public int size() {
796+
return map.size();
797+
}
798+
799+
/**
800+
* Gets the value from the container assuming it's not been closed.
801+
*
802+
* <p>Throws {@link IllegalStateException} if the container has been closed.
803+
*
804+
* @param key The key to lookup.
805+
* @return Optional.of the value if it exists.
806+
*/
807+
public Optional<Tensor> get(String key) {
808+
if (!closed) {
809+
Tensor value = map.get(key);
810+
if (value != null) {
811+
return Optional.of(value);
812+
} else {
813+
return Optional.empty();
814+
}
815+
} else {
816+
throw new IllegalStateException("Result is closed");
817+
}
818+
}
710819

711820
/**
712821
* Metadata about the run.
@@ -715,7 +824,9 @@ public static final class Run {
715824
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
716825
* protocol buffer</a>.
717826
*/
718-
public RunMetadata metadata;
827+
public Optional<RunMetadata> getMetadata() {
828+
return Optional.ofNullable(metadata);
829+
}
719830
}
720831

721832
Graph graph() {

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import java.nio.file.Path;
2727
import java.util.Comparator;
2828
import java.util.Iterator;
29+
import java.util.Optional;
30+
2931
import org.junit.jupiter.api.Test;
3032
import org.tensorflow.ndarray.NdArrays;
3133
import org.tensorflow.ndarray.Shape;
@@ -38,6 +40,7 @@
3840
import org.tensorflow.op.math.Add;
3941
import org.tensorflow.proto.framework.ConfigProto;
4042
import org.tensorflow.proto.framework.GraphDef;
43+
import org.tensorflow.proto.framework.RunMetadata;
4144
import org.tensorflow.proto.framework.RunOptions;
4245
import org.tensorflow.types.TFloat32;
4346
import org.tensorflow.types.TInt32;
@@ -69,8 +72,7 @@ public void runUsingOperationNames() {
6972
Ops tf = Ops.create(g);
7073
transpose_A_times_X(tf, new int[][] {{2}, {3}});
7174
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}));
72-
AutoCloseableList<Tensor> outputs =
73-
new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) {
75+
Session.Result outputs = s.runner().feed("X", x).fetch("Y").run()) {
7476
assertEquals(1, outputs.size());
7577
assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0));
7678
}
@@ -86,8 +88,7 @@ public void runUsingOperationHandles() {
8688
Output<TInt32> feed = g.operation("X").output(0);
8789
Output<TInt32> fetch = g.operation("Y").output(0);
8890
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}));
89-
AutoCloseableList<Tensor> outputs =
90-
new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) {
91+
Session.Result outputs = s.runner().feed(feed, x).fetch(fetch).run()) {
9192
assertEquals(1, outputs.size());
9293
assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0));
9394
}
@@ -124,20 +125,20 @@ public void runWithMetadata() {
124125
Ops tf = Ops.create(g);
125126
transpose_A_times_X(tf, new int[][] {{2}, {3}});
126127
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) {
127-
Session.Run result =
128+
Session.Result result =
128129
s.runner()
129130
.feed("X", x)
130131
.fetch("Y")
131132
.setOptions(fullTraceRunOptions())
132133
.runAndFetchMetadata();
133134
// Sanity check on outputs.
134-
AutoCloseableList<Tensor> outputs = new AutoCloseableList<>(result.outputs);
135-
assertEquals(1, outputs.size());
136-
assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0));
135+
assertEquals(1, result.size());
136+
assertEquals(31, ((TInt32) result.get(0)).getInt(0, 0));
137137
// Sanity check on metadata
138-
assertNotNull(result.metadata);
139-
assertTrue(result.metadata.hasStepStats(), result.metadata.toString());
140-
outputs.close();
138+
Optional<RunMetadata> metadata = result.getMetadata();
139+
assertTrue(metadata.isPresent());
140+
assertTrue(metadata.get().hasStepStats(), metadata.get().toString());
141+
result.close();
141142
}
142143
}
143144
}
@@ -149,8 +150,7 @@ public void runMultipleOutputs() {
149150
Ops tf = Ops.create(g);
150151
tf.withName("c1").constant(2718);
151152
tf.withName("c2").constant(31415);
152-
AutoCloseableList<Tensor> outputs =
153-
new AutoCloseableList<>(s.runner().fetch("c2").fetch("c1").run());
153+
Session.Result outputs = s.runner().fetch("c2").fetch("c1").run();
154154
assertEquals(2, outputs.size());
155155
assertEquals(31415, ((TInt32) outputs.get(0)).getInt());
156156
assertEquals(2718, ((TInt32) outputs.get(1)).getInt());
@@ -227,10 +227,8 @@ public void saveAndRestore() throws IOException {
227227
restoredGraph.importGraphDef(graphDef);
228228
try (Session restoredSession = new Session(restoredGraph)) {
229229
restoredSession.restore(testFolder.resolve("checkpoint").toString());
230-
try (AutoCloseableList<Tensor> oldList =
231-
new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run());
232-
AutoCloseableList<Tensor> newList =
233-
new AutoCloseableList<>(restoredSession.runner().fetch("x").fetch("y").run())) {
230+
try (Session.Result oldList = s.runner().fetch("x").fetch("y").run();
231+
Session.Result newList = restoredSession.runner().fetch("x").fetch("y").run()) {
234232
assertEquals(oldList.get(0), newList.get(0));
235233
assertEquals(oldList.get(1), newList.get(1));
236234
}

0 commit comments

Comments
 (0)