Skip to content

Commit 0640a04

Browse files
committed
Update stdlib, more helpers
Signed-off-by: Ryan Nett <rnett@calpoly.edu>
1 parent 3aea42e commit 0640a04

File tree

4 files changed

+119
-3
lines changed

4 files changed

+119
-3
lines changed

tensorflow-core-kotlin/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
<dependencies>
3939
<dependency>
4040
<groupId>org.jetbrains.kotlin</groupId>
41-
<artifactId>kotlin-stdlib</artifactId>
41+
<artifactId>kotlin-stdlib-jdk8</artifactId>
4242
<version>${kotlin.version}</version>
4343
</dependency>
4444
</dependencies>
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,100 @@
11
package org.tensorflow
22

3+
import org.tensorflow.EagerSession.DevicePlacementPolicy
4+
import org.tensorflow.proto.framework.ConfigProto
5+
import kotlin.contracts.InvocationKind
6+
import kotlin.contracts.contract
7+
8+
/**
9+
* Construct a TensorFlow [Graph] and run [block] on it.
10+
*/
11+
public inline fun <R> Graph(block: Graph.() -> R): R {
12+
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
13+
return Graph().use{
14+
it.run(block)
15+
}
16+
}
17+
18+
19+
/**
20+
* Construct a new session with the associated {@link Graph} and configuration options, and run [block] on it.
21+
*
22+
* @param g The {@link Graph} the created Session will operate on.
23+
* @param config Configuration parameters for the session specified as a [ConfigProto](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
24+
* protocol buffer.
25+
* @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto
26+
* protocol buffer.
27+
*/
28+
public inline fun <R> Graph.withSession(config: ConfigProto? = null, block: (Session) -> R): R {
29+
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
30+
return Session(this, config).use(block)
31+
}
32+
33+
34+
/**
35+
* An environment for executing TensorFlow operations eagerly.
36+
*
37+
* Eager execution is an imperative programming environment that evaluates operations
38+
* immediately, without building graphs. Operations return concrete values instead of constructing a
39+
* computational graph to run later, as with {@link Graph}s and {@link Session}s.
40+
*
41+
* This makes it easy to develop with TensorFlow and debug models, as it behaves more like a
42+
* standard programming library.
43+
*
44+
* Instances of a {@code EagerSession} are thread-safe.
45+
*
46+
* @param options The options for this session.
47+
* @see EagerSession.Options
48+
*/
49+
public inline fun <R> EagerSession(
50+
options: EagerSession.Options? = null,
51+
block: EagerSession.() -> R
52+
): R {
53+
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
54+
55+
val ses = options?.build() ?: EagerSession.create()
56+
return ses.use(block)
57+
}
58+
59+
/**
60+
* An environment for executing TensorFlow operations eagerly.
61+
*
62+
* Eager execution is an imperative programming environment that evaluates operations
63+
* immediately, without building graphs. Operations return concrete values instead of constructing a
64+
* computational graph to run later, as with {@link Graph}s and {@link Session}s.
65+
*
66+
* This makes it easy to develop with TensorFlow and debug models, as it behaves more like a
67+
* standard programming library.
68+
*
69+
* Instances of a {@code EagerSession} are thread-safe.
70+
*
71+
* @param config The session configuration to use. See [EagerSession.Options.config] and [ConfigProto].
72+
* @param async Whether to return from op methods before the outputs have been calculated. See [EagerSession.Options.async].
73+
* @param devicePlacementPolicy How to handle tensors on different devices. See [EagerSession.Options.devicePlacementPolicy].
74+
* @see EagerSession.Options
75+
*/
76+
public inline fun <R> EagerSession(
77+
config: ConfigProto? = null,
78+
async: Boolean = false,
79+
devicePlacementPolicy: DevicePlacementPolicy = DevicePlacementPolicy.SILENT,
80+
block: EagerSession.() -> R
81+
): R {
82+
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
83+
84+
val options = EagerSession.options()
85+
.config(config)
86+
.async(async)
87+
.devicePlacementPolicy(devicePlacementPolicy)
88+
89+
return EagerSession(options, block)
90+
}
91+
92+
/**
93+
* Executed [block] in the default eager session, creating it if necessary.
94+
*
95+
* To configure the default session, use [EagerSession.initDefault].
96+
*/
97+
public fun <R> withDefaultEagerSession(block: EagerSession.() -> R): R {
98+
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
99+
return EagerSession.getDefault().use(block)
100+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package org.tensorflow.ndarray
2+
3+
/**
4+
* Convert the [Shape] to a List.
5+
*/
6+
public fun Shape.toList(): List<Long> = asArray().toList()
7+
8+
/**
9+
* Get the size at [index].
10+
*/
11+
public operator fun Shape.get(index: Int): Long = this.size(index)

tensorflow-core-kotlin/tensorflow-core-kotlin-api/src/main/kotlin/org/tensorflow/op/kotlin/OpsHelpers.kt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
package org.tensorflow.op.kotlin
22

3+
import org.tensorflow.DataType
34
import org.tensorflow.ExecutionEnvironment
5+
import org.tensorflow.ndarray.Shape
46
import org.tensorflow.op.JavaOps
57
import org.tensorflow.op.Op
8+
import org.tensorflow.op.Ops
9+
import org.tensorflow.op.core.Placeholder
10+
import org.tensorflow.types.family.TType
611
import kotlin.contracts.ExperimentalContracts
712
import kotlin.contracts.InvocationKind
813
import kotlin.contracts.contract
@@ -19,7 +24,7 @@ public fun KotlinOps.withSubScope(childScopeName: String): KotlinOps = KotlinOps
1924
*
2025
* @see {@link Scope#withSubScope(String)}
2126
*/
22-
public fun <R> KotlinOps.withSubScope(childScopeName: String, block: KotlinOps.() -> R): R {
27+
public inline fun <R> KotlinOps.withSubScope(childScopeName: String, block: KotlinOps.() -> R): R {
2328
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
2429
return withSubScope(childScopeName).run(block)
2530
}
@@ -43,4 +48,6 @@ public fun KotlinOps.withControlDependencies(controls: Iterable<Op>): KotlinOps
4348
*/
4449
public val ExecutionEnvironment.tf: KotlinOps get() = JavaOps.create(this).kotlin
4550

46-
//TODO we could have tf that gets itself from ExecutionEnvironment.default(). I think this will be too error prone to be worth doing
51+
//TODO we could have tf that gets itself from ExecutionEnvironment.default(). I think this will be too error prone to be worth doing
52+
53+
//public fun <T: TType> Ops.placeholder(dtype: DataType<T>, vararg shape: Long): Placeholder<T> = placeholder(dtype, Shape.of(*shape))

0 commit comments

Comments
 (0)