Skip to content

Commit 2fbefbb

Browse files
authored
added tflite support (experimental) (#40)
* tflite support (wip) * linting * just check for "resnet" in model path * fallback to single tensor input * adjust image_data shape to expected input shape also optional image_size input * removed commented out code * only resize tensor if input image doesn't match * added command to convert to tflite * added --model-architecture argument * added quantization arguments * updated readme * linting * typo in readme * improved wording in readme
1 parent ebc5a4a commit 2fbefbb

File tree

6 files changed

+196
-4
lines changed

6 files changed

+196
-4
lines changed

.pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ disable=
1010
too-few-public-methods,
1111
too-many-arguments,
1212
too-many-instance-attributes,
13+
duplicate-code,
1314
invalid-name

README.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,31 @@ python -m tf_bodypix \
176176
--threshold=0.75
177177
```
178178

179+
## TensorFlow Lite support (experimental)
180+
181+
The model path may also point to a TensorFlow Lite model (`.tflite` extension). Whether that actually improves performance may depend on the platform and available hardware.
182+
183+
You could convert one of the available TensorFlow JS models to TensorFlow Lite using the following command:
184+
185+
```bash
186+
python -m tf_bodypix \
187+
convert-to-tflite \
188+
--model-path \
189+
"https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/075/model-stride16.json" \
190+
--optimize \
191+
--quantization-type=float16 \
192+
--output-model-file "./mobilenet-float16-stride16.tflite"
193+
```
194+
195+
The above command is provided for convenience.
196+
You may use alternative methods depending on your preference and requirements.
197+
198+
Relevant links:
199+
200+
* [TensorFlow Lite converter](https://www.tensorflow.org/lite/convert/)
201+
* [TF Lite post_training_quantization](https://www.tensorflow.org/lite/performance/post_training_quantization)
202+
* [TF GitHub #40183](https://github.com/tensorflow/tensorflow/issues/40183).
203+
179204
## Acknowledgements
180205

181206
* [Original TensorFlow JS Implementation of BodyPix](https://github.com/tensorflow/tfjs-models/tree/body-pix-v2.0.4/body-pix)

tests/cli_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pathlib import Path
33

44
from tf_bodypix.download import BodyPixModelPaths
5+
from tf_bodypix.model import ModelArchitectureNames
56
from tf_bodypix.cli import main
67

78

@@ -87,3 +88,22 @@ def test_should_list_all_default_model_urls(self, capsys):
8788
LOGGER.debug('output_urls: %s', output_urls)
8889
missing_urls = set(expected_urls) - set(output_urls)
8990
assert not missing_urls
91+
92+
def test_should_be_able_to_convert_to_tflite_and_use_model(self, temp_dir: Path):
93+
output_model_file = temp_dir / 'model.tflite'
94+
main([
95+
'convert-to-tflite',
96+
'--model-path=%s' % BodyPixModelPaths.MOBILENET_FLOAT_75_STRIDE_16,
97+
'--optimize',
98+
'--quantization-type=int8',
99+
'--output-model-file=%s' % output_model_file
100+
])
101+
output_image_path = temp_dir / 'mask.jpg'
102+
main([
103+
'draw-mask',
104+
'--model-path=%s' % output_model_file,
105+
'--model-architecture=%s' % ModelArchitectureNames.MOBILENET_V1,
106+
'--output-stride=16',
107+
'--source=%s' % EXAMPLE_IMAGE_URL,
108+
'--output=%s' % output_image_path
109+
])

tf_bodypix/cli.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from abc import ABC, abstractmethod
66
from contextlib import ExitStack
77
from itertools import cycle
8+
from pathlib import Path
89
from time import time
910
from typing import Dict, List
1011

@@ -25,8 +26,10 @@
2526
)
2627
from tf_bodypix.utils.s3 import iter_s3_file_urls
2728
from tf_bodypix.download import download_model
29+
from tf_bodypix.tflite import get_tflite_converter_for_model_path
2830
from tf_bodypix.model import (
2931
load_model,
32+
VALID_MODEL_ARCHITECTURE_NAMES,
3033
PART_CHANNELS,
3134
DEFAULT_RESIZE_METHOD,
3235
BodyPixModelWrapper,
@@ -77,6 +80,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
7780
default=DEFAULT_MODEL_PATH,
7881
help="The path or URL to the bodypix model."
7982
)
83+
parser.add_argument(
84+
"--model-architecture",
85+
choices=VALID_MODEL_ARCHITECTURE_NAMES,
86+
help=(
87+
"The model architecture."
88+
" It will be guessed from the model path if not specified."
89+
)
90+
)
8091
parser.add_argument(
8192
"--output-stride",
8293
type=int,
@@ -219,7 +230,8 @@ def load_bodypix_model(args: argparse.Namespace) -> BodyPixModelWrapper:
219230
return load_model(
220231
local_model_path,
221232
internal_resolution=args.internal_resolution,
222-
output_stride=args.output_stride
233+
output_stride=args.output_stride,
234+
architecture_name=args.model_architecture
223235
)
224236

225237

@@ -266,6 +278,52 @@ def run(self, args: argparse.Namespace): # pylint: disable=unused-argument
266278
print('\n'.join(bodypix_model_json_files))
267279

268280

281+
class ConvertToTFLiteSubCommand(SubCommand):
282+
def __init__(self):
283+
super().__init__("convert-to-tflite", "Converts the model to a tflite model")
284+
285+
def add_arguments(self, parser: argparse.ArgumentParser):
286+
add_common_arguments(parser)
287+
parser.add_argument(
288+
"--model-path",
289+
default=DEFAULT_MODEL_PATH,
290+
help="The path or URL to the bodypix model."
291+
)
292+
parser.add_argument(
293+
"--output-model-file",
294+
required=True,
295+
help="The path to the output file (tflite model)."
296+
)
297+
parser.add_argument(
298+
"--optimize",
299+
action='store_true',
300+
help="Enable optimization (quantization)."
301+
)
302+
parser.add_argument(
303+
"--quantization-type",
304+
choices=['float16', 'float32', 'int8'],
305+
help="The quantization type to use."
306+
)
307+
308+
def run(self, args: argparse.Namespace): # pylint: disable=unused-argument
309+
LOGGER.info('converting model: %s', args.model_path)
310+
converter = get_tflite_converter_for_model_path(download_model(
311+
args.model_path
312+
))
313+
tflite_model = converter.convert()
314+
if args.optimize:
315+
LOGGER.info('enabled optimization')
316+
converter.optimizations = [tf.lite.Optimize.DEFAULT]
317+
if args.quantization_type:
318+
LOGGER.info('quanization type: %s', args.quantization_type)
319+
quantization_type = getattr(tf, args.quantization_type)
320+
converter.target_spec.supported_types = [quantization_type]
321+
converter.inference_input_type = quantization_type
322+
converter.inference_output_type = quantization_type
323+
LOGGER.info('saving tflite model to: %s', args.output_model_file)
324+
Path(args.output_model_file).write_bytes(tflite_model)
325+
326+
269327
class AbstractWebcamFilterApp(ABC):
270328
def __init__(self, args: argparse.Namespace):
271329
self.args = args
@@ -497,6 +555,7 @@ def get_app(self, args: argparse.Namespace) -> AbstractWebcamFilterApp:
497555

498556
SUB_COMMANDS: List[SubCommand] = [
499557
ListModelsSubCommand(),
558+
ConvertToTFLiteSubCommand(),
500559
DrawMaskSubCommand(),
501560
BlurBackgroundSubCommand(),
502561
ReplaceBackgroundSubCommand()

tf_bodypix/model.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,66 @@ def get_structured_output_names(structured_outputs: List[tf.Tensor]) -> List[str
341341
]
342342

343343

344+
def to_number_of_dimensions(data: np.ndarray, dimension_count: int) -> np.ndarray:
345+
while len(data.shape) > dimension_count:
346+
data = data[0]
347+
while len(data.shape) < dimension_count:
348+
data = np.expand_dims(data, axis=0)
349+
return data
350+
351+
352+
def load_tflite_model(model_path: str):
353+
# Load TFLite model and allocate tensors.
354+
interpreter = tf.lite.Interpreter(model_path=model_path)
355+
interpreter.allocate_tensors()
356+
357+
input_details = interpreter.get_input_details()
358+
LOGGER.debug('input_details: %s', input_details)
359+
input_names = [item['name'] for item in input_details]
360+
LOGGER.debug('input_names: %s', input_names)
361+
input_details_map = dict(zip(input_names, input_details))
362+
363+
output_details = interpreter.get_output_details()
364+
LOGGER.debug('output_details: %s', output_details)
365+
output_names = [item['name'] for item in output_details]
366+
LOGGER.debug('output_names: %s', output_names)
367+
368+
try:
369+
image_input = input_details_map['image']
370+
except KeyError:
371+
assert len(input_details_map) == 1
372+
image_input = list(input_details_map.values())[0]
373+
input_shape = image_input['shape']
374+
LOGGER.debug('input_shape: %s', input_shape)
375+
376+
def predict(image_data: np.ndarray):
377+
nonlocal input_shape
378+
image_data = to_number_of_dimensions(image_data, len(input_shape))
379+
LOGGER.debug('tflite predict, image_data.shape=%s (%s)', image_data.shape, image_data.dtype)
380+
height, width, *_ = image_data.shape
381+
if tuple(image_data.shape) != tuple(input_shape):
382+
LOGGER.info('resizing input tensor: %s -> %s', tuple(input_shape), image_data.shape)
383+
interpreter.resize_tensor_input(image_input['index'], list(image_data.shape))
384+
interpreter.allocate_tensors()
385+
input_shape = image_data.shape
386+
interpreter.set_tensor(image_input['index'], image_data)
387+
if 'image_size' in input_details_map:
388+
interpreter.set_tensor(
389+
input_details_map['image_size']['index'],
390+
np.array([height, width], dtype=np.float)
391+
)
392+
393+
interpreter.invoke()
394+
395+
# The function `get_tensor()` returns a copy of the tensor data.
396+
# Use `tensor()` in order to get a pointer to the tensor.
397+
return {
398+
item['name']: interpreter.get_tensor(item['index'])
399+
for item in output_details
400+
}
401+
return predict
402+
403+
344404
def load_using_saved_model_and_get_predict_function(model_path):
345405
loaded = tf.saved_model.load(model_path)
346406
LOGGER.debug('loaded: %s', loaded)
@@ -366,24 +426,26 @@ def load_using_tfjs_graph_converter_and_get_predict_function(
366426
def load_model_and_get_predict_function(
367427
model_path: str
368428
) -> Callable[[np.ndarray], dict]:
429+
if model_path.endswith('.tflite'):
430+
return load_tflite_model(model_path)
369431
try:
370432
return load_using_saved_model_and_get_predict_function(model_path)
371433
except OSError:
372434
return load_using_tfjs_graph_converter_and_get_predict_function(model_path)
373435

374436

375437
def get_output_stride_from_model_path(model_path: str) -> int:
376-
match = re.search(r'stride(\d+)', model_path)
438+
match = re.search(r'stride(\d+)|_(\d+)_quant', model_path)
377439
if not match:
378440
raise ValueError('cannot extract output stride from model path: %r' % model_path)
379-
return int(match.group(1))
441+
return int(match.group(1) or match.group(2))
380442

381443

382444
def get_architecture_from_model_path(model_path: str) -> int:
383445
model_path_lower = model_path.lower()
384446
if 'mobilenet' in model_path_lower:
385447
return ModelArchitectureNames.MOBILENET_V1
386-
if 'resnet50' in model_path_lower:
448+
if 'resnet' in model_path_lower:
387449
return ModelArchitectureNames.RESNET_50
388450
raise ValueError('cannot extract model architecture from model path: %r' % model_path)
389451

tf_bodypix/tflite.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import logging
2+
3+
import tensorflow as tf
4+
5+
try:
6+
import tfjs_graph_converter
7+
except ImportError:
8+
tfjs_graph_converter = None
9+
10+
11+
LOGGER = logging.getLogger(__name__)
12+
13+
14+
def get_tflite_converter_for_tfjs_model_path(model_path: str) -> tf.lite.TFLiteConverter:
15+
if tfjs_graph_converter is None:
16+
raise ImportError('tfjs_graph_converter required')
17+
graph = tfjs_graph_converter.api.load_graph_model(model_path)
18+
tf_fn = tfjs_graph_converter.api.graph_to_function_v2(graph)
19+
return tf.lite.TFLiteConverter.from_concrete_functions([tf_fn])
20+
21+
22+
def get_tflite_converter_for_model_path(model_path: str) -> tf.lite.TFLiteConverter:
23+
LOGGER.debug('converting model_path: %s', model_path)
24+
# if model_path.endswith('.json'):
25+
return get_tflite_converter_for_tfjs_model_path(model_path)

0 commit comments

Comments
 (0)