diff --git a/src/post_processing/dataclass/data_aplose.py b/src/post_processing/dataclass/data_aplose.py index 0939966..59cb01b 100644 --- a/src/post_processing/dataclass/data_aplose.py +++ b/src/post_processing/dataclass/data_aplose.py @@ -167,7 +167,21 @@ def __getitem__(self, item: int) -> Series: return self.df.iloc[item] def change_tz(self, tz: str | tzinfo) -> None: - """Change the timezone of the DataFrame and of the begin and end Timestamps.""" + """Change the timezone of a DataAplose instance. + + Examples + -------- + >>> import pytz + >>> data = DataAplose(...) + >>> data.change_tz(pytz.timezone("Etc/GMT-2")) + + >>> data = DataAplose(...) + >>> data.change_tz("UTC") + + >>> data = DataAplose(...) + >>> data.change_tz("UTC+02:00") + + """ self.df["start_datetime"] = [ elem.tz_convert(tz) for elem in self.df["start_datetime"] diff --git a/src/post_processing/utils/filtering_utils.py b/src/post_processing/utils/filtering_utils.py index db1484f..5083c7d 100644 --- a/src/post_processing/utils/filtering_utils.py +++ b/src/post_processing/utils/filtering_utils.py @@ -6,6 +6,7 @@ import csv from typing import TYPE_CHECKING +import pytz from pandas import ( DataFrame, Timedelta, @@ -19,8 +20,6 @@ if TYPE_CHECKING: from pathlib import Path - from dateutil.tz import tzoffset - from post_processing.dataclass.detection_filter import DetectionFilter @@ -176,9 +175,18 @@ def get_dataset(df: DataFrame) -> list[str]: return datasets if len(datasets) > 1 else datasets[0] -def get_timezone(df: DataFrame) -> tzoffset | list[tzoffset]: +def get_timezone(df: DataFrame): """Return timezone(s) from DataFrame.""" - timezones = {ts.tz for ts in df["start_datetime"] if ts.tz is not None} + + def get_canonical_tz(tz): + if hasattr(tz, "zone") and tz.zone: + return pytz.timezone(tz.zone) + if hasattr(tz, "key"): + return pytz.timezone(tz.key) + return pytz.UTC + + timezones = {get_canonical_tz(ts.tzinfo) for ts in df["start_datetime"]} + if len(timezones) == 1: return next(iter(timezones)) return list(timezones) diff --git a/tests/test_filtering_utils.py b/tests/test_filtering_utils.py index f8a3745..622ad2a 100644 --- a/tests/test_filtering_utils.py +++ b/tests/test_filtering_utils.py @@ -1,8 +1,8 @@ import csv -from datetime import timezone from pathlib import Path import pytest +import pytz from pandas import DataFrame, Timedelta, Timestamp, date_range from post_processing.utils.filtering_utils import ( @@ -170,10 +170,10 @@ def test_get_dataset(sample_df: DataFrame) -> None: def test_get_timezone_single(sample_df: DataFrame) -> None: tz = get_timezone(sample_df) - assert isinstance(tz, timezone) + assert tz == pytz.utc -# %% read DataFrame +# %% read DataFrame def test_read_dataframe_comma_delimiter(tmp_path: Path) -> None: csv_file = tmp_path / "test.csv"