Skip to content

Commit 0959a86

Browse files
author
Hanna Ruth
committed
FIX: fixed bugs in fft rotate function
1 parent a06a945 commit 0959a86

1 file changed

Lines changed: 6 additions & 8 deletions

File tree

src/llama/transformations/functions.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from llama.gpu_utils import get_fft_backend, get_scipy_module
1010
from llama.transformations.helpers import preserve_complexity_or_realness
1111

12-
from llama.api.types import ArrayType, r_type
12+
from llama.api.types import ArrayType, r_type, c_type
1313

1414

1515
def image_crop(
@@ -297,13 +297,11 @@ def image_rotate_fft(images: ArrayType, theta: float) -> ArrayType:
297297
n_x = -xp.sin(theta * xp.pi / 180) * x_grid
298298
n_y = xp.tan(theta / 2 * xp.pi / 180) * y_grid
299299

300-
assert type(n_x) is r_type
300+
m_1 = xp.array(xp.exp(-2j * xp.pi * M_grid * n_y)).astype(c_type)
301+
m_2 = xp.array(xp.exp(-2j * xp.pi * xp.multiply(N_grid, n_x))).astype(c_type)
301302

302-
m_1 = xp.exp(-2j * xp.pi * M_grid * n_y)
303-
m_2 = xp.exp(-2j * xp.pi * xp.multiply(N_grid, n_x))
304-
305-
images = scipy_module.fft.ifft(xp.multiply(scipy.fft.fft(images, axis=2), m_1), axis=2)
306-
images = scipy_module.fft.ifft(xp.multiply(scipy.fft.fft(images, axis=1), m_2), axis=1)
307-
images = scipy_module.fft.ifft(xp.multiply(scipy.fft.fft(images, axis=2), m_1), axis=2)
303+
images = scipy_module.fft.ifft(xp.multiply(scipy_module.fft.fft(images, axis=2), m_1), axis=2)
304+
images = scipy_module.fft.ifft(xp.multiply(scipy_module.fft.fft(images, axis=1), m_2), axis=1)
305+
images = scipy_module.fft.ifft(xp.multiply(scipy_module.fft.fft(images, axis=2), m_1), axis=2)
308306

309307
return images

0 commit comments

Comments
 (0)