Skip to content

Commit 8ac6a8c

Browse files
Fixed #434 Refactor test files (#440)
* refactor test files to move naive implementations to the top of the files * reformat with black code formatter
1 parent 6218b76 commit 8ac6a8c

File tree

4 files changed

+83
-81
lines changed

4 files changed

+83
-81
lines changed

tests/test_core.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,42 +18,6 @@ def naive_rolling_window_dot_product(Q, T):
1818
return result
1919

2020

21-
def test_check_dtype_float32():
22-
assert core.check_dtype(np.random.rand(10).astype(np.float32))
23-
24-
25-
def test_check_dtype_float64():
26-
assert core.check_dtype(np.random.rand(10))
27-
28-
29-
def test_get_max_window_size():
30-
for n in range(3, 10):
31-
ref_max_m = (
32-
int(
33-
n
34-
- math.floor(
35-
(n + (config.STUMPY_EXCL_ZONE_DENOM - 1))
36-
// (config.STUMPY_EXCL_ZONE_DENOM + 1)
37-
)
38-
)
39-
- 1
40-
)
41-
cmp_max_m = core.get_max_window_size(n)
42-
assert ref_max_m == cmp_max_m
43-
44-
45-
def test_check_window_size():
46-
for m in range(-1, 3):
47-
with pytest.raises(ValueError):
48-
core.check_window_size(m)
49-
50-
51-
def test_check_max_window_size():
52-
for m in range(4, 7):
53-
with pytest.raises(ValueError):
54-
core.check_window_size(m, max_size=3)
55-
56-
5721
def naive_compute_mean_std(T, m):
5822
n = T.shape[0]
5923

@@ -103,6 +67,42 @@ def naive_compute_mean_std_multidimensional(T, m):
10367
]
10468

10569

70+
def test_check_dtype_float32():
71+
assert core.check_dtype(np.random.rand(10).astype(np.float32))
72+
73+
74+
def test_check_dtype_float64():
75+
assert core.check_dtype(np.random.rand(10))
76+
77+
78+
def test_get_max_window_size():
79+
for n in range(3, 10):
80+
ref_max_m = (
81+
int(
82+
n
83+
- math.floor(
84+
(n + (config.STUMPY_EXCL_ZONE_DENOM - 1))
85+
// (config.STUMPY_EXCL_ZONE_DENOM + 1)
86+
)
87+
)
88+
- 1
89+
)
90+
cmp_max_m = core.get_max_window_size(n)
91+
assert ref_max_m == cmp_max_m
92+
93+
94+
def test_check_window_size():
95+
for m in range(-1, 3):
96+
with pytest.raises(ValueError):
97+
core.check_window_size(m)
98+
99+
100+
def test_check_max_window_size():
101+
for m in range(4, 7):
102+
with pytest.raises(ValueError):
103+
core.check_window_size(m, max_size=3)
104+
105+
106106
@pytest.mark.parametrize("Q, T", test_data)
107107
def test_sliding_dot_product(Q, T):
108108
ref_mp = naive_rolling_window_dot_product(Q, T)

tests/test_motifs.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,6 @@
66

77
import naive
88

9-
test_data = [
10-
(
11-
np.array([0.0, 1.0, 0.0]),
12-
np.array([0.0, 1.0, 0.0, -1.0, -1.0, 0.0, 1.0, 0.0, -0.5]),
13-
),
14-
(
15-
np.array([0.0, 1.0, 2.0]),
16-
np.array([0.1, 1.0, 2.0, 3.0, -1.0, 0.1, 1.0, 2.0, -0.5]),
17-
),
18-
(np.random.uniform(-1000, 1000, [8]), np.random.uniform(-1000, 1000, [64])),
19-
]
20-
219

2210
def naive_match(Q, T, excl_zone, max_distance):
2311
m = Q.shape[0]
@@ -41,6 +29,19 @@ def naive_match(Q, T, excl_zone, max_distance):
4129
return np.array(result, dtype=object)
4230

4331

32+
test_data = [
33+
(
34+
np.array([0.0, 1.0, 0.0]),
35+
np.array([0.0, 1.0, 0.0, -1.0, -1.0, 0.0, 1.0, 0.0, -0.5]),
36+
),
37+
(
38+
np.array([0.0, 1.0, 2.0]),
39+
np.array([0.1, 1.0, 2.0, 3.0, -1.0, 0.1, 1.0, 2.0, -0.5]),
40+
),
41+
(np.random.uniform(-1000, 1000, [8]), np.random.uniform(-1000, 1000, [64])),
42+
]
43+
44+
4445
def test_motifs_one_motif():
4546
# The top motif for m=3 is a [0 1 0] at indices 0, 5 and 9
4647
T = np.array([0.0, 1.0, 0.0, -1.0, -1.0, 0.0, 1.0, 0.0, -0.5, 2.0, 3.0, 2.0])

tests/test_scraamp.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,6 @@
66
import naive
77

88

9-
test_data = [
10-
(
11-
np.array([9, 8100, -60, 7], dtype=np.float64),
12-
np.array([584, -11, 23, 79, 1001, 0, -19], dtype=np.float64),
13-
),
14-
(
15-
np.random.uniform(-1000, 1000, [8]).astype(np.float64),
16-
np.random.uniform(-1000, 1000, [64]).astype(np.float64),
17-
),
18-
]
19-
20-
window_size = [8, 16, 32]
21-
substitution_locations = [(slice(0, 0), 0, -1, slice(1, 3), [0, 3])]
22-
substitution_values = [np.nan, np.inf]
23-
percentages = [(0.01, 0.1, 1.0)]
24-
25-
269
def naive_prescraamp(T_A, m, T_B, s, exclusion_zone=None):
2710
distance_matrix = naive.aamp_distance_matrix(T_A, T_B, m)
2811

@@ -126,6 +109,24 @@ def naive_scraamp(T_A, m, T_B, percentage, exclusion_zone, pre_scraamp, s):
126109
return out
127110

128111

112+
test_data = [
113+
(
114+
np.array([9, 8100, -60, 7], dtype=np.float64),
115+
np.array([584, -11, 23, 79, 1001, 0, -19], dtype=np.float64),
116+
),
117+
(
118+
np.random.uniform(-1000, 1000, [8]).astype(np.float64),
119+
np.random.uniform(-1000, 1000, [64]).astype(np.float64),
120+
),
121+
]
122+
123+
124+
window_size = [8, 16, 32]
125+
substitution_locations = [(slice(0, 0), 0, -1, slice(1, 3), [0, 3])]
126+
substitution_values = [np.nan, np.inf]
127+
percentages = [(0.01, 0.1, 1.0)]
128+
129+
129130
@pytest.mark.parametrize("T_A, T_B", test_data)
130131
def test_prescraamp_self_join(T_A, T_B):
131132
m = 3

tests/test_stimp.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,22 @@
88
import naive
99

1010

11+
def naive_bsf_indices(n):
12+
a = np.arange(n)
13+
nodes = [a.tolist()]
14+
out = []
15+
16+
while nodes:
17+
tmp = []
18+
for node in nodes:
19+
for n in split(node, out):
20+
if n:
21+
tmp.append(n)
22+
nodes = tmp
23+
24+
return np.array(out)
25+
26+
1127
T = [
1228
np.array([584, -11, 23, 79, 1001, 0, -19], dtype=np.float64),
1329
np.random.uniform(-1000, 1000, [64]).astype(np.float64),
@@ -29,22 +45,6 @@ def split(node, out):
2945
return node[:mid], node[mid + 1 :]
3046

3147

32-
def naive_bsf_indices(n):
33-
a = np.arange(n)
34-
nodes = [a.tolist()]
35-
out = []
36-
37-
while nodes:
38-
tmp = []
39-
for node in nodes:
40-
for n in split(node, out):
41-
if n:
42-
tmp.append(n)
43-
nodes = tmp
44-
45-
return np.array(out)
46-
47-
4848
@pytest.mark.parametrize("n", n)
4949
def test_bsf_indices(n):
5050
ref_bsf_indices = naive_bsf_indices(n)

0 commit comments

Comments
 (0)