diff --git a/src/cmap/_colormap.py b/src/cmap/_colormap.py index 748742ddc..741c451ef 100644 --- a/src/cmap/_colormap.py +++ b/src/cmap/_colormap.py @@ -15,6 +15,7 @@ from . import _external from ._catalog import Catalog from ._color import Color, ColorLike +from ._parametrized_colormaps import get_parametrized_colormap_function if TYPE_CHECKING: from collections.abc import Iterator @@ -92,6 +93,8 @@ class Colormap: - a `str` containing a recognized string colormap name (e.g. `"viridis"`, `"magma"`), optionally suffixed with `"_r"` to reverse the colormap (e.g. `"viridis"`, `"magma_r"`). + - a `str` containing a registered colormap function name (e.g. `"cubehelix"`), + used with the `cmap_kwargs` parameter to pass function arguments. - An iterable of [ColorLike](../../colors.md#colorlike-objects) values (any object that can be cast to a [`Color`][cmap.Color]), or "color-stop-like" tuples ( `(float, ColorLike)` where the first element is a scalar value @@ -109,6 +112,12 @@ class Colormap: the matplotlib docs for more. - a `Callable` that takes an array of N values in the range [0, 1] and returns an (N, 4) array of RGBA values in the range [0, 1]. + cmap_kwargs : dict[str, Any] | None + Keyword arguments to pass to a colormap function when `value` is a string + naming a registered function (e.g. `"cubehelix"`). For example: + `Colormap("cubehelix", cmap_kwargs={"start": 1.0, "rotation": -1.0})`. + If provided when `value` is not a registered function name, a `ValueError` + will be raised. name : str | None A name for the colormap. If None, will be set to the identifier or the string `"custom colormap"`. @@ -226,10 +235,16 @@ def __init__( under: ColorLike | None = None, over: ColorLike | None = None, bad: ColorLike | None = None, + cmap_kwargs: dict[str, Any] | None = None, ) -> None: self.info: CatalogItem | None = None - if isinstance(value, str): + if isinstance(value, str) and cmap_kwargs is not None: + name = name or value + colormap_func = get_parametrized_colormap_function(name) + colormap_func = partial(colormap_func, **cmap_kwargs) + stops = _parse_colorstops(colormap_func) + elif isinstance(value, str): rev = value.endswith("_r") info = self.catalog()[value[:-2] if rev else value] name = name or f"{info.namespace}:{info.name}" diff --git a/src/cmap/_parametrized_colormaps.py b/src/cmap/_parametrized_colormaps.py new file mode 100644 index 000000000..eee825616 --- /dev/null +++ b/src/cmap/_parametrized_colormaps.py @@ -0,0 +1,42 @@ +"""Registry of parametrized colormap functions.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +from .data.cubehelix import cubehelix + +if TYPE_CHECKING: + from typing import Any + + import numpy as np + + LutCallable = Callable[[np.ndarray], np.ndarray] + +# Registry of colormap functions that accept parameters +FUNCTION_REGISTRY: dict[str, Callable[..., Any]] = { + "cubehelix": cubehelix, +} + + +def get_parametrized_colormap_function(name: str) -> Callable[..., Any]: + """ + Get a parametrized colormap function by name. + + Args: + name: Name of the colormap function + + Returns + ------- + The colormap function + + Raises + ------ + ValueError: If the function name is not in the registry + """ + if name not in FUNCTION_REGISTRY: + available = ", ".join(sorted(FUNCTION_REGISTRY.keys())) + raise ValueError( + f"Unknown colormap function: {name!r}. Available functions: {available}" + ) + return FUNCTION_REGISTRY[name] diff --git a/tests/test_colormap.py b/tests/test_colormap.py index 3a9484d71..bd63c428a 100644 --- a/tests/test_colormap.py +++ b/tests/test_colormap.py @@ -236,3 +236,18 @@ def test_shifted() -> None: assert cm.shifted(0.5).shifted(-0.5) == cm # two shifts of 0.5 should give the original array assert cm.shifted().shifted() == cm + + +def test_function_colormap_with_cmap_kwargs() -> None: + # construct cubehelix with custom parameters + ch = Colormap("cubehelix", cmap_kwargs={"start": 1.0, "rotation": -1.0}) + + # values should be different from default cubehelix + default_ch = Colormap("cubehelix") + assert ch(0.5) != default_ch(0.5) + + +def test_invalid_function_colormap_with_cmap_kwargs() -> None: + # name which doesn't map to a registered function should raise ValueError + with pytest.raises(ValueError, match="Unknown colormap function"): + Colormap("nonexistent_function", cmap_kwargs={"param": 1.0})