diff --git a/.gitignore b/.gitignore index 118e75a..1279a8a 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,4 @@ dmypy.json .DS_Store brainles_preprocessing/registration/atlases +brainles_preprocessing/brain_extraction/weights diff --git a/README.md b/README.md index 475e18e..3dda476 100644 --- a/README.md +++ b/README.md @@ -138,7 +138,15 @@ We provide a (WIP) documentation. Have a look [here](https://brainles-preprocess Please credit the authors by citing their work. ### Registration -We currently provide support for [ANTs](https://github.com/ANTsX/ANTs) (default), [Niftyreg](https://github.com/KCL-BMEIS/niftyreg) (Linux). We also offer basic support for [greedy](https://greedy.readthedocs.io/en/latest/reference.html#greedy-usage) and [elastix](https://pypi.org/project/itk-elastix/0.13.0/). +We currently fully support: +- [ANTs](https://github.com/ANTsX/ANTs) (default) +- [Niftyreg](https://github.com/KCL-BMEIS/niftyreg) (Linux) + +We also offer basic support for: +- [greedy](https://greedy.readthedocs.io/en/latest/reference.html#greedy-usage) (Optional dependency, install via: `pip install brainles_preprocessing[picsl_greedy]`) +- [elastix](https://pypi.org/project/itk-elastix/0.13.0/) (Optional dependency, install via: `pip install brainles_preprocessing[itk-elastix]`) + +As of now we do not offer inverse transforms for greedy and elastix. Please resort to ANTs or Niftyreg for this. ### Atlas Reference We provide a range of different atlases, namely: @@ -154,7 +162,9 @@ We also support supplying a custom atlas in NIfTI format We currently provide support for N4 Bias correction based on [SimpleITK](https://simpleitk.org/) ### Brain extraction -We currently provide support for [HD-BET](https://github.com/MIC-DKFZ/HD-BET). +We currently support: +- [HD-BET](https://github.com/MIC-DKFZ/HD-BET) +- [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) (Optional dependency, install via: `pip install brainles_preprocessing[synthstrip]`) ### Defacing We currently provide support for [Quickshear](https://github.com/nipy/quickshear). diff --git a/brainles_preprocessing/brain_extraction/brain_extractor.py b/brainles_preprocessing/brain_extraction/brain_extractor.py index cc6a1b2..03a716b 100644 --- a/brainles_preprocessing/brain_extraction/brain_extractor.py +++ b/brainles_preprocessing/brain_extraction/brain_extractor.py @@ -21,8 +21,6 @@ def extract( input_image_path: Union[str, Path], masked_image_path: Union[str, Path], brain_mask_path: Union[str, Path], - log_file_path: Optional[Union[str, Path]], - mode: Union[str, Mode], **kwargs, ) -> None: """ @@ -32,7 +30,6 @@ def extract( input_image_path (str or Path): Path to the input image. masked_image_path (str or Path): Path where the brain-extracted image will be saved. brain_mask_path (str or Path): Path where the brain mask will be saved. - log_file_path (str or Path, Optional): Path to the log file. mode (str or Mode): Extraction mode. **kwargs: Additional keyword arguments. """ @@ -86,11 +83,10 @@ def extract( input_image_path: Union[str, Path], masked_image_path: Union[str, Path], brain_mask_path: Union[str, Path], - log_file_path: Optional[Union[str, Path]] = None, - # TODO convert mode to enum mode: Union[str, Mode] = Mode.ACCURATE, device: Optional[Union[int, str]] = 0, - do_tta: Optional[bool] = True, + do_tta: bool = True, + **kwargs, ) -> None: # GPU + accurate + TTA """ @@ -100,7 +96,6 @@ def extract( input_image_path (str or Path): Path to the input image. masked_image_path (str or Path): Path where the brain-extracted image will be saved. brain_mask_path (str or Path): Path where the brain mask will be saved. - log_file_path (str or Path, Optional): Path to the log file. mode (str or Mode): Extraction mode ('fast' or 'accurate'). device (str or int): Device to use for computation (e.g., 0 for GPU 0, 'cpu' for CPU). do_tta (bool): whether to do test time data augmentation by mirroring along all axes. diff --git a/brainles_preprocessing/brain_extraction/synthstrip.py b/brainles_preprocessing/brain_extraction/synthstrip.py new file mode 100644 index 0000000..0ea3135 --- /dev/null +++ b/brainles_preprocessing/brain_extraction/synthstrip.py @@ -0,0 +1,231 @@ +# Modified from: +# https://github.com/nipreps/synthstrip/blob/main/nipreps/synthstrip/cli.py +# Original copyright (c) 2024, NiPreps developers +# Licensed under the Apache License, Version 2.0 +# Changes made by the BrainLesion Preprocessing team (2025) + +from pathlib import Path +from typing import Optional, Union, cast + +import nibabel as nib +import numpy as np +import scipy +import torch +from nibabel.nifti1 import Nifti1Image +from nipreps.synthstrip.model import StripModel +from nitransforms.linear import Affine + +from brainles_preprocessing.brain_extraction.brain_extractor import BrainExtractor +from brainles_preprocessing.utils.zenodo import fetch_synthstrip + + +class SynthStripExtractor(BrainExtractor): + + def __init__(self, border: int = 1): + """ + Brain extraction using SynthStrip with preprocessing conforming to model requirements. + + This is an optional dependency - to use this extractor, you need to install the `brainles_preprocessing` package with the `synthstrip` extra: `pip install brainles_preprocessing[synthstrip]` + + Adapted from https://github.com/nipreps/synthstrip + + Args: + border (int): Mask border threshold in mm. Defaults to 1. + """ + + super().__init__() + self.border = border + + def _setup_model(self, device: torch.device) -> StripModel: + """ + Load SynthStrip model and prepare it for inference on the specified device. + + Args: + device: Device to load the model onto. + + Returns: + A configured and ready-to-use StripModel. + """ + # necessary for speed gains (according to original nipreps authors) + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = True + + with torch.no_grad(): + model = StripModel() + model.to(device) + model.eval() + + # Load the model weights + weights_folder = fetch_synthstrip() + weights = weights_folder / "synthstrip.1.pt" + checkpoint = torch.load(weights, map_location=device) + model.load_state_dict(checkpoint["model_state_dict"]) + + return model + + def _conform(self, input_nii: Nifti1Image) -> Nifti1Image: + """ + Resample the input image to match SynthStrip's expected input space. + + Args: + input_nii (Nifti1Image): Input NIfTI image to conform. + + Raises: + ValueError: If the input NIfTI image does not have a valid affine. + + Returns: + A new NIfTI image with conformed shape and affine. + """ + + shape = np.array(input_nii.shape[:3]) + affine = input_nii.affine + + if affine is None: + raise ValueError("Input NIfTI image must have a valid affine.") + + # Get corner voxel centers in index coords + corner_centers_ijk = ( + np.array( + [ + (i, j, k) + for k in (0, shape[2] - 1) + for j in (0, shape[1] - 1) + for i in (0, shape[0] - 1) + ] + ) + + 0.5 + ) + + # Get corner voxel centers in mm + corners_xyz = ( + affine + @ np.hstack((corner_centers_ijk, np.ones((len(corner_centers_ijk), 1)))).T + ) + + # Target affine is 1mm voxels in LIA orientation + target_affine = np.diag([-1.0, 1.0, -1.0, 1.0])[:, (0, 2, 1, 3)] + + # Target shape + extent = corners_xyz.min(1)[:3], corners_xyz.max(1)[:3] + target_shape = ((extent[1] - extent[0]) / 1.0 + 0.999).astype(int) + + # SynthStrip likes dimensions be multiple of 64 (192, 256, or 320) + target_shape = np.clip( + np.ceil(np.array(target_shape) / 64).astype(int) * 64, 192, 320 + ) + + # Ensure shape ordering is LIA too + target_shape[2], target_shape[1] = target_shape[1:3] + + # Coordinates of center voxel do not change + input_c = affine @ np.hstack((0.5 * (shape - 1), 1.0)) + target_c = target_affine @ np.hstack((0.5 * (target_shape - 1), 1.0)) + + # Rebase the origin of the new, plumb affine + target_affine[:3, 3] -= target_c[:3] - input_c[:3] + + nii = Affine( + reference=Nifti1Image( + np.zeros(target_shape), + target_affine, + None, + ), + ).apply(input_nii) + return cast(Nifti1Image, nii) + + def _resample_like( + self, + image: Nifti1Image, + target: Nifti1Image, + output_dtype: Optional[np.dtype] = None, + cval: Union[int, float] = 0, + ) -> Nifti1Image: + """ + Resample the input image to match the target's grid using an identity transform. + + Args: + image: The image to be resampled. + target: The reference image. + output_dtype: Output data type. + cval: Value to use for constant padding. + + Returns: + A resampled NIfTI image. + """ + result = Affine(reference=target).apply( + image, + output_dtype=output_dtype, + cval=cval, + ) + return cast(Nifti1Image, result) + + def extract( + self, + input_image_path: Union[str, Path], + masked_image_path: Union[str, Path], + brain_mask_path: Union[str, Path], + device: Union[torch.device, str] = "cuda", + num_threads: int = 1, + **kwargs, + ) -> None: + """ + Extract the brain from an input image using SynthStrip. + + Args: + input_image_path (Union[str, Path]): Path to the input image. + masked_image_path (Union[str, Path]): Path to the output masked image. + brain_mask_path (Union[str, Path]): Path to the output brain mask. + device (Union[torch.device, str], optional): Device to use for computation. Defaults to "cuda". + num_threads (int, optional): Number of threads to use for computation in CPU mode. Defaults to 1. + + Returns: + None: The function saves the masked image and brain mask to the specified paths. + """ + + device = torch.device(device) if isinstance(device, str) else device + model = self._setup_model(device=device) + + if device.type == "cpu" and num_threads > 0: + torch.set_num_threads(num_threads) + + # normalize intensities + image = nib.load(input_image_path) + image = cast(Nifti1Image, image) + conformed = self._conform(image) + in_data = conformed.get_fdata(dtype="float32") + in_data -= in_data.min() + in_data = np.clip(in_data / np.percentile(in_data, 99), 0, 1) + in_data = in_data[np.newaxis, np.newaxis] + + # predict the surface distance transform + input_tensor = torch.from_numpy(in_data).to(device) + with torch.no_grad(): + sdt = model(input_tensor).cpu().numpy().squeeze() + + # unconform the sdt and extract mask + sdt_target = self._resample_like( + Nifti1Image(sdt, conformed.affine, None), + image, + output_dtype=np.dtype("int16"), + cval=100, + ) + sdt_data = np.asanyarray(sdt_target.dataobj).astype("int16") + + # find largest CC (just do this to be safe for now) + components = scipy.ndimage.label(sdt_data.squeeze() < self.border)[0] + bincount = np.bincount(components.flatten())[1:] + mask = components == (np.argmax(bincount) + 1) + mask = scipy.ndimage.morphology.binary_fill_holes(mask) + + # write the masked output + img_data = image.get_fdata() + bg = np.min([0, img_data.min()]) + img_data[mask == 0] = bg + Nifti1Image(img_data, image.affine, image.header).to_filename( + masked_image_path, + ) + + # write the brain mask + hdr = image.header.copy() + hdr.set_data_dtype("uint8") + Nifti1Image(mask, image.affine, hdr).to_filename(brain_mask_path) diff --git a/brainles_preprocessing/modality.py b/brainles_preprocessing/modality.py index 648e94c..152cf67 100644 --- a/brainles_preprocessing/modality.py +++ b/brainles_preprocessing/modality.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Dict, Optional, Union +import torch from auxiliary.io import read_image, write_image from loguru import logger @@ -16,7 +17,7 @@ NiftyRegRegistrator, ) from brainles_preprocessing.registration.registrator import Registrator -from brainles_preprocessing.utils.zenodo import verify_or_download_atlases +from brainles_preprocessing.utils.zenodo import fetch_atlases class Modality: @@ -598,6 +599,7 @@ def extract_brain_region( self, brain_extractor: BrainExtractor, bet_dir_path: Union[str, Path], + use_gpu: bool = True, ) -> Path: """ @@ -606,6 +608,7 @@ def extract_brain_region( Args: brain_extractor (BrainExtractor): The brain extractor object. bet_dir_path (str or Path): Directory to store brain extraction results. + use_gpu (bool): Whether to use GPU for brain extraction if available. Returns: Path: Path to the extracted brain mask. @@ -617,11 +620,16 @@ def extract_brain_region( bet = bet_dir_path / f"{self.modality_name}_bet.nii.gz" mask_path = bet_dir_path / f"{self.modality_name}_brain_mask.nii.gz" + device = torch.device( + "cuda" if use_gpu and torch.cuda.is_available() else "cpu" + ) + brain_extractor.extract( input_image_path=self.current, masked_image_path=bet, brain_mask_path=mask_path, log_file_path=bet_log, + device=device, ) # always temporarily store bet image for center modality, since e.g. quickshear defacing could require it @@ -666,7 +674,7 @@ def deface( # resolve atlas image path if isinstance(defacer.atlas_image_path, Atlas): - atlas_folder = verify_or_download_atlases() + atlas_folder = fetch_atlases() atlas_image_path = atlas_folder / defacer.atlas_image_path.value else: atlas_image_path = Path(defacer.atlas_image_path) diff --git a/brainles_preprocessing/preprocessor/atlas_centric_preprocessor.py b/brainles_preprocessing/preprocessor/atlas_centric_preprocessor.py index ccfe83b..9ed4891 100644 --- a/brainles_preprocessing/preprocessor/atlas_centric_preprocessor.py +++ b/brainles_preprocessing/preprocessor/atlas_centric_preprocessor.py @@ -11,7 +11,7 @@ from brainles_preprocessing.n4_bias_correction import N4BiasCorrector from brainles_preprocessing.preprocessor.preprocessor import BasePreprocessor from brainles_preprocessing.registration.registrator import Registrator -from brainles_preprocessing.utils.zenodo import verify_or_download_atlases +from brainles_preprocessing.utils.zenodo import fetch_atlases class AtlasCentricPreprocessor(BasePreprocessor): @@ -42,7 +42,7 @@ def __init__( atlas_image_path: Union[str, Path, Atlas] = Atlas.BRATS_SRI24, n4_bias_corrector: Optional[N4BiasCorrector] = None, temp_folder: Optional[Union[str, Path]] = None, - use_gpu: Optional[bool] = None, + use_gpu: bool = True, limit_cuda_visible_devices: Optional[str] = None, ): super().__init__( @@ -58,7 +58,7 @@ def __init__( ) if isinstance(atlas_image_path, Atlas): - atlas_folder = verify_or_download_atlases() + atlas_folder = fetch_atlases() self.atlas_image_path = atlas_folder / atlas_image_path.value else: self.atlas_image_path = Path(atlas_image_path) diff --git a/brainles_preprocessing/preprocessor/preprocessor.py b/brainles_preprocessing/preprocessor/preprocessor.py index 7c69397..2983f39 100644 --- a/brainles_preprocessing/preprocessor/preprocessor.py +++ b/brainles_preprocessing/preprocessor/preprocessor.py @@ -76,7 +76,7 @@ def __init__( defacer: Optional[Defacer] = None, n4_bias_corrector: Optional[N4BiasCorrector] = None, temp_folder: Optional[Union[str, Path]] = None, - use_gpu: Optional[bool] = None, + use_gpu: bool = True, limit_cuda_visible_devices: Optional[str] = None, ): @@ -106,6 +106,7 @@ def __init__( self.brain_extractor = brain_extractor self.defacer = defacer + self.use_gpu = use_gpu self._configure_gpu( use_gpu=use_gpu, limit_cuda_visible_devices=limit_cuda_visible_devices ) @@ -133,7 +134,9 @@ def _check_for_name_conflicts(self): raise ValueError(f"Duplicate modality names found: {', '.join(duplicates)}") def _configure_gpu( - self, use_gpu: Optional[bool], limit_cuda_visible_devices: Optional[str] = None + self, + use_gpu: bool, + limit_cuda_visible_devices: Optional[str] = None, ) -> None: """ Configures the environment for GPU usage based on the `use_gpu` parameter and CUDA availability. @@ -141,7 +144,7 @@ def _configure_gpu( Args: use_gpu (Optional[bool]): Determines the GPU usage strategy. """ - if use_gpu or (use_gpu is None and self._cuda_is_available()): + if use_gpu and self._cuda_is_available(): os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" if limit_cuda_visible_devices: os.environ["CUDA_VISIBLE_DEVICES"] = limit_cuda_visible_devices @@ -315,7 +318,9 @@ def run_brain_extraction( self.brain_extractor = HDBetExtractor() atlas_mask = self.center_modality.extract_brain_region( - brain_extractor=self.brain_extractor, bet_dir_path=bet_dir + brain_extractor=self.brain_extractor, + bet_dir_path=bet_dir, + use_gpu=self.use_gpu, ) for moving_modality in self.moving_modalities: logger.info(f"Applying brain mask to {moving_modality.modality_name}...") diff --git a/brainles_preprocessing/registration/elastix/elastix.py b/brainles_preprocessing/registration/elastix/elastix.py index d685dc3..b5780df 100644 --- a/brainles_preprocessing/registration/elastix/elastix.py +++ b/brainles_preprocessing/registration/elastix/elastix.py @@ -26,6 +26,8 @@ def register( """ Register images using elastix. + This is an optional dependency - to use this registrator, you need to install the `brainles_preprocessing` package with the `itk-elastix` extra: `pip install brainles_preprocessing[itk-elastix]`. + Args: fixed_image_path (str): Path to the fixed image. moving_image_path (str): Path to the moving image. diff --git a/brainles_preprocessing/registration/greedy/greedy.py b/brainles_preprocessing/registration/greedy/greedy.py index d3eabea..9a63b60 100644 --- a/brainles_preprocessing/registration/greedy/greedy.py +++ b/brainles_preprocessing/registration/greedy/greedy.py @@ -24,7 +24,9 @@ def register( log_file_path: Optional[str] = None, ) -> None: """ - Register images using greedy. Ref: https://pypi.org/project/picsl-greedy/ and https://greedy.readthedocs.io/en/latest/reference.html#greedy-usage + Register images using greedy. Ref: https://pypi.org/project/picsl-greedy/ and https://greedy.readthedocs.io/en/latest/reference.html#greedy-usage. + + This is an optional dependency - to use this registrator, you need to install the `brainles_preprocessing` package with the `picsl_greedy` extra: `pip install brainles_preprocessing[picsl_greedy]`. Args: fixed_image_path (str): Path to the fixed image. diff --git a/brainles_preprocessing/utils/zenodo.py b/brainles_preprocessing/utils/zenodo.py index 538787b..b43f9d7 100644 --- a/brainles_preprocessing/utils/zenodo.py +++ b/brainles_preprocessing/utils/zenodo.py @@ -1,7 +1,6 @@ from __future__ import annotations import shutil -import sys import zipfile from io import BytesIO from pathlib import Path @@ -11,174 +10,200 @@ from loguru import logger from rich.progress import Progress, SpinnerColumn, TextColumn -ZENODO_RECORD_URL = "https://zenodo.org/api/records/15236131" ATLASES_FOLDER = Path(__file__).parent.parent / "registration" / "atlases" -ATLASES_DIR_PATTERN = "atlases_v*.*.*" +ATLASES_RECORD_ID = "15236131" +SYNTHSTRIP_FOLDER = Path(__file__).parent.parent / "brain_extraction" / "weights" +SYNTHSTRIP_RECORD_ID = "16535633" -def verify_or_download_atlases() -> Path: - """Check if latest atlases are present and download them otherwise. + +def fetch_atlases() -> Path: + """ + Ensure that the required atlases are available locally, downloading them if necessary. Returns: - Path: Path to the atlases folder. + Path: The path to the folder containing the atlases. """ + record = ZenodoRecord( + record_id=ATLASES_RECORD_ID, + target_dir=ATLASES_FOLDER, + label="atlases", + ) + return record.fetch() - zenodo_data = _get_zenodo_metadata_and_archive_url() - if zenodo_data: - zenodo_metadata, archive_url = zenodo_data - matching_folders = list(ATLASES_FOLDER.glob(ATLASES_DIR_PATTERN)) - # Get the latest downloaded atlases - latest_downloaded_atlases = _get_latest_version_folder_name(matching_folders) +def fetch_synthstrip() -> Path: + """ + Ensure that the SynthStrip weights are available locally, downloading them if necessary. - if not latest_downloaded_atlases: - if not zenodo_data: - logger.error( - "Atlases not found locally and Zenodo could not be reached. Exiting..." - ) - sys.exit() - logger.info(f"Atlases not found locally") + Returns: + Path: The path to the folder containing the SynthStrip weights. + """ + record = ZenodoRecord( + record_id=SYNTHSTRIP_RECORD_ID, + target_dir=SYNTHSTRIP_FOLDER, + label="SynthStrip", + ) + return record.fetch() - return _download_atlases( - zenodo_metadata=zenodo_metadata, - archive_url=archive_url, - ) - logger.info(f"Found downloaded local atlases: {latest_downloaded_atlases}") +class ZenodoException(Exception): + """Raised when Zenodo cannot be reached or the request fails.""" - if not zenodo_metadata: - logger.warning( - "Zenodo server could not be reached. Using the latest downloaded atlases." - ) - return ATLASES_FOLDER / latest_downloaded_atlases + pass - # Compare the latest downloaded atlases with the latest Zenodo version - if zenodo_metadata["version"] == latest_downloaded_atlases.split("_v")[1]: - logger.info( - f"Latest atlases ({zenodo_metadata['version']}) are already present." - ) - return ATLASES_FOLDER / latest_downloaded_atlases - logger.info( - f"New atlases available on Zenodo ({zenodo_metadata['version']}). Deleting old and fetching new atlases..." - ) - # delete old atlases - try: - shutil.rmtree( - ATLASES_FOLDER / latest_downloaded_atlases, - ) - except OSError as e: - logger.warning(f"Failed to delete old atlases: {e}") - return _download_atlases(zenodo_metadata=zenodo_metadata, archive_url=archive_url) +class ZenodoRecord: + BASE_URL = "https://zenodo.org/api/records" + def __init__( + self, + record_id: str, + target_dir: Path, + label: str = "asset", + ): + self.record_id = record_id + self.target_dir = target_dir + self.label = label -def _get_latest_version_folder_name(folders: List[Path]) -> str | None: - """Get the latest (non empty) version folder name from the list of folders. + def fetch(self) -> Path: + """Fetch the latest version of the record from Zenodo or from local storage.""" + zenodo_response = self._get_metadata_and_archive_url() - Args: - folders (List[Path]): List of folders matching the pattern. + pattern = self._glob_pattern() + matching_folders = list(self.target_dir.glob(pattern)) + latest_local = self._get_latest_version_folder_name(matching_folders) - Returns: - str | None: Latest version folder name if one exists, else None. - """ - if not folders: - return None - latest_downloaded_folder = sorted( - folders, - reverse=True, - key=lambda x: tuple(map(int, str(x).split("_v")[1].split("."))), - )[0] - # check folder is not empty - if not list(latest_downloaded_folder.glob("*")): - return None - return latest_downloaded_folder.name - - -def _get_zenodo_metadata_and_archive_url() -> Tuple[Dict, str] | None: - """Get the metadata for the Zenodo record and the files archive url. + if not latest_local: + if not zenodo_response: + msg = f"{self.label.title()} not found locally and Zenodo could not be reached." + logger.error(msg) + raise ZenodoException(msg) - Returns: - Tuple: (dict: Metadata for the Zenodo record, str: URL to the archive file) - """ - try: - response = requests.get(f"{ZENODO_RECORD_URL}") - if response.status_code != 200: - logger.error( - f"Cant find atlases on Zenodo ({ZENODO_RECORD_URL}). Exiting..." - ) - return None - data = response.json() - return data["metadata"], data["links"]["archive"] + logger.info(f"{self.label.title()} not found locally.") + metadata, archive_url = zenodo_response + return self._download(metadata, archive_url) - except requests.exceptions.RequestException as e: - logger.warning(f"Failed to fetch Zenodo metadata: {e}") - return None + logger.info(f"Found local {self.label}: {latest_local}") + if not zenodo_response: + logger.warning(f"Zenodo unreachable. Using latest downloaded {self.label}.") + return self.target_dir / latest_local -def _download_atlases(zenodo_metadata: Dict, archive_url: str) -> Path: - """Download the latest atlases from Zenodo for the requested record and extract them to the target folder. + metadata, archive_url = zenodo_response + remote_version = metadata["version"] + local_version = latest_local.split("_v")[1] - Args: - zenodo_metadata (Dict): Metadata for the Zenodo record. - archive_url (str): URL to the archive file. + if remote_version == local_version: + logger.info(f"Latest {self.label} ({remote_version}) already present.") + return self.target_dir / latest_local - Returns: - Path: Path to the atlases folder for the requested record. - """ - record_folder = ATLASES_FOLDER / f"atlases_v{zenodo_metadata['version']}" - # ensure folder exists - record_folder.mkdir(parents=True, exist_ok=True) - - logger.info(f"Downloading atlases from Zenodo. This might take a while...") - # Make a GET request to the URL - response = requests.get(archive_url, stream=True) - # Ensure the request was successful - if response.status_code != 200: - raise RuntimeError( - f"Failed to download atlases from {archive_url}. Status code: {response.status_code}" + logger.info( + f"New version of {self.label} available on Zenodo ({remote_version}). Replacing local copy..." ) + shutil.rmtree( + self.target_dir / latest_local, + onerror=lambda func, path, excinfo: logger.warning( + f"Failed to delete {path}: {excinfo}" + ), + ) + return self._download(metadata, archive_url) + + def _glob_pattern(self) -> str: + return f"{self.record_id}_v*.*.*" + + def _build_folder_path( + self, + version: str, + ) -> Path: + return self.target_dir / f"{self.record_id}_v{version}" + + def _get_latest_version_folder_name( + self, + folders: List[Path], + ) -> str | None: + if not folders: + return None + latest = sorted( + folders, + reverse=True, + key=lambda x: tuple(map(int, str(x.name).split("_v")[1].split("."))), + )[0] + if not list(latest.glob("*")): + return None + return latest.name + + def _get_metadata_and_archive_url(self) -> Tuple[Dict, str] | None: + try: + response = requests.get(f"{self.BASE_URL}/{self.record_id}") + if response.status_code != 200: + error_msg = ( + f"Cannot find record '{self.record_id}' on Zenodo " + f"({response.status_code=})." + ) + logger.error(error_msg) + raise ZenodoException(error_msg) + + data = response.json() + return data["metadata"], data["links"]["archive"] + + except requests.exceptions.RequestException as e: + logger.warning(f"Failed to fetch metadata from Zenodo: {e}") + return None - _extract_archive(response=response, record_folder=record_folder) - - logger.info(f"Zip file extracted successfully to {record_folder}") - return record_folder - - -def _extract_archive(response: requests.Response, record_folder: Path): - # Download with progress bar - chunk_size = 1024 # 1KB - bytes_io = BytesIO() - - with Progress( - SpinnerColumn(), - TextColumn("[cyan]Downloading atlases..."), - TextColumn("[cyan]{task.completed:.2f} MB"), - transient=True, - ) as progress: - task = progress.add_task("", total=None) # Indeterminate progress - - for data in response.iter_content(chunk_size=chunk_size): - bytes_io.write(data) - progress.update( - task, advance=len(data) / (chunk_size**2) - ) # Convert bytes to MB - - # Extract the downloaded zip file to the target folder - with zipfile.ZipFile(bytes_io) as zip_ref: - zip_ref.extractall(record_folder) - - # check if the extracted file is still a zip - for f in record_folder.iterdir(): - if f.is_file() and f.suffix == ".zip": - with zipfile.ZipFile(f) as zip_ref: - files = zip_ref.namelist() - with Progress(transient=True) as progress: - task = progress.add_task( - "[cyan]Extracting files...", total=len(files) - ) - # Iterate over the files and extract them - for i, file in enumerate(files): - zip_ref.extract(file, record_folder) - # Update the progress bar - progress.update(task, completed=i + 1) - f.unlink() # remove zip after extraction + def _download( + self, + metadata: Dict, + archive_url: str, + ) -> Path: + folder = self._build_folder_path(metadata["version"]) + folder.mkdir(parents=True, exist_ok=True) + + logger.info(f"Downloading {self.label} from Zenodo. This may take a while...") + + response = requests.get(archive_url, stream=True) + if response.status_code != 200: + msg = ( + f"Failed to download {self.label}. Status code: {response.status_code}" + ) + logger.error(msg) + raise ZenodoException(msg) + + self._extract_archive(response, folder) + logger.info(f"{self.label.title()} extracted to {folder}") + return folder + + def _extract_archive( + self, + response: requests.Response, + folder: Path, + ): + chunk_size = 1024 # 1KB + buffer = BytesIO() + + with Progress( + SpinnerColumn(), + TextColumn(f"[cyan]Downloading {self.label}..."), + TextColumn("[cyan]{task.completed:.2f} MB"), + transient=True, + ) as progress: + task = progress.add_task("", total=None) + for chunk in response.iter_content(chunk_size=chunk_size): + buffer.write(chunk) + progress.update(task, advance=len(chunk) / (chunk_size**2)) + + with zipfile.ZipFile(buffer) as z: + z.extractall(folder) + + for file in folder.iterdir(): + if file.is_file() and file.suffix == ".zip": + with zipfile.ZipFile(file) as inner_zip: + files = inner_zip.namelist() + with Progress(transient=True) as progress: + task = progress.add_task( + f"[cyan]Extracting inner zip...", total=len(files) + ) + for i, f in enumerate(files): + inner_zip.extract(f, folder) + progress.update(task, completed=i + 1) + file.unlink() # Remove inner zip after extraction diff --git a/docs/source/brain-extraction.rst b/docs/source/brain-extraction.rst index cb11710..4d38086 100644 --- a/docs/source/brain-extraction.rst +++ b/docs/source/brain-extraction.rst @@ -8,3 +8,9 @@ brain_extraction.brain_extractor .. automodule:: brainles_preprocessing.brain_extraction.brain_extractor +brain_extraction.synthstrip +-------------------------------------------- + +.. automodule:: brainles_preprocessing.brain_extraction.synthstrip + + diff --git a/pyproject.toml b/pyproject.toml index f520a3f..e7aa9db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,6 @@ classifiers = [ python = "^3.10" # core ttictoc = "^0.5.6" -pathlib = "^1.0.1" nibabel = ">=3.2.1" numpy = "^1.23.0" typer = "^0.15.0" @@ -68,17 +67,23 @@ brainles_hd_bet = ">=0.0.10, <1.0.0" tqdm = "^4.64.1" auxiliary = ">=0.3.1" rich = "^13.6.0" +loguru = "^0.7.3" + # optional registration backends itk-elastix = { version = "^0.20.0", optional = true } picsl_greedy = { version = "^0.0.6", optional = true } -loguru = "^0.7.3" + +# optional BET +nipreps-synthstrip = { version = "^0.0.1", optional = true } + [tool.poetry.extras] -all = ["itk-elastix", "picsl_greedy"] +all = ["itk-elastix", "picsl_greedy", "nipreps-synthstrip"] itk-elastix = ["itk-elastix"] picsl_greedy = ["picsl_greedy"] +synthstrip = ["nipreps-synthstrip"] [tool.poetry.group.dev.dependencies] diff --git a/tests/test_zenodo.py b/tests/test_zenodo.py index 15b7a90..7815b22 100644 --- a/tests/test_zenodo.py +++ b/tests/test_zenodo.py @@ -1,162 +1,182 @@ -import pytest -import shutil -import zipfile -from unittest.mock import MagicMock, patch from pathlib import Path -from io import BytesIO +from unittest.mock import MagicMock, patch -from requests import RequestException +import pytest +import requests from brainles_preprocessing.utils.zenodo import ( - ATLASES_FOLDER, - verify_or_download_atlases, - _get_latest_version_folder_name, - _get_zenodo_metadata_and_archive_url, - _download_atlases, - _extract_archive, + ZenodoException, + ZenodoRecord, + fetch_atlases, + fetch_synthstrip, ) +# ---- Fixtures ---- + @pytest.fixture -def mock_zenodo_metadata(): - return {"version": "1.0.0"}, "https://fakeurl.com/archive.zip" +def dummy_metadata(): + return { + "version": "1.2.3", + "title": "Test Record", + } @pytest.fixture -def mock_atlases_folder(tmp_path): - atlases_path = tmp_path / "atlases" - atlases_path.mkdir() - return atlases_path - - -@patch("brainles_preprocessing.utils.zenodo.requests.get") -def test_get_zenodo_metadata_and_archive_url(mock_get, mock_zenodo_metadata): - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "metadata": mock_zenodo_metadata[0], - "links": {"archive": mock_zenodo_metadata[1]}, - } - mock_get.return_value = mock_response +def dummy_archive_url(): + return "https://zenodo.org/record/dummy/archive.zip" - metadata, archive_url = _get_zenodo_metadata_and_archive_url() - assert metadata["version"] == "1.0.0" - assert archive_url == "https://fakeurl.com/archive.zip" +@pytest.fixture +def dummy_zenodo_response(dummy_metadata, dummy_archive_url): + return dummy_metadata, dummy_archive_url -@patch("brainles_preprocessing.utils.zenodo.requests.get") -def test_get_zenodo_metadata_and_archive_url_failure(mock_get): - mock_get.side_effect = RequestException() - assert _get_zenodo_metadata_and_archive_url() == None +# ---- Tests for _get_metadata_and_archive_url ---- -@patch( - "brainles_preprocessing.utils.zenodo._get_latest_version_folder_name", - return_value=None, -) -@patch( - "brainles_preprocessing.utils.zenodo._get_zenodo_metadata_and_archive_url", - return_value=None, -) -@patch("brainles_preprocessing.utils.zenodo.logger.error") -def test_verify_or_download_atlases_no_local_no_meta( - mock_sys_exit, mock_get_meta, mock_get_latest_version + +def test_get_metadata_and_archive_url_success( + monkeypatch, dummy_metadata, dummy_archive_url ): - with pytest.raises(SystemExit): - verify_or_download_atlases() - mock_sys_exit.assert_called_once_with( - "Atlases not found locally and Zenodo could not be reached. Exiting..." - ) + response_mock = MagicMock() + response_mock.status_code = 200 + response_mock.json.return_value = { + "metadata": dummy_metadata, + "links": {"archive": dummy_archive_url}, + } + monkeypatch.setattr("requests.get", lambda *args, **kwargs: response_mock) + record = ZenodoRecord("123", Path("/tmp"), "test") -@patch( - "brainles_preprocessing.utils.zenodo._get_latest_version_folder_name", - return_value=None, -) -@patch("brainles_preprocessing.utils.zenodo._get_zenodo_metadata_and_archive_url") -@patch("brainles_preprocessing.utils.zenodo._download_atlases") -def test_verify_or_download_atlases_no_local( - mock_download, mock_zenodo_meta, mock_atlases_folder -): - mock_zenodo_meta.return_value = ({"version": "1.0.0"}, "https://fakeurl.com") - mock_download.return_value = mock_atlases_folder / "atlases_v1.0.0" + metadata, url = record._get_metadata_and_archive_url() - atlases_path = verify_or_download_atlases() - assert atlases_path == mock_atlases_folder / "atlases_v1.0.0" + assert metadata == dummy_metadata + assert url == dummy_archive_url + + +def test_get_metadata_and_archive_url_failure(monkeypatch): + response_mock = MagicMock() + response_mock.status_code = 404 + + monkeypatch.setattr("requests.get", lambda *args, **kwargs: response_mock) + record = ZenodoRecord("invalid", Path("/tmp"), "test") + + with pytest.raises(ZenodoException): + record._get_metadata_and_archive_url() @patch( - "brainles_preprocessing.utils.zenodo._get_latest_version_folder_name", - return_value="atlases_v1.0.0", + "requests.get", side_effect=requests.exceptions.RequestException("Connection error") ) -@patch("brainles_preprocessing.utils.zenodo.logger.info") -@patch("brainles_preprocessing.utils.zenodo._get_zenodo_metadata_and_archive_url") -def test_verify_or_download_atlases_latest_local( - mock_zenodo_meta, mock_logger_info, mock_atlases_folder +def test_get_metadata_and_archive_url_connection_error(mock_get): + record = ZenodoRecord("123", Path("/tmp"), "test") + + assert record._get_metadata_and_archive_url() is None + + +# ---- Tests for _get_latest_version_folder_name ---- + + +def test_get_latest_version_folder_name(tmp_path): + folder = tmp_path / "123_v1.2.3" + folder.mkdir() + (folder / "dummy.txt").touch() + + record = ZenodoRecord("123", tmp_path, "test") + result = record._get_latest_version_folder_name(list(tmp_path.glob("*"))) + + assert result == "123_v1.2.3" + + +def test_get_latest_version_folder_name_empty(tmp_path): + record = ZenodoRecord("123", tmp_path, "test") + assert record._get_latest_version_folder_name([]) is None + + +def test_get_latest_version_folder_name_ignores_empty_folder(tmp_path): + folder = tmp_path / "123_v1.2.3" + folder.mkdir() + record = ZenodoRecord("123", tmp_path, "test") + assert record._get_latest_version_folder_name([folder]) is None + + +# ---- Tests for fetch() method ---- + + +@patch.object(ZenodoRecord, "_get_metadata_and_archive_url") +@patch.object(ZenodoRecord, "_download") +def test_fetch_downloads_new_if_no_local( + mock_download, mock_metadata, tmp_path, dummy_zenodo_response ): - mock_zenodo_meta.return_value = ({"version": "1.0.0"}, "https://fakeurl.com") + mock_metadata.return_value = dummy_zenodo_response + mock_download.return_value = tmp_path / "123_v1.2.3" - atlases_path = verify_or_download_atlases() - assert atlases_path == ATLASES_FOLDER / "atlases_v1.0.0" - mock_logger_info.assert_called_with(f"Latest atlases (1.0.0) are already present.") + record = ZenodoRecord("123", tmp_path, "test") + result = record.fetch() + assert result.name == "123_v1.2.3" + mock_download.assert_called_once() -@patch("brainles_preprocessing.utils.zenodo.shutil.rmtree") -@patch( - "brainles_preprocessing.utils.zenodo._get_latest_version_folder_name", - return_value="atlases_v1.0.0", -) -@patch("brainles_preprocessing.utils.zenodo.logger.info", return_value=None) -@patch("brainles_preprocessing.utils.zenodo._get_zenodo_metadata_and_archive_url") -@patch("brainles_preprocessing.utils.zenodo._download_atlases") -def test_verify_or_download_atlases_old_local( - mock_download, - mock_zenodo_meta, - mock_logger_info, - mock_shutil_rmtree, - mock_atlases_folder, + +@patch.object(ZenodoRecord, "_get_metadata_and_archive_url", return_value=None) +def test_fetch_zenodo_unreachable_raises(mock_metadata, tmp_path): + record = ZenodoRecord("123", tmp_path, "test") + + with pytest.raises(ZenodoException): + record.fetch() + + +@patch.object(ZenodoRecord, "_get_metadata_and_archive_url") +@patch.object(ZenodoRecord, "_download") +def test_fetch_skips_if_latest_present( + mock_download, mock_metadata, tmp_path, dummy_zenodo_response ): - mock_zenodo_meta.return_value = ({"version": "2.0.0"}, "https://fakeurl.com") - mock_download.return_value = mock_atlases_folder / "atlases_v2.0.0" + local_folder = tmp_path / "123_v1.2.3" + local_folder.mkdir() + (local_folder / "dummy.txt").touch() - atlases_path = verify_or_download_atlases() + mock_metadata.return_value = dummy_zenodo_response + record = ZenodoRecord("123", tmp_path, "test") - mock_logger_info.assert_called_with( - "New atlases available on Zenodo (2.0.0). Deleting old and fetching new atlases..." - ) + result = record.fetch() + assert result.name == "123_v1.2.3" + mock_download.assert_not_called() - mock_shutil_rmtree.assert_called_once() - assert atlases_path == mock_atlases_folder / "atlases_v2.0.0" +@patch.object(ZenodoRecord, "_get_metadata_and_archive_url") +@patch.object(ZenodoRecord, "_download") +def test_fetch_replaces_old_version( + mock_download, mock_metadata, tmp_path, dummy_zenodo_response +): + old_folder = tmp_path / "123_v1.0.0" + old_folder.mkdir() + (old_folder / "dummy.txt").touch() + mock_metadata.return_value = dummy_zenodo_response + mock_download.return_value = tmp_path / "123_v1.2.3" -@patch("brainles_preprocessing.utils.zenodo._extract_archive") -@patch("brainles_preprocessing.utils.zenodo.requests.get") -def test_download_atlases(mock_get, mock_extract_archive, mock_zenodo_metadata): - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.iter_content = lambda chunk_size: [b"data"] - mock_get.return_value = mock_response + record = ZenodoRecord("123", tmp_path, "test") - atlases_path = _download_atlases(mock_zenodo_metadata[0], mock_zenodo_metadata[1]) - assert atlases_path.exists() + result = record.fetch() + assert result.name == "123_v1.2.3" + assert not old_folder.exists() + mock_download.assert_called_once() -@patch("brainles_preprocessing.utils.zenodo.zipfile.ZipFile") -def test_extract_archive(mock_zipfile, tmp_path): - mock_response = MagicMock() - mock_response.iter_content = lambda chunk_size: [b"data"] - record_folder = tmp_path / "atlases_v1.0.0" - record_folder.mkdir() +# ---- fetch_atlases and fetch_synthstrip ---- - dummy_zip = record_folder / "archive.zip" - dummy_zip.touch() - mock_zip = MagicMock() - mock_zip.namelist.return_value = ["file1.txt", "file2.txt"] - mock_zip.__enter__.return_value = mock_zip - mock_zipfile.return_value = mock_zip +@patch("brainles_preprocessing.utils.zenodo.ZenodoRecord.fetch") +def test_fetch_atlases_calls_fetch(mock_fetch): + mock_fetch.return_value = Path("/fake/path") + result = fetch_atlases() + assert result == Path("/fake/path") + mock_fetch.assert_called_once() - _extract_archive(mock_response, record_folder) - mock_zip.extract.assert_called() +@patch("brainles_preprocessing.utils.zenodo.ZenodoRecord.fetch") +def test_fetch_synthstrip_calls_fetch(mock_fetch): + mock_fetch.return_value = Path("/fake/path") + result = fetch_synthstrip() + assert result == Path("/fake/path") + mock_fetch.assert_called_once()