diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..d0662009a 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -205,4 +205,5 @@ API Reference models/pyhealth.models.TextEmbedding models/pyhealth.models.BIOT models/pyhealth.models.unified_multimodal_embedding_docs + models/pyhealth.models.td_icu_mortality models/pyhealth.models.califorest diff --git a/docs/api/models/pyhealth.models.td_icu_mortality.rst b/docs/api/models/pyhealth.models.td_icu_mortality.rst new file mode 100644 index 000000000..27e67f02f --- /dev/null +++ b/docs/api/models/pyhealth.models.td_icu_mortality.rst @@ -0,0 +1,11 @@ +pyhealth.models.td_icu_mortality +================================ + +.. automodule:: pyhealth.models.td_icu_mortality + +.. autoclass:: pyhealth.models.td_icu_mortality.TDICUMortalityModel + :members: + :undoc-members: + :show-inheritance: + + diff --git a/examples/td_icu_mortality_mimic4_example.ipynb b/examples/td_icu_mortality_mimic4_example.ipynb new file mode 100644 index 000000000..3c97d7358 --- /dev/null +++ b/examples/td_icu_mortality_mimic4_example.ipynb @@ -0,0 +1,4050 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "6a362aae36e446a79088e11a38440615": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_1db925d48ebd4597bf1d1d3970ee66a5", + "IPY_MODEL_c2490a1bcd094d8481223702e5916178", + "IPY_MODEL_d4ac3b75f60243d8b90421260095b46c" + ], + "layout": "IPY_MODEL_ea6993b605dd4562b17f123f9f14b57b" + } + }, + "1db925d48ebd4597bf1d1d3970ee66a5": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8d1f62afb4764bb199a2e183a5ce153b", + "placeholder": "​", + "style": "IPY_MODEL_69387d6c3c1a46bf8c11aed964eefdf7", + "value": "test: 100%" + } + }, + "c2490a1bcd094d8481223702e5916178": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3aea421084ef4be3a77a1268a28dd3ee", + "max": 2306, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_c8ca009e64a94258bcf74207f5f63884", + "value": 2306 + } + }, + "d4ac3b75f60243d8b90421260095b46c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_338a086755364c7dadccfe7cf1c76764", + "placeholder": "​", + "style": "IPY_MODEL_beec26dfe5c5471a81a5c51fa37c6533", + "value": " 2306/2306 [13:52<00:00,  3.06it/s]" + } + }, + "ea6993b605dd4562b17f123f9f14b57b": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8d1f62afb4764bb199a2e183a5ce153b": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "69387d6c3c1a46bf8c11aed964eefdf7": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3aea421084ef4be3a77a1268a28dd3ee": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c8ca009e64a94258bcf74207f5f63884": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "338a086755364c7dadccfe7cf1c76764": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "beec26dfe5c5471a81a5c51fa37c6533": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# End-to-End TD-ICU Mortality Prediction Pipeline\n", + "\n", + "\n", + "\n", + "### **From MIMIC-IV Raw Data to Clinical Triage with Uncertainty Quantification**\n", + "\n", + "---\n", + "\n", + "### Notebook Overview\n", + "This notebook implements a complete, end-to-end pipeline for **Real-Time Intensive Care Unit (ICU) Mortality Prediction**. It is based on the Temporal-Difference (TD) learning methodology proposed by *Frost et al. (2024)*, designed to handle complex, sparse, and irregular clinical event streams.\n", + "\n", + "Beyond standard mortality risk, this notebook integrates **Monte Carlo (MC) Dropout** to provide robust per-patient confidence bounds, enabling realistic clinical triage simulations.\n", + "\n", + "---\n", + "\n", + "### Core Pipeline Stages\n", + "\n", + "#### 1. Data Preprocessing & Validation\n", + "* **CSV to Parquet:** Converts raw PhysioNet MIMIC-IV v3.x CSV shards into optimized, schema-typed Parquet files.\n", + "* **HDF5 Construction:** Normalizes events, structures time-series data, and handles feature mapping for high-performance I/O.\n", + "* **Integrity Checks:** Validates data structures and fixed-padding logic prior to training.\n", + "\n", + "#### 2. Model Architecture (`pyhealth/models`)\n", + "* **`CNNLSTMPredictor`:** A hybrid neural network combining CNN feature extraction with LSTM temporal modeling.\n", + "* **`TDICUMortalityModel`:** Employs a dual-network setup (Online + Target) trained via TD-learning to map variable-length event streams to a calibrated mortality horizon.\n", + "\n", + "#### 3. Training & Evaluation\n", + "* **Distributed Orchestration:** Streams batched records, applying continuous metric tracking (AUROC, AUPRC, Brier Score).\n", + "* **Idempotent Saves:** Safely checkpoints the model state and rolls back if execution drops.\n", + "\n", + "#### 4. Uncertainty Estimation & Clinical Triage\n", + "* **MC Dropout Inference:** Executes stochastic forward passes to capture model variance.\n", + "* **Confidence Bounds:** Generates 95% credible intervals and identifies high/low-confidence predictions.\n", + "* **Clinical Triage Simulation:** Automatically stratifies patients into actionable groups based on risk and uncertainty:\n", + " * **Auto-alert:** High confidence, High risk (> 50%)\n", + " * **Senior review:** Low confidence (System asks for human help)\n", + " * **Standard monitoring:** High confidence, Low risk (< 20%)\n", + "\n", + "---" + ], + "metadata": { + "id": "2EoQKvJlhV09" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OrwYtVM0qV29" + }, + "outputs": [], + "source": [ + "!pip install -q polars pyarrow h5py\n", + "!pip install -q tensordict torchmetrics" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install -q pyhealth" + ], + "metadata": { + "id": "81PmNc9u16Fh" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!pip uninstall -y pyarrow -q\n", + "!pip install -q --no-cache-dir pyarrow\n", + "!pip install -q --no-cache-dir pyhealth db-dtypes statsmodels torcheval\n", + "# RESTART SESSION AFTER DOING THIS -\n", + "# Its required due to the version conflicts in colab.\n", + "# Proceed to the next cell after restarting session" + ], + "metadata": { + "id": "wcD2znez18pv" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# MIMIC-IV Data Preprocessing\n", + "\n", + "\n", + "\n", + "\n", + "### **CSV to Normalized Parquet Pipeline**\n", + "\n", + "---\n", + "\n", + "The following cells provide the full data-prep pipeline starting from **PhysioNet's MIMIC-IV v3.x** CSV files. This workflow ensures data integrity and high-performance I/O for downstream machine learning tasks.\n", + "\n", + "### Pipeline Overview\n", + "The pipeline processes the raw data in two primary stages:\n", + "1. **Type Conversion:** Transforms raw CSV shards into schema-typed **Parquet** files.\n", + "2. **Normalization:** Builds the event, label, and split files that `build_hdf5_from_parquets` expects.\n", + "\n", + "---\n", + "\n", + "### Input Configuration\n", + "This is the layout **PhysioNet** distributes for **MIMIC-IV v3.x** (featuring CSV shards per table).\n", + "\n", + "> **Expected Input Layout**\n", + "> Ensure your `csv_root` directory follows this structure:\n", + "\n", + "```text\n", + "csv_root/\n", + "├── hosp/\n", + "│ ├── admissions/\n", + "│ | └── sharded CSVs\n", + "│ └── patients/\n", + "| └── sharded CSVs\n", + "└── icu/\n", + " ├── chartevents/\n", + " | └── sharded CSVs\n", + " └── icustays/\n", + " | └── sharded CSVs\n", + " └── inputevents/\n", + " └── sharded CSVs\n", + "```\n", + "\n", + "### Output Artifacts\n", + "The preprocessing script generates a set of optimized **Parquet** files under `--out-root`. These are structured to be consumed directly by `build_hdf5_from_parquets`:\n", + "\n", + "---\n", + "\n", + "#### Core Metadata\n", + "* `stay_demo.parquet`\n", + "\n", + "#### Normalized Events\n", + "* `chart_events_norm.parquet`\n", + "* `drug_events_norm.parquet`\n", + "* `demo_events_norm.parquet`\n", + "\n", + "#### Labels & Data Splits\n", + "* `labels.parquet`\n", + "* `splits.parquet`\n", + "* `labels_split.parquet`\n", + "* `stay_lists/{train, val, test}_stays.parquet`\n", + "\n", + "#### Feature Mapping\n", + "* `feature_map.parquet`\n", + "* `features.txt`\n", + "\n", + "---" + ], + "metadata": { + "id": "l_irOYrg_S0f" + } + }, + { + "cell_type": "code", + "source": [ + "from __future__ import annotations\n", + "\n", + "import glob\n", + "import json\n", + "import random\n", + "import os\n", + "from pathlib import Path\n", + "import argparse\n", + "import math\n", + "import time\n", + "from copy import deepcopy\n", + "from typing import Any, Dict, List, Mapping, Optional, Tuple\n", + "from collections import defaultdict\n", + "from sklearn.metrics import average_precision_score, roc_auc_score\n", + "from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info\n", + "\n", + "from tqdm.auto import tqdm\n", + "\n", + "import numpy as np\n", + "import polars as pl\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torch.nn.utils.rnn import pack_padded_sequence\n", + "from torch.utils.data import Dataset, DataLoader\n", + "import h5py\n", + "\n", + "from pyhealth.datasets import SampleEHRDataset\n", + "from pyhealth.models import BaseModel\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using device: {device}\")\n" + ], + "metadata": { + "id": "nskUoVhCJQYu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "'''\n", + "Declare the following Paths -\n", + "1) CSV Path - Path where the CSV files from Physionet are residing\n", + "2) Out directory - Path where the output of the processed data should reside\n", + "3) Checkpoint directory - Directory where the trained models which reside\n", + "'''\n", + "csv_root = Path(\"/content/mimic4\")\n", + "out_root = Path(\"/content/data/mimic\")\n", + "checkpoint_dir = Path(\"/content/checkpoints\")\n", + "data_root = out_root" + ], + "metadata": { + "id": "MNZtuBh7Jahf" + }, + "execution_count": 43, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## 1) Declare MIMIC-IV Data Schema and Feature Vocabulary\n" + ], + "metadata": { + "id": "7P1pxqE1LZyi" + } + }, + { + "cell_type": "code", + "source": [ + "def get_mimic_dtypes() -> Dict[str, Dict[str, Any]]:\n", + " \"\"\"Return typed schemas for MIMIC-IV v3.x CSV tables.\n", + "\n", + " These match the fields used by the TD ICU mortality pipeline; unused\n", + " columns are dropped by downstream steps.\n", + "\n", + " Returns:\n", + " Mapping from table name to ``{column: polars_dtype}``.\n", + " \"\"\"\n", + " return {\n", + " \"admissions\": {\n", + " \"subject_id\": pl.Int64,\n", + " \"hadm_id\": pl.Int64,\n", + " \"admittime\": pl.Datetime,\n", + " \"dischtime\": pl.Datetime,\n", + " \"deathtime\": pl.Datetime,\n", + " \"admission_type\": pl.Utf8,\n", + " \"discharge_location\": pl.Utf8,\n", + " \"insurance\": pl.Utf8,\n", + " \"language\": pl.Utf8,\n", + " \"marital_status\": pl.Utf8,\n", + " \"race\": pl.Utf8,\n", + " \"edregtime\": pl.Datetime,\n", + " \"edouttime\": pl.Datetime,\n", + " \"hospital_expire_flag\": pl.Int64,\n", + " },\n", + " \"patients\": {\n", + " \"subject_id\": pl.Int64,\n", + " \"gender\": pl.Utf8,\n", + " \"anchor_age\": pl.Int64,\n", + " \"anchor_year\": pl.Int64,\n", + " \"anchor_year_group\": pl.Utf8,\n", + " \"dod\": pl.Datetime,\n", + " },\n", + " \"chartevents\": {\n", + " \"subject_id\": pl.Int64,\n", + " \"hadm_id\": pl.Int64,\n", + " \"stay_id\": pl.Int64,\n", + " \"caregiver_id\": pl.Int64,\n", + " \"charttime\": pl.Datetime,\n", + " \"storetime\": pl.Datetime,\n", + " \"itemid\": pl.Int64,\n", + " \"value\": pl.Utf8,\n", + " \"valuenum\": pl.Float64,\n", + " \"valueuom\": pl.Utf8,\n", + " \"warning\": pl.Int64,\n", + " },\n", + " \"icustays\": {\n", + " \"subject_id\": pl.Int64,\n", + " \"hadm_id\": pl.Int64,\n", + " \"stay_id\": pl.Int64,\n", + " \"first_careunit\": pl.Utf8,\n", + " \"last_careunit\": pl.Utf8,\n", + " \"intime\": pl.Datetime,\n", + " \"outtime\": pl.Datetime,\n", + " \"los\": pl.Float64,\n", + " },\n", + " \"inputevents\": {\n", + " \"subject_id\": pl.Int64,\n", + " \"hadm_id\": pl.Int64,\n", + " \"stay_id\": pl.Int64,\n", + " \"caregiver_id\": pl.Int64,\n", + " \"starttime\": pl.Datetime,\n", + " \"endtime\": pl.Datetime,\n", + " \"storetime\": pl.Datetime,\n", + " \"itemid\": pl.Int64,\n", + " \"amount\": pl.Float64,\n", + " \"amountuom\": pl.Utf8,\n", + " \"rate\": pl.Float64,\n", + " \"rateuom\": pl.Utf8,\n", + " \"orderid\": pl.Int64,\n", + " \"linkorderid\": pl.Int64,\n", + " \"ordercategoryname\": pl.Utf8,\n", + " \"secondaryordercategoryname\": pl.Utf8,\n", + " \"ordercomponenttypedescription\": pl.Utf8,\n", + " \"ordercategorydescription\": pl.Utf8,\n", + " \"patientweight\": pl.Float64,\n", + " \"totalamount\": pl.Float64,\n", + " \"totalamountuom\": pl.Utf8,\n", + " \"isopenbag\": pl.Int64,\n", + " \"continueinnextdept\": pl.Int64,\n", + " \"statusdescription\": pl.Utf8,\n", + " \"originalamount\": pl.Float64,\n", + " \"originalrate\": pl.Float64,\n", + " },\n", + " }\n", + "\n", + "LAB_ITEMIDS: Dict[int, str] = {\n", + " 220045: \"Heart Rate\",\n", + " 220050: \"Arterial BP Systolic\",\n", + " 220051: \"Arterial BP Diastolic\",\n", + " 220052: \"Arterial BP Mean\",\n", + " 220179: \"Non-Invasive BP Sys\",\n", + " 220180: \"Non-Invasive BP Dia\",\n", + " 220181: \"Non-Invasive BP Mean\",\n", + " 220210: \"Respiratory Rate\",\n", + " 220277: \"SpO2\",\n", + " 223761: \"Temperature\",\n", + " 220739: \"GCS Eye\",\n", + " 223900: \"GCS Verbal\",\n", + " 223901: \"GCS Motor\",\n", + " 226253: \"Albumin\",\n", + " 225624: \"Bilirubin\",\n", + " 220635: \"Calcium\",\n", + " 220615: \"Creatinine\",\n", + " 220228: \"Haemoglobin\",\n", + " 220545: \"Haematocrit\",\n", + " 225664: \"Bedside Glucose\",\n", + " 220621: \"Glucose\",\n", + " 227442: \"Potassium\",\n", + " 220734: \"pH\",\n", + " 220235: \"Blood Gas pCO2\",\n", + " 220224: \"Blood Gas pO2\",\n", + " 225690: \"Bicarbonate\",\n", + " 220602: \"Chloride\",\n", + " 220645: \"Sodium\",\n", + " 220507: \"Urea\",\n", + " 220562: \"Platelets\",\n", + " 220546: \"WBC\",\n", + " 227457: \"Lactate\",\n", + " 225668: \"Troponin - T\",\n", + " 220576: \"CRP\",\n", + " 220274: \"FiO2\",\n", + " 227013: \"PEEP\",\n", + " 224685: \"Tidal Volume\",\n", + " 224684: \"Minute Volume\",\n", + " 224695: \"Peak Inspiratory Pressure\",\n", + " 224688: \"Respiratory Rate (set)\",\n", + " 224690: \"Respiratory Rate (total)\",\n", + " 220659: \"PTT\",\n", + " 220613: \"ALT\",\n", + " 220587: \"AST\",\n", + " 226559: \"Foley\",\n", + "}\n", + "\n", + "DRUG_ITEMIDS: Dict[int, str] = {\n", + " 221906: \"Noradrenaline\",\n", + " 221289: \"Adrenaline\",\n", + " 221662: \"Dopamine\",\n", + " 221653: \"Dobutamine\",\n", + " 222315: \"Vasopressin\",\n", + " 229581: \"Phenylephrine\",\n", + " 225942: \"Propofol\",\n", + " 221744: \"Midazolam\",\n", + " 222168: \"Fentanyl\",\n", + " 225154: \"Morphine Sulfate\",\n", + " 221749: \"Lorazepam\",\n", + " 222011: \"Cisatracurium\",\n", + " 221986: \"Rocuronium\",\n", + " 225916: \"Dexmedetomidine\",\n", + " 229582: \"Ketamine\",\n", + " 220995: \"Heparin\",\n", + " 228339: \"Insulin - Regular\",\n", + " 222021: \"Furosemide\",\n", + "}\n", + "\n", + "ANTIBIOTIC_ITEMIDS: Dict[int, str] = {\n", + " 225798: \"Vancomycin\",\n", + " 225851: \"Piperacillin-Tazobactam\",\n", + " 225893: \"Meropenem\",\n", + " 225879: \"Cefepime\",\n", + " 225875: \"Ampicillin\",\n", + " 225876: \"Ampicillin-Sulbactam\",\n", + " 229584: \"Azithromycin\",\n", + " 229585: \"Ceftriaxone\",\n", + " 225899: \"Metronidazole\",\n", + " 225883: \"Levofloxacin\",\n", + " 228201: \"Ciprofloxacin\",\n", + " 225837: \"Fluconazole\",\n", + "}\n", + "\n", + "# Unit conversions to harmonize across MIMIC-IV's mixed recording units.\n", + "LAB_UNIT_CONV: Dict[str, float] = {\n", + " \"Albumin\": 10.0,\n", + " \"Bedside Glucose\": 1 / 18,\n", + " \"Glucose\": 1 / 18,\n", + " \"Bilirubin\": 17.1,\n", + " \"Blood Gas pO2\": 0.133322,\n", + " \"Blood Gas pCO2\": 0.133322,\n", + " \"Calcium\": 0.2495,\n", + " \"Creatinine\": 88.42,\n", + " \"Haemoglobin\": 10.0,\n", + " \"Urea\": 0.357,\n", + "}\n", + "\n", + "LABEL_KEYS: List[str] = [\n", + " \"1-day-died\", \"3-day-died\", \"7-day-died\", \"14-day-died\", \"28-day-died\",\n", + "]\n", + "\n", + "def build_feature_vocabulary() -> Tuple[List[str], Dict[str, int]]:\n", + " \"\"\"Construct the ordered feature list and name->id mapping.\n", + "\n", + " Returns:\n", + " ``(feature_names, name_to_id)`` where ``feature_names`` is a sorted\n", + " list of chart + drug features, followed by demographic features.\n", + " \"\"\"\n", + " all_drug_map = {**DRUG_ITEMIDS, **ANTIBIOTIC_ITEMIDS}\n", + " lab_features = list(LAB_ITEMIDS.values())\n", + " drug_features = (\n", + " [v + \"_rate\" for v in DRUG_ITEMIDS.values()]\n", + " + [v + \"_rate\" for v in ANTIBIOTIC_ITEMIDS.values()]\n", + " + [v + \"_bolus\" for v in DRUG_ITEMIDS.values()]\n", + " + [v + \"_bolus\" for v in ANTIBIOTIC_ITEMIDS.values()]\n", + " )\n", + " demo_features = [\"age\", \"gender\", \"patientweight\"]\n", + " feature_names = sorted(set(lab_features + drug_features)) + demo_features\n", + " name_to_id = {f: i for i, f in enumerate(feature_names)}\n", + " _ = all_drug_map # silence unused\n", + " return feature_names, name_to_id\n", + "\n" + ], + "metadata": { + "id": "kzQTUOqJLcq8" + }, + "execution_count": 28, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## 2) Convert the CSV to Parquet and build the Parquet Pipeline Stages\n", + "\n" + ], + "metadata": { + "id": "uRtyvkAdL9K2" + } + }, + { + "cell_type": "code", + "source": [ + "def discover_csv_shards(csv_root: Path) -> Dict[str, List[str]]:\n", + " \"\"\"Find MIMIC-IV CSV shard files per logical table.\n", + "\n", + " Args:\n", + " csv_root: Root of the MIMIC-IV CSV tree (contains ``hosp/`` and ``icu/``).\n", + "\n", + " Returns:\n", + " Mapping from table name to sorted list of CSV paths.\n", + " \"\"\"\n", + " table_files = {\n", + " \"admissions\": sorted(\n", + " glob.glob(str(csv_root / \"hosp\" / \"admissions\" / \"*.csv\"))\n", + " ),\n", + " \"patients\": sorted(\n", + " glob.glob(str(csv_root / \"hosp\" / \"patients\" / \"*.csv\"))\n", + " ),\n", + " \"chartevents\": sorted(\n", + " glob.glob(str(csv_root / \"icu\" / \"chartevents\" / \"*.csv\"))\n", + " ),\n", + " \"icustays\": sorted(\n", + " glob.glob(str(csv_root / \"icu\" / \"icustays\" / \"*.csv\"))\n", + " ),\n", + " \"inputevents\": sorted(\n", + " glob.glob(str(csv_root / \"icu\" / \"inputevents\" / \"*.csv\"))\n", + " ),\n", + " }\n", + " for table, files in table_files.items():\n", + " print(f\" {table}: {len(files)} shard(s)\")\n", + " missing = [t for t, fs in table_files.items() if len(fs) == 0]\n", + " if missing:\n", + " raise FileNotFoundError(\n", + " f\"No CSV shards found for {missing} under {csv_root}. \"\n", + " \"Check that the MIMIC-IV download is extracted with the \"\n", + " \"expected layout (hosp/... and icu/...).\"\n", + " )\n", + " return table_files\n", + "\n", + "def csv_to_parquet(\n", + " csv_root: Path,\n", + " parquet_dir: Path,\n", + " overwrite: bool = False,\n", + ") -> None:\n", + " \"\"\"Convert sharded MIMIC-IV CSVs to typed parquet files.\n", + "\n", + " Uses polars ``scan_csv`` + ``sink_parquet`` so the conversion\n", + " streams and does not require the full CSV in memory.\n", + "\n", + " Args:\n", + " csv_root: Root of the MIMIC-IV CSV tree.\n", + " parquet_dir: Output directory for per-table parquet files.\n", + " overwrite: If ``False``, skip tables whose parquet already exists.\n", + " \"\"\"\n", + " parquet_dir.mkdir(parents=True, exist_ok=True)\n", + " table_files = discover_csv_shards(csv_root)\n", + " dtypes = get_mimic_dtypes()\n", + "\n", + " for table, files in table_files.items():\n", + " out_path = parquet_dir / f\"{table}.parquet\"\n", + " if out_path.exists() and not overwrite:\n", + " print(f\" skipping {table}: {out_path} exists\")\n", + " continue\n", + " print(f\" converting {table}: {len(files)} shard(s) -> {out_path}\")\n", + " (\n", + " pl.scan_csv(\n", + " files,\n", + " schema_overrides=dtypes[table],\n", + " infer_schema_length=10000,\n", + " null_values=[\"\", \"NULL\", \"null\", \"NaN\", \"nan\"],\n", + " ignore_errors=False,\n", + " )\n", + " .sink_parquet(str(out_path))\n", + " )\n", + "\n", + "def build_stay_demographics(\n", + " parquet_dir: Path,\n", + " out_dir: Path,\n", + ") -> Path:\n", + " \"\"\"Join admissions, patients, and icustays into a stay-level table.\n", + "\n", + " Args:\n", + " parquet_dir: Directory with per-table parquet files.\n", + " out_dir: Destination for ``stay_demo.parquet``.\n", + "\n", + " Returns:\n", + " Path to the written parquet.\n", + " \"\"\"\n", + " admissions = pl.scan_parquet(parquet_dir / \"admissions.parquet\")\n", + " patients = pl.scan_parquet(parquet_dir / \"patients.parquet\")\n", + " icustays = pl.scan_parquet(parquet_dir / \"icustays.parquet\")\n", + "\n", + " stay_demo = (\n", + " icustays\n", + " .join(\n", + " admissions.select([\n", + " \"subject_id\", \"hadm_id\", \"admittime\", \"dischtime\",\n", + " \"deathtime\", \"hospital_expire_flag\",\n", + " ]),\n", + " on=[\"subject_id\", \"hadm_id\"], how=\"left\",\n", + " )\n", + " .join(\n", + " patients.select([\"subject_id\", \"gender\", \"anchor_age\", \"dod\"]),\n", + " on=\"subject_id\", how=\"left\",\n", + " )\n", + " .with_columns([\n", + " pl.col(\"gender\").replace({\"F\": 1.0, \"M\": 0.0})\n", + " .cast(pl.Float64).alias(\"gender_code\"),\n", + " pl.col(\"anchor_age\").cast(pl.Float64).alias(\"age\"),\n", + " ])\n", + " )\n", + "\n", + " out_dir.mkdir(parents=True, exist_ok=True)\n", + " out_path = out_dir / \"stay_demo.parquet\"\n", + " stay_demo.collect(engine=\"streaming\").write_parquet(out_path)\n", + " print(f\" wrote {out_path}\")\n", + " return out_path\n", + "\n", + "def build_chart_events(\n", + " parquet_dir: Path,\n", + " out_dir: Path,\n", + ") -> Path:\n", + " \"\"\"Build normalized chart events (labs + vitals) with unit harmonization.\n", + "\n", + " Filters to the itemids in ``LAB_ITEMIDS``, applies\n", + " ``LAB_UNIT_CONV`` where required, and produces rows keyed by\n", + " ``(subject_id, hadm_id, stay_id, charttime, feature, value_num)``.\n", + "\n", + " Args:\n", + " parquet_dir: Directory with per-table parquet files.\n", + " out_dir: Destination for ``chart_events_norm.parquet``.\n", + "\n", + " Returns:\n", + " Path to the written parquet.\n", + " \"\"\"\n", + " chartevents = pl.scan_parquet(parquet_dir / \"chartevents.parquet\")\n", + " lab_ids = list(LAB_ITEMIDS.keys())\n", + "\n", + " chart_events = (\n", + " chartevents\n", + " .filter(pl.col(\"itemid\").is_in(lab_ids))\n", + " .filter(pl.col(\"valuenum\").is_not_null())\n", + " .with_columns([\n", + " pl.col(\"itemid\").replace_strict(LAB_ITEMIDS).alias(\"feature\"),\n", + " pl.col(\"valuenum\").cast(pl.Float64).alias(\"value_num\"),\n", + " ])\n", + " .with_columns([\n", + " pl.when(pl.col(\"feature\").is_in(list(LAB_UNIT_CONV.keys())))\n", + " .then(\n", + " pl.col(\"value_num\")\n", + " * pl.col(\"feature\").replace_strict(\n", + " LAB_UNIT_CONV, default=1.0\n", + " )\n", + " )\n", + " .otherwise(pl.col(\"value_num\"))\n", + " .alias(\"value_num\"),\n", + " ])\n", + " .select([\n", + " \"subject_id\", \"hadm_id\", \"stay_id\", \"charttime\",\n", + " \"feature\", \"value_num\",\n", + " ])\n", + " )\n", + "\n", + " out_path = out_dir / \"chart_events_norm.parquet\"\n", + " chart_events.collect(engine=\"streaming\").write_parquet(out_path)\n", + " print(f\" wrote {out_path}\")\n", + " return out_path\n", + "\n", + "def build_drug_events(\n", + " parquet_dir: Path,\n", + " out_dir: Path,\n", + ") -> Path:\n", + " \"\"\"Build normalized drug events (rates + boluses).\n", + "\n", + " Args:\n", + " parquet_dir: Directory with per-table parquet files.\n", + " out_dir: Destination for ``drug_events_norm.parquet``.\n", + "\n", + " Returns:\n", + " Path to the written parquet.\n", + " \"\"\"\n", + " inputevents = pl.scan_parquet(parquet_dir / \"inputevents.parquet\")\n", + " all_drug_map = {**DRUG_ITEMIDS, **ANTIBIOTIC_ITEMIDS}\n", + " all_drug_ids = list(all_drug_map.keys())\n", + "\n", + " drug_events = (\n", + " inputevents\n", + " .filter(pl.col(\"itemid\").is_in(all_drug_ids))\n", + " .with_columns([\n", + " pl.col(\"itemid\").replace_strict(all_drug_map).alias(\"drug_name\"),\n", + " pl.col(\"amount\").cast(pl.Float64),\n", + " pl.col(\"rate\").cast(pl.Float64),\n", + " ])\n", + " .with_columns([\n", + " pl.when(pl.col(\"rate\").is_not_null())\n", + " .then(pl.col(\"drug_name\") + pl.lit(\"_rate\"))\n", + " .otherwise(pl.col(\"drug_name\") + pl.lit(\"_bolus\"))\n", + " .alias(\"feature\"),\n", + " pl.when(pl.col(\"rate\").is_not_null())\n", + " .then(pl.col(\"rate\"))\n", + " .otherwise(pl.col(\"amount\"))\n", + " .alias(\"value_num\"),\n", + " pl.col(\"starttime\").alias(\"charttime\"),\n", + " ])\n", + " .filter(pl.col(\"value_num\").is_not_null())\n", + " .select([\n", + " \"subject_id\", \"hadm_id\", \"stay_id\", \"charttime\",\n", + " \"feature\", \"value_num\", \"patientweight\",\n", + " ])\n", + " )\n", + "\n", + " out_path = out_dir / \"drug_events_norm.parquet\"\n", + " drug_events.collect(engine=\"streaming\").write_parquet(out_path)\n", + " print(f\" wrote {out_path}\")\n", + " return out_path\n", + "\n", + "\n", + "def build_labels_and_splits(\n", + " stay_demo_path: Path,\n", + " out_dir: Path,\n", + " train_frac: float = 0.8,\n", + " val_frac: float = 0.1,\n", + " seed: int = 42,\n", + ") -> Tuple[Path, Path, Path]:\n", + " \"\"\"Build stay-level mortality labels and patient-level splits.\n", + "\n", + " Labels are derived from ``dod - intime`` using the five horizons\n", + " from the paper. Splits are patient-level (no subject appears in\n", + " more than one split).\n", + "\n", + " Args:\n", + " stay_demo_path: Path to ``stay_demo.parquet``.\n", + " out_dir: Directory for output parquets.\n", + " train_frac: Fraction of patients assigned to train.\n", + " val_frac: Fraction assigned to val (remainder goes to test).\n", + " seed: RNG seed for the patient shuffle.\n", + "\n", + " Returns:\n", + " Triple ``(labels_path, splits_path, labels_split_path)``.\n", + " \"\"\"\n", + " stay_demo = pl.scan_parquet(stay_demo_path)\n", + "\n", + " hour_bounds = {\n", + " \"1-day-died\": 24,\n", + " \"3-day-died\": 72,\n", + " \"7-day-died\": 168,\n", + " \"14-day-died\": 336,\n", + " \"28-day-died\": 672,\n", + " }\n", + "\n", + " labels = (\n", + " stay_demo\n", + " .with_columns([\n", + " (\n", + " (pl.col(\"dod\").is_not_null())\n", + " & (\n", + " (pl.col(\"dod\") - pl.col(\"intime\"))\n", + " .dt.total_hours() <= hrs\n", + " )\n", + " )\n", + " .cast(pl.Int8).alias(key)\n", + " for key, hrs in hour_bounds.items()\n", + " ])\n", + " .select(\n", + " [\"subject_id\", \"hadm_id\", \"stay_id\", \"intime\", \"outtime\"]\n", + " + list(hour_bounds.keys())\n", + " + [\"age\", \"gender_code\"]\n", + " )\n", + " )\n", + " labels_path = out_dir / \"labels.parquet\"\n", + " labels.collect(engine=\"streaming\").write_parquet(labels_path)\n", + " print(f\" wrote {labels_path}\")\n", + "\n", + " # Patient-level split\n", + " labels_df = pl.read_parquet(labels_path)\n", + " subject_ids = labels_df[\"subject_id\"].unique().to_numpy()\n", + " rng = np.random.default_rng(seed)\n", + " rng.shuffle(subject_ids)\n", + "\n", + " n = len(subject_ids)\n", + " n_train = int(train_frac * n)\n", + " n_val = int(val_frac * n)\n", + " train_ids = set(subject_ids[:n_train].tolist())\n", + " val_ids = set(subject_ids[n_train:n_train + n_val].tolist())\n", + "\n", + " split_df = labels_df.select([\"subject_id\", \"stay_id\"]).with_columns([\n", + " pl.when(pl.col(\"subject_id\").is_in(train_ids)).then(pl.lit(\"train\"))\n", + " .when(pl.col(\"subject_id\").is_in(val_ids)).then(pl.lit(\"val\"))\n", + " .otherwise(pl.lit(\"test\"))\n", + " .alias(\"split\"),\n", + " ])\n", + " splits_path = out_dir / \"splits.parquet\"\n", + " split_df.write_parquet(splits_path)\n", + " print(f\" wrote {splits_path}\")\n", + "\n", + " labels_split = labels_df.join(\n", + " split_df, on=[\"subject_id\", \"stay_id\"], how=\"left\",\n", + " )\n", + " labels_split_path = out_dir / \"labels_split.parquet\"\n", + " labels_split.write_parquet(labels_split_path)\n", + " print(f\" wrote {labels_split_path}\")\n", + "\n", + " return labels_path, splits_path, labels_split_path\n", + "\n", + "def build_demo_events(\n", + " stay_demo_path: Path,\n", + " labels_split_path: Path,\n", + " drug_events_path: Path,\n", + " out_dir: Path,\n", + ") -> Path:\n", + " \"\"\"Build demographic event rows (age, gender, weight) at ICU admission.\n", + "\n", + " Args:\n", + " stay_demo_path: Path to ``stay_demo.parquet``.\n", + " labels_split_path: Path to ``labels_split.parquet``.\n", + " drug_events_path: Path to ``drug_events_norm.parquet`` (used to\n", + " pull the first-recorded ``patientweight`` per stay).\n", + " out_dir: Destination for ``demo_events_norm.parquet``.\n", + "\n", + " Returns:\n", + " Path to the written parquet.\n", + " \"\"\"\n", + " stay_demo = pl.scan_parquet(stay_demo_path).select([\n", + " \"subject_id\", \"hadm_id\", \"stay_id\", \"intime\", \"age\", \"gender_code\",\n", + " ])\n", + " labels_split = pl.scan_parquet(labels_split_path).select([\n", + " \"subject_id\", \"hadm_id\", \"stay_id\", \"split\",\n", + " ])\n", + " drug_events = pl.scan_parquet(drug_events_path)\n", + "\n", + " weight_events = (\n", + " drug_events\n", + " .filter(pl.col(\"patientweight\").is_not_null())\n", + " .group_by([\"subject_id\", \"hadm_id\", \"stay_id\"])\n", + " .agg([\n", + " pl.col(\"patientweight\").drop_nulls().first().alias(\"patientweight\"),\n", + " ])\n", + " )\n", + "\n", + " demo_base = (\n", + " stay_demo\n", + " .join(\n", + " labels_split,\n", + " on=[\"subject_id\", \"hadm_id\", \"stay_id\"], how=\"left\",\n", + " )\n", + " .join(\n", + " weight_events,\n", + " on=[\"subject_id\", \"hadm_id\", \"stay_id\"], how=\"left\",\n", + " )\n", + " .with_columns([pl.col(\"intime\").alias(\"charttime\")])\n", + " .select([\n", + " \"subject_id\", \"hadm_id\", \"stay_id\", \"charttime\", \"split\",\n", + " \"age\", \"gender_code\", \"patientweight\",\n", + " ])\n", + " )\n", + "\n", + " def _make_event_rows(\n", + " col: str, feature_name: str,\n", + " ) -> pl.LazyFrame:\n", + " return (\n", + " demo_base\n", + " .filter(pl.col(col).is_not_null())\n", + " .select([\n", + " \"subject_id\", \"hadm_id\", \"stay_id\",\n", + " \"charttime\", \"split\", col,\n", + " ])\n", + " .with_columns([\n", + " pl.lit(feature_name).alias(\"feature\"),\n", + " pl.col(col).cast(pl.Float64).alias(\"value_num\"),\n", + " ])\n", + " .select([\n", + " \"subject_id\", \"hadm_id\", \"stay_id\", \"charttime\",\n", + " \"split\", \"feature\", \"value_num\",\n", + " ])\n", + " )\n", + "\n", + " age_rows = _make_event_rows(\"age\", \"age\")\n", + " gender_rows = _make_event_rows(\"gender_code\", \"gender\")\n", + " weight_rows = _make_event_rows(\"patientweight\", \"patientweight\")\n", + "\n", + " demo_event_rows = pl.concat([age_rows, gender_rows, weight_rows])\n", + " out_path = out_dir / \"demo_events_norm.parquet\"\n", + " demo_event_rows.collect(engine=\"streaming\").write_parquet(out_path)\n", + " print(f\" wrote {out_path}\")\n", + " return out_path\n", + "\n", + "def build_stay_lists(\n", + " labels_split_path: Path,\n", + " out_dir: Path,\n", + ") -> Dict[str, Path]:\n", + " \"\"\"Write per-split stay-id lists to ``{out_dir}/stay_lists/``.\n", + "\n", + " Args:\n", + " labels_split_path: Path to ``labels_split.parquet``.\n", + " out_dir: Directory that will contain ``stay_lists/``.\n", + "\n", + " Returns:\n", + " Mapping from split name to parquet path.\n", + " \"\"\"\n", + " stay_list_dir = out_dir / \"stay_lists\"\n", + " stay_list_dir.mkdir(parents=True, exist_ok=True)\n", + " labels_split = pl.read_parquet(labels_split_path)\n", + "\n", + " out: Dict[str, Path] = {}\n", + " for split_name in [\"train\", \"val\", \"test\"]:\n", + " p = stay_list_dir / f\"{split_name}_stays.parquet\"\n", + " (\n", + " labels_split\n", + " .filter(pl.col(\"split\") == split_name)\n", + " .select([\"subject_id\", \"hadm_id\", \"stay_id\"])\n", + " .unique()\n", + " .write_parquet(p)\n", + " )\n", + " print(f\" wrote {p}\")\n", + " out[split_name] = p\n", + " return out\n", + "\n", + "def build_feature_map(out_dir: Path) -> Tuple[Path, Path]:\n", + " \"\"\"Write ``feature_map.parquet`` and ``features.txt``.\n", + "\n", + " Args:\n", + " out_dir: Destination directory.\n", + "\n", + " Returns:\n", + " Pair of paths ``(feature_map_path, features_txt_path)``.\n", + " \"\"\"\n", + " feature_names, name_to_id = build_feature_vocabulary()\n", + " feat_df = pl.DataFrame({\n", + " \"feature\": list(name_to_id.keys()),\n", + " \"feature_id\": list(name_to_id.values()),\n", + " })\n", + " feature_map_path = out_dir / \"feature_map.parquet\"\n", + " feat_df.write_parquet(feature_map_path)\n", + " print(f\" wrote {feature_map_path}\")\n", + "\n", + " features_txt = out_dir / \"features.txt\"\n", + " features_txt.write_text(\"\\n\".join(feature_names) + \"\\n\")\n", + " print(f\" wrote {features_txt}\")\n", + "\n", + " return feature_map_path, features_txt" + ], + "metadata": { + "id": "QimpsKMJL4Mv" + }, + "execution_count": 29, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## 3) Pipeline Orchestrator\n", + "### Run the above pipeline\n", + "\n", + "\n", + "\n", + "---\n", + "\n", + "### 🚨 Prerequisites\n", + "> **CRITICAL:** Ensure all raw CSV files are correctly loaded into your `csv_root` directory before continuing. The orchestrator will fail if it cannot locate the `hosp/` and `icu/` sub-directories.\n", + "\n", + "---\n", + "\n" + ], + "metadata": { + "id": "99DHDCEbQdHH" + } + }, + { + "cell_type": "code", + "source": [ + "out_root.mkdir(parents=True, exist_ok=True)\n", + "\n", + "parquet_dir = out_root / \"tables\"\n", + "overwrite_parquets = False\n", + "\n", + "print(\"\\n[1/7] Converting CSV -> typed parquet...\")\n", + "csv_to_parquet(csv_root, parquet_dir, overwrite=overwrite_parquets)\n", + "\n", + "print(\"\\n[2/7] Building stay_demo...\")\n", + "stay_demo_path = build_stay_demographics(parquet_dir, out_root)\n", + "\n", + "print(\"\\n[3/7] Building chart events (labs + vitals)...\")\n", + "build_chart_events(parquet_dir, out_root)\n", + "\n", + "print(\"\\n[4/7] Building drug events...\")\n", + "drug_events_path = build_drug_events(parquet_dir, out_root)\n", + "\n", + "print(\"\\n[5/7] Building labels + splits...\")\n", + "_, _, labels_split_path = build_labels_and_splits(\n", + " stay_demo_path, out_root,\n", + " )\n", + "\n", + "print(\"\\n[6/7] Building demographic event rows + stay lists...\")\n", + "build_demo_events(\n", + " stay_demo_path, labels_split_path, drug_events_path, out_root,\n", + " )\n", + "build_stay_lists(labels_split_path, out_root)\n", + "\n", + "print(\"\\n[7/7] Writing feature map + features.txt...\")\n", + "build_feature_map(out_root)\n", + "\n", + "print(f\"\\nAll preprocessing outputs live under {out_root}\")" + ], + "metadata": { + "id": "4Xh6sFqdQEI-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## 4) Compute Scaling" + ], + "metadata": { + "id": "zfbvquHqP-rY" + } + }, + { + "cell_type": "code", + "source": [ + "def compute_scaling_from_h5(\n", + " h5_path: Path,\n", + " feature_names: List[str],\n", + " sample_rows: Optional[int] = None,\n", + ") -> Dict[str, Any]:\n", + " \"\"\"Compute per-feature mean and std from a training HDF5.\n", + "\n", + " Sweeps the training split once with running sums. For each of\n", + " ``values``, ``delta_time`` and ``delta_value``, computes per-feature\n", + " stats using only non-NaN entries and feature-id >= 0 slots.\n", + "\n", + " Args:\n", + " h5_path: Path to the training HDF5 file.\n", + " feature_names: Ordered feature names (length == n_features).\n", + " sample_rows: If provided, only use the first ``sample_rows`` rows.\n", + "\n", + " Returns:\n", + " Scaling dict in the format expected by ``TDICUMortalityModel``.\n", + " \"\"\"\n", + " import h5py\n", + "\n", + " n_features = len(feature_names)\n", + " value_sum = np.zeros(n_features, dtype=np.float64)\n", + " value_sq = np.zeros(n_features, dtype=np.float64)\n", + " value_cnt = np.zeros(n_features, dtype=np.int64)\n", + " dt_sum = np.zeros(n_features, dtype=np.float64)\n", + " dt_sq = np.zeros(n_features, dtype=np.float64)\n", + " dt_cnt = np.zeros(n_features, dtype=np.int64)\n", + " dv_sum = np.zeros(n_features, dtype=np.float64)\n", + " dv_sq = np.zeros(n_features, dtype=np.float64)\n", + " dv_cnt = np.zeros(n_features, dtype=np.int64)\n", + " tp_sum, tp_sq, tp_cnt = 0.0, 0.0, 0\n", + "\n", + " with h5py.File(h5_path, \"r\") as h5f:\n", + " n = h5f[\"features\"].shape[0]\n", + " rows = n if sample_rows is None else min(sample_rows, n)\n", + " chunk = 2048\n", + " for i in range(0, rows, chunk):\n", + " j = min(i + chunk, rows)\n", + " feats = h5f[\"features\"][i:j]\n", + " vals = h5f[\"values\"][i:j]\n", + " dts = h5f[\"deltatime\"][i:j]\n", + " dvs = h5f[\"deltavalue\"][i:j]\n", + " tps = h5f[\"timepoints\"][i:j]\n", + "\n", + " feat_valid = feats >= 0\n", + " tp_valid = ~np.isnan(tps)\n", + " tp = tps[tp_valid]\n", + " tp_sum += tp.sum()\n", + " tp_sq += np.square(tp).sum()\n", + " tp_cnt += tp.size\n", + "\n", + " for arr, sum_, sq_, cnt_ in [\n", + " (vals, value_sum, value_sq, value_cnt),\n", + " (dts, dt_sum, dt_sq, dt_cnt),\n", + " (dvs, dv_sum, dv_sq, dv_cnt),\n", + " ]:\n", + " valid = feat_valid & ~np.isnan(arr)\n", + " idx = feats[valid].astype(np.int64)\n", + " x = arr[valid].astype(np.float64)\n", + " np.add.at(sum_, idx, x)\n", + " np.add.at(sq_, idx, x * x)\n", + " np.add.at(cnt_, idx, 1)\n", + "\n", + " def finalize(sum_, sq_, cnt_) -> Tuple[np.ndarray, np.ndarray]:\n", + " \"\"\"Convert running sums to (mean, std) with a small-variance floor.\"\"\"\n", + " mean = np.divide(sum_, np.maximum(cnt_, 1))\n", + " var = np.divide(sq_, np.maximum(cnt_, 1)) - mean ** 2\n", + " var = np.maximum(var, 1e-8)\n", + " std = np.sqrt(var)\n", + " return mean.astype(np.float32), std.astype(np.float32)\n", + "\n", + " v_mean, v_std = finalize(value_sum, value_sq, value_cnt)\n", + " dt_mean, dt_std = finalize(dt_sum, dt_sq, dt_cnt)\n", + " dv_mean, dv_std = finalize(dv_sum, dv_sq, dv_cnt)\n", + " tp_mean = np.float32(tp_sum / max(tp_cnt, 1))\n", + " tp_var = max(tp_sq / max(tp_cnt, 1) - tp_mean ** 2, 1e-8)\n", + " tp_std = np.float32(np.sqrt(tp_var))\n", + "\n", + " def _per_feature(arr: np.ndarray) -> Dict[str, torch.Tensor]:\n", + " \"\"\"Expand a 1D per-feature array into the name-keyed tensor dict.\"\"\"\n", + " return {\n", + " f: torch.tensor([arr[i]], dtype=torch.float32)\n", + " for i, f in enumerate(feature_names)\n", + " }\n", + "\n", + " return {\n", + " \"mean\": {\n", + " \"timepoints\": torch.tensor([tp_mean], dtype=torch.float32),\n", + " \"values\": _per_feature(v_mean),\n", + " \"delta_time\": _per_feature(dt_mean),\n", + " \"delta_value\": _per_feature(dv_mean),\n", + " },\n", + " \"std\": {\n", + " \"timepoints\": torch.tensor([tp_std], dtype=torch.float32),\n", + " \"values\": _per_feature(v_std),\n", + " \"delta_time\": _per_feature(dt_std),\n", + " \"delta_value\": _per_feature(dv_std),\n", + " },\n", + " }\n" + ], + "metadata": { + "id": "mVnLv3Y5S83C" + }, + "execution_count": 31, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "##5) This stage converts the parquet -> HDF5\n" + ], + "metadata": { + "id": "RB5Z1JauUGBY" + } + }, + { + "cell_type": "code", + "source": [ + "\n", + "NEXT_STATE_DELAY_HOURS = 24.0\n", + "NEXT_STATE_WINDOW_HOURS = 24.0\n", + "\n", + "\n", + "def choose_sample_indices(\n", + " n_events: int,\n", + " first_k: int = 8,\n", + " every_k: int = 32,\n", + " max_samples: int = 100,\n", + ") -> List[int]:\n", + " \"\"\"Pick state-marker indices within a single ICU stay.\n", + "\n", + " Args:\n", + " n_events: Number of events in the stay.\n", + " first_k: Keep every event in the first k.\n", + " every_k: After the first k, sample every k'th event.\n", + " max_samples: Cap on total samples per stay.\n", + "\n", + " Returns:\n", + " Sorted list of indices into the stay's event array.\n", + " \"\"\"\n", + " if n_events <= 1:\n", + " return []\n", + " idx_set = set()\n", + " for i in range(min(first_k, n_events - 1)):\n", + " idx_set.add(i)\n", + " i = first_k\n", + " while i < n_events - 1:\n", + " idx_set.add(i)\n", + " i += every_k\n", + " idx_set.add(n_events - 1)\n", + " idx_list = sorted(idx_set)\n", + "\n", + " if len(idx_list) > max_samples:\n", + " first_block = idx_list[:first_k]\n", + " terminal = idx_list[-1]\n", + " middle = idx_list[first_k:-1]\n", + " budget = max_samples - first_k - 1\n", + " if budget > 0 and len(middle) > budget:\n", + " step = len(middle) / budget\n", + " middle = [middle[int(i * step)] for i in range(budget)]\n", + " idx_list = first_block + middle + [terminal]\n", + " return idx_list\n", + "\n", + "\n", + "def get_next_state_idxs_batch(\n", + " times: np.ndarray,\n", + " cur_idxs: np.ndarray,\n", + " delay: float = NEXT_STATE_DELAY_HOURS,\n", + " window: float = NEXT_STATE_WINDOW_HOURS,\n", + ") -> Tuple[np.ndarray, np.ndarray]:\n", + " \"\"\"Vectorized next-state selection per the paper's SMRP rule.\n", + "\n", + " Args:\n", + " times: Sorted 1D float array of event times for one stay.\n", + " cur_idxs: Integer array of current-state indices.\n", + " delay: SMRP delay in hours.\n", + " window: Eligibility window length in hours.\n", + "\n", + " Returns:\n", + " Tuple ``(next_idxs, is_terminal)``.\n", + " \"\"\"\n", + " cur_idxs = np.asarray(cur_idxs, dtype=np.int64)\n", + " t0 = times[cur_idxs]\n", + " lo = t0 + delay\n", + " hi = t0 + delay + window\n", + " next_idxs = np.searchsorted(times, lo, side=\"left\")\n", + " in_bounds = next_idxs < len(times)\n", + " safe_next = np.clip(next_idxs, 0, len(times) - 1)\n", + " within = times[safe_next] <= hi\n", + " after = next_idxs > cur_idxs\n", + " valid = in_bounds & within & after\n", + " is_terminal = ~valid\n", + " next_idxs = np.where(is_terminal, cur_idxs, next_idxs).astype(np.int64)\n", + " return next_idxs, is_terminal\n", + "\n", + "\n", + "def precompute_dt(times: np.ndarray) -> np.ndarray:\n", + " \"\"\"Compute gap-from-previous-event for a whole stay.\n", + "\n", + " Args:\n", + " times: Sorted timepoints.\n", + "\n", + " Returns:\n", + " Same-length dt array with ``dt[0] = 0``.\n", + " \"\"\"\n", + " n = len(times)\n", + " dt = np.zeros(n, dtype=np.float32)\n", + " if n > 1:\n", + " dt[1:] = np.diff(times)\n", + " return dt\n", + "\n", + "\n", + "def precompute_dv(values: np.ndarray, features: np.ndarray) -> np.ndarray:\n", + " \"\"\"Compute value-change-since-previous-measurement for a whole stay.\n", + "\n", + " Args:\n", + " values: Event values.\n", + " features: Event feature ids.\n", + "\n", + " Returns:\n", + " Same-length dv array.\n", + " \"\"\"\n", + " n = len(values)\n", + " dv = np.zeros(n, dtype=np.float32)\n", + " if n == 0:\n", + " return dv\n", + " idx = np.argsort(features, kind=\"stable\")\n", + " f_sorted = features[idx]\n", + " v_sorted = values[idx]\n", + " is_new = np.empty(n, dtype=bool)\n", + " is_new[0] = True\n", + " is_new[1:] = f_sorted[1:] != f_sorted[:-1]\n", + " dv_sorted = np.empty(n, dtype=np.float32)\n", + " dv_sorted[0] = 0.0\n", + " dv_sorted[1:] = v_sorted[1:] - v_sorted[:-1]\n", + " dv_sorted[is_new] = 0.0\n", + " dv[idx] = dv_sorted\n", + " return dv\n", + "\n", + "\n", + "def build_window_arrays(\n", + " times: np.ndarray,\n", + " values: np.ndarray,\n", + " features: np.ndarray,\n", + " dt_full: np.ndarray,\n", + " dv_full: np.ndarray,\n", + " end_idx: int,\n", + " context_len: int,\n", + ") -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:\n", + " \"\"\"Build a single padded state window ending at ``end_idx``.\n", + "\n", + " Args:\n", + " times: Full-stay timepoints.\n", + " values: Full-stay values.\n", + " features: Full-stay feature ids.\n", + " dt_full: Pre-computed full-stay dt array.\n", + " dv_full: Pre-computed full-stay dv array.\n", + " end_idx: Terminal event index.\n", + " context_len: Target window size.\n", + "\n", + " Returns:\n", + " Tuple of arrays ``(timepoints, values, features, deltatime, deltavalue)``\n", + " each of length ``context_len``. Head slots are NaN / -1 padded.\n", + " \"\"\"\n", + " start_idx = max(0, end_idx - context_len + 1)\n", + " t = times[start_idx:end_idx + 1]\n", + " v = values[start_idx:end_idx + 1]\n", + " f = features[start_idx:end_idx + 1]\n", + " dt = dt_full[start_idx:end_idx + 1].copy()\n", + " dv = dv_full[start_idx:end_idx + 1]\n", + " if len(dt) > 0:\n", + " dt[0] = 0.0\n", + " pad_n = context_len - len(t)\n", + "\n", + " timepoints = np.full((context_len,), np.nan, dtype=np.float32)\n", + " values_arr = np.full((context_len,), np.nan, dtype=np.float32)\n", + " features_arr = np.full((context_len,), -1, dtype=np.int16)\n", + " deltatime = np.full((context_len,), -1.0, dtype=np.float32)\n", + " deltavalue = np.full((context_len,), np.nan, dtype=np.float32)\n", + "\n", + " timepoints[pad_n:] = t\n", + " values_arr[pad_n:] = v\n", + " features_arr[pad_n:] = f\n", + " deltatime[pad_n:] = dt\n", + " deltavalue[pad_n:] = dv\n", + " return timepoints, values_arr, features_arr, deltatime, deltavalue\n" + ], + "metadata": { + "id": "xYbzG3gQTDOV" + }, + "execution_count": 32, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def build_hdf5_from_parquets(\n", + " parquet_dir: Path,\n", + " out_dir: Path,\n", + " split: str,\n", + " context_len: int = 400,\n", + " first_k: int = 8,\n", + " every_k: int = 32,\n", + " max_samples_per_stay: int = 100,\n", + " label_keys: Optional[List[str]] = None,\n", + " flush_every: int = 8192,\n", + " stay_batch_size: int = 2000,\n", + ") -> Path:\n", + " \"\"\"Build one split's HDF5 from normalized MIMIC-IV parquet files.\n", + "\n", + " Expected files under ``parquet_dir``:\n", + "\n", + " * ``chart_events_norm.parquet`` with columns ``subject_id``, ``hadm_id``,\n", + " ``stay_id``, ``charttime``, ``feature``, ``value_num``.\n", + " * ``drug_events_norm.parquet`` same schema.\n", + " * ``demo_events_norm.parquet`` same schema.\n", + " * ``labels_split.parquet`` with ``subject_id``, ``hadm_id``,\n", + " ``stay_id``, ``intime``, ``outtime``, ``split``, plus binary label\n", + " columns (``1-day-died`` ... ``28-day-died``).\n", + " * ``stay_lists/{split}_stays.parquet`` with ``stay_id`` column.\n", + " * ``feature_map.parquet`` with ``feature`` (str) and ``feature_id`` (int).\n", + "\n", + " Output: ``{out_dir}/h5{split}_fixed.hdf5``.\n", + "\n", + " Args:\n", + " parquet_dir: Directory containing the parquet inputs.\n", + " out_dir: Destination directory for the HDF5 file.\n", + " split: One of ``\"train\"``, ``\"val\"``, ``\"test\"``.\n", + " context_len: Sequence length per state window.\n", + " first_k: ``choose_sample_indices`` parameter.\n", + " every_k: ``choose_sample_indices`` parameter.\n", + " max_samples_per_stay: Cap on samples per ICU stay.\n", + " label_keys: Binary label columns in ``labels_split.parquet``.\n", + " Defaults to the 5 horizons from the paper.\n", + " flush_every: Rows to buffer between HDF5 writes.\n", + " stay_batch_size: Stays per polars batch load.\n", + "\n", + " Returns:\n", + " Path to the written HDF5 file.\n", + " \"\"\"\n", + " import h5py\n", + " import polars as pl\n", + "\n", + " if label_keys is None:\n", + " label_keys = [\n", + " \"1-day-died\", \"3-day-died\", \"7-day-died\",\n", + " \"14-day-died\", \"28-day-died\",\n", + " ]\n", + "\n", + " chart_path = parquet_dir / \"chart_events_norm.parquet\"\n", + " drug_path = parquet_dir / \"drug_events_norm.parquet\"\n", + " demo_path = parquet_dir / \"demo_events_norm.parquet\"\n", + " labels_path = parquet_dir / \"labels_split.parquet\"\n", + " stay_list_path = parquet_dir / \"stay_lists\" / f\"{split}_stays.parquet\"\n", + " feature_map_path = parquet_dir / \"feature_map.parquet\"\n", + "\n", + " if not feature_map_path.exists():\n", + " raise FileNotFoundError(\n", + " f\"feature_map.parquet not found at {feature_map_path}. \"\n", + " \"Build this from the training split before running build mode.\"\n", + " )\n", + " feat_df = pl.read_parquet(feature_map_path)\n", + "\n", + " out_dir.mkdir(parents=True, exist_ok=True)\n", + " out_path = out_dir / f\"h5{split}_fixed.hdf5\"\n", + " if out_path.exists():\n", + " out_path.unlink()\n", + "\n", + " h5f = h5py.File(out_path, \"w\", libver=\"latest\")\n", + " c = context_len\n", + " chunk_2d = (4096, c)\n", + " chunk_1d = (min(4096, 65536),)\n", + " kw: Dict[str, Any] = dict(compression=None, track_times=False)\n", + "\n", + " def _create_2d(name: str, dtype: str) -> Any:\n", + " \"\"\"Create a resizable 2D HDF5 dataset of shape (0, context_len).\"\"\"\n", + " return h5f.create_dataset(\n", + " name, shape=(0, c), maxshape=(None, c),\n", + " chunks=chunk_2d, dtype=dtype, **kw,\n", + " )\n", + "\n", + " def _create_1d(name: str, dtype: str) -> Any:\n", + " \"\"\"Create a resizable 1D HDF5 dataset of shape (0,).\"\"\"\n", + " return h5f.create_dataset(\n", + " name, shape=(0,), maxshape=(None,),\n", + " chunks=chunk_1d, dtype=dtype, **kw,\n", + " )\n", + "\n", + " datasets: Dict[str, Any] = {\n", + " \"timepoints\": _create_2d(\"timepoints\", \"f4\"),\n", + " \"values\": _create_2d(\"values\", \"f4\"),\n", + " \"features\": _create_2d(\"features\", \"i2\"),\n", + " \"deltatime\": _create_2d(\"deltatime\", \"f4\"),\n", + " \"deltavalue\": _create_2d(\"deltavalue\", \"f4\"),\n", + " \"nexttimepoints\": _create_2d(\"nexttimepoints\", \"f4\"),\n", + " \"nextvalues\": _create_2d(\"nextvalues\", \"f4\"),\n", + " \"nextfeatures\": _create_2d(\"nextfeatures\", \"i2\"),\n", + " \"nextdeltatime\": _create_2d(\"nextdeltatime\", \"f4\"),\n", + " \"nextdeltavalue\": _create_2d(\"nextdeltavalue\", \"f4\"),\n", + " \"isterminal\": _create_1d(\"isterminal\", \"i1\"),\n", + " }\n", + " for key in label_keys:\n", + " datasets[key] = _create_1d(key, \"i1\")\n", + "\n", + " stay_ids = (\n", + " pl.read_parquet(stay_list_path)\n", + " .select([\"stay_id\"]).unique().sort(\"stay_id\")[\"stay_id\"].to_list()\n", + " )\n", + "\n", + " def _flush(store: Dict[str, List]) -> None:\n", + " \"\"\"Append buffered rows to HDF5 datasets, then clear the buffer.\"\"\"\n", + " n = len(store[\"isterminal\"])\n", + " if n == 0:\n", + " return\n", + " for k, arr_list in store.items():\n", + " if k in label_keys or k == \"isterminal\":\n", + " arr = np.asarray(arr_list, dtype=np.int8)\n", + " elif k in (\"features\", \"nextfeatures\"):\n", + " arr = np.stack(arr_list).astype(np.int16, copy=False)\n", + " else:\n", + " arr = np.stack(arr_list).astype(np.float32, copy=False)\n", + " ds = datasets[k]\n", + " old_n = ds.shape[0]\n", + " new_n = old_n + len(arr)\n", + " if arr.ndim == 1:\n", + " ds.resize((new_n,))\n", + " ds[old_n:new_n] = arr\n", + " else:\n", + " ds.resize((new_n, arr.shape[1]))\n", + " ds[old_n:new_n, :] = arr\n", + " store.clear()\n", + "\n", + " batch_store: Dict[str, List] = defaultdict(list)\n", + " total_written = 0\n", + " n_batches = math.ceil(len(stay_ids) / stay_batch_size)\n", + " t0 = time.time()\n", + "\n", + " for b in range(n_batches):\n", + " lo = b * stay_batch_size\n", + " hi = min((b + 1) * stay_batch_size, len(stay_ids))\n", + " batch_ids = stay_ids[lo:hi]\n", + "\n", + " split_labels = (\n", + " pl.scan_parquet(labels_path)\n", + " .filter(\n", + " (pl.col(\"split\") == split)\n", + " & (pl.col(\"stay_id\").is_in(batch_ids))\n", + " )\n", + " .select(\n", + " [\"subject_id\", \"hadm_id\", \"stay_id\", \"intime\", *label_keys]\n", + " )\n", + " )\n", + " events = pl.concat([\n", + " pl.scan_parquet(p).filter(pl.col(\"stay_id\").is_in(batch_ids))\n", + " .select([\n", + " \"subject_id\", \"hadm_id\", \"stay_id\", \"charttime\",\n", + " \"feature\", \"value_num\",\n", + " ])\n", + " for p in [chart_path, drug_path, demo_path]\n", + " ])\n", + " events_df = (\n", + " events\n", + " .join(\n", + " split_labels,\n", + " on=[\"subject_id\", \"hadm_id\", \"stay_id\"],\n", + " how=\"inner\",\n", + " )\n", + " .join(feat_df.lazy(), on=\"feature\", how=\"inner\")\n", + " .with_columns([pl.col(\"feature_id\").cast(pl.Int16)])\n", + " .sort([\"stay_id\", \"charttime\"])\n", + " .collect(engine=\"streaming\")\n", + " )\n", + "\n", + " for _stay_key, sdf in events_df.group_by(\n", + " \"stay_id\", maintain_order=True,\n", + " ):\n", + " sdf = sdf.sort(\"charttime\")\n", + " intime = sdf[\"intime\"][0]\n", + " times = (\n", + " (sdf[\"charttime\"] - intime).dt.total_seconds().to_numpy()\n", + " / 3600.0\n", + " ).astype(np.float32)\n", + " values = sdf[\"value_num\"].to_numpy().astype(np.float32)\n", + " features = sdf[\"feature_id\"].to_numpy().astype(np.int16)\n", + " labels = {k: int(sdf[k][0]) for k in label_keys}\n", + "\n", + " n_events = len(times)\n", + " sample_idxs = choose_sample_indices(\n", + " n_events, first_k=first_k, every_k=every_k,\n", + " max_samples=max_samples_per_stay,\n", + " )\n", + " if not sample_idxs:\n", + " continue\n", + "\n", + " dt_full = precompute_dt(times)\n", + " dv_full = precompute_dv(values, features)\n", + " sample_arr = np.asarray(sample_idxs, dtype=np.int64)\n", + " next_arr, term_arr = get_next_state_idxs_batch(times, sample_arr)\n", + "\n", + " for k, idx in enumerate(sample_idxs):\n", + " cur = build_window_arrays(\n", + " times, values, features, dt_full, dv_full,\n", + " idx, context_len,\n", + " )\n", + " if term_arr[k]:\n", + " nxt = cur\n", + " is_terminal = 1\n", + " else:\n", + " nxt = build_window_arrays(\n", + " times, values, features, dt_full, dv_full,\n", + " int(next_arr[k]), context_len,\n", + " )\n", + " is_terminal = 0\n", + "\n", + " batch_store[\"timepoints\"].append(cur[0])\n", + " batch_store[\"values\"].append(cur[1])\n", + " batch_store[\"features\"].append(cur[2])\n", + " batch_store[\"deltatime\"].append(cur[3])\n", + " batch_store[\"deltavalue\"].append(cur[4])\n", + " batch_store[\"nexttimepoints\"].append(nxt[0])\n", + " batch_store[\"nextvalues\"].append(nxt[1])\n", + " batch_store[\"nextfeatures\"].append(nxt[2])\n", + " batch_store[\"nextdeltatime\"].append(nxt[3])\n", + " batch_store[\"nextdeltavalue\"].append(nxt[4])\n", + " batch_store[\"isterminal\"].append(is_terminal)\n", + " for lk in label_keys:\n", + " batch_store[lk].append(labels[lk])\n", + "\n", + " if len(batch_store[\"isterminal\"]) >= flush_every:\n", + " _flush(batch_store)\n", + " total_written = datasets[\"isterminal\"].shape[0]\n", + "\n", + " print(\n", + " f\" [{split}] batch {b + 1}/{n_batches} | rows={total_written} \"\n", + " f\"| {(time.time() - t0) / 60:.1f} min\"\n", + " )\n", + "\n", + " _flush(batch_store)\n", + " h5f.close()\n", + " print(f\" [{split}] wrote {out_path}\")\n", + " return out_path\n", + "\n" + ], + "metadata": { + "id": "6dXktur4Tn_M" + }, + "execution_count": 33, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## 6) HDF5 -> DataLoader" + ], + "metadata": { + "id": "v6SFOWh8UVm9" + } + }, + { + "cell_type": "code", + "source": [ + "\n", + "class H5SequenceDataset(Dataset):\n", + " \"\"\"Dataset that serves state transitions from a pre-built HDF5 file.\n", + "\n", + " Opens the file lazily inside each worker process. Expects the HDF5\n", + " layout produced by ``build_hdf5_from_parquets``.\n", + "\n", + " Args:\n", + " h5_path: Path to the HDF5 file.\n", + " label_keys: Binary label columns stored in the file.\n", + " label_key: The single label column to expose as the primary target.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " h5_path: Path,\n", + " label_keys: List[str],\n", + " label_key: str = \"28-day-died\",\n", + " ) -> None:\n", + " import h5py\n", + "\n", + " self.h5_path = str(h5_path)\n", + " self.label_keys = label_keys\n", + " self.label_key = label_key\n", + " self._h5 = None\n", + " with h5py.File(self.h5_path, \"r\") as f:\n", + " self._length = f[\"features\"].shape[0]\n", + "\n", + " def _open(self) -> None:\n", + " \"\"\"Open the HDF5 file handle (once per worker process).\"\"\"\n", + " if self._h5 is None:\n", + " import h5py\n", + "\n", + " self._h5 = h5py.File(self.h5_path, \"r\")\n", + "\n", + " def __len__(self) -> int:\n", + " return self._length\n", + "\n", + " def __getitem__(\n", + " self, idx: int,\n", + " ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, Any]]:\n", + " \"\"\"Return one (batch dict, targets dict, meta dict) tuple.\n", + "\n", + " Args:\n", + " idx: Sample index.\n", + "\n", + " Returns:\n", + " Tuple of dicts that ``collate_td`` stacks into a batch.\n", + " \"\"\"\n", + " self._open()\n", + " f = self._h5\n", + " x = {\n", + " \"timepoints\": torch.from_numpy(f[\"timepoints\"][idx]).float(),\n", + " \"values\": torch.from_numpy(f[\"values\"][idx]).float(),\n", + " \"features\": torch.from_numpy(f[\"features\"][idx]).long(),\n", + " \"delta_time\": torch.from_numpy(f[\"deltatime\"][idx]).float(),\n", + " \"delta_value\": torch.from_numpy(f[\"deltavalue\"][idx]).float(),\n", + " \"next_timepoints\": torch.from_numpy(\n", + " f[\"nexttimepoints\"][idx]\n", + " ).float(),\n", + " \"next_values\": torch.from_numpy(f[\"nextvalues\"][idx]).float(),\n", + " \"next_features\": torch.from_numpy(f[\"nextfeatures\"][idx]).long(),\n", + " \"next_delta_time\": torch.from_numpy(\n", + " f[\"nextdeltatime\"][idx]\n", + " ).float(),\n", + " \"next_delta_value\": torch.from_numpy(\n", + " f[\"nextdeltavalue\"][idx]\n", + " ).float(),\n", + " \"isterminal\": torch.tensor(\n", + " [f[\"isterminal\"][idx]], dtype=torch.float32\n", + " ),\n", + " }\n", + " y = {\n", + " k: torch.tensor([f[k][idx]], dtype=torch.float32)\n", + " for k in self.label_keys\n", + " }\n", + " meta = {\"sample_id\": int(idx)}\n", + " return x, y, meta\n", + "\n", + "\n", + "def collate_td(batch):\n", + " \"\"\"Stack a list of samples into a batched dict.\n", + "\n", + " Args:\n", + " batch: List of ``(x_dict, y_dict, meta_dict)`` tuples.\n", + "\n", + " Returns:\n", + " Triple of stacked dicts plus the list of meta dicts.\n", + " \"\"\"\n", + " xs, ys, metas = zip(*batch)\n", + " out_x = {k: torch.stack([x[k] for x in xs], dim=0) for k in xs[0]}\n", + " out_y = {k: torch.stack([y[k] for y in ys], dim=0) for k in ys[0]}\n", + " return out_x, out_y, metas\n" + ], + "metadata": { + "id": "bgRH8qsYT7b-" + }, + "execution_count": 34, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Model Architecture: TD-ICU Mortality\n", + "### **PyHealth Implementation & Uncertainty Extension**\n", + "\n", + "---\n", + "\n", + "This module implements the methodology from:\n", + "> **Frost et al.**, *\"Robust Real-Time Mortality Prediction in the Intensive Care Unit using Temporal Difference Learning\"* ([arXiv:2411.04285](https://arxiv.org/abs/2411.04285)).\n", + "\n", + "Framing mortality prediction as a **value-estimation problem**. Instead of simple classification, the model predicts the probability of death within a fixed horizon (e.g., 28 days) using **Temporal-Difference (TD) targets** derived from a lagged \"target network.\"\n", + "\n", + "---\n", + "\n", + "### Core Components\n", + "\n", + "The implementation provides two specialized classes:\n", + "\n", + "#### 1. `CNNLSTMPredictor`\n", + "* **Architecture:** Hybrid CNN + LSTM backbone.\n", + "* **Function:** Maps complex event streams to a single mortality probability.\n", + "* **Usage:** Ideal for standard supervised training baselines.\n", + "\n", + "#### 2. `TDICUMortalityModel`\n", + "* **Framework:** A `PyHealth.BaseModel` wrapper.\n", + "* **Logic:** Manages dual networks (**Online** + **Target**) and implements the TD training rule.\n", + "* **Inference Plus:** Features `predict_with_confidence`, which uses **Monte Carlo (MC) Dropout** to attach a per-patient uncertainty estimate to every prediction.\n", + "\n", + "---\n", + "\n", + "The cell below is the code from pyhealth/models\n" + ], + "metadata": { + "id": "fPSEovJYElWd" + } + }, + { + "cell_type": "code", + "source": [ + "class MaxPool1D(nn.Module):\n", + " \"\"\"NaN-aware 1D max pool for irregular event streams.\n", + "\n", + " Windows that are entirely NaN remain NaN in the output. Windows that\n", + " contain at least one real value output the max of those real values.\n", + " Padded positions can therefore be tracked across pooling layers.\n", + "\n", + " Args:\n", + " kernel_size: Pooling window size.\n", + " stride: Pooling stride.\n", + " \"\"\"\n", + "\n", + " def __init__(self, kernel_size: int = 2, stride: int = 2) -> None:\n", + " super().__init__()\n", + " self.kernel_size = kernel_size\n", + " self.stride = stride\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"Run NaN-aware max pooling.\n", + "\n", + " Args:\n", + " x: Input tensor of shape ``[batch, channels, seq_len]``.\n", + "\n", + " Returns:\n", + " Pooled tensor with NaN preserved in all-NaN windows.\n", + " \"\"\"\n", + " neg_inf = torch.tensor(-np.inf, dtype=x.dtype, device=x.device)\n", + " pos_nan = torch.tensor(np.nan, dtype=x.dtype, device=x.device)\n", + " out = torch.where(torch.isnan(x), neg_inf, x)\n", + " out = F.max_pool1d(out, kernel_size=self.kernel_size, stride=self.stride)\n", + " return torch.where(torch.isinf(out), pos_nan, out)\n", + "\n", + "\n", + "class Transpose(nn.Module):\n", + " \"\"\"Swap two tensor dimensions inside an ``nn.Sequential`` pipeline.\n", + "\n", + " Args:\n", + " dim1: First dimension to swap.\n", + " dim2: Second dimension to swap.\n", + " \"\"\"\n", + "\n", + " def __init__(self, dim1: int, dim2: int) -> None:\n", + " super().__init__()\n", + " self.dim1 = dim1\n", + " self.dim2 = dim2\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"Transpose the configured dimensions.\n", + "\n", + " Args:\n", + " x: Input tensor.\n", + "\n", + " Returns:\n", + " Tensor with ``dim1`` and ``dim2`` swapped.\n", + " \"\"\"\n", + " return x.transpose(self.dim1, self.dim2)\n", + "\n", + "\n", + "# -----------------------------------------------------------------------------\n", + "# CNN + LSTM backbone\n", + "# -----------------------------------------------------------------------------\n", + "\n", + "\n", + "class CNNLSTMPredictor(nn.Module):\n", + " \"\"\"CNN + LSTM encoder producing mortality predictions.\n", + "\n", + " The predictor embeds each component of an irregular event stream\n", + " (timepoint, value, feature id, delta-time, delta-value), fuses them via\n", + " summation + batchnorm, applies a small CNN stack for sequence-length\n", + " reduction, a 2-layer LSTM for temporal modelling, and a dense head for\n", + " binary classification.\n", + "\n", + " Args:\n", + " n_features: Size of the feature vocabulary.\n", + " features: Ordered list of feature names used to assemble the\n", + " scaling buffers from ``scaling``.\n", + " output_dim: Output dimensionality (always 1 for binary mortality).\n", + " scaling: Dictionary of per-feature statistics with the structure\n", + " ``{\"mean\": {...}, \"std\": {...}}`` where each inner dict contains\n", + " per-feature 1-element tensors for ``values``, ``delta_time``,\n", + " and ``delta_value``, plus scalar tensors for ``timepoints``.\n", + " cnn_layers: Number of CNN blocks (each applies a conv + ReLU + pool).\n", + " hidden_dim: Channel/embedding dimension of the encoder.\n", + " dropout: Dropout applied between LSTM layers.\n", + " batch_first: Whether the LSTM consumes batch-first tensors.\n", + " dtype: Floating point dtype for model parameters.\n", + " device: Device string, used to choose between packed vs masked LSTM.\n", + "\n", + " Attributes:\n", + " embedding_net: ``nn.ModuleDict`` of five per-component embeddings.\n", + " cnn: CNN stack that reduces sequence length.\n", + " lstm: Two-layer LSTM with hidden size ``hidden_dim * 8``.\n", + " dense: Binary prediction head.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " n_features: int,\n", + " features: List[str],\n", + " output_dim: int,\n", + " scaling: Mapping[str, Any],\n", + " cnn_layers: int = 2,\n", + " hidden_dim: int = 32,\n", + " dropout: float = 0.5,\n", + " batch_first: bool = True,\n", + " dtype: torch.dtype = torch.float32,\n", + " device: str = \"cpu\",\n", + " ) -> None:\n", + " super().__init__()\n", + " self.dtype = dtype\n", + " self.device_name = device\n", + " self.n_features = n_features\n", + " self.features = features\n", + " self.output_dim = output_dim\n", + " self.cnn_layers = cnn_layers\n", + " self.hidden_dim = hidden_dim\n", + "\n", + " self._register_scaling_buffers(scaling)\n", + " self._build_embeddings()\n", + " self._build_cnn()\n", + " self._build_lstm(dropout=dropout, batch_first=batch_first)\n", + " self._build_head()\n", + "\n", + " self.to(device)\n", + " self.init_weights()\n", + "\n", + " # -- construction helpers -------------------------------------------------\n", + "\n", + " def _register_scaling_buffers(\n", + " self,\n", + " scaling: Mapping[str, Any],\n", + " ) -> None:\n", + " \"\"\"Register per-feature mean/std tensors as buffers.\"\"\"\n", + " for name in [\"values\", \"delta_time\", \"delta_value\"]:\n", + " mean_t = torch.cat([scaling[\"mean\"][name][f] for f in self.features])\n", + " std_t = torch.cat([scaling[\"std\"][name][f] for f in self.features])\n", + " self.register_buffer(f\"mean_{name}\", mean_t)\n", + " self.register_buffer(f\"std_{name}\", std_t)\n", + " self.register_buffer(\"mean_timepoints\", scaling[\"mean\"][\"timepoints\"])\n", + " self.register_buffer(\"std_timepoints\", scaling[\"std\"][\"timepoints\"])\n", + "\n", + " def _build_embeddings(self) -> None:\n", + " \"\"\"Build per-component embedding networks.\"\"\"\n", + " interim = int(np.sqrt(self.hidden_dim))\n", + " self.embedding_net = nn.ModuleDict()\n", + " for name in [\"time\", \"value\", \"feature\", \"delta_time\", \"delta_value\"]:\n", + " if name == \"feature\":\n", + " self.embedding_net[name] = nn.Embedding(\n", + " self.n_features, self.hidden_dim, dtype=self.dtype\n", + " )\n", + " else:\n", + " self.embedding_net[name] = nn.Sequential(\n", + " nn.Linear(1, interim, dtype=self.dtype),\n", + " nn.ReLU(),\n", + " nn.Linear(interim, self.hidden_dim, dtype=self.dtype),\n", + " )\n", + " self.embedding_norm = nn.Sequential(\n", + " Transpose(1, 2),\n", + " nn.BatchNorm1d(self.hidden_dim),\n", + " Transpose(1, 2),\n", + " )\n", + "\n", + " def _build_cnn(self) -> None:\n", + " \"\"\"Build the CNN stack used after embedding fusion.\"\"\"\n", + " layers: List[nn.Module] = [Transpose(1, 2)]\n", + " for _ in range(self.cnn_layers):\n", + " layers += [\n", + " nn.Conv1d(\n", + " self.hidden_dim,\n", + " self.hidden_dim,\n", + " kernel_size=2,\n", + " stride=1,\n", + " padding=1,\n", + " dtype=self.dtype,\n", + " ),\n", + " nn.ReLU(),\n", + " MaxPool1D(kernel_size=2, stride=2),\n", + " ]\n", + " layers.append(Transpose(1, 2))\n", + " self.cnn = nn.Sequential(*layers)\n", + "\n", + " def _build_lstm(self, dropout: float, batch_first: bool) -> None:\n", + " \"\"\"Build the stacked LSTM.\"\"\"\n", + " self.lstm = nn.LSTM(\n", + " input_size=self.hidden_dim,\n", + " hidden_size=self.hidden_dim * 8,\n", + " num_layers=2,\n", + " dropout=dropout,\n", + " batch_first=batch_first,\n", + " dtype=self.dtype,\n", + " )\n", + "\n", + " def _build_head(self) -> None:\n", + " \"\"\"Build the dense classification head.\"\"\"\n", + " self.dense = nn.Sequential(\n", + " nn.BatchNorm1d(self.hidden_dim * 8, dtype=self.dtype),\n", + " nn.Linear(self.hidden_dim * 8, self.hidden_dim, dtype=self.dtype),\n", + " nn.ReLU(),\n", + " nn.BatchNorm1d(self.hidden_dim, dtype=self.dtype),\n", + " nn.Linear(self.hidden_dim, self.output_dim, dtype=self.dtype),\n", + " )\n", + "\n", + " def init_weights(self) -> None:\n", + " \"\"\"Xavier-normal initialization for matrix-shaped parameters.\"\"\"\n", + " for p in self.parameters():\n", + " if p.dim() > 1:\n", + " nn.init.xavier_normal_(p)\n", + "\n", + " # -- inference helpers ----------------------------------------------------\n", + "\n", + " @torch.no_grad()\n", + " def soft_update(\n", + " self, new_model: \"CNNLSTMPredictor\", alpha: float = 0.99\n", + " ) -> None:\n", + " \"\"\"Exponential-moving-average update from another predictor.\n", + "\n", + " Implements ``theta_self = alpha * theta_self + (1 - alpha) * theta_new``\n", + " in-place, which is the target-network update rule from the paper\n", + " (Appendix D, Eq. 7).\n", + "\n", + " Args:\n", + " new_model: Source predictor whose weights to blend in.\n", + " alpha: EMA coefficient in ``[0, 1]``. ``alpha=1`` leaves ``self``\n", + " unchanged; ``alpha=0`` fully copies ``new_model`` into ``self``.\n", + " \"\"\"\n", + " src = new_model.state_dict()\n", + " tgt = self.state_dict()\n", + " for k in tgt:\n", + " tgt[k].copy_(alpha * tgt[k] + (1.0 - alpha) * src[k])\n", + "\n", + " def pack_sequences(\n", + " self, src: torch.Tensor, mask: torch.Tensor\n", + " ) -> torch.nn.utils.rnn.PackedSequence:\n", + " \"\"\"Pack a padded batch for efficient LSTM consumption.\n", + "\n", + " Args:\n", + " src: Embedded sequence tensor ``[batch, seq_len, hidden]``.\n", + " mask: Boolean padding mask ``[batch, seq_len]`` where ``True``\n", + " marks padded positions.\n", + "\n", + " Returns:\n", + " A ``PackedSequence`` usable by ``self.lstm``.\n", + " \"\"\"\n", + " lengths = (~mask).sum(-1)\n", + " if lengths.min() == 0:\n", + " mask = mask.clone()\n", + " mask[torch.where(lengths == 0)[0], 0] = False\n", + " lengths = torch.where(lengths == 0, torch.ones_like(lengths), lengths)\n", + " return pack_padded_sequence(\n", + " src, lengths.cpu(), batch_first=True, enforce_sorted=False\n", + " )\n", + "\n", + " def get_mask_after_conv(self, mask: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"Propagate a padding mask through the CNN pooling stack.\n", + "\n", + " Args:\n", + " mask: Boolean mask ``[batch, seq_len]`` aligned with the input.\n", + "\n", + " Returns:\n", + " Boolean mask aligned with the post-CNN sequence length.\n", + " \"\"\"\n", + " pooled = mask.unsqueeze(1).float()\n", + " for _ in range(self.cnn_layers):\n", + " pooled = F.avg_pool1d(pooled, kernel_size=2, stride=2)\n", + " return pooled.squeeze(1) == 1\n", + "\n", + " def standardise_inputs(\n", + " self,\n", + " timepoints: torch.Tensor,\n", + " values: torch.Tensor,\n", + " features: torch.Tensor,\n", + " delta_time: torch.Tensor,\n", + " delta_value: torch.Tensor,\n", + " ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n", + " \"\"\"Apply per-feature z-score standardization.\n", + "\n", + " For each of ``values``, ``delta_time`` and ``delta_value``, scatter\n", + " the event's value into a feature-indexed long tensor, apply per-\n", + " feature mean/std, then collapse back to a dense per-event tensor\n", + " using ``torch.nansum``. ``timepoints`` use a single global scaler.\n", + "\n", + " Args:\n", + " timepoints: Event timepoints ``[batch, seq_len]``.\n", + " values: Event values ``[batch, seq_len]``.\n", + " features: Feature ids ``[batch, seq_len]`` with ``-1`` for padding.\n", + " delta_time: Time since previous measurement of the same feature.\n", + " delta_value: Value change since previous measurement of the same\n", + " feature.\n", + "\n", + " Returns:\n", + " Tuple of standardized tensors in the order\n", + " ``(timepoints, values, delta_time, delta_value, features)``\n", + " where ``features`` has had its ``-1`` sentinels remapped to ``0``.\n", + " \"\"\"\n", + " safe_time_std = torch.where(\n", + " self.std_timepoints == 0,\n", + " torch.ones_like(self.std_timepoints),\n", + " self.std_timepoints,\n", + " )\n", + " standardised = [(timepoints - self.mean_timepoints) / safe_time_std]\n", + "\n", + " features = torch.where(features == -1, 0, features)\n", + " long_shape = values.shape[:2] + (self.n_features,)\n", + "\n", + " base_nan = torch.full(\n", + " long_shape,\n", + " float(\"nan\"),\n", + " dtype=self.dtype,\n", + " device=values.device,\n", + " )\n", + "\n", + " stats = {\n", + " \"values\": (self.mean_values, self.std_values),\n", + " \"delta_time\": (self.mean_delta_time, self.std_delta_time),\n", + " \"delta_value\": (self.mean_delta_value, self.std_delta_value),\n", + " }\n", + " for name, vector in [\n", + " (\"values\", values),\n", + " (\"delta_time\", delta_time),\n", + " (\"delta_value\", delta_value),\n", + " ]:\n", + " missing = torch.isnan(vector)\n", + " scattered = base_nan.clone()\n", + " scattered.scatter_(-1, features.unsqueeze(-1), vector.unsqueeze(-1))\n", + "\n", + " mean_v, std_v = stats[name]\n", + " safe_std = torch.where(std_v == 0, torch.ones_like(std_v), std_v)\n", + " scattered = (scattered - mean_v) / safe_std\n", + " collapsed = torch.nansum(scattered, -1).unsqueeze(-1)\n", + " collapsed = torch.where(\n", + " missing.unsqueeze(-1),\n", + " torch.tensor(float(\"nan\"), dtype=self.dtype, device=values.device),\n", + " collapsed,\n", + " )\n", + " standardised.append(collapsed)\n", + "\n", + " standardised.append(features)\n", + " return tuple(standardised)\n", + "\n", + " # -- forward --------------------------------------------------------------\n", + "\n", + " def forward(\n", + " self,\n", + " timepoints: torch.Tensor,\n", + " values: torch.Tensor,\n", + " features: torch.Tensor,\n", + " delta_time: torch.Tensor,\n", + " delta_value: torch.Tensor,\n", + " normalise: bool = True,\n", + " ) -> Tuple[torch.Tensor, torch.Tensor]:\n", + " \"\"\"Run the CNN-LSTM forward pass.\n", + "\n", + " Args:\n", + " timepoints: ``[batch, seq_len]`` hours since intime, NaN for pad.\n", + " values: ``[batch, seq_len]`` measurement values, NaN for pad.\n", + " features: ``[batch, seq_len]`` long tensor of feature ids, -1 for\n", + " pad.\n", + " delta_time: ``[batch, seq_len]`` hours since previous measurement\n", + " of the same feature.\n", + " delta_value: ``[batch, seq_len]`` change in value since previous\n", + " measurement of the same feature.\n", + " normalise: If ``True``, apply ``standardise_inputs`` first.\n", + "\n", + " Returns:\n", + " Tuple ``(probs, logits)`` each of shape ``[batch, output_dim]``.\n", + " ``probs = sigmoid(logits)``.\n", + " \"\"\"\n", + " if normalise:\n", + " (\n", + " timepoints,\n", + " values,\n", + " delta_time,\n", + " delta_value,\n", + " features,\n", + " ) = self.standardise_inputs(\n", + " timepoints, values, features, delta_time, delta_value\n", + " )\n", + "\n", + " timepoints = timepoints.squeeze(-1) if timepoints.ndim == 3 else timepoints\n", + " values = values.squeeze(-1) if values.ndim == 3 else values\n", + " delta_time = delta_time.squeeze(-1) if delta_time.ndim == 3 else delta_time\n", + " delta_value = (\n", + " delta_value.squeeze(-1) if delta_value.ndim == 3 else delta_value\n", + " )\n", + "\n", + " # Sort descending by timepoint; NaNs (padding) go to the end.\n", + " argsort_idx = torch.argsort(\n", + " torch.where(torch.isnan(timepoints), -torch.inf, timepoints),\n", + " dim=1,\n", + " descending=True,\n", + " )\n", + " timepoints = torch.gather(timepoints, 1, argsort_idx)\n", + " values = torch.gather(values, 1, argsort_idx)\n", + " features = torch.gather(features, 1, argsort_idx)\n", + " delta_time = torch.gather(delta_time, 1, argsort_idx)\n", + " delta_value = torch.gather(delta_value, 1, argsort_idx)\n", + "\n", + " src_mask = torch.isnan(timepoints)\n", + " timepoints = torch.where(src_mask, 0, timepoints)\n", + " values = torch.where(src_mask, 0, values)\n", + " features = torch.where(src_mask, 0, features)\n", + "\n", + " dt_mask = torch.isnan(delta_time)\n", + " dv_mask = torch.isnan(delta_value)\n", + " delta_time = torch.where(dt_mask, 0, delta_time)\n", + " delta_value = torch.where(dv_mask, 0, delta_value)\n", + "\n", + " time_emb = self.embedding_net[\"time\"](timepoints.unsqueeze(-1))\n", + " value_emb = self.embedding_net[\"value\"](values.unsqueeze(-1))\n", + " feature_emb = self.embedding_net[\"feature\"](features)\n", + " dt_emb = self.embedding_net[\"delta_time\"](delta_time.unsqueeze(-1))\n", + " dv_emb = self.embedding_net[\"delta_value\"](delta_value.unsqueeze(-1))\n", + "\n", + " dt_emb = torch.where(dt_mask.unsqueeze(-1), 0, dt_emb)\n", + " dv_emb = torch.where(dv_mask.unsqueeze(-1), 0, dv_emb)\n", + "\n", + " embedded = time_emb + value_emb + feature_emb + dt_emb + dv_emb\n", + " embedded = self.embedding_norm(embedded)\n", + " embedded = self.cnn(embedded)\n", + "\n", + " src_mask = self.get_mask_after_conv(src_mask)\n", + "\n", + " if self.device_name == \"cuda\":\n", + " packed = self.pack_sequences(embedded, src_mask)\n", + " embedded = self.lstm(packed)[1][0][-1]\n", + " else:\n", + " embedded = embedded.clone()\n", + " embedded[src_mask] = 0\n", + " embedded = self.lstm(embedded)[1][0][-1]\n", + "\n", + " logits = self.dense(embedded)\n", + " probs = torch.sigmoid(logits)\n", + " return probs, logits\n", + "\n", + "\n", + "# -----------------------------------------------------------------------------\n", + "# Loss\n", + "# -----------------------------------------------------------------------------\n", + "\n", + "\n", + "class WeightedBCELoss(nn.Module):\n", + " \"\"\"Thin wrapper around ``BCEWithLogitsLoss`` with optional pos-weighting.\n", + "\n", + " Args:\n", + " pos_weight: Optional positive-class weight tensor. When ``None``,\n", + " a standard unweighted BCE is used. Only set this for supervised\n", + " training, never for TD training: a weighted loss produces\n", + " incorrect gradients when the target is a continuous bootstrapped\n", + " value rather than a 0/1 label.\n", + " \"\"\"\n", + "\n", + " def __init__(self, pos_weight: Optional[torch.Tensor] = None) -> None:\n", + " super().__init__()\n", + " self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)\n", + "\n", + " def forward(\n", + " self, logits: torch.Tensor, targets: torch.Tensor\n", + " ) -> torch.Tensor:\n", + " \"\"\"Compute binary cross-entropy loss from logits.\n", + "\n", + " Args:\n", + " logits: Raw model output ``[batch, 1]``.\n", + " targets: Target values in ``[0, 1]`` (can be continuous for TD).\n", + "\n", + " Returns:\n", + " Scalar loss tensor.\n", + " \"\"\"\n", + " return self.loss_fn(logits, targets.float())\n", + "\n", + "\n", + "# -----------------------------------------------------------------------------\n", + "# PyHealth-facing TD model\n", + "# -----------------------------------------------------------------------------\n", + "\n", + "\n", + "class TDICUMortalityModel(BaseModel):\n", + " \"\"\"Temporal-difference ICU mortality prediction model.\n", + "\n", + " Implements the method of Frost et al. (2024), arXiv:2411.04285, for\n", + " real-time mortality prediction in the ICU. The model holds two\n", + " ``CNNLSTMPredictor`` instances:\n", + "\n", + " * ``online_net``: updated by gradient descent on a TD target.\n", + " * ``target_net``: an EMA-lagged copy of the online net whose predictions\n", + " form part of the TD target.\n", + "\n", + " At each training step, for a sample transition ``(s_t, s_{t+1})``, the\n", + " training target is:\n", + "\n", + " * ``target = label`` if ``s_{t+1}`` does not exist (terminal transition);\n", + " * ``target = gamma * target_net(s_{t+1})`` otherwise.\n", + "\n", + " The online net is regressed against this target with an unweighted binary\n", + " cross-entropy loss. The target net is updated once per optimizer step via\n", + " ``soft_update_target()``.\n", + "\n", + " The model can also be used in a purely supervised mode\n", + " (``train_td=False``) for baseline comparisons; in this mode the real\n", + " ``label_key`` is used as the target and ``pos_weight`` (if provided) is\n", + " applied.\n", + "\n", + " Example:\n", + " >>> from pyhealth.datasets import SampleEHRDataset\n", + " >>> from pyhealth.models.td_icu_mortality import TDICUMortalityModel\n", + " >>> dataset = SampleEHRDataset(samples=[], dataset_name=\"td_icu_demo\")\n", + " >>> scaling = build_scaling_dict(...) # see ``examples/`` directory\n", + " >>> feature_names = [f\"feat_{i}\" for i in range(128)]\n", + " >>> model = TDICUMortalityModel(\n", + " ... dataset=dataset,\n", + " ... feature_keys=[\"timepoints\", \"values\", \"features\",\n", + " ... \"delta_time\", \"delta_value\"],\n", + " ... label_key=\"28_day_died\",\n", + " ... n_features=128,\n", + " ... scaling=scaling,\n", + " ... features_vocab=feature_names,\n", + " ... )\n", + " >>> out = model(batch, targets=targets, train_td=True)\n", + " >>> out[\"loss\"].backward()\n", + " >>> model.soft_update_target()\n", + "\n", + " Args:\n", + " dataset: A ``SampleEHRDataset`` instance (can be empty; only used for\n", + " metadata such as ``input_schema`` / ``output_schema``).\n", + " feature_keys: The five keys in the batch dict that hold the event\n", + " tuple components, in the order\n", + " ``[timepoints, values, features, delta_time, delta_value]``.\n", + " label_key: Key in the ``targets`` dict holding the real outcome\n", + " (e.g. ``\"28_day_died\"``).\n", + " mode: PyHealth task mode. Only ``\"binary\"`` is currently supported.\n", + " n_features: Size of the feature vocabulary.\n", + " hidden_dim: Hidden size for the encoder and LSTM.\n", + " cnn_layers: Number of CNN pooling blocks.\n", + " dropout: Dropout between LSTM layers.\n", + " output_dim: Output dimensionality (must be 1 for binary mortality).\n", + " scaling: Per-feature mean/std dictionary (see ``CNNLSTMPredictor``).\n", + " features_vocab: Ordered list of feature names keyed by ``scaling``.\n", + " td_alpha: Target-network EMA coefficient. Paper uses 0.99.\n", + " gamma: Discount factor on the bootstrapped next-state value. Default\n", + " 1.0 matches the paper.\n", + " pos_weight: Optional positive-class weight for supervised mode only.\n", + " device: Device string passed to the underlying predictors.\n", + "\n", + " Raises:\n", + " ValueError: If ``mode`` is not ``\"binary\"``.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " dataset: SampleEHRDataset,\n", + " feature_keys: List[str],\n", + " label_key: str,\n", + " mode: str = \"binary\",\n", + " n_features: int = 128,\n", + " hidden_dim: int = 32,\n", + " cnn_layers: int = 2,\n", + " dropout: float = 0.5,\n", + " output_dim: int = 1,\n", + " scaling: Optional[Mapping[str, Any]] = None,\n", + " features_vocab: Optional[List[str]] = None,\n", + " td_alpha: float = 0.99,\n", + " gamma: float = 1.0,\n", + " pos_weight: Optional[torch.Tensor] = None,\n", + " device: str = \"cpu\",\n", + " ) -> None:\n", + " if mode != \"binary\":\n", + " raise ValueError(\n", + " f\"TDICUMortalityModel only supports mode='binary', got {mode!r}\"\n", + " )\n", + " if scaling is None or features_vocab is None:\n", + " raise ValueError(\n", + " \"scaling and features_vocab are required - see the \"\n", + " \"example script for how to build them from training data.\"\n", + " )\n", + "\n", + " # Adapt the SampleEHRDataset for BaseModel's expectations.\n", + " dataset.feature_keys = feature_keys\n", + " dataset.label_key = label_key\n", + " dataset.mode = mode\n", + " if not hasattr(dataset, \"input_schema\"):\n", + " dataset.input_schema = {k: \"sequence\" for k in feature_keys}\n", + " if not hasattr(dataset, \"output_schema\"):\n", + " dataset.output_schema = {label_key: mode}\n", + "\n", + " super().__init__(dataset)\n", + "\n", + " self.feature_keys = feature_keys\n", + " self.label_key = label_key\n", + " self.mode = mode\n", + " self.device_name = device\n", + " self.td_alpha = td_alpha\n", + " self.gamma = gamma\n", + "\n", + " self.online_net = CNNLSTMPredictor(\n", + " n_features=n_features,\n", + " features=features_vocab,\n", + " output_dim=output_dim,\n", + " scaling=scaling,\n", + " cnn_layers=cnn_layers,\n", + " hidden_dim=hidden_dim,\n", + " dropout=dropout,\n", + " device=device,\n", + " )\n", + " self.target_net = CNNLSTMPredictor(\n", + " n_features=n_features,\n", + " features=features_vocab,\n", + " output_dim=output_dim,\n", + " scaling=scaling,\n", + " cnn_layers=cnn_layers,\n", + " hidden_dim=hidden_dim,\n", + " dropout=dropout,\n", + " device=device,\n", + " )\n", + " self.target_net.load_state_dict(deepcopy(self.online_net.state_dict()))\n", + "\n", + " # pos_weight only ever applied in supervised mode. A continuous TD\n", + " # target combined with pos_weight would silently yield wrong grads.\n", + " self.supervised_loss = WeightedBCELoss(pos_weight=pos_weight)\n", + " self.td_loss = WeightedBCELoss(pos_weight=None)\n", + "\n", + " self.to(device)\n", + "\n", + " # -- PyHealth BaseModel abstract methods ---------------------------------\n", + "\n", + " def prepare_labels(self, labels: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"Format labels into the BCE-compatible shape.\n", + "\n", + " Args:\n", + " labels: Tensor of shape ``[batch]`` or ``[batch, 1]``.\n", + "\n", + " Returns:\n", + " Float tensor of shape ``[batch, 1]``.\n", + " \"\"\"\n", + " return labels.float().view(-1, 1)\n", + "\n", + " def get_loss_function(self) -> nn.Module:\n", + " \"\"\"Return the default (supervised) loss function.\n", + "\n", + " The TD path uses ``self.td_loss`` internally. ``get_loss_function``\n", + " exists to satisfy PyHealth's ``BaseModel`` contract and returns the\n", + " supervised (optionally pos-weighted) loss used when ``train_td=False``.\n", + "\n", + " Returns:\n", + " The supervised ``WeightedBCELoss`` instance.\n", + " \"\"\"\n", + " return self.supervised_loss\n", + "\n", + " # -- TD-specific helpers --------------------------------------------------\n", + "\n", + " def predict_current(\n", + " self, batch: Mapping[str, torch.Tensor]\n", + " ) -> Tuple[torch.Tensor, torch.Tensor]:\n", + " \"\"\"Run the online net on the current-state window.\n", + "\n", + " Args:\n", + " batch: Mapping with the five feature keys listed in\n", + " ``self.feature_keys``.\n", + "\n", + " Returns:\n", + " Tuple ``(probs, logits)``.\n", + " \"\"\"\n", + " return self.online_net(\n", + " batch[\"timepoints\"],\n", + " batch[\"values\"],\n", + " batch[\"features\"],\n", + " batch[\"delta_time\"],\n", + " batch[\"delta_value\"],\n", + " )\n", + "\n", + " @torch.no_grad()\n", + " def predict_next_target(\n", + " self, batch: Mapping[str, torch.Tensor]\n", + " ) -> torch.Tensor:\n", + " \"\"\"Run the target net on the next-state window.\n", + "\n", + " The target net is placed in eval mode so BatchNorm stats are frozen\n", + " and LSTM dropout is inactive. The output carries no gradient.\n", + "\n", + " Args:\n", + " batch: Mapping with the five ``next_*`` keys\n", + " (``next_timepoints``, ``next_values``, ``next_features``,\n", + " ``next_delta_time``, ``next_delta_value``).\n", + "\n", + " Returns:\n", + " Detached probabilities of shape ``[batch, output_dim]``.\n", + " \"\"\"\n", + " self.target_net.eval()\n", + " probs, _ = self.target_net(\n", + " batch[\"next_timepoints\"],\n", + " batch[\"next_values\"],\n", + " batch[\"next_features\"],\n", + " batch[\"next_delta_time\"],\n", + " batch[\"next_delta_value\"],\n", + " )\n", + " return probs.detach()\n", + "\n", + " def compute_td_target(\n", + " self,\n", + " batch: Mapping[str, torch.Tensor],\n", + " targets: Mapping[str, torch.Tensor],\n", + " ) -> torch.Tensor:\n", + " \"\"\"Compute the TD target for each transition in the batch.\n", + "\n", + " At terminal transitions (``isterminal > 0.5``) the target is the real\n", + " mortality label; at non-terminal transitions it is\n", + " ``gamma * target_net(next_state)``.\n", + "\n", + " Args:\n", + " batch: Mapping containing at minimum ``isterminal`` and the five\n", + " ``next_*`` keys.\n", + " targets: Mapping containing ``self.label_key``.\n", + "\n", + " Returns:\n", + " Detached target tensor of shape ``[batch, output_dim]``.\n", + " \"\"\"\n", + " next_probs = self.predict_next_target(batch)\n", + " real_reward = targets[self.label_key].float().view_as(next_probs)\n", + " is_terminal = (batch[\"isterminal\"] > 0.5).view_as(next_probs)\n", + " td_target = torch.where(is_terminal, real_reward, self.gamma * next_probs)\n", + " return td_target.detach()\n", + "\n", + " def soft_update_target(self) -> None:\n", + " \"\"\"Update the target net with an EMA step from the online net.\n", + "\n", + " Uses ``self.td_alpha``. Call exactly once after each optimizer step.\n", + " \"\"\"\n", + " self.target_net.soft_update(self.online_net, alpha=self.td_alpha)\n", + "\n", + " # -- Forward pass ---------------------------------------------------------\n", + "\n", + " def forward(\n", + " self,\n", + " batch: Mapping[str, torch.Tensor],\n", + " targets: Optional[Mapping[str, torch.Tensor]] = None,\n", + " train_td: bool = False,\n", + " ) -> Dict[str, Optional[torch.Tensor]]:\n", + " \"\"\"Run the model.\n", + "\n", + " Args:\n", + " batch: Mapping with ``self.feature_keys`` and (for TD training)\n", + " also ``isterminal`` and the five ``next_*`` keys.\n", + " targets: Optional mapping with ``self.label_key``. If omitted,\n", + " no loss is computed.\n", + " train_td: Whether to use the temporal-difference loss. When\n", + " ``True``, the TD target is computed via the target net;\n", + " when ``False``, the real label is used directly.\n", + "\n", + " Returns:\n", + " A dict with keys:\n", + "\n", + " * ``loss``: scalar loss tensor or ``None`` if ``targets is None``.\n", + " * ``y_prob``: ``[batch, output_dim]`` probabilities from the\n", + " online net.\n", + " * ``y_true``: the supplied label tensor, or ``None``.\n", + " * ``logit``: ``[batch, output_dim]`` raw model output.\n", + " \"\"\"\n", + " probs, logits = self.predict_current(batch)\n", + "\n", + " out: Dict[str, Optional[torch.Tensor]] = {\n", + " \"loss\": None,\n", + " \"y_prob\": probs,\n", + " \"y_true\": targets[self.label_key] if targets is not None else None,\n", + " \"logit\": logits,\n", + " }\n", + "\n", + " if targets is not None:\n", + " if train_td:\n", + " td_target = self.compute_td_target(batch, targets)\n", + " out[\"loss\"] = self.td_loss(logits, td_target)\n", + " else:\n", + " supervised_target = self.prepare_labels(targets[self.label_key])\n", + " out[\"loss\"] = self.supervised_loss(logits, supervised_target)\n", + "\n", + " return out\n", + "\n", + " # -- Monte Carlo dropout for uncertainty quantification ------------------\n", + "\n", + " @staticmethod\n", + " def _enable_dropout(module: nn.Module) -> None:\n", + " \"\"\"Put only ``Dropout`` and ``LSTM`` layers into train mode.\n", + "\n", + " PyTorch's ``LSTM`` applies dropout between stacked layers only when\n", + " the module is in train mode, so we flip both ``nn.Dropout*`` and\n", + " ``nn.LSTM`` modules to train while leaving BatchNorm etc. in eval.\n", + "\n", + " Args:\n", + " module: Any ``nn.Module`` whose dropout we want to activate.\n", + " \"\"\"\n", + " for m in module.modules():\n", + " if isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.LSTM)):\n", + " m.train()\n", + "\n", + " @torch.no_grad()\n", + " def predict_with_confidence(\n", + " self,\n", + " batch: Mapping[str, torch.Tensor],\n", + " n_mc_samples: int = 30,\n", + " high_conf_threshold: float = 0.005,\n", + " low_conf_threshold: float = 0.01,\n", + " ) -> Dict[str, torch.Tensor]:\n", + " \"\"\"Return mortality predictions with per-patient confidence.\n", + "\n", + " Runs the online network ``n_mc_samples`` times with dropout active\n", + " (BatchNorm frozen in eval), then reports the MC mean, standard\n", + " deviation, and a ~95% credible interval per sample. Two boolean\n", + " flags classify each prediction as high / low confidence relative\n", + " to user-tunable thresholds on the MC standard deviation.\n", + "\n", + " This is the paper's base predictor extended with Monte Carlo\n", + " dropout, providing a per-patient uncertainty estimate alongside\n", + " the point mortality probability. Empirically, higher MC standard\n", + " deviation correlates with higher true mortality rate, so the\n", + " confidence score also serves as an auxiliary clinical-triage\n", + " signal.\n", + "\n", + " Example:\n", + " >>> out = model.predict_with_confidence(batch, n_mc_samples=30)\n", + " >>> mortality = out[\"mortality_prob\"] # shape [batch]\n", + " >>> ci_lo = out[\"ci_95_lower\"] # shape [batch]\n", + " >>> ci_hi = out[\"ci_95_upper\"] # shape [batch]\n", + " >>> uncertain = out[\"is_low_confidence\"] # bool tensor\n", + "\n", + " Args:\n", + " batch: Mapping with ``self.feature_keys`` (the five event\n", + " tuple components). ``next_*`` keys are not required since\n", + " only the online network is sampled.\n", + " n_mc_samples: Number of stochastic forward passes. More passes\n", + " tighten the uncertainty estimate at linear inference cost.\n", + " 30 is a reasonable default; 50-100 for research-grade.\n", + " high_conf_threshold: Predictions with MC standard deviation\n", + " below this value are flagged \"high confidence\".\n", + " low_conf_threshold: Predictions with MC standard deviation\n", + " above this value are flagged \"low confidence\".\n", + "\n", + " Returns:\n", + " Dict with tensor entries all shaped ``[batch]``:\n", + "\n", + " * ``mortality_prob``: MC-averaged probability.\n", + " * ``confidence_std``: standard deviation across MC samples.\n", + " * ``ci_95_lower``: ``clip(mean - 2*std, 0, 1)``.\n", + " * ``ci_95_upper``: ``clip(mean + 2*std, 0, 1)``.\n", + " * ``is_high_confidence``: bool, std < ``high_conf_threshold``.\n", + " * ``is_low_confidence``: bool, std > ``low_conf_threshold``.\n", + " \"\"\"\n", + " self.online_net.eval()\n", + " self._enable_dropout(self.online_net)\n", + "\n", + " probs_mc: List[torch.Tensor] = []\n", + " for _ in range(n_mc_samples):\n", + " probs, _ = self.online_net(\n", + " batch[\"timepoints\"],\n", + " batch[\"values\"],\n", + " batch[\"features\"],\n", + " batch[\"delta_time\"],\n", + " batch[\"delta_value\"],\n", + " )\n", + " probs_mc.append(probs.float().squeeze(-1))\n", + "\n", + " stacked = torch.stack(probs_mc, dim=1) # [batch, n_mc_samples]\n", + " mean_prob = stacked.mean(dim=1)\n", + " std_prob = stacked.std(dim=1)\n", + "\n", + " ci_lower = (mean_prob - 2.0 * std_prob).clamp(0.0, 1.0)\n", + " ci_upper = (mean_prob + 2.0 * std_prob).clamp(0.0, 1.0)\n", + "\n", + " return {\n", + " \"mortality_prob\": mean_prob,\n", + " \"confidence_std\": std_prob,\n", + " \"ci_95_lower\": ci_lower,\n", + " \"ci_95_upper\": ci_upper,\n", + " \"is_high_confidence\": std_prob < high_conf_threshold,\n", + " \"is_low_confidence\": std_prob > low_conf_threshold,\n", + " }\n" + ], + "metadata": { + "id": "f_5xjMIAfaRO" + }, + "execution_count": 35, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## 7) Build the Model and Train" + ], + "metadata": { + "id": "okqvNnZsWlDT" + } + }, + { + "cell_type": "code", + "source": [ + "def build_model(\n", + " n_features: int,\n", + " feature_names: List[str],\n", + " scaling: Dict[str, Any],\n", + " hidden_dim: int = 32,\n", + " cnn_layers: int = 2,\n", + " dropout: float = 0.5,\n", + " td_alpha: float = 0.99,\n", + " device: str = \"cpu\",\n", + ") -> TDICUMortalityModel:\n", + " \"\"\"Build a fresh TDICUMortalityModel.\n", + "\n", + " Args:\n", + " n_features: Feature vocabulary size.\n", + " feature_names: Ordered feature names.\n", + " scaling: Per-feature mean/std dict.\n", + " hidden_dim: Encoder hidden size.\n", + " cnn_layers: Number of CNN blocks.\n", + " dropout: LSTM dropout.\n", + " td_alpha: Target-network EMA coefficient.\n", + " device: Device string.\n", + "\n", + " Returns:\n", + " Initialized model.\n", + " \"\"\"\n", + " dataset = SampleEHRDataset(samples=[], dataset_name=\"td_icu_demo\")\n", + " return TDICUMortalityModel(\n", + " dataset=dataset,\n", + " feature_keys=[\n", + " \"timepoints\", \"values\", \"features\",\n", + " \"delta_time\", \"delta_value\",\n", + " ],\n", + " label_key=\"28-day-died\",\n", + " mode=\"binary\",\n", + " n_features=n_features,\n", + " hidden_dim=hidden_dim,\n", + " cnn_layers=cnn_layers,\n", + " dropout=dropout,\n", + " scaling=scaling,\n", + " features_vocab=feature_names,\n", + " td_alpha=td_alpha,\n", + " device=device,\n", + " )\n", + "\n", + "\n", + "def _move(\n", + " batch: Dict[str, torch.Tensor],\n", + " targets: Dict[str, torch.Tensor],\n", + " device: str,\n", + ") -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:\n", + " \"\"\"Move batch + targets to ``device`` with non-blocking transfers.\"\"\"\n", + " batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}\n", + " targets = {\n", + " k: v.to(device, non_blocking=True) for k, v in targets.items()\n", + " }\n", + " return batch, targets\n" + ], + "metadata": { + "id": "LUjZw0YlWiJE" + }, + "execution_count": 36, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def train_one_epoch_real(\n", + " model: TDICUMortalityModel,\n", + " loader: DataLoader,\n", + " optimizer: torch.optim.Optimizer,\n", + " device: str,\n", + " max_batches: Optional[int] = None,\n", + ") -> Dict[str, float]:\n", + " \"\"\"Train one epoch on a real DataLoader.\n", + "\n", + " Args:\n", + " model: Model to train.\n", + " loader: DataLoader yielding ``(batch_dict, targets_dict, metas)``.\n", + " optimizer: Optimizer for ``model.online_net`` parameters.\n", + " device: Device string.\n", + " max_batches: Cap on number of batches (for smoke tests).\n", + "\n", + " Returns:\n", + " Metrics dict with ``loss``, ``auroc``, ``auprc``, ``pos_rate``.\n", + " \"\"\"\n", + " model.train()\n", + " losses, y_true_all, y_prob_all = [], [], []\n", + " n_batches = (\n", + " len(loader) if max_batches is None else min(len(loader), max_batches)\n", + " )\n", + " pbar = tqdm(enumerate(loader), total=n_batches, desc=\"train\")\n", + " for step, (batch, targets, _) in pbar:\n", + " if max_batches is not None and step >= max_batches:\n", + " break\n", + " batch, targets = _move(batch, targets, device)\n", + " out = model(batch, targets=targets, train_td=True)\n", + " if not torch.isfinite(out[\"loss\"]):\n", + " continue\n", + " optimizer.zero_grad(set_to_none=True)\n", + " out[\"loss\"].backward()\n", + " torch.nn.utils.clip_grad_norm_(model.online_net.parameters(), 1.0)\n", + " optimizer.step()\n", + " model.soft_update_target()\n", + " losses.append(out[\"loss\"].item())\n", + " y_true_all.append(out[\"y_true\"].detach().float().cpu().numpy().ravel())\n", + " y_prob_all.append(out[\"y_prob\"].detach().float().cpu().numpy().ravel())\n", + "\n", + " y_true = np.concatenate(y_true_all) if y_true_all else np.empty(0)\n", + " y_prob = np.concatenate(y_prob_all) if y_prob_all else np.empty(0)\n", + " return _metrics(losses, y_true, y_prob)\n", + "\n", + "\n", + "@torch.no_grad()\n", + "def eval_real(\n", + " model: TDICUMortalityModel,\n", + " loader: DataLoader,\n", + " device: str,\n", + " max_batches: Optional[int] = None,\n", + ") -> Dict[str, float]:\n", + " \"\"\"Evaluate in supervised mode on a real DataLoader.\n", + "\n", + " Args:\n", + " model: Model to evaluate.\n", + " loader: DataLoader.\n", + " device: Device string.\n", + " max_batches: Cap on batches.\n", + "\n", + " Returns:\n", + " Metrics dict.\n", + " \"\"\"\n", + " model.eval()\n", + " losses, y_true_all, y_prob_all = [], [], []\n", + " n_batches = (\n", + " len(loader) if max_batches is None else min(len(loader), max_batches)\n", + " )\n", + " pbar = tqdm(enumerate(loader), total=n_batches, desc=\"val\")\n", + " for step, (batch, targets, _) in pbar:\n", + " if max_batches is not None and step >= max_batches:\n", + " break\n", + " batch, targets = _move(batch, targets, device)\n", + " out = model(batch, targets=targets, train_td=False)\n", + " losses.append(out[\"loss\"].item())\n", + " y_true_all.append(out[\"y_true\"].float().cpu().numpy().ravel())\n", + " y_prob_all.append(out[\"y_prob\"].float().cpu().numpy().ravel())\n", + " y_true = np.concatenate(y_true_all) if y_true_all else np.empty(0)\n", + " y_prob = np.concatenate(y_prob_all) if y_prob_all else np.empty(0)\n", + " return _metrics(losses, y_true, y_prob)\n", + "\n", + "\n", + "def _metrics(\n", + " losses: List[float],\n", + " y_true: np.ndarray,\n", + " y_prob: np.ndarray,\n", + ") -> Dict[str, float]:\n", + " \"\"\"Summarize a training / eval pass.\"\"\"\n", + " out: Dict[str, float] = {\n", + " \"loss\": float(np.mean(losses)) if losses else float(\"nan\"),\n", + " \"n_samples\": int(len(y_true)),\n", + " }\n", + " if len(y_true) > 0 and len(np.unique(y_true)) > 1:\n", + " out[\"auroc\"] = float(roc_auc_score(y_true, y_prob))\n", + " out[\"auprc\"] = float(average_precision_score(y_true, y_prob))\n", + " else:\n", + " out[\"auroc\"] = float(\"nan\")\n", + " out[\"auprc\"] = float(\"nan\")\n", + " out[\"pos_rate\"] = float(y_true.mean()) if len(y_true) else float(\"nan\")\n", + " return out\n" + ], + "metadata": { + "id": "lrIy3MvCgLSi" + }, + "execution_count": 37, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def _load_feature_names_and_scaling(\n", + " data_root: Path,\n", + ") -> Tuple[List[str], Dict[str, Any]]:\n", + " \"\"\"Load feature names and scaling dict from standard filenames.\"\"\"\n", + " feature_file = data_root / \"features.txt\"\n", + " scaling_file = data_root / \"scaling.pt\"\n", + " train_h5 = data_root / \"h5train_fixed.hdf5\"\n", + "\n", + " if not feature_file.exists():\n", + " raise FileNotFoundError(\n", + " f\"Expected {feature_file}. Save the ordered feature list one \"\n", + " \"name per line.\"\n", + " )\n", + " feature_names = [\n", + " ln.strip()\n", + " for ln in feature_file.read_text().splitlines()\n", + " if ln.strip()\n", + " ]\n", + "\n", + " if scaling_file.exists():\n", + " scaling = torch.load(\n", + " scaling_file, map_location=\"cpu\", weights_only=False,\n", + " )\n", + " print(f\"loaded scaling from {scaling_file}\")\n", + " elif train_h5.exists():\n", + " print(f\"computing scaling from {train_h5}...\")\n", + " scaling = compute_scaling_from_h5(train_h5, feature_names)\n", + " torch.save(scaling, scaling_file)\n", + " print(f\"saved {scaling_file}\")\n", + " else:\n", + " raise FileNotFoundError(\n", + " f\"Neither {scaling_file} nor {train_h5} found. Run --mode build \"\n", + " \"first or provide both.\"\n", + " )\n", + " return feature_names, scaling\n", + "\n", + "def _save_checkpoint(\n", + " path: Path,\n", + " model: TDICUMortalityModel,\n", + " optimizer: torch.optim.Optimizer,\n", + " epoch: int,\n", + " best_val_auroc: float,\n", + ") -> None:\n", + " \"\"\"Save a checkpoint dict (model + optimizer + metadata).\n", + "\n", + " The format mirrors the training notebook shipped with the paper so\n", + " that ``run_evaluate`` can consume checkpoints from either source.\n", + "\n", + " Args:\n", + " path: Destination ``.pt`` file.\n", + " model: The TDICUMortalityModel to snapshot.\n", + " optimizer: Optimizer whose state is saved alongside weights.\n", + " epoch: Current epoch number (for resume / provenance).\n", + " best_val_auroc: Best validation AUROC seen so far.\n", + " \"\"\"\n", + " path.parent.mkdir(parents=True, exist_ok=True)\n", + " torch.save(\n", + " {\n", + " \"model_state_dict\": model.state_dict(),\n", + " \"optimizer_state_dict\": optimizer.state_dict(),\n", + " \"epoch\": epoch,\n", + " \"best_val_auroc\": best_val_auroc,\n", + " \"label_key\": model.label_key,\n", + " },\n", + " path,\n", + " )\n", + "\n", + "def _load_checkpoint_into(\n", + " path: Path,\n", + " model: TDICUMortalityModel,\n", + " optimizer: Optional[torch.optim.Optimizer] = None,\n", + " device: str = \"cpu\",\n", + ") -> Dict[str, Any]:\n", + " \"\"\"Load a checkpoint into ``model`` (and optionally ``optimizer``).\n", + "\n", + " Handles both wrapped dicts (``{\"model_state_dict\": ..., ...}``) and\n", + " raw ``state_dict`` files.\n", + "\n", + " Args:\n", + " path: Source ``.pt`` file.\n", + " model: Model to load weights into (mutated in place).\n", + " optimizer: Optional optimizer to restore alongside weights.\n", + " device: ``map_location`` target for ``torch.load``.\n", + "\n", + " Returns:\n", + " The full checkpoint dict if the file was wrapped, else\n", + " ``{\"model_state_dict\": state_dict}``.\n", + " \"\"\"\n", + " ckpt = torch.load(path, map_location=device, weights_only=False)\n", + " if isinstance(ckpt, dict) and \"model_state_dict\" in ckpt:\n", + " model.load_state_dict(ckpt[\"model_state_dict\"])\n", + " if optimizer is not None and \"optimizer_state_dict\" in ckpt:\n", + " optimizer.load_state_dict(ckpt[\"optimizer_state_dict\"])\n", + " return ckpt\n", + " # Raw state_dict - backward compatibility\n", + " model.load_state_dict(ckpt)\n", + " return {\"model_state_dict\": ckpt}\n", + "\n", + "def run_real_train(\n", + " data_root: Path,\n", + " build_first: bool,\n", + " batch_size: int = 64,\n", + " num_workers: int = 4,\n", + " max_train_batches: Optional[int] = None,\n", + " max_val_batches: Optional[int] = None,\n", + " epochs: int = 1,\n", + " device: str = \"cpu\",\n", + ") -> None:\n", + " \"\"\"Train on real MIMIC-IV data (optionally build HDF5 first).\n", + "\n", + " Args:\n", + " data_root: Directory containing parquets and/or pre-built HDF5s.\n", + " build_first: If ``True``, build HDF5 from parquets before training.\n", + " batch_size: DataLoader batch size.\n", + " num_workers: DataLoader workers.\n", + " max_train_batches: Cap on train batches per epoch.\n", + " max_val_batches: Cap on val batches.\n", + " epochs: Number of epochs.\n", + " device: Device string.\n", + " \"\"\"\n", + " print(f\"=== Mode: {'build' if build_first else 'load'} ===\")\n", + " data_root = Path(data_root)\n", + " ckpt_dir = Path(checkpoint_dir) if checkpoint_dir else data_root / \"checkpoints\"\n", + " ckpt_dir.mkdir(parents=True, exist_ok=True)\n", + " last_ckpt = ckpt_dir / \"last.pt\"\n", + " best_ckpt = ckpt_dir / \"best.pt\"\n", + " print(f\"checkpoint dir: {ckpt_dir}\")\n", + "\n", + " if build_first:\n", + " print(\"Building HDF5 files from parquets...\")\n", + " for split in [\"val\", \"train\", \"test\"]:\n", + " build_hdf5_from_parquets(\n", + " parquet_dir=data_root, out_dir=data_root, split=split,\n", + " )\n", + "\n", + " feature_names, scaling = _load_feature_names_and_scaling(data_root)\n", + " n_features = len(feature_names)\n", + " print(f\"n_features={n_features}\")\n", + "\n", + " label_keys = [\n", + " \"1-day-died\", \"3-day-died\", \"7-day-died\",\n", + " \"14-day-died\", \"28-day-died\",\n", + " ]\n", + "\n", + " train_ds = H5SequenceDataset(\n", + " data_root / \"h5train_fixed.hdf5\", label_keys=label_keys,\n", + " )\n", + " val_ds = H5SequenceDataset(\n", + " data_root / \"h5val_fixed.hdf5\", label_keys=label_keys,\n", + " )\n", + " test_ds = H5SequenceDataset(\n", + " data_root / \"h5test_fixed.hdf5\", label_keys=label_keys,\n", + " )\n", + " print(\n", + " f\"train samples: {len(train_ds):,}, \"\n", + " f\"val samples: {len(val_ds):,}, \"\n", + " f\"test samples: {len(test_ds):,}\"\n", + " )\n", + "\n", + " loader_kwargs: Dict[str, Any] = dict(\n", + " batch_size=batch_size,\n", + " num_workers=num_workers,\n", + " pin_memory=(device == \"cuda\"),\n", + " collate_fn=collate_td,\n", + " persistent_workers=(num_workers > 0),\n", + " )\n", + " train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)\n", + " val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)\n", + " test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)\n", + "\n", + " model = build_model(\n", + " n_features=n_features,\n", + " feature_names=feature_names,\n", + " scaling=scaling,\n", + " hidden_dim=32,\n", + " cnn_layers=2,\n", + " dropout=0.5,\n", + " td_alpha=0.99,\n", + " device=device,\n", + " )\n", + " n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + " lr = 1.0 / math.sqrt(n_params)\n", + " wd = 1 / (lr * len(train_ds))\n", + "\n", + " optimizer = torch.optim.AdamW(\n", + " model.online_net.parameters(), lr=lr, weight_decay=wd,\n", + " )\n", + " best_val_auroc = -1.0\n", + "\n", + "\n", + " for epoch in range(1, epochs + 1):\n", + " train_metrics = train_one_epoch_real(\n", + " model, train_loader, optimizer, device, max_train_batches,\n", + " )\n", + " val_metrics = eval_real(model, val_loader, device, max_val_batches)\n", + " print(f\"[epoch {epoch}] train {train_metrics} | val {val_metrics}\")\n", + "\n", + " # Always save last.pt after each epoch\n", + " _save_checkpoint(\n", + " last_ckpt, model, optimizer, epoch, best_val_auroc,\n", + " )\n", + " # Save best.pt when val AUROC improves\n", + " val_auroc = val_metrics.get(\"auroc\", float(\"nan\"))\n", + " if np.isfinite(val_auroc) and val_auroc > best_val_auroc:\n", + " best_val_auroc = val_auroc\n", + " _save_checkpoint(\n", + " best_ckpt, model, optimizer, epoch, best_val_auroc,\n", + " )\n", + " print(\n", + " f\" [best] val_auroc={best_val_auroc:.4f} \"\n", + " f\"saved to {best_ckpt}\"\n", + " )\n" + ], + "metadata": { + "id": "6LOn13hWgrWM" + }, + "execution_count": 51, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## 8) Execute the Model Training" + ], + "metadata": { + "id": "eoLnGOp0Vqbq" + } + }, + { + "cell_type": "code", + "source": [ + "data_root = out_root\n", + "run_real_train(\n", + " data_root=data_root,\n", + " build_first=True,\n", + " batch_size=128,\n", + " num_workers=6,\n", + " max_train_batches=None,\n", + " max_val_batches=None,\n", + " epochs=1,\n", + " device=device,\n", + ")" + ], + "metadata": { + "id": "VIZAkkegiRhu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## 9) Evaluation functions" + ], + "metadata": { + "id": "Tu3JDL8wVxeh" + } + }, + { + "cell_type": "code", + "source": [ + "@torch.no_grad()\n", + "def evaluate_all_horizons(\n", + " model: TDICUMortalityModel,\n", + " loader: DataLoader,\n", + " label_keys: List[str],\n", + " device: str,\n", + " max_batches: Optional[int] = None,\n", + ") -> Dict[str, Dict[str, float]]:\n", + " \"\"\"Evaluate the model across all mortality horizons on a test loader.\n", + "\n", + " The TD model is trained on a single horizon (e.g. 28-day mortality)\n", + " but its probability output is used to score every horizon. Matches\n", + " the paper's Table 2 layout.\n", + "\n", + " Args:\n", + " model: Trained ``TDICUMortalityModel``.\n", + " loader: Test ``DataLoader`` (must yield all ``label_keys`` in the\n", + " targets dict).\n", + " label_keys: List of horizon names (e.g. ``[\"1-day-died\", ...]``).\n", + " device: Device string.\n", + " max_batches: Optional cap for smoke tests.\n", + "\n", + " Returns:\n", + " Mapping from horizon name to a dict with ``auroc``, ``auprc``,\n", + " ``pos_rate``, ``n``.\n", + " \"\"\"\n", + " model.eval()\n", + " y_prob_all: List[np.ndarray] = []\n", + " y_true_by_key: Dict[str, List[np.ndarray]] = {k: [] for k in label_keys}\n", + "\n", + " n_batches = (\n", + " len(loader) if max_batches is None else min(len(loader), max_batches)\n", + " )\n", + " pbar = tqdm(enumerate(loader), total=n_batches, desc=\"test\")\n", + " for step, (batch, targets, _) in pbar:\n", + " if max_batches is not None and step >= max_batches:\n", + " break\n", + " batch, targets = _move(batch, targets, device)\n", + " probs, _ = model.online_net(\n", + " batch[\"timepoints\"], batch[\"values\"], batch[\"features\"],\n", + " batch[\"delta_time\"], batch[\"delta_value\"],\n", + " )\n", + " y_prob_all.append(probs.float().cpu().numpy().ravel())\n", + " for k in label_keys:\n", + " y_true_by_key[k].append(targets[k].float().cpu().numpy().ravel())\n", + "\n", + " y_prob = np.concatenate(y_prob_all)\n", + " results: Dict[str, Dict[str, float]] = {}\n", + " for k in label_keys:\n", + " y_true = np.concatenate(y_true_by_key[k])\n", + " if len(np.unique(y_true)) > 1:\n", + " results[k] = {\n", + " \"auroc\": float(roc_auc_score(y_true, y_prob)),\n", + " \"auprc\": float(average_precision_score(y_true, y_prob)),\n", + " }\n", + " else:\n", + " results[k] = {\n", + " \"auroc\": float(\"nan\"),\n", + " \"auprc\": float(\"nan\"),\n", + " }\n", + " results[k][\"pos_rate\"] = float(y_true.mean())\n", + " results[k][\"n\"] = int(len(y_true))\n", + " return results\n", + "\n", + "\n", + "def print_horizon_table(\n", + " results: Dict[str, Dict[str, float]],\n", + " trained_on: str = \"28-day-died\",\n", + ") -> None:\n", + " \"\"\"Pretty-print the horizon results table (paper-style).\n", + "\n", + " Args:\n", + " results: Output of ``evaluate_all_horizons``.\n", + " trained_on: Label key the model was actually trained on.\n", + " \"\"\"\n", + " bar = \"=\" * 60\n", + " print(f\"\\n{bar}\")\n", + " print(f\"TEST SET RESULTS (trained on {trained_on} via TD)\")\n", + " print(bar)\n", + " print(\n", + " f\"{'horizon':<12} {'AUROC':>7} {'AUPRC':>7} \"\n", + " f\"{'pos_rate':>8} {'n':>10}\"\n", + " )\n", + " for k, r in results.items():\n", + " print(\n", + " f\"{k:<12} {r['auroc']:>7.4f} {r['auprc']:>7.4f} \"\n", + " f\"{r['pos_rate']:>8.4f} {r['n']:>10,}\"\n", + " )\n" + ], + "metadata": { + "id": "DyJCGlXptloB" + }, + "execution_count": 39, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## 10) EXTENSION - MC DROPOUT" + ], + "metadata": { + "id": "kQjVaZSOWCuU" + } + }, + { + "cell_type": "code", + "source": [ + "@torch.no_grad()\n", + "def mc_dropout_predict_test(\n", + " model: TDICUMortalityModel,\n", + " loader: DataLoader,\n", + " device: str,\n", + " n_mc_samples: int = 30,\n", + " label_key: str = \"28-day-died\",\n", + " max_batches: Optional[int] = None,\n", + ") -> Dict[str, np.ndarray]:\n", + " \"\"\"Run MC dropout inference across a full test loader.\n", + "\n", + " For each sample, runs ``n_mc_samples`` stochastic forward passes of\n", + " the online network (with dropout active, BatchNorm frozen) and\n", + " collects the full distribution of predictions.\n", + "\n", + " Args:\n", + " model: Trained ``TDICUMortalityModel``.\n", + " loader: Test ``DataLoader``.\n", + " device: Device string.\n", + " n_mc_samples: Number of stochastic forward passes per input.\n", + " label_key: Which label to compare against.\n", + " max_batches: Optional cap for smoke tests.\n", + "\n", + " Returns:\n", + " Dict of numpy arrays, each length ``N`` (total test samples):\n", + "\n", + " * ``y_true``: ground-truth labels\n", + " * ``mean_prob``: MC-averaged probabilities\n", + " * ``std_prob``: MC standard deviation (epistemic uncertainty)\n", + " * ``ci_lower``, ``ci_upper``: 95% credible interval bounds\n", + " \"\"\"\n", + " model.online_net.eval()\n", + " # activate only dropout, keep BatchNorm frozen\n", + " for m in model.online_net.modules():\n", + " if isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.LSTM)):\n", + " m.train()\n", + "\n", + " y_true_all: List[np.ndarray] = []\n", + " mean_all: List[np.ndarray] = []\n", + " std_all: List[np.ndarray] = []\n", + "\n", + " n_batches = (\n", + " len(loader) if max_batches is None else min(len(loader), max_batches)\n", + " )\n", + " pbar = tqdm(enumerate(loader), total=n_batches, desc=\"mc-dropout\")\n", + " for step, (batch, targets, _) in pbar:\n", + " if max_batches is not None and step >= max_batches:\n", + " break\n", + " batch, targets = _move(batch, targets, device)\n", + " probs_mc = []\n", + " for _ in range(n_mc_samples):\n", + " probs, _ = model.online_net(\n", + " batch[\"timepoints\"], batch[\"values\"], batch[\"features\"],\n", + " batch[\"delta_time\"], batch[\"delta_value\"],\n", + " )\n", + " probs_mc.append(probs.float().cpu().numpy().ravel())\n", + " stacked = np.stack(probs_mc, axis=1) # [batch, n_mc]\n", + " mean_all.append(stacked.mean(axis=1))\n", + " std_all.append(stacked.std(axis=1))\n", + " y_true_all.append(targets[label_key].float().cpu().numpy().ravel())\n", + "\n", + " mean_prob = np.concatenate(mean_all)\n", + " std_prob = np.concatenate(std_all)\n", + " y_true = np.concatenate(y_true_all)\n", + " return {\n", + " \"y_true\": y_true,\n", + " \"mean_prob\": mean_prob,\n", + " \"std_prob\": std_prob,\n", + " \"ci_lower\": np.clip(mean_prob - 2.0 * std_prob, 0.0, 1.0),\n", + " \"ci_upper\": np.clip(mean_prob + 2.0 * std_prob, 0.0, 1.0),\n", + " }\n", + "\n", + "\n", + "def print_mc_dropout_analysis(mc: Dict[str, np.ndarray]) -> None:\n", + " \"\"\"Print the paper-extension's MC dropout analysis tables.\n", + "\n", + " Reports:\n", + " - Deterministic vs MC-averaged AUROC / AUPRC / Brier\n", + " - Uncertainty distribution summary\n", + " - Stratified-by-uncertainty-quintile table (the paper-extension's\n", + " key finding: higher uncertainty correlates with higher mortality)\n", + "\n", + " Args:\n", + " mc: Output of ``mc_dropout_predict_test``.\n", + " \"\"\"\n", + " y_true = mc[\"y_true\"]\n", + " mean_prob = mc[\"mean_prob\"]\n", + " std_prob = mc[\"std_prob\"]\n", + "\n", + " bar = \"=\" * 60\n", + "\n", + " # MC-averaged metrics\n", + " print(f\"\\n{bar}\")\n", + " print(\"MC-DROPOUT TEST RESULTS\")\n", + " print(bar)\n", + " if len(np.unique(y_true)) > 1:\n", + " auroc = float(roc_auc_score(y_true, mean_prob))\n", + " auprc = float(average_precision_score(y_true, mean_prob))\n", + " brier = float(np.mean((mean_prob - y_true) ** 2))\n", + " print(f\" AUROC (MC-averaged): {auroc:.4f}\")\n", + " print(f\" AUPRC (MC-averaged): {auprc:.4f}\")\n", + " print(f\" Brier (MC-averaged): {brier:.4f}\")\n", + "\n", + " # Uncertainty distribution\n", + " print(f\"\\nUncertainty distribution (std across MC samples):\")\n", + " print(f\" mean: {std_prob.mean():.4f} median: {np.median(std_prob):.4f}\")\n", + " print(\n", + " f\" p5: {np.percentile(std_prob, 5):.4f} \"\n", + " f\"p95: {np.percentile(std_prob, 95):.4f}\"\n", + " )\n", + " print(f\" max: {std_prob.max():.4f}\")\n", + "\n", + " # Stratified by uncertainty quintile\n", + " print(f\"\\n{bar}\")\n", + " print(\"STRATIFIED BY UNCERTAINTY QUINTILE\")\n", + " print(bar)\n", + " print(\n", + " f\" {'bin':<5} {'std range':<17} {'n':>6} \"\n", + " f\"{'pos_rate':>8} {'auroc':>7} {'brier':>7}\"\n", + " )\n", + " quintile_bounds = np.quantile(std_prob, np.linspace(0, 1, 6))\n", + " for i in range(5):\n", + " lo, hi = quintile_bounds[i], quintile_bounds[i + 1]\n", + " mask = (std_prob >= lo) & (std_prob <= hi) if i == 4 \\\n", + " else (std_prob >= lo) & (std_prob < hi)\n", + " if mask.sum() < 10:\n", + " continue\n", + " y_b = y_true[mask]\n", + " p_b = mean_prob[mask]\n", + " if len(np.unique(y_b)) > 1:\n", + " auroc_b = float(roc_auc_score(y_b, p_b))\n", + " else:\n", + " auroc_b = float(\"nan\")\n", + " brier_b = float(np.mean((p_b - y_b) ** 2))\n", + " print(\n", + " f\" {i + 1:<5} [{lo:.3f}, {hi:.3f}] {mask.sum():>6,} \"\n", + " f\"{y_b.mean():>8.3f} {auroc_b:>7.4f} {brier_b:>7.4f}\"\n", + " )\n", + "\n", + "\n", + "def print_clinical_triage_demo(\n", + " mc: Dict[str, np.ndarray],\n", + " high_conf_threshold: float = 0.005,\n", + " low_conf_threshold: float = 0.01,\n", + ") -> None:\n", + " \"\"\"Show the clinical triage interpretation of MC dropout confidence.\n", + "\n", + " Splits predictions into three categories and reports the actual\n", + " mortality rate in each. This is the paper-extension's key clinical\n", + " finding: the confidence score doubles as an auxiliary triage signal.\n", + "\n", + " Args:\n", + " mc: Output of ``mc_dropout_predict_test``.\n", + " high_conf_threshold: MC std below which a prediction is \"confident\".\n", + " low_conf_threshold: MC std above which a prediction is \"uncertain\".\n", + " \"\"\"\n", + " y_true = mc[\"y_true\"]\n", + " mean_prob = mc[\"mean_prob\"]\n", + " std_prob = mc[\"std_prob\"]\n", + "\n", + " is_high_conf = std_prob < high_conf_threshold\n", + " is_low_conf = std_prob > low_conf_threshold\n", + "\n", + " auto_alert = is_high_conf & (mean_prob > 0.5)\n", + " review = is_low_conf\n", + " standard = is_high_conf & (mean_prob < 0.2)\n", + "\n", + " bar = \"=\" * 60\n", + " print(f\"\\n{bar}\")\n", + " print(\"CLINICAL TRIAGE DEMO (paper-extension claim)\")\n", + " print(bar)\n", + " print(\n", + " f\" auto_alert (high conf, high risk): \"\n", + " f\"n={int(auto_alert.sum()):>6,} \"\n", + " f\"mortality={y_true[auto_alert].mean() * 100:>5.1f}%\"\n", + " if auto_alert.sum() > 0 else\n", + " f\" auto_alert (high conf, high risk): n=0\"\n", + " )\n", + " print(\n", + " f\" review (low conf): \"\n", + " f\"n={int(review.sum()):>6,} \"\n", + " f\"mortality={y_true[review].mean() * 100:>5.1f}%\"\n", + " if review.sum() > 0 else\n", + " f\" review (low conf): n=0\"\n", + " )\n", + " print(\n", + " f\" standard (high conf, low risk): \"\n", + " f\"n={int(standard.sum()):>6,} \"\n", + " f\"mortality={y_true[standard].mean() * 100:>5.1f}%\"\n", + " if standard.sum() > 0 else\n", + " f\" standard (high conf, low risk): n=0\"\n", + " )\n", + " print(\n", + " f\" overall: \"\n", + " f\"n={len(y_true):>6,} \"\n", + " f\"mortality={y_true.mean() * 100:>5.1f}%\"\n", + " )" + ], + "metadata": { + "id": "kH5fHa0atvW6" + }, + "execution_count": 40, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def run_evaluate(\n", + " data_root: Path,\n", + " checkpoint_path: Path,\n", + " batch_size: int = 128,\n", + " num_workers: int = 0,\n", + " max_test_batches: Optional[int] = None,\n", + " run_mc_dropout: bool = False,\n", + " n_mc_samples: int = 30,\n", + " device: str = \"cpu\",\n", + ") -> None:\n", + " \"\"\"Evaluate a trained checkpoint on the test set. No training.\n", + "\n", + " Loads a saved ``state_dict`` into a fresh model, builds only the test\n", + " DataLoader (no train/val loaders), runs ``evaluate_all_horizons`` to\n", + " produce the paper's per-horizon AUROC/AUPRC table, and optionally\n", + " runs the MC dropout analysis.\n", + "\n", + " Use this when you already have a trained checkpoint and just want to\n", + " see test metrics, without paying the full training cost again.\n", + "\n", + " Args:\n", + " data_root: Directory with pre-built HDF5s (``h5test_fixed.hdf5``,\n", + " ``features.txt``, and either ``scaling.pt`` or ``h5train_fixed.hdf5``).\n", + " checkpoint_path: Path to a ``.pt`` file containing either a raw\n", + " ``state_dict`` or a dict with a ``model_state_dict`` key\n", + " (matches the format produced by the training loop).\n", + " batch_size: DataLoader batch size.\n", + " num_workers: DataLoader workers.\n", + " max_test_batches: Cap on test batches (None = full test set).\n", + " run_mc_dropout: Whether to run MC dropout analysis.\n", + " n_mc_samples: Number of stochastic forward passes per sample.\n", + " device: Device string.\n", + " \"\"\"\n", + " print(\"=== Mode: evaluate ===\")\n", + " data_root = Path(data_root)\n", + " checkpoint_path = Path(checkpoint_path)\n", + "\n", + " if not checkpoint_path.exists():\n", + " raise FileNotFoundError(\n", + " f\"Checkpoint not found at {checkpoint_path}\"\n", + " )\n", + "\n", + " feature_names, scaling = _load_feature_names_and_scaling(data_root)\n", + " n_features = len(feature_names)\n", + " print(f\"n_features={n_features}\")\n", + "\n", + " label_keys = [\n", + " \"1-day-died\", \"3-day-died\", \"7-day-died\",\n", + " \"14-day-died\", \"28-day-died\",\n", + " ]\n", + "\n", + " test_h5 = data_root / \"h5test_fixed.hdf5\"\n", + " if not test_h5.exists():\n", + " raise FileNotFoundError(\n", + " f\"Test HDF5 not found at {test_h5}. \"\n", + " \"Run --mode build first to produce it.\"\n", + " )\n", + " test_ds = H5SequenceDataset(test_h5, label_keys=label_keys)\n", + " print(f\"test samples: {len(test_ds):,}\")\n", + "\n", + " loader_kwargs: Dict[str, Any] = dict(\n", + " batch_size=batch_size,\n", + " num_workers=num_workers,\n", + " pin_memory=(device == \"cuda\"),\n", + " collate_fn=collate_td,\n", + " persistent_workers=(num_workers > 0),\n", + " )\n", + " test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)\n", + "\n", + " # Build a fresh model with the same architecture as training\n", + " model = build_model(\n", + " n_features=n_features,\n", + " feature_names=feature_names,\n", + " scaling=scaling,\n", + " hidden_dim=32,\n", + " cnn_layers=2,\n", + " dropout=0.5,\n", + " td_alpha=0.99,\n", + " device=device,\n", + " )\n", + "\n", + " # Load checkpoint using the shared helper\n", + " print(f\"loading checkpoint from {checkpoint_path}\")\n", + " ckpt = _load_checkpoint_into(\n", + " checkpoint_path, model, optimizer=None, device=device,\n", + " )\n", + " print(\"checkpoint loaded\")\n", + " if \"best_val_auroc\" in ckpt:\n", + " print(f\" checkpoint best_val_auroc: {ckpt['best_val_auroc']:.4f}\")\n", + " if \"epoch\" in ckpt:\n", + " print(f\" checkpoint epoch: {ckpt['epoch']}\")\n", + "\n", + " # ----- Test-set evaluation across all horizons -----\n", + " print(\"\\nEvaluating on test set across all mortality horizons...\")\n", + " test_results = evaluate_all_horizons(\n", + " model=model,\n", + " loader=test_loader,\n", + " label_keys=label_keys,\n", + " device=device,\n", + " max_batches=max_test_batches,\n", + " )\n", + " print_horizon_table(test_results, trained_on=model.label_key)\n", + "\n", + " # ----- MC dropout analysis on test set -----\n", + " if run_mc_dropout:\n", + " print(\n", + " f\"\\nRunning MC dropout inference \"\n", + " f\"({n_mc_samples} samples/input)...\"\n", + " )\n", + " mc = mc_dropout_predict_test(\n", + " model=model,\n", + " loader=test_loader,\n", + " device=device,\n", + " n_mc_samples=n_mc_samples,\n", + " label_key=model.label_key,\n", + " max_batches=max_test_batches,\n", + " )\n", + " print_mc_dropout_analysis(mc)\n", + " print_clinical_triage_demo(mc)" + ], + "metadata": { + "id": "C3uVQ1YuqAj2" + }, + "execution_count": 41, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## 11) Run Evaluation - Option to run with/without the MC Dropout Extension" + ], + "metadata": { + "id": "0BnKMWAzWYXi" + } + }, + { + "cell_type": "code", + "source": [ + "checkpoint_path = checkpoint_dir / 'best.pt'\n", + "run_evaluate(\n", + " data_root=data_root,\n", + " checkpoint_path=checkpoint_path ,\n", + " batch_size=128,\n", + " num_workers=6,\n", + " max_test_batches=None,\n", + " run_mc_dropout=False,\n", + " n_mc_samples=30,\n", + " device=device,\n", + ")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 396, + "referenced_widgets": [ + "6a362aae36e446a79088e11a38440615", + "1db925d48ebd4597bf1d1d3970ee66a5", + "c2490a1bcd094d8481223702e5916178", + "d4ac3b75f60243d8b90421260095b46c", + "ea6993b605dd4562b17f123f9f14b57b", + "8d1f62afb4764bb199a2e183a5ce153b", + "69387d6c3c1a46bf8c11aed964eefdf7", + "3aea421084ef4be3a77a1268a28dd3ee", + "c8ca009e64a94258bcf74207f5f63884", + "338a086755364c7dadccfe7cf1c76764", + "beec26dfe5c5471a81a5c51fa37c6533" + ] + }, + "id": "gXbOcJy4qLDS", + "outputId": "cb726fe4-a4e6-4eff-8401-bf953d462cd5" + }, + "execution_count": 48, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "=== Mode: evaluate ===\n", + "loaded scaling from /content/data/mimic/scaling.pt\n", + "n_features=108\n", + "test samples: 295,085\n", + "loading checkpoint from /content/checkpoints/best.pt\n", + "checkpoint loaded\n", + " checkpoint best_val_auroc: 0.8110\n", + " checkpoint epoch: 5\n", + "\n", + "Evaluating on test set across all mortality horizons...\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "test: 0%| | 0/2306 [00:00 None:\n", + " \"\"\"Sweep target-network EMA alpha.\n", + "\n", + " Args:\n", + " mode: ``\"load\"``.\n", + " data_root: Required when ``mode=\"load\"``.\n", + " alpha_values: Alphas to try.\n", + " n_train_steps: Steps per alpha (synthetic mode).\n", + " max_train_batches: Batches per alpha (load mode).\n", + " max_val_batches: Val batches per alpha (load mode).\n", + " device: Device string.\n", + " \"\"\"\n", + " print(\"=== alpha ablation ===\")\n", + "\n", + " if mode == \"load\":\n", + " if data_root is None:\n", + " raise ValueError(\"--data-root is required when mode=load\")\n", + " feature_names, scaling = _load_feature_names_and_scaling(data_root)\n", + " n_features = len(feature_names)\n", + " label_keys = [\n", + " \"1-day-died\", \"3-day-died\", \"7-day-died\",\n", + " \"14-day-died\", \"28-day-died\",\n", + " ]\n", + " train_ds = H5SequenceDataset(\n", + " data_root / \"h5train_fixed.hdf5\", label_keys=label_keys,\n", + " )\n", + " val_ds = H5SequenceDataset(\n", + " data_root / \"h5val_fixed.hdf5\", label_keys=label_keys,\n", + " )\n", + " loader_kwargs: Dict[str, Any] = dict(\n", + " batch_size=128,\n", + " num_workers=4,\n", + " pin_memory=(device == \"cuda\"),\n", + " collate_fn=collate_td,\n", + " persistent_workers=True,\n", + " )\n", + " train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)\n", + " val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)\n", + " else:\n", + " raise ValueError(f\"Unknown ablation mode {mode!r}\")\n", + "\n", + " print(\n", + " f\"{'alpha':>6} {'loss_mean':>10} {'loss_std':>10} \"\n", + " f\"{'val_auroc':>10} {'val_loss':>10}\"\n", + " )\n", + " all_results = []\n", + " alpha_val = []\n", + " for alpha in alpha_values:\n", + " torch.manual_seed(42)\n", + " np.random.seed(42)\n", + " model = build_model(\n", + " n_features=n_features,\n", + " feature_names=feature_names,\n", + " scaling=scaling,\n", + " hidden_dim=32,\n", + " td_alpha=alpha,\n", + " device=device,\n", + " )\n", + "\n", + " optimizer = torch.optim.AdamW(\n", + " model.online_net.parameters(), lr=1e-3,\n", + " )\n", + " result = train_one_epoch_real(\n", + " model, train_loader, optimizer, device,\n", + " max_batches=max_train_batches,\n", + " )\n", + " #all_results.append(result)\n", + " val = eval_real(\n", + " model, val_loader, device, max_batches=max_val_batches,\n", + " )\n", + " all_results.append(val)\n", + " alpha_val.append(alpha)\n", + " losses = [val[\"loss\"], val[\"loss\"]] # placeholder for stats\n", + " tail = (\n", + " np.array(losses[-50:]) if len(losses) > 1 else np.array(losses)\n", + " )\n", + " print(\n", + " f\"{alpha:>6.3f} {tail.mean():>10.4f} {tail.std():>10.4f} \"\n", + " f\"{val.get('auroc', float('nan')):>10.4f} \"\n", + " f\"{val['loss']:>10.4f}\"\n", + " )\n", + " # ----------------------------- final summary ---------------------------------\n", + " i = 0\n", + " print(\"\\n\" + \"=\" * 60)\n", + " print(\"ALPHA ABLATION SUMMARY\")\n", + " print(\"=\" * 60)\n", + " print(f\"{'alpha':>6} {'val_auroc':>10} {'val_auprc':>10} \"\n", + " f\"{'loss':>10} {'pos_rate':>10} \")\n", + " for r in all_results:\n", + " print(\n", + " f\"{alpha_val[i]:>6.3f} {r['auroc']:>10.4f} {r['auprc']:>10.4f} \"\n", + " f\"{r['loss']:>10.4f} {r['pos_rate']:>10.4f} \"\n", + " )\n", + " i=i+1\n" + ], + "metadata": { + "id": "ESkV7sEFNg5E" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "run_alpha_ablation(\n", + " mode=\"load\",\n", + " data_root=data_root,\n", + " max_train_batches=4,\n", + " max_val_batches=2,\n", + " device=device,\n", + ")" + ], + "metadata": { + "id": "RP_gl9viNlkA" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 4c168d3e3..bec63cf7a 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -45,4 +45,5 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding -from .califorest import CaliForest \ No newline at end of file +from .td_icu_mortality import TDICUMortalityModel +from .califorest import CaliForest diff --git a/pyhealth/models/td_icu_mortality.py b/pyhealth/models/td_icu_mortality.py new file mode 100644 index 000000000..a299911dc --- /dev/null +++ b/pyhealth/models/td_icu_mortality.py @@ -0,0 +1,919 @@ +"""Temporal-difference ICU mortality prediction model. + +This module implements the model from Frost et al., "Robust Real-Time +Mortality Prediction in the Intensive Care Unit using Temporal Difference +Learning" (arXiv:2411.04285), adapted to the PyHealth framework, plus an +extension that uses Monte Carlo dropout to attach a per-patient confidence +estimate to each mortality prediction. + +The approach frames ICU mortality prediction as a value-estimation problem: +at each observation step the model predicts the probability of death within +a fixed horizon (e.g. 28 days), and is trained with a temporal-difference +(TD) target derived from a lagged copy of itself (the "target network"). +This yields predictions that are calibrated across horizons and robust to +long sparse observation streams. + +Two classes are provided: + +* ``CNNLSTMPredictor``: the underlying CNN + LSTM architecture that maps an + event stream to a mortality probability. Can be used on its own for + supervised training. +* ``TDICUMortalityModel``: a PyHealth ``BaseModel`` that wraps two + ``CNNLSTMPredictor`` instances (online + target) and implements the TD + training rule, including the terminal-state handling from the paper. + At inference time, ``predict_with_confidence`` returns MC-dropout + uncertainty alongside each mortality prediction. +""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Dict, List, Mapping, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pack_padded_sequence + +from pyhealth.datasets import SampleEHRDataset +from pyhealth.models import BaseModel + + +# ----------------------------------------------------------------------------- +# Helper modules +# ----------------------------------------------------------------------------- + + +class MaxPool1D(nn.Module): + """NaN-aware 1D max pool for irregular event streams. + + Windows that are entirely NaN remain NaN in the output. Windows that + contain at least one real value output the max of those real values. + Padded positions can therefore be tracked across pooling layers. + + Args: + kernel_size: Pooling window size. + stride: Pooling stride. + """ + + def __init__(self, kernel_size: int = 2, stride: int = 2) -> None: + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Run NaN-aware max pooling. + + Args: + x: Input tensor of shape ``[batch, channels, seq_len]``. + + Returns: + Pooled tensor with NaN preserved in all-NaN windows. + """ + neg_inf = torch.tensor(-np.inf, dtype=x.dtype, device=x.device) + pos_nan = torch.tensor(np.nan, dtype=x.dtype, device=x.device) + out = torch.where(torch.isnan(x), neg_inf, x) + out = F.max_pool1d(out, kernel_size=self.kernel_size, stride=self.stride) + return torch.where(torch.isinf(out), pos_nan, out) + + +class Transpose(nn.Module): + """Swap two tensor dimensions inside an ``nn.Sequential`` pipeline. + + Args: + dim1: First dimension to swap. + dim2: Second dimension to swap. + """ + + def __init__(self, dim1: int, dim2: int) -> None: + super().__init__() + self.dim1 = dim1 + self.dim2 = dim2 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Transpose the configured dimensions. + + Args: + x: Input tensor. + + Returns: + Tensor with ``dim1`` and ``dim2`` swapped. + """ + return x.transpose(self.dim1, self.dim2) + + +# ----------------------------------------------------------------------------- +# CNN + LSTM backbone +# ----------------------------------------------------------------------------- + + +class CNNLSTMPredictor(nn.Module): + """CNN + LSTM encoder producing mortality predictions. + + The predictor embeds each component of an irregular event stream + (timepoint, value, feature id, delta-time, delta-value), fuses them via + summation + batchnorm, applies a small CNN stack for sequence-length + reduction, a 2-layer LSTM for temporal modelling, and a dense head for + binary classification. + + Args: + n_features: Size of the feature vocabulary. + features: Ordered list of feature names used to assemble the + scaling buffers from ``scaling``. + output_dim: Output dimensionality (always 1 for binary mortality). + scaling: Dictionary of per-feature statistics with the structure + ``{"mean": {...}, "std": {...}}`` where each inner dict contains + per-feature 1-element tensors for ``values``, ``delta_time``, + and ``delta_value``, plus scalar tensors for ``timepoints``. + cnn_layers: Number of CNN blocks (each applies a conv + ReLU + pool). + hidden_dim: Channel/embedding dimension of the encoder. + dropout: Dropout applied between LSTM layers. + batch_first: Whether the LSTM consumes batch-first tensors. + dtype: Floating point dtype for model parameters. + device: Device string, used to choose between packed vs masked LSTM. + + Attributes: + embedding_net: ``nn.ModuleDict`` of five per-component embeddings. + cnn: CNN stack that reduces sequence length. + lstm: Two-layer LSTM with hidden size ``hidden_dim * 8``. + dense: Binary prediction head. + """ + + def __init__( + self, + n_features: int, + features: List[str], + output_dim: int, + scaling: Mapping[str, Any], + cnn_layers: int = 2, + hidden_dim: int = 32, + dropout: float = 0.5, + batch_first: bool = True, + dtype: torch.dtype = torch.float32, + device: str = "cpu", + ) -> None: + super().__init__() + self.dtype = dtype + self.device_name = device + self.n_features = n_features + self.features = features + self.output_dim = output_dim + self.cnn_layers = cnn_layers + self.hidden_dim = hidden_dim + + self._register_scaling_buffers(scaling) + self._build_embeddings() + self._build_cnn() + self._build_lstm(dropout=dropout, batch_first=batch_first) + self._build_head() + + self.to(device) + self.init_weights() + + # -- construction helpers ------------------------------------------------- + + def _register_scaling_buffers( + self, + scaling: Mapping[str, Any], + ) -> None: + """Register per-feature mean/std tensors as buffers.""" + for name in ["values", "delta_time", "delta_value"]: + mean_t = torch.cat([scaling["mean"][name][f] for f in self.features]) + std_t = torch.cat([scaling["std"][name][f] for f in self.features]) + self.register_buffer(f"mean_{name}", mean_t) + self.register_buffer(f"std_{name}", std_t) + self.register_buffer("mean_timepoints", scaling["mean"]["timepoints"]) + self.register_buffer("std_timepoints", scaling["std"]["timepoints"]) + + def _build_embeddings(self) -> None: + """Build per-component embedding networks.""" + interim = int(np.sqrt(self.hidden_dim)) + self.embedding_net = nn.ModuleDict() + for name in ["time", "value", "feature", "delta_time", "delta_value"]: + if name == "feature": + self.embedding_net[name] = nn.Embedding( + self.n_features, self.hidden_dim, dtype=self.dtype + ) + else: + self.embedding_net[name] = nn.Sequential( + nn.Linear(1, interim, dtype=self.dtype), + nn.ReLU(), + nn.Linear(interim, self.hidden_dim, dtype=self.dtype), + ) + self.embedding_norm = nn.Sequential( + Transpose(1, 2), + nn.BatchNorm1d(self.hidden_dim), + Transpose(1, 2), + ) + + def _build_cnn(self) -> None: + """Build the CNN stack used after embedding fusion.""" + layers: List[nn.Module] = [Transpose(1, 2)] + for _ in range(self.cnn_layers): + layers += [ + nn.Conv1d( + self.hidden_dim, + self.hidden_dim, + kernel_size=2, + stride=1, + padding=1, + dtype=self.dtype, + ), + nn.ReLU(), + MaxPool1D(kernel_size=2, stride=2), + ] + layers.append(Transpose(1, 2)) + self.cnn = nn.Sequential(*layers) + + def _build_lstm(self, dropout: float, batch_first: bool) -> None: + """Build the stacked LSTM.""" + self.lstm = nn.LSTM( + input_size=self.hidden_dim, + hidden_size=self.hidden_dim * 8, + num_layers=2, + dropout=dropout, + batch_first=batch_first, + dtype=self.dtype, + ) + + def _build_head(self) -> None: + """Build the dense classification head.""" + self.dense = nn.Sequential( + nn.BatchNorm1d(self.hidden_dim * 8, dtype=self.dtype), + nn.Linear(self.hidden_dim * 8, self.hidden_dim, dtype=self.dtype), + nn.ReLU(), + nn.BatchNorm1d(self.hidden_dim, dtype=self.dtype), + nn.Linear(self.hidden_dim, self.output_dim, dtype=self.dtype), + ) + + def init_weights(self) -> None: + """Xavier-normal initialization for matrix-shaped parameters.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + # -- inference helpers ---------------------------------------------------- + + @torch.no_grad() + def soft_update( + self, new_model: "CNNLSTMPredictor", alpha: float = 0.99 + ) -> None: + """Exponential-moving-average update from another predictor. + + Implements ``theta_self = alpha * theta_self + (1 - alpha) * theta_new`` + in-place, which is the target-network update rule from the paper + (Appendix D, Eq. 7). + + Args: + new_model: Source predictor whose weights to blend in. + alpha: EMA coefficient in ``[0, 1]``. ``alpha=1`` leaves ``self`` + unchanged; ``alpha=0`` fully copies ``new_model`` into ``self``. + """ + src = new_model.state_dict() + tgt = self.state_dict() + for k in tgt: + tgt[k].copy_(alpha * tgt[k] + (1.0 - alpha) * src[k]) + + def pack_sequences( + self, src: torch.Tensor, mask: torch.Tensor + ) -> torch.nn.utils.rnn.PackedSequence: + """Pack a padded batch for efficient LSTM consumption. + + Args: + src: Embedded sequence tensor ``[batch, seq_len, hidden]``. + mask: Boolean padding mask ``[batch, seq_len]`` where ``True`` + marks padded positions. + + Returns: + A ``PackedSequence`` usable by ``self.lstm``. + """ + lengths = (~mask).sum(-1) + if lengths.min() == 0: + mask = mask.clone() + mask[torch.where(lengths == 0)[0], 0] = False + lengths = torch.where(lengths == 0, torch.ones_like(lengths), lengths) + return pack_padded_sequence( + src, lengths.cpu(), batch_first=True, enforce_sorted=False + ) + + def get_mask_after_conv(self, mask: torch.Tensor) -> torch.Tensor: + """Propagate a padding mask through the CNN pooling stack. + + Args: + mask: Boolean mask ``[batch, seq_len]`` aligned with the input. + + Returns: + Boolean mask aligned with the post-CNN sequence length. + """ + pooled = mask.unsqueeze(1).float() + for _ in range(self.cnn_layers): + pooled = F.avg_pool1d(pooled, kernel_size=2, stride=2) + return pooled.squeeze(1) == 1 + + def standardise_inputs( + self, + timepoints: torch.Tensor, + values: torch.Tensor, + features: torch.Tensor, + delta_time: torch.Tensor, + delta_value: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Apply per-feature z-score standardization. + + For each of ``values``, ``delta_time`` and ``delta_value``, scatter + the event's value into a feature-indexed long tensor, apply per- + feature mean/std, then collapse back to a dense per-event tensor + using ``torch.nansum``. ``timepoints`` use a single global scaler. + + Args: + timepoints: Event timepoints ``[batch, seq_len]``. + values: Event values ``[batch, seq_len]``. + features: Feature ids ``[batch, seq_len]`` with ``-1`` for padding. + delta_time: Time since previous measurement of the same feature. + delta_value: Value change since previous measurement of the same + feature. + + Returns: + Tuple of standardized tensors in the order + ``(timepoints, values, delta_time, delta_value, features)`` + where ``features`` has had its ``-1`` sentinels remapped to ``0``. + """ + safe_time_std = torch.where( + self.std_timepoints == 0, + torch.ones_like(self.std_timepoints), + self.std_timepoints, + ) + standardised = [(timepoints - self.mean_timepoints) / safe_time_std] + + features = torch.where(features == -1, 0, features) + long_shape = values.shape[:2] + (self.n_features,) + + base_nan = torch.full( + long_shape, + float("nan"), + dtype=self.dtype, + device=values.device, + ) + + stats = { + "values": (self.mean_values, self.std_values), + "delta_time": (self.mean_delta_time, self.std_delta_time), + "delta_value": (self.mean_delta_value, self.std_delta_value), + } + for name, vector in [ + ("values", values), + ("delta_time", delta_time), + ("delta_value", delta_value), + ]: + missing = torch.isnan(vector) + scattered = base_nan.clone() + scattered.scatter_(-1, features.unsqueeze(-1), vector.unsqueeze(-1)) + + mean_v, std_v = stats[name] + safe_std = torch.where(std_v == 0, torch.ones_like(std_v), std_v) + scattered = (scattered - mean_v) / safe_std + collapsed = torch.nansum(scattered, -1).unsqueeze(-1) + collapsed = torch.where( + missing.unsqueeze(-1), + torch.tensor(float("nan"), dtype=self.dtype, device=values.device), + collapsed, + ) + standardised.append(collapsed) + + standardised.append(features) + return tuple(standardised) + + # -- forward -------------------------------------------------------------- + + def forward( + self, + timepoints: torch.Tensor, + values: torch.Tensor, + features: torch.Tensor, + delta_time: torch.Tensor, + delta_value: torch.Tensor, + normalise: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Run the CNN-LSTM forward pass. + + Args: + timepoints: ``[batch, seq_len]`` hours since intime, NaN for pad. + values: ``[batch, seq_len]`` measurement values, NaN for pad. + features: ``[batch, seq_len]`` long tensor of feature ids, -1 for + pad. + delta_time: ``[batch, seq_len]`` hours since previous measurement + of the same feature. + delta_value: ``[batch, seq_len]`` change in value since previous + measurement of the same feature. + normalise: If ``True``, apply ``standardise_inputs`` first. + + Returns: + Tuple ``(probs, logits)`` each of shape ``[batch, output_dim]``. + ``probs = sigmoid(logits)``. + """ + if normalise: + ( + timepoints, + values, + delta_time, + delta_value, + features, + ) = self.standardise_inputs( + timepoints, values, features, delta_time, delta_value + ) + + timepoints = timepoints.squeeze(-1) if timepoints.ndim == 3 else timepoints + values = values.squeeze(-1) if values.ndim == 3 else values + delta_time = delta_time.squeeze(-1) if delta_time.ndim == 3 else delta_time + delta_value = ( + delta_value.squeeze(-1) if delta_value.ndim == 3 else delta_value + ) + + # Sort descending by timepoint; NaNs (padding) go to the end. + argsort_idx = torch.argsort( + torch.where(torch.isnan(timepoints), -torch.inf, timepoints), + dim=1, + descending=True, + ) + timepoints = torch.gather(timepoints, 1, argsort_idx) + values = torch.gather(values, 1, argsort_idx) + features = torch.gather(features, 1, argsort_idx) + delta_time = torch.gather(delta_time, 1, argsort_idx) + delta_value = torch.gather(delta_value, 1, argsort_idx) + + src_mask = torch.isnan(timepoints) + timepoints = torch.where(src_mask, 0, timepoints) + values = torch.where(src_mask, 0, values) + features = torch.where(src_mask, 0, features) + + dt_mask = torch.isnan(delta_time) + dv_mask = torch.isnan(delta_value) + delta_time = torch.where(dt_mask, 0, delta_time) + delta_value = torch.where(dv_mask, 0, delta_value) + + time_emb = self.embedding_net["time"](timepoints.unsqueeze(-1)) + value_emb = self.embedding_net["value"](values.unsqueeze(-1)) + feature_emb = self.embedding_net["feature"](features) + dt_emb = self.embedding_net["delta_time"](delta_time.unsqueeze(-1)) + dv_emb = self.embedding_net["delta_value"](delta_value.unsqueeze(-1)) + + dt_emb = torch.where(dt_mask.unsqueeze(-1), 0, dt_emb) + dv_emb = torch.where(dv_mask.unsqueeze(-1), 0, dv_emb) + + embedded = time_emb + value_emb + feature_emb + dt_emb + dv_emb + embedded = self.embedding_norm(embedded) + embedded = self.cnn(embedded) + + src_mask = self.get_mask_after_conv(src_mask) + + if self.device_name == "cuda": + packed = self.pack_sequences(embedded, src_mask) + embedded = self.lstm(packed)[1][0][-1] + else: + embedded = embedded.clone() + embedded[src_mask] = 0 + embedded = self.lstm(embedded)[1][0][-1] + + logits = self.dense(embedded) + probs = torch.sigmoid(logits) + return probs, logits + + +# ----------------------------------------------------------------------------- +# Loss +# ----------------------------------------------------------------------------- + + +class WeightedBCELoss(nn.Module): + """Thin wrapper around ``BCEWithLogitsLoss`` with optional pos-weighting. + + Args: + pos_weight: Optional positive-class weight tensor. When ``None``, + a standard unweighted BCE is used. Only set this for supervised + training, never for TD training: a weighted loss produces + incorrect gradients when the target is a continuous bootstrapped + value rather than a 0/1 label. + """ + + def __init__(self, pos_weight: Optional[torch.Tensor] = None) -> None: + super().__init__() + self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight) + + def forward( + self, logits: torch.Tensor, targets: torch.Tensor + ) -> torch.Tensor: + """Compute binary cross-entropy loss from logits. + + Args: + logits: Raw model output ``[batch, 1]``. + targets: Target values in ``[0, 1]`` (can be continuous for TD). + + Returns: + Scalar loss tensor. + """ + return self.loss_fn(logits, targets.float()) + + +# ----------------------------------------------------------------------------- +# PyHealth-facing TD model +# ----------------------------------------------------------------------------- + + +class TDICUMortalityModel(BaseModel): + """Temporal-difference ICU mortality prediction model. + + Implements the method of Frost et al. (2024), arXiv:2411.04285, for + real-time mortality prediction in the ICU. The model holds two + ``CNNLSTMPredictor`` instances: + + * ``online_net``: updated by gradient descent on a TD target. + * ``target_net``: an EMA-lagged copy of the online net whose predictions + form part of the TD target. + + At each training step, for a sample transition ``(s_t, s_{t+1})``, the + training target is: + + * ``target = label`` if ``s_{t+1}`` does not exist (terminal transition); + * ``target = gamma * target_net(s_{t+1})`` otherwise. + + The online net is regressed against this target with an unweighted binary + cross-entropy loss. The target net is updated once per optimizer step via + ``soft_update_target()``. + + The model can also be used in a purely supervised mode + (``train_td=False``) for baseline comparisons; in this mode the real + ``label_key`` is used as the target and ``pos_weight`` (if provided) is + applied. + + Example: + >>> from pyhealth.datasets import SampleEHRDataset + >>> from pyhealth.models.td_icu_mortality import TDICUMortalityModel + >>> dataset = SampleEHRDataset(samples=[], dataset_name="td_icu_demo") + >>> scaling = build_scaling_dict(...) # see ``examples/`` directory + >>> feature_names = [f"feat_{i}" for i in range(128)] + >>> model = TDICUMortalityModel( + ... dataset=dataset, + ... feature_keys=["timepoints", "values", "features", + ... "delta_time", "delta_value"], + ... label_key="28_day_died", + ... n_features=128, + ... scaling=scaling, + ... features_vocab=feature_names, + ... ) + >>> out = model(batch, targets=targets, train_td=True) + >>> out["loss"].backward() + >>> model.soft_update_target() + + Args: + dataset: A ``SampleEHRDataset`` instance (can be empty; only used for + metadata such as ``input_schema`` / ``output_schema``). + feature_keys: The five keys in the batch dict that hold the event + tuple components, in the order + ``[timepoints, values, features, delta_time, delta_value]``. + label_key: Key in the ``targets`` dict holding the real outcome + (e.g. ``"28_day_died"``). + mode: PyHealth task mode. Only ``"binary"`` is currently supported. + n_features: Size of the feature vocabulary. + hidden_dim: Hidden size for the encoder and LSTM. + cnn_layers: Number of CNN pooling blocks. + dropout: Dropout between LSTM layers. + output_dim: Output dimensionality (must be 1 for binary mortality). + scaling: Per-feature mean/std dictionary (see ``CNNLSTMPredictor``). + features_vocab: Ordered list of feature names keyed by ``scaling``. + td_alpha: Target-network EMA coefficient. Paper uses 0.99. + gamma: Discount factor on the bootstrapped next-state value. Default + 1.0 matches the paper. + pos_weight: Optional positive-class weight for supervised mode only. + device: Device string passed to the underlying predictors. + + Raises: + ValueError: If ``mode`` is not ``"binary"``. + """ + + def __init__( + self, + dataset: SampleEHRDataset, + feature_keys: List[str], + label_key: str, + mode: str = "binary", + n_features: int = 128, + hidden_dim: int = 32, + cnn_layers: int = 2, + dropout: float = 0.5, + output_dim: int = 1, + scaling: Optional[Mapping[str, Any]] = None, + features_vocab: Optional[List[str]] = None, + td_alpha: float = 0.99, + gamma: float = 1.0, + pos_weight: Optional[torch.Tensor] = None, + device: str = "cpu", + ) -> None: + if mode != "binary": + raise ValueError( + f"TDICUMortalityModel only supports mode='binary', got {mode!r}" + ) + if scaling is None or features_vocab is None: + raise ValueError( + "scaling and features_vocab are required - see the " + "example script for how to build them from training data." + ) + + # Adapt the SampleEHRDataset for BaseModel's expectations. + dataset.feature_keys = feature_keys + dataset.label_key = label_key + dataset.mode = mode + if not hasattr(dataset, "input_schema"): + dataset.input_schema = {k: "sequence" for k in feature_keys} + if not hasattr(dataset, "output_schema"): + dataset.output_schema = {label_key: mode} + + super().__init__(dataset) + + self.feature_keys = feature_keys + self.label_key = label_key + self.mode = mode + self.device_name = device + self.td_alpha = td_alpha + self.gamma = gamma + + self.online_net = CNNLSTMPredictor( + n_features=n_features, + features=features_vocab, + output_dim=output_dim, + scaling=scaling, + cnn_layers=cnn_layers, + hidden_dim=hidden_dim, + dropout=dropout, + device=device, + ) + self.target_net = CNNLSTMPredictor( + n_features=n_features, + features=features_vocab, + output_dim=output_dim, + scaling=scaling, + cnn_layers=cnn_layers, + hidden_dim=hidden_dim, + dropout=dropout, + device=device, + ) + self.target_net.load_state_dict(deepcopy(self.online_net.state_dict())) + + # pos_weight only ever applied in supervised mode. A continuous TD + # target combined with pos_weight would silently yield wrong grads. + self.supervised_loss = WeightedBCELoss(pos_weight=pos_weight) + self.td_loss = WeightedBCELoss(pos_weight=None) + + self.to(device) + + # -- PyHealth BaseModel abstract methods --------------------------------- + + def prepare_labels(self, labels: torch.Tensor) -> torch.Tensor: + """Format labels into the BCE-compatible shape. + + Args: + labels: Tensor of shape ``[batch]`` or ``[batch, 1]``. + + Returns: + Float tensor of shape ``[batch, 1]``. + """ + return labels.float().view(-1, 1) + + def get_loss_function(self) -> nn.Module: + """Return the default (supervised) loss function. + + The TD path uses ``self.td_loss`` internally. ``get_loss_function`` + exists to satisfy PyHealth's ``BaseModel`` contract and returns the + supervised (optionally pos-weighted) loss used when ``train_td=False``. + + Returns: + The supervised ``WeightedBCELoss`` instance. + """ + return self.supervised_loss + + # -- TD-specific helpers -------------------------------------------------- + + def predict_current( + self, batch: Mapping[str, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Run the online net on the current-state window. + + Args: + batch: Mapping with the five feature keys listed in + ``self.feature_keys``. + + Returns: + Tuple ``(probs, logits)``. + """ + return self.online_net( + batch["timepoints"], + batch["values"], + batch["features"], + batch["delta_time"], + batch["delta_value"], + ) + + @torch.no_grad() + def predict_next_target( + self, batch: Mapping[str, torch.Tensor] + ) -> torch.Tensor: + """Run the target net on the next-state window. + + The target net is placed in eval mode so BatchNorm stats are frozen + and LSTM dropout is inactive. The output carries no gradient. + + Args: + batch: Mapping with the five ``next_*`` keys + (``next_timepoints``, ``next_values``, ``next_features``, + ``next_delta_time``, ``next_delta_value``). + + Returns: + Detached probabilities of shape ``[batch, output_dim]``. + """ + self.target_net.eval() + probs, _ = self.target_net( + batch["next_timepoints"], + batch["next_values"], + batch["next_features"], + batch["next_delta_time"], + batch["next_delta_value"], + ) + return probs.detach() + + def compute_td_target( + self, + batch: Mapping[str, torch.Tensor], + targets: Mapping[str, torch.Tensor], + ) -> torch.Tensor: + """Compute the TD target for each transition in the batch. + + At terminal transitions (``isterminal > 0.5``) the target is the real + mortality label; at non-terminal transitions it is + ``gamma * target_net(next_state)``. + + Args: + batch: Mapping containing at minimum ``isterminal`` and the five + ``next_*`` keys. + targets: Mapping containing ``self.label_key``. + + Returns: + Detached target tensor of shape ``[batch, output_dim]``. + """ + next_probs = self.predict_next_target(batch) + real_reward = targets[self.label_key].float().view_as(next_probs) + is_terminal = (batch["isterminal"] > 0.5).view_as(next_probs) + td_target = torch.where(is_terminal, real_reward, self.gamma * next_probs) + return td_target.detach() + + def soft_update_target(self) -> None: + """Update the target net with an EMA step from the online net. + + Uses ``self.td_alpha``. Call exactly once after each optimizer step. + """ + self.target_net.soft_update(self.online_net, alpha=self.td_alpha) + + # -- Forward pass --------------------------------------------------------- + + def forward( + self, + batch: Mapping[str, torch.Tensor], + targets: Optional[Mapping[str, torch.Tensor]] = None, + train_td: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + """Run the model. + + Args: + batch: Mapping with ``self.feature_keys`` and (for TD training) + also ``isterminal`` and the five ``next_*`` keys. + targets: Optional mapping with ``self.label_key``. If omitted, + no loss is computed. + train_td: Whether to use the temporal-difference loss. When + ``True``, the TD target is computed via the target net; + when ``False``, the real label is used directly. + + Returns: + A dict with keys: + + * ``loss``: scalar loss tensor or ``None`` if ``targets is None``. + * ``y_prob``: ``[batch, output_dim]`` probabilities from the + online net. + * ``y_true``: the supplied label tensor, or ``None``. + * ``logit``: ``[batch, output_dim]`` raw model output. + """ + probs, logits = self.predict_current(batch) + + out: Dict[str, Optional[torch.Tensor]] = { + "loss": None, + "y_prob": probs, + "y_true": targets[self.label_key] if targets is not None else None, + "logit": logits, + } + + if targets is not None: + if train_td: + td_target = self.compute_td_target(batch, targets) + out["loss"] = self.td_loss(logits, td_target) + else: + supervised_target = self.prepare_labels(targets[self.label_key]) + out["loss"] = self.supervised_loss(logits, supervised_target) + + return out + + # -- Monte Carlo dropout for uncertainty quantification ------------------ + + @staticmethod + def _enable_dropout(module: nn.Module) -> None: + """Put only ``Dropout`` and ``LSTM`` layers into train mode. + + PyTorch's ``LSTM`` applies dropout between stacked layers only when + the module is in train mode, so we flip both ``nn.Dropout*`` and + ``nn.LSTM`` modules to train while leaving BatchNorm etc. in eval. + + Args: + module: Any ``nn.Module`` whose dropout we want to activate. + """ + for m in module.modules(): + if isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.LSTM)): + m.train() + + @torch.no_grad() + def predict_with_confidence( + self, + batch: Mapping[str, torch.Tensor], + n_mc_samples: int = 30, + high_conf_threshold: float = 0.005, + low_conf_threshold: float = 0.01, + ) -> Dict[str, torch.Tensor]: + """Return mortality predictions with per-patient confidence. + + Runs the online network ``n_mc_samples`` times with dropout active + (BatchNorm frozen in eval), then reports the MC mean, standard + deviation, and a ~95% credible interval per sample. Two boolean + flags classify each prediction as high / low confidence relative + to user-tunable thresholds on the MC standard deviation. + + This is the paper's base predictor extended with Monte Carlo + dropout, providing a per-patient uncertainty estimate alongside + the point mortality probability. Empirically, higher MC standard + deviation correlates with higher true mortality rate, so the + confidence score also serves as an auxiliary clinical-triage + signal. + + Example: + >>> out = model.predict_with_confidence(batch, n_mc_samples=30) + >>> mortality = out["mortality_prob"] # shape [batch] + >>> ci_lo = out["ci_95_lower"] # shape [batch] + >>> ci_hi = out["ci_95_upper"] # shape [batch] + >>> uncertain = out["is_low_confidence"] # bool tensor + + Args: + batch: Mapping with ``self.feature_keys`` (the five event + tuple components). ``next_*`` keys are not required since + only the online network is sampled. + n_mc_samples: Number of stochastic forward passes. More passes + tighten the uncertainty estimate at linear inference cost. + 30 is a reasonable default; 50-100 for research-grade. + high_conf_threshold: Predictions with MC standard deviation + below this value are flagged "high confidence". + low_conf_threshold: Predictions with MC standard deviation + above this value are flagged "low confidence". + + Returns: + Dict with tensor entries all shaped ``[batch]``: + + * ``mortality_prob``: MC-averaged probability. + * ``confidence_std``: standard deviation across MC samples. + * ``ci_95_lower``: ``clip(mean - 2*std, 0, 1)``. + * ``ci_95_upper``: ``clip(mean + 2*std, 0, 1)``. + * ``is_high_confidence``: bool, std < ``high_conf_threshold``. + * ``is_low_confidence``: bool, std > ``low_conf_threshold``. + """ + self.online_net.eval() + self._enable_dropout(self.online_net) + + probs_mc: List[torch.Tensor] = [] + for _ in range(n_mc_samples): + probs, _ = self.online_net( + batch["timepoints"], + batch["values"], + batch["features"], + batch["delta_time"], + batch["delta_value"], + ) + probs_mc.append(probs.float().squeeze(-1)) + + stacked = torch.stack(probs_mc, dim=1) # [batch, n_mc_samples] + mean_prob = stacked.mean(dim=1) + std_prob = stacked.std(dim=1) + + ci_lower = (mean_prob - 2.0 * std_prob).clamp(0.0, 1.0) + ci_upper = (mean_prob + 2.0 * std_prob).clamp(0.0, 1.0) + + return { + "mortality_prob": mean_prob, + "confidence_std": std_prob, + "ci_95_lower": ci_lower, + "ci_95_upper": ci_upper, + "is_high_confidence": std_prob < high_conf_threshold, + "is_low_confidence": std_prob > low_conf_threshold, + } diff --git a/tests/core/test_td_icu_mortality.py b/tests/core/test_td_icu_mortality.py new file mode 100644 index 000000000..0b162cfe2 --- /dev/null +++ b/tests/core/test_td_icu_mortality.py @@ -0,0 +1,770 @@ +"""Unit tests for pyhealth.models.td_icu_mortality. + +Converted from pytest to unittest +so it runs under the standard ``python -m unittest`` runner used by PyHealth CI. + +Performance strategy: + - setUpClass reuses the same model across tests within a class. Tests that + mutate weights operate on disposable copies or restore from a snapshot. + - Tiny config: n_features=4, hidden_dim=2, cnn_layers=1, seq_len=8, + batch_size=2. A forward pass at this size is a fraction of a ms. + - Synthetic tensor batches are pre-computed once per class. + +Run with: + python -m unittest tests/core/test_td_icu_mortality.py -v +""" + +from __future__ import annotations + +import copy +import shutil +import tempfile +import unittest +import warnings +from pathlib import Path +from typing import Dict, List + +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "" + +import torch +import torch.nn as nn + +from pyhealth.datasets import SampleEHRDataset +from pyhealth.models.td_icu_mortality import ( + CNNLSTMPredictor, + MaxPool1D, + TDICUMortalityModel, + Transpose, + WeightedBCELoss, +) + +# --------------------------------------------------------------------------- +# Tiny config +# --------------------------------------------------------------------------- + +N_FEATURES = 4 +HIDDEN_DIM = 2 +CNN_LAYERS = 1 +SEQ_LEN = 8 +BATCH_SIZE = 2 +LABEL_KEY = "mortality" +FEATURE_KEYS = [ + "timepoints", "values", "features", "delta_time", "delta_value", +] +FEATURE_NAMES = [f"feat_{i}" for i in range(N_FEATURES)] + + +# --------------------------------------------------------------------------- +# Data helpers +# --------------------------------------------------------------------------- + +def _build_scaling(feature_names: List[str]) -> Dict: + scaling: Dict = {"mean": {}, "std": {}} + scaling["mean"]["timepoints"] = torch.tensor([0.0]) + scaling["std"]["timepoints"] = torch.tensor([1.0]) + for key in ["values", "delta_time", "delta_value"]: + scaling["mean"][key] = {f: torch.tensor([0.0]) for f in feature_names} + scaling["std"][key] = {f: torch.tensor([1.0]) for f in feature_names} + return scaling + + +def _make_current_window( + batch_size: int = BATCH_SIZE, + seq_len: int = SEQ_LEN, + n_features: int = N_FEATURES, + pad_head: int = 1, + seed: int = 0, +) -> Dict[str, torch.Tensor]: + g = torch.Generator().manual_seed(seed) + features = torch.randint(0, n_features, (batch_size, seq_len), generator=g) + if pad_head > 0: + features[:, :pad_head] = -1 + timepoints = torch.randn(batch_size, seq_len, generator=g) + values = torch.randn(batch_size, seq_len, generator=g) + dt = torch.randn(batch_size, seq_len, generator=g) + dv = torch.randn(batch_size, seq_len, generator=g) + if pad_head > 0: + pad_mask = features == -1 + timepoints[pad_mask] = float("nan") + values[pad_mask] = float("nan") + dt[pad_mask] = -1.0 + dv[pad_mask] = float("nan") + return { + "timepoints": timepoints, + "values": values, + "features": features, + "delta_time": dt, + "delta_value": dv, + } + + +def _make_td_batch( + batch_size: int = BATCH_SIZE, + seq_len: int = SEQ_LEN, + n_features: int = N_FEATURES, + terminal_mask: torch.Tensor = None, + seed: int = 0, +) -> Dict[str, torch.Tensor]: + cur = _make_current_window(batch_size, seq_len, n_features, seed=seed) + nxt = _make_current_window(batch_size, seq_len, n_features, seed=seed + 1) + batch = dict(cur) + for key, val in nxt.items(): + batch[f"next_{key}"] = val + if terminal_mask is None: + g = torch.Generator().manual_seed(seed + 2) + terminal_mask = torch.randint(0, 2, (batch_size,), generator=g).float() + batch["isterminal"] = terminal_mask.view(-1, 1) + return batch + + +def _make_targets(batch_size: int = BATCH_SIZE, seed: int = 0) -> Dict: + g = torch.Generator().manual_seed(seed) + y = torch.randint(0, 2, (batch_size,), generator=g).float().view(-1, 1) + return {LABEL_KEY: y} + + +def _make_empty_dataset() -> SampleEHRDataset: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return SampleEHRDataset(samples=[], dataset_name="td_icu_unit_test") + + +def _make_predictor(scaling=None, feature_names=None) -> CNNLSTMPredictor: + if scaling is None: + scaling = _build_scaling(FEATURE_NAMES) + if feature_names is None: + feature_names = FEATURE_NAMES + torch.manual_seed(42) + return CNNLSTMPredictor( + n_features=N_FEATURES, + features=feature_names, + output_dim=1, + scaling=scaling, + cnn_layers=CNN_LAYERS, + hidden_dim=HIDDEN_DIM, + dropout=0.1, + device="cpu", + ) + + +def _make_td_model(scaling=None, feature_names=None) -> TDICUMortalityModel: + if scaling is None: + scaling = _build_scaling(FEATURE_NAMES) + if feature_names is None: + feature_names = FEATURE_NAMES + torch.manual_seed(42) + return TDICUMortalityModel( + dataset=_make_empty_dataset(), + feature_keys=FEATURE_KEYS, + label_key=LABEL_KEY, + mode="binary", + n_features=N_FEATURES, + hidden_dim=HIDDEN_DIM, + cnn_layers=CNN_LAYERS, + dropout=0.1, + output_dim=1, + scaling=scaling, + features_vocab=feature_names, + td_alpha=0.99, + pos_weight=None, + device="cpu", + ) + + +# --------------------------------------------------------------------------- +# TestHelperModules +# --------------------------------------------------------------------------- + +class TestHelperModules(unittest.TestCase): + """MaxPool1D, Transpose, WeightedBCELoss.""" + + def test_maxpool1d_preserves_all_nan_window(self): + pool = MaxPool1D(2, 2) + x = torch.tensor([[[1.0, 2.0, float("nan"), float("nan")]]]) + out = pool(x) + self.assertEqual(out.shape, (1, 1, 2)) + self.assertEqual(out[0, 0, 0].item(), 2.0) + self.assertTrue(torch.isnan(out[0, 0, 1])) + + def test_maxpool1d_mixed_window_ignores_nan(self): + pool = MaxPool1D(2, 2) + x = torch.tensor([[[3.0, float("nan")]]]) + out = pool(x) + self.assertEqual(out[0, 0, 0].item(), 3.0) + self.assertFalse(torch.isnan(out[0, 0, 0])) + + def test_transpose_swaps_dims(self): + t = Transpose(1, 2) + x = torch.randn(2, 3, 5) + self.assertTrue(torch.equal(t(x), x.transpose(1, 2))) + + def test_weighted_bce_matches_builtin(self): + loss_fn = WeightedBCELoss(pos_weight=None) + logits = torch.zeros(3, 1) + targets = torch.tensor([[1.0], [0.0], [1.0]]) + expected = nn.functional.binary_cross_entropy_with_logits(logits, targets) + self.assertTrue(torch.allclose(loss_fn(logits, targets), expected)) + + def test_weighted_bce_with_pos_weight(self): + no_w = WeightedBCELoss(pos_weight=None) + w = WeightedBCELoss(pos_weight=torch.tensor([2.0])) + logits = torch.zeros(4, 1) + targets = torch.ones(4, 1) + self.assertGreater(w(logits, targets).item(), no_w(logits, targets).item()) + + +# --------------------------------------------------------------------------- +# TestCNNLSTMPredictorInstantiation +# --------------------------------------------------------------------------- + +class TestCNNLSTMPredictorInstantiation(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.predictor = _make_predictor() + + def test_instantiation_succeeds(self): + self.assertIsInstance(self.predictor, CNNLSTMPredictor) + self.assertEqual(self.predictor.n_features, N_FEATURES) + self.assertEqual(self.predictor.hidden_dim, HIDDEN_DIM) + self.assertEqual(self.predictor.cnn_layers, CNN_LAYERS) + + def test_has_five_embedding_heads(self): + expected = {"time", "value", "feature", "delta_time", "delta_value"} + self.assertEqual(set(self.predictor.embedding_net.keys()), expected) + + def test_feature_embedding_dims(self): + emb = self.predictor.embedding_net["feature"] + self.assertIsInstance(emb, nn.Embedding) + self.assertEqual(emb.num_embeddings, N_FEATURES) + self.assertEqual(emb.embedding_dim, HIDDEN_DIM) + + def test_scalar_embeddings_are_mlps(self): + for name in ["time", "value", "delta_time", "delta_value"]: + net = self.predictor.embedding_net[name] + self.assertIsInstance(net, nn.Sequential) + self.assertEqual(net[0].in_features, 1) + self.assertEqual(net[-1].out_features, HIDDEN_DIM) + + def test_scaling_buffers_registered(self): + buffers = dict(self.predictor.named_buffers()) + for name in [ + "mean_values", "std_values", + "mean_delta_time", "std_delta_time", + "mean_delta_value", "std_delta_value", + "mean_timepoints", "std_timepoints", + ]: + self.assertIn(name, buffers) + + def test_lstm_hidden_is_8x_base(self): + self.assertEqual(self.predictor.lstm.hidden_size, HIDDEN_DIM * 8) + self.assertEqual(self.predictor.lstm.num_layers, 2) + + def test_parameter_count_reasonable(self): + n = sum(p.numel() for p in self.predictor.parameters() if p.requires_grad) + self.assertGreater(n, 0) + self.assertLess(n, 100_000) + + +# --------------------------------------------------------------------------- +# TestCNNLSTMPredictorForward +# --------------------------------------------------------------------------- + +class TestCNNLSTMPredictorForward(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.predictor = _make_predictor() + cls.predictor.eval() + + def test_forward_returns_tuple(self): + out = self.predictor(**_make_current_window()) + self.assertIsInstance(out, tuple) + self.assertEqual(len(out), 2) + + def test_probs_shape(self): + probs, _ = self.predictor(**_make_current_window()) + self.assertEqual(probs.shape, (BATCH_SIZE, 1)) + + def test_logits_shape(self): + _, logits = self.predictor(**_make_current_window()) + self.assertEqual(logits.shape, (BATCH_SIZE, 1)) + + def test_probs_in_unit_interval(self): + probs, _ = self.predictor(**_make_current_window()) + self.assertTrue(torch.all(probs >= 0.0)) + self.assertTrue(torch.all(probs <= 1.0)) + + def test_no_nan_in_output(self): + probs, logits = self.predictor(**_make_current_window()) + self.assertTrue(torch.isfinite(probs).all()) + self.assertTrue(torch.isfinite(logits).all()) + + def test_handles_no_padding(self): + probs, _ = self.predictor(**_make_current_window(pad_head=0)) + self.assertTrue(torch.isfinite(probs).all()) + + def test_handles_heavy_padding(self): + probs, _ = self.predictor(**_make_current_window(pad_head=SEQ_LEN - 2)) + self.assertTrue(torch.isfinite(probs).all()) + + def test_probs_match_sigmoid_logits(self): + probs, logits = self.predictor(**_make_current_window()) + self.assertTrue(torch.allclose(probs, torch.sigmoid(logits), atol=1e-6)) + + +# --------------------------------------------------------------------------- +# TestCNNLSTMPredictorGradients +# --------------------------------------------------------------------------- + +class TestCNNLSTMPredictorGradients(unittest.TestCase): + + def setUp(self): + self.predictor = _make_predictor() + self.predictor.train() + + def test_backward_produces_gradients(self): + _, logits = self.predictor(**_make_current_window()) + logits.sum().backward() + grads = [p.grad for p in self.predictor.parameters() if p.requires_grad] + self.assertTrue(any(g is not None for g in grads)) + + def test_all_trainable_params_receive_gradient(self): + _, logits = self.predictor(**_make_current_window()) + logits.sum().backward() + missing = [ + name for name, p in self.predictor.named_parameters() + if p.requires_grad and p.grad is None + ] + self.assertFalse(missing) + + def test_gradients_are_finite(self): + _, logits = self.predictor(**_make_current_window()) + logits.sum().backward() + for name, p in self.predictor.named_parameters(): + if p.grad is not None: + self.assertTrue(torch.isfinite(p.grad).all(), name) + + +# --------------------------------------------------------------------------- +# TestSoftUpdate +# --------------------------------------------------------------------------- + +class TestSoftUpdate(unittest.TestCase): + + @classmethod + def setUpClass(cls): + scaling = _build_scaling(FEATURE_NAMES) + torch.manual_seed(1) + cls.tgt = _make_predictor(scaling=scaling) + torch.manual_seed(2) + cls.src = _make_predictor(scaling=scaling) + # Snapshot BOTH — test_alpha_half_is_midpoint mutates src.dense weights + cls.tgt_initial = copy.deepcopy(cls.tgt.state_dict()) + cls.src_initial = copy.deepcopy(cls.src.state_dict()) + + def setUp(self): + # Restore both to initial state before every test + self.tgt.load_state_dict(copy.deepcopy(self.tgt_initial)) + self.src.load_state_dict(copy.deepcopy(self.src_initial)) + + def test_alpha_zero_copies_source(self): + self.tgt.soft_update(self.src, alpha=0.0) + for k, v in self.src.state_dict().items(): + if v.dtype.is_floating_point: + self.assertTrue(torch.allclose(self.tgt.state_dict()[k], v)) + + def test_alpha_one_leaves_target(self): + self.tgt.soft_update(self.src, alpha=1.0) + for k, v in self.tgt.state_dict().items(): + if v.dtype.is_floating_point: + self.assertTrue(torch.allclose(self.tgt_initial[k], v)) + + def test_alpha_half_is_midpoint(self): + with torch.no_grad(): + self.tgt.dense[1].weight.fill_(0.0) + self.src.dense[1].weight.fill_(2.0) + self.tgt.soft_update(self.src, alpha=0.5) + self.assertTrue(torch.allclose( + self.tgt.dense[1].weight, + torch.full_like(self.tgt.dense[1].weight, 1.0), + )) + + def test_soft_update_has_no_grad(self): + self.tgt.soft_update(self.src, alpha=0.9) + for p in self.tgt.parameters(): + self.assertIsNone(p.grad) + + +# --------------------------------------------------------------------------- +# TestTDICUMortalityModelInstantiation +# --------------------------------------------------------------------------- + +class TestTDICUMortalityModelInstantiation(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.model = _make_td_model() + cls.empty_dataset = _make_empty_dataset() + cls.scaling = _build_scaling(FEATURE_NAMES) + + def test_instantiation_succeeds(self): + self.assertIsInstance(self.model, TDICUMortalityModel) + self.assertEqual(self.model.label_key, LABEL_KEY) + self.assertEqual(self.model.mode, "binary") + self.assertAlmostEqual(self.model.td_alpha, 0.99) + + def test_inherits_from_basemodel(self): + from pyhealth.models import BaseModel + self.assertIsInstance(self.model, BaseModel) + + def test_has_online_and_target_nets(self): + self.assertIsInstance(self.model.online_net, CNNLSTMPredictor) + self.assertIsInstance(self.model.target_net, CNNLSTMPredictor) + + def test_both_loss_functions_exist(self): + self.assertIsInstance(self.model.supervised_loss, WeightedBCELoss) + self.assertIsInstance(self.model.td_loss, WeightedBCELoss) + + def test_td_loss_has_no_pos_weight(self): + self.assertIsNone(self.model.td_loss.loss_fn.pos_weight) + + def test_non_binary_mode_raises(self): + with self.assertRaises(ValueError): + TDICUMortalityModel( + dataset=self.empty_dataset, + feature_keys=FEATURE_KEYS, + label_key=LABEL_KEY, + mode="multiclass", + n_features=N_FEATURES, + hidden_dim=HIDDEN_DIM, + cnn_layers=CNN_LAYERS, + scaling=self.scaling, + features_vocab=FEATURE_NAMES, + ) + + def test_missing_scaling_raises(self): + with self.assertRaises(ValueError): + TDICUMortalityModel( + dataset=self.empty_dataset, + feature_keys=FEATURE_KEYS, + label_key=LABEL_KEY, + mode="binary", + n_features=N_FEATURES, + hidden_dim=HIDDEN_DIM, + cnn_layers=CNN_LAYERS, + scaling=None, + features_vocab=FEATURE_NAMES, + ) + + +# --------------------------------------------------------------------------- +# TestBaseModelInterface +# --------------------------------------------------------------------------- + +class TestBaseModelInterface(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.model = _make_td_model() + + def test_prepare_labels_shape(self): + labels = torch.tensor([0, 1, 1, 0]) + out = self.model.prepare_labels(labels) + self.assertEqual(out.shape, (4, 1)) + self.assertEqual(out.dtype, torch.float32) + + def test_prepare_labels_idempotent_from_batched(self): + labels = torch.tensor([[0], [1], [1], [0]]) + out = self.model.prepare_labels(labels) + self.assertEqual(out.shape, (4, 1)) + self.assertEqual(out.dtype, torch.float32) + + def test_get_loss_function_returns_module(self): + loss_fn = self.model.get_loss_function() + self.assertIsInstance(loss_fn, nn.Module) + logits = torch.zeros(2, 1) + targets = torch.tensor([[1.0], [0.0]]) + result = loss_fn(logits, targets) + self.assertTrue(torch.isfinite(result)) + + +# --------------------------------------------------------------------------- +# TestTDICUMortalityModelForward +# --------------------------------------------------------------------------- + +class TestTDICUMortalityModelForward(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.model = _make_td_model() + cls.model.eval() + cls.batch = _make_td_batch(seed=0) + cls.targets = _make_targets(seed=0) + + def setUp(self): + # Mirrors pytest's autouse _eval_mode fixture — ensures eval() is set + # before every test, since train_td=True tests may alter model state + self.model.eval() + + def test_returns_all_expected_keys(self): + out = self.model(self.batch, targets=None, train_td=False) + self.assertEqual(set(out.keys()), {"loss", "y_prob", "y_true", "logit"}) + self.assertIsNone(out["loss"]) + self.assertIsNone(out["y_true"]) + + def test_with_targets_produces_loss(self): + out = self.model(self.batch, targets=self.targets, train_td=False) + self.assertIsNotNone(out["loss"]) + self.assertTrue(torch.isfinite(out["loss"])) + + def test_y_prob_shape(self): + out = self.model(self.batch, targets=None, train_td=False) + self.assertEqual(out["y_prob"].shape, (BATCH_SIZE, 1)) + + def test_logit_shape(self): + out = self.model(self.batch, targets=None, train_td=False) + self.assertEqual(out["logit"].shape, (BATCH_SIZE, 1)) + + def test_probs_in_unit_interval(self): + out = self.model(self.batch, targets=None, train_td=False) + self.assertTrue(torch.all(out["y_prob"] >= 0.0)) + self.assertTrue(torch.all(out["y_prob"] <= 1.0)) + + def test_supervised_and_td_mode_differ(self): + batch = _make_td_batch(terminal_mask=torch.zeros(BATCH_SIZE)) + sup_loss = self.model(batch, targets=self.targets, train_td=False)["loss"] + td_loss = self.model(batch, targets=self.targets, train_td=True)["loss"] + self.assertFalse(torch.allclose(sup_loss, td_loss)) + + def test_all_terminal_reduces_to_supervised(self): + batch = _make_td_batch(terminal_mask=torch.ones(BATCH_SIZE)) + sup = self.model(batch, targets=self.targets, train_td=False)["loss"] + td = self.model(batch, targets=self.targets, train_td=True)["loss"] + self.assertTrue(torch.allclose(sup, td, atol=1e-5)) + + +# --------------------------------------------------------------------------- +# TestTDTargetComputation +# --------------------------------------------------------------------------- + +class TestTDTargetComputation(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.model = _make_td_model() + cls.model.eval() + cls.batch = _make_td_batch(seed=0) + cls.targets = _make_targets(seed=0) + + def test_td_target_shape(self): + td = self.model.compute_td_target(self.batch, self.targets) + self.assertEqual(td.shape, (BATCH_SIZE, 1)) + + def test_td_target_in_unit_interval(self): + td = self.model.compute_td_target(self.batch, self.targets) + self.assertTrue(torch.all(td >= 0.0) and torch.all(td <= 1.0)) + + def test_td_target_detached(self): + td = self.model.compute_td_target(self.batch, self.targets) + self.assertFalse(td.requires_grad) + + +# --------------------------------------------------------------------------- +# TestTDICUMortalityModelGradients +# --------------------------------------------------------------------------- + +class TestTDICUMortalityModelGradients(unittest.TestCase): + # Each test gets a fresh model — no shared state, no snapshot needed + # since _make_td_model() always uses torch.manual_seed(42) + + def setUp(self): + self.model = _make_td_model() + self.model.train() + self.batch = _make_td_batch(seed=0) + self.targets = _make_targets(seed=0) + + def test_supervised_backward(self): + out = self.model(self.batch, targets=self.targets, train_td=False) + out["loss"].backward() + grads = [p.grad for p in self.model.online_net.parameters() if p.requires_grad] + self.assertTrue(any(g is not None for g in grads)) + + def test_td_backward(self): + out = self.model(self.batch, targets=self.targets, train_td=True) + out["loss"].backward() + grads = [p.grad for p in self.model.online_net.parameters() if p.requires_grad] + self.assertTrue(any(g is not None for g in grads)) + + def test_target_net_no_gradient_in_td_mode(self): + out = self.model(self.batch, targets=self.targets, train_td=True) + out["loss"].backward() + for name, p in self.model.target_net.named_parameters(): + self.assertIsNone(p.grad, f"target_net.{name} got a gradient") + + +# --------------------------------------------------------------------------- +# TestSoftUpdateTarget +# --------------------------------------------------------------------------- + +class TestSoftUpdateTarget(unittest.TestCase): + + def test_target_changes_after_training_step(self): + model = _make_td_model() + model.train() + batch = _make_td_batch(seed=0) + targets = _make_targets(seed=0) + + target_before = { + k: v.detach().clone() + for k, v in model.target_net.state_dict().items() + } + optim = torch.optim.AdamW(model.online_net.parameters(), lr=1e-1) + out = model(batch, targets=targets, train_td=True) + optim.zero_grad(set_to_none=True) + out["loss"].backward() + optim.step() + model.soft_update_target() + + target_after = model.target_net.state_dict() + changed = any( + not torch.allclose(target_before[k], target_after[k], atol=1e-6) + for k in target_before + if target_before[k].dtype.is_floating_point + ) + self.assertTrue(changed) + + +# --------------------------------------------------------------------------- +# TestMCDropoutConfidence +# --------------------------------------------------------------------------- + +class TestMCDropoutConfidence(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.model = _make_td_model() + cls.model.eval() + cls.batch = _make_td_batch(seed=0) + + def setUp(self): + # Mirrors pytest's autouse _eval_mode fixture + self.model.eval() + + def test_returns_all_expected_keys(self): + out = self.model.predict_with_confidence(self.batch, n_mc_samples=3) + expected = { + "mortality_prob", "confidence_std", + "ci_95_lower", "ci_95_upper", + "is_high_confidence", "is_low_confidence", + } + self.assertEqual(set(out.keys()), expected) + + def test_output_shapes(self): + out = self.model.predict_with_confidence(self.batch, n_mc_samples=3) + for key in ["mortality_prob", "confidence_std", "ci_95_lower", "ci_95_upper"]: + self.assertEqual(out[key].shape, (BATCH_SIZE,), key) + + def test_mortality_prob_in_unit_interval(self): + out = self.model.predict_with_confidence(self.batch, n_mc_samples=3) + self.assertTrue(torch.all(out["mortality_prob"] >= 0.0)) + self.assertTrue(torch.all(out["mortality_prob"] <= 1.0)) + + def test_std_non_negative(self): + out = self.model.predict_with_confidence(self.batch, n_mc_samples=3) + self.assertTrue(torch.all(out["confidence_std"] >= 0.0)) + + def test_ci_in_unit_interval_and_contains_mean(self): + out = self.model.predict_with_confidence(self.batch, n_mc_samples=3) + self.assertTrue(torch.all(out["ci_95_lower"] >= 0.0)) + self.assertTrue(torch.all(out["ci_95_upper"] <= 1.0)) + self.assertTrue(torch.all(out["ci_95_lower"] <= out["mortality_prob"])) + self.assertTrue(torch.all(out["mortality_prob"] <= out["ci_95_upper"])) + + def test_confidence_flags_are_bool(self): + out = self.model.predict_with_confidence(self.batch, n_mc_samples=3) + self.assertEqual(out["is_high_confidence"].dtype, torch.bool) + self.assertEqual(out["is_low_confidence"].dtype, torch.bool) + + def test_confidence_flags_mutually_exclusive(self): + out = self.model.predict_with_confidence(self.batch, n_mc_samples=3) + both = out["is_high_confidence"] & out["is_low_confidence"] + self.assertFalse(both.any()) + + def test_does_not_alter_train_eval_of_non_dropout_modules(self): + self.model.eval() + _ = self.model.predict_with_confidence(self.batch, n_mc_samples=2) + for m in self.model.online_net.modules(): + if isinstance(m, nn.BatchNorm1d): + self.assertFalse( + m.training, + "BatchNorm should stay in eval during MC dropout", + ) + + def test_no_gradients_from_mc_sampling(self): + out = self.model.predict_with_confidence(self.batch, n_mc_samples=2) + self.assertFalse(out["mortality_prob"].requires_grad) + self.assertFalse(out["confidence_std"].requires_grad) + + +# --------------------------------------------------------------------------- +# TestCheckpointIO +# --------------------------------------------------------------------------- + +class TestCheckpointIO(unittest.TestCase): + + def setUp(self): + self.tmp_dir = Path(tempfile.mkdtemp(prefix="td_icu_test_")) + self.model = _make_td_model() + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + + def test_state_dict_roundtrip(self): + ckpt_path = self.tmp_dir / "model.pt" + torch.save(self.model.state_dict(), ckpt_path) + self.assertTrue(ckpt_path.exists()) + + original = {k: v.detach().clone() for k, v in self.model.state_dict().items()} + with torch.no_grad(): + for p in self.model.online_net.parameters(): + p.add_(1.0) + + restored = torch.load(ckpt_path, weights_only=True) + self.model.load_state_dict(restored) + loaded = self.model.state_dict() + for k, v in original.items(): + self.assertTrue(torch.allclose(v, loaded[k])) + + def test_tmp_dir_is_writable(self): + self.assertTrue(self.tmp_dir.exists()) + self.assertTrue(self.tmp_dir.is_dir()) + (self.tmp_dir / "sanity.txt").write_text("ok") + self.assertEqual((self.tmp_dir / "sanity.txt").read_text(), "ok") + + def test_tmp_dir_is_unique(self): + self.assertIn("td_icu_test_", str(self.tmp_dir)) + + +# --------------------------------------------------------------------------- +# TestDeterminism +# --------------------------------------------------------------------------- + +class TestDeterminism(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.predictor = _make_predictor() + cls.predictor.eval() + + def test_same_input_same_output(self): + cur = _make_current_window(seed=42) + with torch.no_grad(): + p1, _ = self.predictor(**cur) + p2, _ = self.predictor(**cur) + self.assertTrue(torch.allclose(p1, p2, atol=1e-6)) + + +if __name__ == "__main__": + unittest.main()