Skip to content

Commit 2ddbb6c

Browse files
committed
Store and allow getting native scope device when it has been set from Java
Signed-off-by: Ryan Nett <JNett96@gmail.com>
1 parent 36a6e30 commit 2ddbb6c

File tree

3 files changed

+17
-11
lines changed

3 files changed

+17
-11
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/GradientScope.java

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ public ExecutionEnvironment env() {
3838

3939
@Override
4040
public GradientScope withSubScope(String childScopeName) {
41-
return new GradientScope(nativeScope.NewSubScope(childScopeName), graph, childScopeName);
41+
return new GradientScope(nativeScope.NewSubScope(childScopeName), graph, childScopeName, device);
4242
}
4343

4444
@Override
4545
public GradientScope withName(String opName) {
46-
return new GradientScope(nativeScope, graph, opName);
46+
return new GradientScope(nativeScope, graph, opName, device);
4747
}
4848

4949
@Override
@@ -53,7 +53,7 @@ public GradientScope withNameAsSubScope(String defaultName) {
5353

5454
@Override
5555
public GradientScope withDevice(DeviceSpec deviceSpec) {
56-
return new GradientScope(nativeScope.WithDevice(deviceSpec.toString()), graph);
56+
return new GradientScope(nativeScope.WithDevice(deviceSpec.toString()), graph, deviceSpec.toString());
5757
}
5858

5959
@Override
@@ -90,7 +90,7 @@ public GradientScope withControlDependencies(Iterable<Op> controls) {
9090
.put(new NativeOperation(((GraphOperation) op).getUnsafeNativeHandle().node()));
9191
}
9292

93-
return new GradientScope(nativeScope.WithControlDependencies(new NativeOperation(ops)), graph);
93+
return new GradientScope(nativeScope.WithControlDependencies(new NativeOperation(ops)), graph, device);
9494
}
9595

9696
@Override
@@ -108,7 +108,7 @@ public Scope withControlDependencyOps(Iterable<Operation> controls) {
108108
.put(new NativeOperation(((GraphOperation) op).getUnsafeNativeHandle().node()));
109109
}
110110

111-
return new GradientScope(nativeScope.WithControlDependencies(new NativeOperation(ops)), graph);
111+
return new GradientScope(nativeScope.WithControlDependencies(new NativeOperation(ops)), graph, device);
112112
}
113113

114114
@Override
@@ -121,25 +121,31 @@ public void onOpCreated(Operation op) {}
121121

122122
@Override
123123
public String getDeviceString() {
124-
throw new IllegalStateException("Can't get device string for gradient scope");
124+
if (device == null) {
125+
throw new UnsupportedOperationException("Can't get device string for gradient scope unless it has been explicitly set");
126+
} else {
127+
return device;
128+
}
125129
}
126130

127131
@Override
128132
public boolean isInit() {
129133
return false;
130134
}
131135

132-
GradientScope(TF_Scope nativeScope, Graph graph) {
133-
this(nativeScope, graph, null);
136+
GradientScope(TF_Scope nativeScope, Graph graph, String device) {
137+
this(nativeScope, graph, null, device);
134138
}
135139

136-
private GradientScope(TF_Scope nativeScope, Graph graph, String opName) {
140+
private GradientScope(TF_Scope nativeScope, Graph graph, String opName, String device) {
137141
this.graph = graph;
138142
this.nativeScope = nativeScope;
139143
this.opName = opName;
144+
this.device = device;
140145
}
141146

142147
private final Graph graph;
143148
private final TF_Scope nativeScope;
144149
private final String opName;
150+
private final String device;
145151
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawGradientAdapter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public NativeStatus call(
5454
throw new IllegalStateException("No graph found for native gradient scope.");
5555
}
5656

57-
Scope nativeScope = new GradientScope(scope, g);
57+
Scope nativeScope = new GradientScope(scope, g, null);
5858
Ops tf = new Ops(nativeScope);
5959

6060
List<Output<?>> gradInputs = BaseGradientAdapter.fromNativeOutputs(g, grad_inputs);

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/TypedGradientAdapter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public NativeStatus call(
6060
throw new IllegalStateException("No graph found for native gradient scope.");
6161
}
6262

63-
Scope nativeScope = new GradientScope(scope, g);
63+
Scope nativeScope = new GradientScope(scope, g, null);
6464

6565
Ops tf = new Ops(nativeScope);
6666

0 commit comments

Comments
 (0)