Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,90 @@ def __init__(
self.set_random_state(seed=get_seed())
self.overrides = overrides

# Automatically assign group ID to child transforms for inversion tracking
self._set_transform_groups()

def _set_transform_groups(self):
"""
Automatically set group IDs on child transforms for inversion tracking.

This allows Invertd to identify which transforms belong to this
``Compose`` instance, including wrapped transforms (for example,
array transforms inside dictionary transforms).

Args:
None.

Returns:
None.
"""
from monai.transforms.inverse import TraceableTransform

group_id = str(id(self))
visited = set() # Track visited objects to avoid infinite recursion

def set_group_recursive(obj, gid, allow_compose: bool = False):
"""
Recursively set a group ID on a transform and its wrapped transforms.

Args:
obj: Transform instance to process.
gid: Group identifier to assign.
allow_compose: Whether to set group on ``Compose`` instances.
``Compose`` internals are not traversed to preserve nested
pipeline boundaries.

Returns:
None.
"""
if obj is None or isinstance(obj, (bool, int, float, str, bytes)):
return

# Avoid infinite recursion
obj_id = id(obj)
if obj_id in visited:
return
visited.add(obj_id)

if isinstance(obj, Compose):
if allow_compose:
obj._group = gid
return

if isinstance(obj, TraceableTransform):
obj._group = gid

if isinstance(obj, Mapping):
for attr in obj.values():
set_group_recursive(attr, gid)
return

if isinstance(obj, (list, tuple, set)):
for attr in obj:
set_group_recursive(attr, gid)
return

attrs: list[Any] = []
if hasattr(obj, "__dict__"):
attrs.extend(vars(obj).values())

slots = getattr(type(obj), "__slots__", ())
if isinstance(slots, str):
slots = (slots,)
for slot in slots:
if slot.startswith("__"):
continue
try:
attrs.append(getattr(obj, slot))
except AttributeError:
continue

for attr in attrs:
set_group_recursive(attr, gid)

for transform in self.transforms:
set_group_recursive(transform, group_id, allow_compose=True)

@LazyTransform.lazy.setter # type: ignore
def lazy(self, val: bool):
self._lazy = val
Expand Down
15 changes: 14 additions & 1 deletion monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def _init_trace_threadlocal(self):
if not hasattr(self._tracing, "value"):
self._tracing.value = MONAIEnvVars.trace_transform() != "0"

# Initialize group identifier (set by Compose for automatic group tracking)
if not hasattr(self, "_group"):
self._group: str | None = None

def __getstate__(self):
"""When pickling, remove the `_tracing` member from the output, if present, since it's not picklable."""
_dict = dict(getattr(self, "__dict__", {})) # this makes __dict__ always present in the unpickled object
Expand Down Expand Up @@ -119,13 +123,22 @@ def get_transform_info(self) -> dict:
"""
Return a dictionary with the relevant information pertaining to an applied transform.
"""
# Ensure _group is initialized
self._init_trace_threadlocal()

vals = (
self.__class__.__name__,
id(self),
self.tracing,
self._do_transform if hasattr(self, "_do_transform") else True,
)
return dict(zip(self.transform_info_keys(), vals))
info = dict(zip(self.transform_info_keys(), vals))

# Add group if set (automatically set by Compose)
if self._group is not None:
info[TraceKeys.GROUP] = self._group

return info

def push_transform(self, data, *args, **kwargs):
"""
Expand Down
28 changes: 27 additions & 1 deletion monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,29 @@ def __init__(
self.post_func = ensure_tuple_rep(post_func, len(self.keys))
self._totensor = ToTensor()

def _filter_transforms_by_group(self, all_transforms: list[dict]) -> list[dict]:
"""
Filter applied_operations to only include transforms from the target Compose instance.
Uses automatic group tracking where Compose assigns its ID to child transforms.
"""
from monai.utils import TraceKeys

# Get the group ID of the transform (Compose instance)
target_group = str(id(self.transform))

# Filter transforms that match the target group
filtered = []
for xform in all_transforms:
xform_group = xform.get(TraceKeys.GROUP)
if xform_group == target_group:
filtered.append(xform)

# If no transforms match (backward compatibility), return all transforms
if not filtered:
return all_transforms

return filtered

def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
d = dict(data)
for (
Expand Down Expand Up @@ -894,8 +917,11 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:

orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}"
if orig_key in d and isinstance(d[orig_key], MetaTensor):
transform_info = d[orig_key].applied_operations
all_transforms = d[orig_key].applied_operations
meta_info = d[orig_key].meta

# Automatically filter by Compose instance group ID
transform_info = self._filter_transforms_by_group(all_transforms)
else:
transform_info = d[InvertibleTransform.trace_key(orig_key)]
meta_info = d.get(orig_meta_key, {})
Expand Down
1 change: 1 addition & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ class TraceKeys(StrEnum):
TRACING: str = "tracing"
STATUSES: str = "statuses"
LAZY: str = "lazy"
GROUP: str = "group"


class TraceStatusKeys(StrEnum):
Expand Down
37 changes: 37 additions & 0 deletions tests/transforms/compose/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,43 @@ def test_data_loader_2(self):
self.assertAlmostEqual(out_1.cpu().item(), 0.28602141572)
set_determinism(None)

def test_set_transform_groups_on_wrapped_transform_attributes(self):
class _IdentityInvertible(mt.InvertibleTransform):
def __call__(self, data):
return data

def inverse(self, data):
return data

class _WrapperWithTransform:
def __init__(self):
self.transform = _IdentityInvertible()

def __call__(self, data):
return self.transform(data)

class _WrapperWithTransforms:
def __init__(self):
self.transforms = [_IdentityInvertible(), {"inner": _IdentityInvertible()}]

def __call__(self, data):
for transform in self.transforms:
if isinstance(transform, dict):
for nested_transform in transform.values():
data = nested_transform(data)
else:
data = transform(data)
return data

wrapped_transform = _WrapperWithTransform()
wrapped_transforms = _WrapperWithTransforms()
composed = mt.Compose([wrapped_transform, wrapped_transforms])
expected_group = str(id(composed))

self.assertEqual(getattr(wrapped_transform.transform, "_group", None), expected_group)
self.assertEqual(getattr(wrapped_transforms.transforms[0], "_group", None), expected_group)
self.assertEqual(getattr(wrapped_transforms.transforms[1]["inner"], "_group", None), expected_group)

def test_flatten_and_len(self):
x = mt.EnsureChannelFirst(channel_dim="no_channel")
t1 = mt.Compose([x, x, x, x, mt.Compose([mt.Compose([x, x]), x, x])])
Expand Down
Loading
Loading