Skip to content

Commit e8ab5ef

Browse files
committed
NativeScope class
Signed-off-by: Ryan Nett <rnett@calpoly.edu>
1 parent 99ea4f8 commit e8ab5ef

File tree

4 files changed

+144
-2
lines changed

4 files changed

+144
-2
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ public Set<GraphOperation> controlConsumers() {
331331
}
332332

333333

334-
TF_Operation getUnsafeNativeHandle() {
334+
public TF_Operation getUnsafeNativeHandle() {
335335
return unsafeNativeHandle;
336336
}
337337

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
17+
package org.tensorflow.op;
18+
19+
import org.tensorflow.DeviceSpec;
20+
import org.tensorflow.ExecutionEnvironment;
21+
import org.tensorflow.OperationBuilder;
22+
23+
public interface IScope {
24+
25+
ExecutionEnvironment env();
26+
27+
IScope withSubScope(String childScopeName);
28+
29+
IScope withName(String opName);
30+
31+
IScope withNameAsSubScope(String defaultName);
32+
33+
IScope withDevice(DeviceSpec deviceSpec);
34+
35+
String makeOpName(String defaultName);
36+
37+
IScope withControlDependencies(Iterable<Op> controls);
38+
39+
OperationBuilder apply(OperationBuilder builder);
40+
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
17+
package org.tensorflow.op;
18+
19+
import org.bytedeco.javacpp.PointerPointer;
20+
import org.tensorflow.*;
21+
import org.tensorflow.internal.c_api.TF_Operation;
22+
import org.tensorflow.internal.c_api.TF_Scope;
23+
24+
import java.util.List;
25+
import java.util.stream.Collectors;
26+
import java.util.stream.StreamSupport;
27+
28+
public final class NativeScope implements IScope {
29+
30+
@Override
31+
public ExecutionEnvironment env() {
32+
return graph;
33+
}
34+
35+
@Override
36+
public NativeScope withSubScope(String childScopeName) {
37+
return new NativeScope(nativeScope.NewSubScope(childScopeName), graph);
38+
}
39+
40+
@Override
41+
public NativeScope withName(String opName) {
42+
return new NativeScope(nativeScope, graph, opName);
43+
}
44+
45+
@Override
46+
public NativeScope withNameAsSubScope(String defaultName) {
47+
return withSubScope(opName);
48+
}
49+
50+
@Override
51+
public NativeScope withDevice(DeviceSpec deviceSpec) {
52+
return new NativeScope(nativeScope.WithDevice(deviceSpec.toString()), graph);
53+
}
54+
55+
@Override
56+
public String makeOpName(String defaultName) {
57+
String name = opName != null ? opName : defaultName;
58+
return nativeScope.GetUniqueNameForOp(name);
59+
}
60+
61+
@Override
62+
public NativeScope withControlDependencies(Iterable<Op> controls) {
63+
List<Op> controlDeps = StreamSupport.stream(controls.spliterator(), false).collect(Collectors.toList());
64+
PointerPointer<TF_Operation> ops = new PointerPointer<TF_Operation>(controlDeps.size());
65+
66+
for(int i = 0 ; i < controlDeps.size() ; i++){
67+
Operation op = controlDeps.get(i).op();
68+
if(!(op instanceof GraphOperation))
69+
throw new IllegalArgumentException("Can only add graph ops as control dependencies");
70+
ops.put(i, (((GraphOperation) op).getUnsafeNativeHandle()));
71+
}
72+
73+
return new NativeScope(nativeScope.WithControlDependencies(new TF_Operation(ops)), graph);
74+
}
75+
76+
@Override
77+
public OperationBuilder apply(OperationBuilder builder) {
78+
return builder;
79+
}
80+
81+
NativeScope(TF_Scope nativeScope, Graph graph){
82+
this(nativeScope, graph, null);
83+
}
84+
85+
private NativeScope(TF_Scope nativeScope, Graph graph, String opName){
86+
this.graph = graph;
87+
this.nativeScope = nativeScope;
88+
this.opName = opName;
89+
}
90+
91+
private final Graph graph;
92+
private final TF_Scope nativeScope;
93+
private final String opName;
94+
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
*
7676
* <p>Scope objects are <b>not</b> thread-safe.
7777
*/
78-
public final class Scope {
78+
public final class Scope implements IScope {
7979

8080
/**
8181
* Create a new top-level scope.
@@ -89,6 +89,7 @@ public Scope(ExecutionEnvironment env) {
8989
/**
9090
* Returns the execution environment used by this scope.
9191
*/
92+
@Override
9293
public ExecutionEnvironment env() {
9394
return env;
9495
}
@@ -105,6 +106,7 @@ public ExecutionEnvironment env() {
105106
* @return a new subscope
106107
* @throws IllegalArgumentException if the name is invalid
107108
*/
109+
@Override
108110
public Scope withSubScope(String childScopeName) {
109111
return new Scope(env, nameScope.withSubScope(childScopeName), controlDependencies, deviceSpec);
110112
}
@@ -121,6 +123,7 @@ public Scope withSubScope(String childScopeName) {
121123
* @return a new Scope that uses opName for operations.
122124
* @throws IllegalArgumentException if the name is invalid
123125
*/
126+
@Override
124127
public Scope withName(String opName) {
125128
return new Scope(env, nameScope.withName(opName), controlDependencies, deviceSpec);
126129
}
@@ -140,6 +143,7 @@ public Scope withName(String opName) {
140143
* @return a new subscope
141144
* @throws IllegalArgumentException if the name is invalid
142145
*/
146+
@Override
143147
public Scope withNameAsSubScope(String defaultName){
144148
return new Scope(env, nameScope.withSubScope(nameScope.makeOpName(defaultName)), controlDependencies, deviceSpec);
145149
}
@@ -153,6 +157,7 @@ public Scope withNameAsSubScope(String defaultName){
153157
* @param deviceSpec device specification for an operator in the returned scope
154158
* @return a new Scope that uses opName for operations.
155159
*/
160+
@Override
156161
public Scope withDevice(DeviceSpec deviceSpec) {
157162
return new Scope(env, nameScope, controlDependencies, deviceSpec);
158163
}
@@ -177,6 +182,7 @@ public Scope withDevice(DeviceSpec deviceSpec) {
177182
* @return unique name for the operator.
178183
* @throws IllegalArgumentException if the default name is invalid.
179184
*/
185+
@Override
180186
public String makeOpName(String defaultName) {
181187
return nameScope.makeOpName(defaultName);
182188
}
@@ -198,6 +204,7 @@ private Scope(
198204
* @param controls control dependencies for ops created with the returned scope
199205
* @return a new scope with the provided control dependencies
200206
*/
207+
@Override
201208
public Scope withControlDependencies(Iterable<Op> controls) {
202209
for (Op control : controls) {
203210
env.checkInput(control);
@@ -211,6 +218,7 @@ public Scope withControlDependencies(Iterable<Op> controls) {
211218
*
212219
* @param builder OperationBuilder to add control inputs and device specification to
213220
*/
221+
@Override
214222
public OperationBuilder apply(OperationBuilder builder) {
215223
builder.setDevice(deviceSpec.toString());
216224
return applyControlDependencies(builder);

0 commit comments

Comments
 (0)