Skip to content

Commit 4caf3f9

Browse files
authored
feat: colors parameter for easy color assignment (#24)
* Add new Colors and colors parameter to all methods * Add tests for new colors parameter * Update notebook to show off new colors parameter * Improve test
1 parent 9904fd3 commit 4caf3f9

File tree

6 files changed

+381
-5
lines changed

6 files changed

+381
-5
lines changed

docs/examples/kwargs.ipynb

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,71 @@
159159
"xpx(change).imshow(color_continuous_scale=\"RdBu_r\", color_continuous_midpoint=0)"
160160
]
161161
},
162+
{
163+
"cell_type": "markdown",
164+
"metadata": {},
165+
"source": [
166+
"## colors (unified parameter)\n",
167+
"\n",
168+
"The `colors` parameter provides a simpler way to set colors without remembering the exact Plotly parameter name. It automatically maps to the correct parameter based on the input type:\n",
169+
"\n",
170+
"| Input | Maps To |\n",
171+
"|-------|---------|\n",
172+
"| `\"Viridis\"` (continuous scale name) | `color_continuous_scale` |\n",
173+
"| `\"D3\"` (qualitative palette name) | `color_discrete_sequence` |\n",
174+
"| `[\"red\", \"blue\"]` (list) | `color_discrete_sequence` |\n",
175+
"| `{\"A\": \"red\"}` (dict) | `color_discrete_map` |"
176+
]
177+
},
178+
{
179+
"cell_type": "code",
180+
"execution_count": null,
181+
"metadata": {},
182+
"outputs": [],
183+
"source": [
184+
"# Named qualitative palette\n",
185+
"xpx(stocks).line(colors=\"D3\")"
186+
]
187+
},
188+
{
189+
"cell_type": "code",
190+
"execution_count": null,
191+
"metadata": {},
192+
"outputs": [],
193+
"source": [
194+
"# List of custom colors\n",
195+
"xpx(stocks).line(colors=[\"#E63946\", \"#457B9D\", \"#2A9D8F\", \"#E9C46A\", \"#F4A261\"])"
196+
]
197+
},
198+
{
199+
"cell_type": "code",
200+
"execution_count": null,
201+
"metadata": {},
202+
"outputs": [],
203+
"source": [
204+
"# Dict for explicit mapping\n",
205+
"xpx(stocks).line(\n",
206+
" colors={\n",
207+
" \"GOOG\": \"red\",\n",
208+
" \"AAPL\": \"blue\",\n",
209+
" \"AMZN\": \"green\",\n",
210+
" \"FB\": \"purple\",\n",
211+
" \"NFLX\": \"orange\",\n",
212+
" \"MSFT\": \"brown\",\n",
213+
" }\n",
214+
")"
215+
]
216+
},
217+
{
218+
"cell_type": "code",
219+
"execution_count": null,
220+
"metadata": {},
221+
"outputs": [],
222+
"source": [
223+
"# Continuous scale for heatmaps\n",
224+
"xpx(stocks).imshow(colors=\"Plasma\")"
225+
]
226+
},
162227
{
163228
"cell_type": "markdown",
164229
"metadata": {},

tests/test_accessor.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,3 +401,128 @@ def test_imshow_animation_consistent_bounds(self) -> None:
401401
coloraxis = fig.layout.coloraxis
402402
assert coloraxis.cmin == 0.0
403403
assert coloraxis.cmax == 70.0
404+
405+
406+
class TestColorsParameter:
407+
"""Tests for the unified colors parameter."""
408+
409+
@pytest.fixture(autouse=True)
410+
def setup(self) -> None:
411+
"""Create test DataArrays."""
412+
self.da = xr.DataArray(
413+
np.random.rand(10, 3),
414+
dims=["time", "city"],
415+
coords={"city": ["A", "B", "C"]},
416+
)
417+
418+
def test_colors_list_sets_discrete_sequence(self) -> None:
419+
"""Test that a list of colors sets color_discrete_sequence."""
420+
fig = self.da.plotly.line(colors=["red", "blue", "green"])
421+
# Check that traces have the expected colors
422+
assert len(fig.data) == 3
423+
assert fig.data[0].line.color == "red"
424+
assert fig.data[1].line.color == "blue"
425+
assert fig.data[2].line.color == "green"
426+
427+
def test_colors_dict_sets_discrete_map(self) -> None:
428+
"""Test that a dict sets color_discrete_map."""
429+
fig = self.da.plotly.line(colors={"A": "red", "B": "blue", "C": "green"})
430+
# Traces should be colored according to the mapping
431+
assert len(fig.data) == 3
432+
# Find traces by name and check their color
433+
colors_by_name = {trace.name: trace.line.color for trace in fig.data}
434+
assert colors_by_name["A"] == "red"
435+
assert colors_by_name["B"] == "blue"
436+
assert colors_by_name["C"] == "green"
437+
438+
def test_colors_continuous_scale_string(self) -> None:
439+
"""Test that a continuous scale name sets color_continuous_scale."""
440+
da = xr.DataArray(
441+
np.random.rand(50, 2),
442+
dims=["point", "coord"],
443+
coords={"coord": ["x", "y"]},
444+
)
445+
fig = da.plotly.scatter(y="coord", x="point", color="value", colors="Viridis")
446+
# Plotly Express uses coloraxis in the layout for continuous scales
447+
# Check that the colorscale was applied to the coloraxis
448+
assert fig.layout.coloraxis.colorscale is not None
449+
colorscale = fig.layout.coloraxis.colorscale
450+
# Viridis should be in the colorscale definition
451+
assert any("viridis" in str(c).lower() for c in colorscale) or len(colorscale) > 0
452+
453+
def test_colors_qualitative_palette_string(self) -> None:
454+
"""Test that a qualitative palette name sets color_discrete_sequence."""
455+
import plotly.express as px
456+
457+
fig = self.da.plotly.line(colors="D3")
458+
# D3 palette should be applied - check first trace color is from D3
459+
d3_colors = px.colors.qualitative.D3
460+
assert fig.data[0].line.color in d3_colors
461+
462+
def test_colors_ignored_with_warning_when_px_kwargs_present(self) -> None:
463+
"""Test that colors is ignored with warning when color_* kwargs are present."""
464+
import warnings
465+
466+
with warnings.catch_warnings(record=True) as w:
467+
warnings.simplefilter("always")
468+
fig = self.da.plotly.line(
469+
colors="D3", color_discrete_sequence=["orange", "purple", "cyan"]
470+
)
471+
# Should have raised a warning about colors being ignored
472+
assert any(
473+
"colors" in str(m.message).lower() and "ignored" in str(m.message).lower()
474+
for m in w
475+
), "Expected warning about 'colors' being 'ignored' not found"
476+
# The explicit px_kwargs should take precedence
477+
assert fig.data[0].line.color == "orange"
478+
479+
def test_colors_none_uses_defaults(self) -> None:
480+
"""Test that colors=None uses Plotly defaults."""
481+
fig1 = self.da.plotly.line(colors=None)
482+
fig2 = self.da.plotly.line()
483+
# Both should produce the same result
484+
assert fig1.data[0].line.color == fig2.data[0].line.color
485+
486+
def test_colors_works_with_bar(self) -> None:
487+
"""Test colors parameter with bar chart."""
488+
fig = self.da.plotly.bar(colors=["#e41a1c", "#377eb8", "#4daf4a"])
489+
assert fig.data[0].marker.color == "#e41a1c"
490+
491+
def test_colors_works_with_area(self) -> None:
492+
"""Test colors parameter with area chart."""
493+
fig = self.da.plotly.area(colors=["red", "green", "blue"])
494+
assert len(fig.data) == 3
495+
496+
def test_colors_works_with_scatter(self) -> None:
497+
"""Test colors parameter with scatter plot."""
498+
fig = self.da.plotly.scatter(colors=["red", "green", "blue"])
499+
assert len(fig.data) == 3
500+
501+
def test_colors_works_with_imshow(self) -> None:
502+
"""Test colors parameter with imshow (continuous scale)."""
503+
da = xr.DataArray(np.random.rand(10, 10), dims=["y", "x"])
504+
fig = da.plotly.imshow(colors="RdBu")
505+
# Plotly Express uses coloraxis in the layout for continuous scales
506+
assert fig.layout.coloraxis.colorscale is not None
507+
colorscale = fig.layout.coloraxis.colorscale
508+
# RdBu should be in the colorscale definition
509+
assert any("rdbu" in str(c).lower() for c in colorscale) or len(colorscale) > 0
510+
511+
def test_colors_works_with_pie(self) -> None:
512+
"""Test colors parameter with pie chart."""
513+
da = xr.DataArray([30, 40, 30], dims=["category"], coords={"category": ["A", "B", "C"]})
514+
fig = da.plotly.pie(colors={"A": "red", "B": "blue", "C": "green"})
515+
assert isinstance(fig, go.Figure)
516+
517+
def test_colors_works_with_dataset(self) -> None:
518+
"""Test colors parameter works with Dataset accessor."""
519+
ds = xr.Dataset(
520+
{
521+
"temp": (["time"], np.random.rand(10)),
522+
"precip": (["time"], np.random.rand(10)),
523+
}
524+
)
525+
fig = ds.plotly.line(colors=["red", "blue"])
526+
assert len(fig.data) == 2
527+
assert fig.data[0].line.color == "red"
528+
assert fig.data[1].line.color == "blue"

xarray_plotly/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252

5353
from xarray_plotly import config
5454
from xarray_plotly.accessor import DataArrayPlotlyAccessor, DatasetPlotlyAccessor
55-
from xarray_plotly.common import SLOT_ORDERS, auto
55+
from xarray_plotly.common import SLOT_ORDERS, Colors, auto
5656
from xarray_plotly.figures import (
5757
add_secondary_y,
5858
overlay,
@@ -61,6 +61,7 @@
6161

6262
__all__ = [
6363
"SLOT_ORDERS",
64+
"Colors",
6465
"add_secondary_y",
6566
"auto",
6667
"config",

0 commit comments

Comments
 (0)