Skip to content

Commit 9601138

Browse files
committed
Adjust GraphOperation#input to not require a graph lock
Signed-off-by: Ryan Nett <JNett96@gmail.com>
1 parent be4840c commit 9601138

File tree

2 files changed

+8
-13
lines changed

2 files changed

+8
-13
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,7 @@ public Output<?> input(int idx) {
213213
try (PointerScope scope = new PointerScope()) {
214214
TF_Input input = new TF_Input().oper(getUnsafeNativeHandle()).index(idx);
215215
TF_Output output = TF_OperationInput(input);
216-
String opName = TF_OperationName(output.oper()).getString();
217-
return graph.operation(opName).output(output.index());
216+
return new GraphOperation(graph, output.oper()).output(output.index());
218217
}
219218
}
220219

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@
2222
import static org.junit.jupiter.api.Assertions.assertTrue;
2323

2424
import java.util.Arrays;
25-
import java.util.Spliterator;
26-
import java.util.Spliterators;
27-
import java.util.stream.Collectors;
28-
import java.util.stream.StreamSupport;
2925
import org.junit.jupiter.api.Test;
3026
import org.tensorflow.ndarray.index.Indices;
3127
import org.tensorflow.op.Ops;
@@ -68,13 +64,13 @@ public void testCustomGradient() {
6864
assertEquals(1, grads0.length);
6965
assertEquals(DataType.DT_FLOAT, grads0[0].dataType());
7066

71-
System.out.println(
72-
StreamSupport.stream(
73-
Spliterators.spliteratorUnknownSize(
74-
g.operations(), Spliterator.ORDERED | Spliterator.NONNULL),
75-
false)
76-
.map(GraphOperation::name)
77-
.collect(Collectors.toList()));
67+
// System.out.println(
68+
// StreamSupport.stream(
69+
// Spliterators.spliteratorUnknownSize(
70+
// g.operations(), Spliterator.ORDERED | Spliterator.NONNULL),
71+
// false)
72+
// .map(GraphOperation::name)
73+
// .collect(Collectors.toList()));
7874

7975
try (TFloat32 c1 = TFloat32.vectorOf(3.0f, 2.0f, 1.0f, 0.0f);
8076
AutoCloseableList<Tensor> outputs =

0 commit comments

Comments
 (0)