diff --git a/src/recursion/models/trivial_memory_model.py b/src/recursion/models/trivial_memory_model.py index f6f651f3..0357784c 100644 --- a/src/recursion/models/trivial_memory_model.py +++ b/src/recursion/models/trivial_memory_model.py @@ -18,7 +18,7 @@ from recursion.dataset.repeat_after_k import make_sequences -class TrivialMemoryModel(nn.Module): +class ResidualMemoryModel(nn.Module): def __init__(self, memory_dim: int): super().__init__() @@ -32,26 +32,26 @@ def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]: x = self.relu(self.fc1(x)) x = self.fc2(x) - return x + return x + memory batch_size = 16 -k = 3 +k = 2 input_sequences, target_sequences = make_sequences(50000, k, batch_size=batch_size) memory_dim = 8 -model = TrivialMemoryModel(memory_dim) +model = ResidualMemoryModel(memory_dim) head = nn.Linear(memory_dim, 1) memory = torch.zeros(batch_size, memory_dim) criterion = nn.BCEWithLogitsLoss(reduction="none") -optimizer = SGD(model.parameters(), lr=1e-2) -head_optimizer = SGD(head.parameters(), lr=1e-2) +optimizer = SGD(model.parameters(), lr=5e-03) +head_optimizer = SGD(head.parameters(), lr=5e-03) memories = [] memories_wrt = [] param_to_jacobians = defaultdict(list) torch.set_printoptions(linewidth=200) -update_every = 4 +update_every = 3 aggregator = UPGrad()