Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
644 changes: 644 additions & 0 deletions examples/corgan_mimic3_colab.ipynb

Large diffs are not rendered by default.

34 changes: 34 additions & 0 deletions examples/corgan_mimic3_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Train CorGAN on MIMIC-III diagnosis codes and save a checkpoint."""

# 1. Load MIMIC-III dataset
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.datasets import split_by_patient
from pyhealth.tasks import corgan_generation_mimic3_fn
from pyhealth.models.generators.corgan import CorGAN

base_dataset = MIMIC3Dataset(
root="/path/to/mimic3",
tables=["diagnoses_icd"],
)

# 2. Apply generation task — flattens all ICD codes per patient into a bag-of-codes
sample_dataset = base_dataset.set_task(corgan_generation_mimic3_fn)
print(f"{len(sample_dataset)} patients after filtering")

# 3. Patient-level split — required for generative models to prevent data leakage across splits
train_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])

# 4. Instantiate and train — reduce epochs for testing; 50+ recommended for quality synthetic data
model = CorGAN(
dataset=sample_dataset,
latent_dim=128,
hidden_dim=128,
batch_size=128,
epochs=50,
lr=1e-4,
save_dir="./corgan_checkpoints/",
)
model.train_model(train_dataset, val_dataset)

# 5. Checkpoint is saved automatically to save_dir by train_model
print("Training complete. Checkpoint saved to ./corgan_checkpoints/")
42 changes: 42 additions & 0 deletions examples/generate_synthetic_mimic3_corgan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Generate synthetic MIMIC-III patient records using a trained CorGAN checkpoint."""
import json

from pyhealth.datasets import MIMIC3Dataset
from pyhealth.tasks import corgan_generation_mimic3_fn
from pyhealth.models.generators.corgan import CorGAN

# Update this to your local MIMIC-III path before running
MIMIC3_ROOT = "/path/to/mimic3"

# 1. Reconstruct dataset — required to initialise CorGAN's vocabulary from the processor.
# If you trained with different tables=, update this to match exactly.
base_dataset = MIMIC3Dataset(
root=MIMIC3_ROOT,
tables=["diagnoses_icd"],
)
sample_dataset = base_dataset.set_task(corgan_generation_mimic3_fn)

# 2. Instantiate model (epochs and training params are unused during generation;
# they must match your training configuration for checkpoint compatibility).
model = CorGAN(
dataset=sample_dataset,
latent_dim=128,
hidden_dim=128,
batch_size=128,
epochs=50,
save_dir="./corgan_checkpoints/",
)

# 3. Load trained checkpoint
model.load_model("./corgan_checkpoints/corgan_final.pt")

# 4. Generate synthetic patients — each patient is a flat bag-of-codes (no visit structure)
synthetic = model.synthesize_dataset(num_samples=10000)
print(f"Generated {len(synthetic)} synthetic patients")
print(f"Example record: {synthetic[0]}")

# 5. Save to JSON
output_path = "synthetic_corgan_10k.json"
with open(output_path, "w") as f:
json.dump(synthetic, f, indent=2)
print(f"Saved to {output_path}")
1 change: 1 addition & 0 deletions pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@
from .sdoh import SdohClassifier
from .medlink import MedLink
from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding
from .generators import CorGAN
3 changes: 3 additions & 0 deletions pyhealth/models/generators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .corgan import CorGAN

__all__ = ["CorGAN"]
Loading