|
| 1 | +""" |
| 2 | + fourier.py test |
| 3 | +""" |
| 4 | + |
1 | 5 | import unittest |
2 | 6 | import numpy as np |
3 | 7 |
|
4 | 8 | from context import fourier |
5 | 9 |
|
6 | | -def fft1(x): |
7 | | - L = len(x) |
| 10 | +def fft1(test_input): |
| 11 | + """ |
| 12 | + A simple fft implementation. |
| 13 | + This is the same function with which numpy tests their FFT funciton. |
| 14 | + """ |
| 15 | + L = len(test_input) |
8 | 16 | phase = -2j*np.pi*(np.arange(L)/float(L)) |
9 | 17 | phase = np.arange(L).reshape(-1, 1) * phase |
10 | | - return np.sum(x*np.exp(phase), axis=1) |
| 18 | + return np.sum(test_input*np.exp(phase), axis=1) |
11 | 19 |
|
12 | 20 | class TestFft(unittest.TestCase): |
13 | 21 | """ |
14 | 22 | Testclass for testing fourier transforms |
15 | 23 | """ |
16 | 24 |
|
17 | 25 | def test_dft(self): |
18 | | - x = np.random.random(30) + 1j*np.random.random(30) |
19 | | - self.assertAlmostEqual(fft1(x), fourier.DFT_slow(x)) |
| 26 | + """ |
| 27 | + Test for DFT function |
| 28 | + """ |
| 29 | + test_input = np.random.random(32) + 1j*np.random.random(32) |
| 30 | + test_input = np.asarray(test_input) |
| 31 | + self.assertAlmostEqual(fft1(test_input).all(), fourier.DFT_slow(test_input).all()) |
| 32 | + |
20 | 33 |
|
21 | 34 | def test_fft(self): |
22 | | - x = np.random.random(30) + 1j*np.random.random(30) |
23 | | - self.assertAlmostEqual(fft1(x), fourier.FFT_vectorized(x)) |
| 35 | + """ |
| 36 | + Test for FFT_vectorized function |
| 37 | + """ |
| 38 | + test_input = np.random.random(32) + 1j*np.random.random(32) |
| 39 | + test_input = np.asarray(test_input) |
| 40 | + self.assertAlmostEqual(fft1(test_input).all(), fourier.FFT_vectorized(test_input).all()) |
| 41 | + |
| 42 | + |
| 43 | + def test_wrong_array_size(self): |
| 44 | + """ |
| 45 | + Tests for correct input for the FFT_vectorized function. |
| 46 | + Fails if the input is not a power of 2. |
| 47 | + """ |
| 48 | + test_input = np.random.random(30) + 1j*np.random.random(30) |
| 49 | + with self.assertRaises(ValueError): |
| 50 | + fourier.FFT_vectorized(test_input) |
| 51 | + |
| 52 | + |
| 53 | + def test_fftfreq(self): |
| 54 | + """ |
| 55 | + Test for fftfreq function |
| 56 | + """ |
| 57 | + test_input = [0, 1, 2, 3, 4, -4, -3, -2, -1] |
| 58 | + test_input = np.asarray(test_input) |
| 59 | + self.assertAlmostEqual(9*fourier.fftfreq(9).all(), test_input.all()) |
| 60 | + self.assertAlmostEqual(9*np.pi*fourier.fftfreq(9, np.pi).all(), test_input.all()) |
| 61 | + test_input = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1] |
| 62 | + test_input = np.asarray(test_input) |
| 63 | + self.assertAlmostEqual(10*fourier.fftfreq(10).all(), test_input.all()) |
| 64 | + self.assertAlmostEqual(10*np.pi*fourier.fftfreq(10, np.pi).all(), test_input.all()) |
| 65 | + |
24 | 66 |
|
25 | 67 | if __name__ == '__main__': |
26 | 68 | unittest.main() |
0 commit comments