From 8c48ac859f5e7b2fec95924eed0583ae70d74f12 Mon Sep 17 00:00:00 2001
From: aaprasad
Date: Mon, 24 Feb 2025 14:34:32 -0500
Subject: [PATCH 1/2] add '.to' function to PPSeqmodel

---
 ppseq/model.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/ppseq/model.py b/ppseq/model.py
index 73f83bf..317dfed 100644
--- a/ppseq/model.py
+++ b/ppseq/model.py
@@ -47,6 +47,7 @@ def __init__(self,
 
         self.base_rates = torch.ones(num_neurons, device=device)
 
+        self.template_scales = torch.ones(num_templates, num_neurons, device=device) / num_neurons
         self.template_offsets = template_duration * torch.rand(num_templates, num_neurons, device=device)
         self.template_widths = torch.ones(self.num_templates, self.num_neurons, device=device)
 
@@ -59,6 +60,17 @@ def __init__(self,
         self.alpha_t0 = alpha_t0
         self.beta_t0 = beta_t0
 
+    def to(self, map_location: str | torch.DeviceObjType | torch.dtype):
+
+        self.base_rates = self.base_rates.to(map_location)
+        self.template_scales = self.template_scales.to(map_location)
+        self.template_offsets = self.template_offsets.to(map_location)
+        self.template_widths = self.template_widths.to(map_location)
+
+        if not isinstance(map_location, torch.dtype):
+            self.device=map_location
+
+
     @property
     def templates(self) -> Float[Tensor, "num_templates num_neurons duration"]:
         """Compute the templates from the mean, std, and amplitude of the Gaussian kernel.
@@ -68,6 +80,7 @@ def templates(self) -> Float[Tensor, "num_templates num_neurons duration"]:
         ds = torch.arange(D, device=self.device)[:, None, None]
         p = dist.Normal(mu, sigma)
         W = p.log_prob(ds).exp().permute(1,2,0)
+
         return W / W.sum(dim=2, keepdim=True) * amp[:, :, None]
 
     def reconstruct(self,

From 4f98f4aab479c2ba41e87925cb17b7dd6a2c2542 Mon Sep 17 00:00:00 2001
From: aaprasad
Date: Mon, 24 Feb 2025 14:47:30 -0500
Subject: [PATCH 2/2] fix potential shape error (B=1,N,T) (eg for torch dataset)

---
 ppseq/batch_model.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/ppseq/batch_model.py b/ppseq/batch_model.py
index ae06bfa..3975c0c 100644
--- a/ppseq/batch_model.py
+++ b/ppseq/batch_model.py
@@ -18,6 +18,8 @@ def __init__(self,
                  beta_a0: float=0.,
                  alpha_b0: float=0.,
                  beta_b0: float=0.,
+                 alpha_t0: float=0.,
+                 beta_t0:float=0.,
                  device=None):
         super().__init__(num_templates,
                          num_neurons,
@@ -26,6 +28,8 @@ def __init__(self,
                          beta_a0,
                          alpha_b0,
                          beta_b0,
+                         alpha_t0,
+                         beta_t0,
                          device)
 
     def fit(self,
@@ -40,7 +44,7 @@ def fit(self,
 
         init_method = dict(random=self.initialize_random)[initialization.lower()]
 
-        amplitude_batches =[init_method(data) for data in data_batches]
+        amplitude_batches =[init_method(data.squeeze()) for data in data_batches]
 
         # TODO: Initialize amplitudes more intelligently?
         # amplitudes = torch.rand(K, T, device=self.device) + 1e-4
@@ -50,12 +54,13 @@ def fit(self,
         for _ in progress_bar(range(num_iter)):
             ll = 0
             for i, data in enumerate(data_batches):
+                data = data.squeeze() # prevents indexing error when data_shape = (1, N, T) (e.g in a torch dataloader)
                 amplitude_batches[i] = self._update_amplitudes(data, amplitude_batches[i])
                 self._update_base_rates(data, amplitude_batches[i])
                 self._update_templates(data, amplitude_batches[i])
                 ll += self.log_likelihood(data, amplitude_batches[i])
 
-            lps.append(ll)
+            lps.append(ll) #return the sum or avg log likelihood?
 
         lps = torch.stack(lps) if num_iter > 0 else torch.tensor([])
         return lps, amplitude_batches
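
Below is a minimal usage sketch of the '.to' method introduced in PATCH 1/2. The diff only touches ppseq/model.py and does not show the public class name or its full constructor signature, so PPSeqModel and its arguments here are placeholders, not the actual API:

    import torch

    # Hypothetical import and constructor: the patch does not show the class
    # name or its full signature, so treat these as placeholders.
    from ppseq.model import PPSeqModel

    model = PPSeqModel(num_templates=5, num_neurons=100)

    # Move every template tensor (base_rates, template_scales, template_offsets,
    # template_widths) onto the GPU; self.device is updated because the argument
    # is not a torch.dtype.
    if torch.cuda.is_available():
        model.to("cuda")

    # Cast the same tensors to float64; self.device is left unchanged because
    # the argument is a torch.dtype.
    model.to(torch.float64)

And a self-contained illustration of the (1, N, T) shape issue that PATCH 2/2 works around with '.squeeze()'; this uses plain PyTorch only and constructs no ppseq objects:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    N, T = 10, 500                                      # neurons x time bins
    spikes = torch.poisson(0.1 * torch.ones(3, N, T))   # three toy (N, T) trials
    loader = DataLoader(TensorDataset(spikes), batch_size=1)

    for (data,) in loader:
        print(data.shape)      # torch.Size([1, 10, 500]): leading batch dim of 1
        data = data.squeeze()  # torch.Size([10, 500]), the (N, T) layout fit() indexes
                               # note: .squeeze() drops every singleton dim, so N == 1
                               # or T == 1 would be squeezed away as well
        print(data.shape)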