diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 0ae6860..c7366ae 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -44,6 +44,14 @@ jobs: exit 1 fi + # Guard against a stale lockfile: uv.lock must agree with pyproject.toml. + # `uv lock --check` fails if the lock is out of date (e.g. the version was + # bumped in pyproject.toml but uv.lock was never re-locked to match). + - name: Verify uv.lock is in sync with pyproject.toml + uses: astral-sh/setup-uv@v5 + - name: uv lock --check + run: uv lock --check + # Pull the notes for this tag out of CHANGELOG.md. We match the heading # whose version equals the tag without the leading "v" (so tag v1.2.3 maps # to a "## [1.2.3]" / "## 1.2.3" heading) and emit everything up to the diff --git a/docs/development.md b/docs/development.md index f481bf6..9b944d5 100644 --- a/docs/development.md +++ b/docs/development.md @@ -67,14 +67,31 @@ rather than from the lockfile. mid-task-deterministic/ ├── src/ │ └── mid_det/ -│ ├── __init__.py # Version -│ ├── __main__.py # Entry point; wires all modules together -│ ├── config.py # All task constants (no cross-module imports) -│ ├── display.py # PsychoPy stimuli construction and draw helpers -│ ├── recorder.py # TrialRecord, ScanPhase, CSV writers, manifest -│ ├── scanner.py # HardwareBackend, EmulatedBackend, PulseCounter -│ ├── session.py # Startup dialog, screen setup, sequence loading -│ └── trial.py # Per-phase functions and run_trial() +│ ├── __init__.py # Version +│ ├── __main__.py # Entry point; wires all modules together +│ ├── _psychopy.py # PsychoPy import shim for headless testing +│ ├── config.py # All task constants (no cross-module imports) +│ ├── task/ # The experiment run + on-screen presentation +│ │ ├── trial.py # run_trial(); ties the phases together +│ │ ├── phases.py # Fixed-duration per-phase display loops +│ │ ├── response.py # Timing-critical response window +│ │ ├── flip_timer.py # FlipTimer per-flip target-display diagnostics +│ │ ├── calibration.py # Per-cue adaptive target-window staircase +│ │ ├── instructions.py # Self-paced instruction presentation +│ │ ├── display.py # PsychoPy stimuli construction and draw helpers +│ │ ├── console.py # Rich live-view trial table +│ │ └── debug.py # F3-toggleable debug overlay HUD +│ ├── io/ # Input/output boundary +│ │ ├── bootstrap.py # SessionInfo/ScreenDiagnostics, screen setup, run dir +│ │ ├── setup_wizard.py # Interactive terminal setup wizard +│ │ ├── scanner.py # HardwareBackend, EmulatedBackend, PulseCounter +│ │ ├── sequences.py # Sequence CSV loading and validation +│ │ └── recording/ # Data recording +│ │ ├── records.py # TrialRecord/TargetTimingRecord/ScanPhase + schemas +│ │ ├── csv_writers.py # CsvWriter + behavioral/target-timing/scan-log writers +│ │ ├── legacy.py # LegacyMidCsvWriter + MATLAB-format helpers +│ │ └── manifest.py # write_manifest / write_ratings_manifest +│ └── ratings/ # Standalone cue-ratings survey (mid-ratings-det) ├── sequences/ │ ├── run_1.csv # 54-trial sequence for run 1 │ ├── run_2.csv # 54-trial sequence for run 2 @@ -92,11 +109,8 @@ mid-task-deterministic/ | Module | Responsibility | |--------|---------------| | `config.py` | Single source of truth for all timing, keyboard, scanner, and target-duration constants | -| `session.py` | Startup GUI dialog, screen/monitor setup, sequence CSV loading, instruction display | -| `display.py` | Build all PsychoPy `Visual` objects; draw helpers for each phase (circle/square cue with magnitude line) | -| `scanner.py` | Abstract scanner backend; `HardwareBackend` (MCC DAQ) and `EmulatedBackend` (software clock) | -| `trial.py` | `run_trial()` and per-phase functions (`run_cue`, `run_fixation`, `run_response`, `run_outcome`, `run_iti`) | -| `recorder.py` | `TrialRecord` and `ScanPhase` dataclasses; CSV writers; `write_manifest()` | +| `task/` | The experiment run: `trial.run_trial()`, per-phase loops (`phases.py`), the timing-critical `response.py` + `flip_timer.py`, the adaptive `calibration.py`, on-screen `display.py`/`instructions.py`, and operator UI (`console.py`, `debug.py`) | +| `io/` | The I/O boundary: session `bootstrap.py` (screen setup, run dir, `SessionInfo`/`ScreenDiagnostics`), the terminal `setup_wizard.py`, `scanner.py` hardware, `sequences.py` loading, and the `recording/` package (records, CSV writers, legacy MATLAB format, manifests) | | `__main__.py` | Orchestration: init → instructions → wait for scan → trial loop → cleanup | ## Relationship to `mid-task` diff --git a/docs/timing.md b/docs/timing.md index 6267e03..6610ed3 100644 --- a/docs/timing.md +++ b/docs/timing.md @@ -23,7 +23,7 @@ should_remove = ( `frame_dur_s` comes from `win.getActualFrameRate()`. If PsychoPy can't get a stable measurement (returns `None`, or a value outside 30–200 Hz), `__main__.run()` raises `RuntimeError` rather than guessing — a wrong frame period silently corrupts every target duration. The user can override with `--fps ` if they know the refresh rate but VSYNC measurement is broken (e.g. macOS dev rigs, where the Cocoa compositor doesn't honor `set_vsync(True)`). -`session.py:setup_screen` passes `waitBlanking=True` and calls `winHandle.set_vsync(True)` so production Windows rigs flip on VSYNC. +`io/bootstrap.py:setup_screen` passes `waitBlanking=True` and calls `winHandle.set_vsync(True)` so production Windows rigs flip on VSYNC. ### macOS caveat diff --git a/pyproject.toml b/pyproject.toml index 3f1283c..ddb99dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "pandas>=2.0", "rich>=14.3.3", "questionary>=2.0", + "prompt_toolkit>=3.0", "mcculw>=1.0.0", "pyobjc-framework-quartz>=10; sys_platform == 'darwin'", ] diff --git a/src/mid_det/__main__.py b/src/mid_det/__main__.py index 58a86d5..0d416da 100644 --- a/src/mid_det/__main__.py +++ b/src/mid_det/__main__.py @@ -20,10 +20,12 @@ from psychopy.hardware import keyboard from rich.console import Console -from mid_det import config, display, recorder, scanner, sequences, session, setup_wizard, trial -from mid_det.calibration import CalibrationState -from mid_det.console import TrialLiveView -from mid_det.debug import DebugOverlay, DebugState +from mid_det import config +from mid_det.task import display, instructions, trial +from mid_det.io import bootstrap, recording, scanner, sequences, setup_wizard +from mid_det.task.calibration import CalibrationState +from mid_det.task.console import TrialLiveView +from mid_det.task.debug import DebugOverlay, DebugState def _raise_process_priority() -> str | None: @@ -58,18 +60,13 @@ def run() -> None: # ── SCREEN & FRAME RATE ────────────────────────────────────────────────── # Open the window first so we have a real frame duration to pass into the # setup wizard (it uses it for RT-field defaults and frame-alignment hints). - win_res, win, screen_diag = session.setup_screen() - - # Warm-up flips before frame-rate measurement: PsychoPy's detectingFrameDrops - # doc notes drops are common during startup as the GPU/driver settle. - for _ in range(30): - win.flip() + win_res, win, screen_diag = bootstrap.setup_screen() if args.fps is not None: frame_rate: float = args.fps frame_dur_s: float = 1.0 / args.fps fps_source = "specified" - elif 1000.0 / 200.0 <= screen_diag.calib_median_ms <= 1000.0 / 30.0: + elif 1000.0 / config.MAX_REFRESH_HZ <= screen_diag.calib_median_ms <= 1000.0 / config.MIN_REFRESH_HZ: # Prefer the 120-flip VSYNC-calibration median: it's a more reliable # estimator than getActualFrameRate(), and the response loop's # `round(t / frame_dur_s)` termination is sensitive to small drift in @@ -83,7 +80,7 @@ def run() -> None: # than silently degrading — a guessed rate would corrupt every target # duration. Use --fps to override. measured_fps = win.getActualFrameRate() - if measured_fps is None or not (30.0 <= measured_fps <= 200.0): + if measured_fps is None or not (config.MIN_REFRESH_HZ <= measured_fps <= config.MAX_REFRESH_HZ): win.close() raise RuntimeError( f"Could not measure a stable refresh rate " @@ -101,7 +98,7 @@ def run() -> None: # ── LOGGING ────────────────────────────────────────────────────────────── data_dir = Path("data") - run_dir = session.make_run_dir(data_dir, session_info, session_time) + run_dir = bootstrap.make_run_dir(data_dir, session_info, session_time) logging.LogFile(str(run_dir / "experiment.log"), level=logging.EXP) logging.console.setLevel(logging.WARNING) @@ -179,10 +176,20 @@ def _flip_with_overlay(*args, **kwargs): # noqa: E306 # ── SETUP OUTPUT FILES ─────────────────────────────────────────────────── file_stem = f"{session_info.subject_id}_run{session_info.run_n}" - behavioral_writer = recorder.BehavioralCsvWriter(run_dir / f"behavioral_{file_stem}.csv") - target_timing_writer = recorder.TargetTimingCsvWriter(run_dir / f"target_timing_{file_stem}.csv") - scan_log_writer = recorder.ScanLogWriter(run_dir / f"scan_log_{file_stem}.csv") - recorder.write_manifest( + behavioral_writer = recording.BehavioralCsvWriter(run_dir / f"behavioral_{file_stem}.csv") + target_timing_writer = recording.TargetTimingCsvWriter(run_dir / f"target_timing_{file_stem}.csv") + scan_log_writer = recording.ScanLogWriter(run_dir / f"scan_log_{file_stem}.csv") + legacy_dir = data_dir / "legacy-fmt" + legacy_dir.mkdir(parents=True, exist_ok=True) + # MATLAB PartialParseData.m numbers trials continuously across blocks: block 1 + # is trials 1-42, so block 2 continues from 43. Our trial_n restarts at 1 each + # run, so shift run 2 up by block 1's length (42) to restore that numbering. + legacy_trial_offset = 42 if session_info.run_n == "2" else 0 + legacy_writer = recording.LegacyMidCsvWriter( + legacy_dir / f"{session_info.legacy_name}_b{session_info.run_n}.csv", + trial_offset=legacy_trial_offset, + ) + recording.write_manifest( run_dir=run_dir, session_info=session_info, session_time=session_time, @@ -212,7 +219,7 @@ def _flip_with_overlay(*args, **kwargs): # noqa: E306 # ── INSTRUCTIONS ───────────────────────────────────────────────────────── if session_info.show_instructions: - session.display_instructions(win, stimuli_obj, session_info, kb, rcon) + instructions.display_instructions(win, stimuli_obj, session_info, kb, rcon) # ── PULSE COUNTER ──────────────────────────────────────────────────────── backend = scanner.make_backend(session_info.fmri) @@ -314,6 +321,7 @@ def _flip_with_overlay(*args, **kwargs): # noqa: E306 behavioral_writer.append(rec) target_timing_writer.append(target_timing) + legacy_writer.append(rec) for sp in scan_phases: scan_log_writer.append(sp) @@ -346,6 +354,7 @@ def _flip_with_overlay(*args, **kwargs): # noqa: E306 behavioral_writer.close() target_timing_writer.close() scan_log_writer.close() + legacy_writer.close() logging.flush() win.close() core.quit() diff --git a/src/mid_det/_psychopy.py b/src/mid_det/_psychopy.py new file mode 100644 index 0000000..f56ef21 --- /dev/null +++ b/src/mid_det/_psychopy.py @@ -0,0 +1,23 @@ +""" +Shared PsychoPy import shim. + +Keeps the per-phase, response, and orchestration modules importable in +headless/CI environments without PsychoPy, so the pure-logic and timing code +stays testable. `core` is a namespace with the attributes those paths reference +— tests patch core.Clock; real runs always have PsychoPy. +""" +from __future__ import annotations + +try: + from psychopy import core, logging, visual + from psychopy.hardware import keyboard +except ModuleNotFoundError: + import types + + visual = keyboard = None # type: ignore[assignment] + logging = types.SimpleNamespace(exp=lambda *a, **k: None) # type: ignore[assignment] + core = types.SimpleNamespace( # type: ignore[assignment] + Clock=None, CountdownTimer=None, quit=lambda *a, **k: None + ) + +__all__ = ["core", "logging", "visual", "keyboard"] diff --git a/src/mid_det/config.py b/src/mid_det/config.py index adc2bb2..594b3b9 100644 --- a/src/mid_det/config.py +++ b/src/mid_det/config.py @@ -12,6 +12,13 @@ "iti": 2.0, } +# Nominal duration of the four fixed slides (cue+fixation+response+outcome) that +# precede the ITI. Used as the drift baseline, mirroring MATLAB main.m's hardcoded +# `- 8.0` (see trial.py timing_drift_ms). +PRE_ITI_NOMINAL_S: float = sum( + STUDY_TIMES_S[k] for k in ("cue", "fixation", "response", "outcome") +) + # Polarity → shape (and reward sign). "polarity" is the gain/loss dimension; # kept distinct from the affective "valence" rated in the cue-ratings survey. POLARITIES: list[str] = ["gain", "loss"] @@ -57,6 +64,12 @@ JITTER_MIN_S: float = 0.25 JITTER_MAX_S: float = 1.0 +# Plausible display refresh rates. Used to sanity-check measured/calibrated +# rates before we trust them for timing — anything outside this band is treated +# as a failed measurement rather than a real refresh rate. +MIN_REFRESH_HZ: float = 30.0 +MAX_REFRESH_HZ: float = 200.0 + # Scanner settings SCANNER_PULSE_RATE: int = 46 BOARD_NUM: int = 0 diff --git a/src/mid_det/io/__init__.py b/src/mid_det/io/__init__.py new file mode 100644 index 0000000..80a44a6 --- /dev/null +++ b/src/mid_det/io/__init__.py @@ -0,0 +1,2 @@ +"""The input/output boundary: session bootstrap, the terminal setup wizard, +scanner hardware, sequence loading, and data recording.""" diff --git a/src/mid_det/session.py b/src/mid_det/io/bootstrap.py similarity index 61% rename from src/mid_det/session.py rename to src/mid_det/io/bootstrap.py index 6f8c848..305b8c3 100644 --- a/src/mid_det/session.py +++ b/src/mid_det/io/bootstrap.py @@ -1,6 +1,7 @@ """ -Session initialisation: dialog, screen setup, output directory, and instruction -display. +Session bootstrap: the SessionInfo / ScreenDiagnostics dataclasses, screen + +frame-timing setup, and output-directory creation. Instruction presentation +lives in mid_det.task.instructions. """ from __future__ import annotations @@ -13,15 +14,9 @@ import pyglet from psychopy import core, monitors, visual -from psychopy.hardware import keyboard -from rich.console import Console from mid_det import config -_PACKAGE_DIR = Path(__file__).parent # src/mid_det/ -_PROJECT_ROOT = _PACKAGE_DIR.parent.parent # project root -_TEXT_DIR = _PROJECT_ROOT / "text" - @dataclass class SessionInfo: @@ -31,6 +26,7 @@ class SessionInfo: show_instructions: bool base_rt_s: float rt_change_s: float = config.RT_CHANGE_S # staircase step; set by wizard + legacy_name: str = "" # NAME for legacy-fmt/{NAME}_b{run}.csv @dataclass @@ -82,7 +78,12 @@ def setup_screen() -> tuple[list[int], visual.Window, ScreenDiagnostics]: # percentile is well above one frame period, vsync is not actually blocking # — typical on Windows under DWM composition or borderless fullscreen. intervals_ms: list[float] = [] - win.flip() # warm-up; first interval after a stale context can be misleading + # Warm-up flips before measurement: PsychoPy's detectingFrameDrops doc notes + # drops are common during startup as the GPU/driver/compositor settle. Run + # these before the calibration loop so the median feeding frame_dur_s is + # measured on a settled context, not a cold one. + for _ in range(30): + win.flip() last_t = core.getTime() for _ in range(120): win.flip() @@ -94,8 +95,8 @@ def setup_screen() -> tuple[list[int], visual.Window, ScreenDiagnostics]: p99 = intervals_ms[int(0.99 * len(intervals_ms)) - 1] mx = intervals_ms[-1] - # Enable PsychoPy's frame interval recording so trial.run_response can read - # win.nDroppedFrames and isolate on-screen drops from measurement artefacts. + # Enable PsychoPy's frame interval recording so response.run_response can read + # win.nDroppedFrames and isolate on-screen drops from measurement artifacts. win.refreshThreshold = (median / 1000.0) * 1.5 win.recordFrameIntervals = True @@ -119,57 +120,3 @@ def make_run_dir(data_dir: Path, session_info: SessionInfo, session_time: dateti run_dir = data_dir / f"{session_info.subject_id}_run{session_info.run_n}_{ts}" run_dir.mkdir(parents=True, exist_ok=True) return run_dir - - -def display_instructions( - win: visual.Window, - stimuli, # Stimuli dataclass from display.py; avoid circular import - session_info: SessionInfo, - kb: keyboard.Keyboard, - rcon: Console, -) -> None: - """Display instructions from text/instructions_MID.txt one page at a time.""" - keys_map = config.KEYS_FMRI if session_info.fmri else config.KEYS_BEHAVIORAL - forward_key = keys_map["forward"] - start_key = keys_map["start"] - end_key = keys_map["end"] - - inst_path = _TEXT_DIR / "instructions_MID.txt" - pages: list[str] = [] - with open(inst_path) as f: - for line in f: - stripped = line.rstrip() - if stripped: - pages.append(stripped) - - if not pages: - return - - kb.clearEvents() - page_idx = 0 - - while True: - stimuli.instr_prompt.text = pages[page_idx] - stimuli.instr_prompt.draw() - stimuli.instr_first.draw() - win.flip() - - pressed = kb.getKeys(keyList=[forward_key, end_key], waitRelease=False) - if not pressed: - continue - key_name = pressed[0].name - if key_name == end_key: - core.quit() - elif key_name == forward_key: - page_idx += 1 - if page_idx >= len(pages): - break - - rcon.print( - f"[bold yellow]End of instructions — press '{start_key}' to continue...[/bold yellow]" - ) - while True: - stimuli.instr_finish.draw() - win.flip() - if kb.getKeys(keyList=[start_key], waitRelease=False): - break diff --git a/src/mid_det/io/recording/__init__.py b/src/mid_det/io/recording/__init__.py new file mode 100644 index 0000000..f4c10a1 --- /dev/null +++ b/src/mid_det/io/recording/__init__.py @@ -0,0 +1,38 @@ +"""Data recording: trial-record dataclasses, the per-trial CSV writers (modern +and legacy MATLAB formats), and the run-manifest writers. The public API is +re-exported here so callers can use ``mid_det.io.recording`` directly.""" +from __future__ import annotations + +from mid_det.io.recording.csv_writers import ( + BehavioralCsvWriter, + CsvWriter, + ScanLogWriter, + TargetTimingCsvWriter, +) +from mid_det.io.recording.legacy import LEGACY_MID_COLUMNS, LegacyMidCsvWriter +from mid_det.io.recording.manifest import write_manifest, write_ratings_manifest +from mid_det.io.recording.records import ( + BEHAVIORAL_COLUMNS, + SCAN_LOG_COLUMNS, + TARGET_TIMING_COLUMNS, + ScanPhase, + TargetTimingRecord, + TrialRecord, +) + +__all__ = [ + "BEHAVIORAL_COLUMNS", + "TARGET_TIMING_COLUMNS", + "SCAN_LOG_COLUMNS", + "LEGACY_MID_COLUMNS", + "TrialRecord", + "TargetTimingRecord", + "ScanPhase", + "CsvWriter", + "BehavioralCsvWriter", + "TargetTimingCsvWriter", + "ScanLogWriter", + "LegacyMidCsvWriter", + "write_manifest", + "write_ratings_manifest", +] diff --git a/src/mid_det/io/recording/csv_writers.py b/src/mid_det/io/recording/csv_writers.py new file mode 100644 index 0000000..ec0f1c1 --- /dev/null +++ b/src/mid_det/io/recording/csv_writers.py @@ -0,0 +1,58 @@ +""" +The modern per-trial CSV writers: a generic CsvWriter plus the behavioral, +target-timing, and scan-log writers that bind it to a fixed column schema. The +MATLAB legacy-format writer lives in legacy.py. +""" +from __future__ import annotations + +import csv +from pathlib import Path + +from mid_det.io.recording.records import ( + BEHAVIORAL_COLUMNS, + SCAN_LOG_COLUMNS, + TARGET_TIMING_COLUMNS, + ScanPhase, + TargetTimingRecord, + TrialRecord, +) + + +class CsvWriter: + def __init__(self, path: Path, columns: list[str]) -> None: + self._file = open(path, "w", newline="") + self._writer = csv.DictWriter(self._file, fieldnames=columns) + self._writer.writeheader() + self._columns = columns + + def append(self, record: object) -> None: + row = {k: getattr(record, k) for k in self._columns} + self._writer.writerow(row) + self._file.flush() + + def close(self) -> None: + self._file.close() + + +class BehavioralCsvWriter(CsvWriter): + def __init__(self, path: Path) -> None: + super().__init__(path, BEHAVIORAL_COLUMNS) + + def append(self, record: TrialRecord) -> None: # type: ignore[override] + super().append(record) + + +class TargetTimingCsvWriter(CsvWriter): + def __init__(self, path: Path) -> None: + super().__init__(path, TARGET_TIMING_COLUMNS) + + def append(self, record: TargetTimingRecord) -> None: # type: ignore[override] + super().append(record) + + +class ScanLogWriter(CsvWriter): + def __init__(self, path: Path) -> None: + super().__init__(path, SCAN_LOG_COLUMNS) + + def append(self, phase: ScanPhase) -> None: # type: ignore[override] + super().append(phase) diff --git a/src/mid_det/io/recording/legacy.py b/src/mid_det/io/recording/legacy.py new file mode 100644 index 0000000..af32218 --- /dev/null +++ b/src/mid_det/io/recording/legacy.py @@ -0,0 +1,116 @@ +""" +The legacy MATLAB MID CSV path (one row per TR) for downstream-system +compatibility, plus the num2str/dollar-string formatting helpers that reproduce +the MATLAB PartialParseData.m / PresentCue.m / PresentFeedback.m output exactly. +""" +from __future__ import annotations + +import csv +import math +from pathlib import Path + +from mid_det.io.recording.records import TrialRecord + +LEGACY_MID_COLUMNS: list[str] = [ + "trial", "TR", "trialonset", "trialtype", "target_ms", "rt", "cue_value", + "hit", "trial_gain", "total", "iti", "drift", + "total_winpercent", "binned_winpercent", +] + + +def _num2str(x: float) -> str: + """Format a number like MATLAB's default ``num2str`` (used throughout + PartialParseData.m). Default precision is ``%g`` with + ``max(floor(log10(|x|)), 0) + 5`` significant figures, so values < 1 (drift, + win-percents) get 5 sig figs (``0.16667``, ``-0.00011858``) and larger + magnitudes get more (``12.013``). ``%g`` collapses whole values (``1.0`` → + ``"1"``) and only uses scientific notation for exponents < -4.""" + if x == 0: + return "0" + sig = max(math.floor(math.log10(abs(x))), 0) + 5 + return f"{x:.{sig}g}" + + +def _legacy_cue_value(polarity: str, magnitude: int) -> str: + """Legacy cue-value string (MATLAB PresentCue.m): gain → "+$X", loss → "-$X". + The sign is always shown, including for magnitude 0 (e.g. "+$0" / "-$0").""" + sign = "+" if polarity == "gain" else "-" + return f"{sign}${magnitude}" + + +def _legacy_trial_gain(polarity: str, magnitude: int, hit: int) -> str: + """Legacy realised-gain string (MATLAB PresentFeedback.m `valuestr`): + gain + hit → "+$mag" gain + miss → "$0" + loss + hit → "$0" loss + miss → "-$mag" + """ + if polarity == "gain": + return f"+${magnitude}" if hit else "$0" + return "$0" if hit else f"-${magnitude}" + + +def _legacy_total(total: int) -> str: + """Legacy running total (MATLAB PartialParseData.m: ['$' num2str(total,'%#4.2f')]). + Negatives render as "$-1.00", non-negatives as "$0.00" / "$12.00".""" + return f"${total:.2f}" + + +class LegacyMidCsvWriter: + """Writes the legacy MATLAB MID CSV (one row per TR) for downstream-system + compatibility. Tracks cumulative win-rate counters across appended trials. + + *trial_offset* is added to each trial number to mirror MATLAB + PartialParseData.m, which numbers block-2 trials starting at 43 (offset 42). + """ + + def __init__(self, path: Path, trial_offset: int = 0) -> None: + self._file = open(path, "w", newline="") + self._writer = csv.DictWriter(self._file, fieldnames=LEGACY_MID_COLUMNS) + self._writer.writeheader() + self._trial_offset = trial_offset + self._n_hits = 0 + self._n_trials = 0 + self._type_hits: dict[int, int] = {} + self._type_trials: dict[int, int] = {} + + def append(self, record: TrialRecord) -> None: + self._n_trials += 1 + self._n_hits += record.hit + self._type_trials[record.trial_type] = self._type_trials.get(record.trial_type, 0) + 1 + self._type_hits[record.trial_type] = self._type_hits.get(record.trial_type, 0) + record.hit + + total_winpercent = self._n_hits / self._n_trials + binned_winpercent = ( + self._type_hits[record.trial_type] / self._type_trials[record.trial_type] + ) + # MATLAB rt_vector: -2 = early press, -1 = miss/too-slow, else RT (seconds). + if record.early_press: + rt: float = -2 + elif isinstance(record.rt_ms, str): # "" sentinel = miss/too-slow + rt = -1 + else: + rt = record.rt_ms / 1000 + # Float columns are formatted with _num2str to match MATLAB's default + # num2str precision; integer and dollar-string columns are left as-is. + row = { + "trial": record.trial_n + self._trial_offset, + "TR": 0, # filled per row below + "trialonset": _num2str(record.time_onset), + "trialtype": record.trial_type, + "target_ms": _num2str(record.target_dur_ms / 1000), + "rt": _num2str(rt), + "cue_value": _legacy_cue_value(record.polarity, record.magnitude), + "hit": record.hit, + "trial_gain": _legacy_trial_gain(record.polarity, record.magnitude, record.hit), + "total": _legacy_total(record.total_earned), + "iti": record.n_iti_trs * 2, + "drift": _num2str(record.timing_drift_ms / 1000), + "total_winpercent": _num2str(total_winpercent), + "binned_winpercent": _num2str(binned_winpercent), + } + for tr in range(1, record.total_trs + 1): + row["TR"] = tr + self._writer.writerow(row) + self._file.flush() + + def close(self) -> None: + self._file.close() diff --git a/src/mid_det/recorder.py b/src/mid_det/io/recording/manifest.py similarity index 62% rename from src/mid_det/recorder.py rename to src/mid_det/io/recording/manifest.py index 27b2a95..0cf489c 100644 --- a/src/mid_det/recorder.py +++ b/src/mid_det/io/recording/manifest.py @@ -1,21 +1,21 @@ """ -Data recording: TrialRecord, ScanPhase, CsvWriter, ScanLogWriter, write_manifest. +Run manifest writers (manifest.json) plus the best-effort system/hardware +diagnostics they embed: write_manifest for a task run, write_ratings_manifest for +a cue-ratings survey run. """ from __future__ import annotations -import csv import json import platform import socket import subprocess import sys -from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING if TYPE_CHECKING: - from mid_det.session import ScreenDiagnostics, SessionInfo + from mid_det.io.bootstrap import ScreenDiagnostics, SessionInfo def _git_commit() -> str: @@ -80,126 +80,6 @@ def _system_info() -> dict: } -@dataclass -class TrialRecord: - trial_n: int - trial_type: int - polarity: str - magnitude: int - cue_label: str - time_onset: float - jitter_ms: int - jitter_ms_actual: float | str - target_dur_ms: int - target_dur_ms_actual: float | str - early_press: int - hit: int - rt_ms: float | str - reward_outcome: str - total_earned: int - time_trial_end: float - trial_dur_ms: int - time_sched_end: float - timing_drift_ms: float - n_iti_trs: int - total_trs: int - subject_id: str - run_n: str - pulse_ct: int - - -@dataclass -class TargetTimingRecord: - trial_n: int - target_frames_scheduled: int - target_frames_shown: int - target_visible_ms_scheduled: float - target_visible_ms_measured: float | str - late_flips_in_window: int - longest_frame_interval_ms: float - target_timing_ok: int - - -@dataclass -class ScanPhase: - trial_n: int - phase: str - tr_n: int - phase_onset_global_time: float - phase_onset_trial_time: float - pulse_ct: int - phase_offset_global_time: float = 0.0 - phase_offset_trial_time: float = 0.0 - trial_type: int = 0 - polarity: str = "" - magnitude: int = 0 - - -BEHAVIORAL_COLUMNS: list[str] = [ - "trial_n", "trial_type", "polarity", "magnitude", "cue_label", - "time_onset", "jitter_ms", "jitter_ms_actual", - "target_dur_ms", "target_dur_ms_actual", "early_press", "hit", "rt_ms", - "reward_outcome", "total_earned", "time_trial_end", "trial_dur_ms", - "time_sched_end", "timing_drift_ms", "n_iti_trs", "total_trs", - "subject_id", "run_n", "pulse_ct", -] - -TARGET_TIMING_COLUMNS: list[str] = [ - "trial_n", - "target_frames_scheduled", "target_frames_shown", - "target_visible_ms_scheduled", "target_visible_ms_measured", - "late_flips_in_window", "longest_frame_interval_ms", - "target_timing_ok", -] - -SCAN_LOG_COLUMNS: list[str] = [ - "trial_n", "trial_type", "polarity", "magnitude", "phase", "tr_n", - "phase_onset_global_time", "phase_offset_global_time", - "phase_onset_trial_time", "phase_offset_trial_time", - "pulse_ct", -] - - -class CsvWriter: - def __init__(self, path: Path, columns: list[str]) -> None: - self._file = open(path, "w", newline="") - self._writer = csv.DictWriter(self._file, fieldnames=columns) - self._writer.writeheader() - self._columns = columns - - def append(self, record: object) -> None: - row = {k: getattr(record, k) for k in self._columns} - self._writer.writerow(row) - self._file.flush() - - def close(self) -> None: - self._file.close() - - -class BehavioralCsvWriter(CsvWriter): - def __init__(self, path: Path) -> None: - super().__init__(path, BEHAVIORAL_COLUMNS) - - def append(self, record: TrialRecord) -> None: # type: ignore[override] - super().append(record) - - -class TargetTimingCsvWriter(CsvWriter): - def __init__(self, path: Path) -> None: - super().__init__(path, TARGET_TIMING_COLUMNS) - - def append(self, record: TargetTimingRecord) -> None: # type: ignore[override] - super().append(record) - - -class ScanLogWriter(CsvWriter): - def __init__(self, path: Path) -> None: - super().__init__(path, SCAN_LOG_COLUMNS) - - def append(self, phase: ScanPhase) -> None: # type: ignore[override] - super().append(phase) - - def write_manifest( run_dir: Path, session_info: "SessionInfo", diff --git a/src/mid_det/io/recording/records.py b/src/mid_det/io/recording/records.py new file mode 100644 index 0000000..4562d3f --- /dev/null +++ b/src/mid_det/io/recording/records.py @@ -0,0 +1,87 @@ +""" +Trial data records and their CSV column schemas: TrialRecord, TargetTimingRecord, +ScanPhase. Pure data — no behaviour, no I/O. +""" +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass +class TrialRecord: + trial_n: int + trial_type: int + polarity: str + magnitude: int + cue_label: str + time_onset: float + jitter_ms: int + jitter_ms_actual: float | str + target_dur_ms: int + target_dur_ms_actual: float | str + early_press: int + hit: int + rt_ms: float | str + reward_outcome: str + total_earned: int + time_trial_end: float + trial_dur_ms: int + time_sched_end: float + timing_drift_ms: float + n_iti_trs: int + total_trs: int + subject_id: str + run_n: str + pulse_ct: int + + +@dataclass +class TargetTimingRecord: + trial_n: int + target_frames_scheduled: int + target_frames_shown: int + target_visible_ms_scheduled: float + target_visible_ms_measured: float | str + late_flips_in_window: int + longest_frame_interval_ms: float + target_timing_ok: int + + +@dataclass +class ScanPhase: + trial_n: int + phase: str + tr_n: int + phase_onset_global_time: float + phase_onset_trial_time: float + pulse_ct: int + phase_offset_global_time: float = 0.0 + phase_offset_trial_time: float = 0.0 + trial_type: int = 0 + polarity: str = "" + magnitude: int = 0 + + +BEHAVIORAL_COLUMNS: list[str] = [ + "trial_n", "trial_type", "polarity", "magnitude", "cue_label", + "time_onset", "jitter_ms", "jitter_ms_actual", + "target_dur_ms", "target_dur_ms_actual", "early_press", "hit", "rt_ms", + "reward_outcome", "total_earned", "time_trial_end", "trial_dur_ms", + "time_sched_end", "timing_drift_ms", "n_iti_trs", "total_trs", + "subject_id", "run_n", "pulse_ct", +] + +TARGET_TIMING_COLUMNS: list[str] = [ + "trial_n", + "target_frames_scheduled", "target_frames_shown", + "target_visible_ms_scheduled", "target_visible_ms_measured", + "late_flips_in_window", "longest_frame_interval_ms", + "target_timing_ok", +] + +SCAN_LOG_COLUMNS: list[str] = [ + "trial_n", "trial_type", "polarity", "magnitude", "phase", "tr_n", + "phase_onset_global_time", "phase_offset_global_time", + "phase_onset_trial_time", "phase_offset_trial_time", + "pulse_ct", +] diff --git a/src/mid_det/scanner.py b/src/mid_det/io/scanner.py similarity index 100% rename from src/mid_det/scanner.py rename to src/mid_det/io/scanner.py diff --git a/src/mid_det/sequences.py b/src/mid_det/io/sequences.py similarity index 92% rename from src/mid_det/sequences.py rename to src/mid_det/io/sequences.py index 6e00a5c..1c98d62 100644 --- a/src/mid_det/sequences.py +++ b/src/mid_det/io/sequences.py @@ -7,8 +7,7 @@ from mid_det import config -_PACKAGE_DIR = Path(__file__).parent # src/mid_det/ -_PROJECT_ROOT = _PACKAGE_DIR.parent.parent # project root +_PROJECT_ROOT = Path(__file__).resolve().parents[3] # src/mid_det/io/ -> project root _SEQUENCES_DIR = _PROJECT_ROOT / "sequences" diff --git a/src/mid_det/setup_wizard.py b/src/mid_det/io/setup_wizard.py similarity index 75% rename from src/mid_det/setup_wizard.py rename to src/mid_det/io/setup_wizard.py index 5ba8eb2..d1a7a84 100644 --- a/src/mid_det/setup_wizard.py +++ b/src/mid_det/io/setup_wizard.py @@ -11,6 +11,8 @@ """ from __future__ import annotations +from collections.abc import Callable +from pathlib import Path from typing import NoReturn import questionary @@ -26,12 +28,16 @@ from rich.text import Text from mid_det import config -from mid_det.session import SessionInfo +from mid_det.io.bootstrap import SessionInfo _rcon = Console(stderr=True) # ── Styles ──────────────────────────────────────────────────────────────────── +# Placeholder subject ID — shown greyed-out and used as the fallback when the +# field is left empty (handy for testing). +_SUBJECT_PLACEHOLDER = "XXX000" + # Match questionary's default palette so everything looks cohesive. _QSTYLE = questionary.Style( [ @@ -43,6 +49,7 @@ ("selected", "fg:#cc5454"), ("separator", "fg:#6c6c6c"), ("instruction", "fg:#858585 italic"), + ("placeholder", "fg:#6c6c6c"), ] ) @@ -50,6 +57,7 @@ _PT_STYLE = PtStyle.from_dict( { "prompt": "#ff9d00 bold", # ❯ arrow: matches questionary answer + "placeholder": "#6c6c6c italic", # greyed-out example, not submitted "bottom-toolbar": "bg:#1e1e1e #888888", "bottom-toolbar.text": "bg:#1e1e1e", } @@ -72,6 +80,67 @@ def _nearest_frame_aligned(value_ms: float, frame_dur_ms: float) -> float: def _quit() -> NoReturn: from psychopy import core # late import — avoids circular / slow startup core.quit() + raise SystemExit(0) # unreachable; tells the type checker this never returns + + +def prompt_legacy_name(legacy_dir: Path, filename_for: Callable[[str], str]) -> str: + """Prompt for the legacy-format file's NAME and guard against clobbering an + existing file (the legacy names aren't timestamped, so same subject/run can + collide). *filename_for* maps the entered NAME to the target filename. + + Re-prompts until the resulting path is free or the operator confirms an + overwrite. Returns the chosen NAME. + """ + # Print the label above the input field (the prompt itself is just "❯"). + _rcon.print( + "[bold #5f819d]?[/bold #5f819d] [bold]Legacy filename[/bold] " + "[dim]NAME is only part of the saved file[/dim]", + highlight=False, + ) + + # Bottom toolbar: live preview of the resolved path as NAME is typed. + def _toolbar(): + typed = get_app().current_buffer.text.strip() + if not typed: + return FormattedText( + [("fg:ansired bold", " ✗ Name cannot be empty")] + ) + preview = legacy_dir / filename_for(typed) + return FormattedText([("fg:ansigreen bold", f" → saves as {preview}")]) + + while True: + raw = "" + try: + raw = _pt_prompt( + FormattedText([("class:prompt", "❯ ")]), + placeholder=HTML("e.g. 1"), + bottom_toolbar=_toolbar, + style=_PT_STYLE, + ) + except (KeyboardInterrupt, EOFError): + _quit() + name = raw.strip() + if not name: + _rcon.print("[red]Name cannot be empty.[/red]") + continue + + target = legacy_dir / filename_for(name) + if not target.exists(): + return name + + overwrite: bool | None = questionary.confirm( + f"{target.name} already exists in {legacy_dir}/ — overwrite?", + default=False, + style=_QSTYLE, + ).ask() + if overwrite is None: + _quit() + if overwrite: + return name + # else: loop and re-prompt for a different NAME + # Unreachable: the loop only exits via `return` or `_quit()`. Present so the + # type checker can see every path returns `str` (not `str | None`). + raise AssertionError("prompt_legacy_name loop exited unexpectedly") # ── Custom RT field ─────────────────────────────────────────────────────────── @@ -249,11 +318,17 @@ def run_wizard(frame_dur_s: float) -> SessionInfo: _rcon.print() # ── Session fields ──────────────────────────────────────────────────────── + # Placeholder (not a default): shown greyed-out so production users type the + # real ID without having to clear the default value. Pressing Enter on an + # empty field falls back to the placeholder value — convenient for testing. subject_id: str | None = questionary.text( - "Subject ID", default="XXX000", style=_QSTYLE + "Subject ID", + placeholder=HTML(f"{_SUBJECT_PLACEHOLDER}"), + style=_QSTYLE, ).ask() if subject_id is None: _quit() + subject_id_str: str = subject_id.strip() or _SUBJECT_PLACEHOLDER run_n: str | None = questionary.select( "Task", @@ -266,6 +341,7 @@ def run_wizard(frame_dur_s: float) -> SessionInfo: ).ask() if run_n is None: _quit() + run_n: str = run_n fmri: bool | None = questionary.confirm( "fMRI session?", default=False, style=_QSTYLE @@ -279,6 +355,12 @@ def run_wizard(frame_dur_s: float) -> SessionInfo: if show_instructions is None: _quit() + # ── Legacy-format filename ──────────────────────────────────────────────── + legacy_name = prompt_legacy_name( + Path("data") / "legacy-fmt", + lambda n: f"{n}_b{run_n}.csv", + ) + # ── Timing fields ───────────────────────────────────────────────────────── _rcon.print() _rcon.print(Rule("[dim]Timing[/dim]", style="dim")) @@ -315,10 +397,11 @@ def run_wizard(frame_dur_s: float) -> SessionInfo: _rcon.print() return SessionInfo( - subject_id=subject_id, + subject_id=subject_id_str, fmri=fmri, run_n=run_n, show_instructions=show_instructions, base_rt_s=base_rt_ms / 1000.0, rt_change_s=rt_change_ms / 1000.0, + legacy_name=legacy_name, ) diff --git a/src/mid_det/ratings/__main__.py b/src/mid_det/ratings/__main__.py index e8112a3..9ae314d 100644 --- a/src/mid_det/ratings/__main__.py +++ b/src/mid_det/ratings/__main__.py @@ -4,7 +4,7 @@ A self-paced survey (no scanner sync, no frame-timing measurement). Each of the 6 MID cues is rated on a VALENCE then an AROUSAL 7-point circle-slider scale, controlled with buttons 1 (left) / 2 (right) / 3 (select). Output is a single -CSV: data/_ratings.csv with columns polarity,magnitude,arousal,valence. +CSV: data/ratings_.csv with columns polarity,magnitude,arousal,valence. Ported from MATLAB RunRatings.m. """ @@ -26,10 +26,10 @@ from psychopy.hardware import keyboard from rich.console import Console -from mid_det import recorder, session +from mid_det.io import bootstrap, recording from mid_det.ratings import core as rcore from mid_det.ratings import display as rdisplay -from mid_det.ratings.wizard import run_ratings_wizard +from mid_det.ratings.setup_wizard import run_ratings_wizard _PACKAGE_DIR = Path(__file__).resolve().parent # src/mid_det/ratings/ _PROJECT_ROOT = _PACKAGE_DIR.parent.parent.parent # project root @@ -114,10 +114,10 @@ def _show_fixation(win: visual.Window, stim: rdisplay.RatingStimuli) -> None: def run() -> None: # ── SCREEN ─────────────────────────────────────────────────────────────── - win_res, win, screen_diag = session.setup_screen() + win_res, win, screen_diag = bootstrap.setup_screen() # ── WIZARD ─────────────────────────────────────────────────────────────── - subject_id, show_instructions = run_ratings_wizard() + subject_id, show_instructions, legacy_name = run_ratings_wizard() session_time = datetime.now() rcon = Console(stderr=True) @@ -129,7 +129,7 @@ def run() -> None: ts = session_time.strftime("%Y%m%dT%H%M%S") run_dir = _PROJECT_ROOT / "data" / f"{subject_id}_ratings_{ts}" run_dir.mkdir(parents=True, exist_ok=True) - recorder.write_ratings_manifest( + recording.write_ratings_manifest( run_dir=run_dir, subject_id=subject_id, show_instructions=show_instructions, @@ -194,10 +194,17 @@ def run() -> None: # ── WRITE CSV ──────────────────────────────────────────────────────────── # (manifest.json was already written to run_dir at startup) - out_path = run_dir / f"{subject_id}_ratings.csv" + out_path = run_dir / f"ratings_{subject_id}.csv" rcore.write_ratings_csv(out_path, results) + # Legacy-format copy (gamble,arousal,valence) for downstream systems. + legacy_dir = _PROJECT_ROOT / "data" / "legacy-fmt" + legacy_dir.mkdir(parents=True, exist_ok=True) + legacy_path = legacy_dir / f"{legacy_name}_ratings.csv" + rcore.write_legacy_ratings_csv(legacy_path, results) + rcon.print(f"[bold green]Ratings saved[/bold green] -> [cyan]{out_path}[/cyan]") + rcon.print(f"[bold green]Legacy ratings saved[/bold green] -> [cyan]{legacy_path}[/cyan]") for r in results: rcon.print( f" {r['polarity']:<4} ${r['magnitude']} valence=[cyan]{r['valence']}[/cyan] " diff --git a/src/mid_det/ratings/core.py b/src/mid_det/ratings/core.py index 4d4a45a..bf80c4d 100644 --- a/src/mid_det/ratings/core.py +++ b/src/mid_det/ratings/core.py @@ -97,3 +97,36 @@ def write_ratings_csv(path: Path, results: list[dict]) -> None: """Write the ratings CSV (polarity,magnitude,arousal,valence) to *path*.""" with open(path, "w", newline="") as f: csv.writer(f).writerows(build_csv_rows(results)) + + +# ── Legacy format (downstream-system compatibility) ─────────────────────────── +# Old MATLAB layout: a single "gamble" name column instead of polarity+magnitude, +# columns gamble,arousal,valence. magnitude 0/1/5 -> low/med/high; loss -> square, +# gain -> circle (matches config.POLARITY_SHAPE). +_MAGNITUDE_SIZE: dict[int, str] = {0: "low", 1: "med", 5: "high"} +_POLARITY_SHAPE: dict[str, str] = {"loss": "square", "gain": "circle"} + +GAMBLE_NAMES: dict[tuple[str, int], str] = { + (cue.polarity, cue.magnitude): ( + _MAGNITUDE_SIZE[cue.magnitude] + _POLARITY_SHAPE[cue.polarity] + ) + for cue in RATING_CUES +} + +LEGACY_RATINGS_HEADER: list[str] = ["gamble", "arousal", "valence"] + + +def build_legacy_csv_rows(results: list[dict]) -> list[list[str]]: + """Build legacy CSV rows (header + one per result) with a single gamble-name + column, in the order *results* are given (RATING_CUES order).""" + rows: list[list[str]] = [list(LEGACY_RATINGS_HEADER)] + for r in results: + gamble = GAMBLE_NAMES[(r["polarity"], r["magnitude"])] + rows.append([gamble, str(r["arousal"]), str(r["valence"])]) + return rows + + +def write_legacy_ratings_csv(path: Path, results: list[dict]) -> None: + """Write the legacy ratings CSV (gamble,arousal,valence) to *path*.""" + with open(path, "w", newline="") as f: + csv.writer(f).writerows(build_legacy_csv_rows(results)) diff --git a/src/mid_det/ratings/display.py b/src/mid_det/ratings/display.py index 5067829..355ca75 100644 --- a/src/mid_det/ratings/display.py +++ b/src/mid_det/ratings/display.py @@ -13,7 +13,7 @@ from psychopy import visual from mid_det import config -from mid_det.display import _LINE_Y_FRAC +from mid_det.task.display import _LINE_Y_FRAC from mid_det.ratings import core # ── Layout constants (height units; screen spans −0.5…+0.5) ────────────────── diff --git a/src/mid_det/ratings/setup_wizard.py b/src/mid_det/ratings/setup_wizard.py new file mode 100644 index 0000000..9218288 --- /dev/null +++ b/src/mid_det/ratings/setup_wizard.py @@ -0,0 +1,67 @@ +""" +Trimmed interactive setup wizard for the cue-ratings survey. + +Reuses the styling and quit helpers from mid_det.io.setup_wizard; prompts only for +Subject ID and whether to show instructions (no fmri/run/timing fields). +""" +from __future__ import annotations + +from pathlib import Path + +import questionary +from prompt_toolkit.formatted_text import HTML +from rich.panel import Panel +from rich.text import Text + +from mid_det.io.setup_wizard import ( + _QSTYLE, + _SUBJECT_PLACEHOLDER, + _quit, + _rcon, + prompt_legacy_name, +) + +# Project root: src/mid_det/ratings/wizard.py -> project root (matches __main__). +_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent + + +def run_ratings_wizard() -> tuple[str, bool, str]: + """Return (subject_id, show_instructions, legacy_name).""" + _rcon.print() + _rcon.print( + Panel( + Text( + "MID Cue-Ratings Survey — Setup", + style="bold white", + justify="center", + ), + border_style="bright_blue", + padding=(0, 4), + ) + ) + _rcon.print() + + # Placeholder (not a default): shown greyed-out so production users type the + # real ID without clearing a field. Pressing Enter on an empty field falls + # back to the placeholder value — convenient for testing. + subject_id: str | None = questionary.text( + "Subject ID", + placeholder=HTML(f"{_SUBJECT_PLACEHOLDER}"), + style=_QSTYLE, + ).ask() + if subject_id is None: + _quit() + subject_id: str = subject_id.strip() or _SUBJECT_PLACEHOLDER + + show_instructions: bool | None = questionary.confirm( + "Show instructions?", default=True, style=_QSTYLE + ).ask() + if show_instructions is None: + _quit() + + legacy_name = prompt_legacy_name( + _PROJECT_ROOT / "data" / "legacy-fmt", + lambda n: f"{n}_ratings.csv", + ) + + return subject_id, show_instructions, legacy_name diff --git a/src/mid_det/ratings/wizard.py b/src/mid_det/ratings/wizard.py deleted file mode 100644 index 7817269..0000000 --- a/src/mid_det/ratings/wizard.py +++ /dev/null @@ -1,44 +0,0 @@ -""" -Trimmed interactive setup wizard for the cue-ratings survey. - -Reuses the styling and quit helpers from mid_det.setup_wizard; prompts only for -Subject ID and whether to show instructions (no fmri/run/timing fields). -""" -from __future__ import annotations - -import questionary -from rich.panel import Panel -from rich.text import Text - -from mid_det.setup_wizard import _QSTYLE, _quit, _rcon - - -def run_ratings_wizard() -> tuple[str, bool]: - """Return (subject_id, show_instructions).""" - _rcon.print() - _rcon.print( - Panel( - Text( - "MID Cue-Ratings Survey — Setup", - style="bold white", - justify="center", - ), - border_style="bright_blue", - padding=(0, 4), - ) - ) - _rcon.print() - - subject_id: str | None = questionary.text( - "Subject ID", default="XXX000", style=_QSTYLE - ).ask() - if subject_id is None: - _quit() - - show_instructions: bool | None = questionary.confirm( - "Show instructions?", default=True, style=_QSTYLE - ).ask() - if show_instructions is None: - _quit() - - return subject_id, show_instructions diff --git a/src/mid_det/task/__init__.py b/src/mid_det/task/__init__.py new file mode 100644 index 0000000..8598cdf --- /dev/null +++ b/src/mid_det/task/__init__.py @@ -0,0 +1,3 @@ +"""The experiment run: trial orchestration, per-phase loops, the timing-critical +response window, on-screen presentation (cues/feedback/instructions), and the +adaptive target-window staircase.""" diff --git a/src/mid_det/calibration.py b/src/mid_det/task/calibration.py similarity index 100% rename from src/mid_det/calibration.py rename to src/mid_det/task/calibration.py diff --git a/src/mid_det/console.py b/src/mid_det/task/console.py similarity index 100% rename from src/mid_det/console.py rename to src/mid_det/task/console.py diff --git a/src/mid_det/debug.py b/src/mid_det/task/debug.py similarity index 100% rename from src/mid_det/debug.py rename to src/mid_det/task/debug.py diff --git a/src/mid_det/display.py b/src/mid_det/task/display.py similarity index 100% rename from src/mid_det/display.py rename to src/mid_det/task/display.py diff --git a/src/mid_det/task/flip_timer.py b/src/mid_det/task/flip_timer.py new file mode 100644 index 0000000..c5fc87c --- /dev/null +++ b/src/mid_det/task/flip_timer.py @@ -0,0 +1,114 @@ +""" +FlipTimer: accumulates per-flip timing for the response window and derives the +target-display diagnostics consumed by run_trial. Fed ``win.lastFrameT`` once per +flip. Kept separate from the timing-critical response loop in response.py so the +measurement bookkeeping can be read (and tested) in isolation. +""" +from __future__ import annotations + + +class FlipTimer: + """Accumulates per-flip timing for the response window and derives the + target-display diagnostics. Fed ``win.lastFrameT`` once per flip. + + Measurement uses ``win.lastFrameT`` — the time PsychoPy stamps inside + ``flip()`` right after the GPU finishes the swap (after glFinish, before + callOnFlip callbacks fire). It is a tighter proxy for the actual swap time + than reading ``core.getTime()`` after ``flip()`` returns, which additionally + absorbs callback-dispatch overhead. + """ + + def __init__(self, win, frame_dur_s: float, n_target_frames: int) -> None: + self._win = win + self._frame_dur_s = frame_dur_s + self._n_target_frames = n_target_frames + self._response_start_flip_t: float | None = None + self._onset_flip_t: float | None = None + self._removal_flip_t: float | None = None + self._flip_iters = 0 + self._dropped_at_onset: int | None = None + self._max_intra_flip_ms = 0.0 + self._last_intra_flip_t: float | None = None + + def on_flip(self, last_frame_t: float) -> None: + """Call after every flip. Stamps the first flip of the response window so + the pre-target jitter wall time can be reported later.""" + if self._response_start_flip_t is None: + self._response_start_flip_t = last_frame_t + + def on_onset(self, last_frame_t: float) -> None: + """Call on the flip that puts the target on the glass.""" + self._onset_flip_t = last_frame_t + self._dropped_at_onset = getattr(self._win, "nDroppedFrames", 0) + self._last_intra_flip_t = last_frame_t + + def on_target_frame(self, last_frame_t: float) -> None: + """Call after each flip while the target is on screen. A stretched flip + interval here is the DWM-hiccup signature; keep the worst one.""" + self._flip_iters += 1 + self._accumulate_interval(last_frame_t) + self._last_intra_flip_t = last_frame_t + + def on_removal(self, last_frame_t: float) -> float | None: + """Call on the flip that clears the target. Returns ``target_removed_at`` + (onset→removal wall seconds), or None if onset was never stamped.""" + self._removal_flip_t = last_frame_t + # Measure the interval into the removal flip too — a stretched removal + # flip is exactly the DWM-hiccup signature and must not be invisible. + self._accumulate_interval(last_frame_t) + if self._onset_flip_t is None: + return None + return last_frame_t - self._onset_flip_t + + def _accumulate_interval(self, last_frame_t: float) -> None: + if self._last_intra_flip_t is not None: + delta_ms = (last_frame_t - self._last_intra_flip_t) * 1000 + if delta_ms > self._max_intra_flip_ms: + self._max_intra_flip_ms = delta_ms + + def summary(self) -> dict: + """Build the diagnostics dict consumed by run_trial.""" + if self._onset_flip_t is not None and self._removal_flip_t is not None: + onset_to_removal_wall_ms = round( + (self._removal_flip_t - self._onset_flip_t) * 1000, 2 + ) + else: + onset_to_removal_wall_ms = "" + + if self._response_start_flip_t is not None and self._onset_flip_t is not None: + jitter_ms_actual = round( + (self._onset_flip_t - self._response_start_flip_t) * 1000, 2 + ) + else: + jitter_ms_actual = "" + + if self._dropped_at_onset is not None: + dropped_frames = int( + getattr(self._win, "nDroppedFrames", 0) - self._dropped_at_onset + ) + else: + dropped_frames = 0 + + # Mark trial unclean if (a) PsychoPy detected any dropped frames during the + # response window, OR (b) the measured wall delta differs from the expected + # on-screen duration by more than half a frame. Either condition makes the + # exact target-display time unreliable for timing-sensitive analyses; flag + # for exclusion at analysis time. DWM-induced extra frames on Windows are + # an acknowledged unsolvable limitation — exclusion is the standard fix. + expected_dur_ms = self._n_target_frames * self._frame_dur_s * 1000 + half_frame_ms = (self._frame_dur_s * 1000) / 2 + timing_off_by_frame = ( + isinstance(onset_to_removal_wall_ms, (int, float)) + and abs(onset_to_removal_wall_ms - expected_dur_ms) > half_frame_ms + ) + trial_clean = dropped_frames == 0 and not timing_off_by_frame + + return { + "flip_iters": self._flip_iters, + "n_target_frames": self._n_target_frames, + "dropped_frames": dropped_frames, + "onset_to_removal_wall_ms": onset_to_removal_wall_ms, + "max_flip_interval_ms": round(self._max_intra_flip_ms, 2), + "trial_clean": trial_clean, + "jitter_ms_actual": jitter_ms_actual, + } diff --git a/src/mid_det/task/instructions.py b/src/mid_det/task/instructions.py new file mode 100644 index 0000000..5c79cb1 --- /dev/null +++ b/src/mid_det/task/instructions.py @@ -0,0 +1,80 @@ +""" +Instruction presentation: a self-paced, keypress-driven loop that pages through +text/instructions_MID.txt (one page per non-blank line) and waits for the start +key. Same draw → flip → poll pattern as the per-phase loops in phases.py, but +shown once before the trial loop rather than per trial. +""" +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +from psychopy import core, visual +from psychopy.hardware import keyboard +from rich.console import Console + +from mid_det import config + +if TYPE_CHECKING: + from mid_det.io.bootstrap import SessionInfo + +_PROJECT_ROOT = Path(__file__).resolve().parents[3] # src/mid_det/task/ -> project root +_TEXT_DIR = _PROJECT_ROOT / "text" + + +def _load_pages(path: Path) -> list[str]: + """Read an instruction text file into a list of pages (one non-blank line each).""" + pages: list[str] = [] + with open(path) as f: + for line in f: + stripped = line.rstrip() + if stripped: + pages.append(stripped) + return pages + + +def display_instructions( + win: visual.Window, + stimuli, # Stimuli dataclass from display.py; avoid circular import + session_info: "SessionInfo", + kb: keyboard.Keyboard, + rcon: Console, +) -> None: + """Display instructions from text/instructions_MID.txt one page at a time.""" + keys_map = config.KEYS_FMRI if session_info.fmri else config.KEYS_BEHAVIORAL + forward_key = keys_map["forward"] + start_key = keys_map["start"] + end_key = keys_map["end"] + + pages = _load_pages(_TEXT_DIR / "instructions_MID.txt") + if not pages: + return + + kb.clearEvents() + page_idx = 0 + + while True: + stimuli.instr_prompt.text = pages[page_idx] + stimuli.instr_prompt.draw() + stimuli.instr_first.draw() + win.flip() + + pressed = kb.getKeys(keyList=[forward_key, end_key], waitRelease=False) + if not pressed: + continue + key_name = pressed[0].name + if key_name == end_key: + core.quit() + elif key_name == forward_key: + page_idx += 1 + if page_idx >= len(pages): + break + + rcon.print( + f"[bold yellow]End of instructions — press '{start_key}' to continue...[/bold yellow]" + ) + while True: + stimuli.instr_finish.draw() + win.flip() + if kb.getKeys(keyList=[start_key], waitRelease=False): + break diff --git a/src/mid_det/task/phases.py b/src/mid_det/task/phases.py new file mode 100644 index 0000000..aa8c642 --- /dev/null +++ b/src/mid_det/task/phases.py @@ -0,0 +1,98 @@ +""" +Fixed-duration per-phase display loops: cue, fixation, outcome, ITI, plus the +shared quit/overlay poll. Each function drives win.flip() for its STUDY_TIMES_S +duration; no data is recorded here. The timing-critical response window lives in +response.py. +""" +from __future__ import annotations + +from mid_det import config +from mid_det._psychopy import core, keyboard, visual +from mid_det.task.debug import DebugOverlay +from mid_det.task.display import ( + Stimuli, + draw_cue, + draw_feedback, + draw_fixation_o, + draw_fixation_x, +) + + +def run_cue( + win: visual.Window, + stimuli: Stimuli, + polarity: str, + magnitude: int, + kb: keyboard.Keyboard, + overlay: DebugOverlay | None = None, +) -> None: + """Display cue for STUDY_TIMES_S['cue'] seconds.""" + timer = core.CountdownTimer(config.STUDY_TIMES_S["cue"]) + while timer.getTime() > 0: + draw_cue(stimuli, polarity, magnitude) + win.flip() + _poll_hotkeys(kb, overlay) + + +def run_fixation( + win: visual.Window, + stimuli: Stimuli, + kb: keyboard.Keyboard, + overlay: DebugOverlay | None = None, +) -> bool: + """Display fixation; return True if any response key was pressed (early press).""" + kb.clearEvents() + early = False + timer = core.CountdownTimer(config.STUDY_TIMES_S["fixation"]) + while timer.getTime() > 0: + draw_fixation_x(stimuli) + win.flip() + _poll_hotkeys(kb, overlay) + # Poll in the loop so a press doesn't sit in the buffer until end-of-phase, + # where a downstream kb.clearEvents() could discard it before inspection. + if not early and kb.getKeys(keyList=config.EXP_KEYS, waitRelease=False): + early = True + if not early and kb.getKeys(keyList=config.EXP_KEYS, waitRelease=False): + early = True + return early + + +def show_outcome( + win: visual.Window, + stimuli: Stimuli, + kb: keyboard.Keyboard, + hit: bool, + reward_outcome: str, + overlay: DebugOverlay | None = None, +) -> None: + """Display the outcome feedback for STUDY_TIMES_S['outcome'] seconds.""" + timer = core.CountdownTimer(config.STUDY_TIMES_S["outcome"]) + while timer.getTime() > 0: + draw_feedback(stimuli, hit, reward_outcome) + win.flip() + _poll_hotkeys(kb, overlay) + + +def run_iti( + win: visual.Window, + stimuli: Stimuli, + kb: keyboard.Keyboard, + fix_dur_s: float, + overlay: DebugOverlay | None = None, +) -> None: + """Display fixation for fix_dur_s seconds (drift-corrected by caller).""" + if fix_dur_s <= 0: + return + timer = core.CountdownTimer(fix_dur_s) + while timer.getTime() > 0: + draw_fixation_o(stimuli) + win.flip() + _poll_hotkeys(kb, overlay) + + +def _poll_hotkeys(kb: keyboard.Keyboard, overlay: DebugOverlay | None = None) -> None: + """Per-frame operator-hotkey poll: quit on escape, toggle the debug overlay on f3.""" + if kb.getKeys(keyList=["escape"], waitRelease=False): + core.quit() + if overlay is not None and kb.getKeys(keyList=["f3"], waitRelease=False): + overlay.toggle() diff --git a/src/mid_det/task/response.py b/src/mid_det/task/response.py new file mode 100644 index 0000000..424e397 --- /dev/null +++ b/src/mid_det/task/response.py @@ -0,0 +1,147 @@ +""" +The response window: target onset/offset timing and keypress capture. + +run_response is the timing-critical core of a trial. It uses +psychopy.hardware.keyboard.Keyboard for accurate RT timestamping. FlipTimer +(flip_timer.py) accumulates per-flip diagnostics; _ResponseState classifies the +keypress outcome. No rendering objects are built here; no data is written here. +""" +from __future__ import annotations + +from dataclasses import dataclass + +from mid_det import config +from mid_det._psychopy import core, keyboard, visual +from mid_det.task.debug import DebugOverlay +from mid_det.task.display import Stimuli, draw_fixation_x, draw_target +from mid_det.task.flip_timer import FlipTimer +from mid_det.task.phases import _poll_hotkeys + + +@dataclass +class _ResponseState: + """Captures the participant's keypress outcome for the response window.""" + + early_press: bool = False + hit: bool = False + rt_s: float | None = None + + def poll_pretarget(self, kb: keyboard.Keyboard) -> None: + """Before target onset any EXP_KEYS press is early. Also drains presses + queued before the loop (e.g. during wait_for_tr); a plain + kb.clearEvents() would silently discard those.""" + if not self.early_press and kb.getKeys( + keyList=config.EXP_KEYS, waitRelease=False + ): + self.early_press = True + + def poll_target( + self, kb: keyboard.Keyboard, target_removed_at: float | None + ) -> None: + """Classify the first press once the target has been shown. An rt < 0 was + pressed before the onset-flip clock reset → early, never a hit.""" + if self.hit or self.rt_s is not None or self.early_press: + return + keys = kb.getKeys(keyList=config.EXP_KEYS, waitRelease=False) + if not keys: + return + rt = keys[0].rt + if rt < 0: + self.early_press = True + else: + self.rt_s = rt + if target_removed_at is None or rt < target_removed_at: + self.hit = True + + +def run_response( + win: visual.Window, + stimuli: Stimuli, + kb: keyboard.Keyboard, + jitter_s: float, + target_dur_s: float, + frame_dur_s: float, + early_press: bool, + overlay: DebugOverlay | None = None, +) -> tuple[bool, float | None, bool, float | None, dict]: + """ + Display response phase (STUDY_TIMES_S['response'] seconds total). + Target appears after jitter_s and stays visible for target_dur_s seconds. + Returns (hit, rt_s, early_press, target_removed_at, diagnostics). + """ + phase_clock = core.Clock() + # Two distinct moments, deliberately one flip apart (draw happens before + # flip, pixels land on the glass when flip completes): + target_onset_scheduled = False # crossed jitter; onset flip is queued/decided + target_on_screen = False # the onset flip has happened — target is on the glass + target_removed_at: float | None = None + + # Derive frames-shown from kb.clock (reset on the onset flip) rather than + # counting loop iterations. Counting iterations assumes every flip() blocks + # exactly one vsync — true on macOS once VSYNC is verified, but Windows + # occasionally drops a frame, making one flip() span two vsyncs. The + # iteration counter would still tick once and the target would be visible + # for one extra frame. The clock advances with real wall time regardless, + # so round(elapsed / frame_dur) + 1 reflects actual displayed frames. + # round() rather than ceil() is intentional: frame-aligned durations (e.g. + # 17 * frame_dur) accumulate floating-point drift and evaluate to + # 17.000000000000004, which ceil() would promote to 18 — one phantom extra + # frame. round() snaps back to the correct integer. + n_target_frames = round(target_dur_s / frame_dur_s) + + timer = FlipTimer(win, frame_dur_s, n_target_frames) + response = _ResponseState(early_press=early_press) + + # Drain any presses queued between fixation end and now (e.g. during + # pulse_counter.wait_for_tr() or scheduler hiccups). Any EXP_KEYS press here + # belongs to the pre-target window and must count as early. + response.poll_pretarget(kb) + + while phase_clock.getTime() < config.STUDY_TIMES_S["response"]: + t = phase_clock.getTime() + + # Schedule kb.clock reset to fire on the next flip so t=0 aligns with onset. + if not target_onset_scheduled and t >= jitter_s: + win.callOnFlip(kb.clock.reset) + target_onset_scheduled = True + + # Decide removal BEFORE the flip: omitting draw_target clears the target + # on this flip. kb.clock (reset on the onset flip) gives the wall time the + # target has been visible; round(elapsed / frame_dur) is the whole frames + # already shown and +1 counts the frame this upcoming flip will complete. + should_remove = False + if target_on_screen and target_removed_at is None: + frames_shown_after_next_flip = round(kb.clock.getTime() / frame_dur_s) + 1 + should_remove = frames_shown_after_next_flip >= n_target_frames + + if target_onset_scheduled and target_removed_at is None and not should_remove: + draw_target(stimuli) + elif not target_onset_scheduled: + draw_fixation_x(stimuli) + win.flip() + + timer.on_flip(win.lastFrameT) + if should_remove: + target_removed_at = timer.on_removal(win.lastFrameT) + elif target_onset_scheduled and not target_on_screen: + target_on_screen = True + timer.on_onset(win.lastFrameT) + # The onset flip is both an onset and the target's first displayed frame. + if target_on_screen and target_removed_at is None: + timer.on_target_frame(win.lastFrameT) + + # Poll keys after the flip so timestamps are relative to the latest screen. + if not target_onset_scheduled: + response.poll_pretarget(kb) + else: + response.poll_target(kb, target_removed_at) + + _poll_hotkeys(kb, overlay) + + return ( + response.hit, + response.rt_s, + response.early_press, + target_removed_at, + timer.summary(), + ) diff --git a/src/mid_det/task/trial.py b/src/mid_det/task/trial.py new file mode 100644 index 0000000..3927c82 --- /dev/null +++ b/src/mid_det/task/trial.py @@ -0,0 +1,263 @@ +""" +Trial orchestration: run_trial() ties the per-phase loops together (cue → +fixation → response → outcome → ITI), applies the reward rule, and builds the +data records. The per-phase display loops live in phases.py and the +timing-critical response window in response.py. No rendering objects are built +here; no data is written here. +""" +from __future__ import annotations + +import random +from collections.abc import Callable + +import pandas as pd + +from mid_det import config +from mid_det._psychopy import core, keyboard, logging, visual +from mid_det.task.calibration import CalibrationState +from mid_det.task.debug import DebugOverlay +from mid_det.task.display import Stimuli +from mid_det.task.phases import run_cue, run_fixation, run_iti, show_outcome +from mid_det.io.recording import ScanPhase, TargetTimingRecord, TrialRecord +from mid_det.task.response import run_response +from mid_det.io.scanner import PulseCounter + + +def _compute_reward( + hit: bool, polarity: str, magnitude: int, total_earned: int +) -> tuple[str, int]: + """ + Return (reward_outcome_label, new_total_earned). + + Gain trial: hit → +$magnitude, miss → $0 + Loss trial: hit → $0, miss → -$magnitude + """ + if polarity == "gain": + if hit and magnitude > 0: + return f"+${magnitude}.00", total_earned + magnitude + if hit and magnitude == 0: + return "+$0.00", total_earned + return "$0.00", total_earned + + # loss + if hit: + return "$0.00", total_earned + if magnitude > 0: + return f"-${magnitude}.00", total_earned - magnitude + return "-$0.00", total_earned + + +def run_trial( + win: visual.Window, + stimuli: Stimuli, + kb: keyboard.Keyboard, + global_clock: core.Clock, + row: pd.Series, + trial_n: int, + n_trials: int, + n_iti_trs: int, + nominal_time: float, + total_earned: int, + subject_id: str, + run_n: str, + pulse_ct: int, + pulse_counter: PulseCounter, + calibration: CalibrationState, + frame_dur_s: float, + on_window: Callable[[int], None] | None = None, + on_response: Callable[[bool, float | None, bool, int, float | str, str, int], None] | None = None, + overlay: DebugOverlay | None = None, +) -> tuple[TrialRecord, TargetTimingRecord, list[ScanPhase], float, int]: + """ + Run one complete trial (cue → fixation → response → outcome → ITI). + + Returns (record, target_timing, scan_phases, nominal_time, total_earned). + """ + polarity = str(row["polarity"]) + magnitude = int(row["magnitude"]) + trial_type = config.TRIAL_TYPE_MAP[(polarity, magnitude)] + target_dur_s = calibration.next_target_dur_s(polarity, magnitude) + jitter_s = random.uniform( + config.JITTER_MIN_S, + config.JITTER_MAX_S, + ) + label = config.cue_label(polarity, magnitude) + + target_dur_ms = int(round(target_dur_s * 1000)) + logging.exp( + f"Trial {trial_n:3d}/{n_trials} cue={label} " + f"target_dur={target_dur_ms} ms jitter={int(jitter_s * 1000)} ms" + ) + if on_window is not None: + on_window(target_dur_ms) + + scan_phases: list[ScanPhase] = [] + tr_within = 0 + + def _update_overlay(phase: str) -> None: + if overlay is not None: + overlay.state.phase = phase + overlay.state.polarity = polarity + overlay.state.magnitude = magnitude + overlay.state.target_dur_ms = target_dur_ms + overlay.state.jitter_ms = int(jitter_s * 1000) + overlay.state.pulse_ct = pulse_ct + overlay.state.global_time = global_clock.getTime() + overlay.state.nominal_time = nominal_time + + # ── CUE ───────────────────────────────────────────────────────────────── + pulse_ct += pulse_counter.drain() + time_onset = global_clock.getTime() + tr_within += 1 + scan_phases.append(ScanPhase( + trial_n=trial_n, phase="cue", tr_n=tr_within, + phase_onset_global_time=time_onset, + phase_onset_trial_time=0.0, + pulse_ct=pulse_ct, + )) + _update_overlay("cue") + run_cue(win, stimuli, polarity, magnitude, kb, overlay) + nominal_time += config.STUDY_TIMES_S["cue"] + + # ── FIXATION ────────────────────────────────────────────────────────────── + pulse_ct += pulse_counter.wait_for_tr() + fixation_start = global_clock.getTime() + tr_within += 1 + scan_phases.append(ScanPhase( + trial_n=trial_n, phase="fixation", tr_n=tr_within, + phase_onset_global_time=fixation_start, + phase_onset_trial_time=fixation_start - time_onset, + pulse_ct=pulse_ct, + )) + _update_overlay("fixation") + early_press = run_fixation(win, stimuli, kb, overlay) + nominal_time += config.STUDY_TIMES_S["fixation"] + + # ── RESPONSE ───────────────────────────────────────────────────────────── + pulse_ct += pulse_counter.wait_for_tr() + response_start = global_clock.getTime() + tr_within += 1 + scan_phases.append(ScanPhase( + trial_n=trial_n, phase="response", tr_n=tr_within, + phase_onset_global_time=response_start, + phase_onset_trial_time=response_start - time_onset, + pulse_ct=pulse_ct, + )) + _update_overlay("response") + hit, rt_s, early_press, target_removed_at, response_diag = run_response( + win, stimuli, kb, jitter_s, target_dur_s, frame_dur_s, early_press, overlay + ) + nominal_time += config.STUDY_TIMES_S["response"] + target_dur_ms_actual = round(target_removed_at * 1000, 2) if target_removed_at is not None else "" + reward_outcome, total_earned = _compute_reward(hit, polarity, magnitude, total_earned) + if on_response is not None: + on_response( + hit, rt_s, early_press, target_dur_ms, target_dur_ms_actual, + reward_outcome, total_earned, + ) + + # ── OUTCOME ────────────────────────────────────────────────────────────── + pulse_ct += pulse_counter.wait_for_tr() + outcome_start = global_clock.getTime() + tr_within += 1 + scan_phases.append(ScanPhase( + trial_n=trial_n, phase="outcome", tr_n=tr_within, + phase_onset_global_time=outcome_start, + phase_onset_trial_time=outcome_start - time_onset, + pulse_ct=pulse_ct, + )) + _update_overlay("outcome") + show_outcome(win, stimuli, kb, hit, reward_outcome, overlay) + nominal_time += config.STUDY_TIMES_S["outcome"] + calibration.record_outcome(polarity, magnitude, bool(hit)) + # Drift is measured here — at the end of feedback, before the ITI — to match + # MATLAB main.m:329 (`GetSecs()-abs_start-onset_t(i)-8.0`). It is the per-trial + # over/under-run of the four fixed slides relative to this trial's own onset, + # i.e. the slippage the ITI is about to correct, NOT the post-correction + # residual. In scan mode the per-phase wait_for_tr() waits fall inside this + # window, so any TR-lock slack is included (same as the prior definition). + time_outcome_end = global_clock.getTime() + + # ── ITI ────────────────────────────────────────────────────────────────── + for _ in range(n_iti_trs): + pulse_ct += pulse_counter.wait_for_tr() + iti_start = global_clock.getTime() + tr_within += 1 + scan_phases.append(ScanPhase( + trial_n=trial_n, phase="post-outcome-fixation", tr_n=tr_within, + phase_onset_global_time=iti_start, + phase_onset_trial_time=iti_start - time_onset, + pulse_ct=pulse_ct, + )) + actual_time = global_clock.getTime() + iti_dur = config.STUDY_TIMES_S["iti"] - (actual_time - nominal_time) + nominal_time += config.STUDY_TIMES_S["iti"] + _update_overlay("iti") + run_iti(win, stimuli, kb, iti_dur, overlay) + + # ── BUILD RECORD ───────────────────────────────────────────────────────── + time_trial_end = global_clock.getTime() + time_sched_end = nominal_time + + # Backfill each phase's end = onset of the next phase (tiled timeline); + # the final phase ends at the trial end. + for i, sp in enumerate(scan_phases): + end_global = ( + scan_phases[i + 1].phase_onset_global_time + if i + 1 < len(scan_phases) + else time_trial_end + ) + sp.phase_offset_global_time = end_global + sp.phase_offset_trial_time = end_global - time_onset + sp.trial_type = trial_type + sp.polarity = polarity + sp.magnitude = magnitude + + record = TrialRecord( + trial_n=trial_n, + trial_type=trial_type, + polarity=polarity, + magnitude=magnitude, + cue_label=label, + time_onset=round(time_onset, 6), + jitter_ms=int(round(jitter_s * 1000)), + jitter_ms_actual=response_diag["jitter_ms_actual"], + target_dur_ms=target_dur_ms, + target_dur_ms_actual=target_dur_ms_actual, + early_press=int(early_press), + hit=int(hit), + rt_ms=round(rt_s * 1000, 2) if rt_s is not None else "", + reward_outcome=reward_outcome, + total_earned=total_earned, + time_trial_end=round(time_trial_end, 6), + trial_dur_ms=int(round((time_trial_end - time_onset) * 1000)), + time_sched_end=round(time_sched_end, 6), + timing_drift_ms=round( + ((time_outcome_end - time_onset) - config.PRE_ITI_NOMINAL_S) * 1000, 2 + ), + n_iti_trs=n_iti_trs, + total_trs=tr_within, + subject_id=subject_id, + run_n=run_n, + pulse_ct=scan_phases[0].pulse_ct, + ) + + target_timing = TargetTimingRecord( + trial_n=trial_n, + target_frames_scheduled=response_diag["n_target_frames"], + target_frames_shown=response_diag["flip_iters"], + target_visible_ms_scheduled=round( + response_diag["n_target_frames"] * frame_dur_s * 1000, 2 + ), + target_visible_ms_measured=response_diag["onset_to_removal_wall_ms"], + late_flips_in_window=response_diag["dropped_frames"], + longest_frame_interval_ms=response_diag["max_flip_interval_ms"], + target_timing_ok=int(response_diag["trial_clean"]), + ) + + if overlay is not None: + overlay.state.last_result = "HIT" if record.hit else ("early" if record.early_press else "miss") + overlay.state.last_rt_ms = f"{record.rt_ms:.0f} ms" if record.rt_ms != "" else "—" + overlay.state.last_timing_drift_ms = record.timing_drift_ms + + return record, target_timing, scan_phases, nominal_time, total_earned diff --git a/src/mid_det/trial.py b/src/mid_det/trial.py deleted file mode 100644 index 665a024..0000000 --- a/src/mid_det/trial.py +++ /dev/null @@ -1,528 +0,0 @@ -""" -Phase functions and run_trial(). -Uses psychopy.hardware.keyboard.Keyboard for accurate RT timestamping. -No rendering objects are built here; no data is written here. -""" -from __future__ import annotations - -import math -import random -from collections.abc import Callable - -import pandas as pd - -try: - from psychopy import core, logging, visual - from psychopy.hardware import keyboard -except ModuleNotFoundError: - # Headless/CI without PsychoPy: keep this module importable so the pure-logic - # (_compute_reward) and timing (run_response, driven by a fake window/clock in - # tests) code stays testable. `core` is a namespace with the attributes those - # paths reference — tests patch core.Clock; real runs always have PsychoPy. - import types - - visual = keyboard = None # type: ignore[assignment] - logging = types.SimpleNamespace(exp=lambda *a, **k: None) # type: ignore[assignment] - core = types.SimpleNamespace( # type: ignore[assignment] - Clock=None, CountdownTimer=None, quit=lambda *a, **k: None - ) - -from mid_det import config -from mid_det.calibration import CalibrationState -from mid_det.debug import DebugOverlay -from mid_det.display import ( - Stimuli, - draw_cue, - draw_feedback, - draw_fixation_o, - draw_fixation_x, - draw_target, -) -from mid_det.recorder import ScanPhase, TargetTimingRecord, TrialRecord -from mid_det.scanner import PulseCounter - - -def run_cue( - win: visual.Window, - stimuli: Stimuli, - polarity: str, - magnitude: int, - kb: keyboard.Keyboard, - overlay: DebugOverlay | None = None, -) -> None: - """Display cue for STUDY_TIMES_S['cue'] seconds.""" - timer = core.CountdownTimer(config.STUDY_TIMES_S["cue"]) - while timer.getTime() > 0: - draw_cue(stimuli, polarity, magnitude) - win.flip() - _check_quit(kb, overlay) - - -def run_fixation( - win: visual.Window, - stimuli: Stimuli, - kb: keyboard.Keyboard, - overlay: DebugOverlay | None = None, -) -> bool: - """Display fixation; return True if any response key was pressed (early press).""" - kb.clearEvents() - early = False - timer = core.CountdownTimer(config.STUDY_TIMES_S["fixation"]) - while timer.getTime() > 0: - draw_fixation_x(stimuli) - win.flip() - _check_quit(kb, overlay) - # Poll in the loop so a press doesn't sit in the buffer until end-of-phase, - # where a downstream kb.clearEvents() could discard it before inspection. - if not early and kb.getKeys(keyList=config.EXP_KEYS, waitRelease=False): - early = True - if not early and kb.getKeys(keyList=config.EXP_KEYS, waitRelease=False): - early = True - return early - - -def run_response( - win: visual.Window, - stimuli: Stimuli, - kb: keyboard.Keyboard, - jitter_s: float, - target_dur_s: float, - frame_dur_s: float, - early_press: bool, - overlay: DebugOverlay | None = None, -) -> tuple[bool, float | None, bool, float | None, dict]: - """ - Display response phase (STUDY_TIMES_S['response'] seconds total). - Target appears after jitter_s and stays visible for target_dur_s seconds. - Returns (hit, rt_s, early_press, target_removed_at, diagnostics). - """ - phase_clock = core.Clock() - target_shown = False - target_onset_flip_done = False - target_removed_at: float | None = None - clock_reset_scheduled = False - hit = False - rt_s: float | None = None - - # Measurement: use win.lastFrameT (the time PsychoPy stamps inside flip(), - # right after glFinish but before callOnFlip callbacks fire). This is a - # tighter proxy for the actual swap time than reading core.getTime() after - # flip() returns, which additionally absorbs callback-dispatch overhead. - response_start_flip_t: float | None = None - onset_flip_t: float | None = None - removal_flip_t: float | None = None - flip_iters = 0 - dropped_at_onset: int | None = None - max_intra_flip_ms = 0.0 - last_intra_flip_t: float | None = None - - # Drain any presses queued between fixation end and now (e.g. during - # pulse_counter.wait_for_tr() or scheduler hiccups). Any EXP_KEYS press - # observed here belongs to the pre-target window and must count as early. - # Plain kb.clearEvents() would silently discard these. - if kb.getKeys(keyList=config.EXP_KEYS, waitRelease=False): - early_press = True - - # Derive frames-shown from kb.clock (reset on the onset flip) rather than - # counting loop iterations. Counting iterations assumes every flip() blocks - # exactly one vsync — true on macOS once VSYNC is verified, but Windows - # occasionally drops a frame, making one flip() span two vsyncs. The - # iteration counter would still tick once and the target would be visible - # for one extra frame. The clock advances with real wall time regardless, - # so round(elapsed / frame_dur) + 1 reflects actual displayed frames. - # round() rather than ceil() is intentional: frame-aligned durations (e.g. - # 17 * frame_dur) accumulate floating-point drift and evaluate to - # 17.000000000000004, which ceil() would promote to 18 — one phantom extra - # frame. round() snaps back to the correct integer. - n_target_frames = round(target_dur_s / frame_dur_s) - - while phase_clock.getTime() < config.STUDY_TIMES_S["response"]: - t = phase_clock.getTime() - - # Schedule kb.clock reset to fire on the next flip so t=0 aligns with target onset - if not clock_reset_scheduled and t >= jitter_s: - win.callOnFlip(kb.clock.reset) - clock_reset_scheduled = True - target_shown = True - - if target_onset_flip_done and target_removed_at is None: - frames_shown = round(kb.clock.getTime() / frame_dur_s) + 1 - else: - frames_shown = 0 - should_remove = ( - target_removed_at is None - and target_onset_flip_done - and frames_shown >= n_target_frames - ) - - # Draw before flip: omitting draw_target when should_remove clears the target on this flip - if target_shown and target_removed_at is None and not should_remove: - draw_target(stimuli) - elif not target_shown: - draw_fixation_x(stimuli) - win.flip() - - # Stamp the first flip of the response phase so we can later report the - # actual pre-target jitter wall time (analog of target_dur_ms_actual). - if response_start_flip_t is None: - response_start_flip_t = win.lastFrameT - - # Timestamp using win.lastFrameT, which PsychoPy sets inside flip() - # right after the GPU finishes the swap. core.getTime() after flip() - # returns would additionally include callOnFlip callback overhead. - if should_remove: - removal_flip_t = win.lastFrameT - target_removed_at = removal_flip_t - onset_flip_t if onset_flip_t else None - # Also measure the interval into the removal flip — a stretched - # removal flip is exactly the DWM-hiccup signature and must not be - # invisible to the diagnostic. - if last_intra_flip_t is not None: - delta_ms = (removal_flip_t - last_intra_flip_t) * 1000 - if delta_ms > max_intra_flip_ms: - max_intra_flip_ms = delta_ms - elif clock_reset_scheduled and not target_onset_flip_done: - target_onset_flip_done = True - onset_flip_t = win.lastFrameT - dropped_at_onset = getattr(win, "nDroppedFrames", 0) - last_intra_flip_t = onset_flip_t - - if target_onset_flip_done and target_removed_at is None: - flip_iters += 1 - now = win.lastFrameT - if last_intra_flip_t is not None: - delta_ms = (now - last_intra_flip_t) * 1000 - if delta_ms > max_intra_flip_ms: - max_intra_flip_ms = delta_ms - last_intra_flip_t = now - - # Poll keys after flip so timestamps are relative to the most recent screen state - if not target_shown and not early_press: - if kb.getKeys(keyList=config.EXP_KEYS, waitRelease=False): - early_press = True - - if target_shown and not hit and rt_s is None and not early_press: - keys = kb.getKeys(keyList=config.EXP_KEYS, waitRelease=False) - if keys: - rt = keys[0].rt - if rt < 0: - early_press = True - else: - rt_s = rt - if target_removed_at is None or rt < target_removed_at: - hit = True - - _check_quit(kb, overlay) - - if onset_flip_t is not None and removal_flip_t is not None: - onset_to_removal_wall_ms = round( - (removal_flip_t - onset_flip_t) * 1000, 2 - ) - else: - onset_to_removal_wall_ms = "" - - if response_start_flip_t is not None and onset_flip_t is not None: - jitter_ms_actual = round((onset_flip_t - response_start_flip_t) * 1000, 2) - else: - jitter_ms_actual = "" - - if dropped_at_onset is not None: - dropped_frames = int(getattr(win, "nDroppedFrames", 0) - dropped_at_onset) - else: - dropped_frames = 0 - - # Mark trial unclean if (a) PsychoPy detected any dropped frames during the - # response window, OR (b) the measured wall delta differs from the expected - # on-screen duration by more than half a frame. Either condition makes the - # exact target-display time unreliable for timing-sensitive analyses; flag - # for exclusion at analysis time. DWM-induced extra frames on Windows are - # an acknowledged unsolvable limitation — exclusion is the standard fix. - expected_dur_ms = n_target_frames * frame_dur_s * 1000 - half_frame_ms = (frame_dur_s * 1000) / 2 - timing_off_by_frame = ( - isinstance(onset_to_removal_wall_ms, (int, float)) - and abs(onset_to_removal_wall_ms - expected_dur_ms) > half_frame_ms - ) - trial_clean = dropped_frames == 0 and not timing_off_by_frame - - diagnostics = { - "flip_iters": flip_iters, - "n_target_frames": n_target_frames, - "dropped_frames": dropped_frames, - "onset_to_removal_wall_ms": onset_to_removal_wall_ms, - "max_flip_interval_ms": round(max_intra_flip_ms, 2), - "trial_clean": trial_clean, - "jitter_ms_actual": jitter_ms_actual, - } - - return hit, rt_s, early_press, target_removed_at, diagnostics - - -def _compute_reward( - hit: bool, polarity: str, magnitude: int, total_earned: int -) -> tuple[str, int]: - """ - Return (reward_outcome_label, new_total_earned). - - Gain trial: hit → +$magnitude, miss → $0 - Loss trial: hit → $0, miss → -$magnitude - """ - if polarity == "gain": - if hit and magnitude > 0: - return f"+${magnitude}.00", total_earned + magnitude - if hit and magnitude == 0: - return "+$0.00", total_earned - return "$0.00", total_earned - - # loss - if hit: - return "$0.00", total_earned - if magnitude > 0: - return f"-${magnitude}.00", total_earned - magnitude - return "-$0.00", total_earned - - -def show_outcome( - win: visual.Window, - stimuli: Stimuli, - kb: keyboard.Keyboard, - hit: bool, - reward_outcome: str, - overlay: DebugOverlay | None = None, -) -> None: - """Display the outcome feedback for STUDY_TIMES_S['outcome'] seconds.""" - timer = core.CountdownTimer(config.STUDY_TIMES_S["outcome"]) - while timer.getTime() > 0: - draw_feedback(stimuli, hit, reward_outcome) - win.flip() - _check_quit(kb, overlay) - - -def run_iti( - win: visual.Window, - stimuli: Stimuli, - kb: keyboard.Keyboard, - fix_dur_s: float, - overlay: DebugOverlay | None = None, -) -> None: - """Display fixation for fix_dur_s seconds (drift-corrected by caller).""" - if fix_dur_s <= 0: - return - timer = core.CountdownTimer(fix_dur_s) - while timer.getTime() > 0: - draw_fixation_o(stimuli) - win.flip() - _check_quit(kb, overlay) - - -def _check_quit(kb: keyboard.Keyboard, overlay: DebugOverlay | None = None) -> None: - if kb.getKeys(keyList=["escape"], waitRelease=False): - core.quit() - if overlay is not None and kb.getKeys(keyList=["f3"], waitRelease=False): - overlay.toggle() - - -def run_trial( - win: visual.Window, - stimuli: Stimuli, - kb: keyboard.Keyboard, - global_clock: core.Clock, - row: pd.Series, - trial_n: int, - n_trials: int, - n_iti_trs: int, - nominal_time: float, - total_earned: int, - subject_id: str, - run_n: str, - pulse_ct: int, - pulse_counter: PulseCounter, - calibration: CalibrationState, - frame_dur_s: float, - on_window: Callable[[int], None] | None = None, - on_response: Callable[[bool, float | None, bool, int, float | str, str, int], None] | None = None, - overlay: DebugOverlay | None = None, -) -> tuple[TrialRecord, TargetTimingRecord, list[ScanPhase], float, int]: - """ - Run one complete trial (cue → fixation → response → outcome → ITI). - - Returns (record, target_timing, scan_phases, nominal_time, total_earned). - """ - polarity = str(row["polarity"]) - magnitude = int(row["magnitude"]) - trial_type = config.TRIAL_TYPE_MAP[(polarity, magnitude)] - target_dur_s = calibration.next_target_dur_s(polarity, magnitude) - jitter_s = random.uniform( - config.JITTER_MIN_S, - config.JITTER_MAX_S, - ) - label = config.cue_label(polarity, magnitude) - - target_dur_ms = int(round(target_dur_s * 1000)) - logging.exp( - f"Trial {trial_n:3d}/{n_trials} cue={label} " - f"target_dur={target_dur_ms} ms jitter={int(jitter_s * 1000)} ms" - ) - if on_window is not None: - on_window(target_dur_ms) - - scan_phases: list[ScanPhase] = [] - tr_within = 0 - - def _update_overlay(phase: str) -> None: - if overlay is not None: - overlay.state.phase = phase - overlay.state.polarity = polarity - overlay.state.magnitude = magnitude - overlay.state.target_dur_ms = target_dur_ms - overlay.state.jitter_ms = int(jitter_s * 1000) - overlay.state.pulse_ct = pulse_ct - overlay.state.global_time = global_clock.getTime() - overlay.state.nominal_time = nominal_time - - # ── CUE ───────────────────────────────────────────────────────────────── - pulse_ct += pulse_counter.drain() - time_onset = global_clock.getTime() - tr_within += 1 - scan_phases.append(ScanPhase( - trial_n=trial_n, phase="cue", tr_n=tr_within, - phase_onset_global_time=time_onset, - phase_onset_trial_time=0.0, - pulse_ct=pulse_ct, - )) - _update_overlay("cue") - run_cue(win, stimuli, polarity, magnitude, kb, overlay) - nominal_time += config.STUDY_TIMES_S["cue"] - - # ── FIXATION ────────────────────────────────────────────────────────────── - pulse_ct += pulse_counter.wait_for_tr() - fixation_start = global_clock.getTime() - tr_within += 1 - scan_phases.append(ScanPhase( - trial_n=trial_n, phase="fixation", tr_n=tr_within, - phase_onset_global_time=fixation_start, - phase_onset_trial_time=fixation_start - time_onset, - pulse_ct=pulse_ct, - )) - _update_overlay("fixation") - early_press = run_fixation(win, stimuli, kb, overlay) - nominal_time += config.STUDY_TIMES_S["fixation"] - - # ── RESPONSE ───────────────────────────────────────────────────────────── - pulse_ct += pulse_counter.wait_for_tr() - response_start = global_clock.getTime() - tr_within += 1 - scan_phases.append(ScanPhase( - trial_n=trial_n, phase="response", tr_n=tr_within, - phase_onset_global_time=response_start, - phase_onset_trial_time=response_start - time_onset, - pulse_ct=pulse_ct, - )) - _update_overlay("response") - hit, rt_s, early_press, target_removed_at, response_diag = run_response( - win, stimuli, kb, jitter_s, target_dur_s, frame_dur_s, early_press, overlay - ) - nominal_time += config.STUDY_TIMES_S["response"] - target_dur_ms_actual = round(target_removed_at * 1000, 2) if target_removed_at is not None else "" - reward_outcome, total_earned = _compute_reward(hit, polarity, magnitude, total_earned) - if on_response is not None: - on_response( - hit, rt_s, early_press, target_dur_ms, target_dur_ms_actual, - reward_outcome, total_earned, - ) - - # ── OUTCOME ────────────────────────────────────────────────────────────── - pulse_ct += pulse_counter.wait_for_tr() - outcome_start = global_clock.getTime() - tr_within += 1 - scan_phases.append(ScanPhase( - trial_n=trial_n, phase="outcome", tr_n=tr_within, - phase_onset_global_time=outcome_start, - phase_onset_trial_time=outcome_start - time_onset, - pulse_ct=pulse_ct, - )) - _update_overlay("outcome") - show_outcome(win, stimuli, kb, hit, reward_outcome, overlay) - nominal_time += config.STUDY_TIMES_S["outcome"] - calibration.record_outcome(polarity, magnitude, bool(hit)) - - # ── ITI ────────────────────────────────────────────────────────────────── - for _ in range(n_iti_trs): - pulse_ct += pulse_counter.wait_for_tr() - iti_start = global_clock.getTime() - tr_within += 1 - scan_phases.append(ScanPhase( - trial_n=trial_n, phase="post-outcome-fixation", tr_n=tr_within, - phase_onset_global_time=iti_start, - phase_onset_trial_time=iti_start - time_onset, - pulse_ct=pulse_ct, - )) - actual_time = global_clock.getTime() - iti_dur = config.STUDY_TIMES_S["iti"] - (actual_time - nominal_time) - nominal_time += config.STUDY_TIMES_S["iti"] - _update_overlay("iti") - run_iti(win, stimuli, kb, iti_dur, overlay) - - # ── BUILD RECORD ───────────────────────────────────────────────────────── - time_trial_end = global_clock.getTime() - time_sched_end = nominal_time - - # Backfill each phase's end = onset of the next phase (tiled timeline); - # the final phase ends at the trial end. - for i, sp in enumerate(scan_phases): - end_global = ( - scan_phases[i + 1].phase_onset_global_time - if i + 1 < len(scan_phases) - else time_trial_end - ) - sp.phase_offset_global_time = end_global - sp.phase_offset_trial_time = end_global - time_onset - sp.trial_type = trial_type - sp.polarity = polarity - sp.magnitude = magnitude - - record = TrialRecord( - trial_n=trial_n, - trial_type=trial_type, - polarity=polarity, - magnitude=magnitude, - cue_label=label, - time_onset=round(time_onset, 6), - jitter_ms=int(round(jitter_s * 1000)), - jitter_ms_actual=response_diag["jitter_ms_actual"], - target_dur_ms=target_dur_ms, - target_dur_ms_actual=target_dur_ms_actual, - early_press=int(early_press), - hit=int(hit), - rt_ms=round(rt_s * 1000, 2) if rt_s is not None else "", - reward_outcome=reward_outcome, - total_earned=total_earned, - time_trial_end=round(time_trial_end, 6), - trial_dur_ms=int(round((time_trial_end - time_onset) * 1000)), - time_sched_end=round(time_sched_end, 6), - timing_drift_ms=round((time_trial_end - time_sched_end) * 1000, 2), - n_iti_trs=n_iti_trs, - total_trs=tr_within, - subject_id=subject_id, - run_n=run_n, - pulse_ct=scan_phases[0].pulse_ct, - ) - - target_timing = TargetTimingRecord( - trial_n=trial_n, - target_frames_scheduled=response_diag["n_target_frames"], - target_frames_shown=response_diag["flip_iters"], - target_visible_ms_scheduled=round( - response_diag["n_target_frames"] * frame_dur_s * 1000, 2 - ), - target_visible_ms_measured=response_diag["onset_to_removal_wall_ms"], - late_flips_in_window=response_diag["dropped_frames"], - longest_frame_interval_ms=response_diag["max_flip_interval_ms"], - target_timing_ok=int(response_diag["trial_clean"]), - ) - - if overlay is not None: - overlay.state.last_result = "HIT" if record.hit else ("early" if record.early_press else "miss") - overlay.state.last_rt_ms = f"{record.rt_ms:.0f} ms" if record.rt_ms != "" else "—" - overlay.state.last_timing_drift_ms = record.timing_drift_ms - - return record, target_timing, scan_phases, nominal_time, total_earned diff --git a/tests/test_calibration.py b/tests/test_calibration.py index 014f163..7c3f00f 100644 --- a/tests/test_calibration.py +++ b/tests/test_calibration.py @@ -9,7 +9,7 @@ import pytest from mid_det import config -from mid_det.calibration import CalibrationState +from mid_det.task.calibration import CalibrationState # ── helpers ────────────────────────────────────────────────────────────────── diff --git a/tests/test_calibration_matlab_parity.py b/tests/test_calibration_matlab_parity.py index d9eeaee..ab156a0 100644 --- a/tests/test_calibration_matlab_parity.py +++ b/tests/test_calibration_matlab_parity.py @@ -21,7 +21,7 @@ import pytest from mid_det import config -from mid_det.calibration import CalibrationState +from mid_det.task.calibration import CalibrationState _REF_DIR = Path(__file__).parent / "matlab_ref" diff --git a/tests/test_ratings.py b/tests/test_ratings.py index 7789a9c..0a95369 100644 --- a/tests/test_ratings.py +++ b/tests/test_ratings.py @@ -82,3 +82,44 @@ def test_write_ratings_csv(tmp_path): def test_ratings_in_valid_range(): for r in range(1, core.N_ELS + 1): assert 1 <= core.clamp_slider(r, 0) <= core.N_ELS + + +def test_legacy_gamble_names(): + # magnitude 0/1/5 -> low/med/high; loss -> square, gain -> circle. + assert core.GAMBLE_NAMES == { + ("loss", 0): "lowsquare", + ("loss", 1): "medsquare", + ("loss", 5): "highsquare", + ("gain", 0): "lowcircle", + ("gain", 1): "medcircle", + ("gain", 5): "highcircle", + } + + +def test_legacy_ratings_rows_order(): + results = [ + {"polarity": c.polarity, "magnitude": c.magnitude, "arousal": 5, "valence": 4} + for c in core.RATING_CUES + ] + rows = core.build_legacy_csv_rows(results) + assert rows[0] == ["gamble", "arousal", "valence"] + assert [r[0] for r in rows[1:]] == [ + "lowsquare", "medsquare", "highsquare", + "lowcircle", "medcircle", "highcircle", + ] + # columns are gamble, arousal, valence (arousal before valence) + assert rows[1] == ["lowsquare", "5", "4"] + + +def test_write_legacy_ratings_csv(tmp_path): + results = [ + {"polarity": c.polarity, "magnitude": c.magnitude, "arousal": 4, "valence": 4} + for c in core.RATING_CUES + ] + out = tmp_path / "1_ratings.csv" + core.write_legacy_ratings_csv(out, results) + + with open(out, newline="") as f: + read = list(csv.reader(f)) + assert read[0] == ["gamble", "arousal", "valence"] + assert len(read) == 1 + len(core.RATING_CUES) diff --git a/tests/test_recorder.py b/tests/test_recorder.py index 8f6e435..62e0098 100644 --- a/tests/test_recorder.py +++ b/tests/test_recorder.py @@ -7,12 +7,36 @@ from datetime import datetime from pathlib import Path -from mid_det.recorder import ( +import pytest + +from mid_det.io.recording import ( BEHAVIORAL_COLUMNS, + LEGACY_MID_COLUMNS, BehavioralCsvWriter, + LegacyMidCsvWriter, TrialRecord, write_ratings_manifest, ) +from mid_det.io.recording.legacy import _num2str + + +@pytest.mark.parametrize( + "value, expected", + [ + (0.0, "0"), # zero + (0.16667, "0.16667"), # < 1 -> 5 sig figs + (1 / 3, "0.33333"), # < 1 -> 5 sig figs (rounded) + (1.0, "1"), # whole value collapses + (-1, "-1"), # integer miss code + (-2, "-2"), # early-press code + (0.4, "0.4"), # short decimal kept short + (12.013308, "12.0133"), # magnitude > 1 -> 6 sig figs + (-0.00011858, "-0.00011858"), # tiny drift -> 5 sig figs, no sci (exp -4) + (-7.83e-05, "-7.83e-05"), # exp < -4 -> scientific, like num2str + ], +) +def test_num2str_matches_matlab(value, expected): + assert _num2str(value) == expected @dataclass @@ -76,3 +100,111 @@ def test_behavioral_csv_roundtrip(tmp_path: Path): assert set(rows[0].keys()) == set(BEHAVIORAL_COLUMNS) assert rows[0]["cue_label"] == "+$5.00" assert rows[0]["reward_outcome"] == "+$5.00" + + +def _legacy_record(**overrides) -> TrialRecord: + base = dict( + trial_n=1, trial_type=4, polarity="gain", magnitude=0, + cue_label="+$0.00", time_onset=12.0003, + jitter_ms=10, jitter_ms_actual=10.5, + target_dur_ms=400, target_dur_ms_actual="", + early_press=0, hit=0, rt_ms="", reward_outcome="$0.00", + total_earned=0, time_trial_end=26.0, trial_dur_ms=14000, + time_sched_end=26.0, timing_drift_ms=-0.12, + n_iti_trs=3, total_trs=7, + subject_id="X", run_n="1", pulse_ct=0, + ) + base.update(overrides) + return TrialRecord(**base) + + +def test_legacy_mid_per_tr_rows_and_header(tmp_path: Path): + path = tmp_path / "1_b1.csv" + w = LegacyMidCsvWriter(path) + w.append(_legacy_record(total_trs=7)) # 7 TR rows + w.append(_legacy_record(trial_n=2, total_trs=5, n_iti_trs=1)) # 5 TR rows + w.close() + + with open(path) as f: + reader = csv.DictReader(f) + assert reader.fieldnames == LEGACY_MID_COLUMNS + rows = list(reader) + assert len(rows) == 7 + 5 + assert [r["TR"] for r in rows[:7]] == [str(i) for i in range(1, 8)] + assert all(r["trial"] == "1" for r in rows[:7]) + assert [r["TR"] for r in rows[7:]] == [str(i) for i in range(1, 6)] + + +def test_legacy_mid_formatting_and_winpercents(tmp_path: Path): + # Formats traced to MATLAB PresentCue.m / PresentFeedback.m / PartialParseData.m + # / PresentTarget.m. + path = tmp_path / "leg.csv" + w = LegacyMidCsvWriter(path) + # gain $0 miss (type 4): cue +$0, gain $0 (gain miss), total $0.00 + w.append(_legacy_record(trial_n=1, trial_type=4, polarity="gain", magnitude=0, + hit=0, early_press=0, total_earned=0, + target_dur_ms=400, rt_ms="", timing_drift_ms=-0.12, + n_iti_trs=3, total_trs=7)) + # loss $1 miss (type 2): cue -$1, gain -$1, running total -1 -> "$-1.00" + w.append(_legacy_record(trial_n=2, trial_type=2, polarity="loss", magnitude=1, + hit=0, early_press=0, total_earned=-1, + target_dur_ms=400, rt_ms="", n_iti_trs=3, total_trs=7)) + # loss $1 hit (type 2): cue -$1, gain $0 (loss avoided), total unchanged, rt set + w.append(_legacy_record(trial_n=3, trial_type=2, polarity="loss", magnitude=1, + hit=1, early_press=0, total_earned=-1, + target_dur_ms=420, rt_ms=226.54, n_iti_trs=3, total_trs=7)) + w.close() + + with open(path) as f: + rows = list(csv.DictReader(f)) + + t1, t2, t3 = rows[0], rows[7], rows[14] + # target_ms / rt / drift are in seconds, formatted via MATLAB-style num2str. + assert t1["target_ms"] == "0.4" + assert t1["rt"] == "-1" # miss + assert t1["cue_value"] == "+$0" + assert t1["trial_gain"] == "$0" # gain trial, missed + assert t1["total"] == "$0.00" + assert t1["iti"] == "6" + assert t1["drift"] == "-0.00012" # -0.12 ms, num2str 5 sig figs + assert t1["total_winpercent"] == "0" # num2str: 0.0 -> "0" + assert t1["binned_winpercent"] == "0" + + # loss $1 miss: incurs the loss + assert t2["cue_value"] == "-$1" + assert t2["trial_gain"] == "-$1" + assert t2["total"] == "$-1.00" + # loss $1 hit: loss avoided, rt recorded + assert t3["cue_value"] == "-$1" + assert t3["trial_gain"] == "$0" + assert t3["total"] == "$-1.00" + assert t3["rt"] == "0.22654" + assert t3["target_ms"] == "0.42" + # win%: after 3 trials, 1 hit overall -> 1/3; type-2 has 2 trials, 1 hit -> 1/2. + # num2str renders 1/3 to 5 sig figs ("0.33333"). + assert t3["total_winpercent"] == "0.33333" + assert t3["binned_winpercent"] == "0.5" + + +def test_legacy_mid_early_press_rt(tmp_path: Path): + # MATLAB encodes an early (front-buffer) press as rt = -2, distinct from a + # miss (-1); both score hit = 0. + path = tmp_path / "early.csv" + w = LegacyMidCsvWriter(path) + w.append(_legacy_record(hit=0, early_press=1, rt_ms="", total_trs=5)) + w.close() + with open(path) as f: + rows = list(csv.DictReader(f)) + assert rows[0]["rt"] == "-2" + assert rows[0]["hit"] == "0" + + +def test_legacy_mid_block2_trial_offset(tmp_path: Path): + # MATLAB PartialParseData.m: block-2 trials are numbered from 43 (offset 42). + path = tmp_path / "b2.csv" + w = LegacyMidCsvWriter(path, trial_offset=42) + w.append(_legacy_record(trial_n=1, total_trs=5)) + w.close() + with open(path) as f: + rows = list(csv.DictReader(f)) + assert all(r["trial"] == "43" for r in rows) diff --git a/tests/test_reward.py b/tests/test_reward.py index a9938e2..83dc689 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -15,7 +15,7 @@ import pytest -from mid_det.trial import _compute_reward +from mid_det.task.trial import _compute_reward @pytest.mark.parametrize( diff --git a/tests/test_scanner.py b/tests/test_scanner.py index 1d7a547..136fe00 100644 --- a/tests/test_scanner.py +++ b/tests/test_scanner.py @@ -2,7 +2,7 @@ from __future__ import annotations from mid_det import config -from mid_det.scanner import PulseCounter +from mid_det.io.scanner import PulseCounter class FakeBackend: diff --git a/tests/test_sequences.py b/tests/test_sequences.py index b032dda..cdc17bf 100644 --- a/tests/test_sequences.py +++ b/tests/test_sequences.py @@ -16,8 +16,9 @@ import pytest -from mid_det import config, sequences -from mid_det.sequences import load_sequence +from mid_det import config +from mid_det.io import sequences +from mid_det.io.sequences import load_sequence # ── schema ──────────────────────────────────────────────────────────────────── diff --git a/tests/test_trial_boundary.py b/tests/test_trial_boundary.py index 040c363..51a229f 100644 --- a/tests/test_trial_boundary.py +++ b/tests/test_trial_boundary.py @@ -1,5 +1,5 @@ """ -Deterministic boundary tests for trial.run_response hit/miss classification. +Deterministic boundary tests for response.run_response hit/miss classification. The concern these tests pin down: a keypress that physically lands while the target is still on screen — including the *very last* visible frame — must be @@ -24,7 +24,8 @@ import pytest -from mid_det import config, trial +from mid_det import config +from mid_det.task import response FRAME_DUR = 0.01 JITTER_S = 0.025 @@ -69,7 +70,7 @@ def callOnFlip(self, fn, *args, **kwargs) -> None: def flip(self) -> None: # Advance, then stamp lastFrameT, then fire callbacks. This mirrors # PsychoPy: lastFrameT is set right after the swap, before callOnFlip - # callbacks run (see trial.run_response comments). + # callbacks run (see response.run_response comments). self._v.t += self._frame_dur self.lastFrameT = self._v.t cbs, self._cbs = self._cbs, [] @@ -125,12 +126,12 @@ def _drive(presses, *, early_press=False, monkeypatch): # Target draws need a real window; patch them out — geometry is irrelevant # to the timing classification under test. - monkeypatch.setattr(trial, "draw_target", lambda *a, **k: None) - monkeypatch.setattr(trial, "draw_fixation_x", lambda *a, **k: None) + monkeypatch.setattr(response, "draw_target", lambda *a, **k: None) + monkeypatch.setattr(response, "draw_fixation_x", lambda *a, **k: None) # phase_clock = core.Clock() inside run_response -> bind to virtual time. - monkeypatch.setattr(trial.core, "Clock", lambda *a, **k: _FakeClock(virtual)) + monkeypatch.setattr(response.core, "Clock", lambda *a, **k: _FakeClock(virtual)) - return trial.run_response( + return response.run_response( win, object(), # stimuli: unused once draws are patched out kb, @@ -196,6 +197,21 @@ def test_no_press_is_miss(monkeypatch): assert removed_at == pytest.approx(EXPECTED_REMOVED_AT, abs=1e-9) +def test_response_diagnostics(monkeypatch): + # Pin the full diagnostics dict so the refactor is provably behavior-preserving. + # On the virtual clock the onset flip lands at t=0.04 and removal at t=0.09. + _, _, _, removed_at, diag = _drive([], monkeypatch=monkeypatch) + assert removed_at == pytest.approx(0.05, abs=1e-9) + assert diag["n_target_frames"] == 5 + assert diag["flip_iters"] == 5 + assert diag["onset_to_removal_wall_ms"] == pytest.approx(50.0) + # onset snaps to the next frame boundary after jitter (0.025 -> 0.04) + assert diag["jitter_ms_actual"] == pytest.approx(30.0) + assert diag["dropped_frames"] == 0 + assert diag["max_flip_interval_ms"] == pytest.approx(10.0) + assert diag["trial_clean"] == 1 + + def test_just_below_and_just_above_removal_boundary(monkeypatch): # Tighten the boundary to sub-frame: 1 ms below removal -> hit, # 1 ms above -> miss. Confirms the half-open [0, target_removed_at) window. diff --git a/text/instructions_MID.txt b/text/instructions_MID.txt index cb6e853..524a4df 100644 --- a/text/instructions_MID.txt +++ b/text/instructions_MID.txt @@ -1,6 +1,4 @@ -In this task, you can win or avoid losing money if you are quick enough. Your task is to press a button when you see the white triangle. -On trials where the triangle is preceded by a circle, you will win the amount shown if you press the key while the triangle is on the screen. -On trials where the triangle is preceded by a square, you will avoid losing the amount shown if you press the key while the triangle is on the screen. -The amount shown on each cue may be $0, $1, or $5. A line across the shape indicates the magnitude: near the bottom means $0, through the middle means $1, near the top means $5. -The triangle will be shown for a different duration on each trial. Some trials will give you more time to respond than others; simply do your best to press the key while the triangle is visible. -After each trial, you will see your outcome. We will add whatever you win, and subtract whatever you lose, from the amount that we pay you today. If you have questions, please ask the experimenter now. +In this experiment you will respond as quickly as possible to earn money. You will see a cue indicating how much money you can win or avoid losing. After this cue, a triangle will appear. Hit a button as fast as you can when the triangle appears to win. Press any button to continue the instructions. +Cues are either circles or squares. CIRCLE cues mean that you can EARN that amount if you hit the target. SQUARE cues mean that you can AVOID LOSING that amount if you hit the target. Press any button to continue the instructions. +If you miss a CIRCLE cue, you will NOT GAIN the amount in the cue. If you miss a SQUARE cue, you will LOSE the amount in the cue. Press any button to continue the instructions. +Please HOLD STILL while you are in the scanner. \ No newline at end of file diff --git a/uv.lock b/uv.lock index eb61554..2228336 100644 --- a/uv.lock +++ b/uv.lock @@ -985,6 +985,7 @@ dependencies = [ { name = "numpy" }, { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" }, { name = "pandas", version = "3.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14'" }, + { name = "prompt-toolkit" }, { name = "psychopy" }, { name = "pyobjc-framework-quartz", marker = "sys_platform == 'darwin'" }, { name = "questionary" }, @@ -1002,6 +1003,7 @@ requires-dist = [ { name = "mcculw", specifier = ">=1.0.0" }, { name = "numpy", specifier = ">=1.26" }, { name = "pandas", specifier = ">=2.0" }, + { name = "prompt-toolkit", specifier = ">=3.0" }, { name = "psychopy", specifier = ">=2026.1" }, { name = "pyobjc-framework-quartz", marker = "sys_platform == 'darwin'", specifier = ">=10" }, { name = "pytest", marker = "extra == 'dev'" },