diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index e723e92f..d9c30e19 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -37,9 +37,9 @@ jobs:
# env:
# SKIP: "no-commit-to-branch"
# run: tox -e pre-commit-all
- - name: Run type checks with mypy
+ - name: Run type checks
shell: bash
- run: tox -e mypy-safe
+ run: tox -e typing
# Run test suits for all supported platforms and Python versions
software-tests:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9f093834..d5e40d79 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -100,19 +100,9 @@ repos:
args: [ --no-build-isolation ]
additional_dependencies: [setuptools-scm]
- - repo: https://github.com/adamchainz/blacken-docs
- rev: 1.19.1
- hooks:
- - id: blacken-docs
-
- - repo: https://github.com/psf/black-pre-commit-mirror
- rev: 24.10.0
- hooks:
- - id: black-jupyter
-
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.1
hooks:
+ - id: ruff-format
- id: ruff
args: [ --show-fixes, --exit-non-zero-on-fix ]
- types_or: [ python, pyi, jupyter ]
diff --git a/MANIFEST.in b/MANIFEST.in
index ca38a4f4..b115ba8b 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -8,6 +8,7 @@ include *.md
include *.toml
include *.yml
include *.yaml
+include *.json
# Stubs
recursive-include src py.typed *.pyi
diff --git a/cicd_utils/README.md b/cicd_utils/README.md
index 7f82cb51..2c2b5481 100644
--- a/cicd_utils/README.md
+++ b/cicd_utils/README.md
@@ -16,6 +16,7 @@ For this reason, the `cicd_utils` directory needs to be made explicitly discover
Static analysis tools will also we need to be made aware of this package:
- For mypy, we can add it to the `files` option in `mypy.ini` to help with import discovery.
+- For pyright, we can add it to the `extraPaths` option in `pyrightconfig.json`.
- For ruff's _isort_-implementation, we also added it to the `known-first-party` list (see `ruff.toml`)
### The `cicd/scripts` directory
diff --git a/cicd_utils/cicd/compile_plotly_charts.py b/cicd_utils/cicd/compile_plotly_charts.py
index 5a95d493..f74910b5 100755
--- a/cicd_utils/cicd/compile_plotly_charts.py
+++ b/cicd_utils/cicd/compile_plotly_charts.py
@@ -5,6 +5,7 @@
used in the docs. It saves the HTML and WebP artefacts to the
`docs/_static/charts` directory.
"""
+
from __future__ import annotations
from copy import deepcopy
@@ -116,9 +117,9 @@ def _write_plotlyjs_bundle() -> None:
bundle_path.write_text(plotlyjs, encoding="utf-8")
-def compile_plotly_charts() -> None:
- # Setup logic ---
- # _write_plotlyjs_bundle()
+def compile_plotly_charts(update_plotlyjs_bundle: bool = False) -> None:
+ if update_plotlyjs_bundle:
+ _write_plotlyjs_bundle()
# Compile all charts ---
if not PATH_STATIC_CHARTS.exists():
diff --git a/cicd_utils/cicd/scripts/extract_latest_release_notes.py b/cicd_utils/cicd/scripts/extract_latest_release_notes.py
index 760e0792..e024e0e7 100755
--- a/cicd_utils/cicd/scripts/extract_latest_release_notes.py
+++ b/cicd_utils/cicd/scripts/extract_latest_release_notes.py
@@ -6,6 +6,7 @@
- The output is written to the `LATEST_RELEASE_NOTES.md` file.
- The body of this file is then used as the body of the GitHub release.
"""
+
from __future__ import annotations
from pathlib import Path
diff --git a/cicd_utils/ridgeplot_examples/__init__.py b/cicd_utils/ridgeplot_examples/__init__.py
index b46bb7ba..b0faf677 100644
--- a/cicd_utils/ridgeplot_examples/__init__.py
+++ b/cicd_utils/ridgeplot_examples/__init__.py
@@ -24,6 +24,12 @@ def load_basic() -> go.Figure:
return main()
+def load_basic_hist() -> go.Figure:
+ from ._basic_hist import main
+
+ return main()
+
+
def load_lincoln_weather() -> go.Figure:
from ._lincoln_weather import main
@@ -44,6 +50,7 @@ def load_probly() -> go.Figure:
ALL_EXAMPLES: list[tuple[str, Callable[[], go.Figure]]] = [
("basic", load_basic),
+ ("basic_hist", load_basic_hist),
("lincoln_weather", load_lincoln_weather),
("lincoln_weather_red_blue", load_lincoln_weather_red_blue),
("probly", load_probly),
diff --git a/cicd_utils/ridgeplot_examples/_basic.py b/cicd_utils/ridgeplot_examples/_basic.py
index 885ad5a5..8192f1f6 100644
--- a/cicd_utils/ridgeplot_examples/_basic.py
+++ b/cicd_utils/ridgeplot_examples/_basic.py
@@ -12,9 +12,9 @@ def main() -> go.Figure:
from ridgeplot import ridgeplot
rng = np.random.default_rng(42)
- my_samples = [rng.normal(n / 1.2, size=600) for n in range(7, 0, -1)]
+ my_samples = [rng.normal(n / 1.2, size=600) for n in range(6, 0, -1)]
fig = ridgeplot(samples=my_samples)
- fig.update_layout(height=400, width=800)
+ fig.update_layout(height=350, width=800)
return fig
diff --git a/cicd_utils/ridgeplot_examples/_basic_hist.py b/cicd_utils/ridgeplot_examples/_basic_hist.py
new file mode 100644
index 00000000..ee2ae666
--- /dev/null
+++ b/cicd_utils/ridgeplot_examples/_basic_hist.py
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ import plotly.graph_objects as go
+
+
+def main() -> go.Figure:
+ import numpy as np
+
+ from ridgeplot import ridgeplot
+
+ rng = np.random.default_rng(42)
+ my_samples = [rng.normal(n / 1.2, size=600) for n in range(7, 0, -1)]
+ fig = ridgeplot(samples=my_samples, nbins=20)
+ fig.update_layout(height=350, width=800)
+
+ return fig
+
+
+if __name__ == "__main__":
+ fig = main()
+ fig.show()
diff --git a/cicd_utils/ridgeplot_examples/_lincoln_weather.py b/cicd_utils/ridgeplot_examples/_lincoln_weather.py
index 970b0572..1c7c1456 100644
--- a/cicd_utils/ridgeplot_examples/_lincoln_weather.py
+++ b/cicd_utils/ridgeplot_examples/_lincoln_weather.py
@@ -22,11 +22,11 @@ def main(
df = load_lincoln_weather()
- months = df.index.month_name().unique() # type: ignore[attr-defined]
+ months = df.index.month_name().unique() # pyright: ignore[reportAttributeAccessIssue]
samples = [
[
- df[df.index.month_name() == month]["Min Temperature [F]"], # type: ignore[attr-defined]
- df[df.index.month_name() == month]["Max Temperature [F]"], # type: ignore[attr-defined]
+ df[df.index.month_name() == month]["Min Temperature [F]"], # pyright: ignore[reportAttributeAccessIssue]
+ df[df.index.month_name() == month]["Max Temperature [F]"], # pyright: ignore[reportAttributeAccessIssue]
]
for month in months
]
diff --git a/docs/_static/charts/basic.html b/docs/_static/charts/basic.html
index 8e43d4db..a3a124c7 100644
--- a/docs/_static/charts/basic.html
+++ b/docs/_static/charts/basic.html
@@ -1 +1 @@
-
+
diff --git a/docs/_static/charts/basic.jpeg b/docs/_static/charts/basic.jpeg
index 29ffc761..42b90ebf 100644
Binary files a/docs/_static/charts/basic.jpeg and b/docs/_static/charts/basic.jpeg differ
diff --git a/docs/_static/charts/basic.webp b/docs/_static/charts/basic.webp
index 7401abf1..9581732e 100644
Binary files a/docs/_static/charts/basic.webp and b/docs/_static/charts/basic.webp differ
diff --git a/docs/_static/charts/basic_hist.html b/docs/_static/charts/basic_hist.html
new file mode 100644
index 00000000..1eab9140
--- /dev/null
+++ b/docs/_static/charts/basic_hist.html
@@ -0,0 +1 @@
+
diff --git a/docs/_static/charts/basic_hist.jpeg b/docs/_static/charts/basic_hist.jpeg
new file mode 100644
index 00000000..98230427
Binary files /dev/null and b/docs/_static/charts/basic_hist.jpeg differ
diff --git a/docs/_static/charts/basic_hist.webp b/docs/_static/charts/basic_hist.webp
new file mode 100644
index 00000000..c90de7c8
Binary files /dev/null and b/docs/_static/charts/basic_hist.webp differ
diff --git a/docs/_static/charts/lincoln_weather.html b/docs/_static/charts/lincoln_weather.html
index ab02afca..118c9328 100644
--- a/docs/_static/charts/lincoln_weather.html
+++ b/docs/_static/charts/lincoln_weather.html
@@ -1 +1 @@
-
+
diff --git a/docs/_static/charts/lincoln_weather_red_blue.html b/docs/_static/charts/lincoln_weather_red_blue.html
index 8949d609..1db6aee5 100644
--- a/docs/_static/charts/lincoln_weather_red_blue.html
+++ b/docs/_static/charts/lincoln_weather_red_blue.html
@@ -1 +1 @@
-
+
diff --git a/docs/_static/charts/probly.html b/docs/_static/charts/probly.html
index 0381eed4..5aa18a96 100644
--- a/docs/_static/charts/probly.html
+++ b/docs/_static/charts/probly.html
@@ -1 +1 @@
-
+
diff --git a/docs/api/index.rst b/docs/api/index.rst
index 1708e41f..0421db38 100644
--- a/docs/api/index.rst
+++ b/docs/api/index.rst
@@ -9,16 +9,6 @@ API Reference
ridgeplot.ridgeplot
-Color utilities
-===============
-
-.. autosummary::
- :toctree: public/
- :nosignatures:
-
- ridgeplot.list_all_colorscale_names
-
-
Data loading utilities
======================
diff --git a/docs/api/internal/_obj/traces.rst b/docs/api/internal/_obj/traces.rst
new file mode 100644
index 00000000..04959322
--- /dev/null
+++ b/docs/api/internal/_obj/traces.rst
@@ -0,0 +1,10 @@
+ridgeplot._obj.traces
+================
+
+Object-oriented trace interfaces.
+
+.. toctree::
+ :maxdepth: 1
+ :glob:
+
+ traces/*
diff --git a/docs/api/internal/_obj/traces/area.rst b/docs/api/internal/_obj/traces/area.rst
new file mode 100644
index 00000000..83c6e3f3
--- /dev/null
+++ b/docs/api/internal/_obj/traces/area.rst
@@ -0,0 +1,7 @@
+ridgeplot._obj.traces.area
+===========================
+
+Area trace object.
+
+.. automodule:: ridgeplot._obj.traces.area
+ :private-members:
diff --git a/docs/api/internal/_obj/traces/bar.rst b/docs/api/internal/_obj/traces/bar.rst
new file mode 100644
index 00000000..2f7e7075
--- /dev/null
+++ b/docs/api/internal/_obj/traces/bar.rst
@@ -0,0 +1,7 @@
+ridgeplot._obj.traces.bar
+===========================
+
+Bar trace object.
+
+.. automodule:: ridgeplot._obj.traces.bar
+ :private-members:
diff --git a/docs/api/internal/_obj/traces/base.rst b/docs/api/internal/_obj/traces/base.rst
new file mode 100644
index 00000000..0f5b55c3
--- /dev/null
+++ b/docs/api/internal/_obj/traces/base.rst
@@ -0,0 +1,7 @@
+ridgeplot._obj.traces.base
+===========================
+
+Base trace object and utilities.
+
+.. automodule:: ridgeplot._obj.traces.base
+ :private-members:
diff --git a/docs/api/internal/hist.rst b/docs/api/internal/hist.rst
new file mode 100644
index 00000000..9b8bf449
--- /dev/null
+++ b/docs/api/internal/hist.rst
@@ -0,0 +1,7 @@
+ridgeplot._hist
+==============
+
+Histogram utilities.
+
+.. automodule:: ridgeplot._hist
+ :private-members:
diff --git a/docs/api/internal/obj.rst b/docs/api/internal/obj.rst
new file mode 100644
index 00000000..1b33486f
--- /dev/null
+++ b/docs/api/internal/obj.rst
@@ -0,0 +1,10 @@
+ridgeplot._obj
+================
+
+Object-oriented interfaces.
+
+.. toctree::
+ :maxdepth: 1
+ :glob:
+
+ _obj/*
diff --git a/docs/conf.py b/docs/conf.py
index 88abd329..d6b8e880 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -3,13 +3,14 @@
import sys
from contextlib import contextmanager
from datetime import datetime
+from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING
try:
import importlib.metadata as importlib_metadata
except ImportError:
- import importlib_metadata # type: ignore[no-redef]
+ import importlib_metadata # pyright: ignore[no-redef]
try:
from cicd.compile_plotly_charts import compile_plotly_charts
@@ -204,9 +205,9 @@
intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"packaging": ("https://packaging.pypa.io/en/latest", None),
- "numpy": ("https://docs.scipy.org/doc/numpy/", None),
+ "numpy": ("https://numpy.org/doc/stable/", None),
"pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
- "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
+ "scipy": ("https://docs.scipy.org/doc/scipy/", None),
"statsmodels": ("https://www.statsmodels.org/stable/", None),
"plotly": ("https://plotly.com/python-api-reference/", None),
}
@@ -263,18 +264,9 @@
# ------- ._color.interpolation ----------------
"ridgeplot._color.interpolation.ColorscaleInterpolants",
"ridgeplot._color.interpolation.SolidColormode",
- # ------- ._figure_factory ---------------------
- "ridgeplot._figure_factory.TraceType",
- "ridgeplot._figure_factory.TraceTypesArray",
- "ridgeplot._figure_factory.ShallowTraceTypesArray",
- "ridgeplot._figure_factory.LabelsArray",
- "ridgeplot._figure_factory.ShallowLabelsArray",
# ------- ._kde --------------------------------
"ridgeplot._kde.KDEPoints",
"ridgeplot._kde.KDEBandwidth",
- "ridgeplot._kde.SampleWeights",
- "ridgeplot._kde.SampleWeightsArray",
- "ridgeplot._kde.ShallowSampleWeightsArray",
# ------- ._missing ----------------------------
"ridgeplot._missing.MISSING",
"ridgeplot._missing.MissingType",
@@ -298,13 +290,32 @@
"ridgeplot._types.SamplesRow",
"ridgeplot._types.Samples",
"ridgeplot._types.ShallowSamples",
+ "ridgeplot._types.TraceType",
+ "ridgeplot._types.TraceTypesArray",
+ "ridgeplot._types.ShallowTraceTypesArray",
+ "ridgeplot._types.LabelsArray",
+ "ridgeplot._types.ShallowLabelsArray",
+ "ridgeplot._types.SampleWeights",
+ "ridgeplot._types.SampleWeightsArray",
+ "ridgeplot._types.ShallowSampleWeightsArray",
}
+for fq in _TYPE_ALIASES_FULLY_QUALIFIED:
+ module_name, _, type_name = fq.rpartition(".")
+ try:
+ import_module(module_name)
+ except ImportError as e:
+ raise AssertionError(f"Type alias {fq!r} is not importable: {e}") from e
+
_TYPE_ALIASES = {fq.split(".")[-1]: fq for fq in _TYPE_ALIASES_FULLY_QUALIFIED}
autodoc_type_aliases = {
**{a: a for a in _TYPE_ALIASES.values()},
**{fq: fq for fq in _TYPE_ALIASES.values()},
}
napoleon_type_aliases = {a: f":data:`~{fq}`" for a, fq in _TYPE_ALIASES.items()}
+EXTRA_NAPOLEON_ALIASES = {
+ "Collection[Color]": r":data:`~collections.abc.Collection`\[:data:`~ridgeplot._types.Color`\]",
+}
+napoleon_type_aliases.update(EXTRA_NAPOLEON_ALIASES)
# -- sphinx_remove_toctrees ------------------------------------------------------------------------
@@ -380,5 +391,5 @@ def setup(app: Sphinx) -> None:
compile_plotly_charts()
# app.connect("html-page-context", register_jinja_functions)
- app.connect("build-finished", lambda *_: _fix_generated_public_api_rst())
- app.connect("build-finished", lambda *_: _fix_html_charts())
+ app.connect("build-finished", lambda *_: _fix_generated_public_api_rst()) # pyright: ignore[reportUnknownLambdaType]
+ app.connect("build-finished", lambda *_: _fix_html_charts()) # pyright: ignore[reportUnknownLambdaType]
diff --git a/docs/development/contributing.md b/docs/development/contributing.md
index b9c0e76f..d0bf31aa 100644
--- a/docs/development/contributing.md
+++ b/docs/development/contributing.md
@@ -158,13 +158,13 @@ pre-commit run --all-files
For more information on all the checks being run here, take a look inside the {repo-file}`.pre-commit-config.yaml` configuration file.
-The only static check that is not run by pre-commit is [mypy](https://github.com/python/mypy), which is too expensive to run on every commit. To run mypy against all files, run:
+The only static check that is not run by pre-commit is [pyright](https://github.com/microsoft/pyright), which is too expensive to run on every commit. To run pyright against all files, run:
```shell
-tox -e mypy-incremental
+tox -e typing
```
-Just like with pytest, you can also pass extra positional arguments to mypy by running `tox -e mypy-incremental -- `.
+Just like with pytest, you can also pass extra positional arguments to pyright by running `tox -e typing -- `.
To trigger all static checks, run:
@@ -184,23 +184,22 @@ Finally, we have a small workflow (see {repo-file}`.github/workflows/check-relea
Here is a quick overview of ~~all~~ most of the CI tools and software used in this project, along with their respective configuration files. If you have any questions or need help with any of these tools, feel free to ask for help from the community by commenting on your issue or pull request.
-| Tool | Category | config files | Details |
-|---------------------------------------------------------|------------------|-------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| [Tox](https://github.com/tox-dev/tox) | 🔧 Orchestration | {repo-file}`tox.ini` | We use Tox to reliably run all integration approval steps in reproducible isolated virtual environments. |
-| [GitHub Actions](https://github.com/features/actions) | 🔧 Orchestration | {repo-file}`.github/workflows/ci.yml` | Workflow automation for GitHub. We use it to automatically run all integration approval steps on every push or pull request event. |
-| [Make](https://www.gnu.org/software/make/) | 🔧 Orchestration | {repo-file}`Makefile` | A build automation tool that we (mis)use to abstract away some bootstrapping and development environment setup steps. |
-| [git](https://git-scm.com/) | 🕰 VCS | {repo-file}`.gitignore` | The project's version control system. |
-| [pytest](https://github.com/pytest-dev/pytest) | 🧪 Testing | {repo-file}`pytest.ini` | Testing framework for python code. |
-| [Coverage.py](https://github.com/nedbat/coveragepy) | 📊 Coverage | {repo-file}`.coveragerc` | The code coverage tool for Python |
-| [Codecov](https://about.codecov.io/) | 📊 Coverage | {repo-file}`.github/workflows/ci.yml` | An external services for tracking, monitoring, and alerting on code coverage metrics. |
-| [pre-commit](https://pre-commit.com/) | 💅 Linting | {repo-file}`.pre-commit-config.yaml` | Used to to automatically check and fix any formatting rules on every commit. |
-| [mypy](https://github.com/python/mypy) | 💅 Linting | {repo-file}`mypy.ini` | A static type checker for Python. We use quite a strict configuration here, which can be tricky at times. Feel free to ask for help from the community by commenting on your issue or pull request. |
-| [black](https://github.com/psf/black) | 💅 Linting | {repo-file}`pyproject.toml` | "The uncompromising Python code formatter". We use `black` to automatically format Python code in a deterministic manner. Maybe we'll replace this with `ruff` in the future. |
-| [ruff](https://github.com/astral-sh/ruff) | 💅 Linting | {repo-file}`ruff.toml` | "An extremely fast Python linter and code formatter, written in Rust." For this project, ruff replaced Flake8 (+plugins), isort, pydocstyle, pyupgrade, and autoflake with a single (and faster) tool. |
-| [EditorConfig](https://editorconfig.org/) | 💅 Linting | {repo-file}`.editorconfig` | This repository uses the `.editorconfig` standard configuration file, which aims to ensure consistent style across multiple programming environments. |
-| [bumpversion](https://github.com/c4urself/bump2version) | 📦 Packaging | {repo-file}`.bumpversion.cfg` | A small command line tool to simplify releasing software by updating all version strings in your source code by the correct increment. |
-| [setuptools](https://setuptools.pypa.io/en/latest/) | 📦 Packaging | {repo-file}`pyproject.toml` and {repo-file}`MANIFEST.in` | `MANIFEST.in` tells `setuptools` which files to include in the distribution. `pyproject.toml` is the new standard for defining static package metadata. |
-| [readthedocs](https://readthedocs.org/) | 📚 Documentation | {repo-file}`.readthedocs.yaml` and {repo-file}`.github/workflows/readthedocs-preview.yml` | An open-source documentation hosting platform. We use it to automatically build and deploy the documentation for this project. |
+| Tool | Category | config files | Details |
+|---------------------------------------------------------|------------------|-------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [Tox](https://github.com/tox-dev/tox) | 🔧 Orchestration | {repo-file}`tox.ini` | We use Tox to reliably run all integration approval steps in reproducible isolated virtual environments. |
+| [GitHub Actions](https://github.com/features/actions) | 🔧 Orchestration | {repo-file}`.github/workflows/ci.yml` | Workflow automation for GitHub. We use it to automatically run all integration approval steps on every push or pull request event. |
+| [Make](https://www.gnu.org/software/make/) | 🔧 Orchestration | {repo-file}`Makefile` | A build automation tool that we (mis)use to abstract away some bootstrapping and development environment setup steps. |
+| [git](https://git-scm.com/) | 🕰 VCS | {repo-file}`.gitignore` | The project's version control system. |
+| [pytest](https://github.com/pytest-dev/pytest) | 🧪 Testing | {repo-file}`pytest.ini` | Testing framework for python code. |
+| [Coverage.py](https://github.com/nedbat/coveragepy) | 📊 Coverage | {repo-file}`.coveragerc` | The code coverage tool for Python |
+| [Codecov](https://about.codecov.io/) | 📊 Coverage | {repo-file}`.github/workflows/ci.yml` | An external services for tracking, monitoring, and alerting on code coverage metrics. |
+| [pre-commit](https://pre-commit.com/) | 💅 Linting | {repo-file}`.pre-commit-config.yaml` | Used to to automatically check and fix any formatting rules on every commit. |
+| [pyright](https://github.com/microsoft/pyright) | 💅 Linting | {repo-file}`pyrightconfig.json` | A static type checker for Python. We use quite a strict configuration here, which can be tricky at times. Feel free to ask for help from the community by commenting on your issue or pull request. |
+| [ruff](https://github.com/astral-sh/ruff) | 💅 Linting | {repo-file}`ruff.toml` | "An extremely fast Python linter and code formatter, written in Rust." For this project, ruff replaced black, Flake8 (+plugins), isort, pydocstyle, pyupgrade, and autoflake with a single (and faster) tool. |
+| [EditorConfig](https://editorconfig.org/) | 💅 Linting | {repo-file}`.editorconfig` | This repository uses the `.editorconfig` standard configuration file, which aims to ensure consistent style across multiple programming environments. |
+| [bumpversion](https://github.com/c4urself/bump2version) | 📦 Packaging | {repo-file}`.bumpversion.cfg` | A small command line tool to simplify releasing software by updating all version strings in your source code by the correct increment. |
+| [setuptools](https://setuptools.pypa.io/en/latest/) | 📦 Packaging | {repo-file}`pyproject.toml` and {repo-file}`MANIFEST.in` | `MANIFEST.in` tells `setuptools` which files to include in the distribution. `pyproject.toml` is the new standard for defining static package metadata. |
+| [readthedocs](https://readthedocs.org/) | 📚 Documentation | {repo-file}`.readthedocs.yaml` and {repo-file}`.github/workflows/readthedocs-preview.yml` | An open-source documentation hosting platform. We use it to automatically build and deploy the documentation for this project. |
## Code of Conduct
diff --git a/docs/getting_started/getting_started.md b/docs/getting_started/getting_started.md
index 899b9479..a1e6f999 100644
--- a/docs/getting_started/getting_started.md
+++ b/docs/getting_started/getting_started.md
@@ -10,7 +10,7 @@ This basic example shows how you can quickly get started with a simple call to t
import numpy as np
from ridgeplot import ridgeplot
-my_samples = [np.random.normal(n / 1.2, size=600) for n in range(7, 0, -1)]
+my_samples = [np.random.normal(n / 1.2, size=600) for n in range(6, 0, -1)]
fig = ridgeplot(samples=my_samples)
fig.show()
```
@@ -19,6 +19,18 @@ fig.show()
:file: ../_static/charts/basic.html
```
+By default, the {py:func}`~ridgeplot.ridgeplot()` function will estimate the samples' probability density functions (PDFs) using kernel density estimation (KDE) and plot them as ridgeline area traces ({py:paramref}`trace_type="area" `). If you want to plot histograms instead, you can set the {py:paramref}`~ridgeplot.ridgeplot.nbins` parameter to an integer, which will automatically switch the trace type to `"bar"`.
+
+```python
+fig = ridgeplot(samples=my_samples, nbins=20)
+fig.show()
+```
+
+```{raw} html
+:file: ../_static/charts/basic_hist.html
+```
+
+
## Flexible configuration
In this example, we will try to replicate the first ridgeline plot in this [_from Data to Viz_ post](https://www.data-to-viz.com/graph/ridgeline.html). The example in the post was created using the _"Perception of Probability Words"_ dataset (see {py:func}`~ridgeplot.datasets.load_probly()`) and the popular [ggridges](https://wilkelab.org/ggridges/) R package. In the end, we will see how the `ridgeplot` Python library can be used to create a (nearly) identical plot, thanks to its extensive configuration options.
@@ -208,10 +220,10 @@ samples = [
```
:::{note}
-For other use cases (like in the two previous examples), you could use a numpy ndarray to represent the samples. However, since different months have different number of days, we need to use a data container that can hold arrays of different lengths along the same dimension. Irregular arrays like this one are called [ragged arrays](https://en.wikipedia.org/wiki/Jagged_array). There are many different ways you can represent irregular arrays in Python. In this specific example, we used a list of lists of pandas Series. However,`ridgeplot` is designed to handle any object that implements the {py:class}`~typing.Collection`\[{py:class}`~typing.Collection`\[{py:class}`~typing.Collection`\[{py:data}`~ridgeplot._types.Numeric`\]]] protocol (_i.e.,_ any numeric 3D ragged array).
+For other use cases (like in the two previous examples), you could use a numpy ndarray to represent the samples. However, since different months have different number of days, we need to use a data container that can hold arrays of different lengths along the same dimension. Irregular arrays like this one are called [ragged arrays](https://en.wikipedia.org/wiki/Jagged_array). There are many different ways you can represent irregular arrays in Python. In this specific example, we used a list of lists of pandas Series. However, {py:func}`~ridgeplot.ridgeplot()` is designed to handle any object that implements the {py:class}`~typing.Collection`\[{py:class}`~typing.Collection`\[{py:class}`~typing.Collection`\[{py:data}`~ridgeplot._types.Numeric`\]]] protocol (_i.e.,_ any numeric 3D ragged array).
:::
-Finally, we can pass the `samples` list to the {py:func}`~ridgeplot.ridgeplot()` function and specify any other arguments we want to customize the plot, like adjusting the KDE's bandwidth, the vertical spacing between rows, etc.
+Finally, we can pass the {py:paramref}`~ridgeplot.ridgeplot.samples` list to the {py:func}`~ridgeplot.ridgeplot()` function and specify any other arguments we want to customize the plot, like adjusting the KDE's bandwidth, the vertical spacing between rows, etc.
```python
fig = ridgeplot(
@@ -251,7 +263,7 @@ We are currently investigating the best way to support all color options availab
The {py:func}`~ridgeplot.ridgeplot()` function offers flexible customisation options that help you control the automatic coloring of ridgeline traces. Take a look at {py:paramref}`~ridgeplot.ridgeplot.colorscale`, {py:paramref}`~ridgeplot.ridgeplot.colormode`, and {py:paramref}`~ridgeplot.ridgeplot.opacity` for more information.
-To demonstrate how these options can be used, we can try to adjust the output from the previous example to use different colors for the minimum and maximum temperature traces. For instance, setting all minimum temperature traces to a shade of blue and all maximum temperature traces to a shade of red. To achieve this, we just need to adjust the `colorscale` and `colormode` parameters in the call to the {py:func}`~ridgeplot.ridgeplot()` function. _i.e._,
+To demonstrate how these options can be used, we can try to adjust the output from the previous example to use different colors for the minimum and maximum temperature traces. For instance, setting all minimum temperature traces to a shade of blue and all maximum temperature traces to a shade of red. To achieve this, we just need to adjust the {py:paramref}`~ridgeplot.ridgeplot.colorscale` and {py:paramref}`~ridgeplot.ridgeplot.colormode` parameters in the call to the {py:func}`~ridgeplot.ridgeplot()` function. _i.e._,
```python
fig = ridgeplot(
diff --git a/docs/index.md b/docs/index.md
index 13857220..6a4bb3ea 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -47,7 +47,7 @@ For those in a hurry, here's a very basic example on how to quickly get started
import numpy as np
from ridgeplot import ridgeplot
-my_samples = [np.random.normal(n / 1.2, size=600) for n in range(7, 0, -1)]
+my_samples = [np.random.normal(n / 1.2, size=600) for n in range(6, 0, -1)]
fig = ridgeplot(samples=my_samples)
fig.show()
```
diff --git a/docs/reference/changelog.md b/docs/reference/changelog.md
index bc9a812b..99380c5f 100644
--- a/docs/reference/changelog.md
+++ b/docs/reference/changelog.md
@@ -5,13 +5,24 @@ This document outlines the list of changes to ridgeplot between each release. Fo
Unreleased changes
------------------
+### Features
+
+- Add support for histogram and bar traces ({gh-pr}`287`)
+
### Documentation
- Small improvements to `ridgeplot()`'s docstring ({gh-pr}`284`)
+- Misc improvements to the API docs and the getting-started and contributing guides ({gh-pr}`287`)
### Internal
- Small improvements to type hints and annotations ({gh-pr}`284`)
+- Introduce an internal `ridgeplot._obj` package to hold object-oriented interfaces ({gh-pr}`287`)
+
+### CI/CD
+
+- Improve type annotations and switch from mypy to pyright with stricter settings ({gh-pr}`287`)
+- Switch from `black` to the new `ruff` formatter ({gh-pr}`287`)
---
diff --git a/pyproject.toml b/pyproject.toml
index 7217266a..c3b81bed 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -67,7 +67,3 @@ namespaces = false
# and we want to push X.devM to TestPyPi
# on every merge to the `main` branch
local_scheme = "no-local-version"
-
-[tool.black]
-line-length = 100
-include = '\.pyi?$'
diff --git a/pyrightconfig.json b/pyrightconfig.json
new file mode 100644
index 00000000..52090272
--- /dev/null
+++ b/pyrightconfig.json
@@ -0,0 +1,20 @@
+{
+ "include": [
+ "src",
+ "tests",
+ "docs",
+ "cicd_utils"
+ ],
+ "exclude": [
+ "docs/build",
+ "**/__pycache__"
+ ],
+ "extraPaths": [
+ "cicd_utils"
+ ],
+ "typeCheckingMode": "strict",
+ "reportMissingTypeStubs": "none",
+ "reportUnknownMemberType": "none",
+ "reportUnknownArgumentType": "none",
+ "reportUnknownVariableType": "none"
+}
diff --git a/requirements/local-dev.txt b/requirements/local-dev.txt
index 5b355dd6..5f7bce80 100644
--- a/requirements/local-dev.txt
+++ b/requirements/local-dev.txt
@@ -23,5 +23,5 @@ ptpython
# And everything else...
-r cicd_utils.txt
-r docs.txt
--r mypy.txt
+-r typing.txt
-r tests.txt
diff --git a/requirements/mypy.txt b/requirements/typing.txt
similarity index 70%
rename from requirements/mypy.txt
rename to requirements/typing.txt
index d3c7eb4e..e07fd046 100644
--- a/requirements/mypy.txt
+++ b/requirements/typing.txt
@@ -1,5 +1,5 @@
-# mypy dependencies
-mypy
+# pyright
+pyright
# Third-party stubs
types-python-dateutil
@@ -9,7 +9,7 @@ types-requests
types-tqdm
pandas-stubs
-# mypy also needs to inherit other environment dependencies in
+# pyright also needs to inherit other environment dependencies in
# order to correctly infer types for code in tests, docs, etc.
-r cicd_utils.txt
-r docs.txt
diff --git a/ruff.toml b/ruff.toml
index 59da2fc0..59cc58b6 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -6,11 +6,11 @@ fix = true
line-length = 100
# ================================================
-# Formatting settings (currently not used)
+# Formatting settings
# ================================================
[format]
line-ending = "lf"
-docstring-code-format = true
+docstring-code-format = false
# ================================================
# Linting settings
@@ -74,8 +74,6 @@ ignore = [
"TD003", # Missing issue link on the line following this TODO
# flake8-annotations (ANN)
- "ANN101", # Missing type annotation for `self` in method
- "ANN102", # Missing type annotation for `cls` in classmethod
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed in {name}
]
diff --git a/src/ridgeplot/_color/colorscale.py b/src/ridgeplot/_color/colorscale.py
index 9d9fa68a..d63b4ce2 100644
--- a/src/ridgeplot/_color/colorscale.py
+++ b/src/ridgeplot/_color/colorscale.py
@@ -13,7 +13,7 @@
from collections.abc import Collection
-class ColorscaleValidator(_ColorscaleValidator): # type: ignore[misc]
+class ColorscaleValidator(_ColorscaleValidator):
def __init__(self) -> None:
super().__init__("colorscale", "ridgeplot")
@@ -29,6 +29,7 @@ def validate_coerce(self, v: Any) -> ColorScale:
coerced = super().validate_coerce(v)
if coerced is None: # pragma: no cover
self.raise_invalid_val(coerced)
+ coerced = cast(ColorScale, coerced)
# This helps us avoid floating point errors when making
# comparisons in our test suite. The user should not
# be able to notice *any* difference in the output
@@ -37,12 +38,12 @@ def validate_coerce(self, v: Any) -> ColorScale:
def infer_default_colorscale() -> ColorScale | Collection[Color] | str:
- return validate_and_coerce_colorscale(
+ return validate_coerce_colorscale(
default_plotly_template().layout.colorscale.sequential or px.colors.sequential.Viridis
)
-def validate_and_coerce_colorscale(
+def validate_coerce_colorscale(
colorscale: ColorScale | Collection[Color] | str | None,
) -> ColorScale:
"""Convert mixed colorscale representations to the canonical
diff --git a/src/ridgeplot/_color/interpolation.py b/src/ridgeplot/_color/interpolation.py
index f3241711..f9376903 100644
--- a/src/ridgeplot/_color/interpolation.py
+++ b/src/ridgeplot/_color/interpolation.py
@@ -1,28 +1,25 @@
from __future__ import annotations
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Literal, Protocol, TypedDict, overload
+from typing import TYPE_CHECKING, Literal, Protocol
-import plotly.graph_objs as go
-
-from ridgeplot._color.colorscale import validate_and_coerce_colorscale
from ridgeplot._color.utils import apply_alpha, round_color, to_rgb, unpack_rgb
-from ridgeplot._types import CollectionL2, Color, ColorScale
+from ridgeplot._types import CollectionL2, ColorScale
from ridgeplot._utils import get_xy_extrema, normalise_min_max
from ridgeplot._vendor.more_itertools import zip_strict
if TYPE_CHECKING:
- from collections.abc import Collection, Generator
+ from collections.abc import Generator
from ridgeplot._types import Densities, Numeric
# ==============================================================
-# --- Common interpolation utilities
+# --- Interpolation utilities
# ==============================================================
-def _interpolate_color(colorscale: ColorScale, p: float) -> str:
+def interpolate_color(colorscale: ColorScale, p: float) -> str:
"""Get a color from a colorscale at a given interpolation point ``p``.
This function always returns a color in the RGB format, even if the input
@@ -59,6 +56,51 @@ def _interpolate_color(colorscale: ColorScale, p: float) -> str:
return round_color(rgb, 5)
+def slice_colorscale(
+ colorscale: ColorScale,
+ p_lower: float,
+ p_upper: float,
+) -> ColorScale:
+ """Slice a continuous colorscale between two intermediate points.
+
+ Parameters
+ ----------
+ colorscale
+ The continuous colorscale to slice.
+ p_lower
+ The lower bound of the slicing interval. Must be >= 0 and < p_upper.
+ p_upper
+ The upper bound of the slicing interval. Must be <= 1 and > p_lower.
+
+ Returns
+ -------
+ ColorScale
+ The sliced colorscale.
+
+ Raises
+ ------
+ ValueError
+ If ``p_lower`` is >= ``p_upper``, or if either ``p_lower`` or ``p_upper``
+ are outside the range [0, 1].
+ """
+ if p_lower >= p_upper:
+ raise ValueError("p_lower should be less than p_upper.")
+ if p_lower < 0 or p_upper > 1:
+ raise ValueError("p_lower should be >= 0 and p_upper should be <= 1.")
+ if p_lower == 0 and p_upper == 1:
+ return colorscale
+
+ return (
+ (0.0, interpolate_color(colorscale, p=p_lower)),
+ *[
+ (normalise_min_max(v, min_=p_lower, max_=p_upper), c)
+ for v, c in colorscale
+ if p_lower < v < p_upper
+ ],
+ (1.0, interpolate_color(colorscale, p=p_upper)),
+ )
+
+
# ==============================================================
# --- Solid color modes
# ==============================================================
@@ -79,6 +121,8 @@ def _interpolate_color(colorscale: ColorScale, p: float) -> str:
@dataclass
class InterpolationContext:
+ """Context information needed by the interpolation functions."""
+
densities: Densities
n_rows: int
n_traces: int
@@ -87,7 +131,7 @@ class InterpolationContext:
@classmethod
def from_densities(cls, densities: Densities) -> InterpolationContext:
- x_min, x_max, _, _ = map(float, get_xy_extrema(densities=densities))
+ x_min, x_max, _, _ = get_xy_extrema(densities=densities)
return cls(
densities=densities,
n_rows=len(densities),
@@ -160,8 +204,8 @@ def _interpolate_mean_means(ctx: InterpolationContext) -> ColorscaleInterpolants
x, y = zip(*trace)
means_row.append(sum(_mul(x, y)) / sum(y))
means.append(means_row)
- min_mean = min([min(row) for row in means])
- max_mean = max([max(row) for row in means])
+ min_mean = min(min(row) for row in means)
+ max_mean = max(max(row) for row in means)
return [
[normalise_min_max(mean, min_=min_mean, max_=max_mean) for mean in row] for row in means
]
@@ -185,14 +229,16 @@ def _interpolate_mean_means(ctx: InterpolationContext) -> ColorscaleInterpolants
}
-def _compute_solid_colors(
+def compute_solid_colors(
colorscale: ColorScale,
colormode: SolidColormode,
opacity: float | None,
interpolation_ctx: InterpolationContext,
) -> Generator[Generator[str]]:
- def _get_fill_color(p: float) -> str:
- fill_color = _interpolate_color(colorscale, p=p)
+ """Compute the solid colors for all traces in the plot."""
+
+ def get_fill_color(p: float) -> str:
+ fill_color = interpolate_color(colorscale, p=p)
if opacity is not None:
# Sometimes the interpolation logic can drop the alpha channel
fill_color = apply_alpha(fill_color, alpha=float(opacity))
@@ -200,195 +246,4 @@ def _get_fill_color(p: float) -> str:
interpolate_func = SOLID_COLORMODE_MAPS[colormode]
interpolants = interpolate_func(ctx=interpolation_ctx)
- return ((_get_fill_color(p) for p in row) for row in interpolants)
-
-
-class SolidColorsDict(TypedDict):
- line_color: Color
- fillcolor: str
-
-
-def _compute_solid_trace_colors(
- colorscale: ColorScale,
- colormode: SolidColormode,
- line_color: Color | Literal["fill-color"],
- opacity: float | None,
- interpolation_ctx: InterpolationContext,
-) -> Generator[Generator[SolidColorsDict]]:
- return (
- (
- dict(
- line_color=fill_color if line_color == "fill-color" else line_color,
- fillcolor=fill_color,
- )
- for fill_color in row
- )
- for row in _compute_solid_colors(
- colorscale=colorscale,
- colormode=colormode,
- opacity=opacity,
- interpolation_ctx=interpolation_ctx,
- )
- )
-
-
-# ==============================================================
-# --- `fillgradient` color mode
-# ==============================================================
-
-
-def _slice_colorscale(
- colorscale: ColorScale,
- p_lower: float,
- p_upper: float,
-) -> ColorScale:
- """Slice a continuous colorscale between two intermediate points.
-
- Parameters
- ----------
- colorscale
- The continuous colorscale to slice.
- p_lower
- The lower bound of the slicing interval. Must be >= 0 and < p_upper.
- p_upper
- The upper bound of the slicing interval. Must be <= 1 and > p_lower.
-
- Returns
- -------
- ColorScale
- The sliced colorscale.
-
- Raises
- ------
- ValueError
- If ``p_lower`` is >= ``p_upper``, or if either ``p_lower`` or ``p_upper``
- are outside the range [0, 1].
- """
- if p_lower >= p_upper:
- raise ValueError("p_lower should be less than p_upper.")
- if p_lower < 0 or p_upper > 1:
- raise ValueError("p_lower should be >= 0 and p_upper should be <= 1.")
- if p_lower == 0 and p_upper == 1:
- return colorscale
-
- return (
- (0.0, _interpolate_color(colorscale, p=p_lower)),
- *[
- (normalise_min_max(v, min_=p_lower, max_=p_upper), c)
- for v, c in colorscale
- if p_lower < v < p_upper
- ],
- (1.0, _interpolate_color(colorscale, p=p_upper)),
- )
-
-
-class FillgradientColorsDict(TypedDict):
- line_color: str
- fillgradient: go.scatter.Fillgradient
-
-
-def _compute_fillgradient_trace_colors(
- colorscale: ColorScale,
- line_color: Color | Literal["fill-color"],
- opacity: float | None,
- interpolation_ctx: InterpolationContext,
-) -> Generator[Generator[FillgradientColorsDict]]:
- solid_line_colors: Generator[Generator[Color]]
- if line_color == "fill-color":
- solid_line_colors = _compute_solid_colors(
- colorscale=colorscale,
- colormode="mean-minmax",
- opacity=opacity,
- interpolation_ctx=interpolation_ctx,
- )
- else:
- solid_line_colors = ((line_color for _ in row) for row in interpolation_ctx.densities)
- if opacity is not None:
- # HACK: Plotly doesn't yet support setting the fill opacity
- # for traces with `fillgradient`. As a workaround, we
- # can override the color-scale's color values and add
- # the corresponding alpha channel to all colors.
- colorscale = [(v, apply_alpha(c, float(opacity))) for v, c in colorscale]
- return (
- (
- dict(
- line_color=line_color,
- fillgradient=go.scatter.Fillgradient(
- colorscale=_slice_colorscale(
- colorscale=colorscale,
- p_lower=normalise_min_max(
- min(next(zip(*trace))),
- min_=interpolation_ctx.x_min,
- max_=interpolation_ctx.x_max,
- ),
- p_upper=normalise_min_max(
- max(next(zip(*trace))),
- min_=interpolation_ctx.x_min,
- max_=interpolation_ctx.x_max,
- ),
- ),
- type="horizontal",
- ),
- )
- for line_color, trace in zip_strict(line_colors_row, densities_row)
- )
- for line_colors_row, densities_row in zip_strict(
- solid_line_colors, interpolation_ctx.densities
- )
- )
-
-
-# ==============================================================
-# --- Main public function
-# ==============================================================
-
-
-@overload
-def compute_trace_colors(
- colorscale: ColorScale | Collection[Color] | str | None,
- colormode: Literal["fillgradient"],
- line_color: Color | Literal["fill-color"],
- opacity: float | None,
- interpolation_ctx: InterpolationContext,
-) -> Generator[Generator[FillgradientColorsDict]]: ...
-
-
-@overload
-def compute_trace_colors(
- colorscale: ColorScale | Collection[Color] | str | None,
- colormode: SolidColormode,
- line_color: Color | Literal["fill-color"],
- opacity: float | None,
- interpolation_ctx: InterpolationContext,
-) -> Generator[Generator[SolidColorsDict]]: ...
-
-
-def compute_trace_colors(
- colorscale: ColorScale | Collection[Color] | str | None,
- colormode: Literal["fillgradient"] | SolidColormode,
- line_color: Color | Literal["fill-color"],
- opacity: float | None,
- interpolation_ctx: InterpolationContext,
-) -> Generator[Generator[FillgradientColorsDict | SolidColorsDict]]:
- colorscale = validate_and_coerce_colorscale(colorscale)
-
- valid_colormodes = ("fillgradient", *SOLID_COLORMODE_MAPS)
- if colormode not in valid_colormodes:
- raise ValueError(
- f"The colormode argument should be one of {valid_colormodes}, got {colormode} instead."
- )
-
- if colormode == "fillgradient":
- return _compute_fillgradient_trace_colors(
- colorscale=colorscale,
- line_color=line_color,
- opacity=opacity,
- interpolation_ctx=interpolation_ctx,
- )
- return _compute_solid_trace_colors(
- colorscale=colorscale,
- colormode=colormode,
- line_color=line_color,
- opacity=opacity,
- interpolation_ctx=interpolation_ctx,
- )
+ return ((get_fill_color(p) for p in row) for row in interpolants)
diff --git a/src/ridgeplot/_color/utils.py b/src/ridgeplot/_color/utils.py
index 5dccb56b..5122ceef 100644
--- a/src/ridgeplot/_color/utils.py
+++ b/src/ridgeplot/_color/utils.py
@@ -1,17 +1,14 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, cast
+from collections.abc import Collection
+from typing import Union, cast
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
-from ridgeplot._color.css_colors import CSS_NAMED_COLORS, CssNamedColor
-
-if TYPE_CHECKING:
- from collections.abc import Collection
-
- from ridgeplot._types import Color
+from ridgeplot._color.css_colors import CSS_NAMED_COLORS
+from ridgeplot._types import Color
def default_plotly_template() -> go.layout.Template:
@@ -21,11 +18,13 @@ def default_plotly_template() -> go.layout.Template:
# TODO: Move this in the future to a separate module
# once we add support for color sequences.
def infer_default_color_sequence() -> Collection[Color]: # pragma: no cover
- return default_plotly_template().layout.colorway or px.colors.qualitative.D3 # type: ignore[no-any-return]
+ return cast(
+ Collection[Color], default_plotly_template().layout.colorway or px.colors.qualitative.D3
+ )
def to_rgb(color: Color) -> str:
- if not isinstance(color, (str, tuple)):
+ if not isinstance(color, (str, tuple)): # type: ignore[reportUnnecessaryIsInstance]
raise TypeError(f"Expected str or tuple for color, got {type(color)} instead.")
if isinstance(color, tuple):
r, g, b = color
@@ -35,7 +34,6 @@ def to_rgb(color: Color) -> str:
elif color.startswith(("rgb(", "rgba(")):
rgb = color
elif color in CSS_NAMED_COLORS:
- color = cast(CssNamedColor, color)
return to_rgb(CSS_NAMED_COLORS[color])
else:
raise ValueError(
@@ -50,7 +48,7 @@ def unpack_rgb(rgb: str) -> tuple[float, float, float, float] | tuple[float, flo
prefix = rgb.split("(")[0] + "("
values_str = map(str.strip, rgb.removeprefix(prefix).removesuffix(")").split(","))
values_num = tuple(int(v) if v.isdecimal() else float(v) for v in values_str)
- return values_num # type: ignore[return-value]
+ return cast(Union[tuple[float, float, float, float], tuple[float, float, float]], values_num)
def apply_alpha(color: Color, alpha: float) -> str:
diff --git a/src/ridgeplot/_figure_factory.py b/src/ridgeplot/_figure_factory.py
index 0793d2a1..dac0822d 100644
--- a/src/ridgeplot/_figure_factory.py
+++ b/src/ridgeplot/_figure_factory.py
@@ -1,22 +1,29 @@
from __future__ import annotations
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Literal, cast
from plotly import graph_objects as go
+from ridgeplot._color.colorscale import validate_coerce_colorscale
from ridgeplot._color.interpolation import (
InterpolationContext,
SolidColormode,
- compute_trace_colors,
+ compute_solid_colors,
)
+from ridgeplot._obj.traces import get_trace_cls
+from ridgeplot._obj.traces.base import ColoringContext
from ridgeplot._types import (
- CollectionL1,
- CollectionL2,
Color,
ColorScale,
- DensityTrace,
+ LabelsArray,
+ ShallowLabelsArray,
+ ShallowTraceTypesArray,
+ TraceType,
+ TraceTypesArray,
is_flat_str_collection,
+ is_shallow_trace_types_array,
+ is_trace_type,
+ is_trace_types_array,
nest_shallow_collection,
)
from ridgeplot._utils import (
@@ -30,48 +37,23 @@
if TYPE_CHECKING:
from collections.abc import Collection
- from ridgeplot._types import Densities, Numeric
+ from ridgeplot._types import Densities
-LabelsArray = CollectionL2[str]
-"""A :data:`LabelsArray` represents the labels of traces in a ridgeplot.
-
-Example
--------
-
->>> labels_array: LabelsArray = [
-... ["trace 1", "trace 2", "trace 3"],
-... ["trace 4", "trace 5"],
-... ]
-"""
-
-ShallowLabelsArray = CollectionL1[str]
-"""Shallow type for :data:`LabelsArray`.
-
-Example
--------
-
->>> labels_array: ShallowLabelsArray = ["trace 1", "trace 2", "trace 3"]
-"""
-
-_D3HF = ".7"
-"""Default (d3-format) format for floats in hover labels.
-
-After trying to read through the plotly.py source code, I couldn't find a
-simple way to replicate the default hover format using the d3-format syntax
-in Plotly's 'hovertemplate' parameter. The closest I got was by using the
-string below, but it's not quite the same... (see '.7~r' as well)
-"""
-
-_DEFAULT_HOVERTEMPLATE = (
- f"(%{{x:{_D3HF}}}, %{{customdata[0]:{_D3HF}}})"
- "
"
- "%{fullData.name}"
-) # fmt: skip
-"""Default ``hovertemplate`` for density traces.
-
-See :func:`draw_density_trace`.
-"""
+def normalise_trace_types(
+ densities: Densities,
+ trace_types: TraceTypesArray | ShallowTraceTypesArray | TraceType,
+) -> TraceTypesArray:
+ if is_trace_type(trace_types):
+ trace_types = cast(TraceTypesArray, [[trace_types] * len(row) for row in densities])
+ elif is_shallow_trace_types_array(trace_types):
+ trace_types = nest_shallow_collection(trace_types)
+ trace_types = normalise_row_attrs(trace_types, l2_target=densities)
+ elif is_trace_types_array(trace_types):
+ trace_types = normalise_row_attrs(trace_types, l2_target=densities)
+ else:
+ raise TypeError(f"Invalid trace_type: {trace_types}")
+ return trace_types
def normalise_trace_labels(
@@ -93,76 +75,6 @@ def normalise_y_labels(trace_labels: LabelsArray) -> LabelsArray:
return [ordered_dedup(row) for row in trace_labels]
-@dataclass
-class RidgeplotTrace:
- trace: DensityTrace
- label: str
- color: dict[str, Any]
-
-
-@dataclass
-class RidgeplotRow:
- traces: list[RidgeplotTrace]
- y_shifted: float
-
-
-def draw_base(
- fig: go.Figure,
- x: Collection[Numeric],
- y_shifted: float,
-) -> go.Figure:
- """Draw the base for a density trace.
-
- Adds an invisible trace at constant y that will serve as the fill-limit
- for the corresponding density trace.
- """
- fig.add_trace(
- go.Scatter(
- x=x,
- y=[y_shifted] * len(x),
- # make trace 'invisible'
- # Note: visible=False does not work with fill="tonexty"
- line=dict(color="rgba(0,0,0,0)", width=0),
- showlegend=False,
- hoverinfo="skip",
- )
- )
- return fig
-
-
-def draw_density_trace(
- fig: go.Figure,
- x: Collection[Numeric],
- y: Collection[Numeric],
- y_shifted: float,
- label: str,
- color: dict[str, Any],
- line_width: float,
-) -> go.Figure:
- """Draw a density trace.
-
- Adds a density 'trace' to the Figure. The ``fill="tonexty"`` option
- fills the trace until the previously drawn trace (see
- :meth:`draw_base`). This is why the base trace must be drawn first.
- """
- fig = draw_base(fig, x=x, y_shifted=y_shifted)
- fig.add_trace(
- go.Scatter(
- x=x,
- y=[y_i + y_shifted for y_i in y],
- **color,
- name=label,
- fill="tonexty",
- mode="lines",
- line=dict(width=line_width),
- # Hover information
- customdata=[[y_i] for y_i in y],
- hovertemplate=_DEFAULT_HOVERTEMPLATE,
- ),
- )
- return fig
-
-
def update_layout(
fig: go.Figure,
y_labels: LabelsArray,
@@ -192,17 +104,27 @@ def update_layout(
showticklabels=True,
**axes_common,
)
+ # Settings for bar/histogram traces:
+ fig.update_layout(
+ # barmode can be either 'stack' or 'relative'
+ barmode="stack",
+ # bargap and bargroupgap should be set
+ # to 0 to avoid gaps between bars
+ bargap=0,
+ bargroupgap=0,
+ )
return fig
def create_ridgeplot(
densities: Densities,
+ trace_types: TraceTypesArray | ShallowTraceTypesArray | TraceType,
colorscale: ColorScale | Collection[Color] | str | None,
opacity: float | None,
colormode: Literal["fillgradient"] | SolidColormode,
trace_labels: LabelsArray | ShallowLabelsArray | None,
line_color: Color | Literal["fill-color"],
- line_width: float,
+ line_width: float | None,
spacing: float,
show_yticklabels: bool,
xpad: float,
@@ -218,6 +140,10 @@ def create_ridgeplot(
n_traces = sum(len(row) for row in densities)
x_min, x_max, _, y_max = map(float, get_xy_extrema(densities=densities))
+ trace_types = normalise_trace_types(
+ densities=densities,
+ trace_types=trace_types,
+ )
trace_labels = normalise_trace_labels(
densities=densities,
trace_labels=trace_labels,
@@ -226,58 +152,65 @@ def create_ridgeplot(
y_labels = normalise_y_labels(trace_labels)
# Force cast certain arguments to the expected types
- line_width = float(line_width)
+ line_width = float(line_width) if line_width is not None else None
spacing = float(spacing)
show_yticklabels = bool(show_yticklabels)
xpad = float(xpad)
+ colorscale = validate_coerce_colorscale(colorscale)
# ==============================================================
# --- Build the figure
# ==============================================================
- colors = compute_trace_colors(
+ interpolation_ctx = InterpolationContext(
+ densities=densities,
+ n_rows=n_rows,
+ n_traces=n_traces,
+ x_min=x_min,
+ x_max=x_max,
+ )
+ solid_colors = compute_solid_colors(
colorscale=colorscale,
- colormode=colormode,
- line_color=line_color,
+ colormode=colormode if colormode != "fillgradient" else "mean-minmax",
opacity=opacity,
- interpolation_ctx=InterpolationContext(
- densities=densities,
- n_rows=n_rows,
- n_traces=n_traces,
- x_min=x_min,
- x_max=x_max,
- ),
+ interpolation_ctx=interpolation_ctx,
)
- rows: list[RidgeplotRow] = [
- RidgeplotRow(
- traces=[
- RidgeplotTrace(trace=trace, label=label, color=color)
- for trace, label, color in zip_strict(traces, labels, colors)
- ],
- y_shifted=float(-ith_row * y_max * spacing),
- )
- for ith_row, (traces, labels, colors) in enumerate(
- zip_strict(densities, trace_labels, colors)
- )
- ]
+ tickvals: list[float] = []
fig = go.Figure()
- for row in rows:
- for trace in row.traces:
- x, y = zip(*trace.trace)
- fig = draw_density_trace(
- fig,
- x=x,
- y=y,
- y_shifted=row.y_shifted,
- label=trace.label,
- color=trace.color,
+ ith_trace = 0
+ for ith_row, (row_traces, row_trace_types, row_labels, row_colors) in enumerate(
+ zip_strict(densities, trace_types, trace_labels, solid_colors)
+ ):
+ y_base = float(-ith_row * y_max * spacing)
+ tickvals.append(y_base)
+ for trace, trace_type, label, color in zip_strict(
+ row_traces, row_trace_types, row_labels, row_colors
+ ):
+ trace_drawer = get_trace_cls(trace_type)(
+ trace=trace,
+ label=label,
+ solid_color=color,
+ zorder=ith_trace,
+ y_base=y_base,
+ line_color=line_color,
line_width=line_width,
)
+ fig = trace_drawer.draw(
+ fig=fig,
+ coloring_ctx=ColoringContext(
+ colorscale=colorscale,
+ colormode=colormode,
+ opacity=opacity,
+ interpolation_ctx=interpolation_ctx,
+ ),
+ )
+ ith_trace += 1
+
fig = update_layout(
fig,
y_labels=y_labels,
- tickvals=[row.y_shifted for row in rows],
+ tickvals=tickvals,
show_yticklabels=show_yticklabels,
xpad=xpad,
x_max=x_max,
diff --git a/src/ridgeplot/_hist.py b/src/ridgeplot/_hist.py
new file mode 100644
index 00000000..a91d02f8
--- /dev/null
+++ b/src/ridgeplot/_hist.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+from ridgeplot._kde import normalize_sample_weights
+from ridgeplot._vendor.more_itertools import zip_strict
+
+if TYPE_CHECKING:
+ from ridgeplot._types import (
+ Densities,
+ DensityTrace,
+ Samples,
+ SamplesTrace,
+ SampleWeights,
+ SampleWeightsArray,
+ ShallowSampleWeightsArray,
+ )
+
+
+def bin_trace_samples(
+ trace_samples: SamplesTrace,
+ nbins: int,
+ weights: SampleWeights = None,
+) -> DensityTrace:
+ trace_samples = np.asarray(trace_samples, dtype=float)
+ if not np.isfinite(trace_samples).all():
+ raise ValueError("The samples array should not contain any infs or NaNs.")
+ if weights is not None:
+ weights = np.asarray(weights, dtype=float)
+ if len(weights) != len(trace_samples):
+ raise ValueError("The weights array should have the same length as the samples array.")
+ if not np.isfinite(weights).all():
+ raise ValueError("The weights array should not contain any infs or NaNs.")
+ hist, bins = np.histogram(trace_samples, bins=nbins, weights=weights)
+ return [(float(x), float(y)) for x, y in zip(bins, hist)]
+
+
+def bin_samples(
+ samples: Samples,
+ nbins: int,
+ sample_weights: SampleWeightsArray | ShallowSampleWeightsArray | SampleWeights = None,
+) -> Densities:
+ normalised_weights = normalize_sample_weights(sample_weights=sample_weights, samples=samples)
+ return [
+ [
+ bin_trace_samples(trace_samples, nbins=nbins, weights=weights)
+ for trace_samples, weights in zip_strict(samples_row, weights_row)
+ ]
+ for samples_row, weights_row in zip_strict(samples, normalised_weights)
+ ]
diff --git a/src/ridgeplot/_kde.py b/src/ridgeplot/_kde.py
index ceda645b..1fafdb21 100644
--- a/src/ridgeplot/_kde.py
+++ b/src/ridgeplot/_kde.py
@@ -3,9 +3,10 @@
import sys
from collections.abc import Collection
from functools import partial
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Union, cast
import numpy as np
+import numpy.typing as npt
import statsmodels.api as sm
from statsmodels.sandbox.nonparametric.kernels import CustomKernel as StatsmodelsKernel
@@ -16,9 +17,11 @@
from ridgeplot._types import (
CollectionL1,
- CollectionL2,
- Float,
+ DensityTrace,
Numeric,
+ SampleWeights,
+ SampleWeightsArray,
+ ShallowSampleWeightsArray,
is_flat_numeric_collection,
nest_shallow_collection,
)
@@ -26,9 +29,7 @@
from ridgeplot._vendor.more_itertools import zip_strict
if TYPE_CHECKING:
- import numpy.typing as npt
-
- from ridgeplot._types import Densities, Samples, SamplesTrace, XYCoordinate
+ from ridgeplot._types import Densities, Samples, SamplesTrace
KDEPoints = Union[int, CollectionL1[Numeric]]
@@ -37,17 +38,6 @@
KDEBandwidth = Union[str, float, Callable[[CollectionL1[Numeric], StatsmodelsKernel], float]]
"""The :paramref:`ridgeplot.ridgeplot.bandwidth` parameter."""
-SampleWeights = Optional[CollectionL1[Numeric]]
-"""An array of KDE weights corresponding to each sample."""
-
-SampleWeightsArray = CollectionL2[SampleWeights]
-"""A :data:`SampleWeightsArray` represents the weights of the datapoints in a
-:data:`Samples` array. The shape of the :data:`SampleWeightsArray` array should
-match the shape of the corresponding :data:`Samples` array."""
-
-ShallowSampleWeightsArray = CollectionL1[SampleWeights]
-"""Shallow type for :data:`SampleWeightsArray`."""
-
def _is_sample_weights(obj: Any) -> TypeIs[SampleWeights]:
"""Type guard for :data:`SampleWeights`.
@@ -106,11 +96,6 @@ def normalize_sample_weights(
"""
if _is_sample_weights(sample_weights):
return [[sample_weights] * len(row) for row in samples]
- # TODO: Investigate this issue with mypy's type narrowing...
- sample_weights = cast( # type: ignore[unreachable]
- Union[SampleWeightsArray, ShallowSampleWeightsArray],
- sample_weights,
- )
if _is_shallow_sample_weights(sample_weights):
sample_weights = nest_shallow_collection(sample_weights)
sample_weights = normalise_row_attrs(sample_weights, l2_target=samples)
@@ -123,7 +108,7 @@ def estimate_density_trace(
kernel: str,
bandwidth: KDEBandwidth,
weights: SampleWeights = None,
-) -> list[XYCoordinate[Float]]:
+) -> DensityTrace:
"""Estimates a density trace from a set of samples.
For a given set of sample values, computes the kernel densities (KDE) at
@@ -170,18 +155,20 @@ def estimate_density_trace(
dens.fit(
kernel=kernel,
fft=kernel == "gau" and weights is None,
- bw=bandwidth,
+ bw=bandwidth, # pyright: ignore[reportArgumentType]
weights=weights,
)
density_y = dens.evaluate(density_x)
- _validate_densities(x=density_x, y=density_y, kernel=kernel)
+ density_y = _validate_densities(x=density_x, y=density_y, kernel=kernel)
return list(zip(density_x, density_y))
def _validate_densities(
- x: npt.NDArray[np.floating[Any]], y: npt.NDArray[np.floating[Any]], kernel: str
-) -> None:
+ x: npt.NDArray[np.floating[Any]],
+ y: Any,
+ kernel: str,
+) -> npt.NDArray[np.floating[Any]]:
# I haven't investigated the root of this issue yet
# but statsmodels' KDEUnivariate implementation
# can return a float('NaN') if something goes
@@ -199,10 +186,12 @@ def _validate_densities(
# Fail early if the return type is incorrect
# Otherwise, the remaining checks will fail
raise RuntimeError(msg) # noqa: TRY004
+ y = cast(npt.NDArray[np.floating[Any]], y)
wrong_shape = y.shape != x.shape
not_finite = ~np.isfinite(y).all()
if wrong_shape or not_finite:
raise RuntimeError(msg)
+ return y
def estimate_densities(
diff --git a/src/ridgeplot/_obj/__init__.py b/src/ridgeplot/_obj/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/ridgeplot/_obj/traces/__init__.py b/src/ridgeplot/_obj/traces/__init__.py
new file mode 100644
index 00000000..8d2f221b
--- /dev/null
+++ b/src/ridgeplot/_obj/traces/__init__.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from ridgeplot._obj.traces.area import AreaTrace
+from ridgeplot._obj.traces.bar import BarTrace
+from ridgeplot._obj.traces.base import RidgeplotTrace
+
+if TYPE_CHECKING:
+ from ridgeplot._types import TraceType
+
+__all__ = [
+ "AreaTrace",
+ "BarTrace",
+ "RidgeplotTrace",
+ "get_trace_cls",
+]
+
+_TRACE_TYPES: dict[TraceType, type[RidgeplotTrace]] = {
+ "area": AreaTrace,
+ "bar": BarTrace,
+}
+"""Mapping of trace types to trace classes."""
+
+
+def get_trace_cls(trace_type: TraceType) -> type[RidgeplotTrace]:
+ """Get a trace class by its type."""
+ try:
+ return _TRACE_TYPES[trace_type]
+ except KeyError as err:
+ types = ", ".join(repr(t) for t in _TRACE_TYPES)
+ raise ValueError(f"Unknown trace type {trace_type!r}. Available types: {types}.") from err
diff --git a/src/ridgeplot/_obj/traces/area.py b/src/ridgeplot/_obj/traces/area.py
new file mode 100644
index 00000000..54e53683
--- /dev/null
+++ b/src/ridgeplot/_obj/traces/area.py
@@ -0,0 +1,85 @@
+from __future__ import annotations
+
+from typing import Any, ClassVar
+
+from plotly import graph_objects as go
+
+from ridgeplot._color.interpolation import slice_colorscale
+from ridgeplot._color.utils import apply_alpha
+from ridgeplot._obj.traces.base import DEFAULT_HOVERTEMPLATE, ColoringContext, RidgeplotTrace
+from ridgeplot._utils import normalise_min_max
+
+
+class AreaTrace(RidgeplotTrace):
+ _DEFAULT_LINE_WIDTH: ClassVar[float] = 1.5
+
+ def _get_coloring_kwargs(self, ctx: ColoringContext) -> dict[str, Any]:
+ if ctx.colormode == "fillgradient":
+ if ctx.opacity is not None:
+ # HACK: Plotly doesn't yet support setting the fill opacity
+ # for traces with `fillgradient`. As a workaround, we
+ # can override the color-scale's color values and add
+ # the corresponding alpha channel to all colors.
+ ctx.colorscale = [
+ (v, apply_alpha(c, float(ctx.opacity))) for v, c in ctx.colorscale
+ ]
+ color_kwargs = dict(
+ line_color=self.line_color,
+ fillgradient=go.scatter.Fillgradient(
+ colorscale=slice_colorscale(
+ colorscale=ctx.colorscale,
+ p_lower=normalise_min_max(
+ min(self.x),
+ min_=ctx.interpolation_ctx.x_min,
+ max_=ctx.interpolation_ctx.x_max,
+ ),
+ p_upper=normalise_min_max(
+ max(self.x),
+ min_=ctx.interpolation_ctx.x_min,
+ max_=ctx.interpolation_ctx.x_max,
+ ),
+ ),
+ type="horizontal",
+ ),
+ )
+ else:
+ color_kwargs = dict(
+ line_color=self.line_color,
+ fillcolor=self.solid_color,
+ )
+ return color_kwargs
+
+ def draw(self, fig: go.Figure, coloring_ctx: ColoringContext) -> go.Figure:
+ # Draw an invisible trace at constance y=y_base so that we
+ # can set fill="tonexty" below and get a filled area plot
+ fig.add_trace(
+ go.Scatter(
+ x=self.x,
+ y=[self.y_base] * len(self.x),
+ # make trace 'invisible'
+ # Note: visible=False does not work with fill="tonexty"
+ line=dict(color="rgba(0,0,0,0)", width=0),
+ # Hide this invisible helper trace from the legend and hoverinfo
+ showlegend=False,
+ hoverinfo="skip",
+ # z-order (higher z-order means the trace is drawn on top)
+ zorder=self.zorder,
+ )
+ )
+ fig.add_trace(
+ go.Scatter(
+ x=self.x,
+ y=[y_i + self.y_base for y_i in self.y],
+ name=self.label,
+ fill="tonexty",
+ mode="lines",
+ line_width=self.line_width,
+ **self._get_coloring_kwargs(ctx=coloring_ctx),
+ # Hover information
+ customdata=[[y_i] for y_i in self.y],
+ hovertemplate=DEFAULT_HOVERTEMPLATE,
+ # z-order (higher z-order means the trace is drawn on top)
+ zorder=self.zorder,
+ ),
+ )
+ return fig
diff --git a/src/ridgeplot/_obj/traces/bar.py b/src/ridgeplot/_obj/traces/bar.py
new file mode 100644
index 00000000..ba621e03
--- /dev/null
+++ b/src/ridgeplot/_obj/traces/bar.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+from typing import Any, ClassVar
+
+from plotly import graph_objects as go
+
+from ridgeplot._color.interpolation import interpolate_color
+from ridgeplot._obj.traces.base import DEFAULT_HOVERTEMPLATE, ColoringContext, RidgeplotTrace
+from ridgeplot._utils import normalise_min_max
+
+
+class BarTrace(RidgeplotTrace):
+ _DEFAULT_LINE_WIDTH: ClassVar[float] = 0.5
+
+ def _get_coloring_kwargs(self, ctx: ColoringContext) -> dict[str, Any]:
+ if ctx.colormode == "fillgradient":
+ color_kwargs = dict(
+ marker_line_color=self.line_color,
+ marker_color=[
+ interpolate_color(
+ colorscale=ctx.colorscale,
+ p=normalise_min_max(
+ x_i, min_=ctx.interpolation_ctx.x_min, max_=ctx.interpolation_ctx.x_max
+ ),
+ )
+ for x_i in self.x
+ ],
+ )
+ else:
+ color_kwargs = dict(
+ marker_line_color=self.line_color,
+ marker_color=self.solid_color,
+ )
+ return color_kwargs
+
+ def draw(self, fig: go.Figure, coloring_ctx: ColoringContext) -> go.Figure:
+ fig.add_trace(
+ go.Bar(
+ x=self.x,
+ y=self.y,
+ name=self.label,
+ base=self.y_base,
+ marker_line_width=self.line_width,
+ width=None, # Plotly automatically picks the right width
+ **self._get_coloring_kwargs(ctx=coloring_ctx),
+ # Hover information
+ customdata=[[y_i] for y_i in self.y],
+ hovertemplate=DEFAULT_HOVERTEMPLATE,
+ # z-order (higher z-order means the trace is drawn on top)
+ zorder=self.zorder,
+ ),
+ )
+ return fig
diff --git a/src/ridgeplot/_obj/traces/base.py b/src/ridgeplot/_obj/traces/base.py
new file mode 100644
index 00000000..b75e26ab
--- /dev/null
+++ b/src/ridgeplot/_obj/traces/base.py
@@ -0,0 +1,76 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, ClassVar, Literal
+
+from ridgeplot._vendor.more_itertools import zip_strict
+
+if TYPE_CHECKING:
+ from plotly import graph_objects as go
+
+ from ridgeplot._color.interpolation import InterpolationContext, SolidColormode
+ from ridgeplot._types import Color, ColorScale, DensityTrace
+
+
+_D3HF = ".7"
+"""Default (d3-format) format for floats in hover labels.
+
+After trying to read through the plotly.py source code, I couldn't find a
+simple way to replicate the default hover format using the d3-format syntax
+in Plotly's 'hovertemplate' parameter. The closest I got was by using the
+string below, but it's not quite the same... (see '.7~r' as well)
+"""
+
+DEFAULT_HOVERTEMPLATE = (
+ f"(%{{x:{_D3HF}}}, %{{customdata[0]:{_D3HF}}})"
+ "
"
+ "%{fullData.name}"
+) # fmt: skip
+"""Default ``hovertemplate`` for density traces.
+
+The default hover template that should be used for all density traces. It
+displays the x and y values of the hovered point, as well as the trace's name.
+When using this as ``hovertemplate=DEFAULT_HOVERTEMPLATE``, it is expected that
+the trace's ``customdata`` is set to a list of lists, where each inner list
+contains a single element that is the y-value of the corresponding x-value
+(e.g. ``customdata=[[y_i] for y_i in y]``). The ``name`` attribute of the trace
+should also be set to the desired label for the trace (e.g. ``name=self.label``).
+"""
+
+
+@dataclass
+class ColoringContext:
+ colorscale: ColorScale
+ colormode: Literal["fillgradient"] | SolidColormode
+ opacity: float | None
+ interpolation_ctx: InterpolationContext
+
+
+class RidgeplotTrace(ABC):
+ _DEFAULT_LINE_WIDTH: ClassVar[float] = 2.0
+
+ def __init__(
+ self,
+ *, # kw only
+ trace: DensityTrace,
+ label: str,
+ solid_color: str,
+ zorder: int,
+ # Constant over the trace's row
+ y_base: float,
+ # Constant over the entire plot
+ line_color: Color | Literal["fill-color"],
+ line_width: float | None,
+ ):
+ self.x, self.y = zip_strict(*trace)
+ self.label = label
+ self.solid_color = solid_color
+ self.zorder = zorder
+ self.y_base = y_base
+ self.line_color: Color = self.solid_color if line_color == "fill-color" else line_color
+ self.line_width: float = line_width if line_width is not None else self._DEFAULT_LINE_WIDTH
+
+ @abstractmethod
+ def draw(self, fig: go.Figure, coloring_ctx: ColoringContext) -> go.Figure:
+ raise NotImplementedError
diff --git a/src/ridgeplot/_ridgeplot.py b/src/ridgeplot/_ridgeplot.py
index de7b00c6..02fbf97c 100644
--- a/src/ridgeplot/_ridgeplot.py
+++ b/src/ridgeplot/_ridgeplot.py
@@ -1,22 +1,13 @@
from __future__ import annotations
import warnings
-from typing import TYPE_CHECKING, Literal, cast
+from typing import TYPE_CHECKING, cast
-from ridgeplot._figure_factory import (
- LabelsArray,
- ShallowLabelsArray,
- create_ridgeplot,
-)
-from ridgeplot._missing import MISSING, MissingType
+from ridgeplot._figure_factory import create_ridgeplot
+from ridgeplot._missing import MISSING
from ridgeplot._types import (
- Color,
- ColorScale,
Densities,
- NormalisationOption,
Samples,
- ShallowDensities,
- ShallowSamples,
is_shallow_densities,
is_shallow_samples,
nest_shallow_collection,
@@ -25,6 +16,7 @@
if TYPE_CHECKING:
from collections.abc import Collection
+ from typing import Literal
import plotly.graph_objects as go
@@ -32,39 +24,68 @@
from ridgeplot._kde import (
KDEBandwidth,
KDEPoints,
+ )
+ from ridgeplot._missing import MissingType
+ from ridgeplot._types import (
+ Color,
+ ColorScale,
+ LabelsArray,
+ NormalisationOption,
SampleWeights,
SampleWeightsArray,
+ ShallowDensities,
+ ShallowLabelsArray,
+ ShallowSamples,
ShallowSampleWeightsArray,
+ ShallowTraceTypesArray,
+ TraceType,
+ TraceTypesArray,
)
def _coerce_to_densities(
samples: Samples | ShallowSamples | None,
densities: Densities | ShallowDensities | None,
+ # KDE parameters
kernel: str,
bandwidth: KDEBandwidth,
kde_points: KDEPoints,
+ # Histogram parameters
+ nbins: int | None,
+ # Common parameters for density estimation
sample_weights: SampleWeightsArray | ShallowSampleWeightsArray | SampleWeights,
) -> Densities:
# Importing statsmodels, scipy, and numpy can be slow,
# so we're hiding the kde import here to only incur
# this cost if the user actually needs this it...
+ from ridgeplot._hist import bin_samples
from ridgeplot._kde import estimate_densities
+ # Input validation
has_samples = samples is not None
has_densities = densities is not None
if has_samples and has_densities:
raise ValueError("You may not specify both `samples` and `densities` arguments!")
if not has_samples and not has_densities:
raise ValueError("You must specify either `samples` or `densities`")
+
+ # Exit early if densities are already provided
if has_densities:
if is_shallow_densities(densities):
densities = nest_shallow_collection(densities)
- densities = cast(Densities, densities)
+ return densities
+
+ # Transform samples into densities via KDE or histogram binning
+ if is_shallow_samples(samples):
+ samples = nest_shallow_collection(samples)
+ samples = cast(Samples, samples)
+ if nbins is not None:
+ densities = bin_samples(
+ samples=samples,
+ nbins=nbins,
+ sample_weights=sample_weights,
+ )
else:
- if is_shallow_samples(samples):
- samples = nest_shallow_collection(samples)
- samples = cast(Samples, samples)
densities = estimate_densities(
samples=samples,
points=kde_points,
@@ -78,29 +99,40 @@ def _coerce_to_densities(
def ridgeplot(
samples: Samples | ShallowSamples | None = None,
densities: Densities | ShallowDensities | None = None,
+ trace_type: TraceTypesArray | ShallowTraceTypesArray | TraceType | None = None,
+ labels: LabelsArray | ShallowLabelsArray | None = None,
+ # KDE parameters
kernel: str = "gau",
bandwidth: KDEBandwidth = "normal_reference",
kde_points: KDEPoints = 500,
+ # Histogram parameters
+ nbins: int | None = None,
+ # Common parameters for density estimation
sample_weights: SampleWeightsArray | ShallowSampleWeightsArray | SampleWeights = None,
+ norm: NormalisationOption | None = None,
+ # Coloring and styling parameters
colorscale: ColorScale | Collection[Color] | str | None = None,
colormode: Literal["fillgradient"] | SolidColormode = "fillgradient",
opacity: float | None = None,
- labels: LabelsArray | ShallowLabelsArray | None = None,
- norm: NormalisationOption | None = None,
line_color: Color | Literal["fill-color"] = "black",
- line_width: float = 1.5,
+ line_width: float | None = None,
spacing: float = 0.5,
show_yticklabels: bool = True,
xpad: float = 0.05,
- # Deprecated arguments
+ # Deprecated parameters
coloralpha: float | None | MissingType = MISSING,
linewidth: float | MissingType = MISSING,
) -> go.Figure:
r"""Return an interactive ridgeline (Plotly) |~go.Figure|.
.. note::
- You must pass either :paramref:`samples` or :paramref:`densities` to
- this function, but not both. See descriptions below for more details.
+ You must specify either :paramref:`samples` or :paramref:`densities` to
+ this function, but not both. When specifying :paramref:`samples`, the
+ function will estimate the densities using either Kernel Density
+ Estimation (KDE) or histogram binning. When specifying
+ :paramref:`densities`, the function will skip the density estimation
+ step and use the provided densities directly. See the parameter
+ descriptions below for more details.
.. _bandwidths.py:
https://www.statsmodels.org/stable/_modules/statsmodels/nonparametric/bandwidths.html
@@ -112,12 +144,19 @@ def ridgeplot(
Parameters
----------
samples : Samples or ShallowSamples
- If ``samples`` data is specified, Kernel Density Estimation (KDE) will
- be computed. See :paramref:`kernel`, :paramref:`bandwidth`,
- :paramref:`kde_points`, and :paramref:`sample_weights` for more details
- and KDE configuration options. The ``samples`` argument should be an
- array of shape :math:`(R, T_r, S_t)`. Note that we support irregular
- (`ragged`_) arrays, where:
+ If ``samples`` data is specified, either Kernel Density Estimation (KDE)
+ or histogram binning will be performed to estimate the underlying
+ densities.
+
+ See :paramref:`kernel`, :paramref:`bandwidth`, and
+ :paramref:`kde_points` for more details on the different KDE parameters.
+ See :paramref:`nbins` for more details on histogram binning. The
+ :paramref:`sample_weights` parameter can be used for both KDE and
+ histogram binning.
+
+ The ``samples`` argument should be an array of shape
+ :math:`(R, T_r, S_t)`. Note that we support irregular (`ragged`_)
+ arrays, where:
- :math:`R` is the number of rows in the plot
- :math:`T_r` is the number of traces per row, where each row
@@ -125,17 +164,17 @@ def ridgeplot(
- :math:`S_t` is the number of samples per trace, where each trace
:math:`t \in T_r` can also have a different number of samples.
- The KDE will be performed over the sample values (:math:`S_t`) for all
- traces. After the KDE, the resulting array will be a (4D)
- :paramref:`densities` array with shape :math:`(R, T_r, P_t, 2)`
- (see below for more details).
+ The density estimation step will be performed over the sample values
+ (:math:`S_t`) for all traces. The resulting array will be a (4D)
+ :paramref:`densities` array of shape :math:`(R, T_r, P_t, 2)`
+ (see :paramref:`densities` below for more details).
densities : Densities or ShallowDensities
- If a ``densities`` array is specified, the KDE step will be skipped and
- all associated arguments ignored. Each density array should have shape
- :math:`(R, T_r, P_t, 2)` (4D). Just like the :paramref:`samples`
- argument, we also support irregular (`ragged`_) ``densities`` arrays,
- where:
+ If a ``densities`` array is specified, the density estimation step will
+ be skipped and all associated arguments ignored. Each density array
+ should have shape :math:`(R, T_r, P_t, 2)` (4D). Just like the
+ :paramref:`samples` argument, we also support irregular (`ragged`_)
+ ``densities`` arrays, where:
- :math:`R` is the number of rows in the plot
- :math:`T_r` is the number of traces per row, where each row
@@ -144,6 +183,24 @@ def ridgeplot(
:math:`t \in T_r` can also have a different number of points.
- :math:`2` is the number of coordinates per point (x and y)
+ See :paramref:`samples` above for more details.
+
+ trace_type : TraceTypesArray or ShallowTraceTypesArray or TraceType or None
+ The type of trace to display. Choices are ``'area'`` or ``'bar'``. If a
+ single value is passed, it will be used for all traces. If a list of
+ values is passed, it should have the same shape as the samples array.
+ If not specified (default), the traces will be displayed as area plots
+ (``trace_type='area'``) unless histogram binning is used, in which case
+ the traces will be displayed as bar plots (``trace_type='bar'``).
+
+ .. versionadded:: 0.3.0
+
+ labels : LabelsArray or ShallowLabelsArray or None
+ A list of string labels for each trace. If not specified (default), the
+ labels will be automatically generated as ``"Trace {n}"``, where ``n``
+ is the trace's index. If instead a list of labels is specified, it
+ should have the same shape as the samples array.
+
kernel : str
The Kernel to be used during Kernel Density Estimation. The default is
a Gaussian Kernel (``"gau"``). Choices are:
@@ -182,11 +239,29 @@ def ridgeplot(
set of samples. Optionally, you can also pass a custom 1D numerical
array, which will be used for all traces.
+ nbins : int or None
+ The number of bins to use when applying histogram binning. If not
+ specified (default), KDE will be used instead of histogram binning.
+
+ .. versionadded:: 0.3.0
+
sample_weights : SampleWeightsArray or ShallowSampleWeightsArray or SampleWeights or None
An (optional) array of KDE weights corresponding to each sample. The
weights should have the same shape as the samples array. If not
specified (default), all samples will be weighted equally.
+ norm : NormalisationOption or None
+ The normalisation option to use when normalising the densities. If not
+ specified (default), no normalisation will be applied and the densities
+ will be used *as is*. The following normalisation options are available:
+
+ - ``"probability"`` - normalise the densities by dividing each trace by
+ its sum.
+ - ``"percent"`` - same as ``"probability"``, but the normalised values
+ are multiplied by 100.
+
+ .. versionadded:: 0.2.0
+
colorscale : ColorScale or Collection[Color] or str or None
A continuous color scale used to color the different traces in the
ridgeline plot. It can be represented by a string name (e.g.,
@@ -256,24 +331,6 @@ def ridgeplot(
.. versionadded:: 0.2.0
Replaces the deprecated :paramref:`coloralpha` argument.
- labels : LabelsArray or ShallowLabelsArray or None
- A list of string labels for each trace. If not specified (default), the
- labels will be automatically generated as ``"Trace {n}"``, where ``n``
- is the trace's index. If instead a list of labels is specified, it
- should have the same shape as the samples array.
-
- norm : NormalisationOption or None
- The normalisation option to use when normalising the densities. If not
- specified (default), no normalisation will be applied and the densities
- will be used *as is*. The following normalisation options are available:
-
- - ``"probability"`` - normalise the densities by dividing each trace by
- its sum.
- - ``"percent"`` - same as ``"probability"``, but the normalised values
- are multiplied by 100.
-
- .. versionadded:: 0.2.0
-
line_color : Color or "fill-color"
The color of the traces' lines. Any valid CSS color is allowed
(default: ``"black"``). If the value is set to "fill-color", the line
@@ -284,8 +341,10 @@ def ridgeplot(
.. versionadded:: 0.2.0
- line_width : float
- The traces' line width (in px).
+ line_width : float or None
+ The traces' line width (in px). If not specified (default), area plots
+ will have a line width of 1.5 px, and bar plots will have a line width
+ of 0.5 px.
.. versionadded:: 0.2.0
Replaces the deprecated :paramref:`linewidth` argument.
@@ -332,15 +391,19 @@ def ridgeplot(
if neither of them is specified. i.e. you may only specify one of them.
"""
+ if trace_type is None:
+ trace_type = "area" if nbins is None else "bar"
+
densities = _coerce_to_densities(
samples=samples,
densities=densities,
kernel=kernel,
bandwidth=bandwidth,
kde_points=kde_points,
+ nbins=nbins,
sample_weights=sample_weights,
)
- del samples, kernel, bandwidth, kde_points
+ del samples, kernel, bandwidth, kde_points, nbins, sample_weights
if norm:
densities = normalise_densities(densities, norm=norm)
@@ -360,7 +423,7 @@ def ridgeplot(
opacity = coloralpha
if linewidth is not MISSING:
- if line_width != 1.5:
+ if line_width is not None:
raise ValueError(
"You may not specify both the 'linewidth' and 'line_width' arguments! "
"HINT: Use the new 'line_width' argument instead of the deprecated 'linewidth'."
@@ -389,6 +452,7 @@ def ridgeplot(
fig = create_ridgeplot(
densities=densities,
trace_labels=labels,
+ trace_types=trace_type,
colorscale=colorscale,
opacity=opacity,
colormode=colormode,
diff --git a/src/ridgeplot/_types.py b/src/ridgeplot/_types.py
index 7ed7de36..e4758b61 100644
--- a/src/ridgeplot/_types.py
+++ b/src/ridgeplot/_types.py
@@ -2,7 +2,7 @@
import sys
from collections.abc import Collection
-from typing import Any, Literal, TypeVar, Union
+from typing import Any, Literal, Optional, TypeVar, Union
import numpy as np
@@ -502,6 +502,122 @@ def is_shallow_samples(obj: Any) -> TypeIs[ShallowSamples]:
return isinstance(obj, Collection) and all(map(is_trace_samples, obj))
+# ========================================================
+# --- Other array types
+# ========================================================
+
+
+# Trace types ---
+
+TraceType = Literal["area", "bar"]
+"""The type of trace to draw in a ridgeplot. See
+:paramref:`ridgeplot.ridgeplot.trace_type` for more information."""
+
+TraceTypesArray = CollectionL2[TraceType]
+"""A :data:`TraceTypesArray` represents the types of traces in a ridgeplot.
+
+Example
+-------
+>>> trace_types_array: TraceTypesArray = [
+... ["area", "bar", "area"],
+... ["bar", "area"],
+... ]
+"""
+
+ShallowTraceTypesArray = CollectionL1[TraceType]
+"""Shallow type for :data:`TraceTypesArray`.
+
+Example
+-------
+>>> trace_types_array: ShallowTraceTypesArray = ["area", "bar", "area"]
+"""
+
+
+def is_trace_type(obj: Any) -> TypeIs[TraceType]:
+ """Type guard for :data:`TraceType`.
+
+ Examples
+ --------
+ >>> is_trace_type("area")
+ True
+ >>> is_trace_type("bar")
+ True
+ >>> is_trace_type("foo")
+ False
+ >>> is_trace_type(42)
+ False
+ """
+ from typing import get_args
+
+ return isinstance(obj, str) and obj in get_args(TraceType)
+
+
+def is_shallow_trace_types_array(obj: Any) -> TypeIs[ShallowTraceTypesArray]:
+ """Type guard for :data:`ShallowTraceTypesArray`.
+
+ Examples
+ --------
+ >>> is_shallow_trace_types_array(["area", "bar", "area"])
+ True
+ >>> is_shallow_trace_types_array(["area", "bar", "foo"])
+ False
+ >>> is_shallow_trace_types_array([1, 2, 3])
+ False
+ """
+ return isinstance(obj, Collection) and all(map(is_trace_type, obj))
+
+
+def is_trace_types_array(obj: Any) -> TypeIs[TraceTypesArray]:
+ """Type guard for :data:`TraceTypesArray`.
+
+ Examples
+ --------
+ >>> is_trace_types_array([["area", "bar"], ["area", "bar"]])
+ True
+ >>> is_trace_types_array([["area", "bar"], ["area", "foo"]])
+ False
+ >>> is_trace_types_array([["area", "bar"], ["area", 42]])
+ False
+ """
+ return isinstance(obj, Collection) and all(map(is_shallow_trace_types_array, obj))
+
+
+# Labels ---
+
+LabelsArray = CollectionL2[str]
+"""A :data:`LabelsArray` represents the labels of traces in a ridgeplot.
+
+Example
+-------
+
+>>> labels_array: LabelsArray = [
+... ["trace 1", "trace 2", "trace 3"],
+... ["trace 4", "trace 5"],
+... ]
+"""
+
+ShallowLabelsArray = CollectionL1[str]
+"""Shallow type for :data:`LabelsArray`.
+
+Example
+-------
+
+>>> labels_array: ShallowLabelsArray = ["trace 1", "trace 2", "trace 3"]
+"""
+
+# Sample weights ---
+
+SampleWeights = Optional[CollectionL1[Numeric]]
+"""An array of KDE weights corresponding to each sample."""
+
+SampleWeightsArray = CollectionL2[SampleWeights]
+"""A :data:`SampleWeightsArray` represents the weights of the datapoints in a
+:data:`Samples` array. The shape of the :data:`SampleWeightsArray` array should
+match the shape of the corresponding :data:`Samples` array."""
+
+ShallowSampleWeightsArray = CollectionL1[SampleWeights]
+"""Shallow type for :data:`SampleWeightsArray`."""
+
# ========================================================
# --- More type guards and other utilities
# ========================================================
diff --git a/src/ridgeplot/_utils.py b/src/ridgeplot/_utils.py
index c2a816e7..10d116c7 100644
--- a/src/ridgeplot/_utils.py
+++ b/src/ridgeplot/_utils.py
@@ -161,7 +161,7 @@ def _get_dim_length(obj: Any) -> int:
return len(obj)
shape: list[int | set[int]] = [_get_dim_length(arr)]
- while isinstance(arr, Collection) and len(arr) > 0:
+ while len(arr) > 0:
try:
dim_lengths = set(map(_get_dim_length, arr))
except TypeError:
diff --git a/src/ridgeplot/datasets/__init__.py b/src/ridgeplot/datasets/__init__.py
index f6dc73dd..aef71832 100644
--- a/src/ridgeplot/datasets/__init__.py
+++ b/src/ridgeplot/datasets/__init__.py
@@ -4,9 +4,9 @@
from typing import TYPE_CHECKING
if sys.version_info >= (3, 10):
- from importlib.resources import files
+ from importlib.resources import as_file, files
else:
- from importlib_resources import files
+ from importlib_resources import as_file, files
if TYPE_CHECKING:
from typing import Literal
@@ -124,7 +124,8 @@ def load_probly(
f"Unknown version {version!r} for the probly dataset. "
f"Valid versions are {list(versions.keys())}."
)
- return pd.read_csv(_DATA_DIR / versions[version])
+ with as_file(_DATA_DIR / versions[version]) as data_file:
+ return pd.read_csv(data_file)
def load_lincoln_weather() -> pd.DataFrame:
@@ -162,6 +163,7 @@ def load_lincoln_weather() -> pd.DataFrame:
https://austinwehrwein.com/data-visualization/plot-inspiration-via-fivethirtyeight/
"""
- data = pd.read_csv(_DATA_DIR / "lincoln-weather.csv", index_col="CST")
- data.index = pd.to_datetime(data.index.to_list())
+ with as_file(_DATA_DIR / "lincoln-weather.csv") as data_file:
+ data = pd.read_csv(data_file, index_col="CST")
+ data.index = pd.to_datetime(data.index)
return data
diff --git a/tests/conftest.py b/tests/conftest.py
index 4116b80e..290cd7ba 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -11,6 +11,6 @@
@pytest.fixture(autouse=True, scope="session")
-def _patch_plotly_show() -> Generator[None]:
+def _patch_plotly_show() -> Generator[None]: # pyright: ignore[reportUnusedFunction]
with patch_plotly_show():
yield
diff --git a/tests/unit/color/test_colorscale.py b/tests/unit/color/test_colorscale.py
index ae5883eb..b4512420 100644
--- a/tests/unit/color/test_colorscale.py
+++ b/tests/unit/color/test_colorscale.py
@@ -8,7 +8,7 @@
from ridgeplot._color.colorscale import (
infer_default_colorscale,
list_all_colorscale_names,
- validate_and_coerce_colorscale,
+ validate_coerce_colorscale,
)
if TYPE_CHECKING:
@@ -23,32 +23,32 @@
def test_infer_default_colorscale() -> None:
- assert infer_default_colorscale() == validate_and_coerce_colorscale("plasma")
+ assert infer_default_colorscale() == validate_coerce_colorscale("plasma")
# ==============================================================
-# --- validate_and_coerce_colorscale()
+# --- validate_coerce_colorscale()
# ==============================================================
-def test_validate_and_coerce_colorscale(
- valid_colorscale: tuple[ColorScale | Collection[Color] | str, ColorScale]
+def test_validate_coerce_colorscale(
+ valid_colorscale: tuple[ColorScale | Collection[Color] | str, ColorScale],
) -> None:
colorscale, expected = valid_colorscale
- coerced = validate_and_coerce_colorscale(colorscale=colorscale)
+ coerced = validate_coerce_colorscale(colorscale=colorscale)
values, colors = zip(*coerced)
values_expected, colors_expected = zip(*expected)
assert values == pytest.approx(values_expected)
assert colors == colors_expected
-def test_validate_and_coerce_colorscale_fails(
+def test_validate_coerce_colorscale_fails(
invalid_colorscale: ColorScale | Collection[Color] | str,
) -> None:
with pytest.raises(
ValueError, match=r"Invalid value .* received for the 'colorscale' property"
):
- validate_and_coerce_colorscale(invalid_colorscale)
+ validate_coerce_colorscale(invalid_colorscale)
# ==============================================================
@@ -65,4 +65,4 @@ def test_list_all_colorscale_names() -> None:
assert "viridis" in all_colorscale_names
assert "default" in all_colorscale_names
for name in all_colorscale_names:
- validate_and_coerce_colorscale(name)
+ validate_coerce_colorscale(name)
diff --git a/tests/unit/color/test_interpolation.py b/tests/unit/color/test_interpolation.py
index 3bddf200..8f9031aa 100644
--- a/tests/unit/color/test_interpolation.py
+++ b/tests/unit/color/test_interpolation.py
@@ -10,11 +10,10 @@
ColorscaleInterpolants,
InterpolationContext,
SolidColormode,
- _interpolate_color,
- _interpolate_mean_means,
- _interpolate_mean_minmax,
- _slice_colorscale,
- compute_trace_colors,
+ _interpolate_mean_means, # pyright: ignore[reportPrivateUsage]
+ _interpolate_mean_minmax, # pyright: ignore[reportPrivateUsage]
+ interpolate_color,
+ slice_colorscale,
)
from ridgeplot._color.utils import to_rgb
@@ -23,36 +22,106 @@
# ==============================================================
-# --- _interpolate_color()
+# --- interpolate_color()
# ==============================================================
def test_interpolate_color_p_in_scale(viridis_colorscale: ColorScale) -> None:
viridis_colorscale = list(viridis_colorscale)
- assert _interpolate_color(colorscale=viridis_colorscale, p=0) == to_rgb(
- viridis_colorscale[0][1]
- )
- assert _interpolate_color(colorscale=viridis_colorscale, p=1) == to_rgb(
+ assert interpolate_color(colorscale=viridis_colorscale, p=0) == to_rgb(viridis_colorscale[0][1])
+ assert interpolate_color(colorscale=viridis_colorscale, p=1) == to_rgb(
viridis_colorscale[-1][1]
)
# Test that the alpha channels are also properly handled here
cs = ((0, "rgba(0, 0, 0, 0)"), (1, "rgba(255, 255, 255, 1)"))
- assert _interpolate_color(colorscale=cs, p=0) == cs[0][1]
- assert _interpolate_color(colorscale=cs, p=1) == cs[-1][1]
+ assert interpolate_color(colorscale=cs, p=0) == cs[0][1]
+ assert interpolate_color(colorscale=cs, p=1) == cs[-1][1]
def test_interpolate_color_p_not_in_scale(viridis_colorscale: ColorScale) -> None:
# Hard-coded test case for the Viridis colorscale
- assert _interpolate_color(colorscale=viridis_colorscale, p=0.5) == "rgb(34.5, 144.0, 139.5)"
+ assert interpolate_color(colorscale=viridis_colorscale, p=0.5) == "rgb(34.5, 144.0, 139.5)"
# Test that the alpha channels are also properly handled here
cs = ((0, "rgba(0, 0, 0, 0)"), (1, "rgba(255, 255, 255, 1)"))
- assert _interpolate_color(colorscale=cs, p=0.5) == "rgba(127.5, 127.5, 127.5, 0.5)"
+ assert interpolate_color(colorscale=cs, p=0.5) == "rgba(127.5, 127.5, 127.5, 0.5)"
@pytest.mark.parametrize("p", [-10.0, -1.3, 1.9, 100.0])
def test_interpolate_color_fails_for_p_out_of_bounds(p: float) -> None:
with pytest.raises(ValueError, match="should be a float value between 0 and 1"):
- _interpolate_color(colorscale=..., p=p) # type: ignore[arg-type]
+ interpolate_color(colorscale=..., p=p) # pyright: ignore[reportArgumentType]
+
+
+# ==============================================================
+# --- slice_colorscale()
+# ==============================================================
+
+
+def test_slice_colorscale_lower_less_than_upper() -> None:
+ with pytest.raises(ValueError, match="p_lower should be less than p_upper"):
+ slice_colorscale(colorscale=[(0, "...")], p_lower=1, p_upper=0)
+
+
+def test_slice_colorscale_lower_than_0() -> None:
+ with pytest.raises(ValueError, match="p_lower should be >= 0"):
+ slice_colorscale(colorscale=[(0, "...")], p_lower=-1, p_upper=0)
+
+
+def test_slice_colorscale_upper_than_1() -> None:
+ with pytest.raises(ValueError, match="p_upper should be <= 1"):
+ slice_colorscale(colorscale=[(0, "...")], p_lower=0, p_upper=1.1)
+
+
+def test_slice_colorscale_unchanged() -> None:
+ cs = ((0, "rgb(0, 0, 0)"), (1, "rgb(255, 255, 255)"))
+ assert slice_colorscale(colorscale=cs, p_lower=0, p_upper=1) == cs
+
+
+def test_slice_colorscale() -> None:
+ cs = (
+ (0, "rgb(0, 0, 0)"),
+ (0.5, "rgb(127.5, 127.5, 127.5)"),
+ (1, "rgb(255, 255, 255)"),
+ )
+ assert slice_colorscale(colorscale=cs, p_lower=0.25, p_upper=0.75) == (
+ (0.0, "rgb(63.75, 63.75, 63.75)"),
+ (0.5, "rgb(127.5, 127.5, 127.5)"),
+ (1.0, "rgb(191.25, 191.25, 191.25)"),
+ )
+
+
+def test_slice_colorscale_no_intermediate_values() -> None:
+ cs = ((0, "rgb(0, 0, 0)"), (1, "rgb(255, 255, 255)"))
+ assert slice_colorscale(colorscale=cs, p_lower=0.25, p_upper=0.75) == (
+ (0.0, "rgb(63.75, 63.75, 63.75)"),
+ (1.0, "rgb(191.25, 191.25, 191.25)"),
+ )
+
+
+def test_slice_colorscale_alpha() -> None:
+ cs = (
+ (0, "rgba(0, 0, 0, 0)"),
+ (0.5, "rgba(127.5, 127.5, 127.5, 0.5)"),
+ (1, "rgba(255, 255, 255, 1)"),
+ )
+ assert slice_colorscale(colorscale=cs, p_lower=0.25, p_upper=0.75) == (
+ (0.0, "rgba(63.75, 63.75, 63.75, 0.25)"),
+ (0.5, "rgba(127.5, 127.5, 127.5, 0.5)"),
+ (1.0, "rgba(191.25, 191.25, 191.25, 0.75)"),
+ )
+
+
+def test_slice_colorscale_mixed_alpha_channels() -> None:
+ cs = (
+ (0, "rgba(0, 0, 0, 0)"),
+ (0.5, "rgba(127.5, 127.5, 127.5, 1)"),
+ (1, "rgba(255, 255, 255, 0)"),
+ )
+ assert slice_colorscale(colorscale=cs, p_lower=0.25, p_upper=0.75) == (
+ (0.0, "rgba(63.75, 63.75, 63.75, 0.5)"),
+ (0.5, "rgba(127.5, 127.5, 127.5, 1)"),
+ (1.0, "rgba(191.25, 191.25, 191.25, 0.5)"),
+ )
# ==============================================================
@@ -130,93 +199,3 @@ def test_index_based_colormodes(
interpolate_func = SOLID_COLORMODE_MAPS[colormode]
interpolants = interpolate_func(ctx=InterpolationContext.from_densities(densities))
assert interpolants == expected
-
-
-# ==============================================================
-# --- _slice_colorscale()
-# ==============================================================
-
-
-def test_slice_colorscale_lower_less_than_upper() -> None:
- with pytest.raises(ValueError, match="p_lower should be less than p_upper"):
- _slice_colorscale(colorscale=[(0, "...")], p_lower=1, p_upper=0)
-
-
-def test_slice_colorscale_lower_than_0() -> None:
- with pytest.raises(ValueError, match="p_lower should be >= 0"):
- _slice_colorscale(colorscale=[(0, "...")], p_lower=-1, p_upper=0)
-
-
-def test_slice_colorscale_upper_than_1() -> None:
- with pytest.raises(ValueError, match="p_upper should be <= 1"):
- _slice_colorscale(colorscale=[(0, "...")], p_lower=0, p_upper=1.1)
-
-
-def test_slice_colorscale_unchanged() -> None:
- cs = ((0, "rgb(0, 0, 0)"), (1, "rgb(255, 255, 255)"))
- assert _slice_colorscale(colorscale=cs, p_lower=0, p_upper=1) == cs
-
-
-def test_slice_colorscale() -> None:
- cs = (
- (0, "rgb(0, 0, 0)"),
- (0.5, "rgb(127.5, 127.5, 127.5)"),
- (1, "rgb(255, 255, 255)"),
- )
- assert _slice_colorscale(colorscale=cs, p_lower=0.25, p_upper=0.75) == (
- (0.0, "rgb(63.75, 63.75, 63.75)"),
- (0.5, "rgb(127.5, 127.5, 127.5)"),
- (1.0, "rgb(191.25, 191.25, 191.25)"),
- )
-
-
-def test_slice_colorscale_no_intermediate_values() -> None:
- cs = ((0, "rgb(0, 0, 0)"), (1, "rgb(255, 255, 255)"))
- assert _slice_colorscale(colorscale=cs, p_lower=0.25, p_upper=0.75) == (
- (0.0, "rgb(63.75, 63.75, 63.75)"),
- (1.0, "rgb(191.25, 191.25, 191.25)"),
- )
-
-
-def test_slice_colorscale_alpha() -> None:
- cs = (
- (0, "rgba(0, 0, 0, 0)"),
- (0.5, "rgba(127.5, 127.5, 127.5, 0.5)"),
- (1, "rgba(255, 255, 255, 1)"),
- )
- assert _slice_colorscale(colorscale=cs, p_lower=0.25, p_upper=0.75) == (
- (0.0, "rgba(63.75, 63.75, 63.75, 0.25)"),
- (0.5, "rgba(127.5, 127.5, 127.5, 0.5)"),
- (1.0, "rgba(191.25, 191.25, 191.25, 0.75)"),
- )
-
-
-def test_slice_colorscale_mixed_alpha_channels() -> None:
- cs = (
- (0, "rgba(0, 0, 0, 0)"),
- (0.5, "rgba(127.5, 127.5, 127.5, 1)"),
- (1, "rgba(255, 255, 255, 0)"),
- )
- assert _slice_colorscale(colorscale=cs, p_lower=0.25, p_upper=0.75) == (
- (0.0, "rgba(63.75, 63.75, 63.75, 0.5)"),
- (0.5, "rgba(127.5, 127.5, 127.5, 1)"),
- (1.0, "rgba(191.25, 191.25, 191.25, 0.5)"),
- )
-
-
-# ==============================================================
-# --- compute_trace_colors()
-# ==============================================================
-
-
-def test_colormode_invalid() -> None:
- with pytest.raises(
- ValueError, match="The colormode argument should be one of .* got INVALID instead"
- ):
- compute_trace_colors(
- colorscale="Viridis",
- colormode="INVALID", # type: ignore[call-overload]
- line_color="black",
- opacity=None,
- interpolation_ctx=InterpolationContext.from_densities([[[(0, 0)]]]),
- )
diff --git a/tests/unit/color/test_utils.py b/tests/unit/color/test_utils.py
index 96a18d15..b8b87b95 100644
--- a/tests/unit/color/test_utils.py
+++ b/tests/unit/color/test_utils.py
@@ -8,7 +8,6 @@
from ridgeplot._color.utils import apply_alpha, default_plotly_template, round_color, to_rgb
if TYPE_CHECKING:
-
from ridgeplot._types import Color
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 3b1ea98b..55363cfa 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -1,14 +1,12 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
+from collections.abc import Collection
+from typing import Union, cast
import plotly.express as px
import pytest
-if TYPE_CHECKING:
- from collections.abc import Collection
-
- from ridgeplot._types import Color, ColorScale
+from ridgeplot._types import Color, ColorScale
VIRIDIS = (
(0.0, "#440154"),
@@ -50,7 +48,7 @@ def viridis_colorscale() -> ColorScale:
def valid_colorscale(
request: pytest.FixtureRequest,
) -> tuple[ColorScale | Collection[Color] | str, ColorScale]:
- return request.param # type: ignore[no-any-return]
+ return cast(tuple[Union[ColorScale, Collection[Color], str], ColorScale], request.param)
INVALID_COLOR_SCALES = [
@@ -68,4 +66,4 @@ def valid_colorscale(
@pytest.fixture(scope="session", params=INVALID_COLOR_SCALES)
def invalid_colorscale(request: pytest.FixtureRequest) -> ColorScale | Collection[Color] | str:
- return request.param # type: ignore[no-any-return]
+ return cast(Union[ColorScale, Collection[Color], str], request.param)
diff --git a/tests/unit/obj/traces/test_bar.py b/tests/unit/obj/traces/test_bar.py
new file mode 100644
index 00000000..a9114c31
--- /dev/null
+++ b/tests/unit/obj/traces/test_bar.py
@@ -0,0 +1,66 @@
+from __future__ import annotations
+
+import pytest
+
+from ridgeplot._color.interpolation import InterpolationContext
+from ridgeplot._obj.traces.bar import BarTrace
+from ridgeplot._obj.traces.base import ColoringContext
+
+
+@pytest.fixture
+def bar_trace() -> BarTrace:
+ return BarTrace(
+ trace=[(0, 0), (1, 1), (2, 0)],
+ label="Trace 1",
+ solid_color="red",
+ zorder=1,
+ y_base=0,
+ line_color="black",
+ line_width=0.5,
+ )
+
+
+@pytest.fixture
+def interpolation_ctx() -> InterpolationContext:
+ return InterpolationContext(
+ densities=[
+ [[(0, 0), (1, 1), (2, 0)]],
+ [[(1, 0), (2, 1), (3, 0)]],
+ ],
+ n_rows=2,
+ n_traces=2,
+ x_min=0,
+ x_max=3,
+ )
+
+
+class TestBarTrace:
+ def test_coloring_kwargs_fillgradient(
+ self, bar_trace: BarTrace, interpolation_ctx: InterpolationContext
+ ) -> None:
+ coloring_ctx = ColoringContext(
+ colorscale=[(0.0, "red"), (1.0, "blue")],
+ colormode="fillgradient",
+ opacity=None,
+ interpolation_ctx=interpolation_ctx,
+ )
+ color_kwargs = bar_trace._get_coloring_kwargs(ctx=coloring_ctx) # pyright: ignore[reportPrivateUsage]
+ assert color_kwargs == {
+ "marker_line_color": "black",
+ "marker_color": ["rgb(255, 0, 0)", "rgb(170.0, 0.0, 85.0)", "rgb(85.0, 0.0, 170.0)"],
+ }
+
+ def test_coloring_kwargs_fillcolor(
+ self, bar_trace: BarTrace, interpolation_ctx: InterpolationContext
+ ) -> None:
+ coloring_ctx = ColoringContext(
+ colorscale=[(0.0, "red"), (1.0, "blue")],
+ colormode="trace-index",
+ opacity=None,
+ interpolation_ctx=interpolation_ctx,
+ )
+ color_kwargs = bar_trace._get_coloring_kwargs(ctx=coloring_ctx) # pyright: ignore[reportPrivateUsage]
+ assert color_kwargs == {
+ "marker_line_color": "black",
+ "marker_color": "red",
+ }
diff --git a/tests/unit/obj/traces/test_init.py b/tests/unit/obj/traces/test_init.py
new file mode 100644
index 00000000..83a4ba73
--- /dev/null
+++ b/tests/unit/obj/traces/test_init.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pytest
+
+from ridgeplot._obj.traces import AreaTrace, BarTrace, RidgeplotTrace, get_trace_cls
+
+if TYPE_CHECKING:
+ from ridgeplot._types import TraceType
+
+
+@pytest.mark.parametrize(
+ ("name", "cls"),
+ [
+ ("area", AreaTrace),
+ ("bar", BarTrace),
+ ],
+)
+def test_get_trace_cls(name: TraceType, cls: type[RidgeplotTrace]) -> None:
+ assert get_trace_cls(name) is cls
+
+
+def test_get_trace_cls_unknown() -> None:
+ with pytest.raises(
+ ValueError, match="Unknown trace type 'foo'. Available types: 'area', 'bar'."
+ ):
+ get_trace_cls("foo") # pyright: ignore[reportArgumentType]
diff --git a/tests/unit/test_datasets.py b/tests/unit/test_datasets.py
index 9a5be3e0..6ac2d499 100644
--- a/tests/unit/test_datasets.py
+++ b/tests/unit/test_datasets.py
@@ -3,7 +3,11 @@
import pandas as pd
import pytest
-from ridgeplot.datasets import _DATA_DIR, load_lincoln_weather, load_probly
+from ridgeplot.datasets import (
+ _DATA_DIR, # pyright: ignore[reportPrivateUsage]
+ load_lincoln_weather,
+ load_probly,
+)
def test_data_dir_contains_data_files() -> None:
@@ -29,7 +33,7 @@ def test_load_probly() -> None:
df_illinois = load_probly(version="illinois")
assert df_illinois.shape == (75, 17)
with pytest.raises(ValueError, match="Unknown version"):
- load_probly(version="nonexistent") # type: ignore[arg-type]
+ load_probly(version="nonexistent") # pyright: ignore[reportArgumentType]
def test_load_lincoln_weather() -> None:
diff --git a/tests/unit/test_figure_factory.py b/tests/unit/test_figure_factory.py
index 008ccf3f..43420550 100644
--- a/tests/unit/test_figure_factory.py
+++ b/tests/unit/test_figure_factory.py
@@ -11,7 +11,6 @@
class TestCreateRidgeplot:
-
@pytest.mark.parametrize(
"densities",
[
@@ -26,13 +25,14 @@ def test_densities_must_be_4d(self, densities: Densities) -> None:
with pytest.raises(ValueError, match="Expected a 4D array of densities"):
create_ridgeplot(
densities=densities,
- colorscale=..., # type: ignore[arg-type]
- opacity=..., # type: ignore[arg-type]
- colormode=..., # type: ignore[arg-type]
- trace_labels=..., # type: ignore[arg-type]
- line_color=..., # type: ignore[arg-type]
- line_width=..., # type: ignore[arg-type]
- spacing=..., # type: ignore[arg-type]
- show_yticklabels=..., # type: ignore[arg-type]
- xpad=..., # type: ignore[arg-type]
+ trace_types=..., # pyright: ignore[reportArgumentType]
+ colorscale=..., # pyright: ignore[reportArgumentType]
+ opacity=..., # pyright: ignore[reportArgumentType]
+ colormode=..., # pyright: ignore[reportArgumentType]
+ trace_labels=..., # pyright: ignore[reportArgumentType]
+ line_color=..., # pyright: ignore[reportArgumentType]
+ line_width=..., # pyright: ignore[reportArgumentType]
+ spacing=..., # pyright: ignore[reportArgumentType]
+ show_yticklabels=..., # pyright: ignore[reportArgumentType]
+ xpad=..., # pyright: ignore[reportArgumentType]
)
diff --git a/tests/unit/test_hist.py b/tests/unit/test_hist.py
new file mode 100644
index 00000000..aaf9029e
--- /dev/null
+++ b/tests/unit/test_hist.py
@@ -0,0 +1,92 @@
+from __future__ import annotations
+
+import numpy as np
+import pytest
+
+from ridgeplot._hist import (
+ bin_samples,
+ bin_trace_samples,
+)
+
+# Example data
+
+SAMPLES_IN = [1, 2, 2, 3, 4]
+NBINS = 4
+DENSITIES_OUT = [(1.0, 1.0), (1.75, 2.0), (2.5, 1.0), (3.25, 1.0)]
+X_OUT, Y_OUT = zip(*DENSITIES_OUT)
+
+WEIGHTS = [1, 1, 1, 1, 9]
+
+# ==============================================================
+# --- estimate_density_trace()
+# ==============================================================
+
+
+def test_bin_trace_samples_simple() -> None:
+ density_trace = bin_trace_samples(trace_samples=SAMPLES_IN, nbins=NBINS)
+ x, y = zip(*density_trace)
+ assert x == X_OUT
+ assert y == Y_OUT
+
+
+@pytest.mark.parametrize("nbins", [2, 5, 8, 11])
+def test_bin_trace_samples_nbins(nbins: int) -> None:
+ density_trace = bin_trace_samples(trace_samples=SAMPLES_IN, nbins=nbins)
+ assert len(density_trace) == nbins
+
+
+@pytest.mark.parametrize("non_finite_value", [np.inf, np.nan, float("inf"), float("nan")])
+def test_bin_trace_samples_fails_for_non_finite_values(non_finite_value: float) -> None:
+ err_msg = "The samples array should not contain any infs or NaNs."
+ with pytest.raises(ValueError, match=err_msg):
+ bin_trace_samples(trace_samples=[*SAMPLES_IN[:-1], non_finite_value], nbins=NBINS)
+
+
+def test_bin_trace_samples_weights() -> None:
+ density_trace = bin_trace_samples(
+ trace_samples=SAMPLES_IN,
+ nbins=NBINS,
+ weights=WEIGHTS,
+ )
+ x, y = zip(*density_trace)
+ assert x == X_OUT
+ assert np.argmax(y) == len(y) - 1
+
+
+def test_bin_trace_samples_weights_not_same_length() -> None:
+ with pytest.raises(
+ ValueError, match="The weights array should have the same length as the samples array"
+ ):
+ bin_trace_samples(trace_samples=SAMPLES_IN, nbins=NBINS, weights=[1, 1, 1])
+
+
+@pytest.mark.parametrize("non_finite_value", [np.inf, np.nan, float("inf"), float("nan")])
+def test_bin_trace_samples_weights_fails_for_non_finite_values(
+ non_finite_value: float,
+) -> None:
+ err_msg = "The weights array should not contain any infs or NaNs."
+ with pytest.raises(ValueError, match=err_msg):
+ bin_trace_samples(
+ trace_samples=SAMPLES_IN,
+ nbins=NBINS,
+ weights=[*WEIGHTS[:-1], non_finite_value],
+ )
+
+
+# ==============================================================
+# --- estimate_densities()
+# ==============================================================
+
+
+def test_bin_samples() -> None:
+ densities = bin_samples(
+ samples=[[SAMPLES_IN], [SAMPLES_IN]],
+ nbins=NBINS,
+ )
+ assert len(densities) == 2
+ for densities_row in densities:
+ assert len(densities_row) == 1
+ density_trace = next(iter(densities_row))
+ x, y = zip(*density_trace)
+ assert x == X_OUT
+ assert y == Y_OUT
diff --git a/tests/unit/test_init.py b/tests/unit/test_init.py
index aeb72412..a4693833 100644
--- a/tests/unit/test_init.py
+++ b/tests/unit/test_init.py
@@ -10,8 +10,9 @@ def test_packaged_installed() -> None:
# By definition, if a module has a __path__ attribute, it is a package.
assert hasattr(ridgeplot, "__path__")
- assert len(ridgeplot.__path__) == 1
- package_path = Path(ridgeplot.__path__[0])
+ pkg_path = list(ridgeplot.__path__)
+ assert len(pkg_path) == 1
+ package_path = Path(pkg_path[0])
assert package_path.exists()
assert package_path.is_dir()
assert package_path.name == "ridgeplot"
diff --git a/tests/unit/test_kde.py b/tests/unit/test_kde.py
index 4860ce9d..140ff897 100644
--- a/tests/unit/test_kde.py
+++ b/tests/unit/test_kde.py
@@ -8,7 +8,7 @@
from ridgeplot._kde import (
KDEPoints,
- _validate_densities,
+ _validate_densities, # pyright: ignore[reportPrivateUsage]
estimate_densities,
estimate_density_trace,
)
@@ -126,7 +126,8 @@ def test__validate_densities() -> None:
inputs."""
x = np.array([0, 1, 2, 3, 4, 5, 6])
y = np.array([0.1, 0.2, 0.3, 0.4, 0.3, 0.2, 0.1])
- _validate_densities(x=x, y=y, kernel="doesn't matter")
+ y_valid = _validate_densities(x=x, y=y, kernel="doesn't matter")
+ np.testing.assert_array_equal(y_valid, y)
@pytest.mark.parametrize(
diff --git a/tests/unit/test_missing.py b/tests/unit/test_missing.py
index 87966d28..ea90a1b6 100644
--- a/tests/unit/test_missing.py
+++ b/tests/unit/test_missing.py
@@ -30,7 +30,7 @@ def assert_all_are(*args: Any) -> None:
b = args[i + 1]
if a is not b:
raise AssertionError(
- f"{a!r} and {b!r} (i={i}) are not the same object (id: {id(a)} != {id(b)})"
+ f"{a!r} and {b!r} ({i=}) are not the same object ({id(a)=} != {id(b)=})"
)
@@ -43,7 +43,7 @@ def test_reloading() -> None:
import ridgeplot._missing as types_module
from ridgeplot._missing import MISSING
- missing1 = ridgeplot._missing.MISSING
+ missing1 = ridgeplot._missing.MISSING # pyright: ignore[reportAttributeAccessIssue]
missing2 = types_module.MISSING
missing3 = MISSING
@@ -54,17 +54,17 @@ def test_reloading() -> None:
missing1,
missing2,
missing3,
- ridgeplot._missing.MISSING,
+ ridgeplot._missing.MISSING, # pyright: ignore[reportAttributeAccessIssue]
types_module.MISSING,
MISSING,
)
reload(types_module)
- assert_all_are(
+ assert_all_are( # pragma: no cover
missing1,
missing2,
missing3,
- ridgeplot._missing.MISSING,
+ ridgeplot._missing.MISSING, # pyright: ignore[reportAttributeAccessIssue]
types_module.MISSING,
MISSING,
)
diff --git a/tests/unit/test_ridgeplot.py b/tests/unit/test_ridgeplot.py
index aa89c5d3..5f452c4c 100644
--- a/tests/unit/test_ridgeplot.py
+++ b/tests/unit/test_ridgeplot.py
@@ -46,6 +46,11 @@ def test_shallow_samples() -> None:
) # fmt: skip
+# ==============================================================
+# --- param: labels
+# ==============================================================
+
+
def test_shallow_labels() -> None:
shallow_labels = ["trace 1", "trace 2"]
assert (
@@ -61,13 +66,42 @@ def test_y_labels_dedup() -> None:
) # fmt: skip
+# ==============================================================
+# --- param: trace_type
+# ==============================================================
+
+
+def test_shallow_trace_type() -> None:
+ assert (
+ ridgeplot(samples=[[1, 2, 3], [1, 2, 3]], trace_type="bar") ==
+ ridgeplot(samples=[[1, 2, 3], [1, 2, 3]], trace_type=["bar", "bar"]) ==
+ ridgeplot(samples=[[1, 2, 3], [1, 2, 3]], trace_type=[["bar"], ["bar"]])
+ ) # fmt: skip
+
+
+def test_unknown_trace_type() -> None:
+ with pytest.raises(TypeError, match="Invalid trace_type: foo"):
+ ridgeplot(samples=[[1, 2, 3], [1, 2, 3]], trace_type="foo") # pyright: ignore[reportArgumentType]
+
+
+# ==============================================================
+# --- param: nbins
+# ==============================================================
+
+
+def test_nbins() -> None:
+ fig = ridgeplot(samples=[[[1, 2, 3], [4, 5, 6]]], nbins=3)
+ assert len(fig.data) == 2
+ assert fig.data[0]._plotly_name == "bar"
+
+
# ==============================================================
# --- param: colorscale
# ==============================================================
def test_colorscale_coercion(
- valid_colorscale: tuple[ColorScale | Collection[Color] | str, ColorScale]
+ valid_colorscale: tuple[ColorScale | Collection[Color] | str, ColorScale],
) -> None:
colorscale, coerced = valid_colorscale
assert ridgeplot(samples=[[[1, 2, 3], [4, 5, 6]]], colorscale=colorscale) == ridgeplot(
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 7ec1eb7d..6d519fdb 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -9,7 +9,6 @@
from ridgeplot._utils import get_xy_extrema, normalise_min_max
if TYPE_CHECKING:
-
from ridgeplot._types import Densities, DensitiesRow
_X = TypeVar("_X")
@@ -36,7 +35,7 @@ def test_raise_for_non_2d_array(self) -> None:
# valid 2D trace
[[(0, 0), (1, 1), (2, 2)]],
# invalid 3D trace
- [[(3, 3, 3), (4, 4, 4)]], # type: ignore[list-item]
+ [[(3, 3, 3), (4, 4, 4)]], # pyright: ignore[reportArgumentType]
]
)
diff --git a/tox.ini b/tox.ini
index 76f11217..25972364 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,7 +1,7 @@
[tox]
labels =
- static = pre-commit-all, mypy-safe
- static-quick = pre-commit-quick, mypy-incremental
+ static = pre-commit-all, typing
+ static-quick = pre-commit-quick, typing
tests = tests-unit, tests-e2e, tests-cicd_utils
upgrade-requirements = pre-commit-autoupgrade
isolated_build = true
@@ -64,19 +64,15 @@ skip_install = true
deps = pre-commit
commands =
all: pre-commit run --all-files --show-diff-on-failure {posargs:}
- quick: pre-commit run black-jupyter --all-files
+ quick: pre-commit run ruff-format --all-files
pre-commit run ruff --all-files
autoupgrade: pre-commit autoupdate {posargs:}
-[testenv:mypy-{safe,incremental}]
-description = run type checks with mypy
-deps = -r requirements/mypy.txt
-setenv =
- {[testenv]setenv}
- _MYPY_DFLT_ARGS=--config-file=mypy.ini --strict --enable-incomplete-feature=NewGenericSyntax
+[testenv:typing]
+description = run type checks
+deps = -r requirements/typing.txt
commands =
- safe: mypy {env:_MYPY_DFLT_ARGS} --no-incremental --cache-dir=/dev/null {posargs:}
- incremental: mypy {env:_MYPY_DFLT_ARGS} --incremental {posargs:}
+ pyright --skipunannotated
[testenv:docs-{live,static}]
description = generate Sphinx (live/static) HTML documentation