Skip to content

Commit 6845d9f

Browse files
committed
PartitionedCall to do calls
Signed-off-by: Ryan Nett <JNett96@gmail.com>
1 parent 272af6a commit 6845d9f

File tree

1 file changed

+14
-36
lines changed

1 file changed

+14
-36
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@
3838
import org.tensorflow.internal.c_api.TF_Operation;
3939
import org.tensorflow.internal.c_api.TF_Output;
4040
import org.tensorflow.internal.c_api.TF_Status;
41+
import org.tensorflow.internal.types.registry.TensorTypeRegistry;
4142
import org.tensorflow.op.Ops;
4243
import org.tensorflow.op.Scope;
44+
import org.tensorflow.op.core.PartitionedCall;
4345
import org.tensorflow.op.core.Placeholder;
4446
import org.tensorflow.op.core.PlaceholderWithDefault;
45-
import org.tensorflow.op.core.StatefulPartitionedCall;
46-
import org.tensorflow.op.core.StatelessPartitionedCall;
4747
import org.tensorflow.proto.framework.AttrValue;
4848
import org.tensorflow.proto.framework.DataType;
4949
import org.tensorflow.proto.framework.FunctionDef;
@@ -218,11 +218,8 @@ public String toString() {
218218
* @return the outputs of the function
219219
*/
220220
public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> arguments) {
221-
List<Operand<?>> inputList = new ArrayList<>();
221+
List<Operand<?>> inputList = new ArrayList<>(signature.inputNames().size());
222222

223-
Output<?>[] inputs = new Output<?>[signature().inputNames().size()];
224-
225-
int i = 0;
226223
for (String inputName : signature().inputNames()) {
227224
if (!arguments.containsKey(inputName)) {
228225
throw new IllegalArgumentException(
@@ -240,42 +237,23 @@ public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> argumen
240237
+ inputName
241238
+ "\" was null.");
242239
}
243-
inputs[i] = input.asOutput();
244-
i++;
240+
inputList.add(input);
245241
}
246242

247-
scope.env().attachFunction(this);
248-
String name = getDefinedName();
249-
250-
String displayName = Scope.isValidOpName(name) ? name : "FunctionCall";
251-
252-
OperationBuilder opBuilder =
253-
scope
254-
.env()
255-
.opBuilder(
256-
isStateful() ? StatefulPartitionedCall.OP_NAME : StatelessPartitionedCall.OP_NAME,
257-
scope.makeOpName(displayName));
258-
259-
opBuilder.addInputList(inputs);
260-
261-
opBuilder.setAttr("f", this);
262-
opBuilder.setAttr("Tin", inputDtypes);
263-
opBuilder.setAttr("Tout", outputDtypes);
264-
265-
opBuilder = scope.apply(opBuilder);
266-
Operation op = opBuilder.build();
267-
268-
int numOutputs1 = op.numOutputs();
269-
List<Operand<?>> outputList = new ArrayList<>(signature().outputNames().size());
270-
271-
for (i = 0; i < numOutputs1; i++) {
272-
outputList.add(op.output(i));
273-
}
243+
List<Output<?>> outputList =
244+
PartitionedCall.create(
245+
scope,
246+
inputList,
247+
Arrays.stream(inputDtypes)
248+
.map(x -> TensorTypeRegistry.find(x).type())
249+
.collect(Collectors.toList()),
250+
this)
251+
.output();
274252

275253
Map<String, Operand<?>> namedOutputs = new LinkedHashMap<>(signature().outputNames().size());
276254

277255
List<String> outputNames = new ArrayList<>(signature().outputNames());
278-
for (i = 0; i < outputNames.size(); i++) {
256+
for (int i = 0; i < outputNames.size(); i++) {
279257
String outputName = outputNames.get(i);
280258

281259
if (i > outputList.size()) {

0 commit comments

Comments
 (0)