From 2d611d74241ce97c122bc647f7149c85f195f508 Mon Sep 17 00:00:00 2001 From: jl33ai Date: Fri, 29 May 2026 09:27:34 -0700 Subject: [PATCH 1/2] add synthetic data source for running without trodes lets you run the pipeline end to end without an acquisition rig. useful for smoke testing a fresh install, working on a laptop, and CI. - realtime_decoder/synthetic.py: SyntheticDataReceiver and SyntheticClient that match the surface of TrodesDataReceiver and TrodesClient. drop in replacements. non-blocking __next__ semantics so the polling loops in encoder/decoder/ripple work unchanged. - runscript.py picks the data source via config['datasource']. defaults to 'trodes' so every existing config keeps working. - config/demo_synthetic.yml is a 5 rank demo wired to the synthetic source. - README adds a 'running without acquisition hardware' section. spikes are poisson, marks gaussian, position walks a triangle wave along a single segment. not biologically realistic, the point is to exercise the data path and message plumbing. to try it: mpiexec -np 5 python -u runscript.py config/demo_synthetic.yml --- README.md | 17 ++ config/demo_synthetic.yml | 175 +++++++++++++++++++ realtime_decoder/synthetic.py | 305 ++++++++++++++++++++++++++++++++++ runscript.py | 42 +++-- 4 files changed, 528 insertions(+), 11 deletions(-) create mode 100644 config/demo_synthetic.yml create mode 100644 realtime_decoder/synthetic.py diff --git a/README.md b/README.md index db0fc9a..803ee71 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,23 @@ mpiexec -np -bind-to hwthread python -u runscript.py 0 + and elapsed > self._p['run_duration_s'] + and not self._stopped + ): + # one-time log; the supervisor's termination is wired + # through SyntheticClient.receive() below. 
+ self._stopped = True + + if self.datatype == Datatypes.LFP: + return self._next_lfp(elapsed) + elif self.datatype == Datatypes.SPIKES: + return self._next_spike(elapsed) + else: + return self._next_position(elapsed) + + # ------------------------------------------------------------------ + # Per-datatype generators + # ------------------------------------------------------------------ + + def _next_lfp(self, elapsed): + target_idx = int(elapsed * self._fs) + if target_idx < self._next_sample_idx: + return None + idx = self._next_sample_idx + self._next_sample_idx += 1 + # white-ish noise sized to (num_channels,), scaled the same way + # TrodesDataReceiver does (raw * voltage_scaling_factor) + n = max(1, len(self.ntrode_ids)) + raw = self._rng.standard_normal(n) * 200.0 # ~uV range pre-scale + data = raw * self._p['voltage_scaling_factor'] + local_ts = idx # LFP uses spike-clock timestamps in real Trodes; + # at fs_lfp=1500, fs_spike=30000 the ratio is 20, but downstream + # only cares about monotonicity within a stream, so use idx. + system_ts = time.time_ns() + return LFPPoint( + local_ts, + list(self.ntrode_ids), + data, + system_ts, + time.time_ns(), + ) + + def _next_spike(self, elapsed): + if not self.ntrode_ids: + return None + spike_sample_now = int(elapsed * self._spike_clock) + # Find any ntrode whose next-spike sample has arrived. + for ntid in self.ntrode_ids: + if self._next_spike_sample[ntid] <= spike_sample_now: + ts = self._next_spike_sample[ntid] + self._schedule_next_spike(ntid, sample_now=spike_sample_now) + # mark vector: gaussian around _amp, all channels positive + samples = ( + self._rng.standard_normal(self._mark_dim) * 8.0 + self._amp + ) / self._p['voltage_scaling_factor'] + # SpikePoint.data is later multiplied by voltage_scaling_factor + # in real Trodes; the encoder reads `max(mark_vec)` so we just + # need the post-scaling magnitudes to clear `encoder.spk_amp`. 
class SyntheticClient(object):
    """Drop-in synthetic replacement for ``trodesnet.TrodesClient``.

    Exposes the same surface used by the supervisor and stim decider:
      * ``set_startup_callback`` / ``set_termination_callback``
      * ``receive`` (called from the supervisor main loop)
      * ``send_statescript_shortcut_message`` (called from stim_decider)

    ``receive`` fires the startup callback once, after ``startup_delay_s``
    seconds have elapsed since construction.  If ``run_duration_s`` is
    positive it then fires the termination callback once, after
    ``startup_delay_s + run_duration_s`` seconds have elapsed.  Each
    callback fires at most once per client.
    """

    def __init__(self, config):
        """Create a client with no-op callbacks and start its clock.

        ``config`` is handed to the module-level ``_params`` helper; the
        result must provide ``startup_delay_s`` and ``run_duration_s``
        (presumably the resolved ``synthetic`` section -- confirm against
        ``_params``).
        """
        self._startup_callback = utils.nop
        self._termination_callback = utils.nop
        self._p = _params(config)
        # Elapsed time is measured on the monotonic clock so that NTP
        # steps or manual wall-clock changes can neither trigger an early
        # termination nor postpone it indefinitely.  The attribute keeps
        # its historical name for compatibility.
        self._t0_wall = time.monotonic()
        self._started = False
        self._terminated = False
        # log-only buffer of "shortcut messages" the stim decider would
        # have sent to ECU; useful for asserting in tests later.
        self.sent_shortcuts = []

    def receive(self):
        """Poll the clock and fire any callback that has become due.

        Safe to call as often as the caller likes: the ``_started`` and
        ``_terminated`` flags are set *before* the corresponding callback
        runs, so a callback is never invoked twice.
        """
        elapsed = time.monotonic() - self._t0_wall
        startup_delay = self._p['startup_delay_s']
        run_duration = self._p['run_duration_s']

        if not self._started and elapsed >= startup_delay:
            self._started = True
            self._startup_callback()

        # run_duration <= 0 means "run until stopped externally".
        termination_due = (
            run_duration > 0 and elapsed >= run_duration + startup_delay
        )
        if self._started and not self._terminated and termination_due:
            self._terminated = True
            self._termination_callback()

    def send_statescript_shortcut_message(self, val):
        """Record ``(time.time_ns(), int(val))`` instead of sending it."""
        self.sent_shortcuts.append((time.time_ns(), int(val)))

    def set_startup_callback(self, callback):
        """Register the zero-argument callable fired when startup is due."""
        self._startup_callback = callback

    def set_termination_callback(self, callback):
        """Register the zero-argument callable fired when termination is due."""
        self._termination_callback = callback
+ """ + ds = config.get('datasource', 'trodes') + if ds == 'trodes': + return trodesnet.TrodesDataReceiver, trodesnet.TrodesClient + if ds == 'synthetic': + return synthetic.SyntheticDataReceiver, synthetic.SyntheticClient + raise ValueError( + f"Unknown datasource {ds!r}; expected 'trodes' or 'synthetic'" + ) + # from line_profiler import LineProfiler class GuiProcessStub(base.RealtimeProcess, base.MessageHandler): @@ -169,21 +187,23 @@ def setup(config_path, numprocs): regloop = True ################################################# + DataReceiver, Client = _data_source_factory(config) + if rank in config['rank']['supervisor']: - trodes_client = trodesnet.TrodesClient(config) + net_client = Client(config) stim_decider = stimulation.TwoArmTrodesStimDecider( - comm, rank, config, trodes_client + comm, rank, config, net_client ) process = main_process.MainProcess( - comm, rank, config, stim_decider, trodes_client + comm, rank, config, stim_decider, net_client ) - trodes_client.set_startup_callback(process.startup) - trodes_client.set_termination_callback(process.trigger_termination) + net_client.set_startup_callback(process.startup) + net_client.set_termination_callback(process.trigger_termination) elif rank in config['rank']['ripples']: - lfp_interface = trodesnet.TrodesDataReceiver( + lfp_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LFP ) - pos_interface = trodesnet.TrodesDataReceiver( + pos_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LINEAR_POSITION ) process = ripple_process.RippleProcess( @@ -196,10 +216,10 @@ def setup(config_path, numprocs): # prof.print_stats() # regloop = False elif rank in config['rank']['encoders']: - spikes_interface = trodesnet.TrodesDataReceiver( + spikes_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.SPIKES ) - pos_interface = trodesnet.TrodesDataReceiver( + pos_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LINEAR_POSITION ) pos_mapper = 
position.TrodesPositionMapper( @@ -211,7 +231,7 @@ def setup(config_path, numprocs): pos_mapper ) elif rank in config['rank']['decoders']: - pos_interface = trodesnet.TrodesDataReceiver( + pos_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LINEAR_POSITION ) pos_mapper = position.TrodesPositionMapper( From 2bc4351225923b73d4f950e527d097acaced9efc Mon Sep 17 00:00:00 2001 From: jl33ai Date: Fri, 29 May 2026 09:27:43 -0700 Subject: [PATCH 2/2] add config defaults and startup validation most of the per-animal yamls duplicate the same sampling rates, ripple filter, gui colors, kinematics smoothing filter, mua, etc. drift is real and 'what is the canonical value of X' is currently unanswerable. - configs can declare `_extends: defaults.yml` and only override what is actually different. nested dicts deep merge key by key, lists and scalars replace. - config/defaults.yml holds the values stable across the SC*/fred/ginny configs. - runscript routes config loading through a new config_loader module that runs a small validator on the resolved dict. catches missing rank.supervisor, unknown algorithm, missing encoder.mark_dim, decoder ranks with no assignment, encoder.mark_dim != synthetic.mark_dim, etc. rank 0 prints one readable message and every rank exits 2 before any rank builds its process, instead of surfacing as IndexError deep in a rank. - config/demo_synthetic.yml now uses _extends, drops about 80 lines. - loader uses plain pyyaml (third-party, not stdlib) instead of oyaml since python 3.7+ dicts preserve order. one fewer dep. backward compatible. verified all 16 existing per-animal configs still load and validate unchanged. follow up not in this PR: the SC*/fred/ginny configs can each be rewritten as _extends + overrides which would shrink them by roughly half. left as a separate diff so per-config behavior changes are auditable on their own.
--- README.md | 34 +++++ config/defaults.yml | 101 +++++++++++++ config/demo_synthetic.yml | 98 ++++--------- realtime_decoder/config_loader.py | 229 ++++++++++++++++++++++++++++++ runscript.py | 16 ++- 5 files changed, 403 insertions(+), 75 deletions(-) create mode 100644 config/defaults.yml create mode 100644 realtime_decoder/config_loader.py diff --git a/README.md b/README.md index 803ee71..654473d 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,40 @@ auto-terminates after `synthetic.run_duration_s` seconds. See Please see the example configuration file in the `example_config` folder. Options are described in more detail below. +## Defaults and inheritance + +Configs can extend a shared base file by declaring `_extends:` at the top: + +```yaml +_extends: defaults.yml + +# only the keys that actually differ from defaults.yml go here +rank: + supervisor: [0] + ... +``` + +`config/defaults.yml` ships with values shared across the SC* / fred / ginny +configs (sampling rates, ripple filter, GUI, MUA, kinematics smoothing +filter, display intervals, process monitor). Per-animal configs only need +to specify what's actually different — typically `rank`, `trode_selection`, +`decoder_assignment`, `files`, `encoder.position`, `kinematics.scale_factor`, +and `stimulation`. Nested mappings merge key by key; lists and scalars are replaced whole. Existing configs without `_extends` continue to work +unchanged. + +Relative paths in `_extends` resolve next to the file that declares them. +You may pass a single path or a list (parents are merged in order, so later parents override earlier ones and the declaring file overrides them all). + +## Validation + +At startup the loader checks the resolved config against a minimal schema +(required ranks, known algorithm, known datasource, `encoder.mark_dim` +equal to `synthetic.mark_dim` when both are set, every decoder rank has an assignment, etc.). +Missing or malformed keys produce a single readable error (printed by rank 0; every rank exits with status 2) before any +rank constructs its process, instead of an `IndexError` deep inside a rank. + + + ## `rank` Describes which MPI rank should be assigned to each process type.
diff --git a/config/defaults.yml b/config/defaults.yml new file mode 100644 index 0000000..7b04ce0 --- /dev/null +++ b/config/defaults.yml @@ -0,0 +1,101 @@ +--- +# Shared defaults for realtime_decoder configs. +# +# Per-animal / per-experiment YAMLs should declare: +# _extends: defaults.yml +# at the top, then override only what's actually different. Anything not +# overridden inherits from here. Lists and scalars are replaced as a +# whole; nested dicts merge key-by-key. +# +# This file is intentionally conservative: it only contains values that +# have been stable across the SC* / fred / ginny configs in this repo, +# plus a few sane defaults for new fields. Per-animal kinematics, +# stimulation, trode selection, and file paths must still live in the +# child config. + +algorithm: "clusterless_decoder" +datasource: "trodes" +num_setup_messages: 100 +preloaded_model: false +frozen_model: false + +sampling_rate: + spikes: 30000 + lfp: 1500 + position: 30 + +ripples: + max_ripple_samples: 450 + vel_thresh: 10 + freeze_stats: false + timings_bufsize: 1000000 + filter: + type: 'iir' + order: 2 + crit_freqs: [150, 250] + kwargs: + btype: 'bandpass' + ftype: 'butter' + smoothing_filter: + num_taps: 15 + band_edges: [50, 55] + desired: [1, 0] + threshold: + standard: 3.5 + conditioning: 3.75 + content: 4 + end: 0 + +decoder: + cred_int_bufsize: 10 + time_bin: + samples: 180 # 6 ms at 30 kHz spike clock + delay_samples: 180 + +clusterless_decoder: + state_labels: ['state'] + transmat_bias: 1 + +gui: + colormap: 'rocket' + send_interval: 0 + refresh_rate: 25 + trace_length: 2 + state_colors: ['#4c72b0', '#dd8452', '#55a868'] + num_xticks: 5 + +mua: + threshold: + trigger: 4 + end: 0 + freeze_stats: false + moving_avg_window: 5 + +cred_interval: + val: 0.5 + max_num: 5 + +kinematics: + smooth_x: true + smooth_y: true + smooth_speed: false + smoothing_filter: [0.31, 0.29, 0.25, 0.15] + +display: + stim_decider: + position: 150 + decoding_bins: 2000 + ripples: + lfp: 100000 
+ encoder: + encoding_spikes: 5000 + total_spikes: 50000 + occupancy: 5000 + position: 5000 + decoder: + total_spikes: 50000 + occupancy: 100 + +process_monitor: + interval: 15 + timeout: 3 diff --git a/config/demo_synthetic.yml b/config/demo_synthetic.yml index 666a7a3..fe62728 100644 --- a/config/demo_synthetic.yml +++ b/config/demo_synthetic.yml @@ -1,12 +1,16 @@ --- # Demo config: runs the full MPI pipeline against the in-process -# synthetic data source (realtime_decoder.synthetic). No Trodes, no -# acquisition hardware, no driver setup required. +# synthetic data source. No Trodes / acquisition hardware required. # # Run: # mpiexec -np 5 python -u runscript.py config/demo_synthetic.yml # -# Ranks: 0=supervisor, 1=decoder, 2=gui, 3=ripples, 4=encoder +# Everything not set here inherits from defaults.yml via `_extends`. + +_extends: defaults.yml + +datasource: "synthetic" + rank: supervisor: [0] ripples: [3] @@ -14,57 +18,35 @@ rank: encoders: [4] gui: [2] rank_settings: - enable_rec: [0,1,3,4] + enable_rec: [0, 1, 3, 4] trode_selection: ripples: [1] decoding: [1] decoder_assignment: 1: [1] -algorithm: "clusterless_decoder" -datasource: "synthetic" -num_setup_messages: 100 -preloaded_model: false -frozen_model: false + files: output_dir: '/tmp/realtime_decoder_demo' prefix: 'demo' rec_postfix: 'bin_rec' timing_postfix: 'timing' -# --- synthetic-source parameters (all optional, defaults shown) ----------- + +# --- synthetic-source parameters (all optional, defaults documented in +# realtime_decoder/synthetic.py) synthetic: spike_rate_hz: 30 - mark_dim: 4 - mark_amplitude_uv: 120 # well above encoder.spk_amp below - track_length_cm: 40 # fits into the 0..41 bins + mark_dim: 4 # must equal encoder.mark_dim below + mark_amplitude_uv: 120 + track_length_cm: 40 walk_speed_cm_s: 20 - startup_delay_s: 1.0 # supervisor waits this long, then fires play - run_duration_s: 30 # auto-terminate after this many seconds + startup_delay_s: 1.0 + run_duration_s: 30 
voltage_scaling_factor: 0.195 -sampling_rate: - spikes: 30000 - lfp: 1500 - position: 30 + +# Lower-volume timings so the demo doesn't waste memory ripples: - max_ripple_samples: 450 - vel_thresh: 10 - freeze_stats: false timings_bufsize: 100000 - filter: - type: 'iir' - order: 2 - crit_freqs: [150, 250] - kwargs: - btype: 'bandpass' - ftype: 'butter' - smoothing_filter: - num_taps: 15 - band_edges: [50, 55] - desired: [1, 0] - threshold: - standard: 3.5 - conditioning: 3.75 - content: 4 - end: 0 + encoder: spk_amp: 60 use_channel_dist_from_max_amp: 2 @@ -78,43 +60,25 @@ encoder: upper: 41 num_bins: 41 arm_ids: [0] - arm_coords: [[0,40]] + arm_coords: [[0, 40]] mark_kernel: mean: 0 std: 20 use_filter: false n_std: 1 n_marks_min: 10 + decoder: decoder_to_message: 1 bufsize: 2000 timings_bufsize: 10000 - cred_int_bufsize: 10 starting_arm1_bin: 10 starting_arm2_bin: 30 num_pos_points: 30 - time_bin: - samples: 180 - delay_samples: 180 -clusterless_decoder: - state_labels: ['state'] - transmat_bias: 1 -gui: - colormap: 'rocket' - send_interval: 0 - refresh_rate: 25 - trace_length: 2 - state_colors: ['#4c72b0','#dd8452', '#55a868'] - num_xticks: 5 -mua: - threshold: - trigger: 4 - end: 0 - freeze_stats: false - moving_avg_window: 5 + stimulation: instructive: false - shortcut_msg_on: false # no ECU in demo mode + shortcut_msg_on: false automatic_threshold_update: false num_each_arm_per_minute: 1.1 num_pos_points: 30 @@ -147,18 +111,12 @@ stimulation: well_angle_range: 6 within_angle_range: 6 well_loc: [[100, 100], [200, 200]] + kinematics: - smooth_x: true - smooth_y: true - smooth_speed: false - smoothing_filter: [0.31, 0.29, 0.25, 0.15] scale_factor: 0.2644 -cred_interval: - val: 0.5 - max_num: 5 + display: stim_decider: - position: 150 decoding_bins: 200 ripples: lfp: 10000 @@ -169,7 +127,3 @@ display: position: 500 decoder: total_spikes: 5000 - occupancy: 100 -process_monitor: - interval: 15 - timeout: 3 diff --git a/realtime_decoder/config_loader.py 
"""Config loader: YAML defaults inheritance + startup validation.

The historical pattern in this repo is one ~200-line YAML per animal,
near-duplicated across the colony.  That breeds drift: a parameter
correctly tuned in `SC79_nTrode16.yml` quietly differs from the same
parameter in `SC80_nTrode16.yml`, and there is no single source of
truth for "what's the standard value of X."

This module provides two small affordances:

1. Optional ``_extends`` key that loads a parent YAML and deep-merges it
   under the current file.  Per-animal files become *overrides* on top
   of a shared ``defaults.yml`` instead of full standalone configs.

2. Startup validation.  Common operator mistakes (missing
   ``rank.supervisor``, unknown ``algorithm``, missing
   ``encoder.mark_dim``) are reported as a single ``ConfigError`` listing
   every problem found, instead of surfacing as ``IndexError`` /
   ``KeyError`` deep inside a worker process.

Every failure this module can detect -- unreadable file, invalid YAML,
non-mapping document, bad ``_extends``, schema violation -- is raised as
``ConfigError`` so callers need exactly one ``except`` clause.

The loader is backward compatible: existing configs without ``_extends``
load identically to ``yaml.safe_load`` (modulo validation).
"""

import os
from typing import Any, Dict, List, Optional, Tuple

# PyYAML is a third-party package (it is NOT part of the standard
# library).  Plain ``yaml.safe_load`` is enough here: Python 3.7+ dicts
# preserve insertion order, so ``oyaml`` is not needed just to read
# configs.  The import is guarded so a missing install is reported as a
# ``ConfigError`` at load time rather than an ``ImportError`` at import.
try:
    import yaml
except ImportError:
    yaml = None


class ConfigError(ValueError):
    """Raised when a config fails validation or cannot be loaded."""


# ---------------------------------------------------------------------------
# loading
# ---------------------------------------------------------------------------


def load_config(path: str) -> Dict[str, Any]:
    """Load a YAML config, resolving ``_extends`` chains and validating.

    ``_extends`` may be a single path or a list of paths.  Each parent is
    loaded recursively (parents may themselves ``_extends``) and merged
    in order, with the current file's keys taking precedence.

    Relative paths in ``_extends`` resolve relative to the file that
    declares them.

    Raises:
        ConfigError: if any file in the chain cannot be read or parsed,
            the chain is circular, or the resolved config is invalid.
    """
    cfg = _load_with_inheritance(path, _seen=set())
    validate(cfg)
    return cfg


def _read_yaml_mapping(abspath: str) -> Dict[str, Any]:
    """Parse one YAML file and return its top-level mapping.

    An empty document yields ``{}``.  I/O errors, YAML syntax errors and
    non-mapping documents are all converted to ``ConfigError``.
    """
    if yaml is None:
        raise ConfigError(
            "PyYAML is required to load configs but is not installed"
        )
    try:
        with open(abspath, 'r') as f:
            raw = yaml.safe_load(f) or {}
    except (OSError, UnicodeDecodeError) as exc:
        raise ConfigError(f"cannot read config file {abspath}: {exc}") from exc
    except yaml.YAMLError as exc:
        raise ConfigError(f"invalid YAML in {abspath}: {exc}") from exc
    if not isinstance(raw, dict):
        raise ConfigError(
            f"top level of {abspath} must be a mapping, "
            f"got {type(raw).__name__}"
        )
    return raw


def _load_with_inheritance(path: str, *, _seen: set) -> Dict[str, Any]:
    """Load ``path`` and merge it over everything it ``_extends``.

    ``_seen`` holds the absolute paths on the current inheritance chain;
    it is copied (not mutated) on the way down, so two siblings may share
    a common parent without tripping the cycle check.
    """
    abspath = os.path.abspath(path)
    if abspath in _seen:
        raise ConfigError(
            f"Circular `_extends` chain detected involving {abspath}"
        )
    _seen = _seen | {abspath}

    raw = _read_yaml_mapping(abspath)

    extends = raw.pop('_extends', None)
    if extends is None:
        return raw

    if isinstance(extends, str):
        parents: List[str] = [extends]
    elif isinstance(extends, list):
        parents = list(extends)
    else:
        raise ConfigError(
            f"`_extends` in {abspath} must be a string or list, got {type(extends).__name__}"
        )

    here = os.path.dirname(abspath)
    merged: Dict[str, Any] = {}
    for parent in parents:
        if not isinstance(parent, str) or not parent:
            raise ConfigError(
                f"`_extends` entries in {abspath} must be non-empty "
                f"strings, got {parent!r}"
            )
        parent_path = parent if os.path.isabs(parent) else os.path.join(here, parent)
        merged = deep_merge(merged, _load_with_inheritance(parent_path, _seen=_seen))
    return deep_merge(merged, raw)


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge ``override`` onto ``base``, preferring ``override``.

    Nested dicts merge key-by-key.  Lists and scalars are replaced, not
    appended -- this matches operator intuition ("override X" means
    "replace X," not "extend X").

    Neither argument is modified.  Values that are not merged are shared
    by reference with the inputs rather than copied.
    """
    out = dict(base)
    for k, v in override.items():
        if k in out and isinstance(out[k], dict) and isinstance(v, dict):
            out[k] = deep_merge(out[k], v)
        else:
            out[k] = v
    return out


# ---------------------------------------------------------------------------
# validation
# ---------------------------------------------------------------------------


# Required keys and the rough shape we expect.  Kept as plain code rather
# than a third-party schema lib so this module has no new install-time
# dependencies; the checks below are cheap and the error messages are
# deliberately operator-friendly.
_REQUIRED_TOP: Tuple[str, ...] = (
    'rank',
    'algorithm',
    'sampling_rate',
    'files',
    'encoder',
    'decoder',
    'ripples',
    'kinematics',
)
_KNOWN_ALGORITHMS = ('clusterless_decoder', 'clusterless_classifier')
_KNOWN_DATASOURCES = ('trodes', 'synthetic')
_REQUIRED_RANK_ROLES = ('supervisor', 'decoders', 'encoders', 'ripples', 'gui')
_SINGLETON_RANK_ROLES = ('supervisor', 'gui')
_REQUIRED_SAMPLING = ('spikes', 'lfp', 'position')
_REQUIRED_FILES = ('output_dir', 'prefix')
_REQUIRED_ENCODER = ('mark_dim', 'bufsize', 'spk_amp', 'position')
_REQUIRED_ENCODER_POSITION = ('lower', 'upper', 'num_bins', 'arm_ids', 'arm_coords')
_REQUIRED_DECODER = ('bufsize', 'time_bin', 'cred_int_bufsize')
_REQUIRED_TIME_BIN = ('samples', 'delay_samples')


def _mapping_section(
    parent: Dict[str, Any], key: str, label: str, errors: List[str]
) -> Optional[Dict[str, Any]]:
    """Return ``parent[key]`` when it is a mapping, otherwise ``None``.

    An absent key returns ``None`` without recording anything (whether
    the key is required is the caller's concern).  A value that is
    present but not a mapping records ``"<label> must be a mapping"``.
    """
    if key not in parent:
        return None
    value = parent[key]
    if isinstance(value, dict):
        return value
    errors.append(f"{label} must be a mapping, got {type(value).__name__}")
    return None


def _require_keys(
    section: Dict[str, Any], label: str, keys: Tuple[str, ...], errors: List[str]
) -> None:
    """Record ``"<label>.<key> is missing"`` for each key absent from ``section``."""
    errors.extend(f"{label}.{k} is missing" for k in keys if k not in section)


def validate(cfg: Dict[str, Any]) -> None:
    """Raise ``ConfigError`` with a clear message if ``cfg`` is malformed.

    All problems are collected and reported together.  Malformed input
    (wrong types, missing sections) is always reported as ``ConfigError``;
    the validator itself never raises ``TypeError`` / ``AttributeError``.
    Sub-keys of a section are only checked when the section itself is a
    mapping, so one missing section produces one error line.
    """
    if not isinstance(cfg, dict):
        raise ConfigError(
            "config validation failed:\n - top level must be a mapping, "
            f"got {type(cfg).__name__}"
        )

    errors: List[str] = []

    for k in _REQUIRED_TOP:
        if k not in cfg:
            errors.append(f"missing required top-level key '{k}'")

    if 'algorithm' in cfg and cfg['algorithm'] not in _KNOWN_ALGORITHMS:
        errors.append(
            f"algorithm={cfg['algorithm']!r} is not one of {_KNOWN_ALGORITHMS}"
        )

    ds = cfg.get('datasource', 'trodes')
    if ds not in _KNOWN_DATASOURCES:
        errors.append(
            f"datasource={ds!r} is not one of {_KNOWN_DATASOURCES}"
        )

    rank = _mapping_section(cfg, 'rank', 'rank', errors)
    if rank is not None:
        for role in _REQUIRED_RANK_ROLES:
            v = rank.get(role)
            if v is None:
                errors.append(f"rank.{role} is missing")
            elif not isinstance(v, list) or not v:
                errors.append(f"rank.{role} must be a non-empty list of ints, got {v!r}")
            elif role in _SINGLETON_RANK_ROLES and len(v) != 1:
                errors.append(
                    f"rank.{role} must contain exactly one rank, got {v!r}"
                )

    sr = _mapping_section(cfg, 'sampling_rate', 'sampling_rate', errors)
    if sr is not None:
        for k in _REQUIRED_SAMPLING:
            if k not in sr:
                errors.append(f"sampling_rate.{k} is missing")
                continue
            v = sr[k]
            # bool is a subclass of int, so reject it explicitly;
            # ``not (v > 0)`` also rejects NaN.
            if isinstance(v, bool) or not isinstance(v, (int, float)) or not (v > 0):
                errors.append(f"sampling_rate.{k} must be a positive number, got {v!r}")

    files = _mapping_section(cfg, 'files', 'files', errors)
    if files is not None:
        for k in _REQUIRED_FILES:
            if not files.get(k):
                errors.append(f"files.{k} is missing or empty")

    enc = _mapping_section(cfg, 'encoder', 'encoder', errors)
    if enc is not None:
        _require_keys(enc, 'encoder', _REQUIRED_ENCODER, errors)
        pos = _mapping_section(enc, 'position', 'encoder.position', errors)
        if pos is not None:
            _require_keys(pos, 'encoder.position', _REQUIRED_ENCODER_POSITION, errors)

    dec = _mapping_section(cfg, 'decoder', 'decoder', errors)
    if dec is not None:
        _require_keys(dec, 'decoder', _REQUIRED_DECODER, errors)
        tb = _mapping_section(dec, 'time_bin', 'decoder.time_bin', errors)
        if tb is not None:
            _require_keys(tb, 'decoder.time_bin', _REQUIRED_TIME_BIN, errors)

    # Cross-field: each decoder rank must be a key in decoder_assignment.
    # Only attempted when rank.decoders really is a list; a malformed
    # value has already been reported above.
    dec_ranks = rank.get('decoders') if rank is not None else None
    assignment = cfg.get('decoder_assignment') or {}
    if isinstance(dec_ranks, list) and isinstance(assignment, dict):
        for r in dec_ranks:
            try:
                missing = r not in assignment
            except TypeError:
                # Unhashable entry (e.g. a nested list) can never be a key.
                missing = True
            if missing:
                errors.append(
                    f"decoder_assignment is missing an entry for rank {r}"
                )

    # Cross-field: encoder.mark_dim must match across encoder and any
    # synthetic-source override.
    syn = cfg.get('synthetic') or {}
    if enc is not None and isinstance(syn, dict):
        if 'mark_dim' in syn and 'mark_dim' in enc and syn['mark_dim'] != enc['mark_dim']:
            errors.append(
                f"synthetic.mark_dim ({syn['mark_dim']}) "
                f"!= encoder.mark_dim ({enc['mark_dim']})"
            )

    if errors:
        bullet = '\n - '.join(errors)
        raise ConfigError(f"config validation failed:\n - {bullet}")