|
1 | 1 | /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. |
2 | 2 |
|
3 | | - Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | - you may not use this file except in compliance with the License. |
5 | | - You may obtain a copy of the License at |
6 | | -
|
7 | | - http://www.apache.org/licenses/LICENSE-2.0 |
8 | | -
|
9 | | - Unless required by applicable law or agreed to in writing, software |
10 | | - distributed under the License is distributed on an "AS IS" BASIS, |
11 | | - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | - See the License for the specific language governing permissions and |
13 | | - limitations under the License. |
14 | | - ======================================================================= |
15 | | - */ |
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +======================================================================= |
| 15 | +*/ |
16 | 16 | package org.tensorflow; |
17 | 17 |
|
18 | | -import static org.tensorflow.internal.c_api.global.tensorflow.TF_LoadSessionFromSavedModel; |
19 | 18 | import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewGraph; |
20 | 19 | import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; |
21 | 20 |
|
|
34 | 33 | import java.util.Map.Entry; |
35 | 34 | import java.util.stream.Collectors; |
36 | 35 | import org.bytedeco.javacpp.BytePointer; |
37 | | -import org.bytedeco.javacpp.PointerPointer; |
38 | 36 | import org.bytedeco.javacpp.PointerScope; |
39 | 37 | import org.tensorflow.exceptions.TensorFlowException; |
40 | 38 | import org.tensorflow.internal.c_api.TF_Buffer; |
@@ -510,21 +508,18 @@ private static SavedModelBundle load( |
510 | 508 | TF_Graph graph = TF_NewGraph(); |
511 | 509 | TF_Buffer metagraphDef = TF_Buffer.newBuffer(); |
512 | 510 | TF_Session session = |
513 | | - TF_LoadSessionFromSavedModel( |
514 | | - opts, |
515 | | - runOpts, |
516 | | - new BytePointer(exportDir), |
517 | | - new PointerPointer(tags), |
518 | | - tags.length, |
519 | | - graph, |
520 | | - metagraphDef, |
521 | | - status); |
| 511 | + TF_Session.loadSessionFromSavedModel( |
| 512 | + opts, runOpts, exportDir, tags, graph, metagraphDef, status); |
522 | 513 | status.throwExceptionIfNotOK(); |
523 | 514 |
|
524 | 515 | // handle the result |
525 | 516 | try { |
526 | 517 | bundle = |
527 | 518 | fromHandle(graph, session, MetaGraphDef.parseFrom(metagraphDef.dataAsByteBuffer())); |
| 519 | + // Only retain the references if the metagraphdef parses correctly, |
| 520 | + // otherwise allow the pointer scope to clean them up |
| 521 | + graph.retainReference(); |
| 522 | + session.retainReference(); |
528 | 523 | } catch (InvalidProtocolBufferException e) { |
529 | 524 | throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e); |
530 | 525 | } |
|
0 commit comments