Skip to content

Commit 3c1b9d9

Browse files
authored
ci: Add mypy to ci and update types in code (#25)
1 parent 4caf3f9 commit 3c1b9d9

File tree

6 files changed

+22
-16
lines changed

6 files changed

+22
-16
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ jobs:
3131
- name: Format check
3232
run: uv run ruff format --check .
3333

34+
- name: Type check
35+
run: uv run mypy xarray_plotly
36+
3437
- name: Test
3538
run: uv run pytest --cov=xarray_plotly --cov-report=xml
3639

xarray_plotly/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,5 +111,5 @@ def xpx(data: DataArray | Dataset) -> DataArrayPlotlyAccessor | DatasetPlotlyAcc
111111
__version__ = version("xarray_plotly")
112112

113113
# Register the accessors
114-
register_dataarray_accessor("plotly")(DataArrayPlotlyAccessor)
115-
register_dataset_accessor("plotly")(DatasetPlotlyAccessor)
114+
register_dataarray_accessor("plotly")(DataArrayPlotlyAccessor) # type: ignore[no-untyped-call]
115+
register_dataset_accessor("plotly")(DatasetPlotlyAccessor) # type: ignore[no-untyped-call]

xarray_plotly/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def to_dataframe(darray: DataArray) -> pd.DataFrame:
159159
return df
160160

161161

162-
def _get_label_from_attrs(attrs: dict, fallback: str) -> str:
162+
def _get_label_from_attrs(attrs: dict[str, object], fallback: str) -> str:
163163
"""Extract a label from xarray attributes based on current config.
164164
165165
Args:

xarray_plotly/config.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from contextlib import contextmanager
1010
from dataclasses import dataclass, field
11-
from typing import TYPE_CHECKING, Any
11+
from typing import TYPE_CHECKING, Any, cast
1212

1313
if TYPE_CHECKING:
1414
from collections.abc import Generator
@@ -166,12 +166,12 @@ def set_options(
166166
yield
167167
finally:
168168
# Restore old values (modify in place)
169-
_options.label_use_long_name = old_values["label_use_long_name"]
170-
_options.label_use_standard_name = old_values["label_use_standard_name"]
171-
_options.label_include_units = old_values["label_include_units"]
172-
_options.label_unit_format = old_values["label_unit_format"]
173-
_options.slot_orders = old_values["slot_orders"]
174-
_options.dataset_variable_position = old_values["dataset_variable_position"]
169+
_options.label_use_long_name = cast("bool", old_values["label_use_long_name"])
170+
_options.label_use_standard_name = cast("bool", old_values["label_use_standard_name"])
171+
_options.label_include_units = cast("bool", old_values["label_include_units"])
172+
_options.label_unit_format = cast("str", old_values["label_unit_format"])
173+
_options.slot_orders = cast("dict[str, tuple[str, ...]]", old_values["slot_orders"])
174+
_options.dataset_variable_position = cast("int", old_values["dataset_variable_position"])
175175

176176

177177
def notebook(renderer: str = "notebook") -> None:

xarray_plotly/figures.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
from __future__ import annotations
66

77
import copy
8-
from typing import TYPE_CHECKING
8+
from typing import TYPE_CHECKING, Any
99

1010
if TYPE_CHECKING:
1111
from collections.abc import Iterator
1212

1313
import plotly.graph_objects as go
1414

1515

16-
def _iter_all_traces(fig: go.Figure) -> Iterator:
16+
def _iter_all_traces(fig: go.Figure) -> Iterator[Any]:
1717
"""Iterate over all traces in a figure, including animation frames.
1818
1919
Yields traces from fig.data first, then from each frame in fig.frames.
@@ -107,7 +107,7 @@ def _merge_frames(
107107
overlays: list[go.Figure],
108108
base_trace_count: int,
109109
overlay_trace_counts: list[int],
110-
) -> list:
110+
) -> list[go.Frame]:
111111
"""Merge animation frames from base and overlay figures.
112112
113113
Args:
@@ -360,7 +360,7 @@ def _merge_secondary_y_frames(
360360
base: go.Figure,
361361
secondary: go.Figure,
362362
y_mapping: dict[str, str],
363-
) -> list:
363+
) -> list[go.Frame]:
364364
"""Merge animation frames for secondary y-axis combination.
365365
366366
Args:
@@ -411,7 +411,9 @@ def _merge_secondary_y_frames(
411411
return merged_frames
412412

413413

414-
def update_traces(fig: go.Figure, selector: dict | None = None, **kwargs) -> go.Figure:
414+
def update_traces(
415+
fig: go.Figure, selector: dict[str, Any] | None = None, **kwargs: Any
416+
) -> go.Figure:
415417
"""Update traces in both base figure and all animation frames.
416418
417419
Plotly's `update_traces()` only updates the base figure, not animation frames.

xarray_plotly/plotting.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import TYPE_CHECKING, Any
99

1010
import numpy as np
11+
import numpy.typing as npt
1112
import plotly.express as px
1213

1314
from xarray_plotly.common import (
@@ -191,7 +192,7 @@ def bar(
191192
)
192193

193194

194-
def _classify_trace_sign(y_values: np.ndarray) -> str:
195+
def _classify_trace_sign(y_values: npt.ArrayLike) -> str:
195196
"""Classify a trace as 'positive', 'negative', or 'mixed' based on its values."""
196197
y_arr = np.asarray(y_values)
197198
y_clean = y_arr[np.isfinite(y_arr) & (np.abs(y_arr) > 1e-9)]

0 commit comments

Comments
 (0)