Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,35 +135,49 @@ def __init__(
lon.size == kwargs[kwvar].size
), f"{kwvar} and positions (lon, lat, depth) don't have the same lengths."

self._data = {
"lon": lon.astype(lonlatdepth_dtype),
"lat": lat.astype(lonlatdepth_dtype),
"depth": depth.astype(lonlatdepth_dtype),
"time": time,
"dt": np.timedelta64(1, "ns") * np.ones(len(trajectory_ids)),
# "ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)),
"state": np.zeros((len(trajectory_ids)), dtype=np.int32),
"lon_nextloop": lon.astype(lonlatdepth_dtype),
"lat_nextloop": lat.astype(lonlatdepth_dtype),
"depth_nextloop": depth.astype(lonlatdepth_dtype),
"time_nextloop": time,
"trajectory": trajectory_ids,
}
self._ptype = pclass.getPType()
self._ds = xr.Dataset(
{
"lon": (["trajectory"], lon.astype(lonlatdepth_dtype)),
"lat": (["trajectory"], lat.astype(lonlatdepth_dtype)),
"depth": (["trajectory"], depth.astype(lonlatdepth_dtype)),
"time": (["trajectory"], time),
"dt": (["trajectory"], np.timedelta64(1, "ns") * np.ones(len(trajectory_ids))),
"ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)),
"state": (["trajectory"], np.zeros((len(trajectory_ids)), dtype=np.int32)),
"lon_nextloop": (["trajectory"], lon.astype(lonlatdepth_dtype)),
"lat_nextloop": (["trajectory"], lat.astype(lonlatdepth_dtype)),
"depth_nextloop": (["trajectory"], depth.astype(lonlatdepth_dtype)),
"time_nextloop": (["trajectory"], time),
},
coords={
"trajectory": ("trajectory", trajectory_ids),
},
attrs={
"ngrid": len(fieldset.gridset),
"ptype": pclass.getPType(),
},
)
# add extra fields from the custom Particle class
for v in pclass.__dict__.values():
if isinstance(v, Variable):
if isinstance(v.initial, attrgetter):
initial = v.initial(self)
initial = v.initial(self._ds).values
else:
initial = v.initial * np.ones(len(trajectory_ids), dtype=v.dtype)
self._data[v.name] = initial
self._ds[v.name] = (["trajectory"], initial)

# update initial values provided on ParticleSet creation
for kwvar, kwval in kwargs.items():
if not hasattr(pclass, kwvar):
raise RuntimeError(f"Particle class does not have Variable {kwvar}")
self._data[kwvar][:] = kwval
self._ds[kwvar][:] = kwval

# also keep a struct of numpy arrays for faster access (see parcels-benchmarks/pull/1)
self._data = {}
for v in self._ds.keys():
self._data[v] = self._ds[v].data
self._data["trajectory"] = self._ds["trajectory"].data
self._ptype = self._ds.attrs["ptype"]

self._kernel = None

Expand Down
23 changes: 23 additions & 0 deletions tests/v4/test_particleset_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,29 @@ def PythonFail(particle, fieldset, time): # pragma: no cover
assert all([time == fieldset.time_interval.left + np.timedelta64(0, "s") for time in pset.time[1:]])


def test_pset_update_particles_in_dataset_and_dict(fieldset, npart=10):
lon_start = np.linspace(0, 1, npart)
lat_start = np.linspace(1, 0, npart)
pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart))

def UpdateParticle(particle, fieldset, time): # pragma: no cover
particle.lon += 0.1
particle.lat -= 0.1

pset.execute(pset.Kernel(UpdateParticle), runtime=np.timedelta64(10, "s"), dt=np.timedelta64(1, "s"))
assert np.allclose(pset.lon, lon_start + 1, atol=1e-5)
assert np.allclose(pset.lat, lat_start - 1, atol=1e-5)
assert all(pset._data["lon"] == pset._ds["lon"].data)


def test_pset_remove_indices_in_dataset_and_dict(fieldset, npart=10):
pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart))
assert len(pset._ds.lon) == len(pset._data["lon"]) == npart

pset.remove_indices([0])
assert len(pset._ds.lon) == len(pset._data["lon"]) == npart - 1


@pytest.mark.parametrize("verbose_progress", [True, False])
def test_uxstommelgyre_pset_execute(verbose_progress):
ds = datasets_unstructured["stommel_gyre_delaunay"]
Expand Down
Loading