Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
d168979
feat: add SleepWakeClassification task for binary sleep/wake labeling
diegofariasc Mar 9, 2026
8f8fe6c
feat: add SleepWakeClassification task with epoch segmentation and ba…
diegofariasc Mar 9, 2026
a511bcf
test: add runs unit test for SleepWakeClassification task
diegofariasc Mar 9, 2026
93246ab
feat: implement feature extraction pipeline for SleepWakeClassification
diegofariasc Mar 9, 2026
58d8366
feat: add record-level BVP features to SleepWakeClassification
diegofariasc Mar 9, 2026
a90af6c
feat: add record-level EDA features to SleepWakeClassification
diegofariasc Mar 9, 2026
30d0f90
feat: add temporal feature enhancement to SleepWakeClassification
diegofariasc Mar 9, 2026
4e91596
feat: add initial sleep-wake task and temporal feature ablation example
diegofariasc Mar 9, 2026
dbcc496
feat: add modality ablation experiments for DREAMT sleep-wake classif…
diegofariasc Mar 9, 2026
7ef252d
refactor: improve readability and reuse in sleep-wake task
diegofariasc Mar 12, 2026
20cf21f
refactor: reorder sleep-wake task methods by responsibility
diegofariasc Mar 12, 2026
4686a51
doc: add sleep_wake_classification.rst
diegofariasc Mar 12, 2026
df5720f
doc: document all methods in SleepWakeClassification
diegofariasc Mar 12, 2026
9098292
feat: add SleepWakeClassification to init.py
diegofariasc Mar 12, 2026
79efe27
feat: add Sleep-Wake Classification to tasks.rst
diegofariasc Mar 12, 2026
b72ae6b
refactor: use black+isort to autoformat task code following PEP88
diegofariasc Mar 12, 2026
df985b3
test: add tests covering new SleepWakeClassification task
diegofariasc Mar 13, 2026
9357506
doc: add docstrings to tests
diegofariasc Mar 13, 2026
f04f602
refactor: use black+isort to autoformat test code following PEP88
diegofariasc Mar 13, 2026
a98a1b4
refactor: use specific Exception types instead of general Exception
diegofariasc Mar 13, 2026
a2090a2
refactor: generalize sleep-wake classification example
diegofariasc Mar 13, 2026
1ba1795
doc: add file header to sleep_wake_classification.py
diegofariasc Mar 13, 2026
cbc98de
refactor: improve formatting of results in sleep_wake_classification …
diegofariasc Mar 14, 2026
e8c7024
refactor: use black+issort on example study
diegofariasc Mar 14, 2026
11447e4
refactor: rename sleep_wake_classification example to dreamt_sleep_wa…
diegofariasc Mar 14, 2026
489d443
refactor: improve typing in sleep_wake_classification task and example
diegofariasc Mar 14, 2026
80798a0
refactor: add support for synthetic data in example
diegofariasc Mar 14, 2026
18d2f80
Merge branch 'sunlabuiuc:master' into diegof4/dreamt_sleep_tracking
diegofariasc Mar 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ Available Tasks
Mortality Prediction (StageNet MIMIC-IV) <tasks/pyhealth.tasks.mortality_prediction_stagenet_mimic4>
Patient Linkage (MIMIC-III) <tasks/pyhealth.tasks.patient_linkage_mimic3_fn>
Readmission Prediction <tasks/pyhealth.tasks.readmission_prediction>
Sleep-Wake Classification <tasks/pyhealth.tasks.sleep_wake_classification>
Sleep Staging <tasks/pyhealth.tasks.sleep_staging>
Sleep Staging (SleepEDF) <tasks/pyhealth.tasks.SleepStagingSleepEDF>
Temple University EEG Tasks <tasks/pyhealth.tasks.temple_university_EEG_tasks>
Expand Down
7 changes: 7 additions & 0 deletions docs/api/tasks/pyhealth.tasks.sleep_wake_classification.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.tasks.sleep_wake_classification
========================================

.. autoclass:: pyhealth.tasks.sleep_wake_classification.SleepWakeClassification
    :members:
    :undoc-members:
    :show-inheritance:
376 changes: 376 additions & 0 deletions examples/dreamt_sleep_wake_classification_lightgbm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,376 @@
import io
import logging
import warnings
from collections import Counter
from contextlib import redirect_stderr, redirect_stdout
from typing import Iterable

import lightgbm as lgb
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import ConvergenceWarning
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
accuracy_score,
average_precision_score,
f1_score,
roc_auc_score,
)

from pyhealth.datasets import DREAMTDataset
from pyhealth.tasks.sleep_wake_classification import SleepWakeClassification

# Configuration
# Placeholder dataset path. While DREAMT_ROOT still equals this sentinel
# string, main() runs the workflow on synthetic data instead of loading DREAMT.
DREAMT_ROOT = "REPLACE_WITH_DREAMT_ROOT"
# Patient IDs that define the fixed train / evaluation split used by
# split_samples_by_patient_ids().
TRAIN_PATIENT_IDS = ["S028", "S062", "S078"]
EVAL_PATIENT_IDS = ["S081", "S099"]
# Forwarded to SleepWakeClassification as epoch_seconds / sampling_rate.
# NOTE(review): the sampling rate is presumably in Hz -- confirm against the task.
EPOCH_SECONDS = 30
SAMPLING_RATE = 64

# Console formatting codes
# ANSI escape sequences: RESET clears styling; the others switch on bold text
# or a foreground colour until the next RESET.
RESET = "\033[0m"
BOLD = "\033[1m"
CYAN = "\033[36m"
GREEN = "\033[32m"
YELLOW = "\033[33m"


def build_synthetic_benchmark_data() -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Generate a random stand-in dataset so the example runs without DREAMT.

    Every patient listed in ``TRAIN_PATIENT_IDS`` and ``EVAL_PATIENT_IDS``
    receives the same number of samples. Labels are Bernoulli draws and the
    features are standard-normal noise whose mean is shifted upwards for rows
    labelled 1, so the classifiers have some signal to pick up.

    Returns:
        A ``(features, labels, patient_ids)`` tuple of NumPy arrays with one
        row / entry per synthetic sample.
    """
    generator = np.random.default_rng(42)

    per_patient = 24
    base_width = 21
    # Base features plus three times as many temporal-context features.
    total_width = base_width + base_width * 3

    groups = np.repeat(TRAIN_PATIENT_IDS + EVAL_PATIENT_IDS, per_patient)
    n_samples = len(groups)

    # Draw order is part of the reproducible output: labels first, then features.
    y = generator.binomial(1, 0.35, size=n_samples)
    X = generator.normal(0.0, 1.0, size=(n_samples, total_width))

    # Mean shift applied to label-1 rows, per column range. The first four
    # ranges line up with the modality index groups that main() slices out.
    positive_rows = y == 1
    mean_shifts = (
        (slice(None, 10), 0.9),
        (slice(10, 14), 0.4),
        (slice(14, 17), 0.3),
        (slice(17, 21), 0.2),
        (slice(21, None), 0.25),
    )
    for columns, offset in mean_shifts:
        X[positive_rows, columns] += offset

    return X.astype(float), y.astype(int), groups.astype(str)


def format_section(title: str) -> str:
    """Wrap a heading in the console codes used for section titles.

    Args:
        title: Text of the heading.

    Returns:
        The heading preceded by a newline, rendered bold and cyan, and
        followed by a style reset.
    """
    return "\n{}{}{}{}".format(BOLD, CYAN, title, RESET)


def format_patient_ids(patient_ids: Iterable[str]) -> str:
    """Formats patient IDs for readable console output.

    Duplicates are removed and the remaining IDs are sorted alphabetically.

    Args:
        patient_ids: Iterable of patient identifiers. Each one is converted
            with ``str()`` before comparison.

    Returns:
        A comma-separated string of the unique patient IDs, or an empty
        string when ``patient_ids`` is empty.
    """
    # De-duplicate *after* converting to str so that identifiers which only
    # differ by type (e.g. 7 and "7") are not listed twice.
    unique_labels = {str(patient_id) for patient_id in patient_ids}
    return ", ".join(sorted(unique_labels))


def print_metric(name: str, value: float) -> None:
    """Write one metric line to stdout in a fixed-width layout.

    Args:
        name: Metric label; left-aligned and padded to 16 characters.
        value: Metric value; shown with four decimal places.
    """
    label_column = format(name, "<16")
    value_column = format(value, ".4f")
    print(" " + label_column + value_column)


def summarize_label_counts(labels: Iterable[int]) -> str:
    """Builds a readable sleep/wake label summary.

    Args:
        labels: Iterable of binary labels, where 0 is counted as sleep and 1
            as wake. Any other values are ignored in the summary.

    Returns:
        A formatted label count string such as
        ``"sleep (0): 3, wake (1): 2"``.
    """
    # Counter returns 0 for missing keys, so absent classes need no special case.
    counts = Counter(labels)
    return f"sleep (0): {counts[0]}, wake (1): {counts[1]}"


def configure_clean_output() -> None:
    """Quiet down warnings and log output so the example prints cleanly.

    Ignores scikit-learn ``ConvergenceWarning`` and raises the level of the
    ``pyhealth`` logger and of the sleep-wake task logger to ``ERROR``.
    """
    warnings.filterwarnings("ignore", category=ConvergenceWarning)

    quiet_loggers = ("pyhealth", "pyhealth.tasks.sleep_wake_classification")
    for logger_name in quiet_loggers:
        logging.getLogger(logger_name).setLevel(logging.ERROR)


def split_samples_by_patient_ids(
    X: np.ndarray,
    y: np.ndarray,
    groups: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Partition samples into the fixed train / evaluation patient split.

    A sample goes to the training set when its patient ID is listed in
    ``TRAIN_PATIENT_IDS`` and to the evaluation set when it is listed in
    ``EVAL_PATIENT_IDS``.

    Args:
        X: Feature matrix with one row per sample.
        y: Binary label for each sample.
        groups: Patient identifier for each sample.

    Returns:
        ``(X_train, X_eval, y_train, y_eval, groups_train, groups_eval)``.

    Raises:
        ValueError: If either split ends up without any samples.
    """
    in_train = np.isin(groups, TRAIN_PATIENT_IDS)
    if not in_train.any():
        raise ValueError("No samples found for TRAIN_PATIENT_IDS.")

    in_eval = np.isin(groups, EVAL_PATIENT_IDS)
    if not in_eval.any():
        raise ValueError("No samples found for EVAL_PATIENT_IDS.")

    # Apply each mask to the three parallel arrays in one go.
    X_train, y_train, groups_train = (arr[in_train] for arr in (X, y, groups))
    X_eval, y_eval, groups_eval = (arr[in_eval] for arr in (X, y, groups))

    return X_train, X_eval, y_train, y_eval, groups_train, groups_eval


def run_experiment(
    X: np.ndarray,
    y: np.ndarray,
    groups: np.ndarray,
    name: str,
    *,
    threshold: float = 0.3,
) -> None:
    """Runs one feature ablation experiment and prints evaluation metrics.

    The samples are split by patient ID, all-NaN feature columns are dropped,
    remaining gaps are median-imputed, and a LightGBM binary classifier is
    trained and scored on the evaluation patients.

    Args:
        X: Feature matrix for the selected experiment.
        y: Binary label vector.
        groups: Patient identifier for each sample.
        name: Name of the ablation setting.
        threshold: Probability cut-off at or above which a sample is
            predicted as class 1.

    Raises:
        ValueError: If the train or evaluation split contains no samples
            (raised by ``split_samples_by_patient_ids``).
    """
    # Split samples into train and evaluation sets
    X_train, X_test, y_train, y_test, g_train, g_test = split_samples_by_patient_ids(
        X,
        y,
        groups,
    )

    # Report dataset statistics
    print(format_section(f"Ablation: {name}"))
    print(f"{BOLD}Train patients:{RESET} {format_patient_ids(g_train)}")
    print(f"{BOLD}Eval patients:{RESET} {format_patient_ids(g_test)}")
    print(f"{BOLD}Train samples:{RESET} {len(X_train)}")
    print(f"{BOLD}Eval samples:{RESET} {len(X_test)}")

    # Remove features that are all NaN in the training set; the median imputer
    # has nothing to learn from them.
    non_all_nan_cols = ~np.isnan(X_train).all(axis=0)
    X_train = X_train[:, non_all_nan_cols]
    X_test = X_test[:, non_all_nan_cols]

    print(f"{BOLD}Feature count:{RESET} {X_train.shape[1]}")

    # Fit the imputer on training data only so no evaluation statistics leak in.
    imputer = SimpleImputer(strategy="median")
    X_train = imputer.fit_transform(X_train)
    X_test = imputer.transform(X_test)

    # Early stopping needs a validation set, but the evaluation patients must
    # stay unseen until the final metrics; otherwise the number of boosting
    # rounds is tuned on the very data being reported. Hold out one training
    # patient for early stopping instead.
    train_patients = np.unique(g_train)
    if len(train_patients) > 1:
        held_out_patient = train_patients[-1]
        val_mask = g_train == held_out_patient
        print(f"{BOLD}Early-stop patient:{RESET} {held_out_patient}")
    else:
        # A single training patient cannot be split by patient, so train for
        # the full number of rounds without early stopping.
        val_mask = np.zeros(len(g_train), dtype=bool)

    # Train a LightGBM model on the current feature subset.
    train_data = lgb.Dataset(X_train[~val_mask], label=y_train[~val_mask])
    valid_sets = None
    callbacks = None
    if val_mask.any():
        valid_sets = [
            lgb.Dataset(
                X_train[val_mask],
                label=y_train[val_mask],
                reference=train_data,
            )
        ]
        callbacks = [lgb.early_stopping(stopping_rounds=20, verbose=False)]

    params = {
        "objective": "binary",
        "metric": "binary_logloss",
        "boosting_type": "gbdt",
        "learning_rate": 0.05,
        "num_leaves": 31,
        "feature_fraction": 0.9,
        "bagging_fraction": 0.9,
        "bagging_freq": 5,
        "verbose": -1,
        "seed": 42,
    }

    model = lgb.train(
        params,
        train_data,
        num_boost_round=200,
        valid_sets=valid_sets,
        callbacks=callbacks,
    )

    y_prob = model.predict(X_test)
    y_pred = (y_prob >= threshold).astype(int)

    # Report standard binary classification metrics. Accuracy and F1 use the
    # thresholded predictions; AUROC and AUPRC use the raw probabilities.
    print_metric("Accuracy", accuracy_score(y_test, y_pred))
    print_metric("F1", f1_score(y_test, y_pred))
    print_metric("AUROC", roc_auc_score(y_test, y_prob))
    print_metric("AUPRC", average_precision_score(y_test, y_prob))


def run_model_comparison(
    X: np.ndarray,
    y: np.ndarray,
    groups: np.ndarray,
    *,
    threshold: float = 0.3,
) -> None:
    """Runs a small model comparison on the full temporal feature set.

    Uses the same patient split and preprocessing as the ablation
    experiments (drop all-NaN columns, median imputation), then fits a
    logistic regression and a random forest and prints their metrics.

    Args:
        X: Full feature matrix.
        y: Binary label vector.
        groups: Patient identifier for each sample.
        threshold: Probability cut-off at or above which a sample is
            predicted as class 1. Only applied to models that expose
            ``predict_proba``.

    Raises:
        ValueError: If the train or evaluation split contains no samples
            (raised by ``split_samples_by_patient_ids``).
    """
    # Use the same predefined patient split to compare alternative models
    X_train, X_test, y_train, y_test, g_train, g_test = split_samples_by_patient_ids(
        X,
        y,
        groups,
    )

    print(format_section("Model Comparison: ALL modalities + temporal"))
    print(f"{BOLD}Train patients:{RESET} {format_patient_ids(g_train)}")
    print(f"{BOLD}Eval patients:{RESET} {format_patient_ids(g_test)}")

    # Drop features that are entirely NaN in the training set.
    non_all_nan_cols = ~np.isnan(X_train).all(axis=0)
    X_train = X_train[:, non_all_nan_cols]
    X_test = X_test[:, non_all_nan_cols]

    # Fit the imputer on training data only so no evaluation statistics leak in.
    imputer = SimpleImputer(strategy="median")
    X_train = imputer.fit_transform(X_train)
    X_test = imputer.transform(X_test)

    # Compare logistic regression and random forest on the full feature set.
    models = {
        "LogisticRegression": LogisticRegression(max_iter=1000),
        "RandomForest": RandomForestClassifier(
            n_estimators=200,
            random_state=42,
            n_jobs=-1,
        ),
    }

    for name, model in models.items():
        model.fit(X_train, y_train)

        if hasattr(model, "predict_proba"):
            y_score = model.predict_proba(X_test)[:, 1]
            y_pred = (y_score >= threshold).astype(int)
        else:
            # decision_function returns unbounded margins, not probabilities,
            # so a probability threshold does not apply. Keep the raw scores
            # for the ranking metrics and use the model's own decision rule
            # for the hard predictions.
            y_score = model.decision_function(X_test)
            y_pred = model.predict(X_test)

        print(f"\n{YELLOW}{name}{RESET}")
        print_metric("Accuracy", accuracy_score(y_test, y_pred))
        print_metric("F1", f1_score(y_test, y_pred))
        print_metric("AUROC", roc_auc_score(y_test, y_score))
        print_metric("AUPRC", average_precision_score(y_test, y_score))


def main() -> None:
    """Runs the DREAMT sleep-wake classification example workflow.

    When ``DREAMT_ROOT`` still holds its placeholder value the workflow runs
    on synthetic data; otherwise the selected DREAMT patients are loaded and
    converted to epoch-level samples. In both cases the modality ablation
    experiments and the model comparison are then executed.

    Raises:
        ValueError: If the selected DREAMT patients yield no epoch samples.
    """
    configure_clean_output()

    if DREAMT_ROOT == "REPLACE_WITH_DREAMT_ROOT":
        print(format_section("DREAMT Sleep-Wake Classification Example"))
        print("DREAMT_ROOT not set. Running the ablation workflow on synthetic data...")
        print(
            f"{YELLOW}Warning:{RESET} synthetic samples are randomly generated to "
            "make the example runnable without DREAMT. The resulting metrics are "
            "not realistic and should not be interpreted as evidence for the "
            "task or paper claims.\n"
        )
        print(f"{BOLD}Train patients:{RESET} {', '.join(TRAIN_PATIENT_IDS)}")
        print(f"{BOLD}Eval patients:{RESET} {', '.join(EVAL_PATIENT_IDS)}")

        X_all, y, groups = build_synthetic_benchmark_data()
        print(f"{BOLD}Total epoch samples:{RESET} {len(X_all)}")
        print(f"{BOLD}Label counts:{RESET} {summarize_label_counts(y)}")
    else:
        # Suppress verbose dataset initialization messages and print a cleaner summary.
        with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
            dataset = DREAMTDataset(root=DREAMT_ROOT)
            task = SleepWakeClassification(
                epoch_seconds=EPOCH_SECONDS,
                sampling_rate=SAMPLING_RATE,
            )

        print(format_section("DREAMT Sleep-Wake Classification Example"))
        print(f"{BOLD}Dataset root:{RESET} {DREAMT_ROOT}")
        print(f"{BOLD}Train patients:{RESET} {', '.join(TRAIN_PATIENT_IDS)}")
        print(f"{BOLD}Eval patients:{RESET} {', '.join(EVAL_PATIENT_IDS)}")

        # Convert the selected DREAMT patients into epoch-level sleep/wake samples.
        all_samples = []
        selected_patient_ids = TRAIN_PATIENT_IDS + EVAL_PATIENT_IDS
        for patient_id in selected_patient_ids:
            patient = dataset.get_patient(patient_id)
            samples = task(patient)
            print(f" patient {patient_id:<4} -> {len(samples)} epoch samples")
            all_samples.extend(samples)

        # Fail with a clear message instead of an IndexError on the empty
        # feature matrix further down.
        if not all_samples:
            raise ValueError(
                "No epoch samples were produced for the selected patients."
            )

        print(f"{BOLD}Total epoch samples:{RESET} {len(all_samples)}")
        print(
            f"{BOLD}Label counts:{RESET} "
            f"{summarize_label_counts(sample['label'] for sample in all_samples)}"
        )

        # Turn the task samples into arrays for training and evaluation.
        X_all = np.array([s["features"] for s in all_samples], dtype=float)
        y = np.array([s["label"] for s in all_samples], dtype=int)
        groups = np.array([s["patient_id"] for s in all_samples])

    # Both data sources report the feature matrix shape the same way.
    print(
        f"{BOLD}Feature matrix:{RESET} "
        f"{X_all.shape[0]} samples x {X_all.shape[1]} features"
    )

    # Keep only the base per-epoch features without temporal augmentation.
    num_base_features = 21
    X_base = X_all[:, :num_base_features]

    # Keep the full feature matrix, including temporal context features.
    X_temporal = X_all

    # Group feature indices by modality for the ablation experiments.
    acc_idx = list(range(0, 10))
    temp_idx = list(range(10, 14))
    bvp_idx = list(range(14, 17))
    eda_idx = list(range(17, 21))

    X_acc = X_base[:, acc_idx]
    X_acc_temp = X_base[:, acc_idx + temp_idx]
    X_acc_temp_bvp = X_base[:, acc_idx + temp_idx + bvp_idx]
    X_all_modalities = X_base[:, acc_idx + temp_idx + bvp_idx + eda_idx]

    # Run experiments using different feature groups.
    run_experiment(X_acc, y, groups, "ACC only")
    run_experiment(X_acc_temp, y, groups, "ACC + TEMP")
    run_experiment(X_acc_temp_bvp, y, groups, "ACC + TEMP + BVP")
    run_experiment(X_all_modalities, y, groups, "ALL modalities")
    run_experiment(X_temporal, y, groups, "ALL modalities + temporal")
    run_model_comparison(X_temporal, y, groups)


# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions pyhealth/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,5 @@
VariantClassificationClinVar,
)
from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task

from .sleep_wake_classification import SleepWakeClassification
Loading