diff --git a/docs/about.rst b/docs/about.rst
index 8e28fa929..d758eaf0b 100644
--- a/docs/about.rst
+++ b/docs/about.rst
@@ -3,7 +3,7 @@ About us
PyHealth is developed and maintained by a diverse community of researchers and practitioners.
-Current Maintainers
+Maintainers
------------------
`Zhenbang Wu `_ (Ph.D. Student @ University of Illinois Urbana-Champaign)
@@ -12,11 +12,11 @@ Current Maintainers
`Junyi Gao `_ (M.S. @ UIUC, Ph.D. Student @ University of Edinburgh)
-Paul Landes (University of Illinois College of Medicine)
+`Paul Landes `_ (University of Illinois College of Medicine)
`Jimeng Sun `_ (Professor @ University of Illinois Urbana-Champaign)
-Major Reviewers
+Reviewers
---------------
Eric Schrock (University of Illinois Urbana-Champaign)
@@ -52,7 +52,7 @@ Muni Bondu
*...and more members as the initiative continues to expand*
-Alumni
-------
+Past Contributors
+-----------------
`Chaoqi Yang `_ (Ph.D. Student @ University of Illinois Urbana-Champaign)
diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 3412e5ac5..b02439d26 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -10,7 +10,7 @@ New to PyHealth datasets? Start here:
This tutorial covers:
-- How to load and work with different healthcare datasets (MIMIC-III, MIMIC-IV, eICU, etc.)
+- How to load and work with any PyHealth dataset (MIMIC-III, MIMIC-IV, eICU, OMOP, and many more)
- Understanding the ``BaseDataset`` structure and patient representation
- Parsing raw EHR data into standardized PyHealth format
- Accessing patient records, visits, and clinical events
@@ -22,6 +22,198 @@ This tutorial covers:
- Using openly available demo datasets (MIMIC-III Demo, MIMIC-IV Demo)
- Working with synthetic data for testing
+How PyHealth Loads Data
+------------------------
+
+When you initialise a dataset, PyHealth reads the raw CSV or Parquet files
+using Polars, joins the tables according to a YAML schema, and writes a
+compact ``global_event_df.parquet`` cache to disk. On subsequent runs with
+the same configuration it reads from cache rather than re-parsing the source
+files, so startup is fast.
+
+The result is a :class:`~pyhealth.datasets.BaseDataset` — a structured
+patient→event tree. It is different from a PyTorch Dataset: it has no integer
+length and you cannot index into it with ``dataset[i]``. Think of it as a
+queryable dictionary of patient records. To turn it into something a model
+can train on, you call ``dataset.set_task()`` (see :doc:`tasks`), which
+returns a :class:`~pyhealth.datasets.SampleDataset` that *is* indexable and
+DataLoader-ready.
+
+From BaseDataset to SampleDataset
+-----------------------------------
+
+``BaseDataset`` and ``SampleDataset`` serve different roles and are not
+interchangeable:
+
+- **BaseDataset** is a queryable patient registry. It holds the raw
+ patient→visit→event tree loaded from disk. You cannot index into it like a
+ list — it has no integer length and is not DataLoader-ready.
+- **SampleDataset** is a PyTorch-compatible streaming dataset returned by
+ ``dataset.set_task()``. Each element is a fully processed feature
+ dictionary that a model can consume directly.
+
+The conversion happens in one call:
+
+.. code-block:: python
+
+ import torch
+ from pyhealth.datasets import MIMIC3Dataset
+ from pyhealth.tasks import MortalityPredictionMIMIC3
+
+ dataset = MIMIC3Dataset(root="...", tables=["diagnoses_icd"])
+ samples = dataset.set_task(MortalityPredictionMIMIC3())
+ # `samples` is a SampleDataset — pass it straight to a DataLoader
+ loader = torch.utils.data.DataLoader(samples, batch_size=32)
+
+Under the hood, ``set_task()`` runs a ``SampleBuilder`` that fits feature
+processors (tokenisers, label encoders, etc.) across the full dataset, then
+writes compressed, chunked sample files to disk via
+`litdata `_. A companion
+``schema.pkl`` stores the fitted processors so the dataset can be reloaded in
+future runs without re-fitting.
+
+``SampleDataset`` also exposes two convenience lookups built during fitting:
+
+- ``samples.patient_to_index`` — maps a patient ID to all sample indices for
+ that patient.
+- ``samples.record_to_index`` — maps a visit/record ID to the sample indices
+ for that visit.
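These lookups make patient-level splits easy to express. A minimal sketch, using a toy ``patient_to_index`` dict in place of a real ``SampleDataset`` (the real attribute has the same shape):

```python
import random

# Toy stand-in for samples.patient_to_index (patient ID -> sample indices).
patient_to_index = {"p1": [0, 1], "p2": [2], "p3": [3, 4, 5]}

patient_ids = sorted(patient_to_index)
random.Random(0).shuffle(patient_ids)
n_train = int(0.8 * len(patient_ids))  # 2 of 3 patients go to training
train_ids, test_ids = patient_ids[:n_train], patient_ids[n_train:]

# Every sample from a given patient lands on the same side of the split,
# which avoids leaking one patient's data across train and test.
train_idx = [i for pid in train_ids for i in patient_to_index[pid]]
test_idx = [i for pid in test_ids for i in patient_to_index[pid]]
```

Splitting by patient ID (rather than by sample index) is the usual way to prevent leakage when one patient contributes multiple samples.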
+
+For testing or small cohorts you can skip the disk step entirely using
+``InMemorySampleDataset``, which holds all processed samples in RAM and is
+returned by default from ``create_sample_dataset()``.
+
+.. note::
+ Building a custom dataset or bringing your own data?
+ See :doc:`../tutorials` (Tutorial 1) for a step-by-step walkthrough, and
+ the `config.yaml for Custom Datasets`_ section below for the schema format.
+
+Native Datasets vs Custom Datasets
+------------------------------------
+
+PyHealth includes native support for many standard healthcare databases — including
+MIMIC-III, MIMIC-IV, eICU, OMOP, and many others (see the full list in `Available Datasets`_
+below). All of these come with built-in schema definitions so you
+can load them with just a root path and a list of tables:
+
+.. code-block:: python
+
+ from pyhealth.datasets import MIMIC3Dataset
+
+ if __name__ == '__main__':
+ dataset = MIMIC3Dataset(
+ root="/data/physionet.org/files/mimiciii/1.4",
+ tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
+ cache_dir=".cache",
+ dev=True, # use 1 000 patients while exploring
+ )
+
+For any other data source — a custom patient registry, an institutional cohort,
+or a non-EHR dataset — you create a subclass of ``BaseDataset`` and provide a
+``config.yaml`` file that describes your table structure.
+
+Initialization Parameters
+--------------------------
+
+- **root** — path to the directory containing the raw data files. For MIMIC-IV
+ specifically, use ``ehr_root`` instead of ``root``.
+- **tables** — the table names you want to load, e.g.
+ ``["diagnoses_icd", "labevents"]``. Only these tables will be accessible in
+ patient queries downstream.
+- **config_path** — path to your ``config.yaml``; needed for custom datasets.
+ Native datasets have this built in and ignore the parameter.
+- **cache_dir** — where to store the cached Parquet and LitData files. PyHealth
+ appends a UUID derived from your configuration, so different setups never
+ overwrite each other.
+- **num_workers** — parallel processes for data loading. Increasing this can
+ speed up ``set_task()`` on large datasets.
+- **dev** — when ``True``, PyHealth caps the dataset at 1 000 patients. This
+ is very useful during development because it makes each iteration complete in
+ seconds rather than minutes. Switch to ``dev=False`` for your final training
+ run.
+
+config.yaml for Custom Datasets
+---------------------------------
+
+If you are bringing your own data, the YAML file tells PyHealth which column
+is the patient identifier, which column is the timestamp, and which other
+columns to include as event attributes:
+
+.. code-block:: yaml
+
+ tables:
+ my_table:
+ file_path: relative/path/to/file.csv
+ patient_id: subject_id
+ timestamp: charttime
+ timestamp_format: "%Y-%m-%d %H:%M:%S"
+ attributes:
+ - icd_code
+ - value
+ - itemid
+ join: [] # optional table joins
+
+All attribute column names are lowercased internally, so ``ICD_CODE`` in
+your CSV becomes ``icd_code`` in your code.
+
+Querying Patients and Events
+-----------------------------
+
+Once a dataset is loaded, you can explore it using these methods:
+
+.. code-block:: python
+
+ dataset.unique_patient_ids # all patient IDs as a list of strings
+ dataset.get_patient("p001") # retrieve one Patient object
+ dataset.iter_patients() # iterate over all patients
+ dataset.stats() # print patient and event counts
+
+Patient records are accessed through ``get_events()``, which supports
+temporal filtering and attribute-level filters:
+
+.. code-block:: python
+
+    from datetime import datetime
+
+    events = patient.get_events(
+        event_type="diagnoses_icd",  # table name from your config
+        start=datetime(2020, 1, 1),  # optional: exclude earlier events
+        end=datetime(2020, 6, 1),  # optional: exclude later events
+        filters=[("icd_code", "==", "250.00")],  # optional: attribute conditions
+    )
+
+Each event in the returned list has:
+
+- ``event.timestamp`` — a Python ``datetime`` object. PyHealth normalises all
+ timestamp columns (``charttime``, ``admittime``, etc.) into this single
+ property, so this is what you should use regardless of what the original
+ column was called.
+- ``event.icd_code``, ``event["icd_code"]``, ``event.attr_dict`` — different
+ ways to access the other attributes. All attribute names are lowercase.
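The equivalence of these access paths can be illustrated with a toy class — a sketch of the behaviour described above, not PyHealth's real ``Event``:

```python
from datetime import datetime

class ToyEvent:
    """Illustration only: mimics the attribute access described above."""

    def __init__(self, timestamp, **attrs):
        self.timestamp = timestamp   # normalised from charttime/admittime/...
        self.attr_dict = attrs       # all keys already lowercased

    def __getattr__(self, name):
        # Fall back to attr_dict for anything that isn't a real attribute.
        try:
            return self.attr_dict[name]
        except KeyError:
            raise AttributeError(name) from None

    def __getitem__(self, name):
        return self.attr_dict[name]

event = ToyEvent(datetime(2020, 1, 1), icd_code="250.00")
assert event.icd_code == event["icd_code"] == event.attr_dict["icd_code"]

# The original column name is gone: only `timestamp` survives.
try:
    event.charttime
except AttributeError:
    print("use event.timestamp instead")
```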
+
+Things to Watch Out For
+------------------------
+
+A few patterns that commonly trip up new users:
+
+**BaseDataset vs SampleDataset.** Models expect a ``SampleDataset`` (the
+output of ``set_task()``), not the raw ``BaseDataset``. Passing the wrong one
+will raise an error. If you see an ``AttributeError`` about ``input_schema``
+or ``output_schema``, this is likely the cause.
+
+**Timestamp attribute names.** Writing ``event.charttime`` will raise an
+``AttributeError`` because PyHealth remaps that column to ``event.timestamp``.
+The same applies to ``admittime``, ``starttime``, or whatever the original
+column was called.
+
+**Column name casing.** PyHealth lowercases all column names at load time.
+Even if your source CSV has ``ICD_CODE``, you access it as ``event.icd_code``.
+
+**dev=True in production.** The ``dev`` flag is great for exploring data but
+it caps the dataset at 1 000 patients. Remember to switch to ``dev=False``
+before running a full training job.
+
+**Multiprocessing guard.** Scripts that call ``set_task()`` should wrap their
+top-level code in ``if __name__ == '__main__':``. See :doc:`tasks` for details.
+
Available Datasets
------------------
diff --git a/docs/api/graph.rst b/docs/api/graph.rst
index e69de29bb..164214285 100644
--- a/docs/api/graph.rst
+++ b/docs/api/graph.rst
@@ -0,0 +1,138 @@
+Graph
+=====
+
+The ``pyhealth.graph`` module lets you bring a healthcare knowledge graph into
+your PyHealth pipeline. Graph-based models like GraphCare and GNN can use
+relational medical knowledge — drug interactions, disease hierarchies,
+symptom–diagnosis links — to enrich patient representations beyond what
+raw EHR codes alone can capture.
+
+What Is a Knowledge Graph?
+---------------------------
+
+A knowledge graph encodes medical relationships as **(head, relation, tail)**
+triples. For example:
+
+- ``("aspirin", "treats", "headache")``
+- ``("metformin", "used_for", "type_2_diabetes")``
+- ``("ICD9:250", "is_a", "ICD9:249")``
+
+PyHealth does not ship a built-in graph — you bring triples from a source of
+your choice (UMLS, DrugBank, an ICD hierarchy, a custom ontology, etc.) and
+the :class:`~pyhealth.graph.KnowledgeGraph` class handles indexing, entity
+mappings, and k-hop subgraph extraction. The typical use case is querying the
+graph at training time: given a patient's active codes, extract the local
+subgraph around those codes and feed it to a graph-aware model.
+
+Getting Started
+---------------
+
+The simplest way to create a graph is to pass a list of triples directly:
+
+.. code-block:: python
+
+ from pyhealth.graph import KnowledgeGraph
+
+ triples = [
+ ("aspirin", "treats", "headache"),
+ ("headache", "symptom_of", "migraine"),
+ ("ibuprofen", "treats", "headache"),
+ ]
+ kg = KnowledgeGraph(triples=triples)
+ kg.stat()
+ # KnowledgeGraph: 4 entities, 2 relations, 3 triples
+
+For larger graphs it is more practical to load from a CSV or TSV file. The
+file should have columns named ``head``, ``relation``, and ``tail``:
+
+.. code-block:: python
+
+ kg = KnowledgeGraph(triples="path/to/medical_kg.tsv")
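For reference, a minimal file in that layout might look like the following — a CSV variant shown for readability, with a header row assumed to match the required column names:

```text
head,relation,tail
aspirin,treats,headache
metformin,used_for,type_2_diabetes
ICD9:250,is_a,ICD9:249
```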
+
+Exploring the Graph
+-------------------
+
+Once built, you can inspect the graph and look up neighbours for any entity:
+
+.. code-block:: python
+
+ kg.num_entities # total unique entities
+ kg.num_relations # total unique relation types
+ kg.num_triples # total edges
+
+ kg.has_entity("aspirin") # True / False
+ kg.neighbors("aspirin") # list of (relation, tail) pairs
+
+ # Integer ID mappings used internally by PyG
+ kg.entity2id["aspirin"] # → int
+ kg.id2entity[0] # → entity name string
+
+Extracting Patient Subgraphs
+------------------------------
+
+The main reason to build a knowledge graph is to extract a patient-specific
+subgraph at training time. ``subgraph()`` returns all entities reachable
+within *n* hops of a set of seed codes, as a PyTorch Geometric ``Data``
+object:
+
+.. code-block:: python
+
+ patient_codes = ["ICD9:250.00", "NDC:0069-0105"]
+ subgraph = kg.subgraph(seed_entities=patient_codes, num_hops=2)
+
+.. note::
+
+ ``subgraph()`` requires `PyTorch Geometric `_
+ (``torch_geometric``). The graph can still be constructed and explored
+ without it — only subgraph extraction needs PyG.
+
+ Install with: ``pip install torch-geometric``
+
+Using with GraphProcessor in a Task
+-------------------------------------
+
+To feed subgraphs into a model automatically during data loading, pass a
+configured :class:`~pyhealth.processors.GraphProcessor` instance in your
+task's ``input_schema``. The processor will call ``kg.subgraph()`` for each
+patient sample:
+
+.. code-block:: python
+
+ from pyhealth.graph import KnowledgeGraph
+ from pyhealth.processors import GraphProcessor
+ from pyhealth.tasks import BaseTask
+
+ kg = KnowledgeGraph(triples="medical_kg.tsv")
+
+ class MyGraphTask(BaseTask):
+ task_name = "MyGraphTask"
+ input_schema = {
+ "conditions": "sequence",
+ "kg_subgraph": GraphProcessor(kg, num_hops=2),
+ }
+ output_schema = {"label": "binary"}
+
+ def __call__(self, patient):
+ ...
+
+Pre-computed Node Embeddings
+-----------------------------
+
+If you already have entity embeddings (e.g. from TransE or an LLM), you can
+attach them to the graph at construction time. The model can then use these
+as initial node features instead of learning them from scratch:
+
+.. code-block:: python
+
+ import torch
+
+ node_features = torch.randn(kg.num_entities, 64) # (num_entities, feat_dim)
+ kg = KnowledgeGraph(triples=triples, node_features=node_features)
+
+API Reference
+-------------
+
+.. toctree::
+ :maxdepth: 3
+
+ graph/pyhealth.graph.KnowledgeGraph
diff --git a/docs/api/models.rst b/docs/api/models.rst
index 2e647898b..7368dec94 100644
--- a/docs/api/models.rst
+++ b/docs/api/models.rst
@@ -1,9 +1,171 @@
Models
===============
-We implement the following models for supporting multiple healthcare predictive tasks.
+PyHealth models sit between the :doc:`processors` (which turn raw patient data
+into tensors) and the :doc:`trainer` (which runs the training loop). Each
+model takes a ``SampleDataset`` — the result of ``dataset.set_task()`` — as
+its first constructor argument, and uses it to automatically build the right
+embedding layers and output head for your task.
+
+One thing worth knowing up front: the ``SampleDataset`` carries fitted
+processor metadata that the model needs to configure itself. If you pass the
+raw ``BaseDataset`` instead you'll get an error, because it hasn't been
+processed into samples yet.
+
+Choosing a Model
+----------------
+
+The table below covers the most commonly used models and when each one fits
+best. If your features are a mix of sequential codes and static numeric
+vectors, ``MultimodalRNN`` is usually the easiest starting point because it
+routes each feature type automatically.
+
+.. list-table::
+ :header-rows: 1
+ :widths: 20 40 40
+
+ * - Model
+ - Good fit when…
+ - Notes
+ * - :doc:`models/pyhealth.models.RNN`
+ - Your features are sequences of medical codes (diagnoses, procedures, drugs) across visits
+ - One RNN per feature, hidden states concatenated; ``rnn_type`` can be ``"GRU"`` (default) or ``"LSTM"``
+ * - :doc:`models/pyhealth.models.Transformer`
+ - You have longer code histories and want attention to capture long-range dependencies
+ - Self-attention across the sequence; tends to work well when visit order matters
+ * - :doc:`models/pyhealth.models.MLP`
+ - Features are static numeric vectors (aggregated lab values, demographics)
+ - Fully connected; no notion of sequence order
+ * - ``MultimodalRNN``
+ - Features mix sequential codes with static tensors or multi-hot encodings
+ - Auto-routes sequential features to RNN layers and non-sequential features to linear layers; good default for EHR
+ * - :doc:`models/pyhealth.models.StageNet`
+ - You have time-stamped vital signs with irregular measurement intervals
+ - Requires ``StageNetProcessor`` or ``StageNetTensorProcessor`` in the task schema
+ * - :doc:`models/pyhealth.models.GNN`
+ - Features include graph-structured data
+ - Works with ``GraphProcessor``; see :doc:`graph` for setup
+ * - :doc:`models/pyhealth.models.GraphCare`
+ - You want to augment EHR codes with a medical knowledge graph
+ - Combines code sequences with a :class:`~pyhealth.graph.KnowledgeGraph`
+
+How BaseModel Works
+--------------------
+
+All PyHealth models inherit from ``BaseModel``, which itself inherits from
+PyTorch's ``nn.Module``. When you call ``MyModel(dataset=sample_ds)``, the
+base class reads the dataset's schemas and automatically sets:
+
+- ``self.feature_keys`` — the list of input field names from ``input_schema``
+- ``self.label_keys`` — the list of output field names from ``output_schema``
+- ``self.device`` — the compute device
+
+It also provides three helper methods that take care of the boilerplate that
+varies by task type:
+
+- ``self.get_output_size()`` returns the output dimension from the fitted
+ label processor, so you don't have to hard-code it.
+- ``self.get_loss_function()`` returns the right loss for the task: BCE for
+ binary and multilabel tasks, cross-entropy for multiclass, MSE for
+ regression.
+- ``self.prepare_y_prob(logits)`` applies sigmoid, softmax, or identity to
+ logits depending on the task, producing calibrated probabilities.
+
+The ``forward()`` method is expected to return a dictionary with four keys:
+``loss``, ``y_prob``, ``y_true``, and ``logit``. The Trainer reads all four.
+
+EmbeddingModel
+--------------
+
+:class:`~pyhealth.models.EmbeddingModel` is a helper that routes each input
+feature to the appropriate embedding layer based on how its processor works.
+Features from token-based processors (``SequenceProcessor``,
+``NestedSequenceProcessor``, and similar) get a learned ``nn.Embedding``
+lookup. Features from continuous processors (``TensorProcessor``,
+``TimeseriesProcessor``, ``MultiHotProcessor``) get a linear projection
+instead. You end up with a uniform embedding shape across all features:
+
+.. code-block:: python
+
+ self.embedding_model = EmbeddingModel(dataset, embedding_dim=128)
+ embedded = self.embedding_model(inputs, masks=masks)
+ # embedded[key] has shape (batch_size, seq_len, embedding_dim)
+
+Task Mode and Loss Functions
+-----------------------------
+
+PyHealth automatically selects the loss function and output activation based
+on the label processor in your task's ``output_schema``:
+
+.. list-table::
+ :header-rows: 1
+ :widths: 20 30 30
+
+ * - Output schema value
+ - Loss function
+ - ``y_prob`` shape and activation
+ * - ``"binary"``
+ - BCE with logits
+ - sigmoid → (batch, 1)
+ * - ``"multiclass"``
+ - Cross-entropy
+ - softmax → (batch, num_classes)
+ * - ``"multilabel"``
+ - BCE with logits
+ - sigmoid → (batch, num_labels)
+ * - ``"regression"``
+ - MSE
+ - identity → (batch, 1)
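The activations in the table are standard functions. A pure-Python sketch of what ``prepare_y_prob`` conceptually applies (PyHealth itself uses the torch equivalents):

```python
import math

def sigmoid(x):
    # binary / multilabel: squash each logit independently into (0, 1)
    return 1.0 / (1.0 + math.exp(-x))

def softmax(logits):
    # multiclass: normalise across classes so probabilities sum to 1
    m = max(logits)                       # subtract max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

print(sigmoid(0.0))                       # 0.5
print(sum(softmax([2.0, 0.5, -1.0])))     # 1.0 (up to float rounding)
```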
+
+Building a Custom Model
+-----------------------
+
+If none of the built-in models fit your architecture, you can subclass
+``BaseModel`` directly. The skeleton below shows the typical structure: build
+an ``EmbeddingModel`` in ``__init__``, unpack processor schemas in
+``forward``, pool or aggregate the embeddings, and return the four-key dict.
+
+.. code-block:: python
+
+ from pyhealth.models import BaseModel
+ from pyhealth.models.embedding import EmbeddingModel
+ import torch
+ import torch.nn as nn
+
+ class MyModel(BaseModel):
+ def __init__(self, dataset, embedding_dim=128):
+ super().__init__(dataset=dataset)
+ self.label_key = self.label_keys[0]
+ self.embedding_model = EmbeddingModel(dataset, embedding_dim)
+ self.fc = nn.Linear(embedding_dim * len(self.feature_keys),
+ self.get_output_size())
+
+ def forward(self, **kwargs):
+ inputs, masks = {}, {}
+ for key in self.feature_keys:
+ feature = kwargs[key]
+ if isinstance(feature, torch.Tensor):
+ feature = (feature,)
+ schema = self.dataset.input_processors[key].schema()
+ inputs[key] = feature[schema.index("value")]
+ if "mask" in schema:
+ masks[key] = feature[schema.index("mask")]
+
+ embedded = self.embedding_model(inputs, masks=masks)
+ pooled = [embedded[k].mean(dim=1) for k in self.feature_keys]
+ logits = self.fc(torch.cat(pooled, dim=1))
+
+ y_true = kwargs[self.label_key].to(self.device)
+ return {
+ "loss": self.get_loss_function()(logits, y_true),
+ "y_prob": self.prepare_y_prob(logits),
+ "y_true": y_true,
+ "logit": logits,
+ }
+
+API Reference
+-------------
-
.. toctree::
:maxdepth: 3
@@ -41,4 +203,4 @@ We implement the following models for supporting multiple healthcare predictive
models/pyhealth.models.VisionEmbeddingModel
models/pyhealth.models.TextEmbedding
models/pyhealth.models.BIOT
- models/pyhealth.models.unified_multimodal_embedding_docs
\ No newline at end of file
+ models/pyhealth.models.unified_multimodal_embedding_docs
diff --git a/docs/api/overview.rst b/docs/api/overview.rst
new file mode 100644
index 000000000..532eeb48b
--- /dev/null
+++ b/docs/api/overview.rst
@@ -0,0 +1,240 @@
+PyHealth Architecture Overview
+==============================
+
+This page describes how all PyHealth components connect, from raw data files
+to a trained, evaluated model. Every stage has its own dedicated reference
+page — this overview is here to show how they fit together.
+
+Pipeline at a Glance
+---------------------
+
+.. code-block:: text
+
+ Raw CSV / Parquet files
+ │
+ ▼
+ config.yaml (table schemas, patient_id col, timestamp col, attributes)
+ │
+ ▼
+ BaseDataset subclass ──── loads tables, caches as global_event_df.parquet
+ │ .unique_patient_ids → List[str]
+ │ .get_patient(id) → Patient
+ │ .iter_patients() → Iterator[Patient]
+ │ .stats() → prints patient/event counts
+ │
+ ▼
+ BaseTask subclass (__call__(patient) → List[Dict])
+ │ .input_schema = {"feature": "processor_name", ...}
+ │ .output_schema = {"label": "binary" | "multiclass" | ...}
+ │
+ dataset.set_task(task, num_workers=N)
+ │
+ ▼
+ SampleDataset ──── len(ds), ds[i], patient_to_index, record_to_index
+ │ Backed by LitData streaming files
+ │ Processors fitted during set_task, applied at load time
+ │
+ get_dataloader(dataset, batch_size=32, shuffle=True)
+ │
+ ▼
+ Model(dataset, ...) ──── BaseModel subclass (RNN, Transformer, MLP, …)
+ │ EmbeddingModel routes features via processor.is_token()
+ │ forward(**batch) → {"loss", "y_prob", "y_true", "logit"}
+ │
+ ▼
+ Trainer(model, metrics=[...], device=...)
+ │ .train(train_dl, val_dl, test_dl, epochs=20, ...)
+ │ .evaluate(test_dl) → Dict[metric_name, value]
+ │
+ ├──▶ Calibration (pyhealth.calib)
+ │ TemperatureScaling / HistogramBinning / KCal / …
+ │ LABEL / SCRIB / FavMac / … (conformal prediction sets)
+ │
+ └──▶ Interpretability (pyhealth.interpret)
+ GradientSaliency / IntegratedGradients / DeepLift / SHAP / LIME / …
+
+
+Stage 1: Raw Data → BaseDataset
+---------------------------------
+
+See :doc:`datasets` for the full reference.
+
+PyHealth reads raw CSV or Parquet files using Polars, joins tables according
+to a ``config.yaml`` schema, and writes a compact
+``global_event_df.parquet`` cache. On subsequent runs with the same
+configuration it reads from the cache rather than re-parsing source files.
+
+**Native datasets** (MIMIC-III, MIMIC-IV, eICU, OMOP, and many others) have
+built-in schemas — just pass a ``root`` path and a list of ``tables``:
+
+.. code-block:: python
+
+ from pyhealth.datasets import MIMIC3Dataset
+
+ if __name__ == '__main__':
+ dataset = MIMIC3Dataset(
+ root="/data/mimiciii/1.4",
+ tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
+ cache_dir=".cache",
+ dev=True, # cap at 1 000 patients during exploration
+ )
+
+**Custom datasets** subclass ``BaseDataset`` and provide a ``config_path``
+pointing to your own ``config.yaml``. If files need preprocessing (e.g.
+merging multiple CSVs), define ``preprocess_(self, df)`` on the
+subclass — it receives a narwhals LazyFrame and must return one.
+
+Key init params: ``root``, ``tables``, ``config_path`` (custom only),
+``cache_dir``, ``num_workers``, ``dev``.
+
+A UUID derived from ``(root, tables, dataset_name, dev)`` is appended to
+``cache_dir``, so different configurations never overwrite each other.
+
+
+Stage 2: Patient and Event Objects
+------------------------------------
+
+See :doc:`data` for the full reference.
+
+Once a dataset is loaded, ``Patient.get_events()`` is the primary query
+method:
+
+.. code-block:: python
+
+    from datetime import datetime
+
+    events = patient.get_events(
+        event_type="diagnoses_icd",  # must match the table name in config.yaml
+        start=datetime(2020, 1, 1),
+        end=datetime(2020, 6, 1),
+        filters=[("icd_code", "==", "250.00")],
+    )
+
+``Event`` attributes to keep in mind:
+
+- ``event.timestamp`` — always use this; PyHealth normalises ``charttime``,
+ ``admittime``, etc. into a single property.
+- ``event.attr_dict`` / ``event["col_name"]`` / ``event.col_name`` — access
+ attribute values. All column names are **lowercased** at ingest time.
+
+
+Stage 3: Task Definition → set_task
+--------------------------------------
+
+See :doc:`tasks` for the full reference.
+
+A ``BaseTask`` subclass defines three things:
+
+- ``task_name: str`` — must be assigned (not just annotated).
+- ``input_schema`` / ``output_schema`` — dicts mapping sample keys to
+ processor string aliases (e.g. ``"sequence"``, ``"binary"``).
+- ``__call__(self, patient) → List[Dict]`` — extracts features from one
+ ``Patient``; return ``[]`` to skip a patient.
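A skeleton matching those three requirements — a sketch mirroring the full example on the tasks page, with ``ToyTask`` as a hypothetical name:

```python
from typing import Any, Dict, List

class ToyTask:  # in real code, subclass pyhealth.tasks.BaseTask
    task_name: str = "ToyTask"                   # assigned, not just annotated
    input_schema: Dict[str, str] = {"conditions": "sequence"}
    output_schema: Dict[str, str] = {"label": "binary"}

    def __call__(self, patient) -> List[Dict[str, Any]]:
        # Return one dict per sample; returning [] skips this patient.
        return []
```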
+
+``dataset.set_task(task, num_workers=N)`` iterates all patients, collects
+samples, fits processors, and writes LitData ``.ld`` streaming files to disk.
+The result is a ``SampleDataset``.
+
+.. important::
+
+ All code calling ``set_task()`` must live inside
+ ``if __name__ == '__main__':``. PyHealth uses multiprocessing internally
+ and will crash without this guard.
+
+
+Stage 4: Processors → SampleDataset
+--------------------------------------
+
+See :doc:`processors` for the full reference.
+
+When ``set_task()`` runs:
+
+1. ``SampleBuilder.fit(samples)`` — calls ``processor.fit(samples, field)``
+ for every schema field.
+2. ``SampleBuilder.transform(sample)`` — calls ``processor.process(value)``
+ for every field, writing tensors to disk.
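The two-phase flow can be sketched with a toy processor — a hypothetical ``MiniSequenceProcessor``, not PyHealth's real class:

```python
# Phase 1 (fit): build shared state across the whole dataset.
# Phase 2 (process): transform one value using that fitted state.
class MiniSequenceProcessor:
    def fit(self, samples, field):
        # Build a code -> index vocabulary over all samples.
        vocab = sorted({code for s in samples for code in s[field]})
        self.code2idx = {c: i for i, c in enumerate(vocab)}

    def process(self, value):
        return [self.code2idx[c] for c in value]

samples = [{"conditions": ["428.0", "250.00"]},
           {"conditions": ["250.00"]}]
proc = MiniSequenceProcessor()
proc.fit(samples, "conditions")
print(proc.process(["250.00", "428.0"]))  # [0, 1]
```

Fitting once over the full dataset is what guarantees every sample is encoded against the same vocabulary.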
+
+The key signal for model routing is ``processor.is_token()``:
+
+- ``True`` → ``nn.Embedding`` (discrete token indices, e.g. medical codes)
+- ``False`` → ``nn.Linear`` (continuous values, e.g. time series, images)
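That routing rule boils down to a single branch. A minimal sketch with toy classes standing in for the real processors:

```python
class TokenLikeProcessor:          # stand-in for e.g. SequenceProcessor
    def is_token(self):
        return True

class ContinuousLikeProcessor:     # stand-in for e.g. TimeseriesProcessor
    def is_token(self):
        return False

def embedding_layer_for(processor):
    # Discrete token indices get an embedding lookup; continuous
    # values get a linear projection to the same embedding size.
    return "nn.Embedding" if processor.is_token() else "nn.Linear"

print(embedding_layer_for(TokenLikeProcessor()))       # nn.Embedding
print(embedding_layer_for(ContinuousLikeProcessor()))  # nn.Linear
```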
+
+
+Stage 5: Model Initialization and Forward Pass
+------------------------------------------------
+
+See :doc:`models` for the full reference.
+
+.. code-block:: python
+
+ from pyhealth.models import RNN
+
+ model = RNN(dataset=sample_dataset, embedding_dim=128, hidden_dim=64)
+
+The model reads ``dataset.input_schema``, ``dataset.output_schema``, and
+``dataset.input_processors`` to auto-build embedding layers and the output
+head. **Always pass the ``SampleDataset`` (result of ``set_task()``), not
+the raw ``BaseDataset``.**
+
+``model(**batch)`` where ``batch`` is a dict from the DataLoader. Must return
+``{"loss", "y_prob", "y_true", "logit"}``.
+
+**Choosing a model:**
+
+- Mixed sequential + static features → ``MultimodalRNN``
+- Purely sequential codes → ``RNN`` or ``Transformer``
+- Static feature vector → ``MLP``
+- Time-stamped vitals with irregular intervals → ``StageNet``
+- Graph-structured features → ``GNN`` or ``GraphCare`` (see :doc:`graph`)
+
+
+Stage 6: Training and Evaluation
+----------------------------------
+
+See :doc:`trainer` for the full reference.
+
+.. code-block:: python
+
+ from pyhealth.trainer import Trainer
+ from pyhealth.datasets import get_dataloader
+
+ train_dl = get_dataloader(train_ds, batch_size=32, shuffle=True)
+ val_dl = get_dataloader(val_ds, batch_size=32, shuffle=False)
+ test_dl = get_dataloader(test_ds, batch_size=32, shuffle=False)
+
+ trainer = Trainer(model=model, metrics=["roc_auc_macro", "f1_macro"], device="cuda")
+ trainer.train(train_dl, val_dl, test_dl, epochs=20,
+ monitor="roc_auc_macro", monitor_criterion="max", patience=5)
+
+ scores = trainer.evaluate(test_dl)
+ # → {"roc_auc_macro": 0.85, "loss": 0.3, ...}
+
+Split by patient to avoid data leakage:
+
+.. code-block:: python
+
+ all_ids = list(sample_dataset.patient_to_index.keys())
+ # ... split all_ids into train_ids / val_ids / test_ids ...
+ train_indices = [i for pid in train_ids for i in sample_dataset.patient_to_index[pid]]
+ train_ds = sample_dataset.subset(train_indices)
+
+
+Common Pitfalls
+----------------
+
+.. list-table::
+ :header-rows: 1
+ :widths: 45 55
+
+ * - Mistake
+ - Fix
+ * - Missing ``if __name__ == '__main__':``
+ - Wrap all ``set_task()`` / dataset loading code in this guard
+ * - ``event.charttime`` instead of ``event.timestamp``
+ - Always use ``event.timestamp``
+ * - Task sample key doesn't match ``input_schema``
+ - Keys in ``__call__`` return dict must exactly match schema keys
+ * - ``dev=True`` during full training
+ - Only use ``dev=True`` during exploration; set ``dev=False`` for final runs
+ * - Passing ``BaseDataset`` to the model
+ - Pass ``SampleDataset`` (result of ``set_task()``) to the model
+ * - ``dataset.patients``
+ - Does not exist; use ``dataset.unique_patient_ids`` + ``dataset.get_patient(id)``
diff --git a/docs/api/processors.rst b/docs/api/processors.rst
index 3a9fb73de..a06e3c955 100644
--- a/docs/api/processors.rst
+++ b/docs/api/processors.rst
@@ -3,6 +3,12 @@ Processors
Processors in PyHealth handle data preprocessing and transformation for healthcare predictive tasks. They convert raw data into tensors suitable for machine learning models.
+Processors sit between :doc:`tasks` (which define *what* data to extract) and
+:doc:`models` (which consume the resulting tensors). You rarely call processors
+directly — they are configured through the ``input_schema`` and
+``output_schema`` of a task and applied automatically during
+``dataset.set_task()``.
+
Overview
--------
diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst
index 3ed1e1c97..d85d04bc3 100644
--- a/docs/api/tasks.rst
+++ b/docs/api/tasks.rst
@@ -67,6 +67,138 @@ After you define a task:
- Discover how to customize processor behavior with kwargs tuples
- Understand processor types for different data modalities (text, images, signals, etc.)
+Writing a Custom Task
+----------------------
+
+When a built-in task doesn't match your cohort or prediction target, you can
+define your own by subclassing :class:`~pyhealth.tasks.BaseTask`. The class
+needs three things: a name, input and output schemas, and a ``__call__``
+method that processes one patient at a time.
+
+.. code-block:: python
+
+ from pyhealth.tasks import BaseTask
+ from pyhealth.data import Patient
+ from typing import List, Dict, Any
+
+ class MyMortalityTask(BaseTask):
+ task_name: str = "MyMortalityTask"
+
+ input_schema: Dict[str, str] = {
+ "conditions": "sequence", # maps to SequenceProcessor
+ "procedures": "sequence",
+ }
+ output_schema: Dict[str, str] = {
+ "label": "binary" # maps to BinaryLabelProcessor
+ }
+
+ def __call__(self, patient: Patient) -> List[Dict[str, Any]]:
+ samples = []
+ for adm in patient.get_events("admissions"):
+ label = 1 if adm.hospital_expire_flag == "1" else 0
+
+ # Fetch historical diagnoses up to this admission
+ conditions = patient.get_events("diagnoses_icd", end=adm.timestamp)
+ cond_codes = [e.icd_code for e in conditions]
+
+ if not cond_codes:
+ continue
+
+ samples.append({
+ "conditions": [cond_codes], # wrapped in a list for the sequence processor
+ "procedures": [[]],
+ "label": label,
+ })
+ return samples
+
+The ``__call__`` method receives one ``Patient`` and returns a list of sample
+dictionaries. Each dictionary's keys should match the schemas you declared.
+Returning an empty list is fine — PyHealth simply skips that patient. Note
+that event attribute names are always lowercase (e.g. ``e.icd_code`` rather
+than ``e.ICD_CODE``) because PyHealth lowercases all column names at ingest
+time. Timestamps are accessed through ``event.timestamp`` rather than the
+original column name like ``charttime``, since PyHealth normalises them into
+a single property.
+
+Processor String Keys
+----------------------
+
+The string values in your schemas map to specific processor classes. Here is
+a quick reference:
+
+.. list-table::
+ :header-rows: 1
+ :widths: 25 35 40
+
+ * - String key
+ - Processor
+ - Typical use
+ * - ``"sequence"``
+ - ``SequenceProcessor``
+ - Diagnosis codes, procedure codes, drug names
+ * - ``"nested_sequence"``
+ - ``NestedSequenceProcessor``
+ - Cumulative visit history (drug recommendation, readmission)
+ * - ``"tensor"``
+ - ``TensorProcessor``
+ - Aggregated numeric values (e.g. last lab value per item)
+ * - ``"timeseries"``
+ - ``TimeseriesProcessor``
+ - Irregular time-series measurements
+ * - ``"multi_hot"``
+ - ``MultiHotProcessor``
+ - Demographics, comorbidity flags
+ * - ``"text"``
+ - ``TextProcessor``
+ - Clinical notes
+ * - ``"binary"``
+ - ``BinaryLabelProcessor``
+ - Binary classification label (0 / 1)
+ * - ``"multiclass"``
+ - ``MultiClassLabelProcessor``
+ - Multi-class label
+ * - ``"multilabel"``
+ - ``MultiLabelProcessor``
+ - Multi-label classification
+ * - ``"regression"``
+ - ``RegressionLabelProcessor``
+ - Continuous regression target
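As a quick illustration, several of the string keys above can be combined in one task's schemas. The field names here are made up for the example; only the string values come from the table:

```python
# Hypothetical schemas mixing several processor string keys from the table.
input_schema = {
    "conditions": "sequence",     # diagnosis codes
    "labs": "timeseries",         # irregular time-series measurements
    "demographics": "multi_hot",  # e.g. gender, ethnicity flags
    "note": "text",               # clinical notes
}
output_schema = {
    "drugs": "multilabel",        # multi-label recommendation target
}
```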
+
+How set_task() Works
+---------------------
+
+Calling ``dataset.set_task(task)`` iterates over every patient in the
+dataset, runs your ``__call__`` method on each one, fits all the processors
+on the collected samples, then serialises everything to disk as LitData
+``.ld`` files. The result is a :class:`~pyhealth.datasets.SampleDataset` that
+supports ``len()`` and index access, ready for a DataLoader.
+
+.. code-block:: python
+
+ sample_ds = dataset.set_task(MyMortalityTask(), num_workers=4)
+ len(sample_ds) # total ML samples across all patients
+ sample_ds[0] # a single sample dict with tensor values
+
+If you re-run ``set_task()`` with the same task and processor configuration,
+PyHealth detects the existing cache and skips reprocessing. During
+development it is useful to set ``dev=True`` on the dataset, which limits
+processing to 1,000 patients so iterations are fast.
+
+.. note::
+
+ **A note on multiprocessing.** ``set_task()`` can spawn worker processes
+ when ``num_workers > 1``. On macOS and Linux this requires the standard
+ Python multiprocessing guard around your top-level script:
+
+ .. code-block:: python
+
+ if __name__ == '__main__':
+ sample_ds = dataset.set_task(task, num_workers=4)
+
+ Without this guard, Python may try to re-import and re-run the script in
+ each worker process, leading to infinite recursion. This is a general
+ Python multiprocessing requirement, not specific to PyHealth.
+
Available Tasks
---------------
diff --git a/docs/api/trainer.rst b/docs/api/trainer.rst
index 039f22dc4..0a52af0f7 100644
--- a/docs/api/trainer.rst
+++ b/docs/api/trainer.rst
@@ -1,7 +1,102 @@
Trainer
-===================================
+=======
+
+:class:`~pyhealth.trainer.Trainer` handles the PyTorch training loop for you.
+Rather than writing your own epoch loop, loss backward pass, optimizer step,
+and metric evaluation, you hand the Trainer your model and data loaders and
+let it manage the details — including early stopping when validation
+performance plateaus and automatic reloading of the best checkpoint at the end.
+
+A Typical Training Run
+-----------------------
+
+Here is what a full training setup looks like. The data loaders come from
+``get_dataloader()`` in :mod:`pyhealth.datasets`, which knows how to work with
+PyHealth's LitData caching format:
+
+.. code-block:: python
+
+ from pyhealth.trainer import Trainer
+ from pyhealth.datasets import get_dataloader
+
+ train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
+ val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)
+ test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
+
+ trainer = Trainer(
+ model=model,
+ metrics=["roc_auc_macro", "pr_auc_macro", "f1_macro"],
+ device="cuda",
+ )
+
+ trainer.train(
+ train_dataloader=train_loader,
+ val_dataloader=val_loader,
+ test_dataloader=test_loader,
+ epochs=50,
+ monitor="roc_auc_macro",
+ monitor_criterion="max",
+ patience=10,
+ )
+
+ scores = trainer.evaluate(test_loader)
+ # {'roc_auc_macro': 0.85, 'pr_auc_macro': 0.79, 'f1_macro': 0.72, 'loss': 0.31}
+
+Setting Up the Trainer
+-----------------------
+
+``Trainer(model, metrics=None, device=None, enable_logging=True, output_path=None, exp_name=None)``
+
+- **model** — your instantiated PyHealth model.
+- **metrics** — the metric names you want computed at validation and test time
+ (e.g. ``["roc_auc_macro", "f1_macro"]``). See :doc:`metrics` for the full
+ list of supported strings.
+- **device** — ``"cuda"`` or ``"cpu"``; defaults to auto-detecting a GPU.
+- **enable_logging** — when enabled, the Trainer creates a timestamped folder
+ under ``output_path`` with a ``log.txt`` and model checkpoints.
+- **output_path** / **exp_name** — where and how to name the output folder.
+
+Controlling the Training Loop
+------------------------------
+
+``trainer.train()`` accepts these key arguments beyond the data loaders:
+
+- **epochs** — the maximum number of training epochs.
+- **optimizer_class** / **optimizer_params** — which optimizer to use and how
+ to configure it. Defaults to ``Adam`` with a learning rate of ``1e-3``.
+- **weight_decay** — L2 regularisation strength. Default ``0.0``.
+- **max_grad_norm** — if set, clips gradients to this norm before each update,
+ which can help stabilise training on noisy medical data.
+- **monitor** / **monitor_criterion** — the metric to watch on the validation
+ set (e.g. ``"roc_auc_macro"``) and whether higher is better (``"max"``) or
+ lower is better (``"min"``). The Trainer saves a checkpoint whenever this
+ metric improves.
+- **patience** — how many epochs without improvement to wait before stopping
+ early.
+- **load_best_model_at_last** — when ``True`` (the default), the Trainer
+ restores the best checkpoint at the end of training rather than keeping the
+ weights from the final epoch.
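The interaction between ``monitor``, ``monitor_criterion``, and ``patience`` can be sketched in plain Python. This is a simplified illustration of the early-stopping behaviour described above, not the Trainer's actual implementation:

```python
def early_stop_epoch(scores, patience, criterion="max"):
    """Return the epoch index at which training would halt, given one
    validation score per epoch (simplified early-stopping sketch)."""
    best = None
    wait = 0
    for epoch, score in enumerate(scores):
        improved = best is None or (
            score > best if criterion == "max" else score < best
        )
        if improved:
            best, wait = score, 0  # a checkpoint would be saved here
        else:
            wait += 1
            if wait >= patience:
                return epoch  # patience exhausted: stop early
    return len(scores) - 1  # ran all epochs without triggering the stop

# Validation AUC plateaus after epoch 1; with patience=3 training stops
# at epoch 4 (three consecutive epochs without improvement).
stop = early_stop_epoch([0.50, 0.60, 0.60, 0.60, 0.60], patience=3)
```

With ``load_best_model_at_last=True``, the weights you end up with come from the epoch that last improved the monitored metric, not from the epoch where training stopped.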
+
+Getting the Test Scores
+------------------------
+
+``trainer.train()`` prints test scores to the console when a
+``test_dataloader`` is provided, but it does not return them as a Python
+object. To capture results for downstream use, call ``evaluate()`` separately:
+
+.. code-block:: python
+
+ scores = trainer.evaluate(test_loader)
+ # scores is a plain dict, e.g. {'roc_auc_macro': 0.85, 'loss': 0.31}
+
+ import json
+ with open("results.json", "w") as f:
+ json.dump(scores, f, indent=2)
+
+API Reference
+-------------
.. autoclass:: pyhealth.trainer.Trainer
:members:
:undoc-members:
- :show-inheritance:
\ No newline at end of file
+ :show-inheritance:
diff --git a/docs/index.rst b/docs/index.rst
index 543676952..37ec2e705 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -12,7 +12,7 @@ Build, test, and deploy healthcare machine learning models with ease. PyHealth i
**Key Features**
- **Dramatically simpler**: Build any healthcare AI model in ~7 lines of code
-- **Blazing fast**: Up to 39× faster than pandas
+- **Blazing fast**: Up to 39× faster than pandas for task processing
- **Memory efficient**: Runs on 16GB laptops
- **True multimodal**: Unified API for EHR, medical images, biosignals, clinical text, and genomics
- **Production-ready**: 25+ pre-built models, 20+ tasks, 12+ datasets with comprehensive evaluation tools
@@ -221,6 +221,7 @@ Quick Navigation
:hidden:
:caption: Documentation
+ api/overview
api/data
api/datasets
api/graph
diff --git a/pyproject.toml b/pyproject.toml
index 308e6b114..c1e0ebcc6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ authors = [
{name = "Zhen Lin"},
{name = "Benjamin Danek"},
{name = "Junyi Gao"},
- {name = "Paul Landes", email = "landes@mailc.net"},
+ {name = "Paul Landes", email = "plande2@uic.edu"},
{name = "Jimeng Sun"},
]
description = "A Python library for healthcare AI"