From d8f7308cecde745e6d368d571a7c0cc3831fc09a Mon Sep 17 00:00:00 2001 From: jhnwu3 Date: Thu, 12 Mar 2026 19:11:55 -0500 Subject: [PATCH 1/4] add new docs --- docs/api/datasets.rst | 142 +++++++++++++++++++++++++++++++++++ docs/api/graph.rst | 138 ++++++++++++++++++++++++++++++++++ docs/api/models.rst | 168 +++++++++++++++++++++++++++++++++++++++++- docs/api/tasks.rst | 132 +++++++++++++++++++++++++++++++++ docs/api/trainer.rst | 99 ++++++++++++++++++++++++- 5 files changed, 674 insertions(+), 5 deletions(-) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 3412e5ac5..24913c69f 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -22,6 +22,148 @@ This tutorial covers: - Using openly available demo datasets (MIMIC-III Demo, MIMIC-IV Demo) - Working with synthetic data for testing +How PyHealth Loads Data +------------------------ + +When you initialise a dataset, PyHealth reads the raw CSV or Parquet files +using Polars, joins the tables according to a YAML schema, and writes a +compact ``global_event_df.parquet`` cache to disk. On subsequent runs with +the same configuration it reads from cache rather than re-parsing the source +files, so startup is fast. + +The result is a :class:`~pyhealth.datasets.BaseDataset` — a structured +patient→event tree. It is different from a PyTorch Dataset: it has no integer +length and you cannot index into it with ``dataset[i]``. Think of it as a +queryable dictionary of patient records. To turn it into something a model +can train on, you call ``dataset.set_task()`` (see :doc:`tasks`), which +returns a :class:`~pyhealth.datasets.SampleDataset` that *is* indexable and +DataLoader-ready. + +Native Datasets vs Custom Datasets +------------------------------------ + +PyHealth includes native support for several standard EHR databases — MIMIC-III, +MIMIC-IV, eICU, and OMOP. These come with built-in schema definitions so you +can load them with just a root path and a list of tables: + +.. 
code-block:: python + + from pyhealth.datasets import MIMIC3Dataset + + if __name__ == '__main__': + dataset = MIMIC3Dataset( + root="/data/physionet.org/files/mimiciii/1.4", + tables=["diagnoses_icd", "procedures_icd", "prescriptions"], + cache_dir=".cache", + dev=True, # use 1 000 patients while exploring + ) + +For any other data source — a custom patient registry, an institutional cohort, +or a non-EHR dataset — you create a subclass of ``BaseDataset`` and provide a +``config.yaml`` file that describes your table structure. + +Initialization Parameters +-------------------------- + +- **root** — path to the directory containing the raw data files. For MIMIC-IV + specifically, use ``ehr_root`` instead of ``root``. +- **tables** — the table names you want to load, e.g. + ``["diagnoses_icd", "labevents"]``. Only these tables will be accessible in + patient queries downstream. +- **config_path** — path to your ``config.yaml``; needed for custom datasets. + Native datasets have this built in and ignore the parameter. +- **cache_dir** — where to store the cached Parquet and LitData files. PyHealth + appends a UUID derived from your configuration, so different setups never + overwrite each other. +- **num_workers** — parallel processes for data loading. Increasing this can + speed up ``set_task()`` on large datasets. +- **dev** — when ``True``, PyHealth caps the dataset at 1 000 patients. This + is very useful during development because it makes each iteration complete in + seconds rather than minutes. Switch to ``dev=False`` for your final training + run. + +config.yaml for Custom Datasets +--------------------------------- + +If you are bringing your own data, the YAML file tells PyHealth which column +is the patient identifier, which column is the timestamp, and which other +columns to include as event attributes: + +.. 
code-block:: yaml + + tables: + my_table: + file_path: relative/path/to/file.csv + patient_id: subject_id + timestamp: charttime + timestamp_format: "%Y-%m-%d %H:%M:%S" + attributes: + - icd_code + - value + - itemid + join: [] # optional table joins + +All attribute column names are lowercased internally, so ``ICD_CODE`` in +your CSV becomes ``icd_code`` in your code. + +Querying Patients and Events +----------------------------- + +Once a dataset is loaded, you can explore it using these methods: + +.. code-block:: python + + dataset.unique_patient_ids # all patient IDs as a list of strings + dataset.get_patient("p001") # retrieve one Patient object + dataset.iter_patients() # iterate over all patients + dataset.stats() # print patient and event counts + +Patient records are accessed through ``get_events()``, which supports +temporal filtering and attribute-level filters: + +.. code-block:: python + + events = patient.get_events( + event_type="diagnoses_icd", # table name from your config + start=datetime(2020, 1, 1), # optional: exclude earlier events + end=datetime(2020, 6, 1), # optional: exclude later events + filters=[("icd_code", "==", "250.00")], # optional: attribute conditions + ) + +Each event in the returned list has: + +- ``event.timestamp`` — a Python ``datetime`` object. PyHealth normalises all + timestamp columns (``charttime``, ``admittime``, etc.) into this single + property, so this is what you should use regardless of what the original + column was called. +- ``event.icd_code``, ``event["icd_code"]``, ``event.attr_dict`` — different + ways to access the other attributes. All attribute names are lowercase. + +Things to Watch Out For +------------------------ + +A few patterns that commonly trip up new users: + +**BaseDataset vs SampleDataset.** Models expect a ``SampleDataset`` (the +output of ``set_task()``), not the raw ``BaseDataset``. Passing the wrong one +will raise an error. 
If you see an ``AttributeError`` about ``input_schema`` +or ``output_schema``, this is likely the cause. + +**Timestamp attribute names.** Writing ``event.charttime`` will raise an +``AttributeError`` because PyHealth remaps that column to ``event.timestamp``. +The same applies to ``admittime``, ``starttime``, or whatever the original +column was called. + +**Column name casing.** PyHealth lowercases all column names at load time. +Even if your source CSV has ``ICD_CODE``, you access it as ``event.icd_code``. + +**dev=True in production.** The ``dev`` flag is great for exploring data but +it caps the dataset at 1 000 patients. Remember to switch to ``dev=False`` +before running a full training job. + +**Multiprocessing guard.** Scripts that call ``set_task()`` should wrap their +top-level code in ``if __name__ == '__main__':``. See :doc:`tasks` for details. + Available Datasets ------------------ diff --git a/docs/api/graph.rst b/docs/api/graph.rst index e69de29bb..164214285 100644 --- a/docs/api/graph.rst +++ b/docs/api/graph.rst @@ -0,0 +1,138 @@ +Graph +===== + +The ``pyhealth.graph`` module lets you bring a healthcare knowledge graph into +your PyHealth pipeline. Graph-based models like GraphCare and GNN can use +relational medical knowledge — drug interactions, disease hierarchies, +symptom–diagnosis links — to enrich patient representations beyond what +raw EHR codes alone can capture. + +What Is a Knowledge Graph? +--------------------------- + +A knowledge graph encodes medical relationships as **(head, relation, tail)** +triples. For example: + +- ``("aspirin", "treats", "headache")`` +- ``("metformin", "used_for", "type_2_diabetes")`` +- ``("ICD9:250", "is_a", "ICD9:249")`` + +PyHealth does not ship a built-in graph — you bring triples from a source of +your choice (UMLS, DrugBank, an ICD hierarchy, a custom ontology, etc.) and +the :class:`~pyhealth.graph.KnowledgeGraph` class handles indexing, entity +mappings, and k-hop subgraph extraction. 
The typical use case is querying the
+graph at training time: given a patient's active codes, extract the local
+subgraph around those codes and feed it to a graph-aware model.
+
+Getting Started
+---------------
+
+The simplest way to create a graph is to pass a list of triples directly:
+
+.. code-block:: python
+
+    from pyhealth.graph import KnowledgeGraph
+
+    triples = [
+        ("aspirin", "treats", "headache"),
+        ("headache", "symptom_of", "migraine"),
+        ("ibuprofen", "treats", "headache"),
+    ]
+    kg = KnowledgeGraph(triples=triples)
+    kg.stat()
+    # KnowledgeGraph: 4 entities, 2 relations, 3 triples
+
+For larger graphs it is more practical to load from a CSV or TSV file. The
+file should have columns named ``head``, ``relation``, and ``tail``:
+
+.. code-block:: python
+
+    kg = KnowledgeGraph(triples="path/to/medical_kg.tsv")
+
+Exploring the Graph
+-------------------
+
+Once built, you can inspect the graph and look up neighbours for any entity:
+
+.. code-block:: python
+
+    kg.num_entities           # total unique entities
+    kg.num_relations          # total unique relation types
+    kg.num_triples            # total edges
+
+    kg.has_entity("aspirin")  # True / False
+    kg.neighbors("aspirin")   # list of (relation, tail) pairs
+
+    # Integer ID mappings used internally by PyG
+    kg.entity2id["aspirin"]   # → int
+    kg.id2entity[0]           # → entity name string
+
+Extracting Patient Subgraphs
+------------------------------
+
+The main reason to build a knowledge graph is to extract a patient-specific
+subgraph at training time. ``subgraph()`` returns all entities reachable
+within *n* hops of a set of seed codes, as a PyTorch Geometric ``Data``
+object:
+
+.. code-block:: python
+
+    patient_codes = ["ICD9:250.00", "NDC:0069-0105"]
+    subgraph = kg.subgraph(seed_entities=patient_codes, num_hops=2)
+
+.. note::
+
+    ``subgraph()`` requires `PyTorch Geometric <https://pytorch-geometric.readthedocs.io>`_
+    (``torch_geometric``). The graph can still be constructed and explored
+    without it — only subgraph extraction needs PyG. 
+ + Install with: ``pip install torch-geometric`` + +Using with GraphProcessor in a Task +------------------------------------- + +To feed subgraphs into a model automatically during data loading, pass a +configured :class:`~pyhealth.processors.GraphProcessor` instance in your +task's ``input_schema``. The processor will call ``kg.subgraph()`` for each +patient sample: + +.. code-block:: python + + from pyhealth.graph import KnowledgeGraph + from pyhealth.processors import GraphProcessor + from pyhealth.tasks import BaseTask + + kg = KnowledgeGraph(triples="medical_kg.tsv") + + class MyGraphTask(BaseTask): + task_name = "MyGraphTask" + input_schema = { + "conditions": "sequence", + "kg_subgraph": GraphProcessor(kg, num_hops=2), + } + output_schema = {"label": "binary"} + + def __call__(self, patient): + ... + +Pre-computed Node Embeddings +----------------------------- + +If you already have entity embeddings (e.g. from TransE or an LLM), you can +attach them to the graph at construction time. The model can then use these +as initial node features instead of learning them from scratch: + +.. code-block:: python + + import torch + + node_features = torch.randn(kg.num_entities, 64) # (num_entities, feat_dim) + kg = KnowledgeGraph(triples=triples, node_features=node_features) + +API Reference +------------- + +.. toctree:: + :maxdepth: 3 + + graph/pyhealth.graph.KnowledgeGraph diff --git a/docs/api/models.rst b/docs/api/models.rst index 2e647898b..7368dec94 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -1,9 +1,171 @@ Models =============== -We implement the following models for supporting multiple healthcare predictive tasks. +PyHealth models sit between the :doc:`processors` (which turn raw patient data +into tensors) and the :doc:`trainer` (which runs the training loop). 
Each +model takes a ``SampleDataset`` — the result of ``dataset.set_task()`` — as +its first constructor argument, and uses it to automatically build the right +embedding layers and output head for your task. + +One thing worth knowing up front: the ``SampleDataset`` carries fitted +processor metadata that the model needs to configure itself. If you pass the +raw ``BaseDataset`` instead you'll get an error, because it hasn't been +processed into samples yet. + +Choosing a Model +---------------- + +The table below covers the most commonly used models and when each one fits +best. If your features are a mix of sequential codes and static numeric +vectors, ``MultimodalRNN`` is usually the easiest starting point because it +routes each feature type automatically. + +.. list-table:: + :header-rows: 1 + :widths: 20 40 40 + + * - Model + - Good fit when… + - Notes + * - :doc:`models/pyhealth.models.RNN` + - Your features are sequences of medical codes (diagnoses, procedures, drugs) across visits + - One RNN per feature, hidden states concatenated; ``rnn_type`` can be ``"GRU"`` (default) or ``"LSTM"`` + * - :doc:`models/pyhealth.models.Transformer` + - You have longer code histories and want attention to capture long-range dependencies + - Self-attention across the sequence; tends to work well when visit order matters + * - :doc:`models/pyhealth.models.MLP` + - Features are static numeric vectors (aggregated lab values, demographics) + - Fully connected; no notion of sequence order + * - ``MultimodalRNN`` + - Features mix sequential codes with static tensors or multi-hot encodings + - Auto-routes sequential features to RNN layers and non-sequential features to linear layers; good default for EHR + * - :doc:`models/pyhealth.models.StageNet` + - You have time-stamped vital signs with irregular measurement intervals + - Requires ``StageNetProcessor`` or ``StageNetTensorProcessor`` in the task schema + * - :doc:`models/pyhealth.models.GNN` + - Features include 
graph-structured data + - Works with ``GraphProcessor``; see :doc:`graph` for setup + * - :doc:`models/pyhealth.models.GraphCare` + - You want to augment EHR codes with a medical knowledge graph + - Combines code sequences with a :class:`~pyhealth.graph.KnowledgeGraph` + +How BaseModel Works +-------------------- + +All PyHealth models inherit from ``BaseModel``, which itself inherits from +PyTorch's ``nn.Module``. When you call ``MyModel(dataset=sample_ds)``, the +base class reads the dataset's schemas and automatically sets: + +- ``self.feature_keys`` — the list of input field names from ``input_schema`` +- ``self.label_keys`` — the list of output field names from ``output_schema`` +- ``self.device`` — the compute device + +It also provides three helper methods that take care of the boilerplate that +varies by task type: + +- ``self.get_output_size()`` returns the output dimension from the fitted + label processor, so you don't have to hard-code it. +- ``self.get_loss_function()`` returns the right loss for the task: BCE for + binary and multilabel tasks, cross-entropy for multiclass, MSE for + regression. +- ``self.prepare_y_prob(logits)`` applies sigmoid, softmax, or identity to + logits depending on the task, producing calibrated probabilities. + +The ``forward()`` method is expected to return a dictionary with four keys: +``loss``, ``y_prob``, ``y_true``, and ``logit``. The Trainer reads all four. + +EmbeddingModel +-------------- + +:class:`~pyhealth.models.EmbeddingModel` is a helper that routes each input +feature to the appropriate embedding layer based on how its processor works. +Features from token-based processors (``SequenceProcessor``, +``NestedSequenceProcessor``, and similar) get a learned ``nn.Embedding`` +lookup. Features from continuous processors (``TensorProcessor``, +``TimeseriesProcessor``, ``MultiHotProcessor``) get a linear projection +instead. You end up with a uniform embedding shape across all features: + +.. 
code-block:: python + + self.embedding_model = EmbeddingModel(dataset, embedding_dim=128) + embedded = self.embedding_model(inputs, masks=masks) + # embedded[key] has shape (batch_size, seq_len, embedding_dim) + +Task Mode and Loss Functions +----------------------------- + +PyHealth automatically selects the loss function and output activation based +on the label processor in your task's ``output_schema``: + +.. list-table:: + :header-rows: 1 + :widths: 20 30 30 + + * - Output schema value + - Loss function + - ``y_prob`` shape and activation + * - ``"binary"`` + - BCE with logits + - sigmoid → (batch, 1) + * - ``"multiclass"`` + - Cross-entropy + - softmax → (batch, num_classes) + * - ``"multilabel"`` + - BCE with logits + - sigmoid → (batch, num_labels) + * - ``"regression"`` + - MSE + - identity → (batch, 1) + +Building a Custom Model +----------------------- + +If none of the built-in models fit your architecture, you can subclass +``BaseModel`` directly. The skeleton below shows the typical structure: build +an ``EmbeddingModel`` in ``__init__``, unpack processor schemas in +``forward``, pool or aggregate the embeddings, and return the four-key dict. + +.. 
code-block:: python + + from pyhealth.models import BaseModel + from pyhealth.models.embedding import EmbeddingModel + import torch + import torch.nn as nn + + class MyModel(BaseModel): + def __init__(self, dataset, embedding_dim=128): + super().__init__(dataset=dataset) + self.label_key = self.label_keys[0] + self.embedding_model = EmbeddingModel(dataset, embedding_dim) + self.fc = nn.Linear(embedding_dim * len(self.feature_keys), + self.get_output_size()) + + def forward(self, **kwargs): + inputs, masks = {}, {} + for key in self.feature_keys: + feature = kwargs[key] + if isinstance(feature, torch.Tensor): + feature = (feature,) + schema = self.dataset.input_processors[key].schema() + inputs[key] = feature[schema.index("value")] + if "mask" in schema: + masks[key] = feature[schema.index("mask")] + + embedded = self.embedding_model(inputs, masks=masks) + pooled = [embedded[k].mean(dim=1) for k in self.feature_keys] + logits = self.fc(torch.cat(pooled, dim=1)) + + y_true = kwargs[self.label_key].to(self.device) + return { + "loss": self.get_loss_function()(logits, y_true), + "y_prob": self.prepare_y_prob(logits), + "y_true": y_true, + "logit": logits, + } + +API Reference +------------- - .. toctree:: :maxdepth: 3 @@ -41,4 +203,4 @@ We implement the following models for supporting multiple healthcare predictive models/pyhealth.models.VisionEmbeddingModel models/pyhealth.models.TextEmbedding models/pyhealth.models.BIOT - models/pyhealth.models.unified_multimodal_embedding_docs \ No newline at end of file + models/pyhealth.models.unified_multimodal_embedding_docs diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 3ed1e1c97..d85d04bc3 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -67,6 +67,138 @@ After you define a task: - Discover how to customize processor behavior with kwargs tuples - Understand processor types for different data modalities (text, images, signals, etc.) 
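
At its core, a task is just a callable that maps one patient record to a list
of sample dictionaries, and a patient that yields no samples simply
contributes nothing. Before the PyHealth-specific version below, here is that
contract sketched in plain Python (a schematic using made-up dict-based
records and a hypothetical ``toy_mortality_task`` function, not the PyHealth
API):

```python
# Schematic of the task contract: one patient in, a list of ML samples out.
# The dict-based "patient" below is a stand-in, not a PyHealth object.
def toy_mortality_task(patient):
    samples = []
    for adm in patient.get("admissions", []):
        # Use only diagnoses recorded before this admission.
        codes = [d["code"] for d in patient.get("diagnoses", [])
                 if d["time"] < adm["time"]]
        if not codes:
            continue  # a visit with no usable history yields no sample
        samples.append({"conditions": [codes], "label": int(adm["died"])})
    return samples

patient = {
    "admissions": [{"time": 2, "died": True}],
    "diagnoses": [{"time": 1, "code": "250.00"}],
}
print(toy_mortality_task(patient))
# [{'conditions': [['250.00']], 'label': 1}]
```

PyHealth's ``BaseTask.__call__`` follows the same shape, with a real
``Patient`` object and processor-backed schemas in place of the toy dicts.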
+Writing a Custom Task +---------------------- + +When a built-in task doesn't match your cohort or prediction target, you can +define your own by subclassing :class:`~pyhealth.tasks.BaseTask`. The class +needs three things: a name, input and output schemas, and a ``__call__`` +method that processes one patient at a time. + +.. code-block:: python + + from pyhealth.tasks import BaseTask + from pyhealth.data import Patient + from typing import List, Dict, Any + + class MyMortalityTask(BaseTask): + task_name: str = "MyMortalityTask" + + input_schema: Dict[str, str] = { + "conditions": "sequence", # maps to SequenceProcessor + "procedures": "sequence", + } + output_schema: Dict[str, str] = { + "label": "binary" # maps to BinaryLabelProcessor + } + + def __call__(self, patient: Patient) -> List[Dict[str, Any]]: + samples = [] + for adm in patient.get_events("admissions"): + label = 1 if adm.hospital_expire_flag == "1" else 0 + + # Fetch historical diagnoses up to this admission + conditions = patient.get_events("diagnoses_icd", end=adm.timestamp) + cond_codes = [e.icd_code for e in conditions] + + if not cond_codes: + continue + + samples.append({ + "conditions": [cond_codes], # wrapped in a list for the sequence processor + "procedures": [[]], + "label": label, + }) + return samples + +The ``__call__`` method receives one ``Patient`` and returns a list of sample +dictionaries. Each dictionary's keys should match the schemas you declared. +Returning an empty list is fine — PyHealth simply skips that patient. Note +that event attribute names are always lowercase (e.g. ``e.icd_code`` rather +than ``e.ICD_CODE``) because PyHealth lowercases all column names at ingest +time. Timestamps are accessed through ``event.timestamp`` rather than the +original column name like ``charttime``, since PyHealth normalises them into +a single property. + +Processor String Keys +---------------------- + +The string values in your schemas map to specific processor classes. 
Here is +a quick reference: + +.. list-table:: + :header-rows: 1 + :widths: 25 35 40 + + * - String key + - Processor + - Typical use + * - ``"sequence"`` + - ``SequenceProcessor`` + - Diagnosis codes, procedure codes, drug names + * - ``"nested_sequence"`` + - ``NestedSequenceProcessor`` + - Cumulative visit history (drug recommendation, readmission) + * - ``"tensor"`` + - ``TensorProcessor`` + - Aggregated numeric values (e.g. last lab value per item) + * - ``"timeseries"`` + - ``TimeseriesProcessor`` + - Irregular time-series measurements + * - ``"multi_hot"`` + - ``MultiHotProcessor`` + - Demographics, comorbidity flags + * - ``"text"`` + - ``TextProcessor`` + - Clinical notes + * - ``"binary"`` + - ``BinaryLabelProcessor`` + - Binary classification label (0 / 1) + * - ``"multiclass"`` + - ``MultiClassLabelProcessor`` + - Multi-class label + * - ``"multilabel"`` + - ``MultiLabelProcessor`` + - Multi-label classification + * - ``"regression"`` + - ``RegressionLabelProcessor`` + - Continuous regression target + +How set_task() Works +--------------------- + +Calling ``dataset.set_task(task)`` iterates over every patient in the +dataset, runs your ``__call__`` method on each one, fits all the processors +on the collected samples, then serialises everything to disk as LitData +``.ld`` files. The result is a :class:`~pyhealth.datasets.SampleDataset` that +supports ``len()`` and index access, ready for a DataLoader. + +.. code-block:: python + + sample_ds = dataset.set_task(MyMortalityTask(), num_workers=4) + len(sample_ds) # total ML samples across all patients + sample_ds[0] # a single sample dict with tensor values + +If you re-run ``set_task()`` with the same task and processor configuration, +PyHealth detects the existing cache and skips reprocessing. During +development it is useful to set ``dev=True`` on the dataset, which limits +processing to 1 000 patients so iterations are fast. + +.. 
note::
+
+    **A note on multiprocessing.** ``set_task()`` can spawn worker processes
+    when ``num_workers > 1``. On Windows and macOS, where Python's default
+    multiprocessing start method is *spawn*, this requires the standard
+    guard around your top-level script:
+
+    .. code-block:: python
+
+        if __name__ == '__main__':
+            sample_ds = dataset.set_task(task, num_workers=4)
+
+    Without this guard, each spawned worker re-imports and re-executes the
+    script; Python detects this and aborts with a ``RuntimeError`` rather
+    than creating processes endlessly. This is a general Python
+    multiprocessing requirement, not specific to PyHealth.
+
 Available Tasks
 ---------------

diff --git a/docs/api/trainer.rst b/docs/api/trainer.rst
index 039f22dc4..0a52af0f7 100644
--- a/docs/api/trainer.rst
+++ b/docs/api/trainer.rst
@@ -1,7 +1,102 @@
 Trainer
-===================================
+=======
+
+:class:`~pyhealth.trainer.Trainer` handles the PyTorch training loop for you.
+Rather than writing your own epoch loop, loss backward pass, optimizer step,
+and metric evaluation, you hand the Trainer your model and data loaders and
+let it manage the details — including early stopping when validation
+performance plateaus and automatic reloading of the best checkpoint at the end.
+
+A Typical Training Run
+-----------------------
+
+Here is what a full training setup looks like. The data loaders come from
+``get_dataloader()`` in :mod:`pyhealth.datasets`, which knows how to work with
+PyHealth's LitData caching format:
+
+.. 
code-block:: python + + from pyhealth.trainer import Trainer + from pyhealth.datasets import get_dataloader + + train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False) + test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False) + + trainer = Trainer( + model=model, + metrics=["roc_auc_macro", "pr_auc_macro", "f1_macro"], + device="cuda", + ) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + test_dataloader=test_loader, + epochs=50, + monitor="roc_auc_macro", + monitor_criterion="max", + patience=10, + ) + + scores = trainer.evaluate(test_loader) + # {'roc_auc_macro': 0.85, 'pr_auc_macro': 0.79, 'f1_macro': 0.72, 'loss': 0.31} + +Setting Up the Trainer +----------------------- + +``Trainer(model, metrics=None, device=None, enable_logging=True, output_path=None, exp_name=None)`` + +- **model** — your instantiated PyHealth model. +- **metrics** — the metric names you want computed at validation and test time + (e.g. ``["roc_auc_macro", "f1_macro"]``). See :doc:`metrics` for the full + list of supported strings. +- **device** — ``"cuda"`` or ``"cpu"``; defaults to auto-detecting a GPU. +- **enable_logging** — when enabled, the Trainer creates a timestamped folder + under ``output_path`` with a ``log.txt`` and model checkpoints. +- **output_path** / **exp_name** — where and how to name the output folder. + +Controlling the Training Loop +------------------------------ + +``trainer.train()`` accepts these key arguments beyond the data loaders: + +- **epochs** — the maximum number of training epochs. +- **optimizer_class** / **optimizer_params** — which optimizer to use and how + to configure it. Defaults to ``Adam`` with a learning rate of ``1e-3``. +- **weight_decay** — L2 regularisation strength. Default ``0.0``. 
+- **max_grad_norm** — if set, clips gradients to this norm before each update, + which can help stabilise training on noisy medical data. +- **monitor** / **monitor_criterion** — the metric to watch on the validation + set (e.g. ``"roc_auc_macro"``) and whether higher is better (``"max"``) or + lower is better (``"min"``). The Trainer saves a checkpoint whenever this + metric improves. +- **patience** — how many epochs without improvement to wait before stopping + early. +- **load_best_model_at_last** — when ``True`` (the default), the Trainer + restores the best checkpoint at the end of training rather than keeping the + weights from the final epoch. + +Getting the Test Scores +------------------------ + +``trainer.train()`` prints test scores to the console when a +``test_dataloader`` is provided, but it does not return them as a Python +object. To capture results for downstream use, call ``evaluate()`` separately: + +.. code-block:: python + + scores = trainer.evaluate(test_loader) + # scores is a plain dict, e.g. {'roc_auc_macro': 0.85, 'loss': 0.31} + + import json + with open("results.json", "w") as f: + json.dump(scores, f, indent=2) + +API Reference +------------- .. autoclass:: pyhealth.trainer.Trainer :members: :undoc-members: - :show-inheritance: \ No newline at end of file + :show-inheritance: From 19090a2177c66321fcc01ef34a8eafdfbe560682 Mon Sep 17 00:00:00 2001 From: jhnwu3 Date: Thu, 12 Mar 2026 19:17:16 -0500 Subject: [PATCH 2/4] index --- docs/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index 543676952..d2f6b69e9 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,7 +12,7 @@ Build, test, and deploy healthcare machine learning models with ease. 
PyHealth i **Key Features** - **Dramatically simpler**: Build any healthcare AI model in ~7 lines of code -- **Blazing fast**: Up to 39× faster than pandas +- **Blazing fast**: Up to 39× faster than pandas for task processing - **Memory efficient**: Runs on 16GB laptops - **True multimodal**: Unified API for EHR, medical images, biosignals, clinical text, and genomics - **Production-ready**: 25+ pre-built models, 20+ tasks, 12+ datasets with comprehensive evaluation tools From fa390b4926f61eccae232e01980e46c65d7be7fc Mon Sep 17 00:00:00 2001 From: jhnwu3 Date: Thu, 12 Mar 2026 19:41:42 -0500 Subject: [PATCH 3/4] overview page added --- docs/api/datasets.rst | 56 +++++++++- docs/api/overview.rst | 240 ++++++++++++++++++++++++++++++++++++++++ docs/api/processors.rst | 6 + docs/index.rst | 1 + 4 files changed, 300 insertions(+), 3 deletions(-) create mode 100644 docs/api/overview.rst diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 24913c69f..b02439d26 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -10,7 +10,7 @@ New to PyHealth datasets? Start here: This tutorial covers: -- How to load and work with different healthcare datasets (MIMIC-III, MIMIC-IV, eICU, etc.) +- How to load and work with any PyHealth dataset (MIMIC-III, MIMIC-IV, eICU, OMOP, and many more) - Understanding the ``BaseDataset`` structure and patient representation - Parsing raw EHR data into standardized PyHealth format - Accessing patient records, visits, and clinical events @@ -39,11 +39,61 @@ can train on, you call ``dataset.set_task()`` (see :doc:`tasks`), which returns a :class:`~pyhealth.datasets.SampleDataset` that *is* indexable and DataLoader-ready. +From BaseDataset to SampleDataset +----------------------------------- + +``BaseDataset`` and ``SampleDataset`` serve different roles and are not +interchangeable: + +- **BaseDataset** is a queryable patient registry. It holds the raw + patient→visit→event tree loaded from disk. 
You cannot index into it like a
+  list — it has no integer length and is not DataLoader-ready.
+- **SampleDataset** is a PyTorch-compatible streaming dataset returned by
+  ``dataset.set_task()``. Each element is a fully processed feature
+  dictionary that a model can consume directly.
+
+The conversion happens in one call:
+
+.. code-block:: python
+
+    import torch
+    from pyhealth.datasets import MIMIC3Dataset
+    from pyhealth.tasks import MortalityPredictionMIMIC3
+
+    dataset = MIMIC3Dataset(root="...", tables=["diagnoses_icd"])
+    samples = dataset.set_task(MortalityPredictionMIMIC3())
+    # `samples` is a SampleDataset — pass it straight to a DataLoader
+    loader = torch.utils.data.DataLoader(samples, batch_size=32)
+
+Under the hood, ``set_task()`` runs a ``SampleBuilder`` that fits feature
+processors (tokenisers, label encoders, etc.) across the full dataset, then
+writes compressed, chunked sample files to disk via
+`litdata <https://github.com/Lightning-AI/litdata>`_. A companion
+``schema.pkl`` stores the fitted processors so the dataset can be reloaded in
+future runs without re-fitting.
+
+``SampleDataset`` also exposes two convenience lookups built during fitting:
+
+- ``samples.patient_to_index`` — maps a patient ID to all sample indices for
+  that patient.
+- ``samples.record_to_index`` — maps a visit/record ID to the sample indices
+  for that visit.
+
+For testing or small cohorts you can skip the disk step entirely using
+``InMemorySampleDataset``, which holds all processed samples in RAM and is
+returned by default from ``create_sample_dataset()``.
+
+.. note::
+    Building a custom dataset or bringing your own data?
+    See :doc:`../tutorials` (Tutorial 1) for a step-by-step walkthrough, and
+    the `config.yaml for Custom Datasets`_ section below for the schema format.
+
 Native Datasets vs Custom Datasets
 ------------------------------------

-PyHealth includes native support for several standard EHR databases — MIMIC-III,
-MIMIC-IV, eICU, and OMOP. 
These come with built-in schema definitions so you +PyHealth includes native support for many standard healthcare databases — including +MIMIC-III, MIMIC-IV, eICU, OMOP, and many others (see the full list in `Available Datasets`_ +below). All of these come with built-in schema definitions so you can load them with just a root path and a list of tables: .. code-block:: python diff --git a/docs/api/overview.rst b/docs/api/overview.rst new file mode 100644 index 000000000..532eeb48b --- /dev/null +++ b/docs/api/overview.rst @@ -0,0 +1,240 @@ +PyHealth Architecture Overview +============================== + +This page describes how all PyHealth components connect, from raw data files +to a trained, evaluated model. Every stage has its own dedicated reference +page — this overview is here to show how they fit together. + +Pipeline at a Glance +--------------------- + +.. code-block:: text + + Raw CSV / Parquet files + │ + ▼ + config.yaml (table schemas, patient_id col, timestamp col, attributes) + │ + ▼ + BaseDataset subclass ──── loads tables, caches as global_event_df.parquet + │ .unique_patient_ids → List[str] + │ .get_patient(id) → Patient + │ .iter_patients() → Iterator[Patient] + │ .stats() → prints patient/event counts + │ + ▼ + BaseTask subclass (__call__(patient) → List[Dict]) + │ .input_schema = {"feature": "processor_name", ...} + │ .output_schema = {"label": "binary" | "multiclass" | ...} + │ + dataset.set_task(task, num_workers=N) + │ + ▼ + SampleDataset ──── len(ds), ds[i], patient_to_index, record_to_index + │ Backed by LitData streaming files + │ Processors fitted during set_task, applied at load time + │ + get_dataloader(dataset, batch_size=32, shuffle=True) + │ + ▼ + Model(dataset, ...) ──── BaseModel subclass (RNN, Transformer, MLP, …) + │ EmbeddingModel routes features via processor.is_token() + │ forward(**batch) → {"loss", "y_prob", "y_true", "logit"} + │ + ▼ + Trainer(model, metrics=[...], device=...) 
+          │   .train(train_dl, val_dl, test_dl, epochs=20, ...)
+          │   .evaluate(test_dl) → Dict[metric_name, value]
+          │
+          ├──▶ Calibration (pyhealth.calib)
+          │      TemperatureScaling / HistogramBinning / KCal / …
+          │      LABEL / SCRIB / FavMac / …  (conformal prediction sets)
+          │
+          └──▶ Interpretability (pyhealth.interpret)
+                 GradientSaliency / IntegratedGradients / DeepLift / SHAP / LIME / …
+
+
+Stage 1: Raw Data → BaseDataset
+---------------------------------
+
+See :doc:`datasets` for the full reference.
+
+PyHealth reads raw CSV or Parquet files using Polars, joins tables according
+to a ``config.yaml`` schema, and writes a compact
+``global_event_df.parquet`` cache. On subsequent runs with the same
+configuration it reads from the cache rather than re-parsing source files.
+
+**Native datasets** (MIMIC-III, MIMIC-IV, eICU, OMOP, and many others) have
+built-in schemas — just pass a ``root`` path and a list of ``tables``:
+
+.. code-block:: python
+
+    from pyhealth.datasets import MIMIC3Dataset
+
+    if __name__ == '__main__':
+        dataset = MIMIC3Dataset(
+            root="/data/mimiciii/1.4",
+            tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
+            cache_dir=".cache",
+            dev=True,  # cap at 1 000 patients during exploration
+        )
+
+**Custom datasets** subclass ``BaseDataset`` and provide a ``config_path``
+pointing to your own ``config.yaml``. If files need preprocessing (e.g.
+merging multiple CSVs), define ``preprocess_(self, df)`` on the
+subclass — it receives a narwhals LazyFrame and must return one.
+
+Key init params: ``root``, ``tables``, ``config_path`` (custom only),
+``cache_dir``, ``num_workers``, ``dev``.
+
+A UUID derived from ``(root, tables, dataset_name, dev)`` is appended to
+``cache_dir``, so different configurations never overwrite each other.
+
+
+Stage 2: Patient and Event Objects
+------------------------------------
+
+See :doc:`data` for the full reference.
+
+Once a dataset is loaded, ``Patient.get_events()`` is the primary query
+method:
+
+.. code-block:: python
+
+    from datetime import datetime
+
+    events = patient.get_events(
+        event_type="diagnoses_icd",  # must match the table name in config.yaml
+        start=datetime(2020, 1, 1),
+        end=datetime(2020, 6, 1),
+        filters=[("icd_code", "==", "250.00")],
+    )
+
+``Event`` attributes to keep in mind:
+
+- ``event.timestamp`` — always use this; PyHealth normalises ``charttime``,
+  ``admittime``, etc. into a single property.
+- ``event.attr_dict`` / ``event["col_name"]`` / ``event.col_name`` — access
+  attribute values. All column names are **lowercased** at ingest time.
+
+
+Stage 3: Task Definition → set_task
+--------------------------------------
+
+See :doc:`tasks` for the full reference.
+
+A ``BaseTask`` subclass defines three things:
+
+- ``task_name: str`` — must be assigned (not just annotated).
+- ``input_schema`` / ``output_schema`` — dicts mapping sample keys to
+  processor string aliases (e.g. ``"sequence"``, ``"binary"``).
+- ``__call__(self, patient) → List[Dict]`` — extracts features from one
+  ``Patient``; return ``[]`` to skip a patient.
+
+``dataset.set_task(task, num_workers=N)`` iterates all patients, collects
+samples, fits processors, and writes LitData ``.ld`` streaming files to disk.
+The result is a ``SampleDataset``.
+
+.. important::
+
+   All code calling ``set_task()`` must live inside
+   ``if __name__ == '__main__':``. PyHealth uses multiprocessing internally
+   and will crash without this guard.
+
+
+Stage 4: Processors → SampleDataset
+--------------------------------------
+
+See :doc:`processors` for the full reference.
+
+When ``set_task()`` runs:
+
+1. ``SampleBuilder.fit(samples)`` — calls ``processor.fit(samples, field)``
+   for every schema field.
+2. ``SampleBuilder.transform(sample)`` — calls ``processor.process(value)``
+   for every field, writing tensors to disk.
+
+The key signal for model routing is ``processor.is_token()``:
+
+- ``True``  → ``nn.Embedding``  (discrete token indices, e.g. medical codes)
+- ``False`` → ``nn.Linear``  (continuous values, e.g. time series, images)
+
+
+Stage 5: Model Initialization and Forward Pass
+------------------------------------------------
+
+See :doc:`models` for the full reference.
+
+.. code-block:: python
+
+    from pyhealth.models import RNN
+
+    model = RNN(dataset=sample_dataset, embedding_dim=128, hidden_dim=64)
+
+The model reads ``dataset.input_schema``, ``dataset.output_schema``, and
+``dataset.input_processors`` to auto-build embedding layers and the output
+head. **Important:** always pass the ``SampleDataset`` (the result of
+``set_task()``), not the raw ``BaseDataset``.
+
+The forward pass is ``model(**batch)``, where ``batch`` is a dict yielded by
+the DataLoader; it must return ``{"loss", "y_prob", "y_true", "logit"}``.
+
+**Choosing a model:**
+
+- Mixed sequential + static features → ``MultimodalRNN``
+- Purely sequential codes → ``RNN`` or ``Transformer``
+- Static feature vector → ``MLP``
+- Time-stamped vitals with irregular intervals → ``StageNet``
+- Graph-structured features → ``GNN`` or ``GraphCare`` (see :doc:`graph`)
+
+
+Stage 6: Training and Evaluation
+----------------------------------
+
+See :doc:`trainer` for the full reference.
+
+.. code-block:: python
+
+    from pyhealth.trainer import Trainer
+    from pyhealth.datasets import get_dataloader
+
+    train_dl = get_dataloader(train_ds, batch_size=32, shuffle=True)
+    val_dl = get_dataloader(val_ds, batch_size=32, shuffle=False)
+    test_dl = get_dataloader(test_ds, batch_size=32, shuffle=False)
+
+    trainer = Trainer(model=model, metrics=["roc_auc_macro", "f1_macro"], device="cuda")
+    trainer.train(train_dl, val_dl, test_dl, epochs=20,
+                  monitor="roc_auc_macro", monitor_criterion="max", patience=5)
+
+    scores = trainer.evaluate(test_dl)
+    # → {"roc_auc_macro": 0.85, "loss": 0.3, ...}
+
+Split by patient to avoid data leakage:
+
+.. code-block:: python
+
+    all_ids = list(sample_dataset.patient_to_index.keys())
+    # ... split all_ids into train_ids / val_ids / test_ids ...
+    # (one illustrative option: shuffle once with a fixed seed, slice 70/10/20)
+    import random
+    random.Random(42).shuffle(all_ids)
+    n = len(all_ids)
+    train_ids = all_ids[:int(0.7 * n)]
+    val_ids = all_ids[int(0.7 * n):int(0.8 * n)]
+    test_ids = all_ids[int(0.8 * n):]
+ train_indices = [i for pid in train_ids for i in sample_dataset.patient_to_index[pid]] + train_ds = sample_dataset.subset(train_indices) + + +Common Pitfalls +---------------- + +.. list-table:: + :header-rows: 1 + :widths: 45 55 + + * - Mistake + - Fix + * - Missing ``if __name__ == '__main__':`` + - Wrap all ``set_task()`` / dataset loading code in this guard + * - ``event.charttime`` instead of ``event.timestamp`` + - Always use ``event.timestamp`` + * - Task sample key doesn't match ``input_schema`` + - Keys in ``__call__`` return dict must exactly match schema keys + * - ``dev=True`` during full training + - Only use ``dev=True`` during exploration; set ``dev=False`` for final runs + * - Passing ``BaseDataset`` to the model + - Pass ``SampleDataset`` (result of ``set_task()``) to the model + * - ``dataset.patients`` + - Does not exist; use ``dataset.unique_patient_ids`` + ``dataset.get_patient(id)`` diff --git a/docs/api/processors.rst b/docs/api/processors.rst index 3a9fb73de..a06e3c955 100644 --- a/docs/api/processors.rst +++ b/docs/api/processors.rst @@ -3,6 +3,12 @@ Processors Processors in PyHealth handle data preprocessing and transformation for healthcare predictive tasks. They convert raw data into tensors suitable for machine learning models. +Processors sit between :doc:`tasks` (which define *what* data to extract) and +:doc:`models` (which consume the resulting tensors). You rarely call processors +directly — they are configured through the ``input_schema`` and +``output_schema`` of a task and applied automatically during +``dataset.set_task()``. 
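To make the schema-driven behaviour above concrete, here is a minimal sketch of the processor contract. ``ToySequenceProcessor`` is *not* a PyHealth class; it is an illustrative stand-in that mimics the ``fit(samples, field)`` / ``process(value)`` / ``is_token()`` interface, with the vocabulary names and sample codes invented for the example:

```python
# Illustrative sketch only: ToySequenceProcessor is NOT part of PyHealth.
# It mimics the processor contract: fit() once over all samples, then
# process() per value, with is_token() telling the model how to route it.

class ToySequenceProcessor:
    def __init__(self):
        # Reserve index 0 for padding and 1 for out-of-vocabulary codes.
        self.vocab = {"<pad>": 0, "<unk>": 1}

    def fit(self, samples, field):
        # Build the vocabulary from every code observed in `field`.
        for sample in samples:
            for code in sample[field]:
                self.vocab.setdefault(code, len(self.vocab))

    def process(self, value):
        # Map codes to integer indices; unseen codes fall back to <unk>.
        return [self.vocab.get(code, self.vocab["<unk>"]) for code in value]

    def is_token(self):
        # True → discrete indices, routed to an embedding-style layer.
        return True


samples = [
    {"conditions": ["E11.9", "I10"]},
    {"conditions": ["I10", "J45"]},
]
proc = ToySequenceProcessor()
proc.fit(samples, "conditions")
print(proc.process(["I10", "E11.9", "C34"]))  # → [3, 2, 1] ("C34" unseen → <unk>)
```

In real use you never instantiate a processor by hand like this; the task's ``input_schema`` names a processor by its string alias, and ``set_task()`` performs the fit and transform steps for you.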
+
 Overview
 --------
 
diff --git a/docs/index.rst b/docs/index.rst
index d2f6b69e9..37ec2e705 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -221,6 +221,7 @@ Quick Navigation
    :hidden:
    :caption: Documentation
 
+   api/overview
    api/data
    api/datasets
    api/graph

From 80c0b38f0f548c612e72092e59a75d77e6513863 Mon Sep 17 00:00:00 2001
From: jhnwu3
Date: Fri, 13 Mar 2026 14:42:56 -0500
Subject: [PATCH 4/4] clean up and fix old details

---
 docs/about.rst | 10 +++++-----
 pyproject.toml |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/docs/about.rst b/docs/about.rst
index 8e28fa929..d758eaf0b 100644
--- a/docs/about.rst
+++ b/docs/about.rst
@@ -3,7 +3,7 @@ About us
 
 PyHealth is developed and maintained by a diverse community of researchers and practitioners.
 
-Current Maintainers
+Maintainers
 ------------------
 
 `Zhenbang Wu `_ (Ph.D. Student @ University of Illinois Urbana-Champaign)
@@ -12,11 +12,11 @@
 
 `Junyi Gao `_ (M.S. @ UIUC, Ph.D. Student @ University of Edinburgh)
 
-Paul Landes (University of Illinois College of Medicine)
+`Paul Landes `_ (University of Illinois College of Medicine)
 
 `Jimeng Sun `_ (Professor @ University of Illinois Urbana-Champaign)
 
-Major Reviewers
+Reviewers
 ---------------
 
 Eric Schrock (University of Illinois Urbana-Champaign)
@@ -52,7 +52,7 @@ Muni Bondu
 
 *...and more members as the initiative continues to expand*
 
-Alumni
-------
+Past Contributors
+-----------------
 
 `Chaoqi Yang `_ (Ph.D. Student @ University of Illinois Urbana-Champaign)

diff --git a/pyproject.toml b/pyproject.toml
index 308e6b114..c1e0ebcc6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ authors = [
     {name = "Zhen Lin"},
     {name = "Benjamin Danek"},
     {name = "Junyi Gao"},
-    {name = "Paul Landes", email = "landes@mailc.net"},
+    {name = "Paul Landes", email = "plande2@uic.edu"},
     {name = "Jimeng Sun"},
 ]
 description = "A Python library for healthcare AI"