Skip to content

Commit 5baabb6

Browse files
committed
Graph native pointers
Signed-off-by: Ryan Nett <rnett@calpoly.edu>
1 parent e8ab5ef commit 5baabb6

File tree

7 files changed

+94
-40
lines changed

7 files changed

+94
-40
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Targeted by JavaCPP version 1.5.4: DO NOT EDIT THIS FILE
2+
3+
package org.tensorflow.internal.c_api;
4+
5+
import java.nio.*;
6+
import org.bytedeco.javacpp.*;
7+
import org.bytedeco.javacpp.annotation.*;
8+
9+
import static org.tensorflow.internal.c_api.global.tensorflow.*;
10+
11+
12+
@Name("tensorflow::Graph") @Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
13+
public class NativeGraphPointer extends Pointer {
14+
/** Empty constructor. Calls {@code super((Pointer)null)}. */
15+
public NativeGraphPointer() { super((Pointer)null); }
16+
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
17+
public NativeGraphPointer(Pointer p) { super(p); }
18+
}

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Graph.java

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,42 @@
88

99
import static org.tensorflow.internal.c_api.global.tensorflow.*;
1010

11+
// Parsed from tensorflow/c/c_api_internal.h
1112

12-
// TODO(jeff,sanjay):
13-
// - export functions to set Config fields
14-
15-
// --------------------------------------------------------------------------
16-
// The new graph construction API, still under development.
17-
18-
// Represents a computation graph. Graphs may be shared between sessions.
19-
// Graphs are thread-safe when used as directed below.
20-
@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
13+
@NoOffset @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
2114
public class TF_Graph extends org.tensorflow.internal.c_api.AbstractTF_Graph {
22-
/** Empty constructor. Calls {@code super((Pointer)null)}. */
23-
public TF_Graph() { super((Pointer)null); }
15+
static { Loader.load(); }
2416
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
2517
public TF_Graph(Pointer p) { super(p); }
18+
19+
20+
21+
public native @MemberGetter @ByRef NativeGraphPointer graph();
22+
23+
// Runs shape inference.
24+
25+
26+
// Maps from name of an operation to the Node* in 'graph'.
27+
28+
29+
// The keys of this map are all the active sessions using this graph. Each
30+
// value records whether the graph has been mutated since the corresponding
31+
// session has been run (this is detected in RecordMutation function). If the
32+
// string is empty, no mutation has occurred. Otherwise the string is a
33+
// description of the mutation suitable for returning to the user.
34+
//
35+
// Sessions are added to this map in TF_NewSession, and removed in
36+
// TF_DeleteSession.
37+
// TF_Graph may only / must be deleted when
38+
// sessions.size() == 0 && delete_requested == true
39+
//
40+
// TODO(b/74949947): mutations currently trigger a warning instead of a bad
41+
// status, this should be reverted when possible.
42+
43+
// set true by TF_DeleteGraph
44+
45+
// Used to link graphs contained in TF_WhileParams to the parent graph that
46+
// will eventually contain the full while loop.
47+
public native TF_Graph parent(); public native TF_Graph parent(TF_Graph setter);
48+
public native TF_Output parent_inputs(); public native TF_Graph parent_inputs(TF_Output setter);
2649
}

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_OperationDescription.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88

99
import static org.tensorflow.internal.c_api.global.tensorflow.*;
1010

11-
// Parsed from tensorflow/c/c_api_internal.h
12-
1311
@NoOffset @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
1412
public class TF_OperationDescription extends Pointer {
1513
static { Loader.load(); }

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Scope.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ public class TF_Scope extends Pointer {
167167
public native @Cast("bool") boolean ok();
168168

169169
// TODO(skyewm): Graph is not part of public API
170-
170+
public native NativeGraphPointer graph();
171171

172172
// TODO(skyewm): Graph is not part of public API
173173

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -933,9 +933,15 @@ public static native void TF_SetConfig(TF_SessionOptions options,
933933

934934
// Destroy an options object.
935935
public static native void TF_DeleteSessionOptions(TF_SessionOptions arg0);
936-
// Targeting ../TF_Graph.java
937936

937+
// TODO(jeff,sanjay):
938+
// - export functions to set Config fields
939+
940+
// --------------------------------------------------------------------------
941+
// The new graph construction API, still under development.
938942

943+
// Represents a computation graph. Graphs may be shared between sessions.
944+
// Graphs are thread-safe when used as directed below.
939945

940946
// Return a new graph object.
941947
public static native TF_Graph TF_NewGraph();
@@ -4429,6 +4435,9 @@ public static native void TFE_ContextExportRunMetadata(TFE_Context ctx,
44294435
// #include "tensorflow/core/common_runtime/graph_constructor.h"
44304436
// #include "tensorflow/core/lib/core/status.h"
44314437
// #include "tensorflow/core/lib/gtl/array_slice.h"
4438+
// Targeting ../NativeGraphPointer.java
4439+
4440+
44324441
// Targeting ../NodeBuilder.java
44334442

44344443

@@ -4450,6 +4459,9 @@ public static native void TFE_ContextExportRunMetadata(TFE_Context ctx,
44504459
// #endif // TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
44514460

44524461

4462+
// Targeting ../TF_Graph.java
4463+
4464+
44534465
// Targeting ../TF_OperationDescription.java
44544466

44554467

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,28 +27,15 @@
2727
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewWhile;
2828

2929
import com.google.protobuf.InvalidProtocolBufferException;
30-
import java.util.ArrayDeque;
31-
import java.util.ArrayList;
32-
import java.util.Arrays;
33-
import java.util.Collections;
34-
import java.util.Iterator;
35-
import java.util.LinkedHashSet;
36-
import java.util.List;
37-
import java.util.Queue;
38-
import java.util.Set;
30+
31+
import java.util.*;
3932
import java.util.stream.Collectors;
4033
import org.bytedeco.javacpp.BytePointer;
4134
import org.bytedeco.javacpp.Pointer;
4235
import org.bytedeco.javacpp.PointerScope;
4336
import org.bytedeco.javacpp.SizeTPointer;
4437
import org.tensorflow.exceptions.TensorFlowException;
45-
import org.tensorflow.internal.c_api.TF_Buffer;
46-
import org.tensorflow.internal.c_api.TF_Graph;
47-
import org.tensorflow.internal.c_api.TF_ImportGraphDefOptions;
48-
import org.tensorflow.internal.c_api.TF_Operation;
49-
import org.tensorflow.internal.c_api.TF_Output;
50-
import org.tensorflow.internal.c_api.TF_Status;
51-
import org.tensorflow.internal.c_api.TF_WhileParams;
38+
import org.tensorflow.internal.c_api.*;
5239
import org.tensorflow.ndarray.StdArrays;
5340
import org.tensorflow.op.Op;
5441
import org.tensorflow.op.Ops;
@@ -78,14 +65,15 @@ public final class Graph implements ExecutionEnvironment, AutoCloseable {
7865
* Create an empty Graph.
7966
*/
8067
public Graph() {
81-
nativeHandle = allocate();
68+
this(allocate());
8269
}
8370

8471
/**
8572
* Create a Graph from an existing handle (takes ownership).
8673
*/
8774
Graph(TF_Graph nativeHandle) {
8875
this.nativeHandle = nativeHandle;
76+
allGraphs.add(this);
8977
}
9078

9179
Graph(TF_Graph nativeHandle, SaverDef saverDef) {
@@ -1069,6 +1057,21 @@ private static SaverDef addVariableSaver(Graph graph) {
10691057
.build();
10701058
}
10711059

1060+
private static Set<Graph> allGraphs = Collections.newSetFromMap(new WeakHashMap<>());
1061+
1062+
/**
1063+
* Find the graph with the matching underlying native pointer.
1064+
* @return the graph if there is one, else null.
1065+
*/
1066+
Graph findGraphForPointer(NativeGraphPointer pointer){
1067+
for(Graph g : allGraphs){
1068+
if(g.nativeHandle.graph().equals(pointer)){
1069+
return g;
1070+
}
1071+
}
1072+
return null;
1073+
}
1074+
10721075
static {
10731076
TensorFlow.init();
10741077
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,19 +201,20 @@ public void init(ClassProperties properties) {
201201

202202
public void map(InfoMap infoMap) {
203203
infoMap
204-
.put(new Info("TF_OperationDescription").pointerTypes("TF_OperationDescription").purify())
205204
.put(new Info("c_api_internal.h")
206-
.linePatterns("struct TF_OperationDescription \\{", "\\};"))
207-
.put(new Info("TF_CAPI_EXPORT", "TF_Bool").cppTypes().annotations())
205+
.linePatterns("struct TF_OperationDescription \\{", "\\};",
206+
"struct TF_Graph \\{", "\\};"))
207+
.put(new Info("TF_CAPI_EXPORT", "TF_Bool", "TF_GUARDED_BY").cppTypes().annotations())
208208
.put(new Info("TF_Buffer::data").javaText("public native @Const Pointer data(); public native TF_Buffer data(Pointer data);"))
209209
.put(new Info("TF_Status").pointerTypes("TF_Status").base("org.tensorflow.internal.c_api.AbstractTF_Status"))
210210
.put(new Info("TF_Buffer").pointerTypes("TF_Buffer").base("org.tensorflow.internal.c_api.AbstractTF_Buffer"))
211211
.put(new Info("TF_Tensor").pointerTypes("TF_Tensor").base("org.tensorflow.internal.c_api.AbstractTF_Tensor"))
212212
.put(new Info("TF_Session").pointerTypes("TF_Session").base("org.tensorflow.internal.c_api.AbstractTF_Session"))
213213
.put(new Info("TF_SessionOptions").pointerTypes("TF_SessionOptions").base("org.tensorflow.internal.c_api.AbstractTF_SessionOptions"))
214-
.put(new Info("TF_Graph").pointerTypes("TF_Graph").base("org.tensorflow.internal.c_api.AbstractTF_Graph"))
215-
.put(new Info("TF_Graph::graph").javaText("public native @MemberGetter @ByRef Graph graph();"))
216-
.put(new Info("TF_Graph::refiner").javaText("public native @MemberGetter @ByRef ShapeRefiner refiner();"))
214+
.put(new Info("TF_Graph").pointerTypes("TF_Graph").base("org.tensorflow.internal.c_api.AbstractTF_Graph").purify())
215+
.put(new Info("tensorflow::Graph").javaNames("NativeGraphPointer"))
216+
.put(new Info("TF_Graph::graph").javaText("public native @MemberGetter @ByRef NativeGraphPointer graph();"))
217+
.put(new Info("TF_Graph::refiner", "TF_Graph::mu", "TF_Graph::name_map", "TF_Graph::sessions", "TF_Graph::delete_requested").skip())
217218
.put(new Info("TF_ImportGraphDefOptions").pointerTypes("TF_ImportGraphDefOptions").base("org.tensorflow.internal.c_api.AbstractTF_ImportGraphDefOptions"))
218219
.put(new Info("TF_Operation", "TF_WhileParams", "TFE_MonitoringCounterCell", "TFE_MonitoringSamplerCell",
219220
"TFE_MonitoringCounter0", "TFE_MonitoringCounter1", "TFE_MonitoringCounter2",
@@ -235,16 +236,15 @@ public void map(InfoMap infoMap) {
235236
.put(new Info("TFE_Op::operation").javaText("@MemberGetter public native @ByRef EagerOperation operation();"))
236237
.put(new Info("TFE_TensorHandle").pointerTypes("TFE_TensorHandle").base("org.tensorflow.internal.c_api.AbstractTFE_TensorHandle"))
237238
.put(new Info("TF_ShapeInferenceContextDimValueKnown", "TFE_NewTensorHandle(const tensorflow::Tensor&, TF_Status*)").skip())
239+
.put(new Info("TF_OperationDescription").pointerTypes("TF_OperationDescription").purify())
238240
.put(new Info("tensorflow::Scope").javaNames("TF_Scope").pointerTypes("TF_Scope"))
239241
.put(new Info("tensorflow::NodeBuilder").pointerTypes("NodeBuilder"))
240242
.put(new Info("tensorflow::string", "absl::string_view", "tensorflow::StringPiece").annotations("@StdString").valueTypes("BytePointer", "String").pointerTypes("BytePointer"))
241243
.put(new Info("absl::Span", "tensorflow::gtl::ArraySlice").annotations("@Span"))
242244
.put(new Info("tensorflow::Output").javaNames("TF_Output").cast())
243245
.put(new Info("tensorflow::Operation").javaNames("TF_Operation").cast())
244246
.put(new Info("tensorflow::CompositeOpScopes",
245-
"tensorflow::Graph",
246247
"tensorflow::GraphDef",
247-
"tensorflow::Scope::graph",
248248
"tensorflow::Scope::graph_as_shared_ptr",
249249
"tensorflow::Scope::ToGraphDef",
250250
"tensorflow::Scope::ToGraph",

0 commit comments

Comments
 (0)