Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13731.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Speed up :func:`mne.stats.spatio_temporal_cluster_1samp_test` and related permutation cluster functions via precomputed sum-of-squares for sign-flip t-tests and SciPy connected-components clustering (~5x), by :newcontrib:`Sharif Haason`.
150 changes: 101 additions & 49 deletions mne/stats/cluster_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _masked_sum_power(x, c, t_power):

@jit()
def _sum_cluster_data(data, tstep):
    """Return the signed per-sample contribution (``±tstep`` or 0).

    ``np.sign`` already maps zero entries to 0, so no explicit
    ``data != 0`` mask is needed: each nonzero element contributes
    ``sign(data) * tstep`` and each zero element contributes 0.
    """
    return np.sign(data) * tstep


def _get_clusters_spatial(s, neighbors):
Expand Down Expand Up @@ -258,33 +258,62 @@ def _get_clusters_st_multistep(keepers, neighbors, max_step=1):


def _get_clusters_st(x_in, neighbors, max_step=1):
"""Choose the most efficient version."""
"""Find spatio-temporal clusters via SciPy connected components.

Builds a sparse adjacency graph over only the supra-threshold vertices
(spatial edges from the neighbor lists, temporal edges between the same
source at adjacent time steps) and labels clusters with
``scipy.sparse.csgraph.connected_components``.
"""
n_src = len(neighbors)
n_times = x_in.size // n_src
cl_goods = np.where(x_in)[0]
if len(cl_goods) > 0:
keepers = [np.array([], dtype=int)] * n_times
row, col = np.unravel_index(cl_goods, (n_times, n_src))
lims = [0]
if isinstance(row, int):
row = [row]
col = [col]
else:
order = np.argsort(row)
row = row[order]
col = col[order]
lims += (np.where(np.diff(row) > 0)[0] + 1).tolist()
lims.append(len(row))

for start, end in zip(lims[:-1], lims[1:]):
keepers[row[start]] = np.sort(col[start:end])
if max_step == 1:
return _get_clusters_st_1step(keepers, neighbors)
else:
return _get_clusters_st_multistep(keepers, neighbors, max_step)
else:
n_total = len(x_in)
active = np.where(x_in)[0]
if len(active) == 0:
return []

# Convert neighbor lists to CSR for vectorized expansion
lengths = np.array([len(a) for a in neighbors])
indptr = np.zeros(n_src + 1, dtype=np.intp)
np.cumsum(lengths, out=indptr[1:])
indices = np.concatenate(neighbors).astype(np.intp)

active_t, active_s = np.divmod(active, n_src)

# Spatial edges: vectorized CSR neighbor expansion
neighbor_counts = indptr[active_s + 1] - indptr[active_s]
src_flat = np.repeat(active, neighbor_counts)
src_t = np.repeat(active_t, neighbor_counts)
starts = indptr[active_s]
offsets = np.arange(int(np.sum(neighbor_counts))) - np.repeat(
np.cumsum(neighbor_counts) - neighbor_counts, neighbor_counts
)
nb_s = indices[np.repeat(starts, neighbor_counts) + offsets]
nb_flat = src_t * n_src + nb_s
mask = x_in[nb_flat]
rows = [src_flat[mask]]
cols = [nb_flat[mask]]

# Temporal edges: same source, adjacent time steps
for step in range(1, max_step + 1):
mask_t = active_t >= step
later = active[mask_t]
earlier = later - step * n_src
both = x_in[earlier]
rows.extend([later[both], earlier[both]])
cols.extend([earlier[both], later[both]])

# Self-loops so isolated active vertices get their own component
rows.append(active)
cols.append(active)
row = np.concatenate(rows)
col = np.concatenate(cols)
adj = sparse.coo_array((np.ones(len(row)), (row, col)), shape=(n_total, n_total))
_, labels = connected_components(adj)

# Build cluster list directly from component labels
cluster_labels = labels[active]
return [active[cluster_labels == id_] for id_ in np.unique(cluster_labels)]


def _get_components(x_in, adjacency, return_list=True):
"""Get connected components from a mask and a adjacency matrix."""
Expand Down Expand Up @@ -745,41 +774,61 @@ def _do_1samp_permutations(
# allocate space for output
max_cluster_sums = np.empty(len(orders), dtype=np.double)

# For sign-flips s²=1, so sum(X²) is constant across permutations.
# Precompute once and derive t-statistics via algebra instead of
# calling stat_fun each iteration.
use_fast_ttest = stat_fun is ttest_1samp_no_p
if use_fast_ttest:
sum_sq = np.sum(X**2, axis=0)
sqrt_n_nm1 = np.sqrt(n_samp * (n_samp - 1))
inv_n = 1.0 / n_samp
neg_n = -float(n_samp)

if buffer_size is not None:
# allocate a buffer so we don't need to allocate memory in loop
X_flip_buffer = np.empty((n_samp, buffer_size), dtype=X.dtype)

for seed_idx, order in enumerate(orders):
assert isinstance(order, np.ndarray)
# new surrogate data with specified sign flip
assert order.size == n_samp # should be guaranteed by parent
signs = 2 * order[:, None].astype(int) - 1
if not np.all(np.equal(np.abs(signs), 1)):
raise ValueError("signs from rng must be +/- 1")

if buffer_size is None:
# be careful about non-writable memmap (GH#1507)
if X.flags.writeable:
X *= signs
# Recompute statistic on randomized data
t_obs_surr = stat_fun(X)
# Set X back to previous state (trade memory eff. for CPU use)
X *= signs
else:
t_obs_surr = stat_fun(X * signs)
if use_fast_ttest:
signs = 2.0 * order - 1.0 # (n_samp,) ±1
dot = signs @ X # (n_vars,)
mean_s = dot * inv_n
denom_sq = np.maximum(sum_sq + mean_s * mean_s * neg_n, 0.0)
t_obs_surr = np.where(
denom_sq > 0, mean_s / np.sqrt(denom_sq) * sqrt_n_nm1, 0.0
)
else:
# only sign-flip a small data buffer, so we need less memory
t_obs_surr = np.empty(n_vars, dtype=X.dtype)
# new surrogate data with specified sign flip
signs = 2 * order[:, None].astype(int) - 1
if not np.all(np.equal(np.abs(signs), 1)):
raise ValueError("signs from rng must be +/- 1")

if buffer_size is None:
# be careful about non-writable memmap (GH#1507)
if X.flags.writeable:
X *= signs
# Recompute statistic on randomized data
t_obs_surr = stat_fun(X)
# Set X back to previous state (trade memory eff. for CPU use)
X *= signs
else:
t_obs_surr = stat_fun(X * signs)
else:
# only sign-flip a small data buffer, so we need less memory
t_obs_surr = np.empty(n_vars, dtype=X.dtype)

for pos in range(0, n_vars, buffer_size):
# number of variables for this loop
n_var_loop = min(pos + buffer_size, n_vars) - pos
for pos in range(0, n_vars, buffer_size):
# number of variables for this loop
n_var_loop = min(pos + buffer_size, n_vars) - pos

X_flip_buffer[:, :n_var_loop] = signs * X[:, pos : pos + n_var_loop]
X_flip_buffer[:, :n_var_loop] = signs * X[:, pos : pos + n_var_loop]

# apply stat_fun and store result
tmp = stat_fun(X_flip_buffer)
t_obs_surr[pos : pos + n_var_loop] = tmp[:n_var_loop]
# apply stat_fun and store result
tmp = stat_fun(X_flip_buffer)
t_obs_surr[pos : pos + n_var_loop] = tmp[:n_var_loop]

# The stat should have the same shape as the samples for no adj.
if adjacency is None:
Expand Down Expand Up @@ -953,7 +1002,10 @@ def _permutation_cluster_test(
logger.info(f"stat_fun(H1): min={np.min(t_obs)} max={np.max(t_obs)}")

# test if stat_fun treats variables independently
if buffer_size is not None:
# (skip for built-in stat functions which are known to be independent)
if buffer_size is not None and (
stat_fun is not ttest_1samp_no_p and stat_fun is not f_oneway
):
t_obs_buffer = np.zeros_like(t_obs)
for pos in range(0, n_tests, buffer_size):
t_obs_buffer[pos : pos + buffer_size] = stat_fun(
Expand Down
Loading