Skip to content

Commit 272af6a

Browse files
committed
PartitionedCall test
Signed-off-by: Ryan Nett <JNett96@gmail.com>
1 parent add3c80 commit 272af6a

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================
15+
*/
16+
package org.tensorflow.op.core;
17+
18+
import static org.junit.jupiter.api.Assertions.assertEquals;
19+
20+
import java.util.Arrays;
21+
import org.junit.jupiter.api.Test;
22+
import org.tensorflow.ConcreteFunction;
23+
import org.tensorflow.EagerSession;
24+
import org.tensorflow.Graph;
25+
import org.tensorflow.Operand;
26+
import org.tensorflow.Session;
27+
import org.tensorflow.Signature;
28+
import org.tensorflow.op.Ops;
29+
import org.tensorflow.types.TInt32;
30+
31+
public class PartitionedCallTest {
32+
33+
public static Signature plusTwo(Ops tf) {
34+
Operand<TInt32> x = tf.placeholder(TInt32.class);
35+
Operand<TInt32> y = tf.math.add(x, tf.constant(2));
36+
return Signature.builder().input("x", x).output("y", y).build();
37+
}
38+
39+
@Test
40+
public void testEager() {
41+
try (EagerSession e = EagerSession.create();
42+
ConcreteFunction f = ConcreteFunction.create(PartitionedCallTest::plusTwo)) {
43+
Ops tf = Ops.create(e);
44+
Operand<TInt32> x = tf.constant(3);
45+
Operand<TInt32> y =
46+
(Operand<TInt32>)
47+
tf.partitionedCall(Arrays.asList(x), Arrays.asList(TInt32.class), f).output().get(0);
48+
assertEquals(5, y.asTensor().getInt());
49+
}
50+
}
51+
52+
@Test
53+
public void testGraph() {
54+
try (Graph g = new Graph();
55+
ConcreteFunction f = ConcreteFunction.create(PartitionedCallTest::plusTwo)) {
56+
Ops tf = Ops.create(g);
57+
Operand<TInt32> x = tf.placeholder(TInt32.class);
58+
Operand<TInt32> y =
59+
(Operand<TInt32>)
60+
tf.partitionedCall(Arrays.asList(x), Arrays.asList(TInt32.class), f).output().get(0);
61+
62+
try (Session s = new Session(g);
63+
TInt32 out = (TInt32) s.runner().feed(x, TInt32.scalarOf(3)).fetch(y).run().get(0)) {
64+
assertEquals(5, out.getInt());
65+
}
66+
}
67+
}
68+
}

0 commit comments

Comments
 (0)