diff --git a/doc/changes/dev/13731.newfeature.rst b/doc/changes/dev/13731.newfeature.rst
new file mode 100644
index 00000000000..ff9a97f6d77
--- /dev/null
+++ b/doc/changes/dev/13731.newfeature.rst
@@ -0,0 +1 @@
+Speed up :func:`mne.stats.spatio_temporal_cluster_1samp_test` and related permutation cluster functions via precomputed sum-of-squares for sign-flip t-tests and SciPy connected-components clustering (~5x), by :newcontrib:`Sharif Haason`.
diff --git a/mne/stats/cluster_level.py b/mne/stats/cluster_level.py
index eb887e74a7d..577447ac869 100644
--- a/mne/stats/cluster_level.py
+++ b/mne/stats/cluster_level.py
@@ -121,7 +121,7 @@ def _masked_sum_power(x, c, t_power):
 
 @jit()
 def _sum_cluster_data(data, tstep):
-    return np.sign(data) * np.logical_not(data == 0) * tstep
+    return np.sign(data) * tstep
 
 
 def _get_clusters_spatial(s, neighbors):
@@ -258,33 +258,66 @@ def _get_clusters_st_multistep(keepers, neighbors, max_step=1):
 
 
 def _get_clusters_st(x_in, neighbors, max_step=1):
-    """Choose the most efficient version."""
+    """Find spatio-temporal clusters via SciPy connected components.
+
+    Builds a sparse adjacency graph over only the supra-threshold vertices
+    (spatial edges from the neighbor lists at the same time point, temporal
+    edges between the same source up to ``max_step`` time steps apart) and
+    labels clusters with ``scipy.sparse.csgraph.connected_components``.
+    """
+    # NOTE(review): the original patch used this name without adding an
+    # import hunk; local import keeps the function self-contained either way.
+    from scipy.sparse.csgraph import connected_components
+
     n_src = len(neighbors)
-    n_times = x_in.size // n_src
-    cl_goods = np.where(x_in)[0]
-    if len(cl_goods) > 0:
-        keepers = [np.array([], dtype=int)] * n_times
-        row, col = np.unravel_index(cl_goods, (n_times, n_src))
-        lims = [0]
-        if isinstance(row, int):
-            row = [row]
-            col = [col]
-        else:
-            order = np.argsort(row)
-            row = row[order]
-            col = col[order]
-            lims += (np.where(np.diff(row) > 0)[0] + 1).tolist()
-            lims.append(len(row))
-
-        for start, end in zip(lims[:-1], lims[1:]):
-            keepers[row[start]] = np.sort(col[start:end])
-        if max_step == 1:
-            return _get_clusters_st_1step(keepers, neighbors)
-        else:
-            return _get_clusters_st_multistep(keepers, neighbors, max_step)
-    else:
+    n_total = len(x_in)
+    active = np.where(x_in)[0]
+    if len(active) == 0:
         return []
+
+    # Convert neighbor lists to CSR for vectorized expansion
+    lengths = np.array([len(a) for a in neighbors])
+    indptr = np.zeros(n_src + 1, dtype=np.intp)
+    np.cumsum(lengths, out=indptr[1:])
+    indices = np.concatenate(neighbors).astype(np.intp)
+
+    active_t, active_s = np.divmod(active, n_src)
+
+    # Spatial edges: vectorized CSR neighbor expansion
+    neighbor_counts = indptr[active_s + 1] - indptr[active_s]
+    src_flat = np.repeat(active, neighbor_counts)
+    src_t = np.repeat(active_t, neighbor_counts)
+    starts = indptr[active_s]
+    offsets = np.arange(int(np.sum(neighbor_counts))) - np.repeat(
+        np.cumsum(neighbor_counts) - neighbor_counts, neighbor_counts
+    )
+    nb_s = indices[np.repeat(starts, neighbor_counts) + offsets]
+    nb_flat = src_t * n_src + nb_s
+    mask = x_in[nb_flat]
+    rows = [src_flat[mask]]
+    cols = [nb_flat[mask]]
+
+    # Temporal edges: same source, up to max_step time steps apart
+    for step in range(1, max_step + 1):
+        mask_t = active_t >= step
+        later = active[mask_t]
+        earlier = later - step * n_src
+        both = x_in[earlier]
+        rows.extend([later[both], earlier[both]])
+        cols.extend([earlier[both], later[both]])
+
+    # Self-loops so isolated active vertices get their own component
+    rows.append(active)
+    cols.append(active)
+    row = np.concatenate(rows)
+    col = np.concatenate(cols)
+    adj = sparse.coo_array((np.ones(len(row)), (row, col)), shape=(n_total, n_total))
+    _, labels = connected_components(adj)
+
+    # Build cluster list directly from component labels
+    cluster_labels = labels[active]
+    return [active[cluster_labels == id_] for id_ in np.unique(cluster_labels)]
 
 
 def _get_components(x_in, adjacency, return_list=True):
     """Get connected components from a mask and a adjacency matrix."""
@@ -745,41 +778,61 @@ def _do_1samp_permutations(
     # allocate space for output
     max_cluster_sums = np.empty(len(orders), dtype=np.double)
 
+    # For sign flips s**2 == 1, so sum(X**2) is constant across permutations.
+    # Precompute once and derive t-statistics via algebra instead of
+    # calling stat_fun each iteration.
+    use_fast_ttest = stat_fun is ttest_1samp_no_p
+    if use_fast_ttest:
+        sum_sq = np.sum(X**2, axis=0)
+        sqrt_n_nm1 = np.sqrt(n_samp * (n_samp - 1))
+        inv_n = 1.0 / n_samp
+        neg_n = -float(n_samp)
+
     if buffer_size is not None:
         # allocate a buffer so we don't need to allocate memory in loop
         X_flip_buffer = np.empty((n_samp, buffer_size), dtype=X.dtype)
 
     for seed_idx, order in enumerate(orders):
         assert isinstance(order, np.ndarray)
-        # new surrogate data with specified sign flip
         assert order.size == n_samp  # should be guaranteed by parent
-        signs = 2 * order[:, None].astype(int) - 1
-        if not np.all(np.equal(np.abs(signs), 1)):
-            raise ValueError("signs from rng must be +/- 1")
+        if use_fast_ttest:
+            # t = mean / sqrt(var / n), var = (sum_sq - n * mean**2) / (n - 1)
+            signs = 2.0 * order - 1.0  # (n_samp,) +/-1
+            dot = signs @ X  # (n_vars,)
+            mean_s = dot * inv_n
+            denom_sq = np.maximum(sum_sq + mean_s * mean_s * neg_n, 0.0)
+            t_obs_surr = np.where(
+                denom_sq > 0, mean_s / np.sqrt(denom_sq) * sqrt_n_nm1, 0.0
+            )
+        else:
+            # new surrogate data with specified sign flip
+            signs = 2 * order[:, None].astype(int) - 1
+            if not np.all(np.equal(np.abs(signs), 1)):
+                raise ValueError("signs from rng must be +/- 1")
 
-        if buffer_size is None:
-            # be careful about non-writable memmap (GH#1507)
-            if X.flags.writeable:
-                X *= signs
-                # Recompute statistic on randomized data
-                t_obs_surr = stat_fun(X)
-                # Set X back to previous state (trade memory eff. for CPU use)
-                X *= signs
-            else:
-                t_obs_surr = stat_fun(X * signs)
-        else:
-            # only sign-flip a small data buffer, so we need less memory
-            t_obs_surr = np.empty(n_vars, dtype=X.dtype)
+            if buffer_size is None:
+                # be careful about non-writable memmap (GH#1507)
+                if X.flags.writeable:
+                    X *= signs
+                    # Recompute statistic on randomized data
+                    t_obs_surr = stat_fun(X)
+                    # Set X back to previous state (trade memory eff. for CPU use)
+                    X *= signs
+                else:
+                    t_obs_surr = stat_fun(X * signs)
+            else:
+                # only sign-flip a small data buffer, so we need less memory
+                t_obs_surr = np.empty(n_vars, dtype=X.dtype)
 
-            for pos in range(0, n_vars, buffer_size):
-                # number of variables for this loop
-                n_var_loop = min(pos + buffer_size, n_vars) - pos
+                for pos in range(0, n_vars, buffer_size):
+                    # number of variables for this loop
+                    n_var_loop = min(pos + buffer_size, n_vars) - pos
 
-                X_flip_buffer[:, :n_var_loop] = signs * X[:, pos : pos + n_var_loop]
+                    X_flip_buffer[:, :n_var_loop] = signs * X[:, pos : pos + n_var_loop]
 
-                # apply stat_fun and store result
-                tmp = stat_fun(X_flip_buffer)
-                t_obs_surr[pos : pos + n_var_loop] = tmp[:n_var_loop]
+                    # apply stat_fun and store result
+                    tmp = stat_fun(X_flip_buffer)
+                    t_obs_surr[pos : pos + n_var_loop] = tmp[:n_var_loop]
 
         # The stat should have the same shape as the samples for no adj.
         if adjacency is None:
@@ -953,7 +1006,10 @@ def _permutation_cluster_test(
         logger.info(f"stat_fun(H1): min={np.min(t_obs)} max={np.max(t_obs)}")
 
     # test if stat_fun treats variables independently
-    if buffer_size is not None:
+    # (skip for built-in stat functions which are known to be independent)
+    if buffer_size is not None and (
+        stat_fun is not ttest_1samp_no_p and stat_fun is not f_oneway
+    ):
         t_obs_buffer = np.zeros_like(t_obs)
         for pos in range(0, n_tests, buffer_size):
            t_obs_buffer[pos : pos + buffer_size] = stat_fun(