From db7993e7a996d0e2a16b37e637a3ef5495f8c48f Mon Sep 17 00:00:00 2001
From: Chin-Yun Yu
Date: Thu, 29 May 2025 10:20:14 +0100
Subject: [PATCH 1/4] feat: CUDA kernels for LPC and complex LPC computation

---
 torchlpc/csrc/cuda/lpc.cu | 177 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 177 insertions(+)
 create mode 100644 torchlpc/csrc/cuda/lpc.cu

diff --git a/torchlpc/csrc/cuda/lpc.cu b/torchlpc/csrc/cuda/lpc.cu
new file mode 100644
index 0000000..0d472be
--- /dev/null
+++ b/torchlpc/csrc/cuda/lpc.cu
@@ -0,0 +1,177 @@
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/library.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+// CUDA kernel for LPC computation
+template <typename scalar_t>
+__global__ void lpc_cuda_kernel(scalar_t* padded_y,  // [B, T + order]
+                                const scalar_t* A,   // [B, T, order]
+                                int64_t B, int64_t T, int64_t order) {
+    extern __shared__ char smem[];
+    scalar_t* sm = reinterpret_cast<scalar_t*>(smem);
+
+    int b = blockIdx.x;
+    int i = threadIdx.x;
+
+    if (b >= B || i >= order) return;
+
+    // Initialize shared memory with the first 'order' elements
+    sm[i] = padded_y[b * (T + order) + i];
+    __syncthreads();
+
+    int circular_idx = 0;
+    for (int t = 0; t < T; ++t) {
+        circular_idx = t % order;
+        scalar_t a = -A[((b * T + t) * order) + i];
+
+        // Compute s as in the Python code
+        int idx_offset = circular_idx - i - 1;
+        if (i > circular_idx - 1) {
+            idx_offset += order;
+        }
+        scalar_t s = sm[(idx_offset + order) % order];
+
+        scalar_t v = a * s;
+
+        if (i == order - 1) {
+            sm[circular_idx] = v;
+            v = padded_y[b * (T + order) + t + order];
+        }
+        __syncthreads();
+
+        // Atomic add to shared memory
+        atomicAdd(&sm[circular_idx], v);
+        __syncthreads();
+
+        if (i == order - 1) {
+            padded_y[b * (T + order) + t + order] = sm[circular_idx];
+        }
+        __syncthreads();
+    }
+}
+// CUDA kernel for complex LPC computation
+template <typename scalar_t>
+__global__ void lpc_cuda_kernel_complex(
+    scalar_t* padded_y_real,  // [B, T + order]
+    scalar_t* padded_y_imag,  // [B, T + order]
+    const scalar_t* A_real,   // [B, T, order]
+    const scalar_t* A_imag,   // [B, T, order]
+    int64_t B, int64_t T, int64_t order) {
+    extern __shared__ char smem[];
+    scalar_t* sm_real = reinterpret_cast<scalar_t*>(smem);
+    scalar_t* sm_imag = sm_real + order;
+
+    int b = blockIdx.x;
+    int i = threadIdx.x;
+
+    if (b >= B || i >= order) return;
+
+    // Initialize shared memory with the first 'order' elements
+    sm_real[i] = padded_y_real[b * (T + order) + i];
+    sm_imag[i] = padded_y_imag[b * (T + order) + i];
+    __syncthreads();
+
+    int circular_idx = 0;
+    for (int t = 0; t < T; ++t) {
+        circular_idx = t % order;
+        scalar_t a_real = -A_real[((b * T + t) * order) + i];
+        scalar_t a_imag = -A_imag[((b * T + t) * order) + i];
+
+        int idx_offset = circular_idx - i - 1;
+        if (i > circular_idx - 1) {
+            idx_offset += order;
+        }
+        int s_idx = (idx_offset + order) % order;
+        scalar_t s_real = sm_real[s_idx];
+        scalar_t s_imag = sm_imag[s_idx];
+
+        // Complex multiply: v = a * s
+        scalar_t v_real = a_real * s_real - a_imag * s_imag;
+        scalar_t v_imag = a_real * s_imag + a_imag * s_real;
+
+        if (i == order - 1) {
+            sm_real[circular_idx] = v_real;
+            sm_imag[circular_idx] = v_imag;
+            v_real = padded_y_real[b * (T + order) + t + order];
+            v_imag = padded_y_imag[b * (T + order) + t + order];
+        }
+        __syncthreads();
+
+        atomicAdd(&sm_real[circular_idx], v_real);
+        atomicAdd(&sm_imag[circular_idx], v_imag);
+        __syncthreads();
+
+        if (i == order - 1) {
+            padded_y_real[b * (T + order) + t + order] = sm_real[circular_idx];
+            padded_y_imag[b * (T + order) + t + order] = sm_imag[circular_idx];
+        }
+        __syncthreads();
+    }
+}
+
+at::Tensor lpc_cuda_wrapper(const at::Tensor& x, const at::Tensor& a,
+                            const at::Tensor& zi) {
+    TORCH_CHECK(x.is_floating_point() || x.is_complex(),
+                "Input must be floating point or complex");
+    TORCH_CHECK(a.scalar_type() == x.scalar_type(),
+                "Coefficients must have the same scalar type as input");
+    TORCH_CHECK(zi.scalar_type() == x.scalar_type(),
+                "Initial conditions must have the same scalar type as input");
+
+    TORCH_CHECK(x.dim() == 2, "Input must be 2D");
+    TORCH_CHECK(zi.dim() == 2, "Initial conditions must be 2D");
+    TORCH_CHECK(x.size(0) == zi.size(0),
+                "Batch size of input and initial conditions must match");
+
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+    auto a_contiguous = a.contiguous();
+
+    at::Tensor out;
+    auto order = a_contiguous.size(2);
+    assert(order <= 1024 && "LPC order must be less than or equal to 1024");
+    auto threads_per_block = order;
+
+    if (x.is_floating_point()) {
+        out = at::cat({zi.flip(1), x}, 1).contiguous();
+        AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "lpc_cuda", [&] {
+            auto padded_y = out.mutable_data_ptr<scalar_t>();
+            auto A = a_contiguous.const_data_ptr<scalar_t>();
+            auto B = x.size(0);
+            auto T = x.size(1);
+
+            lpc_cuda_kernel<scalar_t>
+                <<<B, threads_per_block, order * sizeof(scalar_t),
+                   at::cuda::getCurrentCUDAStream()>>>(padded_y, A, B, T,
+                                                       order);
+        });
+    } else {
+        auto out_real =
+            at::cat({at::real(zi).flip(1), at::real(x)}, 1).contiguous();
+        auto out_imag =
+            at::cat({at::imag(zi).flip(1), at::imag(x)}, 1).contiguous();
+        auto a_real = at::real(a_contiguous).contiguous();
+        auto a_imag = at::imag(a_contiguous).contiguous();
+        AT_DISPATCH_FLOATING_TYPES(
+            out_real.scalar_type(), "lpc_cuda_complex", [&] {
+                auto padded_y_real = out_real.mutable_data_ptr<scalar_t>();
+                auto padded_y_imag = out_imag.mutable_data_ptr<scalar_t>();
+                auto A_real = a_real.const_data_ptr<scalar_t>();
+                auto A_imag = a_imag.const_data_ptr<scalar_t>();
+                auto B = x.size(0);
+                auto T = x.size(1);
+
+                lpc_cuda_kernel_complex<scalar_t>
+                    <<<B, threads_per_block, 2 * order * sizeof(scalar_t),
+                       at::cuda::getCurrentCUDAStream()>>>(
+                        padded_y_real, padded_y_imag, A_real, A_imag, B, T,
+                        order);
+            });
+        out = at::view_as_complex(at::stack({out_real, out_imag}, -1));
+    }
+    return out.slice(1, order, out.size(1)).contiguous();
+}
+
+TORCH_LIBRARY_IMPL(torchlpc, CUDA, m) { m.impl("lpc", &lpc_cuda_wrapper); }
\ No newline at end of file
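For reference, the two kernels above evaluate the time-varying all-pole recursion
y[t] = x[t] - sum_{i=1}^{order} A[t, i] * y[t - i], using one CUDA block per batch
element, one thread per coefficient, and a circular shared-memory buffer that holds the
last `order` outputs; the thread handling the last lag also folds in the input sample
and writes the finished output back, hence the three synchronizations per time step.
A minimal NumPy sketch of the same recursion (illustrative only, not part of the patch;
lpc_reference is a hypothetical helper, and the (B, T) / (B, T, order) / (B, order)
layouts mirror lpc_cuda_wrapper):

import numpy as np


def lpc_reference(x, A, zi):
    # x: (B, T) input, A: (B, T, p) time-varying coefficients,
    # zi: (B, p) initial conditions with zi[:, i] == y[-1 - i].
    T = x.shape[1]
    p = A.shape[2]
    # Same layout as the wrapper: prepend the flipped initial state.
    padded_y = np.concatenate([zi[:, ::-1], x], axis=1)
    for t in range(T):
        # y[t] = x[t] - sum_i A[:, t, i] * y[t - 1 - i]
        for i in range(p):
            padded_y[:, t + p] -= A[:, t, i] * padded_y[:, t + p - 1 - i]
    return padded_y[:, p:]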
From e67cea27af7acc6cc90832658371f97db3cda161 Mon Sep 17 00:00:00 2001
From: Chin-Yun Yu
Date: Thu, 29 May 2025 10:50:01 +0100
Subject: [PATCH 2/4] refactor: backend selection logic

---
 torchlpc/core.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/torchlpc/core.py b/torchlpc/core.py
index 2101601..4c3a981 100644
--- a/torchlpc/core.py
+++ b/torchlpc/core.py
@@ -159,20 +159,21 @@ def lpc_np(x: np.ndarray, A: np.ndarray, zi: np.ndarray) -> np.ndarray:
 class LPC(Function):
     @staticmethod
     def forward(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor:
-        if x.is_cuda:
-            y = lpc_cuda(x.detach(), A.detach(), zi.detach())
-        elif EXTENSION_LOADED:
+        if EXTENSION_LOADED:
             y = torch.ops.torchlpc.lpc(x, A, zi)
         else:
             warnings.warn(
                 "Cannot find custom extension. Falling back to Numba implementation which will be deprecated in v1.0."
             )
-            y = lpc_np(
-                x.detach().cpu().numpy(),
-                A.detach().cpu().numpy(),
-                zi.detach().cpu().numpy(),
-            )
-            y = torch.from_numpy(y).to(x.device, x.dtype)
+            if x.is_cuda:
+                y = lpc_cuda(x.detach(), A.detach(), zi.detach())
+            else:
+                y = lpc_np(
+                    x.detach().cpu().numpy(),
+                    A.detach().cpu().numpy(),
+                    zi.detach().cpu().numpy(),
+                )
+                y = torch.from_numpy(y).to(x.device, x.dtype)
         return y
 
     @staticmethod

From eea9cbfe66fda0addcfd2de29a932a9d0d7b0a79 Mon Sep 17 00:00:00 2001
From: Chin-Yun Yu
Date: Thu, 29 May 2025 10:50:50 +0100
Subject: [PATCH 3/4] refactor: streamline CUDA and CPU runner assignments in
 recurrence.py

---
 torchlpc/recurrence.py | 62 ++++++++++++++++++++++++++----------------
 1 file changed, 38 insertions(+), 24 deletions(-)

diff --git a/torchlpc/recurrence.py b/torchlpc/recurrence.py
index bcb5334..2f4d176 100644
--- a/torchlpc/recurrence.py
+++ b/torchlpc/recurrence.py
@@ -8,30 +8,48 @@
 from .core import lpc_cuda, lpc_np
 from . import EXTENSION_LOADED
 
+if EXTENSION_LOADED:
+    lpc_cuda_runner = torch.ops.torchlpc.lpc
+    lpc_cpu_runner = torch.ops.torchlpc.lpc
+
+    scan_cuda_runner = torch.ops.torchlpc.scan
+    scan_cpu_runner = torch.ops.torchlpc.scan
+else:
+    lpc_cuda_runner = lpc_cuda
+    lpc_cpu_runner = lambda x, A, zi: torch.from_numpy(
+        lpc_np(x.detach().numpy(), A.detach().numpy(), zi.detach().numpy())
+    )
+
+    scan_cuda_runner = lambda impulse, decay, initial_state: (
+        lambda out: (
+            out,
+            compute_linear_recurrence(
+                cuda.as_cuda_array(decay.detach()),
+                cuda.as_cuda_array(impulse.detach()),
+                cuda.as_cuda_array(initial_state.detach()),
+                cuda.as_cuda_array(out),
+                decay.shape[0],
+                decay.shape[1],
+            ),
+        )
+    )(torch.empty_like(impulse))[0]
+    scan_cpu_runner = lambda impulse, decay, initial_state: torch.from_numpy(
+        lpc_np(
+            impulse.detach().numpy(),
+            -decay.unsqueeze(2).detach().numpy(),
+            initial_state.unsqueeze(1).detach().numpy(),
+        )
+    )
+
 
 def _cuda_recurrence(
     impulse: torch.Tensor, decay: torch.Tensor, initial_state: torch.Tensor
 ) -> torch.Tensor:
     n_dims, n_steps = decay.shape
     if n_dims * WARPSIZE < n_steps:
-        if EXTENSION_LOADED:
-            runner = torch.ops.torchlpc.scan
-        else:
-
-            def runner(impulse, decay, initial_state):
-                out = torch.empty_like(impulse)
-                compute_linear_recurrence(
-                    cuda.as_cuda_array(decay.detach()),
-                    cuda.as_cuda_array(impulse.detach()),
-                    cuda.as_cuda_array(initial_state.detach()),
-                    cuda.as_cuda_array(out),
-                    n_dims,
-                    n_steps,
-                )
-                return out
-
+        runner = scan_cuda_runner
     else:
-        runner = lambda impulse, decay, initial_state: lpc_cuda(
+        runner = lambda impulse, decay, initial_state: lpc_cuda_runner(
             impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1)
         )
     return runner(impulse, decay, initial_state)
@@ -44,14 +62,10 @@ def _cpu_recurrence(
     n_dims, _ = decay.shape
     # This is just a rough estimation of the computational cost
     if EXTENSION_LOADED and min(n_dims, num_threads) < num_threads / 3:
-        runner = torch.ops.torchlpc.scan
+        runner = scan_cpu_runner
     else:
-        runner = lambda impulse, decay, initial_state: torch.from_numpy(
-            lpc_np(
-                impulse.detach().numpy(),
-                -decay.unsqueeze(2).detach().numpy(),
-                initial_state.unsqueeze(1).detach().numpy(),
-            )
+        runner = lambda impulse, decay, initial_state: lpc_cpu_runner(
+            impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1)
         )
     return runner(impulse, decay, initial_state)
 
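For context (again illustrative, not part of the patch): every runner assigned above
computes the first-order linear recurrence h[t] = decay[:, t] * h[t - 1] + impulse[:, t]
with h[-1] = initial_state. An order-1 all-pole filter with coefficient a[t] = -decay[t]
is exactly this recurrence, which is why the fallback lambdas can reuse the LPC runners
by passing -decay.unsqueeze(2) and initial_state.unsqueeze(1). A rough reference loop,
assuming the (n_dims, n_steps) shapes used by _cuda_recurrence and _cpu_recurrence
(linear_recurrence_reference is a hypothetical name):

import torch


def linear_recurrence_reference(impulse, decay, initial_state):
    # impulse, decay: (n_dims, n_steps); initial_state: (n_dims,)
    # h[t] = decay[:, t] * h[t - 1] + impulse[:, t], with h[-1] = initial_state
    h = initial_state
    outputs = []
    for t in range(impulse.shape[1]):
        h = decay[:, t] * h + impulse[:, t]
        outputs.append(h)
    return torch.stack(outputs, dim=1)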
From 65a5a83374cec1940b071921264ab31e63d610ad Mon Sep 17 00:00:00 2001
From: Chin-Yun Yu
Date: Thu, 29 May 2025 10:51:37 +0100
Subject: [PATCH 4/4] test: update lpc equivalence test for cuda device

---
 tests/test_extension.py | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/tests/test_extension.py b/tests/test_extension.py
index c3fa107..5131085 100644
--- a/tests/test_extension.py
+++ b/tests/test_extension.py
@@ -64,20 +64,32 @@ def test_scan_equiv(samples: int, cmplx: bool, device: str):
     ).item()
 
 
-@pytest.mark.parametrize(
-    "samples",
-    [1024],
-)
+@pytest.mark.parametrize("samples", [1021, 4097])
 @pytest.mark.parametrize(
     "cmplx",
     [True, False],
 )
-def test_lpc_equiv(samples: int, cmplx: bool):
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param(
+            "cuda",
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="CUDA not available"
+            ),
+        ),
+    ],
+)
+def test_lpc_equiv(samples: int, cmplx: bool, device: str):
     batch_size = 4
     x, A, zi = tuple(
-        x.to("cpu") for x in create_test_inputs(batch_size, samples, cmplx)
+        x.to(device) for x in create_test_inputs(batch_size, samples, cmplx)
     )
-    numba_y = torch.from_numpy(lpc_np(x.numpy(), A.numpy(), zi.numpy()))
+    if device == "cuda":
+        numba_y = lpc_cuda(x, A, zi)
+    else:
+        numba_y = torch.from_numpy(lpc_np(x.numpy(), A.numpy(), zi.numpy()))
     ext_y = torch.ops.torchlpc.lpc(x, A, zi)
 
     assert torch.allclose(numba_y, ext_y)
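Taken together, the four patches route LPC evaluation through the single registered op:
when the compiled extension is available, torch.ops.torchlpc.lpc dispatches by device
(the CUDA implementation from PATCH 1/4 for CUDA tensors, the existing CPU implementation
otherwise), and the Numba/NumPy paths remain only as fallbacks. Below is a quick
end-to-end sketch of how the op is expected to be called; the shapes and the small
coefficient scale are illustrative, and it assumes that importing torchlpc loads the
compiled extension and registers the op:

import torch
import torchlpc  # noqa: F401  -- assumed to load the extension and register the ops

device = "cuda" if torch.cuda.is_available() else "cpu"
B, T, order = 4, 4096, 6

x = torch.randn(B, T, dtype=torch.double, device=device)
# Small random coefficients to keep the time-varying filter well behaved.
A = 0.1 * torch.randn(B, T, order, dtype=torch.double, device=device)
zi = torch.randn(B, order, dtype=torch.double, device=device)

# Same op name on every device; the dispatcher picks the CPU or CUDA kernel.
y = torch.ops.torchlpc.lpc(x, A, zi)
print(y.shape)  # torch.Size([4, 4096])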