From be6271e47919040bed7bec0f4e7de47210ff0e9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 18 Dec 2025 00:42:20 +0100 Subject: [PATCH 1/3] feat: [WIP] add RNN training --- pyproject.toml | 2 + src/recursion/dataset/repeat_after_k.py | 14 +++ src/recursion/models/trivial_memory_model.py | 103 +++++++++++++++++++ 3 files changed, 119 insertions(+) create mode 100644 src/recursion/dataset/repeat_after_k.py create mode 100644 src/recursion/models/trivial_memory_model.py diff --git a/pyproject.toml b/pyproject.toml index 2d20c234..02455643 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,4 +26,6 @@ dependencies = [ "numba", "triton", "pre-commit", + "torchjd", + "torchviz" ] diff --git a/src/recursion/dataset/repeat_after_k.py b/src/recursion/dataset/repeat_after_k.py new file mode 100644 index 00000000..67a21cc2 --- /dev/null +++ b/src/recursion/dataset/repeat_after_k.py @@ -0,0 +1,14 @@ +import torch +from torch import Tensor + + +def make_sequence(length: int, k: int) -> tuple[Tensor, Tensor]: + seq = torch.randint(low=0, high=2, size=[length + k]) + input = seq[k:] + + if k == 0: + target = seq + else: + target = seq[:-k] + + return input, target diff --git a/src/recursion/models/trivial_memory_model.py b/src/recursion/models/trivial_memory_model.py new file mode 100644 index 00000000..33281f96 --- /dev/null +++ b/src/recursion/models/trivial_memory_model.py @@ -0,0 +1,103 @@ +from collections import defaultdict + +import torch +from torch import Tensor, nn +from torch.optim import SGD + +from recursion.dataset.repeat_after_k import make_sequence + + +class TrivialMemoryModel(nn.Module): + def __init__(self, memory_dim: int): + super().__init__() + + hidden_size = 2 * (1 + memory_dim) + self.fc1 = nn.Linear(1 + memory_dim, hidden_size) + self.fc2 = nn.Linear(hidden_size, memory_dim) + # self.fc3 = nn.Linear(memory_dim, 1) + self.relu = nn.ReLU() + + def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]: + x = torch.cat([input, memory], dim=-1) + x = self.relu(self.fc1(x)) + x = self.fc2(x) + + return x + + +input_sequence, target_sequence = make_sequence(7, 3) + +memory_dim = 8 + +model = TrivialMemoryModel(memory_dim) +head = nn.Linear(memory_dim, 1) +memory = torch.randn(memory_dim) +criterion = nn.BCEWithLogitsLoss() +optimizer = SGD(model.parameters(), lr=1e-2) +memories = [] +memories_wrt = [] + +param_to_gradients = defaultdict(list) +torch.set_printoptions(linewidth=200) +update_every = 6 + +from torchjd.aggregation import UPGradWeighting + +weighting = UPGradWeighting() + +for i, (input, target) in enumerate(zip(input_sequence, target_sequence, strict=True)): + memories_wrt.append(memory.detach().requires_grad_(True)) + memory = model(input.unsqueeze(0).to(dtype=torch.float32), memories_wrt[-1]) + output = head(memory) + loss = criterion(output, target.unsqueeze(0).to(dtype=torch.float32)) + memories.append(memory) + + print(f"{loss.item():.1e}") + + if (i + 1) % update_every == 0: + optimizer.zero_grad() + + grad_output = torch.autograd.grad(loss, [memories[-1]]) + + for j in range(update_every): + print(j) + grads = torch.autograd.grad( + memories[-j - 1], + list(model.parameters()) + [memories_wrt[-j - 1]], + grad_outputs=grad_output, + ) + grads_wrt_params = grads[:-1] + grad_output = grads[-1] + + for param, grad in zip(model.parameters(), grads_wrt_params, strict=True): + param_to_gradients[param].append(grad) + + param_to_jacobian_matrix = { + param: torch.stack([g.flatten() for g in gradients], dim=0) + for param, 
gradients in param_to_gradients.items() + } + jacobian_matrix = torch.cat([mat for mat in param_to_jacobian_matrix.values()], dim=1) + + gramian = jacobian_matrix @ jacobian_matrix.T + weights = weighting(gramian) + # print(jacobian_matrix.shape) + print(gramian) + print(weights) + + # graph = make_dot(loss, params=dict(model.named_parameters()), show_attrs=True, show_saved=True) + # graph.view() + + # graph = make_dot(attached_memories[-1], params=dict(model.named_parameters()), show_attrs=True, + # show_saved=True) + # graph.view() + + # loss.backward() + + # print("fc1 weights: ", model.fc1.weight.grad) + # print("fc1 biases: ", model.fc1.bias.grad) + # + # print("fc2 weights: ", model.fc2.weight.grad) + # print("fc2 biases: ", model.fc2.bias.grad) + + optimizer.step() + memory = memory.detach() From 25a32ed628464e5cdcbf003e3ccb105c2eb9d972 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 18 Dec 2025 15:02:32 +0100 Subject: [PATCH 2/3] Improvements: * Reset memories, memories_wrt and param_to_gradients * Use transform to aggregate and accumulate into .grad * Train head too * Change some values * At this point, it seems hard to train with Mean() and doable with UPGrad() --- src/recursion/models/trivial_memory_model.py | 55 +++++++------------- 1 file changed, 19 insertions(+), 36 deletions(-) diff --git a/src/recursion/models/trivial_memory_model.py b/src/recursion/models/trivial_memory_model.py index 33281f96..5d2f66f3 100644 --- a/src/recursion/models/trivial_memory_model.py +++ b/src/recursion/models/trivial_memory_model.py @@ -3,6 +3,8 @@ import torch from torch import Tensor, nn from torch.optim import SGD +from torchjd.aggregation import UPGrad +from torchjd.autojac._transform import Accumulate, Aggregate, OrderedSet from recursion.dataset.repeat_after_k import make_sequence @@ -14,7 +16,6 @@ def __init__(self, memory_dim: int): hidden_size = 2 * (1 + memory_dim) self.fc1 = nn.Linear(1 + memory_dim, hidden_size) self.fc2 = nn.Linear(hidden_size, memory_dim) - # self.fc3 = nn.Linear(memory_dim, 1) self.relu = nn.ReLU() def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]: @@ -25,7 +26,7 @@ def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]: return x -input_sequence, target_sequence = make_sequence(7, 3) +input_sequence, target_sequence = make_sequence(50000, 3) memory_dim = 8 @@ -34,16 +35,14 @@ def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]: memory = torch.randn(memory_dim) criterion = nn.BCEWithLogitsLoss() optimizer = SGD(model.parameters(), lr=1e-2) +head_optimizer = SGD(head.parameters(), lr=1e-2) memories = [] memories_wrt = [] - param_to_gradients = defaultdict(list) torch.set_printoptions(linewidth=200) -update_every = 6 - -from torchjd.aggregation import UPGradWeighting +update_every = 4 -weighting = UPGradWeighting() +aggregator = UPGrad() for i, (input, target) in enumerate(zip(input_sequence, target_sequence, strict=True)): memories_wrt.append(memory.detach().requires_grad_(True)) @@ -51,16 +50,14 @@ def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]: output = head(memory) loss = criterion(output, target.unsqueeze(0).to(dtype=torch.float32)) memories.append(memory) + transform = Accumulate() << Aggregate(aggregator, OrderedSet(list(model.parameters()))) print(f"{loss.item():.1e}") if (i + 1) % update_every == 0: - optimizer.zero_grad() - - grad_output = torch.autograd.grad(loss, [memories[-1]]) + grad_output = torch.autograd.grad(loss, [memories[-1]], 
retain_graph=True) for j in range(update_every): - print(j) grads = torch.autograd.grad( memories[-j - 1], list(model.parameters()) + [memories_wrt[-j - 1]], @@ -72,32 +69,18 @@ def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]: for param, grad in zip(model.parameters(), grads_wrt_params, strict=True): param_to_gradients[param].append(grad) - param_to_jacobian_matrix = { - param: torch.stack([g.flatten() for g in gradients], dim=0) - for param, gradients in param_to_gradients.items() + param_to_jacobian = { + param: torch.stack(gradients, dim=0) for param, gradients in param_to_gradients.items() } - jacobian_matrix = torch.cat([mat for mat in param_to_jacobian_matrix.values()], dim=1) - - gramian = jacobian_matrix @ jacobian_matrix.T - weights = weighting(gramian) - # print(jacobian_matrix.shape) - print(gramian) - print(weights) - # graph = make_dot(loss, params=dict(model.named_parameters()), show_attrs=True, show_saved=True) - # graph.view() - - # graph = make_dot(attached_memories[-1], params=dict(model.named_parameters()), show_attrs=True, - # show_saved=True) - # graph.view() - - # loss.backward() + optimizer.zero_grad() + transform(param_to_jacobian) # This stores the aggregated Jacobian in the .grad fields + optimizer.step() - # print("fc1 weights: ", model.fc1.weight.grad) - # print("fc1 biases: ", model.fc1.bias.grad) - # - # print("fc2 weights: ", model.fc2.weight.grad) - # print("fc2 biases: ", model.fc2.bias.grad) + memories = [] + memories_wrt = [] + param_to_gradients = defaultdict(list) - optimizer.step() - memory = memory.detach() + head_optimizer.zero_grad() + torch.autograd.backward(loss, inputs=list(head.parameters())) + head_optimizer.step() From 927b36eb891bbc16208d4609e94db368e64b7aba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 18 Dec 2025 17:19:56 +0100 Subject: [PATCH 3/3] Add batched JD training --- src/recursion/dataset/repeat_after_k.py | 8 +- src/recursion/models/trivial_memory_model.py | 89 +++++++++++++++----- 2 files changed, 73 insertions(+), 24 deletions(-) diff --git a/src/recursion/dataset/repeat_after_k.py b/src/recursion/dataset/repeat_after_k.py index 67a21cc2..549ad8b7 100644 --- a/src/recursion/dataset/repeat_after_k.py +++ b/src/recursion/dataset/repeat_after_k.py @@ -2,13 +2,13 @@ from torch import Tensor -def make_sequence(length: int, k: int) -> tuple[Tensor, Tensor]: - seq = torch.randint(low=0, high=2, size=[length + k]) - input = seq[k:] +def make_sequences(length: int, k: int, batch_size: int) -> tuple[Tensor, Tensor]: + seq = torch.randint(low=0, high=2, size=[batch_size, length + k]) + input = seq[:, k:] if k == 0: target = seq else: - target = seq[:-k] + target = seq[:, :-k] return input, target diff --git a/src/recursion/models/trivial_memory_model.py b/src/recursion/models/trivial_memory_model.py index 5d2f66f3..f6f651f3 100644 --- a/src/recursion/models/trivial_memory_model.py +++ b/src/recursion/models/trivial_memory_model.py @@ -2,11 +2,20 @@ import torch from torch import Tensor, nn +from torch.nn.functional import cosine_similarity from torch.optim import SGD from torchjd.aggregation import UPGrad -from torchjd.autojac._transform import Accumulate, Aggregate, OrderedSet +from torchjd.autojac._transform import ( + Accumulate, + Aggregate, + Diagonalize, + Init, + Jac, + OrderedSet, + Select, +) -from recursion.dataset.repeat_after_k import make_sequence +from recursion.dataset.repeat_after_k import make_sequences class TrivialMemoryModel(nn.Module): @@ -26,51 +35,91 @@ def 
forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]: return x -input_sequence, target_sequence = make_sequence(50000, 3) +batch_size = 16 +k = 3 +input_sequences, target_sequences = make_sequences(50000, k, batch_size=batch_size) memory_dim = 8 model = TrivialMemoryModel(memory_dim) head = nn.Linear(memory_dim, 1) -memory = torch.randn(memory_dim) -criterion = nn.BCEWithLogitsLoss() +memory = torch.zeros(batch_size, memory_dim) +criterion = nn.BCEWithLogitsLoss(reduction="none") optimizer = SGD(model.parameters(), lr=1e-2) head_optimizer = SGD(head.parameters(), lr=1e-2) memories = [] memories_wrt = [] -param_to_gradients = defaultdict(list) +param_to_jacobians = defaultdict(list) torch.set_printoptions(linewidth=200) update_every = 4 aggregator = UPGrad() -for i, (input, target) in enumerate(zip(input_sequence, target_sequence, strict=True)): + +def hook(_, args: tuple[Tensor], __) -> None: + jacobian = args[0] + gramian = jacobian @ jacobian.T + print(gramian[0, 0] / gramian[k * batch_size, k * batch_size]) + + +def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None: + """Prints the cosine similarity between the aggregation and the average gradient.""" + matrix = inputs[0] + gd_output = matrix.mean(dim=0) + similarity = cosine_similarity(aggregation, gd_output, dim=0) + print(f"Cosine similarity: {similarity.item():.4f}") + + +aggregator.register_forward_hook(hook) +aggregator.register_forward_hook(print_gd_similarity) + +for i, (input, target) in enumerate(zip(input_sequences.T, target_sequences.T, strict=True)): memories_wrt.append(memory.detach().requires_grad_(True)) - memory = model(input.unsqueeze(0).to(dtype=torch.float32), memories_wrt[-1]) + + memory = model(input.unsqueeze(1).to(dtype=torch.float32), memories_wrt[-1]) output = head(memory) - loss = criterion(output, target.unsqueeze(0).to(dtype=torch.float32)) + losses = criterion(output, target.unsqueeze(1).to(dtype=torch.float32)) + loss = losses.mean() memories.append(memory) transform = Accumulate() << Aggregate(aggregator, OrderedSet(list(model.parameters()))) print(f"{loss.item():.1e}") if (i + 1) % update_every == 0: - grad_output = torch.autograd.grad(loss, [memories[-1]], retain_graph=True) + # grad_output = torch.autograd.grad(loss, [memories[-1]], retain_graph=True) + + ordered_set = OrderedSet(losses) + init = Init(ordered_set) + diag = Diagonalize(ordered_set) + jac = Jac(ordered_set, OrderedSet([memories[-1]]), chunk_size=None, retain_graph=True) + trans = jac << diag << init + + trans.check_keys(set()) + + jac_output = trans({}) for j in range(update_every): - grads = torch.autograd.grad( - memories[-j - 1], - list(model.parameters()) + [memories_wrt[-j - 1]], - grad_outputs=grad_output, + new_jac = Jac( + OrderedSet([memories[-j - 1]]), + OrderedSet(list(model.parameters()) + [memories_wrt[-j - 1]]), + chunk_size=None, ) - grads_wrt_params = grads[:-1] - grad_output = grads[-1] + select_jac_wrt_model = Select(OrderedSet(list(model.parameters()))) + select_jac_wrt_memory = Select(OrderedSet([memories_wrt[-j - 1]])) + + jacobians = new_jac(jac_output) + jac_output = select_jac_wrt_memory(jacobians) + + if j < update_every - 1: + jac_output = {memories[-j - 2]: jac_output[memories_wrt[-j - 1]]} + + jac_wrt_params = select_jac_wrt_model(jacobians) - for param, grad in zip(model.parameters(), grads_wrt_params, strict=True): - param_to_gradients[param].append(grad) + for param, jacob in jac_wrt_params.items(): + param_to_jacobians[param].append(jacob) 
param_to_jacobian = { - param: torch.stack(gradients, dim=0) for param, gradients in param_to_gradients.items() + param: torch.cat(jacobs, dim=0) for param, jacobs in param_to_jacobians.items() } optimizer.zero_grad() @@ -79,7 +128,7 @@ def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]: memories = [] memories_wrt = [] - param_to_gradients = defaultdict(list) + param_to_jacobians = defaultdict(list) head_optimizer.zero_grad() torch.autograd.backward(loss, inputs=list(head.parameters()))
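
The patches above route the per-step gradients through torchjd's UPGrad and its (private) transform API. For reference, the sketch below re-implements the truncated-window scheme of the second patch with plain torch.autograd, substituting a simple mean over the per-step gradients for the UPGrad aggregation. It is a minimal standalone approximation, not part of the patch series; all names here (cell, detached, per_step, ...) are illustrative only.

import torch
from torch import nn
from torch.optim import SGD

torch.manual_seed(0)

# Toy "repeat after k" task, same idea as repeat_after_k.py: the target at
# step i is the input bit observed k steps earlier.
k, length, memory_dim, update_every = 3, 64, 8, 4
seq = torch.randint(0, 2, (length + k,))
inputs, targets = seq[k:], seq[:-k]

hidden = 2 * (1 + memory_dim)
cell = nn.Sequential(nn.Linear(1 + memory_dim, hidden), nn.ReLU(), nn.Linear(hidden, memory_dim))
head = nn.Linear(memory_dim, 1)
criterion = nn.BCEWithLogitsLoss()
optimizer = SGD(list(cell.parameters()) + list(head.parameters()), lr=1e-2)

memory = torch.zeros(memory_dim)
memories, detached = [], []

for i, (x, y) in enumerate(zip(inputs, targets)):
    # Re-root the graph at every step so each step owns a disjoint graph.
    detached.append(memory.detach().requires_grad_(True))
    memory = cell(torch.cat([x.float().unsqueeze(0), detached[-1]]))
    memories.append(memory)
    loss = criterion(head(memory), y.float().unsqueeze(0))

    if (i + 1) % update_every == 0:
        optimizer.zero_grad()
        # The head is trained directly on the last loss of the window.
        torch.autograd.backward(loss, inputs=list(head.parameters()), retain_graph=True)

        # Walk the window backwards, collecting one gradient per step for the
        # shared cell parameters and propagating the memory gradient by hand.
        grad_output = torch.autograd.grad(loss, [memories[-1]], retain_graph=True)
        per_step = {p: [] for p in cell.parameters()}
        for j in range(update_every):
            grads = torch.autograd.grad(
                memories[-j - 1],
                list(cell.parameters()) + [detached[-j - 1]],
                grad_outputs=grad_output,
            )
            for p, g in zip(cell.parameters(), grads[:-1]):
                per_step[p].append(g)
            grad_output = (grads[-1],)  # gradient w.r.t. the previous step's memory

        # Plain mean over the window; the patches aggregate with UPGrad instead.
        for p, gs in per_step.items():
            p.grad = torch.stack(gs).mean(dim=0)
        optimizer.step()

        memories, detached = [], []

Replacing the mean with torchjd's UPGrad applied to the stacked per-step gradients is what the second patch does via Aggregate/Accumulate, and it is the variant its commit message reports as trainable, in contrast to plain averaging.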