diff --git a/torchlpc/__init__.py b/torchlpc/__init__.py index 620d360..6f3d1a6 100644 --- a/torchlpc/__init__.py +++ b/torchlpc/__init__.py @@ -1,5 +1,5 @@ import torch -from typing import Optional +from typing import Optional, Union, Tuple from pathlib import Path import warnings @@ -31,14 +31,18 @@ def sample_wise_lpc( - x: torch.Tensor, a: torch.Tensor, zi: Optional[torch.Tensor] = None -) -> torch.Tensor: + x: torch.Tensor, + a: torch.Tensor, + zi: Optional[torch.Tensor] = None, + return_zf: bool = False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """Compute LPC filtering sample-wise. Args: x (torch.Tensor): Input signal. a (torch.Tensor): LPC coefficients. zi (torch.Tensor): Initial conditions. + return_zf (bool): If True, return the final filter delay values. Defaults to False. Shape: - x: :math:`(B, T)` @@ -46,7 +50,8 @@ def sample_wise_lpc( - zi: :math:`(B, order)` Returns: - torch.Tensor: Filtered signal with the same shape as x. + Filtered signal with the same shape as x if `return_zf` is False. + If `return_zf` is True, returns a tuple of the filtered signal and the final delay values. """ assert x.shape[0] == a.shape[0] assert x.shape[1] == a.shape[1] @@ -62,6 +67,10 @@ def sample_wise_lpc( # if order == 1 and x.is_cuda and B * WARPSIZE < T: # return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1)) if order == 1: - return Recurrence.apply(-a.squeeze(2), x, zi.squeeze(1)) + y = Recurrence.apply(-a.squeeze(2), x, zi.squeeze(1)) + else: + y = LPC.apply(x, a, zi) - return LPC.apply(x, a, zi) + if return_zf: + return y, y[:, -order:].flip(1) + return y