Skip to content

Multi View Contrastive learning#1093

Open
coderookie1994 wants to merge 30 commits into
sunlabuiuc:masterfrom
coderookie1994:contrastive_learning
Open

Multi View Contrastive learning#1093
coderookie1994 wants to merge 30 commits into
sunlabuiuc:masterfrom
coderookie1994:contrastive_learning

Conversation

@coderookie1994
Copy link
Copy Markdown

@coderookie1994 coderookie1994 commented Apr 22, 2026

Authors: Heyan Gao (heyan3@illinois.edu), Sharthak Ghosh(sghos52@illinois.edu)
Paper: https://arxiv.org/pdf/2506.22393

Overview

This PR adds the following: -

  1. SleepEEG multi view contrastive learning task
  2. Multi view contrastive learning based on time series (paper implementation)
  3. Generic n-view multi view contrastive learning (extension)

Task

  1. The MVCLTrainingSleepEEG task returns sleep windows per patient based on the openly available SleepEDF dataset.
  2. The key properties returned are xt, xd, xf. These are the Time, Derivative and Frequency tensors respectively.

Examples

The example is here examples\mvcl_training_sleepedf.ipynb.

  1. A valid SleepEDF dataset path is necessary
  2. It runs the task associated with it and runs the generic multi-view model

Models

There are two models, MultiViewContrastiveModel which is the generic model that can work with n number of views and MultiViewContrastiveTimeSeriesModel that specifically implments the paper as mentioned above.

Associated Unit-Tests

Task tests here tests\core\test_mvcl_training_sleepedf_task.py
Model tests here tests\core\test_mvcl.py

Dataset sources

https://physionet.org/content/sleep-edfx/1.0.0/

Complete List of files added/modified

  1. docs/api/models/pyhealth.models.MVCL.rst
  2. docs/api/tasks/pyhealth.tasks.MVCLTrainingSleepEEG.rst
  3. docs/api/models.rst
  4. docs/api/tasks.rst
  5. examples/mvcl_training_sleepedf.ipynb
  6. pyhealth/datasets/sleepedf.py
  7. pyhealth/models/__init__.py
  8. pyhealth/models/multi_view_contrastive_time_series_model.py
  9. pyhealth/models/mvcl_model.py
  10. pyhealth/tasks/__init__.py
  11. pyhealth/tasks/mvcl_training_sleepedf_task.py
  12. tests/core/test_mvcl.py
  13. tests/core/test_mvcl_training_sleepedf_task.py

Bug Fix

In pyhealth/datasets/sleepedf.py the dev flag wasn't being propagated to the base class.

heyan3 and others added 26 commits April 5, 2026 16:29
…amend pooling to include residuals; 4. add MHA to finetune
small updated in task; plus unit test/exmaples/api docs
@coderookie1994 coderookie1994 changed the title Contrastive learning Multi View Contrastive learning Apr 22, 2026
@coderookie1994 coderookie1994 marked this pull request as draft April 22, 2026 09:52
@coderookie1994 coderookie1994 marked this pull request as ready for review April 22, 2026 14:16
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds Multi-View Contrastive Learning (MVCL) support to PyHealth by introducing a SleepEDF EEG windowing task that produces time/derivative/frequency tensor views, along with two MVCL model implementations (a generic n-view model and a time-series-specific variant), plus documentation, an example notebook, and unit tests. Also fixes SleepEDFDataset so the dev flag propagates to the base dataset.

Changes:

  • Add MVCLTrainingSleepEEG task for SleepEDF that returns xt/xd/xf per fixed-length window, plus helpers for converting/loading {samples, labels} .pt data.
  • Add MVCL models: MultiViewContrastiveModel (generic n-view) and MultiViewContrastiveTimeSeriesModel (paper-oriented time-series variant), and export them via pyhealth.models.
  • Add docs, example notebook, and unit tests; fix SleepEDF dataset dev propagation.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 9 comments.

Show a summary per file
File Description
docs/api/models/pyhealth.models.MVCL.rst Adds API docs stub for MultiViewContrastiveModel.
docs/api/tasks/pyhealth.tasks.MVCLTrainingSleepEEG.rst Adds API docs stub for MVCLTrainingSleepEEG.
docs/api/models.rst Registers MVCL docs page in model API reference.
docs/api/tasks.rst Registers MVCL SleepEDF task docs page in task API reference.
examples/mvcl_training_sleepedf.ipynb Provides an end-to-end example for dataset/task setup and MVCL pretraining.
pyhealth/datasets/sleepedf.py Fix: propagate dev flag to BaseDataset.
pyhealth/models/init.py Exposes new MVCL models via pyhealth.models.
pyhealth/models/multi_view_contrastive_time_series_model.py Implements time-series MVCL model (pretrain/finetune).
pyhealth/models/mvcl_model.py Implements generic n-view MVCL model (pretrain/finetune).
pyhealth/tasks/init.py Exposes MVCLTrainingSleepEEG and .pt helpers via pyhealth.tasks.
pyhealth/tasks/mvcl_training_sleepedf_task.py Adds SleepEDF MVCL task + NumPy preprocessing + .pt conversion/loading utilities.
tests/core/test_mvcl.py Adds unit tests for generic MVCL model forward paths.
tests/core/test_mvcl_training_sleepedf_task.py Adds task test using a mocked SleepEDF-like dummy dataset and patched MNE loaders.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +76 to +81
self.mode = ""

# Disable inference metrics during pre-training
if self.training_stage == "pretrain":
self.mode = None

):
super().__init__(dataset=dataset)
self.hidden_dim = 128
self.training_stage = training_stage
Comment on lines +442 to +446
"""Load one `.pt` file and return a PyHealth SampleDataset."""
try:
tensor_dict = torch.load(pt_path, map_location="cpu", weights_only=False)
except TypeError:
tensor_dict = torch.load(pt_path, map_location="cpu")
@@ -0,0 +1,7 @@
pyhealth.models.MultiViewContrastiveModel
===================================
Comment on lines +22 to +73
"execution_count": 1,
"id": "943a666a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2mUsing Python 3.13.7 environment at: C:\\Users\\shart\\workspace\\CS-598\\PyHealth\\.venv\u001b[0m\n",
"\u001b[2mChecked \u001b[1m1 package\u001b[0m \u001b[2min 182ms\u001b[0m\u001b[0m\n"
]
}
],
"source": [
"!uv pip install ipywidgets\n",
"\n",
"import os\n",
"\n",
"from pyhealth.datasets import SleepEDFDataset\n",
"from pyhealth.tasks import MVCLTrainingSleepEEG\n",
"\n",
"# Update this absolute path to your local Sleep-EDF root.\n",
"DATA_ROOT = \"C:\\\\Users\\\\shart\\\\workspace\\\\CS-598\\\\PyHealth\\\\sleepedf\"\n",
"assert os.path.exists(DATA_ROOT), f\"Sleep-EDF root path {DATA_ROOT} does not exist. Please update the path to your local Sleep-EDF root.\"\n",
"assert os.path.isabs(DATA_ROOT), f\"Sleep-EDF root path {DATA_ROOT} is not an absolute path. Please update to an absolute path.\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "866a8eb3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"No config path provided, using default config\n",
"Initializing sleepedf dataset from C:\\Users\\shart\\workspace\\CS-598\\PyHealth\\sleepedf (dev mode: True)\n",
"No cache_dir provided. Using default cache dir: C:\\Users\\shart\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\c8f0e13c-fb2e-5216-8969-8e6afcc7338c\n",
"Found cached event dataframe: C:\\Users\\shart\\AppData\\Local\\pyhealth\\pyhealth\\Cache\\c8f0e13c-fb2e-5216-8969-8e6afcc7338c\\global_event_df.parquet\n",
"Dataset: sleepedf\n",
"Dev mode: True\n",
"Number of patients: 78\n",
"Number of events: 153\n",
"Found 78 unique patient IDs\n",
"Number of patients: 78\n"
]
}
],
"source": [
"dataset = SleepEDFDataset(root=DATA_ROOT, subset=\"cassette\", dev=True)\n",
Comment on lines +25 to +33
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2mUsing Python 3.13.7 environment at: C:\\Users\\shart\\workspace\\CS-598\\PyHealth\\.venv\u001b[0m\n",
"\u001b[2mChecked \u001b[1m1 package\u001b[0m \u001b[2min 182ms\u001b[0m\u001b[0m\n"
]
}
Comment on lines +236 to +256
"# Factory functions and helpers required to setup the model.\n",
"import math\n",
"from pathlib import Path\n",
"from typing import Any, Dict, Dict, List, List, Mapping, Union\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"def augment_time(x: torch.Tensor, std: float = 0.1) -> torch.Tensor:\n",
" \"\"\"Time-domain jitter augmentation\"\"\"\n",
" noise = torch.randn_like(x) * std\n",
" return x + noise\n",
" \n",
"def augment_freq(sample: torch.Tensor, pertub_ratio: float = 0.05) -> torch.Tensor:\n",
" \"\"\"Frequency-domain augmentation (remove and add frequencies)\"\"\"\n",
" aug_1 = remove_frequency(sample, pertub_ratio)\n",
" aug_2 = add_frequency(sample, pertub_ratio)\n",
" return aug_1 + aug_2\n",
"\n",
"def remove_frequency(x: torch.Tensor, pertub_ratio: float = 0.0) -> torch.Tensor:\n",
" mask = torch.rand(x.shape, device=x.device) > pertub_ratio\n",
Comment on lines +21 to +32
import torch.fft as fft

from pyhealth.datasets.sample_dataset import create_sample_dataset
from pyhealth.tasks import BaseTask



def _map_to_MVCL_five_class(pyhealth_stage: int) -> int:
"""Map PyHealth 6-class staging to 5-class AASM-style (N3+N4 → deep)."""
return (0, 1, 2, 3, 3, 4)[int(pyhealth_stage)]


Comment on lines +73 to +82
def __init__(
self,
chunk_duration: float = 30.0,
window_size: int = 200,
crop_length: Optional[int] = 178,
eeg_channel: Optional[str] = "EEG Fpz-Cz",
time_as_feature: bool = False,
dx_backend: str = "cde",
root_path: Optional[Union[str, Path]] = None,
) -> None:
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants