Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
9f4de74
Remove pathlib dependency (include in stdlib as of python 3.10)
MarcelRosier Jul 28, 2025
40ed1ec
Add synthstrip base package as optional dependency
MarcelRosier Jul 28, 2025
f922179
rm outdated todo comment
MarcelRosier Jul 28, 2025
35ef47f
WIP: synthstrip extractor
MarcelRosier Jul 28, 2025
86f34ba
Pass down use_gpu to bet method
MarcelRosier Jul 28, 2025
80debfb
Remove unused params and fix pylance errors
MarcelRosier Jul 28, 2025
b4de51d
Handle device param and ad docstrings
MarcelRosier Jul 28, 2025
3ea459e
Add synthstrip to docs
MarcelRosier Jul 28, 2025
bb1bb45
Generalize zenodo fetching
MarcelRosier Jul 28, 2025
3e0e4a9
Integrate updated zenodo fetching, rename synth class, type fixes
MarcelRosier Jul 28, 2025
a129aff
Add bet weights
MarcelRosier Jul 28, 2025
df31d41
Add synthstrip to all extras option
MarcelRosier Jul 28, 2025
eab3e1e
Remove default import of optional synthstrip and add info to docstring
MarcelRosier Jul 28, 2025
74323be
Basic test suite for zenodo integration
MarcelRosier Jul 28, 2025
cb0f0fa
Fix latest folder logic bug
MarcelRosier Jul 28, 2025
ed020e8
Update registration and bet section
MarcelRosier Jul 28, 2025
2d2b513
Fix label name
MarcelRosier Jul 28, 2025
cacde44
Add note that registrator is optional and provide install instructions
MarcelRosier Jul 28, 2025
8052fcb
Update brainles_preprocessing/brain_extraction/synthstrip.py
neuronflow Jul 28, 2025
5312860
Autoformat with black
brainless-bot[bot] Jul 28, 2025
5aa514c
Simplify exception mock
MarcelRosier Jul 28, 2025
a796971
Integrate comments
MarcelRosier Jul 28, 2025
23dd8f6
Merge branch '132-feature-consider-synthstrip-brain-extraction' of gi…
MarcelRosier Jul 28, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,4 @@ dmypy.json
.DS_Store

brainles_preprocessing/registration/atlases
brainles_preprocessing/brain_extraction/weights
14 changes: 12 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,15 @@ We provide a (WIP) documentation. Have a look [here](https://brainles-preprocess
Please credit the authors by citing their work.

### Registration
We currently provide support for [ANTs](https://github.com/ANTsX/ANTs) (default), [Niftyreg](https://github.com/KCL-BMEIS/niftyreg) (Linux). We also offer basic support for [greedy](https://greedy.readthedocs.io/en/latest/reference.html#greedy-usage) and [elastix](https://pypi.org/project/itk-elastix/0.13.0/).
We currently fully support:
- [ANTs](https://github.com/ANTsX/ANTs) (default)
- [Niftyreg](https://github.com/KCL-BMEIS/niftyreg) (Linux)

We also offer basic support for:
- [greedy](https://greedy.readthedocs.io/en/latest/reference.html#greedy-usage) (Optional dependency, install via: `pip install brainles_preprocessing[picsl_greedy]`)
- [elastix](https://pypi.org/project/itk-elastix/0.13.0/) (Optional dependency, install via: `pip install brainles_preprocessing[itk-elastix]`)

As of now we do not offer inverse transforms for greedy and elastix. Please resort to ANTs or Niftyreg for this.

### Atlas Reference
We provide a range of different atlases, namely:
Expand All @@ -154,7 +162,9 @@ We also support supplying a custom atlas in NIfTI format
We currently provide support for N4 Bias correction based on [SimpleITK](https://simpleitk.org/)

### Brain extraction
We currently provide support for [HD-BET](https://github.com/MIC-DKFZ/HD-BET).
We currently support:
- [HD-BET](https://github.com/MIC-DKFZ/HD-BET)
- [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) (Optional dependency, install via: `pip install brainles_preprocessing[synthstrip]`)

### Defacing
We currently provide support for [Quickshear](https://github.com/nipy/quickshear).
9 changes: 2 additions & 7 deletions brainles_preprocessing/brain_extraction/brain_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ def extract(
input_image_path: Union[str, Path],
masked_image_path: Union[str, Path],
brain_mask_path: Union[str, Path],
log_file_path: Optional[Union[str, Path]],
mode: Union[str, Mode],
**kwargs,
) -> None:
"""
Expand All @@ -32,7 +30,6 @@ def extract(
input_image_path (str or Path): Path to the input image.
masked_image_path (str or Path): Path where the brain-extracted image will be saved.
brain_mask_path (str or Path): Path where the brain mask will be saved.
log_file_path (str or Path, Optional): Path to the log file.
mode (str or Mode): Extraction mode.
**kwargs: Additional keyword arguments.
"""
Expand Down Expand Up @@ -86,11 +83,10 @@ def extract(
input_image_path: Union[str, Path],
masked_image_path: Union[str, Path],
brain_mask_path: Union[str, Path],
log_file_path: Optional[Union[str, Path]] = None,
# TODO convert mode to enum
mode: Union[str, Mode] = Mode.ACCURATE,
device: Optional[Union[int, str]] = 0,
do_tta: Optional[bool] = True,
do_tta: bool = True,
**kwargs,
) -> None:
# GPU + accurate + TTA
"""
Expand All @@ -100,7 +96,6 @@ def extract(
input_image_path (str or Path): Path to the input image.
masked_image_path (str or Path): Path where the brain-extracted image will be saved.
brain_mask_path (str or Path): Path where the brain mask will be saved.
log_file_path (str or Path, Optional): Path to the log file.
mode (str or Mode): Extraction mode ('fast' or 'accurate').
device (str or int): Device to use for computation (e.g., 0 for GPU 0, 'cpu' for CPU).
do_tta (bool): whether to do test time data augmentation by mirroring along all axes.
Expand Down
231 changes: 231 additions & 0 deletions brainles_preprocessing/brain_extraction/synthstrip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
# Modified from:
# https://github.com/nipreps/synthstrip/blob/main/nipreps/synthstrip/cli.py
# Original copyright (c) 2024, NiPreps developers
# Licensed under the Apache License, Version 2.0
# Changes made by the BrainLesion Preprocessing team (2025)

from pathlib import Path
from typing import Optional, Union, cast

import nibabel as nib
import numpy as np
import scipy
import torch
from nibabel.nifti1 import Nifti1Image
from nipreps.synthstrip.model import StripModel
from nitransforms.linear import Affine

from brainles_preprocessing.brain_extraction.brain_extractor import BrainExtractor
from brainles_preprocessing.utils.zenodo import fetch_synthstrip


class SynthStripExtractor(BrainExtractor):

def __init__(self, border: int = 1):
"""
Brain extraction using SynthStrip with preprocessing conforming to model requirements.

This is an optional dependency - to use this extractor, you need to install the `brainles_preprocessing` package with the `synthstrip` extra: `pip install brainles_preprocessing[synthstrip]`

Adapted from https://github.com/nipreps/synthstrip

Args:
border (int): Mask border threshold in mm. Defaults to 1.
"""

super().__init__()
self.border = border

def _setup_model(self, device: torch.device) -> StripModel:
"""
Load SynthStrip model and prepare it for inference on the specified device.

Args:
device: Device to load the model onto.

Returns:
A configured and ready-to-use StripModel.
"""
# necessary for speed gains (according to original nipreps authors)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

with torch.no_grad():
model = StripModel()
model.to(device)
model.eval()

# Load the model weights
weights_folder = fetch_synthstrip()
weights = weights_folder / "synthstrip.1.pt"
checkpoint = torch.load(weights, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

return model

def _conform(self, input_nii: Nifti1Image) -> Nifti1Image:
"""
Resample the input image to match SynthStrip's expected input space.

Args:
input_nii (Nifti1Image): Input NIfTI image to conform.

Raises:
ValueError: If the input NIfTI image does not have a valid affine.

Returns:
A new NIfTI image with conformed shape and affine.
"""

shape = np.array(input_nii.shape[:3])
affine = input_nii.affine

if affine is None:
raise ValueError("Input NIfTI image must have a valid affine.")

# Get corner voxel centers in index coords
corner_centers_ijk = (
np.array(
[
(i, j, k)
for k in (0, shape[2] - 1)
for j in (0, shape[1] - 1)
for i in (0, shape[0] - 1)
]
)
+ 0.5
)

# Get corner voxel centers in mm
corners_xyz = (
affine
@ np.hstack((corner_centers_ijk, np.ones((len(corner_centers_ijk), 1)))).T
)

# Target affine is 1mm voxels in LIA orientation
target_affine = np.diag([-1.0, 1.0, -1.0, 1.0])[:, (0, 2, 1, 3)]

# Target shape
extent = corners_xyz.min(1)[:3], corners_xyz.max(1)[:3]
target_shape = ((extent[1] - extent[0]) / 1.0 + 0.999).astype(int)

# SynthStrip likes dimensions be multiple of 64 (192, 256, or 320)
target_shape = np.clip(
np.ceil(np.array(target_shape) / 64).astype(int) * 64, 192, 320
)

# Ensure shape ordering is LIA too
target_shape[2], target_shape[1] = target_shape[1:3]

# Coordinates of center voxel do not change
input_c = affine @ np.hstack((0.5 * (shape - 1), 1.0))
target_c = target_affine @ np.hstack((0.5 * (target_shape - 1), 1.0))

# Rebase the origin of the new, plumb affine
target_affine[:3, 3] -= target_c[:3] - input_c[:3]

nii = Affine(
reference=Nifti1Image(
np.zeros(target_shape),
target_affine,
None,
),
).apply(input_nii)
return cast(Nifti1Image, nii)

def _resample_like(
self,
image: Nifti1Image,
target: Nifti1Image,
output_dtype: Optional[np.dtype] = None,
cval: Union[int, float] = 0,
) -> Nifti1Image:
"""
Resample the input image to match the target's grid using an identity transform.

Args:
image: The image to be resampled.
target: The reference image.
output_dtype: Output data type.
cval: Value to use for constant padding.

Returns:
A resampled NIfTI image.
"""
result = Affine(reference=target).apply(
image,
output_dtype=output_dtype,
cval=cval,
)
return cast(Nifti1Image, result)

def extract(
self,
input_image_path: Union[str, Path],
masked_image_path: Union[str, Path],
brain_mask_path: Union[str, Path],
device: Union[torch.device, str] = "cuda",
num_threads: int = 1,
**kwargs,
) -> None:
"""
Extract the brain from an input image using SynthStrip.

Args:
input_image_path (Union[str, Path]): Path to the input image.
masked_image_path (Union[str, Path]): Path to the output masked image.
brain_mask_path (Union[str, Path]): Path to the output brain mask.
device (Union[torch.device, str], optional): Device to use for computation. Defaults to "cuda".
num_threads (int, optional): Number of threads to use for computation in CPU mode. Defaults to 1.

Returns:
None: The function saves the masked image and brain mask to the specified paths.
"""

device = torch.device(device) if isinstance(device, str) else device
model = self._setup_model(device=device)

if device.type == "cpu" and num_threads > 0:
torch.set_num_threads(num_threads)

# normalize intensities
image = nib.load(input_image_path)
image = cast(Nifti1Image, image)
conformed = self._conform(image)
in_data = conformed.get_fdata(dtype="float32")
in_data -= in_data.min()
in_data = np.clip(in_data / np.percentile(in_data, 99), 0, 1)
in_data = in_data[np.newaxis, np.newaxis]

# predict the surface distance transform
input_tensor = torch.from_numpy(in_data).to(device)
with torch.no_grad():
sdt = model(input_tensor).cpu().numpy().squeeze()

# unconform the sdt and extract mask
sdt_target = self._resample_like(
Nifti1Image(sdt, conformed.affine, None),
image,
output_dtype=np.dtype("int16"),
cval=100,
)
sdt_data = np.asanyarray(sdt_target.dataobj).astype("int16")

# find largest CC (just do this to be safe for now)
components = scipy.ndimage.label(sdt_data.squeeze() < self.border)[0]
bincount = np.bincount(components.flatten())[1:]
mask = components == (np.argmax(bincount) + 1)
mask = scipy.ndimage.morphology.binary_fill_holes(mask)

# write the masked output
img_data = image.get_fdata()
bg = np.min([0, img_data.min()])
img_data[mask == 0] = bg
Nifti1Image(img_data, image.affine, image.header).to_filename(
masked_image_path,
)

# write the brain mask
hdr = image.header.copy()
hdr.set_data_dtype("uint8")
Nifti1Image(mask, image.affine, hdr).to_filename(brain_mask_path)
12 changes: 10 additions & 2 deletions brainles_preprocessing/modality.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import Dict, Optional, Union

import torch
from auxiliary.io import read_image, write_image
from loguru import logger

Expand All @@ -16,7 +17,7 @@
NiftyRegRegistrator,
)
from brainles_preprocessing.registration.registrator import Registrator
from brainles_preprocessing.utils.zenodo import verify_or_download_atlases
from brainles_preprocessing.utils.zenodo import fetch_atlases


class Modality:
Expand Down Expand Up @@ -598,6 +599,7 @@ def extract_brain_region(
self,
brain_extractor: BrainExtractor,
bet_dir_path: Union[str, Path],
use_gpu: bool = True,
) -> Path:
"""

Expand All @@ -606,6 +608,7 @@ def extract_brain_region(
Args:
brain_extractor (BrainExtractor): The brain extractor object.
bet_dir_path (str or Path): Directory to store brain extraction results.
use_gpu (bool): Whether to use GPU for brain extraction if available.

Returns:
Path: Path to the extracted brain mask.
Expand All @@ -617,11 +620,16 @@ def extract_brain_region(
bet = bet_dir_path / f"{self.modality_name}_bet.nii.gz"
mask_path = bet_dir_path / f"{self.modality_name}_brain_mask.nii.gz"

device = torch.device(
"cuda" if use_gpu and torch.cuda.is_available() else "cpu"
)

brain_extractor.extract(
input_image_path=self.current,
masked_image_path=bet,
brain_mask_path=mask_path,
log_file_path=bet_log,
device=device,
)

# always temporarily store bet image for center modality, since e.g. quickshear defacing could require it
Expand Down Expand Up @@ -666,7 +674,7 @@ def deface(

# resolve atlas image path
if isinstance(defacer.atlas_image_path, Atlas):
atlas_folder = verify_or_download_atlases()
atlas_folder = fetch_atlases()
atlas_image_path = atlas_folder / defacer.atlas_image_path.value
else:
atlas_image_path = Path(defacer.atlas_image_path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from brainles_preprocessing.n4_bias_correction import N4BiasCorrector
from brainles_preprocessing.preprocessor.preprocessor import BasePreprocessor
from brainles_preprocessing.registration.registrator import Registrator
from brainles_preprocessing.utils.zenodo import verify_or_download_atlases
from brainles_preprocessing.utils.zenodo import fetch_atlases


class AtlasCentricPreprocessor(BasePreprocessor):
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(
atlas_image_path: Union[str, Path, Atlas] = Atlas.BRATS_SRI24,
n4_bias_corrector: Optional[N4BiasCorrector] = None,
temp_folder: Optional[Union[str, Path]] = None,
use_gpu: Optional[bool] = None,
use_gpu: bool = True,
limit_cuda_visible_devices: Optional[str] = None,
):
super().__init__(
Expand All @@ -58,7 +58,7 @@ def __init__(
)

if isinstance(atlas_image_path, Atlas):
atlas_folder = verify_or_download_atlases()
atlas_folder = fetch_atlases()
self.atlas_image_path = atlas_folder / atlas_image_path.value
else:
self.atlas_image_path = Path(atlas_image_path)
Expand Down
Loading