diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index d85d04bc3..6da8560e7 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -220,6 +220,7 @@ Available Tasks Mortality Prediction (StageNet MIMIC-IV) Patient Linkage (MIMIC-III) Readmission Prediction + Sleep-Wake Classification Sleep Staging Sleep Staging (SleepEDF) Temple University EEG Tasks diff --git a/docs/api/tasks/pyhealth.tasks.sleep_wake_classification.rst b/docs/api/tasks/pyhealth.tasks.sleep_wake_classification.rst new file mode 100644 index 000000000..ee613f09f --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.sleep_wake_classification.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.sleep_wake_classification +======================================== + +.. autoclass:: pyhealth.tasks.sleep_wake_classification.SleepWakeClassification + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/dreamt_sleep_wake_classification_lightgbm.py b/examples/dreamt_sleep_wake_classification_lightgbm.py new file mode 100644 index 000000000..c376e2e5b --- /dev/null +++ b/examples/dreamt_sleep_wake_classification_lightgbm.py @@ -0,0 +1,376 @@ +import io +import logging +import warnings +from collections import Counter +from contextlib import redirect_stderr, redirect_stdout +from typing import Iterable + +import lightgbm as lgb +import numpy as np +from sklearn.ensemble import RandomForestClassifier +from sklearn.exceptions import ConvergenceWarning +from sklearn.impute import SimpleImputer +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import ( + accuracy_score, + average_precision_score, + f1_score, + roc_auc_score, +) + +from pyhealth.datasets import DREAMTDataset +from pyhealth.tasks.sleep_wake_classification import SleepWakeClassification + +# Configuration +DREAMT_ROOT = "REPLACE_WITH_DREAMT_ROOT" +TRAIN_PATIENT_IDS = ["S028", "S062", "S078"] +EVAL_PATIENT_IDS = ["S081", "S099"] +EPOCH_SECONDS = 30 +SAMPLING_RATE = 64 + +# Console formatting codes +RESET = "\033[0m" +BOLD 
BOLD = "\033[1m"
CYAN = "\033[36m"
GREEN = "\033[32m"
YELLOW = "\033[33m"


def build_synthetic_benchmark_data() -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Builds synthetic sleep-wake samples for a runnable ablation example.

    Returns:
        Synthetic feature matrix, binary labels, and patient IDs.
    """
    rng = np.random.default_rng(42)
    patient_ids = TRAIN_PATIENT_IDS + EVAL_PATIENT_IDS
    samples_per_patient = 24
    num_base_features = 21
    # Each base feature gains 3 temporal-context companions (smoothed,
    # derivative, rolling variance), mirroring the real task's augmentation.
    num_temporal_features = num_base_features * 3
    num_features = num_base_features + num_temporal_features

    groups = np.repeat(patient_ids, samples_per_patient)
    y = rng.binomial(1, 0.35, size=len(groups))

    # Shift wake-epoch features upward so the classifiers have signal to find.
    X = rng.normal(0.0, 1.0, size=(len(groups), num_features))
    X[y == 1, :10] += 0.9
    X[y == 1, 10:14] += 0.4
    X[y == 1, 14:17] += 0.3
    X[y == 1, 17:21] += 0.2
    X[y == 1, 21:] += 0.25

    return X.astype(float), y.astype(int), groups.astype(str)


def format_section(title: str) -> str:
    """Formats a section title for console output.

    Args:
        title: Section title to format.

    Returns:
        A colorized section title string.
    """
    return f"\n{BOLD}{CYAN}{title}{RESET}"


def format_patient_ids(patient_ids: Iterable[str]) -> str:
    """Formats patient IDs for readable console output.

    Args:
        patient_ids: Iterable of patient identifiers.

    Returns:
        A comma-separated string of unique, sorted patient IDs.
    """
    return ", ".join(sorted(str(patient_id) for patient_id in set(patient_ids)))


def print_metric(name: str, value: float) -> None:
    """Prints a metric with consistent console formatting.

    Args:
        name: Metric name.
        value: Metric value.
    """
    print(f"  {name:<16}{value:.4f}")


def summarize_label_counts(labels):
    """Builds a readable sleep/wake label summary.

    Args:
        labels: Iterable of binary labels.

    Returns:
        A formatted label count string.
    """
    counts = Counter(labels)
    return f"sleep (0): {counts.get(0, 0)}, " f"wake (1): {counts.get(1, 0)}"


def configure_clean_output() -> None:
    """Suppresses noisy warnings and logs for a cleaner example run."""
    warnings.filterwarnings("ignore", category=ConvergenceWarning)
    logging.getLogger("pyhealth").setLevel(logging.ERROR)
    logging.getLogger("pyhealth.tasks.sleep_wake_classification").setLevel(
        logging.ERROR
    )


def split_samples_by_patient_ids(
    X: np.ndarray,
    y: np.ndarray,
    groups: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Splits samples into train and evaluation sets using patient IDs.

    Splitting by patient (not by row) prevents epochs from one recording
    leaking across the train/eval boundary.

    Args:
        X: Feature matrix.
        y: Binary label vector.
        groups: Patient identifier for each sample.

    Returns:
        Train and evaluation features, labels, and patient groups.

    Raises:
        ValueError: If either patient list matches no samples.
    """
    train_mask = np.isin(groups, TRAIN_PATIENT_IDS)
    eval_mask = np.isin(groups, EVAL_PATIENT_IDS)

    if not np.any(train_mask):
        raise ValueError("No samples found for TRAIN_PATIENT_IDS.")
    if not np.any(eval_mask):
        raise ValueError("No samples found for EVAL_PATIENT_IDS.")

    return (
        X[train_mask],
        X[eval_mask],
        y[train_mask],
        y[eval_mask],
        groups[train_mask],
        groups[eval_mask],
    )


def run_experiment(
    X: np.ndarray,
    y: np.ndarray,
    groups: np.ndarray,
    name: str,
) -> None:
    """Runs one feature ablation experiment and prints evaluation metrics.

    Args:
        X: Feature matrix for the selected experiment.
        y: Binary label vector.
        groups: Patient identifier for each sample.
        name: Name of the ablation setting.
    """
    # Split samples into train and evaluation sets
    X_train, X_test, y_train, y_test, g_train, g_test = split_samples_by_patient_ids(
        X,
        y,
        groups,
    )

    # Report dataset statistics
    print(format_section(f"Ablation: {name}"))
    print(f"{BOLD}Train patients:{RESET} {format_patient_ids(g_train)}")
    print(f"{BOLD}Eval patients:{RESET} {format_patient_ids(g_test)}")
    print(f"{BOLD}Train samples:{RESET} {len(X_train)}")
    print(f"{BOLD}Eval samples:{RESET} {len(X_test)}")

    # Remove features that are all NaN in the training set
    non_all_nan_cols = ~np.isnan(X_train).all(axis=0)
    X_train = X_train[:, non_all_nan_cols]
    X_test = X_test[:, non_all_nan_cols]

    print(f"{BOLD}Feature count:{RESET} {X_train.shape[1]}")

    # Median imputation is fit on train only to avoid information leakage.
    imputer = SimpleImputer(strategy="median")
    X_train = imputer.fit_transform(X_train)
    X_test = imputer.transform(X_test)

    # Train a LightGBM model on the current feature subset.
    train_data = lgb.Dataset(X_train, label=y_train)
    test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)

    params = {
        "objective": "binary",
        "metric": "binary_logloss",
        "boosting_type": "gbdt",
        "learning_rate": 0.05,
        "num_leaves": 31,
        "feature_fraction": 0.9,
        "bagging_fraction": 0.9,
        "bagging_freq": 5,
        "verbose": -1,
        "seed": 42,
    }

    model = lgb.train(
        params,
        train_data,
        num_boost_round=200,
        valid_sets=[test_data],
        callbacks=[lgb.early_stopping(stopping_rounds=20, verbose=False)],
    )

    y_prob = model.predict(X_test)
    # NOTE(review): 0.3 decision threshold favors wake recall on the
    # sleep-heavy class balance — confirm against the paper's setup.
    y_pred = (y_prob >= 0.3).astype(int)

    # Report standard binary classification metrics.
    print_metric("Accuracy", accuracy_score(y_test, y_pred))
    print_metric("F1", f1_score(y_test, y_pred))
    print_metric("AUROC", roc_auc_score(y_test, y_prob))
    print_metric("AUPRC", average_precision_score(y_test, y_prob))


def run_model_comparison(
    X: np.ndarray,
    y: np.ndarray,
    groups: np.ndarray,
) -> None:
    """Runs a small model comparison on the full temporal feature set.

    Args:
        X: Full feature matrix.
        y: Binary label vector.
        groups: Patient identifier for each sample.
    """
    # Use the same predefined patient split to compare alternative models
    X_train, X_test, y_train, y_test, g_train, g_test = split_samples_by_patient_ids(
        X,
        y,
        groups,
    )

    print(format_section("Model Comparison: ALL modalities + temporal"))
    print(f"{BOLD}Train patients:{RESET} {format_patient_ids(g_train)}")
    print(f"{BOLD}Eval patients:{RESET} {format_patient_ids(g_test)}")

    non_all_nan_cols = ~np.isnan(X_train).all(axis=0)
    X_train = X_train[:, non_all_nan_cols]
    X_test = X_test[:, non_all_nan_cols]

    imputer = SimpleImputer(strategy="median")
    X_train = imputer.fit_transform(X_train)
    X_test = imputer.transform(X_test)

    # Compare logistic regression and random forest on the full feature set.
    models = {
        "LogisticRegression": LogisticRegression(max_iter=1000),
        "RandomForest": RandomForestClassifier(
            n_estimators=200,
            random_state=42,
            n_jobs=-1,
        ),
    }

    for name, model in models.items():
        model.fit(X_train, y_train)

        if hasattr(model, "predict_proba"):
            y_prob = model.predict_proba(X_test)[:, 1]
        else:
            y_prob = model.decision_function(X_test)

        y_pred = (y_prob >= 0.3).astype(int)

        print(f"\n{YELLOW}{name}{RESET}")
        print_metric("Accuracy", accuracy_score(y_test, y_pred))
        print_metric("F1", f1_score(y_test, y_pred))
        print_metric("AUROC", roc_auc_score(y_test, y_prob))
        print_metric("AUPRC", average_precision_score(y_test, y_prob))


def main() -> None:
    """Runs the DREAMT sleep-wake classification example workflow."""
    configure_clean_output()

    if DREAMT_ROOT == "REPLACE_WITH_DREAMT_ROOT":
        print(format_section("DREAMT Sleep-Wake Classification Example"))
        print("DREAMT_ROOT not set. Running the ablation workflow on synthetic data...")
        # Fix: trailing newline was misplaced before the final period
        # ("claims\n."), which printed the period on its own line.
        print(
            f"{YELLOW}Warning:{RESET} synthetic samples are randomly generated to "
            "make the example runnable without DREAMT. The resulting metrics are "
            "not realistic and should not be interpreted as evidence for the "
            "task or paper claims.\n"
        )
        print(f"{BOLD}Train patients:{RESET} {', '.join(TRAIN_PATIENT_IDS)}")
        print(f"{BOLD}Eval patients:{RESET} {', '.join(EVAL_PATIENT_IDS)}")

        X_all, y, groups = build_synthetic_benchmark_data()
        print(f"{BOLD}Total epoch samples:{RESET} {len(X_all)}")
        print(f"{BOLD}Label counts:{RESET} {summarize_label_counts(y)}")
        print(
            f"{BOLD}Feature matrix:{RESET} "
            f"{X_all.shape[0]} samples x {X_all.shape[1]} features"
        )
    else:
        # Suppress verbose dataset initialization messages and print a cleaner summary.
        with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
            dataset = DREAMTDataset(root=DREAMT_ROOT)
            task = SleepWakeClassification(
                epoch_seconds=EPOCH_SECONDS,
                sampling_rate=SAMPLING_RATE,
            )

        print(format_section("DREAMT Sleep-Wake Classification Example"))
        print(f"{BOLD}Dataset root:{RESET} {DREAMT_ROOT}")
        print(f"{BOLD}Train patients:{RESET} {', '.join(TRAIN_PATIENT_IDS)}")
        print(f"{BOLD}Eval patients:{RESET} {', '.join(EVAL_PATIENT_IDS)}")

        # Convert the selected DREAMT patients into epoch-level sleep/wake samples.
        all_samples = []
        selected_patient_ids = TRAIN_PATIENT_IDS + EVAL_PATIENT_IDS
        for patient_id in selected_patient_ids:
            patient = dataset.get_patient(patient_id)
            samples = task(patient)
            print(f"  patient {patient_id:<4} -> {len(samples)} epoch samples")
            all_samples.extend(samples)

        print(f"{BOLD}Total epoch samples:{RESET} {len(all_samples)}")
        print(
            f"{BOLD}Label counts:{RESET} "
            f"{summarize_label_counts(sample['label'] for sample in all_samples)}"
        )

        # Turn the task samples into arrays for training and evaluation.
        X_all = np.array([s["features"] for s in all_samples], dtype=float)
        y = np.array([s["label"] for s in all_samples], dtype=int)
        groups = np.array([s["patient_id"] for s in all_samples])

    if DREAMT_ROOT != "REPLACE_WITH_DREAMT_ROOT":
        print(
            f"{BOLD}Feature matrix:{RESET} "
            f"{X_all.shape[0]} samples x {X_all.shape[1]} features"
        )

    # Keep only the base per-epoch features without temporal augmentation.
    X_base = X_all[:, :21]

    # Keep the full feature matrix, including temporal context features.
    X_temporal = X_all

    # Group feature indices by modality for the ablation experiments.
    acc_idx = list(range(0, 10))
    temp_idx = list(range(10, 14))
    bvp_idx = list(range(14, 17))
    eda_idx = list(range(17, 21))

    X_acc = X_base[:, acc_idx]
    X_acc_temp = X_base[:, acc_idx + temp_idx]
    X_acc_temp_bvp = X_base[:, acc_idx + temp_idx + bvp_idx]
    X_all_modalities = X_base[:, acc_idx + temp_idx + bvp_idx + eda_idx]

    # Run experiments using different feature groups.
    run_experiment(X_acc, y, groups, "ACC only")
    run_experiment(X_acc_temp, y, groups, "ACC + TEMP")
    run_experiment(X_acc_temp_bvp, y, groups, "ACC + TEMP + BVP")
    run_experiment(X_all_modalities, y, groups, "ALL modalities")
    run_experiment(X_temporal, y, groups, "ALL modalities + temporal")
    run_model_comparison(X_temporal, y, groups)


if __name__ == "__main__":
    main()
"""Diego Farias Castro (diegof4@illinois.edu).

Paper: Addressing Wearable Sleep Tracking Inequity:
A New Dataset and Novel Methods for a Population with Sleep Disorders
Paper link: https://proceedings.mlr.press/v248/wang24a.html

Implements a DREAMT sleep-wake classification task using multimodal wearable
features and temporal context augmentation.
"""

import logging
from os import PathLike
from typing import Any, Dict, List, Tuple

import neurokit2 as nk
import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter1d
from scipy.signal import butter, cheby2, filtfilt
from scipy.stats import trim_mean
from scipy.stats.mstats import winsorize

from ..data import Patient
from .base_task import BaseTask


class SleepWakeClassification(BaseTask):
    """Binary sleep-wake classification task over DREAMT wearable recordings.

    Each DREAMT recording is sliced into fixed-length epochs. Per epoch,
    handcrafted features are computed from accelerometer, temperature, blood
    volume pulse, and electrodermal activity channels, then augmented with
    temporal context (smoothed value, derivative, rolling variance). Every
    epoch receives a binary label: ``1`` for wake, ``0`` for sleep.
    """

    task_name = "SleepWakeClassification"
    input_schema = {"features": "tensor"}
    output_schema = {"label": "binary"}
    logger = logging.getLogger(__name__)

    def __init__(self, epoch_seconds: int = 30, sampling_rate: int = 64):
        """Initializes the sleep-wake classification task.

        Args:
            epoch_seconds: Length of each epoch in seconds.
            sampling_rate: Sampling rate of the wearable recording in Hz.
        """
        self.epoch_seconds = epoch_seconds
        self.sampling_rate = sampling_rate
        super().__init__()

    def _convert_sleep_stage_to_binary_label(
        self,
        label: Any,
    ) -> int | None:
        """Maps a raw sleep stage annotation to the binary target.

        Args:
            label: Raw sleep stage label from the DREAMT recording.

        Returns:
            ``1`` for wake, ``0`` for any sleep stage, or ``None`` when the
            label is missing or not a recognized stage.
        """
        if label is None or pd.isna(label):
            return None

        normalized = str(label).strip().upper()

        if normalized in {"WAKE", "W"}:
            return 1
        if normalized in {"REM", "R", "N1", "N2", "N3"}:
            return 0

        return None

    def _split_signal_into_epochs(
        self,
        signal: np.ndarray,
        sampling_rate_hz: float,
    ) -> List[np.ndarray]:
        """Cuts a 1D signal into non-overlapping fixed-length epochs.

        Trailing samples that do not fill a complete epoch are dropped.

        Args:
            signal: One-dimensional signal array.
            sampling_rate_hz: Sampling rate of the signal in Hz.

        Returns:
            One segment per complete epoch.
        """
        samples_per_epoch = int(sampling_rate_hz * self.epoch_seconds)
        num_epochs = len(signal) // samples_per_epoch
        return [
            signal[i * samples_per_epoch : (i + 1) * samples_per_epoch]
            for i in range(num_epochs)
        ]

    def _design_bandpass_filter_coefficients(
        self,
        filter_family: str,
        low_hz: float,
        high_hz: float,
        sampling_rate_hz: float,
        order: int,
        stopband_attenuation_db: float = 40.0,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Designs band-pass filter coefficients for a supported family.

        Args:
            filter_family: Filter family name, currently ``"butter"`` or
                ``"cheby2"``.
            low_hz: Lower cutoff frequency in Hz.
            high_hz: Upper cutoff frequency in Hz.
            sampling_rate_hz: Sampling rate in Hz.
            order: Filter order.
            stopband_attenuation_db: Stopband attenuation used by Chebyshev-II.

        Returns:
            The numerator and denominator filter coefficients.

        Raises:
            ValueError: If the filter family is unsupported.
        """
        nyquist = 0.5 * sampling_rate_hz
        normalized_band = [low_hz / nyquist, high_hz / nyquist]

        if filter_family == "butter":
            return butter(order, normalized_band, btype="band")
        if filter_family == "cheby2":
            return cheby2(
                order,
                stopband_attenuation_db,
                normalized_band,
                btype="band",
            )

        raise ValueError(f"Unsupported bandpass filter family: {filter_family}")

    def _apply_zero_phase_filter(
        self,
        signal: np.ndarray,
        numerator_coefficients: np.ndarray,
        denominator_coefficients: np.ndarray,
    ) -> np.ndarray:
        """Applies forward-backward (zero-phase) filtering to a 1D signal.

        Args:
            signal: One-dimensional signal array.
            numerator_coefficients: Numerator filter coefficients.
            denominator_coefficients: Denominator filter coefficients.

        Returns:
            The filtered signal.

        Raises:
            ValueError: If ``signal`` is not one-dimensional.
        """
        if signal.ndim != 1:
            raise ValueError("Signal must be 1D.")
        return filtfilt(
            numerator_coefficients,
            denominator_coefficients,
            signal,
        )

    def _filter_signal_with_lowpass(
        self,
        signal: np.ndarray,
        sampling_rate_hz: float,
        cutoff_hz: float,
        order: int = 4,
    ) -> np.ndarray:
        """Applies a zero-phase Butterworth low-pass filter.

        Args:
            signal: One-dimensional signal array.
            sampling_rate_hz: Sampling rate in Hz.
            cutoff_hz: Low-pass cutoff frequency in Hz.
            order: Filter order.

        Returns:
            The filtered signal.
        """
        nyquist = 0.5 * sampling_rate_hz
        numerator, denominator = butter(order, cutoff_hz / nyquist, btype="low")
        return self._apply_zero_phase_filter(signal, numerator, denominator)

    def _filter_accelerometer_signal(
        self,
        signal: np.ndarray,
        sampling_rate_hz: float,
    ) -> np.ndarray:
        """Band-pass filters one accelerometer axis for movement features.

        Uses a 3-11 Hz Butterworth band to isolate movement content.

        Args:
            signal: Raw accelerometer signal from one axis.
            sampling_rate_hz: Sampling rate in Hz.

        Returns:
            The filtered accelerometer signal.
        """
        numerator, denominator = self._design_bandpass_filter_coefficients(
            filter_family="butter",
            low_hz=3.0,
            high_hz=11.0,
            sampling_rate_hz=sampling_rate_hz,
            order=5,
        )
        return self._apply_zero_phase_filter(signal, numerator, denominator)

    def _filter_blood_volume_pulse_signal(
        self,
        signal: np.ndarray,
        sampling_rate_hz: float,
    ) -> np.ndarray:
        """Band-pass filters a blood volume pulse signal.

        Uses a 0.5-20 Hz Chebyshev-II band before pulse-peak analysis.

        Args:
            signal: Raw blood volume pulse signal.
            sampling_rate_hz: Sampling rate in Hz.

        Returns:
            The filtered blood volume pulse signal.
        """
        numerator, denominator = self._design_bandpass_filter_coefficients(
            filter_family="cheby2",
            low_hz=0.5,
            high_hz=20.0,
            sampling_rate_hz=sampling_rate_hz,
            order=4,
            stopband_attenuation_db=40.0,
        )
        return self._apply_zero_phase_filter(signal, numerator, denominator)

    def _detrend_signal_by_segments(
        self,
        signal: np.ndarray,
        sampling_rate_hz: float,
        segment_seconds: int,
    ) -> np.ndarray:
        """Removes a per-segment linear trend from a signal.

        Args:
            signal: One-dimensional signal array.
            sampling_rate_hz: Sampling rate in Hz.
            segment_seconds: Segment length used for local detrending.

        Returns:
            A detrended copy of the input signal.
        """
        samples_per_segment = int(sampling_rate_hz * segment_seconds)
        detrended = signal.copy()

        for start in range(0, len(signal), samples_per_segment):
            segment = signal[start : start + samples_per_segment]
            # A line fit needs at least two points; skip degenerate tails.
            if len(segment) < 2:
                continue

            positions = np.arange(len(segment))
            line_coefficients = np.polyfit(positions, segment, deg=1)
            trend = np.polyval(line_coefficients, positions)
            detrended[start : start + len(segment)] = segment - trend

        return detrended

    def _extract_accelerometer_axis_epoch_features(
        self,
        signal: np.ndarray,
        sampling_rate_hz: float,
    ) -> List[Dict[str, float]]:
        """Extracts per-epoch features from one accelerometer axis.

        Args:
            signal: Raw accelerometer signal from one axis.
            sampling_rate_hz: Sampling rate in Hz.

        Returns:
            A list of feature dictionaries, one per epoch.
        """
        # Feature extraction runs on the rectified band-passed signal.
        rectified = np.abs(
            self._filter_accelerometer_signal(signal, sampling_rate_hz)
        )

        per_epoch = []
        for epoch in self._split_signal_into_epochs(rectified, sampling_rate_hz):
            upper_quartile = np.percentile(epoch, 75)
            lower_quartile = np.percentile(epoch, 25)
            per_epoch.append(
                {
                    "trimmed_mean": float(trim_mean(epoch, proportiontocut=0.10)),
                    "max": float(np.max(epoch)),
                    "iqr": float(upper_quartile - lower_quartile),
                }
            )
        return per_epoch

    def _extract_accelerometer_magnitude_deviation_epoch_features(
        self,
        accelerometer_x_signal: np.ndarray,
        accelerometer_y_signal: np.ndarray,
        accelerometer_z_signal: np.ndarray,
        sampling_rate_hz: float,
    ) -> List[Dict[str, float]]:
        """Extracts mean absolute deviation features from ACC magnitude.

        Args:
            accelerometer_x_signal: Accelerometer X-axis signal.
            accelerometer_y_signal: Accelerometer Y-axis signal.
            accelerometer_z_signal: Accelerometer Z-axis signal.
            sampling_rate_hz: Sampling rate in Hz.

        Returns:
            A list of feature dictionaries containing magnitude deviation.
        """
        magnitude = np.sqrt(
            accelerometer_x_signal**2
            + accelerometer_y_signal**2
            + accelerometer_z_signal**2
        )

        return [
            {"mad": float(np.mean(np.abs(epoch - np.mean(epoch))))}
            for epoch in self._split_signal_into_epochs(magnitude, sampling_rate_hz)
        ]

    def _extract_temperature_epoch_features(
        self,
        signal: np.ndarray,
        sampling_rate_hz: float,
    ) -> List[Dict[str, float]]:
        """Extracts per-epoch temperature summary statistics.

        The signal is winsorized (5% per tail) and clipped to a plausible
        skin-temperature range before epoching.

        Args:
            signal: Raw temperature signal.
            sampling_rate_hz: Sampling rate in Hz.

        Returns:
            A list of feature dictionaries, one per epoch.
        """
        winsorized = winsorize(signal, limits=(0.05, 0.05))
        bounded = np.clip(winsorized, 31.0, 40.0)

        return [
            {
                "mean": float(np.mean(epoch)),
                "min": float(np.min(epoch)),
                "max": float(np.max(epoch)),
                "std": float(np.std(epoch)),
            }
            for epoch in self._split_signal_into_epochs(
                np.asarray(bounded),
                sampling_rate_hz,
            )
        ]

    def _extract_blood_volume_pulse_epoch_features(
        self,
        signal: np.ndarray,
        sampling_rate_hz: float,
    ) -> List[Dict[str, float]]:
        """Extracts HRV-based features from blood volume pulse epochs.

        Epochs where peak detection or HRV computation fails are retained
        with NaN features so epoch alignment across modalities is preserved.

        Args:
            signal: Raw blood volume pulse signal.
            sampling_rate_hz: Sampling rate in Hz.

        Returns:
            A list of feature dictionaries, one per epoch.
        """
        filtered = self._filter_blood_volume_pulse_signal(signal, sampling_rate_hz)

        per_epoch: List[Dict[str, float]] = []
        for epoch in self._split_signal_into_epochs(filtered, sampling_rate_hz):
            try:
                _, info = nk.ppg_process(epoch, sampling_rate=sampling_rate_hz)
                hrv = nk.hrv_time(
                    info["PPG_Peaks"],
                    sampling_rate=sampling_rate_hz,
                    show=False,
                )
                per_epoch.append(
                    {
                        "rmssd": float(hrv["HRV_RMSSD"].values[0]),
                        "sdnn": float(hrv["HRV_SDNN"].values[0]),
                        "pnn50": float(hrv["HRV_pNN50"].values[0]),
                    }
                )
            except (KeyError, IndexError, TypeError, ValueError) as error:
                self.logger.warning(
                    "Skipping BVP epoch due to feature extraction error: %s",
                    error,
                )
                per_epoch.append(
                    {"rmssd": np.nan, "sdnn": np.nan, "pnn50": np.nan}
                )

        return per_epoch

    def _extract_electrodermal_activity_epoch_features(
        self,
        signal: np.ndarray,
        sampling_rate_hz: float,
    ) -> List[Dict[str, float]]:
        """Extracts SCR-based features from electrodermal activity epochs.

        The signal is detrended in 5 s segments and low-pass filtered before
        phasic (SCR) decomposition; failed epochs yield NaN features to keep
        alignment with the other modalities.

        Args:
            signal: Raw electrodermal activity signal.
            sampling_rate_hz: Sampling rate in Hz.

        Returns:
            A list of feature dictionaries, one per epoch.
        """
        detrended = self._detrend_signal_by_segments(
            signal,
            sampling_rate_hz,
            segment_seconds=5,
        )
        filtered = self._filter_signal_with_lowpass(
            detrended,
            sampling_rate_hz,
            cutoff_hz=1.0,
        )

        eda_signals, _ = nk.eda_process(filtered, sampling_rate=sampling_rate_hz)
        phasic = eda_signals["EDA_Phasic"].values

        per_epoch: List[Dict[str, float]] = []
        for epoch in self._split_signal_into_epochs(phasic, sampling_rate_hz):
            try:
                _, info = nk.eda_peaks(epoch, sampling_rate=sampling_rate_hz)

                amplitudes = info["SCR_Amplitude"]
                rise_times = info["SCR_RiseTime"]
                recovery_times = info["SCR_RecoveryTime"]

                per_epoch.append(
                    {
                        "scr_amp_mean": (
                            float(np.mean(amplitudes)) if len(amplitudes) else 0.0
                        ),
                        "scr_amp_max": (
                            float(np.max(amplitudes)) if len(amplitudes) else 0.0
                        ),
                        "scr_rise_mean": (
                            float(np.mean(rise_times)) if len(rise_times) else 0.0
                        ),
                        "scr_recovery_mean": (
                            float(np.mean(recovery_times))
                            if len(recovery_times)
                            else 0.0
                        ),
                    }
                )
            except (KeyError, IndexError, TypeError, ValueError) as error:
                self.logger.warning(
                    "Skipping EDA epoch due to feature extraction error: %s",
                    error,
                )
                per_epoch.append(
                    {
                        "scr_amp_mean": np.nan,
                        "scr_amp_max": np.nan,
                        "scr_rise_mean": np.nan,
                        "scr_recovery_mean": np.nan,
                    }
                )

        return per_epoch

    def _compute_rolling_variance(
        self,
        values: np.ndarray,
        window: int,
    ) -> np.ndarray:
        """Computes a centered rolling variance over a feature series.

        Windows are truncated at the series boundaries.

        Args:
            values: One-dimensional array of feature values over epochs.
            window: Rolling window size in number of epochs.

        Returns:
            An array of rolling variance values.
        """
        half_window = window // 2
        out = np.zeros_like(values)

        for index in range(len(values)):
            lo = max(0, index - half_window)
            hi = min(len(values), index + half_window + 1)
            out[index] = np.var(values[lo:hi])

        return out

    def _augment_epoch_features_with_temporal_context(
        self,
        epoch_features: List[List[float]],
        gaussian_sigma: float = 2.0,
        variance_window: int = 5,
    ) -> List[List[float]]:
        """Appends temporal context features to each epoch vector.

        For every base feature column this adds, in order: a Gaussian-smoothed
        value, its first difference, and a rolling variance of the smoothed
        series.

        Args:
            epoch_features: Base feature matrix as a list of epoch vectors.
            gaussian_sigma: Gaussian smoothing parameter.
            variance_window: Rolling variance window size in epochs.

        Returns:
            The augmented epoch feature matrix.
        """
        if len(epoch_features) == 0:
            return []

        base = np.asarray(epoch_features, dtype=float)
        num_epochs, num_features = base.shape
        augmented = base.tolist()

        for feature_idx in range(num_features):
            series = base[:, feature_idx]

            smoothed = gaussian_filter1d(
                series, sigma=gaussian_sigma, mode="nearest"
            )
            derivative = np.diff(smoothed, prepend=smoothed[0])
            rolling_var = self._compute_rolling_variance(smoothed, variance_window)

            for epoch_idx in range(num_epochs):
                augmented[epoch_idx].append(float(smoothed[epoch_idx]))
                augmented[epoch_idx].append(float(derivative[epoch_idx]))
                augmented[epoch_idx].append(float(rolling_var[epoch_idx]))

        return augmented

    def _extract_sensor_signals_from_dataframe(
        self,
        record_dataframe: pd.DataFrame,
    ) -> Dict[str, np.ndarray]:
        """Extracts all sensor modalities from a DREAMT record.

        Non-numeric readings are coerced to NaN and then filled with 0.0.

        Args:
            record_dataframe: Wearable recording loaded as a pandas DataFrame.

        Returns:
            A dictionary mapping modality names to numeric NumPy arrays.
        """

        def column_as_float_array(column_name: str) -> np.ndarray:
            return (
                pd.to_numeric(record_dataframe[column_name], errors="coerce")
                .fillna(0.0)
                .to_numpy()
            )

        return {
            "accelerometer_x": column_as_float_array("ACC_X"),
            "accelerometer_y": column_as_float_array("ACC_Y"),
            "accelerometer_z": column_as_float_array("ACC_Z"),
            "temperature": column_as_float_array("TEMP"),
            "blood_volume_pulse": column_as_float_array("BVP"),
            "electrodermal_activity": column_as_float_array("EDA"),
        }

    def _extract_feature_sets_for_all_modalities(
        self,
        sensor_signals: Dict[str, np.ndarray],
        sampling_rate_hz: float,
    ) -> Dict[str, List[Dict[str, float]]]:
        """Extracts per-epoch features for every supported sensor modality.

        Args:
            sensor_signals: Dictionary of raw modality arrays.
            sampling_rate_hz: Sampling rate in Hz.

        Returns:
            A dictionary mapping modality names to lists of epoch features.
        """
        return {
            "accelerometer_x": self._extract_accelerometer_axis_epoch_features(
                sensor_signals["accelerometer_x"],
                sampling_rate_hz,
            ),
            "accelerometer_y": self._extract_accelerometer_axis_epoch_features(
                sensor_signals["accelerometer_y"],
                sampling_rate_hz,
            ),
            "accelerometer_z": self._extract_accelerometer_axis_epoch_features(
                sensor_signals["accelerometer_z"],
                sampling_rate_hz,
            ),
            "accelerometer_magnitude_deviation": self._extract_accelerometer_magnitude_deviation_epoch_features(
                sensor_signals["accelerometer_x"],
                sensor_signals["accelerometer_y"],
                sensor_signals["accelerometer_z"],
                sampling_rate_hz,
            ),
            "temperature": self._extract_temperature_epoch_features(
                sensor_signals["temperature"],
                sampling_rate_hz,
            ),
            "blood_volume_pulse": self._extract_blood_volume_pulse_epoch_features(
                sensor_signals["blood_volume_pulse"],
                sampling_rate_hz,
            ),
            "electrodermal_activity": self._extract_electrodermal_activity_epoch_features(
                sensor_signals["electrodermal_activity"],
                sampling_rate_hz,
            ),
        }

    def _count_complete_epochs(
        self,
        feature_sets: Dict[str, List[Dict[str, float]]],
    ) -> int:
        """Counts how many epochs are complete across all modalities.

        Args:
            feature_sets: Per-modality epoch feature dictionaries.

        Returns:
            The minimum number of aligned complete epochs available across all
            modalities.
        """
        return min(len(features) for features in feature_sets.values())

    def _build_epoch_feature_vector(
        self,
        feature_sets: Dict[str, List[Dict[str, float]]],
        epoch_index: int,
    ) -> List[float]:
        """Assembles the flat feature vector for one epoch.

        The modality/feature ordering below is fixed and must match the
        column-index conventions used by downstream ablation code.

        Args:
            feature_sets: Per-modality epoch feature dictionaries.
            epoch_index: Epoch index to assemble.

        Returns:
            A flat feature vector for the requested epoch.
        """
        feature_layout = [
            ("accelerometer_x", ["trimmed_mean", "max", "iqr"]),
            ("accelerometer_y", ["trimmed_mean", "max", "iqr"]),
            ("accelerometer_z", ["trimmed_mean", "max", "iqr"]),
            ("accelerometer_magnitude_deviation", ["mad"]),
            ("temperature", ["mean", "min", "max", "std"]),
            ("blood_volume_pulse", ["rmssd", "sdnn", "pnn50"]),
            (
                "electrodermal_activity",
                [
                    "scr_amp_mean",
                    "scr_amp_max",
                    "scr_rise_mean",
                    "scr_recovery_mean",
                ],
            ),
        ]

        vector: List[float] = []
        for modality, feature_names in feature_layout:
            modality_features = feature_sets[modality][epoch_index]
            vector.extend(
                modality_features[feature_name] for feature_name in feature_names
            )
        return vector

    def _build_record_epoch_feature_matrix(
        self,
        record_dataframe: pd.DataFrame,
    ) -> List[List[float]]:
        """Builds the full epoch feature matrix for one wearable record.

        Args:
            record_dataframe: Wearable recording loaded as a pandas DataFrame.

        Returns:
            A list of epoch feature vectors. Returns an empty list if required
            sensor columns are missing.
        """
        sampling_rate_hz = float(self.sampling_rate)

        required_columns = {"ACC_X", "ACC_Y", "ACC_Z", "TEMP", "BVP", "EDA"}
        if not required_columns.issubset(record_dataframe.columns):
            return []

        sensor_signals = self._extract_sensor_signals_from_dataframe(record_dataframe)
        feature_sets = self._extract_feature_sets_for_all_modalities(
            sensor_signals,
            sampling_rate_hz,
        )
        num_epochs = self._count_complete_epochs(feature_sets)

        base_features = [
            self._build_epoch_feature_vector(feature_sets, epoch_idx)
            for epoch_idx in range(num_epochs)
        ]

        return self._augment_epoch_features_with_temporal_context(
            base_features,
            gaussian_sigma=2.0,
            variance_window=5,
        )

    def _load_wearable_record_dataframe(
        self,
        event: Any,
    ) -> pd.DataFrame | None:
        """Loads the wearable CSV file associated with a DREAMT event.

        Args:
            event: DREAMT event containing the ``file_64hz`` attribute.

        Returns:
            A pandas DataFrame if the file can be loaded, otherwise ``None``.
        """
        file_path: str | PathLike[str] | None = getattr(event, "file_64hz", None)
        if file_path is None:
            return None

        try:
            return pd.read_csv(file_path)
        except (
            FileNotFoundError,
            OSError,
            pd.errors.EmptyDataError,
            pd.errors.ParserError,
        ) as error:
            # Skip unreadable records instead of aborting the whole patient.
            self.logger.warning(
                "Skipping DREAMT record '%s' due to file loading error: %s",
                file_path,
                error,
            )
            return None

    def _extract_binary_label_for_epoch(
        self,
        record_dataframe: pd.DataFrame,
        epoch_index: int,
        samples_per_epoch: int,
    ) -> int | None:
        """Extracts the binary sleep-wake label for one epoch.

        The epoch label is the modal ``Sleep_Stage`` value over the epoch's
        samples, mapped to the binary target.

        Args:
            record_dataframe: Wearable recording loaded as a pandas DataFrame.
            epoch_index: Index of the epoch to label.
            samples_per_epoch: Number of samples contained in one epoch.

        Returns:
            ``1`` for wake, ``0`` for sleep, or ``None`` if the epoch cannot be
            labeled.
        """
        start = epoch_index * samples_per_epoch
        epoch_rows = record_dataframe.iloc[start : start + samples_per_epoch]

        # Incomplete trailing epochs are not labeled.
        if len(epoch_rows) < samples_per_epoch:
            return None

        stage_mode = epoch_rows["Sleep_Stage"].mode(dropna=True)
        if len(stage_mode) == 0:
            return None

        return self._convert_sleep_stage_to_binary_label(stage_mode.iloc[0])

    def _build_samples_for_sleep_event(
        self,
        patient: Patient,
        sleep_event_index: int,
        record_dataframe: pd.DataFrame,
        record_epoch_feature_matrix: List[List[float]],
        samples_per_epoch: int,
    ) -> List[Dict[str, object]]:
        """Builds epoch-level samples for one DREAMT sleep event.

        Args:
            patient: Patient containing the event.
            sleep_event_index: Index of the sleep event within the patient.
            record_dataframe: Wearable recording loaded as a pandas DataFrame.
            record_epoch_feature_matrix: Per-epoch feature matrix for the
                record.
            samples_per_epoch: Number of samples contained in one epoch.

        Returns:
            A list of task samples for the event; unlabelable epochs are
            skipped.
        """
        n_labeled_epochs = len(record_dataframe) // samples_per_epoch
        n_epochs = min(len(record_epoch_feature_matrix), n_labeled_epochs)

        samples: List[Dict[str, object]] = []
        for epoch_idx in range(n_epochs):
            label = self._extract_binary_label_for_epoch(
                record_dataframe,
                epoch_idx,
                samples_per_epoch,
            )
            if label is None:
                continue

            samples.append(
                {
                    "patient_id": patient.patient_id,
                    "record_id": f"{patient.patient_id}-event{sleep_event_index}-epoch{epoch_idx}",
                    "epoch_index": epoch_idx,
                    "features": record_epoch_feature_matrix[epoch_idx],
                    "label": label,
                }
            )

        return samples
+ """ + samples = [] + events = patient.get_events(event_type="dreamt_sleep") + if len(events) == 0: + return samples + + samples_per_epoch = self.epoch_seconds * self.sampling_rate + + for event_idx, event in enumerate(events): + record_dataframe = self._load_wearable_record_dataframe(event) + if record_dataframe is None: + continue + + if "Sleep_Stage" not in record_dataframe.columns: + continue + + record_epoch_feature_matrix = self._build_record_epoch_feature_matrix( + record_dataframe + ) + if len(record_epoch_feature_matrix) == 0: + continue + + samples.extend( + self._build_samples_for_sleep_event( + patient=patient, + sleep_event_index=event_idx, + record_dataframe=record_dataframe, + record_epoch_feature_matrix=record_epoch_feature_matrix, + samples_per_epoch=samples_per_epoch, + ) + ) + + return samples diff --git a/tests/core/test_sleep_wake_classification.py b/tests/core/test_sleep_wake_classification.py new file mode 100644 index 000000000..3dfc8fcb0 --- /dev/null +++ b/tests/core/test_sleep_wake_classification.py @@ -0,0 +1,295 @@ +import numpy as np +import pandas as pd + +from pyhealth.tasks.sleep_wake_classification import SleepWakeClassification + + +class FakeEvent: + """Creates a minimal DREAMT-like event for task tests. + + Args: + file_64hz: Path to the wearable CSV associated with the event. + """ + + def __init__(self, file_64hz=None): + self.file_64hz = file_64hz + + +class FakePatient: + """Creates a minimal DREAMT-like patient with configurable events. + + Args: + patient_id: Identifier used by the task in generated samples. + events: Synthetic sleep events available for the patient. + """ + + def __init__(self, patient_id: str, events=None): + self.patient_id = patient_id + self._events = [] if events is None else events + + def get_events(self, event_type=None): + """Returns synthetic events for the requested DREAMT event type. + + Args: + event_type: Event type requested by the task. 
+ + Returns: + The stored sleep events for ``"dreamt_sleep"``, otherwise an empty + list. + """ + if event_type == "dreamt_sleep": + return self._events + return [] + + +def _build_valid_record(num_rows: int = 4) -> pd.DataFrame: + """Builds a small synthetic wearable record for task tests. + + Args: + num_rows: Number of rows to include in the record. + + Returns: + A DREAMT-like wearable DataFrame with the required sensor columns. + """ + sleep_stages = ["W", "W", "N2", "N2", "REM", "REM", "W", "W"][:num_rows] + return pd.DataFrame( + { + "TIMESTAMP": list(range(num_rows)), + "BVP": np.linspace(0.1, 0.8, num_rows), + "EDA": np.linspace(0.01, 0.08, num_rows), + "TEMP": np.linspace(36.1, 36.5, num_rows), + "ACC_X": np.linspace(1.0, 2.0, num_rows), + "ACC_Y": np.linspace(0.5, 1.5, num_rows), + "ACC_Z": np.linspace(0.2, 1.2, num_rows), + "HR": np.linspace(60, 67, num_rows), + "Sleep_Stage": sleep_stages, + } + ) + + +def _build_patient_with_single_event(patient_id: str = "S001") -> FakePatient: + """Builds a synthetic patient with one sleep event. + + Args: + patient_id: Identifier to assign to the synthetic patient. + + Returns: + A synthetic patient with one DREAMT-like sleep event. + """ + return FakePatient(patient_id, events=[FakeEvent("unused.csv")]) + + +def test_convert_sleep_stage_to_binary_label(): + """Tests that raw sleep stages map to the expected binary labels. + + Returns: + None. + """ + task = SleepWakeClassification() + + assert task._convert_sleep_stage_to_binary_label("WAKE") == 1 + assert task._convert_sleep_stage_to_binary_label("W") == 1 + assert task._convert_sleep_stage_to_binary_label("N2") == 0 + assert task._convert_sleep_stage_to_binary_label("REM") == 0 + assert task._convert_sleep_stage_to_binary_label(None) is None + assert task._convert_sleep_stage_to_binary_label("UNKNOWN") is None + + +def test_split_signal_into_epochs(): + """Tests that only complete fixed-length epochs are returned. + + Returns: + None. 
+ """ + task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + signal = np.array([0, 1, 2, 3, 4]) + + epochs = task._split_signal_into_epochs(signal, sampling_rate_hz=1) + + assert len(epochs) == 2 + assert np.array_equal(epochs[0], np.array([0, 1])) + assert np.array_equal(epochs[1], np.array([2, 3])) + + +def test_extract_binary_label_for_epoch(): + """Tests epoch-level label extraction from sleep-stage mode. + + Returns: + None. + """ + task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + record_dataframe = pd.DataFrame({"Sleep_Stage": ["W", "W", "N2", "N2"]}) + + assert task._extract_binary_label_for_epoch(record_dataframe, 0, 2) == 1 + assert task._extract_binary_label_for_epoch(record_dataframe, 1, 2) == 0 + + +def test_build_record_epoch_feature_matrix_returns_empty_when_columns_missing(): + """Tests that missing sensor columns yield an empty feature matrix. + + Returns: + None. + """ + task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + record_dataframe = pd.DataFrame( + { + "ACC_X": [1, 2, 3, 4], + "ACC_Y": [1, 2, 3, 4], + "ACC_Z": [1, 2, 3, 4], + "TEMP": [36, 36, 36, 36], + "Sleep_Stage": ["W", "W", "N2", "N2"], + } + ) + + assert task._build_record_epoch_feature_matrix(record_dataframe) == [] + + +def test_load_wearable_record_dataframe_returns_none_for_missing_file(): + """Tests that missing wearable files return ``None``. + + Returns: + None. + """ + task = SleepWakeClassification() + + assert ( + task._load_wearable_record_dataframe(FakeEvent(file_64hz="missing.csv")) is None + ) + assert task._load_wearable_record_dataframe(FakeEvent(file_64hz=None)) is None + + +def test_task_returns_empty_when_patient_has_no_sleep_events(): + """Tests that a patient without sleep events produces no samples. + + Returns: + None. 
+ """ + task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + patient = FakePatient("S001", events=[]) + + assert task(patient) == [] + + +def test_task_returns_empty_when_sleep_stage_column_is_missing(monkeypatch): + """Tests that records missing ``Sleep_Stage`` are skipped. + + Args: + monkeypatch: Pytest fixture used to replace file loading with a + synthetic DataFrame. + + Returns: + None. + """ + task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + record_dataframe = _build_valid_record().drop(columns=["Sleep_Stage"]) + patient = _build_patient_with_single_event() + + monkeypatch.setattr( + task, + "_load_wearable_record_dataframe", + lambda event: record_dataframe, + ) + assert task(patient) == [] + + +def test_task_skips_epochs_with_unsupported_labels(monkeypatch): + """Tests that epochs with unsupported labels are not emitted. + + Args: + monkeypatch: Pytest fixture used to replace record loading and feature + extraction with synthetic values. + + Returns: + None. + """ + task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + record_dataframe = _build_valid_record(num_rows=4) + record_dataframe["Sleep_Stage"] = ["X", "X", "N2", "N2"] + patient = _build_patient_with_single_event() + + monkeypatch.setattr( + task, + "_load_wearable_record_dataframe", + lambda event: record_dataframe, + ) + monkeypatch.setattr( + task, + "_build_record_epoch_feature_matrix", + lambda df: [[1.0, 2.0], [3.0, 4.0]], + ) + + samples = task(patient) + + assert len(samples) == 1 + assert samples[0]["epoch_index"] == 1 + assert samples[0]["label"] == 0 + + +def test_task_runs_full_flow_with_lightweight_feature_stub(monkeypatch): + """Tests end-to-end sample generation with lightweight mocked features. + + Args: + monkeypatch: Pytest fixture used to replace record loading and avoid + expensive feature extraction. + + Returns: + None. 
+ """ + task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + record_dataframe = _build_valid_record(num_rows=4) + patient = _build_patient_with_single_event() + + monkeypatch.setattr( + task, + "_load_wearable_record_dataframe", + lambda event: record_dataframe, + ) + monkeypatch.setattr( + task, + "_build_record_epoch_feature_matrix", + lambda df: [ + [1.0, 10.0], + [2.0, 20.0], + ], + ) + + samples = task(patient) + + assert len(samples) == 2 + assert all("features" in sample for sample in samples) + assert all("label" in sample for sample in samples) + assert all("record_id" in sample for sample in samples) + assert samples[0]["record_id"] == "S001-event0-epoch0" + assert samples[0]["label"] == 1 + assert samples[1]["label"] == 0 + + +def test_task_uses_minimum_epoch_count_between_labels_and_features(monkeypatch): + """Tests that sample count is bounded by the shorter epoch source. + + Args: + monkeypatch: Pytest fixture used to control how many epoch features + the task returns during the test. + + Returns: + None. + """ + task = SleepWakeClassification(epoch_seconds=2, sampling_rate=1) + record_dataframe = _build_valid_record(num_rows=4) + patient = _build_patient_with_single_event() + + monkeypatch.setattr( + task, + "_load_wearable_record_dataframe", + lambda _: record_dataframe, + ) + monkeypatch.setattr( + task, + "_build_record_epoch_feature_matrix", + lambda _: [[1.0], [2.0]], + ) + + samples = task(patient) + + assert len(samples) == 2 + assert [sample["epoch_index"] for sample in samples] == [0, 1]