From 961cca544a1894bef43a92c4ce7fa76890323fb6 Mon Sep 17 00:00:00 2001
From: aseembits93
Date: Fri, 6 Feb 2026 13:25:30 -0800
Subject: [PATCH 1/3] fix for weakref

---
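A minimal sketch of the intended semantics (hedged; `Holder` is an
illustrative helper, not part of the patch — weakref.ref cannot target
every built-in type directly):

    import weakref

    from codeflash.verification.comparator import comparator

    class Holder:
        def __init__(self, value):
            self.value = value

    a, b = Holder(1), Holder(1)
    ra, rb = weakref.ref(a), weakref.ref(b)
    assert comparator(ra, rb)      # live refs compare by referent
    del a                          # ra now dangles (CPython collects immediately)
    assert not comparator(ra, rb)  # dead vs. live -> False
    del b
    assert comparator(ra, rb)      # both dead -> True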
 codeflash/verification/comparator.py |  12 +
 tests/test_comparator.py             | 951 +++++++++++++++++++++++++++
 2 files changed, 963 insertions(+)

diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py
index ad7c59ede..eb86c790a 100644
--- a/codeflash/verification/comparator.py
+++ b/codeflash/verification/comparator.py
@@ -6,6 +6,7 @@
 import math
 import re
 import types
+import weakref
 from collections import ChainMap, OrderedDict, deque
 from importlib.util import find_spec
 from typing import Any, Optional
@@ -171,6 +172,17 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
             return True
         return math.isclose(orig, new)
 
+    # Handle weak references (e.g., found in torch.nn.LSTM/GRU modules)
+    if isinstance(orig, weakref.ref):
+        orig_referent = orig()
+        new_referent = new()
+        # Both dead refs are equal, otherwise compare referents
+        if orig_referent is None and new_referent is None:
+            return True
+        if orig_referent is None or new_referent is None:
+            return False
+        return comparator(orig_referent, new_referent, superset_obj)
+
     if HAS_JAX:
         import jax  # type: ignore  # noqa: PGH003
         import jax.numpy as jnp  # type: ignore  # noqa: PGH003
diff --git a/tests/test_comparator.py b/tests/test_comparator.py
index 753929843..30bbe8700 100644
--- a/tests/test_comparator.py
+++ b/tests/test_comparator.py
@@ -877,6 +877,957 @@ def test_torch_device():
     assert not comparator(l, n)
 
 
+def test_torch_nn_linear():
+    """Test comparator for torch.nn.Linear modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test identical Linear layers
+    torch.manual_seed(42)
+    a = nn.Linear(10, 5)
+    torch.manual_seed(42)
+    b = nn.Linear(10, 5)
+    assert comparator(a, b)
+
+    # Test Linear layers with different weights (different seeds)
+    torch.manual_seed(42)
+    c = nn.Linear(10, 5)
+    torch.manual_seed(123)
+    d = nn.Linear(10, 5)
+    assert not comparator(c, d)
+
+    # Test Linear layers with different in_features
+    torch.manual_seed(42)
+    e = nn.Linear(10, 5)
+    torch.manual_seed(42)
+    f = nn.Linear(20, 5)
+    assert not comparator(e, f)
+
+    # Test Linear layers with different out_features
+    torch.manual_seed(42)
+    g = nn.Linear(10, 5)
+    torch.manual_seed(42)
+    h = nn.Linear(10, 10)
+    assert not comparator(g, h)
+
+    # Test Linear with and without bias
+    torch.manual_seed(42)
+    i = nn.Linear(10, 5, bias=True)
+    torch.manual_seed(42)
+    j = nn.Linear(10, 5, bias=False)
+    assert not comparator(i, j)
+
+    # Test Linear layers in train vs eval mode
+    torch.manual_seed(42)
+    k = nn.Linear(10, 5)
+    k.train()
+    torch.manual_seed(42)
+    l = nn.Linear(10, 5)
+    l.eval()
+    assert not comparator(k, l)
+
+
+def test_torch_nn_conv2d():
+    """Test comparator for torch.nn.Conv2d modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test identical Conv2d layers
+    torch.manual_seed(42)
+    a = nn.Conv2d(3, 16, kernel_size=3)
+    torch.manual_seed(42)
+    b = nn.Conv2d(3, 16, kernel_size=3)
+    assert comparator(a, b)
+
+    # Test Conv2d with different weights
+    torch.manual_seed(42)
+    c = nn.Conv2d(3, 16, kernel_size=3)
+    torch.manual_seed(123)
+    d = nn.Conv2d(3, 16, kernel_size=3)
+    assert not comparator(c, d)
+
+    # Test Conv2d with different in_channels
+    torch.manual_seed(42)
+    e = nn.Conv2d(3, 16, kernel_size=3)
+    torch.manual_seed(42)
+    f = nn.Conv2d(1, 16, kernel_size=3)
+    assert not comparator(e, f)
+
+    # Test Conv2d with different out_channels
+    torch.manual_seed(42)
+    g = nn.Conv2d(3, 16, kernel_size=3)
+    torch.manual_seed(42)
+    h = nn.Conv2d(3, 32, kernel_size=3)
+    assert not comparator(g, h)
+
+    # Test Conv2d with different kernel_size
+    torch.manual_seed(42)
+    i = nn.Conv2d(3, 16, kernel_size=3)
+    torch.manual_seed(42)
+    j = nn.Conv2d(3, 16, kernel_size=5)
+    assert not comparator(i, j)
+
+    # Test Conv2d with different stride
+    torch.manual_seed(42)
+    k = nn.Conv2d(3, 16, kernel_size=3, stride=1)
+    torch.manual_seed(42)
+    l = nn.Conv2d(3, 16, kernel_size=3, stride=2)
+    assert not comparator(k, l)
+
+    # Test Conv2d with different padding
+    torch.manual_seed(42)
+    m = nn.Conv2d(3, 16, kernel_size=3, padding=0)
+    torch.manual_seed(42)
+    n = nn.Conv2d(3, 16, kernel_size=3, padding=1)
+    assert not comparator(m, n)
+
+
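+# BatchNorm keeps running_mean/running_var as buffers rather than parameters,
+# so the running-stats assertions below also check that the comparator sees
+# module state beyond the learnable weights.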
+def test_torch_nn_batchnorm():
+    """Test comparator for torch.nn.BatchNorm modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test identical BatchNorm2d layers
+    torch.manual_seed(42)
+    a = nn.BatchNorm2d(16)
+    torch.manual_seed(42)
+    b = nn.BatchNorm2d(16)
+    assert comparator(a, b)
+
+    # Test BatchNorm2d with different num_features
+    torch.manual_seed(42)
+    c = nn.BatchNorm2d(16)
+    torch.manual_seed(42)
+    d = nn.BatchNorm2d(32)
+    assert not comparator(c, d)
+
+    # Test BatchNorm2d with different eps
+    torch.manual_seed(42)
+    e = nn.BatchNorm2d(16, eps=1e-5)
+    torch.manual_seed(42)
+    f = nn.BatchNorm2d(16, eps=1e-3)
+    assert not comparator(e, f)
+
+    # Test BatchNorm2d with different momentum
+    torch.manual_seed(42)
+    g = nn.BatchNorm2d(16, momentum=0.1)
+    torch.manual_seed(42)
+    h = nn.BatchNorm2d(16, momentum=0.01)
+    assert not comparator(g, h)
+
+    # Test BatchNorm2d with and without affine
+    torch.manual_seed(42)
+    i = nn.BatchNorm2d(16, affine=True)
+    torch.manual_seed(42)
+    j = nn.BatchNorm2d(16, affine=False)
+    assert not comparator(i, j)
+
+    # Test BatchNorm2d running stats after forward passes
+    torch.manual_seed(42)
+    k = nn.BatchNorm2d(16)
+    k.train()
+    input_k = torch.randn(4, 16, 8, 8)
+    _ = k(input_k)
+    torch.manual_seed(42)
+    l = nn.BatchNorm2d(16)
+    l.train()
+    input_l = torch.randn(4, 16, 8, 8)
+    _ = l(input_l)
+    # Same seed means same running stats
+    assert comparator(k, l)
+
+    # Test BatchNorm2d with different running stats
+    torch.manual_seed(42)
+    m = nn.BatchNorm2d(16)
+    m.train()
+    torch.manual_seed(42)
+    _ = m(torch.randn(4, 16, 8, 8))
+    torch.manual_seed(42)
+    n = nn.BatchNorm2d(16)
+    n.train()
+    torch.manual_seed(123)
+    _ = n(torch.randn(4, 16, 8, 8))
+    assert not comparator(m, n)
+
+    # Test BatchNorm1d
+    torch.manual_seed(42)
+    o = nn.BatchNorm1d(16)
+    torch.manual_seed(42)
+    p = nn.BatchNorm1d(16)
+    assert comparator(o, p)
+
+
+def test_torch_nn_dropout():
+    """Test comparator for torch.nn.Dropout modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test identical Dropout layers
+    a = nn.Dropout(p=0.5)
+    b = nn.Dropout(p=0.5)
+    assert comparator(a, b)
+
+    # Test Dropout with different p values
+    c = nn.Dropout(p=0.5)
+    d = nn.Dropout(p=0.3)
+    assert not comparator(c, d)
+
+    # Test Dropout with different inplace values
+    e = nn.Dropout(p=0.5, inplace=False)
+    f = nn.Dropout(p=0.5, inplace=True)
+    assert not comparator(e, f)
+
+    # Test Dropout2d
+    g = nn.Dropout2d(p=0.5)
+    h = nn.Dropout2d(p=0.5)
+    assert comparator(g, h)
+
+    # Test Dropout vs Dropout2d (different types)
+    i = nn.Dropout(p=0.5)
+    j = nn.Dropout2d(p=0.5)
+    assert not comparator(i, j)
+
+
+def test_torch_nn_activation():
+    """Test comparator for torch.nn activation modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test ReLU
+    a = nn.ReLU()
+    b = nn.ReLU()
+    assert comparator(a, b)
+
+    # Test ReLU with different inplace
+    c = nn.ReLU(inplace=False)
+    d = nn.ReLU(inplace=True)
+    assert not comparator(c, d)
+
+    # Test LeakyReLU
+    e = nn.LeakyReLU(negative_slope=0.01)
+    f = nn.LeakyReLU(negative_slope=0.01)
+    assert comparator(e, f)
+
+    # Test LeakyReLU with different negative_slope
+    g = nn.LeakyReLU(negative_slope=0.01)
+    h = nn.LeakyReLU(negative_slope=0.1)
+    assert not comparator(g, h)
+
+    # Test Sigmoid vs ReLU (different types)
+    i = nn.Sigmoid()
+    j = nn.ReLU()
+    assert not comparator(i, j)
+
+    # Test GELU
+    k = nn.GELU()
+    l = nn.GELU()
+    assert comparator(k, l)
+
+    # Test Softmax
+    m = nn.Softmax(dim=1)
+    n = nn.Softmax(dim=1)
+    assert comparator(m, n)
+
+    # Test Softmax with different dim
+    o = nn.Softmax(dim=1)
+    p = nn.Softmax(dim=0)
+    assert not comparator(o, p)
+
+
+def test_torch_nn_pooling():
+    """Test comparator for torch.nn pooling modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test MaxPool2d
+    a = nn.MaxPool2d(kernel_size=2)
+    b = nn.MaxPool2d(kernel_size=2)
+    assert comparator(a, b)
+
+    # Test MaxPool2d with different kernel_size
+    c = nn.MaxPool2d(kernel_size=2)
+    d = nn.MaxPool2d(kernel_size=3)
+    assert not comparator(c, d)
+
+    # Test MaxPool2d with different stride
+    e = nn.MaxPool2d(kernel_size=2, stride=2)
+    f = nn.MaxPool2d(kernel_size=2, stride=1)
+    assert not comparator(e, f)
+
+    # Test AvgPool2d
+    g = nn.AvgPool2d(kernel_size=2)
+    h = nn.AvgPool2d(kernel_size=2)
+    assert comparator(g, h)
+
+    # Test MaxPool2d vs AvgPool2d (different types)
+    i = nn.MaxPool2d(kernel_size=2)
+    j = nn.AvgPool2d(kernel_size=2)
+    assert not comparator(i, j)
+
+    # Test AdaptiveAvgPool2d
+    k = nn.AdaptiveAvgPool2d(output_size=(1, 1))
+    l = nn.AdaptiveAvgPool2d(output_size=(1, 1))
+    assert comparator(k, l)
+
+    # Test AdaptiveAvgPool2d with different output_size
+    m = nn.AdaptiveAvgPool2d(output_size=(1, 1))
+    n = nn.AdaptiveAvgPool2d(output_size=(2, 2))
+    assert not comparator(m, n)
+
+
+def test_torch_nn_embedding():
+    """Test comparator for torch.nn.Embedding modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test identical Embedding layers
+    torch.manual_seed(42)
+    a = nn.Embedding(1000, 128)
+    torch.manual_seed(42)
+    b = nn.Embedding(1000, 128)
+    assert comparator(a, b)
+
+    # Test Embedding with different weights
+    torch.manual_seed(42)
+    c = nn.Embedding(1000, 128)
+    torch.manual_seed(123)
+    d = nn.Embedding(1000, 128)
+    assert not comparator(c, d)
+
+    # Test Embedding with different num_embeddings
+    torch.manual_seed(42)
+    e = nn.Embedding(1000, 128)
+    torch.manual_seed(42)
+    f = nn.Embedding(2000, 128)
+    assert not comparator(e, f)
+
+    # Test Embedding with different embedding_dim
+    torch.manual_seed(42)
+    g = nn.Embedding(1000, 128)
+    torch.manual_seed(42)
+    h = nn.Embedding(1000, 256)
+    assert not comparator(g, h)
+
+    # Test Embedding with different padding_idx
+    torch.manual_seed(42)
+    i = nn.Embedding(1000, 128, padding_idx=0)
+    torch.manual_seed(42)
+    j = nn.Embedding(1000, 128, padding_idx=1)
+    assert not comparator(i, j)
+
+    # Test Embedding with and without padding_idx
+    torch.manual_seed(42)
+    k = nn.Embedding(1000, 128)
+    torch.manual_seed(42)
+    l = nn.Embedding(1000, 128, padding_idx=0)
+    assert not comparator(k, l)
+
+
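+# LSTM/GRU modules are the motivating case for the weakref handling above:
+# recent PyTorch versions keep weak references to their flat weights
+# (e.g. RNNBase._flat_weight_refs), which previously tripped up the comparator.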
+def test_torch_nn_lstm():
+    """Test comparator for torch.nn.LSTM modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test identical LSTM layers
+    torch.manual_seed(42)
+    a = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
+    torch.manual_seed(42)
+    b = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
+    assert comparator(a, b)
+
+    # Test LSTM with different weights
+    torch.manual_seed(42)
+    c = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
+    torch.manual_seed(123)
+    d = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
+    assert not comparator(c, d)
+
+    # Test LSTM with different input_size
+    torch.manual_seed(42)
+    e = nn.LSTM(input_size=10, hidden_size=20)
+    torch.manual_seed(42)
+    f = nn.LSTM(input_size=20, hidden_size=20)
+    assert not comparator(e, f)
+
+    # Test LSTM with different hidden_size
+    torch.manual_seed(42)
+    g = nn.LSTM(input_size=10, hidden_size=20)
+    torch.manual_seed(42)
+    h = nn.LSTM(input_size=10, hidden_size=40)
+    assert not comparator(g, h)
+
+    # Test LSTM with different num_layers
+    torch.manual_seed(42)
+    i = nn.LSTM(input_size=10, hidden_size=20, num_layers=1)
+    torch.manual_seed(42)
+    j = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
+    assert not comparator(i, j)
+
+    # Test LSTM with different bidirectional
+    torch.manual_seed(42)
+    k = nn.LSTM(input_size=10, hidden_size=20, bidirectional=False)
+    torch.manual_seed(42)
+    l = nn.LSTM(input_size=10, hidden_size=20, bidirectional=True)
+    assert not comparator(k, l)
+
+    # Test LSTM with different batch_first
+    torch.manual_seed(42)
+    m = nn.LSTM(input_size=10, hidden_size=20, batch_first=False)
+    torch.manual_seed(42)
+    n = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
+    assert not comparator(m, n)
+
+
+def test_torch_nn_gru():
+    """Test comparator for torch.nn.GRU modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test identical GRU layers
+    torch.manual_seed(42)
+    a = nn.GRU(input_size=10, hidden_size=20, num_layers=2)
+    torch.manual_seed(42)
+    b = nn.GRU(input_size=10, hidden_size=20, num_layers=2)
+    assert comparator(a, b)
+
+    # Test GRU with different hidden_size
+    torch.manual_seed(42)
+    c = nn.GRU(input_size=10, hidden_size=20)
+    torch.manual_seed(42)
+    d = nn.GRU(input_size=10, hidden_size=40)
+    assert not comparator(c, d)
+
+    # Test GRU vs LSTM (different types)
+    torch.manual_seed(42)
+    e = nn.GRU(input_size=10, hidden_size=20)
+    torch.manual_seed(42)
+    f = nn.LSTM(input_size=10, hidden_size=20)
+    assert not comparator(e, f)
+
+
+def test_torch_nn_layernorm():
+    """Test comparator for torch.nn.LayerNorm modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test identical LayerNorm layers
+    torch.manual_seed(42)
+    a = nn.LayerNorm(normalized_shape=[10])
+    torch.manual_seed(42)
+    b = nn.LayerNorm(normalized_shape=[10])
+    assert comparator(a, b)
+
+    # Test LayerNorm with different normalized_shape
+    torch.manual_seed(42)
+    c = nn.LayerNorm(normalized_shape=[10])
+    torch.manual_seed(42)
+    d = nn.LayerNorm(normalized_shape=[20])
+    assert not comparator(c, d)
+
+    # Test LayerNorm with different eps
+    torch.manual_seed(42)
+    e = nn.LayerNorm(normalized_shape=[10], eps=1e-5)
+    torch.manual_seed(42)
+    f = nn.LayerNorm(normalized_shape=[10], eps=1e-3)
+    assert not comparator(e, f)
+
+    # Test LayerNorm with and without elementwise_affine
+    torch.manual_seed(42)
+    g = nn.LayerNorm(normalized_shape=[10], elementwise_affine=True)
+    torch.manual_seed(42)
+    h = nn.LayerNorm(normalized_shape=[10], elementwise_affine=False)
+    assert not comparator(g, h)
+
+
+def test_torch_nn_multihead_attention():
+    """Test comparator for torch.nn.MultiheadAttention modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test identical MultiheadAttention layers
+    torch.manual_seed(42)
+    a = nn.MultiheadAttention(embed_dim=64, num_heads=8)
+    torch.manual_seed(42)
+    b = nn.MultiheadAttention(embed_dim=64, num_heads=8)
+    assert comparator(a, b)
+
+    # Test MultiheadAttention with different weights
+    torch.manual_seed(42)
+    c = nn.MultiheadAttention(embed_dim=64, num_heads=8)
+    torch.manual_seed(123)
+    d = nn.MultiheadAttention(embed_dim=64, num_heads=8)
+    assert not comparator(c, d)
+
+    # Test MultiheadAttention with different embed_dim
+    torch.manual_seed(42)
+    e = nn.MultiheadAttention(embed_dim=64, num_heads=8)
+    torch.manual_seed(42)
+    f = nn.MultiheadAttention(embed_dim=128, num_heads=8)
+    assert not comparator(e, f)
+
+    # Test MultiheadAttention with different num_heads
+    torch.manual_seed(42)
+    g = nn.MultiheadAttention(embed_dim=64, num_heads=8)
+    torch.manual_seed(42)
+    h = nn.MultiheadAttention(embed_dim=64, num_heads=4)
+    assert not comparator(g, h)
+
+    # Test MultiheadAttention with different dropout
+    torch.manual_seed(42)
+    i = nn.MultiheadAttention(embed_dim=64, num_heads=8, dropout=0.0)
+    torch.manual_seed(42)
+    j = nn.MultiheadAttention(embed_dim=64, num_heads=8, dropout=0.1)
+    assert not comparator(i, j)
+
+
+def test_torch_nn_sequential():
+    """Test comparator for torch.nn.Sequential modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test identical Sequential modules
+    torch.manual_seed(42)
+    a = nn.Sequential(
+        nn.Linear(10, 20),
+        nn.ReLU(),
+        nn.Linear(20, 5)
+    )
+    torch.manual_seed(42)
+    b = nn.Sequential(
+        nn.Linear(10, 20),
+        nn.ReLU(),
+        nn.Linear(20, 5)
+    )
+    assert comparator(a, b)
+
+    # Test Sequential with different weights
+    torch.manual_seed(42)
+    c = nn.Sequential(
+        nn.Linear(10, 20),
+        nn.ReLU(),
+        nn.Linear(20, 5)
+    )
+    torch.manual_seed(123)
+    d = nn.Sequential(
+        nn.Linear(10, 20),
+        nn.ReLU(),
+        nn.Linear(20, 5)
+    )
+    assert not comparator(c, d)
+
+    # Test Sequential with different number of layers
+    torch.manual_seed(42)
+    e = nn.Sequential(
+        nn.Linear(10, 20),
+        nn.ReLU()
+    )
+    torch.manual_seed(42)
+    f = nn.Sequential(
+        nn.Linear(10, 20),
+        nn.ReLU(),
+        nn.Linear(20, 5)
+    )
+    assert not comparator(e, f)
+
+    # Test Sequential with different layer types
+    torch.manual_seed(42)
+    g = nn.Sequential(
+        nn.Linear(10, 20),
+        nn.ReLU()
+    )
+    torch.manual_seed(42)
+    h = nn.Sequential(
+        nn.Linear(10, 20),
+        nn.Sigmoid()
+    )
+    assert not comparator(g, h)
+
+
+def test_torch_nn_modulelist():
+    """Test comparator for torch.nn.ModuleList modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test identical ModuleList
+    torch.manual_seed(42)
+    a = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])
+    torch.manual_seed(42)
+    b = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])
+    assert comparator(a, b)
+
+    # Test ModuleList with different number of modules
+    torch.manual_seed(42)
+    c = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])
+    torch.manual_seed(42)
+    d = nn.ModuleList([nn.Linear(10, 10) for _ in range(4)])
+    assert not comparator(c, d)
+
+
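+# ModuleDict stores submodules under string keys, so the comparison below must
+# treat a key rename as a difference even when the wrapped weights match.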
+def test_torch_nn_moduledict():
+    """Test comparator for torch.nn.ModuleDict modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test identical ModuleDict
+    torch.manual_seed(42)
+    a = nn.ModuleDict({
+        "fc1": nn.Linear(10, 20),
+        "fc2": nn.Linear(20, 5)
+    })
+    torch.manual_seed(42)
+    b = nn.ModuleDict({
+        "fc1": nn.Linear(10, 20),
+        "fc2": nn.Linear(20, 5)
+    })
+    assert comparator(a, b)
+
+    # Test ModuleDict with different keys
+    torch.manual_seed(42)
+    c = nn.ModuleDict({
+        "fc1": nn.Linear(10, 20),
+        "fc2": nn.Linear(20, 5)
+    })
+    torch.manual_seed(42)
+    d = nn.ModuleDict({
+        "layer1": nn.Linear(10, 20),
+        "layer2": nn.Linear(20, 5)
+    })
+    assert not comparator(c, d)
+
+
+def test_torch_nn_custom_module():
+    """Test comparator for custom torch.nn.Module subclasses."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    class SimpleNet(nn.Module):
+        def __init__(self, hidden_size):
+            super().__init__()
+            self.fc1 = nn.Linear(10, hidden_size)
+            self.relu = nn.ReLU()
+            self.fc2 = nn.Linear(hidden_size, 5)
+
+        def forward(self, x):
+            x = self.fc1(x)
+            x = self.relu(x)
+            x = self.fc2(x)
+            return x
+
+    # Test identical custom modules
+    torch.manual_seed(42)
+    a = SimpleNet(hidden_size=20)
+    torch.manual_seed(42)
+    b = SimpleNet(hidden_size=20)
+    assert comparator(a, b)
+
+    # Test custom modules with different weights
+    torch.manual_seed(42)
+    c = SimpleNet(hidden_size=20)
+    torch.manual_seed(123)
+    d = SimpleNet(hidden_size=20)
+    assert not comparator(c, d)
+
+    # Test custom modules with different architecture
+    torch.manual_seed(42)
+    e = SimpleNet(hidden_size=20)
+    torch.manual_seed(42)
+    f = SimpleNet(hidden_size=40)
+    assert not comparator(e, f)
+
+
+def test_torch_nn_nested_modules():
+    """Test comparator for nested torch.nn.Module structures."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    class EncoderBlock(nn.Module):
+        def __init__(self, in_channels, out_channels):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
+            self.bn = nn.BatchNorm2d(out_channels)
+            self.relu = nn.ReLU()
+
+        def forward(self, x):
+            return self.relu(self.bn(self.conv(x)))
+
+    class Encoder(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.block1 = EncoderBlock(3, 16)
+            self.block2 = EncoderBlock(16, 32)
+            self.pool = nn.MaxPool2d(2)
+
+        def forward(self, x):
+            x = self.block1(x)
+            x = self.pool(x)
+            x = self.block2(x)
+            x = self.pool(x)
+            return x
+
+    # Test identical nested modules
+    torch.manual_seed(42)
+    a = Encoder()
+    torch.manual_seed(42)
+    b = Encoder()
+    assert comparator(a, b)
+
+    # Test nested modules with different weights
+    torch.manual_seed(42)
+    c = Encoder()
+    torch.manual_seed(123)
+    d = Encoder()
+    assert not comparator(c, d)
+
+
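+# nn.Transformer nests MultiheadAttention and LayerNorm submodules several
+# levels deep, so this exercises the recursive comparison path end to end.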
+def test_torch_nn_transformer():
+    """Test comparator for torch.nn.Transformer modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test identical Transformer
+    torch.manual_seed(42)
+    a = nn.Transformer(d_model=64, nhead=4, num_encoder_layers=2, num_decoder_layers=2)
+    torch.manual_seed(42)
+    b = nn.Transformer(d_model=64, nhead=4, num_encoder_layers=2, num_decoder_layers=2)
+    assert comparator(a, b)
+
+    # Test Transformer with different d_model
+    torch.manual_seed(42)
+    c = nn.Transformer(d_model=64, nhead=4)
+    torch.manual_seed(42)
+    d = nn.Transformer(d_model=128, nhead=4)
+    assert not comparator(c, d)
+
+    # Test Transformer with different nhead
+    torch.manual_seed(42)
+    e = nn.Transformer(d_model=64, nhead=4)
+    torch.manual_seed(42)
+    f = nn.Transformer(d_model=64, nhead=8)
+    assert not comparator(e, f)
+
+    # Test TransformerEncoder
+    torch.manual_seed(42)
+    encoder_layer_a = nn.TransformerEncoderLayer(d_model=64, nhead=4)
+    g = nn.TransformerEncoder(encoder_layer_a, num_layers=2)
+    torch.manual_seed(42)
+    encoder_layer_b = nn.TransformerEncoderLayer(d_model=64, nhead=4)
+    h = nn.TransformerEncoder(encoder_layer_b, num_layers=2)
+    assert comparator(g, h)
+
+
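+# The in-place weight edit below must run under torch.no_grad(); mutating a
+# leaf tensor that requires grad in place raises a RuntimeError otherwise.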
+def test_torch_nn_parameter_buffer_modification():
+    """Test comparator detects parameter and buffer modifications."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test that modifying a parameter is detected
+    torch.manual_seed(42)
+    a = nn.Linear(10, 5)
+    torch.manual_seed(42)
+    b = nn.Linear(10, 5)
+    assert comparator(a, b)
+
+    # Modify a parameter
+    with torch.no_grad():
+        b.weight[0, 0] = 999.0
+    assert not comparator(a, b)
+
+    # Test that modifying a buffer is detected (BatchNorm running_mean)
+    torch.manual_seed(42)
+    c = nn.BatchNorm2d(16)
+    torch.manual_seed(42)
+    d = nn.BatchNorm2d(16)
+    assert comparator(c, d)
+
+    # Modify a buffer
+    d.running_mean[0] = 999.0
+    assert not comparator(c, d)
+
+
+def test_torch_nn_device_placement():
+    """Test comparator handles modules on different devices."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Create modules on CPU
+    torch.manual_seed(42)
+    cpu_module = nn.Linear(10, 5)
+    torch.manual_seed(42)
+    cpu_module2 = nn.Linear(10, 5)
+    assert comparator(cpu_module, cpu_module2)
+
+    # If CUDA is available, test device mismatch
+    if torch.cuda.is_available():
+        torch.manual_seed(42)
+        cpu_mod = nn.Linear(10, 5)
+        torch.manual_seed(42)
+        cuda_mod = nn.Linear(10, 5).cuda()
+        assert not comparator(cpu_mod, cuda_mod)
+
+
+def test_torch_nn_conv1d_conv3d():
+    """Test comparator for Conv1d and Conv3d modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test Conv1d
+    torch.manual_seed(42)
+    a = nn.Conv1d(3, 16, kernel_size=3)
+    torch.manual_seed(42)
+    b = nn.Conv1d(3, 16, kernel_size=3)
+    assert comparator(a, b)
+
+    # Test Conv1d with different out_channels
+    torch.manual_seed(42)
+    c = nn.Conv1d(3, 16, kernel_size=3)
+    torch.manual_seed(42)
+    d = nn.Conv1d(3, 32, kernel_size=3)
+    assert not comparator(c, d)
+
+    # Test Conv3d
+    torch.manual_seed(42)
+    e = nn.Conv3d(3, 16, kernel_size=3)
+    torch.manual_seed(42)
+    f = nn.Conv3d(3, 16, kernel_size=3)
+    assert comparator(e, f)
+
+    # Test Conv1d vs Conv2d (different types)
+    torch.manual_seed(42)
+    g = nn.Conv1d(3, 16, kernel_size=3)
+    torch.manual_seed(42)
+    h = nn.Conv2d(3, 16, kernel_size=3)
+    assert not comparator(g, h)
+
+
+def test_torch_nn_flatten_unflatten():
+    """Test comparator for Flatten and Unflatten modules."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test Flatten
+    a = nn.Flatten()
+    b = nn.Flatten()
+    assert comparator(a, b)
+
+    # Test Flatten with different start_dim
+    c = nn.Flatten(start_dim=1)
+    d = nn.Flatten(start_dim=0)
+    assert not comparator(c, d)
+
+    # Test Unflatten
+    e = nn.Unflatten(dim=1, unflattened_size=(2, 5))
+    f = nn.Unflatten(dim=1, unflattened_size=(2, 5))
+    assert comparator(e, f)
+
+
+def test_torch_nn_identity():
+    """Test comparator for Identity module."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # Test Identity
+    a = nn.Identity()
+    b = nn.Identity()
+    assert comparator(a, b)
+
+    # Test Identity vs Linear (different types)
+    torch.manual_seed(42)
+    c = nn.Identity()
+    d = nn.Linear(10, 10)
+    assert not comparator(c, d)
+
+
+def test_torch_nn_with_superset():
+    """Test comparator superset_obj mode with nn.Module."""
+    try:
+        import torch
+        from torch import nn
+    except ImportError:
+        pytest.skip()
+
+    # For nn.Module, superset_obj should still work
+    torch.manual_seed(42)
+    a = nn.Linear(10, 5)
+    torch.manual_seed(42)
+    b = nn.Linear(10, 5)
+
+    # superset_obj=True should pass for identical modules
+    assert comparator(a, b, superset_obj=True)
+
+    # Different modules should still fail
+    torch.manual_seed(42)
+    c = nn.Linear(10, 5)
+    torch.manual_seed(123)
+    d = nn.Linear(10, 5)
+    assert not comparator(c, d, superset_obj=True)
+
+
 def test_jax():
     try:
         import jax.numpy as jnp

From 648c4473e5e87a4c883122236d87a76742bb1e46 Mon Sep 17 00:00:00 2001
From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com>
Date: Fri, 6 Feb 2026 21:57:23 +0000
Subject: [PATCH 2/3] style: add type annotation for superset_obj parameter

---
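The annotation itself changes no behavior; for reference, the superset rule
the docstring below describes works roughly like this (hedged sketch,
assuming plain dicts follow the same "new may gain keys" rule the docstring
states):

    from codeflash.verification.comparator import comparator

    comparator({"a": 1}, {"a": 1, "b": 2}, superset_obj=True)  # True: extra key allowed
    comparator({"a": 1, "b": 2}, {"a": 1}, superset_obj=True)  # False: keys may not disappear
    comparator({"a": 1}, {"a": 2}, superset_obj=True)          # False: existing values must match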
 codeflash/verification/comparator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py
index eb86c790a..6d3618a3f 100644
--- a/codeflash/verification/comparator.py
+++ b/codeflash/verification/comparator.py
@@ -94,7 +94,7 @@ def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]:  # no
     return _extract_exception_from_message(str(exc))
 
 
-def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
+def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
     """Compare two objects for equality recursively.
 
     If superset_obj is True, the new object is allowed to have more keys than the original object.
     However, the existing keys/values must be equivalent."""
     try:
         # Handle exceptions specially - before type check to allow wrapper comparison

From 555a2f92c597026e97799c7d0ba9a4d16ef9478a Mon Sep 17 00:00:00 2001
From: aseembits93
Date: Fri, 6 Feb 2026 14:04:33 -0800
Subject: [PATCH 3/3] standalone tests

---
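Why the tests below wrap payloads in a small Holder class: CPython cannot
create weak references to plain list or dict instances, so weakref.ref needs
an ordinary object around the data:

    import weakref

    weakref.ref([1, 2, 3])  # TypeError: cannot create weak reference to 'list' object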
 tests/test_comparator.py | 157 +++++++++++++++++++++++++++
 1 file changed, 157 insertions(+)

diff --git a/tests/test_comparator.py b/tests/test_comparator.py
index 30bbe8700..4635acc54 100644
--- a/tests/test_comparator.py
+++ b/tests/test_comparator.py
@@ -7,6 +7,7 @@
 import re
 import sys
 import uuid
+import weakref
 from collections import ChainMap, Counter, OrderedDict, UserDict, UserList, UserString, defaultdict, deque, namedtuple
 from enum import Enum, Flag, IntFlag, auto
 from pathlib import Path
@@ -136,6 +137,162 @@ def test_basic_python_objects() -> None:
     assert not comparator(a, c)
 
 
+def test_weakref() -> None:
+    """Test comparator for weakref.ref objects."""
+
+    # Helper class that supports weak references and has comparable __dict__
+    class Holder:
+        def __init__(self, value):
+            self.value = value
+
+    # Test weak references to the same object
+    obj = Holder([1, 2, 3])
+    ref1 = weakref.ref(obj)
+    ref2 = weakref.ref(obj)
+    assert comparator(ref1, ref2)
+
+    # Test weak references to equivalent but different objects
+    obj1 = Holder({"key": "value"})
+    obj2 = Holder({"key": "value"})
+    ref1 = weakref.ref(obj1)
+    ref2 = weakref.ref(obj2)
+    assert comparator(ref1, ref2)
+
+    # Test weak references to different objects
+    obj1 = Holder([1, 2, 3])
+    obj2 = Holder([1, 2, 4])
+    ref1 = weakref.ref(obj1)
+    ref2 = weakref.ref(obj2)
+    assert not comparator(ref1, ref2)
+
+    # Test weak references with different data
+    obj1 = Holder([1, 2, 3])
+    obj2 = Holder([1, 2, 3, 4])
+    ref1 = weakref.ref(obj1)
+    ref2 = weakref.ref(obj2)
+    assert not comparator(ref1, ref2)
+
+    # Test dead weak references (both dead)
+    obj1 = Holder([1, 2, 3])
+    obj2 = Holder([1, 2, 3])
+    ref1 = weakref.ref(obj1)
+    ref2 = weakref.ref(obj2)
+    del obj1
+    del obj2
+    # Both refs are now dead, should be equal
+    assert comparator(ref1, ref2)
+
+    # Test one dead, one alive weak reference
+    obj1 = Holder([1, 2, 3])
+    obj2 = Holder([1, 2, 3])
+    ref1 = weakref.ref(obj1)
+    ref2 = weakref.ref(obj2)
+    del obj1
+    # ref1 is dead, ref2 is alive, should not be equal
+    assert not comparator(ref1, ref2)
+    assert not comparator(ref2, ref1)
+
+    # Test weak references to nested structures
+    obj1 = Holder({"nested": [1, 2, {"inner": "value"}]})
+    obj2 = Holder({"nested": [1, 2, {"inner": "value"}]})
+    ref1 = weakref.ref(obj1)
+    ref2 = weakref.ref(obj2)
+    assert comparator(ref1, ref2)
+
+    # Test weak references to nested structures with differences
+    obj1 = Holder({"nested": [1, 2, {"inner": "value1"}]})
+    obj2 = Holder({"nested": [1, 2, {"inner": "value2"}]})
+    ref1 = weakref.ref(obj1)
+    ref2 = weakref.ref(obj2)
+    assert not comparator(ref1, ref2)
+
+    # Test weak references in a dictionary (simulating __dict__ with weakrefs)
+    obj1 = Holder([1, 2, 3])
+    obj2 = Holder([1, 2, 3])
+    dict1 = {"data": 42, "ref": weakref.ref(obj1)}
+    dict2 = {"data": 42, "ref": weakref.ref(obj2)}
+    assert comparator(dict1, dict2)
+
+    # Test weak references in a dictionary with different referents
+    obj1 = Holder([1, 2, 3])
+    obj2 = Holder([4, 5, 6])
+    dict1 = {"data": 42, "ref": weakref.ref(obj1)}
+    dict2 = {"data": 42, "ref": weakref.ref(obj2)}
+    assert not comparator(dict1, dict2)
+
+    # Test weak references in a list
+    obj1 = Holder({"a": 1})
+    obj2 = Holder({"a": 1})
+    list1 = [weakref.ref(obj1), "other"]
+    list2 = [weakref.ref(obj2), "other"]
+    assert comparator(list1, list2)
+
+
+def test_weakref_to_custom_objects() -> None:
+    """Test comparator for weakref.ref to custom class instances."""
+
+    class MyClass:
+        def __init__(self, value):
+            self.value = value
+
+    # Test weak references to equivalent custom objects
+    obj1 = MyClass(42)
+    obj2 = MyClass(42)
+    ref1 = weakref.ref(obj1)
+    ref2 = weakref.ref(obj2)
+    assert comparator(ref1, ref2)
+
+    # Test weak references to different custom objects
+    obj1 = MyClass(42)
+    obj2 = MyClass(99)
+    ref1 = weakref.ref(obj1)
+    ref2 = weakref.ref(obj2)
+    assert not comparator(ref1, ref2)
+
+    # Test weak references to custom objects with nested data
+    class Container:
+        def __init__(self, items):
+            self.items = items
+
+    obj1 = Container([1, 2, 3])
+    obj2 = Container([1, 2, 3])
+    ref1 = weakref.ref(obj1)
+    ref2 = weakref.ref(obj2)
+    assert comparator(ref1, ref2)
+
+    obj1 = Container([1, 2, 3])
+    obj2 = Container([1, 2, 4])
+    ref1 = weakref.ref(obj1)
+    ref2 = weakref.ref(obj2)
+    assert not comparator(ref1, ref2)
+
+
+def test_weakref_with_callbacks() -> None:
+    """Test that weakrefs with callbacks are compared correctly."""
+
+    class Holder:
+        def __init__(self, value):
+            self.value = value
+
+    callback_called = []
+
+    def callback(ref):
+        callback_called.append(ref)
+
+    obj1 = Holder([1, 2, 3])
+    obj2 = Holder([1, 2, 3])
+    # Weakrefs with callbacks should still compare based on referents
+    ref1 = weakref.ref(obj1, callback)
+    ref2 = weakref.ref(obj2, callback)
+    assert comparator(ref1, ref2)
+
+    obj1 = Holder([1, 2, 3])
+    obj2 = Holder([4, 5, 6])
+    ref1 = weakref.ref(obj1, callback)
+    ref2 = weakref.ref(obj2, callback)
+    assert not comparator(ref1, ref2)
+
+
 @pytest.mark.parametrize(
     "r1, r2, expected",
     [