diff --git a/mne/epochs.py b/mne/epochs.py index 2d317caa63e..184a0f8a80b 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1595,6 +1595,34 @@ def _handle_empty(self, on_empty, meth): ) _on_missing(on_empty, msg, error_klass=RuntimeError) + def _handle_tmin_tmax(self, tmin, tmax): + """Convert seconds to index into data.""" + _validate_type( + tmin, + types=("numeric", None), + item_name="tmin", + type_name="int, float, None", + ) + _validate_type( + tmax, + types=("numeric", None), + item_name="tmax", + type_name="int, float, None", + ) + + # handle tmin/tmax as start and stop indices into data array + n_times = self.times.size + start = 0 if tmin is None else self.time_as_index(tmin, use_rounding=True)[0] + stop = ( + n_times if tmax is None else self.time_as_index(tmax, use_rounding=True)[0] + ) + + # truncate start/stop to the open interval [0, n_times] + start = min(max(0, start), n_times) + stop = min(max(0, stop), n_times) + + return start, stop + @verbose def _get_data( self, diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 91c5f902ac8..6513e76b167 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -4855,6 +4855,40 @@ def fun(data): assert_array_equal(out.get_data(non_picks), epochs.get_data(non_picks)) +def test_get_data_rounding(): + """Test that get_data respects rounding for tmin/tmax (gh-13634).""" + # Data setup mirroring the issue report + data = np.linspace(-3.5, 1, 451).reshape((1, 1, 451)) + info = create_info(["test"], 100.0, "eeg") + epochs = EpochsArray(data, info, tmin=-3.5, verbose=False) + + t = 0.77 + + # compare crop() vs get_data() + # crop() uses proper rounding internally, get_data() should match it. + val_crop = epochs.copy().crop(tmin=t).get_data()[0, 0, 0] + val_get_data = epochs.get_data(tmin=t)[0, 0, 0] + + assert_allclose( + val_get_data, + val_crop, + atol=1e-12, + err_msg="get_data(tmin) does not match crop(tmin)", + ) + + # verification on time consistency + # ensure we are getting the sample corresponding exactly to time 't' + idx = np.where(epochs.times == t)[0][0] + val_direct = epochs.get_data()[0, 0, idx] + + assert_allclose( + val_get_data, + val_direct, + atol=1e-12, + err_msg="get_data(tmin) does not match direct indexing", + ) + + def test_apply_function_epo_ch_access(): """Test ch-access within apply function to epoch objects."""