diff --git a/pyproject.toml b/pyproject.toml index 2d20c234..02455643 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,4 +26,6 @@ dependencies = [ "numba", "triton", "pre-commit", + "torchjd", + "torchviz" ] diff --git a/src/recursion/dataset/repeat_after_k.py b/src/recursion/dataset/repeat_after_k.py new file mode 100644 index 00000000..549ad8b7 --- /dev/null +++ b/src/recursion/dataset/repeat_after_k.py @@ -0,0 +1,14 @@ +import torch +from torch import Tensor + + +def make_sequences(length: int, k: int, batch_size: int) -> tuple[Tensor, Tensor]: + seq = torch.randint(low=0, high=2, size=[batch_size, length + k]) + input = seq[:, k:] + + if k == 0: + target = seq + else: + target = seq[:, :-k] + + return input, target diff --git a/src/recursion/models/trivial_memory_model.py b/src/recursion/models/trivial_memory_model.py new file mode 100644 index 00000000..f6f651f3 --- /dev/null +++ b/src/recursion/models/trivial_memory_model.py @@ -0,0 +1,135 @@ +from collections import defaultdict + +import torch +from torch import Tensor, nn +from torch.nn.functional import cosine_similarity +from torch.optim import SGD +from torchjd.aggregation import UPGrad +from torchjd.autojac._transform import ( + Accumulate, + Aggregate, + Diagonalize, + Init, + Jac, + OrderedSet, + Select, +) + +from recursion.dataset.repeat_after_k import make_sequences + + +class TrivialMemoryModel(nn.Module): + def __init__(self, memory_dim: int): + super().__init__() + + hidden_size = 2 * (1 + memory_dim) + self.fc1 = nn.Linear(1 + memory_dim, hidden_size) + self.fc2 = nn.Linear(hidden_size, memory_dim) + self.relu = nn.ReLU() + + def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]: + x = torch.cat([input, memory], dim=-1) + x = self.relu(self.fc1(x)) + x = self.fc2(x) + + return x + + +batch_size = 16 +k = 3 +input_sequences, target_sequences = make_sequences(50000, k, batch_size=batch_size) + +memory_dim = 8 + +model = TrivialMemoryModel(memory_dim) +head = nn.Linear(memory_dim, 1) +memory = torch.zeros(batch_size, memory_dim) +criterion = nn.BCEWithLogitsLoss(reduction="none") +optimizer = SGD(model.parameters(), lr=1e-2) +head_optimizer = SGD(head.parameters(), lr=1e-2) +memories = [] +memories_wrt = [] +param_to_jacobians = defaultdict(list) +torch.set_printoptions(linewidth=200) +update_every = 4 + +aggregator = UPGrad() + + +def hook(_, args: tuple[Tensor], __) -> None: + jacobian = args[0] + gramian = jacobian @ jacobian.T + print(gramian[0, 0] / gramian[k * batch_size, k * batch_size]) + + +def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None: + """Prints the cosine similarity between the aggregation and the average gradient.""" + matrix = inputs[0] + gd_output = matrix.mean(dim=0) + similarity = cosine_similarity(aggregation, gd_output, dim=0) + print(f"Cosine similarity: {similarity.item():.4f}") + + +aggregator.register_forward_hook(hook) +aggregator.register_forward_hook(print_gd_similarity) + +for i, (input, target) in enumerate(zip(input_sequences.T, target_sequences.T, strict=True)): + memories_wrt.append(memory.detach().requires_grad_(True)) + + memory = model(input.unsqueeze(1).to(dtype=torch.float32), memories_wrt[-1]) + output = head(memory) + losses = criterion(output, target.unsqueeze(1).to(dtype=torch.float32)) + loss = losses.mean() + memories.append(memory) + transform = Accumulate() << Aggregate(aggregator, OrderedSet(list(model.parameters()))) + + print(f"{loss.item():.1e}") + + if (i + 1) % update_every == 0: + # grad_output = torch.autograd.grad(loss, [memories[-1]], retain_graph=True) + + ordered_set = OrderedSet(losses) + init = Init(ordered_set) + diag = Diagonalize(ordered_set) + jac = Jac(ordered_set, OrderedSet([memories[-1]]), chunk_size=None, retain_graph=True) + trans = jac << diag << init + + trans.check_keys(set()) + + jac_output = trans({}) + + for j in range(update_every): + new_jac = Jac( + OrderedSet([memories[-j - 1]]), + OrderedSet(list(model.parameters()) + [memories_wrt[-j - 1]]), + chunk_size=None, + ) + select_jac_wrt_model = Select(OrderedSet(list(model.parameters()))) + select_jac_wrt_memory = Select(OrderedSet([memories_wrt[-j - 1]])) + + jacobians = new_jac(jac_output) + jac_output = select_jac_wrt_memory(jacobians) + + if j < update_every - 1: + jac_output = {memories[-j - 2]: jac_output[memories_wrt[-j - 1]]} + + jac_wrt_params = select_jac_wrt_model(jacobians) + + for param, jacob in jac_wrt_params.items(): + param_to_jacobians[param].append(jacob) + + param_to_jacobian = { + param: torch.cat(jacobs, dim=0) for param, jacobs in param_to_jacobians.items() + } + + optimizer.zero_grad() + transform(param_to_jacobian) # This stores the aggregated Jacobian in the .grad fields + optimizer.step() + + memories = [] + memories_wrt = [] + param_to_jacobians = defaultdict(list) + + head_optimizer.zero_grad() + torch.autograd.backward(loss, inputs=list(head.parameters())) + head_optimizer.step()