|
| 1 | +import itertools |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import xarray as xr |
| 5 | + |
| 6 | +from parcels.application_kernels.interpolation import XTriCurviLinear |
| 7 | +from parcels.field import Field |
| 8 | +from parcels.xgcm import Grid |
| 9 | +from parcels.xgrid import XGrid |
| 10 | + |
| 11 | + |
| 12 | +def get_unit_square_ds(): |
| 13 | + T, Z, Y, X = 2, 2, 2, 2 |
| 14 | + TIME = xr.date_range("2000", "2001", T) |
| 15 | + |
| 16 | + _, data_z, data_y, data_x = np.meshgrid( |
| 17 | + np.zeros(T), |
| 18 | + np.linspace(0, 1, Z), |
| 19 | + np.linspace(0, 1, Y), |
| 20 | + np.linspace(0, 1, X), |
| 21 | + indexing="ij", |
| 22 | + ) |
| 23 | + |
| 24 | + return xr.Dataset( |
| 25 | + { |
| 26 | + "0 to 1 in X": (["time", "ZG", "YG", "XG"], data_x), |
| 27 | + "0 to 1 in Y": (["time", "ZG", "YG", "XG"], data_y), |
| 28 | + "0 to 1 in Z": (["time", "ZG", "YG", "XG"], data_z), |
| 29 | + "0 to 1 in X (T-points)": (["time", "ZC", "YC", "XC"], data_x + 0.5), |
| 30 | + "0 to 1 in Y (T-points)": (["time", "ZC", "YC", "XC"], data_y + 0.5), |
| 31 | + "0 to 1 in Z (T-points)": (["time", "ZC", "YC", "XC"], data_z + 0.5), |
| 32 | + "0 to 1 in X (U velocity C-grid points)": (["time", "ZC", "YC", "XG"], data_x), |
| 33 | + "0 to 1 in Y (V velocity C-grid points)": (["time", "ZC", "YG", "XC"], data_y), |
| 34 | + }, |
| 35 | + coords={ |
| 36 | + "XG": ( |
| 37 | + ["XG"], |
| 38 | + np.arange(0, X), |
| 39 | + {"axis": "X", "c_grid_axis_shift": -0.5}, |
| 40 | + ), |
| 41 | + "XC": (["XC"], np.arange(0, X) + 0.5, {"axis": "X"}), |
| 42 | + "YG": ( |
| 43 | + ["YG"], |
| 44 | + np.arange(0, Y), |
| 45 | + {"axis": "Y", "c_grid_axis_shift": -0.5}, |
| 46 | + ), |
| 47 | + "YC": ( |
| 48 | + ["YC"], |
| 49 | + np.arange(0, Y) + 0.5, |
| 50 | + {"axis": "Y"}, |
| 51 | + ), |
| 52 | + "ZG": ( |
| 53 | + ["ZG"], |
| 54 | + np.arange(Z), |
| 55 | + {"axis": "Z", "c_grid_axis_shift": -0.5}, |
| 56 | + ), |
| 57 | + "ZC": ( |
| 58 | + ["ZC"], |
| 59 | + np.arange(Z) + 0.5, |
| 60 | + {"axis": "Z"}, |
| 61 | + ), |
| 62 | + "lon": (["XG"], np.arange(0, X)), |
| 63 | + "lat": (["YG"], np.arange(0, Y)), |
| 64 | + "depth": (["ZG"], np.arange(Z)), |
| 65 | + "time": (["time"], TIME, {"axis": "T"}), |
| 66 | + }, |
| 67 | + ) |
| 68 | + |
| 69 | + |
| 70 | +def test_XTriRectiLinear_interpolation(): |
| 71 | + ds = get_unit_square_ds() |
| 72 | + grid = XGrid(Grid(ds)) |
| 73 | + field = Field("test", ds["0 to 1 in X"], grid=grid, interp_method=XTriCurviLinear) |
| 74 | + left = field.time_interval.left |
| 75 | + |
| 76 | + epsilon = 1e-6 |
| 77 | + N = 4 |
| 78 | + |
| 79 | + # Interpolate wrt. items on f-points |
| 80 | + for x, y, z in itertools.product(np.linspace(0 + epsilon, 1 - epsilon, N), repeat=3): |
| 81 | + assert np.isclose(x, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}" |
| 82 | + |
| 83 | + field = Field("test", ds["0 to 1 in Y"], grid=grid, interp_method=XTriCurviLinear) |
| 84 | + for x, y, z in itertools.product(np.linspace(0 + epsilon, 1 - epsilon, N), repeat=3): |
| 85 | + assert np.isclose(y, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}" |
| 86 | + |
| 87 | + field = Field("test", ds["0 to 1 in Z"], grid=grid, interp_method=XTriCurviLinear) |
| 88 | + for x, y, z in itertools.product(np.linspace(0 + epsilon, 1 - epsilon, N), repeat=3): |
| 89 | + assert np.isclose(z, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}" |
| 90 | + |
| 91 | + # Interpolate wrt. items on T-points |
| 92 | + field = Field("test", ds["0 to 1 in X (T-points)"], grid=grid, interp_method=XTriCurviLinear) |
| 93 | + for x, y, z in itertools.product(np.linspace(0.5 + epsilon, 1 - epsilon, N), repeat=3): |
| 94 | + assert np.isclose(x, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}" |
| 95 | + |
| 96 | + field = Field("test", ds["0 to 1 in Y (T-points)"], grid=grid, interp_method=XTriCurviLinear) |
| 97 | + for x, y, z in itertools.product(np.linspace(0.5 + epsilon, 1 - epsilon, N), repeat=3): |
| 98 | + assert np.isclose(y, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}" |
| 99 | + |
| 100 | + field = Field("test", ds["0 to 1 in Z (T-points)"], grid=grid, interp_method=XTriCurviLinear) |
| 101 | + for x, y, z in itertools.product(np.linspace(0.5 + epsilon, 1 - epsilon, N), repeat=3): |
| 102 | + assert np.isclose(z, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}" |
0 commit comments