From d1689790d6f5977cf0bf161e9b46de6a20f236be Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Sun, 8 Mar 2026 22:26:57 -0600 Subject: [PATCH 01/27] feat: add SleepWakeClassification task for binary sleep/wake labeling --- pyhealth/tasks/sleep_wake_classification.py | 78 +++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 pyhealth/tasks/sleep_wake_classification.py diff --git a/pyhealth/tasks/sleep_wake_classification.py b/pyhealth/tasks/sleep_wake_classification.py new file mode 100644 index 000000000..6d44ae6a5 --- /dev/null +++ b/pyhealth/tasks/sleep_wake_classification.py @@ -0,0 +1,78 @@ +from typing import Any, Dict, List + +import pandas as pd + +from pyhealth.tasks import BaseTask + + +class SleepWakeClassification(BaseTask): + task_name: str = "SleepWakeClassification" + input_schema: Dict[str, str] = {"features": "vector"} + output_schema: Dict[str, str] = {"label": "binary"} + + def __init__(self, epoch_seconds: int = 30): + """Initializes the sleep-wake classification task. + + Args: + epoch_seconds: Length of each epoch in seconds. Default is 30. + """ + self.epoch_seconds = epoch_seconds + super().__init__() + + def _map_sleep_label(self, label: str) -> int | None: + """Maps DREAMT sleep stage labels to binary sleep/wake labels. + + Args: + label: Original sleep stage label. + + Returns: + 1 for wake, 0 for sleep, or None if the label should be skipped. 
+ """ + if label is None or pd.isna(label): + return None + + label = str(label).strip() + + if label.lower() == "wake": + return 1 + + if label.upper() in {"REM", "N1", "N2", "N3"}: + return 0 + + return None + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + samples: List[Dict[str, Any]] = [] + + events = patient.get_events() + if len(events) == 0: + return samples + + # For DREAMT, each patient should typically have one wearable file event + for event in events: + if not hasattr(event, "file_64hz") or event.file_64hz is None: + continue + + df = pd.read_csv(event.file_64hz) + + if "Sleep Stage" not in df.columns: + continue + + unique_labels = df["Sleep Stage"].dropna().unique().tolist() + + for epoch_idx, raw_label in enumerate(unique_labels): + label = self._map_sleep_label(raw_label) + if label is None: + continue + + samples.append( + { + "patient_id": patient.patient_id, + "record_id": f"{patient.patient_id}-{epoch_idx}", + "epoch_index": epoch_idx, + "features": [], + "label": label, + } + ) + + return samples \ No newline at end of file From 8f8fe6ca47f7440bca0ec56a7ac3213f7fbe88f2 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Sun, 8 Mar 2026 22:44:59 -0600 Subject: [PATCH 02/27] feat: add SleepWakeClassification task with epoch segmentation and basic wearable features --- pyhealth/tasks/sleep_wake_classification.py | 46 +++++++++++++++++---- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/pyhealth/tasks/sleep_wake_classification.py b/pyhealth/tasks/sleep_wake_classification.py index 6d44ae6a5..0b2d4809c 100644 --- a/pyhealth/tasks/sleep_wake_classification.py +++ b/pyhealth/tasks/sleep_wake_classification.py @@ -1,3 +1,5 @@ +import numpy as np +import pandas as pd from typing import Any, Dict, List import pandas as pd @@ -10,13 +12,15 @@ class SleepWakeClassification(BaseTask): input_schema: Dict[str, str] = {"features": "vector"} output_schema: Dict[str, str] = {"label": "binary"} - def __init__(self, epoch_seconds: int = 
30): + def __init__(self, epoch_seconds: int = 30, sampling_rate: int = 64): """Initializes the sleep-wake classification task. Args: epoch_seconds: Length of each epoch in seconds. Default is 30. + sampling_rate: Sampling rate of the wearable data in Hz. Default is 64. """ self.epoch_seconds = epoch_seconds + self.sampling_rate = sampling_rate super().__init__() def _map_sleep_label(self, label: str) -> int | None: @@ -24,7 +28,6 @@ def _map_sleep_label(self, label: str) -> int | None: Args: label: Original sleep stage label. - Returns: 1 for wake, 0 for sleep, or None if the label should be skipped. """ @@ -35,20 +38,36 @@ def _map_sleep_label(self, label: str) -> int | None: if label.lower() == "wake": return 1 - if label.upper() in {"REM", "N1", "N2", "N3"}: return 0 return None + + def _extract_basic_features(self, epoch_df: pd.DataFrame) -> List[float]: + """Extracts basic features (mean values) from the epoch data. + + Args: + epoch_df: DataFrame containing the data for the current epoch. + Returns: + A list of basic features (mean values) for the epoch. 
+ """ + features = [] - def __call__(self, patient: Any) -> List[Dict[str, Any]]: - samples: List[Dict[str, Any]] = [] + for col in ["BVP", "HR", "TEMP", "EDA"]: + if col in epoch_df.columns: + values = pd.to_numeric(epoch_df[col], errors="coerce").dropna() + features.append(float(values.mean()) if len(values) > 0 else 0.0) + return features + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + samples = [] events = patient.get_events() if len(events) == 0: return samples - # For DREAMT, each patient should typically have one wearable file event + epoch_size = self.epoch_seconds * self.sampling_rate + for event in events: if not hasattr(event, "file_64hz") or event.file_64hz is None: continue @@ -58,19 +77,28 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: if "Sleep Stage" not in df.columns: continue - unique_labels = df["Sleep Stage"].dropna().unique().tolist() + n_epochs = len(df) // epoch_size + for epoch_idx in range(n_epochs): + start = epoch_idx * epoch_size + end = start + epoch_size + epoch_df = df.iloc[start:end].copy() - for epoch_idx, raw_label in enumerate(unique_labels): + if len(epoch_df) == 0: + continue + + raw_label = epoch_df["Sleep Stage"].mode().iloc[0] label = self._map_sleep_label(raw_label) if label is None: continue + features = self._extract_basic_features(epoch_df) + samples.append( { "patient_id": patient.patient_id, "record_id": f"{patient.patient_id}-{epoch_idx}", "epoch_index": epoch_idx, - "features": [], + "features": features, "label": label, } ) From a511bcf3a3739ef5744855009edbbdc0172c45d8 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Sun, 8 Mar 2026 22:53:41 -0600 Subject: [PATCH 03/27] test: add runs unit test for SleepWakeClassification task --- tests/core/test_sleep_wake_classification.py | 53 ++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 tests/core/test_sleep_wake_classification.py diff --git a/tests/core/test_sleep_wake_classification.py 
b/tests/core/test_sleep_wake_classification.py new file mode 100644 index 000000000..dd80cdd6e --- /dev/null +++ b/tests/core/test_sleep_wake_classification.py @@ -0,0 +1,53 @@ +import tempfile +from pathlib import Path + +import pandas as pd + +from pyhealth.tasks.sleep_wake_classification import SleepWakeClassification + + +class FakeEvent: + """A fake event class to simulate patient events.""" + def __init__(self, file_64hz): + self.file_64hz = file_64hz + +class FakePatient: + """A fake patient class to simulate patient data and events.""" + def __init__(self, patient_id, file_64hz): + self.patient_id = patient_id + self._events = [FakeEvent(file_64hz)] + + """Returns the list of events for the patient.""" + def get_events(self): + return self._events + +def test_sleep_wake_classification_runs(): + """Test that the SleepWakeClassification task runs without errors and produces expected output format.""" + tmp = tempfile.mkdtemp() + csv_path = Path(tmp) / "S001_whole_df.csv" + + df = pd.DataFrame( + { + "TIMESTAMP": [0, 1, 2, 3], + "BVP": [0.1, 0.2, 0.3, 0.4], + "EDA": [0.01, 0.02, 0.03, 0.04], + "TEMP": [36.1, 36.1, 36.2, 36.2], + "ACC_X": [1, 1, 2, 2], + "ACC_Y": [0, 0, 1, 1], + "ACC_Z": [0, 0, 1, 1], + "HR": [60, 60, 61, 61], + "Sleep Stage": ["Wake", "N2", "REM", "Wake"], + } + ) + df.to_csv(csv_path, index=False) + + task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + patient = FakePatient("S001", str(csv_path)) + + samples = task(patient) + + assert isinstance(samples, list) + assert len(samples) == 2 + assert "features" in samples[0] + assert len(samples[0]["features"]) > 0 + assert samples[0]["label"] in [0, 1] \ No newline at end of file From 93246ab688fad294e59edaa1cd2876e2a568dc23 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Mon, 9 Mar 2026 01:44:06 -0600 Subject: [PATCH 04/27] feat: implement feature extraction pipeline for SleepWakeClassification --- pyhealth/tasks/sleep_wake_classification.py | 243 ++++++++++++++++---- 1 file 
changed, 195 insertions(+), 48 deletions(-) diff --git a/pyhealth/tasks/sleep_wake_classification.py b/pyhealth/tasks/sleep_wake_classification.py index 0b2d4809c..9db4c5465 100644 --- a/pyhealth/tasks/sleep_wake_classification.py +++ b/pyhealth/tasks/sleep_wake_classification.py @@ -2,103 +2,250 @@ import pandas as pd from typing import Any, Dict, List -import pandas as pd +from scipy.signal import butter, filtfilt +from scipy.stats import trim_mean +from scipy.stats.mstats import winsorize -from pyhealth.tasks import BaseTask +from .base_task import BaseTask class SleepWakeClassification(BaseTask): - task_name: str = "SleepWakeClassification" - input_schema: Dict[str, str] = {"features": "vector"} - output_schema: Dict[str, str] = {"label": "binary"} + task_name = "SleepWakeClassification" + input_schema = {"features": "vector"} + output_schema = {"label": "binary"} def __init__(self, epoch_seconds: int = 30, sampling_rate: int = 64): - """Initializes the sleep-wake classification task. - - Args: - epoch_seconds: Length of each epoch in seconds. Default is 30. - sampling_rate: Sampling rate of the wearable data in Hz. Default is 64. - """ self.epoch_seconds = epoch_seconds self.sampling_rate = sampling_rate super().__init__() - def _map_sleep_label(self, label: str) -> int | None: - """Maps DREAMT sleep stage labels to binary sleep/wake labels. - - Args: - label: Original sleep stage label. - Returns: - 1 for wake, 0 for sleep, or None if the label should be skipped. - """ + def _map_sleep_label(self, label): if label is None or pd.isna(label): return None - label = str(label).strip() + label = str(label).strip().upper() - if label.lower() == "wake": + if label in {"WAKE", "W"}: return 1 - if label.upper() in {"REM", "N1", "N2", "N3"}: + if label in {"REM", "R", "N1", "N2", "N3"}: return 0 return None - - def _extract_basic_features(self, epoch_df: pd.DataFrame) -> List[float]: - """Extracts basic features (mean values) from the epoch data. 
- - Args: - epoch_df: DataFrame containing the data for the current epoch. - Returns: - A list of basic features (mean values) for the epoch. - """ + + def _safe_numeric(self, series: pd.Series) -> np.ndarray: + return pd.to_numeric(series, errors="coerce").fillna(0.0).to_numpy() + + def _split_into_epochs(self, signal: np.ndarray, fs: float) -> List[np.ndarray]: + samples_per_epoch = int(fs * self.epoch_seconds) + num_epochs = len(signal) // samples_per_epoch + + epochs = [] + for i in range(num_epochs): + start = i * samples_per_epoch + end = start + samples_per_epoch + epochs.append(signal[start:end]) + + return epochs + + def _butter_bandpass( + self, + low_hz: float, + high_hz: float, + fs: float, + order: int, + ): + nyq = 0.5 * fs + low = low_hz / nyq + high = high_hz / nyq + return butter(order, [low, high], btype="band") + + def _apply_filter(self, signal: np.ndarray, b, a) -> np.ndarray: + if signal.ndim != 1: + raise ValueError("Signal must be 1D.") + return filtfilt(b, a, signal) + + def _filter_acc(self, signal: np.ndarray, fs: float) -> np.ndarray: + b, a = self._butter_bandpass( + low_hz=3.0, + high_hz=11.0, + fs=fs, + order=5, + ) + return self._apply_filter(signal, b, a) + + def _iqr(self, x: np.ndarray) -> float: + return float(np.percentile(x, 75) - np.percentile(x, 25)) + + def _extract_acc_axis_features(self, signal: np.ndarray, fs: float) -> List[Dict[str, float]]: + filtered = self._filter_acc(signal, fs) + filtered_abs = np.abs(filtered) + epochs = self._split_into_epochs(filtered_abs, fs) + features = [] + for ep in epochs: + features.append( + { + "trimmed_mean": float(trim_mean(ep, proportiontocut=0.10)), + "max": float(np.max(ep)), + "iqr": self._iqr(ep), + } + ) + return features - for col in ["BVP", "HR", "TEMP", "EDA"]: - if col in epoch_df.columns: - values = pd.to_numeric(epoch_df[col], errors="coerce").dropna() - features.append(float(values.mean()) if len(values) > 0 else 0.0) + def _extract_acc_mad_features( + self, + acc_x: 
np.ndarray, + acc_y: np.ndarray, + acc_z: np.ndarray, + fs: float, + ) -> List[Dict[str, float]]: + magnitude = np.sqrt(acc_x**2 + acc_y**2 + acc_z**2) + epochs = self._split_into_epochs(magnitude, fs) + features = [] + for ep in epochs: + mad = np.mean(np.abs(ep - np.mean(ep))) + features.append({"mad": float(mad)}) return features + def _extract_temp_features(self, signal: np.ndarray, fs: float) -> List[Dict[str, float]]: + limits = (0.05, 0.05) + wins_signal = winsorize(signal, limits=limits) + wins_signal = np.clip(wins_signal, 31.0, 40.0) + epochs = self._split_into_epochs(np.asarray(wins_signal), fs) + + features = [] + for ep in epochs: + features.append( + { + "mean": float(np.mean(ep)), + "min": float(np.min(ep)), + "max": float(np.max(ep)), + "std": float(np.std(ep)), + } + ) + return features + + def _build_record_epoch_features(self, df: pd.DataFrame) -> List[List[float]]: + fs = float(self.sampling_rate) + + required_acc = ["ACC_X", "ACC_Y", "ACC_Z"] + if not all(col in df.columns for col in required_acc): + return [] + + if "TEMP" not in df.columns: + return [] + + acc_x = self._safe_numeric(df["ACC_X"]) + acc_y = self._safe_numeric(df["ACC_Y"]) + acc_z = self._safe_numeric(df["ACC_Z"]) + temp = self._safe_numeric(df["TEMP"]) + + acc_x_feats = self._extract_acc_axis_features(acc_x, fs) + acc_y_feats = self._extract_acc_axis_features(acc_y, fs) + acc_z_feats = self._extract_acc_axis_features(acc_z, fs) + acc_mad_feats = self._extract_acc_mad_features(acc_x, acc_y, acc_z, fs) + temp_feats = self._extract_temp_features(temp, fs) + + num_epochs = min( + len(acc_x_feats), + len(acc_y_feats), + len(acc_z_feats), + len(acc_mad_feats), + len(temp_feats), + ) + + all_epoch_features = [] + for i in range(num_epochs): + feats = [] + + feats.extend( + [ + acc_x_feats[i]["trimmed_mean"], + acc_x_feats[i]["max"], + acc_x_feats[i]["iqr"], + ] + ) + feats.extend( + [ + acc_y_feats[i]["trimmed_mean"], + acc_y_feats[i]["max"], + acc_y_feats[i]["iqr"], + ] + ) + 
feats.extend( + [ + acc_z_feats[i]["trimmed_mean"], + acc_z_feats[i]["max"], + acc_z_feats[i]["iqr"], + ] + ) + feats.append(acc_mad_feats[i]["mad"]) + + feats.extend( + [ + temp_feats[i]["mean"], + temp_feats[i]["min"], + temp_feats[i]["max"], + temp_feats[i]["std"], + ] + ) + + all_epoch_features.append(feats) + + return all_epoch_features + def __call__(self, patient: Any) -> List[Dict[str, Any]]: samples = [] - events = patient.get_events() + events = patient.get_events(event_type="dreamt_sleep") if len(events) == 0: return samples epoch_size = self.epoch_seconds * self.sampling_rate - for event in events: - if not hasattr(event, "file_64hz") or event.file_64hz is None: + for event_idx, event in enumerate(events): + file_path = getattr(event, "file_64hz", None) + if file_path is None: continue - df = pd.read_csv(event.file_64hz) + try: + df = pd.read_csv(file_path) + except Exception: + continue + + if "Sleep_Stage" not in df.columns: + continue - if "Sleep Stage" not in df.columns: + record_epoch_features = self._build_record_epoch_features(df) + if len(record_epoch_features) == 0: continue - n_epochs = len(df) // epoch_size + n_label_epochs = len(df) // epoch_size + n_epochs = min(len(record_epoch_features), n_label_epochs) + for epoch_idx in range(n_epochs): start = epoch_idx * epoch_size end = start + epoch_size - epoch_df = df.iloc[start:end].copy() + epoch_df = df.iloc[start:end] + + if len(epoch_df) < epoch_size: + continue - if len(epoch_df) == 0: + stage_mode = epoch_df["Sleep_Stage"].mode(dropna=True) + if len(stage_mode) == 0: continue - raw_label = epoch_df["Sleep Stage"].mode().iloc[0] + raw_label = stage_mode.iloc[0] label = self._map_sleep_label(raw_label) if label is None: continue - features = self._extract_basic_features(epoch_df) - samples.append( { "patient_id": patient.patient_id, - "record_id": f"{patient.patient_id}-{epoch_idx}", + "record_id": f"{patient.patient_id}-event{event_idx}-epoch{epoch_idx}", "epoch_index": epoch_idx, - 
"features": features, + "features": record_epoch_features[epoch_idx], "label": label, } ) From 58d8366da22262fdb3c30d5d1330c47e47a3c364 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Mon, 9 Mar 2026 02:40:47 -0600 Subject: [PATCH 05/27] feat: add record-level BVP features to SleepWakeClassification --- pyhealth/tasks/sleep_wake_classification.py | 76 ++++++++++++++++++++- 1 file changed, 75 insertions(+), 1 deletion(-) diff --git a/pyhealth/tasks/sleep_wake_classification.py b/pyhealth/tasks/sleep_wake_classification.py index 9db4c5465..7c65aaa1f 100644 --- a/pyhealth/tasks/sleep_wake_classification.py +++ b/pyhealth/tasks/sleep_wake_classification.py @@ -2,7 +2,8 @@ import pandas as pd from typing import Any, Dict, List -from scipy.signal import butter, filtfilt +import neurokit2 as nk +from scipy.signal import butter, cheby2, filtfilt from scipy.stats import trim_mean from scipy.stats.mstats import winsorize @@ -72,6 +73,65 @@ def _filter_acc(self, signal: np.ndarray, fs: float) -> np.ndarray: order=5, ) return self._apply_filter(signal, b, a) + + def _cheby2_bandpass( + self, + low_hz: float, + high_hz: float, + fs: float, + order: int, + rs: float = 40.0, + ): + nyq = 0.5 * fs + low = low_hz / nyq + high = high_hz / nyq + return cheby2(order, rs, [low, high], btype="band") + + def _filter_bvp(self, signal: np.ndarray, fs: float) -> np.ndarray: + b, a = self._cheby2_bandpass( + low_hz=0.5, + high_hz=20.0, + fs=fs, + order=4, + rs=40.0, + ) + return self._apply_filter(signal, b, a) + + def _extract_bvp_features( + self, + signal: np.ndarray, + fs: float, + ) -> List[Dict[str, float]]: + filtered = self._filter_bvp(signal, fs) + epochs = self._split_into_epochs(filtered, fs) + + features = [] + for ep in epochs: + try: + _, info = nk.ppg_process(ep, sampling_rate=fs) + hrv = nk.hrv_time( + info["PPG_Peaks"], + sampling_rate=fs, + show=False, + ) + + features.append( + { + "rmssd": float(hrv["HRV_RMSSD"].values[0]), + "sdnn": float(hrv["HRV_SDNN"].values[0]), + 
"pnn50": float(hrv["HRV_pNN50"].values[0]), + } + ) + except Exception: + features.append( + { + "rmssd": np.nan, + "sdnn": np.nan, + "pnn50": np.nan, + } + ) + + return features def _iqr(self, x: np.ndarray) -> float: return float(np.percentile(x, 75) - np.percentile(x, 25)) @@ -136,16 +196,21 @@ def _build_record_epoch_features(self, df: pd.DataFrame) -> List[List[float]]: if "TEMP" not in df.columns: return [] + if "BVP" not in df.columns: + return [] + acc_x = self._safe_numeric(df["ACC_X"]) acc_y = self._safe_numeric(df["ACC_Y"]) acc_z = self._safe_numeric(df["ACC_Z"]) temp = self._safe_numeric(df["TEMP"]) + bvp = self._safe_numeric(df["BVP"]) acc_x_feats = self._extract_acc_axis_features(acc_x, fs) acc_y_feats = self._extract_acc_axis_features(acc_y, fs) acc_z_feats = self._extract_acc_axis_features(acc_z, fs) acc_mad_feats = self._extract_acc_mad_features(acc_x, acc_y, acc_z, fs) temp_feats = self._extract_temp_features(temp, fs) + bvp_feats = self._extract_bvp_features(bvp, fs) num_epochs = min( len(acc_x_feats), @@ -153,6 +218,7 @@ def _build_record_epoch_features(self, df: pd.DataFrame) -> List[List[float]]: len(acc_z_feats), len(acc_mad_feats), len(temp_feats), + len(bvp_feats), ) all_epoch_features = [] @@ -191,6 +257,14 @@ def _build_record_epoch_features(self, df: pd.DataFrame) -> List[List[float]]: ] ) + feats.extend( + [ + bvp_feats[i]["rmssd"], + bvp_feats[i]["sdnn"], + bvp_feats[i]["pnn50"], + ] + ) + all_epoch_features.append(feats) return all_epoch_features From a90af6c78d4b15fffdebf492d67c098f84159ec5 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Mon, 9 Mar 2026 03:14:10 -0600 Subject: [PATCH 06/27] feat: add record-level EDA features to SleepWakeClassification --- pyhealth/tasks/sleep_wake_classification.py | 87 +++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/pyhealth/tasks/sleep_wake_classification.py b/pyhealth/tasks/sleep_wake_classification.py index 7c65aaa1f..973eb2b56 100644 --- 
a/pyhealth/tasks/sleep_wake_classification.py +++ b/pyhealth/tasks/sleep_wake_classification.py @@ -133,6 +133,79 @@ def _extract_bvp_features( return features + def _lowpass_filter( + self, + signal: np.ndarray, + fs: float, + cutoff_hz: float, + order: int = 4, + ) -> np.ndarray: + nyq = 0.5 * fs + b, a = butter(order, cutoff_hz / nyq, btype="low") + return self._apply_filter(signal, b, a) + + def _detrend_segments( + self, + signal: np.ndarray, + fs: float, + segment_seconds: int, + ) -> np.ndarray: + samples_per_seg = int(fs * segment_seconds) + detrended = signal.copy() + + for i in range(0, len(signal), samples_per_seg): + seg = signal[i : i + samples_per_seg] + if len(seg) < 2: + continue + + x = np.arange(len(seg)) + coeffs = np.polyfit(x, seg, deg=1) + trend = np.polyval(coeffs, x) + detrended[i : i + len(seg)] = seg - trend + + return detrended + + def _extract_eda_features( + self, + signal: np.ndarray, + fs: float, + ) -> List[Dict[str, float]]: + detrended = self._detrend_segments(signal, fs, segment_seconds=5) + filtered = self._lowpass_filter(detrended, fs, cutoff_hz=1.0) + + eda_signals, _ = nk.eda_process(filtered, sampling_rate=fs) + scr = eda_signals["EDA_Phasic"].values + epochs = self._split_into_epochs(scr, fs) + + features = [] + for ep in epochs: + try: + _, info = nk.eda_peaks(ep, sampling_rate=fs) + + amplitudes = info["SCR_Amplitude"] + rise_times = info["SCR_RiseTime"] + recovery_times = info["SCR_RecoveryTime"] + + features.append( + { + "scr_amp_mean": float(np.mean(amplitudes)) if len(amplitudes) else 0.0, + "scr_amp_max": float(np.max(amplitudes)) if len(amplitudes) else 0.0, + "scr_rise_mean": float(np.mean(rise_times)) if len(rise_times) else 0.0, + "scr_recovery_mean": float(np.mean(recovery_times)) if len(recovery_times) else 0.0, + } + ) + except Exception: + features.append( + { + "scr_amp_mean": np.nan, + "scr_amp_max": np.nan, + "scr_rise_mean": np.nan, + "scr_recovery_mean": np.nan, + } + ) + + return features + def 
_iqr(self, x: np.ndarray) -> float: return float(np.percentile(x, 75) - np.percentile(x, 25)) @@ -198,12 +271,16 @@ def _build_record_epoch_features(self, df: pd.DataFrame) -> List[List[float]]: if "BVP" not in df.columns: return [] + + if "EDA" not in df.columns: + return [] acc_x = self._safe_numeric(df["ACC_X"]) acc_y = self._safe_numeric(df["ACC_Y"]) acc_z = self._safe_numeric(df["ACC_Z"]) temp = self._safe_numeric(df["TEMP"]) bvp = self._safe_numeric(df["BVP"]) + eda = self._safe_numeric(df["EDA"]) acc_x_feats = self._extract_acc_axis_features(acc_x, fs) acc_y_feats = self._extract_acc_axis_features(acc_y, fs) @@ -211,6 +288,7 @@ def _build_record_epoch_features(self, df: pd.DataFrame) -> List[List[float]]: acc_mad_feats = self._extract_acc_mad_features(acc_x, acc_y, acc_z, fs) temp_feats = self._extract_temp_features(temp, fs) bvp_feats = self._extract_bvp_features(bvp, fs) + eda_feats = self._extract_eda_features(eda, fs) num_epochs = min( len(acc_x_feats), @@ -265,6 +343,15 @@ def _build_record_epoch_features(self, df: pd.DataFrame) -> List[List[float]]: ] ) + feats.extend( + [ + eda_feats[i]["scr_amp_mean"], + eda_feats[i]["scr_amp_max"], + eda_feats[i]["scr_rise_mean"], + eda_feats[i]["scr_recovery_mean"], + ] + ) + all_epoch_features.append(feats) return all_epoch_features From 30d0f90643d615071621fa498244abbe2283f199 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Mon, 9 Mar 2026 04:14:28 -0600 Subject: [PATCH 07/27] feat: add temporal feature enhancement to SleepWakeClassification --- pyhealth/tasks/sleep_wake_classification.py | 60 +++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/pyhealth/tasks/sleep_wake_classification.py b/pyhealth/tasks/sleep_wake_classification.py index 973eb2b56..6d536fc86 100644 --- a/pyhealth/tasks/sleep_wake_classification.py +++ b/pyhealth/tasks/sleep_wake_classification.py @@ -6,6 +6,7 @@ from scipy.signal import butter, cheby2, filtfilt from scipy.stats import trim_mean from scipy.stats.mstats import 
winsorize +from scipy.ndimage import gaussian_filter1d from .base_task import BaseTask @@ -205,6 +206,59 @@ def _extract_eda_features( ) return features + + def _apply_gaussian_smoothing( + self, + values: np.ndarray, + sigma: float, + ) -> np.ndarray: + return gaussian_filter1d(values, sigma=sigma, mode="nearest") + + def _temporal_derivative(self, values: np.ndarray) -> np.ndarray: + return np.diff(values, prepend=values[0]) + + def _rolling_variance( + self, + values: np.ndarray, + window: int, + ) -> np.ndarray: + out = np.zeros_like(values) + half = window // 2 + + for i in range(len(values)): + start = max(0, i - half) + end = min(len(values), i + half + 1) + out[i] = np.var(values[start:end]) + + return out + + def _enhance_features_temporally( + self, + epoch_features: List[List[float]], + gaussian_sigma: float = 2.0, + variance_window: int = 5, + ) -> List[List[float]]: + if len(epoch_features) == 0: + return [] + + feature_matrix = np.asarray(epoch_features, dtype=float) + num_epochs, num_features = feature_matrix.shape + + enhanced = feature_matrix.tolist() + + for j in range(num_features): + values = feature_matrix[:, j] + + smoothed = self._apply_gaussian_smoothing(values, gaussian_sigma) + deriv = self._temporal_derivative(smoothed) + var = self._rolling_variance(smoothed, variance_window) + + for i in range(num_epochs): + enhanced[i].append(float(smoothed[i])) + enhanced[i].append(float(deriv[i])) + enhanced[i].append(float(var[i])) + + return enhanced def _iqr(self, x: np.ndarray) -> float: return float(np.percentile(x, 75) - np.percentile(x, 25)) @@ -354,6 +408,12 @@ def _build_record_epoch_features(self, df: pd.DataFrame) -> List[List[float]]: all_epoch_features.append(feats) + all_epoch_features = self._enhance_features_temporally( + all_epoch_features, + gaussian_sigma=2.0, + variance_window=5, + ) + return all_epoch_features def __call__(self, patient: Any) -> List[Dict[str, Any]]: From 4e91596be4249607ddd757876e857f03e1cb30aa Mon Sep 17 
00:00:00 2001 From: diegofariasc Date: Mon, 9 Mar 2026 05:13:56 -0600 Subject: [PATCH 08/27] feat: add initial sleep-wake task and temporal feature ablation example --- examples/sleep_wake_classification.py | 105 ++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 examples/sleep_wake_classification.py diff --git a/examples/sleep_wake_classification.py b/examples/sleep_wake_classification.py new file mode 100644 index 000000000..d2cd4fa9e --- /dev/null +++ b/examples/sleep_wake_classification.py @@ -0,0 +1,105 @@ +from collections import Counter + +import numpy as np +import lightgbm as lgb +from sklearn.impute import SimpleImputer +from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, average_precision_score +from sklearn.model_selection import GroupShuffleSplit + +from pyhealth.datasets import DREAMTDataset +from pyhealth.tasks.sleep_wake_classification import SleepWakeClassification + + +def run_experiment(X, y, groups, name): + splitter = GroupShuffleSplit(n_splits=1, test_size=0.4, random_state=42) + train_idx, test_idx = next(splitter.split(X, y, groups=groups)) + + X_train, X_test = X[train_idx], X[test_idx] + y_train, y_test = y[train_idx], y[test_idx] + g_train, g_test = groups[train_idx], groups[test_idx] + + print(f"\n=== {name} ===") + print("train patients:", sorted(set(g_train))) + print("test patients:", sorted(set(g_test))) + print("train size:", len(X_train)) + print("test size:", len(X_test)) + + non_all_nan_cols = ~np.isnan(X_train).all(axis=0) + X_train = X_train[:, non_all_nan_cols] + X_test = X_test[:, non_all_nan_cols] + + print("kept features:", X_train.shape[1]) + + imputer = SimpleImputer(strategy="median") + X_train = imputer.fit_transform(X_train) + X_test = imputer.transform(X_test) + + train_data = lgb.Dataset(X_train, label=y_train) + test_data = lgb.Dataset(X_test, label=y_test, reference=train_data) + + params = { + "objective": "binary", + "metric": "binary_logloss", + "boosting_type": 
"gbdt", + "learning_rate": 0.05, + "num_leaves": 31, + "feature_fraction": 0.9, + "bagging_fraction": 0.9, + "bagging_freq": 5, + "verbose": -1, + "seed": 42, + } + + model = lgb.train( + params, + train_data, + num_boost_round=200, + valid_sets=[test_data], + callbacks=[lgb.early_stopping(stopping_rounds=20, verbose=False)], + ) + + y_prob = model.predict(X_test) + y_pred = (y_prob >= 0.3).astype(int) + + print("Accuracy:", accuracy_score(y_test, y_pred)) + print("F1:", f1_score(y_test, y_pred)) + print("AUROC:", roc_auc_score(y_test, y_prob)) + print("AUPRC:", average_precision_score(y_test, y_prob)) + + +def main(): + root = r"C:\Users\faria\OneDrive - University of Illinois - Urbana\CS-598-DLH\dreamt-replication\data\DREAMT" + + dataset = DREAMTDataset(root=root) + task = SleepWakeClassification() + + selected_patient_ids = ["S028", "S062", "S078", "S081", "S099"] + + all_samples = [] + for patient_id in selected_patient_ids: + patient = dataset.get_patient(patient_id) + samples = task(patient) + print(patient_id, len(samples)) + all_samples.extend(samples) + + print("total samples:", len(all_samples)) + print("label counts:", Counter(s["label"] for s in all_samples)) + + X_all = np.array([s["features"] for s in all_samples], dtype=float) + y = np.array([s["label"] for s in all_samples], dtype=int) + groups = np.array([s["patient_id"] for s in all_samples]) + + print("X_all shape:", X_all.shape) + + # base only = first 21 features + X_base = X_all[:, :21] + + # base + temporal = all features + X_temporal = X_all + + run_experiment(X_base, y, groups, "Base features only") + run_experiment(X_temporal, y, groups, "Base + temporal features") + + +if __name__ == "__main__": + main() \ No newline at end of file From dbcc4964044a721260a039371f79898cfe161e2a Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Mon, 9 Mar 2026 05:59:33 -0600 Subject: [PATCH 09/27] feat: add modality ablation experiments for DREAMT sleep-wake classification task --- 
examples/sleep_wake_classification.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/examples/sleep_wake_classification.py b/examples/sleep_wake_classification.py index d2cd4fa9e..bfe2f3574 100644 --- a/examples/sleep_wake_classification.py +++ b/examples/sleep_wake_classification.py @@ -97,8 +97,21 @@ def main(): # base + temporal = all features X_temporal = X_all - run_experiment(X_base, y, groups, "Base features only") - run_experiment(X_temporal, y, groups, "Base + temporal features") + acc_idx = list(range(0, 10)) # ACC_X, ACC_Y, ACC_Z, ACC_MAD + temp_idx = list(range(10, 14)) # TEMP + bvp_idx = list(range(14, 17)) # BVP + eda_idx = list(range(17, 21)) # EDA + + X_acc = X_base[:, acc_idx] + X_acc_temp = X_base[:, acc_idx + temp_idx] + X_acc_temp_bvp = X_base[:, acc_idx + temp_idx + bvp_idx] + X_all_modalities = X_base[:, acc_idx + temp_idx + bvp_idx + eda_idx] + + run_experiment(X_acc, y, groups, "ACC only") + run_experiment(X_acc_temp, y, groups, "ACC + TEMP") + run_experiment(X_acc_temp_bvp, y, groups, "ACC + TEMP + BVP") + run_experiment(X_all_modalities, y, groups, "ALL modalities") + run_experiment(X_temporal, y, groups, "ALL modalities + temporal") if __name__ == "__main__": From 7ef252da7d02e2a1429ee85f66ae8cb8352b6f07 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Thu, 12 Mar 2026 03:03:33 -0600 Subject: [PATCH 10/27] refactor: improve readability and reuse in sleep-wake task --- examples/sleep_wake_classification.py | 49 ++ pyhealth/tasks/sleep_wake_classification.py | 651 ++++++++++++-------- 2 files changed, 437 insertions(+), 263 deletions(-) diff --git a/examples/sleep_wake_classification.py b/examples/sleep_wake_classification.py index bfe2f3574..96d1dbd3c 100644 --- a/examples/sleep_wake_classification.py +++ b/examples/sleep_wake_classification.py @@ -5,6 +5,8 @@ from sklearn.impute import SimpleImputer from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, average_precision_score from 
sklearn.model_selection import GroupShuffleSplit +from sklearn.linear_model import LogisticRegression +from sklearn.ensemble import RandomForestClassifier from pyhealth.datasets import DREAMTDataset from pyhealth.tasks.sleep_wake_classification import SleepWakeClassification @@ -67,6 +69,52 @@ def run_experiment(X, y, groups, name): print("AUPRC:", average_precision_score(y_test, y_prob)) +def run_model_comparison(X, y, groups): + splitter = GroupShuffleSplit(n_splits=1, test_size=0.4, random_state=42) + train_idx, test_idx = next(splitter.split(X, y, groups=groups)) + + X_train, X_test = X[train_idx], X[test_idx] + y_train, y_test = y[train_idx], y[test_idx] + g_train, g_test = groups[train_idx], groups[test_idx] + + print("\n=== Model comparison (ALL modalities + temporal) ===") + print("train patients:", sorted(set(g_train))) + print("test patients:", sorted(set(g_test))) + + non_all_nan_cols = ~np.isnan(X_train).all(axis=0) + X_train = X_train[:, non_all_nan_cols] + X_test = X_test[:, non_all_nan_cols] + + imputer = SimpleImputer(strategy="median") + X_train = imputer.fit_transform(X_train) + X_test = imputer.transform(X_test) + + models = { + "LogisticRegression": LogisticRegression(max_iter=1000), + "RandomForest": RandomForestClassifier( + n_estimators=200, + random_state=42, + n_jobs=-1 + ), + } + + for name, model in models.items(): + + model.fit(X_train, y_train) + + if hasattr(model, "predict_proba"): + y_prob = model.predict_proba(X_test)[:, 1] + else: + y_prob = model.decision_function(X_test) + + y_pred = (y_prob >= 0.3).astype(int) + + print(f"\n{name}") + print("Accuracy:", accuracy_score(y_test, y_pred)) + print("F1:", f1_score(y_test, y_pred)) + print("AUROC:", roc_auc_score(y_test, y_prob)) + print("AUPRC:", average_precision_score(y_test, y_prob)) + def main(): root = r"C:\Users\faria\OneDrive - University of Illinois - Urbana\CS-598-DLH\dreamt-replication\data\DREAMT" @@ -112,6 +160,7 @@ def main(): run_experiment(X_acc_temp_bvp, y, groups, 
"ACC + TEMP + BVP") run_experiment(X_all_modalities, y, groups, "ALL modalities") run_experiment(X_temporal, y, groups, "ALL modalities + temporal") + run_model_comparison(X_temporal, y, groups) if __name__ == "__main__": diff --git a/pyhealth/tasks/sleep_wake_classification.py b/pyhealth/tasks/sleep_wake_classification.py index 6d536fc86..0d7c821b4 100644 --- a/pyhealth/tasks/sleep_wake_classification.py +++ b/pyhealth/tasks/sleep_wake_classification.py @@ -1,13 +1,15 @@ +from typing import Callable, Dict, List + +import neurokit2 as nk import numpy as np import pandas as pd -from typing import Any, Dict, List -import neurokit2 as nk +from scipy.ndimage import gaussian_filter1d from scipy.signal import butter, cheby2, filtfilt from scipy.stats import trim_mean from scipy.stats.mstats import winsorize -from scipy.ndimage import gaussian_filter1d +from ..data import Patient from .base_task import BaseTask @@ -21,7 +23,7 @@ def __init__(self, epoch_seconds: int = 30, sampling_rate: int = 64): self.sampling_rate = sampling_rate super().__init__() - def _map_sleep_label(self, label): + def _convert_sleep_stage_to_binary_label(self, label): if label is None or pd.isna(label): return None @@ -34,11 +36,15 @@ def _map_sleep_label(self, label): return None - def _safe_numeric(self, series: pd.Series) -> np.ndarray: + def _convert_series_to_numeric_array(self, series: pd.Series) -> np.ndarray: return pd.to_numeric(series, errors="coerce").fillna(0.0).to_numpy() - def _split_into_epochs(self, signal: np.ndarray, fs: float) -> List[np.ndarray]: - samples_per_epoch = int(fs * self.epoch_seconds) + def _split_signal_into_epochs( + self, + signal: np.ndarray, + sampling_rate_hz: float, + ) -> List[np.ndarray]: + samples_per_epoch = int(sampling_rate_hz * self.epoch_seconds) num_epochs = len(signal) // samples_per_epoch epochs = [] @@ -49,109 +55,134 @@ def _split_into_epochs(self, signal: np.ndarray, fs: float) -> List[np.ndarray]: return epochs - def _butter_bandpass( + def 
_design_bandpass_filter_coefficients( self, + filter_family: str, low_hz: float, high_hz: float, - fs: float, + sampling_rate_hz: float, order: int, + stopband_attenuation_db: float = 40.0, ): - nyq = 0.5 * fs + nyq = 0.5 * sampling_rate_hz low = low_hz / nyq high = high_hz / nyq - return butter(order, [low, high], btype="band") - def _apply_filter(self, signal: np.ndarray, b, a) -> np.ndarray: + if filter_family == "butter": + return butter(order, [low, high], btype="band") + if filter_family == "cheby2": + return cheby2( + order, + stopband_attenuation_db, + [low, high], + btype="band", + ) + + raise ValueError(f"Unsupported bandpass filter family: {filter_family}") + + def _apply_zero_phase_filter(self, signal: np.ndarray, b, a) -> np.ndarray: if signal.ndim != 1: raise ValueError("Signal must be 1D.") return filtfilt(b, a, signal) - def _filter_acc(self, signal: np.ndarray, fs: float) -> np.ndarray: - b, a = self._butter_bandpass( + def _filter_accelerometer_signal( + self, + signal: np.ndarray, + sampling_rate_hz: float, + ) -> np.ndarray: + b, a = self._design_bandpass_filter_coefficients( + filter_family="butter", low_hz=3.0, high_hz=11.0, - fs=fs, + sampling_rate_hz=sampling_rate_hz, order=5, ) - return self._apply_filter(signal, b, a) - - def _cheby2_bandpass( - self, - low_hz: float, - high_hz: float, - fs: float, - order: int, - rs: float = 40.0, - ): - nyq = 0.5 * fs - low = low_hz / nyq - high = high_hz / nyq - return cheby2(order, rs, [low, high], btype="band") + return self._apply_zero_phase_filter(signal, b, a) - def _filter_bvp(self, signal: np.ndarray, fs: float) -> np.ndarray: - b, a = self._cheby2_bandpass( + def _filter_blood_volume_pulse_signal( + self, + signal: np.ndarray, + sampling_rate_hz: float, + ) -> np.ndarray: + b, a = self._design_bandpass_filter_coefficients( + filter_family="cheby2", low_hz=0.5, high_hz=20.0, - fs=fs, + sampling_rate_hz=sampling_rate_hz, order=4, - rs=40.0, + stopband_attenuation_db=40.0, ) - return 
self._apply_filter(signal, b, a) + return self._apply_zero_phase_filter(signal, b, a) + + def _build_feature_dictionary_from_epochs( + self, + epochs: List[np.ndarray], + feature_builder: Callable[[np.ndarray], Dict[str, float]], + ) -> List[Dict[str, float]]: + return [feature_builder(epoch) for epoch in epochs] + + def _build_missing_feature_dictionary( + self, + feature_names: List[str], + ) -> Dict[str, float]: + return {feature_name: np.nan for feature_name in feature_names} + + def _append_feature_values( + self, + feature_vector: List[float], + feature_dictionary: Dict[str, float], + feature_names: List[str], + ) -> None: + feature_vector.extend(feature_dictionary[feature_name] for feature_name in feature_names) - def _extract_bvp_features( + def _extract_blood_volume_pulse_epoch_features( self, signal: np.ndarray, - fs: float, + sampling_rate_hz: float, ) -> List[Dict[str, float]]: - filtered = self._filter_bvp(signal, fs) - epochs = self._split_into_epochs(filtered, fs) + filtered = self._filter_blood_volume_pulse_signal(signal, sampling_rate_hz) + epochs = self._split_signal_into_epochs(filtered, sampling_rate_hz) - features = [] - for ep in epochs: + def build_blood_volume_pulse_feature_dictionary(epoch: np.ndarray) -> Dict[str, float]: try: - _, info = nk.ppg_process(ep, sampling_rate=fs) + _, info = nk.ppg_process(epoch, sampling_rate=sampling_rate_hz) hrv = nk.hrv_time( info["PPG_Peaks"], - sampling_rate=fs, + sampling_rate=sampling_rate_hz, show=False, ) - features.append( - { - "rmssd": float(hrv["HRV_RMSSD"].values[0]), - "sdnn": float(hrv["HRV_SDNN"].values[0]), - "pnn50": float(hrv["HRV_pNN50"].values[0]), - } - ) + return { + "rmssd": float(hrv["HRV_RMSSD"].values[0]), + "sdnn": float(hrv["HRV_SDNN"].values[0]), + "pnn50": float(hrv["HRV_pNN50"].values[0]), + } except Exception: - features.append( - { - "rmssd": np.nan, - "sdnn": np.nan, - "pnn50": np.nan, - } - ) + return self._build_missing_feature_dictionary(["rmssd", "sdnn", "pnn50"]) - 
return features + return self._build_feature_dictionary_from_epochs( + epochs, + build_blood_volume_pulse_feature_dictionary, + ) - def _lowpass_filter( + def _filter_signal_with_lowpass( self, signal: np.ndarray, - fs: float, + sampling_rate_hz: float, cutoff_hz: float, order: int = 4, ) -> np.ndarray: - nyq = 0.5 * fs + nyq = 0.5 * sampling_rate_hz b, a = butter(order, cutoff_hz / nyq, btype="low") - return self._apply_filter(signal, b, a) + return self._apply_zero_phase_filter(signal, b, a) - def _detrend_segments( + def _detrend_signal_by_segments( self, signal: np.ndarray, - fs: float, + sampling_rate_hz: float, segment_seconds: int, ) -> np.ndarray: - samples_per_seg = int(fs * segment_seconds) + samples_per_seg = int(sampling_rate_hz * segment_seconds) detrended = signal.copy() for i in range(0, len(signal), samples_per_seg): @@ -166,58 +197,68 @@ def _detrend_segments( return detrended - def _extract_eda_features( + def _extract_electrodermal_activity_epoch_features( self, signal: np.ndarray, - fs: float, + sampling_rate_hz: float, ) -> List[Dict[str, float]]: - detrended = self._detrend_segments(signal, fs, segment_seconds=5) - filtered = self._lowpass_filter(detrended, fs, cutoff_hz=1.0) + detrended = self._detrend_signal_by_segments( + signal, + sampling_rate_hz, + segment_seconds=5, + ) + filtered = self._filter_signal_with_lowpass( + detrended, + sampling_rate_hz, + cutoff_hz=1.0, + ) - eda_signals, _ = nk.eda_process(filtered, sampling_rate=fs) + eda_signals, _ = nk.eda_process(filtered, sampling_rate=sampling_rate_hz) scr = eda_signals["EDA_Phasic"].values - epochs = self._split_into_epochs(scr, fs) + epochs = self._split_signal_into_epochs(scr, sampling_rate_hz) - features = [] - for ep in epochs: + def build_electrodermal_activity_feature_dictionary( + epoch: np.ndarray, + ) -> Dict[str, float]: try: - _, info = nk.eda_peaks(ep, sampling_rate=fs) + _, info = nk.eda_peaks(epoch, sampling_rate=sampling_rate_hz) amplitudes = info["SCR_Amplitude"] 
rise_times = info["SCR_RiseTime"] recovery_times = info["SCR_RecoveryTime"] - features.append( - { - "scr_amp_mean": float(np.mean(amplitudes)) if len(amplitudes) else 0.0, - "scr_amp_max": float(np.max(amplitudes)) if len(amplitudes) else 0.0, - "scr_rise_mean": float(np.mean(rise_times)) if len(rise_times) else 0.0, - "scr_recovery_mean": float(np.mean(recovery_times)) if len(recovery_times) else 0.0, - } - ) + return { + "scr_amp_mean": float(np.mean(amplitudes)) if len(amplitudes) else 0.0, + "scr_amp_max": float(np.max(amplitudes)) if len(amplitudes) else 0.0, + "scr_rise_mean": float(np.mean(rise_times)) if len(rise_times) else 0.0, + "scr_recovery_mean": float(np.mean(recovery_times)) if len(recovery_times) else 0.0, + } except Exception: - features.append( - { - "scr_amp_mean": np.nan, - "scr_amp_max": np.nan, - "scr_rise_mean": np.nan, - "scr_recovery_mean": np.nan, - } + return self._build_missing_feature_dictionary( + [ + "scr_amp_mean", + "scr_amp_max", + "scr_rise_mean", + "scr_recovery_mean", + ] ) - return features + return self._build_feature_dictionary_from_epochs( + epochs, + build_electrodermal_activity_feature_dictionary, + ) - def _apply_gaussian_smoothing( + def _smooth_values_with_gaussian( self, values: np.ndarray, sigma: float, ) -> np.ndarray: return gaussian_filter1d(values, sigma=sigma, mode="nearest") - def _temporal_derivative(self, values: np.ndarray) -> np.ndarray: + def _compute_temporal_derivative(self, values: np.ndarray) -> np.ndarray: return np.diff(values, prepend=values[0]) - def _rolling_variance( + def _compute_rolling_variance( self, values: np.ndarray, window: int, @@ -232,7 +273,7 @@ def _rolling_variance( return out - def _enhance_features_temporally( + def _augment_epoch_features_with_temporal_context( self, epoch_features: List[List[float]], gaussian_sigma: float = 2.0, @@ -249,9 +290,9 @@ def _enhance_features_temporally( for j in range(num_features): values = feature_matrix[:, j] - smoothed = 
self._apply_gaussian_smoothing(values, gaussian_sigma) - deriv = self._temporal_derivative(smoothed) - var = self._rolling_variance(smoothed, variance_window) + smoothed = self._smooth_values_with_gaussian(values, gaussian_sigma) + deriv = self._compute_temporal_derivative(smoothed) + var = self._compute_rolling_variance(smoothed, variance_window) for i in range(num_epochs): enhanced[i].append(float(smoothed[i])) @@ -260,155 +301,202 @@ def _enhance_features_temporally( return enhanced - def _iqr(self, x: np.ndarray) -> float: + def _compute_interquartile_range(self, x: np.ndarray) -> float: return float(np.percentile(x, 75) - np.percentile(x, 25)) - def _extract_acc_axis_features(self, signal: np.ndarray, fs: float) -> List[Dict[str, float]]: - filtered = self._filter_acc(signal, fs) + def _extract_accelerometer_axis_epoch_features( + self, + signal: np.ndarray, + sampling_rate_hz: float, + ) -> List[Dict[str, float]]: + filtered = self._filter_accelerometer_signal(signal, sampling_rate_hz) filtered_abs = np.abs(filtered) - epochs = self._split_into_epochs(filtered_abs, fs) - - features = [] - for ep in epochs: - features.append( - { - "trimmed_mean": float(trim_mean(ep, proportiontocut=0.10)), - "max": float(np.max(ep)), - "iqr": self._iqr(ep), - } - ) - return features + epochs = self._split_signal_into_epochs(filtered_abs, sampling_rate_hz) + + return self._build_feature_dictionary_from_epochs( + epochs, + lambda epoch: { + "trimmed_mean": float(trim_mean(epoch, proportiontocut=0.10)), + "max": float(np.max(epoch)), + "iqr": self._compute_interquartile_range(epoch), + }, + ) - def _extract_acc_mad_features( + def _extract_accelerometer_magnitude_deviation_epoch_features( self, - acc_x: np.ndarray, - acc_y: np.ndarray, - acc_z: np.ndarray, - fs: float, + accelerometer_x_signal: np.ndarray, + accelerometer_y_signal: np.ndarray, + accelerometer_z_signal: np.ndarray, + sampling_rate_hz: float, ) -> List[Dict[str, float]]: - magnitude = np.sqrt(acc_x**2 + acc_y**2 + 
acc_z**2) - epochs = self._split_into_epochs(magnitude, fs) + magnitude = np.sqrt( + accelerometer_x_signal**2 + + accelerometer_y_signal**2 + + accelerometer_z_signal**2 + ) + epochs = self._split_signal_into_epochs(magnitude, sampling_rate_hz) - features = [] - for ep in epochs: - mad = np.mean(np.abs(ep - np.mean(ep))) - features.append({"mad": float(mad)}) - return features + return self._build_feature_dictionary_from_epochs( + epochs, + lambda epoch: {"mad": float(np.mean(np.abs(epoch - np.mean(epoch))))}, + ) - def _extract_temp_features(self, signal: np.ndarray, fs: float) -> List[Dict[str, float]]: + def _extract_temperature_epoch_features( + self, + signal: np.ndarray, + sampling_rate_hz: float, + ) -> List[Dict[str, float]]: limits = (0.05, 0.05) wins_signal = winsorize(signal, limits=limits) wins_signal = np.clip(wins_signal, 31.0, 40.0) - epochs = self._split_into_epochs(np.asarray(wins_signal), fs) + epochs = self._split_signal_into_epochs(np.asarray(wins_signal), sampling_rate_hz) + + return self._build_feature_dictionary_from_epochs( + epochs, + lambda epoch: { + "mean": float(np.mean(epoch)), + "min": float(np.min(epoch)), + "max": float(np.max(epoch)), + "std": float(np.std(epoch)), + }, + ) - features = [] - for ep in epochs: - features.append( - { - "mean": float(np.mean(ep)), - "min": float(np.min(ep)), - "max": float(np.max(ep)), - "std": float(np.std(ep)), - } - ) - return features + def _has_required_sensor_columns( + self, + record_dataframe: pd.DataFrame, + ) -> bool: + required_columns = {"ACC_X", "ACC_Y", "ACC_Z", "TEMP", "BVP", "EDA"} + return required_columns.issubset(record_dataframe.columns) - def _build_record_epoch_features(self, df: pd.DataFrame) -> List[List[float]]: - fs = float(self.sampling_rate) + def _extract_sensor_signals_from_dataframe( + self, + record_dataframe: pd.DataFrame, + ) -> Dict[str, np.ndarray]: + return { + "accelerometer_x": self._convert_series_to_numeric_array(record_dataframe["ACC_X"]), + 
"accelerometer_y": self._convert_series_to_numeric_array(record_dataframe["ACC_Y"]), + "accelerometer_z": self._convert_series_to_numeric_array(record_dataframe["ACC_Z"]), + "temperature": self._convert_series_to_numeric_array(record_dataframe["TEMP"]), + "blood_volume_pulse": self._convert_series_to_numeric_array(record_dataframe["BVP"]), + "electrodermal_activity": self._convert_series_to_numeric_array(record_dataframe["EDA"]), + } + + def _extract_feature_sets_for_all_modalities( + self, + sensor_signals: Dict[str, np.ndarray], + sampling_rate_hz: float, + ) -> Dict[str, List[Dict[str, float]]]: + return { + "accelerometer_x": self._extract_accelerometer_axis_epoch_features( + sensor_signals["accelerometer_x"], + sampling_rate_hz, + ), + "accelerometer_y": self._extract_accelerometer_axis_epoch_features( + sensor_signals["accelerometer_y"], + sampling_rate_hz, + ), + "accelerometer_z": self._extract_accelerometer_axis_epoch_features( + sensor_signals["accelerometer_z"], + sampling_rate_hz, + ), + "accelerometer_magnitude_deviation": self._extract_accelerometer_magnitude_deviation_epoch_features( + sensor_signals["accelerometer_x"], + sensor_signals["accelerometer_y"], + sensor_signals["accelerometer_z"], + sampling_rate_hz, + ), + "temperature": self._extract_temperature_epoch_features( + sensor_signals["temperature"], + sampling_rate_hz, + ), + "blood_volume_pulse": self._extract_blood_volume_pulse_epoch_features( + sensor_signals["blood_volume_pulse"], + sampling_rate_hz, + ), + "electrodermal_activity": self._extract_electrodermal_activity_epoch_features( + sensor_signals["electrodermal_activity"], + sampling_rate_hz, + ), + } + + def _count_complete_epochs( + self, + feature_sets: Dict[str, List[Dict[str, float]]], + ) -> int: + return min(len(features) for features in feature_sets.values()) - required_acc = ["ACC_X", "ACC_Y", "ACC_Z"] - if not all(col in df.columns for col in required_acc): - return [] + def _build_epoch_feature_vector( + self, + 
feature_sets: Dict[str, List[Dict[str, float]]], + epoch_index: int, + ) -> List[float]: + accelerometer_x_features = feature_sets["accelerometer_x"][epoch_index] + accelerometer_y_features = feature_sets["accelerometer_y"][epoch_index] + accelerometer_z_features = feature_sets["accelerometer_z"][epoch_index] + accelerometer_magnitude_deviation_features = feature_sets["accelerometer_magnitude_deviation"][epoch_index] + temperature_features = feature_sets["temperature"][epoch_index] + blood_volume_pulse_features = feature_sets["blood_volume_pulse"][epoch_index] + electrodermal_activity_features = feature_sets["electrodermal_activity"][epoch_index] - if "TEMP" not in df.columns: - return [] + features = [] + self._append_feature_values( + features, + accelerometer_x_features, + ["trimmed_mean", "max", "iqr"], + ) + self._append_feature_values( + features, + accelerometer_y_features, + ["trimmed_mean", "max", "iqr"], + ) + self._append_feature_values( + features, + accelerometer_z_features, + ["trimmed_mean", "max", "iqr"], + ) + self._append_feature_values( + features, + accelerometer_magnitude_deviation_features, + ["mad"], + ) + self._append_feature_values( + features, + temperature_features, + ["mean", "min", "max", "std"], + ) + self._append_feature_values( + features, + blood_volume_pulse_features, + ["rmssd", "sdnn", "pnn50"], + ) + self._append_feature_values( + features, + electrodermal_activity_features, + ["scr_amp_mean", "scr_amp_max", "scr_rise_mean", "scr_recovery_mean"], + ) + return features - if "BVP" not in df.columns: - return [] - - if "EDA" not in df.columns: + def _build_record_epoch_feature_matrix( + self, + record_dataframe: pd.DataFrame, + ) -> List[List[float]]: + sampling_rate_hz = float(self.sampling_rate) + + if not self._has_required_sensor_columns(record_dataframe): return [] - acc_x = self._safe_numeric(df["ACC_X"]) - acc_y = self._safe_numeric(df["ACC_Y"]) - acc_z = self._safe_numeric(df["ACC_Z"]) - temp = 
self._safe_numeric(df["TEMP"]) - bvp = self._safe_numeric(df["BVP"]) - eda = self._safe_numeric(df["EDA"]) - - acc_x_feats = self._extract_acc_axis_features(acc_x, fs) - acc_y_feats = self._extract_acc_axis_features(acc_y, fs) - acc_z_feats = self._extract_acc_axis_features(acc_z, fs) - acc_mad_feats = self._extract_acc_mad_features(acc_x, acc_y, acc_z, fs) - temp_feats = self._extract_temp_features(temp, fs) - bvp_feats = self._extract_bvp_features(bvp, fs) - eda_feats = self._extract_eda_features(eda, fs) - - num_epochs = min( - len(acc_x_feats), - len(acc_y_feats), - len(acc_z_feats), - len(acc_mad_feats), - len(temp_feats), - len(bvp_feats), + sensor_signals = self._extract_sensor_signals_from_dataframe(record_dataframe) + feature_sets = self._extract_feature_sets_for_all_modalities( + sensor_signals, + sampling_rate_hz, ) + num_epochs = self._count_complete_epochs(feature_sets) all_epoch_features = [] for i in range(num_epochs): - feats = [] - - feats.extend( - [ - acc_x_feats[i]["trimmed_mean"], - acc_x_feats[i]["max"], - acc_x_feats[i]["iqr"], - ] - ) - feats.extend( - [ - acc_y_feats[i]["trimmed_mean"], - acc_y_feats[i]["max"], - acc_y_feats[i]["iqr"], - ] - ) - feats.extend( - [ - acc_z_feats[i]["trimmed_mean"], - acc_z_feats[i]["max"], - acc_z_feats[i]["iqr"], - ] - ) - feats.append(acc_mad_feats[i]["mad"]) - - feats.extend( - [ - temp_feats[i]["mean"], - temp_feats[i]["min"], - temp_feats[i]["max"], - temp_feats[i]["std"], - ] - ) + all_epoch_features.append(self._build_epoch_feature_vector(feature_sets, i)) - feats.extend( - [ - bvp_feats[i]["rmssd"], - bvp_feats[i]["sdnn"], - bvp_feats[i]["pnn50"], - ] - ) - - feats.extend( - [ - eda_feats[i]["scr_amp_mean"], - eda_feats[i]["scr_amp_max"], - eda_feats[i]["scr_rise_mean"], - eda_feats[i]["scr_recovery_mean"], - ] - ) - - all_epoch_features.append(feats) - - all_epoch_features = self._enhance_features_temporally( + all_epoch_features = self._augment_epoch_features_with_temporal_context( 
all_epoch_features, gaussian_sigma=2.0, variance_window=5, @@ -416,59 +504,96 @@ def _build_record_epoch_features(self, df: pd.DataFrame) -> List[List[float]]: return all_epoch_features - def __call__(self, patient: Any) -> List[Dict[str, Any]]: + def _load_wearable_record_dataframe(self, event) -> pd.DataFrame | None: + file_path = getattr(event, "file_64hz", None) + if file_path is None: + return None + + try: + return pd.read_csv(file_path) + except Exception: + return None + + def _extract_binary_label_for_epoch( + self, + record_dataframe: pd.DataFrame, + epoch_index: int, + samples_per_epoch: int, + ) -> int | None: + start = epoch_index * samples_per_epoch + end = start + samples_per_epoch + epoch_dataframe = record_dataframe.iloc[start:end] + + if len(epoch_dataframe) < samples_per_epoch: + return None + + stage_mode = epoch_dataframe["Sleep_Stage"].mode(dropna=True) + if len(stage_mode) == 0: + return None + + return self._convert_sleep_stage_to_binary_label(stage_mode.iloc[0]) + + def _build_samples_for_sleep_event( + self, + patient: Patient, + sleep_event_index: int, + record_dataframe: pd.DataFrame, + record_epoch_feature_matrix: List[List[float]], + samples_per_epoch: int, + ) -> List[Dict[str, object]]: + samples = [] + n_labeled_epochs = len(record_dataframe) // samples_per_epoch + n_epochs = min(len(record_epoch_feature_matrix), n_labeled_epochs) + + for epoch_idx in range(n_epochs): + label = self._extract_binary_label_for_epoch( + record_dataframe, + epoch_idx, + samples_per_epoch, + ) + if label is None: + continue + + samples.append( + { + "patient_id": patient.patient_id, + "record_id": f"{patient.patient_id}-event{sleep_event_index}-epoch{epoch_idx}", + "epoch_index": epoch_idx, + "features": record_epoch_feature_matrix[epoch_idx], + "label": label, + } + ) + + return samples + + def __call__(self, patient: Patient) -> List[Dict[str, object]]: samples = [] events = patient.get_events(event_type="dreamt_sleep") if len(events) == 0: return 
samples - epoch_size = self.epoch_seconds * self.sampling_rate + samples_per_epoch = self.epoch_seconds * self.sampling_rate for event_idx, event in enumerate(events): - file_path = getattr(event, "file_64hz", None) - if file_path is None: + record_dataframe = self._load_wearable_record_dataframe(event) + if record_dataframe is None: continue - try: - df = pd.read_csv(file_path) - except Exception: + if "Sleep_Stage" not in record_dataframe.columns: continue - if "Sleep_Stage" not in df.columns: + record_epoch_feature_matrix = self._build_record_epoch_feature_matrix(record_dataframe) + if len(record_epoch_feature_matrix) == 0: continue - record_epoch_features = self._build_record_epoch_features(df) - if len(record_epoch_features) == 0: - continue - - n_label_epochs = len(df) // epoch_size - n_epochs = min(len(record_epoch_features), n_label_epochs) - - for epoch_idx in range(n_epochs): - start = epoch_idx * epoch_size - end = start + epoch_size - epoch_df = df.iloc[start:end] - - if len(epoch_df) < epoch_size: - continue - - stage_mode = epoch_df["Sleep_Stage"].mode(dropna=True) - if len(stage_mode) == 0: - continue - - raw_label = stage_mode.iloc[0] - label = self._map_sleep_label(raw_label) - if label is None: - continue - - samples.append( - { - "patient_id": patient.patient_id, - "record_id": f"{patient.patient_id}-event{event_idx}-epoch{epoch_idx}", - "epoch_index": epoch_idx, - "features": record_epoch_features[epoch_idx], - "label": label, - } + samples.extend( + self._build_samples_for_sleep_event( + patient=patient, + sleep_event_index=event_idx, + record_dataframe=record_dataframe, + record_epoch_feature_matrix=record_epoch_feature_matrix, + samples_per_epoch=samples_per_epoch, ) + ) - return samples \ No newline at end of file + return samples From 20cf21f668e26116ab0036bcde3d0e815489415e Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Thu, 12 Mar 2026 03:21:50 -0600 Subject: [PATCH 11/27] refactor: reorder sleep-wake task methods by responsibility --- 
pyhealth/tasks/sleep_wake_classification.py | 279 +++++++++++--------- 1 file changed, 153 insertions(+), 126 deletions(-) diff --git a/pyhealth/tasks/sleep_wake_classification.py b/pyhealth/tasks/sleep_wake_classification.py index 0d7c821b4..6a4be49c9 100644 --- a/pyhealth/tasks/sleep_wake_classification.py +++ b/pyhealth/tasks/sleep_wake_classification.py @@ -15,7 +15,7 @@ class SleepWakeClassification(BaseTask): task_name = "SleepWakeClassification" - input_schema = {"features": "vector"} + input_schema = {"features": "tensor"} output_schema = {"label": "binary"} def __init__(self, epoch_seconds: int = 30, sampling_rate: int = 64): @@ -39,6 +39,13 @@ def _convert_sleep_stage_to_binary_label(self, label): def _convert_series_to_numeric_array(self, series: pd.Series) -> np.ndarray: return pd.to_numeric(series, errors="coerce").fillna(0.0).to_numpy() + def _has_required_sensor_columns( + self, + record_dataframe: pd.DataFrame, + ) -> bool: + required_columns = {"ACC_X", "ACC_Y", "ACC_Z", "TEMP", "BVP", "EDA"} + return required_columns.issubset(record_dataframe.columns) + def _split_signal_into_epochs( self, signal: np.ndarray, @@ -55,6 +62,32 @@ def _split_signal_into_epochs( return epochs + def _build_feature_dictionary_from_epochs( + self, + epochs: List[np.ndarray], + feature_builder: Callable[[np.ndarray], Dict[str, float]], + ) -> List[Dict[str, float]]: + return [feature_builder(epoch) for epoch in epochs] + + def _build_missing_feature_dictionary( + self, + feature_names: List[str], + ) -> Dict[str, float]: + return {feature_name: np.nan for feature_name in feature_names} + + def _append_feature_values( + self, + feature_vector: List[float], + feature_dictionary: Dict[str, float], + feature_names: List[str], + ) -> None: + feature_vector.extend( + feature_dictionary[feature_name] for feature_name in feature_names + ) + + def _compute_interquartile_range(self, x: np.ndarray) -> float: + return float(np.percentile(x, 75) - np.percentile(x, 25)) + def 
_design_bandpass_filter_coefficients( self, filter_family: str, @@ -85,6 +118,17 @@ def _apply_zero_phase_filter(self, signal: np.ndarray, b, a) -> np.ndarray: raise ValueError("Signal must be 1D.") return filtfilt(b, a, signal) + def _filter_signal_with_lowpass( + self, + signal: np.ndarray, + sampling_rate_hz: float, + cutoff_hz: float, + order: int = 4, + ) -> np.ndarray: + nyq = 0.5 * sampling_rate_hz + b, a = butter(order, cutoff_hz / nyq, btype="low") + return self._apply_zero_phase_filter(signal, b, a) + def _filter_accelerometer_signal( self, signal: np.ndarray, @@ -114,26 +158,86 @@ def _filter_blood_volume_pulse_signal( ) return self._apply_zero_phase_filter(signal, b, a) - def _build_feature_dictionary_from_epochs( + def _detrend_signal_by_segments( self, - epochs: List[np.ndarray], - feature_builder: Callable[[np.ndarray], Dict[str, float]], + signal: np.ndarray, + sampling_rate_hz: float, + segment_seconds: int, + ) -> np.ndarray: + samples_per_seg = int(sampling_rate_hz * segment_seconds) + detrended = signal.copy() + + for i in range(0, len(signal), samples_per_seg): + seg = signal[i : i + samples_per_seg] + if len(seg) < 2: + continue + + x = np.arange(len(seg)) + coeffs = np.polyfit(x, seg, deg=1) + trend = np.polyval(coeffs, x) + detrended[i : i + len(seg)] = seg - trend + + return detrended + + def _extract_accelerometer_axis_epoch_features( + self, + signal: np.ndarray, + sampling_rate_hz: float, ) -> List[Dict[str, float]]: - return [feature_builder(epoch) for epoch in epochs] + filtered = self._filter_accelerometer_signal(signal, sampling_rate_hz) + filtered_abs = np.abs(filtered) + epochs = self._split_signal_into_epochs(filtered_abs, sampling_rate_hz) - def _build_missing_feature_dictionary( + return self._build_feature_dictionary_from_epochs( + epochs, + lambda epoch: { + "trimmed_mean": float(trim_mean(epoch, proportiontocut=0.10)), + "max": float(np.max(epoch)), + "iqr": self._compute_interquartile_range(epoch), + }, + ) + + def 
_extract_accelerometer_magnitude_deviation_epoch_features( self, - feature_names: List[str], - ) -> Dict[str, float]: - return {feature_name: np.nan for feature_name in feature_names} + accelerometer_x_signal: np.ndarray, + accelerometer_y_signal: np.ndarray, + accelerometer_z_signal: np.ndarray, + sampling_rate_hz: float, + ) -> List[Dict[str, float]]: + magnitude = np.sqrt( + accelerometer_x_signal**2 + + accelerometer_y_signal**2 + + accelerometer_z_signal**2 + ) + epochs = self._split_signal_into_epochs(magnitude, sampling_rate_hz) - def _append_feature_values( + return self._build_feature_dictionary_from_epochs( + epochs, + lambda epoch: {"mad": float(np.mean(np.abs(epoch - np.mean(epoch))))}, + ) + + def _extract_temperature_epoch_features( self, - feature_vector: List[float], - feature_dictionary: Dict[str, float], - feature_names: List[str], - ) -> None: - feature_vector.extend(feature_dictionary[feature_name] for feature_name in feature_names) + signal: np.ndarray, + sampling_rate_hz: float, + ) -> List[Dict[str, float]]: + limits = (0.05, 0.05) + wins_signal = winsorize(signal, limits=limits) + wins_signal = np.clip(wins_signal, 31.0, 40.0) + epochs = self._split_signal_into_epochs( + np.asarray(wins_signal), + sampling_rate_hz, + ) + + return self._build_feature_dictionary_from_epochs( + epochs, + lambda epoch: { + "mean": float(np.mean(epoch)), + "min": float(np.min(epoch)), + "max": float(np.max(epoch)), + "std": float(np.std(epoch)), + }, + ) def _extract_blood_volume_pulse_epoch_features( self, @@ -143,7 +247,9 @@ def _extract_blood_volume_pulse_epoch_features( filtered = self._filter_blood_volume_pulse_signal(signal, sampling_rate_hz) epochs = self._split_signal_into_epochs(filtered, sampling_rate_hz) - def build_blood_volume_pulse_feature_dictionary(epoch: np.ndarray) -> Dict[str, float]: + def build_blood_volume_pulse_feature_dictionary( + epoch: np.ndarray, + ) -> Dict[str, float]: try: _, info = nk.ppg_process(epoch, 
sampling_rate=sampling_rate_hz) hrv = nk.hrv_time( @@ -158,45 +264,15 @@ def build_blood_volume_pulse_feature_dictionary(epoch: np.ndarray) -> Dict[str, "pnn50": float(hrv["HRV_pNN50"].values[0]), } except Exception: - return self._build_missing_feature_dictionary(["rmssd", "sdnn", "pnn50"]) + return self._build_missing_feature_dictionary( + ["rmssd", "sdnn", "pnn50"] + ) return self._build_feature_dictionary_from_epochs( epochs, build_blood_volume_pulse_feature_dictionary, ) - def _filter_signal_with_lowpass( - self, - signal: np.ndarray, - sampling_rate_hz: float, - cutoff_hz: float, - order: int = 4, - ) -> np.ndarray: - nyq = 0.5 * sampling_rate_hz - b, a = butter(order, cutoff_hz / nyq, btype="low") - return self._apply_zero_phase_filter(signal, b, a) - - def _detrend_signal_by_segments( - self, - signal: np.ndarray, - sampling_rate_hz: float, - segment_seconds: int, - ) -> np.ndarray: - samples_per_seg = int(sampling_rate_hz * segment_seconds) - detrended = signal.copy() - - for i in range(0, len(signal), samples_per_seg): - seg = signal[i : i + samples_per_seg] - if len(seg) < 2: - continue - - x = np.arange(len(seg)) - coeffs = np.polyfit(x, seg, deg=1) - trend = np.polyval(coeffs, x) - detrended[i : i + len(seg)] = seg - trend - - return detrended - def _extract_electrodermal_activity_epoch_features( self, signal: np.ndarray, @@ -247,7 +323,7 @@ def build_electrodermal_activity_feature_dictionary( epochs, build_electrodermal_activity_feature_dictionary, ) - + def _smooth_values_with_gaussian( self, values: np.ndarray, @@ -301,84 +377,29 @@ def _augment_epoch_features_with_temporal_context( return enhanced - def _compute_interquartile_range(self, x: np.ndarray) -> float: - return float(np.percentile(x, 75) - np.percentile(x, 25)) - - def _extract_accelerometer_axis_epoch_features( - self, - signal: np.ndarray, - sampling_rate_hz: float, - ) -> List[Dict[str, float]]: - filtered = self._filter_accelerometer_signal(signal, sampling_rate_hz) - filtered_abs = 
np.abs(filtered) - epochs = self._split_signal_into_epochs(filtered_abs, sampling_rate_hz) - - return self._build_feature_dictionary_from_epochs( - epochs, - lambda epoch: { - "trimmed_mean": float(trim_mean(epoch, proportiontocut=0.10)), - "max": float(np.max(epoch)), - "iqr": self._compute_interquartile_range(epoch), - }, - ) - - def _extract_accelerometer_magnitude_deviation_epoch_features( - self, - accelerometer_x_signal: np.ndarray, - accelerometer_y_signal: np.ndarray, - accelerometer_z_signal: np.ndarray, - sampling_rate_hz: float, - ) -> List[Dict[str, float]]: - magnitude = np.sqrt( - accelerometer_x_signal**2 - + accelerometer_y_signal**2 - + accelerometer_z_signal**2 - ) - epochs = self._split_signal_into_epochs(magnitude, sampling_rate_hz) - - return self._build_feature_dictionary_from_epochs( - epochs, - lambda epoch: {"mad": float(np.mean(np.abs(epoch - np.mean(epoch))))}, - ) - - def _extract_temperature_epoch_features( - self, - signal: np.ndarray, - sampling_rate_hz: float, - ) -> List[Dict[str, float]]: - limits = (0.05, 0.05) - wins_signal = winsorize(signal, limits=limits) - wins_signal = np.clip(wins_signal, 31.0, 40.0) - epochs = self._split_signal_into_epochs(np.asarray(wins_signal), sampling_rate_hz) - - return self._build_feature_dictionary_from_epochs( - epochs, - lambda epoch: { - "mean": float(np.mean(epoch)), - "min": float(np.min(epoch)), - "max": float(np.max(epoch)), - "std": float(np.std(epoch)), - }, - ) - - def _has_required_sensor_columns( - self, - record_dataframe: pd.DataFrame, - ) -> bool: - required_columns = {"ACC_X", "ACC_Y", "ACC_Z", "TEMP", "BVP", "EDA"} - return required_columns.issubset(record_dataframe.columns) - def _extract_sensor_signals_from_dataframe( self, record_dataframe: pd.DataFrame, ) -> Dict[str, np.ndarray]: return { - "accelerometer_x": self._convert_series_to_numeric_array(record_dataframe["ACC_X"]), - "accelerometer_y": self._convert_series_to_numeric_array(record_dataframe["ACC_Y"]), - 
"accelerometer_z": self._convert_series_to_numeric_array(record_dataframe["ACC_Z"]), - "temperature": self._convert_series_to_numeric_array(record_dataframe["TEMP"]), - "blood_volume_pulse": self._convert_series_to_numeric_array(record_dataframe["BVP"]), - "electrodermal_activity": self._convert_series_to_numeric_array(record_dataframe["EDA"]), + "accelerometer_x": self._convert_series_to_numeric_array( + record_dataframe["ACC_X"] + ), + "accelerometer_y": self._convert_series_to_numeric_array( + record_dataframe["ACC_Y"] + ), + "accelerometer_z": self._convert_series_to_numeric_array( + record_dataframe["ACC_Z"] + ), + "temperature": self._convert_series_to_numeric_array( + record_dataframe["TEMP"] + ), + "blood_volume_pulse": self._convert_series_to_numeric_array( + record_dataframe["BVP"] + ), + "electrodermal_activity": self._convert_series_to_numeric_array( + record_dataframe["EDA"] + ), } def _extract_feature_sets_for_all_modalities( @@ -433,10 +454,14 @@ def _build_epoch_feature_vector( accelerometer_x_features = feature_sets["accelerometer_x"][epoch_index] accelerometer_y_features = feature_sets["accelerometer_y"][epoch_index] accelerometer_z_features = feature_sets["accelerometer_z"][epoch_index] - accelerometer_magnitude_deviation_features = feature_sets["accelerometer_magnitude_deviation"][epoch_index] + accelerometer_magnitude_deviation_features = feature_sets[ + "accelerometer_magnitude_deviation" + ][epoch_index] temperature_features = feature_sets["temperature"][epoch_index] blood_volume_pulse_features = feature_sets["blood_volume_pulse"][epoch_index] - electrodermal_activity_features = feature_sets["electrodermal_activity"][epoch_index] + electrodermal_activity_features = feature_sets["electrodermal_activity"][ + epoch_index + ] features = [] self._append_feature_values( @@ -582,7 +607,9 @@ def __call__(self, patient: Patient) -> List[Dict[str, object]]: if "Sleep_Stage" not in record_dataframe.columns: continue - record_epoch_feature_matrix = 
self._build_record_epoch_feature_matrix(record_dataframe) + record_epoch_feature_matrix = self._build_record_epoch_feature_matrix( + record_dataframe + ) if len(record_epoch_feature_matrix) == 0: continue From 4686a5147b7e795256d47c6e1ec2fad80a5b9dd5 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Thu, 12 Mar 2026 03:29:14 -0600 Subject: [PATCH 12/27] doc: add sleep_wake_classification.rst --- .../api/tasks/pyhealth.tasks.sleep_wake_classification.rst | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 docs/api/tasks/pyhealth.tasks.sleep_wake_classification.rst diff --git a/docs/api/tasks/pyhealth.tasks.sleep_wake_classification.rst b/docs/api/tasks/pyhealth.tasks.sleep_wake_classification.rst new file mode 100644 index 000000000..ee613f09f --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.sleep_wake_classification.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.sleep_wake_classification +======================================== + +.. autoclass:: pyhealth.tasks.sleep_wake_classification.SleepWakeClassification + :members: + :undoc-members: + :show-inheritance: From df5720f2ed11afe5cca9130bfa04e4df76cd7138 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Thu, 12 Mar 2026 03:48:51 -0600 Subject: [PATCH 13/27] doc: document all methods in SleepWakeClassification --- pyhealth/tasks/sleep_wake_classification.py | 504 ++++++++++++++------ 1 file changed, 357 insertions(+), 147 deletions(-) diff --git a/pyhealth/tasks/sleep_wake_classification.py b/pyhealth/tasks/sleep_wake_classification.py index 6a4be49c9..fbfe3c782 100644 --- a/pyhealth/tasks/sleep_wake_classification.py +++ b/pyhealth/tasks/sleep_wake_classification.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List +from typing import Dict, List import neurokit2 as nk import numpy as np @@ -14,16 +14,39 @@ class SleepWakeClassification(BaseTask): + """Binary sleep-wake classification task for DREAMT wearable recordings. 
+ + This task converts each DREAMT wearable recording into fixed-length epochs, + extracts physiological features from multiple sensor modalities, + augments them with temporal context, and assigns a binary sleep/wake + label to each epoch. + """ + task_name = "SleepWakeClassification" input_schema = {"features": "tensor"} output_schema = {"label": "binary"} def __init__(self, epoch_seconds: int = 30, sampling_rate: int = 64): + """Initializes the sleep-wake classification task. + + Args: + epoch_seconds: Length of each epoch in seconds. + sampling_rate: Sampling rate of the wearable recording in Hz. + """ self.epoch_seconds = epoch_seconds self.sampling_rate = sampling_rate super().__init__() def _convert_sleep_stage_to_binary_label(self, label): + """Maps a sleep stage label to the binary sleep-wake target. + + Args: + label: Raw sleep stage label from the DREAMT recording. + + Returns: + ``1`` for wake, ``0`` for sleep, or ``None`` if the label is + missing or unsupported. + """ if label is None or pd.isna(label): return None @@ -36,21 +59,20 @@ def _convert_sleep_stage_to_binary_label(self, label): return None - def _convert_series_to_numeric_array(self, series: pd.Series) -> np.ndarray: - return pd.to_numeric(series, errors="coerce").fillna(0.0).to_numpy() - - def _has_required_sensor_columns( - self, - record_dataframe: pd.DataFrame, - ) -> bool: - required_columns = {"ACC_X", "ACC_Y", "ACC_Z", "TEMP", "BVP", "EDA"} - return required_columns.issubset(record_dataframe.columns) - def _split_signal_into_epochs( self, signal: np.ndarray, sampling_rate_hz: float, ) -> List[np.ndarray]: + """Splits a 1D signal into non-overlapping fixed-length epochs. + + Args: + signal: One-dimensional signal array. + sampling_rate_hz: Sampling rate of the signal in Hz. + + Returns: + A list of signal segments, one per complete epoch. 
+ """ samples_per_epoch = int(sampling_rate_hz * self.epoch_seconds) num_epochs = len(signal) // samples_per_epoch @@ -62,32 +84,6 @@ def _split_signal_into_epochs( return epochs - def _build_feature_dictionary_from_epochs( - self, - epochs: List[np.ndarray], - feature_builder: Callable[[np.ndarray], Dict[str, float]], - ) -> List[Dict[str, float]]: - return [feature_builder(epoch) for epoch in epochs] - - def _build_missing_feature_dictionary( - self, - feature_names: List[str], - ) -> Dict[str, float]: - return {feature_name: np.nan for feature_name in feature_names} - - def _append_feature_values( - self, - feature_vector: List[float], - feature_dictionary: Dict[str, float], - feature_names: List[str], - ) -> None: - feature_vector.extend( - feature_dictionary[feature_name] for feature_name in feature_names - ) - - def _compute_interquartile_range(self, x: np.ndarray) -> float: - return float(np.percentile(x, 75) - np.percentile(x, 25)) - def _design_bandpass_filter_coefficients( self, filter_family: str, @@ -97,6 +93,23 @@ def _design_bandpass_filter_coefficients( order: int, stopband_attenuation_db: float = 40.0, ): + """Designs band-pass filter coefficients for a supported family. + + Args: + filter_family: Filter family name, currently ``"butter"`` or + ``"cheby2"``. + low_hz: Lower cutoff frequency in Hz. + high_hz: Upper cutoff frequency in Hz. + sampling_rate_hz: Sampling rate in Hz. + order: Filter order. + stopband_attenuation_db: Stopband attenuation used by Chebyshev-II. + + Returns: + The numerator and denominator filter coefficients. + + Raises: + ValueError: If the filter family is unsupported. + """ nyq = 0.5 * sampling_rate_hz low = low_hz / nyq high = high_hz / nyq @@ -114,6 +127,19 @@ def _design_bandpass_filter_coefficients( raise ValueError(f"Unsupported bandpass filter family: {filter_family}") def _apply_zero_phase_filter(self, signal: np.ndarray, b, a) -> np.ndarray: + """Applies zero-phase filtering to a one-dimensional signal. 
+ + Args: + signal: One-dimensional signal array. + b: Numerator filter coefficients. + a: Denominator filter coefficients. + + Returns: + The filtered signal. + + Raises: + ValueError: If ``signal`` is not one-dimensional. + """ if signal.ndim != 1: raise ValueError("Signal must be 1D.") return filtfilt(b, a, signal) @@ -125,6 +151,17 @@ def _filter_signal_with_lowpass( cutoff_hz: float, order: int = 4, ) -> np.ndarray: + """Applies a Butterworth low-pass filter to a signal. + + Args: + signal: One-dimensional signal array. + sampling_rate_hz: Sampling rate in Hz. + cutoff_hz: Low-pass cutoff frequency in Hz. + order: Filter order. + + Returns: + The filtered signal. + """ nyq = 0.5 * sampling_rate_hz b, a = butter(order, cutoff_hz / nyq, btype="low") return self._apply_zero_phase_filter(signal, b, a) @@ -134,6 +171,15 @@ def _filter_accelerometer_signal( signal: np.ndarray, sampling_rate_hz: float, ) -> np.ndarray: + """Band-pass filters an accelerometer signal for movement features. + + Args: + signal: Raw accelerometer signal from one axis. + sampling_rate_hz: Sampling rate in Hz. + + Returns: + The filtered accelerometer signal. + """ b, a = self._design_bandpass_filter_coefficients( filter_family="butter", low_hz=3.0, @@ -148,6 +194,15 @@ def _filter_blood_volume_pulse_signal( signal: np.ndarray, sampling_rate_hz: float, ) -> np.ndarray: + """Band-pass filters a blood volume pulse signal. + + Args: + signal: Raw blood volume pulse signal. + sampling_rate_hz: Sampling rate in Hz. + + Returns: + The filtered blood volume pulse signal. + """ b, a = self._design_bandpass_filter_coefficients( filter_family="cheby2", low_hz=0.5, @@ -164,6 +219,16 @@ def _detrend_signal_by_segments( sampling_rate_hz: float, segment_seconds: int, ) -> np.ndarray: + """Detrends a signal by removing a linear trend per segment. + + Args: + signal: One-dimensional signal array. + sampling_rate_hz: Sampling rate in Hz. + segment_seconds: Segment length used for local detrending. 
+ + Returns: + A detrended copy of the input signal. + """ samples_per_seg = int(sampling_rate_hz * segment_seconds) detrended = signal.copy() @@ -184,18 +249,27 @@ def _extract_accelerometer_axis_epoch_features( signal: np.ndarray, sampling_rate_hz: float, ) -> List[Dict[str, float]]: + """Extracts per-epoch features from one accelerometer axis. + + Args: + signal: Raw accelerometer signal from one axis. + sampling_rate_hz: Sampling rate in Hz. + + Returns: + A list of feature dictionaries, one per epoch. + """ filtered = self._filter_accelerometer_signal(signal, sampling_rate_hz) filtered_abs = np.abs(filtered) epochs = self._split_signal_into_epochs(filtered_abs, sampling_rate_hz) - return self._build_feature_dictionary_from_epochs( - epochs, - lambda epoch: { + return [ + { "trimmed_mean": float(trim_mean(epoch, proportiontocut=0.10)), "max": float(np.max(epoch)), - "iqr": self._compute_interquartile_range(epoch), - }, - ) + "iqr": float(np.percentile(epoch, 75) - np.percentile(epoch, 25)), + } + for epoch in epochs + ] def _extract_accelerometer_magnitude_deviation_epoch_features( self, @@ -204,6 +278,17 @@ def _extract_accelerometer_magnitude_deviation_epoch_features( accelerometer_z_signal: np.ndarray, sampling_rate_hz: float, ) -> List[Dict[str, float]]: + """Extracts mean absolute deviation features from ACC magnitude. + + Args: + accelerometer_x_signal: Accelerometer X-axis signal. + accelerometer_y_signal: Accelerometer Y-axis signal. + accelerometer_z_signal: Accelerometer Z-axis signal. + sampling_rate_hz: Sampling rate in Hz. + + Returns: + A list of feature dictionaries containing magnitude deviation. 
+ """ magnitude = np.sqrt( accelerometer_x_signal**2 + accelerometer_y_signal**2 @@ -211,16 +296,25 @@ def _extract_accelerometer_magnitude_deviation_epoch_features( ) epochs = self._split_signal_into_epochs(magnitude, sampling_rate_hz) - return self._build_feature_dictionary_from_epochs( - epochs, - lambda epoch: {"mad": float(np.mean(np.abs(epoch - np.mean(epoch))))}, - ) + return [ + {"mad": float(np.mean(np.abs(epoch - np.mean(epoch))))} + for epoch in epochs + ] def _extract_temperature_epoch_features( self, signal: np.ndarray, sampling_rate_hz: float, ) -> List[Dict[str, float]]: + """Extracts per-epoch temperature summary statistics. + + Args: + signal: Raw temperature signal. + sampling_rate_hz: Sampling rate in Hz. + + Returns: + A list of feature dictionaries, one per epoch. + """ limits = (0.05, 0.05) wins_signal = winsorize(signal, limits=limits) wins_signal = np.clip(wins_signal, 31.0, 40.0) @@ -229,27 +323,34 @@ def _extract_temperature_epoch_features( sampling_rate_hz, ) - return self._build_feature_dictionary_from_epochs( - epochs, - lambda epoch: { + return [ + { "mean": float(np.mean(epoch)), "min": float(np.min(epoch)), "max": float(np.max(epoch)), "std": float(np.std(epoch)), - }, - ) + } + for epoch in epochs + ] def _extract_blood_volume_pulse_epoch_features( self, signal: np.ndarray, sampling_rate_hz: float, ) -> List[Dict[str, float]]: + """Extracts HRV-based features from blood volume pulse epochs. + + Args: + signal: Raw blood volume pulse signal. + sampling_rate_hz: Sampling rate in Hz. + + Returns: + A list of feature dictionaries, one per epoch. 
+ """ filtered = self._filter_blood_volume_pulse_signal(signal, sampling_rate_hz) epochs = self._split_signal_into_epochs(filtered, sampling_rate_hz) - - def build_blood_volume_pulse_feature_dictionary( - epoch: np.ndarray, - ) -> Dict[str, float]: + epoch_features = [] + for epoch in epochs: try: _, info = nk.ppg_process(epoch, sampling_rate=sampling_rate_hz) hrv = nk.hrv_time( @@ -258,26 +359,34 @@ def build_blood_volume_pulse_feature_dictionary( show=False, ) - return { - "rmssd": float(hrv["HRV_RMSSD"].values[0]), - "sdnn": float(hrv["HRV_SDNN"].values[0]), - "pnn50": float(hrv["HRV_pNN50"].values[0]), - } + epoch_features.append( + { + "rmssd": float(hrv["HRV_RMSSD"].values[0]), + "sdnn": float(hrv["HRV_SDNN"].values[0]), + "pnn50": float(hrv["HRV_pNN50"].values[0]), + } + ) except Exception: - return self._build_missing_feature_dictionary( - ["rmssd", "sdnn", "pnn50"] + epoch_features.append( + {"rmssd": np.nan, "sdnn": np.nan, "pnn50": np.nan} ) - return self._build_feature_dictionary_from_epochs( - epochs, - build_blood_volume_pulse_feature_dictionary, - ) + return epoch_features def _extract_electrodermal_activity_epoch_features( self, signal: np.ndarray, sampling_rate_hz: float, ) -> List[Dict[str, float]]: + """Extracts SCR-based features from electrodermal activity epochs. + + Args: + signal: Raw electrodermal activity signal. + sampling_rate_hz: Sampling rate in Hz. + + Returns: + A list of feature dictionaries, one per epoch. 
+ """ detrended = self._detrend_signal_by_segments( signal, sampling_rate_hz, @@ -292,10 +401,8 @@ def _extract_electrodermal_activity_epoch_features( eda_signals, _ = nk.eda_process(filtered, sampling_rate=sampling_rate_hz) scr = eda_signals["EDA_Phasic"].values epochs = self._split_signal_into_epochs(scr, sampling_rate_hz) - - def build_electrodermal_activity_feature_dictionary( - epoch: np.ndarray, - ) -> Dict[str, float]: + epoch_features = [] + for epoch in epochs: try: _, info = nk.eda_peaks(epoch, sampling_rate=sampling_rate_hz) @@ -303,42 +410,48 @@ def build_electrodermal_activity_feature_dictionary( rise_times = info["SCR_RiseTime"] recovery_times = info["SCR_RecoveryTime"] - return { - "scr_amp_mean": float(np.mean(amplitudes)) if len(amplitudes) else 0.0, - "scr_amp_max": float(np.max(amplitudes)) if len(amplitudes) else 0.0, - "scr_rise_mean": float(np.mean(rise_times)) if len(rise_times) else 0.0, - "scr_recovery_mean": float(np.mean(recovery_times)) if len(recovery_times) else 0.0, - } + epoch_features.append( + { + "scr_amp_mean": float(np.mean(amplitudes)) + if len(amplitudes) + else 0.0, + "scr_amp_max": float(np.max(amplitudes)) + if len(amplitudes) + else 0.0, + "scr_rise_mean": float(np.mean(rise_times)) + if len(rise_times) + else 0.0, + "scr_recovery_mean": float(np.mean(recovery_times)) + if len(recovery_times) + else 0.0, + } + ) except Exception: - return self._build_missing_feature_dictionary( - [ - "scr_amp_mean", - "scr_amp_max", - "scr_rise_mean", - "scr_recovery_mean", - ] + epoch_features.append( + { + "scr_amp_mean": np.nan, + "scr_amp_max": np.nan, + "scr_rise_mean": np.nan, + "scr_recovery_mean": np.nan, + } ) - return self._build_feature_dictionary_from_epochs( - epochs, - build_electrodermal_activity_feature_dictionary, - ) - - def _smooth_values_with_gaussian( - self, - values: np.ndarray, - sigma: float, - ) -> np.ndarray: - return gaussian_filter1d(values, sigma=sigma, mode="nearest") - - def 
_compute_temporal_derivative(self, values: np.ndarray) -> np.ndarray: - return np.diff(values, prepend=values[0]) + return epoch_features def _compute_rolling_variance( self, values: np.ndarray, window: int, ) -> np.ndarray: + """Computes a centered rolling variance over a feature series. + + Args: + values: One-dimensional array of feature values over epochs. + window: Rolling window size in number of epochs. + + Returns: + An array of rolling variance values. + """ out = np.zeros_like(values) half = window // 2 @@ -355,6 +468,19 @@ def _augment_epoch_features_with_temporal_context( gaussian_sigma: float = 2.0, variance_window: int = 5, ) -> List[List[float]]: + """Augments each epoch feature vector with temporal context features. + + For each base feature, this method appends a smoothed version, its + temporal derivative, and a rolling variance estimate. + + Args: + epoch_features: Base feature matrix as a list of epoch vectors. + gaussian_sigma: Gaussian smoothing parameter. + variance_window: Rolling variance window size in epochs. + + Returns: + The augmented epoch feature matrix. + """ if len(epoch_features) == 0: return [] @@ -366,8 +492,8 @@ def _augment_epoch_features_with_temporal_context( for j in range(num_features): values = feature_matrix[:, j] - smoothed = self._smooth_values_with_gaussian(values, gaussian_sigma) - deriv = self._compute_temporal_derivative(smoothed) + smoothed = gaussian_filter1d(values, sigma=gaussian_sigma, mode="nearest") + deriv = np.diff(smoothed, prepend=smoothed[0]) var = self._compute_rolling_variance(smoothed, variance_window) for i in range(num_epochs): @@ -381,25 +507,33 @@ def _extract_sensor_signals_from_dataframe( self, record_dataframe: pd.DataFrame, ) -> Dict[str, np.ndarray]: + """Extracts all sensor modalities from a DREAMT record. + + Args: + record_dataframe: Wearable recording loaded as a pandas DataFrame. + + Returns: + A dictionary mapping modality names to numeric NumPy arrays. 
+ """ return { - "accelerometer_x": self._convert_series_to_numeric_array( - record_dataframe["ACC_X"] - ), - "accelerometer_y": self._convert_series_to_numeric_array( - record_dataframe["ACC_Y"] - ), - "accelerometer_z": self._convert_series_to_numeric_array( - record_dataframe["ACC_Z"] - ), - "temperature": self._convert_series_to_numeric_array( - record_dataframe["TEMP"] - ), - "blood_volume_pulse": self._convert_series_to_numeric_array( - record_dataframe["BVP"] - ), - "electrodermal_activity": self._convert_series_to_numeric_array( - record_dataframe["EDA"] - ), + "accelerometer_x": pd.to_numeric( + record_dataframe["ACC_X"], errors="coerce" + ).fillna(0.0).to_numpy(), + "accelerometer_y": pd.to_numeric( + record_dataframe["ACC_Y"], errors="coerce" + ).fillna(0.0).to_numpy(), + "accelerometer_z": pd.to_numeric( + record_dataframe["ACC_Z"], errors="coerce" + ).fillna(0.0).to_numpy(), + "temperature": pd.to_numeric( + record_dataframe["TEMP"], errors="coerce" + ).fillna(0.0).to_numpy(), + "blood_volume_pulse": pd.to_numeric( + record_dataframe["BVP"], errors="coerce" + ).fillna(0.0).to_numpy(), + "electrodermal_activity": pd.to_numeric( + record_dataframe["EDA"], errors="coerce" + ).fillna(0.0).to_numpy(), } def _extract_feature_sets_for_all_modalities( @@ -407,6 +541,15 @@ def _extract_feature_sets_for_all_modalities( sensor_signals: Dict[str, np.ndarray], sampling_rate_hz: float, ) -> Dict[str, List[Dict[str, float]]]: + """Extracts per-epoch features for every supported sensor modality. + + Args: + sensor_signals: Dictionary of raw modality arrays. + sampling_rate_hz: Sampling rate in Hz. + + Returns: + A dictionary mapping modality names to lists of epoch features. + """ return { "accelerometer_x": self._extract_accelerometer_axis_epoch_features( sensor_signals["accelerometer_x"], @@ -444,6 +587,15 @@ def _count_complete_epochs( self, feature_sets: Dict[str, List[Dict[str, float]]], ) -> int: + """Counts how many epochs are complete across all modalities. 
+ + Args: + feature_sets: Per-modality epoch feature dictionaries. + + Returns: + The minimum number of aligned complete epochs available across all + modalities. + """ return min(len(features) for features in feature_sets.values()) def _build_epoch_feature_vector( @@ -451,6 +603,15 @@ def _build_epoch_feature_vector( feature_sets: Dict[str, List[Dict[str, float]]], epoch_index: int, ) -> List[float]: + """Builds the final feature vector for one epoch. + + Args: + feature_sets: Per-modality epoch feature dictionaries. + epoch_index: Epoch index to assemble. + + Returns: + A flat feature vector for the requested epoch. + """ accelerometer_x_features = feature_sets["accelerometer_x"][epoch_index] accelerometer_y_features = feature_sets["accelerometer_y"][epoch_index] accelerometer_z_features = feature_sets["accelerometer_z"][epoch_index] @@ -464,40 +625,38 @@ def _build_epoch_feature_vector( ] features = [] - self._append_feature_values( - features, - accelerometer_x_features, - ["trimmed_mean", "max", "iqr"], + features.extend( + accelerometer_x_features[feature_name] + for feature_name in ["trimmed_mean", "max", "iqr"] ) - self._append_feature_values( - features, - accelerometer_y_features, - ["trimmed_mean", "max", "iqr"], + features.extend( + accelerometer_y_features[feature_name] + for feature_name in ["trimmed_mean", "max", "iqr"] ) - self._append_feature_values( - features, - accelerometer_z_features, - ["trimmed_mean", "max", "iqr"], + features.extend( + accelerometer_z_features[feature_name] + for feature_name in ["trimmed_mean", "max", "iqr"] ) - self._append_feature_values( - features, - accelerometer_magnitude_deviation_features, - ["mad"], + features.extend( + accelerometer_magnitude_deviation_features[feature_name] + for feature_name in ["mad"] ) - self._append_feature_values( - features, - temperature_features, - ["mean", "min", "max", "std"], + features.extend( + temperature_features[feature_name] + for feature_name in ["mean", "min", "max", "std"] ) 
- self._append_feature_values( - features, - blood_volume_pulse_features, - ["rmssd", "sdnn", "pnn50"], + features.extend( + blood_volume_pulse_features[feature_name] + for feature_name in ["rmssd", "sdnn", "pnn50"] ) - self._append_feature_values( - features, - electrodermal_activity_features, - ["scr_amp_mean", "scr_amp_max", "scr_rise_mean", "scr_recovery_mean"], + features.extend( + electrodermal_activity_features[feature_name] + for feature_name in [ + "scr_amp_mean", + "scr_amp_max", + "scr_rise_mean", + "scr_recovery_mean", + ] ) return features @@ -505,9 +664,19 @@ def _build_record_epoch_feature_matrix( self, record_dataframe: pd.DataFrame, ) -> List[List[float]]: + """Builds the full epoch feature matrix for one wearable record. + + Args: + record_dataframe: Wearable recording loaded as a pandas DataFrame. + + Returns: + A list of epoch feature vectors. Returns an empty list if required + sensor columns are missing. + """ sampling_rate_hz = float(self.sampling_rate) - if not self._has_required_sensor_columns(record_dataframe): + required_columns = {"ACC_X", "ACC_Y", "ACC_Z", "TEMP", "BVP", "EDA"} + if not required_columns.issubset(record_dataframe.columns): return [] sensor_signals = self._extract_sensor_signals_from_dataframe(record_dataframe) @@ -530,6 +699,14 @@ def _build_record_epoch_feature_matrix( return all_epoch_features def _load_wearable_record_dataframe(self, event) -> pd.DataFrame | None: + """Loads the wearable CSV file associated with a DREAMT event. + + Args: + event: DREAMT event containing the ``file_64hz`` attribute. + + Returns: + A pandas DataFrame if the file can be loaded, otherwise ``None``. + """ file_path = getattr(event, "file_64hz", None) if file_path is None: return None @@ -545,6 +722,17 @@ def _extract_binary_label_for_epoch( epoch_index: int, samples_per_epoch: int, ) -> int | None: + """Extracts the binary sleep-wake label for one epoch. + + Args: + record_dataframe: Wearable recording loaded as a pandas DataFrame. 
+ epoch_index: Index of the epoch to label. + samples_per_epoch: Number of samples contained in one epoch. + + Returns: + ``1`` for wake, ``0`` for sleep, or ``None`` if the epoch cannot be + labeled. + """ start = epoch_index * samples_per_epoch end = start + samples_per_epoch epoch_dataframe = record_dataframe.iloc[start:end] @@ -566,6 +754,19 @@ def _build_samples_for_sleep_event( record_epoch_feature_matrix: List[List[float]], samples_per_epoch: int, ) -> List[Dict[str, object]]: + """Builds epoch-level samples for one DREAMT sleep event. + + Args: + patient: Patient containing the event. + sleep_event_index: Index of the sleep event within the patient. + record_dataframe: Wearable recording loaded as a pandas DataFrame. + record_epoch_feature_matrix: Per-epoch feature matrix for the + record. + samples_per_epoch: Number of samples contained in one epoch. + + Returns: + A list of task samples for the event. + """ samples = [] n_labeled_epochs = len(record_dataframe) // samples_per_epoch n_epochs = min(len(record_epoch_feature_matrix), n_labeled_epochs) @@ -592,6 +793,15 @@ def _build_samples_for_sleep_event( return samples def __call__(self, patient: Patient) -> List[Dict[str, object]]: + """Processes one patient into epoch-level sleep-wake samples. + + Args: + patient: DREAMT patient to process. + + Returns: + A list of samples containing patient identifier, record identifier, + epoch index, handcrafted features, and binary sleep-wake label. 
+ """ samples = [] events = patient.get_events(event_type="dreamt_sleep") if len(events) == 0: From 90982927049b95c18ba51817698d9b4071001f72 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Thu, 12 Mar 2026 03:52:08 -0600 Subject: [PATCH 14/27] feat: add SleepWakeClassification to init.py --- pyhealth/tasks/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 2f4294a19..75666bc57 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -68,3 +68,5 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task + +from .sleep_wake_classification import SleepWakeClassification \ No newline at end of file From 79efe276154f732d7d2b20d2645329860abc69b6 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Thu, 12 Mar 2026 03:53:34 -0600 Subject: [PATCH 15/27] feat: add Sleep-Wake Classification to tasks.rst --- docs/api/tasks.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 3ed1e1c97..4fc3eac23 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -88,6 +88,7 @@ Available Tasks Mortality Prediction (StageNet MIMIC-IV) Patient Linkage (MIMIC-III) Readmission Prediction + Sleep-Wake Classification Sleep Staging Sleep Staging (SleepEDF) Temple University EEG Tasks From b72ae6b8503c86f723cf941b2c7d8b461d4b6760 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Thu, 12 Mar 2026 04:04:19 -0600 Subject: [PATCH 16/27] refactor: use black+isort to autoformat task code following PEP88 --- pyhealth/tasks/sleep_wake_classification.py | 66 +++++++++++---------- 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/pyhealth/tasks/sleep_wake_classification.py b/pyhealth/tasks/sleep_wake_classification.py index fbfe3c782..92cd2c41e 100644 --- a/pyhealth/tasks/sleep_wake_classification.py +++ b/pyhealth/tasks/sleep_wake_classification.py @@ -3,7 +3,6 @@ import neurokit2 as nk import numpy as np import 
pandas as pd - from scipy.ndimage import gaussian_filter1d from scipy.signal import butter, cheby2, filtfilt from scipy.stats import trim_mean @@ -17,8 +16,8 @@ class SleepWakeClassification(BaseTask): """Binary sleep-wake classification task for DREAMT wearable recordings. This task converts each DREAMT wearable recording into fixed-length epochs, - extracts physiological features from multiple sensor modalities, - augments them with temporal context, and assigns a binary sleep/wake + extracts physiological features from multiple sensor modalities, + augments them with temporal context, and assigns a binary sleep/wake label to each epoch. """ @@ -297,8 +296,7 @@ def _extract_accelerometer_magnitude_deviation_epoch_features( epochs = self._split_signal_into_epochs(magnitude, sampling_rate_hz) return [ - {"mad": float(np.mean(np.abs(epoch - np.mean(epoch))))} - for epoch in epochs + {"mad": float(np.mean(np.abs(epoch - np.mean(epoch))))} for epoch in epochs ] def _extract_temperature_epoch_features( @@ -412,18 +410,20 @@ def _extract_electrodermal_activity_epoch_features( epoch_features.append( { - "scr_amp_mean": float(np.mean(amplitudes)) - if len(amplitudes) - else 0.0, - "scr_amp_max": float(np.max(amplitudes)) - if len(amplitudes) - else 0.0, - "scr_rise_mean": float(np.mean(rise_times)) - if len(rise_times) - else 0.0, - "scr_recovery_mean": float(np.mean(recovery_times)) - if len(recovery_times) - else 0.0, + "scr_amp_mean": ( + float(np.mean(amplitudes)) if len(amplitudes) else 0.0 + ), + "scr_amp_max": ( + float(np.max(amplitudes)) if len(amplitudes) else 0.0 + ), + "scr_rise_mean": ( + float(np.mean(rise_times)) if len(rise_times) else 0.0 + ), + "scr_recovery_mean": ( + float(np.mean(recovery_times)) + if len(recovery_times) + else 0.0 + ), } ) except Exception: @@ -516,24 +516,28 @@ def _extract_sensor_signals_from_dataframe( A dictionary mapping modality names to numeric NumPy arrays. 
""" return { - "accelerometer_x": pd.to_numeric( - record_dataframe["ACC_X"], errors="coerce" - ).fillna(0.0).to_numpy(), - "accelerometer_y": pd.to_numeric( - record_dataframe["ACC_Y"], errors="coerce" - ).fillna(0.0).to_numpy(), - "accelerometer_z": pd.to_numeric( - record_dataframe["ACC_Z"], errors="coerce" - ).fillna(0.0).to_numpy(), - "temperature": pd.to_numeric( - record_dataframe["TEMP"], errors="coerce" - ).fillna(0.0).to_numpy(), + "accelerometer_x": pd.to_numeric(record_dataframe["ACC_X"], errors="coerce") + .fillna(0.0) + .to_numpy(), + "accelerometer_y": pd.to_numeric(record_dataframe["ACC_Y"], errors="coerce") + .fillna(0.0) + .to_numpy(), + "accelerometer_z": pd.to_numeric(record_dataframe["ACC_Z"], errors="coerce") + .fillna(0.0) + .to_numpy(), + "temperature": pd.to_numeric(record_dataframe["TEMP"], errors="coerce") + .fillna(0.0) + .to_numpy(), "blood_volume_pulse": pd.to_numeric( record_dataframe["BVP"], errors="coerce" - ).fillna(0.0).to_numpy(), + ) + .fillna(0.0) + .to_numpy(), "electrodermal_activity": pd.to_numeric( record_dataframe["EDA"], errors="coerce" - ).fillna(0.0).to_numpy(), + ) + .fillna(0.0) + .to_numpy(), } def _extract_feature_sets_for_all_modalities( From df985b3ae92ae9c4961833397d634b46bd8cdbb2 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Fri, 13 Mar 2026 04:34:15 -0600 Subject: [PATCH 17/27] test: add tests covering new SleepWakeClassification task --- tests/core/test_sleep_wake_classification.py | 205 ++++++++++++++++--- 1 file changed, 173 insertions(+), 32 deletions(-) diff --git a/tests/core/test_sleep_wake_classification.py b/tests/core/test_sleep_wake_classification.py index dd80cdd6e..7fb5a61c7 100644 --- a/tests/core/test_sleep_wake_classification.py +++ b/tests/core/test_sleep_wake_classification.py @@ -1,53 +1,194 @@ -import tempfile -from pathlib import Path - +import numpy as np import pandas as pd - from pyhealth.tasks.sleep_wake_classification import SleepWakeClassification class FakeEvent: - """A fake 
event class to simulate patient events.""" - def __init__(self, file_64hz): + """Minimal DREAMT-like event for task tests.""" + + def __init__(self, file_64hz=None): self.file_64hz = file_64hz + class FakePatient: - """A fake patient class to simulate patient data and events.""" - def __init__(self, patient_id, file_64hz): + """Minimal DREAMT-like patient with configurable events.""" + + def __init__(self, patient_id: str, events=None): self.patient_id = patient_id - self._events = [FakeEvent(file_64hz)] + self._events = [] if events is None else events - """Returns the list of events for the patient.""" - def get_events(self): - return self._events + def get_events(self, event_type=None): + if event_type == "dreamt_sleep": + return self._events + return [] -def test_sleep_wake_classification_runs(): - """Test that the SleepWakeClassification task runs without errors and produces expected output format.""" - tmp = tempfile.mkdtemp() - csv_path = Path(tmp) / "S001_whole_df.csv" +def _build_valid_record(num_rows: int = 8) -> pd.DataFrame: + sleep_stages = ["W", "W", "N2", "N2", "REM", "REM", "W", "W"][:num_rows] + return pd.DataFrame( + { + "TIMESTAMP": list(range(num_rows)), + "BVP": np.linspace(0.1, 0.8, num_rows), + "EDA": np.linspace(0.01, 0.08, num_rows), + "TEMP": np.linspace(36.1, 36.5, num_rows), + "ACC_X": np.linspace(1.0, 2.0, num_rows), + "ACC_Y": np.linspace(0.5, 1.5, num_rows), + "ACC_Z": np.linspace(0.2, 1.2, num_rows), + "HR": np.linspace(60, 67, num_rows), + "Sleep_Stage": sleep_stages, + } + ) - df = pd.DataFrame( + +def _build_patient_with_single_event(patient_id: str = "S001") -> FakePatient: + return FakePatient(patient_id, events=[FakeEvent("unused.csv")]) + + +def test_convert_sleep_stage_to_binary_label(): + task = SleepWakeClassification() + + assert task._convert_sleep_stage_to_binary_label("WAKE") == 1 + assert task._convert_sleep_stage_to_binary_label("W") == 1 + assert task._convert_sleep_stage_to_binary_label("N2") == 0 + assert 
task._convert_sleep_stage_to_binary_label("REM") == 0 + assert task._convert_sleep_stage_to_binary_label(None) is None + assert task._convert_sleep_stage_to_binary_label("UNKNOWN") is None + + +def test_split_signal_into_epochs(): + task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + signal = np.array([0, 1, 2, 3, 4]) + + epochs = task._split_signal_into_epochs(signal, sampling_rate_hz=1) + + assert len(epochs) == 2 + assert np.array_equal(epochs[0], np.array([0, 1])) + assert np.array_equal(epochs[1], np.array([2, 3])) + +def test_extract_binary_label_for_epoch(): + task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + record_dataframe = pd.DataFrame({"Sleep_Stage": ["W", "W", "N2", "N2"]}) + + assert task._extract_binary_label_for_epoch(record_dataframe, 0, 2) == 1 + assert task._extract_binary_label_for_epoch(record_dataframe, 1, 2) == 0 + + +def test_build_record_epoch_feature_matrix_returns_empty_when_columns_missing(): + task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + record_dataframe = pd.DataFrame( { - "TIMESTAMP": [0, 1, 2, 3], - "BVP": [0.1, 0.2, 0.3, 0.4], - "EDA": [0.01, 0.02, 0.03, 0.04], - "TEMP": [36.1, 36.1, 36.2, 36.2], - "ACC_X": [1, 1, 2, 2], - "ACC_Y": [0, 0, 1, 1], - "ACC_Z": [0, 0, 1, 1], - "HR": [60, 60, 61, 61], - "Sleep Stage": ["Wake", "N2", "REM", "Wake"], + "ACC_X": [1, 2, 3, 4], + "ACC_Y": [1, 2, 3, 4], + "ACC_Z": [1, 2, 3, 4], + "TEMP": [36, 36, 36, 36], + "Sleep_Stage": ["W", "W", "N2", "N2"], } ) - df.to_csv(csv_path, index=False) + assert task._build_record_epoch_feature_matrix(record_dataframe) == [] + + +def test_load_wearable_record_dataframe_returns_none_for_missing_file(): + task = SleepWakeClassification() + + assert task._load_wearable_record_dataframe(FakeEvent(file_64hz="missing.csv")) is None + assert task._load_wearable_record_dataframe(FakeEvent(file_64hz=None)) is None + + +def test_task_returns_empty_when_patient_has_no_sleep_events(): + task = 
SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + patient = FakePatient("S001", events=[]) + + assert task(patient) == [] + + +def test_task_returns_empty_when_sleep_stage_column_is_missing(monkeypatch): + task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + record_dataframe = _build_valid_record().drop(columns=["Sleep_Stage"]) + patient = _build_patient_with_single_event() + + monkeypatch.setattr( + task, + "_load_wearable_record_dataframe", + lambda event: record_dataframe, + ) + assert task(patient) == [] + + +def test_task_skips_epochs_with_unsupported_labels(monkeypatch): task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) - patient = FakePatient("S001", str(csv_path)) + record_dataframe = _build_valid_record(num_rows=4) + record_dataframe["Sleep_Stage"] = ["X", "X", "N2", "N2"] + patient = _build_patient_with_single_event() + + monkeypatch.setattr( + task, + "_load_wearable_record_dataframe", + lambda event: record_dataframe, + ) + monkeypatch.setattr( + task, + "_build_record_epoch_feature_matrix", + lambda df: [[1.0, 2.0], [3.0, 4.0]], + ) + + samples = task(patient) + + assert len(samples) == 1 + assert samples[0]["epoch_index"] == 1 + assert samples[0]["label"] == 0 + + +def test_task_runs_full_flow_with_lightweight_feature_stub(monkeypatch): + task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + record_dataframe = _build_valid_record(num_rows=8) + patient = _build_patient_with_single_event() + + monkeypatch.setattr( + task, + "_load_wearable_record_dataframe", + lambda event: record_dataframe, + ) + monkeypatch.setattr( + task, + "_build_record_epoch_feature_matrix", + lambda df: [ + [1.0, 10.0], + [2.0, 20.0], + [3.0, 30.0], + [4.0, 40.0], + ], + ) + + samples = task(patient) + + assert len(samples) == 4 + assert all("features" in sample for sample in samples) + assert all("label" in sample for sample in samples) + assert all("record_id" in sample for sample in samples) + assert samples[0]["record_id"] 
== "S001-event0-epoch0" + assert samples[0]["label"] == 1 + assert samples[1]["label"] == 0 + assert samples[2]["label"] == 0 + assert samples[3]["label"] == 1 + + +def test_task_uses_minimum_epoch_count_between_labels_and_features(monkeypatch): + task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + record_dataframe = _build_valid_record(num_rows=8) + patient = _build_patient_with_single_event() + + monkeypatch.setattr( + task, + "_load_wearable_record_dataframe", + lambda _: record_dataframe, + ) + monkeypatch.setattr( + task, + "_build_record_epoch_feature_matrix", + lambda _: [[1.0], [2.0]], + ) samples = task(patient) - assert isinstance(samples, list) assert len(samples) == 2 - assert "features" in samples[0] - assert len(samples[0]["features"]) > 0 - assert samples[0]["label"] in [0, 1] \ No newline at end of file + assert [sample["epoch_index"] for sample in samples] == [0, 1] From 9357506e988798b2f9bdbf5db6e9ea93f5087617 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Fri, 13 Mar 2026 04:40:53 -0600 Subject: [PATCH 18/27] doc: add docstrings to tests --- tests/core/test_sleep_wake_classification.py | 118 +++++++++++++++++-- 1 file changed, 108 insertions(+), 10 deletions(-) diff --git a/tests/core/test_sleep_wake_classification.py b/tests/core/test_sleep_wake_classification.py index 7fb5a61c7..a98b07048 100644 --- a/tests/core/test_sleep_wake_classification.py +++ b/tests/core/test_sleep_wake_classification.py @@ -4,25 +4,52 @@ class FakeEvent: - """Minimal DREAMT-like event for task tests.""" + """Creates a minimal DREAMT-like event for task tests. + + Args: + file_64hz: Path to the wearable CSV associated with the event. + """ def __init__(self, file_64hz=None): self.file_64hz = file_64hz class FakePatient: - """Minimal DREAMT-like patient with configurable events.""" + """Creates a minimal DREAMT-like patient with configurable events. + + Args: + patient_id: Identifier used by the task in generated samples. 
+ events: Synthetic sleep events available for the patient. + """ def __init__(self, patient_id: str, events=None): self.patient_id = patient_id self._events = [] if events is None else events def get_events(self, event_type=None): + """Returns synthetic events for the requested DREAMT event type. + + Args: + event_type: Event type requested by the task. + + Returns: + The stored sleep events for ``"dreamt_sleep"``, otherwise an empty + list. + """ if event_type == "dreamt_sleep": return self._events return [] -def _build_valid_record(num_rows: int = 8) -> pd.DataFrame: + +def _build_valid_record(num_rows: int = 4) -> pd.DataFrame: + """Builds a small synthetic wearable record for task tests. + + Args: + num_rows: Number of rows to include in the record. + + Returns: + A DREAMT-like wearable DataFrame with the required sensor columns. + """ sleep_stages = ["W", "W", "N2", "N2", "REM", "REM", "W", "W"][:num_rows] return pd.DataFrame( { @@ -40,10 +67,23 @@ def _build_valid_record(num_rows: int = 8) -> pd.DataFrame: def _build_patient_with_single_event(patient_id: str = "S001") -> FakePatient: + """Builds a synthetic patient with one sleep event. + + Args: + patient_id: Identifier to assign to the synthetic patient. + + Returns: + A synthetic patient with one DREAMT-like sleep event. + """ return FakePatient(patient_id, events=[FakeEvent("unused.csv")]) def test_convert_sleep_stage_to_binary_label(): + """Tests that raw sleep stages map to the expected binary labels. + + Returns: + None. + """ task = SleepWakeClassification() assert task._convert_sleep_stage_to_binary_label("WAKE") == 1 @@ -55,6 +95,11 @@ def test_convert_sleep_stage_to_binary_label(): def test_split_signal_into_epochs(): + """Tests that only complete fixed-length epochs are returned. + + Returns: + None. 
+ """ task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) signal = np.array([0, 1, 2, 3, 4]) @@ -64,7 +109,13 @@ def test_split_signal_into_epochs(): assert np.array_equal(epochs[0], np.array([0, 1])) assert np.array_equal(epochs[1], np.array([2, 3])) + def test_extract_binary_label_for_epoch(): + """Tests epoch-level label extraction from sleep-stage mode. + + Returns: + None. + """ task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) record_dataframe = pd.DataFrame({"Sleep_Stage": ["W", "W", "N2", "N2"]}) @@ -73,6 +124,11 @@ def test_extract_binary_label_for_epoch(): def test_build_record_epoch_feature_matrix_returns_empty_when_columns_missing(): + """Tests that missing sensor columns yield an empty feature matrix. + + Returns: + None. + """ task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) record_dataframe = pd.DataFrame( { @@ -88,6 +144,11 @@ def test_build_record_epoch_feature_matrix_returns_empty_when_columns_missing(): def test_load_wearable_record_dataframe_returns_none_for_missing_file(): + """Tests that missing wearable files return ``None``. + + Returns: + None. + """ task = SleepWakeClassification() assert task._load_wearable_record_dataframe(FakeEvent(file_64hz="missing.csv")) is None @@ -95,6 +156,11 @@ def test_load_wearable_record_dataframe_returns_none_for_missing_file(): def test_task_returns_empty_when_patient_has_no_sleep_events(): + """Tests that a patient without sleep events produces no samples. + + Returns: + None. + """ task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) patient = FakePatient("S001", events=[]) @@ -102,6 +168,15 @@ def test_task_returns_empty_when_patient_has_no_sleep_events(): def test_task_returns_empty_when_sleep_stage_column_is_missing(monkeypatch): + """Tests that records missing ``Sleep_Stage`` are skipped. + + Args: + monkeypatch: Pytest fixture used to replace file loading with a + synthetic DataFrame. + + Returns: + None. 
+ """ task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) record_dataframe = _build_valid_record().drop(columns=["Sleep_Stage"]) patient = _build_patient_with_single_event() @@ -115,6 +190,15 @@ def test_task_returns_empty_when_sleep_stage_column_is_missing(monkeypatch): def test_task_skips_epochs_with_unsupported_labels(monkeypatch): + """Tests that epochs with unsupported labels are not emitted. + + Args: + monkeypatch: Pytest fixture used to replace record loading and feature + extraction with synthetic values. + + Returns: + None. + """ task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) record_dataframe = _build_valid_record(num_rows=4) record_dataframe["Sleep_Stage"] = ["X", "X", "N2", "N2"] @@ -139,8 +223,17 @@ def test_task_skips_epochs_with_unsupported_labels(monkeypatch): def test_task_runs_full_flow_with_lightweight_feature_stub(monkeypatch): + """Tests end-to-end sample generation with lightweight mocked features. + + Args: + monkeypatch: Pytest fixture used to replace record loading and avoid + expensive feature extraction. + + Returns: + None. 
+ """ task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) - record_dataframe = _build_valid_record(num_rows=8) + record_dataframe = _build_valid_record(num_rows=4) patient = _build_patient_with_single_event() monkeypatch.setattr( @@ -154,27 +247,32 @@ def test_task_runs_full_flow_with_lightweight_feature_stub(monkeypatch): lambda df: [ [1.0, 10.0], [2.0, 20.0], - [3.0, 30.0], - [4.0, 40.0], ], ) samples = task(patient) - assert len(samples) == 4 + assert len(samples) == 2 assert all("features" in sample for sample in samples) assert all("label" in sample for sample in samples) assert all("record_id" in sample for sample in samples) assert samples[0]["record_id"] == "S001-event0-epoch0" assert samples[0]["label"] == 1 assert samples[1]["label"] == 0 - assert samples[2]["label"] == 0 - assert samples[3]["label"] == 1 def test_task_uses_minimum_epoch_count_between_labels_and_features(monkeypatch): + """Tests that sample count is bounded by the shorter epoch source. + + Args: + monkeypatch: Pytest fixture used to control how many epoch features + the task returns during the test. + + Returns: + None. 
+ """ task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) - record_dataframe = _build_valid_record(num_rows=8) + record_dataframe = _build_valid_record(num_rows=4) patient = _build_patient_with_single_event() monkeypatch.setattr( From f04f6026a90f2c85fa01ec5c993c4dabcf171978 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Fri, 13 Mar 2026 04:43:44 -0600 Subject: [PATCH 19/27] refactor: use black+isort to autoformat test code following PEP88 --- tests/core/test_sleep_wake_classification.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/core/test_sleep_wake_classification.py b/tests/core/test_sleep_wake_classification.py index a98b07048..3dfc8fcb0 100644 --- a/tests/core/test_sleep_wake_classification.py +++ b/tests/core/test_sleep_wake_classification.py @@ -1,5 +1,6 @@ import numpy as np import pandas as pd + from pyhealth.tasks.sleep_wake_classification import SleepWakeClassification @@ -151,7 +152,9 @@ def test_load_wearable_record_dataframe_returns_none_for_missing_file(): """ task = SleepWakeClassification() - assert task._load_wearable_record_dataframe(FakeEvent(file_64hz="missing.csv")) is None + assert ( + task._load_wearable_record_dataframe(FakeEvent(file_64hz="missing.csv")) is None + ) assert task._load_wearable_record_dataframe(FakeEvent(file_64hz=None)) is None From a98a1b48cc867f925732de669e9ab11a2290b7da Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Fri, 13 Mar 2026 04:51:05 -0600 Subject: [PATCH 20/27] refactor: use specific Exception types instead of general Exception --- pyhealth/tasks/sleep_wake_classification.py | 26 ++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/pyhealth/tasks/sleep_wake_classification.py b/pyhealth/tasks/sleep_wake_classification.py index 92cd2c41e..be5abb3fa 100644 --- a/pyhealth/tasks/sleep_wake_classification.py +++ b/pyhealth/tasks/sleep_wake_classification.py @@ -1,3 +1,4 @@ +import logging from typing import Dict, List import neurokit2 as 
nk @@ -24,6 +25,7 @@ class SleepWakeClassification(BaseTask): task_name = "SleepWakeClassification" input_schema = {"features": "tensor"} output_schema = {"label": "binary"} + logger = logging.getLogger(__name__) def __init__(self, epoch_seconds: int = 30, sampling_rate: int = 64): """Initializes the sleep-wake classification task. @@ -364,7 +366,11 @@ def _extract_blood_volume_pulse_epoch_features( "pnn50": float(hrv["HRV_pNN50"].values[0]), } ) - except Exception: + except (KeyError, IndexError, TypeError, ValueError) as error: + self.logger.warning( + "Skipping BVP epoch due to feature extraction error: %s", + error, + ) epoch_features.append( {"rmssd": np.nan, "sdnn": np.nan, "pnn50": np.nan} ) @@ -426,7 +432,11 @@ def _extract_electrodermal_activity_epoch_features( ), } ) - except Exception: + except (KeyError, IndexError, TypeError, ValueError) as error: + self.logger.warning( + "Skipping EDA epoch due to feature extraction error: %s", + error, + ) epoch_features.append( { "scr_amp_mean": np.nan, @@ -717,7 +727,17 @@ def _load_wearable_record_dataframe(self, event) -> pd.DataFrame | None: try: return pd.read_csv(file_path) - except Exception: + except ( + FileNotFoundError, + OSError, + pd.errors.EmptyDataError, + pd.errors.ParserError, + ) as error: + self.logger.warning( + "Skipping DREAMT record '%s' due to file loading error: %s", + file_path, + error, + ) return None def _extract_binary_label_for_epoch( From a2090a2c6931d068d6a35d4656793fd171e34f7b Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Fri, 13 Mar 2026 05:46:11 -0600 Subject: [PATCH 21/27] refactor: generalize sleep-wake classification example --- examples/sleep_wake_classification.py | 127 ++++++++++++++++++-------- 1 file changed, 91 insertions(+), 36 deletions(-) diff --git a/examples/sleep_wake_classification.py b/examples/sleep_wake_classification.py index 96d1dbd3c..86c666fff 100644 --- a/examples/sleep_wake_classification.py +++ b/examples/sleep_wake_classification.py @@ -1,31 +1,73 @@ 
from collections import Counter -import numpy as np import lightgbm as lgb +import numpy as np +from sklearn.ensemble import RandomForestClassifier from sklearn.impute import SimpleImputer -from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, average_precision_score -from sklearn.model_selection import GroupShuffleSplit from sklearn.linear_model import LogisticRegression -from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import ( + accuracy_score, + average_precision_score, + f1_score, + roc_auc_score, +) from pyhealth.datasets import DREAMTDataset from pyhealth.tasks.sleep_wake_classification import SleepWakeClassification +# Configuration +DREAMT_ROOT = "REPLACE_WITH_DREAMT_ROOT" +TRAIN_PATIENT_IDS = ["S028", "S062", "S078"] +EVAL_PATIENT_IDS = ["S081", "S099"] +EPOCH_SECONDS = 30 +SAMPLING_RATE = 64 + + +def split_samples_by_patient_ids(X, y, groups): + """Splits samples into train and evaluation sets using patient IDs. + + Args: + X: Feature matrix. + y: Binary label vector. + groups: Patient identifier for each sample. + + Returns: + Train and evaluation features, labels, and patient groups. 
+ """ + train_mask = np.isin(groups, TRAIN_PATIENT_IDS) + eval_mask = np.isin(groups, EVAL_PATIENT_IDS) + + if not np.any(train_mask): + raise ValueError("No samples found for TRAIN_PATIENT_IDS.") + if not np.any(eval_mask): + raise ValueError("No samples found for EVAL_PATIENT_IDS.") + + return ( + X[train_mask], + X[eval_mask], + y[train_mask], + y[eval_mask], + groups[train_mask], + groups[eval_mask], + ) -def run_experiment(X, y, groups, name): - splitter = GroupShuffleSplit(n_splits=1, test_size=0.4, random_state=42) - train_idx, test_idx = next(splitter.split(X, y, groups=groups)) - X_train, X_test = X[train_idx], X[test_idx] - y_train, y_test = y[train_idx], y[test_idx] - g_train, g_test = groups[train_idx], groups[test_idx] +def run_experiment(X, y, groups, name): + # Split samples into train and evaluation sets + X_train, X_test, y_train, y_test, g_train, g_test = split_samples_by_patient_ids( + X, + y, + groups, + ) + # Report dataset statistics print(f"\n=== {name} ===") print("train patients:", sorted(set(g_train))) - print("test patients:", sorted(set(g_test))) + print("evaluation patients:", sorted(set(g_test))) print("train size:", len(X_train)) - print("test size:", len(X_test)) + print("evaluation size:", len(X_test)) + # Remove features that are all NaN in the training set non_all_nan_cols = ~np.isnan(X_train).all(axis=0) X_train = X_train[:, non_all_nan_cols] X_test = X_test[:, non_all_nan_cols] @@ -36,6 +78,7 @@ def run_experiment(X, y, groups, name): X_train = imputer.fit_transform(X_train) X_test = imputer.transform(X_test) + # Train a LightGBM model on the current feature subset. train_data = lgb.Dataset(X_train, label=y_train) test_data = lgb.Dataset(X_test, label=y_test, reference=train_data) @@ -63,6 +106,7 @@ def run_experiment(X, y, groups, name): y_prob = model.predict(X_test) y_pred = (y_prob >= 0.3).astype(int) + # Report standard binary classification metrics. 
print("Accuracy:", accuracy_score(y_test, y_pred)) print("F1:", f1_score(y_test, y_pred)) print("AUROC:", roc_auc_score(y_test, y_prob)) @@ -70,16 +114,16 @@ def run_experiment(X, y, groups, name): def run_model_comparison(X, y, groups): - splitter = GroupShuffleSplit(n_splits=1, test_size=0.4, random_state=42) - train_idx, test_idx = next(splitter.split(X, y, groups=groups)) - - X_train, X_test = X[train_idx], X[test_idx] - y_train, y_test = y[train_idx], y[test_idx] - g_train, g_test = groups[train_idx], groups[test_idx] + # Use the same predefined patient split to compare alternative models + X_train, X_test, y_train, y_test, g_train, g_test = split_samples_by_patient_ids( + X, + y, + groups, + ) print("\n=== Model comparison (ALL modalities + temporal) ===") print("train patients:", sorted(set(g_train))) - print("test patients:", sorted(set(g_test))) + print("evaluation patients:", sorted(set(g_test))) non_all_nan_cols = ~np.isnan(X_train).all(axis=0) X_train = X_train[:, non_all_nan_cols] @@ -89,17 +133,17 @@ def run_model_comparison(X, y, groups): X_train = imputer.fit_transform(X_train) X_test = imputer.transform(X_test) + # Compare logistic regression and random forest on the full feature set. 
models = { "LogisticRegression": LogisticRegression(max_iter=1000), "RandomForest": RandomForestClassifier( n_estimators=200, random_state=42, - n_jobs=-1 + n_jobs=-1, ), } for name, model in models.items(): - model.fit(X_train, y_train) if hasattr(model, "predict_proba"): @@ -115,46 +159,57 @@ def run_model_comparison(X, y, groups): print("AUROC:", roc_auc_score(y_test, y_prob)) print("AUPRC:", average_precision_score(y_test, y_prob)) -def main(): - root = r"C:\Users\faria\OneDrive - University of Illinois - Urbana\CS-598-DLH\dreamt-replication\data\DREAMT" - dataset = DREAMTDataset(root=root) - task = SleepWakeClassification() - - selected_patient_ids = ["S028", "S062", "S078", "S081", "S099"] +def main(): + if DREAMT_ROOT == "REPLACE_WITH_DREAMT_ROOT": + raise ValueError( + "Please set DREAMT_ROOT in examples/sleep_wake_classification.py " + "before running this example.", + ) + + dataset = DREAMTDataset(root=DREAMT_ROOT) + task = SleepWakeClassification( + epoch_seconds=EPOCH_SECONDS, + sampling_rate=SAMPLING_RATE, + ) + # Convert the selected DREAMT patients into epoch-level sleep/wake samples. all_samples = [] + selected_patient_ids = TRAIN_PATIENT_IDS + EVAL_PATIENT_IDS for patient_id in selected_patient_ids: patient = dataset.get_patient(patient_id) samples = task(patient) - print(patient_id, len(samples)) + print(f"patient {patient_id}: {len(samples)} epoch samples") all_samples.extend(samples) - print("total samples:", len(all_samples)) + print("total epoch samples:", len(all_samples)) print("label counts:", Counter(s["label"] for s in all_samples)) + # Turn the task samples into arrays for training and evaluation. X_all = np.array([s["features"] for s in all_samples], dtype=float) y = np.array([s["label"] for s in all_samples], dtype=int) groups = np.array([s["patient_id"] for s in all_samples]) print("X_all shape:", X_all.shape) - # base only = first 21 features + # Keep only the base per-epoch features without temporal augmentation. 
X_base = X_all[:, :21] - # base + temporal = all features + # Keep the full feature matrix, including temporal context features. X_temporal = X_all - acc_idx = list(range(0, 10)) # ACC_X, ACC_Y, ACC_Z, ACC_MAD - temp_idx = list(range(10, 14)) # TEMP - bvp_idx = list(range(14, 17)) # BVP - eda_idx = list(range(17, 21)) # EDA + # Group feature indices by modality for the ablation experiments. + acc_idx = list(range(0, 10)) + temp_idx = list(range(10, 14)) + bvp_idx = list(range(14, 17)) + eda_idx = list(range(17, 21)) X_acc = X_base[:, acc_idx] X_acc_temp = X_base[:, acc_idx + temp_idx] X_acc_temp_bvp = X_base[:, acc_idx + temp_idx + bvp_idx] X_all_modalities = X_base[:, acc_idx + temp_idx + bvp_idx + eda_idx] + # Run experiments using different features run_experiment(X_acc, y, groups, "ACC only") run_experiment(X_acc_temp, y, groups, "ACC + TEMP") run_experiment(X_acc_temp_bvp, y, groups, "ACC + TEMP + BVP") @@ -164,4 +219,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() From 1ba1795541987b3f8eeed7d3b44e17db47ed5c97 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Fri, 13 Mar 2026 05:57:53 -0600 Subject: [PATCH 22/27] doc: add file header to sleep_wake_classification.py --- pyhealth/tasks/sleep_wake_classification.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pyhealth/tasks/sleep_wake_classification.py b/pyhealth/tasks/sleep_wake_classification.py index be5abb3fa..90d8a473d 100644 --- a/pyhealth/tasks/sleep_wake_classification.py +++ b/pyhealth/tasks/sleep_wake_classification.py @@ -1,3 +1,13 @@ +"""Diego Farias Castro (diegof4@illinois.edu). + +Paper: Addressing Wearable Sleep Tracking Inequity: +A New Dataset and Novel Methods for a Population with Sleep Disorders +Paper link: https://proceedings.mlr.press/v248/wang24a.html + +Implements a DREAMT sleep-wake classification task using multimodal wearable +features and temporal context augmentation. 
+""" + import logging from typing import Dict, List From cbc98deefffd6c815844b521d1a0588a3c12bd5e Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Sat, 14 Mar 2026 00:55:45 -0600 Subject: [PATCH 23/27] refactor: improve formatting of results in sleep_wake_classification example --- examples/sleep_wake_classification.py | 137 +++++++++++++++++++++----- 1 file changed, 113 insertions(+), 24 deletions(-) diff --git a/examples/sleep_wake_classification.py b/examples/sleep_wake_classification.py index 86c666fff..f9bfff46c 100644 --- a/examples/sleep_wake_classification.py +++ b/examples/sleep_wake_classification.py @@ -1,7 +1,13 @@ from collections import Counter +from contextlib import redirect_stderr, redirect_stdout +import io +import logging +import warnings import lightgbm as lgb import numpy as np + +from sklearn.exceptions import ConvergenceWarning from sklearn.ensemble import RandomForestClassifier from sklearn.impute import SimpleImputer from sklearn.linear_model import LogisticRegression @@ -12,6 +18,7 @@ roc_auc_score, ) + from pyhealth.datasets import DREAMTDataset from pyhealth.tasks.sleep_wake_classification import SleepWakeClassification @@ -22,6 +29,72 @@ EPOCH_SECONDS = 30 SAMPLING_RATE = 64 +# Console formatting codes +RESET = "\033[0m" +BOLD = "\033[1m" +CYAN = "\033[36m" +GREEN = "\033[32m" +YELLOW = "\033[33m" + + +def format_section(title): + """Formats a section title for console output. + + Args: + title: Section title to format. + + Returns: + A colorized section title string. + """ + return f"\n{BOLD}{CYAN}{title}{RESET}" + + +def format_patient_ids(patient_ids): + """Formats patient IDs for readable console output. + + Args: + patient_ids: Iterable of patient identifiers. + + Returns: + A comma-separated string of patient IDs. + """ + return ", ".join(sorted(str(patient_id) for patient_id in set(patient_ids))) + + +def print_metric(name, value): + """Prints a metric with consistent console formatting. + + Args: + name: Metric name. 
+ value: Metric value. + """ + print(f" {name:<16}{value:.4f}") + + +def summarize_label_counts(labels): + """Builds a readable sleep/wake label summary. + + Args: + labels: Iterable of binary labels. + + Returns: + A formatted label count string. + """ + counts = Counter(labels) + return ( + f"sleep (0): {counts.get(0, 0)}, " + f"wake (1): {counts.get(1, 0)}" + ) + + +def configure_clean_output(): + """Suppresses noisy warnings and logs for a cleaner example run.""" + warnings.filterwarnings("ignore", category=ConvergenceWarning) + logging.getLogger("pyhealth").setLevel(logging.ERROR) + logging.getLogger("pyhealth.tasks.sleep_wake_classification").setLevel( + logging.ERROR + ) + def split_samples_by_patient_ids(X, y, groups): """Splits samples into train and evaluation sets using patient IDs. @@ -61,18 +134,18 @@ def run_experiment(X, y, groups, name): ) # Report dataset statistics - print(f"\n=== {name} ===") - print("train patients:", sorted(set(g_train))) - print("evaluation patients:", sorted(set(g_test))) - print("train size:", len(X_train)) - print("evaluation size:", len(X_test)) + print(format_section(f"Ablation: {name}")) + print(f"{BOLD}Train patients:{RESET} {format_patient_ids(g_train)}") + print(f"{BOLD}Eval patients:{RESET} {format_patient_ids(g_test)}") + print(f"{BOLD}Train samples:{RESET} {len(X_train)}") + print(f"{BOLD}Eval samples:{RESET} {len(X_test)}") # Remove features that are all NaN in the training set non_all_nan_cols = ~np.isnan(X_train).all(axis=0) X_train = X_train[:, non_all_nan_cols] X_test = X_test[:, non_all_nan_cols] - print("kept features:", X_train.shape[1]) + print(f"{BOLD}Feature count:{RESET} {X_train.shape[1]}") imputer = SimpleImputer(strategy="median") X_train = imputer.fit_transform(X_train) @@ -107,10 +180,10 @@ def run_experiment(X, y, groups, name): y_pred = (y_prob >= 0.3).astype(int) # Report standard binary classification metrics. 
- print("Accuracy:", accuracy_score(y_test, y_pred)) - print("F1:", f1_score(y_test, y_pred)) - print("AUROC:", roc_auc_score(y_test, y_prob)) - print("AUPRC:", average_precision_score(y_test, y_prob)) + print_metric("Accuracy", accuracy_score(y_test, y_pred)) + print_metric("F1", f1_score(y_test, y_pred)) + print_metric("AUROC", roc_auc_score(y_test, y_prob)) + print_metric("AUPRC", average_precision_score(y_test, y_prob)) def run_model_comparison(X, y, groups): @@ -121,9 +194,9 @@ def run_model_comparison(X, y, groups): groups, ) - print("\n=== Model comparison (ALL modalities + temporal) ===") - print("train patients:", sorted(set(g_train))) - print("evaluation patients:", sorted(set(g_test))) + print(format_section("Model Comparison: ALL modalities + temporal")) + print(f"{BOLD}Train patients:{RESET} {format_patient_ids(g_train)}") + print(f"{BOLD}Eval patients:{RESET} {format_patient_ids(g_test)}") non_all_nan_cols = ~np.isnan(X_train).all(axis=0) X_train = X_train[:, non_all_nan_cols] @@ -153,44 +226,60 @@ def run_model_comparison(X, y, groups): y_pred = (y_prob >= 0.3).astype(int) - print(f"\n{name}") - print("Accuracy:", accuracy_score(y_test, y_pred)) - print("F1:", f1_score(y_test, y_pred)) - print("AUROC:", roc_auc_score(y_test, y_prob)) - print("AUPRC:", average_precision_score(y_test, y_prob)) + print(f"\n{YELLOW}{name}{RESET}") + print_metric("Accuracy", accuracy_score(y_test, y_pred)) + print_metric("F1", f1_score(y_test, y_pred)) + print_metric("AUROC", roc_auc_score(y_test, y_prob)) + print_metric("AUPRC", average_precision_score(y_test, y_prob)) def main(): + configure_clean_output() + if DREAMT_ROOT == "REPLACE_WITH_DREAMT_ROOT": raise ValueError( "Please set DREAMT_ROOT in examples/sleep_wake_classification.py " "before running this example.", ) - dataset = DREAMTDataset(root=DREAMT_ROOT) + # Suppress verbose dataset initialization messages and print a cleaner summary. 
+ with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()): + dataset = DREAMTDataset(root=DREAMT_ROOT) task = SleepWakeClassification( epoch_seconds=EPOCH_SECONDS, sampling_rate=SAMPLING_RATE, ) + print(format_section("DREAMT Sleep-Wake Classification Example")) + print(f"{BOLD}Dataset root:{RESET} {DREAMT_ROOT}") + print( + f"{BOLD}Train patients:{RESET} {', '.join(TRAIN_PATIENT_IDS)}" + ) + print( + f"{BOLD}Eval patients:{RESET} {', '.join(EVAL_PATIENT_IDS)}" + ) + # Convert the selected DREAMT patients into epoch-level sleep/wake samples. all_samples = [] selected_patient_ids = TRAIN_PATIENT_IDS + EVAL_PATIENT_IDS for patient_id in selected_patient_ids: patient = dataset.get_patient(patient_id) samples = task(patient) - print(f"patient {patient_id}: {len(samples)} epoch samples") + print(f" patient {patient_id:<4} -> {len(samples)} epoch samples") all_samples.extend(samples) - print("total epoch samples:", len(all_samples)) - print("label counts:", Counter(s["label"] for s in all_samples)) + print(f"{BOLD}Total epoch samples:{RESET} {len(all_samples)}") + print( + f"{BOLD}Label counts:{RESET} " + f"{summarize_label_counts(sample['label'] for sample in all_samples)}" + ) # Turn the task samples into arrays for training and evaluation. X_all = np.array([s["features"] for s in all_samples], dtype=float) y = np.array([s["label"] for s in all_samples], dtype=int) groups = np.array([s["patient_id"] for s in all_samples]) - print("X_all shape:", X_all.shape) + print(f"{BOLD}Feature matrix:{RESET} {X_all.shape[0]} samples x {X_all.shape[1]} features") # Keep only the base per-epoch features without temporal augmentation. X_base = X_all[:, :21] @@ -209,7 +298,7 @@ def main(): X_acc_temp_bvp = X_base[:, acc_idx + temp_idx + bvp_idx] X_all_modalities = X_base[:, acc_idx + temp_idx + bvp_idx + eda_idx] - # Run experiments using different features + # Run experiments using different feature groups. 
run_experiment(X_acc, y, groups, "ACC only")
 run_experiment(X_acc_temp, y, groups, "ACC + TEMP")
 run_experiment(X_acc_temp_bvp, y, groups, "ACC + TEMP + BVP")

From e8c70244ae28d9b7eef038a41c87a5edfc6a91af Mon Sep 17 00:00:00 2001
From: diegofariasc
Date: Sat, 14 Mar 2026 00:57:39 -0600
Subject: [PATCH 24/27] refactor: use black+isort on example study

---
 examples/sleep_wake_classification.py | 25 +++++++++----------------
 1 file changed, 9 insertions(+), 16 deletions(-)

diff --git a/examples/sleep_wake_classification.py b/examples/sleep_wake_classification.py
index f9bfff46c..e47fa282b 100644
--- a/examples/sleep_wake_classification.py
+++ b/examples/sleep_wake_classification.py
@@ -1,14 +1,13 @@
-from collections import Counter
-from contextlib import redirect_stderr, redirect_stdout
 import io
 import logging
 import warnings
+from collections import Counter
+from contextlib import redirect_stderr, redirect_stdout
 
 import lightgbm as lgb
 import numpy as np
-
-from sklearn.exceptions import ConvergenceWarning
 from sklearn.ensemble import RandomForestClassifier
+from sklearn.exceptions import ConvergenceWarning
 from sklearn.impute import SimpleImputer
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import (
@@ -18,7 +17,6 @@
     roc_auc_score,
 )
 
-
 from pyhealth.datasets import DREAMTDataset
 from pyhealth.tasks.sleep_wake_classification import SleepWakeClassification
 
@@ -81,10 +79,7 @@ def summarize_label_counts(labels):
         A formatted label count string. 
""" counts = Counter(labels) - return ( - f"sleep (0): {counts.get(0, 0)}, " - f"wake (1): {counts.get(1, 0)}" - ) + return f"sleep (0): {counts.get(0, 0)}, " f"wake (1): {counts.get(1, 0)}" def configure_clean_output(): @@ -252,12 +247,8 @@ def main(): print(format_section("DREAMT Sleep-Wake Classification Example")) print(f"{BOLD}Dataset root:{RESET} {DREAMT_ROOT}") - print( - f"{BOLD}Train patients:{RESET} {', '.join(TRAIN_PATIENT_IDS)}" - ) - print( - f"{BOLD}Eval patients:{RESET} {', '.join(EVAL_PATIENT_IDS)}" - ) + print(f"{BOLD}Train patients:{RESET} {', '.join(TRAIN_PATIENT_IDS)}") + print(f"{BOLD}Eval patients:{RESET} {', '.join(EVAL_PATIENT_IDS)}") # Convert the selected DREAMT patients into epoch-level sleep/wake samples. all_samples = [] @@ -279,7 +270,9 @@ def main(): y = np.array([s["label"] for s in all_samples], dtype=int) groups = np.array([s["patient_id"] for s in all_samples]) - print(f"{BOLD}Feature matrix:{RESET} {X_all.shape[0]} samples x {X_all.shape[1]} features") + print( + f"{BOLD}Feature matrix:{RESET} {X_all.shape[0]} samples x {X_all.shape[1]} features" + ) # Keep only the base per-epoch features without temporal augmentation. 
X_base = X_all[:, :21] From 11447e4b17c64751435caff388aac1bfa187814f Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Sat, 14 Mar 2026 01:04:59 -0600 Subject: [PATCH 25/27] refactor: rename sleep_wake_classification example to dreamt_sleep_wake_classification_lightgbm.py --- ...sification.py => dreamt_sleep_wake_classification_lightgbm.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/{sleep_wake_classification.py => dreamt_sleep_wake_classification_lightgbm.py} (100%) diff --git a/examples/sleep_wake_classification.py b/examples/dreamt_sleep_wake_classification_lightgbm.py similarity index 100% rename from examples/sleep_wake_classification.py rename to examples/dreamt_sleep_wake_classification_lightgbm.py From 489d4435fcbce72e04645f653f6c72a5898e1941 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Sat, 14 Mar 2026 01:20:01 -0600 Subject: [PATCH 26/27] refactor: improve typing in sleep_wake_classification task and example --- ...eamt_sleep_wake_classification_lightgbm.py | 49 +++++++++++++++---- pyhealth/tasks/sleep_wake_classification.py | 36 ++++++++++---- 2 files changed, 66 insertions(+), 19 deletions(-) diff --git a/examples/dreamt_sleep_wake_classification_lightgbm.py b/examples/dreamt_sleep_wake_classification_lightgbm.py index e47fa282b..134ea2f62 100644 --- a/examples/dreamt_sleep_wake_classification_lightgbm.py +++ b/examples/dreamt_sleep_wake_classification_lightgbm.py @@ -3,6 +3,7 @@ import warnings from collections import Counter from contextlib import redirect_stderr, redirect_stdout +from typing import Iterable import lightgbm as lgb import numpy as np @@ -35,7 +36,7 @@ YELLOW = "\033[33m" -def format_section(title): +def format_section(title: str) -> str: """Formats a section title for console output. 
Args: @@ -47,7 +48,7 @@ def format_section(title): return f"\n{BOLD}{CYAN}{title}{RESET}" -def format_patient_ids(patient_ids): +def format_patient_ids(patient_ids: Iterable[str]) -> str: """Formats patient IDs for readable console output. Args: @@ -59,7 +60,7 @@ def format_patient_ids(patient_ids): return ", ".join(sorted(str(patient_id) for patient_id in set(patient_ids))) -def print_metric(name, value): +def print_metric(name: str, value: float) -> None: """Prints a metric with consistent console formatting. Args: @@ -82,7 +83,7 @@ def summarize_label_counts(labels): return f"sleep (0): {counts.get(0, 0)}, " f"wake (1): {counts.get(1, 0)}" -def configure_clean_output(): +def configure_clean_output() -> None: """Suppresses noisy warnings and logs for a cleaner example run.""" warnings.filterwarnings("ignore", category=ConvergenceWarning) logging.getLogger("pyhealth").setLevel(logging.ERROR) @@ -91,7 +92,11 @@ def configure_clean_output(): ) -def split_samples_by_patient_ids(X, y, groups): +def split_samples_by_patient_ids( + X: np.ndarray, + y: np.ndarray, + groups: np.ndarray, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Splits samples into train and evaluation sets using patient IDs. Args: @@ -120,7 +125,20 @@ def split_samples_by_patient_ids(X, y, groups): ) -def run_experiment(X, y, groups, name): +def run_experiment( + X: np.ndarray, + y: np.ndarray, + groups: np.ndarray, + name: str, +) -> None: + """Runs one feature ablation experiment and prints evaluation metrics. + + Args: + X: Feature matrix for the selected experiment. + y: Binary label vector. + groups: Patient identifier for each sample. + name: Name of the ablation setting. 
+ """ # Split samples into train and evaluation sets X_train, X_test, y_train, y_test, g_train, g_test = split_samples_by_patient_ids( X, @@ -181,7 +199,18 @@ def run_experiment(X, y, groups, name): print_metric("AUPRC", average_precision_score(y_test, y_prob)) -def run_model_comparison(X, y, groups): +def run_model_comparison( + X: np.ndarray, + y: np.ndarray, + groups: np.ndarray, +) -> None: + """Runs a small model comparison on the full temporal feature set. + + Args: + X: Full feature matrix. + y: Binary label vector. + groups: Patient identifier for each sample. + """ # Use the same predefined patient split to compare alternative models X_train, X_test, y_train, y_test, g_train, g_test = split_samples_by_patient_ids( X, @@ -228,12 +257,14 @@ def run_model_comparison(X, y, groups): print_metric("AUPRC", average_precision_score(y_test, y_prob)) -def main(): +def main() -> None: + """Runs the DREAMT sleep-wake classification example workflow.""" configure_clean_output() if DREAMT_ROOT == "REPLACE_WITH_DREAMT_ROOT": raise ValueError( - "Please set DREAMT_ROOT in examples/sleep_wake_classification.py " + "Please set DREAMT_ROOT in " + "examples/dreamt_sleep_wake_classification_lightgbm.py " "before running this example.", ) diff --git a/pyhealth/tasks/sleep_wake_classification.py b/pyhealth/tasks/sleep_wake_classification.py index 90d8a473d..843cfcf3d 100644 --- a/pyhealth/tasks/sleep_wake_classification.py +++ b/pyhealth/tasks/sleep_wake_classification.py @@ -1,6 +1,6 @@ """Diego Farias Castro (diegof4@illinois.edu). 
-Paper: Addressing Wearable Sleep Tracking Inequity: +Paper: Addressing Wearable Sleep Tracking Inequity: A New Dataset and Novel Methods for a Population with Sleep Disorders Paper link: https://proceedings.mlr.press/v248/wang24a.html @@ -9,7 +9,8 @@ """ import logging -from typing import Dict, List +from os import PathLike +from typing import Any, Dict, List, Tuple import neurokit2 as nk import numpy as np @@ -48,7 +49,10 @@ def __init__(self, epoch_seconds: int = 30, sampling_rate: int = 64): self.sampling_rate = sampling_rate super().__init__() - def _convert_sleep_stage_to_binary_label(self, label): + def _convert_sleep_stage_to_binary_label( + self, + label: Any, + ) -> int | None: """Maps a sleep stage label to the binary sleep-wake target. Args: @@ -103,7 +107,7 @@ def _design_bandpass_filter_coefficients( sampling_rate_hz: float, order: int, stopband_attenuation_db: float = 40.0, - ): + ) -> Tuple[np.ndarray, np.ndarray]: """Designs band-pass filter coefficients for a supported family. Args: @@ -137,13 +141,18 @@ def _design_bandpass_filter_coefficients( raise ValueError(f"Unsupported bandpass filter family: {filter_family}") - def _apply_zero_phase_filter(self, signal: np.ndarray, b, a) -> np.ndarray: + def _apply_zero_phase_filter( + self, + signal: np.ndarray, + numerator_coefficients: np.ndarray, + denominator_coefficients: np.ndarray, + ) -> np.ndarray: """Applies zero-phase filtering to a one-dimensional signal. Args: signal: One-dimensional signal array. - b: Numerator filter coefficients. - a: Denominator filter coefficients. + numerator_coefficients: Numerator filter coefficients. + denominator_coefficients: Denominator filter coefficients. Returns: The filtered signal. 
@@ -153,7 +162,11 @@ def _apply_zero_phase_filter(self, signal: np.ndarray, b, a) -> np.ndarray: """ if signal.ndim != 1: raise ValueError("Signal must be 1D.") - return filtfilt(b, a, signal) + return filtfilt( + numerator_coefficients, + denominator_coefficients, + signal, + ) def _filter_signal_with_lowpass( self, @@ -722,7 +735,10 @@ def _build_record_epoch_feature_matrix( return all_epoch_features - def _load_wearable_record_dataframe(self, event) -> pd.DataFrame | None: + def _load_wearable_record_dataframe( + self, + event: Any, + ) -> pd.DataFrame | None: """Loads the wearable CSV file associated with a DREAMT event. Args: @@ -731,7 +747,7 @@ def _load_wearable_record_dataframe(self, event) -> pd.DataFrame | None: Returns: A pandas DataFrame if the file can be loaded, otherwise ``None``. """ - file_path = getattr(event, "file_64hz", None) + file_path: str | PathLike[str] | None = getattr(event, "file_64hz", None) if file_path is None: return None From 80798a0336bea6c2a3e2e1a14aa475206a924a35 Mon Sep 17 00:00:00 2001 From: diegofariasc Date: Sat, 14 Mar 2026 01:56:06 -0600 Subject: [PATCH 27/27] refactor: add support for synthetic data in example --- ...eamt_sleep_wake_classification_lightgbm.py | 117 ++++++++++++------ 1 file changed, 79 insertions(+), 38 deletions(-) diff --git a/examples/dreamt_sleep_wake_classification_lightgbm.py b/examples/dreamt_sleep_wake_classification_lightgbm.py index 134ea2f62..c376e2e5b 100644 --- a/examples/dreamt_sleep_wake_classification_lightgbm.py +++ b/examples/dreamt_sleep_wake_classification_lightgbm.py @@ -36,6 +36,32 @@ YELLOW = "\033[33m" +def build_synthetic_benchmark_data() -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Builds synthetic sleep-wake samples for a runnable ablation example. + + Returns: + Synthetic feature matrix, binary labels, and patient IDs. 
+ """ + rng = np.random.default_rng(42) + patient_ids = TRAIN_PATIENT_IDS + EVAL_PATIENT_IDS + samples_per_patient = 24 + num_base_features = 21 + num_temporal_features = num_base_features * 3 + num_features = num_base_features + num_temporal_features + + groups = np.repeat(patient_ids, samples_per_patient) + y = rng.binomial(1, 0.35, size=len(groups)) + + X = rng.normal(0.0, 1.0, size=(len(groups), num_features)) + X[y == 1, :10] += 0.9 + X[y == 1, 10:14] += 0.4 + X[y == 1, 14:17] += 0.3 + X[y == 1, 17:21] += 0.2 + X[y == 1, 21:] += 0.25 + + return X.astype(float), y.astype(int), groups.astype(str) + + def format_section(title: str) -> str: """Formats a section title for console output. @@ -262,48 +288,63 @@ def main() -> None: configure_clean_output() if DREAMT_ROOT == "REPLACE_WITH_DREAMT_ROOT": - raise ValueError( - "Please set DREAMT_ROOT in " - "examples/dreamt_sleep_wake_classification_lightgbm.py " - "before running this example.", + print(format_section("DREAMT Sleep-Wake Classification Example")) + print("DREAMT_ROOT not set. Running the ablation workflow on synthetic data...") + print( + f"{YELLOW}Warning:{RESET} synthetic samples are randomly generated to " + "make the example runnable without DREAMT. The resulting metrics are " + "not realistic and should not be interpreted as evidence for the " + "task or paper claims\n." + ) + print(f"{BOLD}Train patients:{RESET} {', '.join(TRAIN_PATIENT_IDS)}") + print(f"{BOLD}Eval patients:{RESET} {', '.join(EVAL_PATIENT_IDS)}") + + X_all, y, groups = build_synthetic_benchmark_data() + print(f"{BOLD}Total epoch samples:{RESET} {len(X_all)}") + print(f"{BOLD}Label counts:{RESET} {summarize_label_counts(y)}") + print( + f"{BOLD}Feature matrix:{RESET} " + f"{X_all.shape[0]} samples x {X_all.shape[1]} features" + ) + else: + # Suppress verbose dataset initialization messages and print a cleaner summary. 
+ with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()): + dataset = DREAMTDataset(root=DREAMT_ROOT) + task = SleepWakeClassification( + epoch_seconds=EPOCH_SECONDS, + sampling_rate=SAMPLING_RATE, ) - # Suppress verbose dataset initialization messages and print a cleaner summary. - with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()): - dataset = DREAMTDataset(root=DREAMT_ROOT) - task = SleepWakeClassification( - epoch_seconds=EPOCH_SECONDS, - sampling_rate=SAMPLING_RATE, - ) - - print(format_section("DREAMT Sleep-Wake Classification Example")) - print(f"{BOLD}Dataset root:{RESET} {DREAMT_ROOT}") - print(f"{BOLD}Train patients:{RESET} {', '.join(TRAIN_PATIENT_IDS)}") - print(f"{BOLD}Eval patients:{RESET} {', '.join(EVAL_PATIENT_IDS)}") - - # Convert the selected DREAMT patients into epoch-level sleep/wake samples. - all_samples = [] - selected_patient_ids = TRAIN_PATIENT_IDS + EVAL_PATIENT_IDS - for patient_id in selected_patient_ids: - patient = dataset.get_patient(patient_id) - samples = task(patient) - print(f" patient {patient_id:<4} -> {len(samples)} epoch samples") - all_samples.extend(samples) - - print(f"{BOLD}Total epoch samples:{RESET} {len(all_samples)}") - print( - f"{BOLD}Label counts:{RESET} " - f"{summarize_label_counts(sample['label'] for sample in all_samples)}" - ) + print(format_section("DREAMT Sleep-Wake Classification Example")) + print(f"{BOLD}Dataset root:{RESET} {DREAMT_ROOT}") + print(f"{BOLD}Train patients:{RESET} {', '.join(TRAIN_PATIENT_IDS)}") + print(f"{BOLD}Eval patients:{RESET} {', '.join(EVAL_PATIENT_IDS)}") + + # Convert the selected DREAMT patients into epoch-level sleep/wake samples. 
+ all_samples = [] + selected_patient_ids = TRAIN_PATIENT_IDS + EVAL_PATIENT_IDS + for patient_id in selected_patient_ids: + patient = dataset.get_patient(patient_id) + samples = task(patient) + print(f" patient {patient_id:<4} -> {len(samples)} epoch samples") + all_samples.extend(samples) + + print(f"{BOLD}Total epoch samples:{RESET} {len(all_samples)}") + print( + f"{BOLD}Label counts:{RESET} " + f"{summarize_label_counts(sample['label'] for sample in all_samples)}" + ) - # Turn the task samples into arrays for training and evaluation. - X_all = np.array([s["features"] for s in all_samples], dtype=float) - y = np.array([s["label"] for s in all_samples], dtype=int) - groups = np.array([s["patient_id"] for s in all_samples]) + # Turn the task samples into arrays for training and evaluation. + X_all = np.array([s["features"] for s in all_samples], dtype=float) + y = np.array([s["label"] for s in all_samples], dtype=int) + groups = np.array([s["patient_id"] for s in all_samples]) - print( - f"{BOLD}Feature matrix:{RESET} {X_all.shape[0]} samples x {X_all.shape[1]} features" - ) + if DREAMT_ROOT != "REPLACE_WITH_DREAMT_ROOT": + print( + f"{BOLD}Feature matrix:{RESET} " + f"{X_all.shape[0]} samples x {X_all.shape[1]} features" + ) # Keep only the base per-epoch features without temporal augmentation. X_base = X_all[:, :21]