From 70985cab4ab349b293b911f84ba0f023f1337b3d Mon Sep 17 00:00:00 2001 From: chufangao Date: Sun, 17 May 2026 23:46:52 -0500 Subject: [PATCH 1/2] Add synthetic-EHR generative evaluation metrics Adds pyhealth/metrics/generative/, a subpackage for evaluating synthetic EHR data along privacy, utility, and statistical-fidelity axes: - privacy.py: NNAAR, membership inference attack, discriminator privacy - utility.py: machine learning efficacy (TRTR vs TSTR), code-prevalence similarity (R2, Pearson, RMSE) - utils.py: shared data prep, an LSTM classifier, and a random-forest baseline - evaluate_synthetic_ehr(): convenience orchestrator for the full suite These functions are ported from a standalone evaluation script. The MIMIC-specific data-loading/CLI glue is dropped; the metrics work on any flat EHR dataframe. Public functions are re-exported from pyhealth.metrics. Adds unit tests in tests/core/test_generative_metrics.py and Sphinx docs. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/api/metrics.rst | 3 + .../metrics/pyhealth.metrics.generative.rst | 25 + pyhealth/metrics/__init__.py | 14 + pyhealth/metrics/generative/__init__.py | 179 ++++++ pyhealth/metrics/generative/privacy.py | 335 ++++++++++ pyhealth/metrics/generative/utility.py | 245 ++++++++ pyhealth/metrics/generative/utils.py | 584 ++++++++++++++++++ tests/core/test_generative_metrics.py | 256 ++++++++ 8 files changed, 1641 insertions(+) create mode 100644 docs/api/metrics/pyhealth.metrics.generative.rst create mode 100644 pyhealth/metrics/generative/__init__.py create mode 100644 pyhealth/metrics/generative/privacy.py create mode 100644 pyhealth/metrics/generative/utility.py create mode 100644 pyhealth/metrics/generative/utils.py create mode 100644 tests/core/test_generative_metrics.py diff --git a/docs/api/metrics.rst b/docs/api/metrics.rst index 1767e0026..9e6bc160a 100644 --- a/docs/api/metrics.rst +++ b/docs/api/metrics.rst @@ -7,6 +7,8 @@ For applicable tasks, we provide the relevant metrics for model calibration, as Among these we also provide metrics related to uncertainty quantification, for model calibration, as well as metrics that measure the quality of prediction sets We also provide other metrics specically for healthcare tasks, such as drug drug interaction (DDI) rate. +For synthetic (generative) EHR data, we provide privacy, utility, and statistical +fidelity metrics. .. toctree:: @@ -19,3 +21,4 @@ tasks, such as drug drug interaction (DDI) rate. metrics/pyhealth.metrics.prediction_set metrics/pyhealth.metrics.fairness metrics/pyhealth.metrics.interpretability + metrics/pyhealth.metrics.generative diff --git a/docs/api/metrics/pyhealth.metrics.generative.rst b/docs/api/metrics/pyhealth.metrics.generative.rst new file mode 100644 index 000000000..85e448a52 --- /dev/null +++ b/docs/api/metrics/pyhealth.metrics.generative.rst @@ -0,0 +1,25 @@ +pyhealth.metrics.generative +=================================== + +Evaluation metrics for synthetic (generative) EHR data, covering privacy, +utility, and statistical fidelity. + +.. currentmodule:: pyhealth.metrics.generative + +.. autofunction:: evaluate_synthetic_ehr + +Privacy metrics +------------------------------------- + +.. autofunction:: calc_nnaar + +.. autofunction:: calc_membership_inference + +.. autofunction:: compute_discriminator_privacy + +Utility and fidelity metrics +------------------------------------- + +.. autofunction:: compute_mle + +.. autofunction:: compute_prevalence_metrics diff --git a/pyhealth/metrics/__init__.py b/pyhealth/metrics/__init__.py index da8da0f5b..f04b6ba6a 100644 --- a/pyhealth/metrics/__init__.py +++ b/pyhealth/metrics/__init__.py @@ -1,5 +1,13 @@ from .binary import binary_metrics_fn from .drug_recommendation import ddi_rate_score +from .generative import ( + calc_membership_inference, + calc_nnaar, + compute_discriminator_privacy, + compute_mle, + compute_prevalence_metrics, + evaluate_synthetic_ehr, +) from .interpretability import ( ComprehensivenessMetric, Evaluator, @@ -17,6 +25,12 @@ __all__ = [ "binary_metrics_fn", "ddi_rate_score", + "calc_nnaar", + "calc_membership_inference", + "compute_discriminator_privacy", + "compute_mle", + "compute_prevalence_metrics", + "evaluate_synthetic_ehr", "ComprehensivenessMetric", "SufficiencyMetric", "RemovalBasedMetric", diff --git a/pyhealth/metrics/generative/__init__.py b/pyhealth/metrics/generative/__init__.py new file mode 100644 index 000000000..4efad79da --- /dev/null +++ b/pyhealth/metrics/generative/__init__.py @@ -0,0 +1,179 @@ +"""Evaluation metrics for synthetic (generative) EHR data. + +This subpackage provides metrics for assessing synthetic electronic health +record (EHR) data along three axes: + + - **Privacy** (:mod:`pyhealth.metrics.generative.privacy`): NNAAR, + membership inference, and discriminator-based adversarial accuracy. + - **Utility / fidelity** (:mod:`pyhealth.metrics.generative.utility`): + machine learning efficacy (TRTR vs TSTR) and code-prevalence similarity. + +The convenience function :func:`evaluate_synthetic_ehr` runs the full suite +and returns a single merged dictionary of ``{metric_name: (mean, std)}``. +""" + +import logging +from typing import Dict, Optional, Tuple + +import pandas as pd + +from .privacy import ( + calc_membership_inference, + calc_nnaar, + compute_discriminator_privacy, +) +from .utility import compute_mle, compute_prevalence_metrics +from .utils import train_lstm_model, train_sklearn_model + +logger = logging.getLogger(__name__) + +__all__ = [ + "calc_nnaar", + "calc_membership_inference", + "compute_discriminator_privacy", + "compute_mle", + "compute_prevalence_metrics", + "evaluate_synthetic_ehr", +] + + +def evaluate_synthetic_ehr( + train_ehr: pd.DataFrame, + test_ehr: pd.DataFrame, + syn_ehr: pd.DataFrame, + subject_col: str = "id", + visit_col: str = "time", + code_col: str = "visit_codes", + label_col: str = "labels", + sample_size: int = 1000, + mode: str = "lstm", + metrics: str = "all", + lstm_params: Optional[Dict] = None, + sklearn_params: Optional[Dict] = None, + n_bootstraps: int = 100, + n_runs: int = 5, +) -> Dict[str, Tuple[float, float]]: + """Runs the full synthetic-EHR evaluation suite. + + Computes privacy and/or utility metrics comparing synthetic EHR data + against real train/test data, and returns a single merged dictionary. + + Args: + train_ehr: Real training EHR dataframe. + test_ehr: Real held-out test EHR dataframe. + syn_ehr: Synthetic EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + label_col: Column name for the label. + sample_size: Number of patients sampled per dataset for the + privacy metrics. + mode: Predictive backbone for the utility metrics; ``"lstm"`` uses the + built-in LSTM classifier, ``"rf"`` uses a random forest. + metrics: Which metric group to compute: ``"all"``, ``"privacy"`` or + ``"utility"``. + lstm_params: Optional overrides for the LSTM (``embed_dim``, + ``hidden_dim``, ``batch_size``, ``epochs``). + sklearn_params: Optional overrides for the sklearn model (``model``). + n_bootstraps: Number of bootstrap resamples for the utility metrics. + n_runs: Number of sampling runs for the privacy metrics. + + Returns: + Dictionary mapping each metric name to a ``(mean, std)`` tuple. + + Raises: + ValueError: If ``metrics`` or ``mode`` is not a recognized value. + """ + if metrics not in ("all", "privacy", "utility"): + raise ValueError( + f"Unknown metrics group: {metrics!r}. " + "Expected 'all', 'privacy' or 'utility'." + ) + if mode not in ("lstm", "rf"): + raise ValueError(f"Unknown mode: {mode!r}. Expected 'lstm' or 'rf'.") + + lstm_params = lstm_params or {} + sklearn_params = sklearn_params or {} + final_output: Dict[str, Tuple[float, float]] = {} + + if metrics in ("all", "privacy"): + final_output.update( + calc_nnaar( + train_ehr, + test_ehr, + syn_ehr, + subject_col=subject_col, + visit_col=visit_col, + code_col=code_col, + label_col=label_col, + sample_size=sample_size, + n_runs=n_runs, + ) + ) + final_output.update( + calc_membership_inference( + train_ehr, + test_ehr, + syn_ehr, + subject_col=subject_col, + visit_col=visit_col, + code_col=code_col, + label_col=label_col, + num_attack_samples=sample_size, + n_runs=n_runs, + ) + ) + + if metrics in ("all", "utility"): + if mode == "lstm": + train_fn = train_lstm_model + train_kwargs = { + "embed_dim": lstm_params.get("embed_dim", 32), + "hidden_dim": lstm_params.get("hidden_dim", 32), + "batch_size": lstm_params.get("batch_size", 32), + "epochs": lstm_params.get("epochs", 5), + "verbose": False, + } + else: + train_fn = train_sklearn_model + train_kwargs = {"model": sklearn_params.get("model", "rf")} + + final_output.update( + compute_mle( + train_fn=train_fn, + train_ehr=train_ehr, + test_ehr=test_ehr, + syn_ehr=syn_ehr, + subject_col=subject_col, + visit_col=visit_col, + code_col=code_col, + label_col=label_col, + n_bootstraps=n_bootstraps, + **train_kwargs, + ) + ) + final_output.update( + compute_discriminator_privacy( + train_fn=train_fn, + train_ehr=train_ehr, + test_ehr=test_ehr, + syn_ehr=syn_ehr, + subject_col=subject_col, + visit_col=visit_col, + code_col=code_col, + label_col=label_col, + n_bootstraps=n_bootstraps, + **train_kwargs, + ) + ) + final_output.update( + compute_prevalence_metrics( + train_ehr, + syn_ehr, + subject_col=subject_col, + code_col=code_col, + n_bootstraps=n_bootstraps, + ) + ) + + return final_output diff --git a/pyhealth/metrics/generative/privacy.py b/pyhealth/metrics/generative/privacy.py new file mode 100644 index 000000000..c553233ab --- /dev/null +++ b/pyhealth/metrics/generative/privacy.py @@ -0,0 +1,335 @@ +"""Privacy metrics for synthetic EHR data. + +These metrics quantify how much a synthetic EHR dataset leaks about the real +records it was trained on. They include: + + - Nearest Neighbor Adversarial Accuracy Risk (NNAAR) + - Membership Inference Attack (MIA) metrics + - A discriminator-based adversarial-accuracy privacy score + +All functions take flat EHR dataframes (one row per patient/visit/code event) +and return ``{metric_name: (mean, std)}`` summaries computed over multiple runs +or bootstrap resamples. +""" + +import copy +import logging +from typing import Callable, Dict, Tuple + +import numpy as np +import pandas as pd +from sklearn import metrics as sklearn_metrics +from sklearn.model_selection import train_test_split +from tqdm import tqdm + +from .utils import ( + convert_visits_to_sets, + find_nearest_neighbor_dist, + summarize_metric_runs, +) + +logger = logging.getLogger(__name__) + +__all__ = [ + "calc_nnaar", + "calc_membership_inference", + "compute_discriminator_privacy", +] + + +def calc_nnaar( + train_ehr: pd.DataFrame, + test_ehr: pd.DataFrame, + syn_ehr: pd.DataFrame, + subject_col: str = "id", + visit_col: str = "time", + code_col: str = "visit_codes", + label_col: str = "labels", + sample_size: int = 1000, + n_runs: int = 5, + verbose: bool = False, +) -> Dict[str, Tuple[float, float]]: + """Computes the Nearest Neighbor Adversarial Accuracy Risk (NNAAR). + + NNAAR measures whether the synthetic data sits closer to the real training + data than to held-out test data, which would indicate memorization:: + + NNAAR = AA_ES - AA_TS + + where ``AA_ES`` is the adversarial accuracy between test and synthetic data + and ``AA_TS`` is the adversarial accuracy between train and synthetic data. + Values near 0 indicate low privacy risk. + + Args: + train_ehr: Real training EHR dataframe. + test_ehr: Real held-out test EHR dataframe. + syn_ehr: Synthetic EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + label_col: Column name for the label (unused, kept for a uniform API). + sample_size: Number of patients to sample per dataset per run. + n_runs: Number of independent sampling runs. + verbose: Whether to show per-run progress bars. + + Returns: + Dictionary mapping ``"nnaar"``, ``"aa_es"`` and ``"aa_ts"`` to their + ``(mean, std)`` across runs. + """ + logger.info( + "Calculating NNAAR (sample_size=%d, n_runs=%d)", sample_size, n_runs + ) + train = convert_visits_to_sets(train_ehr, subject_col, visit_col, code_col) + test = convert_visits_to_sets(test_ehr, subject_col, visit_col, code_col) + synthetic = convert_visits_to_sets(syn_ehr, subject_col, visit_col, code_col) + + metrics_runs = [] + n = min(sample_size, len(train), len(test), len(synthetic)) + + for _ in range(n_runs): + if len(train) > n: + inds = np.random.choice(len(train), n, replace=False) + s_train = [train[i] for i in inds] + else: + s_train = list(train) + if len(test) > n: + inds = np.random.choice(len(test), n, replace=False) + s_test = [test[i] for i in inds] + else: + s_test = list(test) + if len(synthetic) > n: + inds = np.random.choice(len(synthetic), n, replace=False) + s_syn = [synthetic[i] for i in inds] + else: + s_syn = list(synthetic) + + # AA_ES (test vs synthetic). + val1 = sum( + 1 + for p in tqdm(s_test, desc="Test vs Syn", disable=not verbose) + if find_nearest_neighbor_dist(p, s_syn) + > find_nearest_neighbor_dist(p, s_test) + ) + val2 = sum( + 1 + for p in tqdm(s_syn, desc="Syn vs Test", disable=not verbose) + if find_nearest_neighbor_dist(p, s_test) + > find_nearest_neighbor_dist(p, s_syn) + ) + # AA_TS (train vs synthetic). + val3 = sum( + 1 + for p in tqdm(s_train, desc="Train vs Syn", disable=not verbose) + if find_nearest_neighbor_dist(p, s_syn) + > find_nearest_neighbor_dist(p, s_train) + ) + val4 = sum( + 1 + for p in tqdm(s_syn, desc="Syn vs Train", disable=not verbose) + if find_nearest_neighbor_dist(p, s_train) + > find_nearest_neighbor_dist(p, s_syn) + ) + + aa_es = 0.5 * (val1 / n + val2 / n) + aa_ts = 0.5 * (val3 / n + val4 / n) + metrics_runs.append( + {"nnaar": aa_es - aa_ts, "aa_es": aa_es, "aa_ts": aa_ts} + ) + + return summarize_metric_runs(metrics_runs) + + +def calc_membership_inference( + train_ehr: pd.DataFrame, + test_ehr: pd.DataFrame, + syn_ehr: pd.DataFrame, + subject_col: str = "id", + visit_col: str = "time", + code_col: str = "visit_codes", + label_col: str = "labels", + num_attack_samples: int = 1000, + n_runs: int = 5, + verbose: bool = False, +) -> Dict[str, Tuple[float, float]]: + """Computes Membership Inference Attack (MIA) metrics. + + An attacker tries to tell members (training patients) from non-members + (test patients) using proximity to the synthetic data: members are expected + to be closer to synthetic records. Predictions are made by thresholding the + nearest-neighbor distance at its median; F1, precision, recall and accuracy + near 0.5 indicate low membership-inference risk. + + Args: + train_ehr: Real training EHR dataframe (members). + test_ehr: Real held-out test EHR dataframe (non-members). + syn_ehr: Synthetic EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + label_col: Column name for the label (unused, kept for a uniform API). + num_attack_samples: Total attack-set size (half members, half not). + n_runs: Number of independent sampling runs. + verbose: Whether to show per-run progress bars. + + Returns: + Dictionary mapping ``"MIA_F1"``, ``"MIA_Precision"``, ``"MIA_Recall"`` + and ``"MIA_Accuracy"`` to their ``(mean, std)`` across runs. + """ + logger.info( + "Calculating Membership Inference (attack_size=%d, n_runs=%d)", + num_attack_samples, + n_runs, + ) + train = convert_visits_to_sets(train_ehr, subject_col, visit_col, code_col) + test = convert_visits_to_sets(test_ehr, subject_col, visit_col, code_col) + synthetic = convert_visits_to_sets(syn_ehr, subject_col, visit_col, code_col) + + metrics_runs = [] + for _ in range(n_runs): + # Build a balanced attack set: 50% members, 50% non-members. + n_half = min(len(train), len(test), num_attack_samples) // 2 + if n_half == 0: + continue + + pos_inds = np.random.choice(len(train), n_half, replace=False) + pos_samples = [train[i] for i in pos_inds] + neg_inds = np.random.choice(len(test), n_half, replace=False) + neg_samples = [test[i] for i in neg_inds] + + attack_data = pos_samples + neg_samples + attack_labels = [1] * len(pos_samples) + [0] * len(neg_samples) + + distances = [ + find_nearest_neighbor_dist(record, synthetic) + for record in tqdm( + attack_data, desc="Calculating Distances", disable=not verbose + ) + ] + if len(distances) == 0: + continue + + # Members are expected to be closer (smaller distance) to synthetic. + median_dist = np.median(distances) + predictions = [1 if d < median_dist else 0 for d in distances] + + metrics_runs.append( + { + "MIA_F1": sklearn_metrics.f1_score(attack_labels, predictions), + "MIA_Precision": sklearn_metrics.precision_score( + attack_labels, predictions, zero_division=0 + ), + "MIA_Recall": sklearn_metrics.recall_score( + attack_labels, predictions, zero_division=0 + ), + "MIA_Accuracy": sklearn_metrics.accuracy_score( + attack_labels, predictions + ), + } + ) + + summary = summarize_metric_runs(metrics_runs) + logger.info("MIA results: %s", summary) + return summary + + +def compute_discriminator_privacy( + train_fn: Callable, + train_ehr: pd.DataFrame, + test_ehr: pd.DataFrame, + syn_ehr: pd.DataFrame, + subject_col: str = "id", + visit_col: str = "time", + code_col: str = "visit_codes", + label_col: str = "labels", + n_bootstraps: int = 5, + seed: int = 4, + **kwargs, +) -> Dict[str, Tuple[float, float]]: + """Computes a discriminator-based adversarial-accuracy privacy score. + + A classifier is trained to predict whether a record is real (1) or + synthetic (0). An accuracy near 0.5 means real and synthetic data are + indistinguishable (good privacy); accuracy well above 0.5 means the + synthetic data is easy to tell apart (poor privacy). The ``Privacy_Score`` + rescales accuracy so 1.0 is perfect privacy and 0.0 is none. + + Args: + train_fn: A training function such as + :func:`pyhealth.metrics.generative.utils.train_lstm_model` or + ``train_sklearn_model``. It must accept ``train_ehr``, ``test_ehr``, + the four column-name arguments and return ``(model, y_true, + y_pred)``. + train_ehr: Real training EHR dataframe. + test_ehr: Real held-out test EHR dataframe (unused; kept for a uniform + API with the other metrics). + syn_ehr: Synthetic EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + label_col: Column name for the original label (unused; the + discriminator target replaces it). + n_bootstraps: Number of bootstrap resamples of the predictions. + seed: Random seed for the patient-level train/test split. + **kwargs: Extra keyword arguments forwarded to ``train_fn``. + + Returns: + Dictionary mapping ``"Privacy_Discriminator_Accuracy"`` and + ``"Privacy_Score"`` to their ``(mean, std)`` across bootstraps. + """ + logger.info("Computing discriminator privacy") + + # Label data: real = 1, synthetic = 0. + real_df = copy.deepcopy(train_ehr) + syn_df = copy.deepcopy(syn_ehr) + disc_label = "is_real" + real_df[disc_label] = 1 + syn_df[disc_label] = 0 + + # Disambiguate subject IDs so real/synthetic patients never collide. + real_df[subject_col] = real_df[subject_col].astype(str) + "_real" + syn_df[subject_col] = syn_df[subject_col].astype(str) + "_syn" + + combined_df = pd.concat([real_df, syn_df]) + unique_patients = combined_df[subject_col].unique() + train_ids, test_ids = train_test_split( + unique_patients, test_size=0.2, random_state=seed + ) + disc_train = combined_df[combined_df[subject_col].isin(train_ids)] + disc_test = combined_df[combined_df[subject_col].isin(test_ids)] + + logger.info( + "Discriminator train size=%d, test size=%d", + len(disc_train), + len(disc_test), + ) + _, y_true, y_pred = train_fn( + train_ehr=disc_train, + test_ehr=disc_test, + subject_col=subject_col, + visit_col=visit_col, + code_col=code_col, + label_col=disc_label, + **kwargs, + ) + + metrics_runs = [] + n_samples = len(y_true) + for _ in range(n_bootstraps): + if n_samples > 0: + indices = np.random.choice(n_samples, n_samples, replace=True) + y_t, y_p = y_true[indices], y_pred[indices] + else: + y_t, y_p = y_true, y_pred + + acc = sklearn_metrics.accuracy_score(y_t, y_p) if len(y_t) > 0 else 0.0 + metrics_runs.append( + { + "Privacy_Discriminator_Accuracy": acc, + # 1.0 = perfect privacy (acc 0.5); 0.0 = no privacy. + "Privacy_Score": 1.0 - 2 * abs(0.5 - acc), + } + ) + + summary = summarize_metric_runs(metrics_runs) + logger.info("Discriminator privacy results: %s", summary) + return summary diff --git a/pyhealth/metrics/generative/utility.py b/pyhealth/metrics/generative/utility.py new file mode 100644 index 000000000..cbaff4adb --- /dev/null +++ b/pyhealth/metrics/generative/utility.py @@ -0,0 +1,245 @@ +"""Utility and statistical-fidelity metrics for synthetic EHR data. + +These metrics quantify how *useful* synthetic EHR data is as a stand-in for +real data: + + - Machine Learning Efficacy (MLE): compares a model trained on real data + against one trained on synthetic data, both evaluated on real data. + - Code-prevalence similarity: compares per-code patient-level prevalence + between real and synthetic data (R-squared, Pearson correlation, RMSE). + +All functions take flat EHR dataframes (one row per patient/visit/code event) +and return ``{metric_name: (mean, std)}`` summaries over bootstrap resamples. +""" + +import copy +import logging +from typing import Callable, Dict, Tuple + +import numpy as np +import pandas as pd +from sklearn import metrics as sklearn_metrics + +from .utils import build_next_visit_prediction_dataset, summarize_metric_runs + +logger = logging.getLogger(__name__) + +__all__ = [ + "compute_mle", + "compute_prevalence_metrics", +] + + +def compute_mle( + train_fn: Callable, + train_ehr: pd.DataFrame, + test_ehr: pd.DataFrame, + syn_ehr: pd.DataFrame, + subject_col: str = "id", + visit_col: str = "time", + code_col: str = "visit_codes", + label_col: str = "labels", + n_bootstraps: int = 5, + **kwargs, +) -> Dict[str, Tuple[float, float]]: + """Computes Machine Learning Efficacy (utility) for synthetic data. + + Two classifiers are trained on a next-visit prediction task: one on real + training data (Train-Real-Test-Real, TRTR) and one on synthetic data + (Train-Synthetic-Test-Real, TSTR). Both are evaluated on the same real test + set. Synthetic accuracy/F1 close to real accuracy/F1 indicates high utility. + + Args: + train_fn: A training function such as + :func:`pyhealth.metrics.generative.utils.train_lstm_model` or + ``train_sklearn_model``, returning ``(model, y_true, y_pred)``. + train_ehr: Real training EHR dataframe. + test_ehr: Real held-out test EHR dataframe. + syn_ehr: Synthetic EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + label_col: Column name for the label (overwritten by the next-visit + prediction label). + n_bootstraps: Number of bootstrap resamples of the predictions. + **kwargs: Extra keyword arguments forwarded to ``train_fn``. + + Returns: + Dictionary mapping the MLE metrics (real/synthetic accuracy and F1, + their difference and ratio) to their ``(mean, std)`` across + bootstraps. + """ + logger.info("Computing MLE (utility)") + + train_task = build_next_visit_prediction_dataset( + train_ehr, subject_col, visit_col, label_col + ) + test_task = build_next_visit_prediction_dataset( + test_ehr, subject_col, visit_col, label_col + ) + syn_task = build_next_visit_prediction_dataset( + syn_ehr, subject_col, visit_col, label_col + ) + + # Train on Real, test on Real (TRTR). + _, real_y_true, real_y_pred = train_fn( + copy.deepcopy(train_task), + copy.deepcopy(test_task), + subject_col=subject_col, + visit_col=visit_col, + code_col=code_col, + label_col=label_col, + **kwargs, + ) + # Train on Synthetic, test on Real (TSTR). + _, syn_y_true, syn_y_pred = train_fn( + copy.deepcopy(syn_task), + copy.deepcopy(test_task), + subject_col=subject_col, + visit_col=visit_col, + code_col=code_col, + label_col=label_col, + **kwargs, + ) + + metrics_runs = [] + n_samples = len(real_y_true) + for _ in range(n_bootstraps): + if n_samples > 0: + indices = np.random.choice(n_samples, n_samples, replace=True) + r_true, r_pred = real_y_true[indices], real_y_pred[indices] + s_true, s_pred = syn_y_true[indices], syn_y_pred[indices] + else: + r_true, r_pred = real_y_true, real_y_pred + s_true, s_pred = syn_y_true, syn_y_pred + + real_acc = ( + sklearn_metrics.accuracy_score(r_true, r_pred) + if len(r_true) > 0 + else 0.0 + ) + syn_acc = ( + sklearn_metrics.accuracy_score(s_true, s_pred) + if len(s_true) > 0 + else 0.0 + ) + real_f1 = ( + sklearn_metrics.f1_score(r_true, r_pred, average="macro") + if len(r_true) > 0 + else 0.0 + ) + syn_f1 = ( + sklearn_metrics.f1_score(s_true, s_pred, average="macro") + if len(s_true) > 0 + else 0.0 + ) + + metrics_runs.append( + { + "MLE_Real_Accuracy": real_acc, + "MLE_Synth_Accuracy": syn_acc, + "MLE_Difference": real_acc - syn_acc, + "MLE_Ratio": syn_acc / real_acc if real_acc > 0 else 0.0, + "MLE_Real_F1": real_f1, + "MLE_Synth_F1": syn_f1, + } + ) + + summary = summarize_metric_runs(metrics_runs) + logger.info("MLE results: %s", summary) + return summary + + +def compute_prevalence_metrics( + train_ehr: pd.DataFrame, + syn_ehr: pd.DataFrame, + subject_col: str = "id", + code_col: str = "visit_codes", + n_bootstraps: int = 5, +) -> Dict[str, Tuple[float, float]]: + """Compares per-code patient-level prevalence of real vs synthetic data. + + For every code, prevalence is the fraction of unique patients who have that + code at least once. The real and synthetic prevalence vectors are compared + with R-squared, Pearson correlation and RMSE; bootstrap resampling is over + codes. + + Args: + train_ehr: Real training EHR dataframe. + syn_ehr: Synthetic EHR dataframe. + subject_col: Column name for patient/subject identifiers. + code_col: Column name for the medical codes. + n_bootstraps: Number of bootstrap resamples over codes. + + Returns: + Dictionary mapping ``"Prevalence_R2"``, ``"Prevalence_Pearson"`` and + ``"Prevalence_RMSE"`` to their ``(mean, std)`` across bootstraps. + """ + logger.info("Computing prevalence metrics") + + all_codes = set() + all_codes.update(train_ehr[code_col].unique().tolist()) + all_codes.update(syn_ehr[code_col].unique().tolist()) + + n_train = train_ehr[subject_col].nunique() + n_syn = syn_ehr[subject_col].nunique() + if n_train == 0 or n_syn == 0: + return { + "Prevalence_R2": (0.0, 0.0), + "Prevalence_Pearson": (0.0, 0.0), + "Prevalence_RMSE": (0.0, 0.0), + } + + # Count unique patients per code. + train_counts = train_ehr.groupby(code_col)[subject_col].nunique() + syn_counts = syn_ehr.groupby(code_col)[subject_col].nunique() + for code in all_codes: + if code not in train_counts.index: + train_counts.loc[code] = 0 + if code not in syn_counts.index: + syn_counts.loc[code] = 0 + + train_probs = train_counts / n_train + syn_probs = syn_counts / n_syn + df_compare = pd.DataFrame( + {"real": train_probs, "syn": syn_probs} + ).fillna(0) + + metrics_runs = [] + n_samples = len(df_compare) + for _ in range(n_bootstraps): + if n_samples > 0: + df_sampled = df_compare.sample(n=n_samples, replace=True) + real_vec = df_sampled["real"].values + syn_vec = df_sampled["syn"].values + else: + real_vec = df_compare["real"].values + syn_vec = df_compare["syn"].values + + r2 = ( + sklearn_metrics.r2_score(real_vec, syn_vec) + if n_samples > 1 + else 0.0 + ) + # Pearson correlation via numpy (avoids a hard scipy dependency). + if len(np.unique(real_vec)) > 1 and len(np.unique(syn_vec)) > 1: + rho = float(np.corrcoef(real_vec, syn_vec)[0, 1]) + else: + rho = 0.0 + rmse = ( + float(np.sqrt(sklearn_metrics.mean_squared_error(real_vec, syn_vec))) + if n_samples > 0 + else 0.0 + ) + + metrics_runs.append( + { + "Prevalence_R2": r2, + "Prevalence_Pearson": rho, + "Prevalence_RMSE": rmse, + } + ) + + summary = summarize_metric_runs(metrics_runs) + logger.info("Prevalence results: %s", summary) + return summary diff --git a/pyhealth/metrics/generative/utils.py b/pyhealth/metrics/generative/utils.py new file mode 100644 index 000000000..2e2123586 --- /dev/null +++ b/pyhealth/metrics/generative/utils.py @@ -0,0 +1,584 @@ +"""Shared utilities for synthetic-EHR generative evaluation metrics. + +This module contains the data-preparation helpers, distance functions, and +the lightweight predictive models (an LSTM classifier and a random-forest +baseline) that the privacy and utility metrics build on. It is not intended +to be used directly; see :mod:`pyhealth.metrics.generative.privacy` and +:mod:`pyhealth.metrics.generative.utility` for the public metric functions. +""" + +from typing import Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn + +__all__ = [ + "summarize_metric_runs", + "convert_visits_to_sets", + "calculate_hamming_distance_cutoff", + "find_nearest_neighbor_dist", + "process_patient_data_for_lstm", + "collate_fn", + "EHRDataset", + "EHR_LSTM_Classifier", + "train_lstm_model", + "aggregate_patient_visits", + "train_sklearn_model", + "build_next_visit_prediction_dataset", + "convert_cols_to_multihot", +] + + +def summarize_metric_runs( + metrics_list: List[Dict[str, float]] +) -> Dict[str, Tuple[float, float]]: + """Summarizes a list of per-run metric dicts into (mean, std) tuples. + + Args: + metrics_list: List of dicts, one per run, mapping metric name to value. + + Returns: + Dictionary mapping each metric name to a ``(mean, std)`` tuple computed + across the runs. Returns an empty dict if ``metrics_list`` is empty. + """ + if not metrics_list: + return {} + summary: Dict[str, Tuple[float, float]] = {} + for key in metrics_list[0].keys(): + values = [run[key] for run in metrics_list if key in run] + summary[key] = (float(np.mean(values)), float(np.std(values))) + return summary + + +# --- Privacy distance helpers --------------------------------------------- + + +def convert_visits_to_sets( + df: pd.DataFrame, + subject_col: str = "id", + visit_col: str = "time", + code_col: str = "visit_codes", +) -> List[List[set]]: + """Converts a flat EHR dataframe into per-patient lists of code sets. + + Each patient becomes a list of visits, and each visit is a ``set`` of the + codes recorded at that timestep. + + Args: + df: Input dataframe with one row per (patient, visit, code) event. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + + Returns: + List of patients, where each patient is a list of code sets. + """ + records = ( + df.groupby(subject_col)[[visit_col, code_col]] + .apply(lambda x: x.groupby(visit_col)[code_col].apply(set).tolist()) + .tolist() + ) + return records + + +def calculate_hamming_distance_cutoff( + v1: List[set], v2: List[set], cutoff: float +) -> float: + """Computes a set-based Hamming distance between two patients, with cutoff. + + The distance accumulates the symmetric-difference size of aligned visits + plus a penalty for differing sequence lengths. Computation stops early once + the running distance reaches ``cutoff``. + + Args: + v1: First patient as a list of code sets. + v2: Second patient as a list of code sets. + cutoff: Distance value at which to stop early. + + Returns: + The distance between ``v1`` and ``v2``, capped at ``cutoff``. + """ + len1, len2 = len(v1), len(v2) + dist = 0 if len1 == len2 else 1 + if dist >= cutoff: + return cutoff + + min_len = min(len1, len2) + for i in range(min_len): + dist += len(v1[i] ^ v2[i]) + if dist >= cutoff: + return cutoff + + if len1 > min_len: + dist += sum(len(v) for v in v1[min_len:]) + elif len2 > min_len: + dist += sum(len(v) for v in v2[min_len:]) + return dist + + +def find_nearest_neighbor_dist( + query: List[set], reference_dataset: List[List[set]] +) -> float: + """Finds the distance from a query patient to its nearest neighbor. + + Args: + query: Query patient as a list of code sets. + reference_dataset: Patients to search over. + + Returns: + The smallest :func:`calculate_hamming_distance_cutoff` distance between + ``query`` and any patient in ``reference_dataset``. + """ + best = float("inf") + for ref in reference_dataset: + d = calculate_hamming_distance_cutoff(query, ref, best) + if d == 0: + return 0 + if d < best: + best = d + return best + + +# --- LSTM classifier ------------------------------------------------------- + + +def process_patient_data_for_lstm( + df: pd.DataFrame, + subject_col: str = "id", + visit_col: str = "time", + code_col: str = "visit_codes", + label_col: str = "labels", + code_to_idx: Optional[Dict] = None, +) -> Tuple[List[Tuple[torch.Tensor, int]], Dict]: + """Transforms a flat EHR dataframe into multi-hot visit sequences. + + Each patient is converted into a ``(seq_len, vocab_size)`` tensor of + multi-hot visit vectors, paired with a single static label (the per-patient + max of ``label_col``). + + Args: + df: Input dataframe with one row per (patient, visit, code) event. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + label_col: Column name for the binary label. + code_to_idx: Optional precomputed mapping from code to integer index. + If ``None``, one is built from ``df``. + + Returns: + A tuple ``(patients, code_to_idx)`` where ``patients`` is a list of + ``(sequence_tensor, label)`` tuples. + """ + assert label_col in df.columns, f"Label column '{label_col}' not found." + assert subject_col in df.columns, f"Subject column '{subject_col}' not found." + assert visit_col in df.columns, f"Visit column '{visit_col}' not found." + + df = df.copy() + if code_to_idx is None: + vocab_size = df[code_col].nunique() + 1 + code_to_idx = { + code: idx for idx, code in enumerate(df[code_col].unique(), start=0) + } + else: + vocab_size = len(code_to_idx) + 1 + df[code_col] = df[code_col].map(code_to_idx) + + patients = [] + for _, group in df.groupby(subject_col): + # Static per-patient label: the max over visits (e.g. "ever diagnosed"). + label = group[label_col].max() + visits = group.sort_values(visit_col).groupby(visit_col) + patient_seq = [] + for _, visit_data in visits: + multi_hot = torch.zeros(vocab_size) + codes = visit_data[code_col].values + multi_hot[codes] = 1.0 + patient_seq.append(multi_hot) + patient_seq_tensor = torch.stack(patient_seq) + patients.append((patient_seq_tensor, label)) + + return patients, code_to_idx + + +def collate_fn(batch): + """Pads variable-length visit sequences for batched LSTM training. + + Args: + batch: List of ``(sequence_tensor, label)`` tuples. + + Returns: + A tuple ``(padded_seqs, lengths, labels)``. + """ + sequences, labels = zip(*batch) + lengths = torch.tensor([len(seq) for seq in sequences]) + padded_seqs = torch.nn.utils.rnn.pad_sequence( + sequences, batch_first=True, padding_value=0 + ) + labels = torch.tensor(labels, dtype=torch.float32) + return padded_seqs, lengths, labels + + +class EHRDataset(torch.utils.data.Dataset): + """A minimal :class:`torch.utils.data.Dataset` wrapper over a list.""" + + def __init__(self, data): + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +class EHR_LSTM_Classifier(nn.Module): + """A simple LSTM classifier over multi-hot EHR visit sequences. + + The model embeds each multi-hot visit vector, encodes the sequence with an + LSTM, and classifies using the final hidden state. + + Args: + vocab_size: Size of the code vocabulary (input dimension per visit). + embed_dim: Dimension of the dense visit embedding. + hidden_dim: Hidden dimension of the LSTM. + num_layers: Number of stacked LSTM layers. + """ + + def __init__( + self, + vocab_size: int, + embed_dim: int, + hidden_dim: int, + num_layers: int = 1, + ): + super().__init__() + self.embedding = nn.Linear(vocab_size, embed_dim) + self.lstm = nn.LSTM( + input_size=embed_dim, + hidden_size=hidden_dim, + num_layers=num_layers, + batch_first=True, + ) + self.fc = nn.Linear(hidden_dim, 1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor: + x = self.embedding(x) + packed_x = torch.nn.utils.rnn.pack_padded_sequence( + x, lengths.cpu(), batch_first=True, enforce_sorted=False + ) + _, (h_n, _) = self.lstm(packed_x) + final_encoding = h_n[-1] + logits = self.fc(final_encoding) + probs = self.sigmoid(logits) + return probs.squeeze(-1) + + +def train_lstm_model( + train_ehr: pd.DataFrame, + test_ehr: pd.DataFrame, + subject_col: str, + visit_col: str, + code_col: str, + label_col: str, + embed_dim: int = 32, + hidden_dim: int = 32, + batch_size: int = 32, + epochs: int = 5, + verbose: bool = True, + seed: int = 4, +) -> Tuple[nn.Module, np.ndarray, np.ndarray]: + """Trains :class:`EHR_LSTM_Classifier` and evaluates it on a test set. + + Args: + train_ehr: Training EHR dataframe. + test_ehr: Test EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + code_col: Column name for the medical codes. + label_col: Column name for the binary label. + embed_dim: Visit embedding dimension. + hidden_dim: LSTM hidden dimension. + batch_size: Training/eval batch size. + epochs: Number of training epochs. + verbose: Whether to print per-epoch loss. + seed: Random seed for reproducibility. + + Returns: + A tuple ``(model, y_true, y_pred)`` where ``y_true`` and ``y_pred`` are + numpy arrays of test labels and binary predictions. + """ + torch.manual_seed(seed) + all_codes = set() + all_codes.update(train_ehr[code_col].unique().tolist()) + all_codes.update(test_ehr[code_col].unique().tolist()) + # Start indices at 1 to reserve 0 for padding. + code_to_idx = {code: idx for idx, code in enumerate(all_codes, start=1)} + + train_data, _ = process_patient_data_for_lstm( + train_ehr, subject_col, visit_col, code_col, label_col, code_to_idx + ) + test_data, _ = process_patient_data_for_lstm( + test_ehr, subject_col, visit_col, code_col, label_col, code_to_idx + ) + train_dataloader = torch.utils.data.DataLoader( + dataset=EHRDataset(train_data), + batch_size=batch_size, + collate_fn=collate_fn, + shuffle=True, + ) + test_dataloader = torch.utils.data.DataLoader( + dataset=EHRDataset(test_data), + batch_size=batch_size, + collate_fn=collate_fn, + shuffle=False, + ) + + model = EHR_LSTM_Classifier( + vocab_size=len(code_to_idx) + 1, + embed_dim=embed_dim, + hidden_dim=hidden_dim, + ) + criterion = nn.BCELoss() + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + + use_cuda = torch.cuda.is_available() + if use_cuda: + model = model.cuda() + + model.train() + for epoch in range(epochs): + total_loss = 0.0 + for batch_x, batch_lens, batch_y in train_dataloader: + optimizer.zero_grad() + if use_cuda: + batch_x, batch_y = batch_x.cuda(), batch_y.cuda() + predictions = model(batch_x, batch_lens) + loss = criterion(predictions, batch_y) + loss.backward() + optimizer.step() + total_loss += loss.item() + if verbose: + avg_loss = total_loss / max(len(train_dataloader), 1) + print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}") + + model.eval() + all_preds: List[float] = [] + all_labels: List[float] = [] + with torch.no_grad(): + for batch_x, batch_lens, batch_y in test_dataloader: + if use_cuda: + batch_x, batch_y = batch_x.cuda(), batch_y.cuda() + predictions = model(batch_x, batch_lens) + all_preds.extend(predictions.cpu().numpy()) + all_labels.extend(batch_y.cpu().numpy()) + + y_true = np.array(all_labels) + y_pred = np.array([1 if p >= 0.5 else 0 for p in all_preds]) + return model, y_true, y_pred + + +# --- Random-forest baseline ------------------------------------------------ + + +def aggregate_patient_visits( + df: pd.DataFrame, + subject_col: str, + code_col: str, + label_col: str, + code_to_idx: Dict, +) -> Tuple[np.ndarray, np.ndarray]: + """Aggregates each patient's visits into a single multi-hot vector. + + Args: + df: Input dataframe with integer-encoded codes in ``code_col``. + subject_col: Column name for patient/subject identifiers. + code_col: Column name for the (integer-encoded) medical codes. + label_col: Column name for the binary label. + code_to_idx: Mapping from code to index (used to size the vector). + + Returns: + A tuple ``(patient_vectors, patient_labels)`` of numpy arrays. + """ + patient_vectors = [] + patient_labels = [] + for _, group in df.groupby(subject_col): + codes = group[code_col].unique() + multi_hot = np.zeros(len(code_to_idx) + 1) + multi_hot[codes] = 1 + patient_vectors.append(multi_hot) + patient_labels.append(group[label_col].max()) + return np.array(patient_vectors), np.array(patient_labels) + + +def train_sklearn_model( + train_ehr: pd.DataFrame, + test_ehr: pd.DataFrame, + subject_col: str, + visit_col: str, + code_col: str, + label_col: str, + model: str = "rf", + seed: int = 4, +) -> Tuple[object, np.ndarray, np.ndarray]: + """Trains an sklearn classifier on aggregated patient-level multi-hot data. + + Args: + train_ehr: Training EHR dataframe. + test_ehr: Test EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers (unused, kept for + a uniform signature with :func:`train_lstm_model`). + code_col: Column name for the medical codes. + label_col: Column name for the binary label. + model: Which model to train. Only ``"rf"`` (random forest) is supported. + seed: Random seed for reproducibility. + + Returns: + A tuple ``(model, y_true, y_pred)``. + """ + train_ehr = train_ehr.copy() + test_ehr = test_ehr.copy() + + all_codes = set() + all_codes.update(train_ehr[code_col].unique().tolist()) + all_codes.update(test_ehr[code_col].unique().tolist()) + # Start indices at 1 to reserve 0 for padding. + code_to_idx = {code: idx for idx, code in enumerate(all_codes, start=1)} + train_ehr[code_col] = train_ehr[code_col].map(code_to_idx) + test_ehr[code_col] = test_ehr[code_col].map(code_to_idx) + + X_train, y_train = aggregate_patient_visits( + train_ehr, subject_col, code_col, label_col, code_to_idx + ) + X_test, y_test = aggregate_patient_visits( + test_ehr, subject_col, code_col, label_col, code_to_idx + ) + + if model == "rf": + from sklearn.ensemble import RandomForestClassifier + + clf = RandomForestClassifier(n_estimators=100, random_state=seed) + else: + raise NotImplementedError(f"Model '{model}' not implemented.") + clf.fit(X_train, y_train) + + y_pred = clf.predict(X_test) + return clf, y_test, y_pred + + +# --- Task / feature construction ------------------------------------------ + + +def build_next_visit_prediction_dataset( + df: pd.DataFrame, + subject_col: str, + visit_col: str, + label_col: str, + multi_visit_sample_frac: float = 0.5, + seed: int = 4, +) -> pd.DataFrame: + """Builds a next-visit prediction task from an EHR dataframe. + + For patients with multiple visits, a fraction is sampled and their last + visit is dropped; these patients are labeled 1 (has a next visit). The + remaining multi-visit patients are kept intact and labeled 0. Single-visit + patients are labeled 0 by definition. + + Args: + df: Input EHR dataframe. + subject_col: Column name for patient/subject identifiers. + visit_col: Column name for visit/timestep identifiers. + label_col: Column name to overwrite with the next-visit label. + multi_visit_sample_frac: Fraction of multi-visit patients to truncate. + seed: Random seed for reproducibility. + + Returns: + A new dataframe with ``label_col`` set to the next-visit label. + """ + assert 0.0 <= multi_visit_sample_frac <= 1.0, ( + "multi_visit_sample_frac must be in [0, 1]." + ) + + rng = np.random.default_rng(seed) + transformed_groups = [] + + for _, group in df.groupby(subject_col): + group_sorted = group.sort_values(visit_col) + unique_visits = np.sort(group_sorted[visit_col].unique()) + n_visits = len(unique_visits) + + if n_visits <= 1: + g = group_sorted.copy() + g[label_col] = 0 + transformed_groups.append(g) + continue + + should_truncate = rng.random() < multi_visit_sample_frac + if should_truncate: + last_visit = unique_visits[-1] + g = group_sorted[group_sorted[visit_col] != last_visit].copy() + if g.empty: + # Defensive fallback for unexpected edge cases. + g = group_sorted.copy() + g[label_col] = 0 + else: + g[label_col] = 1 + else: + g = group_sorted.copy() + g[label_col] = 0 + transformed_groups.append(g) + + if len(transformed_groups) == 0: + return df.copy() + return pd.concat(transformed_groups, ignore_index=True) + + +def convert_cols_to_multihot( + df: pd.DataFrame, + code_col: str, + visit_col: str, + cat_cols: List[str], + num_cols: List[str], + bins_per_num: int = 5, +) -> pd.DataFrame: + """Folds categorical and numeric columns into per-visit multi-hot codes. + + Categorical columns are prefixed with their column name; numeric columns + are quantile-binned and likewise prefixed. All values are combined with the + original code into a single comma-separated ``combined_codes`` column. + + Args: + df: Input dataframe. + code_col: Column name for the existing medical codes. + visit_col: Column name for visit/timestep identifiers (kept for a + uniform signature; not modified). + cat_cols: Categorical column names to fold in. + num_cols: Numeric column names to bin and fold in. + bins_per_num: Number of quantile bins per numeric column. + + Returns: + A copy of ``df`` with an added ``combined_codes`` column. + """ + df = df.copy() + for col in cat_cols: + df[col] = col + "_" + df[col].astype(str) + + for col in num_cols: + df[col + "_binned"] = pd.qcut( + df[col], q=bins_per_num, duplicates="drop" + ).astype(str) + df[col + "_binned"] = col + "_" + df[col + "_binned"] + + def combine_codes(row): + codes = [str(row[code_col])] + for col in cat_cols: + codes.append(str(row[col])) + for col in num_cols: + codes.append(str(row[col + "_binned"])) + return ",".join(codes) + + df["combined_codes"] = df.apply(combine_codes, axis=1) + return df diff --git a/tests/core/test_generative_metrics.py b/tests/core/test_generative_metrics.py new file mode 100644 index 000000000..4f6a38ae7 --- /dev/null +++ b/tests/core/test_generative_metrics.py @@ -0,0 +1,256 @@ +"""Unit tests for pyhealth.metrics.generative (synthetic-EHR metrics). + +Run with:: + + python -m unittest tests.core.test_generative_metrics -v +""" + +import unittest + +import numpy as np +import pandas as pd + +from pyhealth.metrics.generative import ( + calc_membership_inference, + calc_nnaar, + compute_discriminator_privacy, + compute_mle, + compute_prevalence_metrics, + evaluate_synthetic_ehr, +) +from pyhealth.metrics.generative.utils import ( + convert_cols_to_multihot, + train_lstm_model, + train_sklearn_model, +) + +SUBJECT_COL, VISIT_COL, CODE_COL, LABEL_COL = "id", "time", "visit_codes", "labels" + + +def _make_dataframes(): + """Builds small synthetic train/test/synthetic EHR dataframes.""" + train_ehr = pd.DataFrame( + { + "visit_codes": [0, 1, 3, 4, 1, 2, 0, 3, 2, 4, 1, 0, 2, 3, 4, + 1, 0, 2, 3, 4, 1, 0, 2, 3, 4], + "labels": [0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, + 1, 0, 0, 1, 0, 1, 0, 0, 1, 0], + "time": [0, 0, 1, 1, 0, 1, 2, 2, 3, 3, 1, 2, 3, 4, 4, + 0, 1, 2, 3, 4, 1, 2, 3, 4, 5], + "id": [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 4, 4, 4, 4, 4], + } + ).astype({"visit_codes": str, "labels": int, "time": int, "id": str}) + + test_ehr = pd.DataFrame( + { + "visit_codes": [1, 2, 0, 3, 4, 2, 1, 0, 3, 4, 1, 2, 3, 0, 4, + 2, 1, 3, 0, 4], + "labels": [0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, + 0, 0, 1, 0, 1], + "time": [0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 0, 1, 1, 2, 2, + 3, 3, 3, 4, 4], + "id": [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2], + } + ).astype({"visit_codes": str, "labels": int, "time": int, "id": str}) + + syn_ehr = pd.DataFrame( + { + "visit_codes": [2, 3, 1, 4, 0, 2, 3, 1, 0, 4, 1, 2, 3, 4, 0, + 2, 1, 3, 4, 0, 2, 1, 3, 4, 0, 1, 2, 3, 4, 0], + "labels": [0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, + 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1], + "time": [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5], + "id": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], + } + ).astype({"visit_codes": str, "labels": int, "time": int, "id": str}) + + return train_ehr, test_ehr, syn_ehr + + +class GenerativeMetricsTestCase(unittest.TestCase): + """Shared fixtures and assertion helpers for the generative metrics.""" + + def setUp(self): + np.random.seed(0) + self.train_ehr, self.test_ehr, self.syn_ehr = _make_dataframes() + self.cols = dict( + subject_col=SUBJECT_COL, + visit_col=VISIT_COL, + code_col=CODE_COL, + label_col=LABEL_COL, + ) + + def assertSummary(self, summary, expected_keys): + """Asserts a metrics summary has the expected (mean, std) structure.""" + self.assertIsInstance(summary, dict) + for key in expected_keys: + self.assertIn(key, summary) + value = summary[key] + self.assertIsInstance(value, tuple) + self.assertEqual(len(value), 2) + mean, std = value + self.assertTrue(np.isfinite(mean), f"{key} mean not finite") + self.assertTrue(np.isfinite(std), f"{key} std not finite") + self.assertGreaterEqual(std, 0.0) + + +class TestNNAAR(GenerativeMetricsTestCase): + def test_calc_nnaar(self): + summary = calc_nnaar( + self.train_ehr, self.test_ehr, self.syn_ehr, + **self.cols, sample_size=10, n_runs=3, + ) + self.assertSummary(summary, ["nnaar", "aa_es", "aa_ts"]) + for key in ("aa_es", "aa_ts"): + self.assertGreaterEqual(summary[key][0], 0.0) + self.assertLessEqual(summary[key][0], 1.0) + self.assertGreaterEqual(summary["nnaar"][0], -1.0) + self.assertLessEqual(summary["nnaar"][0], 1.0) + + +class TestMembershipInference(GenerativeMetricsTestCase): + def test_calc_membership_inference(self): + summary = calc_membership_inference( + self.train_ehr, self.test_ehr, self.syn_ehr, + **self.cols, num_attack_samples=10, n_runs=3, + ) + keys = ["MIA_F1", "MIA_Precision", "MIA_Recall", "MIA_Accuracy"] + self.assertSummary(summary, keys) + for key in keys: + self.assertGreaterEqual(summary[key][0], 0.0) + self.assertLessEqual(summary[key][0], 1.0) + + +class TestDiscriminatorPrivacy(GenerativeMetricsTestCase): + def test_discriminator_privacy_lstm(self): + summary = compute_discriminator_privacy( + train_fn=train_lstm_model, + train_ehr=self.train_ehr, test_ehr=self.test_ehr, + syn_ehr=self.syn_ehr, **self.cols, n_bootstraps=3, + embed_dim=8, hidden_dim=8, batch_size=8, epochs=2, verbose=False, + ) + keys = ["Privacy_Discriminator_Accuracy", "Privacy_Score"] + self.assertSummary(summary, keys) + self.assertGreaterEqual(summary["Privacy_Score"][0], 0.0) + self.assertLessEqual(summary["Privacy_Score"][0], 1.0) + + def test_discriminator_privacy_rf(self): + summary = compute_discriminator_privacy( + train_fn=train_sklearn_model, + train_ehr=self.train_ehr, test_ehr=self.test_ehr, + syn_ehr=self.syn_ehr, **self.cols, n_bootstraps=3, model="rf", + ) + self.assertSummary( + summary, ["Privacy_Discriminator_Accuracy", "Privacy_Score"] + ) + + +class TestMLE(GenerativeMetricsTestCase): + def test_compute_mle_lstm(self): + summary = compute_mle( + train_fn=train_lstm_model, + train_ehr=self.train_ehr, test_ehr=self.test_ehr, + syn_ehr=self.syn_ehr, **self.cols, n_bootstraps=3, + embed_dim=8, hidden_dim=8, batch_size=8, epochs=2, verbose=False, + ) + keys = [ + "MLE_Real_Accuracy", "MLE_Synth_Accuracy", "MLE_Difference", + "MLE_Ratio", "MLE_Real_F1", "MLE_Synth_F1", + ] + self.assertSummary(summary, keys) + for key in ("MLE_Real_Accuracy", "MLE_Synth_Accuracy"): + self.assertGreaterEqual(summary[key][0], 0.0) + self.assertLessEqual(summary[key][0], 1.0) + + def test_compute_mle_rf(self): + summary = compute_mle( + train_fn=train_sklearn_model, + train_ehr=self.train_ehr, test_ehr=self.test_ehr, + syn_ehr=self.syn_ehr, **self.cols, n_bootstraps=3, model="rf", + ) + self.assertSummary(summary, ["MLE_Real_Accuracy", "MLE_Synth_Accuracy"]) + + +class TestPrevalenceMetrics(GenerativeMetricsTestCase): + def test_compute_prevalence_metrics(self): + summary = compute_prevalence_metrics( + self.train_ehr, self.syn_ehr, + subject_col=SUBJECT_COL, code_col=CODE_COL, n_bootstraps=3, + ) + keys = ["Prevalence_R2", "Prevalence_Pearson", "Prevalence_RMSE"] + self.assertSummary(summary, keys) + self.assertGreaterEqual(summary["Prevalence_Pearson"][0], -1.0) + self.assertLessEqual(summary["Prevalence_Pearson"][0], 1.0) + self.assertGreaterEqual(summary["Prevalence_RMSE"][0], 0.0) + + +class TestConvertColsToMultihot(GenerativeMetricsTestCase): + def test_convert_cols_to_multihot(self): + df = self.train_ehr.copy() + df["gender"] = ["M", "F"] * 12 + ["M"] + df["age"] = np.arange(len(df), dtype=float) + out = convert_cols_to_multihot( + df, code_col=CODE_COL, visit_col=VISIT_COL, + cat_cols=["gender"], num_cols=["age"], bins_per_num=2, + ) + self.assertIn("combined_codes", out.columns) + self.assertEqual(len(out), len(df)) + # Each combined code should fold in the code, the category and the bin. + first = out["combined_codes"].iloc[0] + self.assertIn("gender_", first) + self.assertIn("age_", first) + # The original dataframe must not be mutated. + self.assertNotIn("combined_codes", df.columns) + + +class TestEvaluateSyntheticEHR(GenerativeMetricsTestCase): + def test_evaluate_all_lstm(self): + out = evaluate_synthetic_ehr( + self.train_ehr, self.test_ehr, self.syn_ehr, **self.cols, + sample_size=10, mode="lstm", metrics="all", + lstm_params={"embed_dim": 8, "hidden_dim": 8, + "batch_size": 8, "epochs": 2}, + n_bootstraps=3, n_runs=3, + ) + for key in ("nnaar", "MIA_F1", "MLE_Real_Accuracy", + "Privacy_Score", "Prevalence_RMSE"): + self.assertIn(key, out) + + def test_evaluate_privacy_only_rf(self): + out = evaluate_synthetic_ehr( + self.train_ehr, self.test_ehr, self.syn_ehr, **self.cols, + sample_size=10, mode="rf", metrics="privacy", + n_bootstraps=3, n_runs=3, + ) + self.assertIn("nnaar", out) + self.assertNotIn("MLE_Real_Accuracy", out) + + def test_evaluate_utility_only_rf(self): + out = evaluate_synthetic_ehr( + self.train_ehr, self.test_ehr, self.syn_ehr, **self.cols, + mode="rf", metrics="utility", n_bootstraps=3, + ) + self.assertIn("MLE_Real_Accuracy", out) + self.assertNotIn("nnaar", out) + + def test_invalid_mode_raises(self): + with self.assertRaises(ValueError): + evaluate_synthetic_ehr( + self.train_ehr, self.test_ehr, self.syn_ehr, **self.cols, + mode="bad", + ) + + def test_invalid_metrics_raises(self): + with self.assertRaises(ValueError): + evaluate_synthetic_ehr( + self.train_ehr, self.test_ehr, self.syn_ehr, **self.cols, + metrics="bad", + ) + + +if __name__ == "__main__": + unittest.main() From 20e444658682d669f087269c57902b1bab3d24be Mon Sep 17 00:00:00 2001 From: chufangao Date: Sun, 17 May 2026 23:46:52 -0500 Subject: [PATCH 2/2] Add synthetic-EHR generative evaluation metrics Adds pyhealth/metrics/generative/, a subpackage for evaluating synthetic EHR data along privacy, utility, and statistical-fidelity axes: - privacy.py: NNAAR, membership inference attack, discriminator privacy - utility.py: machine learning efficacy (TRTR vs TSTR), code-prevalence similarity (R2, Pearson, RMSE) - utils.py: shared data prep, an LSTM classifier, and a random-forest baseline - evaluate_synthetic_ehr(): convenience orchestrator for the full suite These functions are ported from a standalone evaluation script. The MIMIC-specific data-loading/CLI glue is dropped; the metrics work on any flat EHR dataframe. Public functions are re-exported from pyhealth.metrics. Adds unit tests in tests/core/test_generative_metrics.py and Sphinx docs. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/core/test_generative_metrics.py | 209 ++++++++++++++++++++++++++ 1 file changed, 209 insertions(+) diff --git a/tests/core/test_generative_metrics.py b/tests/core/test_generative_metrics.py index 4f6a38ae7..7926d2b88 100644 --- a/tests/core/test_generative_metrics.py +++ b/tests/core/test_generative_metrics.py @@ -71,6 +71,42 @@ def _make_dataframes(): return train_ehr, test_ehr, syn_ehr +def _generate_ehr( + n_patients, vocab, seed, id_offset=0, + n_visits_range=(2, 7), n_codes_range=(2, 6), +): + """Generates a random EHR dataframe with patients drawn from ``vocab``.""" + rng = np.random.default_rng(seed) + rows = [] + for i in range(n_patients): + pid = str(id_offset + i) + n_visits = int(rng.integers(*n_visits_range)) + label = int(rng.integers(0, 2)) + for t in range(n_visits): + n_codes = int(rng.integers(*n_codes_range)) + codes = rng.choice( + vocab, size=min(n_codes, len(vocab)), replace=False + ) + for code in codes: + rows.append( + {"id": pid, "time": t, + "visit_codes": str(code), "labels": label} + ) + return pd.DataFrame(rows).astype( + {"visit_codes": str, "labels": int, "time": int, "id": str} + ) + + +def _perturb_ehr(df, vocab, frac, seed): + """Returns a copy of ``df`` with a fraction of codes randomly replaced.""" + rng = np.random.default_rng(seed) + df = df.copy().reset_index(drop=True) + mask = rng.random(len(df)) < frac + new_codes = rng.choice(vocab, size=int(mask.sum())) + df.loc[mask, "visit_codes"] = [str(c) for c in new_codes] + return df + + class GenerativeMetricsTestCase(unittest.TestCase): """Shared fixtures and assertion helpers for the generative metrics.""" @@ -252,5 +288,178 @@ def test_invalid_metrics_raises(self): ) +class TestMetricsBehavior(unittest.TestCase): + """Sanity checks: metrics should respond to how close synthetic data is. + + Three synthetic datasets are compared against the same real data: + + - ``exact``: an exact copy of the real training data, + - ``similar``: the training data with ~15% of codes randomly changed, + - ``different``: independent data over a disjoint code vocabulary. + + A well-behaved metric should rank these consistently (e.g. an exact copy + is the worst case for privacy and the best case for fidelity). + """ + + VOCAB_REAL = list(range(50)) + VOCAB_DIFF = list(range(100, 150)) + + @classmethod + def setUpClass(cls): + cls.train_ehr = _generate_ehr(60, cls.VOCAB_REAL, seed=1, id_offset=0) + cls.test_ehr = _generate_ehr( + 60, cls.VOCAB_REAL, seed=2, id_offset=10000 + ) + cls.syn_exact = cls.train_ehr.copy() + cls.syn_similar = _perturb_ehr( + cls.train_ehr, cls.VOCAB_REAL, frac=0.15, seed=3 + ) + cls.syn_different = _generate_ehr( + 60, cls.VOCAB_DIFF, seed=4, id_offset=20000 + ) + cls.cols = dict( + subject_col=SUBJECT_COL, + visit_col=VISIT_COL, + code_col=CODE_COL, + label_col=LABEL_COL, + ) + + def test_prevalence_orders_by_similarity(self): + # Prevalence similarity should degrade monotonically: exact > similar + # > different. + results = {} + for name, syn in [ + ("exact", self.syn_exact), + ("similar", self.syn_similar), + ("different", self.syn_different), + ]: + np.random.seed(0) + results[name] = compute_prevalence_metrics( + self.train_ehr, syn, + subject_col=SUBJECT_COL, code_col=CODE_COL, n_bootstraps=10, + ) + + rmse = {k: v["Prevalence_RMSE"][0] for k, v in results.items()} + r2 = {k: v["Prevalence_R2"][0] for k, v in results.items()} + pearson = {k: v["Prevalence_Pearson"][0] for k, v in results.items()} + + # An exact copy has identical code prevalence. + self.assertAlmostEqual(rmse["exact"], 0.0, places=9) + self.assertAlmostEqual(r2["exact"], 1.0, places=6) + self.assertAlmostEqual(pearson["exact"], 1.0, places=6) + + # Error grows / agreement shrinks as synthetic data drifts away. + self.assertLess(rmse["exact"], rmse["similar"]) + self.assertLess(rmse["similar"], rmse["different"]) + self.assertGreater(r2["exact"], r2["similar"]) + self.assertGreater(r2["similar"], r2["different"]) + self.assertGreaterEqual(pearson["exact"], pearson["similar"]) + self.assertGreater(pearson["similar"], pearson["different"]) + + def test_nnaar_flags_exact_copies(self): + # NNAAR should be high when synthetic data memorizes the training set + # and near zero otherwise. + nnaar = {} + for name, syn in [ + ("exact", self.syn_exact), + ("similar", self.syn_similar), + ("different", self.syn_different), + ]: + np.random.seed(0) + nnaar[name] = calc_nnaar( + self.train_ehr, self.test_ehr, syn, + **self.cols, sample_size=1000, n_runs=3, + )["nnaar"][0] + + self.assertGreater(nnaar["exact"], 0.5) + self.assertGreater(nnaar["exact"], nnaar["similar"]) + self.assertGreater(nnaar["exact"], nnaar["different"]) + self.assertLess(nnaar["similar"], 0.3) + self.assertLess(nnaar["different"], 0.3) + + def test_membership_inference_detects_training_data(self): + # The attack should succeed when synthetic data is derived from the + # training set and be near chance when it is unrelated. + acc = {} + for name, syn in [ + ("exact", self.syn_exact), + ("similar", self.syn_similar), + ("different", self.syn_different), + ]: + np.random.seed(0) + acc[name] = calc_membership_inference( + self.train_ehr, self.test_ehr, syn, + **self.cols, num_attack_samples=1000, n_runs=5, + )["MIA_Accuracy"][0] + + self.assertGreater(acc["exact"], 0.8) + self.assertGreater(acc["exact"], acc["different"]) + self.assertGreater(acc["similar"], acc["different"]) + self.assertLess(acc["different"], 0.7) + + def test_discriminator_privacy_orders_by_similarity(self): + # A discriminator easily separates a disjoint-vocabulary synthetic set + # (accuracy ~1, privacy score ~0) but not data derived from the real + # data (lower accuracy, higher privacy score). + score, acc = {}, {} + for name, syn in [ + ("exact", self.syn_exact), + ("similar", self.syn_similar), + ("different", self.syn_different), + ]: + np.random.seed(0) + result = compute_discriminator_privacy( + train_fn=train_sklearn_model, + train_ehr=self.train_ehr, test_ehr=self.test_ehr, + syn_ehr=syn, **self.cols, n_bootstraps=10, model="rf", + ) + score[name] = result["Privacy_Score"][0] + acc[name] = result["Privacy_Discriminator_Accuracy"][0] + + # The disjoint-vocabulary set is trivially detected. + self.assertGreater(acc["different"], 0.8) + self.assertLess(score["different"], 0.1) + # Data derived from the real data is harder to flag. + self.assertGreater(acc["different"], acc["exact"]) + self.assertGreater(acc["different"], acc["similar"]) + self.assertGreater(score["exact"], score["different"]) + self.assertGreater(score["similar"], score["different"]) + + def test_mle_orders_by_similarity(self): + # Utility should be highest for an exact copy and degrade as the + # synthetic data drifts away from the real data. + mle = {} + for name, syn in [ + ("exact", self.syn_exact), + ("similar", self.syn_similar), + ("different", self.syn_different), + ]: + np.random.seed(0) + mle[name] = compute_mle( + train_fn=train_sklearn_model, + train_ehr=self.train_ehr, test_ehr=self.test_ehr, + syn_ehr=syn, **self.cols, n_bootstraps=10, model="rf", + ) + + # An exact copy reproduces real utility exactly. + exact = mle["exact"] + self.assertAlmostEqual(exact["MLE_Difference"][0], 0.0, places=9) + self.assertAlmostEqual(exact["MLE_Difference"][1], 0.0, places=9) + self.assertAlmostEqual(exact["MLE_Ratio"][0], 1.0, places=9) + self.assertAlmostEqual( + exact["MLE_Synth_Accuracy"][0], exact["MLE_Real_Accuracy"][0], + places=9, + ) + + # Synthetic-trained accuracy degrades monotonically. + diff = {k: abs(v["MLE_Difference"][0]) for k, v in mle.items()} + ratio = {k: v["MLE_Ratio"][0] for k, v in mle.items()} + self.assertLessEqual(diff["exact"], diff["similar"]) + self.assertLess(diff["similar"], diff["different"]) + self.assertGreaterEqual(ratio["exact"], ratio["similar"]) + self.assertGreater(ratio["similar"], ratio["different"]) + self.assertLess(ratio["different"], 1.0) + + if __name__ == "__main__": unittest.main()