diff --git a/dandi/pynwb_utils.py b/dandi/pynwb_utils.py index 2b21335ea..75a4ee4a2 100644 --- a/dandi/pynwb_utils.py +++ b/dandi/pynwb_utils.py @@ -339,19 +339,21 @@ def _get_session_duration(nwb: pynwb.NWBFile) -> float | None: # Read only the first and last spike time from each unit if "spike_times" in obj.colnames and len(obj["spike_times"]): idxs = obj["spike_times"].data[:] + # Keep only boundaries where cumulative spike count increases. + # Non-spiking units repeat the prior cumulative index and are skipped. + unit_end_idxs = idxs[np.diff(np.r_[0, idxs]) > 0] - # handle bug if the first unit has no spikes - if idxs[0] == 0: - idxs = idxs[1:] + if len(unit_end_idxs) == 0: + continue st_data = obj["spike_times"].target - if len(idxs) > 1: - start = float(np.min(np.r_[st_data[0], st_data[idxs[:-1]]])) + if len(unit_end_idxs) > 1: + start = float(np.min(np.r_[st_data[0], st_data[unit_end_idxs[:-1]]])) else: start = float(st_data[0]) - end = float(np.max(st_data[idxs - 1])) + end = float(np.max(st_data[unit_end_idxs - 1])) start_times.append(float(start)) end_times.append(float(end)) diff --git a/dandi/tests/test_metadata.py b/dandi/tests/test_metadata.py index 8614a6eb1..00520d08f 100644 --- a/dandi/tests/test_metadata.py +++ b/dandi/tests/test_metadata.py @@ -506,9 +506,6 @@ def test_session_duration_extraction(tmp_path: Path) -> None: with NWBHDF5IO(str(nwb_path), "w") as io: io.write(nwbfile) - # Extract metadata - from ..metadata.nwb import get_metadata, nwb2asset - metadata = get_metadata(nwb_path) # Check that session_end_time was calculated @@ -567,9 +564,6 @@ def test_session_duration_with_trials(tmp_path: Path) -> None: with NWBHDF5IO(str(nwb_path), "w") as io: io.write(nwbfile) - # Extract metadata - from ..metadata.nwb import get_metadata, nwb2asset - metadata = get_metadata(nwb_path) # Check that session_end_time was calculated @@ -626,9 +620,6 @@ def test_session_duration_with_units(tmp_path: Path) -> None: with NWBHDF5IO(str(nwb_path), "w") as io: io.write(nwbfile) - # Extract metadata - from ..metadata.nwb import get_metadata - metadata = get_metadata(nwb_path) # Check that session_end_time was calculated @@ -642,6 +633,41 @@ def test_session_duration_with_units(tmp_path: Path) -> None: assert abs(duration - 245.0) < 1.0 # Allow small floating point errors +@pytest.mark.ai_generated +def test_session_duration_with_scattered_non_spiking_units(tmp_path: Path) -> None: + """Test session duration with multiple non-spiking units in Units table.""" + nwb_path = tmp_path / "test_duration_scattered_nonspiking_units.nwb" + session_start = datetime(2020, 1, 1, 12, 0, 0, tzinfo=tzutc()) + + nwbfile = NWBFile( + session_description="test session with scattered non-spiking units", + identifier="test_scattered_nonspiking_units_123", + session_start_time=session_start, + ) + + nwbfile.add_unit(spike_times=np.array([])) + nwbfile.add_unit(spike_times=np.array([10.0, 20.0])) + nwbfile.add_unit(spike_times=np.array([])) + nwbfile.add_unit(spike_times=np.array([5.0, 250.0])) + nwbfile.add_unit(spike_times=np.array([])) + nwbfile.add_unit(spike_times=np.array([100.0])) + + with NWBHDF5IO(str(nwb_path), "w") as io: + io.write(nwbfile) + + metadata = get_metadata(nwb_path) + assert "session_start_time" in metadata + assert "session_end_time" in metadata + + end_offset = (metadata["session_end_time"] - session_start).total_seconds() + assert abs(end_offset - 245.0) < 1.0 + + duration = ( + metadata["session_end_time"] - metadata["session_start_time"] + ).total_seconds() + assert abs(duration - 245.0) < 1.0 # max 250s and min 5s spike times + + @pytest.mark.ai_generated def test_session_duration_with_events(tmp_path: Path) -> None: """Test that session duration includes timestamp/duration from DynamicTable""" @@ -692,9 +718,6 @@ def test_session_duration_with_events(tmp_path: Path) -> None: with NWBHDF5IO(str(nwb_path), "w") as io: io.write(nwbfile) - # Extract metadata - from ..metadata.nwb import get_metadata - metadata = get_metadata(nwb_path) # Check that session_end_time was calculated