2121import static org .tensorflow .internal .c_api .global .tensorflow .TF_SetConfig ;
2222
2323import com .google .protobuf .InvalidProtocolBufferException ;
24+
25+ import java .sql .Array ;
2426import java .util .ArrayList ;
2527import java .util .Collections ;
28+ import java .util .Iterator ;
29+ import java .util .LinkedHashMap ;
2630import java .util .LinkedHashSet ;
2731import java .util .List ;
2832import java .util .Map ;
33+ import java .util .Optional ;
2934import java .util .Set ;
35+ import java .util .logging .Logger ;
36+
3037import org .bytedeco .javacpp .BytePointer ;
3138import org .bytedeco .javacpp .Pointer ;
3239import org .bytedeco .javacpp .PointerPointer ;
@@ -490,13 +497,13 @@ private void doInit() {
490497 *
491498 * @return list of resulting tensors fetched by this session runner
492499 */
493- public List < Tensor > run () {
500+ public Result run () {
494501 doInit ();
495502 return runNoInit ();
496503 }
497504
498- List < Tensor > runNoInit () {
499- return runHelper (false ). outputs ;
505+ Result runNoInit () {
506+ return runHelper (false );
500507 }
501508
502509 /**
@@ -509,18 +516,19 @@ List<Tensor> runNoInit() {
509516 *
510517 * @return list of resulting tensors fetched by this session runner, with execution metadata
511518 */
512- public Run runAndFetchMetadata () {
519+ public Result runAndFetchMetadata () {
513520 doInit ();
514521 return runHelper (true );
515522 }
516523
517- private Run runHelper (boolean wantMetadata ) {
524+ private Result runHelper (boolean wantMetadata ) {
518525 TF_Tensor [] inputTensorHandles = new TF_Tensor [inputTensors .size ()];
519526 TF_Operation [] inputOpHandles = new TF_Operation [inputs .size ()];
520527 int [] inputOpIndices = new int [inputs .size ()];
521528 TF_Operation [] outputOpHandles = new TF_Operation [outputs .size ()];
522529 int [] outputOpIndices = new int [outputs .size ()];
523530 TF_Operation [] targetOpHandles = new TF_Operation [targets .size ()];
531+ List <String > outputNames = new ArrayList <>();
524532
525533 // It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the
526534 // validity of the Graph and graphRef ensures that.
@@ -538,6 +546,7 @@ private Run runHelper(boolean wantMetadata) {
538546 for (Output <?> o : outputs ) {
539547 outputOpHandles [idx ] = (TF_Operation ) o .getUnsafeNativeHandle ();
540548 outputOpIndices [idx ] = o .index ();
549+ outputNames .add (o .name ());
541550 idx ++;
542551 }
543552 idx = 0 ;
@@ -569,10 +578,7 @@ private Run runHelper(boolean wantMetadata) {
569578 } finally {
570579 runRef .close ();
571580 }
572- Run ret = new Run ();
573- ret .outputs = outputs ;
574- ret .metadata = metadata ;
575- return ret ;
581+ return new Result (outputNames ,outputs ,metadata );
576582 }
577583
578584 private class Reference implements AutoCloseable {
@@ -699,14 +705,117 @@ public void restore(String prefix) {
699705 }
700706
701707 /**
702- * Output tensors and metadata obtained when executing a session .
708+ * An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s .
703709 *
704- * <p>See {@link Runner#runAndFetchMetadata()}
710+ * <p>When this is closed it closes all the {@link Tensor}s inside it. If you maintain a
711+ * reference to a value after this object has been closed it will throw an {@link
712+ * IllegalStateException} upon access.
705713 */
706- public static final class Run {
714+ public static class Result implements AutoCloseable , Iterable <Map .Entry <String , Tensor >> {
715+
716+ private static final Logger logger = Logger .getLogger (Result .class .getName ());
717+
718+ private final Map <String , Tensor > map ;
719+
720+ private final List <Tensor > list ;
721+
722+ private final RunMetadata metadata ;
723+
724+ private boolean closed ;
725+
726+ /**
727+ * Creates a Result from the names and values produced by {@link Session.Runner#run()}.
728+ *
729+ * @param names The output names.
730+ * @param values The output values.
731+ * @param metadata The run metadata, may be null.
732+ */
733+ Result (List <String > names , List <Tensor > values , RunMetadata metadata ) {
734+ this .map = new LinkedHashMap <>();
735+ this .list = new ArrayList <>(values );
736+
737+ if (names .size () != values .size ()) {
738+ throw new IllegalArgumentException (
739+ "Expected same number of names and values, found names.length = "
740+ + names .size ()
741+ + ", values.length = "
742+ + values .size ());
743+ }
707744
708- /** Tensors from requested fetches. */
709- public List <Tensor > outputs ;
745+ for (int i = 0 ; i < names .size (); i ++) {
746+ this .map .put (names .get (i ), values .get (i ));
747+ }
748+ this .metadata = metadata ;
749+ this .closed = false ;
750+ }
751+
752+ @ Override
753+ public void close () {
754+ if (!closed ) {
755+ closed = true ;
756+ for (Tensor t : map .values ()) {
757+ t .close ();
758+ }
759+ } else {
760+ logger .warning ("Closing an already closed Result" );
761+ }
762+ }
763+
764+ @ Override
765+ public Iterator <Map .Entry <String , Tensor >> iterator () {
766+ if (!closed ) {
767+ return map .entrySet ().iterator ();
768+ } else {
769+ throw new IllegalStateException ("Result is closed" );
770+ }
771+ }
772+
773+ /**
774+ * Gets the value from the container at the specified index.
775+ *
776+ * <p>Throws {@link IllegalStateException} if the container has been closed, and {@link
777+ * IndexOutOfBoundsException} if the index is invalid.
778+ *
779+ * @param index The index to lookup.
780+ * @return The value at the index.
781+ */
782+ public Tensor get (int index ) {
783+ if (!closed ) {
784+ return list .get (index );
785+ } else {
786+ throw new IllegalStateException ("Result is closed" );
787+ }
788+ }
789+
790+ /**
791+ * Returns the number of outputs in this Result.
792+ *
793+ * @return The number of outputs.
794+ */
795+ public int size () {
796+ return map .size ();
797+ }
798+
799+ /**
800+ * Gets the value from the container assuming it's not been closed.
801+ *
802+ * <p>Throws {@link IllegalStateException} if the container has been closed.
803+ *
804+ * @param key The key to lookup.
805+ * @return Optional.of the value if it exists.
806+ */
807+ public Optional <Tensor > get (String key ) {
808+ if (!closed ) {
809+ Tensor value = map .get (key );
810+ if (value != null ) {
811+ return Optional .of (value );
812+ } else {
813+ return Optional .empty ();
814+ }
815+ } else {
816+ throw new IllegalStateException ("Result is closed" );
817+ }
818+ }
710819
711820 /**
712821 * Metadata about the run.
@@ -715,7 +824,9 @@ public static final class Run {
715824 * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
716825 * protocol buffer</a>.
717826 */
718- public RunMetadata metadata ;
827+ public Optional <RunMetadata > getMetadata () {
828+ return Optional .ofNullable (metadata );
829+ }
719830 }
720831
721832 Graph graph () {
0 commit comments