From 8f5e9820e6af14100a1db0b5081e6460749c1af1 Mon Sep 17 00:00:00 2001 From: yousefkotp Date: Mon, 6 Apr 2026 14:58:59 -0400 Subject: [PATCH 1/8] chore: prepare v1.1.0 release plumbing --- .github/workflows/publish.yml | 44 +++++++++++ README.md | 22 ++++-- atlas_patch/__init__.py | 21 ++++- atlas_patch/cli.py | 73 ++++++++++-------- atlas_patch/models/patch/__init__.py | 81 ++++++++------------ atlas_patch/models/patch/chief_ctranspath.py | 12 ++- atlas_patch/models/patch/quilt.py | 3 +- atlas_patch/models/patch/registry.py | 7 ++ atlas_patch/services/__init__.py | 31 ++++++-- pyproject.toml | 31 +++++--- 10 files changed, 215 insertions(+), 110 deletions(-) create mode 100644 .github/workflows/publish.yml diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..36f1158 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,44 @@ +name: publish + +on: + release: + types: [published] + +jobs: + build: + runs-on: ubuntu-latest + permissions: + contents: read + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - name: Install build tooling + run: python -m pip install --upgrade build twine + - name: Build distributions + run: python -m build + - name: Validate distributions + run: python -m twine check dist/* + - name: Upload distributions + uses: actions/upload-artifact@v4 + with: + name: python-dist + path: dist/* + + publish: + needs: build + runs-on: ubuntu-latest + environment: pypi + permissions: + id-token: write + steps: + - name: Download distributions + uses: actions/download-artifact@v4 + with: + name: python-dist + path: dist + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + packages-dir: dist diff --git a/README.md b/README.md index 2c72570..b49e166 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ ### Quick Install (Recommended) ```bash -# Install AtlasPatch +# Install base AtlasPatch pip install 
atlas-patch # Install SAM2 (required for tissue segmentation) @@ -96,7 +96,15 @@ Before installing AtlasPatch, you need the OpenSlide system library: ### Optional Encoder Dependencies -Some feature extractors require additional dependencies that must be installed separately: +AtlasPatch keeps model-specific dependencies out of the base install. + +Use the optional extra below if you want the broader built-in patch encoder registry: + +```bash +pip install "atlas-patch[patch-encoders]" +``` + +Some encoders also require upstream project packages that must still be installed separately: ```bash # For CONCH encoder (conch_v1, conch_v15) @@ -106,7 +114,7 @@ pip install git+https://github.com/Mahmoodlab/CONCH.git pip install git+https://github.com/lilab-stanford/MUSK.git ``` -These are only needed if you plan to use those specific encoders. +These installs are only needed if you plan to use those specific encoders. ### Alternative Installation Methods @@ -307,8 +315,8 @@ All visualization outputs are saved under `/visualization/`. | `--save-images` | Off | Export each patch as a PNG file under `/images//`. | | `--recursive` | Off | Walk subdirectories when `WSI_PATH` is a directory. | | `--mpp-csv` | None | Path to a CSV file with `wsi,mpp` columns to override microns-per-pixel when slide metadata is missing or incorrect. | -| `--skip-existing` | Off | Skip slides that already have an output H5 file. | -| `--force` | Off | Overwrite existing output files. | +| `--skip-existing` | On | Skip slides that already have an output H5 file. This is the default behavior. | +| `--force` | Off | Reprocess slides even when an output H5 file already exists. | | `--verbose`, `-v` | Off | Enable debug logging and disable the progress bar. | | `--write-batch` | `8192` | Number of coordinate rows to buffer before flushing to H5. Tune for RAM vs. I/O trade-off. | @@ -641,10 +649,10 @@ If your format isn't supported, consider converting it to a supported format or
How do I skip already processed slides? -Use the `--skip-existing` flag to skip slides that already have an output H5 file: +`process` and `segment-and-get-coords` already skip existing per-slide H5 outputs by default. Use `--force` when you want to overwrite them: ```bash -atlaspatch process /path/to/slides --output ./output --skip-existing +atlaspatch process /path/to/slides --output ./output --force ```
diff --git a/atlas_patch/__init__.py b/atlas_patch/__init__.py index 2923b7d..65a1fe8 100644 --- a/atlas_patch/__init__.py +++ b/atlas_patch/__init__.py @@ -1,6 +1,21 @@ -"""AtlasPatch module.""" +"""AtlasPatch package.""" -from . import core, services +from __future__ import annotations -__version__ = "1.0.0.post5" +from importlib import import_module +from typing import Any + +__version__ = "1.1.0" __all__ = ["core", "services", "__version__"] + + +def __getattr__(name: str) -> Any: + if name in {"core", "services"}: + module = import_module(f"{__name__}.{name}") + globals()[name] = module + return module + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__() -> list[str]: + return sorted(__all__) diff --git a/atlas_patch/cli.py b/atlas_patch/cli.py index 89c2a9f..f2b3556 100644 --- a/atlas_patch/cli.py +++ b/atlas_patch/cli.py @@ -5,32 +5,11 @@ from pathlib import Path import click -import torch -from tqdm import tqdm - -from atlas_patch.core.config import ( - AppConfig, - ExtractionConfig, - FeatureExtractionConfig, - OutputConfig, - ProcessingConfig, - SegmentationConfig, - VisualizationConfig, -) -from atlas_patch.core.models import Slide -from atlas_patch.models.patch import PatchFeatureExtractorRegistry, build_default_registry -from atlas_patch.models.patch.custom import register_feature_extractors_from_module -from atlas_patch.orchestration.runner import ProcessingRunner -from atlas_patch.services.extraction import PatchExtractionService -from atlas_patch.services.feature_embedding import PatchFeatureEmbeddingService, resolve_feature_dtype -from atlas_patch.services.mpp import CSVMPPResolver -from atlas_patch.services.segmentation import SAM2SegmentationService -from atlas_patch.services.visualization import DefaultVisualizationService -from atlas_patch.services.wsi_loader import DefaultWSILoader + +from atlas_patch import __version__ from atlas_patch.utils import ( configure_logging, install_embedding_log_filter, - 
parse_feature_list, ) from atlas_patch.utils.params import get_wsi_files from atlas_patch.utils.visualization import visualize_mask_on_thumbnail @@ -46,10 +25,6 @@ def _default_config_path() -> Path: return Path(__file__).resolve().parent / "configs" / "sam2.1_hiera_t.yaml" - -FEATURE_EXTRACTOR_CHOICES = build_default_registry(device="cpu").available() - - # Shared option sets ----------------------------------------------------------- _COMMON_OPTIONS: list = [ click.argument("wsi_path", type=click.Path(exists=True)), @@ -154,9 +129,11 @@ def _default_config_path() -> Path: "--feature-extractors", required=True, type=str, - help="Space/comma separated feature extractors to run (available: " - + ", ".join(FEATURE_EXTRACTOR_CHOICES) - + "; add more via --feature-plugin).", + help=( + "Space/comma separated feature extractors to run. " + "The registry is resolved lazily at runtime; install `atlas-patch[patch-encoders]` " + "for optional timm/transformers/open-clip based models and add more via --feature-plugin." 
+ ), ), click.option( "--feature-batch-size", @@ -233,6 +210,24 @@ def _run_pipeline( feature_cfg: FeatureExtractionConfig | None = None, registry: PatchFeatureExtractorRegistry | None = None, ) -> tuple[list, list]: + from tqdm import tqdm + + from atlas_patch.core.config import ( + AppConfig, + ExtractionConfig, + OutputConfig, + ProcessingConfig, + SegmentationConfig, + VisualizationConfig, + ) + from atlas_patch.orchestration.runner import ProcessingRunner + from atlas_patch.services.extraction import PatchExtractionService + from atlas_patch.services.feature_embedding import PatchFeatureEmbeddingService + from atlas_patch.services.mpp import CSVMPPResolver + from atlas_patch.services.segmentation import SAM2SegmentationService + from atlas_patch.services.visualization import DefaultVisualizationService + from atlas_patch.services.wsi_loader import DefaultWSILoader + configure_logging(verbose) processing_cfg = ProcessingConfig( @@ -336,6 +331,14 @@ def _run_tissue_visualization( mpp_csv: str | None, verbose: bool, ) -> tuple[list[tuple[Slide, Path]], list[tuple[Slide, Exception | str]]]: + from tqdm import tqdm + + from atlas_patch.core.config import ProcessingConfig, SegmentationConfig, VisualizationConfig + from atlas_patch.core.models import Slide + from atlas_patch.services.mpp import CSVMPPResolver + from atlas_patch.services.segmentation import SAM2SegmentationService + from atlas_patch.services.wsi_loader import DefaultWSILoader + configure_logging(verbose) processing_cfg = ProcessingConfig( @@ -464,7 +467,7 @@ def _echo_mask_results( @click.group() -@click.version_option(version="0.2.0") +@click.version_option(version=__version__) def cli(): """AtlasPatch CLI. 
@@ -612,6 +615,14 @@ def process( feature_plugins: tuple[str, ...], ): """Run segmentation, patch extraction, and feature embedding into a single H5.""" + import torch + + from atlas_patch.core.config import FeatureExtractionConfig + from atlas_patch.models.patch import build_default_registry + from atlas_patch.models.patch.custom import register_feature_extractors_from_module + from atlas_patch.services.feature_embedding import resolve_feature_dtype + from atlas_patch.utils import parse_feature_list + feat_device = feature_device.lower() if feature_device else device.lower() torch_device = torch.device(feat_device) dtype = resolve_feature_dtype(torch_device, feature_precision.lower()) diff --git a/atlas_patch/models/patch/__init__.py b/atlas_patch/models/patch/__init__.py index 8004ef5..0e36c12 100644 --- a/atlas_patch/models/patch/__init__.py +++ b/atlas_patch/models/patch/__init__.py @@ -1,32 +1,10 @@ from __future__ import annotations +from importlib import import_module + import torch -from atlas_patch.models.patch.biomedclip import register_biomedclip_model -from atlas_patch.models.patch.clip import register_openai_clip_models -from atlas_patch.models.patch.conch import register_conch_models -from atlas_patch.models.patch.convnext import register_convnexts -from atlas_patch.models.patch.chief_ctranspath import register_chief_ctranspath_model -from atlas_patch.models.patch.dinov2 import register_dinov2_models -from atlas_patch.models.patch.dinov3 import register_dinov3_models -from atlas_patch.models.patch.gigapath import register_prov_gigapath_model -from atlas_patch.models.patch.hibou import register_hibou_models -from atlas_patch.models.patch.hoptimus import register_hoptimus_models -from atlas_patch.models.patch.lunit import register_lunit_models -from atlas_patch.models.patch.medsiglip import register_medsiglip_model -from atlas_patch.models.patch.midnight import register_midnight_model -from atlas_patch.models.patch.musk import register_musk_model -from 
atlas_patch.models.patch.omiclip import register_omiclip_model -from atlas_patch.models.patch.openmidnight import register_openmidnight_model -from atlas_patch.models.patch.pathorchestra import register_pathorchestra_model -from atlas_patch.models.patch.phikon import register_phikon_models -from atlas_patch.models.patch.plip import register_plip_model -from atlas_patch.models.patch.quilt import register_quilt_models from atlas_patch.models.patch.registry import PatchFeatureExtractorRegistry -from atlas_patch.models.patch.resnet import register_resnets -from atlas_patch.models.patch.uni import register_uni_models -from atlas_patch.models.patch.virchow import register_virchow_models -from atlas_patch.models.patch.vit import register_vits from atlas_patch.models.patch.custom import ( CustomEncoderComponents, CustomEncoderLoader, @@ -43,6 +21,33 @@ "register_feature_extractors_from_module", ] +_REGISTRARS: tuple[tuple[str, str], ...] = ( + ("atlas_patch.models.patch.resnet", "register_resnets"), + ("atlas_patch.models.patch.convnext", "register_convnexts"), + ("atlas_patch.models.patch.vit", "register_vits"), + ("atlas_patch.models.patch.dinov2", "register_dinov2_models"), + ("atlas_patch.models.patch.dinov3", "register_dinov3_models"), + ("atlas_patch.models.patch.clip", "register_openai_clip_models"), + ("atlas_patch.models.patch.conch", "register_conch_models"), + ("atlas_patch.models.patch.omiclip", "register_omiclip_model"), + ("atlas_patch.models.patch.quilt", "register_quilt_models"), + ("atlas_patch.models.patch.uni", "register_uni_models"), + ("atlas_patch.models.patch.lunit", "register_lunit_models"), + ("atlas_patch.models.patch.plip", "register_plip_model"), + ("atlas_patch.models.patch.medsiglip", "register_medsiglip_model"), + ("atlas_patch.models.patch.musk", "register_musk_model"), + ("atlas_patch.models.patch.openmidnight", "register_openmidnight_model"), + ("atlas_patch.models.patch.pathorchestra", "register_pathorchestra_model"), + 
("atlas_patch.models.patch.hoptimus", "register_hoptimus_models"), + ("atlas_patch.models.patch.hibou", "register_hibou_models"), + ("atlas_patch.models.patch.biomedclip", "register_biomedclip_model"), + ("atlas_patch.models.patch.phikon", "register_phikon_models"), + ("atlas_patch.models.patch.virchow", "register_virchow_models"), + ("atlas_patch.models.patch.gigapath", "register_prov_gigapath_model"), + ("atlas_patch.models.patch.midnight", "register_midnight_model"), + ("atlas_patch.models.patch.chief_ctranspath", "register_chief_ctranspath_model"), +) + def build_default_registry( *, @@ -53,28 +58,8 @@ def build_default_registry( """Factory that registers the built-in extractors.""" dev = torch.device(device) registry = PatchFeatureExtractorRegistry() - register_resnets(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_convnexts(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_vits(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_dinov2_models(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_dinov3_models(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_openai_clip_models(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_conch_models(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_omiclip_model(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_quilt_models(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_uni_models(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_lunit_models(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_plip_model(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_medsiglip_model(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_musk_model(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_openmidnight_model(registry, device=dev, 
num_workers=num_workers, dtype=dtype) - register_pathorchestra_model(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_hoptimus_models(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_hibou_models(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_biomedclip_model(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_phikon_models(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_virchow_models(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_prov_gigapath_model(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_midnight_model(registry, device=dev, num_workers=num_workers, dtype=dtype) - register_chief_ctranspath_model(registry, device=dev, num_workers=num_workers, dtype=dtype) + for module_name, registrar_name in _REGISTRARS: + module = import_module(module_name) + registrar = getattr(module, registrar_name) + registrar(registry, device=dev, num_workers=num_workers, dtype=dtype) return registry diff --git a/atlas_patch/models/patch/chief_ctranspath.py b/atlas_patch/models/patch/chief_ctranspath.py index b1ff4f1..674739d 100644 --- a/atlas_patch/models/patch/chief_ctranspath.py +++ b/atlas_patch/models/patch/chief_ctranspath.py @@ -9,8 +9,6 @@ from atlas_patch.models.patch.base import PatchFeatureExtractor from atlas_patch.models.patch.registry import PatchFeatureExtractorRegistry -from timm.layers import to_2tuple - logger = logging.getLogger(__name__) _CHECKPOINT_ID = "1_vgRF1QXa8sPCOpJ1S9BihwZhXQMOVJc" @@ -18,6 +16,12 @@ _EMB_DIM = 768 +def _to_2tuple(value): + from timm.layers import to_2tuple + + return to_2tuple(value) + + def _build_preprocess(): from torchvision import transforms @@ -81,8 +85,8 @@ def __init__( assert patch_size == 4 assert embed_dim % 8 == 0 - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) + img_size = _to_2tuple(img_size) + patch_size = _to_2tuple(patch_size) self.img_size = 
img_size self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) diff --git a/atlas_patch/models/patch/quilt.py b/atlas_patch/models/patch/quilt.py index 291e0ca..d82f26c 100644 --- a/atlas_patch/models/patch/quilt.py +++ b/atlas_patch/models/patch/quilt.py @@ -3,7 +3,6 @@ import logging import torch -from transformers import CLIPModel, CLIPProcessor from atlas_patch.models.patch.base import PatchFeatureExtractor from atlas_patch.models.patch.registry import PatchFeatureExtractorRegistry @@ -53,6 +52,8 @@ def _forward(x, m=model): forward_fn = _forward else: + from transformers import CLIPModel, CLIPProcessor + model = CLIPModel.from_pretrained(model_id) processor = CLIPProcessor.from_pretrained(model_id) diff --git a/atlas_patch/models/patch/registry.py b/atlas_patch/models/patch/registry.py index 4bb6ea0..91958b5 100644 --- a/atlas_patch/models/patch/registry.py +++ b/atlas_patch/models/patch/registry.py @@ -30,6 +30,13 @@ def create(self, name: str) -> FeatureExtractor: builder = self._builders[key] try: return builder() + except ModuleNotFoundError as exc: + logger.exception("Missing optional dependency for feature extractor '%s'", name) + dependency = exc.name or "unknown dependency" + raise RuntimeError( + f"Feature extractor '{name}' requires optional dependency '{dependency}'. " + "Install `atlas-patch[patch-encoders]` and any model-specific packages documented in the README." 
+ ) from exc except Exception: logger.exception("Failed to create feature extractor '%s'", name) raise diff --git a/atlas_patch/services/__init__.py b/atlas_patch/services/__init__.py index 26dfaba..7b87a0c 100644 --- a/atlas_patch/services/__init__.py +++ b/atlas_patch/services/__init__.py @@ -1,10 +1,17 @@ """Service implementations for segmentation, extraction, visualization, and WSI access.""" -from .extraction import PatchExtractionService -from .mpp import CSVMPPResolver -from .segmentation import SAM2SegmentationService -from .visualization import DefaultVisualizationService -from .wsi_loader import DefaultWSILoader +from __future__ import annotations + +from importlib import import_module +from typing import Any + +_LAZY_EXPORTS = { + "PatchExtractionService": "atlas_patch.services.extraction", + "CSVMPPResolver": "atlas_patch.services.mpp", + "SAM2SegmentationService": "atlas_patch.services.segmentation", + "DefaultVisualizationService": "atlas_patch.services.visualization", + "DefaultWSILoader": "atlas_patch.services.wsi_loader", +} __all__ = [ "PatchExtractionService", @@ -13,3 +20,17 @@ "DefaultVisualizationService", "DefaultWSILoader", ] + + +def __getattr__(name: str) -> Any: + module_name = _LAZY_EXPORTS.get(name) + if module_name is None: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + module = import_module(module_name) + value = getattr(module, name) + globals()[name] = value + return value + + +def __dir__() -> list[str]: + return sorted(__all__) diff --git a/pyproject.toml b/pyproject.toml index adbe14e..5f73ac0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "atlas-patch" -version = "1.0.0.post5" +dynamic = ["version"] description = "A Python package for processing and handling whole slide images" readme = "README.md" requires-python = ">=3.10" @@ -40,24 +40,30 @@ dependencies = [ "matplotlib>=3.5.0", "tqdm>=4.66.0", "torchvision>=0.15.0", - 
"timm>=0.9.0", "huggingface-hub>=0.23.0", - "gdown>=5.2.0", - "transformers>=4.41.0", - "sentencepiece>=0.2.0", - "open-clip-torch>=2.24.0", - "fairscale>=0.4.0", - "einops>=0.8.0", - "einops-exts>=0.0.4", ] [project.optional-dependencies] +patch-encoders = [ + "einops>=0.8.0", + "einops-exts>=0.0.4", + "fairscale>=0.4.0", + "gdown>=5.2.0", + "open-clip-torch>=2.24.0", + "sentencepiece>=0.2.0", + "timm>=0.9.0", + "transformers>=4.41.0", +] +release = [ + "build>=1.2.2", + "twine>=5.1.1", +] dev = [ - "ruff>=0.1.0", + "mypy>=1.0.0", "pre-commit>=3.0.0", "pytest>=7.0.0", "pytest-cov>=4.0.0", - "mypy>=1.0.0", + "ruff>=0.1.0", ] # See README for installation instructions: @@ -81,6 +87,9 @@ include = ["atlas_patch*"] [tool.setuptools.package-data] "atlas_patch" = ["configs/*.yaml"] +[tool.setuptools.dynamic] +version = {attr = "atlas_patch.__version__"} + [tool.ruff] line-length = 100 target-version = "py310" From cf39c6afffcdf50174cc0623f4acee1ef22f9131 Mon Sep 17 00:00:00 2001 From: yousefkotp Date: Mon, 6 Apr 2026 15:57:58 -0400 Subject: [PATCH 2/8] feat: add slide and patient encoder foundations --- atlas_patch/core/config.py | 50 +++++++ atlas_patch/core/models.py | 25 ++++ atlas_patch/core/paths.py | 66 +++++++++ atlas_patch/models/patient/__init__.py | 14 ++ atlas_patch/models/patient/base.py | 74 ++++++++++ atlas_patch/models/patient/registry.py | 59 ++++++++ atlas_patch/models/slide/__init__.py | 17 +++ atlas_patch/models/slide/base.py | 69 ++++++++++ atlas_patch/models/slide/registry.py | 59 ++++++++ atlas_patch/utils/__init__.py | 10 ++ atlas_patch/utils/feature_h5.py | 180 +++++++++++++++++++++++++ 11 files changed, 623 insertions(+) create mode 100644 atlas_patch/models/patient/__init__.py create mode 100644 atlas_patch/models/patient/base.py create mode 100644 atlas_patch/models/patient/registry.py create mode 100644 atlas_patch/models/slide/__init__.py create mode 100644 atlas_patch/models/slide/base.py create mode 100644 
atlas_patch/models/slide/registry.py create mode 100644 atlas_patch/utils/feature_h5.py diff --git a/atlas_patch/core/config.py b/atlas_patch/core/config.py index a7c306d..d4f26df 100644 --- a/atlas_patch/core/config.py +++ b/atlas_patch/core/config.py @@ -37,6 +37,22 @@ def _validate_device(device: str) -> str: raise ValueError(f"device must be 'cpu', 'cuda', or 'cuda:', got {device}") +def _normalize_names(values: list[str], name: str) -> list[str]: + normalized: list[str] = [] + seen: set[str] = set() + for raw in values: + item = str(raw).strip().lower() + if not item: + raise ValueError(f"{name} entries must be non-empty strings.") + if item in seen: + continue + normalized.append(item) + seen.add(item) + if not normalized: + raise ValueError(f"At least one {name} entry must be provided.") + return normalized + + @dataclass class SegmentationConfig: checkpoint_path: Path | None @@ -149,6 +165,40 @@ def validated(self) -> ProcessingConfig: return self +@dataclass +class SlideEncodingConfig: + input_path: Path + encoders: list[str] + recursive: bool = False + device: str = "cuda" + skip_existing: bool = True + mpp_csv: Path | None = None + + def validated(self) -> SlideEncodingConfig: + if not self.input_path.exists(): + raise FileNotFoundError(f"Input path not found: {self.input_path}") + self.encoders = _normalize_names(self.encoders, "slide encoder") + self.device = _validate_device(str(self.device)) + if self.mpp_csv is not None and not self.mpp_csv.exists(): + raise FileNotFoundError(f"MPP CSV not found: {self.mpp_csv}") + return self + + +@dataclass +class PatientEncodingConfig: + manifest_path: Path + encoders: list[str] + device: str = "cuda" + skip_existing: bool = True + + def validated(self) -> PatientEncodingConfig: + if not self.manifest_path.exists(): + raise FileNotFoundError(f"Patient manifest not found: {self.manifest_path}") + self.encoders = _normalize_names(self.encoders, "patient encoder") + self.device = _validate_device(str(self.device)) + 
return self + + @dataclass class VisualizationConfig: thumbnail_size: int = 1024 diff --git a/atlas_patch/core/models.py b/atlas_patch/core/models.py index 6662f09..1aef841 100644 --- a/atlas_patch/core/models.py +++ b/atlas_patch/core/models.py @@ -34,3 +34,28 @@ class ExtractionResult: metadata: dict[str, Any] = field(default_factory=dict) coords: np.ndarray | None = None # Optional in-memory coords for visualization patch_size_level0: int | None = None + + +@dataclass +class SlideEmbeddingResult: + slide: Slide + h5_path: Path + encoder_name: str + dataset_key: str + embedding_dim: int + num_patches: int + source_patch_encoder: str + metadata: dict[str, Any] = field(default_factory=dict) + embedding: np.ndarray | None = None + + +@dataclass +class PatientEmbeddingResult: + case_id: str + h5_path: Path + encoder_name: str + embedding_dim: int + num_slides: int + source_patch_encoder: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + embedding: np.ndarray | None = None diff --git a/atlas_patch/core/paths.py b/atlas_patch/core/paths.py index 8d6c3d6..b6ec097 100644 --- a/atlas_patch/core/paths.py +++ b/atlas_patch/core/paths.py @@ -1,10 +1,43 @@ from __future__ import annotations +import os from pathlib import Path from atlas_patch.core.config import ExtractionConfig, OutputConfig from atlas_patch.core.models import Slide +PATCH_FEATURE_GROUP = "features" +SLIDE_FEATURE_GROUP = "slide_features" +PATIENT_FEATURES_DIRNAME = "patient_features" + + +def normalize_encoder_name(name: str) -> str: + value = str(name).strip().lower() + if not value: + raise ValueError("encoder name must be a non-empty string") + return value + + +def _normalize_dataset_component(name: str, *, prefix: str) -> str: + value = str(name).strip() + if value.lower().startswith(f"{prefix}/"): + value = value.split("/", 1)[1] + return normalize_encoder_name(value) + + +def _validate_output_stem(value: str, *, field_name: str) -> str: + cleaned = str(value).strip() + if 
not cleaned: + raise ValueError(f"{field_name} must be a non-empty string") + if cleaned in {".", ".."}: + raise ValueError(f"{field_name} cannot be '.' or '..'") + separators = [os.sep] + if os.altsep is not None: + separators.append(os.altsep) + if any(sep in cleaned for sep in separators): + raise ValueError(f"{field_name} cannot contain path separators") + return cleaned + def build_run_root(output_cfg: OutputConfig, extraction_cfg: ExtractionConfig) -> Path: """Return the root output directory for a run. @@ -17,6 +50,14 @@ def patch_h5_path(slide: Slide, output_cfg: OutputConfig, extraction_cfg: Extrac return run_root / "patches" / f"{slide.stem}.h5" +def patch_feature_dataset_key(encoder_name: str) -> str: + return f"{PATCH_FEATURE_GROUP}/{_normalize_dataset_component(encoder_name, prefix=PATCH_FEATURE_GROUP)}" + + +def slide_feature_dataset_key(encoder_name: str) -> str: + return f"{SLIDE_FEATURE_GROUP}/{_normalize_dataset_component(encoder_name, prefix=SLIDE_FEATURE_GROUP)}" + + def find_existing_patch( slide: Slide, output_cfg: OutputConfig, extraction_cfg: ExtractionConfig ) -> Path | None: @@ -40,3 +81,28 @@ def patch_lock_path( ) -> Path: run_root = build_run_root(output_cfg, extraction_cfg) return run_root / "patches" / f"{slide.stem}.lock" + + +def slide_append_lock_path( + slide: Slide, output_cfg: OutputConfig, extraction_cfg: ExtractionConfig +) -> Path: + """Serialize all writes against the canonical per-slide patch H5.""" + return patch_lock_path(slide, output_cfg, extraction_cfg) + + +def patient_features_dir(output_cfg: OutputConfig) -> Path: + return output_cfg.output_root / PATIENT_FEATURES_DIRNAME + + +def patient_encoder_dir(output_cfg: OutputConfig, encoder_name: str) -> Path: + return patient_features_dir(output_cfg) / normalize_encoder_name(encoder_name) + + +def patient_embedding_path(output_cfg: OutputConfig, encoder_name: str, case_id: str) -> Path: + case_stem = _validate_output_stem(case_id, field_name="case_id") + return 
patient_encoder_dir(output_cfg, encoder_name) / f"{case_stem}.h5" + + +def patient_lock_path(output_cfg: OutputConfig, encoder_name: str, case_id: str) -> Path: + case_stem = _validate_output_stem(case_id, field_name="case_id") + return patient_encoder_dir(output_cfg, encoder_name) / f"{case_stem}.lock" diff --git a/atlas_patch/models/patient/__init__.py b/atlas_patch/models/patient/__init__.py new file mode 100644 index 0000000..f2b1400 --- /dev/null +++ b/atlas_patch/models/patient/__init__.py @@ -0,0 +1,14 @@ +from atlas_patch.models.patient.base import PatientEncoder, PatientEncoderSpec +from atlas_patch.models.patient.registry import PatientEncoderRegistry + +__all__ = [ + "PatientEncoder", + "PatientEncoderRegistry", + "PatientEncoderSpec", + "build_default_registry", +] + + +def build_default_registry() -> PatientEncoderRegistry: + """Return an empty patient-encoder registry for phased population.""" + return PatientEncoderRegistry() diff --git a/atlas_patch/models/patient/base.py b/atlas_patch/models/patient/base.py new file mode 100644 index 0000000..6b12cf2 --- /dev/null +++ b/atlas_patch/models/patient/base.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Mapping, Sequence + +import numpy as np + + +def _normalize_encoder_name(name: str) -> str: + value = str(name).strip().lower() + if not value: + raise ValueError("encoder name must be a non-empty string") + return value + + +def _normalize_patch_size(patch_size: int) -> int: + value = int(patch_size) + if value <= 0: + raise ValueError("patch_size must be > 0") + return value + + +@dataclass(frozen=True) +class PatientEncoderSpec: + name: str + embedding_dim: int + patch_encoder_name: str + patch_size: int + + def __post_init__(self) -> None: + object.__setattr__(self, "name", _normalize_encoder_name(self.name)) + embedding_dim = int(self.embedding_dim) + if embedding_dim <= 0: 
+ raise ValueError("embedding_dim must be > 0") + object.__setattr__(self, "embedding_dim", embedding_dim) + object.__setattr__( + self, + "patch_encoder_name", + _normalize_encoder_name(self.patch_encoder_name), + ) + object.__setattr__(self, "patch_size", _normalize_patch_size(self.patch_size)) + + +class PatientEncoder(ABC): + """Base interface for case-level encoders operating on grouped slide patch H5s.""" + + spec: PatientEncoderSpec + + @property + def name(self) -> str: + return self.spec.name + + @property + def embedding_dim(self) -> int: + return self.spec.embedding_dim + + @property + def required_patch_encoder(self) -> str: + return self.spec.patch_encoder_name + + @property + def required_patch_size(self) -> int: + return self.spec.patch_size + + @abstractmethod + def encode_case( + self, + slide_h5_paths: Sequence[Path], + metadata: Mapping[str, object] | None = None, + ) -> np.ndarray: + """Return a 1-D embedding for the provided case's per-slide patch H5 files.""" + raise NotImplementedError diff --git a/atlas_patch/models/patient/registry.py b/atlas_patch/models/patient/registry.py new file mode 100644 index 0000000..55b3b5a --- /dev/null +++ b/atlas_patch/models/patient/registry.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import logging +from typing import Callable, Mapping + +from atlas_patch.core.paths import normalize_encoder_name +from atlas_patch.models.patient.base import PatientEncoder, PatientEncoderSpec + +logger = logging.getLogger(__name__) + + +class PatientEncoderRegistry: + """Registry of patient encoder builders plus lightweight metadata specs.""" + + def __init__(self) -> None: + self._builders: dict[str, Callable[[], PatientEncoder]] = {} + self._specs: dict[str, PatientEncoderSpec] = {} + + def register(self, spec: PatientEncoderSpec, builder: Callable[[], PatientEncoder]) -> None: + key = normalize_encoder_name(spec.name) + if key in self._builders: + raise ValueError(f"Patient encoder '{spec.name}' already 
registered.") + self._builders[key] = builder + self._specs[key] = spec + + def available(self) -> list[str]: + return sorted(self._builders.keys()) + + def get_spec(self, name: str) -> PatientEncoderSpec: + key = normalize_encoder_name(name) + if key not in self._specs: + raise KeyError(f"Unknown patient encoder '{name}'. Available: {self.available()}") + return self._specs[key] + + def specs(self) -> Mapping[str, PatientEncoderSpec]: + return dict(self._specs) + + def create(self, name: str) -> PatientEncoder: + key = normalize_encoder_name(name) + if key not in self._builders: + raise KeyError(f"Unknown patient encoder '{name}'. Available: {self.available()}") + builder = self._builders[key] + try: + encoder = builder() + except ModuleNotFoundError as exc: + logger.exception("Missing optional dependency for patient encoder '%s'", name) + dependency = exc.name or "unknown dependency" + raise RuntimeError( + f"Patient encoder '{name}' requires optional dependency '{dependency}'. " + "Install the relevant optional extras documented in the README." + ) from exc + except Exception: + logger.exception("Failed to create patient encoder '%s'", name) + raise + if normalize_encoder_name(encoder.name) != key: + raise RuntimeError( + f"Patient encoder builder for '{name}' returned '{encoder.name}', which does not match." 
+ ) + return encoder diff --git a/atlas_patch/models/slide/__init__.py b/atlas_patch/models/slide/__init__.py new file mode 100644 index 0000000..85f2b00 --- /dev/null +++ b/atlas_patch/models/slide/__init__.py @@ -0,0 +1,17 @@ +from atlas_patch.models.slide.base import ( + SlideEncoder, + SlideEncoderSpec, +) +from atlas_patch.models.slide.registry import SlideEncoderRegistry + +__all__ = [ + "SlideEncoder", + "SlideEncoderRegistry", + "SlideEncoderSpec", + "build_default_registry", +] + + +def build_default_registry() -> SlideEncoderRegistry: + """Return an empty slide-encoder registry for phased population.""" + return SlideEncoderRegistry() diff --git a/atlas_patch/models/slide/base.py b/atlas_patch/models/slide/base.py new file mode 100644 index 0000000..7c578bb --- /dev/null +++ b/atlas_patch/models/slide/base.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path + +import numpy as np + + +def _normalize_encoder_name(name: str) -> str: + value = str(name).strip().lower() + if not value: + raise ValueError("encoder name must be a non-empty string") + return value + + +def _normalize_patch_size(patch_size: int) -> int: + value = int(patch_size) + if value <= 0: + raise ValueError("patch_size must be > 0") + return value + + +@dataclass(frozen=True) +class SlideEncoderSpec: + name: str + embedding_dim: int + patch_encoder_name: str + patch_size: int + + def __post_init__(self) -> None: + object.__setattr__(self, "name", _normalize_encoder_name(self.name)) + embedding_dim = int(self.embedding_dim) + if embedding_dim <= 0: + raise ValueError("embedding_dim must be > 0") + object.__setattr__(self, "embedding_dim", embedding_dim) + object.__setattr__( + self, + "patch_encoder_name", + _normalize_encoder_name(self.patch_encoder_name), + ) + object.__setattr__(self, "patch_size", _normalize_patch_size(self.patch_size)) + + +class SlideEncoder(ABC): + """Base interface 
for slide-level encoders operating on AtlasPatch H5 artifacts.""" + + spec: SlideEncoderSpec + + @property + def name(self) -> str: + return self.spec.name + + @property + def embedding_dim(self) -> int: + return self.spec.embedding_dim + + @property + def required_patch_encoder(self) -> str: + return self.spec.patch_encoder_name + + @property + def required_patch_size(self) -> int: + return self.spec.patch_size + + @abstractmethod + def encode_slide(self, patch_h5_path: Path) -> np.ndarray: + """Return a 1-D embedding for the provided per-slide patch H5.""" + raise NotImplementedError diff --git a/atlas_patch/models/slide/registry.py b/atlas_patch/models/slide/registry.py new file mode 100644 index 0000000..967318f --- /dev/null +++ b/atlas_patch/models/slide/registry.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import logging +from typing import Callable, Mapping + +from atlas_patch.core.paths import normalize_encoder_name +from atlas_patch.models.slide.base import SlideEncoder, SlideEncoderSpec + +logger = logging.getLogger(__name__) + + +class SlideEncoderRegistry: + """Registry of slide encoder builders plus lightweight metadata specs.""" + + def __init__(self) -> None: + self._builders: dict[str, Callable[[], SlideEncoder]] = {} + self._specs: dict[str, SlideEncoderSpec] = {} + + def register(self, spec: SlideEncoderSpec, builder: Callable[[], SlideEncoder]) -> None: + key = normalize_encoder_name(spec.name) + if key in self._builders: + raise ValueError(f"Slide encoder '{spec.name}' already registered.") + self._builders[key] = builder + self._specs[key] = spec + + def available(self) -> list[str]: + return sorted(self._builders.keys()) + + def get_spec(self, name: str) -> SlideEncoderSpec: + key = normalize_encoder_name(name) + if key not in self._specs: + raise KeyError(f"Unknown slide encoder '{name}'. 
Available: {self.available()}") + return self._specs[key] + + def specs(self) -> Mapping[str, SlideEncoderSpec]: + return dict(self._specs) + + def create(self, name: str) -> SlideEncoder: + key = normalize_encoder_name(name) + if key not in self._builders: + raise KeyError(f"Unknown slide encoder '{name}'. Available: {self.available()}") + builder = self._builders[key] + try: + encoder = builder() + except ModuleNotFoundError as exc: + logger.exception("Missing optional dependency for slide encoder '%s'", name) + dependency = exc.name or "unknown dependency" + raise RuntimeError( + f"Slide encoder '{name}' requires optional dependency '{dependency}'. " + "Install the relevant optional extras documented in the README." + ) from exc + except Exception: + logger.exception("Failed to create slide encoder '%s'", name) + raise + if normalize_encoder_name(encoder.name) != key: + raise RuntimeError( + f"Slide encoder builder for '{name}' returned '{encoder.name}', which does not match." + ) + return encoder diff --git a/atlas_patch/utils/__init__.py b/atlas_patch/utils/__init__.py index b7f64d3..954277d 100644 --- a/atlas_patch/utils/__init__.py +++ b/atlas_patch/utils/__init__.py @@ -9,6 +9,12 @@ missing_features, parse_feature_list, ) +from .feature_h5 import ( + PatchFeatureData, + append_slide_embedding, + load_patch_feature_data, + write_patient_embedding_h5, +) from .h5 import H5AppendWriter from .hf import import_module_from_hf from .image import is_black_patch, is_white_patch @@ -29,5 +35,9 @@ "parse_feature_list", "get_existing_features", "missing_features", + "PatchFeatureData", + "load_patch_feature_data", + "append_slide_embedding", + "write_patient_embedding_h5", "import_module_from_hf", ] diff --git a/atlas_patch/utils/feature_h5.py b/atlas_patch/utils/feature_h5.py new file mode 100644 index 0000000..cd37990 --- /dev/null +++ b/atlas_patch/utils/feature_h5.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +import json +import os +import uuid +from 
contextlib import suppress +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Mapping + +import h5py +import numpy as np + +from atlas_patch.core.paths import patch_feature_dataset_key, slide_feature_dataset_key + +REQUIRED_PATCH_FILE_ATTRS = ( + "patch_size_level0", + "patch_size", + "target_magnification", +) + + +@dataclass(frozen=True) +class PatchFeatureData: + h5_path: Path + feature_name: str + dataset_key: str + features: np.ndarray + coords: np.ndarray + patch_size_level0: int + patch_size: int + target_magnification: int + + @property + def num_patches(self) -> int: + return int(self.features.shape[0]) + + +def _encode_attr_value(value: Any) -> Any: + if isinstance(value, Path): + return str(value) + if isinstance(value, (dict, list, tuple)): + return json.dumps(value) + if value is None: + return "None" + return value + + +def _write_attrs(target: Any, attrs: Mapping[str, Any] | None) -> None: + if not attrs: + return + for key, value in attrs.items(): + target.attrs[key] = _encode_attr_value(value) + + +def _read_required_int_attrs(h5_path: Path, *sources: Mapping[str, Any]) -> dict[str, int]: + values: dict[str, int] = {} + for key in REQUIRED_PATCH_FILE_ATTRS: + for source in sources: + if key not in source: + continue + values[key] = int(source[key]) + break + else: + raise ValueError(f"{h5_path} is missing required metadata '{key}'.") + return values + + +def _coerce_embedding(embedding: np.ndarray, *, label: str) -> np.ndarray: + vector = np.asarray(embedding) + if vector.ndim != 1: + raise ValueError(f"{label} must be 1-D, got shape {vector.shape}") + return vector + + +def load_patch_feature_data( + h5_path: str | Path, + feature_name: str, + *, + validate_shapes: bool = True, +) -> PatchFeatureData: + """Load one AtlasPatch feature matrix and its aligned coordinate metadata.""" + path = Path(h5_path) + dataset_key = patch_feature_dataset_key(feature_name) + feature_key = dataset_key.split("/", 1)[1] + + with 
h5py.File(path, "r") as handle: + coords_ds = handle.get("coords") + if not isinstance(coords_ds, h5py.Dataset): + raise ValueError(f"{path} is missing required dataset 'coords'.") + + features_ds = handle.get(dataset_key) + if not isinstance(features_ds, h5py.Dataset): + raise ValueError(f"{path} is missing required dataset '{dataset_key}'.") + + attrs = _read_required_int_attrs(path, handle.attrs, coords_ds.attrs) + coords = np.asarray(coords_ds[()]) + features = np.asarray(features_ds[()]) + + if validate_shapes: + if features.ndim != 2: + raise ValueError(f"{path} has invalid feature shape {features.shape}; expected 2-D.") + if coords.ndim != 2 or coords.shape[1] < 2: + raise ValueError( + f"{path} has invalid coords shape {coords.shape}; expected at least 2 columns." + ) + if coords.shape[0] != features.shape[0]: + raise ValueError( + f"{path} has mismatched features/coords lengths: " + f"features {features.shape[0]} vs coords {coords.shape[0]}" + ) + + return PatchFeatureData( + h5_path=path, + feature_name=feature_key, + dataset_key=dataset_key, + features=np.asarray(features, dtype=np.float32), + coords=np.asarray(coords, dtype=np.int64), + patch_size_level0=attrs["patch_size_level0"], + patch_size=attrs["patch_size"], + target_magnification=attrs["target_magnification"], + ) + + +def append_slide_embedding( + h5_path: str | Path, + encoder_name: str, + embedding: np.ndarray, + *, + attrs: Mapping[str, Any] | None = None, + overwrite: bool = False, +) -> str: + """Append or replace a slide-level embedding inside an existing AtlasPatch H5.""" + path = Path(h5_path) + dataset_key = slide_feature_dataset_key(encoder_name) + dataset_name = dataset_key.split("/", 1)[1] + vector = _coerce_embedding(embedding, label=dataset_key) + + with h5py.File(path, "r+") as handle: + group = handle.require_group("slide_features") + if dataset_name in group: + if not overwrite: + raise ValueError(f"{path} already contains '{dataset_key}'.") + del group[dataset_name] + dset = 
group.create_dataset(dataset_name, data=vector, dtype=vector.dtype) + merged_attrs = {"encoder_name": dataset_name} + if attrs: + merged_attrs.update(attrs) + _write_attrs(dset, merged_attrs) + handle.flush() + + return dataset_key + + +def write_patient_embedding_h5( + output_path: str | Path, + embedding: np.ndarray, + *, + attrs: Mapping[str, Any] | None = None, + overwrite: bool = False, +) -> Path: + """Write a compact patient-level embedding H5 atomically.""" + path = Path(output_path) + vector = _coerce_embedding(embedding, label=str(path)) + path.parent.mkdir(parents=True, exist_ok=True) + if path.exists() and not overwrite: + raise FileExistsError(f"Patient embedding already exists: {path}") + + tmp_path = path.parent / f".{path.name}.tmp.{uuid.uuid4().hex}" + try: + with h5py.File(tmp_path, "w") as handle: + handle.create_dataset("features", data=vector, dtype=vector.dtype) + _write_attrs(handle, attrs) + os.replace(tmp_path, path) + finally: + if tmp_path.exists(): + with suppress(OSError): + tmp_path.unlink() + + return path From baba65b7ca681f5b68a28c674149a61e156e6584 Mon Sep 17 00:00:00 2001 From: yousefkotp Date: Mon, 6 Apr 2026 17:49:00 -0400 Subject: [PATCH 3/8] feat: add slide encoder support --- atlas_patch/cli.py | 573 +++++++++++++++++++----- atlas_patch/core/config.py | 21 +- atlas_patch/core/models.py | 2 - atlas_patch/models/slide/__init__.py | 26 +- atlas_patch/models/slide/common.py | 44 ++ atlas_patch/models/slide/moozy.py | 92 ++++ atlas_patch/models/slide/prism.py | 98 ++++ atlas_patch/models/slide/titan.py | 91 ++++ atlas_patch/orchestration/runner.py | 8 +- atlas_patch/services/slide_embedding.py | 260 +++++++++++ atlas_patch/utils/__init__.py | 14 +- atlas_patch/utils/feature_h5.py | 93 +++- atlas_patch/utils/features.py | 16 +- pyproject.toml | 32 ++ 14 files changed, 1227 insertions(+), 143 deletions(-) create mode 100644 atlas_patch/models/slide/common.py create mode 100644 atlas_patch/models/slide/moozy.py create mode 100644 
atlas_patch/models/slide/prism.py create mode 100644 atlas_patch/models/slide/titan.py create mode 100644 atlas_patch/services/slide_embedding.py diff --git a/atlas_patch/cli.py b/atlas_patch/cli.py index f2b3556..1247c45 100644 --- a/atlas_patch/cli.py +++ b/atlas_patch/cli.py @@ -7,10 +7,26 @@ import click from atlas_patch import __version__ +from atlas_patch.core.config import ( + DEFAULT_FEATURE_BATCH_SIZE, + DEFAULT_FEATURE_NUM_WORKERS, + DEFAULT_FEATURE_PRECISION, + AppConfig, + FeatureExtractionConfig, + SlideEncodingConfig, + default_sam2_config_path, +) +from atlas_patch.core.models import Slide +from atlas_patch.core.paths import patch_h5_path +from atlas_patch.models.patch.registry import PatchFeatureExtractorRegistry +from atlas_patch.models.slide.registry import SlideEncoderRegistry from atlas_patch.utils import ( configure_logging, install_embedding_log_filter, + missing_features, + parse_named_list, ) +from atlas_patch.utils.feature_h5 import read_patch_artifact_summary from atlas_patch.utils.params import get_wsi_files from atlas_patch.utils.visualization import visualize_mask_on_thumbnail @@ -18,13 +34,8 @@ level=logging.WARNING, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", ) -logger = logging.getLogger("atlas_patch.cli") install_embedding_log_filter() - -def _default_config_path() -> Path: - return Path(__file__).resolve().parent / "configs" / "sam2.1_hiera_t.yaml" - # Shared option sets ----------------------------------------------------------- _COMMON_OPTIONS: list = [ click.argument("wsi_path", type=click.Path(exists=True)), @@ -118,41 +129,31 @@ def _default_config_path() -> Path: click.option("--verbose", "-v", is_flag=True, help="Enable debug logging."), ] -_FEATURE_OPTIONS: list = [ +_FEATURE_RUNTIME_OPTIONS: list = [ click.option( "--feature-device", type=str, default=None, help="Device for feature extraction; e.g. cuda, cuda:0, cpu. 
Defaults to --device.", ), - click.option( - "--feature-extractors", - required=True, - type=str, - help=( - "Space/comma separated feature extractors to run. " - "The registry is resolved lazily at runtime; install `atlas-patch[patch-encoders]` " - "for optional timm/transformers/open-clip based models and add more via --feature-plugin." - ), - ), click.option( "--feature-batch-size", type=int, - default=32, + default=DEFAULT_FEATURE_BATCH_SIZE, show_default=True, help="Batch size used when embedding patches.", ), click.option( "--feature-num-workers", type=int, - default=4, + default=DEFAULT_FEATURE_NUM_WORKERS, show_default=True, help="DataLoader worker count for feature extraction.", ), click.option( "--feature-precision", type=click.Choice(["float32", "float16", "bfloat16"], case_sensitive=False), - default="float16", + default=DEFAULT_FEATURE_PRECISION, show_default=True, help="Computation precision for feature extraction.", ), @@ -167,6 +168,20 @@ def _default_config_path() -> Path: ), ), ] +_FEATURE_EXTRACTOR_OPTION = click.option( + "--feature-extractors", + required=True, + type=str, + help=( + "Space/comma separated feature extractors to run. " + "The registry is resolved lazily at runtime; install `atlas-patch[patch-encoders]` " + "for optional timm/transformers/open-clip based models and add more via --feature-plugin." 
+ ), +) +_FEATURE_OPTIONS: list = [ + _FEATURE_EXTRACTOR_OPTION, + *_FEATURE_RUNTIME_OPTIONS, +] def _apply_options(func, options: list): @@ -183,7 +198,137 @@ def feature_options(func): return _apply_options(func, _FEATURE_OPTIONS) -def _run_pipeline( +def feature_runtime_options(func): + return _apply_options(func, _FEATURE_RUNTIME_OPTIONS) + + +def _build_patch_feature_registry( + *, + device: str, + feature_device: str | None, + feature_num_workers: int, + feature_precision: str, + feature_plugins: tuple[str, ...], +): + import torch + + from atlas_patch.models.patch import build_default_registry + from atlas_patch.models.patch.custom import register_feature_extractors_from_module + from atlas_patch.services.feature_embedding import resolve_feature_dtype + + feat_device = feature_device.lower() if feature_device else device.lower() + precision = feature_precision.lower() + torch_device = torch.device(feat_device) + dtype = resolve_feature_dtype(torch_device, precision) + registry = build_default_registry( + device=torch_device, + num_workers=feature_num_workers, + dtype=dtype, + ) + for plugin in feature_plugins: + register_feature_extractors_from_module( + plugin, + registry=registry, + device=torch_device, + dtype=dtype, + num_workers=feature_num_workers, + ) + return registry, feat_device, precision + + +def _resolve_slide_encoder_requirements( + slide_registry: SlideEncoderRegistry, + encoders: list[str], +) -> tuple[int, list[str]]: + specs = [slide_registry.get_spec(name) for name in encoders] + patch_sizes = {spec.patch_size for spec in specs} + if len(patch_sizes) != 1: + parts = ", ".join(f"{spec.name}:{spec.patch_size}" for spec in specs) + raise click.ClickException( + "Requested slide encoders require incompatible patch geometries for one canonical H5: " + f"{parts}" + ) + required_patch_encoders = sorted({spec.patch_encoder_name for spec in specs}) + return next(iter(patch_sizes)), required_patch_encoders + + +def 
_validate_existing_slide_outputs(app_cfg: AppConfig, slides: list[Slide]) -> None: + if not app_cfg.output.skip_existing: + return + + for slide in slides: + h5_path = patch_h5_path(slide, app_cfg.output, app_cfg.extraction) + if not h5_path.exists(): + continue + + try: + summary = read_patch_artifact_summary(h5_path) + except Exception: + continue + if summary.patch_size != app_cfg.extraction.patch_size: + raise click.ClickException( + f"Existing patch artifact {h5_path} has patch_size={summary.patch_size}, " + f"but encode-slide requested --patch-size {app_cfg.extraction.patch_size}. " + "Re-run with --force to rebuild the canonical H5." + ) + if summary.target_magnification != app_cfg.extraction.target_magnification: + raise click.ClickException( + f"Existing patch artifact {h5_path} has target_mag={summary.target_magnification}, " + f"but encode-slide requested --target-mag {app_cfg.extraction.target_magnification}. " + "Re-run with --force to rebuild the canonical H5." + ) + + +def _collect_ready_slide_artifacts( + app_cfg: AppConfig, + slides: list[Slide], + *, + required_patch_encoders: list[str], + failed_slide_paths: set[Path], +) -> tuple[list[tuple[Slide, Path]], list[tuple[Slide, Exception | str]]]: + artifacts: list[tuple[Slide, Path]] = [] + failures: list[tuple[Slide, Exception | str]] = [] + + for slide in slides: + if slide.path.resolve() in failed_slide_paths: + continue + + h5_path = patch_h5_path(slide, app_cfg.output, app_cfg.extraction) + if not h5_path.exists(): + failures.append( + ( + slide, + RuntimeError( + f"No canonical patch artifact was produced for {slide.path.name} at {h5_path}." + ), + ) + ) + continue + + try: + summary = read_patch_artifact_summary(h5_path) + if summary.patch_size != app_cfg.extraction.patch_size: + raise ValueError( + f"{h5_path} has patch_size={summary.patch_size}, " + f"but encode-slide expects {app_cfg.extraction.patch_size}." 
+ ) + missing = missing_features( + h5_path, + required_patch_encoders, + expected_total=summary.num_patches, + ) + if missing: + raise ValueError( + f"{h5_path} is missing required patch feature sets: {', '.join(missing)}" + ) + artifacts.append((slide, h5_path)) + except Exception as exc: # noqa: BLE001 + failures.append((slide, exc)) + + return artifacts, failures + + +def _build_wsi_app_config( *, wsi_path: str, output: str, @@ -206,20 +351,66 @@ def _run_pipeline( recursive: bool, mpp_csv: str | None, skip_existing: bool, - verbose: bool, - feature_cfg: FeatureExtractionConfig | None = None, - registry: PatchFeatureExtractorRegistry | None = None, -) -> tuple[list, list]: - from tqdm import tqdm - + feature_cfg=None, +): from atlas_patch.core.config import ( - AppConfig, ExtractionConfig, OutputConfig, ProcessingConfig, SegmentationConfig, VisualizationConfig, ) + + return AppConfig( + processing=ProcessingConfig( + input_path=Path(wsi_path), + recursive=recursive, + mpp_csv=Path(mpp_csv) if mpp_csv else None, + ), + segmentation=SegmentationConfig( + checkpoint_path=None, + config_path=default_sam2_config_path(), + device=device.lower(), + batch_size=seg_batch_size, + ), + extraction=ExtractionConfig( + patch_size=patch_size, + step_size=step_size, + target_magnification=target_mag, + tissue_threshold=tissue_thresh, + white_threshold=white_thresh, + black_threshold=black_thresh, + fast_mode=fast_mode, + write_batch=write_batch, + workers=patch_workers, + max_open_slides=max_open_slides, + ), + output=OutputConfig( + output_root=Path(output), + save_images=save_images, + visualize_grids=visualize_grids, + visualize_mask=visualize_mask, + visualize_contours=visualize_contours, + skip_existing=skip_existing, + ), + visualization=VisualizationConfig(), + features=feature_cfg, + device=device.lower(), + ).validated() + + +def _run_pipeline_from_config( + app_cfg: AppConfig, + *, + slides: list[Slide] | None = None, + verbose: bool, + registry: 
PatchFeatureExtractorRegistry | None = None, + announce_completion: bool = True, +) -> tuple[list, list]: + from tqdm import tqdm + + configure_logging(verbose) + from atlas_patch.orchestration.runner import ProcessingRunner from atlas_patch.services.extraction import PatchExtractionService from atlas_patch.services.feature_embedding import PatchFeatureEmbeddingService @@ -228,60 +419,22 @@ def _run_pipeline( from atlas_patch.services.visualization import DefaultVisualizationService from atlas_patch.services.wsi_loader import DefaultWSILoader - configure_logging(verbose) - - processing_cfg = ProcessingConfig( - input_path=Path(wsi_path), - recursive=recursive, - mpp_csv=Path(mpp_csv) if mpp_csv else None, - ) - segmentation_cfg = SegmentationConfig( - checkpoint_path=None, - config_path=_default_config_path(), - device=device.lower(), - batch_size=seg_batch_size, - ) - extraction_cfg = ExtractionConfig( - patch_size=patch_size, - step_size=step_size, - target_magnification=target_mag, - tissue_threshold=tissue_thresh, - white_threshold=white_thresh, - black_threshold=black_thresh, - fast_mode=fast_mode, - write_batch=write_batch, - workers=patch_workers, - max_open_slides=max_open_slides, - ) - output_cfg = OutputConfig( - output_root=Path(output), - save_images=save_images, - visualize_grids=visualize_grids, - visualize_mask=visualize_mask, - visualize_contours=visualize_contours, - skip_existing=skip_existing, - ) - app_cfg = AppConfig( - processing=processing_cfg, - segmentation=segmentation_cfg, - extraction=extraction_cfg, - output=output_cfg, - visualization=VisualizationConfig(), - features=feature_cfg, - device=device.lower(), - ).validated() - segmentation_service = SAM2SegmentationService(app_cfg.segmentation) extractor_service = PatchExtractionService(app_cfg.extraction, app_cfg.output) visualizer_service = None - if visualize_grids or visualize_mask or visualize_contours: + if ( + app_cfg.output.visualize_grids + or app_cfg.output.visualize_mask + or 
app_cfg.output.visualize_contours + ): visualizer_service = DefaultVisualizationService( - app_cfg.output, app_cfg.extraction, app_cfg.visualization + app_cfg.output, + app_cfg.extraction, + app_cfg.visualization, ) mpp_resolver = CSVMPPResolver(app_cfg.processing.mpp_csv) wsi_loader = DefaultWSILoader() - runner = ProcessingRunner( config=app_cfg, segmentation=segmentation_service, @@ -292,18 +445,17 @@ def _run_pipeline( show_progress=not verbose, ) - results: list - failures: list try: - results, failures = runner.run() + results, failures = runner.run(slides=slides) finally: segmentation_service.close() - click.echo("Segmentation and patch coordinate extraction complete.") - if app_cfg.features is not None: feature_service = PatchFeatureEmbeddingService( - app_cfg.extraction, app_cfg.output, app_cfg.features, registry=registry + app_cfg.extraction, + app_cfg.output, + app_cfg.features, + registry=registry, ) total_units = len(results) * len(app_cfg.features.extractors) feature_progress = tqdm( @@ -318,9 +470,71 @@ def _run_pipeline( finally: feature_progress.close() + if announce_completion: + click.echo("Segmentation and patch coordinate extraction complete.") return results, failures +def _run_pipeline( + *, + wsi_path: str, + output: str, + patch_size: int, + step_size: int | None, + target_mag: int, + device: str, + tissue_thresh: float, + white_thresh: int, + black_thresh: int, + seg_batch_size: int, + write_batch: int, + patch_workers: int | None, + max_open_slides: int | None, + fast_mode: bool, + save_images: bool, + visualize_grids: bool, + visualize_mask: bool, + visualize_contours: bool, + recursive: bool, + mpp_csv: str | None, + skip_existing: bool, + verbose: bool, + feature_cfg: FeatureExtractionConfig | None = None, + registry: PatchFeatureExtractorRegistry | None = None, + announce_completion: bool = True, +) -> tuple[list, list]: + app_cfg = _build_wsi_app_config( + wsi_path=wsi_path, + output=output, + patch_size=patch_size, + 
step_size=step_size, + target_mag=target_mag, + device=device, + tissue_thresh=tissue_thresh, + white_thresh=white_thresh, + black_thresh=black_thresh, + seg_batch_size=seg_batch_size, + write_batch=write_batch, + patch_workers=patch_workers, + max_open_slides=max_open_slides, + fast_mode=fast_mode, + save_images=save_images, + visualize_grids=visualize_grids, + visualize_mask=visualize_mask, + visualize_contours=visualize_contours, + recursive=recursive, + mpp_csv=mpp_csv, + skip_existing=skip_existing, + feature_cfg=feature_cfg, + ) + return _run_pipeline_from_config( + app_cfg, + verbose=verbose, + registry=registry, + announce_completion=announce_completion, + ) + + def _run_tissue_visualization( *, wsi_path: str, @@ -349,7 +563,7 @@ def _run_tissue_visualization( segmentation_cfg = ( SegmentationConfig( checkpoint_path=None, - config_path=_default_config_path(), + config_path=default_sam2_config_path(), device=device.lower(), batch_size=seg_batch_size, ) @@ -466,6 +680,20 @@ def _echo_mask_results( click.echo(f"[FAIL] {slide.path.name}: {err}", err=True) +def _echo_slide_results( + results: list, failures: list, verbose: bool +) -> None: + click.echo(f"Completed {len(results)} slide embedding(s), failures: {len(failures)}") + if verbose: + for res in results: + status = "reused" if res.metadata.get("reused") else "written" + click.echo( + f"[OK] {res.slide.path.name} -> {res.h5_path} ({res.dataset_key}, {status})" + ) + for slide, err in failures: + click.echo(f"[FAIL] {slide.path.name}: {err}", err=True) + + @click.group() @click.version_option(version=__version__) def cli(): @@ -615,37 +843,26 @@ def process( feature_plugins: tuple[str, ...], ): """Run segmentation, patch extraction, and feature embedding into a single H5.""" - import torch - - from atlas_patch.core.config import FeatureExtractionConfig - from atlas_patch.models.patch import build_default_registry - from atlas_patch.models.patch.custom import register_feature_extractors_from_module - from 
atlas_patch.services.feature_embedding import resolve_feature_dtype - from atlas_patch.utils import parse_feature_list - - feat_device = feature_device.lower() if feature_device else device.lower() - torch_device = torch.device(feat_device) - dtype = resolve_feature_dtype(torch_device, feature_precision.lower()) - registry = build_default_registry( - device=torch_device, num_workers=feature_num_workers, dtype=dtype + registry, feat_device, precision = _build_patch_feature_registry( + device=device, + feature_device=feature_device, + feature_num_workers=feature_num_workers, + feature_precision=feature_precision, + feature_plugins=feature_plugins, ) - for plugin in feature_plugins: - register_feature_extractors_from_module( - plugin, - registry=registry, - device=torch_device, - dtype=dtype, - num_workers=feature_num_workers, - ) available_extractors = registry.available() - feats = parse_feature_list(feature_extractors, choices=available_extractors) + feats = parse_named_list( + feature_extractors, + choices=available_extractors, + item_label="feature extractor", + ) feature_cfg = FeatureExtractionConfig( extractors=feats, batch_size=feature_batch_size, device=feat_device, num_workers=feature_num_workers, - precision=feature_precision.lower(), + precision=precision, plugins=[Path(p) for p in feature_plugins], ) results, failures = _run_pipeline( @@ -677,6 +894,155 @@ def process( _echo_results(results, failures, verbose, feature_cfg) +@cli.command() +@feature_runtime_options +@common_options +@click.option( + "--slide-encoders", + required=True, + type=str, + help="Space/comma separated slide encoders to run.", +) +def encode_slide( + wsi_path: str, + output: str, + patch_size: int, + step_size: int | None, + target_mag: int, + device: str, + feature_device: str | None, + feature_batch_size: int, + feature_num_workers: int, + feature_precision: str, + tissue_thresh: float, + white_thresh: int, + black_thresh: int, + seg_batch_size: int, + write_batch: int, + 
patch_workers: int | None, + max_open_slides: int | None, + fast_mode: bool, + save_images: bool, + visualize_grids: bool, + visualize_mask: bool, + visualize_contours: bool, + recursive: bool, + mpp_csv: str | None, + skip_existing: bool, + verbose: bool, + feature_plugins: tuple[str, ...], + slide_encoders: str, +): + """Run the WSI pipeline and append slide embeddings into canonical AtlasPatch H5 files.""" + from atlas_patch.models.slide import build_default_registry + from atlas_patch.services.slide_embedding import SlideEmbeddingService + + slide_registry = build_default_registry(device=device.lower()) + available_encoders = slide_registry.available() + encoders = parse_named_list( + slide_encoders, + choices=available_encoders, + item_label="slide encoder", + ) + required_patch_size, required_patch_encoders = _resolve_slide_encoder_requirements( + slide_registry, + encoders, + ) + if patch_size != required_patch_size: + raise click.ClickException( + f"Requested slide encoders require patch_size={required_patch_size}, " + f"but encode-slide was given --patch-size {patch_size}." 
+ ) + patch_registry, feat_device, precision = _build_patch_feature_registry( + device=device, + feature_device=feature_device, + feature_num_workers=feature_num_workers, + feature_precision=feature_precision, + feature_plugins=feature_plugins, + ) + available_extractors = patch_registry.available() + missing_extractors = [name for name in required_patch_encoders if name not in available_extractors] + if missing_extractors: + joined = ", ".join(missing_extractors) + raise click.ClickException( + f"Required patch extractor(s) for the selected slide encoders are unavailable: {joined}" + ) + + feature_cfg = FeatureExtractionConfig( + extractors=required_patch_encoders, + batch_size=feature_batch_size, + device=feat_device, + num_workers=feature_num_workers, + precision=precision, + plugins=[Path(p) for p in feature_plugins], + ).validated() + slide_cfg = SlideEncodingConfig( + encoders=encoders, + device=device.lower(), + skip_existing=skip_existing, + ).validated() + app_cfg = _build_wsi_app_config( + wsi_path=wsi_path, + output=output, + patch_size=patch_size, + step_size=step_size, + target_mag=target_mag, + device=device, + tissue_thresh=tissue_thresh, + white_thresh=white_thresh, + black_thresh=black_thresh, + seg_batch_size=seg_batch_size, + write_batch=write_batch, + patch_workers=patch_workers, + max_open_slides=max_open_slides, + fast_mode=fast_mode, + save_images=save_images, + visualize_grids=visualize_grids, + visualize_mask=visualize_mask, + visualize_contours=visualize_contours, + recursive=recursive, + mpp_csv=mpp_csv, + skip_existing=skip_existing, + feature_cfg=feature_cfg, + ) + slides = [ + Slide(path=Path(path)) + for path in get_wsi_files( + str(app_cfg.processing.input_path), + recursive=app_cfg.processing.recursive, + ) + ] + _validate_existing_slide_outputs(app_cfg, slides) + + try: + _, failures = _run_pipeline_from_config( + app_cfg, + slides=slides, + verbose=verbose, + registry=patch_registry, + announce_completion=False, + ) + except 
Exception as exc: # noqa: BLE001 + raise click.ClickException(str(exc)) from exc + + failed_slide_paths = {slide.path.resolve() for slide, _ in failures} + artifacts, artifact_failures = _collect_ready_slide_artifacts( + app_cfg, + slides, + required_patch_encoders=required_patch_encoders, + failed_slide_paths=failed_slide_paths, + ) + failures.extend(artifact_failures) + results, embedding_failures = SlideEmbeddingService( + app_cfg.extraction, + app_cfg.output, + slide_cfg, + registry=slide_registry, + ).embed_all(artifacts) + failures.extend(embedding_failures) + _echo_slide_results(results, failures, verbose) + + @cli.command() def info(): """Display supported formats and output structure.""" @@ -687,6 +1053,7 @@ def info(): click.echo( "Outputs: HDF5 per slide under patches/.h5; optional PNGs under images/; visualizations under visualization/." ) + click.echo("Slide embeddings append into slide_features/ inside patches/.h5.") def main(): diff --git a/atlas_patch/core/config.py b/atlas_patch/core/config.py index d4f26df..b75d636 100644 --- a/atlas_patch/core/config.py +++ b/atlas_patch/core/config.py @@ -3,6 +3,14 @@ from dataclasses import dataclass, field from pathlib import Path +DEFAULT_FEATURE_BATCH_SIZE = 32 +DEFAULT_FEATURE_NUM_WORKERS = 4 +DEFAULT_FEATURE_PRECISION = "float16" + + +def default_sam2_config_path() -> Path: + return Path(__file__).resolve().parent.parent / "configs" / "sam2.1_hiera_t.yaml" + def _ensure_positive(value: int, name: str) -> int: if value <= 0: @@ -108,10 +116,10 @@ def validated(self) -> ExtractionConfig: @dataclass class FeatureExtractionConfig: extractors: list[str] - batch_size: int = 32 + batch_size: int = DEFAULT_FEATURE_BATCH_SIZE device: str = "cuda" - num_workers: int = 4 - precision: str = "float32" + num_workers: int = DEFAULT_FEATURE_NUM_WORKERS + precision: str = DEFAULT_FEATURE_PRECISION plugins: list[Path] = field(default_factory=list) def validated(self) -> FeatureExtractionConfig: @@ -167,20 +175,13 @@ def 
validated(self) -> ProcessingConfig: @dataclass class SlideEncodingConfig: - input_path: Path encoders: list[str] - recursive: bool = False device: str = "cuda" skip_existing: bool = True - mpp_csv: Path | None = None def validated(self) -> SlideEncodingConfig: - if not self.input_path.exists(): - raise FileNotFoundError(f"Input path not found: {self.input_path}") self.encoders = _normalize_names(self.encoders, "slide encoder") self.device = _validate_device(str(self.device)) - if self.mpp_csv is not None and not self.mpp_csv.exists(): - raise FileNotFoundError(f"MPP CSV not found: {self.mpp_csv}") return self diff --git a/atlas_patch/core/models.py b/atlas_patch/core/models.py index 1aef841..a6c9d1e 100644 --- a/atlas_patch/core/models.py +++ b/atlas_patch/core/models.py @@ -46,7 +46,6 @@ class SlideEmbeddingResult: num_patches: int source_patch_encoder: str metadata: dict[str, Any] = field(default_factory=dict) - embedding: np.ndarray | None = None @dataclass @@ -58,4 +57,3 @@ class PatientEmbeddingResult: num_slides: int source_patch_encoder: str | None = None metadata: dict[str, Any] = field(default_factory=dict) - embedding: np.ndarray | None = None diff --git a/atlas_patch/models/slide/__init__.py b/atlas_patch/models/slide/__init__.py index 85f2b00..33c23eb 100644 --- a/atlas_patch/models/slide/__init__.py +++ b/atlas_patch/models/slide/__init__.py @@ -1,3 +1,9 @@ +from __future__ import annotations + +from importlib import import_module + +import torch + from atlas_patch.models.slide.base import ( SlideEncoder, SlideEncoderSpec, @@ -11,7 +17,21 @@ "build_default_registry", ] +_REGISTRARS: tuple[tuple[str, str], ...] 
= ( + ("atlas_patch.models.slide.titan", "register_titan_slide_encoder"), + ("atlas_patch.models.slide.prism", "register_prism_slide_encoder"), + ("atlas_patch.models.slide.moozy", "register_moozy_slide_encoder"), +) + -def build_default_registry() -> SlideEncoderRegistry: - """Return an empty slide-encoder registry for phased population.""" - return SlideEncoderRegistry() +def build_default_registry( + *, + device: str | torch.device = "cuda", +) -> SlideEncoderRegistry: + """Factory that registers the built-in slide encoders.""" + registry = SlideEncoderRegistry() + for module_name, registrar_name in _REGISTRARS: + module = import_module(module_name) + registrar = getattr(module, registrar_name) + registrar(registry, device=device) + return registry diff --git a/atlas_patch/models/slide/common.py b/atlas_patch/models/slide/common.py new file mode 100644 index 0000000..20b14a1 --- /dev/null +++ b/atlas_patch/models/slide/common.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import logging + +import numpy as np +import torch + +from atlas_patch.utils.feature_h5 import coerce_1d_embedding + +logger = logging.getLogger("atlas_patch.slide_models") + + +def resolve_slide_device(device: str | torch.device) -> torch.device: + resolved = torch.device(device) + if resolved.type == "cuda" and not torch.cuda.is_available(): + logger.warning("Slide encoding requested on CUDA but unavailable; using CPU instead.") + return torch.device("cpu") + return resolved + + +def coerce_slide_embedding( + output: torch.Tensor | np.ndarray, + *, + expected_dim: int, + encoder_name: str, +) -> np.ndarray: + tensor = torch.as_tensor(output) + if tensor.ndim == 2: + if tensor.shape[0] != 1: + raise ValueError( + f"Expected a single {encoder_name} embedding, got shape {tuple(tensor.shape)}" + ) + tensor = tensor.squeeze(0) + if tensor.ndim != 1: + raise ValueError(f"Expected a 1-D {encoder_name} embedding, got shape {tuple(tensor.shape)}") + vector = coerce_1d_embedding( + 
tensor.detach().to(device="cpu", dtype=torch.float32).numpy(), + label=encoder_name, + ) + if vector.shape[0] != expected_dim: + raise ValueError( + f"{encoder_name} embedding dim mismatch: expected {expected_dim}, got {vector.shape[0]}" + ) + return vector diff --git a/atlas_patch/models/slide/moozy.py b/atlas_patch/models/slide/moozy.py new file mode 100644 index 0000000..220d482 --- /dev/null +++ b/atlas_patch/models/slide/moozy.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from contextlib import nullcontext +import tempfile +from pathlib import Path + +import h5py +import numpy as np +import torch + +from atlas_patch.models.slide.base import SlideEncoder, SlideEncoderSpec +from atlas_patch.models.slide.common import coerce_slide_embedding, resolve_slide_device +from atlas_patch.models.slide.registry import SlideEncoderRegistry + +_EMBEDDING_DATASET = "features" + + +def _run_moozy_public_api( + slide_paths: list[str], + output_path: str, + *, + device: torch.device, + mixed_precision: bool = False, +) -> None: + try: + from moozy.encoding import run_encoding + except ModuleNotFoundError as exc: + raise RuntimeError( + "MOOZY slide encoding requires the optional `moozy` package. " + "Install `atlas-patch[moozy]`, `atlas-patch[slide-encoders]`, or `pip install moozy`." + ) from exc + + if device.type == "cpu" and torch.cuda.is_available(): + raise RuntimeError( + "MOOZY's public `run_encoding` API does not support forcing CPU when CUDA is " + "available. Use `--device cuda` or run in a CPU-only environment." 
+ ) + + cuda_ctx = torch.cuda.device(device) if device.type == "cuda" else nullcontext() + with cuda_ctx: + run_encoding( + slide_paths=slide_paths, + output_path=output_path, + mixed_precision=mixed_precision, + ) + + +def _load_moozy_embedding(output_path: Path) -> np.ndarray: + with h5py.File(output_path, "r") as handle: + dataset = handle.get(_EMBEDDING_DATASET) + if not isinstance(dataset, h5py.Dataset): + raise ValueError(f"{output_path} is missing required dataset '{_EMBEDDING_DATASET}'.") + return np.asarray(dataset[()], dtype=np.float32) + + +class MOOZYSlideEncoder(SlideEncoder): + spec = SlideEncoderSpec( + name="moozy", + embedding_dim=768, + patch_encoder_name="lunit_vit_small_patch8_dino", + patch_size=224, + ) + + def __init__(self, *, device: str | torch.device = "cuda") -> None: + self.device = resolve_slide_device(device) + + def encode_slide(self, patch_h5_path: Path) -> np.ndarray: + with tempfile.TemporaryDirectory(prefix="atlaspatch_moozy_") as tmp_dir: + output_path = Path(tmp_dir) / "case_embedding.h5" + _run_moozy_public_api( + [str(patch_h5_path)], + str(output_path), + device=self.device, + mixed_precision=False, + ) + embedding = _load_moozy_embedding(output_path) + return coerce_slide_embedding( + embedding, + expected_dim=self.embedding_dim, + encoder_name="MOOZY", + ) + + +def register_moozy_slide_encoder( + registry: SlideEncoderRegistry, + *, + device: str | torch.device = "cuda", +) -> None: + registry.register( + MOOZYSlideEncoder.spec, + lambda: MOOZYSlideEncoder(device=device), + ) diff --git a/atlas_patch/models/slide/prism.py b/atlas_patch/models/slide/prism.py new file mode 100644 index 0000000..b8ef453 --- /dev/null +++ b/atlas_patch/models/slide/prism.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +import numpy as np +import torch + +from atlas_patch.models.slide.base import SlideEncoder, SlideEncoderSpec +from atlas_patch.models.slide.common import coerce_slide_embedding, 
resolve_slide_device +from atlas_patch.models.slide.registry import SlideEncoderRegistry +from atlas_patch.utils.feature_h5 import load_patch_feature_data + +_MODEL_ID = "paige-ai/Prism" +_PATCH_FEATURE_DIM = 2560 + + +def _load_prism_model(*, device: torch.device, dtype: torch.dtype): + if sys.version_info < (3, 10): + raise RuntimeError("PRISM requires Python 3.10 or newer.") + + try: + import environs # noqa: F401 + import sacremoses # noqa: F401 + from transformers import AutoModel + except ModuleNotFoundError as exc: + raise RuntimeError( + "PRISM requires optional slide-encoder dependencies. " + "Install `atlas-patch[prism]` or `atlas-patch[slide-encoders]`." + ) from exc + + model = AutoModel.from_pretrained(_MODEL_ID, trust_remote_code=True) + if hasattr(model, "text_decoder"): + model.text_decoder = None + return model.to(device=device, dtype=dtype).eval() + + +class PrismSlideEncoder(SlideEncoder): + spec = SlideEncoderSpec( + name="prism", + embedding_dim=1280, + patch_encoder_name="virchow_v1", + patch_size=224, + ) + + def __init__(self, *, device: str | torch.device = "cuda") -> None: + self.device = resolve_slide_device(device) + self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32 + self.model = _load_prism_model(device=self.device, dtype=self.dtype) + + def encode_slide(self, patch_h5_path: Path) -> np.ndarray: + patch_data = load_patch_feature_data(patch_h5_path, self.required_patch_encoder) + if patch_data.patch_size != self.required_patch_size: + raise ValueError( + f"{patch_h5_path} has patch_size={patch_data.patch_size}, " + f"but PRISM requires {self.required_patch_size}." + ) + if patch_data.features.shape[1] != _PATCH_FEATURE_DIM: + raise ValueError( + f"{patch_h5_path} has feature dim {patch_data.features.shape[1]}, " + f"but PRISM expects {_PATCH_FEATURE_DIM} from '{self.required_patch_encoder}'." 
+ ) + + features = torch.from_numpy(patch_data.features).unsqueeze(0).to( + device=self.device, + dtype=self.dtype, + ) + tile_mask = torch.ones( + (1, patch_data.num_patches), + device=self.device, + dtype=torch.long, + ) + with torch.inference_mode(): + output = self.model.slide_representations(features, tile_mask=tile_mask) + if not isinstance(output, dict) or "image_embedding" not in output: + raise ValueError("PRISM did not return an 'image_embedding' entry.") + return coerce_slide_embedding( + output["image_embedding"], + expected_dim=self.embedding_dim, + encoder_name="PRISM", + ) + + def cleanup(self) -> None: + try: + self.model.cpu() + except Exception: + pass + + +def register_prism_slide_encoder( + registry: SlideEncoderRegistry, + *, + device: str | torch.device = "cuda", +) -> None: + registry.register( + PrismSlideEncoder.spec, + lambda: PrismSlideEncoder(device=device), + ) diff --git a/atlas_patch/models/slide/titan.py b/atlas_patch/models/slide/titan.py new file mode 100644 index 0000000..9edc427 --- /dev/null +++ b/atlas_patch/models/slide/titan.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import torch + +from atlas_patch.models.slide.base import SlideEncoder, SlideEncoderSpec +from atlas_patch.models.slide.common import coerce_slide_embedding, resolve_slide_device +from atlas_patch.models.slide.registry import SlideEncoderRegistry +from atlas_patch.utils.feature_h5 import load_patch_feature_data + +_MODEL_ID = "MahmoodLab/TITAN" +_PATCH_FEATURE_DIM = 768 + + +def _load_titan_model(*, device: torch.device, dtype: torch.dtype): + try: + from transformers import AutoModel + except ModuleNotFoundError as exc: + raise RuntimeError( + "TITAN requires optional slide-encoder dependencies. " + "Install `atlas-patch[titan]` or `atlas-patch[slide-encoders]`." 
+ ) from exc + + model = AutoModel.from_pretrained(_MODEL_ID, trust_remote_code=True) + return model.to(device=device, dtype=dtype).eval() + + +class TitanSlideEncoder(SlideEncoder): + spec = SlideEncoderSpec( + name="titan", + embedding_dim=768, + patch_encoder_name="conch_v15", + patch_size=512, + ) + + def __init__(self, *, device: str | torch.device = "cuda") -> None: + self.device = resolve_slide_device(device) + self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32 + self.model = _load_titan_model(device=self.device, dtype=self.dtype) + + def encode_slide(self, patch_h5_path: Path) -> np.ndarray: + patch_data = load_patch_feature_data(patch_h5_path, self.required_patch_encoder) + if patch_data.patch_size != self.required_patch_size: + raise ValueError( + f"{patch_h5_path} has patch_size={patch_data.patch_size}, " + f"but TITAN requires {self.required_patch_size}." + ) + if patch_data.features.shape[1] != _PATCH_FEATURE_DIM: + raise ValueError( + f"{patch_h5_path} has feature dim {patch_data.features.shape[1]}, " + f"but TITAN expects {_PATCH_FEATURE_DIM} from '{self.required_patch_encoder}'." 
+ ) + + features = torch.from_numpy(patch_data.features).unsqueeze(0).to( + device=self.device, + dtype=self.dtype, + ) + coords = torch.from_numpy(patch_data.coords[:, :2]).unsqueeze(0).to( + device=self.device, + dtype=torch.int64, + ) + with torch.inference_mode(): + embedding = self.model.encode_slide_from_patch_features( + features, + coords, + int(patch_data.patch_size_level0), + ) + return coerce_slide_embedding( + embedding, + expected_dim=self.embedding_dim, + encoder_name="TITAN", + ) + + def cleanup(self) -> None: + try: + self.model.cpu() + except Exception: + pass + + +def register_titan_slide_encoder( + registry: SlideEncoderRegistry, + *, + device: str | torch.device = "cuda", +) -> None: + registry.register( + TitanSlideEncoder.spec, + lambda: TitanSlideEncoder(device=device), + ) diff --git a/atlas_patch/orchestration/runner.py b/atlas_patch/orchestration/runner.py index 4a44c20..6893b86 100644 --- a/atlas_patch/orchestration/runner.py +++ b/atlas_patch/orchestration/runner.py @@ -199,8 +199,12 @@ def _resolve_max_open_slides(self, patch_workers: int, batch_size: int) -> int: raise ValueError("max_open_slides must be defined") return max(1, int(cfg_val)) - def run(self) -> tuple[list[ExtractionResult], list[tuple[Slide, Exception | str]]]: - slides = self.discover_slides() + def run( + self, + slides: list[Slide] | None = None, + ) -> tuple[list[ExtractionResult], list[tuple[Slide, Exception | str]]]: + if slides is None: + slides = self.discover_slides() slides = self._attach_mpp(slides) if not slides: diff --git a/atlas_patch/services/slide_embedding.py b/atlas_patch/services/slide_embedding.py new file mode 100644 index 0000000..e3cce2f --- /dev/null +++ b/atlas_patch/services/slide_embedding.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +import os +import time +from pathlib import Path +from typing import Iterable + +from atlas_patch.core.config import ExtractionConfig, OutputConfig, SlideEncodingConfig +from 
atlas_patch.core.models import Slide, SlideEmbeddingResult +from atlas_patch.core.paths import slide_append_lock_path, slide_feature_dataset_key +from atlas_patch.models.slide import SlideEncoderRegistry, build_default_registry +from atlas_patch.utils.feature_h5 import ( + PatchArtifactSummary, + SlideEmbeddingSummary, + append_slide_embedding, + read_patch_artifact_summary, + read_slide_embedding_summary, +) + + +class SlideEmbeddingService: + """Append slide-level embeddings into canonical AtlasPatch patch H5 files.""" + + def __init__( + self, + extraction_cfg: ExtractionConfig, + output_cfg: OutputConfig, + slide_cfg: SlideEncodingConfig, + *, + registry: SlideEncoderRegistry | None = None, + ) -> None: + self.extraction_cfg = extraction_cfg.validated() + self.output_cfg = output_cfg.validated() + self.slide_cfg = slide_cfg.validated() + self.registry = registry or build_default_registry(device=self.slide_cfg.device) + + def _acquire_lock(self, slide: Slide) -> tuple[int | None, Path]: + lock_path = slide_append_lock_path(slide, self.output_cfg, self.extraction_cfg) + lock_path.parent.mkdir(parents=True, exist_ok=True) + payload = f"pid={os.getpid()},time={int(time.time())},slide={slide.path},phase=slide-embedding" + try: + fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) + os.write(fd, payload.encode()) + os.fsync(fd) + return fd, lock_path + except FileExistsError: + return None, lock_path + except Exception as exc: # noqa: BLE001 + raise RuntimeError(f"Failed to create slide lock {lock_path}: {exc}") from exc + + @staticmethod + def _release_lock(fd: int | None, lock_path: Path) -> None: + if fd is not None: + try: + os.close(fd) + except Exception: + pass + try: + lock_path.unlink() + except FileNotFoundError: + pass + except Exception: + pass + + @staticmethod + def _cleanup_encoder(encoder) -> None: + cleanup = getattr(encoder, "cleanup", None) + if callable(cleanup): + try: + cleanup() + except Exception: + pass + + @staticmethod + def 
_build_result( + *, + slide: Slide, + h5_path: Path, + encoder_name: str, + embedding_dim: int, + source_patch_encoder: str, + num_patches: int, + patch_size: int, + patch_size_level0: int, + target_magnification: int, + reused: bool, + ) -> SlideEmbeddingResult: + return SlideEmbeddingResult( + slide=slide, + h5_path=h5_path, + encoder_name=encoder_name, + dataset_key=slide_feature_dataset_key(encoder_name), + embedding_dim=int(embedding_dim), + num_patches=int(num_patches), + source_patch_encoder=source_patch_encoder, + metadata={ + "patch_size": int(patch_size), + "patch_size_level0": int(patch_size_level0), + "target_magnification": int(target_magnification), + "reused": bool(reused), + }, + ) + + @staticmethod + def _validate_reusable_embedding( + *, + encoder, + h5_path: Path, + summary: SlideEmbeddingSummary, + patch_summary: PatchArtifactSummary, + ) -> None: + if summary.embedding_dim != encoder.embedding_dim: + raise ValueError( + f"{h5_path} has slide embedding dim {summary.embedding_dim} for '{summary.dataset_key}', " + f"but '{encoder.name}' expects {encoder.embedding_dim}." + ) + if summary.source_patch_encoder != encoder.required_patch_encoder: + raise ValueError( + f"{h5_path} was encoded from '{summary.source_patch_encoder}', " + f"but '{encoder.name}' requires '{encoder.required_patch_encoder}'." + ) + if summary.patch_size != encoder.required_patch_size: + raise ValueError( + f"{h5_path} records patch_size={summary.patch_size} for '{summary.dataset_key}', " + f"but '{encoder.name}' requires {encoder.required_patch_size}." + ) + if summary.num_patches != patch_summary.num_patches: + raise ValueError( + f"{h5_path} records {summary.num_patches} patches for '{summary.dataset_key}', " + f"but the current patch artifact has {patch_summary.num_patches}." 
+ ) + if summary.patch_size_level0 != patch_summary.patch_size_level0: + raise ValueError( + f"{h5_path} records patch_size_level0={summary.patch_size_level0} for " + f"'{summary.dataset_key}', but the current patch artifact has " + f"{patch_summary.patch_size_level0}." + ) + if summary.target_magnification != patch_summary.target_magnification: + raise ValueError( + f"{h5_path} records target_mag={summary.target_magnification} for " + f"'{summary.dataset_key}', but the current patch artifact has " + f"{patch_summary.target_magnification}." + ) + + def embed_all( + self, + artifacts: Iterable[tuple[Slide, Path]], + ) -> tuple[list[SlideEmbeddingResult], list[tuple[Slide, Exception | str]]]: + items = list(artifacts) + results: list[SlideEmbeddingResult] = [] + failures: list[tuple[Slide, Exception | str]] = [] + if not items: + return results, failures + + for encoder_name in self.slide_cfg.encoders: + try: + encoder = self.registry.create(encoder_name) + except Exception as exc: # noqa: BLE001 + for slide, _ in items: + failures.append((slide, RuntimeError(f"{encoder_name}: {exc}"))) + continue + + try: + for slide, h5_path in items: + try: + patch_summary = read_patch_artifact_summary(h5_path) + except Exception as exc: # noqa: BLE001 + failures.append((slide, RuntimeError(f"{encoder_name}: {exc}"))) + continue + + overwrite_existing = not self.slide_cfg.skip_existing + if self.slide_cfg.skip_existing: + try: + existing = read_slide_embedding_summary(h5_path, encoder_name) + except Exception: + existing = None + overwrite_existing = True + + if existing is not None: + try: + self._validate_reusable_embedding( + encoder=encoder, + h5_path=h5_path, + summary=existing, + patch_summary=patch_summary, + ) + except Exception: + overwrite_existing = True + else: + results.append( + self._build_result( + slide=slide, + h5_path=h5_path, + encoder_name=encoder_name, + embedding_dim=existing.embedding_dim, + source_patch_encoder=existing.source_patch_encoder, + 
num_patches=existing.num_patches, + patch_size=existing.patch_size, + patch_size_level0=existing.patch_size_level0, + target_magnification=existing.target_magnification, + reused=True, + ) + ) + continue + + lock_fd, lock_path = self._acquire_lock(slide) + if lock_fd is None: + failures.append( + ( + slide, + RuntimeError( + f"{encoder_name}: {h5_path} is locked by another process." + ), + ) + ) + continue + + try: + if patch_summary.patch_size != encoder.required_patch_size: + raise ValueError( + f"{h5_path} has patch_size={patch_summary.patch_size}, " + f"but '{encoder_name}' requires {encoder.required_patch_size}." + ) + + embedding = encoder.encode_slide(h5_path) + append_slide_embedding( + h5_path, + encoder_name, + embedding, + attrs={ + "source_patch_encoder": encoder.required_patch_encoder, + "num_patches": patch_summary.num_patches, + "patch_size": patch_summary.patch_size, + "patch_size_level0": patch_summary.patch_size_level0, + "target_magnification": patch_summary.target_magnification, + }, + overwrite=overwrite_existing, + ) + results.append( + self._build_result( + slide=slide, + h5_path=h5_path, + encoder_name=encoder_name, + embedding_dim=int(embedding.shape[0]), + source_patch_encoder=encoder.required_patch_encoder, + num_patches=patch_summary.num_patches, + patch_size=patch_summary.patch_size, + patch_size_level0=patch_summary.patch_size_level0, + target_magnification=patch_summary.target_magnification, + reused=False, + ) + ) + except Exception as exc: # noqa: BLE001 + failures.append((slide, RuntimeError(f"{encoder_name}: {exc}"))) + finally: + self._release_lock(lock_fd, lock_path) + finally: + self._cleanup_encoder(encoder) + + return results, failures diff --git a/atlas_patch/utils/__init__.py b/atlas_patch/utils/__init__.py index 954277d..bb9486d 100644 --- a/atlas_patch/utils/__init__.py +++ b/atlas_patch/utils/__init__.py @@ -7,13 +7,7 @@ from .features import ( get_existing_features, missing_features, - parse_feature_list, -) -from 
.feature_h5 import ( - PatchFeatureData, - append_slide_embedding, - load_patch_feature_data, - write_patient_embedding_h5, + parse_named_list, ) from .h5 import H5AppendWriter from .hf import import_module_from_hf @@ -32,12 +26,8 @@ "SuppressEmbeddingLogs", "configure_logging", "install_embedding_log_filter", - "parse_feature_list", + "parse_named_list", "get_existing_features", "missing_features", - "PatchFeatureData", - "load_patch_feature_data", - "append_slide_embedding", - "write_patient_embedding_h5", "import_module_from_hf", ] diff --git a/atlas_patch/utils/feature_h5.py b/atlas_patch/utils/feature_h5.py index cd37990..d30db2e 100644 --- a/atlas_patch/utils/feature_h5.py +++ b/atlas_patch/utils/feature_h5.py @@ -18,6 +18,13 @@ "patch_size", "target_magnification", ) +REQUIRED_SLIDE_EMBEDDING_ATTRS = ( + "source_patch_encoder", + "num_patches", + "patch_size", + "patch_size_level0", + "target_magnification", +) @dataclass(frozen=True) @@ -36,6 +43,27 @@ def num_patches(self) -> int: return int(self.features.shape[0]) +@dataclass(frozen=True) +class PatchArtifactSummary: + h5_path: Path + num_patches: int + patch_size_level0: int + patch_size: int + target_magnification: int + wsi_path: str | None = None + + +@dataclass(frozen=True) +class SlideEmbeddingSummary: + dataset_key: str + embedding_dim: int + source_patch_encoder: str + num_patches: int + patch_size: int + patch_size_level0: int + target_magnification: int + + def _encode_attr_value(value: Any) -> Any: if isinstance(value, Path): return str(value) @@ -66,13 +94,72 @@ def _read_required_int_attrs(h5_path: Path, *sources: Mapping[str, Any]) -> dict return values -def _coerce_embedding(embedding: np.ndarray, *, label: str) -> np.ndarray: +def coerce_1d_embedding(embedding: np.ndarray, *, label: str) -> np.ndarray: vector = np.asarray(embedding) if vector.ndim != 1: raise ValueError(f"{label} must be 1-D, got shape {vector.shape}") return vector +def read_patch_artifact_summary(h5_path: str | Path) -> 
PatchArtifactSummary: + path = Path(h5_path) + with h5py.File(path, "r") as handle: + coords_ds = handle.get("coords") + if not isinstance(coords_ds, h5py.Dataset): + raise ValueError(f"{path} is missing required dataset 'coords'.") + + num_patches_attr = handle.attrs.get("num_patches") + if num_patches_attr is not None: + num_patches = int(num_patches_attr) + else: + num_patches = int(coords_ds.shape[0]) + + attrs = _read_required_int_attrs(path, handle.attrs, coords_ds.attrs) + wsi_path = handle.attrs.get("wsi_path") + + return PatchArtifactSummary( + h5_path=path, + num_patches=num_patches, + patch_size_level0=attrs["patch_size_level0"], + patch_size=attrs["patch_size"], + target_magnification=attrs["target_magnification"], + wsi_path=str(wsi_path) if wsi_path is not None else None, + ) + + +def read_slide_embedding_summary( + h5_path: str | Path, + encoder_name: str, +) -> SlideEmbeddingSummary | None: + path = Path(h5_path) + dataset_key = slide_feature_dataset_key(encoder_name) + + with h5py.File(path, "r") as handle: + dataset = handle.get(dataset_key) + if dataset is None: + return None + if not isinstance(dataset, h5py.Dataset): + raise ValueError(f"{path} contains '{dataset_key}', but it is not an HDF5 dataset.") + if dataset.ndim != 1: + raise ValueError(f"{path} has invalid slide embedding shape {dataset.shape} for '{dataset_key}'.") + + attrs = _read_required_int_attrs(path, dataset.attrs) + missing = [key for key in REQUIRED_SLIDE_EMBEDDING_ATTRS if key not in dataset.attrs] + if missing: + joined = ", ".join(missing) + raise ValueError(f"{path} is missing required slide embedding metadata for '{dataset_key}': {joined}") + + return SlideEmbeddingSummary( + dataset_key=dataset_key, + embedding_dim=int(dataset.shape[0]), + source_patch_encoder=str(dataset.attrs["source_patch_encoder"]).strip().lower(), + num_patches=int(dataset.attrs["num_patches"]), + patch_size=attrs["patch_size"], + patch_size_level0=attrs["patch_size_level0"], + 
target_magnification=attrs["target_magnification"], + ) + + def load_patch_feature_data( h5_path: str | Path, feature_name: str, @@ -134,7 +221,7 @@ def append_slide_embedding( path = Path(h5_path) dataset_key = slide_feature_dataset_key(encoder_name) dataset_name = dataset_key.split("/", 1)[1] - vector = _coerce_embedding(embedding, label=dataset_key) + vector = coerce_1d_embedding(embedding, label=dataset_key) with h5py.File(path, "r+") as handle: group = handle.require_group("slide_features") @@ -161,7 +248,7 @@ def write_patient_embedding_h5( ) -> Path: """Write a compact patient-level embedding H5 atomically.""" path = Path(output_path) - vector = _coerce_embedding(embedding, label=str(path)) + vector = coerce_1d_embedding(embedding, label=str(path)) path.parent.mkdir(parents=True, exist_ok=True) if path.exists() and not overwrite: raise FileExistsError(f"Patient embedding already exists: {path}") diff --git a/atlas_patch/utils/features.py b/atlas_patch/utils/features.py index ac4ab05..e628421 100644 --- a/atlas_patch/utils/features.py +++ b/atlas_patch/utils/features.py @@ -7,18 +7,18 @@ import h5py -def parse_feature_list(raw: str, *, choices: list[str]) -> list[str]: - """Normalize, validate, and deduplicate a feature extractor list.""" +def parse_named_list(raw: str, *, choices: list[str], item_label: str) -> list[str]: + """Normalize, validate, and deduplicate a list of named registry entries.""" parts = [p.strip().lower() for p in raw.replace(",", " ").split() if p.strip()] if not parts: - raise click.BadParameter("At least one feature extractor name is required.") + raise click.BadParameter(f"At least one {item_label} name is required.") unknown = [p for p in parts if p not in choices] if unknown: raise click.BadParameter( - f"Unknown extractor(s): {', '.join(unknown)}. Available: {', '.join(choices)}" + f"Unknown {item_label}(s): {', '.join(unknown)}. 
Available: {', '.join(choices)}" ) - seen = set() - dupes = [] + seen: set[str] = set() + dupes: list[str] = [] unique_parts: list[str] = [] for p in parts: if p in seen: @@ -28,8 +28,8 @@ def parse_feature_list(raw: str, *, choices: list[str]) -> list[str]: unique_parts.append(p) if dupes: raise click.BadParameter( - f"Duplicate extractor(s) specified: {', '.join(sorted(set(dupes)))}. " - "Provide each extractor at most once." + f"Duplicate {item_label}(s) specified: {', '.join(sorted(set(dupes)))}. " + f"Provide each {item_label} once." ) return unique_parts diff --git a/pyproject.toml b/pyproject.toml index 5f73ac0..4cffea0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,38 @@ patch-encoders = [ "timm>=0.9.0", "transformers>=4.41.0", ] +slide-encoders = [ + "einops>=0.8.0", + "einops-exts>=0.0.4", + "fairscale>=0.4.0", + "gdown>=5.2.0", + "open-clip-torch>=2.24.0", + "sentencepiece>=0.2.0", + "timm>=0.9.0", + "transformers>=4.41.0", + "environs>=11.0.0", + "sacremoses>=0.1.1", + "moozy>=0.1.0", +] +titan = [ + "einops>=0.8.0", + "einops-exts>=0.0.4", + "fairscale>=0.4.0", + "gdown>=5.2.0", + "open-clip-torch>=2.24.0", + "sentencepiece>=0.2.0", + "transformers>=4.41.0", +] +prism = [ + "timm>=0.9.0", + "transformers>=4.41.0", + "environs>=11.0.0", + "sacremoses>=0.1.1", +] +moozy = [ + "timm>=0.9.0", + "moozy>=0.1.0", +] release = [ "build>=1.2.2", "twine>=5.1.1", From eda669fe0ee3a658cf95335900f0c9e2c561a5f7 Mon Sep 17 00:00:00 2001 From: yousefkotp Date: Mon, 6 Apr 2026 18:22:41 -0400 Subject: [PATCH 4/8] feat: add patient encoder support --- atlas_patch/cli.py | 273 +++++++++++++++++--- atlas_patch/core/__init__.py | 3 +- atlas_patch/core/models.py | 10 + atlas_patch/core/paths.py | 6 +- atlas_patch/models/{slide => }/common.py | 22 +- atlas_patch/models/moozy_api.py | 49 ++++ atlas_patch/models/patient/__init__.py | 24 +- atlas_patch/models/patient/moozy.py | 62 +++++ atlas_patch/models/slide/moozy.py | 55 +--- atlas_patch/models/slide/prism.py 
| 8 +- atlas_patch/models/slide/titan.py | 8 +- atlas_patch/services/_embedding_runtime.py | 45 ++++ atlas_patch/services/mpp.py | 2 + atlas_patch/services/patient_embedding.py | 282 +++++++++++++++++++++ atlas_patch/services/slide_embedding.py | 55 ++-- atlas_patch/utils/feature_h5.py | 100 ++++++++ 16 files changed, 858 insertions(+), 146 deletions(-) rename atlas_patch/models/{slide => }/common.py (54%) create mode 100644 atlas_patch/models/moozy_api.py create mode 100644 atlas_patch/models/patient/moozy.py create mode 100644 atlas_patch/services/_embedding_runtime.py create mode 100644 atlas_patch/services/patient_embedding.py diff --git a/atlas_patch/cli.py b/atlas_patch/cli.py index 1247c45..89cde73 100644 --- a/atlas_patch/cli.py +++ b/atlas_patch/cli.py @@ -13,13 +13,13 @@ DEFAULT_FEATURE_PRECISION, AppConfig, FeatureExtractionConfig, + PatientEncodingConfig, SlideEncodingConfig, default_sam2_config_path, ) -from atlas_patch.core.models import Slide +from atlas_patch.core.models import PatientCase, Slide from atlas_patch.core.paths import patch_h5_path from atlas_patch.models.patch.registry import PatchFeatureExtractorRegistry -from atlas_patch.models.slide.registry import SlideEncoderRegistry from atlas_patch.utils import ( configure_logging, install_embedding_log_filter, @@ -37,8 +37,9 @@ install_embedding_log_filter() # Shared option sets ----------------------------------------------------------- -_COMMON_OPTIONS: list = [ - click.argument("wsi_path", type=click.Path(exists=True)), +_INPUT_ARGUMENT = click.argument("wsi_path", type=click.Path(exists=True)) + +_PIPELINE_OPTIONS: list = [ click.option( "--output", "-o", @@ -119,16 +120,21 @@ click.option("--visualize-grids", is_flag=True, help="Render patch grid overlay."), click.option("--visualize-mask", is_flag=True, help="Render predicted mask overlay."), click.option("--visualize-contours", is_flag=True, help="Render contour overlay."), - click.option("--recursive", is_flag=True, help="Recursively 
search directories for WSIs."), - click.option( - "--mpp-csv", type=click.Path(exists=True), default=None, help="CSV with custom MPP." - ), click.option( "--skip-existing/--force", default=True, show_default=True, help="Skip existing H5." ), click.option("--verbose", "-v", is_flag=True, help="Enable debug logging."), ] +_DISCOVERY_OPTIONS: list = [ + click.option("--recursive", is_flag=True, help="Recursively search directories for WSIs."), + click.option( + "--mpp-csv", type=click.Path(exists=True), default=None, help="CSV with custom MPP." + ), +] + +_COMMON_OPTIONS: list = [_INPUT_ARGUMENT, *_PIPELINE_OPTIONS, *_DISCOVERY_OPTIONS] + _FEATURE_RUNTIME_OPTIONS: list = [ click.option( "--feature-device", @@ -194,6 +200,10 @@ def common_options(func): return _apply_options(func, _COMMON_OPTIONS) +def pipeline_options(func): + return _apply_options(func, _PIPELINE_OPTIONS) + + def feature_options(func): return _apply_options(func, _FEATURE_OPTIONS) @@ -236,23 +246,23 @@ def _build_patch_feature_registry( return registry, feat_device, precision -def _resolve_slide_encoder_requirements( - slide_registry: SlideEncoderRegistry, +def _resolve_encoder_requirements( + registry, encoders: list[str], ) -> tuple[int, list[str]]: - specs = [slide_registry.get_spec(name) for name in encoders] + specs = [registry.get_spec(name) for name in encoders] patch_sizes = {spec.patch_size for spec in specs} if len(patch_sizes) != 1: parts = ", ".join(f"{spec.name}:{spec.patch_size}" for spec in specs) raise click.ClickException( - "Requested slide encoders require incompatible patch geometries for one canonical H5: " + "Requested encoders require incompatible patch geometries for one canonical H5: " f"{parts}" ) required_patch_encoders = sorted({spec.patch_encoder_name for spec in specs}) return next(iter(patch_sizes)), required_patch_encoders -def _validate_existing_slide_outputs(app_cfg: AppConfig, slides: list[Slide]) -> None: +def _validate_existing_patch_artifacts(app_cfg: AppConfig, 
slides: list[Slide]) -> None: if not app_cfg.output.skip_existing: return @@ -268,18 +278,18 @@ def _validate_existing_slide_outputs(app_cfg: AppConfig, slides: list[Slide]) -> if summary.patch_size != app_cfg.extraction.patch_size: raise click.ClickException( f"Existing patch artifact {h5_path} has patch_size={summary.patch_size}, " - f"but encode-slide requested --patch-size {app_cfg.extraction.patch_size}. " + f"but the current run expects --patch-size {app_cfg.extraction.patch_size}. " "Re-run with --force to rebuild the canonical H5." ) if summary.target_magnification != app_cfg.extraction.target_magnification: raise click.ClickException( f"Existing patch artifact {h5_path} has target_mag={summary.target_magnification}, " - f"but encode-slide requested --target-mag {app_cfg.extraction.target_magnification}. " + f"but the current run expects --target-mag {app_cfg.extraction.target_magnification}. " "Re-run with --force to rebuild the canonical H5." ) -def _collect_ready_slide_artifacts( +def _collect_ready_patch_artifacts( app_cfg: AppConfig, slides: list[Slide], *, @@ -310,7 +320,7 @@ def _collect_ready_slide_artifacts( if summary.patch_size != app_cfg.extraction.patch_size: raise ValueError( f"{h5_path} has patch_size={summary.patch_size}, " - f"but encode-slide expects {app_cfg.extraction.patch_size}." + f"but the current run expects {app_cfg.extraction.patch_size}." 
) missing = missing_features( h5_path, @@ -328,9 +338,37 @@ def _collect_ready_slide_artifacts( return artifacts, failures -def _build_wsi_app_config( +def _collect_ready_patient_cases( + cases: list[PatientCase], *, - wsi_path: str, + slide_h5_by_path: dict[Path, Path], + slide_errors: dict[Path, Exception | str], +) -> tuple[list[PatientCase], list[tuple[PatientCase, Exception | str]]]: + ready_cases: list[PatientCase] = [] + failures: list[tuple[PatientCase, Exception | str]] = [] + + for case in cases: + missing: list[str] = [] + for slide in case.slides: + resolved_path = slide.path.resolve() + if resolved_path in slide_h5_by_path: + continue + err = slide_errors.get(resolved_path) + if err is None: + missing.append(f"{slide.path.name}: no canonical patch artifact was produced") + else: + missing.append(f"{slide.path.name}: {err}") + if missing: + failures.append((case, RuntimeError("; ".join(missing)))) + continue + ready_cases.append(case) + + return ready_cases, failures + + +def _build_pipeline_app_config( + *, + input_path: str, output: str, patch_size: int, step_size: int | None, @@ -363,7 +401,7 @@ def _build_wsi_app_config( return AppConfig( processing=ProcessingConfig( - input_path=Path(wsi_path), + input_path=Path(input_path), recursive=recursive, mpp_csv=Path(mpp_csv) if mpp_csv else None, ), @@ -477,7 +515,7 @@ def _run_pipeline_from_config( def _run_pipeline( *, - wsi_path: str, + input_path: str, output: str, patch_size: int, step_size: int | None, @@ -503,8 +541,8 @@ def _run_pipeline( registry: PatchFeatureExtractorRegistry | None = None, announce_completion: bool = True, ) -> tuple[list, list]: - app_cfg = _build_wsi_app_config( - wsi_path=wsi_path, + app_cfg = _build_pipeline_app_config( + input_path=input_path, output=output, patch_size=patch_size, step_size=step_size, @@ -694,6 +732,18 @@ def _echo_slide_results( click.echo(f"[FAIL] {slide.path.name}: {err}", err=True) +def _echo_patient_results( + results: list, failures: list, verbose: 
bool +) -> None: + click.echo(f"Completed {len(results)} patient embedding(s), failures: {len(failures)}") + if verbose: + for res in results: + status = "reused" if res.metadata.get("reused") else "written" + click.echo(f"[OK] {res.case_id} -> {res.h5_path} ({status})") + for case, err in failures: + click.echo(f"[FAIL] {case.case_id}: {err}", err=True) + + @click.group() @click.version_option(version=__version__) def cli(): @@ -732,7 +782,7 @@ def segment_and_get_coords( ): """Segment, patchify, and optionally visualize WSI files.""" results, failures = _run_pipeline( - wsi_path=wsi_path, + input_path=wsi_path, output=output, patch_size=patch_size, step_size=step_size, @@ -866,7 +916,7 @@ def process( plugins=[Path(p) for p in feature_plugins], ) results, failures = _run_pipeline( - wsi_path=wsi_path, + input_path=wsi_path, output=output, patch_size=patch_size, step_size=step_size, @@ -944,7 +994,7 @@ def encode_slide( choices=available_encoders, item_label="slide encoder", ) - required_patch_size, required_patch_encoders = _resolve_slide_encoder_requirements( + required_patch_size, required_patch_encoders = _resolve_encoder_requirements( slide_registry, encoders, ) @@ -981,8 +1031,8 @@ def encode_slide( device=device.lower(), skip_existing=skip_existing, ).validated() - app_cfg = _build_wsi_app_config( - wsi_path=wsi_path, + app_cfg = _build_pipeline_app_config( + input_path=wsi_path, output=output, patch_size=patch_size, step_size=step_size, @@ -1012,7 +1062,7 @@ def encode_slide( recursive=app_cfg.processing.recursive, ) ] - _validate_existing_slide_outputs(app_cfg, slides) + _validate_existing_patch_artifacts(app_cfg, slides) try: _, failures = _run_pipeline_from_config( @@ -1026,7 +1076,7 @@ def encode_slide( raise click.ClickException(str(exc)) from exc failed_slide_paths = {slide.path.resolve() for slide, _ in failures} - artifacts, artifact_failures = _collect_ready_slide_artifacts( + artifacts, artifact_failures = _collect_ready_patch_artifacts( app_cfg, 
slides, required_patch_encoders=required_patch_encoders, @@ -1043,6 +1093,168 @@ def encode_slide( _echo_slide_results(results, failures, verbose) +@cli.command() +@click.argument("manifest_path", type=click.Path(exists=True)) +@feature_runtime_options +@pipeline_options +@click.option( + "--patient-encoders", + required=True, + type=str, + help="Space/comma separated patient encoders to run.", +) +def encode_patient( + manifest_path: str, + output: str, + patch_size: int, + step_size: int | None, + target_mag: int, + device: str, + feature_device: str | None, + feature_batch_size: int, + feature_num_workers: int, + feature_precision: str, + tissue_thresh: float, + white_thresh: int, + black_thresh: int, + seg_batch_size: int, + write_batch: int, + patch_workers: int | None, + max_open_slides: int | None, + fast_mode: bool, + save_images: bool, + visualize_grids: bool, + visualize_mask: bool, + visualize_contours: bool, + skip_existing: bool, + verbose: bool, + feature_plugins: tuple[str, ...], + patient_encoders: str, +): + """Run the WSI pipeline from a case manifest and write patient embeddings.""" + from atlas_patch.models.patient import build_default_registry + from atlas_patch.services.patient_embedding import ( + PatientEmbeddingService, + load_patient_cases, + ) + + patient_registry = build_default_registry(device=device.lower()) + available_encoders = patient_registry.available() + encoders = parse_named_list( + patient_encoders, + choices=available_encoders, + item_label="patient encoder", + ) + required_patch_size, required_patch_encoders = _resolve_encoder_requirements( + patient_registry, + encoders, + ) + if patch_size != required_patch_size: + raise click.ClickException( + f"Requested patient encoders require patch_size={required_patch_size}, " + f"but encode-patient was given --patch-size {patch_size}." 
+ ) + + cases, slides = load_patient_cases(manifest_path) + if not cases: + raise click.ClickException(f"{manifest_path} does not contain any cases.") + + patch_registry, feat_device, precision = _build_patch_feature_registry( + device=device, + feature_device=feature_device, + feature_num_workers=feature_num_workers, + feature_precision=feature_precision, + feature_plugins=feature_plugins, + ) + available_extractors = patch_registry.available() + missing_extractors = [ + name for name in required_patch_encoders if name not in available_extractors + ] + if missing_extractors: + joined = ", ".join(missing_extractors) + raise click.ClickException( + f"Required patch extractor(s) for the selected patient encoders are unavailable: {joined}" + ) + + feature_cfg = FeatureExtractionConfig( + extractors=required_patch_encoders, + batch_size=feature_batch_size, + device=feat_device, + num_workers=feature_num_workers, + precision=precision, + plugins=[Path(p) for p in feature_plugins], + ).validated() + patient_cfg = PatientEncodingConfig( + manifest_path=Path(manifest_path), + encoders=encoders, + device=device.lower(), + skip_existing=skip_existing, + ).validated() + app_cfg = _build_pipeline_app_config( + input_path=manifest_path, + output=output, + patch_size=patch_size, + step_size=step_size, + target_mag=target_mag, + device=device, + tissue_thresh=tissue_thresh, + white_thresh=white_thresh, + black_thresh=black_thresh, + seg_batch_size=seg_batch_size, + write_batch=write_batch, + patch_workers=patch_workers, + max_open_slides=max_open_slides, + fast_mode=fast_mode, + save_images=save_images, + visualize_grids=visualize_grids, + visualize_mask=visualize_mask, + visualize_contours=visualize_contours, + recursive=False, + mpp_csv=None, + skip_existing=skip_existing, + feature_cfg=feature_cfg, + ) + _validate_existing_patch_artifacts(app_cfg, slides) + + try: + _, pipeline_failures = _run_pipeline_from_config( + app_cfg, + slides=slides, + verbose=verbose, + 
registry=patch_registry, + announce_completion=False, + ) + except Exception as exc: # noqa: BLE001 + raise click.ClickException(str(exc)) from exc + + failed_slide_paths = {slide.path.resolve() for slide, _ in pipeline_failures} + artifacts, artifact_failures = _collect_ready_patch_artifacts( + app_cfg, + slides, + required_patch_encoders=required_patch_encoders, + failed_slide_paths=failed_slide_paths, + ) + slide_errors: dict[Path, Exception | str] = { + slide.path.resolve(): err for slide, err in [*pipeline_failures, *artifact_failures] + } + slide_h5_by_path = {slide.path.resolve(): h5_path for slide, h5_path in artifacts} + ready_cases, case_failures = _collect_ready_patient_cases( + cases, + slide_h5_by_path=slide_h5_by_path, + slide_errors=slide_errors, + ) + results, embedding_failures = PatientEmbeddingService( + app_cfg.output, + patient_cfg, + registry=patient_registry, + ).embed_all( + ready_cases, + slide_h5_by_path=slide_h5_by_path, + ) + failures = [*case_failures, *embedding_failures] + _echo_patient_results(results, failures, verbose) + + @cli.command() def info(): """Display supported formats and output structure.""" @@ -1054,6 +1266,7 @@ def info(): "Outputs: HDF5 per slide under patches/.h5; optional PNGs under images/; visualizations under visualization/." 
     )
     click.echo("Slide embeddings append into slide_features/ inside patches/.h5.")
+    click.echo("Patient embeddings are written separately under patient_features/<encoder_name>/<case_id>.h5.")
 
 
 def main():
diff --git a/atlas_patch/core/__init__.py b/atlas_patch/core/__init__.py
index 72bfd62..b93ed69 100644
--- a/atlas_patch/core/__init__.py
+++ b/atlas_patch/core/__init__.py
@@ -9,7 +9,7 @@
     SegmentationConfig,
     VisualizationConfig,
 )
-from .models import ExtractionResult, Mask, Slide
+from .models import ExtractionResult, Mask, PatientCase, Slide
 
 __all__ = [
     "AppConfig",
@@ -21,5 +21,6 @@
     "VisualizationConfig",
     "ExtractionResult",
     "Mask",
+    "PatientCase",
     "Slide",
 ]
diff --git a/atlas_patch/core/models.py b/atlas_patch/core/models.py
index a6c9d1e..683d0f3 100644
--- a/atlas_patch/core/models.py
+++ b/atlas_patch/core/models.py
@@ -57,3 +57,13 @@ class PatientEmbeddingResult:
     num_slides: int
     source_patch_encoder: str | None = None
     metadata: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass(frozen=True)
+class PatientCase:
+    case_id: str
+    slides: tuple[Slide, ...]
+ + @property + def num_slides(self) -> int: + return len(self.slides) diff --git a/atlas_patch/core/paths.py b/atlas_patch/core/paths.py index b6ec097..8052a18 100644 --- a/atlas_patch/core/paths.py +++ b/atlas_patch/core/paths.py @@ -25,7 +25,7 @@ def _normalize_dataset_component(name: str, *, prefix: str) -> str: return normalize_encoder_name(value) -def _validate_output_stem(value: str, *, field_name: str) -> str: +def validate_output_stem(value: str, *, field_name: str) -> str: cleaned = str(value).strip() if not cleaned: raise ValueError(f"{field_name} must be a non-empty string") @@ -99,10 +99,10 @@ def patient_encoder_dir(output_cfg: OutputConfig, encoder_name: str) -> Path: def patient_embedding_path(output_cfg: OutputConfig, encoder_name: str, case_id: str) -> Path: - case_stem = _validate_output_stem(case_id, field_name="case_id") + case_stem = validate_output_stem(case_id, field_name="case_id") return patient_encoder_dir(output_cfg, encoder_name) / f"{case_stem}.h5" def patient_lock_path(output_cfg: OutputConfig, encoder_name: str, case_id: str) -> Path: - case_stem = _validate_output_stem(case_id, field_name="case_id") + case_stem = validate_output_stem(case_id, field_name="case_id") return patient_encoder_dir(output_cfg, encoder_name) / f"{case_stem}.lock" diff --git a/atlas_patch/models/slide/common.py b/atlas_patch/models/common.py similarity index 54% rename from atlas_patch/models/slide/common.py rename to atlas_patch/models/common.py index 20b14a1..0e977ad 100644 --- a/atlas_patch/models/slide/common.py +++ b/atlas_patch/models/common.py @@ -7,38 +7,34 @@ from atlas_patch.utils.feature_h5 import coerce_1d_embedding -logger = logging.getLogger("atlas_patch.slide_models") +logger = logging.getLogger("atlas_patch.models") -def resolve_slide_device(device: str | torch.device) -> torch.device: +def resolve_model_device(device: str | torch.device) -> torch.device: resolved = torch.device(device) if resolved.type == "cuda" and not 
torch.cuda.is_available(): - logger.warning("Slide encoding requested on CUDA but unavailable; using CPU instead.") + logger.warning("Model inference requested on CUDA but unavailable; using CPU instead.") return torch.device("cpu") return resolved -def coerce_slide_embedding( +def coerce_model_embedding( output: torch.Tensor | np.ndarray, *, expected_dim: int, - encoder_name: str, + label: str, ) -> np.ndarray: tensor = torch.as_tensor(output) if tensor.ndim == 2: if tensor.shape[0] != 1: - raise ValueError( - f"Expected a single {encoder_name} embedding, got shape {tuple(tensor.shape)}" - ) + raise ValueError(f"Expected a single {label} embedding, got shape {tuple(tensor.shape)}") tensor = tensor.squeeze(0) if tensor.ndim != 1: - raise ValueError(f"Expected a 1-D {encoder_name} embedding, got shape {tuple(tensor.shape)}") + raise ValueError(f"Expected a 1-D {label} embedding, got shape {tuple(tensor.shape)}") vector = coerce_1d_embedding( tensor.detach().to(device="cpu", dtype=torch.float32).numpy(), - label=encoder_name, + label=label, ) if vector.shape[0] != expected_dim: - raise ValueError( - f"{encoder_name} embedding dim mismatch: expected {expected_dim}, got {vector.shape[0]}" - ) + raise ValueError(f"{label} embedding dim mismatch: expected {expected_dim}, got {vector.shape[0]}") return vector diff --git a/atlas_patch/models/moozy_api.py b/atlas_patch/models/moozy_api.py new file mode 100644 index 0000000..0e168cb --- /dev/null +++ b/atlas_patch/models/moozy_api.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from contextlib import nullcontext +from pathlib import Path + +import h5py +import numpy as np +import torch + +_EMBEDDING_DATASET = "features" + + +def run_moozy_public_api( + slide_paths: list[str], + output_path: str, + *, + device: torch.device, + mixed_precision: bool = False, +) -> None: + try: + from moozy.encoding import run_encoding + except ModuleNotFoundError as exc: + raise RuntimeError( + "MOOZY encoding requires the optional 
`moozy` package. " + "Install `atlas-patch[moozy]`, `atlas-patch[slide-encoders]`, or `pip install moozy`." + ) from exc + + if device.type == "cpu" and torch.cuda.is_available(): + raise RuntimeError( + "MOOZY's public `run_encoding` API does not support forcing CPU when CUDA is " + "available. Use `--device cuda` or run in a CPU-only environment." + ) + + cuda_ctx = torch.cuda.device(device) if device.type == "cuda" else nullcontext() + with cuda_ctx: + run_encoding( + slide_paths=slide_paths, + output_path=output_path, + mixed_precision=mixed_precision, + ) + + +def load_moozy_embedding(output_path: str | Path) -> np.ndarray: + path = Path(output_path) + with h5py.File(path, "r") as handle: + dataset = handle.get(_EMBEDDING_DATASET) + if not isinstance(dataset, h5py.Dataset): + raise ValueError(f"{path} is missing required dataset '{_EMBEDDING_DATASET}'.") + return np.asarray(dataset[()], dtype=np.float32) diff --git a/atlas_patch/models/patient/__init__.py b/atlas_patch/models/patient/__init__.py index f2b1400..26de41d 100644 --- a/atlas_patch/models/patient/__init__.py +++ b/atlas_patch/models/patient/__init__.py @@ -1,3 +1,9 @@ +from __future__ import annotations + +from importlib import import_module + +import torch + from atlas_patch.models.patient.base import PatientEncoder, PatientEncoderSpec from atlas_patch.models.patient.registry import PatientEncoderRegistry @@ -8,7 +14,19 @@ "build_default_registry", ] +_REGISTRARS: tuple[tuple[str, str], ...] 
= ( + ("atlas_patch.models.patient.moozy", "register_moozy_patient_encoder"), +) + -def build_default_registry() -> PatientEncoderRegistry: - """Return an empty patient-encoder registry for phased population.""" - return PatientEncoderRegistry() +def build_default_registry( + *, + device: str | torch.device = "cuda", +) -> PatientEncoderRegistry: + """Factory that registers the built-in patient encoders.""" + registry = PatientEncoderRegistry() + for module_name, registrar_name in _REGISTRARS: + module = import_module(module_name) + registrar = getattr(module, registrar_name) + registrar(registry, device=device) + return registry diff --git a/atlas_patch/models/patient/moozy.py b/atlas_patch/models/patient/moozy.py new file mode 100644 index 0000000..8077652 --- /dev/null +++ b/atlas_patch/models/patient/moozy.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import tempfile +from pathlib import Path +from typing import Mapping, Sequence + +import numpy as np +import torch + +from atlas_patch.models.common import resolve_model_device +from atlas_patch.models.moozy_api import load_moozy_embedding, run_moozy_public_api +from atlas_patch.models.patient.base import PatientEncoder, PatientEncoderSpec +from atlas_patch.models.patient.registry import PatientEncoderRegistry +from atlas_patch.utils.feature_h5 import coerce_1d_embedding + + +class MOOZYPatientEncoder(PatientEncoder): + spec = PatientEncoderSpec( + name="moozy", + embedding_dim=768, + patch_encoder_name="lunit_vit_small_patch8_dino", + patch_size=224, + ) + + def __init__(self, *, device: str | torch.device = "cuda") -> None: + self.device = resolve_model_device(device) + + def encode_case( + self, + slide_h5_paths: Sequence[Path], + metadata: Mapping[str, object] | None = None, + ) -> np.ndarray: + if not slide_h5_paths: + raise ValueError("MOOZY patient encoding requires at least one slide H5.") + + with tempfile.TemporaryDirectory(prefix="atlaspatch_moozy_case_") as tmp_dir: + output_path = 
Path(tmp_dir) / "case_embedding.h5" + run_moozy_public_api( + [str(path) for path in slide_h5_paths], + str(output_path), + device=self.device, + mixed_precision=False, + ) + embedding = load_moozy_embedding(output_path) + + vector = coerce_1d_embedding(embedding, label="MOOZY patient embedding") + if vector.shape[0] != self.embedding_dim: + raise ValueError( + f"MOOZY patient embedding dim mismatch: expected {self.embedding_dim}, got {vector.shape[0]}" + ) + return np.asarray(vector, dtype=np.float32) + + +def register_moozy_patient_encoder( + registry: PatientEncoderRegistry, + *, + device: str | torch.device = "cuda", +) -> None: + registry.register( + MOOZYPatientEncoder.spec, + lambda: MOOZYPatientEncoder(device=device), + ) diff --git a/atlas_patch/models/slide/moozy.py b/atlas_patch/models/slide/moozy.py index 220d482..3762b0c 100644 --- a/atlas_patch/models/slide/moozy.py +++ b/atlas_patch/models/slide/moozy.py @@ -1,57 +1,16 @@ from __future__ import annotations -from contextlib import nullcontext import tempfile from pathlib import Path -import h5py import numpy as np import torch +from atlas_patch.models.common import coerce_model_embedding, resolve_model_device +from atlas_patch.models.moozy_api import load_moozy_embedding, run_moozy_public_api from atlas_patch.models.slide.base import SlideEncoder, SlideEncoderSpec -from atlas_patch.models.slide.common import coerce_slide_embedding, resolve_slide_device from atlas_patch.models.slide.registry import SlideEncoderRegistry -_EMBEDDING_DATASET = "features" - - -def _run_moozy_public_api( - slide_paths: list[str], - output_path: str, - *, - device: torch.device, - mixed_precision: bool = False, -) -> None: - try: - from moozy.encoding import run_encoding - except ModuleNotFoundError as exc: - raise RuntimeError( - "MOOZY slide encoding requires the optional `moozy` package. " - "Install `atlas-patch[moozy]`, `atlas-patch[slide-encoders]`, or `pip install moozy`." 
- ) from exc - - if device.type == "cpu" and torch.cuda.is_available(): - raise RuntimeError( - "MOOZY's public `run_encoding` API does not support forcing CPU when CUDA is " - "available. Use `--device cuda` or run in a CPU-only environment." - ) - - cuda_ctx = torch.cuda.device(device) if device.type == "cuda" else nullcontext() - with cuda_ctx: - run_encoding( - slide_paths=slide_paths, - output_path=output_path, - mixed_precision=mixed_precision, - ) - - -def _load_moozy_embedding(output_path: Path) -> np.ndarray: - with h5py.File(output_path, "r") as handle: - dataset = handle.get(_EMBEDDING_DATASET) - if not isinstance(dataset, h5py.Dataset): - raise ValueError(f"{output_path} is missing required dataset '{_EMBEDDING_DATASET}'.") - return np.asarray(dataset[()], dtype=np.float32) - class MOOZYSlideEncoder(SlideEncoder): spec = SlideEncoderSpec( @@ -62,22 +21,22 @@ class MOOZYSlideEncoder(SlideEncoder): ) def __init__(self, *, device: str | torch.device = "cuda") -> None: - self.device = resolve_slide_device(device) + self.device = resolve_model_device(device) def encode_slide(self, patch_h5_path: Path) -> np.ndarray: with tempfile.TemporaryDirectory(prefix="atlaspatch_moozy_") as tmp_dir: output_path = Path(tmp_dir) / "case_embedding.h5" - _run_moozy_public_api( + run_moozy_public_api( [str(patch_h5_path)], str(output_path), device=self.device, mixed_precision=False, ) - embedding = _load_moozy_embedding(output_path) - return coerce_slide_embedding( + embedding = load_moozy_embedding(output_path) + return coerce_model_embedding( embedding, expected_dim=self.embedding_dim, - encoder_name="MOOZY", + label="MOOZY", ) diff --git a/atlas_patch/models/slide/prism.py b/atlas_patch/models/slide/prism.py index b8ef453..4d50818 100644 --- a/atlas_patch/models/slide/prism.py +++ b/atlas_patch/models/slide/prism.py @@ -6,8 +6,8 @@ import numpy as np import torch +from atlas_patch.models.common import coerce_model_embedding, resolve_model_device from 
atlas_patch.models.slide.base import SlideEncoder, SlideEncoderSpec -from atlas_patch.models.slide.common import coerce_slide_embedding, resolve_slide_device from atlas_patch.models.slide.registry import SlideEncoderRegistry from atlas_patch.utils.feature_h5 import load_patch_feature_data @@ -44,7 +44,7 @@ class PrismSlideEncoder(SlideEncoder): ) def __init__(self, *, device: str | torch.device = "cuda") -> None: - self.device = resolve_slide_device(device) + self.device = resolve_model_device(device) self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32 self.model = _load_prism_model(device=self.device, dtype=self.dtype) @@ -74,10 +74,10 @@ def encode_slide(self, patch_h5_path: Path) -> np.ndarray: output = self.model.slide_representations(features, tile_mask=tile_mask) if not isinstance(output, dict) or "image_embedding" not in output: raise ValueError("PRISM did not return an 'image_embedding' entry.") - return coerce_slide_embedding( + return coerce_model_embedding( output["image_embedding"], expected_dim=self.embedding_dim, - encoder_name="PRISM", + label="PRISM", ) def cleanup(self) -> None: diff --git a/atlas_patch/models/slide/titan.py b/atlas_patch/models/slide/titan.py index 9edc427..61f5d97 100644 --- a/atlas_patch/models/slide/titan.py +++ b/atlas_patch/models/slide/titan.py @@ -5,8 +5,8 @@ import numpy as np import torch +from atlas_patch.models.common import coerce_model_embedding, resolve_model_device from atlas_patch.models.slide.base import SlideEncoder, SlideEncoderSpec -from atlas_patch.models.slide.common import coerce_slide_embedding, resolve_slide_device from atlas_patch.models.slide.registry import SlideEncoderRegistry from atlas_patch.utils.feature_h5 import load_patch_feature_data @@ -36,7 +36,7 @@ class TitanSlideEncoder(SlideEncoder): ) def __init__(self, *, device: str | torch.device = "cuda") -> None: - self.device = resolve_slide_device(device) + self.device = resolve_model_device(device) self.dtype = 
torch.float16 if self.device.type == "cuda" else torch.float32 self.model = _load_titan_model(device=self.device, dtype=self.dtype) @@ -67,10 +67,10 @@ def encode_slide(self, patch_h5_path: Path) -> np.ndarray: coords, int(patch_data.patch_size_level0), ) - return coerce_slide_embedding( + return coerce_model_embedding( embedding, expected_dim=self.embedding_dim, - encoder_name="TITAN", + label="TITAN", ) def cleanup(self) -> None: diff --git a/atlas_patch/services/_embedding_runtime.py b/atlas_patch/services/_embedding_runtime.py new file mode 100644 index 0000000..363e2c7 --- /dev/null +++ b/atlas_patch/services/_embedding_runtime.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import os +from pathlib import Path + + +def acquire_exclusive_lock( + lock_path: Path, + *, + payload: str, + label: str, +) -> tuple[int | None, Path]: + lock_path.parent.mkdir(parents=True, exist_ok=True) + try: + fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) + os.write(fd, payload.encode()) + os.fsync(fd) + return fd, lock_path + except FileExistsError: + return None, lock_path + except Exception as exc: # noqa: BLE001 + raise RuntimeError(f"Failed to create {label} lock {lock_path}: {exc}") from exc + + +def release_lock(fd: int | None, lock_path: Path) -> None: + if fd is not None: + try: + os.close(fd) + except Exception: + pass + try: + lock_path.unlink() + except FileNotFoundError: + pass + except Exception: + pass + + +def cleanup_encoder(encoder: object) -> None: + cleanup = getattr(encoder, "cleanup", None) + if callable(cleanup): + try: + cleanup() + except Exception: + pass diff --git a/atlas_patch/services/mpp.py b/atlas_patch/services/mpp.py index a2382bf..f1945bc 100644 --- a/atlas_patch/services/mpp.py +++ b/atlas_patch/services/mpp.py @@ -15,4 +15,6 @@ def __init__(self, csv_path: Path | None) -> None: self._mpp_map = load_mpp_csv(str(csv_path)) def resolve(self, slide: Slide) -> float | None: + if slide.mpp is not None: + return 
float(slide.mpp) return get_mpp_for_wsi(str(slide.path), self._mpp_map) diff --git a/atlas_patch/services/patient_embedding.py b/atlas_patch/services/patient_embedding.py new file mode 100644 index 0000000..4587273 --- /dev/null +++ b/atlas_patch/services/patient_embedding.py @@ -0,0 +1,282 @@ +from __future__ import annotations + +import csv +import os +import time +from pathlib import Path +from typing import Iterable, Mapping + +from atlas_patch.core.config import OutputConfig, PatientEncodingConfig +from atlas_patch.core.models import PatientCase, PatientEmbeddingResult, Slide +from atlas_patch.core.paths import patient_embedding_path, patient_lock_path, validate_output_stem +from atlas_patch.models.patient import PatientEncoderRegistry, build_default_registry +from atlas_patch.services._embedding_runtime import ( + acquire_exclusive_lock, + cleanup_encoder, + release_lock, +) +from atlas_patch.utils.feature_h5 import ( + PatchArtifactSummary, + read_patch_artifact_summary, + read_patient_embedding_summary, + write_patient_embedding_h5, +) + + +def _parse_manifest_mpp(raw: str | None, *, manifest_path: Path, row_number: int) -> float | None: + value = str(raw or "").strip() + if not value: + return None + try: + return float(value) + except ValueError as exc: + raise ValueError( + f"{manifest_path}:{row_number} has invalid mpp value {value!r}; expected a number." 
+ ) from exc + + +def load_patient_cases(manifest_path: str | Path) -> tuple[list[PatientCase], list[Slide]]: + path = Path(manifest_path) + manifest_dir = path.resolve().parent + with path.open("r", encoding="utf-8", newline="") as handle: + reader = csv.DictReader(handle) + fieldnames = reader.fieldnames or [] + required = {"case_id", "slide_path"} + missing = [name for name in required if name not in fieldnames] + if missing: + joined = ", ".join(sorted(missing)) + raise ValueError(f"{path} is missing required column(s): {joined}") + + slides_by_path: dict[Path, Slide] = {} + slide_stems: dict[str, Path] = {} + cases_by_id: dict[str, list[Slide]] = {} + seen_within_case: dict[str, set[Path]] = {} + + for row_number, row in enumerate(reader, start=2): + case_id = str(row.get("case_id") or "").strip() + if not case_id: + raise ValueError(f"{path}:{row_number} has empty case_id.") + validate_output_stem(case_id, field_name="case_id") + + slide_raw = str(row.get("slide_path") or "").strip() + if not slide_raw: + raise ValueError(f"{path}:{row_number} has empty slide_path.") + slide_path = Path(slide_raw).expanduser() + if not slide_path.is_absolute(): + slide_path = manifest_dir / slide_path + if not slide_path.exists(): + raise FileNotFoundError(f"{path}:{row_number} references missing slide {slide_path}") + if not slide_path.is_file(): + raise ValueError(f"{path}:{row_number} references non-file slide {slide_path}") + resolved_path = slide_path.resolve() + mpp = _parse_manifest_mpp(row.get("mpp"), manifest_path=path, row_number=row_number) + + existing = slides_by_path.get(resolved_path) + if existing is None: + existing_stem_path = slide_stems.get(slide_path.stem) + if existing_stem_path is not None and existing_stem_path != resolved_path: + raise ValueError( + f"{path} includes multiple slides named '{slide_path.stem}', which would " + "collide in patches/.h5. Rename the slide files or split the run." 
+ ) + existing = Slide(path=slide_path, mpp=mpp) + slides_by_path[resolved_path] = existing + slide_stems[slide_path.stem] = resolved_path + elif existing.mpp != mpp: + raise ValueError( + f"{path}:{row_number} assigns conflicting mpp overrides to {slide_path}." + ) + + case_slides = cases_by_id.setdefault(case_id, []) + seen_paths = seen_within_case.setdefault(case_id, set()) + if resolved_path in seen_paths: + raise ValueError(f"{path}:{row_number} repeats slide {slide_path} within case {case_id}.") + seen_paths.add(resolved_path) + case_slides.append(existing) + + cases = [PatientCase(case_id=case_id, slides=tuple(slides)) for case_id, slides in cases_by_id.items()] + slides = list(slides_by_path.values()) + return cases, slides + + +class PatientEmbeddingService: + """Write patient-level embeddings into separate case-scoped H5 files.""" + + def __init__( + self, + output_cfg: OutputConfig, + patient_cfg: PatientEncodingConfig, + *, + registry: PatientEncoderRegistry | None = None, + ) -> None: + self.output_cfg = output_cfg.validated() + self.patient_cfg = patient_cfg.validated() + self.registry = registry or build_default_registry(device=self.patient_cfg.device) + + @staticmethod + def _build_source_slide_summaries(slide_h5_paths: list[Path]) -> tuple[PatchArtifactSummary, ...]: + return tuple(read_patch_artifact_summary(path.resolve()) for path in slide_h5_paths) + + @staticmethod + def _validate_reusable_embedding( + *, + output_path: Path, + encoder, + case: PatientCase, + slide_h5_paths: list[Path], + ) -> PatientEmbeddingResult | None: + try: + current_slide_summaries = PatientEmbeddingService._build_source_slide_summaries( + slide_h5_paths + ) + summary = read_patient_embedding_summary(output_path) + except Exception: + return None + if summary.encoder_name != encoder.name: + return None + if summary.embedding_dim != encoder.embedding_dim: + return None + if summary.num_slides != case.num_slides: + return None + if summary.source_patch_encoder != 
encoder.required_patch_encoder: + return None + current_slide_h5s = tuple(str(path.resolve()) for path in slide_h5_paths) + if summary.source_slide_h5_paths != current_slide_h5s: + return None + if summary.source_slide_patch_summaries != current_slide_summaries: + return None + return PatientEmbeddingResult( + case_id=case.case_id, + h5_path=output_path, + encoder_name=encoder.name, + embedding_dim=summary.embedding_dim, + num_slides=summary.num_slides, + source_patch_encoder=summary.source_patch_encoder, + metadata={"reused": True}, + ) + + def embed_all( + self, + cases: Iterable[PatientCase], + *, + slide_h5_by_path: Mapping[Path, Path], + ) -> tuple[list[PatientEmbeddingResult], list[tuple[PatientCase, Exception | str]]]: + ordered_cases = list(cases) + results: list[PatientEmbeddingResult] = [] + failures: list[tuple[PatientCase, Exception | str]] = [] + if not ordered_cases: + return results, failures + + for encoder_name in self.patient_cfg.encoders: + try: + encoder = self.registry.create(encoder_name) + except Exception as exc: # noqa: BLE001 + for case in ordered_cases: + failures.append((case, RuntimeError(f"{encoder_name}: {exc}"))) + continue + + try: + for case in ordered_cases: + try: + slide_h5_paths = [ + slide_h5_by_path[slide.path.resolve()].resolve() + for slide in case.slides + ] + except KeyError as exc: + failures.append( + ( + case, + RuntimeError( + f"{encoder_name}: missing canonical patch artifact for {exc.args[0]}" + ), + ) + ) + continue + + output_path = patient_embedding_path(self.output_cfg, encoder_name, case.case_id) + if self.patient_cfg.skip_existing and output_path.exists(): + reused = self._validate_reusable_embedding( + output_path=output_path, + encoder=encoder, + case=case, + slide_h5_paths=slide_h5_paths, + ) + if reused is not None: + results.append(reused) + continue + overwrite = True + else: + overwrite = not self.patient_cfg.skip_existing + + lock_fd, lock_path = acquire_exclusive_lock( + 
patient_lock_path(self.output_cfg, encoder_name, case.case_id), + payload=( + f"pid={os.getpid()},time={int(time.time())}," + f"case={case.case_id},phase=patient-embedding" + ), + label="patient", + ) + if lock_fd is None: + failures.append( + ( + case, + RuntimeError( + f"{encoder_name}: {output_path} is locked by another process." + ), + ) + ) + continue + + try: + embedding = encoder.encode_case( + slide_h5_paths, + metadata={ + "case_id": case.case_id, + "manifest_path": self.patient_cfg.manifest_path.resolve(), + }, + ) + source_slide_summaries = self._build_source_slide_summaries(slide_h5_paths) + write_patient_embedding_h5( + output_path, + embedding, + attrs={ + "encoder_name": encoder.name, + "num_slides": case.num_slides, + "source_patch_encoder": encoder.required_patch_encoder, + "source_manifest": str(self.patient_cfg.manifest_path.resolve()), + "source_slide_h5_paths": [ + str(path.resolve()) for path in slide_h5_paths + ], + "source_slide_patch_summaries": [ + { + "h5_path": str(summary.h5_path.resolve()), + "num_patches": summary.num_patches, + "patch_size_level0": summary.patch_size_level0, + "patch_size": summary.patch_size, + "target_magnification": summary.target_magnification, + "wsi_path": summary.wsi_path, + } + for summary in source_slide_summaries + ], + }, + overwrite=overwrite, + ) + results.append( + PatientEmbeddingResult( + case_id=case.case_id, + h5_path=output_path, + encoder_name=encoder.name, + embedding_dim=int(embedding.shape[0]), + num_slides=case.num_slides, + source_patch_encoder=encoder.required_patch_encoder, + metadata={"reused": False}, + ) + ) + except Exception as exc: # noqa: BLE001 + failures.append((case, RuntimeError(f"{encoder_name}: {exc}"))) + finally: + release_lock(lock_fd, lock_path) + finally: + cleanup_encoder(encoder) + + return results, failures diff --git a/atlas_patch/services/slide_embedding.py b/atlas_patch/services/slide_embedding.py index e3cce2f..3060469 100644 --- 
a/atlas_patch/services/slide_embedding.py +++ b/atlas_patch/services/slide_embedding.py @@ -9,6 +9,11 @@ from atlas_patch.core.models import Slide, SlideEmbeddingResult from atlas_patch.core.paths import slide_append_lock_path, slide_feature_dataset_key from atlas_patch.models.slide import SlideEncoderRegistry, build_default_registry +from atlas_patch.services._embedding_runtime import ( + acquire_exclusive_lock, + cleanup_encoder, + release_lock, +) from atlas_patch.utils.feature_h5 import ( PatchArtifactSummary, SlideEmbeddingSummary, @@ -34,43 +39,6 @@ def __init__( self.slide_cfg = slide_cfg.validated() self.registry = registry or build_default_registry(device=self.slide_cfg.device) - def _acquire_lock(self, slide: Slide) -> tuple[int | None, Path]: - lock_path = slide_append_lock_path(slide, self.output_cfg, self.extraction_cfg) - lock_path.parent.mkdir(parents=True, exist_ok=True) - payload = f"pid={os.getpid()},time={int(time.time())},slide={slide.path},phase=slide-embedding" - try: - fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) - os.write(fd, payload.encode()) - os.fsync(fd) - return fd, lock_path - except FileExistsError: - return None, lock_path - except Exception as exc: # noqa: BLE001 - raise RuntimeError(f"Failed to create slide lock {lock_path}: {exc}") from exc - - @staticmethod - def _release_lock(fd: int | None, lock_path: Path) -> None: - if fd is not None: - try: - os.close(fd) - except Exception: - pass - try: - lock_path.unlink() - except FileNotFoundError: - pass - except Exception: - pass - - @staticmethod - def _cleanup_encoder(encoder) -> None: - cleanup = getattr(encoder, "cleanup", None) - if callable(cleanup): - try: - cleanup() - except Exception: - pass - @staticmethod def _build_result( *, @@ -203,7 +171,14 @@ def embed_all( ) continue - lock_fd, lock_path = self._acquire_lock(slide) + lock_fd, lock_path = acquire_exclusive_lock( + slide_append_lock_path(slide, self.output_cfg, self.extraction_cfg), + payload=( + 
f"pid={os.getpid()},time={int(time.time())}," + f"slide={slide.path},phase=slide-embedding" + ), + label="slide", + ) if lock_fd is None: failures.append( ( @@ -253,8 +228,8 @@ def embed_all( except Exception as exc: # noqa: BLE001 failures.append((slide, RuntimeError(f"{encoder_name}: {exc}"))) finally: - self._release_lock(lock_fd, lock_path) + release_lock(lock_fd, lock_path) finally: - self._cleanup_encoder(encoder) + cleanup_encoder(encoder) return results, failures diff --git a/atlas_patch/utils/feature_h5.py b/atlas_patch/utils/feature_h5.py index d30db2e..c792eab 100644 --- a/atlas_patch/utils/feature_h5.py +++ b/atlas_patch/utils/feature_h5.py @@ -25,6 +25,10 @@ "patch_size_level0", "target_magnification", ) +REQUIRED_PATIENT_EMBEDDING_ATTRS = ( + "encoder_name", + "num_slides", +) @dataclass(frozen=True) @@ -64,6 +68,17 @@ class SlideEmbeddingSummary: target_magnification: int +@dataclass(frozen=True) +class PatientEmbeddingSummary: + embedding_dim: int + encoder_name: str + num_slides: int + source_patch_encoder: str | None = None + source_manifest: str | None = None + source_slide_h5_paths: tuple[str, ...] = () + source_slide_patch_summaries: tuple[PatchArtifactSummary, ...] 
= () + + def _encode_attr_value(value: Any) -> Any: if isinstance(value, Path): return str(value) @@ -160,6 +175,91 @@ def read_slide_embedding_summary( ) +def _decode_json_list_attr(value: Any, *, label: str) -> tuple[str, ...]: + if value is None: + return () + try: + parsed = json.loads(str(value)) + except json.JSONDecodeError as exc: + raise ValueError(f"{label} must be valid JSON.") from exc + if not isinstance(parsed, list): + raise ValueError(f"{label} must decode to a JSON list.") + return tuple(str(item) for item in parsed) + + +def _decode_patient_source_slide_summaries( + value: Any, + *, + label: str, +) -> tuple[PatchArtifactSummary, ...]: + if value is None: + return () + try: + parsed = json.loads(str(value)) + except json.JSONDecodeError as exc: + raise ValueError(f"{label} must be valid JSON.") from exc + if not isinstance(parsed, list): + raise ValueError(f"{label} must decode to a JSON list.") + + summaries: list[PatchArtifactSummary] = [] + for item in parsed: + if not isinstance(item, dict): + raise ValueError(f"{label} entries must be JSON objects.") + summaries.append( + PatchArtifactSummary( + h5_path=Path(str(item["h5_path"])).expanduser().resolve(), + num_patches=int(item["num_patches"]), + patch_size_level0=int(item["patch_size_level0"]), + patch_size=int(item["patch_size"]), + target_magnification=int(item["target_magnification"]), + wsi_path=(str(item["wsi_path"]) if item.get("wsi_path") is not None else None), + ) + ) + return tuple(summaries) + + +def read_patient_embedding_summary(h5_path: str | Path) -> PatientEmbeddingSummary: + path = Path(h5_path) + with h5py.File(path, "r") as handle: + dataset = handle.get("features") + if not isinstance(dataset, h5py.Dataset): + raise ValueError(f"{path} is missing required dataset 'features'.") + if dataset.ndim != 1: + raise ValueError(f"{path} has invalid patient embedding shape {dataset.shape}.") + + missing = [key for key in REQUIRED_PATIENT_EMBEDDING_ATTRS if key not in handle.attrs] + 
if missing: + joined = ", ".join(missing) + raise ValueError(f"{path} is missing required patient embedding metadata: {joined}") + + return PatientEmbeddingSummary( + embedding_dim=int(dataset.shape[0]), + encoder_name=str(handle.attrs["encoder_name"]).strip().lower(), + num_slides=int(handle.attrs["num_slides"]), + source_patch_encoder=( + str(handle.attrs["source_patch_encoder"]).strip().lower() + if "source_patch_encoder" in handle.attrs + else None + ), + source_manifest=( + str(handle.attrs["source_manifest"]) + if "source_manifest" in handle.attrs + else None + ), + source_slide_h5_paths=tuple( + str(Path(item).expanduser().resolve()) + for item in _decode_json_list_attr( + handle.attrs.get("source_slide_h5_paths"), + label=f"{path} source_slide_h5_paths", + ) + ), + source_slide_patch_summaries=_decode_patient_source_slide_summaries( + handle.attrs.get("source_slide_patch_summaries"), + label=f"{path} source_slide_patch_summaries", + ), + ) + + def load_patch_feature_data( h5_path: str | Path, feature_name: str, From da6a382d3f48742f4aa1cea21368b806b50e51e5 Mon Sep 17 00:00:00 2001 From: yousefkotp Date: Mon, 6 Apr 2026 19:35:33 -0400 Subject: [PATCH 5/8] docs: refresh v1.1.0 documentation --- CHANGELOG.md | 15 + README.md | 585 +++++++++++++----------- docs/commands/detect-tissue.md | 52 +++ docs/commands/encode-patient.md | 114 +++++ docs/commands/encode-slide.md | 108 +++++ docs/commands/process.md | 106 +++++ docs/commands/segment-and-get-coords.md | 94 ++++ docs/release-notes/v1.1.0.md | 19 + pyproject.toml | 4 + 9 files changed, 823 insertions(+), 274 deletions(-) create mode 100644 docs/commands/detect-tissue.md create mode 100644 docs/commands/encode-patient.md create mode 100644 docs/commands/encode-slide.md create mode 100644 docs/commands/process.md create mode 100644 docs/commands/segment-and-get-coords.md create mode 100644 docs/release-notes/v1.1.0.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b2b9d0..a98b44b 100644 --- a/CHANGELOG.md +++ 
b/CHANGELOG.md @@ -5,6 +5,21 @@ All notable changes to AtlasPatch will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.1.0] - Unreleased + +### Added + +#### Slide and Patient Encoding +- `encode-slide` command for slide embeddings appended into `slide_features/` inside the canonical per-slide H5 +- `encode-patient` command for manifest-driven patient embeddings written under `patient_features//.h5` +- Slide encoder registry and built-in slide encoders: `titan`, `prism`, `moozy` +- Patient encoder registry and built-in patient encoder: `moozy` + +#### Packaging and Release +- Optional extras for slide and patient encoder stacks: `titan`, `prism`, `moozy`, `slide-encoders`, and `patient-encoders` +- GitHub Release to PyPI trusted publishing workflow in `publish.yml` + + ## [1.0.0] - 2025-02-03 ### Added diff --git a/README.md b/README.md index b49e166..d8b10dd 100644 --- a/README.md +++ b/README.md @@ -7,10 +7,10 @@

PyPI Python + HuggingFace License

-

Project Page | Paper | @@ -20,102 +20,102 @@ ## Table of Contents - [Installation](#installation) - - [Quick Install (Recommended)](#quick-install-recommended) - - [OpenSlide Prerequisites](#openslide-prerequisites) - - [Optional Encoder Dependencies](#optional-encoder-dependencies) + - [1. Install OpenSlide](#1-install-openslide) + - [2. Install AtlasPatch](#2-install-atlaspatch) + - [3. Install SAM2](#3-install-sam2) + - [Optional extras](#optional-extras) - [Alternative Installation Methods](#alternative-installation-methods) - [Usage Guide](#usage-guide) - - [Pipeline Checkpoints](#pipeline-checkpoints) - - [A - Tissue Detection](#a-tissue-detection) - - [B - Patch Coordinate Extraction](#b-patch-coordinate-extraction) - - [C - Patch Embedding](#c-patch-embedding) - - [D - Patch Writing](#d-patch-writing) - - [Visualization Samples](#visualization-samples) - - [Process Command Arguments](#process-command-arguments) - - [Required](#required) - - [Optional](#optional) - - [Patch Layout](#patch-layout) - - [Segmentation & Extraction Performance](#segmentation--extraction-performance) - - [Feature Extraction](#feature-extraction) - - [Filtering & Quality](#filtering--quality) - - [Visualization](#visualization) - - [Run Control](#run-control) -- [Supported Formats](#supported-formats) -- [Using Extracted Data](#using-extracted-data) - - [Patch Coordinates](#patch-coordinates) - - [Feature Matrices](#feature-matrices) -- [Available Feature Extractors](#available-feature-extractors) - - [Core vision backbones on Natural Images](#core-vision-backbones-on-natural-images) - - [Medical- and Pathology-Specific Vision Encoders](#medical--and-pathology-specific-vision-encoders) - - [CLIP-like models](#clip-like-models) - - [Natural Images](#natural-images) - - [Medical- and Pathology-Specific CLIP](#medical--and-pathology-specific-clip) -- [Bring Your Own Encoder](#bring-your-own-encoder) + - [Choose a Command](#choose-a-command) + - [Quick Start](#quick-start) + - [Full 
patch pipeline](#full-patch-pipeline) + - [Slide encoding](#slide-encoding) + - [Patient encoding](#patient-encoding) + - [Supported Inputs](#supported-inputs) +- [Encoders](#encoders) + - [Available Patch Feature Extractors](#available-patch-feature-extractors) + - [Core vision backbones on Natural Images](#core-vision-backbones-on-natural-images) + - [Medical- and Pathology-Specific Vision Encoders](#medical--and-pathology-specific-vision-encoders) + - [CLIP-like models](#clip-like-models) + - [Natural Images](#natural-images) + - [Medical- and Pathology-Specific CLIP](#medical--and-pathology-specific-clip) + - [Bring Your Own Encoder](#bring-your-own-encoder) + - [Available Slide Encoders](#available-slide-encoders) + - [Available Patient Encoders](#available-patient-encoders) +- [Output Files](#output-files) + - [What AtlasPatch writes](#what-atlaspatch-writes) + - [Per-Slide H5 files](#per-slide-h5-files) + - [Patient embedding files](#patient-embedding-files) + - [Optional image outputs](#optional-image-outputs) + - [Reading the files](#reading-the-files) + - [Patch coordinates](#patch-coordinates) + - [Patch feature matrices](#patch-feature-matrices) + - [Slide embeddings](#slide-embeddings) + - [Patient embeddings](#patient-embeddings) - [SLURM job scripts](#slurm-job-scripts) - [Frequently Asked Questions (FAQ)](#frequently-asked-questions-faq) - [Feedback](#feedback) - [Citation](#citation) - [License](#license) -- [Future Updates](#future-updates) - - [Slide Encoders](#slide-encoders) ## Installation -### Quick Install (Recommended) +AtlasPatch targets Python 3.10+. -```bash -# Install base AtlasPatch -pip install atlas-patch + -# Install SAM2 (required for tissue segmentation) -pip install git+https://github.com/facebookresearch/sam2.git -``` +### 1. Install OpenSlide -> **Note:** AtlasPatch requires the OpenSlide system library for WSI processing. See [OpenSlide Prerequisites](#openslide-prerequisites) below. 
+AtlasPatch needs the OpenSlide system library before you install the Python package. -### OpenSlide Prerequisites - -Before installing AtlasPatch, you need the OpenSlide system library: - -- **Using Conda (Recommended)**: - ```bash - conda install -c conda-forge openslide - ``` +```bash +# conda +conda install -c conda-forge openslide -- **Ubuntu/Debian**: - ```bash - sudo apt-get install openslide-tools - ``` +# Ubuntu / Debian +sudo apt-get install openslide-tools -- **macOS**: - ```bash - brew install openslide - ``` +# macOS +brew install openslide +``` -- **Other systems**: Visit [OpenSlide Documentation](https://openslide.org/) +### 2. Install AtlasPatch -### Optional Encoder Dependencies +```bash +pip install atlas-patch +``` -AtlasPatch keeps model-specific dependencies out of the base install. +### 3. Install SAM2 -Use the optional extra below if you want the broader built-in patch encoder registry: +All WSI-facing AtlasPatch commands use SAM2 for tissue segmentation. ```bash -pip install "atlas-patch[patch-encoders]" +pip install git+https://github.com/facebookresearch/sam2.git ``` -Some encoders also require upstream project packages that must still be installed separately: +### Optional extras + +The default install stays lean. Install only the model stacks you need. 
+ +| Use case | Install | +| --- | --- | +| Broader built-in patch encoder registry | `pip install "atlas-patch[patch-encoders]"` | +| TITAN slide encoding | `pip install "atlas-patch[titan]"` | +| PRISM slide encoding | `pip install "atlas-patch[prism]"` | +| MOOZY slide or patient encoding | `pip install "atlas-patch[moozy]"` | +| All bundled slide encoder extras | `pip install "atlas-patch[slide-encoders]"` | +| All bundled patient encoder extras | `pip install "atlas-patch[patient-encoders]"` | + +Some patch encoders still require upstream project packages in addition to `atlas-patch[patch-encoders]`: ```bash -# For CONCH encoder (conch_v1, conch_v15) -pip install git+https://github.com/Mahmoodlab/CONCH.git +# Optional: CONCH patch encoders +pip install git+https://github.com/MahmoodLab/CONCH.git -# For MUSK encoder +# Optional: MUSK patch encoder pip install git+https://github.com/lilab-stanford/MUSK.git ``` -These installs are only needed if you plan to use those specific encoders. - ### Alternative Installation Methods

@@ -133,255 +133,113 @@ conda install -c conda-forge openslide pip install atlas-patch pip install git+https://github.com/facebookresearch/sam2.git ``` +
-Using uv (faster installs) +Using uv ```bash -# Install uv (see https://docs.astral.sh/uv/getting-started/) +# Install uv curl -LsSf https://astral.sh/uv/install.sh | sh # Create and activate environment uv venv -source .venv/bin/activate # On Windows: .venv\Scripts\activate +source .venv/bin/activate # On Windows: .venv\\Scripts\\activate # Install AtlasPatch and SAM2 uv pip install atlas-patch uv pip install git+https://github.com/facebookresearch/sam2.git ``` -
+ ## Usage Guide -AtlasPatch provides a flexible pipeline with **4 checkpoints** that you can use independently or combine based on your needs. - -### Pipeline Checkpoints - -

- AtlasPatch Pipeline Checkpoints -

- -Quick overview of the checkpoint commands: -- `detect-tissue`: runs SAM2 segmentation and writes mask overlays under `/visualization/`. -- `segment-and-get-coords`: runs segmentation + patch coordinate extraction into `/patches/.h5`. -- `process`: full pipeline (segmentation + coords + feature embeddings) in the same H5. -- `segment-and-get-coords --save-images`: same as `segment-and-get-coords`, plus patch PNGs under `/images//`. - ---- - -#### [A] Tissue Detection - -Detect and visualize tissue regions in your WSI using SAM2 segmentation. - -```bash -atlaspatch detect-tissue /path/to/slide.svs \ - --output ./output \ - --device cuda -``` - ---- - -#### [B] Patch Coordinate Extraction - -Detect tissue and extract patch coordinates without feature embedding. - -```bash -atlaspatch segment-and-get-coords /path/to/slide.svs \ - --output ./output \ - --patch-size 256 \ - --target-mag 20 \ - --device cuda -``` +### Choose a Command ---- +| Command | Use it when | Main outputs | Docs | +| --- | --- | --- | --- | +| `detect-tissue` | You only need tissue masks and overlays | `visualization/` | [docs/commands/detect-tissue.md](docs/commands/detect-tissue.md) | +| `segment-and-get-coords` | You want patch locations now and feature extraction later | `patches/.h5` | [docs/commands/segment-and-get-coords.md](docs/commands/segment-and-get-coords.md) | +| `process` | You want the full patch pipeline, including patch features | `patches/.h5`, optional `images/`, optional `visualization/` | [docs/commands/process.md](docs/commands/process.md) | +| `encode-slide` | You want slide embeddings added to each slide H5 file | `patches/.h5` with `slide_features/` | [docs/commands/encode-slide.md](docs/commands/encode-slide.md) | +| `encode-patient` | You have a CSV file that lists which slides belong to each patient | `patient_features//.h5` | [docs/commands/encode-patient.md](docs/commands/encode-patient.md) | -#### [C] Patch Embedding +### Quick Start -Run the full pipeline: Tissue 
detection, coordinate extraction, and feature embedding. +#### Full patch pipeline ```bash atlaspatch process /path/to/slide.svs \ --output ./output \ --patch-size 256 \ --target-mag 20 \ - --feature-extractors resnet50 \ + --feature-extractors uni_v2 \ --device cuda ``` ---- +This writes a per-slide H5 file at `./output/patches/slide.h5` with: -#### [D] Patch Writing +- `coords` +- `features/uni_v2` +- slide metadata in file attributes -Full pipeline with optional patch image export for visualization or downstream tasks. +#### Slide encoding ```bash -atlaspatch segment-and-get-coords /path/to/slide.svs \ +atlaspatch encode-slide /path/to/slides \ --output ./output \ - --patch-size 256 \ + --slide-encoders titan \ + --patch-size 512 \ --target-mag 20 \ - --device cuda \ - --save-images + --device cuda ``` ---- - -Pass a directory instead of a single file to process multiple WSIs; outputs land in `/patches/.h5` based on the path you provide to `--output`. - -### Visualization Samples +`encode-slide` is WSI-only. It creates or reuses the per-slide patch H5, backfills the required upstream patch features automatically, and appends slide embeddings into the same file. -Below are some examples for the output masks and overlays (original image, predicted mask, overlay, contours, grid). - -

- AtlasPatch visualization samples -

+See [Available Slide Encoders](#available-slide-encoders) for the built-in encoder list, requirements, and install extras. -Quantitative and qualitative analysis of AtlasPatch tissue detection against existing slide-preprocessing tools. +If you want multiple slide encoders in one run, they must be compatible with one patch geometry. For example, `titan` and `prism` need different patch sizes, so they should be run separately. -

- AtlasPatch method comparison -

+#### Patient encoding -Representative WSI thumbnails are shown from diverse tissue features and artifact conditions, with tissue masks predicted by thresholding methods (TIAToolbox, CLAM) and deep learning methods (pretrained "non-finetuned" SAM2 model, Trident-QC, Trident-Hest and AtlasPatch), highlighting differences in boundary fidelity, artifact suppression and handling of fragmented tissue. Tissue detection performance is also shown on the held-out test set for AtlasPatch and baseline pipelines, highlighting that AtlasPatch matches or exceeds their segmentation quality. The segmentation complexity–performance trade-off, plotting F1-score against segmentation runtime (on a random set of 100 WSIs), shows AtlasPatch achieves high performance with substantially lower wall-clock time than tile-wise detectors and heuristic pipelines, underscoring its suitability for large-scale WSI preprocessing. +Create a CSV file: -### Process Command Arguments +```csv +case_id,slide_path,mpp +case-001,/data/case-001-slide-a.svs,0.25 +case-001,/data/case-001-slide-b.svs, +case-002,/data/case-002-slide-a.svs, +``` -The `process` command is the primary entry point for most workflows. It runs the full pipeline: tissue segmentation, patch coordinate extraction, and feature embedding. You can process a single slide or an entire directory of WSIs in one command. +Then run: ```bash -atlaspatch process --output --patch-size --target-mag --feature-extractors [OPTIONS] +atlaspatch encode-patient cases.csv \ + --output ./output \ + --patient-encoders moozy \ + --patch-size 224 \ + --target-mag 20 \ + --device cuda ``` -#### Required - -| Argument | Description | -| --- | --- | -| `WSI_PATH` | Path to a single slide file or a directory containing slides. When a directory is provided, all supported formats are processed. | -| `--output`, `-o` | Root directory for results. Outputs are organized as `/patches/.h5` for coordinates and features, and `/visualization/` for overlays. 
| -| `--patch-size` | Final patch size in pixels at the target magnification (e.g., `256` for 256×256 patches). | -| `--target-mag` | Magnification level to extract patches at. Common values: `5`, `10`, `20`, `40`. The pipeline reads from the closest available pyramid level and resizes if needed. | -| `--feature-extractors` | Comma or space-separated list of encoder names from [Available Feature Extractors](#available-feature-extractors). Multiple encoders can be specified to extract several feature sets in one pass (e.g., `resnet50,uni_v2`). | - -#### Optional - -##### Patch Layout - -| Argument | Default | Description | -| --- | --- | --- | -| `--step-size` | Same as `--patch-size` | Stride between patches. Omit for non-overlapping grids. Use smaller values (e.g., `128` with `--patch-size 256`) to create 50% overlap. | - -##### Segmentation & Extraction Performance - -| Argument | Default | Description | -| --- | --- | --- | -| `--device` | `cuda` | Device for SAM2 tissue segmentation. Options: `cuda`, `cuda:0`, `cuda:1`, or `cpu`. | -| `--seg-batch-size` | `1` | Batch size for SAM2 thumbnail segmentation. Increase for faster processing if GPU memory allows. | -| `--patch-workers` | CPU count | Number of threads for patch extraction and H5 writes. | -| `--max-open-slides` | `200` | Maximum number of WSI files open simultaneously. Lower this if you hit file descriptor limits. | - -##### Feature Extraction - -| Argument | Default | Description | -| --- | --- | --- | -| `--feature-device` | Same as `--device` | Device for feature extraction. Set separately to use a different GPU than segmentation. | -| `--feature-batch-size` | `32` | Batch size for the feature extractor forward pass. Increase for faster throughput; decrease if running out of GPU memory. | -| `--feature-num-workers` | `4` | Number of DataLoader workers for loading patches during feature extraction. | -| `--feature-precision` | `float16` | Precision for feature extraction. 
Options: `float32`, `float16`, `bfloat16`. Lower precision reduces memory and can improve throughput on compatible GPUs. | - -##### Filtering & Quality - -| Argument | Default | Description | -| --- | --- | --- | -| `--fast-mode` | Enabled | Skips per-patch black/white content filtering for faster processing. Use `--no-fast-mode` to enable filtering. | -| `--tissue-thresh` | `0.0` | Minimum tissue area fraction to keep a region. Filters out tiny tissue fragments. | -| `--white-thresh` | `15` | Saturation threshold for white patch filtering (only with `--no-fast-mode`). Lower values are stricter. | -| `--black-thresh` | `50` | RGB threshold for black/dark patch filtering (only with `--no-fast-mode`). Higher values are stricter. | - -##### Visualization - -| Argument | Default | Description | -| --- | --- | --- | -| `--visualize-grids` | Off | Render patch grid overlay on slide thumbnails. | -| `--visualize-mask` | Off | Render tissue segmentation mask overlay. | -| `--visualize-contours` | Off | Render tissue contour overlay. | - -All visualization outputs are saved under `/visualization/`. - -##### Run Control - -| Argument | Default | Description | -| --- | --- | --- | -| `--save-images` | Off | Export each patch as a PNG file under `/images//`. | -| `--recursive` | Off | Walk subdirectories when `WSI_PATH` is a directory. | -| `--mpp-csv` | None | Path to a CSV file with `wsi,mpp` columns to override microns-per-pixel when slide metadata is missing or incorrect. | -| `--skip-existing` | On | Skip slides that already have an output H5 file. This is the default behavior. | -| `--force` | Off | Reprocess slides even when an output H5 file already exists. | -| `--verbose`, `-v` | Off | Enable debug logging and disable the progress bar. | -| `--write-batch` | `8192` | Number of coordinate rows to buffer before flushing to H5. Tune for RAM vs. I/O trade-off. 
| - -## Supported Formats - -AtlasPatch uses OpenSlide for WSIs and Pillow for standard images: - -- WSIs: `.svs`, `.tif`, `.tiff`, `.ndpi`, `.vms`, `.vmu`, `.scn`, `.mrxs`, `.bif`, `.dcm` -- Images: `.png`, `.jpg`, `.jpeg`, `.bmp`, `.webp`, `.gif` - -## Using Extracted Data - -`atlaspatch process` writes one HDF5 per slide under `/patches/.h5` containing coordinates and feature matrices. Coordinates and features share row order. - -### Patch Coordinates +`encode-patient` groups slides by `case_id`, creates or reuses per-slide H5 files, ensures the required patch features exist, and writes one patient embedding per case. -- Dataset: `coords` (int32, shape `(N, 5)`) with columns `(x, y, read_w, read_h, level)`. -- `x` and `y` are level-0 pixel coordinates. `read_w`, `read_h`, and `level` describe how the patch was read from the WSI. -- The level-0 footprint of each patch is stored as the `patch_size_level0` file attribute; some slide encoders use it for positional encoding (e.g., ALiBi in TITAN). +See [Available Patient Encoders](#available-patient-encoders) for the built-in encoder list, requirements, and install extras. -Example: - -```python -import h5py -import numpy as np -import openslide -from PIL import Image - -h5_path = "output/patches/sample.h5" -wsi_path = "/path/to/slide.svs" - -with h5py.File(h5_path, "r") as f: - coords = f["coords"][...] # (N, 5) int32: [x, y, read_w, read_h, level] - patch_size = int(f.attrs["patch_size"]) +### Supported Inputs -with openslide.OpenSlide(wsi_path) as wsi: - for x, y, read_w, read_h, level in coords: - img = wsi.read_region( - (int(x), int(y)), - int(level), - (int(read_w), int(read_h)), - ).convert("RGB") - if img.size != (patch_size, patch_size): - # Some slides don't have a pyramid level that matches target magnification exactly, so they have to be resized. 
- img = img.resize((patch_size, patch_size), resample=Image.BILINEAR) - patch = np.array(img) # (H, W, 3) uint8 -``` +From `atlaspatch info`: -### Feature Matrices - -- Group: `features/` inside the same HDF5. -- Each extractor is stored as `features/` (float32, shape `(N, D)`), aligned row-for-row with `coords`. -- List available feature sets with `list(f['features'].keys())`. - -```python -import h5py +- WSI formats: `.svs`, `.tif`, `.tiff`, `.ndpi`, `.vms`, `.vmu`, `.scn`, `.mrxs`, `.bif`, `.dcm` +- image formats: `.png`, `.jpg`, `.jpeg`, `.bmp`, `.webp`, `.gif` -with h5py.File("output/patches/sample.h5", "r") as f: - feat_names = list(f["features"].keys()) - resnet50_feats = f["features/resnet50"][...] # (N, 2048) float32 -``` +## Encoders -## Available Feature Extractors +### Available Patch Feature Extractors -### Core vision backbones on Natural Images +#### Core vision backbones on Natural Images | Name | Output Dim | | --- | --- | @@ -412,7 +270,7 @@ with h5py.File("output/patches/sample.h5", "r") as f: | [`dinov3_vit7b16`](https://huggingface.co/facebook/dinov3-vit7b16-pretrain-lvd1689m) ([DINOv3](https://arxiv.org/abs/2508.10104)) | 4096 | | [`dinov3_vit7b16_sat`](https://huggingface.co/facebook/dinov3-vit7b16-pretrain-sat493m) ([DINOv3](https://arxiv.org/abs/2508.10104)) | 4096 | -### Medical- and Pathology-Specific Vision Encoders +#### Medical- and Pathology-Specific Vision Encoders | Name | Output Dim | | --- | --- | @@ -449,9 +307,9 @@ with h5py.File("output/patches/sample.h5", "r") as f: > ``` > 3. 
Then you can use the encoder in your commands -### CLIP-like models +#### CLIP-like models -#### Natural Images +##### Natural Images | Name | Output Dim | | --- | --- | @@ -465,7 +323,7 @@ with h5py.File("output/patches/sample.h5", "r") as f: | `clip_vit_l_14` ([Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)) | 768 | | `clip_vit_l_14_336` ([Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)) | 768 | -#### Medical- and Pathology-Specific CLIP +##### Medical- and Pathology-Specific CLIP | Name | Output Dim | | --- | --- | @@ -477,7 +335,7 @@ with h5py.File("output/patches/sample.h5", "r") as f: | [`biomedclip`](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224) ([BiomedCLIP: a multimodal biomedical foundation model pretrained from fifteen million scientific image-text pairs](https://aka.ms/biomedclip-paper)) | 512 | | [`omiclip`](https://huggingface.co/WangGuangyuLab/Loki) ([A visual-omics foundation model to bridge histopathology with spatial transcriptomics](https://www.nature.com/articles/s41592-025-02707-1)) | 768 | -## Bring Your Own Encoder +#### Bring Your Own Encoder Add a custom encoder without touching AtlasPatch by writing a small plugin and pointing the CLI at it with `--feature-plugin /path/to/plugin.py`. The plugin must expose a `register_feature_extractors(registry, device, dtype, num_workers)` function; inside that hook call `register_custom_encoder` with a loader that knows how to load the model and run a forward pass. @@ -516,7 +374,138 @@ def register_feature_extractors(registry, device, dtype, num_workers): ) ``` -Run AtlasPatch with `--feature-plugin /path/to/plugin.py --feature-extractors my_encoder` to benchmark your encoder alongside the built-ins, multiple plugins and extractors can be added at once. 
Outputs keep the same HDF5 layout—your custom embeddings live under `features/my_encoder` (row-aligned with `coords`) next to other extractors. +Run AtlasPatch with `--feature-plugin /path/to/plugin.py --feature-extractors my_encoder` to benchmark your encoder alongside the built-ins. Multiple plugins and extractors can be added at once. Your custom embeddings will be written under `features/my_encoder`, row-aligned with `coords`, next to the built-in extractors. + +### Available Slide Encoders + +| Encoder | Embedding dim | Required patch encoder | Patch size | Model | Paper | Install | +| --- | --- | --- | --- | --- | --- | --- | +| `titan` | 768 | `conch_v15` | 512 | [MahmoodLab/TITAN](https://huggingface.co/MahmoodLab/TITAN) | [A multimodal whole-slide foundation model for pathology](https://www.nature.com/articles/s41591-025-03982-3) | `atlas-patch[titan]` or `atlas-patch[slide-encoders]` | +| `prism` | 1280 | `virchow_v1` | 224 | [paige-ai/Prism](https://huggingface.co/paige-ai/Prism) | [PRISM: A Multi-Modal Generative Foundation Model for Slide-Level Histopathology](https://arxiv.org/abs/2405.10254) | `atlas-patch[prism]` or `atlas-patch[slide-encoders]` | +| `moozy` | 768 | `lunit_vit_small_patch8_dino` | 224 | [AtlasAnalyticsLab/MOOZY](https://huggingface.co/AtlasAnalyticsLab/MOOZY) | [MOOZY: A Patient-First Foundation Model for Computational Pathology](https://arxiv.org/abs/2603.27048) | `atlas-patch[moozy]` or `atlas-patch[slide-encoders]` | + +### Available Patient Encoders + +| Encoder | Embedding dim | Required patch encoder | Patch size | Model | Paper | Install | +| --- | --- | --- | --- | --- | --- | --- | +| `moozy` | 768 | `lunit_vit_small_patch8_dino` | 224 | [AtlasAnalyticsLab/MOOZY](https://huggingface.co/AtlasAnalyticsLab/MOOZY) | [MOOZY: A Patient-First Foundation Model for Computational Pathology](https://arxiv.org/abs/2603.27048) | `atlas-patch[moozy]` or `atlas-patch[patient-encoders]` | + +## Output Files + +Everything AtlasPatch writes 
lives under the directory you pass to `--output`. + +### What AtlasPatch writes + +#### Per-Slide H5 files + +Each processed slide gets one H5 file: + +```text +/patches/.h5 +``` + +That file may contain: + +- `coords` +- `features/` +- `slide_features/` + +Rows in `features/` are aligned with `coords`. + +#### Patient embedding files + +Patient embeddings are written separately: + +```text +/patient_features//.h5 +``` + +Each patient H5 file stores the case embedding in `features`. + +#### Optional image outputs + +- patch PNGs: `/images//` +- overlays and masks: `/visualization/` + +### Reading the files + +Per-slide H5 files keep patch coordinates, patch features, and slide embeddings together in one place. + +#### Patch coordinates + +- dataset: `coords` +- shape: `(N, 5)` +- columns: `(x, y, read_w, read_h, level)` +- `x` and `y` are level-0 pixel coordinates. +- `read_w`, `read_h`, and `level` describe how the patch was read from the WSI. +- the level-0 footprint of each patch is stored as the `patch_size_level0` file attribute + +Example: + +```python +import h5py +import numpy as np +import openslide +from PIL import Image + +h5_path = "output/patches/sample.h5" +wsi_path = "/path/to/slide.svs" + +with h5py.File(h5_path, "r") as f: + coords = f["coords"][...] # (N, 5) int32: [x, y, read_w, read_h, level] + patch_size = int(f.attrs["patch_size"]) + +with openslide.OpenSlide(wsi_path) as wsi: + for x, y, read_w, read_h, level in coords: + img = wsi.read_region( + (int(x), int(y)), + int(level), + (int(read_w), int(read_h)), + ).convert("RGB") + if img.size != (patch_size, patch_size): + img = img.resize((patch_size, patch_size), resample=Image.BILINEAR) + patch = np.array(img) # (H, W, 3) uint8 +``` + +#### Patch feature matrices + +- group: `features/` +- dataset: `features/` +- shape: `(N, D)` + +Rows in every feature matrix are aligned with `coords`. 
+ +```python +import h5py + +with h5py.File("output/patches/sample.h5", "r") as f: + feature_names = list(f["features"].keys()) + resnet50_features = f["features/resnet50"][...] +``` + +#### Slide embeddings + +- group: `slide_features/` +- dataset: `slide_features/` +- shape: `(D,)` + +```python +import h5py + +with h5py.File("output/patches/sample.h5", "r") as f: + titan_embedding = f["slide_features/titan"][...] +``` + +#### Patient embeddings + +Patient embeddings are stored in separate H5 files under `patient_features//`. + +```python +import h5py + +with h5py.File("output/patient_features/moozy/case-001.h5", "r") as f: + case_embedding = f["features"][...] +``` ## SLURM job scripts @@ -532,7 +521,7 @@ We prepared ready-to-run SLURM templates under `jobs/`: - Configure `FEATURES` (comma/space list, multiple extractors are supported), `FEATURE_DEVICE`, `FEATURE_BATCH`, `FEATURE_WORKERS`, and `FEATURE_PRECISION`. - This script is intended for feature extraction; use the patch script when you need segmentation + coordinates, and run the feature script to embed one or more models into those H5 files. - Submit with `sbatch jobs/atlaspatch_features.slurm.sh`. -- Running multiple jobs: you can submit several jobs in a loop (e.g., 50 job using `for i in {1..50}; do sbatch jobs/atlaspatch_features.slurm.sh; done`). AtlasPatch uses per-slide lock files to avoid overlapping work on the same slide. +- Running multiple jobs: you can submit several jobs in a loop (for example, `for i in {1..50}; do sbatch jobs/atlaspatch_features.slurm.sh; done`). AtlasPatch uses per-slide lock files to avoid overlapping work on the same slide. ## Frequently Asked Questions (FAQ) @@ -656,6 +645,59 @@ atlaspatch process /path/to/slides --output ./output --force ``` +
+Can I run multiple slide encoders in one command? + +Yes, but only when they agree on the required patch geometry. `encode-slide` runs one patch pipeline and then appends the requested slide embeddings into the same per-slide H5 file, so all requested slide encoders in that run must agree on the patch size they need. + +For example, `titan` and `prism` should be run separately because they require different patch sizes. +
+ +
+What does encode-slide or encode-patient reuse? + +Both commands reuse existing per-slide H5 files by default. If the required patch features for the requested encoder are already present, AtlasPatch uses them directly. If the H5 file exists but the required patch feature dataset is missing, AtlasPatch runs the missing patch feature extraction step and then continues with slide or patient encoding. + +Use `--force` if you want to rebuild instead of reuse. +
+ +
+What should the encode-patient CSV file look like? + +The CSV file must contain: + +- `case_id` +- `slide_path` + +It may also contain: + +- `mpp` + +Each row links one slide to one patient. AtlasPatch groups rows by `case_id`, runs or reuses the per-slide H5 pipeline for each referenced slide, and then writes one patient embedding per patient. +
+ +
+Does encode-patient use slide embeddings? + +No. In `v1.1.0`, patient encoding uses the patch features stored in each slide H5 file. It does not read `slide_features/`. +
+ +
+Where are slide and patient embeddings written? + +Slide embeddings are written into the per-slide H5 file under: + +```text +slide_features/ +``` + +Patient embeddings are written to separate files under: + +```text +patient_features//.h5 +``` +
+ --- Have a question not covered here? Feel free to [open an issue](https://github.com/AtlasAnalyticsLab/AtlasPatch/issues/new) and ask! @@ -681,9 +723,4 @@ If you use AtlasPatch in your research, please cite our paper: ## License -AtlasPatch is released under CC-BY-NC-SA-4.0, which strictly disallows commercial use of the model weights or any derivative works. Commercialization includes selling the model, offering it as a paid service, using it inside commercial products, or distributing modified versions for commercial gain. Non-commercial research, experimentation, educational use, and use by academic or non-profit organizations is permitted under the license terms. If you need commercial rights, please contact the authors to obtain a separate commercial license. See the LICENSE file in this repository for full terms. For the complete license text and detailed terms, see the [LICENSE](./LICENSE) file in this repository. - -## Future Updates - -### Slide Encoders -- We plan to add slide-level encoders (open for extension): TITAN, PRISM, GigaPath, Madeleine. +AtlasPatch is released under [CC BY-NC-SA 4.0](LICENSE). diff --git a/docs/commands/detect-tissue.md b/docs/commands/detect-tissue.md new file mode 100644 index 0000000..8aca5d5 --- /dev/null +++ b/docs/commands/detect-tissue.md @@ -0,0 +1,52 @@ +# detect-tissue + +- [Usage Guide](#usage-guide) + - [One slide](#one-slide) + - [Directory of slides](#directory-of-slides) +- [Arguments](#arguments) +- [Outputs](#outputs) + +## Usage Guide + +`atlaspatch detect-tissue` runs tissue segmentation only. It does not write patch coordinates, patch feature datasets, slide embeddings, or patient embeddings. + +### One slide + +Use this when you only want to inspect the segmentation result for a single WSI. + +```bash +atlaspatch detect-tissue /path/to/slide.svs \ + --output ./output \ + --device cuda +``` + +### Directory of slides + +Point `WSI_PATH` at a directory to segment many slides in one run. 
Add `--recursive` if slides are nested in subdirectories. + +```bash +atlaspatch detect-tissue /path/to/slides \ + --output ./output \ + --device cuda \ + --recursive +``` + +## Arguments + +| Argument | Type | Required | Default | Description | +|----------|------|----------|---------|-------------| +| `WSI_PATH` | path | yes | - | Path to one slide file or a directory of slides. When a directory is provided, AtlasPatch scans for supported WSI extensions and uses `--recursive` to control whether subdirectories are included. | +| `--output`, `-o` | path | yes | - | Output root for the visualization outputs generated by segmentation. | +| `--device` | text | no | `cuda` | Device used for tissue segmentation. AtlasPatch accepts values such as `cuda`, `cuda:0`, and `cpu`. | +| `--seg-batch-size` | int | no | `1` | Batch size for thumbnail-level tissue segmentation. Increase this only if the segmentation device has enough memory to handle larger thumbnail batches. | +| `--recursive` | flag | no | off | Recurse into subdirectories when `WSI_PATH` is a directory. Ignored when `WSI_PATH` is a single slide file. | +| `--mpp-csv` | path | no | - | CSV file with columns `wsi,mpp` that overrides the slide microns-per-pixel metadata for selected slides. Slides are matched by stem. | +| `--verbose`, `-v` | flag | no | off | Enable debug logging. | + +## Outputs + +`atlaspatch detect-tissue` writes visualization outputs under: + +- `/visualization/` + +Use this command when you want to inspect tissue masks before running patch extraction or when you only need mask overlays and no H5 outputs. 
diff --git a/docs/commands/encode-patient.md b/docs/commands/encode-patient.md new file mode 100644 index 0000000..426de8b --- /dev/null +++ b/docs/commands/encode-patient.md @@ -0,0 +1,114 @@ +# encode-patient + +- [Usage Guide](#usage-guide) + - [CSV format](#csv-format) + - [One case with multiple slides](#one-case-with-multiple-slides) + - [Many cases in one run](#many-cases-in-one-run) +- [Arguments](#arguments) +- [Outputs](#outputs) + +## Usage Guide + +`atlaspatch encode-patient` reads a CSV file listing slides for each patient, runs or reuses the per-slide patch pipeline for every referenced slide, ensures the required upstream patch features exist, and then writes one patient embedding per case. + +Patient encoders operate on the patch features stored in each slide H5. They do not consume `slide_features/`. + +### CSV format + +The input is a CSV with one row per slide. All rows with the same `case_id` are grouped into one patient case. + +| Column | Required | Description | +|--------|----------|-------------| +| `case_id` | yes | Case identifier. AtlasPatch uses this as the output filename stem under `patient_features//`. | +| `slide_path` | yes | Path to one slide belonging to that case. Relative paths are resolved relative to the CSV file directory. | +| `mpp` | no | Optional per-slide MPP override. If omitted, AtlasPatch falls back to the slide metadata. | + +Additional validation: + +- duplicate slide paths within the same case are rejected +- invalid `case_id` values are rejected before pipeline work starts +- duplicate slide stems that would collide in `patches/.h5` are rejected + +Example CSV: + +```csv +case_id,slide_path,mpp +case_001,/data/case_001_slide_a.svs,0.25 +case_001,/data/case_001_slide_b.svs,0.25 +case_002,/data/case_002_slide_a.svs, +``` + +### One case with multiple slides + +In `v1.1.0`, AtlasPatch ships one built-in patient encoder: `moozy`. 
If a case has multiple slides, AtlasPatch builds or reuses one H5 file per slide, then aggregates all of that case's slide-level patch-feature inputs into one patient embedding per encoder. + +```bash +atlaspatch encode-patient cases.csv \ + --output ./output \ + --patient-encoders moozy \ + --patch-size 224 \ + --target-mag 20 \ + --device cuda +``` + +### Many cases in one run + +One CSV file can contain many cases. AtlasPatch groups rows by `case_id` and writes one output H5 per case under `patient_features//`. + +```bash +atlaspatch encode-patient cases.csv \ + --output ./output \ + --patient-encoders moozy \ + --patch-size 224 \ + --target-mag 20 +``` + +## Arguments + +| Argument | Type | Required | Default | Description | +|----------|------|----------|---------|-------------| +| `MANIFEST_PATH` | path | yes | - | Path to the CSV file that maps slides to patient cases. Each row names one slide, and rows are grouped by `case_id` during patient encoding. | +| `--output`, `-o` | path | yes | - | Output root for the per-slide H5 files, optional overlays or patch images, and final patient embedding files. | +| `--patient-encoders` | text | yes | - | One or more patient encoders, separated by spaces or commas. In `v1.1.0`, the built-in choice is `moozy`. Each encoder writes one file under `patient_features//`. | +| `--patch-size` | int | yes | - | Patch size, in pixels, at the requested target magnification. This must match the geometry required by the selected patient encoder set. | +| `--step-size` | int | no | same as `--patch-size` | Stride, in pixels, between adjacent patches at the target magnification when AtlasPatch needs to build or refresh per-slide H5 files. | +| `--target-mag` | int | yes | - | Target magnification used when extracting or validating the per-slide H5 files referenced by the patient cases. 
| +| `--feature-device` | text | no | same as `--device` | Device used for any upstream patch feature extraction required by the selected patient encoders. | +| `--feature-batch-size` | int | no | `32` | Batch size used while computing any missing upstream patch features. | +| `--feature-num-workers` | int | no | `4` | DataLoader worker count for upstream patch feature extraction. | +| `--feature-precision` | choice | no | `float16` | Computation precision for any missing upstream patch feature extraction. Supported values are `float32`, `float16`, and `bfloat16`. | +| `--feature-plugin` | path | no | - | Path to a Python module that registers custom patch feature extractors. This matters only if a selected patient encoder depends on a custom upstream patch encoder. | +| `--device` | text | no | `cuda` | Device used for tissue segmentation and patient encoder inference. AtlasPatch accepts values such as `cuda`, `cuda:0`, and `cpu`. | +| `--tissue-thresh` | float | no | `0.0` | Minimum tissue area fraction required for a patch to be kept while building or refreshing per-slide H5 files. | +| `--white-thresh` | int | no | `15` | Saturation threshold used by the optional white-filtering stage in `--no-fast-mode`. | +| `--black-thresh` | int | no | `50` | RGB threshold used by the optional black-filtering stage in `--no-fast-mode`. | +| `--seg-batch-size` | int | no | `1` | Batch size for thumbnail-level tissue segmentation. | +| `--write-batch` | int | no | `8192` | Number of coordinate rows buffered before writing to H5 while building or refreshing per-slide H5 files. | +| `--patch-workers` | int | no | CPU count | Number of worker threads used during patch extraction and optional patch PNG export. | +| `--max-open-slides` | int | no | `200` | Upper bound on how many slides AtlasPatch keeps open across segmentation and extraction. 
| +| `--fast-mode / --no-fast-mode` | flag | no | `--fast-mode` | `--fast-mode` skips per-patch black and white filtering after segmentation. Use `--no-fast-mode` if you want that extra filtering pass. | +| `--save-images` | flag | no | off | Save extracted patches as PNGs under `images//` while building or refreshing per-slide H5 files. | +| `--visualize-grids` | flag | no | off | Save patch-grid overlays under `visualization/`. | +| `--visualize-mask` | flag | no | off | Save tissue-mask overlays under `visualization/`. | +| `--visualize-contours` | flag | no | off | Save contour overlays under `visualization/`. | +| `--skip-existing / --force` | flag | no | `--skip-existing` | Reuse existing per-slide H5 files and existing patient embedding files when their saved metadata still matches the current source H5 files. Use `--force` to rebuild and overwrite them. | +| `--verbose`, `-v` | flag | no | off | Enable debug logging. | + +## Outputs + +`atlaspatch encode-patient` writes or reuses per-slide H5 files under: + +- `/patches/.h5` + +Patient embeddings are written as separate files under: + +- `/patient_features//.h5` + +Important constraints: + +- Patient encoders consume patch features from the per-slide H5 files, not slide embeddings. +- AtlasPatch resolves required upstream patch encoders automatically. You do not pass `--feature-extractors` directly. +- The built-in MOOZY path uses the upstream public Python API. +- MOOZY's public API cannot force CPU when CUDA is visible. On a GPU-visible host, use `--device cuda` or run in a CPU-only environment if you need CPU inference. 
+ +More detail: [../../README.md#available-patient-encoders](../../README.md#available-patient-encoders) diff --git a/docs/commands/encode-slide.md b/docs/commands/encode-slide.md new file mode 100644 index 0000000..339f457 --- /dev/null +++ b/docs/commands/encode-slide.md @@ -0,0 +1,108 @@ +# encode-slide + +- [Usage Guide](#usage-guide) + - [One slide](#one-slide) + - [Directory of slides](#directory-of-slides) + - [Multiple slide encoders in one run](#multiple-slide-encoders-in-one-run) +- [Arguments](#arguments) +- [Outputs](#outputs) + +## Usage Guide + +`atlaspatch encode-slide` runs the full WSI pipeline for each requested slide, reuses existing per-slide H5 files when possible, ensures the required upstream patch features exist, and then appends slide embeddings into the same H5. + +This command is WSI-only. You pass slide paths, not precomputed H5 files. + +### One slide + +Use this when you want a slide embedding for one WSI and AtlasPatch should manage the upstream patch pipeline for you. + +```bash +atlaspatch encode-slide /path/to/slide.svs \ + --output ./output \ + --slide-encoders titan \ + --patch-size 512 \ + --target-mag 20 \ + --device cuda +``` + +### Directory of slides + +Point `WSI_PATH` at a directory to encode many slides in one run. Add `--recursive` if slides are nested in subdirectories. + +```bash +atlaspatch encode-slide /path/to/slides \ + --output ./output \ + --slide-encoders prism \ + --patch-size 224 \ + --target-mag 20 \ + --recursive +``` + +### Multiple slide encoders in one run + +AtlasPatch can run multiple slide encoders in one pass as long as they depend on the same upstream patch geometry. For example, `prism` and `moozy` can share a run because both use 224-pixel patches, while `titan` must run separately because it expects 512-pixel patches. 
+ +```bash +atlaspatch encode-slide /path/to/slides \ + --output ./output \ + --slide-encoders prism,moozy \ + --patch-size 224 \ + --target-mag 20 +``` + +## Arguments + +| Argument | Type | Required | Default | Description | +|----------|------|----------|---------|-------------| +| `WSI_PATH` | path | yes | - | Path to one slide file or a directory of slides. When a directory is provided, AtlasPatch scans for supported WSI extensions and uses `--recursive` to control whether subdirectories are included. | +| `--output`, `-o` | path | yes | - | Output root for the per-slide H5 files and any optional overlays or patch images generated while building or refreshing those H5 files. | +| `--slide-encoders` | text | yes | - | One or more slide encoders, separated by spaces or commas. Each encoder writes one dataset under `slide_features/` inside the per-slide H5. | +| `--patch-size` | int | yes | - | Patch size, in pixels, at the requested target magnification. This must match the geometry required by the selected slide encoder set. | +| `--step-size` | int | no | same as `--patch-size` | Stride, in pixels, between adjacent patches at the target magnification. Use a smaller value than `--patch-size` if you want overlapping patches in the per-slide H5. | +| `--target-mag` | int | yes | - | Target magnification used when extracting or validating the per-slide H5. AtlasPatch records this in the H5 metadata and uses it to determine whether existing H5 files are reusable. | +| `--feature-device` | text | no | same as `--device` | Device used for any upstream patch feature extraction required by the selected slide encoders. | +| `--feature-batch-size` | int | no | `32` | Batch size used while computing any missing upstream patch features. | +| `--feature-num-workers` | int | no | `4` | DataLoader worker count for upstream patch feature extraction. | +| `--feature-precision` | choice | no | `float16` | Computation precision for any missing upstream patch feature extraction. 
Supported values are `float32`, `float16`, and `bfloat16`. | +| `--feature-plugin` | path | no | - | Path to a Python module that registers custom patch feature extractors. This matters only if a selected slide encoder depends on a custom upstream patch encoder. | +| `--device` | text | no | `cuda` | Device used for tissue segmentation and slide encoder inference. AtlasPatch accepts values such as `cuda`, `cuda:0`, and `cpu`. | +| `--tissue-thresh` | float | no | `0.0` | Minimum tissue area fraction required for a patch to be kept while building or refreshing the per-slide H5. | +| `--white-thresh` | int | no | `15` | Saturation threshold used by the optional white-filtering stage in `--no-fast-mode`. | +| `--black-thresh` | int | no | `50` | RGB threshold used by the optional black-filtering stage in `--no-fast-mode`. | +| `--seg-batch-size` | int | no | `1` | Batch size for thumbnail-level tissue segmentation. | +| `--write-batch` | int | no | `8192` | Number of coordinate rows buffered before writing to H5 while building or refreshing the per-slide H5. | +| `--patch-workers` | int | no | CPU count | Number of worker threads used during patch extraction and optional patch PNG export. | +| `--max-open-slides` | int | no | `200` | Upper bound on how many slides AtlasPatch keeps open across segmentation and extraction. | +| `--fast-mode / --no-fast-mode` | flag | no | `--fast-mode` | `--fast-mode` skips per-patch black and white filtering after segmentation. Use `--no-fast-mode` if you want that extra filtering pass. | +| `--save-images` | flag | no | off | Save extracted patches as PNGs under `images//` while building or refreshing the per-slide H5. | +| `--visualize-grids` | flag | no | off | Save patch-grid overlays under `visualization/`. | +| `--visualize-mask` | flag | no | off | Save tissue-mask overlays under `visualization/`. | +| `--visualize-contours` | flag | no | off | Save contour overlays under `visualization/`. 
| +| `--skip-existing / --force` | flag | no | `--skip-existing` | Reuse existing H5 files and existing slide embeddings when their saved metadata still matches the current H5 file. Use `--force` to rebuild and overwrite them. | +| `--recursive` | flag | no | off | Recurse into subdirectories when `WSI_PATH` is a directory. Ignored when `WSI_PATH` is a single slide file. | +| `--mpp-csv` | path | no | - | CSV file with columns `wsi,mpp` that overrides the slide microns-per-pixel metadata for selected slides. Slides are matched by stem. | +| `--verbose`, `-v` | flag | no | off | Enable debug logging. | + +## Outputs + +`atlaspatch encode-slide` writes or reuses the per-slide H5: + +- `/patches/.h5` + +Slide embeddings are appended inside that H5 under: + +- `slide_features/` + +Optional outputs: + +- patch PNGs under `/images//` +- overlays under `/visualization/` + +Important constraints: + +- `encode-slide` resolves required upstream patch encoders automatically. You do not pass `--feature-extractors` directly. +- Existing slide embeddings are reused only if their saved metadata still matches the current H5 file. +- Slide encoders depend on patch features in the per-slide H5, not on raw patch pixels after extraction time. + +More detail: [../../README.md#available-slide-encoders](../../README.md#available-slide-encoders) diff --git a/docs/commands/process.md b/docs/commands/process.md new file mode 100644 index 0000000..ec67835 --- /dev/null +++ b/docs/commands/process.md @@ -0,0 +1,106 @@ +# process + +- [Usage Guide](#usage-guide) + - [One slide](#one-slide) + - [Directory of slides](#directory-of-slides) + - [Separate segmentation and feature devices](#separate-segmentation-and-feature-devices) +- [Arguments](#arguments) +- [Outputs](#outputs) + +## Usage Guide + +`atlaspatch process` runs the full per-slide patch pipeline: tissue segmentation, patch coordinate extraction, patch feature extraction, and any optional image or overlay exports you enable. 
+ +### One slide + +Use this when you want one H5 file for one slide. + +```bash +atlaspatch process /path/to/slide.svs \ + --output ./output \ + --patch-size 256 \ + --target-mag 20 \ + --feature-extractors uni_v2 \ + --device cuda +``` + +### Directory of slides + +Point `WSI_PATH` at a directory to process many slides in one run. Add `--recursive` if slides are nested in subdirectories. + +```bash +atlaspatch process /path/to/slides \ + --output ./output \ + --patch-size 256 \ + --target-mag 20 \ + --feature-extractors resnet50,uni_v2 \ + --recursive +``` + +### Separate segmentation and feature devices + +Segmentation and feature extraction can run on different devices. This is useful when one GPU is reserved for SAM2 and another for patch encoders. + +```bash +atlaspatch process /path/to/slides \ + --output ./output \ + --patch-size 256 \ + --target-mag 20 \ + --feature-extractors virchow_v1 \ + --device cuda:0 \ + --feature-device cuda:1 +``` + +## Arguments + +| Argument | Type | Required | Default | Description | +|----------|------|----------|---------|-------------| +| `WSI_PATH` | path | yes | - | Path to one slide file or a directory of slides. When a directory is provided, AtlasPatch scans that directory for supported WSI extensions and uses `--recursive` to control whether subdirectories are included. | +| `--output`, `-o` | path | yes | - | Output root for all generated outputs. H5 files are written under `patches/`, optional patch PNGs under `images/`, and optional overlays under `visualization/`. | +| `--feature-extractors` | text | yes | - | One or more patch feature extractors, separated by spaces or commas. Each extractor writes one dataset under `features/` inside the per-slide H5. | +| `--patch-size` | int | yes | - | Patch size, in pixels, at the requested target magnification. This controls both the patch grid and the expected geometry for any downstream encoders that reuse the generated H5. 
| 
+| `--step-size` | int | no | same as `--patch-size` | Stride, in pixels, between adjacent patch origins at the target magnification. Use a smaller value than `--patch-size` if you want overlapping patches. |
+| `--target-mag` | int | yes | - | Target magnification used when reading patches from the slide pyramid. AtlasPatch records this in the H5 metadata and uses it later to validate reuse of the artifact. |
+| `--feature-device` | text | no | same as `--device` | Device used for patch feature extraction. This can differ from `--device` when segmentation and embedding should run on different devices. |
+| `--feature-batch-size` | int | no | `32` | Batch size used when embedding extracted patches. Larger values improve throughput but increase memory use. |
+| `--feature-num-workers` | int | no | `4` | DataLoader worker count for patch feature extraction. Increase this if patch embedding becomes input-bound rather than model-bound. |
+| `--feature-precision` | choice | no | `float16` | Computation precision for patch feature extraction. Supported values are `float32`, `float16`, and `bfloat16`. |
+| `--feature-plugin` | path | no | - | Path to a Python module that registers custom patch feature extractors. Use this when extending AtlasPatch beyond the built-in registry. |
+| `--device` | text | no | `cuda` | Device used for tissue segmentation. AtlasPatch accepts values such as `cuda`, `cuda:0`, and `cpu`. |
+| `--tissue-thresh` | float | no | `0.0` | Minimum tissue area fraction required for a patch to be kept after segmentation. Increase this to drop patches with too little tissue. |
+| `--white-thresh` | int | no | `15` | Saturation threshold used by the optional white-filtering stage in `--no-fast-mode`. Higher values mark more bright background as discardable. |
+| `--black-thresh` | int | no | `50` | RGB threshold used by the optional black-filtering stage in `--no-fast-mode`. Higher values mark more dark regions as discardable. 
| +| `--seg-batch-size` | int | no | `1` | Batch size for thumbnail-level tissue segmentation. This affects the SAM2 inference stage, not patch embedding. | +| `--write-batch` | int | no | `8192` | Number of coordinate rows buffered before writing to H5. Larger values reduce write frequency but increase transient memory use. | +| `--patch-workers` | int | no | CPU count | Number of worker threads used while extracting patch coordinates and, if enabled, saving patch PNGs. | +| `--max-open-slides` | int | no | `200` | Upper bound on how many slides AtlasPatch keeps open across segmentation and extraction. Reduce this if the host has strict file-handle limits. | +| `--fast-mode / --no-fast-mode` | flag | no | `--fast-mode` | `--fast-mode` skips per-patch black and white filtering after segmentation. Use `--no-fast-mode` if you want the additional filtering pass. | +| `--save-images` | flag | no | off | Save each extracted patch as a PNG under `images//`. This is optional and is not required for H5-based downstream workflows. | +| `--visualize-grids` | flag | no | off | Save a patch-grid overlay for each processed slide under `visualization/`. | +| `--visualize-mask` | flag | no | off | Save the predicted tissue mask overlay for each processed slide under `visualization/`. | +| `--visualize-contours` | flag | no | off | Save the contour overlay used during patch extraction under `visualization/`. | +| `--skip-existing / --force` | flag | no | `--skip-existing` | Reuse existing H5 outputs by default. Use `--force` to rebuild them even when the output files already exist. | +| `--recursive` | flag | no | off | Recurse into subdirectories when `WSI_PATH` is a directory. Ignored when `WSI_PATH` is a single slide file. | +| `--mpp-csv` | path | no | - | CSV file with columns `wsi,mpp` that overrides the slide microns-per-pixel metadata for selected slides. Slides are matched by stem. | +| `--verbose`, `-v` | flag | no | off | Enable debug logging. 
| + +## Outputs + +`atlaspatch process` writes one H5 file per slide: + +- `/patches/.h5` + +That H5 contains: + +- `coords` +- `features/` + +Optional outputs: + +- patch PNGs under `/images//` +- overlays under `/visualization/` + +More detail: + +- [../../README.md#using-extracted-data](../../README.md#using-extracted-data) +- [../../README.md#available-patch-feature-extractors](../../README.md#available-patch-feature-extractors) diff --git a/docs/commands/segment-and-get-coords.md b/docs/commands/segment-and-get-coords.md new file mode 100644 index 0000000..7f1c44f --- /dev/null +++ b/docs/commands/segment-and-get-coords.md @@ -0,0 +1,94 @@ +# segment-and-get-coords + +- [Usage Guide](#usage-guide) + - [One slide](#one-slide) + - [Directory of slides](#directory-of-slides) + - [Coordinates plus overlays](#coordinates-plus-overlays) +- [Arguments](#arguments) +- [Outputs](#outputs) + +## Usage Guide + +`atlaspatch segment-and-get-coords` runs tissue segmentation and patch coordinate extraction, then writes the per-slide H5 without patch feature matrices. + +Use this command when you want AtlasPatch's coordinates and metadata now, but you want to defer feature extraction to a later `process`, `encode-slide`, or `encode-patient` run. + +### One slide + +```bash +atlaspatch segment-and-get-coords /path/to/slide.svs \ + --output ./output \ + --patch-size 256 \ + --target-mag 20 \ + --device cuda +``` + +### Directory of slides + +Point `WSI_PATH` at a directory to patchify many slides in one run. Add `--recursive` if slides are nested in subdirectories. + +```bash +atlaspatch segment-and-get-coords /path/to/slides \ + --output ./output \ + --patch-size 256 \ + --target-mag 20 \ + --recursive +``` + +### Coordinates plus overlays + +Add visualization flags when you want to inspect the extracted grid and segmentation outputs alongside the H5 coordinates. 
+ +```bash +atlaspatch segment-and-get-coords /path/to/slide.svs \ + --output ./output \ + --patch-size 256 \ + --target-mag 20 \ + --visualize-grids \ + --visualize-mask \ + --visualize-contours +``` + +## Arguments + +| Argument | Type | Required | Default | Description | +|----------|------|----------|---------|-------------| +| `WSI_PATH` | path | yes | - | Path to one slide file or a directory of slides. When a directory is provided, AtlasPatch scans for supported WSI extensions and uses `--recursive` to control whether subdirectories are included. | +| `--output`, `-o` | path | yes | - | Output root for the H5 files and any optional overlays or patch images generated during patch extraction. | +| `--patch-size` | int | yes | - | Patch size, in pixels, at the requested target magnification. This controls the extracted grid written into the H5 file. | +| `--step-size` | int | no | same as `--patch-size` | Stride, in pixels, between adjacent patches at the target magnification. Use a smaller value than `--patch-size` if you want overlapping coordinates. | +| `--target-mag` | int | yes | - | Target magnification used when reading patches from the slide pyramid. AtlasPatch records this in the H5 metadata. | +| `--device` | text | no | `cuda` | Device used for tissue segmentation. AtlasPatch accepts values such as `cuda`, `cuda:0`, and `cpu`. | +| `--tissue-thresh` | float | no | `0.0` | Minimum tissue area fraction required for a patch to be kept after segmentation. | +| `--white-thresh` | int | no | `15` | Saturation threshold used by the optional white-filtering stage in `--no-fast-mode`. | +| `--black-thresh` | int | no | `50` | RGB threshold used by the optional black-filtering stage in `--no-fast-mode`. | +| `--seg-batch-size` | int | no | `1` | Batch size for thumbnail-level tissue segmentation. | +| `--write-batch` | int | no | `8192` | Number of coordinate rows buffered before writing to H5. 
Larger values reduce write frequency but increase transient memory use. | +| `--patch-workers` | int | no | CPU count | Number of worker threads used during patch extraction and optional patch PNG export. | +| `--max-open-slides` | int | no | `200` | Upper bound on how many slides AtlasPatch keeps open across segmentation and extraction. | +| `--fast-mode / --no-fast-mode` | flag | no | `--fast-mode` | `--fast-mode` skips per-patch black and white filtering after segmentation. Use `--no-fast-mode` if you want that extra filtering pass. | +| `--save-images` | flag | no | off | Save extracted patches as PNGs under `images//`. This is optional and is not required for later H5-based feature extraction. | +| `--visualize-grids` | flag | no | off | Save patch-grid overlays under `visualization/`. | +| `--visualize-mask` | flag | no | off | Save tissue-mask overlays under `visualization/`. | +| `--visualize-contours` | flag | no | off | Save contour overlays under `visualization/`. | +| `--skip-existing / --force` | flag | no | `--skip-existing` | Reuse existing H5 outputs by default. Use `--force` to rebuild them even when the output files already exist. | +| `--recursive` | flag | no | off | Recurse into subdirectories when `WSI_PATH` is a directory. Ignored when `WSI_PATH` is a single slide file. | +| `--mpp-csv` | path | no | - | CSV file with columns `wsi,mpp` that overrides the slide microns-per-pixel metadata for selected slides. Slides are matched by stem. | +| `--verbose`, `-v` | flag | no | off | Enable debug logging. | + +## Outputs + +`atlaspatch segment-and-get-coords` writes one H5 file per slide: + +- `/patches/.h5` + +That H5 contains slide metadata and: + +- `coords` + +It does not contain `features/` datasets unless you later run `process`, `encode-slide`, or `encode-patient`. 
+ +Optional outputs: + +- patch PNGs under `/images//` +- overlays under `/visualization/` diff --git a/docs/release-notes/v1.1.0.md b/docs/release-notes/v1.1.0.md new file mode 100644 index 0000000..9e1d346 --- /dev/null +++ b/docs/release-notes/v1.1.0.md @@ -0,0 +1,19 @@ +# AtlasPatch v1.1.0 Release Notes + +- Added `encode-slide`. +- Added `encode-patient`. +- `process` remains patch-only. +- Added built-in slide encoders: `titan`, `prism`, `moozy`. +- Added built-in patient encoder: `moozy`. +- Patch features remain in `patches/.h5` under `features/`. +- Slide embeddings are written to `slide_features/` inside the per-slide H5. +- Patient embeddings are written to `patient_features//.h5`. +- `encode-slide` reuses existing slide H5 files and computes missing required patch features when needed. +- `encode-patient` reuses existing slide H5 files and computes missing required patch features when needed. +- Added optional install extras: + - `atlas-patch[patch-encoders]` + - `atlas-patch[titan]` + - `atlas-patch[prism]` + - `atlas-patch[moozy]` + - `atlas-patch[slide-encoders]` + - `atlas-patch[patient-encoders]` diff --git a/pyproject.toml b/pyproject.toml index 4cffea0..87242d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,10 @@ slide-encoders = [ "sacremoses>=0.1.1", "moozy>=0.1.0", ] +patient-encoders = [ + "timm>=0.9.0", + "moozy>=0.1.0", +] titan = [ "einops>=0.8.0", "einops-exts>=0.0.4", From 47c1d0342e361807337a5124672ccdba81c38953 Mon Sep 17 00:00:00 2001 From: yousefkotp Date: Mon, 6 Apr 2026 20:43:08 -0400 Subject: [PATCH 6/8] fix: restore titan and prism slide encoding --- atlas_patch/models/common.py | 7 +++++++ atlas_patch/models/slide/prism.py | 32 ++++++++++++++++++------------- atlas_patch/models/slide/titan.py | 16 +++++++++++----- atlas_patch/utils/feature_h5.py | 2 +- atlas_patch/utils/hf.py | 13 +++++++++++++ pyproject.toml | 3 +++ 6 files changed, 54 insertions(+), 19 deletions(-) diff --git a/atlas_patch/models/common.py 
b/atlas_patch/models/common.py index 0e977ad..8e43c74 100644 --- a/atlas_patch/models/common.py +++ b/atlas_patch/models/common.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from contextlib import nullcontext import numpy as np import torch @@ -18,6 +19,12 @@ def resolve_model_device(device: str | torch.device) -> torch.device: return resolved +def model_autocast(device: torch.device, dtype: torch.dtype): + if device.type != "cuda" or dtype == torch.float32: + return nullcontext() + return torch.autocast(device_type=device.type, dtype=dtype) + + def coerce_model_embedding( output: torch.Tensor | np.ndarray, *, diff --git a/atlas_patch/models/slide/prism.py b/atlas_patch/models/slide/prism.py index 4d50818..4a51371 100644 --- a/atlas_patch/models/slide/prism.py +++ b/atlas_patch/models/slide/prism.py @@ -6,9 +6,10 @@ import numpy as np import torch -from atlas_patch.models.common import coerce_model_embedding, resolve_model_device +from atlas_patch.models.common import coerce_model_embedding, model_autocast, resolve_model_device from atlas_patch.models.slide.base import SlideEncoder, SlideEncoderSpec from atlas_patch.models.slide.registry import SlideEncoderRegistry +from atlas_patch.utils.hf import download_hf_file, load_remote_class from atlas_patch.utils.feature_h5 import load_patch_feature_data _MODEL_ID = "paige-ai/Prism" @@ -22,16 +23,27 @@ def _load_prism_model(*, device: torch.device, dtype: torch.dtype): try: import environs # noqa: F401 import sacremoses # noqa: F401 - from transformers import AutoModel + from safetensors.torch import load_file except ModuleNotFoundError as exc: raise RuntimeError( "PRISM requires optional slide-encoder dependencies. " "Install `atlas-patch[prism]` or `atlas-patch[slide-encoders]`." 
) from exc - model = AutoModel.from_pretrained(_MODEL_ID, trust_remote_code=True) - if hasattr(model, "text_decoder"): - model.text_decoder = None + config_class = load_remote_class(_MODEL_ID, "configuring_prism.PrismConfig") + model_class = load_remote_class(_MODEL_ID, "modeling_prism.Prism") + config = config_class.from_json_file(str(download_hf_file(_MODEL_ID, "config.json"))) + config.tie_word_embeddings = False + config.biogpt_config.tie_word_embeddings = False + + model = model_class(config) + state_dict = load_file(str(download_hf_file(_MODEL_ID, "model.safetensors"))) + # PRISM shares repeated layer weights, so the checkpoint only stores one copy per shared block. + _, unexpected = model.load_state_dict(state_dict, strict=False) + if unexpected: + joined = ", ".join(sorted(unexpected)) + raise RuntimeError(f"PRISM checkpoint contains unexpected parameter keys: {joined}") + model.text_decoder = None return model.to(device=device, dtype=dtype).eval() @@ -63,15 +75,9 @@ def encode_slide(self, patch_h5_path: Path) -> np.ndarray: features = torch.from_numpy(patch_data.features).unsqueeze(0).to( device=self.device, - dtype=self.dtype, ) - tile_mask = torch.ones( - (1, patch_data.num_patches), - device=self.device, - dtype=torch.long, - ) - with torch.inference_mode(): - output = self.model.slide_representations(features, tile_mask=tile_mask) + with torch.inference_mode(), model_autocast(self.device, self.dtype): + output = self.model.slide_representations(features) if not isinstance(output, dict) or "image_embedding" not in output: raise ValueError("PRISM did not return an 'image_embedding' entry.") return coerce_model_embedding( diff --git a/atlas_patch/models/slide/titan.py b/atlas_patch/models/slide/titan.py index 61f5d97..171cfc4 100644 --- a/atlas_patch/models/slide/titan.py +++ b/atlas_patch/models/slide/titan.py @@ -5,9 +5,10 @@ import numpy as np import torch -from atlas_patch.models.common import coerce_model_embedding, resolve_model_device +from 
atlas_patch.models.common import coerce_model_embedding, model_autocast, resolve_model_device from atlas_patch.models.slide.base import SlideEncoder, SlideEncoderSpec from atlas_patch.models.slide.registry import SlideEncoderRegistry +from atlas_patch.utils.hf import download_hf_file, load_remote_class from atlas_patch.utils.feature_h5 import load_patch_feature_data _MODEL_ID = "MahmoodLab/TITAN" @@ -16,14 +17,20 @@ def _load_titan_model(*, device: torch.device, dtype: torch.dtype): try: - from transformers import AutoModel + from safetensors.torch import load_file + + config_class = load_remote_class(_MODEL_ID, "configuration_titan.TitanConfig") + model_class = load_remote_class(_MODEL_ID, "modeling_titan.Titan") except ModuleNotFoundError as exc: raise RuntimeError( "TITAN requires optional slide-encoder dependencies. " "Install `atlas-patch[titan]` or `atlas-patch[slide-encoders]`." ) from exc - model = AutoModel.from_pretrained(_MODEL_ID, trust_remote_code=True) + config = config_class.from_json_file(str(download_hf_file(_MODEL_ID, "config.json"))) + model = model_class(config) + state_dict = load_file(str(download_hf_file(_MODEL_ID, "model.safetensors"))) + model.load_state_dict(state_dict, strict=True) return model.to(device=device, dtype=dtype).eval() @@ -55,13 +62,12 @@ def encode_slide(self, patch_h5_path: Path) -> np.ndarray: features = torch.from_numpy(patch_data.features).unsqueeze(0).to( device=self.device, - dtype=self.dtype, ) coords = torch.from_numpy(patch_data.coords[:, :2]).unsqueeze(0).to( device=self.device, dtype=torch.int64, ) - with torch.inference_mode(): + with torch.inference_mode(), model_autocast(self.device, self.dtype): embedding = self.model.encode_slide_from_patch_features( features, coords, diff --git a/atlas_patch/utils/feature_h5.py b/atlas_patch/utils/feature_h5.py index c792eab..ac855ef 100644 --- a/atlas_patch/utils/feature_h5.py +++ b/atlas_patch/utils/feature_h5.py @@ -302,7 +302,7 @@ def load_patch_feature_data( 
feature_name=feature_key, dataset_key=dataset_key, features=np.asarray(features, dtype=np.float32), - coords=np.asarray(coords, dtype=np.int64), + coords=np.asarray(coords), patch_size_level0=attrs["patch_size_level0"], patch_size=attrs["patch_size"], target_magnification=attrs["target_magnification"], diff --git a/atlas_patch/utils/hf.py b/atlas_patch/utils/hf.py index d136282..9060380 100644 --- a/atlas_patch/utils/hf.py +++ b/atlas_patch/utils/hf.py @@ -1,6 +1,7 @@ from __future__ import annotations import importlib.util +from pathlib import Path from types import ModuleType from huggingface_hub import hf_hub_download @@ -15,3 +16,15 @@ def import_module_from_hf(repo_id: str, filename: str) -> ModuleType: module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module + + +def download_hf_file(repo_id: str, filename: str) -> Path: + """Download a file from HuggingFace Hub and return its local path.""" + return Path(hf_hub_download(repo_id, filename=filename)) + + +def load_remote_class(repo_id: str, class_reference: str): + """Load a trusted remote-code class from a HuggingFace model repo.""" + from transformers.dynamic_module_utils import get_class_from_dynamic_module + + return get_class_from_dynamic_module(class_reference, repo_id) diff --git a/pyproject.toml b/pyproject.toml index 87242d1..c1f0c0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ slide-encoders = [ "fairscale>=0.4.0", "gdown>=5.2.0", "open-clip-torch>=2.24.0", + "safetensors>=0.4.0", "sentencepiece>=0.2.0", "timm>=0.9.0", "transformers>=4.41.0", @@ -77,10 +78,12 @@ titan = [ "fairscale>=0.4.0", "gdown>=5.2.0", "open-clip-torch>=2.24.0", + "safetensors>=0.4.0", "sentencepiece>=0.2.0", "transformers>=4.41.0", ] prism = [ + "safetensors>=0.4.0", "timm>=0.9.0", "transformers>=4.41.0", "environs>=11.0.0", From 3ff1aba870bde606ba0caa2f75f43abff307c260 Mon Sep 17 00:00:00 2001 From: yousefkotp Date: Mon, 6 Apr 2026 21:35:09 -0400 Subject: [PATCH 7/8] 
perf: short-circuit skip-existing reuse --- atlas_patch/cli.py | 108 +++++++++++------ atlas_patch/orchestration/runner.py | 134 +++++++++++++--------- atlas_patch/services/feature_embedding.py | 2 + atlas_patch/services/patient_embedding.py | 99 +++++++++------- atlas_patch/services/slide_embedding.py | 113 +++++++++--------- 5 files changed, 271 insertions(+), 185 deletions(-) diff --git a/atlas_patch/cli.py b/atlas_patch/cli.py index 89cde73..efa2613 100644 --- a/atlas_patch/cli.py +++ b/atlas_patch/cli.py @@ -449,46 +449,84 @@ def _run_pipeline_from_config( configure_logging(verbose) - from atlas_patch.orchestration.runner import ProcessingRunner - from atlas_patch.services.extraction import PatchExtractionService - from atlas_patch.services.feature_embedding import PatchFeatureEmbeddingService - from atlas_patch.services.mpp import CSVMPPResolver - from atlas_patch.services.segmentation import SAM2SegmentationService - from atlas_patch.services.visualization import DefaultVisualizationService - from atlas_patch.services.wsi_loader import DefaultWSILoader + from atlas_patch.orchestration.runner import ProcessingRunner, classify_existing_slide_output + + if slides is None: + slides = [ + Slide(path=Path(path)) + for path in get_wsi_files( + str(app_cfg.processing.input_path), + recursive=app_cfg.processing.recursive, + ) + ] + else: + slides = list(slides) - segmentation_service = SAM2SegmentationService(app_cfg.segmentation) - extractor_service = PatchExtractionService(app_cfg.extraction, app_cfg.output) - visualizer_service = None - if ( - app_cfg.output.visualize_grids - or app_cfg.output.visualize_mask - or app_cfg.output.visualize_contours - ): - visualizer_service = DefaultVisualizationService( - app_cfg.output, - app_cfg.extraction, - app_cfg.visualization, + if not slides: + logging.getLogger("atlas_patch.runner").warning("No slides found to process.") + return [], [] + + preflight_results: list = [] + slides_to_process: list[Slide] = [] + for slide in 
slides: + decision, existing_result = classify_existing_slide_output(app_cfg, slide) + if decision == "skip": + continue + if decision == "reuse" and existing_result is not None: + preflight_results.append(existing_result) + continue + slides_to_process.append(slide) + + results = list(preflight_results) + failures: list = [] + wsi_loader = None + + if slides_to_process: + from atlas_patch.services.extraction import PatchExtractionService + from atlas_patch.services.mpp import CSVMPPResolver + from atlas_patch.services.segmentation import SAM2SegmentationService + from atlas_patch.services.visualization import DefaultVisualizationService + from atlas_patch.services.wsi_loader import DefaultWSILoader + + wsi_loader = DefaultWSILoader() + segmentation_service = SAM2SegmentationService(app_cfg.segmentation) + extractor_service = PatchExtractionService(app_cfg.extraction, app_cfg.output) + visualizer_service = None + if ( + app_cfg.output.visualize_grids + or app_cfg.output.visualize_mask + or app_cfg.output.visualize_contours + ): + visualizer_service = DefaultVisualizationService( + app_cfg.output, + app_cfg.extraction, + app_cfg.visualization, + ) + + mpp_resolver = CSVMPPResolver(app_cfg.processing.mpp_csv) + runner = ProcessingRunner( + config=app_cfg, + segmentation=segmentation_service, + extractor=extractor_service, + visualizer=visualizer_service, + mpp_resolver=mpp_resolver, + wsi_loader=wsi_loader, + show_progress=not verbose, ) - mpp_resolver = CSVMPPResolver(app_cfg.processing.mpp_csv) - wsi_loader = DefaultWSILoader() - runner = ProcessingRunner( - config=app_cfg, - segmentation=segmentation_service, - extractor=extractor_service, - visualizer=visualizer_service, - mpp_resolver=mpp_resolver, - wsi_loader=wsi_loader, - show_progress=not verbose, - ) + try: + run_results, run_failures = runner.run(slides=slides_to_process) + results.extend(run_results) + failures.extend(run_failures) + finally: + segmentation_service.close() - try: - results, failures = 
runner.run(slides=slides) - finally: - segmentation_service.close() + if app_cfg.features is not None and results: + from atlas_patch.services.feature_embedding import PatchFeatureEmbeddingService + from atlas_patch.services.wsi_loader import DefaultWSILoader - if app_cfg.features is not None: + if wsi_loader is None: + wsi_loader = DefaultWSILoader() feature_service = PatchFeatureEmbeddingService( app_cfg.extraction, app_cfg.output, diff --git a/atlas_patch/orchestration/runner.py b/atlas_patch/orchestration/runner.py index 6893b86..8e01223 100644 --- a/atlas_patch/orchestration/runner.py +++ b/atlas_patch/orchestration/runner.py @@ -36,6 +36,78 @@ def _chunked(items: Sequence[Slide], size: int) -> Iterable[Sequence[Slide]]: yield items[i : i + size] +def build_existing_extraction_result(slide: Slide, h5_path: Path) -> ExtractionResult | None: + """Create a lightweight ExtractionResult from an existing H5.""" + metadata: dict[str, Any] = {} + num_patches: int | None = None + patch_size_level0: int | None = None + try: + with h5py.File(h5_path, "r") as f: + num_attr = f.attrs.get("num_patches") + if num_attr is not None: + num_patches = int(num_attr) + elif "coords" in f: + num_patches = int(f["coords"].shape[0]) + + ps_level0_attr = f.attrs.get("patch_size_level0") + if ps_level0_attr is not None: + patch_size_level0 = int(ps_level0_attr) + except Exception as e: # noqa: BLE001 + logger.warning( + "Failed to read existing output for %s; will reprocess. Error: %s", + slide.path.name, + e, + ) + return None + + if num_patches is None or num_patches <= 0: + return None + + return ExtractionResult( + slide=slide, + h5_path=h5_path, + num_patches=int(num_patches), + patch_size_level0=patch_size_level0, + metadata=metadata, + ) + + +def classify_existing_slide_output( + config: AppConfig, + slide: Slide, +) -> tuple[str | None, ExtractionResult | None]: + """Return how an existing patch H5 should be treated for this run. 
+ + Returns one of: + - (None, None): no reusable output; the slide needs full processing + - ("skip", None): output is fully complete for this run + - ("reuse", ExtractionResult): reuse patches/H5 and continue with downstream features + """ + if not config.output.skip_existing: + return None, None + + existing_path = find_existing_patch(slide, config.output, config.extraction) + if existing_path is None: + return None, None + + feat_cfg = config.features + if feat_cfg is None or not feat_cfg.extractors: + return "skip", None + + existing_result = build_existing_extraction_result(slide, existing_path) + if existing_result is None: + return None, None + + missing = missing_features( + existing_path, + feat_cfg.extractors, + expected_total=existing_result.num_patches, + ) + if not missing: + return "skip", None + return "reuse", existing_result + + class ProcessingRunner: """High-level orchestration of WSI segmentation, patch extraction, and visualization.""" @@ -68,41 +140,6 @@ def discover_slides(self) -> list[Slide]: slides.append(slide) return slides - def _build_existing_result(self, slide: Slide, h5_path: Path) -> ExtractionResult | None: - """Create a lightweight ExtractionResult from an existing H5 (no re-segmentation).""" - metadata: dict[str, Any] = {} - num_patches: int | None = None - patch_size_level0: int | None = None - try: - with h5py.File(h5_path, "r") as f: - num_attr = f.attrs.get("num_patches") - if num_attr is not None: - num_patches = int(num_attr) - elif "coords" in f: - num_patches = int(f["coords"].shape[0]) - - ps_level0_attr = f.attrs.get("patch_size_level0") - if ps_level0_attr is not None: - patch_size_level0 = int(ps_level0_attr) - except Exception as e: # noqa: BLE001 - logger.warning( - "Failed to read existing output for %s; will reprocess. 
Error: %s", - slide.path.name, - e, - ) - return None - - if num_patches is None or num_patches <= 0: - return None - - return ExtractionResult( - slide=slide, - h5_path=h5_path, - num_patches=int(num_patches), - patch_size_level0=patch_size_level0, - metadata=metadata, - ) - def _handle_existing_slide( self, slide: Slide, @@ -113,35 +150,22 @@ def _handle_existing_slide( Returns True when the slide is fully handled (skip or reuse), False to continue processing. """ - if not self.config.output.skip_existing: - return False - - existing_path = find_existing_patch(slide, self.config.output, self.config.extraction) - if existing_path is None: + decision, existing_result = classify_existing_slide_output(self.config, slide) + if decision is None: return False - - feat_cfg = self.config.features - if feat_cfg is None or not feat_cfg.extractors: + if decision == "skip": logger.info("Skipping %s (already processed).", slide.path.name) if progress: progress.update(1) return True - - existing_result = self._build_existing_result(slide, existing_path) if existing_result is None: - logger.info("Existing output invalid for %s; reprocessing.", slide.path.name) return False - + results.append(existing_result) missing = missing_features( - existing_path, feat_cfg.extractors, expected_total=existing_result.num_patches + existing_result.h5_path, + self.config.features.extractors, + expected_total=existing_result.num_patches, ) - if not missing: - logger.info("Skipping %s (features complete).", slide.path.name) - if progress: - progress.update(1) - return True - - results.append(existing_result) logger.info( "Reusing existing patches for %s; missing features: %s", slide.path.name, diff --git a/atlas_patch/services/feature_embedding.py b/atlas_patch/services/feature_embedding.py index 938bdbc..0e1ff29 100644 --- a/atlas_patch/services/feature_embedding.py +++ b/atlas_patch/services/feature_embedding.py @@ -275,6 +275,8 @@ def embed_all( progress.update(completed_units) for name in 
self.extractor_names: + if not any(name in missing_for_slide for missing_for_slide in pending.values()): + continue try: extractor = self.registry.create(name) except Exception as e: # noqa: BLE001 diff --git a/atlas_patch/services/patient_embedding.py b/atlas_patch/services/patient_embedding.py index 4587273..7f01729 100644 --- a/atlas_patch/services/patient_embedding.py +++ b/atlas_patch/services/patient_embedding.py @@ -9,7 +9,11 @@ from atlas_patch.core.config import OutputConfig, PatientEncodingConfig from atlas_patch.core.models import PatientCase, PatientEmbeddingResult, Slide from atlas_patch.core.paths import patient_embedding_path, patient_lock_path, validate_output_stem -from atlas_patch.models.patient import PatientEncoderRegistry, build_default_registry +from atlas_patch.models.patient import ( + PatientEncoderRegistry, + PatientEncoderSpec, + build_default_registry, +) from atlas_patch.services._embedding_runtime import ( acquire_exclusive_lock, cleanup_encoder, @@ -121,7 +125,7 @@ def _build_source_slide_summaries(slide_h5_paths: list[Path]) -> tuple[PatchArti def _validate_reusable_embedding( *, output_path: Path, - encoder, + spec: PatientEncoderSpec, case: PatientCase, slide_h5_paths: list[Path], ) -> PatientEmbeddingResult | None: @@ -132,13 +136,13 @@ def _validate_reusable_embedding( summary = read_patient_embedding_summary(output_path) except Exception: return None - if summary.encoder_name != encoder.name: + if summary.encoder_name != spec.name: return None - if summary.embedding_dim != encoder.embedding_dim: + if summary.embedding_dim != spec.embedding_dim: return None if summary.num_slides != case.num_slides: return None - if summary.source_patch_encoder != encoder.required_patch_encoder: + if summary.source_patch_encoder != spec.patch_encoder_name: return None current_slide_h5s = tuple(str(path.resolve()) for path in slide_h5_paths) if summary.source_slide_h5_paths != current_slide_h5s: @@ -148,7 +152,7 @@ def _validate_reusable_embedding( 
return PatientEmbeddingResult( case_id=case.case_id, h5_path=output_path, - encoder_name=encoder.name, + encoder_name=spec.name, embedding_dim=summary.embedding_dim, num_slides=summary.num_slides, source_patch_encoder=summary.source_patch_encoder, @@ -168,46 +172,55 @@ def embed_all( return results, failures for encoder_name in self.patient_cfg.encoders: + spec = self.registry.get_spec(encoder_name) + pending_cases: list[tuple[PatientCase, Path, list[Path], bool]] = [] + + for case in ordered_cases: + try: + slide_h5_paths = [ + slide_h5_by_path[slide.path.resolve()].resolve() + for slide in case.slides + ] + except KeyError as exc: + failures.append( + ( + case, + RuntimeError( + f"{encoder_name}: missing canonical patch artifact for {exc.args[0]}" + ), + ) + ) + continue + + output_path = patient_embedding_path(self.output_cfg, encoder_name, case.case_id) + if self.patient_cfg.skip_existing and output_path.exists(): + reused = self._validate_reusable_embedding( + output_path=output_path, + spec=spec, + case=case, + slide_h5_paths=slide_h5_paths, + ) + if reused is not None: + results.append(reused) + continue + overwrite = True + else: + overwrite = not self.patient_cfg.skip_existing + + pending_cases.append((case, output_path, slide_h5_paths, overwrite)) + + if not pending_cases: + continue + try: encoder = self.registry.create(encoder_name) except Exception as exc: # noqa: BLE001 - for case in ordered_cases: + for case, _, _, _ in pending_cases: failures.append((case, RuntimeError(f"{encoder_name}: {exc}"))) continue try: - for case in ordered_cases: - try: - slide_h5_paths = [ - slide_h5_by_path[slide.path.resolve()].resolve() - for slide in case.slides - ] - except KeyError as exc: - failures.append( - ( - case, - RuntimeError( - f"{encoder_name}: missing canonical patch artifact for {exc.args[0]}" - ), - ) - ) - continue - - output_path = patient_embedding_path(self.output_cfg, encoder_name, case.case_id) - if self.patient_cfg.skip_existing and 
output_path.exists(): - reused = self._validate_reusable_embedding( - output_path=output_path, - encoder=encoder, - case=case, - slide_h5_paths=slide_h5_paths, - ) - if reused is not None: - results.append(reused) - continue - overwrite = True - else: - overwrite = not self.patient_cfg.skip_existing - + for case, output_path, slide_h5_paths, overwrite in pending_cases: lock_fd, lock_path = acquire_exclusive_lock( patient_lock_path(self.output_cfg, encoder_name, case.case_id), payload=( @@ -240,9 +253,9 @@ def embed_all( output_path, embedding, attrs={ - "encoder_name": encoder.name, + "encoder_name": spec.name, "num_slides": case.num_slides, - "source_patch_encoder": encoder.required_patch_encoder, + "source_patch_encoder": spec.patch_encoder_name, "source_manifest": str(self.patient_cfg.manifest_path.resolve()), "source_slide_h5_paths": [ str(path.resolve()) for path in slide_h5_paths @@ -265,10 +278,10 @@ def embed_all( PatientEmbeddingResult( case_id=case.case_id, h5_path=output_path, - encoder_name=encoder.name, + encoder_name=spec.name, embedding_dim=int(embedding.shape[0]), num_slides=case.num_slides, - source_patch_encoder=encoder.required_patch_encoder, + source_patch_encoder=spec.patch_encoder_name, metadata={"reused": False}, ) ) diff --git a/atlas_patch/services/slide_embedding.py b/atlas_patch/services/slide_embedding.py index 3060469..48f9966 100644 --- a/atlas_patch/services/slide_embedding.py +++ b/atlas_patch/services/slide_embedding.py @@ -8,7 +8,7 @@ from atlas_patch.core.config import ExtractionConfig, OutputConfig, SlideEncodingConfig from atlas_patch.core.models import Slide, SlideEmbeddingResult from atlas_patch.core.paths import slide_append_lock_path, slide_feature_dataset_key -from atlas_patch.models.slide import SlideEncoderRegistry, build_default_registry +from atlas_patch.models.slide import SlideEncoderRegistry, SlideEncoderSpec, build_default_registry from atlas_patch.services._embedding_runtime import ( acquire_exclusive_lock, 
cleanup_encoder, @@ -72,25 +72,25 @@ def _build_result( @staticmethod def _validate_reusable_embedding( *, - encoder, + spec: SlideEncoderSpec, h5_path: Path, summary: SlideEmbeddingSummary, patch_summary: PatchArtifactSummary, ) -> None: - if summary.embedding_dim != encoder.embedding_dim: + if summary.embedding_dim != spec.embedding_dim: raise ValueError( f"{h5_path} has slide embedding dim {summary.embedding_dim} for '{summary.dataset_key}', " - f"but '{encoder.name}' expects {encoder.embedding_dim}." + f"but '{spec.name}' expects {spec.embedding_dim}." ) - if summary.source_patch_encoder != encoder.required_patch_encoder: + if summary.source_patch_encoder != spec.patch_encoder_name: raise ValueError( f"{h5_path} was encoded from '{summary.source_patch_encoder}', " - f"but '{encoder.name}' requires '{encoder.required_patch_encoder}'." + f"but '{spec.name}' requires '{spec.patch_encoder_name}'." ) - if summary.patch_size != encoder.required_patch_size: + if summary.patch_size != spec.patch_size: raise ValueError( f"{h5_path} records patch_size={summary.patch_size} for '{summary.dataset_key}', " - f"but '{encoder.name}' requires {encoder.required_patch_size}." + f"but '{spec.name}' requires {spec.patch_size}." 
) if summary.num_patches != patch_summary.num_patches: raise ValueError( @@ -121,56 +121,65 @@ def embed_all( return results, failures for encoder_name in self.slide_cfg.encoders: - try: - encoder = self.registry.create(encoder_name) - except Exception as exc: # noqa: BLE001 - for slide, _ in items: + spec = self.registry.get_spec(encoder_name) + pending_items: list[tuple[Slide, Path, PatchArtifactSummary, bool]] = [] + + for slide, h5_path in items: + try: + patch_summary = read_patch_artifact_summary(h5_path) + except Exception as exc: # noqa: BLE001 failures.append((slide, RuntimeError(f"{encoder_name}: {exc}"))) - continue + continue - try: - for slide, h5_path in items: + overwrite_existing = not self.slide_cfg.skip_existing + if self.slide_cfg.skip_existing: try: - patch_summary = read_patch_artifact_summary(h5_path) - except Exception as exc: # noqa: BLE001 - failures.append((slide, RuntimeError(f"{encoder_name}: {exc}"))) - continue + existing = read_slide_embedding_summary(h5_path, encoder_name) + except Exception: + existing = None + overwrite_existing = True - overwrite_existing = not self.slide_cfg.skip_existing - if self.slide_cfg.skip_existing: + if existing is not None: try: - existing = read_slide_embedding_summary(h5_path, encoder_name) + self._validate_reusable_embedding( + spec=spec, + h5_path=h5_path, + summary=existing, + patch_summary=patch_summary, + ) except Exception: - existing = None overwrite_existing = True - - if existing is not None: - try: - self._validate_reusable_embedding( - encoder=encoder, + else: + results.append( + self._build_result( + slide=slide, h5_path=h5_path, - summary=existing, - patch_summary=patch_summary, + encoder_name=encoder_name, + embedding_dim=existing.embedding_dim, + source_patch_encoder=existing.source_patch_encoder, + num_patches=existing.num_patches, + patch_size=existing.patch_size, + patch_size_level0=existing.patch_size_level0, + target_magnification=existing.target_magnification, + reused=True, ) - 
except Exception: - overwrite_existing = True - else: - results.append( - self._build_result( - slide=slide, - h5_path=h5_path, - encoder_name=encoder_name, - embedding_dim=existing.embedding_dim, - source_patch_encoder=existing.source_patch_encoder, - num_patches=existing.num_patches, - patch_size=existing.patch_size, - patch_size_level0=existing.patch_size_level0, - target_magnification=existing.target_magnification, - reused=True, - ) - ) - continue + ) + continue + + pending_items.append((slide, h5_path, patch_summary, overwrite_existing)) + + if not pending_items: + continue + + try: + encoder = self.registry.create(encoder_name) + except Exception as exc: # noqa: BLE001 + for slide, _, _, _ in pending_items: + failures.append((slide, RuntimeError(f"{encoder_name}: {exc}"))) + continue + try: + for slide, h5_path, patch_summary, overwrite_existing in pending_items: lock_fd, lock_path = acquire_exclusive_lock( slide_append_lock_path(slide, self.output_cfg, self.extraction_cfg), payload=( @@ -191,10 +200,10 @@ def embed_all( continue try: - if patch_summary.patch_size != encoder.required_patch_size: + if patch_summary.patch_size != spec.patch_size: raise ValueError( f"{h5_path} has patch_size={patch_summary.patch_size}, " - f"but '{encoder_name}' requires {encoder.required_patch_size}." + f"but '{encoder_name}' requires {spec.patch_size}." 
) embedding = encoder.encode_slide(h5_path) @@ -203,7 +212,7 @@ def embed_all( encoder_name, embedding, attrs={ - "source_patch_encoder": encoder.required_patch_encoder, + "source_patch_encoder": spec.patch_encoder_name, "num_patches": patch_summary.num_patches, "patch_size": patch_summary.patch_size, "patch_size_level0": patch_summary.patch_size_level0, @@ -217,7 +226,7 @@ def embed_all( h5_path=h5_path, encoder_name=encoder_name, embedding_dim=int(embedding.shape[0]), - source_patch_encoder=encoder.required_patch_encoder, + source_patch_encoder=spec.patch_encoder_name, num_patches=patch_summary.num_patches, patch_size=patch_summary.patch_size, patch_size_level0=patch_summary.patch_size_level0, From 569720c4e94d39de47c696401889fa5d0db2a438 Mon Sep 17 00:00:00 2001 From: yousefkotp Date: Mon, 6 Apr 2026 21:43:54 -0400 Subject: [PATCH 8/8] fix: align medsiglip preprocessing with upstream usage --- atlas_patch/models/patch/medsiglip.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/atlas_patch/models/patch/medsiglip.py b/atlas_patch/models/patch/medsiglip.py index fc73538..3a9a923 100644 --- a/atlas_patch/models/patch/medsiglip.py +++ b/atlas_patch/models/patch/medsiglip.py @@ -3,6 +3,7 @@ import logging import torch +from PIL import Image from atlas_patch.models.patch.base import PatchFeatureExtractor from atlas_patch.models.patch.registry import PatchFeatureExtractorRegistry @@ -11,6 +12,14 @@ _MODEL_ID = "google/medsiglip-448" _EMB_DIM = 1152 +_IMAGE_SIZE = (448, 448) + + +def _resize_patch(pil_img: Image.Image) -> Image.Image: + image = pil_img if pil_img.mode == "RGB" else pil_img.convert("RGB") + if image.size == _IMAGE_SIZE: + return image + return image.resize(_IMAGE_SIZE, resample=Image.BILINEAR) class MedSigLip(PatchFeatureExtractor): @@ -44,11 +53,11 @@ def __init__( raise RuntimeError( f"Loaded model '{_MODEL_ID}' does not expose get_image_features; cannot be used for patch embeddings." 
) - model = model.to(device=self.device, dtype=self.dtype).eval() def _preprocess(pil_img): - inputs = processor(images=pil_img, padding="max_length", return_tensors="pt") + resized = _resize_patch(pil_img) + inputs = processor(images=resized, padding="max_length", return_tensors="pt") return inputs["pixel_values"].squeeze(0) super().__init__(