Skip to content

Commit 3eb57bd

Browse files
authored
Merge pull request #73 from Hendrik-code/speed_extract_cc_overhaul
Speed extract cc overhaul
2 parents 1339cdc + 2ed8fc3 commit 3eb57bd

22 files changed

+892
-369
lines changed

TPTBox/core/nii_wrapper.py

Lines changed: 59 additions & 98 deletions
Large diffs are not rendered by default.

TPTBox/core/np_utils.py

Lines changed: 144 additions & 129 deletions
Large diffs are not rendered by default.

TPTBox/tests/speedtests/speedtest.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import random
4+
import timeit
35
from collections.abc import Callable
46
from copy import deepcopy
57
from time import perf_counter
@@ -8,14 +10,35 @@
810
from tqdm import tqdm
911

1012

11-
def speed_test_input(inp, functions: list[Callable], assert_equal_function: Callable | None = None, *args, **kwargs):
13+
def speed_test_input(
14+
inp,
15+
functions: list[Callable],
16+
assert_equal_function: Callable | None = None,
17+
print_output: bool = False,
18+
*args,
19+
**kwargs,
20+
):
1221
time_measures = {}
1322
outs = {}
23+
random.shuffle(functions)
1424
for f in functions:
15-
start = perf_counter()
1625
input_copy = deepcopy(inp)
1726
out = f(*input_copy, *args, **kwargs) if isinstance(input_copy, (tuple, list)) else f(input_copy, *args, **kwargs)
18-
time = perf_counter() - start
27+
if print_output:
28+
print(f.__name__, out)
29+
30+
# def f_test():
31+
# f(*input_copy, *args, **kwargs) if isinstance(input_copy, (tuple, list)) else f(input_copy, *args, **kwargs)
32+
33+
time = timeit.timeit(
34+
"f(*input_copy, *args, **kwargs) if isinstance(input_copy, (tuple, list)) else f(input_copy, *args, **kwargs)",
35+
# "f_test()",
36+
# setup="def f_test(): f(*input_copy, *args, **kwargs) if isinstance(input_copy, (tuple, list)) else f(input_copy, *args, **kwargs)",
37+
number=5,
38+
globals=locals(),
39+
)
40+
41+
# time = perf_counter() - start
1942
outs[f.__name__] = out
2043
time_measures[f.__name__] = time
2144

@@ -38,22 +61,27 @@ def speed_test(
3861
# print first iteration
3962
print()
4063
print("Print first speed test")
41-
start = perf_counter()
42-
first_input = get_input_func()
43-
time = perf_counter() - start
44-
time_sums["input_function"].append(time)
45-
for f in functions:
46-
input_copy = deepcopy(first_input)
47-
out = f(*input_copy, *args, **kwargs) if isinstance(input_copy, (tuple, list)) else f(input_copy, *args, **kwargs)
48-
print(f.__name__, out)
4964

50-
for _ in tqdm(range(repeats)):
65+
for repeat_idx in tqdm(range(repeats)):
66+
start = perf_counter()
5167
inp = get_input_func()
52-
time_measures = speed_test_input(inp, *args, functions=functions, assert_equal_function=assert_equal_function, **kwargs)
68+
time = perf_counter() - start
69+
time_sums["input_function"].append(time)
70+
time_measures = speed_test_input(
71+
inp,
72+
*args,
73+
functions=functions,
74+
assert_equal_function=assert_equal_function,
75+
print_output=repeat_idx == 0,
76+
**kwargs,
77+
)
5378
for k, v in time_measures.items():
5479
if k not in time_sums:
5580
time_sums[k] = []
5681
time_sums[k].append(v)
5782

58-
for k, v in time_sums.items():
59-
print(k, "\t", round(sum(v) / repeats, ndigits=6), "+-", round(np.std(v), ndigits=6))
83+
times_sorted = dict(sorted(time_sums.items(), key=lambda x: sum(x[1]) / repeats))
84+
for idx, (k, v) in enumerate(times_sorted.items()):
85+
print(idx + 1, ".\t", round(sum(v) / repeats, ndigits=6), "+-", round(np.std(v), ndigits=6), "\t", k)
86+
# for k, v in time_sums.items():
87+
# print(k, "\t", round(sum(v) / repeats, ndigits=6), "+-", round(np.std(v), ndigits=6))

TPTBox/tests/speedtests/speedtest_cc3d.py

Lines changed: 8 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from TPTBox.core.np_utils import (
1212
_to_labels,
13-
np_approx_center_of_mass,
1413
np_bbox_binary,
1514
np_bounding_boxes,
1615
np_center_of_mass,
@@ -26,8 +25,8 @@ def get_nii_array():
2625
num_points = random.randint(1, 30)
2726
nii, points, orientation, sizes = get_nii(x=(140, 140, 150), num_point=num_points)
2827
# nii.map_labels_({1: -1}, verbose=False)
29-
arr = nii.get_seg_array().astype(int)
30-
arr[arr == 1] = -1
28+
arr = nii.get_seg_array().astype(np.uint8)
29+
# arr[arr == 1] = -1
3130
arr_r = arr.copy()
3231
return arr_r
3332

@@ -39,64 +38,30 @@ def center_of_mass_one(arr: np.ndarray):
3938
return coms
4039

4140
def center_of_mass_(arr: np.ndarray):
42-
cc_label_set = np.unique(arr)
41+
cc_label_set = np_unique(arr)
4342
coms = {}
4443
for c in cc_label_set:
4544
if c == 0:
4645
continue
47-
c_l = arr.copy()
48-
c_l[c_l != c] = 0
49-
com = center_of_mass(c_l)
46+
com = center_of_mass(arr == c)
5047
coms[c] = com
5148
return coms
5249

5350
def bbox_(arr: np.ndarray):
54-
cc_label_set = np.unique(arr)
51+
cc_label_set = np_unique(arr)
5552
coms = {}
5653
for c in cc_label_set:
5754
if c == 0:
5855
continue
59-
c_l = arr.copy()
60-
c_l[c_l != c] = 0
61-
com = np_bbox_binary(c_l)
56+
com = np_bbox_binary(arr == c)
6257
coms[c] = com
6358
return coms
6459

65-
# def cc3d_volume(arr: np.ndarray):
66-
# volumes = dict(enumerate(cc3dstats(arr)["voxel_counts"]))
67-
# volumes.pop(0)
68-
# return volumes
69-
70-
# def cc3d_countnonzero(arr: np.ndarray):
71-
# return sum(cc3d_volume(arr).values())
72-
73-
# def cc3d_bbox(arr: np.ndarray):
74-
# return dict(enumerate(cc3dstats(arr)["bounding_boxes"]))
75-
76-
# def cc3d_unique(arr: np.ndarray) -> list[int]:
77-
# return [i for i, v in cc3d_volume(arr).items() if v > 0]
78-
79-
# a = np.ones((100, 100, 50), dtype=np.uint16)
80-
# a[0, 0, 0] = 0
81-
# print(type(a[0, 0, 0]))
82-
# print(cc3d_volume(a))
83-
84-
arr = get_nii_array()
85-
# print(np_unique_withoutzero(arr))
86-
# print("npunique", np.unique(arr))
87-
# print()
88-
# print(center_of_mass_(arr))
89-
# print()
90-
# print(np_center_of_mass(arr))
91-
print()
92-
print(np_unique(arr))
93-
print(np.unique(arr))
94-
9560
speed_test(
9661
repeats=50,
9762
get_input_func=get_nii_array,
98-
functions=[np_unique, np.unique],
99-
assert_equal_function=lambda x, y: True, # np.all([x[i] == y[i] for i in range(len(x))]), # noqa: ARG005
63+
functions=[cc3d_com, center_of_mass_],
64+
assert_equal_function=lambda x, y: np.all([x[i][0] == y[i][0] for i in x.keys()]), # noqa: ARG005
10065
# np.all([x[i] == y[i] for i in range(len(x))])
10166
)
10267
# print(time_measures)

TPTBox/tests/speedtests/speedtest_cc3d_crop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ def get_nii_array():
2525
return arr
2626

2727
def normal(arr):
28-
return cc3dstatistics(arr)
28+
return cc3dstatistics(arr, use_crop=False)
2929

3030
def crop(arr):
3131
crop = np_bbox_binary(arr)
3232
arr = arr[crop]
33-
return cc3dstatistics(arr)
33+
return cc3dstatistics(arr, use_crop=False)
3434

3535
speed_test(
3636
repeats=50,

TPTBox/tests/speedtests/speedtest_connected_components.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
def get_nii_array():
2424
num_points = random.randint(50, 51)
25-
nii, points, orientation, sizes = get_nii(x=(300, 300, 300), num_point=num_points)
25+
nii, points, orientation, sizes = get_nii(x=(100, 100, 100), num_point=num_points)
2626
# nii.map_labels_({1: -1}, verbose=False)
2727
arr = nii.get_seg_array().astype(np.uint8)
2828
# arr[arr == 1] = -1
@@ -32,6 +32,9 @@ def get_nii_array():
3232
def np_naive_cc(arr: np.ndarray):
3333
return np_connected_components(arr)[0][1]
3434

35+
def np_naive_cc_extract(arr: np.ndarray):
36+
return np_connected_components(arr, use_extract2=True)[0][1]
37+
3538
def np_naive_cc_gcrop(arr: np.ndarray):
3639
crop = np_bbox_binary(arr)
3740
arrc = arr[crop]
@@ -72,31 +75,33 @@ def np_crop_cc(arr: np.ndarray):
7275
subreg_cc_n[subreg] = n
7376
return subreg_cc[1] # , subreg_cc_n
7477

75-
def np_cc_once(arr: np.ndarray):
76-
# call cc once, then relabel
77-
connectivity = 3
78-
connectivity = min((connectivity + 1) * 2, 8) if arr.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26
79-
80-
labels: list[int] = np_unique(arr)
81-
82-
subreg_cc = {}
83-
subreg_cc_n = {}
84-
crop = np_bbox_binary(arr)
85-
arrc = arr[crop]
86-
zarr = np.zeros((len(labels), *arr.shape), dtype=arr.dtype)
87-
88-
labels_out = connected_components(arrc, connectivity=connectivity, return_N=False)
89-
for sidx, subreg in enumerate(labels): # type:ignore
90-
img_subreg = np_extract_label(arrc, subreg, inplace=False)
91-
lcrop = np_bbox_binary(img_subreg)
92-
img_subregc = img_subreg[lcrop]
93-
img_subreg[lcrop] = labels_out[lcrop] * img_subregc
94-
95-
arrcc = zarr[sidx]
96-
arrcc[crop] = img_subreg
97-
subreg_cc[subreg] = arrcc
98-
subreg_cc_n[subreg] = len(np_unique_withoutzero(img_subreg[lcrop]))
99-
return subreg_cc[1] # , subreg_cc_n
78+
# def np_cc_once(arr: np.ndarray):
79+
# # call cc once, then relabel
80+
# connectivity = 3
81+
# connectivity = min((connectivity + 1) * 2, 8) if arr.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26
82+
#
83+
# labels: list[int] = np_unique(arr)
84+
#
85+
# subreg_cc = {}
86+
# subreg_cc_n = {}
87+
# crop = np_bbox_binary(arr)
88+
# arrc = arr[crop]
89+
# zarr = np.zeros((len(labels), *arr.shape), dtype=arr.dtype)
90+
#
91+
# labels_out = connected_components(arrc, connectivity=connectivity, return_N=False)
92+
# for sidx, subreg in enumerate(labels): # type:ignore
93+
# arrcc[crop][np.logical_and()]
94+
# # arr[s == subreg]
95+
# # img_subreg = np_extract_label(arrc, subreg, inplace=False)
96+
# # lcrop = np_bbox_binary(img_subreg)
97+
# img_subregc = img_subreg[lcrop]
98+
# img_subreg[lcrop] = labels_out[lcrop] * img_subregc
99+
#
100+
# arrcc = zarr[sidx]
101+
# arrcc[crop] = img_subreg
102+
# subreg_cc[subreg] = arrcc
103+
# subreg_cc_n[subreg] = len(np_unique_withoutzero(img_subreg[lcrop]))
104+
# return subreg_cc[1] # , subreg_cc_n
100105

101106
def np_cc_once_lcrop(arr: np.ndarray):
102107
# call cc once, then relabel
@@ -125,9 +130,9 @@ def np_cc_once_lcrop(arr: np.ndarray):
125130
return subreg_cc[1] # , subreg_cc_n
126131

127132
speed_test(
128-
repeats=10,
133+
repeats=50,
129134
get_input_func=get_nii_array,
130-
functions=[np_naive_cc, np_naive_cc_gcrop, np_cc_once],
135+
functions=[np_naive_cc, np_naive_cc_extract],
131136
assert_equal_function=lambda x, y: np.count_nonzero(x) == np.count_nonzero(y),
132137
# np.all([x[i] == y[i] for i in range(x.shape[0])]), # noqa: ARG005
133138
# np.all([x[i] == y[i] for i in range(len(x))])
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
if __name__ == "__main__":
2+
# speed test dilation
3+
import random
4+
5+
import numpy as np
6+
from cc3d import statistics as cc3dstats
7+
from scipy.ndimage import center_of_mass
8+
9+
from TPTBox.core.nii_wrapper import NII
10+
from TPTBox.core.np_utils import (
11+
_connected_components,
12+
_to_labels,
13+
np_bbox_binary,
14+
np_calc_overlapping_labels,
15+
np_connected_components,
16+
np_connected_components_per_label,
17+
np_connected_components_per_label2,
18+
np_extract_label,
19+
np_unique,
20+
np_unique_withoutzero,
21+
)
22+
from TPTBox.tests.speedtests.speedtest import speed_test
23+
from TPTBox.tests.test_utils import get_nii
24+
25+
def get_nii_array():
26+
num_points = random.randint(10, 31)
27+
nii, points, orientation, sizes = get_nii(x=(300, 300, 300), num_point=num_points)
28+
# nii.map_labels_({1: -1}, verbose=False)
29+
arr = nii.get_seg_array().astype(np.uint8)
30+
# arr[arr == 1] = -1
31+
# arr_r = arr.copy()
32+
return arr
33+
34+
def np_cc_labelwise1(arr: np.ndarray):
35+
return np_connected_components_per_label(arr)[0][1]
36+
37+
def np_cc_labelwise2(arr: np.ndarray):
38+
return np_connected_components_per_label2(arr)[0][1]
39+
40+
def np_cc_labelwise2crop(arr: np.ndarray):
41+
return np_connected_components_per_label2(arr, use_crop=True)[0][1]
42+
43+
# def np_cc_once(arr: np.ndarray):
44+
# # call cc once, then relabel
45+
# connectivity = 3
46+
# connectivity = min((connectivity + 1) * 2, 8) if arr.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26
47+
#
48+
# labels: list[int] = np_unique(arr)
49+
#
50+
# subreg_cc = {}
51+
# subreg_cc_n = {}
52+
# crop = np_bbox_binary(arr)
53+
# arrc = arr[crop]
54+
# zarr = np.zeros((len(labels), *arr.shape), dtype=arr.dtype)
55+
#
56+
# labels_out = connected_components(arrc, connectivity=connectivity, return_N=False)
57+
# for sidx, subreg in enumerate(labels): # type:ignore
58+
# arrcc[crop][np.logical_and()]
59+
# # arr[s == subreg]
60+
# # img_subreg = np_extract_label(arrc, subreg, inplace=False)
61+
# # lcrop = np_bbox_binary(img_subreg)
62+
# img_subregc = img_subreg[lcrop]
63+
# img_subreg[lcrop] = labels_out[lcrop] * img_subregc
64+
#
65+
# arrcc = zarr[sidx]
66+
# arrcc[crop] = img_subreg
67+
# subreg_cc[subreg] = arrcc
68+
# subreg_cc_n[subreg] = len(np_unique_withoutzero(img_subreg[lcrop]))
69+
# return subreg_cc[1] # , subreg_cc_n
70+
71+
speed_test(
72+
repeats=50,
73+
get_input_func=get_nii_array,
74+
functions=[
75+
np_cc_labelwise1,
76+
np_cc_labelwise2,
77+
],
78+
assert_equal_function=lambda x, y: True, # np.array_equal(x, y), # noqa: ARG005
79+
# np.all([x[i] == y[i] for i in range(x.shape[0])]), # noqa: ARG005
80+
# np.all([x[i] == y[i] for i in range(len(x))])
81+
)
82+
# print(time_measures)

0 commit comments

Comments
 (0)