Add a `mode` argument to ParticleFile.

- mode=None (default): raise ValueError if the output path already exists.
- mode="w": delete the existing file at the output path before writing.
- Any other value raises ValueError("Invalid mode value ...").

Tests cover the existing-path behaviour for both modes and the invalid-mode error.

NOTE(review): this patch was restored from a whitespace-collapsed copy. Line
boundaries were checked against the hunk headers, but leading indentation is
reconstructed -- verify against the working tree before applying.

diff --git a/src/parcels/_core/particlefile.py b/src/parcels/_core/particlefile.py
index 58072f6e6..56044d271 100644
--- a/src/parcels/_core/particlefile.py
+++ b/src/parcels/_core/particlefile.py
@@ -61,6 +61,10 @@ class ParticleFile:
         It is either a numpy.timedelta64, a datimetime.timedelta object or a positive float (in seconds).
     compression : {"zstd", "gzip", "snappy", "brotli", None}, optional
         Compression algorithm to use for the Parquet file. Default is "zstd".
+    mode : {None, "w"}, optional
+        Writing behaviour.
+        - None (default): Write dataset, and raise an error if it already exists.
+        - "w": Write dataset, overwriting it.
 
     Returns
     -------
@@ -69,7 +73,11 @@ class ParticleFile:
     """
 
     def __init__(
-        self, path: PathLike, outputdt, compression: Literal["zstd", "gzip", "snappy", "brotli", None] = "zstd"
+        self,
+        path: PathLike,
+        outputdt,
+        compression: Literal["zstd", "gzip", "snappy", "brotli", None] = "zstd",
+        mode: Literal[None, "w"] = None,
     ):
         if not isinstance(outputdt, (np.timedelta64, timedelta, float)):
             raise ValueError(
@@ -92,9 +100,15 @@ def __init__(
         self._path = path
         # TODO v4: Consider https://arrow.apache.org/docs/python/getstarted.html#working-with-large-data - though a significant question becomes how to partition, perhaps using a particle variable "partition"?
         self._writer: pq.ParquetWriter | None = None
+
+        if mode not in {None, "w"}:
+            raise ValueError(f"Invalid mode value {mode!r}. Expected one of None or 'w'.")
+
         if path.exists():
-            # TODO: Add logic for recovering/appending to existing parquet file
-            raise ValueError(f"{path=!r} already exists. Either delete this file or use a path that doesn't exist.")
+            if mode is None:
+                raise ValueError(f"{path=!r} already exists. Use mode='w' or use a new path.")
+            if mode == "w":
+                path.unlink()
 
         if not path.parent.exists():
             raise ValueError(f"Folder location for {path=!r} does not exist. Create the folder location first.")
diff --git a/tests/test_particlefile.py b/tests/test_particlefile.py
index 0aa6e8b8b..9814dfe33 100755
--- a/tests/test_particlefile.py
+++ b/tests/test_particlefile.py
@@ -412,6 +412,30 @@ def test_particlefile_init(tmp_parquet):
     ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s"))
 
 
+def test_particlefile_init_existing_path_modes(fieldset, tmp_parquet):
+    pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0)
+
+    first_file = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s"))
+    pset.execute(DoNothing, runtime=np.timedelta64(10, "s"), dt=np.timedelta64(1, "s"), output_file=first_file)
+
+    df_first = pd.read_parquet(tmp_parquet)
+
+    with pytest.raises(ValueError, match="already exists"):
+        ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s"))
+
+    overwrite_file = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s"), mode="w")
+    pset.execute(DoNothing, runtime=np.timedelta64(10, "s"), dt=np.timedelta64(1, "s"), output_file=overwrite_file)
+
+    df_overwrite = pd.read_parquet(tmp_parquet)
+
+    assert len(df_first) == len(df_overwrite)
+
+
+def test_particlefile_init_invalid_mode(tmp_parquet):
+    with pytest.raises(ValueError, match="Invalid mode value"):
+        ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s"), mode="something-else")
+
+
 @pytest.mark.parametrize("name", ["path", "outputdt"])
 def test_particlefile_readonly_attrs(tmp_parquet, name):
     pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s"))