Skip to content

Commit c1850bc

Browse files
committed
Helper methods for withDevice, a combined with method, and tf(DeviceSpec) since device will often be used at or near the top level.
Signed-off-by: Ryan Nett <rnett@calpoly.edu>
1 parent 2e5f465 commit c1850bc

File tree

1 file changed

+100
-9
lines changed
  • tensorflow-core-kotlin/tensorflow-core-kotlin-api/src/main/kotlin/org/tensorflow/op/kotlin

1 file changed

+100
-9
lines changed
Lines changed: 100 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.tensorflow.op.kotlin
22

3+
import org.tensorflow.DeviceSpec
34
import org.tensorflow.ExecutionEnvironment
45
import org.tensorflow.op.JavaOps
56
import org.tensorflow.op.Op
@@ -11,38 +12,128 @@ import kotlin.contracts.contract
1112
*/
1213
public val JavaOps.kotlin: KotlinOps get() = KotlinOps(this)
1314

15+
/**
16+
* Returns a child [KotlinOps] builder that builds operations with the provided name prefix.
17+
*
18+
* @see org.tensorflow.op.Scope.withSubScope
19+
*/
1420
public fun KotlinOps.withSubScope(childScopeName: String): KotlinOps = KotlinOps(java.withSubScope(childScopeName))
1521

1622
/**
17-
* Returns an API that builds operations with the provided name prefix.
23+
* Runs [block] on a child [KotlinOps] builder that builds operations with the provided name prefix.
1824
*
19-
* @see {@link Scope#withSubScope(String)}
25+
* @see org.tensorflow.op.Scope.withSubScope
2026
*/
2127
public inline fun <R> KotlinOps.withSubScope(childScopeName: String, block: KotlinOps.() -> R): R {
2228
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
2329
return withSubScope(childScopeName).run(block)
2430
}
2531

2632
/**
27-
* Returns an API that uses the provided name for an op.
33+
* Returns a child [KotlinOps] builder that uses the provided name for an op.
2834
*
29-
* @see {@link Scope#withName(String)}
35+
* @see org.tensorflow.op.Scope.withName
3036
*/
3137
public fun KotlinOps.withName(opName: String): KotlinOps = java.withName(opName).kotlin
3238

3339
/**
34-
* Returns an API that adds operations to the graph with the provided control dependencies.
40+
* Returns a child [KotlinOps] builder that adds operations to the graph with the provided control dependencies.
3541
*
36-
* @see {@link Scope#withControlDependencies(Iterable<Op<?>>)}
42+
* @see org.tensorflow.op.Scope.withControlDependencies
3743
*/
3844
public fun KotlinOps.withControlDependencies(controls: Iterable<Op>): KotlinOps =
3945
java.withControlDependencies(controls).kotlin
4046

4147
/**
42-
* Creates an API for building operations in the provided execution environment
48+
* Returns a child [KotlinOps] builder that adds operations to the graph with the provided control dependencies.
49+
*
50+
* @see org.tensorflow.op.Scope.withControlDependencies
51+
*/
52+
public fun KotlinOps.withControlDependencies(vararg controls: Op): KotlinOps =
53+
withControlDependencies(controls.toList())
54+
55+
/**
56+
* Runs [block] on a child [KotlinOps] builder that adds operations to the graph with the provided control dependencies.
57+
*
58+
* @see org.tensorflow.op.Scope.withControlDependencies
59+
*/
60+
public inline fun <R> KotlinOps.withControlDependencies(controls: Iterable<Op>, block: KotlinOps.() -> R): R {
61+
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
62+
return withControlDependencies(controls).run(block)
63+
}
64+
65+
/**
66+
* Runs [block] on a child [KotlinOps] builder that adds operations to the graph with the provided control dependencies.
67+
*
68+
* @see org.tensorflow.op.Scope.withControlDependencies
69+
*/
70+
public inline fun <R> KotlinOps.withControlDependencies(vararg controls: Op, block: KotlinOps.() -> R): R {
71+
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
72+
return withControlDependencies(*controls).run(block)
73+
}
74+
75+
/**
76+
* Returns a child [KotlinOps] builder that uses the provided device for created ops.
77+
*
78+
* @see org.tensorflow.op.Scope.withDevice
79+
*/
80+
public fun KotlinOps.withDevice(device: DeviceSpec): KotlinOps = java.withDevice(device).kotlin
81+
82+
/**
83+
* Runs [block] on a child [KotlinOps] builder that uses the provided device for created ops.
84+
*
85+
* @see org.tensorflow.op.Scope.withDevice
86+
*/
87+
public inline fun <R> KotlinOps.withDevice(device: DeviceSpec, block: KotlinOps.() -> R): R {
88+
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
89+
return withDevice(device).run(block)
90+
}
91+
92+
/**
93+
* Returns a child [KotlinOps] builder, combining [withSubScope], [withControlDependencies], and [withDevice].
94+
* Null arguments are ignored.
95+
*
96+
* @see org.tensorflow.op.Scope.withSubScope
97+
* @see org.tensorflow.op.Scope.withControlDependencies
98+
* @see org.tensorflow.op.Scope.withDevice
99+
*/
100+
public fun KotlinOps.with(
101+
childScopeName: String? = null,
102+
controlDependencies: Iterable<Op>? = null,
103+
device: DeviceSpec? = null
104+
): KotlinOps {
105+
var ops = this
106+
childScopeName?.let { ops = ops.withSubScope(it) }
107+
controlDependencies?.let { ops = ops.withControlDependencies(it) }
108+
device?.let { ops = ops.withDevice(it) }
109+
return ops
110+
}
111+
112+
/**
113+
* Runs [block] on a child [KotlinOps] builder, combining [withSubScope], [withControlDependencies], and [withDevice].
114+
* Null arguments are ignored.
115+
*
116+
* @see org.tensorflow.op.Scope.withSubScope
117+
* @see org.tensorflow.op.Scope.withControlDependencies
118+
* @see org.tensorflow.op.Scope.withDevice
119+
*/
120+
public inline fun <R> KotlinOps.with(
121+
childScopeName: String? = null,
122+
controlDependencies: Iterable<Op>? = null,
123+
device: DeviceSpec? = null,
124+
block: KotlinOps.() -> R
125+
): R {
126+
return with(childScopeName, controlDependencies, device).run(block)
127+
}
128+
129+
/**
130+
* Creates a [KotlinOps] builder for building operations in the provided execution environment.
43131
*/
44132
public val ExecutionEnvironment.tf: KotlinOps get() = JavaOps.create(this).kotlin
45133

46-
// TODO we could have tf that gets itself from ExecutionEnvironment.default(). I think this will be too error prone to be worth doing
134+
/**
135+
* Creates a [KotlinOps] builder for building operations in the provided execution environment with the provided device.
136+
*/
137+
public fun ExecutionEnvironment.tf(device: DeviceSpec): KotlinOps = tf.withDevice(device)
47138

48-
// public fun <T: TType> Ops.placeholder(dtype: DataType<T>, vararg shape: Long): Placeholder<T> = placeholder(dtype, Shape.of(*shape))
139+
// TODO we could have tf that gets itself from ExecutionEnvironment.default(). I think this will be too error prone to be worth doing

0 commit comments

Comments
 (0)