From 090caa4b85e15af6a2fb55258c19e8607afeeaf0 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 25 Feb 2026 02:12:43 -0600 Subject: [PATCH 01/13] Add CorGAN baseline from corgan-medgan-port (pre-integration) --- examples/generate_synthetic_mimic3_corgan.py | 161 ++++ pyhealth/models/__init__.py | 1 + pyhealth/models/generators/__init__.py | 4 + pyhealth/models/generators/corgan.py | 851 +++++++++++++++++++ 4 files changed, 1017 insertions(+) create mode 100644 examples/generate_synthetic_mimic3_corgan.py create mode 100644 pyhealth/models/generators/__init__.py create mode 100644 pyhealth/models/generators/corgan.py diff --git a/examples/generate_synthetic_mimic3_corgan.py b/examples/generate_synthetic_mimic3_corgan.py new file mode 100644 index 000000000..a3d8538fa --- /dev/null +++ b/examples/generate_synthetic_mimic3_corgan.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +""" +Generate synthetic MIMIC-III patients using a trained CorGAN checkpoint. +Uses Variable top-K sampling to maintain natural variation in code counts. 
+""" + +import os +import sys +sys.path.insert(0, '/u/jalenj4/PyHealth-Medgan-Corgan-Port') +import argparse +import torch +import numpy as np +import pandas as pd +from pyhealth.models.generators.corgan import CorGANAutoencoder, CorGAN8LayerAutoencoder, CorGANGenerator, CorGANDiscriminator + + +def main(): + parser = argparse.ArgumentParser(description="Generate synthetic patients using trained CorGAN") + parser.add_argument("--checkpoint", required=True, help="Path to trained CorGAN checkpoint (.pth)") + parser.add_argument("--vocab", required=True, help="Path to ICD-9 vocabulary file (.txt)") + parser.add_argument("--binary_matrix", required=True, help="Path to training binary matrix (.npy)") + parser.add_argument("--output", required=True, help="Path to output CSV file") + parser.add_argument("--n_samples", type=int, default=10000, help="Number of synthetic patients to generate") + parser.add_argument("--mean_k", type=float, default=13, help="Mean K for Variable top-K sampling") + parser.add_argument("--std_k", type=float, default=5, help="Std dev K for Variable top-K sampling") + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Load vocabulary + print(f"Loading vocabulary from {args.vocab}") + with open(args.vocab, 'r') as f: + code_vocab = [line.strip() for line in f] + print(f"Loaded {len(code_vocab)} ICD-9 codes") + + # Load binary matrix to get architecture dimensions + print(f"Loading binary matrix from {args.binary_matrix}") + binary_matrix = np.load(args.binary_matrix) + n_codes = binary_matrix.shape[1] + print(f"Binary matrix shape: {binary_matrix.shape}") + print(f"Real data avg codes/patient: {binary_matrix.sum(axis=1).mean():.2f}") + + # Load checkpoint + print(f"\nLoading checkpoint from {args.checkpoint}") + checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False) + + # Detect architecture from checkpoint + # Check if this is 
8-layer by looking at state dict keys + state_keys = checkpoint['autoencoder_state_dict'].keys() + is_8layer = any('encoder.18' in k or 'encoder.21' in k for k in state_keys) # 8-layer has more layers + + # Initialize CorGAN components with correct architecture + print("Initializing CorGAN model components...") + if n_codes == 6955: + if is_8layer: + autoencoder = CorGAN8LayerAutoencoder(feature_size=n_codes).to(device) + print("Detected 8-layer architecture") + else: + # Assume adaptive pooling + autoencoder = CorGANAutoencoder( + feature_size=n_codes, + use_adaptive_pooling=True + ).to(device) + print("Detected 6-layer + adaptive pooling architecture") + else: + autoencoder = CorGANAutoencoder( + feature_size=n_codes, + use_adaptive_pooling=False + ).to(device) + print(f"Using standard 6-layer architecture for {n_codes} codes") + + generator = CorGANGenerator(latent_dim=128, hidden_dim=128).to(device) + discriminator = CorGANDiscriminator(input_dim=n_codes, hidden_dim=256).to(device) + + # Load trained weights + autoencoder.load_state_dict(checkpoint['autoencoder_state_dict']) + generator.load_state_dict(checkpoint['generator_state_dict']) + discriminator.load_state_dict(checkpoint['discriminator_state_dict']) + + autoencoder.eval() + generator.eval() + discriminator.eval() + print("Model loaded successfully") + + # Generate synthetic patients + print(f"\nGenerating {args.n_samples} synthetic patients...") + + with torch.no_grad(): + # Generate random noise + z = torch.randn(args.n_samples, 128, device=device) + # Generate latent codes + generated_latent = generator(z) + # Decode to probabilities + synthetic_probs = autoencoder.decode(generated_latent) + + # Trim or pad if needed + if synthetic_probs.shape[1] > n_codes: + synthetic_probs = synthetic_probs[:, :n_codes] + elif synthetic_probs.shape[1] < n_codes: + padding = torch.zeros(synthetic_probs.shape[0], n_codes - synthetic_probs.shape[1], device=device) + synthetic_probs = torch.cat([synthetic_probs, 
padding], dim=1) + + probs = synthetic_probs.cpu().numpy() + + # Apply Variable top-K sampling + print(f"Applying Variable top-K sampling (μ={args.mean_k}, σ={args.std_k})...") + binary_matrix_synthetic = np.zeros_like(probs) + + for i in range(args.n_samples): + # Sample K from normal distribution, clip to reasonable range + k = int(np.clip(np.random.normal(args.mean_k, args.std_k), 1, 50)) + # Get indices of top-K probabilities + top_k_indices = np.argsort(probs[i])[-k:] + binary_matrix_synthetic[i, top_k_indices] = 1 + + # Calculate statistics + avg_codes = binary_matrix_synthetic.sum(axis=1).mean() + std_codes = binary_matrix_synthetic.sum(axis=1).std() + min_codes = binary_matrix_synthetic.sum(axis=1).min() + max_codes = binary_matrix_synthetic.sum(axis=1).max() + sparsity = (binary_matrix_synthetic == 0).mean() + + print(f"\nSynthetic data statistics:") + print(f" Avg codes per patient: {avg_codes:.2f} ± {std_codes:.2f}") + print(f" Range: [{min_codes:.0f}, {max_codes:.0f}]") + print(f" Sparsity: {sparsity:.4f}") + + # Check heterogeneity + unique_profiles = len(set(tuple(row) for row in binary_matrix_synthetic)) + print(f" Unique patient profiles: {unique_profiles}/{args.n_samples} ({unique_profiles/args.n_samples*100:.1f}%)") + + # Convert to CSV format (SUBJECT_ID, ICD9_CODE) + print(f"\nConverting to CSV format...") + records = [] + for patient_idx in range(args.n_samples): + patient_id = f"SYNTHETIC_{patient_idx+1:06d}" + code_indices = np.where(binary_matrix_synthetic[patient_idx] == 1)[0] + + for code_idx in code_indices: + records.append({ + 'SUBJECT_ID': patient_id, + 'ICD9_CODE': code_vocab[code_idx] + }) + + df = pd.DataFrame(records) + print(f"Created {len(df)} diagnosis records for {args.n_samples} patients") + + # Save to CSV + print(f"\nSaving to {args.output}") + df.to_csv(args.output, index=False) + + file_size_mb = os.path.getsize(args.output) / (1024 * 1024) + print(f"Saved {file_size_mb:.1f} MB") + + print("\n✓ Generation complete!") + 
print(f"Output: {args.output}") + + +if __name__ == '__main__': + main() diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index a13b18a51..834159181 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -43,3 +43,4 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding +from .generators import CorGAN diff --git a/pyhealth/models/generators/__init__.py b/pyhealth/models/generators/__init__.py new file mode 100644 index 000000000..e27bb8993 --- /dev/null +++ b/pyhealth/models/generators/__init__.py @@ -0,0 +1,4 @@ +from .corgan import CorGAN +from .plasmode import PlasMode + +__all__ = ["CorGAN", "PlasMode"] diff --git a/pyhealth/models/generators/corgan.py b/pyhealth/models/generators/corgan.py new file mode 100644 index 000000000..8acfe5297 --- /dev/null +++ b/pyhealth/models/generators/corgan.py @@ -0,0 +1,851 @@ +import functools +from typing import Dict, List, Optional, Tuple, Union +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset +import random +import time + +from pyhealth.datasets import BaseDataset +from pyhealth.models import BaseModel +from pyhealth.tokenizer import Tokenizer + + +class CorGANDataset(Dataset): + """Dataset wrapper for CorGAN training""" + + def __init__(self, data, transform=None): + self.transform = transform + self.data = data.astype(np.float32) + self.sampleSize = data.shape[0] + self.featureSize = data.shape[1] + + def return_data(self): + return self.data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + + sample = self.data[idx] + sample = np.clip(sample, 0, 1) + + if self.transform: + pass + + return torch.from_numpy(sample) + + +class CorGANAutoencoder(nn.Module): + """Autoencoder for CorGAN - uses 1D convolutions to 
capture correlations""" + + def __init__(self, feature_size: int, latent_dim: int = 128, use_adaptive_pooling: bool = False): + super(CorGANAutoencoder, self).__init__() + self.feature_size = feature_size + self.latent_dim = latent_dim + self.use_adaptive_pooling = use_adaptive_pooling + n_channels_base = 4 + + # calculate the size after convolutions + # input: (batch, 1, feature_size) + # conv1: kernel=5, stride=2 -> (batch, 4, (feature_size-4)//2) + # conv2: kernel=5, stride=2 -> (batch, 8, ((feature_size-4)//2-4)//2) + # conv3: kernel=5, stride=3 -> (batch, 16, (((feature_size-4)//2-4)//2-4)//3) + # conv4: kernel=5, stride=3 -> (batch, 32, ((((feature_size-4)//2-4)//2-4)//3-4)//3) + # conv5: kernel=5, stride=3 -> (batch, 64, (((((feature_size-4)//2-4)//2-4)//3-4)//3-4)//3) + # conv6: kernel=8, stride=1 -> (batch, 128, ((((((feature_size-4)//2-4)//2-4)//3-4)//3-4)//3-7)) + + # rough estimate for latent size + latent_size = max(1, feature_size // 100) # ensure at least 1 + + self.encoder = nn.Sequential( + nn.Conv1d(in_channels=1, out_channels=n_channels_base, kernel_size=5, stride=2, padding=0, dilation=1, + groups=1, bias=True, padding_mode='zeros'), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(in_channels=n_channels_base, out_channels=2 * n_channels_base, kernel_size=5, stride=2, padding=0, + dilation=1, groups=1, bias=True, padding_mode='zeros'), + nn.BatchNorm1d(2 * n_channels_base), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(in_channels=2 * n_channels_base, out_channels=4 * n_channels_base, kernel_size=5, stride=3, + padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'), + nn.BatchNorm1d(4 * n_channels_base), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(in_channels=4 * n_channels_base, out_channels=8 * n_channels_base, kernel_size=5, stride=3, + padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'), + nn.BatchNorm1d(8 * n_channels_base), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(in_channels=8 * n_channels_base, 
out_channels=16 * n_channels_base, kernel_size=5, stride=3, + padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'), + nn.BatchNorm1d(16 * n_channels_base), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(in_channels=16 * n_channels_base, out_channels=32 * n_channels_base, kernel_size=8, stride=1, + padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'), + nn.Tanh(), + ) + + # decoder - exact match to synthEHRella (wgancnnmimic.py lines 200-228) + # Kernel sizes: [5, 5, 7, 7, 7, 3] + # Strides: [1, 4, 4, 3, 2, 2] + # Activations: ReLU (not LeakyReLU) + # Note: First layer has NO BatchNorm + decoder_layers = [ + nn.ConvTranspose1d(in_channels=32 * n_channels_base, out_channels=16 * n_channels_base, kernel_size=5, stride=1, + padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'), + nn.ReLU(), + nn.ConvTranspose1d(in_channels=16 * n_channels_base, out_channels=8 * n_channels_base, kernel_size=5, stride=4, + padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'), + nn.BatchNorm1d(8 * n_channels_base), + nn.ReLU(), + nn.ConvTranspose1d(in_channels=8 * n_channels_base, out_channels=4 * n_channels_base, kernel_size=7, stride=4, + padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'), + nn.BatchNorm1d(4 * n_channels_base), + nn.ReLU(), + nn.ConvTranspose1d(in_channels=4 * n_channels_base, out_channels=2 * n_channels_base, kernel_size=7, stride=3, + padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'), + nn.BatchNorm1d(2 * n_channels_base), + nn.ReLU(), + nn.ConvTranspose1d(in_channels=2 * n_channels_base, out_channels=n_channels_base, kernel_size=7, stride=2, + padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'), + nn.BatchNorm1d(n_channels_base), + nn.ReLU(), + nn.ConvTranspose1d(in_channels=n_channels_base, out_channels=1, kernel_size=3, stride=2, + padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'), + ] + + # Add adaptive pooling if enabled (for variable vocabulary sizes) 
+ if self.use_adaptive_pooling: + decoder_layers.append(nn.AdaptiveAvgPool1d(output_size=feature_size)) + + decoder_layers.append(nn.Sigmoid()) + + self.decoder = nn.Sequential(*decoder_layers) + + def forward(self, x): + # add channel dimension if needed + if len(x.shape) == 2: + x = x.unsqueeze(1) # (batch, 1, features) + encoded = self.encoder(x) + decoded = self.decoder(encoded) + # Squeeze only the channel dimension (dim=1), not the batch dimension + if decoded.dim() == 3 and decoded.shape[1] == 1: + decoded = decoded.squeeze(1) + return decoded + + def decode(self, x): + # x shape: (batch, 128) from generator - unsqueeze for CNN decoder + if x.dim() == 2: + x = x.unsqueeze(2) # (batch, 128, 1) + decoded = self.decoder(x) # (batch, 1, output_len) + if decoded.dim() == 3 and decoded.shape[1] == 1: + decoded = decoded.squeeze(1) # (batch, output_len) + return decoded + + +class CorGAN8LayerAutoencoder(nn.Module): + """ + 8-Layer CNN Autoencoder for CorGAN - designed for 6,955 codes. + + Extends the original 6-layer architecture to support larger vocabularies + without adaptive pooling. The encoder compresses 6,955 codes down to a + latent space of size (128, 1), then the decoder reconstructs exactly 6,955. + + This is an experimental architecture designed to test whether native + dimension matching (no adaptive pooling) produces better synthetic data + quality compared to the 6-layer + adaptive pooling approach. 
+ + Args: + feature_size: Must be 6955 (architecture is hardcoded for this size) + latent_dim: Latent dimension (default: 128) + """ + + def __init__(self, feature_size: int = 6955, latent_dim: int = 128): + super(CorGAN8LayerAutoencoder, self).__init__() + assert feature_size == 6955, "8-layer architecture only supports 6955 codes" + + self.feature_size = feature_size + self.latent_dim = latent_dim + + # Encoder: 6955 → 1 (8 layers) + self.encoder = nn.Sequential( + # Layer 1: 6955 → 3476 + nn.Conv1d(1, 4, kernel_size=5, stride=2, padding=0), + nn.LeakyReLU(0.2, inplace=True), + + # Layer 2: 3476 → 1736 + nn.Conv1d(4, 8, kernel_size=5, stride=2, padding=0), + nn.BatchNorm1d(8), + nn.LeakyReLU(0.2, inplace=True), + + # Layer 3: 1736 → 578 + nn.Conv1d(8, 16, kernel_size=5, stride=3, padding=0), + nn.BatchNorm1d(16), + nn.LeakyReLU(0.2, inplace=True), + + # Layer 4: 578 → 192 + nn.Conv1d(16, 32, kernel_size=5, stride=3, padding=0), + nn.BatchNorm1d(32), + nn.LeakyReLU(0.2, inplace=True), + + # Layer 5: 192 → 63 + nn.Conv1d(32, 64, kernel_size=5, stride=3, padding=0), + nn.BatchNorm1d(64), + nn.LeakyReLU(0.2, inplace=True), + + # Layer 6: 63 → 20 [NEW] + nn.Conv1d(64, 96, kernel_size=5, stride=3, padding=0), + nn.BatchNorm1d(96), + nn.LeakyReLU(0.2, inplace=True), + + # Layer 7: 20 → 4 [NEW] + nn.Conv1d(96, 112, kernel_size=5, stride=4, padding=0), + nn.BatchNorm1d(112), + nn.LeakyReLU(0.2, inplace=True), + + # Layer 8: 4 → 1 [NEW] + nn.Conv1d(112, 128, kernel_size=4, stride=1, padding=0), + nn.Tanh(), + ) + + # Decoder: 1 → 6955 (8 layers) + self.decoder = nn.Sequential( + # Layer 1: 1 → 4 (NO BatchNorm on first layer) + nn.ConvTranspose1d(128, 112, kernel_size=4, stride=1, padding=0), + nn.ReLU(), + + # Layer 2: 4 → 20 + nn.ConvTranspose1d(112, 96, kernel_size=8, stride=4, padding=0), + nn.BatchNorm1d(96), + nn.ReLU(), + + # Layer 3: 20 → 63 + nn.ConvTranspose1d(96, 64, kernel_size=6, stride=3, padding=0), + nn.BatchNorm1d(64), + nn.ReLU(), + + # Layer 4: 63 → 192 + 
nn.ConvTranspose1d(64, 32, kernel_size=6, stride=3, padding=0), + nn.BatchNorm1d(32), + nn.ReLU(), + + # Layer 5: 192 → 578 + nn.ConvTranspose1d(32, 16, kernel_size=5, stride=3, padding=0), + nn.BatchNorm1d(16), + nn.ReLU(), + + # Layer 6: 578 → 1736 + nn.ConvTranspose1d(16, 8, kernel_size=5, stride=3, padding=0), + nn.BatchNorm1d(8), + nn.ReLU(), + + # Layer 7: 1736 → 3476 + nn.ConvTranspose1d(8, 4, kernel_size=6, stride=2, padding=0), + nn.BatchNorm1d(4), + nn.ReLU(), + + # Layer 8: 3476 → 6955 + nn.ConvTranspose1d(4, 1, kernel_size=5, stride=2, padding=0), + nn.Sigmoid(), + ) + + def forward(self, x): + if len(x.shape) == 2: + x = x.unsqueeze(1) # (batch, 1, features) + encoded = self.encoder(x) + decoded = self.decoder(encoded) + # Squeeze only channel dimension + if decoded.dim() == 3 and decoded.shape[1] == 1: + decoded = decoded.squeeze(1) + return decoded + + def decode(self, x): + """Decode latent representation from generator.""" + if x.dim() == 2: + x = x.unsqueeze(2) # (batch, 128, 1) + decoded = self.decoder(x) + if decoded.dim() == 3 and decoded.shape[1] == 1: + decoded = decoded.squeeze(1) + return decoded + + +class CorGANLinearAutoencoder(nn.Module): + """ + Linear autoencoder for CorGAN - simpler than CNN, appropriate for unordered codes. + + This variant replaces the CNN autoencoder with a simple linear architecture, + which is more appropriate for unordered medical codes (ICD-9) where spatial + locality doesn't exist. Based on: + - SynthEHRella's commented linear decoder alternative (line 229 in wgancnnmimic.py) + - MedGAN's proven linear architecture (achieves 10.66 codes/patient) + - Simpler gradient flow to address mode collapse + + The core CorGAN components are preserved: + - WGAN training with Wasserstein loss + - Generator with residual connections + - Discriminator with minibatch averaging + + This architecture is referred to as "CorGAN-Linear" to distinguish it from + the original CNN-based CorGAN while maintaining the core WGAN design. 
+ """ + + def __init__(self, feature_size: int, latent_dim: int = 128): + super(CorGANLinearAutoencoder, self).__init__() + self.feature_size = feature_size + self.latent_dim = latent_dim + + # Encoder: feature_size → latent_dim + # Use ReLU+BatchNorm (V11 achieved 4.49 codes, best linear result) + self.encoder = nn.Sequential( + nn.Linear(feature_size, latent_dim), + nn.ReLU(), + nn.BatchNorm1d(latent_dim) + ) + + # Decoder: latent_dim → feature_size + self.decoder = nn.Sequential( + nn.Linear(latent_dim, feature_size), + nn.Sigmoid() + ) + + def forward(self, x): + """ + Forward pass for autoencoder training. + + Args: + x: Input tensor of shape (batch, feature_size) + + Returns: + Decoded tensor of shape (batch, feature_size) + """ + encoded = self.encoder(x) + decoded = self.decoder(encoded) + return decoded + + def decode(self, x): + """ + Decode latent representation from generator. + + Args: + x: Latent tensor from generator of shape (batch, latent_dim) + + Returns: + Decoded tensor of shape (batch, feature_size) + """ + return self.decoder(x) + + +class CorGANGenerator(nn.Module): + """ + Generator for CorGAN - MLP with residual connections + + Architecture matches synthEHRella exactly (wgancnnmimic.py lines 242-263) + """ + + def __init__(self, latent_dim: int = 128, hidden_dim: int = 128): + super(CorGANGenerator, self).__init__() + self.latent_dim = latent_dim + self.hidden_dim = hidden_dim + + # Layer 1 + self.linear1 = nn.Linear(latent_dim, hidden_dim) + self.bn1 = nn.BatchNorm1d(hidden_dim, eps=0.001, momentum=0.01) + self.activation1 = nn.ReLU() + + # Layer 2 + self.linear2 = nn.Linear(hidden_dim, hidden_dim) + self.bn2 = nn.BatchNorm1d(hidden_dim, eps=0.001, momentum=0.01) + self.activation2 = nn.Tanh() + + def forward(self, x): + # Layer 1 with residual connection + residual = x + temp = self.activation1(self.bn1(self.linear1(x))) + out1 = temp + residual + + # Layer 2 with residual connection + residual = out1 + temp = 
self.activation2(self.bn2(self.linear2(out1))) + out2 = temp + residual + + return out2 + + +class CorGANDiscriminator(nn.Module): + """ + Discriminator for CorGAN - MLP with minibatch averaging + + Architecture matches synthEHRella exactly (wgancnnmimic.py lines 265-296): + - 4 linear layers: input → 256 → 256 → 256 → 1 + - ReLU activations + - No sigmoid (WGAN uses unbounded critic outputs) + """ + + def __init__(self, input_dim: int, hidden_dim: int = 256, minibatch_averaging: bool = True): + super(CorGANDiscriminator, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.minibatch_averaging = minibatch_averaging + + # adjust input dimension for minibatch averaging + ma_coef = 1 + if minibatch_averaging: + ma_coef = ma_coef * 2 + model_input_dim = ma_coef * input_dim + + # 4-layer architecture matching synthEHRella exactly + self.model = nn.Sequential( + nn.Linear(model_input_dim, self.hidden_dim), + nn.ReLU(True), + nn.Linear(self.hidden_dim, int(self.hidden_dim)), + nn.ReLU(True), + nn.Linear(self.hidden_dim, int(self.hidden_dim)), + nn.ReLU(True), + nn.Linear(int(self.hidden_dim), 1) + # No sigmoid - WGAN uses unbounded critic outputs + ) + + def forward(self, x): + if self.minibatch_averaging: + # minibatch averaging: concatenate batch mean to each sample + x_mean = torch.mean(x, dim=0).repeat(x.shape[0], 1) + x = torch.cat((x, x_mean), dim=1) + + output = self.model(x) + return output + + +def weights_init(m): + """ + Custom weight initialization (synthEHRella implementation) + + Reference: synthEHRella wgancnnmimic.py lines 363-377 + """ + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + if type(m) == nn.Linear: + torch.nn.init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + +def autoencoder_loss(x_output, y_target): + """ + 
Autoencoder reconstruction loss (synthEHRella implementation) + + This implementation is equivalent to torch.nn.BCELoss(reduction='sum') / batch_size + As our matrix is too sparse, first we will take a sum over the features and then + do the mean over the batch. + + WARNING: This is NOT equivalent to torch.nn.BCELoss(reduction='mean') as the latter + means over both features and batches. + + Reference: synthEHRella wgancnnmimic.py lines 312-323 + """ + epsilon = 1e-12 + term = y_target * torch.log(x_output + epsilon) + (1. - y_target) * torch.log(1. - x_output + epsilon) + loss = torch.mean(-torch.sum(term, 1), 0) + return loss + + +def discriminator_accuracy(predicted, y_true): + """Calculate discriminator accuracy""" + predicted = (predicted >= 0.5).float() + accuracy = (predicted == y_true).float().mean() + return accuracy.item() + + +class CorGAN(BaseModel): + """ + CorGAN: Correlation-capturing Generative Adversarial Network + + Uses CNNs to capture correlations between adjacent medical features by combining + Convolutional GANs with Convolutional Autoencoders. 
+ + Args: + dataset: PyHealth dataset object + feature_keys: List of feature keys to use + label_key: Label key (not used in unsupervised generation) + mode: Training mode (not used in GAN context) + latent_dim: Dimensionality of latent space + hidden_dim: Hidden dimension for networks + batch_size: Training batch size + n_epochs: Number of training epochs + n_epochs_pretrain: Number of autoencoder pretraining epochs + lr: Learning rate + weight_decay: Weight decay for optimization + b1: Beta1 for Adam optimizer + b2: Beta2 for Adam optimizer + n_iter_D: Number of discriminator iterations per generator iteration + clamp_lower: Lower bound for weight clipping + clamp_upper: Upper bound for weight clipping + minibatch_averaging: Whether to use minibatch averaging in discriminator + **kwargs: Additional arguments + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> dataset = MIMIC3Dataset(...) + >>> model = CorGAN(dataset=dataset, feature_keys=["conditions"]) + >>> model.fit() + >>> synthetic_data = model.generate(n_samples=50000) + """ + + def __init__( + self, + dataset: BaseDataset, + feature_keys: List[str], + label_key: str, + mode: str = "generation", + latent_dim: int = 128, + hidden_dim: int = 128, + batch_size: int = 512, + n_epochs: int = 1000, + n_epochs_pretrain: int = 1, + lr: float = 0.001, + weight_decay: float = 0.0001, + b1: float = 0.9, + b2: float = 0.999, + n_iter_D: int = 5, + clamp_lower: float = -0.01, + clamp_upper: float = 0.01, + minibatch_averaging: bool = True, + **kwargs + ): + super(CorGAN, self).__init__( + dataset=dataset, + feature_keys=feature_keys, + label_key=label_key, + mode=mode, + **kwargs + ) + + self.latent_dim = latent_dim + self.hidden_dim = hidden_dim + self.batch_size = batch_size + self.n_epochs = n_epochs + self.n_epochs_pretrain = n_epochs_pretrain + self.lr = lr + self.weight_decay = weight_decay + self.b1 = b1 + self.b2 = b2 + self.n_iter_D = n_iter_D + self.clamp_lower = clamp_lower + 
self.clamp_upper = clamp_upper + self.minibatch_averaging = minibatch_averaging + + # build unified vocabulary for all feature keys + self.global_vocab = self._build_global_vocab(dataset, feature_keys) + self.input_dim = len(self.global_vocab) + self.tokenizer = Tokenizer(tokens=self.global_vocab, special_tokens=[]) + + # initialize components + # Determine if adaptive pooling is needed (only for non-standard vocabulary sizes) + use_adaptive_pooling = (self.input_dim != 1071) + self.autoencoder = CorGANAutoencoder( + feature_size=self.input_dim, + latent_dim=latent_dim, + use_adaptive_pooling=use_adaptive_pooling + ) + self.autoencoder_decoder = self.autoencoder.decoder # separate decoder for generator + + self.generator = CorGANGenerator( + latent_dim=latent_dim, + hidden_dim=hidden_dim + ) + + self.discriminator = CorGANDiscriminator( + input_dim=self.input_dim, + hidden_dim=256, # Match synthEHRella exactly (not hidden_dim * 2) + minibatch_averaging=minibatch_averaging + ) + + # apply custom weight initialization + self._init_weights() + + # setup device + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.to(self.device) + + # setup optimizers + g_params = [ + {'params': self.generator.parameters()}, + {'params': self.autoencoder_decoder.parameters(), 'lr': 1e-4} + ] + self.optimizer_G = torch.optim.Adam(g_params, lr=lr, betas=(b1, b2), weight_decay=weight_decay) + self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2), weight_decay=weight_decay) + self.optimizer_A = torch.optim.Adam(self.autoencoder.parameters(), lr=lr, betas=(b1, b2), weight_decay=weight_decay) + + # setup tensors for training + self.one = torch.tensor(1.0, device=self.device) + self.mone = torch.tensor(-1.0, device=self.device) + + def _build_global_vocab(self, dataset: BaseDataset, feature_keys: List[str]) -> List[str]: + """Build unified vocabulary across all feature keys""" + global_vocab = set() + + # collect all unique 
codes from all patients and feature keys + for patient_id in dataset.patients: + patient = dataset.patients[patient_id] + for feature_key in feature_keys: + if feature_key in patient: + for visit in patient[feature_key]: + if isinstance(visit, list): + global_vocab.update(visit) + else: + global_vocab.add(visit) + + return sorted(list(global_vocab)) + + def _encode_patient_record(self, record: Dict) -> torch.Tensor: + """Encode a patient record to binary vector""" + # create binary vector + binary_vector = np.zeros(self.input_dim, dtype=np.float32) + + for feature_key in self.feature_keys: + if feature_key in record: + for visit in record[feature_key]: + if isinstance(visit, list): + for code in visit: + if code in self.global_vocab: + idx = self.global_vocab.index(code) + binary_vector[idx] = 1.0 + else: + if visit in self.global_vocab: + idx = self.global_vocab.index(visit) + binary_vector[idx] = 1.0 + + return torch.from_numpy(binary_vector) + + def _init_weights(self): + """Initialize network weights""" + self.generator.apply(weights_init) + self.discriminator.apply(weights_init) + self.autoencoder.apply(weights_init) + + def _extract_features_from_batch(self, batch_data, device: torch.device) -> torch.Tensor: + """Extract features from batch data""" + features = [] + for patient_id in batch_data: + patient = self.dataset.patients[patient_id] + feature_vector = self._encode_patient_record(patient) + features.append(feature_vector) + + return torch.stack(features).to(device) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward pass - not used in GAN context""" + raise NotImplementedError("Forward pass not implemented for GAN models") + + def fit(self, train_dataloader: Optional[DataLoader] = None): + """Train the CorGAN model""" + print("Starting CorGAN training...") + + # create dataset and dataloader + if train_dataloader is None: + # create binary matrix from dataset + data_matrix = [] + for patient_id in self.dataset.patients: + patient = 
self.dataset.patients[patient_id] + feature_vector = self._encode_patient_record(patient) + data_matrix.append(feature_vector.numpy()) + + data_matrix = np.array(data_matrix) + dataset = CorGANDataset(data=data_matrix) + + sampler = torch.utils.data.sampler.RandomSampler( + data_source=dataset, replacement=True + ) + train_dataloader = DataLoader( + dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=0, + drop_last=True, + sampler=sampler + ) + + # pretrain autoencoder + print(f"Pretraining autoencoder for {self.n_epochs_pretrain} epochs...") + for epoch_pre in range(self.n_epochs_pretrain): + for i, samples in enumerate(train_dataloader): + # configure input + real_samples = samples.to(self.device) + + # generate a batch of images + recons_samples = self.autoencoder(real_samples) + + # loss measures autoencoder's ability to reconstruct + a_loss = autoencoder_loss(recons_samples, real_samples) + + # reset gradients + self.optimizer_A.zero_grad() + a_loss.backward() + self.optimizer_A.step() + + if i % 100 == 0: + print(f"[Epoch {epoch_pre + 1}/{self.n_epochs_pretrain}] [Batch {i}/{len(train_dataloader)}] [A loss: {a_loss.item():.3f}]") + + # adversarial training + print(f"Starting adversarial training for {self.n_epochs} epochs...") + gen_iterations = 0 + + for epoch in range(self.n_epochs): + epoch_start = time.time() + + for i, samples in enumerate(train_dataloader): + # adversarial ground truths + valid = torch.ones(samples.shape[0], device=self.device) + fake = torch.zeros(samples.shape[0], device=self.device) + + # configure input + real_samples = samples.to(self.device) + + # sample noise as generator input + z = torch.randn(samples.shape[0], self.latent_dim, device=self.device) + + # --------------------- + # Train Discriminator + # --------------------- + + for p in self.discriminator.parameters(): + p.requires_grad = True + + # train the discriminator n_iter_D times + if gen_iterations < 25 or gen_iterations % 500 == 0: + n_iter_D = 100 + 
else: + n_iter_D = self.n_iter_D + + j = 0 + while j < n_iter_D: + j += 1 + + # clamp parameters to a cube + for p in self.discriminator.parameters(): + p.data.clamp_(self.clamp_lower, self.clamp_upper) + + # reset gradients of discriminator + self.optimizer_D.zero_grad() + + errD_real = torch.mean(self.discriminator(real_samples), dim=0) + errD_real.backward(self.one) + + # sample noise as generator input + z = torch.randn(samples.shape[0], self.latent_dim, device=self.device) + + # generate a batch of images + fake_samples = self.generator(z) + fake_samples = torch.squeeze(self.autoencoder_decoder(fake_samples.unsqueeze(dim=2))) + + errD_fake = torch.mean(self.discriminator(fake_samples.detach()), dim=0) + errD_fake.backward(self.mone) + errD = errD_real - errD_fake + + # optimizer step + self.optimizer_D.step() + + # ----------------- + # Train Generator + # ----------------- + + for p in self.discriminator.parameters(): + p.requires_grad = False + + # zero grads + self.optimizer_G.zero_grad() + + # sample noise as generator input + z = torch.randn(samples.shape[0], self.latent_dim, device=self.device) + + # generate a batch of images + fake_samples = self.generator(z) + fake_samples = torch.squeeze(self.autoencoder_decoder(fake_samples.unsqueeze(dim=2))) + + # loss measures generator's ability to fool the discriminator + errG = torch.mean(self.discriminator(fake_samples), dim=0) + errG.backward(self.one) + + # optimizer step + self.optimizer_G.step() + gen_iterations += 1 + + # end of epoch + epoch_end = time.time() + print(f"[Epoch {epoch + 1}/{self.n_epochs}] [Batch {i}/{len(train_dataloader)}] " + f"Loss_D: {errD.item():.3f} Loss_G: {errG.item():.3f} " + f"Loss_D_real: {errD_real.item():.3f} Loss_D_fake: {errD_fake.item():.3f}") + print(f"Epoch time: {epoch_end - epoch_start:.2f} seconds") + + print("Training completed!") + + def generate(self, n_samples: int, device: torch.device = None) -> torch.Tensor: + """Generate synthetic data""" + if device is None: 
+ device = self.device + + # set models to eval mode + self.generator.eval() + self.autoencoder_decoder.eval() + + # generate samples + gen_samples = np.zeros((n_samples, self.input_dim), dtype=np.float32) + n_batches = int(n_samples / self.batch_size) + + with torch.no_grad(): + for i in range(n_batches): + # sample noise as generator input + z = torch.randn(self.batch_size, self.latent_dim, device=device) + gen_samples_tensor = self.generator(z) + gen_samples_decoded = torch.squeeze(self.autoencoder_decoder(gen_samples_tensor.unsqueeze(dim=2))) + gen_samples[i * self.batch_size:(i + 1) * self.batch_size, :] = gen_samples_decoded.cpu().data.numpy() + + # handle remaining samples + remaining = n_samples % self.batch_size + if remaining > 0: + z = torch.randn(remaining, self.latent_dim, device=device) + gen_samples_tensor = self.generator(z) + gen_samples_decoded = torch.squeeze(self.autoencoder_decoder(gen_samples_tensor.unsqueeze(dim=2))) + gen_samples[n_batches * self.batch_size:, :] = gen_samples_decoded.cpu().data.numpy() + + # binarize output + gen_samples[gen_samples >= 0.5] = 1.0 + gen_samples[gen_samples < 0.5] = 0.0 + + return torch.from_numpy(gen_samples) + + def save_model(self, path: str): + """Save model checkpoint""" + torch.save({ + 'generator_state_dict': self.generator.state_dict(), + 'discriminator_state_dict': self.discriminator.state_dict(), + 'autoencoder_state_dict': self.autoencoder.state_dict(), + 'autoencoder_decoder_state_dict': self.autoencoder_decoder.state_dict(), + 'optimizer_G_state_dict': self.optimizer_G.state_dict(), + 'optimizer_D_state_dict': self.optimizer_D.state_dict(), + 'optimizer_A_state_dict': self.optimizer_A.state_dict(), + 'global_vocab': self.global_vocab, + 'input_dim': self.input_dim, + 'latent_dim': self.latent_dim, + }, path) + + def load_model(self, path: str): + """Load model checkpoint""" + checkpoint = torch.load(path, map_location=self.device) + + 
self.generator.load_state_dict(checkpoint['generator_state_dict']) + self.discriminator.load_state_dict(checkpoint['discriminator_state_dict']) + self.autoencoder.load_state_dict(checkpoint['autoencoder_state_dict']) + self.autoencoder_decoder.load_state_dict(checkpoint['autoencoder_decoder_state_dict']) + self.optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict']) + self.optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict']) + self.optimizer_A.load_state_dict(checkpoint['optimizer_A_state_dict']) + + self.global_vocab = checkpoint['global_vocab'] + self.input_dim = checkpoint['input_dim'] + self.latent_dim = checkpoint['latent_dim'] \ No newline at end of file From 5280d46d36ce9ab0ad365122de48b808e807c8e0 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 25 Feb 2026 02:15:01 -0600 Subject: [PATCH 02/13] Add CorGANGenerationMIMIC3 task function --- pyhealth/tasks/__init__.py | 6 ++ pyhealth/tasks/corgan_generation.py | 91 +++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 pyhealth/tasks/corgan_generation.py diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 2f4294a19..aef106ead 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -9,6 +9,12 @@ cardiology_isCD_fn, cardiology_isWA_fn, ) +from .corgan_generation import ( + CorGANGenerationMIMIC3, + CorGANGenerationMIMIC4, + corgan_generation_mimic3_fn, + corgan_generation_mimic4_fn, +) from .chestxray14_binary_classification import ChestXray14BinaryClassification from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification from .covid19_cxr_classification import COVID19CXRClassification diff --git a/pyhealth/tasks/corgan_generation.py b/pyhealth/tasks/corgan_generation.py new file mode 100644 index 000000000..0c25fd066 --- /dev/null +++ b/pyhealth/tasks/corgan_generation.py @@ -0,0 +1,91 @@ +import polars as pl +from typing import Dict, List + +from pyhealth.tasks.base_task import BaseTask + + 
+class CorGANGenerationMIMIC3(BaseTask): + """Task function for CorGAN synthetic EHR generation using MIMIC-III. + + Extracts ICD-9 diagnosis codes from MIMIC-III admission records into a + nested visit structure suitable for training the CorGAN model. + + Each sample contains the full visit history for a single patient, where + each visit is a list of ICD-9 codes recorded during that admission. + Patients with fewer than 2 visits are excluded. + + Attributes: + task_name (str): Unique task identifier. + input_schema (dict): Schema descriptor for the visits field. + output_schema (dict): Empty — generative task, no conditioning label. + _icd_col (str): Polars column path for ICD codes in MIMIC-III. + + Examples: + >>> fn = CorGANGenerationMIMIC3() + >>> fn.task_name + 'CorGANGenerationMIMIC3' + """ + + task_name = "CorGANGenerationMIMIC3" + input_schema = {"visits": "nested_sequence"} + output_schema = {} + _icd_col = "diagnoses_icd/icd9_code" + + def __call__(self, patient) -> List[Dict]: + """Extract structured visit data for a single patient. + + Args: + patient: A PyHealth patient object with admission and diagnosis + event data. + + Returns: + list of dict: A single-element list containing the patient record, + or an empty list if the patient has fewer than 2 visits with + diagnosis codes. Each dict has: + ``"patient_id"`` (str): the patient identifier. + ``"visits"`` (list of list of str): per-visit ICD code lists. 
+ """ + admissions = list(patient.get_events(event_type="admissions")) + visits = [] + for adm in admissions: + codes = ( + patient.get_events( + event_type="diagnoses_icd", + filters=[("hadm_id", "==", adm.hadm_id)], + return_df=True, + ) + .select(pl.col(self._icd_col)) + .to_series() + .drop_nulls() + .to_list() + ) + if codes: + visits.append(codes) + if len(visits) < 2: + return [] + return [{"patient_id": patient.patient_id, "visits": visits}] + + +class CorGANGenerationMIMIC4(CorGANGenerationMIMIC3): + """Task function for CorGAN synthetic EHR generation using MIMIC-IV. + + Inherits all logic from :class:`CorGANGenerationMIMIC3`. Overrides only + the task name and the ICD code column to match the MIMIC-IV schema, where + the column is ``icd_code`` (unversioned) rather than ``icd9_code``. + + Attributes: + task_name (str): Unique task identifier. + _icd_col (str): Polars column path for ICD codes in MIMIC-IV. + + Examples: + >>> fn = CorGANGenerationMIMIC4() + >>> fn.task_name + 'CorGANGenerationMIMIC4' + """ + + task_name = "CorGANGenerationMIMIC4" + _icd_col = "diagnoses_icd/icd_code" + + +corgan_generation_mimic3_fn = CorGANGenerationMIMIC3() +corgan_generation_mimic4_fn = CorGANGenerationMIMIC4() From 544945d1482caf8f552a6d3c161985ac3b71cd70 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 25 Feb 2026 02:19:56 -0600 Subject: [PATCH 03/13] Fix code quality issues in CorGAN task function --- pyhealth/tasks/__init__.py | 4 ++-- pyhealth/tasks/corgan_generation.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index aef106ead..c982b385b 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -9,14 +9,14 @@ cardiology_isCD_fn, cardiology_isWA_fn, ) +from .chestxray14_binary_classification import ChestXray14BinaryClassification +from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification from .corgan_generation import ( 
CorGANGenerationMIMIC3, CorGANGenerationMIMIC4, corgan_generation_mimic3_fn, corgan_generation_mimic4_fn, ) -from .chestxray14_binary_classification import ChestXray14BinaryClassification -from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification from .covid19_cxr_classification import COVID19CXRClassification from .dka import DKAPredictionMIMIC4, T1DDKAPredictionMIMIC4 from .drug_recommendation import ( diff --git a/pyhealth/tasks/corgan_generation.py b/pyhealth/tasks/corgan_generation.py index 0c25fd066..d3ce692fe 100644 --- a/pyhealth/tasks/corgan_generation.py +++ b/pyhealth/tasks/corgan_generation.py @@ -1,6 +1,7 @@ -import polars as pl from typing import Dict, List +import polars as pl + from pyhealth.tasks.base_task import BaseTask @@ -42,8 +43,8 @@ def __call__(self, patient) -> List[Dict]: list of dict: A single-element list containing the patient record, or an empty list if the patient has fewer than 2 visits with diagnosis codes. Each dict has: - ``"patient_id"`` (str): the patient identifier. - ``"visits"`` (list of list of str): per-visit ICD code lists. + ``"patient_id"`` (str): the patient identifier. + ``"visits"`` (list of list of str): per-visit ICD code lists. 
""" admissions = list(patient.get_events(event_type="admissions")) visits = [] From 41e02a45ed68c43f41d68ae88194855ee3e82564 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 25 Feb 2026 03:24:30 -0600 Subject: [PATCH 04/13] Refactor CorGAN to PyHealth 2.0 BaseModel compliance (train_model, synthesize_dataset) --- pyhealth/models/generators/corgan.py | 526 ++++++++++++++------------- 1 file changed, 283 insertions(+), 243 deletions(-) diff --git a/pyhealth/models/generators/corgan.py b/pyhealth/models/generators/corgan.py index 8acfe5297..3c216d1c2 100644 --- a/pyhealth/models/generators/corgan.py +++ b/pyhealth/models/generators/corgan.py @@ -1,21 +1,17 @@ -import functools -from typing import Dict, List, Optional, Tuple, Union +import time +from typing import Dict, List, Optional + import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset -import random -import time -from pyhealth.datasets import BaseDataset from pyhealth.models import BaseModel -from pyhealth.tokenizer import Tokenizer class CorGANDataset(Dataset): """Dataset wrapper for CorGAN training""" - + def __init__(self, data, transform=None): self.transform = transform self.data = data.astype(np.float32) @@ -50,7 +46,7 @@ def __init__(self, feature_size: int, latent_dim: int = 128, use_adaptive_poolin self.latent_dim = latent_dim self.use_adaptive_pooling = use_adaptive_pooling n_channels_base = 4 - + # calculate the size after convolutions # input: (batch, 1, feature_size) # conv1: kernel=5, stride=2 -> (batch, 4, (feature_size-4)//2) @@ -59,10 +55,10 @@ def __init__(self, feature_size: int, latent_dim: int = 128, use_adaptive_poolin # conv4: kernel=5, stride=3 -> (batch, 32, ((((feature_size-4)//2-4)//2-4)//3-4)//3) # conv5: kernel=5, stride=3 -> (batch, 64, (((((feature_size-4)//2-4)//2-4)//3-4)//3-4)//3) # conv6: kernel=8, stride=1 -> (batch, 128, ((((((feature_size-4)//2-4)//2-4)//3-4)//3-4)//3-7)) - + # rough estimate 
for latent size latent_size = max(1, feature_size // 100) # ensure at least 1 - + self.encoder = nn.Sequential( nn.Conv1d(in_channels=1, out_channels=n_channels_base, kernel_size=5, stride=2, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'), @@ -170,84 +166,84 @@ def __init__(self, feature_size: int = 6955, latent_dim: int = 128): self.feature_size = feature_size self.latent_dim = latent_dim - # Encoder: 6955 → 1 (8 layers) + # Encoder: 6955 -> 1 (8 layers) self.encoder = nn.Sequential( - # Layer 1: 6955 → 3476 + # Layer 1: 6955 -> 3476 nn.Conv1d(1, 4, kernel_size=5, stride=2, padding=0), nn.LeakyReLU(0.2, inplace=True), - # Layer 2: 3476 → 1736 + # Layer 2: 3476 -> 1736 nn.Conv1d(4, 8, kernel_size=5, stride=2, padding=0), nn.BatchNorm1d(8), nn.LeakyReLU(0.2, inplace=True), - # Layer 3: 1736 → 578 + # Layer 3: 1736 -> 578 nn.Conv1d(8, 16, kernel_size=5, stride=3, padding=0), nn.BatchNorm1d(16), nn.LeakyReLU(0.2, inplace=True), - # Layer 4: 578 → 192 + # Layer 4: 578 -> 192 nn.Conv1d(16, 32, kernel_size=5, stride=3, padding=0), nn.BatchNorm1d(32), nn.LeakyReLU(0.2, inplace=True), - # Layer 5: 192 → 63 + # Layer 5: 192 -> 63 nn.Conv1d(32, 64, kernel_size=5, stride=3, padding=0), nn.BatchNorm1d(64), nn.LeakyReLU(0.2, inplace=True), - # Layer 6: 63 → 20 [NEW] + # Layer 6: 63 -> 20 [NEW] nn.Conv1d(64, 96, kernel_size=5, stride=3, padding=0), nn.BatchNorm1d(96), nn.LeakyReLU(0.2, inplace=True), - # Layer 7: 20 → 4 [NEW] + # Layer 7: 20 -> 4 [NEW] nn.Conv1d(96, 112, kernel_size=5, stride=4, padding=0), nn.BatchNorm1d(112), nn.LeakyReLU(0.2, inplace=True), - # Layer 8: 4 → 1 [NEW] + # Layer 8: 4 -> 1 [NEW] nn.Conv1d(112, 128, kernel_size=4, stride=1, padding=0), nn.Tanh(), ) - # Decoder: 1 → 6955 (8 layers) + # Decoder: 1 -> 6955 (8 layers) self.decoder = nn.Sequential( - # Layer 1: 1 → 4 (NO BatchNorm on first layer) + # Layer 1: 1 -> 4 (NO BatchNorm on first layer) nn.ConvTranspose1d(128, 112, kernel_size=4, stride=1, padding=0), nn.ReLU(), - # Layer 2: 
4 → 20 + # Layer 2: 4 -> 20 nn.ConvTranspose1d(112, 96, kernel_size=8, stride=4, padding=0), nn.BatchNorm1d(96), nn.ReLU(), - # Layer 3: 20 → 63 + # Layer 3: 20 -> 63 nn.ConvTranspose1d(96, 64, kernel_size=6, stride=3, padding=0), nn.BatchNorm1d(64), nn.ReLU(), - # Layer 4: 63 → 192 + # Layer 4: 63 -> 192 nn.ConvTranspose1d(64, 32, kernel_size=6, stride=3, padding=0), nn.BatchNorm1d(32), nn.ReLU(), - # Layer 5: 192 → 578 + # Layer 5: 192 -> 578 nn.ConvTranspose1d(32, 16, kernel_size=5, stride=3, padding=0), nn.BatchNorm1d(16), nn.ReLU(), - # Layer 6: 578 → 1736 + # Layer 6: 578 -> 1736 nn.ConvTranspose1d(16, 8, kernel_size=5, stride=3, padding=0), nn.BatchNorm1d(8), nn.ReLU(), - # Layer 7: 1736 → 3476 + # Layer 7: 1736 -> 3476 nn.ConvTranspose1d(8, 4, kernel_size=6, stride=2, padding=0), nn.BatchNorm1d(4), nn.ReLU(), - # Layer 8: 3476 → 6955 + # Layer 8: 3476 -> 6955 nn.ConvTranspose1d(4, 1, kernel_size=5, stride=2, padding=0), nn.Sigmoid(), ) @@ -297,7 +293,7 @@ def __init__(self, feature_size: int, latent_dim: int = 128): self.feature_size = feature_size self.latent_dim = latent_dim - # Encoder: feature_size → latent_dim + # Encoder: feature_size -> latent_dim # Use ReLU+BatchNorm (V11 achieved 4.49 codes, best linear result) self.encoder = nn.Sequential( nn.Linear(feature_size, latent_dim), @@ -305,7 +301,7 @@ def __init__(self, feature_size: int, latent_dim: int = 128): nn.BatchNorm1d(latent_dim) ) - # Decoder: latent_dim → feature_size + # Decoder: latent_dim -> feature_size self.decoder = nn.Sequential( nn.Linear(latent_dim, feature_size), nn.Sigmoid() @@ -379,7 +375,7 @@ class CorGANDiscriminator(nn.Module): Discriminator for CorGAN - MLP with minibatch averaging Architecture matches synthEHRella exactly (wgancnnmimic.py lines 265-296): - - 4 linear layers: input → 256 → 256 → 256 → 1 + - 4 linear layers: input -> 256 -> 256 -> 256 -> 1 - ReLU activations - No sigmoid (WGAN uses unbounded critic outputs) """ @@ -463,49 +459,58 @@ def 
 discriminator_accuracy(predicted, y_true): class CorGAN(BaseModel): """ - CorGAN: Correlation-capturing Generative Adversarial Network - + CorGAN: Correlation-capturing Generative Adversarial Network for synthetic EHR generation. + Uses CNNs to capture correlations between adjacent medical features by combining Convolutional GANs with Convolutional Autoencoders. - + + Reference: + Torfi and Fox, "CorGAN: Correlation-Capturing Convolutional Generative + Adversarial Networks for Generating Synthetic Healthcare Records", FLAIRS 2020. + Args: - dataset: PyHealth dataset object - feature_keys: List of feature keys to use - label_key: Label key (not used in unsupervised generation) - mode: Training mode (not used in GAN context) - latent_dim: Dimensionality of latent space - hidden_dim: Hidden dimension for networks - batch_size: Training batch size - n_epochs: Number of training epochs - n_epochs_pretrain: Number of autoencoder pretraining epochs - lr: Learning rate - weight_decay: Weight decay for optimization - b1: Beta1 for Adam optimizer - b2: Beta2 for Adam optimizer - n_iter_D: Number of discriminator iterations per generator iteration - clamp_lower: Lower bound for weight clipping - clamp_upper: Upper bound for weight clipping - minibatch_averaging: Whether to use minibatch averaging in discriminator - **kwargs: Additional arguments - + dataset: A fitted SampleDataset with ``input_schema = {"visits": "nested_sequence"}``. + latent_dim: Dimensionality of the generator latent space. Default: 128. + hidden_dim: Hidden dimension for the generator MLP. Default: 128. + batch_size: Training batch size. Default: 512. + epochs: Total GAN training epochs. Default: 1000. + n_epochs_pretrain: Autoencoder pre-training epochs. Default: 1. + lr: Learning rate for all optimizers. Default: 0.001. + weight_decay: Weight decay for Adam optimizers. Default: 0.0001. + b1: Beta1 for Adam optimizers. Default: 0.9. + b2: Beta2 for Adam optimizers. Default: 0.999. 
+ n_iter_D: Discriminator update steps per generator step. Default: 5. + clamp_lower: Lower weight-clipping bound for WGAN critic. Default: -0.01. + clamp_upper: Upper weight-clipping bound for WGAN critic. Default: 0.01. + autoencoder_type: One of ``"cnn"`` (default), ``"cnn8layer"``, or ``"linear"``. + use_adaptive_pooling: If True, add adaptive average pooling to the CNN + autoencoder decoder so it matches any vocabulary size. Ignored when + ``autoencoder_type`` is not ``"cnn"``. Default: True. + minibatch_averaging: Whether to use minibatch averaging in the discriminator. + Default: True. + save_dir: Directory for saving checkpoints. Default: ``"./corgan_checkpoints"``. + **kwargs: Additional arguments passed to ``BaseModel``. + Examples: - >>> from pyhealth.datasets import MIMIC3Dataset - >>> dataset = MIMIC3Dataset(...) - >>> model = CorGAN(dataset=dataset, feature_keys=["conditions"]) - >>> model.fit() - >>> synthetic_data = model.generate(n_samples=50000) + >>> from pyhealth.datasets.sample_dataset import InMemorySampleDataset + >>> samples = [{"patient_id": "p1", "visits": [["A", "B"], ["C"]]}] + >>> dataset = InMemorySampleDataset( + ... samples=samples, + ... input_schema={"visits": "nested_sequence"}, + ... output_schema={}, + ... 
) + >>> model = CorGAN(dataset, latent_dim=32, hidden_dim=32, epochs=1) + >>> model.train_model(dataset) + >>> records = model.synthesize_dataset(num_samples=10) """ - + def __init__( self, - dataset: BaseDataset, - feature_keys: List[str], - label_key: str, - mode: str = "generation", + dataset, latent_dim: int = 128, hidden_dim: int = 128, batch_size: int = 512, - n_epochs: int = 1000, + epochs: int = 1000, n_epochs_pretrain: int = 1, lr: float = 0.001, weight_decay: float = 0.0001, @@ -514,21 +519,18 @@ def __init__( n_iter_D: int = 5, clamp_lower: float = -0.01, clamp_upper: float = 0.01, + autoencoder_type: str = "cnn", + use_adaptive_pooling: bool = True, minibatch_averaging: bool = True, + save_dir: str = "./corgan_checkpoints", **kwargs ): - super(CorGAN, self).__init__( - dataset=dataset, - feature_keys=feature_keys, - label_key=label_key, - mode=mode, - **kwargs - ) - + super(CorGAN, self).__init__(dataset=dataset) + self.latent_dim = latent_dim self.hidden_dim = hidden_dim self.batch_size = batch_size - self.n_epochs = n_epochs + self.n_epochs = epochs self.n_epochs_pretrain = n_epochs_pretrain self.lr = lr self.weight_decay = weight_decay @@ -538,40 +540,53 @@ def __init__( self.clamp_lower = clamp_lower self.clamp_upper = clamp_upper self.minibatch_averaging = minibatch_averaging - - # build unified vocabulary for all feature keys - self.global_vocab = self._build_global_vocab(dataset, feature_keys) - self.input_dim = len(self.global_vocab) - self.tokenizer = Tokenizer(tokens=self.global_vocab, special_tokens=[]) - + self.save_dir = save_dir + + # vocabulary from the dataset's fitted processor + processor = dataset.input_processors["visits"] + self.input_dim = processor.vocab_size() + # build reverse-lookup: integer index -> code string + self._idx_to_code: List[Optional[str]] = [None] * self.input_dim + for code, idx in processor.code_vocab.items(): + self._idx_to_code[idx] = code + # initialize components - # Determine if adaptive pooling is needed 
(only for non-standard vocabulary sizes) - use_adaptive_pooling = (self.input_dim != 1071) - self.autoencoder = CorGANAutoencoder( - feature_size=self.input_dim, - latent_dim=latent_dim, - use_adaptive_pooling=use_adaptive_pooling - ) + if autoencoder_type == "cnn8layer": + self.autoencoder = CorGAN8LayerAutoencoder( + feature_size=self.input_dim, + latent_dim=latent_dim, + ) + elif autoencoder_type == "linear": + self.autoencoder = CorGANLinearAutoencoder( + feature_size=self.input_dim, + latent_dim=latent_dim, + ) + else: + self.autoencoder = CorGANAutoencoder( + feature_size=self.input_dim, + latent_dim=latent_dim, + use_adaptive_pooling=use_adaptive_pooling, + ) + self.autoencoder_decoder = self.autoencoder.decoder # separate decoder for generator - + self.generator = CorGANGenerator( latent_dim=latent_dim, - hidden_dim=hidden_dim + hidden_dim=hidden_dim, ) - + self.discriminator = CorGANDiscriminator( input_dim=self.input_dim, hidden_dim=256, # Match synthEHRella exactly (not hidden_dim * 2) - minibatch_averaging=minibatch_averaging + minibatch_averaging=minibatch_averaging, ) - + # apply custom weight initialization self._init_weights() - - # setup device - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # move to device (uses BaseModel's device property) self.to(self.device) - + # setup optimizers g_params = [ {'params': self.generator.parameters()}, @@ -580,247 +595,268 @@ def __init__( self.optimizer_G = torch.optim.Adam(g_params, lr=lr, betas=(b1, b2), weight_decay=weight_decay) self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2), weight_decay=weight_decay) self.optimizer_A = torch.optim.Adam(self.autoencoder.parameters(), lr=lr, betas=(b1, b2), weight_decay=weight_decay) - + # setup tensors for training self.one = torch.tensor(1.0, device=self.device) self.mone = torch.tensor(-1.0, device=self.device) - - def _build_global_vocab(self, dataset: BaseDataset, feature_keys: List[str]) -> 
List[str]: - """Build unified vocabulary across all feature keys""" - global_vocab = set() - - # collect all unique codes from all patients and feature keys - for patient_id in dataset.patients: - patient = dataset.patients[patient_id] - for feature_key in feature_keys: - if feature_key in patient: - for visit in patient[feature_key]: - if isinstance(visit, list): - global_vocab.update(visit) - else: - global_vocab.add(visit) - - return sorted(list(global_vocab)) - - def _encode_patient_record(self, record: Dict) -> torch.Tensor: - """Encode a patient record to binary vector""" - # create binary vector - binary_vector = np.zeros(self.input_dim, dtype=np.float32) - - for feature_key in self.feature_keys: - if feature_key in record: - for visit in record[feature_key]: - if isinstance(visit, list): - for code in visit: - if code in self.global_vocab: - idx = self.global_vocab.index(code) - binary_vector[idx] = 1.0 - else: - if visit in self.global_vocab: - idx = self.global_vocab.index(visit) - binary_vector[idx] = 1.0 - - return torch.from_numpy(binary_vector) - + def _init_weights(self): """Initialize network weights""" self.generator.apply(weights_init) self.discriminator.apply(weights_init) self.autoencoder.apply(weights_init) - - def _extract_features_from_batch(self, batch_data, device: torch.device) -> torch.Tensor: - """Extract features from batch data""" - features = [] - for patient_id in batch_data: - patient = self.dataset.patients[patient_id] - feature_vector = self._encode_patient_record(patient) - features.append(feature_vector) - - return torch.stack(features).to(device) - + + def _encode_samples_to_multihot(self, dataset) -> np.ndarray: + """Build a multi-hot binary matrix from a SampleDataset. + + Each row corresponds to one patient sample. All visits are aggregated + into a single flat set of codes per patient and encoded as a binary + vector of length ``vocab_size``. 
+ + Args: + dataset: A fitted SampleDataset whose raw samples contain + ``sample["visits"]`` as a list of lists of code strings. + + Returns: + np.ndarray of shape ``(n_patients, vocab_size)`` with dtype float32. + """ + processor = self.dataset.input_processors["visits"] + code_vocab = processor.code_vocab + + n = len(dataset) + matrix = np.zeros((n, self.input_dim), dtype=np.float32) + + for i, sample in enumerate(dataset): + # sample["visits"] is the raw nested list from the original dict, + # but after SampleDataset processing it may be a tensor. + # We need the raw string codes, so access via dataset.samples if + # available, otherwise decode from the processed tensor. + visits = dataset.samples[i].get("visits", []) + for visit in visits: + if isinstance(visit, list): + for code in visit: + if code is not None and code in code_vocab: + matrix[i, code_vocab[code]] = 1.0 + + return matrix + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: - """Forward pass - not used in GAN context""" - raise NotImplementedError("Forward pass not implemented for GAN models") - - def fit(self, train_dataloader: Optional[DataLoader] = None): - """Train the CorGAN model""" + """Not used in GAN context.""" + raise NotImplementedError("Forward pass not implemented for GAN models.") + + def train_model(self, train_dataset, val_dataset=None): + """Train the CorGAN model on a SampleDataset. + + Builds multi-hot encodings from ``train_dataset``, pre-trains the + autoencoder, then runs WGAN adversarial training. + + Args: + train_dataset: A fitted SampleDataset with + ``input_schema = {"visits": "nested_sequence"}``. + val_dataset: Unused. Accepted for API compatibility. 
+ """ print("Starting CorGAN training...") - - # create dataset and dataloader - if train_dataloader is None: - # create binary matrix from dataset - data_matrix = [] - for patient_id in self.dataset.patients: - patient = self.dataset.patients[patient_id] - feature_vector = self._encode_patient_record(patient) - data_matrix.append(feature_vector.numpy()) - - data_matrix = np.array(data_matrix) - dataset = CorGANDataset(data=data_matrix) - - sampler = torch.utils.data.sampler.RandomSampler( - data_source=dataset, replacement=True - ) - train_dataloader = DataLoader( - dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=0, - drop_last=True, - sampler=sampler - ) - + + # build multi-hot matrix from SampleDataset + data_matrix = self._encode_samples_to_multihot(train_dataset) + + corgan_ds = CorGANDataset(data=data_matrix) + sampler = torch.utils.data.sampler.RandomSampler( + data_source=corgan_ds, replacement=True + ) + train_dataloader = DataLoader( + corgan_ds, + batch_size=self.batch_size, + shuffle=False, + num_workers=0, + drop_last=True, + sampler=sampler, + ) + # pretrain autoencoder print(f"Pretraining autoencoder for {self.n_epochs_pretrain} epochs...") for epoch_pre in range(self.n_epochs_pretrain): for i, samples in enumerate(train_dataloader): # configure input real_samples = samples.to(self.device) - + # generate a batch of images recons_samples = self.autoencoder(real_samples) - + # loss measures autoencoder's ability to reconstruct a_loss = autoencoder_loss(recons_samples, real_samples) - + # reset gradients self.optimizer_A.zero_grad() a_loss.backward() self.optimizer_A.step() - + if i % 100 == 0: print(f"[Epoch {epoch_pre + 1}/{self.n_epochs_pretrain}] [Batch {i}/{len(train_dataloader)}] [A loss: {a_loss.item():.3f}]") - + # adversarial training print(f"Starting adversarial training for {self.n_epochs} epochs...") gen_iterations = 0 - + for epoch in range(self.n_epochs): epoch_start = time.time() - + for i, samples in 
enumerate(train_dataloader): - # adversarial ground truths - valid = torch.ones(samples.shape[0], device=self.device) - fake = torch.zeros(samples.shape[0], device=self.device) - # configure input real_samples = samples.to(self.device) - + # sample noise as generator input z = torch.randn(samples.shape[0], self.latent_dim, device=self.device) - + # --------------------- # Train Discriminator # --------------------- - + for p in self.discriminator.parameters(): p.requires_grad = True - + # train the discriminator n_iter_D times if gen_iterations < 25 or gen_iterations % 500 == 0: n_iter_D = 100 else: n_iter_D = self.n_iter_D - + j = 0 while j < n_iter_D: j += 1 - + # clamp parameters to a cube for p in self.discriminator.parameters(): p.data.clamp_(self.clamp_lower, self.clamp_upper) - + # reset gradients of discriminator self.optimizer_D.zero_grad() - + errD_real = torch.mean(self.discriminator(real_samples), dim=0) errD_real.backward(self.one) - + # sample noise as generator input z = torch.randn(samples.shape[0], self.latent_dim, device=self.device) - + # generate a batch of images fake_samples = self.generator(z) fake_samples = torch.squeeze(self.autoencoder_decoder(fake_samples.unsqueeze(dim=2))) - + errD_fake = torch.mean(self.discriminator(fake_samples.detach()), dim=0) errD_fake.backward(self.mone) errD = errD_real - errD_fake - + # optimizer step self.optimizer_D.step() - + # ----------------- # Train Generator # ----------------- - + for p in self.discriminator.parameters(): p.requires_grad = False - + # zero grads self.optimizer_G.zero_grad() - + # sample noise as generator input z = torch.randn(samples.shape[0], self.latent_dim, device=self.device) - + # generate a batch of images fake_samples = self.generator(z) fake_samples = torch.squeeze(self.autoencoder_decoder(fake_samples.unsqueeze(dim=2))) - + # loss measures generator's ability to fool the discriminator errG = torch.mean(self.discriminator(fake_samples), dim=0) errG.backward(self.one) - + # 
optimizer step self.optimizer_G.step() gen_iterations += 1 - + # end of epoch epoch_end = time.time() print(f"[Epoch {epoch + 1}/{self.n_epochs}] [Batch {i}/{len(train_dataloader)}] " f"Loss_D: {errD.item():.3f} Loss_G: {errG.item():.3f} " f"Loss_D_real: {errD_real.item():.3f} Loss_D_fake: {errD_fake.item():.3f}") print(f"Epoch time: {epoch_end - epoch_start:.2f} seconds") - + print("Training completed!") - - def generate(self, n_samples: int, device: torch.device = None) -> torch.Tensor: - """Generate synthetic data""" - if device is None: - device = self.device - + + def synthesize_dataset(self, num_samples: int, random_sampling: bool = True) -> List[Dict]: + """Generate synthetic patient records. + + Each synthetic patient is represented as a single visit containing all + generated codes. This is an honest representation of what CorGAN produces — + a flat multi-hot vector aggregated across all visits. + + Args: + num_samples: Number of synthetic patients to generate. + random_sampling: Unused; accepted for API compatibility. + + Returns: + List of dicts, each with: + ``"patient_id"`` (str): e.g. ``"synthetic_0"``. + ``"visits"`` (list of list of str): one visit per patient + containing the decoded ICD codes. 
+ """ # set models to eval mode self.generator.eval() self.autoencoder_decoder.eval() - - # generate samples - gen_samples = np.zeros((n_samples, self.input_dim), dtype=np.float32) - n_batches = int(n_samples / self.batch_size) - + + device = self.device + gen_samples = np.zeros((num_samples, self.input_dim), dtype=np.float32) + n_batches = num_samples // self.batch_size + with torch.no_grad(): for i in range(n_batches): - # sample noise as generator input z = torch.randn(self.batch_size, self.latent_dim, device=device) gen_samples_tensor = self.generator(z) - gen_samples_decoded = torch.squeeze(self.autoencoder_decoder(gen_samples_tensor.unsqueeze(dim=2))) - gen_samples[i * self.batch_size:(i + 1) * self.batch_size, :] = gen_samples_decoded.cpu().data.numpy() - - # handle remaining samples - remaining = n_samples % self.batch_size - if remaining > 0: - z = torch.randn(remaining, self.latent_dim, device=device) - gen_samples_tensor = self.generator(z) - gen_samples_decoded = torch.squeeze(self.autoencoder_decoder(gen_samples_tensor.unsqueeze(dim=2))) - gen_samples[n_batches * self.batch_size:, :] = gen_samples_decoded.cpu().data.numpy() - - # binarize output + gen_samples_decoded = torch.squeeze( + self.autoencoder_decoder(gen_samples_tensor.unsqueeze(dim=2)) + ) + gen_samples[i * self.batch_size:(i + 1) * self.batch_size, :] = ( + gen_samples_decoded.cpu().data.numpy() + ) + + # handle remaining samples + remaining = num_samples % self.batch_size + if remaining > 0: + z = torch.randn(remaining, self.latent_dim, device=device) + gen_samples_tensor = self.generator(z) + gen_samples_decoded = torch.squeeze( + self.autoencoder_decoder(gen_samples_tensor.unsqueeze(dim=2)) + ) + gen_samples[n_batches * self.batch_size:, :] = ( + gen_samples_decoded.cpu().data.numpy() + ) + + # binarize at threshold 0.5 gen_samples[gen_samples >= 0.5] = 1.0 gen_samples[gen_samples < 0.5] = 0.0 - - return torch.from_numpy(gen_samples) - + + # decode binary vectors to code strings + 
results: List[Dict] = [] + for i in range(num_samples): + row = gen_samples[i] + codes = [ + self._idx_to_code[idx] + for idx in np.where(row == 1.0)[0] + if self._idx_to_code[idx] not in (None, "", "") + ] + results.append({ + "patient_id": f"synthetic_{i}", + "visits": [codes], + }) + + return results + def save_model(self, path: str): - """Save model checkpoint""" + """Save model checkpoint. + + Args: + path: File path to write the checkpoint to. + """ torch.save({ 'generator_state_dict': self.generator.state_dict(), 'discriminator_state_dict': self.discriminator.state_dict(), @@ -829,15 +865,19 @@ def save_model(self, path: str): 'optimizer_G_state_dict': self.optimizer_G.state_dict(), 'optimizer_D_state_dict': self.optimizer_D.state_dict(), 'optimizer_A_state_dict': self.optimizer_A.state_dict(), - 'global_vocab': self.global_vocab, + 'idx_to_code': self._idx_to_code, 'input_dim': self.input_dim, 'latent_dim': self.latent_dim, }, path) - + def load_model(self, path: str): - """Load model checkpoint""" + """Load model checkpoint. + + Args: + path: File path to read the checkpoint from. 
+ """ checkpoint = torch.load(path, map_location=self.device) - + self.generator.load_state_dict(checkpoint['generator_state_dict']) self.discriminator.load_state_dict(checkpoint['discriminator_state_dict']) self.autoencoder.load_state_dict(checkpoint['autoencoder_state_dict']) @@ -845,7 +885,7 @@ def load_model(self, path: str): self.optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict']) self.optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict']) self.optimizer_A.load_state_dict(checkpoint['optimizer_A_state_dict']) - - self.global_vocab = checkpoint['global_vocab'] + + self._idx_to_code = checkpoint['idx_to_code'] self.input_dim = checkpoint['input_dim'] - self.latent_dim = checkpoint['latent_dim'] \ No newline at end of file + self.latent_dim = checkpoint['latent_dim'] From 0e2d29034b7f343dc8eb2dcf5f7d1546433afcca Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 25 Feb 2026 03:56:45 -0600 Subject: [PATCH 05/13] Fix generators/__init__.py: remove untracked PlasMode reference --- pyhealth/models/generators/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyhealth/models/generators/__init__.py b/pyhealth/models/generators/__init__.py index e27bb8993..a5ff099af 100644 --- a/pyhealth/models/generators/__init__.py +++ b/pyhealth/models/generators/__init__.py @@ -1,4 +1,3 @@ from .corgan import CorGAN -from .plasmode import PlasMode -__all__ = ["CorGAN", "PlasMode"] +__all__ = ["CorGAN"] From c02f15843021bd9868af7320e15ddca5c2fc2c3a Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 25 Feb 2026 10:18:44 -0600 Subject: [PATCH 06/13] Switch CorGAN to multi_hot schema: flat codes, no nested visits --- pyhealth/models/generators/corgan.py | 99 +++++++++++----------------- pyhealth/tasks/corgan_generation.py | 36 +++++----- 2 files changed, 59 insertions(+), 76 deletions(-) diff --git a/pyhealth/models/generators/corgan.py b/pyhealth/models/generators/corgan.py index 3c216d1c2..93cbb9029 100644 --- 
a/pyhealth/models/generators/corgan.py +++ b/pyhealth/models/generators/corgan.py @@ -469,7 +469,7 @@ class CorGAN(BaseModel): Generative Adversarial Networks", JAMIA 2019. Args: - dataset: A fitted SampleDataset with ``input_schema = {"visits": "nested_sequence"}``. + dataset: A fitted SampleDataset with ``input_schema = {"visits": "multi_hot"}``. latent_dim: Dimensionality of the generator latent space. Default: 128. hidden_dim: Hidden dimension for the generator MLP. Default: 128. batch_size: Training batch size. Default: 512. @@ -493,10 +493,10 @@ class CorGAN(BaseModel): Examples: >>> from pyhealth.datasets.sample_dataset import InMemorySampleDataset - >>> samples = [{"patient_id": "p1", "visits": [["A", "B"], ["C"]]}] + >>> samples = [{"patient_id": "p1", "visits": ["A", "B", "C"]}] >>> dataset = InMemorySampleDataset( ... samples=samples, - ... input_schema={"visits": "nested_sequence"}, + ... input_schema={"visits": "multi_hot"}, ... output_schema={}, ... ) >>> model = CorGAN(dataset, latent_dim=32, hidden_dim=32, epochs=1) @@ -544,19 +544,26 @@ def __init__( # vocabulary from the dataset's fitted processor processor = dataset.input_processors["visits"] - self.input_dim = processor.vocab_size() + self.input_dim = processor.size() # build reverse-lookup: integer index -> code string self._idx_to_code: List[Optional[str]] = [None] * self.input_dim - for code, idx in processor.code_vocab.items(): + for code, idx in processor.label_vocab.items(): self._idx_to_code[idx] = code # initialize components - if autoencoder_type == "cnn8layer": + # CNN autoencoder requires a minimum input size to survive its convolution chain + # (6 layers with kernels 5,5,5,5,5,8 and strides 2,2,3,3,3,1 need at least ~500 + # features). Fall back to the linear autoencoder for small vocabularies. 
+ _effective_type = autoencoder_type + if autoencoder_type not in ("linear", "cnn8layer") and self.input_dim < 500: + _effective_type = "linear" + + if _effective_type == "cnn8layer": self.autoencoder = CorGAN8LayerAutoencoder( feature_size=self.input_dim, latent_dim=latent_dim, ) - elif autoencoder_type == "linear": + elif _effective_type == "linear": self.autoencoder = CorGANLinearAutoencoder( feature_size=self.input_dim, latent_dim=latent_dim, @@ -606,40 +613,6 @@ def _init_weights(self): self.discriminator.apply(weights_init) self.autoencoder.apply(weights_init) - def _encode_samples_to_multihot(self, dataset) -> np.ndarray: - """Build a multi-hot binary matrix from a SampleDataset. - - Each row corresponds to one patient sample. All visits are aggregated - into a single flat set of codes per patient and encoded as a binary - vector of length ``vocab_size``. - - Args: - dataset: A fitted SampleDataset whose raw samples contain - ``sample["visits"]`` as a list of lists of code strings. - - Returns: - np.ndarray of shape ``(n_patients, vocab_size)`` with dtype float32. - """ - processor = self.dataset.input_processors["visits"] - code_vocab = processor.code_vocab - - n = len(dataset) - matrix = np.zeros((n, self.input_dim), dtype=np.float32) - - for i, sample in enumerate(dataset): - # sample["visits"] is the raw nested list from the original dict, - # but after SampleDataset processing it may be a tensor. - # We need the raw string codes, so access via dataset.samples if - # available, otherwise decode from the processed tensor. 
- visits = dataset.samples[i].get("visits", []) - for visit in visits: - if isinstance(visit, list): - for code in visit: - if code is not None and code in code_vocab: - matrix[i, code_vocab[code]] = 1.0 - - return matrix - def forward(self, **kwargs) -> Dict[str, torch.Tensor]: """Not used in GAN context.""" raise NotImplementedError("Forward pass not implemented for GAN models.") @@ -652,13 +625,14 @@ def train_model(self, train_dataset, val_dataset=None): Args: train_dataset: A fitted SampleDataset with - ``input_schema = {"visits": "nested_sequence"}``. + ``input_schema = {"visits": "multi_hot"}``. val_dataset: Unused. Accepted for API compatibility. """ print("Starting CorGAN training...") - # build multi-hot matrix from SampleDataset - data_matrix = self._encode_samples_to_multihot(train_dataset) + # build multi-hot matrix by stacking the pre-encoded tensors from MultiHotProcessor + tensors = [train_dataset[i]["visits"] for i in range(len(train_dataset))] + data_matrix = torch.stack(tensors).numpy() # shape (n_patients, vocab_size) corgan_ds = CorGANDataset(data=data_matrix) sampler = torch.utils.data.sampler.RandomSampler( @@ -732,7 +706,7 @@ def train_model(self, train_dataset, val_dataset=None): # reset gradients of discriminator self.optimizer_D.zero_grad() - errD_real = torch.mean(self.discriminator(real_samples), dim=0) + errD_real = torch.mean(self.discriminator(real_samples)).squeeze() errD_real.backward(self.one) # sample noise as generator input @@ -740,9 +714,9 @@ def train_model(self, train_dataset, val_dataset=None): # generate a batch of images fake_samples = self.generator(z) - fake_samples = torch.squeeze(self.autoencoder_decoder(fake_samples.unsqueeze(dim=2))) + fake_samples = self.autoencoder.decode(fake_samples) - errD_fake = torch.mean(self.discriminator(fake_samples.detach()), dim=0) + errD_fake = torch.mean(self.discriminator(fake_samples.detach())).squeeze() errD_fake.backward(self.mone) errD = errD_real - errD_fake @@ -764,10 +738,10 
@@ def train_model(self, train_dataset, val_dataset=None): # generate a batch of images fake_samples = self.generator(z) - fake_samples = torch.squeeze(self.autoencoder_decoder(fake_samples.unsqueeze(dim=2))) + fake_samples = self.autoencoder.decode(fake_samples) # loss measures generator's ability to fool the discriminator - errG = torch.mean(self.discriminator(fake_samples), dim=0) + errG = torch.mean(self.discriminator(fake_samples)).squeeze() errG.backward(self.one) # optimizer step @@ -783,12 +757,20 @@ def train_model(self, train_dataset, val_dataset=None): print("Training completed!") + # save final checkpoint if save_dir is configured + if self.save_dir: + import os + os.makedirs(self.save_dir, exist_ok=True) + checkpoint_path = os.path.join(self.save_dir, "corgan_final.pt") + self.save_model(checkpoint_path) + print(f"Checkpoint saved to {checkpoint_path}") + def synthesize_dataset(self, num_samples: int, random_sampling: bool = True) -> List[Dict]: """Generate synthetic patient records. - Each synthetic patient is represented as a single visit containing all - generated codes. This is an honest representation of what CorGAN produces — - a flat multi-hot vector aggregated across all visits. + Each synthetic patient is represented as a flat list of codes decoded + from the generated binary vector. This mirrors the ``multi_hot`` input + schema used during training. Args: num_samples: Number of synthetic patients to generate. @@ -797,8 +779,7 @@ def synthesize_dataset(self, num_samples: int, random_sampling: bool = True) -> Returns: List of dicts, each with: ``"patient_id"`` (str): e.g. ``"synthetic_0"``. - ``"visits"`` (list of list of str): one visit per patient - containing the decoded ICD codes. + ``"visits"`` (list of str): flat list of decoded ICD code strings. 
""" # set models to eval mode self.generator.eval() @@ -812,9 +793,7 @@ def synthesize_dataset(self, num_samples: int, random_sampling: bool = True) -> for i in range(n_batches): z = torch.randn(self.batch_size, self.latent_dim, device=device) gen_samples_tensor = self.generator(z) - gen_samples_decoded = torch.squeeze( - self.autoencoder_decoder(gen_samples_tensor.unsqueeze(dim=2)) - ) + gen_samples_decoded = self.autoencoder.decode(gen_samples_tensor) gen_samples[i * self.batch_size:(i + 1) * self.batch_size, :] = ( gen_samples_decoded.cpu().data.numpy() ) @@ -824,9 +803,7 @@ def synthesize_dataset(self, num_samples: int, random_sampling: bool = True) -> if remaining > 0: z = torch.randn(remaining, self.latent_dim, device=device) gen_samples_tensor = self.generator(z) - gen_samples_decoded = torch.squeeze( - self.autoencoder_decoder(gen_samples_tensor.unsqueeze(dim=2)) - ) + gen_samples_decoded = self.autoencoder.decode(gen_samples_tensor) gen_samples[n_batches * self.batch_size:, :] = ( gen_samples_decoded.cpu().data.numpy() ) @@ -846,7 +823,7 @@ def synthesize_dataset(self, num_samples: int, random_sampling: bool = True) -> ] results.append({ "patient_id": f"synthetic_{i}", - "visits": [codes], + "visits": codes, }) return results diff --git a/pyhealth/tasks/corgan_generation.py b/pyhealth/tasks/corgan_generation.py index d3ce692fe..d2e948741 100644 --- a/pyhealth/tasks/corgan_generation.py +++ b/pyhealth/tasks/corgan_generation.py @@ -9,15 +9,17 @@ class CorGANGenerationMIMIC3(BaseTask): """Task function for CorGAN synthetic EHR generation using MIMIC-III. Extracts ICD-9 diagnosis codes from MIMIC-III admission records into a - nested visit structure suitable for training the CorGAN model. + flat list of codes suitable for training the CorGAN model. - Each sample contains the full visit history for a single patient, where - each visit is a list of ICD-9 codes recorded during that admission. - Patients with fewer than 2 visits are excluded. 
+ CorGAN is a bag-of-codes model: it collapses all visit codes for a patient + into a single binary vector, so visit structure is irrelevant. All codes + from all admissions are pooled into one flat list per patient. + Patients with no codes are excluded. Attributes: task_name (str): Unique task identifier. - input_schema (dict): Schema descriptor for the visits field. + input_schema (dict): Schema descriptor — ``"visits"`` field uses + ``"multi_hot"`` encoding (flat list of code strings). output_schema (dict): Empty — generative task, no conditioning label. _icd_col (str): Polars column path for ICD codes in MIMIC-III. @@ -28,12 +30,16 @@ class CorGANGenerationMIMIC3(BaseTask): """ task_name = "CorGANGenerationMIMIC3" - input_schema = {"visits": "nested_sequence"} + input_schema = {"visits": "multi_hot"} output_schema = {} _icd_col = "diagnoses_icd/icd9_code" def __call__(self, patient) -> List[Dict]: - """Extract structured visit data for a single patient. + """Extract flat code list for a single patient. + + All ICD codes from all admissions are pooled into a single flat list. + Visit temporal structure is discarded because CorGAN operates on a + single multi-hot binary vector per patient. Args: patient: A PyHealth patient object with admission and diagnosis @@ -41,13 +47,14 @@ def __call__(self, patient) -> List[Dict]: Returns: list of dict: A single-element list containing the patient record, - or an empty list if the patient has fewer than 2 visits with - diagnosis codes. Each dict has: + or an empty list if the patient has no diagnosis codes. Each + dict has: ``"patient_id"`` (str): the patient identifier. - ``"visits"`` (list of list of str): per-visit ICD code lists. + ``"visits"`` (list of str): flat list of all ICD codes across all + admissions. 
""" admissions = list(patient.get_events(event_type="admissions")) - visits = [] + all_codes = [] for adm in admissions: codes = ( patient.get_events( @@ -60,11 +67,10 @@ def __call__(self, patient) -> List[Dict]: .drop_nulls() .to_list() ) - if codes: - visits.append(codes) - if len(visits) < 2: + all_codes.extend(codes) + if not all_codes: return [] - return [{"patient_id": patient.patient_id, "visits": visits}] + return [{"patient_id": patient.patient_id, "visits": all_codes}] class CorGANGenerationMIMIC4(CorGANGenerationMIMIC3): From 77bfd406baff583a292dd46e73a22c4fd57b343e Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 25 Feb 2026 10:23:19 -0600 Subject: [PATCH 07/13] Fix docstring quality issues in T2+T3 multi_hot schema change --- pyhealth/models/generators/corgan.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pyhealth/models/generators/corgan.py b/pyhealth/models/generators/corgan.py index 93cbb9029..6aa01282b 100644 --- a/pyhealth/models/generators/corgan.py +++ b/pyhealth/models/generators/corgan.py @@ -469,7 +469,7 @@ class CorGAN(BaseModel): Generative Adversarial Networks", JAMIA 2019. Args: - dataset: A fitted SampleDataset with ``input_schema = {"visits": "multi_hot"}``. + dataset (SampleDataset): A fitted SampleDataset with ``input_schema = {"visits": "multi_hot"}``. latent_dim: Dimensionality of the generator latent space. Default: 128. hidden_dim: Hidden dimension for the generator MLP. Default: 128. batch_size: Training batch size. Default: 512. @@ -627,6 +627,9 @@ def train_model(self, train_dataset, val_dataset=None): train_dataset: A fitted SampleDataset with ``input_schema = {"visits": "multi_hot"}``. val_dataset: Unused. Accepted for API compatibility. + + Returns: + None """ print("Starting CorGAN training...") @@ -777,9 +780,10 @@ def synthesize_dataset(self, num_samples: int, random_sampling: bool = True) -> random_sampling: Unused; accepted for API compatibility. 
Returns: - List of dicts, each with: + list of dict: Synthetic patient records. Each dict has: ``"patient_id"`` (str): e.g. ``"synthetic_0"``. ``"visits"`` (list of str): flat list of decoded ICD code strings. + May be empty if the generated vector has all values below the 0.5 threshold. """ # set models to eval mode self.generator.eval() From f3b79951fcc59f488c763c39866ffece21162028 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 25 Feb 2026 10:26:39 -0600 Subject: [PATCH 08/13] T7: fix deferred import and complete CorGAN docstrings --- pyhealth/models/generators/corgan.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/pyhealth/models/generators/corgan.py b/pyhealth/models/generators/corgan.py index 6aa01282b..df09bc6b8 100644 --- a/pyhealth/models/generators/corgan.py +++ b/pyhealth/models/generators/corgan.py @@ -1,3 +1,4 @@ +import os import time from typing import Dict, List, Optional @@ -493,15 +494,17 @@ class CorGAN(BaseModel): Examples: >>> from pyhealth.datasets.sample_dataset import InMemorySampleDataset - >>> samples = [{"patient_id": "p1", "visits": ["A", "B", "C"]}] >>> dataset = InMemorySampleDataset( - ... samples=samples, + ... samples=[ + ... {"patient_id": "p1", "visits": ["A", "B", "C"]}, + ... {"patient_id": "p2", "visits": ["A", "C", "D"]}, + ... ], ... input_schema={"visits": "multi_hot"}, ... output_schema={}, ... 
) - >>> model = CorGAN(dataset, latent_dim=32, hidden_dim=32, epochs=1) - >>> model.train_model(dataset) - >>> records = model.synthesize_dataset(num_samples=10) + >>> model = CorGAN(dataset, latent_dim=32, hidden_dim=32) + >>> isinstance(model, CorGAN) + True """ def __init__( @@ -762,7 +765,6 @@ def train_model(self, train_dataset, val_dataset=None): # save final checkpoint if save_dir is configured if self.save_dir: - import os os.makedirs(self.save_dir, exist_ok=True) checkpoint_path = os.path.join(self.save_dir, "corgan_final.pt") self.save_model(checkpoint_path) @@ -833,10 +835,13 @@ def synthesize_dataset(self, num_samples: int, random_sampling: bool = True) -> return results def save_model(self, path: str): - """Save model checkpoint. + """Save model weights and vocabulary to a checkpoint file. Args: - path: File path to write the checkpoint to. + path (str): File path to write the checkpoint (.pt file). + + Returns: + None """ torch.save({ 'generator_state_dict': self.generator.state_dict(), @@ -852,10 +857,13 @@ def save_model(self, path: str): }, path) def load_model(self, path: str): - """Load model checkpoint. + """Load model weights and vocabulary from a checkpoint file. Args: - path: File path to read the checkpoint from. + path (str): File path to read the checkpoint (.pt file). 
+ + Returns: + None """ checkpoint = torch.load(path, map_location=self.device) From cddcffeff9ff06e4cfcd204478d6e16c3343dca8 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 25 Feb 2026 10:39:13 -0600 Subject: [PATCH 09/13] Rewrite CorGAN generation example to PyHealth 2.0 API --- examples/generate_synthetic_mimic3_corgan.py | 200 ++++--------------- 1 file changed, 39 insertions(+), 161 deletions(-) diff --git a/examples/generate_synthetic_mimic3_corgan.py b/examples/generate_synthetic_mimic3_corgan.py index a3d8538fa..d88ac0534 100644 --- a/examples/generate_synthetic_mimic3_corgan.py +++ b/examples/generate_synthetic_mimic3_corgan.py @@ -1,161 +1,39 @@ -#!/usr/bin/env python3 -""" -Generate synthetic MIMIC-III patients using a trained CorGAN checkpoint. -Uses Variable top-K sampling to maintain natural variation in code counts. -""" - -import os -import sys -sys.path.insert(0, '/u/jalenj4/PyHealth-Medgan-Corgan-Port') -import argparse -import torch -import numpy as np -import pandas as pd -from pyhealth.models.generators.corgan import CorGANAutoencoder, CorGAN8LayerAutoencoder, CorGANGenerator, CorGANDiscriminator - - -def main(): - parser = argparse.ArgumentParser(description="Generate synthetic patients using trained CorGAN") - parser.add_argument("--checkpoint", required=True, help="Path to trained CorGAN checkpoint (.pth)") - parser.add_argument("--vocab", required=True, help="Path to ICD-9 vocabulary file (.txt)") - parser.add_argument("--binary_matrix", required=True, help="Path to training binary matrix (.npy)") - parser.add_argument("--output", required=True, help="Path to output CSV file") - parser.add_argument("--n_samples", type=int, default=10000, help="Number of synthetic patients to generate") - parser.add_argument("--mean_k", type=float, default=13, help="Mean K for Variable top-K sampling") - parser.add_argument("--std_k", type=float, default=5, help="Std dev K for Variable top-K sampling") - args = parser.parse_args() - - device = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") - - # Load vocabulary - print(f"Loading vocabulary from {args.vocab}") - with open(args.vocab, 'r') as f: - code_vocab = [line.strip() for line in f] - print(f"Loaded {len(code_vocab)} ICD-9 codes") - - # Load binary matrix to get architecture dimensions - print(f"Loading binary matrix from {args.binary_matrix}") - binary_matrix = np.load(args.binary_matrix) - n_codes = binary_matrix.shape[1] - print(f"Binary matrix shape: {binary_matrix.shape}") - print(f"Real data avg codes/patient: {binary_matrix.sum(axis=1).mean():.2f}") - - # Load checkpoint - print(f"\nLoading checkpoint from {args.checkpoint}") - checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False) - - # Detect architecture from checkpoint - # Check if this is 8-layer by looking at state dict keys - state_keys = checkpoint['autoencoder_state_dict'].keys() - is_8layer = any('encoder.18' in k or 'encoder.21' in k for k in state_keys) # 8-layer has more layers - - # Initialize CorGAN components with correct architecture - print("Initializing CorGAN model components...") - if n_codes == 6955: - if is_8layer: - autoencoder = CorGAN8LayerAutoencoder(feature_size=n_codes).to(device) - print("Detected 8-layer architecture") - else: - # Assume adaptive pooling - autoencoder = CorGANAutoencoder( - feature_size=n_codes, - use_adaptive_pooling=True - ).to(device) - print("Detected 6-layer + adaptive pooling architecture") - else: - autoencoder = CorGANAutoencoder( - feature_size=n_codes, - use_adaptive_pooling=False - ).to(device) - print(f"Using standard 6-layer architecture for {n_codes} codes") - - generator = CorGANGenerator(latent_dim=128, hidden_dim=128).to(device) - discriminator = CorGANDiscriminator(input_dim=n_codes, hidden_dim=256).to(device) - - # Load trained weights - autoencoder.load_state_dict(checkpoint['autoencoder_state_dict']) - 
generator.load_state_dict(checkpoint['generator_state_dict']) - discriminator.load_state_dict(checkpoint['discriminator_state_dict']) - - autoencoder.eval() - generator.eval() - discriminator.eval() - print("Model loaded successfully") - - # Generate synthetic patients - print(f"\nGenerating {args.n_samples} synthetic patients...") - - with torch.no_grad(): - # Generate random noise - z = torch.randn(args.n_samples, 128, device=device) - # Generate latent codes - generated_latent = generator(z) - # Decode to probabilities - synthetic_probs = autoencoder.decode(generated_latent) - - # Trim or pad if needed - if synthetic_probs.shape[1] > n_codes: - synthetic_probs = synthetic_probs[:, :n_codes] - elif synthetic_probs.shape[1] < n_codes: - padding = torch.zeros(synthetic_probs.shape[0], n_codes - synthetic_probs.shape[1], device=device) - synthetic_probs = torch.cat([synthetic_probs, padding], dim=1) - - probs = synthetic_probs.cpu().numpy() - - # Apply Variable top-K sampling - print(f"Applying Variable top-K sampling (μ={args.mean_k}, σ={args.std_k})...") - binary_matrix_synthetic = np.zeros_like(probs) - - for i in range(args.n_samples): - # Sample K from normal distribution, clip to reasonable range - k = int(np.clip(np.random.normal(args.mean_k, args.std_k), 1, 50)) - # Get indices of top-K probabilities - top_k_indices = np.argsort(probs[i])[-k:] - binary_matrix_synthetic[i, top_k_indices] = 1 - - # Calculate statistics - avg_codes = binary_matrix_synthetic.sum(axis=1).mean() - std_codes = binary_matrix_synthetic.sum(axis=1).std() - min_codes = binary_matrix_synthetic.sum(axis=1).min() - max_codes = binary_matrix_synthetic.sum(axis=1).max() - sparsity = (binary_matrix_synthetic == 0).mean() - - print(f"\nSynthetic data statistics:") - print(f" Avg codes per patient: {avg_codes:.2f} ± {std_codes:.2f}") - print(f" Range: [{min_codes:.0f}, {max_codes:.0f}]") - print(f" Sparsity: {sparsity:.4f}") - - # Check heterogeneity - unique_profiles = len(set(tuple(row) for 
row in binary_matrix_synthetic)) - print(f" Unique patient profiles: {unique_profiles}/{args.n_samples} ({unique_profiles/args.n_samples*100:.1f}%)") - - # Convert to CSV format (SUBJECT_ID, ICD9_CODE) - print(f"\nConverting to CSV format...") - records = [] - for patient_idx in range(args.n_samples): - patient_id = f"SYNTHETIC_{patient_idx+1:06d}" - code_indices = np.where(binary_matrix_synthetic[patient_idx] == 1)[0] - - for code_idx in code_indices: - records.append({ - 'SUBJECT_ID': patient_id, - 'ICD9_CODE': code_vocab[code_idx] - }) - - df = pd.DataFrame(records) - print(f"Created {len(df)} diagnosis records for {args.n_samples} patients") - - # Save to CSV - print(f"\nSaving to {args.output}") - df.to_csv(args.output, index=False) - - file_size_mb = os.path.getsize(args.output) / (1024 * 1024) - print(f"Saved {file_size_mb:.1f} MB") - - print("\n✓ Generation complete!") - print(f"Output: {args.output}") - - -if __name__ == '__main__': - main() +"""Generate synthetic MIMIC-III patient records using a trained CorGAN checkpoint.""" +import json + +from pyhealth.datasets import MIMIC3Dataset +from pyhealth.tasks import corgan_generation_mimic3_fn +from pyhealth.models.generators.corgan import CorGAN + +# 1. Reconstruct dataset — required to initialise CorGAN's vocabulary from the processor. +# If you trained with different tables=, update this to match exactly. +base_dataset = MIMIC3Dataset( + root="/path/to/mimic3", + tables=["diagnoses_icd"], +) +sample_dataset = base_dataset.set_task(corgan_generation_mimic3_fn) + +# 2. Instantiate model (epochs and training params are unused during generation; +# they must match your training configuration for checkpoint compatibility). +model = CorGAN( + dataset=sample_dataset, + latent_dim=128, + hidden_dim=128, + batch_size=128, + epochs=50, + save_dir="./corgan_checkpoints/", +) + +# 3. Load trained checkpoint +model.load_model("./corgan_checkpoints/corgan_final.pt") + +# 4. 
Generate synthetic patients — each patient is a flat bag-of-codes (no visit structure) +synthetic = model.synthesize_dataset(num_samples=10000) +print(f"Generated {len(synthetic)} synthetic patients") +print(f"Example record: {synthetic[0]}") + +# 5. Save to JSON +output_path = "synthetic_corgan_10k.json" +with open(output_path, "w") as f: + json.dump(synthetic, f, indent=2) +print(f"Saved to {output_path}") From 2bd1bca6f77f53dd3605780ae17fb92ef076d5e9 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 25 Feb 2026 10:43:24 -0600 Subject: [PATCH 10/13] Add CorGAN PyHealth 2.0 training example --- examples/corgan_mimic3_training.py | 34 ++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 examples/corgan_mimic3_training.py diff --git a/examples/corgan_mimic3_training.py b/examples/corgan_mimic3_training.py new file mode 100644 index 000000000..c68d700e5 --- /dev/null +++ b/examples/corgan_mimic3_training.py @@ -0,0 +1,34 @@ +"""Train CorGAN on MIMIC-III diagnosis codes and save a checkpoint.""" + +# 1. Load MIMIC-III dataset +from pyhealth.datasets import MIMIC3Dataset +from pyhealth.datasets import split_by_patient +from pyhealth.tasks import corgan_generation_mimic3_fn +from pyhealth.models.generators.corgan import CorGAN + +base_dataset = MIMIC3Dataset( + root="/path/to/mimic3", + tables=["diagnoses_icd"], +) + +# 2. Apply generation task — flattens all ICD codes per patient into a bag-of-codes +sample_dataset = base_dataset.set_task(corgan_generation_mimic3_fn) +print(f"{len(sample_dataset)} patients after filtering") + +# 3. Patient-level split — required for generative models to prevent data leakage across splits +train_dataset, val_dataset, _ = split_by_patient(sample_dataset, [0.8, 0.1, 0.1]) + +# 4. 
Instantiate and train — reduce epochs for testing; 50+ recommended for quality synthetic data +model = CorGAN( + dataset=sample_dataset, + latent_dim=128, + hidden_dim=128, + batch_size=128, + epochs=50, + lr=1e-4, + save_dir="./corgan_checkpoints/", +) +model.train_model(train_dataset, val_dataset) + +# 5. Checkpoint is saved automatically to save_dir by train_model +print("Training complete. Checkpoint saved to ./corgan_checkpoints/") From 5c6b699f9dce947e7b7e7d431cf1b585f06bbcd1 Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 25 Feb 2026 10:45:03 -0600 Subject: [PATCH 11/13] T5: use named variable for MIMIC3 root path --- examples/generate_synthetic_mimic3_corgan.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/generate_synthetic_mimic3_corgan.py b/examples/generate_synthetic_mimic3_corgan.py index d88ac0534..f01fc40a3 100644 --- a/examples/generate_synthetic_mimic3_corgan.py +++ b/examples/generate_synthetic_mimic3_corgan.py @@ -5,10 +5,13 @@ from pyhealth.tasks import corgan_generation_mimic3_fn from pyhealth.models.generators.corgan import CorGAN +# Update this to your local MIMIC-III path before running +MIMIC3_ROOT = "/path/to/mimic3" + # 1. Reconstruct dataset — required to initialise CorGAN's vocabulary from the processor. # If you trained with different tables=, update this to match exactly. 
base_dataset = MIMIC3Dataset( - root="/path/to/mimic3", + root=MIMIC3_ROOT, tables=["diagnoses_icd"], ) sample_dataset = base_dataset.set_task(corgan_generation_mimic3_fn) From 5411b3eee4a6e5c2cfb9a17934cd7ee9f1ddfbeb Mon Sep 17 00:00:00 2001 From: jalengg Date: Wed, 25 Feb 2026 11:08:32 -0600 Subject: [PATCH 12/13] T8: Add CorGAN end-to-end integration tests --- tests/integration/__init__.py | 0 tests/integration/test_corgan_end_to_end.py | 383 ++++++++++++++++++++ 2 files changed, 383 insertions(+) create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/test_corgan_end_to_end.py diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/test_corgan_end_to_end.py b/tests/integration/test_corgan_end_to_end.py new file mode 100644 index 000000000..aff81883a --- /dev/null +++ b/tests/integration/test_corgan_end_to_end.py @@ -0,0 +1,383 @@ +"""End-to-end integration tests for the CorGAN synthetic EHR generation pipeline. + +Category A tests use InMemorySampleDataset with synthetic data — no external +data required and must always pass. + +Category B tests require actual MIMIC-III data and are skipped gracefully when +the data is unavailable. + +The bootstrap pattern mirrors test_halo_end_to_end.py: load CorGAN and +InMemorySampleDataset via importlib while stubbing out heavy optional +dependencies (einops, litdata, etc.) that are not yet in the venv. +""" + +import importlib.util +import os +import sys +import tempfile +import unittest +from unittest.mock import MagicMock + + +# --------------------------------------------------------------------------- +# Bootstrap: load CorGAN, BaseModel, and InMemorySampleDataset without +# triggering pyhealth.models.__init__ (requires einops, litdata, etc.) or +# pyhealth.datasets.__init__ (requires litdata, pyarrow, pandas, dask, ...). 
+# --------------------------------------------------------------------------- + + +def _bootstrap(): + """Load CorGAN, BaseModel, and InMemorySampleDataset via importlib. + + Returns: + (BaseModel, CorGAN, InMemorySampleDataset) + """ + import pyhealth # noqa: F401 — top-level __init__ has no heavy deps + + # Stub pyhealth.datasets so that base_model.py's + # "from ..datasets import SampleDataset" resolves cleanly. + if "pyhealth.datasets" not in sys.modules: + ds_stub = MagicMock() + + class _FakeSampleDataset: # noqa: N801 + pass + + ds_stub.SampleDataset = _FakeSampleDataset + sys.modules["pyhealth.datasets"] = ds_stub + + # Stub pyhealth.models so we can control loading without the real __init__. + if "pyhealth.models" not in sys.modules or isinstance( + sys.modules["pyhealth.models"], MagicMock + ): + models_stub = MagicMock() + sys.modules["pyhealth.models"] = models_stub + else: + models_stub = sys.modules["pyhealth.models"] + + # Processors are safe to import normally. + from pyhealth.processors import PROCESSOR_REGISTRY # noqa: F401 + + def _load_file(mod_name, filepath): + spec = importlib.util.spec_from_file_location(mod_name, filepath) + mod = importlib.util.module_from_spec(spec) + sys.modules[mod_name] = mod + spec.loader.exec_module(mod) + return mod + + root = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + base = os.path.join(root, "pyhealth", "models") + + # Load base_model and expose via stub. + bm_mod = _load_file( + "pyhealth.models.base_model", os.path.join(base, "base_model.py") + ) + BaseModel = bm_mod.BaseModel + models_stub.BaseModel = BaseModel + + gen_stub = MagicMock() + sys.modules.setdefault("pyhealth.models.generators", gen_stub) + + # Load CorGAN directly — generators/corgan.py only imports torch + BaseModel. 
+ corgan_mod = _load_file( + "pyhealth.models.generators.corgan", + os.path.join(base, "generators", "corgan.py"), + ) + CorGAN = corgan_mod.CorGAN + + # Stub litdata so sample_dataset.py can be loaded without the full package. + # sample_dataset.py imports litdata.StreamingDataset and + # litdata.utilities.train_test_split.deepcopy_dataset. + if "litdata" not in sys.modules: + litdata_pkg = MagicMock() + litdata_pkg.StreamingDataset = type( + "StreamingDataset", (), {"__init__": lambda self, *a, **kw: None} + ) + litdata_utilities = MagicMock() + litdata_utilities_train_test = MagicMock() + litdata_utilities_train_test.deepcopy_dataset = lambda x: x + litdata_utilities.train_test_split = litdata_utilities_train_test + litdata_pkg.utilities = litdata_utilities + sys.modules["litdata"] = litdata_pkg + sys.modules["litdata.utilities"] = litdata_utilities + sys.modules["litdata.utilities.train_test_split"] = ( + litdata_utilities_train_test + ) + + # Load sample_dataset.py directly (bypasses datasets/__init__.py). + ds_file_mod = _load_file( + "pyhealth.datasets.sample_dataset", + os.path.join(root, "pyhealth", "datasets", "sample_dataset.py"), + ) + InMemorySampleDataset = ds_file_mod.InMemorySampleDataset + + return BaseModel, CorGAN, InMemorySampleDataset + + +BaseModel, CorGAN, InMemorySampleDataset = _bootstrap() + +import torch # noqa: E402 + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + +# Flat lists of code strings — CorGAN collapses all visit codes into a single +# multi-hot binary vector, so the task function returns one flat list per patient. +# Vocab = {A, B, C, D, E} (5 codes) → input_dim = 5 < 500 → linear autoencoder. 
+_SMALL_SAMPLES = [ + {"patient_id": "p1", "visits": ["A", "B", "C"]}, + {"patient_id": "p2", "visits": ["A", "C", "D"]}, + {"patient_id": "p3", "visits": ["B", "D", "E"]}, + {"patient_id": "p4", "visits": ["A", "B", "C", "D"]}, + {"patient_id": "p5", "visits": ["C", "E"]}, + {"patient_id": "p6", "visits": ["A", "D", "E"]}, + {"patient_id": "p7", "visits": ["B", "C", "D"]}, + {"patient_id": "p8", "visits": ["A", "E"]}, +] + +# Minimal dimensions to keep tests fast. 5-code vocab < 500 triggers the linear +# autoencoder fallback automatically (no architectural changes needed in tests). +_SMALL_MODEL_KWARGS = dict( + latent_dim=4, + hidden_dim=4, + batch_size=4, + epochs=1, + n_epochs_pretrain=1, + save_dir=None, +) + + +def _make_dataset(samples=None): + if samples is None: + samples = _SMALL_SAMPLES + return InMemorySampleDataset( + samples=samples, + input_schema={"visits": "multi_hot"}, + output_schema={}, + ) + + +# --------------------------------------------------------------------------- +# Category A: In-Memory Integration Tests (must always pass) +# --------------------------------------------------------------------------- + + +class TestCorGANIsBaseModelInstance(unittest.TestCase): + """CorGAN model is an instance of BaseModel.""" + + def test_model_is_basemodel_instance(self): + dataset = _make_dataset() + model = CorGAN(dataset, **_SMALL_MODEL_KWARGS) + self.assertIsInstance(model, BaseModel) + + +class TestCorGANFeatureKeys(unittest.TestCase): + """model.feature_keys equals ['visits'].""" + + def test_feature_keys(self): + dataset = _make_dataset() + model = CorGAN(dataset, **_SMALL_MODEL_KWARGS) + self.assertEqual(model.feature_keys, ["visits"]) + + +class TestCorGANVocabSize(unittest.TestCase): + """CorGAN.input_dim matches processor.size().""" + + def test_vocab_size_matches_processor(self): + dataset = _make_dataset() + expected = dataset.input_processors["visits"].size() + model = CorGAN(dataset, **_SMALL_MODEL_KWARGS) + 
self.assertEqual(model.input_dim, expected) + + +class TestCorGANForwardRaisesNotImplementedError(unittest.TestCase): + """Calling forward() raises NotImplementedError. + + CorGAN uses GAN-style training (train_model / synthesize_dataset), not + the supervised forward pass used by discriminative BaseModel subclasses. + """ + + def test_forward_not_implemented(self): + dataset = _make_dataset() + model = CorGAN(dataset, **_SMALL_MODEL_KWARGS) + with self.assertRaises(NotImplementedError): + model.forward() + + +class TestCorGANTrainModelRuns(unittest.TestCase): + """train_model completes one epoch without error.""" + + def test_train_model_runs_one_epoch(self): + dataset = _make_dataset() + model = CorGAN(dataset, **_SMALL_MODEL_KWARGS) + try: + model.train_model(dataset, val_dataset=None) + except Exception as exc: # noqa: BLE001 + self.fail(f"train_model raised an unexpected exception: {exc}") + + +class TestCorGANSynthesizeCount(unittest.TestCase): + """synthesize_dataset(num_samples=5) returns exactly 5 dicts.""" + + def setUp(self): + dataset = _make_dataset() + self.model = CorGAN(dataset, **_SMALL_MODEL_KWARGS) + + def test_synthesize_returns_correct_count(self): + result = self.model.synthesize_dataset(num_samples=5) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 5) + + +class TestCorGANSynthesizeOutputStructure(unittest.TestCase): + """Each synthesized dict has patient_id (str) and visits (flat list of str). + + CorGAN outputs a flat list of code strings per patient — not nested visit + lists. This reflects the multi_hot input schema where all codes are pooled + into a single binary vector. 
+ """ + + def setUp(self): + dataset = _make_dataset() + self.model = CorGAN(dataset, **_SMALL_MODEL_KWARGS) + + def test_synthesize_output_structure(self): + result = self.model.synthesize_dataset(num_samples=3) + for i, item in enumerate(result): + self.assertIsInstance(item, dict, f"Item {i} is not a dict") + self.assertIn("patient_id", item, f"Item {i} missing 'patient_id'") + self.assertIn("visits", item, f"Item {i} missing 'visits'") + self.assertIsInstance( + item["patient_id"], str, f"patient_id in item {i} is not a str" + ) + self.assertIsInstance( + item["visits"], list, f"visits in item {i} is not a list" + ) + # visits is a flat list of strings — NOT nested visit lists + for code in item["visits"]: + self.assertIsInstance( + code, str, f"code '{code}' in item {i} is not a str" + ) + + +class TestCorGANSaveLoadRoundtrip(unittest.TestCase): + """save_model then load_model; synthesize_dataset still returns correct count.""" + + def test_save_load_roundtrip(self): + dataset = _make_dataset() + model = CorGAN(dataset, **_SMALL_MODEL_KWARGS) + with tempfile.TemporaryDirectory() as tmpdir: + ckpt_path = os.path.join(tmpdir, "corgan_test.pt") + model.save_model(ckpt_path) + self.assertTrue( + os.path.exists(ckpt_path), + f"Expected checkpoint at {ckpt_path}", + ) + model.load_model(ckpt_path) + result = model.synthesize_dataset(num_samples=3) + self.assertEqual(len(result), 3) + + +# --------------------------------------------------------------------------- +# Category B: MIMIC-III Integration Tests (skipped if data unavailable) +# --------------------------------------------------------------------------- + +_MIMIC3_PATH = os.environ.get( + "PYHEALTH_MIMIC3_PATH", + "/srv/local/data/physionet.org/files/mimiciii/1.4", +) + + +class TestCorGANMIMIC3Integration(unittest.TestCase): + """End-to-end pipeline test with actual MIMIC-III data. + + Skipped automatically when MIMIC-III is not present on this machine. 
+ """ + + @classmethod + def setUpClass(cls): + cls.skip_integration = False + cls.skip_reason = "" + try: + # Remove the bootstrap stub for pyhealth.datasets so we can attempt + # a real import (which will raise ImportError if litdata is absent). + _saved_stub = sys.modules.pop("pyhealth.datasets", None) + try: + import importlib as _il + _il.invalidate_caches() + from pyhealth.datasets import MIMIC3Dataset as _MIMIC3Dataset + from pyhealth.tasks.corgan_generation import CorGANGenerationMIMIC3 + except (ImportError, ModuleNotFoundError) as exc: + # Restore stub so the rest of the test session is unaffected. + if _saved_stub is not None: + sys.modules["pyhealth.datasets"] = _saved_stub + raise ImportError(str(exc)) from exc + + cls.dataset = _MIMIC3Dataset( + root=_MIMIC3_PATH, + tables=["diagnoses_icd"], + ) + task = CorGANGenerationMIMIC3() + cls.sample_dataset = cls.dataset.set_task(task) + except (FileNotFoundError, OSError, ImportError, ValueError) as exc: + cls.skip_integration = True + cls.skip_reason = str(exc) + + def setUp(self): + if self.skip_integration: + self.skipTest(f"MIMIC-III integration test skipped: {self.skip_reason}") + + def test_mimic3_set_task_returns_nonempty_dataset(self): + """set_task produces at least one sample from MIMIC-III.""" + self.assertGreater(len(self.sample_dataset), 0) + + def test_mimic3_sample_keys(self): + """Every sample must contain patient_id and visits keys.""" + for sample in self.sample_dataset: + self.assertIn("patient_id", sample) + self.assertIn("visits", sample) + + def test_mimic3_visits_are_flat_multihot_tensors(self): + """visits must be a float32 tensor of shape (vocab_size,) with values in {0, 1}. + + MultiHotProcessor outputs a binary float32 tensor of shape (vocab_size,). + This verifies that the multi_hot schema round-trips through set_task correctly. 
+ """ + processor = self.sample_dataset.input_processors["visits"] + vocab_size = processor.size() + for sample in self.sample_dataset: + visits = sample["visits"] + self.assertIsInstance(visits, torch.Tensor) + self.assertEqual(visits.shape, (vocab_size,)) + self.assertEqual(visits.dtype, torch.float32) + self.assertTrue( + torch.all((visits == 0.0) | (visits == 1.0)), + "visits tensor contains values outside {0, 1}", + ) + + def test_mimic3_full_pipeline_train_and_synthesize(self): + """Train one epoch on MIMIC-III data and synthesize a small batch.""" + with tempfile.TemporaryDirectory() as tmpdir: + model = CorGAN( + self.sample_dataset, + latent_dim=64, + hidden_dim=64, + batch_size=32, + epochs=1, + n_epochs_pretrain=1, + save_dir=tmpdir, + ) + model.train_model(self.sample_dataset, val_dataset=None) + synthetic = model.synthesize_dataset(num_samples=10) + self.assertEqual(len(synthetic), 10) + for item in synthetic: + self.assertIn("patient_id", item) + self.assertIn("visits", item) + self.assertIsInstance(item["visits"], list) + + +if __name__ == "__main__": + unittest.main() From 057b9346d4c7225492270319766d453c1818048c Mon Sep 17 00:00:00 2001 From: Stevie Xie <66903483+shiitavie@users.noreply.github.com> Date: Mon, 2 Mar 2026 22:42:45 -0700 Subject: [PATCH 13/13] corgan notebook --- examples/corgan_mimic3_colab.ipynb | 644 ++++++++++++++++++++ pyhealth/models/generators/corgan.py | 16 +- tests/integration/test_corgan_end_to_end.py | 29 + 3 files changed, 688 insertions(+), 1 deletion(-) create mode 100644 examples/corgan_mimic3_colab.ipynb diff --git a/examples/corgan_mimic3_colab.ipynb b/examples/corgan_mimic3_colab.ipynb new file mode 100644 index 000000000..2d3cd9657 --- /dev/null +++ b/examples/corgan_mimic3_colab.ipynb @@ -0,0 +1,644 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CorGAN Synthetic Data Generation for MIMIC-III\n", + "\n", + "**Last updated:** 2026-02-28\n", + "\n", + "This notebook trains a 
[CorGAN](https://doi.org/10.1093/jamia/ocz120) model on MIMIC-III diagnosis codes and generates synthetic electronic health records (EHRs).\n", + "\n", + "### What You'll Need\n", + "- **MIMIC-III access** via [PhysioNet](https://physionet.org/content/mimiciii/) — or use **demo mode** (no data required)\n", + "- **Google Colab** with GPU runtime (Runtime > Change runtime type > T4 GPU), or a local machine with PyTorch\n", + "- **Time:** Demo ~20–30 min on T4 | Production ~2–4 hrs on T4\n", + "\n", + "### What You'll Get\n", + "- A trained CorGAN model checkpoint\n", + "- Synthetic MIMIC-III patients (default: 1,000 in demo, 10,000 in production)\n", + "- A quality report (`quality_report.json`) with evaluation metrics\n", + "\n", + "### How It Works (6 Steps)\n", + "1. **Setup** — Install PyHealth, detect GPU\n", + "2. **Configure** — Set training parameters (demo vs production)\n", + "3. **Upload Data** — Upload MIMIC-III CSVs or use demo data\n", + "4. **Train** — Pre-train autoencoder, then adversarial WGAN training\n", + "5. **Generate** — Synthesize patient records\n", + "6. **Evaluate** — Compare real vs synthetic data quality\n", + "\n", + "### What Makes CorGAN Different\n", + "\n", + "CorGAN uses a **CNN autoencoder** combined with a **Wasserstein GAN (WGAN)** to generate synthetic patient records as flat bags-of-codes (binary vectors of ICD-9 diagnoses).\n", + "\n", + "The key insight is that 1D convolutions in the autoencoder capture **inter-code correlations** — for example, that diabetes and hypertension frequently co-occur — which plain linear autoencoders miss. The three components work together:\n", + "\n", + "- **Autoencoder**: Learns a compressed latent representation of multi-hot ICD-9 code vectors. 
The CNN layers capture which codes tend to appear together.\n", + "- **Generator**: Produces synthetic latent codes from random noise.\n", + "- **Discriminator (Critic)**: Distinguishes real latent codes from generated ones, providing the adversarial training signal.\n", + "\n", + "Training proceeds in two phases: (1) the autoencoder is pre-trained to learn good latent representations, then (2) the WGAN trains the generator to produce realistic latent codes that the decoder maps back to binary code vectors.\n", + "\n", + "Unlike HALO (which generates sequential visits via a Transformer), CorGAN aggregates all of a patient's diagnoses across admissions into a single flat set.\n", + "\n", + "### Reference\n", + "Baowaly et al., \"Synthesizing Electronic Health Records Using Improved Generative Adversarial Networks\", *JAMIA*, 2019. [DOI: 10.1093/jamia/ocz120](https://doi.org/10.1093/jamia/ocz120)\n", + "\n", + "---\n", + "\n", + "> **Colab timeout warning:** Free Colab sessions disconnect after ~90 min of inactivity. For production training, consider [Colab Pro](https://colab.research.google.com/signup) or running on a local GPU/SLURM cluster." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 
Setup & Installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "import subprocess, sys\n\nFORK = \"jalengg\"\nBRANCH = \"corgan-pr-integration\"\ninstall_url = f\"git+https://github.com/{FORK}/PyHealth.git@{BRANCH}\"\n\n# Only install from GitHub in Colab; locally, use the editable install\ntry:\n import google.colab\n result = subprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", install_url, \"--quiet\", \"--no-cache-dir\"],\n capture_output=True, text=True\n )\n if result.returncode != 0:\n print(result.stderr)\n raise RuntimeError(\"PyHealth installation failed\")\n print(f\"PyHealth installed from {FORK}/{BRANCH}\")\nexcept ImportError:\n print(\"Not in Colab — using local PyHealth installation\")" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os, json, random, time\n", + "from datetime import datetime\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Environment detection\n", + "try:\n", + " import google.colab\n", + " IN_COLAB = True\n", + "except ImportError:\n", + " IN_COLAB = False\n", + "\n", + "# GPU detection\n", + "if torch.cuda.is_available():\n", + " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n", + " DEVICE = \"cuda\"\n", + "else:\n", + " print(\"WARNING: No GPU detected. Training will be slow.\")\n", + " DEVICE = \"cpu\"\n", + "\n", + "print(f\"PyTorch {torch.__version__} | Environment: {'Colab' if IN_COLAB else 'Local'}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Configuration\n", + "\n", + "All hyperparameters are centralized here. Change `PRESET` to switch between a quick demo and full production training. You should not need to modify any other cell." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ============================================================\n", + "# CONFIGURATION — Modify these parameters\n", + "# ============================================================\n", + "\n", + "PRESET = \"demo\" # \"demo\" or \"production\"\n", + "\n", + "# Training parameters\n", + "if PRESET == \"demo\":\n", + " EPOCHS = 5 # Quick smoke test (~20-30 min on T4)\n", + " BATCH_SIZE = 64\n", + " N_SYNTHETIC_SAMPLES = 1000\n", + " N_EPOCHS_PRETRAIN = 1 # Autoencoder pre-training epochs\n", + "elif PRESET == \"production\":\n", + " EPOCHS = 50 # Full training (~2-4 hrs on T4)\n", + " BATCH_SIZE = 512\n", + " N_SYNTHETIC_SAMPLES = 10000\n", + " N_EPOCHS_PRETRAIN = 3\n", + "\n", + "# Model architecture\n", + "LATENT_DIM = 128 # Generator + decoder latent space dimension\n", + "HIDDEN_DIM = 128 # Generator MLP hidden layer width\n", + "AUTOENCODER_TYPE = \"cnn\" # \"cnn\", \"cnn8layer\", or \"linear\"\n", + "\n", + "# WGAN parameters\n", + "# N_ITER_D: Discriminator updates per generator step. Higher = stronger critic\n", + "# but slower training. 
5 is the WGAN default from Arjovsky et al.\n", + "N_ITER_D = 5\n", + "# Weight clipping enforces the Lipschitz constraint on the discriminator.\n", + "# Smaller range = more stable but slower convergence.\n", + "CLAMP_LOWER = -0.01\n", + "CLAMP_UPPER = 0.01\n", + "LR = 0.001 # Learning rate for all optimizers\n", + "\n", + "# Reproducibility\n", + "SEED = 42\n", + "\n", + "# Paths\n", + "BASE_DIR = \"/content/drive/MyDrive/CorGAN_Training\" if IN_COLAB else \"./corgan_training\"\n", + "DATA_DIR = os.path.join(BASE_DIR, \"data\")\n", + "CHECKPOINT_DIR = os.path.join(BASE_DIR, \"checkpoints\")\n", + "OUTPUT_DIR = os.path.join(BASE_DIR, \"output\")\n", + "\n", + "print(f\"Preset: {PRESET}\")\n", + "print(f\"Training: {EPOCHS} epochs, batch size {BATCH_SIZE}\")\n", + "print(f\"Architecture: {AUTOENCODER_TYPE} autoencoder, latent_dim={LATENT_DIM}\")\n", + "print(f\"Output: {N_SYNTHETIC_SAMPLES} synthetic patients\")\n", + "print(f\"Base directory: {BASE_DIR}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Data Upload\n", + "\n", + "**Two paths:**\n", + "- **With MIMIC-III:** Upload `DIAGNOSES_ICD.csv` (or `.csv.gz`) from your PhysioNet download, plus `ADMISSIONS.csv`\n", + "- **Demo mode (no MIMIC-III):** A synthetic dataset is generated automatically so you can run the full pipeline\n", + "\n", + "If you've already uploaded files to Google Drive in a previous session, they'll be detected automatically." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Mount Google Drive (Colab only) and create directories\n", + "if IN_COLAB:\n", + " from google.colab import drive, files\n", + " drive.mount(\"/content/drive\")\n", + "\n", + "for d in [DATA_DIR, CHECKPOINT_DIR, OUTPUT_DIR]:\n", + " os.makedirs(d, exist_ok=True)\n", + " print(f\"Directory ready: {d}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.datasets.sample_dataset import InMemorySampleDataset\n", + "\n", + "MIMIC3_ROOT = None # Set to your MIMIC-III directory path, or leave None for demo\n", + "\n", + "# Try to detect MIMIC-III data\n", + "if MIMIC3_ROOT and os.path.isdir(MIMIC3_ROOT):\n", + " USE_DEMO = False\n", + " print(f\"Using MIMIC-III data from: {MIMIC3_ROOT}\")\n", + "elif IN_COLAB:\n", + " # Try uploading files\n", + " print(\"Upload DIAGNOSES_ICD.csv (or .csv.gz) and ADMISSIONS.csv from MIMIC-III:\")\n", + " try:\n", + " uploaded = files.upload()\n", + " if uploaded:\n", + " for fname, content in uploaded.items():\n", + " dest = os.path.join(DATA_DIR, fname)\n", + " with open(dest, \"wb\") as f:\n", + " f.write(content)\n", + " print(f\"Saved: {dest} ({len(content):,} bytes)\")\n", + " MIMIC3_ROOT = DATA_DIR\n", + " USE_DEMO = False\n", + " else:\n", + " USE_DEMO = True\n", + " except Exception:\n", + " USE_DEMO = True\n", + "else:\n", + " USE_DEMO = True\n", + "\n", + "if USE_DEMO:\n", + " print(\"MIMIC-III not found \\u2014 using synthetic demo data\")\n", + " print(\"Demo mode uses a small random vocabulary to demonstrate the full pipeline.\")\n", + "\n", + " # Generate demo data: 200 patients with random ICD-9-like codes\n", + " demo_codes = [f\"D{i:03d}\" for i in range(50)] # 50-code vocabulary\n", + " rng = np.random.RandomState(SEED)\n", + " demo_samples = []\n", + " for pid in range(200):\n", + " n_codes = rng.randint(3, 20)\n", + " codes = 
rng.choice(demo_codes, size=n_codes, replace=False).tolist()\n", + " demo_samples.append({\"patient_id\": f\"demo_{pid}\", \"visits\": codes})\n", + "\n", + " dataset = InMemorySampleDataset(\n", + " samples=demo_samples,\n", + " input_schema={\"visits\": \"multi_hot\"},\n", + " output_schema={},\n", + " dataset_name=\"demo\",\n", + " task_name=\"CorGANGeneration\",\n", + " )\n", + " print(f\"Demo dataset: {len(dataset)} patients, {len(demo_codes)} unique codes\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not USE_DEMO:\n", + " from pyhealth.datasets import MIMIC3Dataset\n", + " from pyhealth.tasks.corgan_generation import CorGANGenerationMIMIC3\n", + "\n", + " mimic3_dataset = MIMIC3Dataset(\n", + " root=MIMIC3_ROOT,\n", + " tables=[\"diagnoses_icd\"],\n", + " code_mapping={},\n", + " refresh_cache=False,\n", + " )\n", + "\n", + " task = CorGANGenerationMIMIC3()\n", + " dataset = mimic3_dataset.set_task(task)\n", + " print(f\"MIMIC-III dataset: {len(dataset)} patients\")\n", + "\n", + "# Display sample\n", + "print(f\"\\nSample patient: {dataset[0]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Training\n", + "\n", + "CorGAN training has **two phases**:\n", + "\n", + "1. **Autoencoder pre-training**: The CNN autoencoder learns to compress multi-hot ICD-9 code vectors into a low-dimensional latent space and reconstruct them. This teaches the model which codes tend to co-occur.\n", + "\n", + "2. **WGAN adversarial training**: The generator learns to produce latent codes that the decoder maps back to realistic binary code vectors. The discriminator (critic) is trained to distinguish real from generated latent codes, providing the Wasserstein distance as a training signal.\n", + "\n", + "After training, three loss curves are plotted: autoencoder reconstruction loss, discriminator Wasserstein distance, and generator loss. 
Stabilization of the Wasserstein distance indicates convergence." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set random seeds for reproducibility\n", + "torch.manual_seed(SEED)\n", + "np.random.seed(SEED)\n", + "random.seed(SEED)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(SEED)\n", + "torch.backends.cudnn.deterministic = True\n", + "print(f\"Random seed set to {SEED}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.datasets import split_by_patient\n", + "\n", + "# Split dataset\n", + "train_dataset, val_dataset, test_dataset = split_by_patient(\n", + " dataset, ratios=[0.85, 0.1, 0.05], seed=SEED\n", + ")\n", + "print(f\"Split: train={len(train_dataset)}, val={len(val_dataset)}, test={len(test_dataset)}\")\n", + "\n", + "# Save hyperparameter config\n", + "config_record = {\n", + " \"preset\": PRESET,\n", + " \"epochs\": EPOCHS,\n", + " \"batch_size\": BATCH_SIZE,\n", + " \"n_epochs_pretrain\": N_EPOCHS_PRETRAIN,\n", + " \"latent_dim\": LATENT_DIM,\n", + " \"hidden_dim\": HIDDEN_DIM,\n", + " \"autoencoder_type\": AUTOENCODER_TYPE,\n", + " \"n_iter_d\": N_ITER_D,\n", + " \"clamp_lower\": CLAMP_LOWER,\n", + " \"clamp_upper\": CLAMP_UPPER,\n", + " \"lr\": LR,\n", + " \"seed\": SEED,\n", + " \"n_synthetic_samples\": N_SYNTHETIC_SAMPLES,\n", + " \"use_demo\": USE_DEMO,\n", + " \"timestamp\": datetime.now().isoformat(),\n", + "}\n", + "config_path = os.path.join(CHECKPOINT_DIR, \"config.json\")\n", + "with open(config_path, \"w\") as f:\n", + " json.dump(config_record, f, indent=2)\n", + "print(f\"Config saved to {config_path}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "from pyhealth.models.generators.corgan import CorGAN\n\nmodel = CorGAN(\n dataset=train_dataset,\n latent_dim=LATENT_DIM,\n hidden_dim=HIDDEN_DIM,\n 
batch_size=BATCH_SIZE,\n epochs=EPOCHS,\n n_epochs_pretrain=N_EPOCHS_PRETRAIN,\n n_iter_D=N_ITER_D,\n clamp_lower=CLAMP_LOWER,\n clamp_upper=CLAMP_UPPER,\n lr=LR,\n autoencoder_type=AUTOENCODER_TYPE,\n save_dir=CHECKPOINT_DIR,\n)\n\nprint(f\"Model initialized: input_dim={model.input_dim}, autoencoder={AUTOENCODER_TYPE}\")\nprint(f\"Training: {EPOCHS} adversarial epochs + {N_EPOCHS_PRETRAIN} pretrain epochs\\n\")\n\nhistory = model.train_model(train_dataset)" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n", + "\n", + "# Autoencoder loss\n", + "axes[0].plot(history[\"autoencoder_loss\"], marker=\"o\", color=\"tab:blue\")\n", + "axes[0].set_title(\"Autoencoder Loss (Pre-training)\")\n", + "axes[0].set_xlabel(\"Epoch\")\n", + "axes[0].set_ylabel(\"Reconstruction Loss\")\n", + "axes[0].grid(True, alpha=0.3)\n", + "\n", + "# Discriminator loss (Wasserstein distance)\n", + "axes[1].plot(history[\"discriminator_loss\"], marker=\"o\", color=\"tab:orange\")\n", + "axes[1].set_title(\"Discriminator Loss (Wasserstein Distance)\")\n", + "axes[1].set_xlabel(\"Epoch\")\n", + "axes[1].set_ylabel(\"W-Distance\")\n", + "axes[1].grid(True, alpha=0.3)\n", + "\n", + "# Generator loss\n", + "axes[2].plot(history[\"generator_loss\"], marker=\"o\", color=\"tab:green\")\n", + "axes[2].set_title(\"Generator Loss\")\n", + "axes[2].set_xlabel(\"Epoch\")\n", + "axes[2].set_ylabel(\"Loss\")\n", + "axes[2].grid(True, alpha=0.3)\n", + "\n", + "plt.suptitle(\"CorGAN Training Loss Curves\", fontsize=14)\n", + "plt.tight_layout()\n", + "plt.savefig(os.path.join(OUTPUT_DIR, \"loss_curves.png\"), dpi=150, bbox_inches=\"tight\")\n", + "plt.show()\n", + "print(\"Loss curves saved to output/loss_curves.png\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Generation\n", + "\n", + "CorGAN generates a **flat bag-of-codes** per patient — a single set of ICD-9 diagnosis codes representing all their diagnoses across admissions. Unlike HALO (which generates sequential visits), there is no `VISIT_NUM` column in the output." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Generating {N_SYNTHETIC_SAMPLES} synthetic patients...\")\n", + "synthetic = model.synthesize_dataset(num_samples=N_SYNTHETIC_SAMPLES)\n", + "print(f\"Generated {len(synthetic)} patients\")\n", + "\n", + "# Display sample\n", + "sample_df = pd.DataFrame([\n", + " {\n", + " \"patient_id\": p[\"patient_id\"],\n", + " \"n_codes\": len(p[\"visits\"]),\n", + " \"sample_codes\": \", \".join(p[\"visits\"][:5]) + (\"...\" if len(p[\"visits\"]) > 5 else \"\"),\n", + " }\n", + " for p in synthetic[:10]\n", + "])\n", + "display(sample_df)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save as JSON\n", + "json_path = os.path.join(OUTPUT_DIR, \"synthetic_patients.json\")\n", + "with open(json_path, \"w\") as f:\n", + " json.dump(synthetic, f, indent=2)\n", + "print(f\"JSON saved: {json_path} ({os.path.getsize(json_path):,} bytes)\")\n", + "\n", + "# Save as CSV (flat: SUBJECT_ID, ICD9_CODE)\n", + "rows = []\n", + "for p in synthetic:\n", + " for code in p[\"visits\"]:\n", + " rows.append({\"SUBJECT_ID\": p[\"patient_id\"], \"ICD9_CODE\": code})\n", + "csv_df = pd.DataFrame(rows)\n", + "csv_path = os.path.join(OUTPUT_DIR, \"synthetic_patients.csv\")\n", + "csv_df.to_csv(csv_path, index=False)\n", + "print(f\"CSV saved: {csv_path} ({len(csv_df)} rows, {os.path.getsize(csv_path):,} bytes)\")\n", + "\n", + "display(csv_df.head(10))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. 
Results & Evaluation\n", + "\n", + "This section validates synthetic data quality by comparing it against the real training data. CorGAN's core contribution is capturing inter-code correlations, so we measure code frequency fidelity (Pearson correlation) as the central metric." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 6a. Vocabulary Coverage\n", + "all_generated_codes = set(code for p in synthetic for code in p[\"visits\"])\n", + "vocab_size = train_dataset.input_processors[\"visits\"].size()\n", + "coverage = len(all_generated_codes) / vocab_size * 100\n", + "print(f\"Unique codes generated: {len(all_generated_codes)}\")\n", + "print(f\"Vocabulary size: {vocab_size}\")\n", + "print(f\"Vocabulary coverage: {coverage:.1f}%\")\n", + "if coverage < 30:\n", + " print(\"WARNING: Low coverage may indicate mode collapse.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 6b. 
Code Count Statistics\n", + "synth_code_counts = [len(p[\"visits\"]) for p in synthetic]\n", + "\n", + "# Real training data code counts\n", + "real_code_counts = []\n", + "for i in range(len(train_dataset)):\n", + " sample = train_dataset[i]\n", + " n_codes = int(sample[\"visits\"].sum().item())\n", + " real_code_counts.append(n_codes)\n", + "\n", + "print(\"=== Codes per Patient ===\")\n", + "print(f\"{'':>20s} {'Real':>10s} {'Synthetic':>10s}\")\n", + "print(f\"{'Mean':>20s} {np.mean(real_code_counts):>10.1f} {np.mean(synth_code_counts):>10.1f}\")\n", + "print(f\"{'Std':>20s} {np.std(real_code_counts):>10.1f} {np.std(synth_code_counts):>10.1f}\")\n", + "print(f\"{'Min':>20s} {np.min(real_code_counts):>10d} {np.min(synth_code_counts):>10d}\")\n", + "print(f\"{'Max':>20s} {np.max(real_code_counts):>10d} {np.max(synth_code_counts):>10d}\")\n", + "print(f\"{'Median':>20s} {np.median(real_code_counts):>10.1f} {np.median(synth_code_counts):>10.1f}\")\n", + "\n", + "# Histogram comparison\n", + "fig, ax = plt.subplots(figsize=(10, 5))\n", + "bins = np.linspace(0, max(max(real_code_counts), max(synth_code_counts)), 30)\n", + "ax.hist(real_code_counts, bins=bins, alpha=0.6, label=\"Real\", density=True, color=\"tab:blue\")\n", + "ax.hist(synth_code_counts, bins=bins, alpha=0.6, label=\"Synthetic\", density=True, color=\"tab:orange\")\n", + "ax.set_xlabel(\"Number of Codes per Patient\")\n", + "ax.set_ylabel(\"Density\")\n", + "ax.set_title(\"Code Count Distribution: Real vs Synthetic\")\n", + "ax.legend()\n", + "ax.grid(True, alpha=0.3)\n", + "plt.tight_layout()\n", + "plt.savefig(os.path.join(OUTPUT_DIR, \"code_count_histogram.png\"), dpi=150, bbox_inches=\"tight\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "# 6c. 
Code Frequency Comparison\nfrom collections import Counter\n\n# Build index-to-code mapping from the processor's label_vocab\nprocessor = train_dataset.input_processors[\"visits\"]\nidx_to_code = {idx: code for code, idx in processor.label_vocab.items()}\n\n# Count code frequencies in real data\nreal_freq = Counter()\nfor i in range(len(train_dataset)):\n vec = train_dataset[i][\"visits\"]\n active = torch.where(vec > 0)[0].tolist()\n for idx in active:\n code = idx_to_code.get(idx)\n if code and code not in (\"\", \"\"):\n real_freq[code] += 1\n\n# Count code frequencies in synthetic data\nsynth_freq = Counter()\nfor p in synthetic:\n for code in p[\"visits\"]:\n synth_freq[code] += 1\n\n# Align frequencies on common codes\nall_codes = sorted(set(real_freq.keys()) | set(synth_freq.keys()))\nreal_vals = np.array([real_freq.get(c, 0) for c in all_codes], dtype=float)\nsynth_vals = np.array([synth_freq.get(c, 0) for c in all_codes], dtype=float)\n\n# Normalize to frequencies\nif real_vals.sum() > 0:\n real_vals /= real_vals.sum()\nif synth_vals.sum() > 0:\n synth_vals /= synth_vals.sum()\n\n# Pearson correlation\npearson_r = np.corrcoef(real_vals, synth_vals)[0, 1] if len(all_codes) > 1 else 0.0\nprint(f\"Code frequency Pearson r = {pearson_r:.4f}\")\n\n# Plot top 20 codes\nn_top = min(20, len(all_codes))\ntop_codes_idx = np.argsort(real_vals)[-n_top:][::-1]\ntop_codes = [all_codes[i] for i in top_codes_idx]\ntop_real = [real_vals[i] for i in top_codes_idx]\ntop_synth = [synth_vals[i] for i in top_codes_idx]\n\nfig, ax = plt.subplots(figsize=(12, 5))\nx = np.arange(len(top_codes))\nwidth = 0.35\nax.bar(x - width / 2, top_real, width, label=\"Real\", color=\"tab:blue\", alpha=0.8)\nax.bar(x + width / 2, top_synth, width, label=\"Synthetic\", color=\"tab:orange\", alpha=0.8)\nax.set_xlabel(\"ICD-9 Code\")\nax.set_ylabel(\"Frequency (normalized)\")\nax.set_title(f\"Top {n_top} Code Frequencies: Real vs Synthetic (Pearson r = 
{pearson_r:.3f})\")\nax.set_xticks(x)\nax.set_xticklabels(top_codes, rotation=45, ha=\"right\", fontsize=8)\nax.legend()\nax.grid(True, alpha=0.3, axis=\"y\")\nplt.tight_layout()\nplt.savefig(os.path.join(OUTPUT_DIR, \"code_frequency_comparison.png\"), dpi=150, bbox_inches=\"tight\")\nplt.show()" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 6d. All-Zeros Detection\n", + "empty_patients = sum(1 for p in synthetic if len(p[\"visits\"]) == 0)\n", + "print(f\"Empty patients (all-zeros): {empty_patients} / {len(synthetic)}\")\n", + "if empty_patients > 0:\n", + " print(\"WARNING: Some patients have no codes (all-zeros generation).\")\n", + " print(\"Consider: adjusting binarization threshold, retraining with more epochs,\")\n", + " print(\"or reducing learning rate.\")\n", + "else:\n", + " print(\"No all-zeros patients detected.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 6e. 
Quality Report\n", + "quality_report = {\n", + " \"total_synthetic_patients\": len(synthetic),\n", + " \"mean_codes_per_patient\": float(np.mean(synth_code_counts)),\n", + " \"std_codes_per_patient\": float(np.std(synth_code_counts)),\n", + " \"min_codes\": int(np.min(synth_code_counts)),\n", + " \"max_codes\": int(np.max(synth_code_counts)),\n", + " \"unique_codes_generated\": len(all_generated_codes),\n", + " \"vocabulary_size\": vocab_size,\n", + " \"vocabulary_coverage_percent\": round(coverage, 2),\n", + " \"empty_patients_count\": empty_patients,\n", + " \"code_frequency_pearson_r\": round(float(pearson_r), 4),\n", + " \"seed\": SEED,\n", + " \"preset\": PRESET,\n", + " \"epochs\": EPOCHS,\n", + " \"timestamp\": datetime.now().isoformat(),\n", + "}\n", + "\n", + "report_path = os.path.join(OUTPUT_DIR, \"quality_report.json\")\n", + "with open(report_path, \"w\") as f:\n", + " json.dump(quality_report, f, indent=2)\n", + "\n", + "print(f\"Quality report saved to {report_path}\")\n", + "print(json.dumps(quality_report, indent=2))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Download & Next Steps" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Download outputs\n", + "print(\"=== Output Files ===\")\n", + "for fname in [\"synthetic_patients.json\", \"synthetic_patients.csv\", \"quality_report.json\"]:\n", + " fpath = os.path.join(OUTPUT_DIR, fname)\n", + " if os.path.exists(fpath):\n", + " size = os.path.getsize(fpath)\n", + " print(f\" {fname}: {size:,} bytes\")\n", + "\n", + "if IN_COLAB:\n", + " for fname in [\"synthetic_patients.csv\", \"quality_report.json\", \"synthetic_patients.json\"]:\n", + " fpath = os.path.join(OUTPUT_DIR, fname)\n", + " if os.path.exists(fpath):\n", + " files.download(fpath)\n", + " print(\"\\nFiles downloaded! 
Also backed up in Google Drive.\")\n", + "else:\n", + "    print(f\"\\nFiles saved in: {OUTPUT_DIR}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Congratulations!\n", + "\n", + "You've trained a CorGAN model and generated synthetic MIMIC-III patient records.\n", + "\n", + "### Next Steps\n", + "- **Use your synthetic data**: Load `synthetic_patients.json` into downstream PyHealth pipelines\n", + "- **Generate more patients**: Re-run Section 5 with a larger `N_SYNTHETIC_SAMPLES`\n", + "- **Production training**: Change `PRESET = \"production\"` and re-run from Section 2\n", + "\n", + "### Troubleshooting\n", + "\n", + "**Out of memory (OOM)**: Reduce `BATCH_SIZE` (try 32 or 16)\n", + "\n", + "**Mode collapse** (many empty patients or low vocabulary coverage):\n", + "- Check `quality_report.json` for `empty_patients_count` and `vocabulary_coverage_percent`\n", + "- Try reducing `LR` (e.g., 0.0001)\n", + "- Increase `N_ITER_D` (e.g., 10) to strengthen the critic\n", + "- Train for more epochs\n", + "\n", + "**Slow training**: Use GPU runtime (Runtime > Change runtime type > T4 GPU)\n", + "\n", + "### References\n", + "- [CorGAN Paper (Torfi & Fox, 2020)](https://arxiv.org/abs/2001.09346)\n", + "- [PyHealth Documentation](https://pyhealth.readthedocs.io/)\n", + "- [MIMIC-III on PhysioNet](https://physionet.org/content/mimiciii/)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/pyhealth/models/generators/corgan.py b/pyhealth/models/generators/corgan.py index df09bc6b8..57278ce54 100644 --- a/pyhealth/models/generators/corgan.py +++ b/pyhealth/models/generators/corgan.py @@ -632,10 +632,19 @@ def train_model(self, 
train_dataset, val_dataset=None): val_dataset: Unused. Accepted for API compatibility. Returns: - None + dict: Loss history with keys: + - ``"autoencoder_loss"``: list of float, one per pretrain epoch. + - ``"discriminator_loss"``: list of float, one per adversarial epoch. + - ``"generator_loss"``: list of float, one per adversarial epoch. """ print("Starting CorGAN training...") + history = { + "autoencoder_loss": [], + "discriminator_loss": [], + "generator_loss": [], + } + # build multi-hot matrix by stacking the pre-encoded tensors from MultiHotProcessor tensors = [train_dataset[i]["visits"] for i in range(len(train_dataset))] data_matrix = torch.stack(tensors).numpy() # shape (n_patients, vocab_size) @@ -673,6 +682,7 @@ def train_model(self, train_dataset, val_dataset=None): if i % 100 == 0: print(f"[Epoch {epoch_pre + 1}/{self.n_epochs_pretrain}] [Batch {i}/{len(train_dataloader)}] [A loss: {a_loss.item():.3f}]") + history["autoencoder_loss"].append(a_loss.item()) # adversarial training print(f"Starting adversarial training for {self.n_epochs} epochs...") @@ -760,6 +770,8 @@ def train_model(self, train_dataset, val_dataset=None): f"Loss_D: {errD.item():.3f} Loss_G: {errG.item():.3f} " f"Loss_D_real: {errD_real.item():.3f} Loss_D_fake: {errD_fake.item():.3f}") print(f"Epoch time: {epoch_end - epoch_start:.2f} seconds") + history["discriminator_loss"].append(errD.item()) + history["generator_loss"].append(errG.item()) print("Training completed!") @@ -770,6 +782,8 @@ def train_model(self, train_dataset, val_dataset=None): self.save_model(checkpoint_path) print(f"Checkpoint saved to {checkpoint_path}") + return history + def synthesize_dataset(self, num_samples: int, random_sampling: bool = True) -> List[Dict]: """Generate synthetic patient records. 
diff --git a/tests/integration/test_corgan_end_to_end.py b/tests/integration/test_corgan_end_to_end.py index aff81883a..667f0e094 100644 --- a/tests/integration/test_corgan_end_to_end.py +++ b/tests/integration/test_corgan_end_to_end.py @@ -281,6 +281,35 @@ def test_save_load_roundtrip(self): self.assertEqual(len(result), 3) +class TestCorGANTrainModelReturnsLossHistory(unittest.TestCase): + """train_model() returns a dict with three loss lists.""" + + def test_train_model_returns_loss_history(self): + ds = _make_dataset() + model = CorGAN(dataset=ds, **_SMALL_MODEL_KWARGS) + history = model.train_model(ds) + + self.assertIsInstance(history, dict) + self.assertIn("autoencoder_loss", history) + self.assertIn("discriminator_loss", history) + self.assertIn("generator_loss", history) + self.assertEqual( + len(history["autoencoder_loss"]), + _SMALL_MODEL_KWARGS["n_epochs_pretrain"], + ) + self.assertEqual( + len(history["discriminator_loss"]), + _SMALL_MODEL_KWARGS["epochs"], + ) + self.assertEqual( + len(history["generator_loss"]), + _SMALL_MODEL_KWARGS["epochs"], + ) + self.assertTrue( + all(isinstance(v, float) for v in history["autoencoder_loss"]) + ) + + # --------------------------------------------------------------------------- # Category B: MIMIC-III Integration Tests (skipped if data unavailable) # ---------------------------------------------------------------------------