From 86e0b54237fe7e772740c6f9fc88c1fc364eb216 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Mon, 21 Jul 2025 09:01:53 +0200 Subject: [PATCH 1/4] Adding both a data struct and a dict for the particledata Following @VeckoTheGecko's suggestion at https://github.com/OceanParcels/parcels-benchmarks/pull/1#issuecomment-3089184625 --- parcels/particleset.py | 50 ++++++++++++++++++---------- tests/v4/test_particleset_execute.py | 14 ++++++++ 2 files changed, 46 insertions(+), 18 deletions(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index f018a356d..62d22c7e6 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -135,35 +135,49 @@ def __init__( lon.size == kwargs[kwvar].size ), f"{kwvar} and positions (lon, lat, depth) don't have the same lengths." - self._data = { - "lon": lon.astype(lonlatdepth_dtype), - "lat": lat.astype(lonlatdepth_dtype), - "depth": depth.astype(lonlatdepth_dtype), - "time": time, - "dt": np.timedelta64(1, "ns") * np.ones(len(trajectory_ids)), - # "ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)), - "state": np.zeros((len(trajectory_ids)), dtype=np.int32), - "lon_nextloop": lon.astype(lonlatdepth_dtype), - "lat_nextloop": lat.astype(lonlatdepth_dtype), - "depth_nextloop": depth.astype(lonlatdepth_dtype), - "time_nextloop": time, - "trajectory": trajectory_ids, - } - self._ptype = pclass.getPType() + self._ds = xr.Dataset( + { + "lon": (["trajectory"], lon.astype(lonlatdepth_dtype)), + "lat": (["trajectory"], lat.astype(lonlatdepth_dtype)), + "depth": (["trajectory"], depth.astype(lonlatdepth_dtype)), + "time": (["trajectory"], time), + "dt": (["trajectory"], np.timedelta64(1, "ns") * np.ones(len(trajectory_ids))), + "ei": (["trajectory", "ngrid"], np.zeros((len(trajectory_ids), len(fieldset.gridset)), dtype=np.int32)), + "state": (["trajectory"], np.zeros((len(trajectory_ids)), dtype=np.int32)), + "lon_nextloop": (["trajectory"], 
lon.astype(lonlatdepth_dtype)), + "lat_nextloop": (["trajectory"], lat.astype(lonlatdepth_dtype)), + "depth_nextloop": (["trajectory"], depth.astype(lonlatdepth_dtype)), + "time_nextloop": (["trajectory"], time), + }, + coords={ + "trajectory": ("trajectory", trajectory_ids), + }, + attrs={ + "ngrid": len(fieldset.gridset), + "ptype": pclass.getPType(), + }, + ) # add extra fields from the custom Particle class for v in pclass.__dict__.values(): if isinstance(v, Variable): if isinstance(v.initial, attrgetter): - initial = v.initial(self) + initial = v.initial(self).values else: initial = v.initial * np.ones(len(trajectory_ids), dtype=v.dtype) - self._data[v.name] = initial + self._ds[v.name] = (["trajectory"], initial) # update initial values provided on ParticleSet creation for kwvar, kwval in kwargs.items(): if not hasattr(pclass, kwvar): raise RuntimeError(f"Particle class does not have Variable {kwvar}") - self._data[kwvar][:] = kwval + self._ds[kwvar][:] = kwval + + # also keep a struct of numpy arrays for faster access (see parcels-benchmarks/pull/1) + self._data = {} + for v in self._ds.keys(): + self._data[v] = self._ds[v].data + self._data["trajectory"] = self._ds["trajectory"].data + self._ptype = self._ds.attrs["ptype"] self._kernel = None diff --git a/tests/v4/test_particleset_execute.py b/tests/v4/test_particleset_execute.py index f63210add..4b751e1ff 100644 --- a/tests/v4/test_particleset_execute.py +++ b/tests/v4/test_particleset_execute.py @@ -116,6 +116,20 @@ def PythonFail(particle, fieldset, time): # pragma: no cover assert all([time == fieldset.time_interval.left + np.timedelta64(0, "s") for time in pset.time[1:]]) +def test_pset_update_particle(fieldset, npart=10): + lon_start = np.linspace(0, 1, npart) + lat_start = np.linspace(1, 0, npart) + pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart)) + + def UpdateParticle(particle, fieldset, time): # pragma: no cover + particle.lon += 0.1 + particle.lat -= 0.1 + 
+ pset.execute(pset.Kernel(UpdateParticle), runtime=np.timedelta64(10, "s"), dt=np.timedelta64(1, "s")) + assert np.allclose(pset.lon, lon_start + 1, atol=1e-5) + assert np.allclose(pset.lat, lat_start - 1, atol=1e-5) + + @pytest.mark.parametrize("verbose_progress", [True, False]) def test_uxstommelgyre_pset_execute(verbose_progress): ds = datasets_unstructured["stommel_gyre_delaunay"] From 486765d6c74920635cca3ff41e99468622aac01f Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Mon, 21 Jul 2025 09:10:28 +0200 Subject: [PATCH 2/4] Adding one extra assert to check dataset and dict equivalence --- tests/v4/test_particleset_execute.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/v4/test_particleset_execute.py b/tests/v4/test_particleset_execute.py index 4b751e1ff..8c9cb8b00 100644 --- a/tests/v4/test_particleset_execute.py +++ b/tests/v4/test_particleset_execute.py @@ -116,7 +116,7 @@ def PythonFail(particle, fieldset, time): # pragma: no cover assert all([time == fieldset.time_interval.left + np.timedelta64(0, "s") for time in pset.time[1:]]) -def test_pset_update_particle(fieldset, npart=10): +def test_pset_update_particles_in_dataset_and_dict(fieldset, npart=10): lon_start = np.linspace(0, 1, npart) lat_start = np.linspace(1, 0, npart) pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart)) @@ -128,6 +128,7 @@ def UpdateParticle(particle, fieldset, time): # pragma: no cover pset.execute(pset.Kernel(UpdateParticle), runtime=np.timedelta64(10, "s"), dt=np.timedelta64(1, "s")) assert np.allclose(pset.lon, lon_start + 1, atol=1e-5) assert np.allclose(pset.lat, lat_start - 1, atol=1e-5) + assert all(pset._data["lon"] == pset._ds["lon"].data) @pytest.mark.parametrize("verbose_progress", [True, False]) def test_uxstommelgyre_pset_execute(verbose_progress): ds = datasets_unstructured["stommel_gyre_delaunay"] From 80d9749a7838e3a9ba14fc37e33878624b7d2200 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Mon, 21 Jul 2025 09:41:46 +0200 Subject: [PATCH 3/4] Fixing bug in attrgetter for dataset --- 
parcels/particleset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index 62d22c7e6..917f52531 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -161,7 +161,7 @@ def __init__( for v in pclass.__dict__.values(): if isinstance(v, Variable): if isinstance(v.initial, attrgetter): - initial = v.initial(self).values + initial = v.initial(self._ds).values else: initial = v.initial * np.ones(len(trajectory_ids), dtype=v.dtype) self._ds[v.name] = (["trajectory"], initial) From e70b9f920a4afcbd24e9f39657e5244bdf1de02a Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Mon, 21 Jul 2025 14:11:51 +0200 Subject: [PATCH 4/4] Adding failing deletion test Adding @VeckoTheGecko's failing test showing that the dict-of-numpys does not track the xarray dataset anymore after a deletion --- tests/v4/test_particleset_execute.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/v4/test_particleset_execute.py b/tests/v4/test_particleset_execute.py index 8c9cb8b00..b7368e841 100644 --- a/tests/v4/test_particleset_execute.py +++ b/tests/v4/test_particleset_execute.py @@ -131,6 +131,14 @@ def UpdateParticle(particle, fieldset, time): # pragma: no cover assert all(pset._data["lon"] == pset._ds["lon"].data) +def test_pset_remove_indices_in_dataset_and_dict(fieldset, npart=10): + pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart)) + assert len(pset._ds.lon) == len(pset._data["lon"]) == npart + + pset.remove_indices([0]) + assert len(pset._ds.lon) == len(pset._data["lon"]) == npart - 1 + + @pytest.mark.parametrize("verbose_progress", [True, False]) def test_uxstommelgyre_pset_execute(verbose_progress): ds = datasets_unstructured["stommel_gyre_delaunay"]