Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions cicd_utils/cicd/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from __future__ import annotations

import contextlib
import copy
import pickle
import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar, cast
from unittest.mock import MagicMock, patch

import plotly.io
import pytest_socket
from plotly import graph_objects as go

if TYPE_CHECKING:
from collections.abc import Iterator
from importlib.abc import Loader
Expand All @@ -17,6 +22,24 @@
from plotly.graph_objs import Figure


_PLOTLY_SHOW_DEEPCOPY = copy.deepcopy(plotly.io.show)


def plotly_show_browser(fig: go.Figure, renderer: str = "browser", **kwargs: Any) -> None:
"""Display a Plotly figure in a new browser tab.

This temporarily enables network connections (if disabled by pytest-socket)
and ensures the real (unpatched) `plotly.io.show()` is used. Useful for
debugging test failures by viewing the actual rendered figure in a browser
window.
"""
try:
pytest_socket.enable_socket()
_PLOTLY_SHOW_DEEPCOPY(fig=fig.to_dict(), renderer=renderer, **kwargs)
finally:
pytest_socket.disable_socket()


@contextlib.contextmanager
def patch_plotly_show() -> Iterator[None]:
"""Patch the :func:`plotly.io.show()` function to skip any rendering steps
Expand Down
15 changes: 5 additions & 10 deletions cicd_utils/ridgeplot_examples/_lincoln_weather.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Collection

import plotly.graph_objects as go

from ridgeplot._color.interpolation import SolidColormode
from ridgeplot._types import Color, ColorScale


def main(
colorscale: ColorScale | Collection[Color] | str | None = "Inferno",
colormode: SolidColormode | Literal["fillgradient"] = "fillgradient",
color_discrete_map: dict[str, str] | None = None,
) -> go.Figure:
import numpy as np

Expand All @@ -33,9 +27,10 @@ def main(

fig = ridgeplot(
samples=samples,
labels=[["Min Temperature [F]", "Max Temperature [F]"]] * len(months),
row_labels=months,
colorscale=colorscale,
colormode=colormode,
colorscale="Inferno",
color_discrete_map=color_discrete_map,
bandwidth=4,
kde_points=np.linspace(-40, 110, 400),
spacing=0.3,
Expand Down
6 changes: 4 additions & 2 deletions cicd_utils/ridgeplot_examples/_lincoln_weather_red_blue.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

def main() -> go.Figure:
fig = lincoln_weather(
colorscale=["orangered", "deepskyblue"],
colormode="trace-index-row-wise",
color_discrete_map={
"Min Temperature [F]": "deepskyblue",
"Max Temperature [F]": "orangered",
}
)
return fig

Expand Down
25 changes: 18 additions & 7 deletions docs/getting_started/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ samples = [
# And finish by styling it up to your liking!
fig = ridgeplot(
samples=samples,
labels=[["Min Temperature [F]", "Max Temperature [F]"]] * len(months),
row_labels=months,
colorscale="Inferno",
bandwidth=4,
Expand Down Expand Up @@ -228,6 +229,7 @@ Finally, we can pass the {py:paramref}`~ridgeplot.ridgeplot.samples` list to the
```python
fig = ridgeplot(
samples=samples,
labels=[["Min Temperature [F]", "Max Temperature [F]"]] * len(months),
row_labels=months,
colorscale="Inferno",
bandwidth=4,
Expand All @@ -237,8 +239,8 @@ fig = ridgeplot(

fig.update_layout(
title="Minimum and maximum daily temperatures in Lincoln, NE (2016)",
height=650,
width=950,
height=600,
width=800,
font_size=14,
plot_bgcolor="rgb(245, 245, 245)",
xaxis_gridcolor="white",
Expand All @@ -261,15 +263,24 @@ fig.show()
We are currently investigating the best way to support all color options available in Plotly Express. If you have any suggestions or requests, or just want to track the progress, please check out {gh-issue}`226`.
:::

The {py:func}`~ridgeplot.ridgeplot()` function offers flexible customisation options that help you control the automatic coloring of ridgeline traces. Take a look at {py:paramref}`~ridgeplot.ridgeplot.colorscale`, {py:paramref}`~ridgeplot.ridgeplot.colormode`, and {py:paramref}`~ridgeplot.ridgeplot.opacity` for more information.
The {py:func}`~ridgeplot.ridgeplot()` function offers flexible customisation options that help you control the exact coloring of ridgeline traces. Take a look at {py:paramref}`~ridgeplot.ridgeplot.colorscale`, {py:paramref}`~ridgeplot.ridgeplot.colormode`, {py:paramref}`~ridgeplot.ridgeplot.color_discrete_map`, {py:paramref}`~ridgeplot.ridgeplot.opacity`, and {py:paramref}`~ridgeplot.ridgeplot.line_color` for a detailed description of the available options.

As a simple (but quite common) example, we'll try to adjust the output of the previous example to use different discrete colors for the minimum and maximum temperature traces. Specifically, we'll set all minimum temperature traces to a shade of blue and all maximum temperature traces to a shade of red. To achieve this, we just need to set the {py:paramref}`~ridgeplot.ridgeplot.color_discrete_map` parameter to a dictionary that maps the trace labels to the desired colors.

To demonstrate how these options can be used, we can try to adjust the output from the previous example to use different colors for the minimum and maximum temperature traces. For instance, setting all minimum temperature traces to a shade of blue and all maximum temperature traces to a shade of red. To achieve this, we just need to adjust the {py:paramref}`~ridgeplot.ridgeplot.colorscale` and {py:paramref}`~ridgeplot.ridgeplot.colormode` parameters in the call to the {py:func}`~ridgeplot.ridgeplot()` function. _i.e._,
:::{note}
Because the {py:paramref}`~ridgeplot.ridgeplot.color_discrete_map` parameter takes precedence over the {py:paramref}`~ridgeplot.ridgeplot.colorscale` and {py:paramref}`~ridgeplot.ridgeplot.colormode` parameters, we can keep them as they are in the previous example. However, since their behavior will be overridden by {py:paramref}`~ridgeplot.ridgeplot.color_discrete_map`, it is a good practice to remove them from the function call to avoid any confusion.
:::

```python
fig = ridgeplot(
# Same options as before, with only the following changes:
colorscale=["orangered", "deepskyblue"],
colormode="trace-index-row-wise",
# Same options as before, with the
# addition of `color_discrete_map`
# ...
color_discrete_map={
"Min Temperature [F]": "deepskyblue",
"Max Temperature [F]": "orangered",
}
# ...
)
```

Expand Down
4 changes: 4 additions & 0 deletions docs/reference/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Unreleased changes
- Dropped support for Python 3.9, in accordance with the official Python support policy[^1] ({gh-pr}`345`)
- Bump project classification from Pre-Alpha to Alpha ({gh-pr}`336`)

### Features

- Implement a new `color_discrete_map` parameter to allow users to specify custom colors for each trace ({gh-pr}`348`)

### CI/CD

- Bump actions/github-script from 7 to 8 ({gh-pr}`338`)
Expand Down
47 changes: 30 additions & 17 deletions src/ridgeplot/_figure_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
ShallowTraceTypesArray,
TraceType,
TraceTypesArray,
is_flat_str_collection,
is_shallow_trace_types_array,
is_trace_type,
is_trace_types_array,
Expand All @@ -50,9 +49,9 @@ def normalise_trace_types(
trace_types = cast("TraceTypesArray", [[trace_types] * len(row) for row in densities])
elif is_shallow_trace_types_array(trace_types):
trace_types = nest_shallow_collection(trace_types)
trace_types = normalise_row_attrs(trace_types, l2_target=densities)
trace_types = normalise_row_attrs(attrs=trace_types, l2_target=densities)
elif is_trace_types_array(trace_types):
trace_types = normalise_row_attrs(trace_types, l2_target=densities)
trace_types = normalise_row_attrs(attrs=trace_types, l2_target=densities)
else:
raise TypeError(f"Invalid trace_type: {trace_types}")
return trace_types
Expand All @@ -67,17 +66,15 @@ def normalise_trace_labels(
ids = iter(range(1, n_traces + 1))
trace_labels = [[f"Trace {next(ids)}" for _ in row] for row in densities]
else:
if is_flat_str_collection(trace_labels):
trace_labels = nest_shallow_collection(trace_labels)
trace_labels = normalise_row_attrs(trace_labels, l2_target=densities)
trace_labels = normalise_row_attrs(attrs=trace_labels, l2_target=densities)
return trace_labels


def normalise_row_labels(trace_labels: LabelsArray) -> Collection[str]:
return [",".join(ordered_dedup(row)) for row in trace_labels]


def update_layout(
def _update_layout(
fig: go.Figure,
row_labels: Collection[str] | Literal[False],
tickvals: list[float],
Expand Down Expand Up @@ -123,12 +120,13 @@ def update_layout(

def create_ridgeplot(
densities: Densities,
trace_labels: LabelsArray | ShallowLabelsArray | None,
trace_types: TraceTypesArray | ShallowTraceTypesArray | TraceType,
row_labels: Collection[str] | None | Literal[False],
colorscale: ColorScale | Collection[Color] | str | None,
opacity: float | None,
colormode: Literal["fillgradient"] | SolidColormode,
trace_labels: LabelsArray | ShallowLabelsArray | None,
color_discrete_map: dict[str, str] | None,
opacity: float | None,
line_color: Color | Literal["fill-color"],
line_width: float | None,
spacing: float,
Expand Down Expand Up @@ -159,6 +157,15 @@ def create_ridgeplot(
elif row_labels is not False and len(row_labels) != n_rows:
raise ValueError(f"Expected {n_rows} row_labels, got {len(row_labels)} instead.")

if color_discrete_map:
missing_labels = {
label for row in trace_labels for label in row if label not in color_discrete_map
}
if missing_labels:
raise ValueError(
f"The following labels are missing from 'color_discrete_map': {missing_labels}",
)

# Force cast certain arguments to the expected types
line_width = float(line_width) if line_width is not None else None
spacing = float(spacing)
Expand All @@ -176,12 +183,18 @@ def create_ridgeplot(
x_min=x_min,
x_max=x_max,
)
solid_colors = compute_solid_colors(
colorscale=colorscale,
colormode=colormode if colormode != "fillgradient" else "mean-minmax",
opacity=opacity,
interpolation_ctx=interpolation_ctx,
)
if color_discrete_map:
solid_colors = (
(color_discrete_map[label] for label in row_trace_labels)
for row_trace_labels in trace_labels
)
else:
solid_colors = compute_solid_colors(
colorscale=colorscale,
colormode=colormode if colormode != "fillgradient" else "mean-minmax",
opacity=opacity,
interpolation_ctx=interpolation_ctx,
)

tickvals: list[float] = []
fig = go.Figure()
Expand All @@ -207,14 +220,14 @@ def create_ridgeplot(
fig=fig,
coloring_ctx=ColoringContext(
colorscale=colorscale,
colormode=colormode,
fillgradient=colormode == "fillgradient" and not color_discrete_map,
opacity=opacity,
interpolation_ctx=interpolation_ctx,
),
)
ith_trace += 1

fig = update_layout(
fig = _update_layout(
fig,
row_labels=row_labels,
tickvals=tickvals,
Expand Down
2 changes: 1 addition & 1 deletion src/ridgeplot/_kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def normalize_sample_weights(
return [[sample_weights] * len(row) for row in samples]
if _is_shallow_sample_weights(sample_weights):
sample_weights = nest_shallow_collection(sample_weights)
sample_weights = normalise_row_attrs(sample_weights, l2_target=samples)
sample_weights = normalise_row_attrs(attrs=sample_weights, l2_target=samples)
return sample_weights


Expand Down
2 changes: 1 addition & 1 deletion src/ridgeplot/_obj/traces/area.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class AreaTrace(RidgeplotTrace):
_DEFAULT_LINE_WIDTH: ClassVar[float] = 1.5

def _get_coloring_kwargs(self, ctx: ColoringContext) -> dict[str, Any]:
if ctx.colormode == "fillgradient":
if ctx.fillgradient:
if ctx.opacity is not None:
# HACK: Plotly doesn't yet support setting the fill opacity
# for traces with `fillgradient`. As a workaround, we
Expand Down
2 changes: 1 addition & 1 deletion src/ridgeplot/_obj/traces/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class BarTrace(RidgeplotTrace):
_DEFAULT_LINE_WIDTH: ClassVar[float] = 0.5

def _get_coloring_kwargs(self, ctx: ColoringContext) -> dict[str, Any]:
if ctx.colormode == "fillgradient":
if ctx.fillgradient:
color_kwargs = dict(
marker_line_color=self.line_color,
marker_color=[
Expand Down
4 changes: 2 additions & 2 deletions src/ridgeplot/_obj/traces/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
if TYPE_CHECKING:
from plotly import graph_objects as go

from ridgeplot._color.interpolation import InterpolationContext, SolidColormode
from ridgeplot._color.interpolation import InterpolationContext
from ridgeplot._types import Color, ColorScale, DensityTrace


Expand Down Expand Up @@ -44,7 +44,7 @@
@dataclass
class ColoringContext:
colorscale: ColorScale
colormode: Literal["fillgradient"] | SolidColormode
fillgradient: bool
opacity: float | None
interpolation_ctx: InterpolationContext

Expand Down
22 changes: 21 additions & 1 deletion src/ridgeplot/_ridgeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def ridgeplot(
# Coloring and styling parameters
colorscale: ColorScale | Collection[Color] | str | None = None,
colormode: Literal["fillgradient"] | SolidColormode = "fillgradient",
color_discrete_map: dict[str, str] | None = None,
opacity: float | None = None,
line_color: Color | Literal["fill-color"] = "black",
line_width: float | None = None,
Expand Down Expand Up @@ -337,6 +338,24 @@ def ridgeplot(
The default value changed from ``"mean-minmax"`` to
``"fillgradient"``.

color_discrete_map: dict or None
A mapping from trace labels to specific colors.

This parameter is useful when you want to have full manual control over
the colors assigned to each trace. If specified, the assigned colors
are determined by looking up the trace's label as a key in this
dictionary. All labels must be present as keys in the dictionary.

Note that this parameter overrides any value specified for
:paramref:`.colorscale` and :paramref:`.colormode`. In this case, the
color assigned to each trace will be a solid color, as specified in
this dictionary.

If not specified (default), the colors will be determined using the
:paramref:`.colorscale` and :paramref:`.colormode` parameters.

.. versionadded:: 0.5.0

opacity : float or None
If None (default), this parameter will be ignored and the transparency
values of the specified color-scale will remain untouched. Otherwise,
Expand Down Expand Up @@ -485,8 +504,9 @@ def ridgeplot(
trace_types=trace_type,
row_labels=row_labels,
colorscale=colorscale,
opacity=opacity,
colormode=colormode,
color_discrete_map=color_discrete_map,
opacity=opacity,
line_color=line_color,
line_width=line_width,
spacing=spacing,
Expand Down
Loading