From ba590dca2d19258fa9b5b07aeb3cd5706c668de8 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 19 Dec 2025 16:02:50 +0100 Subject: [PATCH 1/3] Change TrivialMemoryModel to ResidualMemoryModel --- src/recursion/models/trivial_memory_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/recursion/models/trivial_memory_model.py b/src/recursion/models/trivial_memory_model.py index f6f651f3..da3b14c0 100644 --- a/src/recursion/models/trivial_memory_model.py +++ b/src/recursion/models/trivial_memory_model.py @@ -18,7 +18,7 @@ from recursion.dataset.repeat_after_k import make_sequences -class TrivialMemoryModel(nn.Module): +class ResidualMemoryModel(nn.Module): def __init__(self, memory_dim: int): super().__init__() @@ -32,7 +32,7 @@ def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]: x = self.relu(self.fc1(x)) x = self.fc2(x) - return x + return x + memory batch_size = 16 @@ -41,7 +41,7 @@ def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]: memory_dim = 8 -model = TrivialMemoryModel(memory_dim) +model = ResidualMemoryModel(memory_dim) head = nn.Linear(memory_dim, 1) memory = torch.zeros(batch_size, memory_dim) criterion = nn.BCEWithLogitsLoss(reduction="none") From f454a669973e2d498a80fb9598f2bcd4c5d9285c Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 19 Dec 2025 20:22:51 +0100 Subject: [PATCH 2/3] Interesting, this seems to explode even with an insanely low LR. 
--- src/recursion/models/trivial_memory_model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/recursion/models/trivial_memory_model.py b/src/recursion/models/trivial_memory_model.py index da3b14c0..156bfb00 100644 --- a/src/recursion/models/trivial_memory_model.py +++ b/src/recursion/models/trivial_memory_model.py @@ -4,7 +4,7 @@ from torch import Tensor, nn from torch.nn.functional import cosine_similarity from torch.optim import SGD -from torchjd.aggregation import UPGrad +from torchjd.aggregation import Mean from torchjd.autojac._transform import ( Accumulate, Aggregate, @@ -36,24 +36,24 @@ def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]: batch_size = 16 -k = 3 +k = 20 input_sequences, target_sequences = make_sequences(50000, k, batch_size=batch_size) -memory_dim = 8 +memory_dim = 21 model = ResidualMemoryModel(memory_dim) head = nn.Linear(memory_dim, 1) memory = torch.zeros(batch_size, memory_dim) criterion = nn.BCEWithLogitsLoss(reduction="none") -optimizer = SGD(model.parameters(), lr=1e-2) -head_optimizer = SGD(head.parameters(), lr=1e-2) +optimizer = SGD(model.parameters(), lr=1e-60) +head_optimizer = SGD(head.parameters(), lr=1e-60) memories = [] memories_wrt = [] param_to_jacobians = defaultdict(list) torch.set_printoptions(linewidth=200) -update_every = 4 +update_every = 21 -aggregator = UPGrad() +aggregator = Mean() def hook(_, args: tuple[Tensor], __) -> None: From 48f246c937783b1e0da50fb4c34ae72a1b8117b8 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sat, 20 Dec 2025 10:11:51 +0100 Subject: [PATCH 3/3] With k=3 it works sometimes. 
--- src/recursion/models/trivial_memory_model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/recursion/models/trivial_memory_model.py b/src/recursion/models/trivial_memory_model.py index 156bfb00..0357784c 100644 --- a/src/recursion/models/trivial_memory_model.py +++ b/src/recursion/models/trivial_memory_model.py @@ -4,7 +4,7 @@ from torch import Tensor, nn from torch.nn.functional import cosine_similarity from torch.optim import SGD -from torchjd.aggregation import Mean +from torchjd.aggregation import UPGrad from torchjd.autojac._transform import ( Accumulate, Aggregate, @@ -36,24 +36,24 @@ def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]: batch_size = 16 -k = 20 +k = 2 input_sequences, target_sequences = make_sequences(50000, k, batch_size=batch_size) -memory_dim = 21 +memory_dim = 8 model = ResidualMemoryModel(memory_dim) head = nn.Linear(memory_dim, 1) memory = torch.zeros(batch_size, memory_dim) criterion = nn.BCEWithLogitsLoss(reduction="none") -optimizer = SGD(model.parameters(), lr=1e-60) -head_optimizer = SGD(head.parameters(), lr=1e-60) +optimizer = SGD(model.parameters(), lr=5e-03) +head_optimizer = SGD(head.parameters(), lr=5e-03) memories = [] memories_wrt = [] param_to_jacobians = defaultdict(list) torch.set_printoptions(linewidth=200) -update_every = 21 +update_every = 3 -aggregator = Mean() +aggregator = UPGrad() def hook(_, args: tuple[Tensor], __) -> None: