Skip to content

Commit c513036

Browse files
authored
fixed predict with certain image resultion (#45)
fixed predict with internal image resultion matching output stride plus one
1 parent 70a1a2d commit c513036

3 files changed

Lines changed: 45 additions & 2 deletions

File tree

tests/model_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import logging
2+
from unittest.mock import MagicMock
3+
4+
import numpy as np
5+
6+
from tf_bodypix.model import BodyPixModelWrapper
7+
8+
9+
LOGGER = logging.getLogger(__name__)
10+
11+
12+
ANY_INT_FACTOR_1 = 5
13+
14+
15+
class TestBodyPixModelWrapper:
16+
def test_should_be_able_to_padded_and_resized_image_matching_output_stride_plus_one(self):
17+
predict_fn = MagicMock(name='predict_fn')
18+
output_stride = 16
19+
internal_resolution = 0.5
20+
model = BodyPixModelWrapper(
21+
predict_fn=predict_fn,
22+
output_stride=output_stride,
23+
internal_resolution=internal_resolution
24+
)
25+
resolution_matching_output_stride_plus_1 = int(
26+
(output_stride * ANY_INT_FACTOR_1 + 1) / internal_resolution
27+
)
28+
LOGGER.debug(
29+
'resolution_matching_output_stride_plus_1: %s',
30+
resolution_matching_output_stride_plus_1
31+
)
32+
image = np.ones(
33+
shape=(
34+
resolution_matching_output_stride_plus_1,
35+
resolution_matching_output_stride_plus_1,
36+
3
37+
)
38+
)
39+
model.predict_single(image)

tf_bodypix/bodypix_js_utils/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ def to_valid_input_resolution(
2424
input_resolution: int, output_stride: int
2525
) -> int:
2626
if is_valid_input_resolution(input_resolution, output_stride):
27-
return input_resolution
27+
return int(input_resolution)
2828

29-
return math.floor(input_resolution / output_stride) * output_stride + 1
29+
return int(math.floor(input_resolution / output_stride) * output_stride + 1)
3030

3131

3232
# see toInputResolutionHeightAndWidth

tf_bodypix/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,10 @@ def get_bodypix_input_size(self, original_size: ImageSize) -> ImageSize:
301301
def get_padded_and_resized(
302302
self, image: np.ndarray, model_input_size: ImageSize
303303
) -> Tuple[np.ndarray, Padding]:
304+
LOGGER.debug(
305+
'pad_and_resize_to: image.shape=%s, model_input_size=%s',
306+
image.shape, model_input_size
307+
)
304308
return pad_and_resize_to(
305309
image,
306310
model_input_size.height,

0 commit comments

Comments
 (0)