diff --git a/monai/apps/reconstruction/transforms/dictionary.py b/monai/apps/reconstruction/transforms/dictionary.py index c166740768..0185974040 100644 --- a/monai/apps/reconstruction/transforms/dictionary.py +++ b/monai/apps/reconstruction/transforms/dictionary.py @@ -37,16 +37,19 @@ class ExtractDataKeyFromMetaKeyd(MapTransform): keys: keys to be transferred from meta to data meta_key: the meta key where all the meta-data is stored allow_missing_keys: don't raise exception if key is missing + image_only: if True, only extract metadata from MetaTensor images to avoid duplication Example: When the fastMRI dataset is loaded, "kspace" is stored in the data dictionary, but the ground-truth image with the key "reconstruction_rss" is stored in the meta data. In this case, ExtractDataKeyFromMetaKeyd moves "reconstruction_rss" to data. + For MetaTensor objects, setting image_only=True prevents extracting redundant metadata. """ - def __init__(self, keys: KeysCollection, meta_key: str, allow_missing_keys: bool = False) -> None: + def __init__(self, keys: KeysCollection, meta_key: str, allow_missing_keys: bool = False, image_only: bool = False) -> None: MapTransform.__init__(self, keys, allow_missing_keys) self.meta_key = meta_key + self.image_only = image_only def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, Tensor]: """ @@ -60,7 +63,12 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, T d = dict(data) for key in self.keys: if key in d[self.meta_key]: - d[key] = d[self.meta_key][key] # type: ignore + extracted_value = d[self.meta_key][key] + # When image_only is True, skip if the extracted value is a MetaTensor + # to preserve metadata associations + if self.image_only and isinstance(extracted_value, MetaTensor): + continue + d[key] = extracted_value # type: ignore elif not self.allow_missing_keys: raise KeyError( f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the meta data" diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index 20e2d74c8c..0a69ccfaba 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -50,6 +50,11 @@ class TensorBoardHandler: def __init__(self, summary_writer: SummaryWriter | SummaryWriterX | None = None, log_dir: str = "./runs"): if summary_writer is None: + if SummaryWriter is None: + raise RuntimeError( + "TensorBoardHandler requires tensorboard to be installed. " + "Please install it with: pip install tensorboard" + ) self._writer = SummaryWriter(log_dir=log_dir) self.internal_writer = True else: diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index 653db43bc5..572d01ea1d 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -96,12 +96,7 @@ def pad_nd( return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs) try: _pad = _np_pad - if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"} and img.dtype not in { - torch.int16, - torch.int64, - torch.bool, - torch.uint8, - }: + if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"}: _pad = _pt_pad return _pad(img, pad_width=to_pad, mode=mode, **kwargs) except (ValueError, TypeError, RuntimeError) as err: