Skip to content

Commit 6218b76

Browse files
committed
Improved docstrings and naive tests
1 parent 210c2ef commit 6218b76

File tree

6 files changed

+54
-31
lines changed

6 files changed

+54
-31
lines changed

stumpy/aampdist_snippets.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,10 @@ def aampdist_snippets(
179179
snippets
180180
181181
snippets_regimes: ndarray
182-
The slices of indices that show the starting and ending indices of snippets
182+
The index slices corresponding to the set of regimes for each of the top `k`
183+
snippets. The first column is the (zero-based) snippet index while the second
184+
and third columns correspond to the (inclusive) regime start indices and the
185+
(exclusive) regime stop indices, respectively.
183186
184187
Notes
185188
-----
@@ -239,7 +242,7 @@ def aampdist_snippets(
239242
snippets_regimes_list.append(slices)
240243

241244
n_slices = [regime.shape[0] for regime in snippets_regimes_list]
242-
snippets_regimes = np.empty((sum(n_slices), 3), dtype=object)
245+
snippets_regimes = np.empty((sum(n_slices), 3), dtype=np.int64)
243246
snippets_regimes[:, 0] = np.repeat(np.arange(len(snippets_regimes_list)), n_slices)
244247
snippets_regimes[:, 1:] = np.vstack(snippets_regimes_list)
245248

stumpy/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1810,7 +1810,8 @@ def _jagged_list_to_array(a, fill_value, dtype):
18101810

18111811
def _get_mask_slices(mask):
18121812
"""
1813-
For a boolean vector mask, returns the slices of indices at which the mask is True.
1813+
For a boolean vector mask, return the (inclusive) start and (exclusive) stop
1814+
indices where the mask is `True`.
18141815
18151816
Parameters
18161817
----------

stumpy/snippets.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,10 @@ def snippets(
188188
snippets
189189
190190
snippets_regimes: ndarray
191-
The slices of indices that show the starting and ending indices of snippets
191+
The index slices corresponding to the set of regimes for each of the top `k`
192+
snippets. The first column is the (zero-based) snippet index while the second
193+
and third columns correspond to the (inclusive) regime start indices and the
194+
(exclusive) regime stop indices, respectively.
192195
193196
Notes
194197
-----
@@ -248,7 +251,7 @@ def snippets(
248251
snippets_regimes_list.append(slices)
249252

250253
n_slices = [regime.shape[0] for regime in snippets_regimes_list]
251-
snippets_regimes = np.empty((sum(n_slices), 3), dtype=object)
254+
snippets_regimes = np.empty((sum(n_slices), 3), dtype=np.int64)
252255
snippets_regimes[:, 0] = np.repeat(np.arange(len(snippets_regimes_list)), n_slices)
253256
snippets_regimes[:, 1:] = np.vstack(snippets_regimes_list)
254257

tests/naive.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,10 +1136,24 @@ def mpdist_snippets(
11361136
slices = _get_mask_slices(mask)
11371137
snippets_regimes_list.append(slices)
11381138

1139-
n_slices = [regime.shape[0] for regime in snippets_regimes_list]
1140-
snippets_regimes = np.empty((sum(n_slices), 3), dtype=object)
1141-
snippets_regimes[:, 0] = np.repeat(np.arange(len(snippets_regimes_list)), n_slices)
1142-
snippets_regimes[:, 1:] = np.vstack(snippets_regimes_list)
1139+
n_slices = []
1140+
for regime in snippets_regimes_list:
1141+
n_slices.append(regime.shape[0])
1142+
1143+
snippets_regimes = np.empty((sum(n_slices), 3), dtype=np.int64)
1144+
i = 0
1145+
j = 0
1146+
for n_slice in n_slices:
1147+
for _ in range(n_slice):
1148+
snippets_regimes[i, 0] = j
1149+
i += 1
1150+
j += 1
1151+
1152+
i = 0
1153+
for regimes in snippets_regimes_list:
1154+
for regime in regimes:
1155+
snippets_regimes[i, 1:] = regime
1156+
i += 1
11431157

11441158
return (
11451159
snippets,
@@ -1208,10 +1222,24 @@ def aampdist_snippets(
12081222
slices = _get_mask_slices(mask)
12091223
snippets_regimes_list.append(slices)
12101224

1211-
n_slices = [regime.shape[0] for regime in snippets_regimes_list]
1212-
snippets_regimes = np.empty((sum(n_slices), 3), dtype=object)
1213-
snippets_regimes[:, 0] = np.repeat(np.arange(len(snippets_regimes_list)), n_slices)
1214-
snippets_regimes[:, 1:] = np.vstack(snippets_regimes_list)
1225+
n_slices = []
1226+
for regime in snippets_regimes_list:
1227+
n_slices.append(regime.shape[0])
1228+
1229+
snippets_regimes = np.empty((sum(n_slices), 3), dtype=np.int64)
1230+
i = 0
1231+
j = 0
1232+
for n_slice in n_slices:
1233+
for _ in range(n_slice):
1234+
snippets_regimes[i, 0] = j
1235+
i += 1
1236+
j += 1
1237+
1238+
i = 0
1239+
for regimes in snippets_regimes_list:
1240+
for regime in regimes:
1241+
snippets_regimes[i, 1:] = regime
1242+
i += 1
12151243

12161244
return (
12171245
snippets,

tests/test_aampdist_snippets.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,7 @@ def test_aampdist_snippets(T, m, k):
4949
# npt.assert_almost_equal(
5050
# ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION
5151
# )
52-
npt.assert_almost_equal(
53-
ref_regimes, cmp_regimes, decimal=config.STUMPY_TEST_PRECISION
54-
)
52+
npt.assert_almost_equal(ref_regimes, cmp_regimes)
5553

5654

5755
@pytest.mark.parametrize("T", test_data)
@@ -91,9 +89,7 @@ def test_mpdist_snippets_percentage(T, m, k, percentage):
9189
# npt.assert_almost_equal(
9290
# ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION
9391
# )
94-
npt.assert_almost_equal(
95-
ref_regimes, cmp_regimes, decimal=config.STUMPY_TEST_PRECISION
96-
)
92+
npt.assert_almost_equal(ref_regimes, cmp_regimes)
9793

9894

9995
@pytest.mark.parametrize("T", test_data)
@@ -133,6 +129,4 @@ def test_mpdist_snippets_s(T, m, k, s):
133129
# npt.assert_almost_equal(
134130
# ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION
135131
# )
136-
npt.assert_almost_equal(
137-
ref_regimes, cmp_regimes, decimal=config.STUMPY_TEST_PRECISION
138-
)
132+
npt.assert_almost_equal(ref_regimes, cmp_regimes)

tests/test_snippets.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,7 @@ def test_mpdist_snippets(T, m, k):
8383
ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION
8484
)
8585
npt.assert_almost_equal(ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION)
86-
npt.assert_almost_equal(
87-
ref_regimes, cmp_regimes, decimal=config.STUMPY_TEST_PRECISION
88-
)
86+
npt.assert_almost_equal(ref_regimes, cmp_regimes)
8987

9088

9189
@pytest.mark.parametrize("T", test_data)
@@ -123,9 +121,7 @@ def test_mpdist_snippets_percentage(T, m, k, percentage):
123121
ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION
124122
)
125123
npt.assert_almost_equal(ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION)
126-
npt.assert_almost_equal(
127-
ref_regimes, cmp_regimes, decimal=config.STUMPY_TEST_PRECISION
128-
)
124+
npt.assert_almost_equal(ref_regimes, cmp_regimes)
129125

130126

131127
@pytest.mark.parametrize("T", test_data)
@@ -163,6 +159,4 @@ def test_mpdist_snippets_s(T, m, k, s):
163159
ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION
164160
)
165161
npt.assert_almost_equal(ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION)
166-
npt.assert_almost_equal(
167-
ref_regimes, cmp_regimes, decimal=config.STUMPY_TEST_PRECISION
168-
)
162+
npt.assert_almost_equal(ref_regimes, cmp_regimes)

0 commit comments

Comments
 (0)