Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/post_processing/dataclass/data_aplose.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,21 @@ def __getitem__(self, item: int) -> Series:
return self.df.iloc[item]

def change_tz(self, tz: str | tzinfo) -> None:
"""Change the timezone of the DataFrame and of the begin and end Timestamps."""
"""Change the timezone of a DataAplose instance.

Examples
--------
>>> import pytz
>>> data = DataAplose(...)
>>> data.change_tz(pytz.timezone("Etc/GMT-2"))

>>> data = DataAplose(...)
>>> data.change_tz("UTC")

>>> data = DataAplose(...)
>>> data.change_tz("UTC+02:00")

"""
self.df["start_datetime"] = [
elem.tz_convert(tz)
for elem in self.df["start_datetime"]
Expand Down
16 changes: 12 additions & 4 deletions src/post_processing/utils/filtering_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import csv
from typing import TYPE_CHECKING

import pytz
from pandas import (
DataFrame,
Timedelta,
Expand All @@ -19,8 +20,6 @@
if TYPE_CHECKING:
from pathlib import Path

from dateutil.tz import tzoffset

from post_processing.dataclass.detection_filter import DetectionFilter


Expand Down Expand Up @@ -176,9 +175,18 @@ def get_dataset(df: DataFrame) -> list[str]:
return datasets if len(datasets) > 1 else datasets[0]


def get_timezone(df: DataFrame):
    """Return the timezone(s) found in the ``start_datetime`` column.

    Every timestamp's tzinfo is normalised to a pytz object so that
    equivalent zones coming from different libraries collapse to a single
    entry.

    Parameters
    ----------
    df:
        DataFrame with a ``start_datetime`` column of Timestamps.

    Returns
    -------
    A single pytz timezone when all rows share the same one, otherwise a
    list of the distinct pytz timezones (empty list for an empty column).

    """

    def get_canonical_tz(ts):
        """Map the tzinfo of one timestamp to its pytz equivalent."""
        tz = ts.tzinfo

        # pytz timezones expose their name as ``zone``.
        zone = getattr(tz, "zone", None)
        if zone:
            return pytz.timezone(zone)

        # zoneinfo.ZoneInfo timezones expose their name as ``key``.
        key = getattr(tz, "key", None)
        if key:
            return pytz.timezone(key)

        # Unnamed tzinfo (datetime.timezone, dateutil tzoffset,
        # pytz.FixedOffset, ...): keep the real UTC offset instead of
        # silently reporting UTC.
        offset = ts.utcoffset()
        if offset is None or not offset:
            # Naive timestamps and zero offsets are reported as UTC.
            return pytz.UTC
        # pytz.FixedOffset works in whole minutes; sub-minute parts are dropped.
        return pytz.FixedOffset(int(offset.total_seconds() // 60))

    timezones = {get_canonical_tz(ts) for ts in df["start_datetime"]}

    if len(timezones) == 1:
        return next(iter(timezones))
    return list(timezones)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_filtering_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import csv
from datetime import timezone
from pathlib import Path

import pytest
import pytz
from pandas import DataFrame, Timedelta, Timestamp, date_range

from post_processing.utils.filtering_utils import (
Expand Down Expand Up @@ -170,10 +170,10 @@ def test_get_dataset(sample_df: DataFrame) -> None:

def test_get_timezone_single(sample_df: DataFrame) -> None:
    """A frame with one timezone yields that timezone itself, not a list."""
    tz = get_timezone(sample_df)
    # pytz.utc is not a datetime.timezone instance, so the result is
    # checked by value rather than with isinstance(tz, timezone).
    assert not isinstance(tz, list)
    assert tz == pytz.utc

# %% read DataFrame

# %% read DataFrame

def test_read_dataframe_comma_delimiter(tmp_path: Path) -> None:
csv_file = tmp_path / "test.csv"
Expand Down