3838import org .tensorflow .internal .c_api .TF_Operation ;
3939import org .tensorflow .internal .c_api .TF_Output ;
4040import org .tensorflow .internal .c_api .TF_Status ;
41+ import org .tensorflow .internal .types .registry .TensorTypeRegistry ;
4142import org .tensorflow .op .Ops ;
4243import org .tensorflow .op .Scope ;
44+ import org .tensorflow .op .core .PartitionedCall ;
4345import org .tensorflow .op .core .Placeholder ;
4446import org .tensorflow .op .core .PlaceholderWithDefault ;
45- import org .tensorflow .op .core .StatefulPartitionedCall ;
46- import org .tensorflow .op .core .StatelessPartitionedCall ;
4747import org .tensorflow .proto .framework .AttrValue ;
4848import org .tensorflow .proto .framework .DataType ;
4949import org .tensorflow .proto .framework .FunctionDef ;
@@ -218,11 +218,8 @@ public String toString() {
218218 * @return the outputs of the function
219219 */
220220 public Map <String , Operand <?>> call (Scope scope , Map <String , Operand <?>> arguments ) {
221- List <Operand <?>> inputList = new ArrayList <>();
221+ List <Operand <?>> inputList = new ArrayList <>(signature . inputNames (). size () );
222222
223- Output <?>[] inputs = new Output <?>[signature ().inputNames ().size ()];
224-
225- int i = 0 ;
226223 for (String inputName : signature ().inputNames ()) {
227224 if (!arguments .containsKey (inputName )) {
228225 throw new IllegalArgumentException (
@@ -240,42 +237,23 @@ public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> argumen
240237 + inputName
241238 + "\" was null." );
242239 }
243- inputs [i ] = input .asOutput ();
244- i ++;
240+ inputList .add (input );
245241 }
246242
247- scope .env ().attachFunction (this );
248- String name = getDefinedName ();
249-
250- String displayName = Scope .isValidOpName (name ) ? name : "FunctionCall" ;
251-
252- OperationBuilder opBuilder =
253- scope
254- .env ()
255- .opBuilder (
256- isStateful () ? StatefulPartitionedCall .OP_NAME : StatelessPartitionedCall .OP_NAME ,
257- scope .makeOpName (displayName ));
258-
259- opBuilder .addInputList (inputs );
260-
261- opBuilder .setAttr ("f" , this );
262- opBuilder .setAttr ("Tin" , inputDtypes );
263- opBuilder .setAttr ("Tout" , outputDtypes );
264-
265- opBuilder = scope .apply (opBuilder );
266- Operation op = opBuilder .build ();
267-
268- int numOutputs1 = op .numOutputs ();
269- List <Operand <?>> outputList = new ArrayList <>(signature ().outputNames ().size ());
270-
271- for (i = 0 ; i < numOutputs1 ; i ++) {
272- outputList .add (op .output (i ));
273- }
243+ List <Output <?>> outputList =
244+ PartitionedCall .create (
245+ scope ,
246+ inputList ,
247+ Arrays .stream (inputDtypes )
248+ .map (x -> TensorTypeRegistry .find (x ).type ())
249+ .collect (Collectors .toList ()),
250+ this )
251+ .output ();
274252
275253 Map <String , Operand <?>> namedOutputs = new LinkedHashMap <>(signature ().outputNames ().size ());
276254
277255 List <String > outputNames = new ArrayList <>(signature ().outputNames ());
278- for (i = 0 ; i < outputNames .size (); i ++) {
256+ for (int i = 0 ; i < outputNames .size (); i ++) {
279257 String outputName = outputNames .get (i );
280258
281259 if (i > outputList .size ()) {
0 commit comments