1919import static org .tensorflow .internal .c_api .global .tensorflow .TF_GraphToFunction ;
2020
2121import java .util .ArrayList ;
22- import java .util .Arrays ;
2322import java .util .Collection ;
2423import java .util .Collections ;
2524import java .util .HashSet ;
25+ import java .util .Iterator ;
2626import java .util .LinkedHashMap ;
2727import java .util .List ;
2828import java .util .Map ;
6666 * Map<String, Tensor> outputTensorMap = myFunction.call(inputTensorMap);
6767 * }</pre>
6868 */
69- public class ConcreteFunction implements AutoCloseable , TensorFunction {
69+ public final class ConcreteFunction implements AutoCloseable , TensorFunction {
7070
7171 /**
7272 * Creates a function by building a new graph.
@@ -220,11 +220,11 @@ public String toString() {
220220 public Map <String , Operand <?>> call (Scope scope , Map <String , Operand <?>> arguments ) {
221221 List <Operand <?>> inputList = new ArrayList <>(signature .inputNames ().size ());
222222
223- for (String inputName : signature () .inputNames ()) {
223+ for (String inputName : signature .inputNames ()) {
224224 if (!arguments .containsKey (inputName )) {
225225 throw new IllegalArgumentException (
226226 "Function "
227- + signature () .methodName ()
227+ + signature .methodName ()
228228 + " has parameter \" "
229229 + inputName
230230 + "\" , but no argument was passed for it." );
@@ -241,30 +241,30 @@ public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> argumen
241241 }
242242
243243 List <Output <?>> outputList =
244- PartitionedCall .create (
245- scope ,
246- inputList ,
247- Arrays .stream (outputDtypes )
248- .map (x -> TensorTypeRegistry .find (x ).type ())
249- .collect (Collectors .toList ()),
250- this )
251- .output ();
252-
253- Map <String , Operand <?>> namedOutputs = new LinkedHashMap <>(signature ().outputNames ().size ());
254-
255- List <String > outputNames = new ArrayList <>(signature ().outputNames ());
256- for (int i = 0 ; i < outputNames .size (); i ++) {
257- String outputName = outputNames .get (i );
258-
259- if (i > outputList .size ()) {
260- throw new IllegalStateException (
261- "Somehow, not all required outputs were returned from the function" );
262- }
244+ PartitionedCall .create (scope , inputList , outputTypes , this ).output ();
263245
246+ if (signature .outputNames ().size () == 0 ) {
247+ return Collections .emptyMap ();
248+ }
249+ if (signature .outputNames ().size () == 1 ) {
250+ return Collections .singletonMap (signature .outputNames ().iterator ().next (), outputList .get (0 ));
251+ }
252+ if (outputList .size () < signature .outputNames ().size ()) {
253+ throw new IllegalStateException (
254+ "Somehow, not all required outputs were returned from the function"
255+ + "(expected: "
256+ + signature .outputNames ().size ()
257+ + ", returned: "
258+ + outputList .size ()
259+ + ")" );
260+ }
261+ Map <String , Operand <?>> namedOutputs = new LinkedHashMap <>(signature .outputNames ().size ());
262+ Iterator <String > outputNames = signature .outputNames ().iterator ();
263+ for (int i = 0 ; outputNames .hasNext (); i ++) {
264+ String outputName = outputNames .next ();
264265 Operand <?> output = outputList .get (i );
265266 namedOutputs .put (outputName , output );
266267 }
267-
268268 return Collections .unmodifiableMap (namedOutputs );
269269 }
270270
@@ -291,10 +291,7 @@ public Operand<?> call(Scope scope, Operand<?> argument) {
291291 }
292292 String outputName = signatureDef .getOutputsMap ().keySet ().iterator ().next ();
293293
294- Map <String , Operand <?>> inputMap = new LinkedHashMap <>();
295- inputMap .put (inputName , argument );
296-
297- return call (scope , inputMap ).get (outputName );
294+ return call (scope , Collections .singletonMap (inputName , argument )).get (outputName );
298295 }
299296
300297 @ Override
@@ -395,8 +392,7 @@ static ConcreteFunction fromNativeHandle(
395392 private final NativeFunction nativeFunction ;
396393 private final PointerScope scope ;
397394 private final Set <TF_Function > dependencies ;
398- private final DataType [] inputDtypes ;
399- private final DataType [] outputDtypes ;
395+ private final List <Class <? extends TType >> outputTypes ;
400396
401397 /** All native functions should have deallocators registered */
402398 private ConcreteFunction (
@@ -405,7 +401,7 @@ private ConcreteFunction(
405401 this .nativeFunction = nativeFunction ;
406402 this .dependencies = Collections .unmodifiableSet (dependencies );
407403
408- if (this . signature .getInputs ().size ()
404+ if (signature .getInputs ().size ()
409405 != nativeFunction .getFunctionDef ().getSignature ().getInputArgCount ()) {
410406 throw new IllegalArgumentException (
411407 "Signature must have the same number of inputs as the native function. Expected "
@@ -414,7 +410,7 @@ private ConcreteFunction(
414410 + this .signature .getInputs ().size ());
415411 }
416412
417- if (this . signature .getOutputs ().size ()
413+ if (signature .getOutputs ().size ()
418414 != nativeFunction .getFunctionDef ().getSignature ().getOutputArgCount ()) {
419415 throw new IllegalArgumentException (
420416 "New signature must have the same number of outputs as the native function. Expected "
@@ -423,10 +419,8 @@ private ConcreteFunction(
423419 + this .signature .getOutputs ().size ());
424420 }
425421
426- inputDtypes =
427- this .signature .getInputs ().values ().stream ().map (x -> x .dataType ).toArray (DataType []::new );
428-
429- List <DataType > inputs = Arrays .asList (inputDtypes );
422+ List <DataType > inputs =
423+ signature .getInputs ().values ().stream ().map (x -> x .dataType ).collect (Collectors .toList ());
430424 List <DataType > nativeInputs =
431425 nativeFunction .getFunctionDef ().getSignature ().getInputArgList ().stream ()
432426 .map (ArgDef ::getType )
@@ -440,10 +434,8 @@ private ConcreteFunction(
440434 + inputs );
441435 }
442436
443- outputDtypes =
444- signature ().getOutputs ().values ().stream ().map (x -> x .dataType ).toArray (DataType []::new );
445-
446- List <DataType > outputs = Arrays .asList (outputDtypes );
437+ List <DataType > outputs =
438+ signature .getOutputs ().values ().stream ().map (x -> x .dataType ).collect (Collectors .toList ());
447439 List <DataType > nativeOutputs =
448440 nativeFunction .getFunctionDef ().getSignature ().getOutputArgList ().stream ()
449441 .map (ArgDef ::getType )
@@ -457,6 +449,9 @@ private ConcreteFunction(
457449 + outputs );
458450 }
459451
452+ outputTypes =
453+ outputs .stream ().map (x -> TensorTypeRegistry .find (x ).type ()).collect (Collectors .toList ());
454+
460455 try (PointerScope scope = new PointerScope ()) {
461456 this .scope = scope ;
462457 scope .extend ();
@@ -469,6 +464,8 @@ private ConcreteFunction(
469464 * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because
470465 * how to enable XLA JIT is extremely non-obvious.
471466 *
467+ * <p>See https://github.com/tensorflow/java/issues/347
468+ *
472469 * <p>Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered
473470 * platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
474471 */
0 commit comments