|
9 | 9 | from llama.gpu_utils import get_fft_backend, get_scipy_module |
10 | 10 | from llama.transformations.helpers import preserve_complexity_or_realness |
11 | 11 |
|
12 | | -from llama.api.types import ArrayType, r_type |
| 12 | +from llama.api.types import ArrayType, r_type, c_type |
13 | 13 |
|
14 | 14 |
|
15 | 15 | def image_crop( |
@@ -297,13 +297,11 @@ def image_rotate_fft(images: ArrayType, theta: float) -> ArrayType: |
297 | 297 | n_x = -xp.sin(theta * xp.pi / 180) * x_grid |
298 | 298 | n_y = xp.tan(theta / 2 * xp.pi / 180) * y_grid |
299 | 299 |
|
300 | | - assert type(n_x) is r_type |
| 300 | + m_1 = xp.array(xp.exp(-2j * xp.pi * M_grid * n_y)).astype(c_type) |
| 301 | + m_2 = xp.array(xp.exp(-2j * xp.pi * xp.multiply(N_grid, n_x))).astype(c_type) |
301 | 302 |
|
302 | | - m_1 = xp.exp(-2j * xp.pi * M_grid * n_y) |
303 | | - m_2 = xp.exp(-2j * xp.pi * xp.multiply(N_grid, n_x)) |
304 | | - |
305 | | - images = scipy_module.fft.ifft(xp.multiply(scipy.fft.fft(images, axis=2), m_1), axis=2) |
306 | | - images = scipy_module.fft.ifft(xp.multiply(scipy.fft.fft(images, axis=1), m_2), axis=1) |
307 | | - images = scipy_module.fft.ifft(xp.multiply(scipy.fft.fft(images, axis=2), m_1), axis=2) |
| 303 | + images = scipy_module.fft.ifft(xp.multiply(scipy_module.fft.fft(images, axis=2), m_1), axis=2) |
| 304 | + images = scipy_module.fft.ifft(xp.multiply(scipy_module.fft.fft(images, axis=1), m_2), axis=1) |
| 305 | + images = scipy_module.fft.ifft(xp.multiply(scipy_module.fft.fft(images, axis=2), m_1), axis=2) |
308 | 306 |
|
309 | 307 | return images |
0 commit comments