Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -204,5 +204,6 @@ API Reference
models/pyhealth.models.VisionEmbeddingModel
models/pyhealth.models.TextEmbedding
models/pyhealth.models.BIOT
models/pyhealth.models.cbramod
models/pyhealth.models.unified_multimodal_embedding_docs
models/pyhealth.models.califorest
67 changes: 67 additions & 0 deletions docs/api/models/pyhealth.models.cbramod.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
pyhealth.models.CBraMod_Wrapper
===================================

CBraMod model for EEG signal classification.

Overview
--------

CBraMod is a criss-cross attention transformer tailored for EEG decoding. The
wrapper integrates the model into the PyHealth ``BaseModel`` pipeline so it can
be trained with the standard ``Trainer`` APIs.

Input/Output
------------

- **Input:** ``signal`` tensor shaped ``(batch, channels, timesteps)`` where
``timesteps`` is a multiple of 200 (the patch size used by CBraMod).
- **Output (classifier_head=True):** dict with ``loss``, ``y_prob``, ``y_true``,
``logit``, and ``embeddings``.
- **Output (classifier_head=False):** dict with ``logit`` and ``embeddings``.

Example Usage
-------------

.. code-block:: python

import torch
from pyhealth.datasets import create_sample_dataset, get_dataloader
from pyhealth.models import CBraMod_Wrapper

# Synthetic signal dimensions: n_samples is an exact multiple of patch_size,
# i.e. each channel holds n_patches patches of patch_size samples.
n_channels = 16
patch_size = 200
n_patches = 10
n_samples = patch_size * n_patches

# Eight random samples, each with an (n_channels, n_samples) signal stored as
# nested lists; labels cycle through 0..5 (i % 6), so all six classes appear.
samples = [
{
"patient_id": f"patient-{i}",
"visit_id": "visit-0",
"signal": torch.randn(n_channels, n_samples).numpy().tolist(),
"label": i % 6,
}
for i in range(8)
]

# Wrap the samples in a dataset whose "signal" input is declared as a tensor
# and whose "label" output is declared as multiclass.
dataset = create_sample_dataset(
samples=samples,
input_schema={"signal": "tensor"},
output_schema={"label": "multiclass"},
dataset_name="test_cbramod",
)

# seq_len is set to the number of patches per channel; n_classes matches the
# six distinct label values generated above.
model = CBraMod_Wrapper(
dataset=dataset,
seq_len=n_patches,
n_classes=6,
classifier_head=True,
)

# Run one forward pass on a single batch of two samples and print the shape
# of the returned logits.
batch = next(iter(get_dataloader(dataset, batch_size=2, shuffle=True)))
output = model(**batch)
print(output["logit"].shape)

.. autoclass:: pyhealth.models.CBraMod_Wrapper
:members:
:undoc-members:
:show-inheritance:
4 changes: 3 additions & 1 deletion docs/tutorials.rst
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ EEG and Sleep Analysis
* - ``EEG_events_SparcNet.py``
- SparcNet for EEG event detection
* - ``EEG_isAbnormal_SparcNet.py``
- SparcNet for EEG abnormality detection
- SparcNet for EEG abnormality detection
* - ``CBraMod_tuab_eeg_abnormal_classification.py``
- CBraMod for EEG abnormality detection on TUAB
* - ``cardiology_detection_isAR_SparcNet.py``
- SparcNet for cardiology arrhythmia detection

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Train and evaluate CBraMod on the TUAB abnormal-EEG classification task.

Pipeline:
    1. Load the TUAB EDF recordings (dev subset).
    2. Build the ``EEGAbnormalTUAB`` task with raw signals (no STFT).
    3. Split by visit, wrap the splits in dataloaders.
    4. Train ``CBraMod_Wrapper`` with the PyHealth ``Trainer`` and evaluate it.

Adjust ``root`` below to the location of the TUAB corpus on your machine.
"""

import torch

from pyhealth.datasets import TUABDataset, split_by_visit, get_dataloader
from pyhealth.tasks import EEGAbnormalTUAB
from pyhealth.models import CBraMod_Wrapper
from pyhealth.trainer import Trainer

# Number of time steps per patch; the model's ``seq_len`` is expressed in
# patches, so the signal length must be an exact multiple of this value.
PATCH_SIZE = 200

# step 1: load signal data
dataset = TUABDataset(
    root="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf/",
    dev=True,
    refresh_cache=True,
)
print(dataset.stats())

# step 2: set task (disable STFT for CBraMod)
TUAB_ds = dataset.set_task(
    EEGAbnormalTUAB(
        resample_rate=200,
        bandpass_filter=(0.1, 75.0),
        notch_filter=50.0,
        compute_stft=False,
    )
)

print(f"Total task samples: {len(TUAB_ds)}")
print(f"Input schema: {TUAB_ds.input_schema}")
print(f"Output schema: {TUAB_ds.output_schema}")

# Inspect a sample to infer sequence length
sample = TUAB_ds[0]
print(f"\nSample keys: {sample.keys()}")
print(f"Signal shape: {sample['signal'].shape}")
print(f"Label: {sample['label']}")

# Fail loudly instead of letting floor division silently drop trailing
# samples, which would hand the model a seq_len that does not match the data.
n_timesteps = sample["signal"].shape[-1]
if n_timesteps == 0 or n_timesteps % PATCH_SIZE != 0:
    raise ValueError(
        f"Signal length {n_timesteps} is not a positive multiple of the "
        f"patch size {PATCH_SIZE}; check the task's resample rate/windowing."
    )
seq_len = n_timesteps // PATCH_SIZE

# split dataset
train_dataset, val_dataset, test_dataset = split_by_visit(
    TUAB_ds, [0.6, 0.2, 0.2]
)
train_dataloader = get_dataloader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=16, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=16, shuffle=False)
# These are sample counts of the splits, not numbers of batches.
print(
    "dataset size: train/val/test",
    len(train_dataset),
    len(val_dataset),
    len(test_dataset),
)

# step 3: define model
model = CBraMod_Wrapper(
    dataset=TUAB_ds,
    seq_len=seq_len,
    n_classes=2,
    classifier_head=True,
)

# step 4: define trainer
# Use the first GPU when one is present; otherwise fall back to CPU so the
# example still runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
trainer = Trainer(model=model, device=device)
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=10,
    optimizer_params={"lr": 1e-4},
)

# step 5: evaluate
print(trainer.evaluate(test_dataloader))
3 changes: 2 additions & 1 deletion pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .base_model import BaseModel
from .transformer_deid import TransformerDeID
from .biot import BIOT
from .cbramod import CBraMod_Wrapper
from .cnn import CNN, CNNLayer
from .concare import ConCare, ConCareLayer
from .contrawr import ContraWR, ResBlock2D
Expand Down Expand Up @@ -45,4 +46,4 @@
from .sdoh import SdohClassifier
from .medlink import MedLink
from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding
from .califorest import CaliForest
from .califorest import CaliForest
Loading