Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions dandi/pynwb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,19 +339,21 @@ def _get_session_duration(nwb: pynwb.NWBFile) -> float | None:
# Read only the first and last spike time from each unit
if "spike_times" in obj.colnames and len(obj["spike_times"]):
idxs = obj["spike_times"].data[:]
# Keep only boundaries where cumulative spike count increases.
# Non-spiking units repeat the prior cumulative index and are skipped.
unit_end_idxs = idxs[np.diff(np.r_[0, idxs]) > 0]

# handle bug if the first unit has no spikes
if idxs[0] == 0:
idxs = idxs[1:]
if len(unit_end_idxs) == 0:
continue

st_data = obj["spike_times"].target

if len(idxs) > 1:
start = float(np.min(np.r_[st_data[0], st_data[idxs[:-1]]]))
if len(unit_end_idxs) > 1:
start = float(np.min(np.r_[st_data[0], st_data[unit_end_idxs[:-1]]]))
else:
start = float(st_data[0])

end = float(np.max(st_data[idxs - 1]))
end = float(np.max(st_data[unit_end_idxs - 1]))
start_times.append(float(start))
end_times.append(float(end))

Expand Down
47 changes: 35 additions & 12 deletions dandi/tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,9 +506,6 @@ def test_session_duration_extraction(tmp_path: Path) -> None:
with NWBHDF5IO(str(nwb_path), "w") as io:
io.write(nwbfile)

# Extract metadata
from ..metadata.nwb import get_metadata, nwb2asset

metadata = get_metadata(nwb_path)

# Check that session_end_time was calculated
Expand Down Expand Up @@ -567,9 +564,6 @@ def test_session_duration_with_trials(tmp_path: Path) -> None:
with NWBHDF5IO(str(nwb_path), "w") as io:
io.write(nwbfile)

# Extract metadata
from ..metadata.nwb import get_metadata, nwb2asset

metadata = get_metadata(nwb_path)

# Check that session_end_time was calculated
Expand Down Expand Up @@ -626,9 +620,6 @@ def test_session_duration_with_units(tmp_path: Path) -> None:
with NWBHDF5IO(str(nwb_path), "w") as io:
io.write(nwbfile)

# Extract metadata
from ..metadata.nwb import get_metadata

metadata = get_metadata(nwb_path)

# Check that session_end_time was calculated
Expand All @@ -642,6 +633,41 @@ def test_session_duration_with_units(tmp_path: Path) -> None:
assert abs(duration - 245.0) < 1.0 # Allow small floating point errors


@pytest.mark.ai_generated
def test_session_duration_with_scattered_non_spiking_units(tmp_path: Path) -> None:
"""Test session duration with multiple non-spiking units in Units table."""
nwb_path = tmp_path / "test_duration_scattered_nonspiking_units.nwb"
session_start = datetime(2020, 1, 1, 12, 0, 0, tzinfo=tzutc())

nwbfile = NWBFile(
session_description="test session with scattered non-spiking units",
identifier="test_scattered_nonspiking_units_123",
session_start_time=session_start,
)

nwbfile.add_unit(spike_times=np.array([]))
nwbfile.add_unit(spike_times=np.array([10.0, 20.0]))
nwbfile.add_unit(spike_times=np.array([]))
nwbfile.add_unit(spike_times=np.array([5.0, 250.0]))
nwbfile.add_unit(spike_times=np.array([]))
nwbfile.add_unit(spike_times=np.array([100.0]))

with NWBHDF5IO(str(nwb_path), "w") as io:
io.write(nwbfile)

metadata = get_metadata(nwb_path)
assert "session_start_time" in metadata
assert "session_end_time" in metadata

end_offset = (metadata["session_end_time"] - session_start).total_seconds()
assert abs(end_offset - 245.0) < 1.0

duration = (
metadata["session_end_time"] - metadata["session_start_time"]
).total_seconds()
assert abs(duration - 245.0) < 1.0 # max 250s and min 5s spike times


@pytest.mark.ai_generated
def test_session_duration_with_events(tmp_path: Path) -> None:
"""Test that session duration includes timestamp/duration from DynamicTable"""
Expand Down Expand Up @@ -692,9 +718,6 @@ def test_session_duration_with_events(tmp_path: Path) -> None:
with NWBHDF5IO(str(nwb_path), "w") as io:
io.write(nwbfile)

# Extract metadata
from ..metadata.nwb import get_metadata

metadata = get_metadata(nwb_path)

# Check that session_end_time was calculated
Expand Down
Loading