8 changes: 4 additions & 4 deletions docs/about.rst
@@ -3,7 +3,7 @@ About us

PyHealth is developed and maintained by a diverse community of researchers and practitioners.

Current Maintainers
Maintainers
------------------

`Zhenbang Wu <https://zzachw.github.io/>`_ (Ph.D. Student @ University of Illinois Urbana-Champaign)
@@ -12,11 +12,11 @@ Current Maintainers

`Junyi Gao <http://aboutme.vixerunt.org/>`_ (M.S. @ UIUC, Ph.D. Student @ University of Edinburgh)

Paul Landes (University of Illinois College of Medicine)
`Paul Landes <https://scholar.google.com/citations?user=7xs2tnQAAAAJ&hl=en>`_ (University of Illinois College of Medicine)

`Jimeng Sun <http://sunlab.org/>`_ (Professor @ University of Illinois Urbana-Champaign)

Major Reviewers
Reviewers
---------------

Eric Schrock (University of Illinois Urbana-Champaign)
@@ -52,7 +52,7 @@ Muni Bondu

*...and more members as the initiative continues to expand*

Alumni
Past Contributors
-----------------

`Chaoqi Yang <https://ycq091044.github.io//>`_ (Ph.D. Student @ University of Illinois Urbana-Champaign)
194 changes: 193 additions & 1 deletion docs/api/datasets.rst
@@ -10,7 +10,7 @@ New to PyHealth datasets? Start here:

This tutorial covers:

- How to load and work with different healthcare datasets (MIMIC-III, MIMIC-IV, eICU, etc.)
- How to load and work with any PyHealth dataset (MIMIC-III, MIMIC-IV, eICU, OMOP, and many more)
- Understanding the ``BaseDataset`` structure and patient representation
- Parsing raw EHR data into standardized PyHealth format
- Accessing patient records, visits, and clinical events
@@ -22,6 +22,198 @@ This tutorial covers:
- Using openly available demo datasets (MIMIC-III Demo, MIMIC-IV Demo)
- Working with synthetic data for testing

How PyHealth Loads Data
------------------------

When you initialise a dataset, PyHealth reads the raw CSV or Parquet files
using Polars, joins the tables according to a YAML schema, and writes a
compact ``global_event_df.parquet`` cache to disk. On subsequent runs with
the same configuration it reads from cache rather than re-parsing the source
files, so startup is fast.

The result is a :class:`~pyhealth.datasets.BaseDataset` — a structured
patient→event tree. It is different from a PyTorch Dataset: it has no integer
length and you cannot index into it with ``dataset[i]``. Think of it as a
queryable dictionary of patient records. To turn it into something a model
can train on, you call ``dataset.set_task()`` (see :doc:`tasks`), which
returns a :class:`~pyhealth.datasets.SampleDataset` that *is* indexable and
DataLoader-ready.
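
The practical consequence: treat the loaded dataset as a mapping keyed by
patient ID, not as a sequence. A toy contrast in plain Python (these are
illustrative stand-ins, not the PyHealth classes themselves):

```python
# A BaseDataset behaves like a keyed registry of patient records...
registry = {"p001": ["eventA", "eventB"], "p002": ["eventC"]}
# registry[0] would be a KeyError: there is no positional index.

# ...while a SampleDataset behaves like a sequence a DataLoader can walk.
samples = [{"x": [1, 2], "y": 0}, {"x": [3], "y": 1}]
assert len(samples) == 2 and samples[0]["y"] == 0
```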

From BaseDataset to SampleDataset
-----------------------------------

``BaseDataset`` and ``SampleDataset`` serve different roles and are not
interchangeable:

- **BaseDataset** is a queryable patient registry. It holds the raw
patient→visit→event tree loaded from disk. You cannot index into it like a
list — it has no integer length and is not DataLoader-ready.
- **SampleDataset** is a PyTorch-compatible streaming dataset returned by
``dataset.set_task()``. Each element is a fully processed feature
dictionary that a model can consume directly.

The conversion happens in one call:

.. code-block:: python

import torch
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.tasks import MortalityPredictionMIMIC3

dataset = MIMIC3Dataset(root="...", tables=["diagnoses_icd"])
samples = dataset.set_task(MortalityPredictionMIMIC3())
# `samples` is a SampleDataset — pass it straight to a DataLoader
loader = torch.utils.data.DataLoader(samples, batch_size=32)

Under the hood, ``set_task()`` runs a ``SampleBuilder`` that fits feature
processors (tokenisers, label encoders, etc.) across the full dataset, then
writes compressed, chunked sample files to disk via
`litdata <https://github.com/Lightning-AI/litdata>`_. A companion
``schema.pkl`` stores the fitted processors so the dataset can be reloaded in
future runs without re-fitting.
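
The fit-then-persist pattern can be sketched in plain Python. The
``LabelProcessor`` class below is an illustrative stand-in, not the actual
PyHealth processor API:

```python
import pickle

class LabelProcessor:
    """Toy stand-in for a fitted feature processor."""
    def fit(self, labels):
        # Build a stable vocabulary over the full dataset.
        self.vocab = {l: i for i, l in enumerate(sorted(set(labels)))}
    def process(self, label):
        return self.vocab[label]

proc = LabelProcessor()
proc.fit(["neg", "pos", "neg"])

# Persist the fitted processor, as schema.pkl does, so a later run
# can reload it without re-fitting over the whole dataset.
blob = pickle.dumps({"label": proc})
restored = pickle.loads(blob)["label"]
assert restored.process("pos") == 1
```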

``SampleDataset`` also exposes two convenience lookups built during fitting:

- ``samples.patient_to_index`` — maps a patient ID to all sample indices for
that patient.
- ``samples.record_to_index`` — maps a visit/record ID to the sample indices
for that visit.
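
These lookups are plain index maps. A minimal sketch of how such a mapping is
built over a flat sample list (the IDs here are hypothetical, and this is not
the real builder code):

```python
samples = [
    {"patient_id": "p001", "visit_id": "v1", "label": 0},
    {"patient_id": "p001", "visit_id": "v2", "label": 1},
    {"patient_id": "p002", "visit_id": "v3", "label": 0},
]

patient_to_index = {}
record_to_index = {}
for i, s in enumerate(samples):
    # Each patient (or visit) maps to every sample index derived from it.
    patient_to_index.setdefault(s["patient_id"], []).append(i)
    record_to_index.setdefault(s["visit_id"], []).append(i)

assert patient_to_index["p001"] == [0, 1]
assert record_to_index["v3"] == [2]
```

A mapping like this is what makes patient-level dataset splits possible: all
samples from one patient can be routed into the same fold.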

For testing or small cohorts you can skip the disk step entirely using
``InMemorySampleDataset``, which holds all processed samples in RAM and is
returned by default from ``create_sample_dataset()``.

.. note::
Building a custom dataset or bringing your own data?
See :doc:`../tutorials` (Tutorial 1) for a step-by-step walkthrough, and
the `config.yaml for Custom Datasets`_ section below for the schema format.

Native Datasets vs Custom Datasets
------------------------------------

PyHealth includes native support for many standard healthcare databases — including
MIMIC-III, MIMIC-IV, eICU, OMOP, and many others (see the full list in `Available Datasets`_
below). All of these come with built-in schema definitions so you
can load them with just a root path and a list of tables:

.. code-block:: python

from pyhealth.datasets import MIMIC3Dataset

if __name__ == '__main__':
dataset = MIMIC3Dataset(
root="/data/physionet.org/files/mimiciii/1.4",
tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
cache_dir=".cache",
dev=True, # use 1 000 patients while exploring
)

For any other data source — a custom patient registry, an institutional cohort,
or a non-EHR dataset — you create a subclass of ``BaseDataset`` and provide a
``config.yaml`` file that describes your table structure.

Initialization Parameters
--------------------------

- **root** — path to the directory containing the raw data files. For MIMIC-IV
specifically, use ``ehr_root`` instead of ``root``.
- **tables** — the table names you want to load, e.g.
``["diagnoses_icd", "labevents"]``. Only these tables will be accessible in
patient queries downstream.
- **config_path** — path to your ``config.yaml``; needed for custom datasets.
Native datasets have this built in and ignore the parameter.
- **cache_dir** — where to store the cached Parquet and LitData files. PyHealth
appends a UUID derived from your configuration, so different setups never
overwrite each other.
- **num_workers** — parallel processes for data loading. Increasing this can
speed up ``set_task()`` on large datasets.
- **dev** — when ``True``, PyHealth caps the dataset at 1 000 patients. This
is very useful during development because it makes each iteration complete in
seconds rather than minutes. Switch to ``dev=False`` for your final training
run.
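
The cache-keying idea (derive a stable identifier from the configuration, so
that different setups never share a cache directory) can be sketched as
follows. The hashing scheme here is illustrative, not PyHealth's exact
implementation:

```python
import hashlib
import json

def cache_key(config: dict) -> str:
    """Derive a deterministic ID from a configuration dict."""
    # sort_keys makes the key independent of dict insertion order.
    canonical = json.dumps(config, sort_keys=True)
    return hashlib.md5(canonical.encode()).hexdigest()

a = cache_key({"root": "/data/mimic", "tables": ["diagnoses_icd"], "dev": True})
b = cache_key({"root": "/data/mimic", "tables": ["diagnoses_icd"], "dev": False})
assert a != b  # different configs, different cache dirs
assert a == cache_key({"dev": True, "root": "/data/mimic", "tables": ["diagnoses_icd"]})
```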

config.yaml for Custom Datasets
---------------------------------

If you are bringing your own data, the YAML file tells PyHealth which column
is the patient identifier, which column is the timestamp, and which other
columns to include as event attributes:

.. code-block:: yaml

tables:
my_table:
file_path: relative/path/to/file.csv
patient_id: subject_id
timestamp: charttime
timestamp_format: "%Y-%m-%d %H:%M:%S"
attributes:
- icd_code
- value
- itemid
join: [] # optional table joins

All attribute column names are lowercased internally, so ``ICD_CODE`` in
your CSV becomes ``icd_code`` in your code.
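
The lowercasing behaviour is easy to emulate when debugging a mismatch between
your CSV headers and the attribute names PyHealth exposes:

```python
# A raw CSV row with mixed-case headers...
row = {"SUBJECT_ID": "p001", "ICD_CODE": "250.00", "VALUE": "7.2"}
# ...is exposed downstream with lowercased attribute names.
attrs = {k.lower(): v for k, v in row.items()}
assert attrs["icd_code"] == "250.00"
```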

Querying Patients and Events
-----------------------------

Once a dataset is loaded, you can explore it using these methods:

.. code-block:: python

dataset.unique_patient_ids # all patient IDs as a list of strings
dataset.get_patient("p001") # retrieve one Patient object
dataset.iter_patients() # iterate over all patients
dataset.stats() # print patient and event counts

Patient records are accessed through ``get_events()``, which supports
temporal filtering and attribute-level filters:

.. code-block:: python

from datetime import datetime

events = patient.get_events(
    event_type="diagnoses_icd",              # table name from your config
    start=datetime(2020, 1, 1),              # optional: exclude earlier events
    end=datetime(2020, 6, 1),                # optional: exclude later events
    filters=[("icd_code", "==", "250.00")],  # optional: attribute conditions
)
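
The filtering semantics can be mimicked over plain dictionaries, which is
handy for reasoning about the time window (this sketch assumes inclusive
bounds; check the API reference for the exact boundary convention):

```python
from datetime import datetime

events = [
    {"timestamp": datetime(2019, 12, 1), "icd_code": "250.00"},
    {"timestamp": datetime(2020, 3, 1),  "icd_code": "250.00"},
    {"timestamp": datetime(2020, 3, 2),  "icd_code": "401.9"},
]

start, end = datetime(2020, 1, 1), datetime(2020, 6, 1)
hits = [
    e for e in events
    if start <= e["timestamp"] <= end and e["icd_code"] == "250.00"
]
assert len(hits) == 1 and hits[0]["timestamp"] == datetime(2020, 3, 1)
```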

Each event in the returned list has:

- ``event.timestamp`` — a Python ``datetime`` object. PyHealth normalises all
timestamp columns (``charttime``, ``admittime``, etc.) into this single
property, so this is what you should use regardless of what the original
column was called.
- ``event.icd_code``, ``event["icd_code"]``, ``event.attr_dict`` — different
ways to access the other attributes. All attribute names are lowercase.
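
The three access styles behave like a thin wrapper over an attribute dict. A
toy reconstruction (not the real PyHealth Event class) makes the equivalence
concrete:

```python
from datetime import datetime

class ToyEvent:
    """Illustrative stand-in for a PyHealth event."""
    def __init__(self, timestamp, **attrs):
        self.timestamp = timestamp
        self.attr_dict = {k.lower(): v for k, v in attrs.items()}

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails,
        # so `timestamp` and `attr_dict` never reach this path.
        try:
            return self.attr_dict[name]
        except KeyError:
            raise AttributeError(name)

    def __getitem__(self, key):
        return self.attr_dict[key]

e = ToyEvent(datetime(2020, 1, 1), ICD_CODE="250.00")
assert e.icd_code == e["icd_code"] == e.attr_dict["icd_code"] == "250.00"
```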

Things to Watch Out For
------------------------

A few patterns that commonly trip up new users:

**BaseDataset vs SampleDataset.** Models expect a ``SampleDataset`` (the
output of ``set_task()``), not the raw ``BaseDataset``. Passing the wrong one
will raise an error. If you see an ``AttributeError`` about ``input_schema``
or ``output_schema``, this is likely the cause.

**Timestamp attribute names.** Writing ``event.charttime`` will raise an
``AttributeError`` because PyHealth remaps that column to ``event.timestamp``.
The same applies to ``admittime``, ``starttime``, or whatever the original
column was called.

**Column name casing.** PyHealth lowercases all column names at load time.
Even if your source CSV has ``ICD_CODE``, you access it as ``event.icd_code``.

**dev=True in production.** The ``dev`` flag is great for exploring data but
it caps the dataset at 1 000 patients. Remember to switch to ``dev=False``
before running a full training job.

**Multiprocessing guard.** Scripts that call ``set_task()`` should wrap their
top-level code in ``if __name__ == '__main__':``. See :doc:`tasks` for details.

Available Datasets
------------------

138 changes: 138 additions & 0 deletions docs/api/graph.rst
@@ -0,0 +1,138 @@
Graph
=====

The ``pyhealth.graph`` module lets you bring a healthcare knowledge graph into
your PyHealth pipeline. Graph-based models like GraphCare and GNN can use
relational medical knowledge — drug interactions, disease hierarchies,
symptom–diagnosis links — to enrich patient representations beyond what
raw EHR codes alone can capture.

What Is a Knowledge Graph?
---------------------------

A knowledge graph encodes medical relationships as **(head, relation, tail)**
triples. For example:

- ``("aspirin", "treats", "headache")``
- ``("metformin", "used_for", "type_2_diabetes")``
- ``("ICD9:250", "is_a", "ICD9:249")``

PyHealth does not ship a built-in graph — you bring triples from a source of
your choice (UMLS, DrugBank, an ICD hierarchy, a custom ontology, etc.) and
the :class:`~pyhealth.graph.KnowledgeGraph` class handles indexing, entity
mappings, and k-hop subgraph extraction. The typical use case is querying the
graph at training time: given a patient's active codes, extract the local
subgraph around those codes and feed it to a graph-aware model.

Getting Started
---------------

The simplest way to create a graph is to pass a list of triples directly:

.. code-block:: python

from pyhealth.graph import KnowledgeGraph

triples = [
("aspirin", "treats", "headache"),
("headache", "symptom_of", "migraine"),
("ibuprofen", "treats", "headache"),
]
kg = KnowledgeGraph(triples=triples)
kg.stat()
# KnowledgeGraph: 4 entities, 2 relations, 3 triples

For larger graphs it is more practical to load from a CSV or TSV file. The
file should have columns named ``head``, ``relation``, and ``tail``:

.. code-block:: python

kg = KnowledgeGraph(triples="path/to/medical_kg.tsv")
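
A loader for that file format is a few lines of standard-library code. This
sketch shows the expected head/relation/tail layout; it is not the internal
PyHealth parser:

```python
import csv
import io

def load_triples(f):
    """Read (head, relation, tail) triples from a TSV with a header row."""
    reader = csv.DictReader(f, delimiter="\t")
    return [(r["head"], r["relation"], r["tail"]) for r in reader]

tsv = "head\trelation\ttail\naspirin\ttreats\theadache\n"
triples = load_triples(io.StringIO(tsv))
assert triples == [("aspirin", "treats", "headache")]
```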

Exploring the Graph
-------------------

Once built, you can inspect the graph and look up neighbours for any entity:

.. code-block:: python

kg.num_entities # total unique entities
kg.num_relations # total unique relation types
kg.num_triples # total edges

kg.has_entity("aspirin") # True / False
kg.neighbors("aspirin") # list of (relation, tail) pairs

# Integer ID mappings used internally by PyG
kg.entity2id["aspirin"] # → int
kg.id2entity[0] # → entity name string
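
The integer mappings are the standard trick for turning string entities into
tensor indices. Building one is straightforward; this is illustrative, not the
class internals:

```python
triples = [
    ("aspirin", "treats", "headache"),
    ("headache", "symptom_of", "migraine"),
    ("ibuprofen", "treats", "headache"),
]

# Collect every entity that appears as a head or a tail.
entities = sorted({h for h, _, t in triples} | {t for h, _, t in triples})
entity2id = {e: i for i, e in enumerate(entities)}
id2entity = {i: e for e, i in entity2id.items()}

assert len(entity2id) == 4  # matches the kg.stat() output above
assert id2entity[entity2id["aspirin"]] == "aspirin"
```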

Extracting Patient Subgraphs
------------------------------

The main reason to build a knowledge graph is to extract a patient-specific
subgraph at training time. ``subgraph()`` returns all entities reachable
within *n* hops of a set of seed codes, as a PyTorch Geometric ``Data``
object:

.. code-block:: python

patient_codes = ["ICD9:250.00", "NDC:0069-0105"]
subgraph = kg.subgraph(seed_entities=patient_codes, num_hops=2)

.. note::

``subgraph()`` requires `PyTorch Geometric <https://pyg.org/>`_
(``torch_geometric``). The graph can still be constructed and explored
without it — only subgraph extraction needs PyG.

Install with: ``pip install torch-geometric``
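
Conceptually, k-hop extraction is a breadth-first walk over an undirected
adjacency map. This standard-library sketch returns just the node set; the
real method builds a PyG ``Data`` object with edge tensors as well:

```python
def k_hop_nodes(triples, seeds, num_hops):
    """Return all entities within `num_hops` of the seed set (undirected)."""
    adj = {}
    for h, _, t in triples:
        adj.setdefault(h, set()).add(t)
        adj.setdefault(t, set()).add(h)

    visited = set(seeds)
    frontier = set(seeds)
    for _ in range(num_hops):
        # Expand one hop, dropping nodes we have already collected.
        frontier = {n for u in frontier for n in adj.get(u, ())} - visited
        visited |= frontier
    return visited

triples = [
    ("aspirin", "treats", "headache"),
    ("headache", "symptom_of", "migraine"),
    ("migraine", "treated_by", "sumatriptan"),
]
assert k_hop_nodes(triples, {"aspirin"}, 2) == {"aspirin", "headache", "migraine"}
```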

Using with GraphProcessor in a Task
-------------------------------------

To feed subgraphs into a model automatically during data loading, pass a
configured :class:`~pyhealth.processors.GraphProcessor` instance in your
task's ``input_schema``. The processor will call ``kg.subgraph()`` for each
patient sample:

.. code-block:: python

from pyhealth.graph import KnowledgeGraph
from pyhealth.processors import GraphProcessor
from pyhealth.tasks import BaseTask

kg = KnowledgeGraph(triples="medical_kg.tsv")

class MyGraphTask(BaseTask):
task_name = "MyGraphTask"
input_schema = {
"conditions": "sequence",
"kg_subgraph": GraphProcessor(kg, num_hops=2),
}
output_schema = {"label": "binary"}

def __call__(self, patient):
...

Pre-computed Node Embeddings
-----------------------------

If you already have entity embeddings (e.g. from TransE or an LLM), you can
attach them to the graph at construction time. The model can then use these
as initial node features instead of learning them from scratch:

.. code-block:: python

import torch

node_features = torch.randn(kg.num_entities, 64) # (num_entities, feat_dim)
kg = KnowledgeGraph(triples=triples, node_features=node_features)

API Reference
-------------

.. toctree::
:maxdepth: 3

graph/pyhealth.graph.KnowledgeGraph