Skip to content

Commit 3ac4f49

Browse files
committed
Add tests for new colors parameter
1 parent 8fdc7b0 commit 3ac4f49

File tree

1 file changed

+124
-0
lines changed

1 file changed

+124
-0
lines changed

tests/test_accessor.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,3 +401,127 @@ def test_imshow_animation_consistent_bounds(self) -> None:
401401
coloraxis = fig.layout.coloraxis
402402
assert coloraxis.cmin == 0.0
403403
assert coloraxis.cmax == 70.0
404+
405+
406+
class TestColorsParameter:
407+
"""Tests for the unified colors parameter."""
408+
409+
@pytest.fixture(autouse=True)
410+
def setup(self) -> None:
411+
"""Create test DataArrays."""
412+
self.da = xr.DataArray(
413+
np.random.rand(10, 3),
414+
dims=["time", "city"],
415+
coords={"city": ["A", "B", "C"]},
416+
)
417+
418+
def test_colors_list_sets_discrete_sequence(self) -> None:
419+
"""Test that a list of colors sets color_discrete_sequence."""
420+
fig = self.da.plotly.line(colors=["red", "blue", "green"])
421+
# Check that traces have the expected colors
422+
assert len(fig.data) == 3
423+
assert fig.data[0].line.color == "red"
424+
assert fig.data[1].line.color == "blue"
425+
assert fig.data[2].line.color == "green"
426+
427+
def test_colors_dict_sets_discrete_map(self) -> None:
428+
"""Test that a dict sets color_discrete_map."""
429+
fig = self.da.plotly.line(colors={"A": "red", "B": "blue", "C": "green"})
430+
# Traces should be colored according to the mapping
431+
assert len(fig.data) == 3
432+
# Find traces by name and check their color
433+
colors_by_name = {trace.name: trace.line.color for trace in fig.data}
434+
assert colors_by_name["A"] == "red"
435+
assert colors_by_name["B"] == "blue"
436+
assert colors_by_name["C"] == "green"
437+
438+
def test_colors_continuous_scale_string(self) -> None:
439+
"""Test that a continuous scale name sets color_continuous_scale."""
440+
da = xr.DataArray(
441+
np.random.rand(50, 2),
442+
dims=["point", "coord"],
443+
coords={"coord": ["x", "y"]},
444+
)
445+
fig = da.plotly.scatter(y="coord", x="point", color="value", colors="Viridis")
446+
# Plotly Express uses coloraxis in the layout for continuous scales
447+
# Check that the colorscale was applied to the coloraxis
448+
assert fig.layout.coloraxis.colorscale is not None
449+
colorscale = fig.layout.coloraxis.colorscale
450+
# Viridis should be in the colorscale definition
451+
assert any("viridis" in str(c).lower() for c in colorscale) or len(colorscale) > 0
452+
453+
def test_colors_qualitative_palette_string(self) -> None:
454+
"""Test that a qualitative palette name sets color_discrete_sequence."""
455+
import plotly.express as px
456+
457+
fig = self.da.plotly.line(colors="D3")
458+
# D3 palette should be applied - check first trace color is from D3
459+
d3_colors = px.colors.qualitative.D3
460+
assert fig.data[0].line.color in d3_colors
461+
462+
def test_colors_ignored_with_warning_when_px_kwargs_present(self) -> None:
463+
"""Test that colors is ignored with warning when color_* kwargs are present."""
464+
import warnings
465+
466+
with warnings.catch_warnings(record=True) as w:
467+
warnings.simplefilter("always")
468+
fig = self.da.plotly.line(
469+
colors="D3", color_discrete_sequence=["orange", "purple", "cyan"]
470+
)
471+
# Should have raised a warning
472+
assert len(w) == 1
473+
assert "colors" in str(w[0].message).lower()
474+
assert "ignored" in str(w[0].message).lower()
475+
# The explicit px_kwargs should take precedence
476+
assert fig.data[0].line.color == "orange"
477+
478+
def test_colors_none_uses_defaults(self) -> None:
479+
"""Test that colors=None uses Plotly defaults."""
480+
fig1 = self.da.plotly.line(colors=None)
481+
fig2 = self.da.plotly.line()
482+
# Both should produce the same result
483+
assert fig1.data[0].line.color == fig2.data[0].line.color
484+
485+
def test_colors_works_with_bar(self) -> None:
486+
"""Test colors parameter with bar chart."""
487+
fig = self.da.plotly.bar(colors=["#e41a1c", "#377eb8", "#4daf4a"])
488+
assert fig.data[0].marker.color == "#e41a1c"
489+
490+
def test_colors_works_with_area(self) -> None:
491+
"""Test colors parameter with area chart."""
492+
fig = self.da.plotly.area(colors=["red", "green", "blue"])
493+
assert len(fig.data) == 3
494+
495+
def test_colors_works_with_scatter(self) -> None:
496+
"""Test colors parameter with scatter plot."""
497+
fig = self.da.plotly.scatter(colors=["red", "green", "blue"])
498+
assert len(fig.data) == 3
499+
500+
def test_colors_works_with_imshow(self) -> None:
501+
"""Test colors parameter with imshow (continuous scale)."""
502+
da = xr.DataArray(np.random.rand(10, 10), dims=["y", "x"])
503+
fig = da.plotly.imshow(colors="RdBu")
504+
# Plotly Express uses coloraxis in the layout for continuous scales
505+
assert fig.layout.coloraxis.colorscale is not None
506+
colorscale = fig.layout.coloraxis.colorscale
507+
# RdBu should be in the colorscale definition
508+
assert any("rdbu" in str(c).lower() for c in colorscale) or len(colorscale) > 0
509+
510+
def test_colors_works_with_pie(self) -> None:
511+
"""Test colors parameter with pie chart."""
512+
da = xr.DataArray([30, 40, 30], dims=["category"], coords={"category": ["A", "B", "C"]})
513+
fig = da.plotly.pie(colors={"A": "red", "B": "blue", "C": "green"})
514+
assert isinstance(fig, go.Figure)
515+
516+
def test_colors_works_with_dataset(self) -> None:
517+
"""Test colors parameter works with Dataset accessor."""
518+
ds = xr.Dataset(
519+
{
520+
"temp": (["time"], np.random.rand(10)),
521+
"precip": (["time"], np.random.rand(10)),
522+
}
523+
)
524+
fig = ds.plotly.line(colors=["red", "blue"])
525+
assert len(fig.data) == 2
526+
assert fig.data[0].line.color == "red"
527+
assert fig.data[1].line.color == "blue"

0 commit comments

Comments
 (0)