Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion codeflash/verification/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import math
import re
import types
import weakref
from collections import ChainMap, OrderedDict, deque
from importlib.util import find_spec
from typing import Any, Optional
Expand Down Expand Up @@ -93,7 +94,7 @@ def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]: # no
return _extract_exception_from_message(str(exc))


def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
try:
# Handle exceptions specially - before type check to allow wrapper comparison
Expand Down Expand Up @@ -171,6 +172,17 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
return True
return math.isclose(orig, new)

# Handle weak references (e.g., found in torch.nn.LSTM/GRU modules)
if isinstance(orig, weakref.ref):
orig_referent = orig()
new_referent = new()
# Both dead refs are equal, otherwise compare referents
if orig_referent is None and new_referent is None:
return True
if orig_referent is None or new_referent is None:
return False
return comparator(orig_referent, new_referent, superset_obj)

if HAS_JAX:
import jax # type: ignore # noqa: PGH003
import jax.numpy as jnp # type: ignore # noqa: PGH003
Expand Down
Loading
Loading