@@ -38,12 +38,12 @@ public ExecutionEnvironment env() {
3838
3939 @ Override
4040 public GradientScope withSubScope (String childScopeName ) {
41- return new GradientScope (nativeScope .NewSubScope (childScopeName ), graph , childScopeName );
41+ return new GradientScope (nativeScope .NewSubScope (childScopeName ), graph , childScopeName , device );
4242 }
4343
4444 @ Override
4545 public GradientScope withName (String opName ) {
46- return new GradientScope (nativeScope , graph , opName );
46+ return new GradientScope (nativeScope , graph , opName , device );
4747 }
4848
4949 @ Override
@@ -53,7 +53,7 @@ public GradientScope withNameAsSubScope(String defaultName) {
5353
5454 @ Override
5555 public GradientScope withDevice (DeviceSpec deviceSpec ) {
56- return new GradientScope (nativeScope .WithDevice (deviceSpec .toString ()), graph );
56+ return new GradientScope (nativeScope .WithDevice (deviceSpec .toString ()), graph , deviceSpec . toString () );
5757 }
5858
5959 @ Override
@@ -90,7 +90,7 @@ public GradientScope withControlDependencies(Iterable<Op> controls) {
9090 .put (new NativeOperation (((GraphOperation ) op ).getUnsafeNativeHandle ().node ()));
9191 }
9292
93- return new GradientScope (nativeScope .WithControlDependencies (new NativeOperation (ops )), graph );
93+ return new GradientScope (nativeScope .WithControlDependencies (new NativeOperation (ops )), graph , device );
9494 }
9595
9696 @ Override
@@ -108,7 +108,7 @@ public Scope withControlDependencyOps(Iterable<Operation> controls) {
108108 .put (new NativeOperation (((GraphOperation ) op ).getUnsafeNativeHandle ().node ()));
109109 }
110110
111- return new GradientScope (nativeScope .WithControlDependencies (new NativeOperation (ops )), graph );
111+ return new GradientScope (nativeScope .WithControlDependencies (new NativeOperation (ops )), graph , device );
112112 }
113113
114114 @ Override
@@ -121,25 +121,31 @@ public void onOpCreated(Operation op) {}
121121
122122 @ Override
123123 public String getDeviceString () {
124- throw new IllegalStateException ("Can't get device string for gradient scope" );
124+ if (device == null ) {
125+ throw new UnsupportedOperationException ("Can't get device string for gradient scope unless it has been explicitly set" );
126+ } else {
127+ return device ;
128+ }
125129 }
126130
127131 @ Override
128132 public boolean isInit () {
129133 return false ;
130134 }
131135
132- GradientScope (TF_Scope nativeScope , Graph graph ) {
133- this (nativeScope , graph , null );
136+ GradientScope (TF_Scope nativeScope , Graph graph , String device ) {
137+ this (nativeScope , graph , null , device );
134138 }
135139
136- private GradientScope (TF_Scope nativeScope , Graph graph , String opName ) {
140+ private GradientScope (TF_Scope nativeScope , Graph graph , String opName , String device ) {
137141 this .graph = graph ;
138142 this .nativeScope = nativeScope ;
139143 this .opName = opName ;
144+ this .device = device ;
140145 }
141146
142147 private final Graph graph ;
143148 private final TF_Scope nativeScope ;
144149 private final String opName ;
150+ private final String device ;
145151}
0 commit comments