diff --git a/AGENTS.md b/AGENTS.md index 03c18edb..37402fde 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -160,6 +160,26 @@ async def test_something(self, sm_runner): Do **not** manually add async no-op listeners or duplicate test classes — prefer `sm_runner`. +### TDD and coverage requirements + +Follow a **test-driven development** approach: tests are not an afterthought — they are a +first-class requirement that must be part of every implementation plan. + +- **Planning phase:** every plan must include test tasks as explicit steps, not a final + "add tests" bullet. Identify what needs to be tested (new branches, edge cases, error + paths) while designing the implementation. +- **100% branch coverage is mandatory.** The pre-commit hook enforces `--cov-fail-under=100` + with branch coverage enabled. Code that drops coverage will not pass CI. +- **Verify coverage before committing:** after writing tests, run coverage on the affected + modules and check for missing lines/branches: + ```bash + timeout 120 uv run pytest tests/.py --cov=statemachine. --cov-report=term-missing --cov-branch + ``` +- **Use pytest fixtures** (`tmp_path`, `monkeypatch`, etc.) — never hardcode paths or + use mutable global state when a fixture exists. +- **Unreachable defensive branches** (e.g., `if` guards that can never be True given the + type system) may be marked with `pragma: no cover`, but prefer writing a test first. + ## Linting and formatting ```bash diff --git a/docs/conf.py b/docs/conf.py index af845c4f..18846738 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -51,6 +51,7 @@ "sphinx.ext.autosectionlabel", "sphinx_gallery.gen_gallery", "sphinx_copybutton", + "statemachine.contrib.diagram.sphinx_ext", ] autosectionlabel_prefix_document = True diff --git a/docs/diagram.md b/docs/diagram.md index 9caeb39e..5707d794 100644 --- a/docs/diagram.md +++ b/docs/diagram.md @@ -6,6 +6,10 @@ You can generate visual diagrams from any {class}`~statemachine.statemachine.StateChart` — useful for documentation, debugging, or sharing your machine's structure with teammates. +```{statemachine-diagram} tests.examples.order_control_machine.OrderControl +:target: +``` + ## Installation Diagram generation requires [pydot](https://github.com/pydot/pydot) and @@ -26,77 +30,72 @@ For other systems, see the [Graphviz downloads page](https://graphviz.org/downlo ## Generating diagrams -Use `DotGraphMachine` to create a diagram from a class or an instance: +Every state machine instance exposes a `_graph()` method that returns a +[pydot.Dot](https://github.com/pydot/pydot) graph object: ```py ->>> from statemachine.contrib.diagram import DotGraphMachine - >>> from tests.examples.order_control_machine import OrderControl ->>> graph = DotGraphMachine(OrderControl) # also accepts instances +>>> sm = OrderControl() ->>> dot = graph() - ->>> dot.to_string() # doctest: +ELLIPSIS -'digraph OrderControl {... +>>> sm._graph() # doctest: +ELLIPSIS +>> dot.write_png("docs/images/order_control_machine_initial.png") +The diagram automatically highlights the current state of the instance. +Send events to advance the machine and see the active state change: -``` +``` py +>>> # This example will only run on automated tests if dot is present +>>> getfixture("requires_dot_installed") -![OrderControl](images/order_control_machine_initial.png) +>>> from tests.examples.order_control_machine import OrderControl -For higher resolution, set the DPI before exporting: +>>> sm = OrderControl() -```py ->>> dot.set_dpi(300) +>>> sm.receive_payment(10) +[10] ->>> dot.write_png("docs/images/order_control_machine_initial_300dpi.png") +>>> sm._graph().write_png("docs/images/order_control_machine_processing.png") ``` -![OrderControl (300 DPI)](images/order_control_machine_initial_300dpi.png) - -### Highlighting the current state +![OrderControl after receiving payment](images/order_control_machine_processing.png) -When you pass a machine **instance** (not a class), the diagram highlights -the current state: -``` py ->>> # This example will only run on automated tests if dot is present ->>> getfixture("requires_dot_installed") +### Exporting to a file ->>> from statemachine.contrib.diagram import DotGraphMachine +The `pydot.Dot` object supports writing to many formats — use +`write_png()`, `write_svg()`, `write_pdf()`, etc.: +```py >>> from tests.examples.order_control_machine import OrderControl ->>> machine = OrderControl() - ->>> graph = DotGraphMachine(machine) # also accepts instances +>>> sm = OrderControl() ->>> machine.receive_payment(10) -[10] - ->>> graph().write_png("docs/images/order_control_machine_processing.png") +>>> sm._graph().write_png("docs/images/order_control_machine_initial.png") ``` -![OrderControl](images/order_control_machine_processing.png) +![OrderControl](images/order_control_machine_initial.png) -```{tip} -Every state machine instance exposes a `_graph()` shortcut that returns -the `pydot.Dot` object directly. -``` +For higher resolution PNGs, set the DPI before exporting: ```py ->>> machine._graph() # doctest: +ELLIPSIS ->> sm._graph().set_dpi(300) + +>>> sm._graph().write_png("docs/images/order_control_machine_initial_300dpi.png") + +``` +```{note} +Supported formats include `dia`, `dot`, `fig`, `gif`, `jpg`, `pdf`, +`png`, `ps`, `svg`, and many others. See +[Graphviz output formats](https://graphviz.org/docs/outputs/) for the +complete list. ``` @@ -114,11 +113,124 @@ The output format is inferred from the file extension: python -m statemachine.contrib.diagram tests.examples.traffic_light_machine.TrafficLightMachine diagram.png ``` + +## Sphinx directive + +If you use [Sphinx](https://www.sphinx-doc.org/) to build your documentation, the +`statemachine-diagram` directive renders diagrams inline — no need to generate +image files manually. + +### Setup + +Add the extension to your `conf.py`: + +```python +extensions = [ + ... + "statemachine.contrib.diagram.sphinx_ext", +] +``` + +### Basic usage + +Reference any importable {class}`~statemachine.statemachine.StateChart` class by +its fully qualified path: + +````markdown +```{statemachine-diagram} myproject.machines.OrderControl +``` +```` + +```{statemachine-diagram} tests.examples.order_control_machine.OrderControl +:alt: OrderControl state machine +:align: center +``` + +### Highlighting a specific state + +Pass `:events:` to instantiate the machine and send events before rendering. +This highlights the current state after processing: + +````markdown +```{statemachine-diagram} myproject.machines.TrafficLight +:events: cycle +:caption: Traffic light after one cycle +``` +```` + +```{statemachine-diagram} tests.examples.traffic_light_machine.TrafficLightMachine +:events: cycle +:caption: Traffic light after one cycle +:align: center +``` + +### Enabling zoom + +For complex diagrams, add `:target:` (without a value) to make the diagram +clickable — it opens the full SVG in a new browser tab where users can +zoom and pan freely: + +````markdown +```{statemachine-diagram} myproject.machines.OrderControl +:target: +``` +```` + +```{statemachine-diagram} tests.examples.order_control_machine.OrderControl +:caption: Click to open full-size SVG +:target: +:align: center +``` + +### Directive options + +The directive supports the same layout options as the standard `image` and +`figure` directives, plus state-machine-specific ones. + +**State-machine options:** + +`:events:` *(comma-separated string)* +: Events to send in sequence. When present, the machine is instantiated and + each event is sent before rendering. + +**Image/figure options:** + +`:caption:` *(string)* +: Caption text; wraps the image in a `figure` node. + +`:alt:` *(string)* +: Alt text for the image. Defaults to the class name. + +`:width:` *(CSS length, e.g. `400px`, `80%`)* +: Explicit width for the diagram. + +`:height:` *(CSS length)* +: Explicit height for the diagram. + +`:scale:` *(integer percentage, e.g. `50%`)* +: Uniform scaling relative to the intrinsic size. + +`:align:` *(left | center | right)* +: Image alignment. Defaults to `center`. + +`:target:` *(URL or empty)* +: Makes the diagram clickable. When set without a value, the raw SVG is + saved as a file and linked so users can open it in a new tab for + full-resolution zooming — useful for large or complex diagrams. + +`:class:` *(space-separated strings)* +: Extra CSS classes for the wrapper element. + +`:figclass:` *(space-separated strings)* +: Extra CSS classes for the `figure` element (only when `:caption:` is set). + +`:name:` *(string)* +: Reference target name for cross-referencing with `{ref}`. + ```{note} -Supported formats include `dia`, `dot`, `fig`, `gif`, `jpg`, `pdf`, -`png`, `ps`, `svg`, and many others. See -[Graphviz output formats](https://graphviz.org/docs/outputs/) for the -complete list. +The directive imports the state machine class at Sphinx parse time. Machines +defined inline in doctest blocks cannot be referenced — use the +`_graph()` method for those cases. ``` @@ -139,4 +251,294 @@ using the [QuickChart](https://quickchart.io/) online service: .. autofunction:: statemachine.contrib.diagram.quickchart_write_svg ``` -![OrderControl](images/oc_machine_processing.svg) + +## Customizing the output + +The `DotGraphMachine` class gives you control over the diagram's visual +properties. Subclass it and override the class attributes to customize +fonts, colors, and layout: + +```py +>>> from statemachine.contrib.diagram import DotGraphMachine + +>>> from tests.examples.order_control_machine import OrderControl + +``` + +Available attributes: + +| Attribute | Default | Description | +|-----------|---------|-------------| +| `graph_rankdir` | `"LR"` | Graph direction (`"LR"` left-to-right, `"TB"` top-to-bottom) | +| `font_name` | `"Helvetica"` | Font face for labels | +| `state_font_size` | `"10"` | State label font size | +| `state_active_penwidth` | `2` | Border width of the active state | +| `state_active_fillcolor` | `"turquoise"` | Fill color of the active state | +| `transition_font_size` | `"9"` | Transition label font size | + +For example, to generate a top-to-bottom diagram with a custom active +state color: + +```py +>>> class CustomDiagram(DotGraphMachine): +... graph_rankdir = "TB" +... state_active_fillcolor = "lightyellow" + +>>> sm = OrderControl() + +>>> sm.receive_payment(10) +[10] + +>>> graph = CustomDiagram(sm) + +>>> dot = graph() + +>>> dot.to_string() # doctest: +ELLIPSIS +'digraph OrderControl {... + +``` + +`DotGraphMachine` also works with **classes** (not just instances) to +generate diagrams without an active state: + +```py +>>> dot = DotGraphMachine(OrderControl)() + +>>> dot.to_string() # doctest: +ELLIPSIS +'digraph OrderControl {... + +``` + + +## Visual showcase + +This section shows how each state machine feature is rendered in diagrams. +Each example includes the class definition, the **class** diagram (no +active state), and **instance** diagrams (with the current state +highlighted after sending events). + + +### Simple states + +A minimal state machine with three atomic states and linear transitions. + +```{literalinclude} ../tests/machines/showcase_simple.py +:pyobject: SimpleSC +:language: python +``` + +```{statemachine-diagram} tests.machines.showcase_simple.SimpleSC +:caption: Class +``` + +```{statemachine-diagram} tests.machines.showcase_simple.SimpleSC +:events: +:caption: Initial +``` + +```{statemachine-diagram} tests.machines.showcase_simple.SimpleSC +:events: start +:caption: Running +``` + +```{statemachine-diagram} tests.machines.showcase_simple.SimpleSC +:events: start, finish +:caption: Done (final) +``` + + +### Entry and exit actions + +States can declare `entry` / `exit` callbacks, shown in the state label. + +```{literalinclude} ../tests/machines/showcase_actions.py +:pyobject: ActionsSC +:language: python +``` + +```{statemachine-diagram} tests.machines.showcase_actions.ActionsSC +:caption: Class +``` + +```{statemachine-diagram} tests.machines.showcase_actions.ActionsSC +:events: power_on +:caption: Active: On +``` + + +### Guard conditions + +Transitions can have `cond` guards, shown in brackets on the edge label. + +```{literalinclude} ../tests/machines/showcase_guards.py +:pyobject: GuardSC +:language: python +``` + +```{statemachine-diagram} tests.machines.showcase_guards.GuardSC +:caption: Class +``` + +```{statemachine-diagram} tests.machines.showcase_guards.GuardSC +:events: +:caption: Active: Pending +``` + + +### Self-transitions + +A transition from a state back to itself. + +```{literalinclude} ../tests/machines/showcase_self_transition.py +:pyobject: SelfTransitionSC +:language: python +``` + +```{statemachine-diagram} tests.machines.showcase_self_transition.SelfTransitionSC +:caption: Class +``` + +```{statemachine-diagram} tests.machines.showcase_self_transition.SelfTransitionSC +:events: +:caption: Active: Counting +``` + + +### Internal transitions + +Internal transitions execute actions without exiting/entering the state. + +```{literalinclude} ../tests/machines/showcase_internal.py +:pyobject: InternalSC +:language: python +``` + +```{statemachine-diagram} tests.machines.showcase_internal.InternalSC +:caption: Class +``` + +```{statemachine-diagram} tests.machines.showcase_internal.InternalSC +:events: +:caption: Active: Monitoring +``` + + +### Compound states + +A compound state contains child states. Entering the compound activates +its initial child. + +```{literalinclude} ../tests/machines/showcase_compound.py +:pyobject: CompoundSC +:language: python +``` + +```{statemachine-diagram} tests.machines.showcase_compound.CompoundSC +:caption: Class +:target: +``` + +```{statemachine-diagram} tests.machines.showcase_compound.CompoundSC +:events: +:caption: Off +:target: +``` + +```{statemachine-diagram} tests.machines.showcase_compound.CompoundSC +:events: turn_on +:caption: Active/Idle +:target: +``` + +```{statemachine-diagram} tests.machines.showcase_compound.CompoundSC +:events: turn_on, begin +:caption: Active/Working +:target: +``` + + +### Parallel states + +A parallel state activates all its regions simultaneously. + +```{literalinclude} ../tests/machines/showcase_parallel.py +:pyobject: ParallelSC +:language: python +``` + +```{statemachine-diagram} tests.machines.showcase_parallel.ParallelSC +:caption: Class +:target: +``` + +```{statemachine-diagram} tests.machines.showcase_parallel.ParallelSC +:events: enter +:caption: Both active +:target: +``` + +```{statemachine-diagram} tests.machines.showcase_parallel.ParallelSC +:events: enter, go_l +:caption: Left done +:target: +``` + + +### History states (shallow) + +A history pseudo-state remembers the last active child of a compound state. + +```{literalinclude} ../tests/machines/showcase_history.py +:pyobject: HistorySC +:language: python +``` + +```{statemachine-diagram} tests.machines.showcase_history.HistorySC +:caption: Class +:target: +``` + +```{statemachine-diagram} tests.machines.showcase_history.HistorySC +:events: begin, advance +:caption: Step2 +:target: +``` + +```{statemachine-diagram} tests.machines.showcase_history.HistorySC +:events: begin, advance, pause +:caption: Paused +:target: +``` + +```{statemachine-diagram} tests.machines.showcase_history.HistorySC +:events: begin, advance, pause, resume +:caption: Resumed (→Step2) +:target: +``` + + +### Deep history + +Deep history remembers the exact leaf state across nested compounds. + +```{literalinclude} ../tests/machines/showcase_deep_history.py +:pyobject: DeepHistorySC +:language: python +``` + +```{statemachine-diagram} tests.machines.showcase_deep_history.DeepHistorySC +:caption: Class +:target: +``` + +```{statemachine-diagram} tests.machines.showcase_deep_history.DeepHistorySC +:events: dive, enter_inner, go +:caption: Inner/B +:target: +``` + +```{statemachine-diagram} tests.machines.showcase_deep_history.DeepHistorySC +:events: dive, enter_inner, go, leave, restore +:caption: Restored (→Inner/B) +:target: +``` diff --git a/docs/images/order_control_machine_initial.png b/docs/images/order_control_machine_initial.png index e843ddf0..9ba36061 100644 Binary files a/docs/images/order_control_machine_initial.png and b/docs/images/order_control_machine_initial.png differ diff --git a/docs/images/order_control_machine_initial_300dpi.png b/docs/images/order_control_machine_initial_300dpi.png index c4c3bcb3..9ba36061 100644 Binary files a/docs/images/order_control_machine_initial_300dpi.png and b/docs/images/order_control_machine_initial_300dpi.png differ diff --git a/docs/images/order_control_machine_processing.png b/docs/images/order_control_machine_processing.png index 747d5f78..5a66d93e 100644 Binary files a/docs/images/order_control_machine_processing.png and b/docs/images/order_control_machine_processing.png differ diff --git a/docs/images/readme_trafficlightmachine.png b/docs/images/readme_trafficlightmachine.png index 2defa820..add082de 100644 Binary files a/docs/images/readme_trafficlightmachine.png and b/docs/images/readme_trafficlightmachine.png differ diff --git a/docs/images/transition_compound_cancel.png b/docs/images/transition_compound_cancel.png deleted file mode 100644 index 86ca6278..00000000 Binary files a/docs/images/transition_compound_cancel.png and /dev/null differ diff --git a/docs/images/transition_from_any.png b/docs/images/transition_from_any.png deleted file mode 100644 index a5d57039..00000000 Binary files a/docs/images/transition_from_any.png and /dev/null differ diff --git a/docs/images/tutorial_coffeeorder.png b/docs/images/tutorial_coffeeorder.png deleted file mode 100644 index 52659d6d..00000000 Binary files a/docs/images/tutorial_coffeeorder.png and /dev/null differ diff --git a/docs/integrations.md b/docs/integrations.md index bec7434b..fd362cee 100644 --- a/docs/integrations.md +++ b/docs/integrations.md @@ -39,6 +39,11 @@ You can attach it to a model by inheriting from `MachineMixin` and setting `state_machine_name` to the fully qualified class name: ``` py +>>> from statemachine import registry +>>> registry.register(CampaignMachine) # register for lookup by qualname + +>>> registry._initialized = True # skip Django autodiscovery in doctest + >>> class Workflow(MachineMixin): ... state_machine_name = '__main__.CampaignMachine' ... state_machine_attr = 'sm' @@ -47,13 +52,6 @@ You can attach it to a model by inheriting from `MachineMixin` and setting ... ... workflow_step = 1 -``` - -When a `Workflow` instance is created, it automatically receives a `CampaignMachine` -instance at the `state_machine_attr` attribute. The state value is read from and -written to the `state_field_name` field: - -``` py >>> model = Workflow() >>> isinstance(model.sm, CampaignMachine) @@ -70,6 +68,7 @@ True With `bind_events_as_methods = True`, events become methods on the model itself: ``` py +>>> model = Workflow() >>> model.produce() >>> model.workflow_step 2 diff --git a/docs/releases/3.1.0.md b/docs/releases/3.1.0.md index 801170f4..69b194e7 100644 --- a/docs/releases/3.1.0.md +++ b/docs/releases/3.1.0.md @@ -4,6 +4,35 @@ ## What's new in 3.1.0 +### Sphinx directive for inline diagrams + +A new Sphinx extension renders state machine diagrams directly in your +documentation from an importable class path — no manual image generation +needed. + +Add `"statemachine.contrib.diagram.sphinx_ext"` to your `conf.py` +extensions, then use the directive in any MyST Markdown page: + +````markdown +```{statemachine-diagram} myproject.machines.OrderControl +:events: receive_payment +:caption: After payment +:target: +``` +```` + +The directive supports the same options as the standard `image`/`figure` +directives (`:width:`, `:height:`, `:scale:`, `:align:`, `:target:`, +`:class:`, `:name:`), plus `:events:` to instantiate the machine and send +events before rendering (highlighting the current state). + +Using `:target:` without a value makes the diagram clickable, opening the +full SVG in a new browser tab for zooming — useful for large statecharts. + +See {ref}`diagram:Sphinx directive` for full documentation. +[#589](https://github.com/fgmacedo/python-statemachine/pull/589). + + ### Bugfixes in 3.1.0 - Fixes silent misuse of `Event()` with multiple positional arguments. Passing more than one diff --git a/docs/transitions.md b/docs/transitions.md index d5322a70..aa1809b5 100644 --- a/docs/transitions.md +++ b/docs/transitions.md @@ -171,16 +171,14 @@ True Compare the diagrams — both model the same behavior, but the compound version makes the "cancellable" grouping explicit in the hierarchy: -```py ->>> getfixture("requires_dot_installed") ->>> OrderWorkflow()._graph().write_png("docs/images/transition_from_any.png") ->>> OrderWorkflowCompound()._graph().write_png("docs/images/transition_compound_cancel.png") - +```{statemachine-diagram} tests.machines.transition_from_any.OrderWorkflow +:caption: from_.any() ``` -| `from_.any()` | Compound | -|---|---| -| ![from_.any()](images/transition_from_any.png) | ![Compound](images/transition_compound_cancel.png) | +```{statemachine-diagram} tests.machines.transition_from_any.OrderWorkflowCompound +:caption: Compound +:target: +``` The compound approach scales better as you add more states — no need to remember to include each new state in a `from_()` list. diff --git a/docs/tutorial.md b/docs/tutorial.md index 3e150e71..e6b23d3b 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -345,46 +345,36 @@ factories, and the full list of listener callbacks. ## Generating diagrams -Visualize any state machine as a diagram. Install the `diagrams` extra -first: +Visualize any state machine as a diagram: +```{statemachine-diagram} tests.machines.tutorial_coffee_order.CoffeeOrder +:alt: CoffeeOrder diagram ``` -pip install python-statemachine[diagrams] -``` - -Then generate an image at runtime: - -```py ->>> getfixture("requires_dot_installed") - ->>> from statemachine import StateChart, State - ->>> class CoffeeOrder(StateChart): -... pending = State(initial=True) -... preparing = State() -... ready = State() -... picked_up = State(final=True) -... -... start = pending.to(preparing) -... finish = preparing.to(ready) -... pick_up = ready.to(picked_up) ->>> order = CoffeeOrder() ->>> order._graph().write_png("docs/images/tutorial_coffeeorder.png") +Generate diagrams programmatically with `_graph()`: +```python +order = CoffeeOrder() +order._graph().write_png("order.png") ``` -![CoffeeOrder](images/tutorial_coffeeorder.png) - Or from the command line: ```bash python -m statemachine.contrib.diagram my_module.CoffeeOrder order.png ``` +```{tip} +Diagram generation requires [Graphviz](https://graphviz.org/) (`dot` command) +and the `diagrams` extra: + + pip install python-statemachine[diagrams] +``` + ```{seealso} -See [](diagram.md) for Jupyter integration, SVG output, DPI settings, and -the `quickchart_write_svg` alternative that doesn't require Graphviz. +See [](diagram.md) for highlighting active states, Jupyter integration, +SVG output, DPI settings, Sphinx directive, and the `quickchart_write_svg` +alternative that doesn't require Graphviz. ``` diff --git a/pyproject.toml b/pyproject.toml index e04c9b31..50912c67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,14 +88,11 @@ markers = [ ] python_files = ["tests.py", "test_*.py", "*_tests.py"] xfail_strict = true -log_cli = true log_cli_level = "DEBUG" log_cli_format = "%(relativeCreated)6.0fms %(threadName)-18s %(name)-35s %(message)s" log_cli_date_format = "%H:%M:%S" asyncio_default_fixture_loop_scope = "module" -filterwarnings = [ - "ignore::pytest_benchmark.logger.PytestBenchmarkWarning", -] +filterwarnings = ["ignore::pytest_benchmark.logger.PytestBenchmarkWarning"] [tool.coverage.run] branch = true @@ -119,6 +116,7 @@ exclude_lines = [ "raise AssertionError", "raise NotImplementedError", "if TYPE_CHECKING", + 'if __name__ == "__main__"', ] [tool.coverage.html] @@ -133,7 +131,7 @@ disable_error_code = "annotation-unchecked" mypy_path = "$MYPY_CONFIG_FILE_DIR/tests/django_project" [[tool.mypy.overrides]] -module = ['django.*', 'pytest.*', 'pydot.*', 'sphinx_gallery.*'] +module = ['django.*', 'pytest.*', 'pydot.*', 'sphinx_gallery.*', 'docutils.*', 'sphinx.*'] ignore_missing_imports = true [tool.ruff] diff --git a/statemachine/contrib/diagram.py b/statemachine/contrib/diagram.py deleted file mode 100644 index 1ec59804..00000000 --- a/statemachine/contrib/diagram.py +++ /dev/null @@ -1,370 +0,0 @@ -import importlib -import sys -from urllib.parse import quote -from urllib.request import urlopen - -import pydot - -from ..statemachine import StateChart - - -class DotGraphMachine: - graph_rankdir = "LR" - """ - Direction of the graph. Defaults to "LR" (option "TB" for top bottom) - http://www.graphviz.org/doc/info/attrs.html#d:rankdir - """ - - font_name = "Arial" - """Graph font face name""" - - state_font_size = "10pt" - """State font size""" - - state_active_penwidth = 2 - """Active state external line width""" - - state_active_fillcolor = "turquoise" - - transition_font_size = "9pt" - """Transition font size""" - - def __init__(self, machine): - self.machine = machine - - def _get_graph(self, machine): - return pydot.Dot( - machine.name, - graph_type="digraph", - label=machine.name, - fontname=self.font_name, - fontsize=self.state_font_size, - rankdir=self.graph_rankdir, - compound="true", - ) - - def _get_subgraph(self, state): - style = ", solid" - if state.parent and state.parent.parallel: - style = ", dashed" - label = state.name - if state.parallel: - label = f"<{state.name} ☷>" - subgraph = pydot.Subgraph( - label=label, - graph_name=f"cluster_{state.id}", - style=f"rounded{style}", - cluster="true", - ) - return subgraph - - def _initial_node(self, state): - node = pydot.Node( - self._state_id(state), - label="", - shape="point", - style="filled", - fontsize="1pt", - fixedsize="true", - width=0.2, - height=0.2, - ) - node.set_fillcolor("black") # type: ignore[attr-defined] - return node - - def _initial_edge(self, initial_node, state): - extra_params = {} - if state.states: - extra_params["lhead"] = f"cluster_{state.id}" - return pydot.Edge( - initial_node.get_name(), - self._state_id(state), - label="", - color="blue", - fontname=self.font_name, - fontsize=self.transition_font_size, - **extra_params, - ) - - def _actions_getter(self): - if isinstance(self.machine, StateChart): - - def getter(grouper): # pyright: ignore[reportRedeclaration] - return self.machine._callbacks.str(grouper.key) - else: - - def getter(grouper): - all_names = set(dir(self.machine)) - return ", ".join( - str(c) for c in grouper if not c.is_convention or c.func in all_names - ) - - return getter - - def _state_actions(self, state): - getter = self._actions_getter() - - entry = str(getter(state.enter)) - exit_ = str(getter(state.exit)) - internal = ", ".join( - f"{transition.event} / {str(getter(transition.on))}" - for transition in state.transitions - if transition.internal - ) - - if entry: - entry = f"entry / {entry}" - if exit_: - exit_ = f"exit / {exit_}" - - actions = "\n".join(x for x in [entry, exit_, internal] if x) - - if actions: - actions = f"\n{actions}" - - return actions - - @staticmethod - def _state_id(state): - if state.states: - return f"{state.id}_anchor" - else: - return state.id - - def _history_node(self, state): - label = "H*" if state.type.is_deep else "H" - return pydot.Node( - self._state_id(state), - label=label, - shape="circle", - style="filled", - fillcolor="white", - fontname=self.font_name, - fontsize="8pt", - fixedsize="true", - width=0.3, - height=0.3, - ) - - def _state_as_node(self, state): - actions = self._state_actions(state) - - node = pydot.Node( - self._state_id(state), - label=f"{state.name}{actions}", - shape="rectangle", - style="rounded, filled", - fontname=self.font_name, - fontsize=self.state_font_size, - peripheries=2 if state.final else 1, - ) - if ( - isinstance(self.machine, StateChart) - and state.value in self.machine.configuration_values - ): - node.set_penwidth(self.state_active_penwidth) # type: ignore[attr-defined] - node.set_fillcolor(self.state_active_fillcolor) # type: ignore[attr-defined] - else: - node.set_fillcolor("white") # type: ignore[attr-defined] - return node - - def _transition_as_edges(self, transition): - targets = transition.targets if transition.targets else [None] - cond = ", ".join([str(c) for c in transition.cond]) - if cond: - cond = f"\n[{cond}]" - - edges = [] - for i, target in enumerate(targets): - extra_params = {} - has_substates = transition.source.states or (target and target.states) - if transition.source.states: - extra_params["ltail"] = f"cluster_{transition.source.id}" - if target and target.states: - extra_params["lhead"] = f"cluster_{target.id}" - - targetless = target is None - label = f"{transition.event}{cond}" if i == 0 else "" - dst = self._state_id(target) if not targetless else self._state_id(transition.source) - edges.append( - pydot.Edge( - self._state_id(transition.source), - dst, - label=label, - color="blue", - fontname=self.font_name, - fontsize=self.transition_font_size, - minlen=2 if has_substates else 1, - **extra_params, - ) - ) - return edges - - def get_graph(self): - graph = self._get_graph(self.machine) - self._graph_states(self.machine, graph) - return graph - - def _add_transitions(self, graph, state): - for transition in state.transitions: - if transition.internal: - continue - for edge in self._transition_as_edges(transition): - graph.add_edge(edge) - - def _graph_states(self, state, graph): - initial_node = self._initial_node(state) - initial_subgraph = pydot.Subgraph( - graph_name=f"{initial_node.get_name()}_initial", - label="", - peripheries=0, - margin=0, - ) - atomic_states_subgraph = pydot.Subgraph( - graph_name=f"cluster_{initial_node.get_name()}_atomic", - label="", - peripheries=0, - cluster="true", - ) - initial_subgraph.add_node(initial_node) - graph.add_subgraph(initial_subgraph) - graph.add_subgraph(atomic_states_subgraph) - - if state.states and not getattr(state, "parallel", False): - initial = next((s for s in state.states if s.initial), None) - if initial: # pragma: no branch - graph.add_edge(self._initial_edge(initial_node, initial)) - - for substate in state.states: - if substate.states: - subgraph = self._get_subgraph(substate) - self._graph_states(substate, subgraph) - graph.add_subgraph(subgraph) - else: - atomic_states_subgraph.add_node(self._state_as_node(substate)) - self._add_transitions(graph, substate) - - for history_state in getattr(state, "history", []): - atomic_states_subgraph.add_node(self._history_node(history_state)) - self._add_transitions(graph, history_state) - - def __call__(self): - return self.get_graph() - - -def quickchart_write_svg(sm: StateChart, path: str): - """ - If the default dependency of GraphViz installed locally doesn't work for you. As an option, - you can generate the image online from the output of the `dot` language, - using one of the many services available. - - To get the **dot** representation of your state machine is as easy as follows: - - >>> from tests.examples.order_control_machine import OrderControl - >>> sm = OrderControl() - >>> print(sm._graph().to_string()) - digraph OrderControl { - compound=true; - fontname=Arial; - fontsize="10pt"; - label=OrderControl; - rankdir=LR; - ... - - To give you an example, we included this method that will serialize the dot, request the graph - to https://quickchart.io, and persist the result locally as an ``.svg`` file. - - - .. warning:: - Quickchart is an external graph service that supports many formats to generate diagrams. - - By using this method, you should trust http://quickchart.io. - - Please read https://quickchart.io/documentation/faq/ for more information. - - >>> quickchart_write_svg(sm, "docs/images/oc_machine_processing.svg") # doctest: +SKIP - - """ - dot_representation = sm._graph().to_string() - - url = f"https://quickchart.io/graphviz?graph={quote(dot_representation)}" - - response = urlopen(url) - data = response.read() - - with open(path, "wb") as f: - f.write(data) - - -def _find_sm_class(module): - """Find the first StateChart subclass defined in a module.""" - import inspect - - for _name, obj in inspect.getmembers(module, inspect.isclass): - if ( - issubclass(obj, StateChart) - and obj is not StateChart - and obj.__module__ == module.__name__ - ): - return obj - return None - - -def import_sm(qualname): - module_name, class_name = qualname.rsplit(".", 1) - module = importlib.import_module(module_name) - smclass = getattr(module, class_name, None) - if smclass is not None and isinstance(smclass, type) and issubclass(smclass, StateChart): - return smclass - - # qualname may be a module path without a class name — try importing - # the whole path as a module and find the first StateChart subclass. - try: - module = importlib.import_module(qualname) - except ImportError as err: - raise ValueError(f"{class_name} is not a subclass of StateMachine") from err - - smclass = _find_sm_class(module) - if smclass is None: - raise ValueError(f"No StateMachine subclass found in module {qualname!r}") - - return smclass - - -def write_image(qualname, out): - """ - Given a `qualname`, that is the fully qualified dotted path to a StateMachine - classes, imports the class and generates a dot graph using the `pydot` lib. - Writes the graph representation to the filename 'out' that will - open/create and truncate such file and write on it a representation of - the graph defined by the statemachine, in the format specified by - the extension contained in the out path (out.ext). - """ - smclass = import_sm(qualname) - - graph = DotGraphMachine(smclass).get_graph() - out_extension = out.rsplit(".", 1)[1] - graph.write(out, format=out_extension) - - -def main(argv=None): - import argparse - - parser = argparse.ArgumentParser( - usage="%(prog)s [OPTION] ", - description="Generate diagrams for StateMachine classes.", - ) - parser.add_argument( - "class_path", help="A fully-qualified dotted path to the StateMachine class." - ) - parser.add_argument( - "out", - help="File to generate the image using extension as the output format.", - ) - - args = parser.parse_args(argv) - write_image(qualname=args.class_path, out=args.out) - - -if __name__ == "__main__": # pragma: no cover - sys.exit(main()) diff --git a/statemachine/contrib/diagram/__init__.py b/statemachine/contrib/diagram/__init__.py new file mode 100644 index 00000000..51e7bc78 --- /dev/null +++ b/statemachine/contrib/diagram/__init__.py @@ -0,0 +1,170 @@ +import importlib +from urllib.parse import quote +from urllib.request import urlopen + +from .extract import extract +from .renderers.dot import DotRenderer +from .renderers.dot import DotRendererConfig + + +class DotGraphMachine: + """Backwards-compatible facade that uses the extract + render pipeline. + + Maintains the same public API and class-level customization attributes + as the original monolithic DotGraphMachine. + """ + + graph_rankdir = "LR" + """ + Direction of the graph. Defaults to "LR" (option "TB" for top bottom) + http://www.graphviz.org/doc/info/attrs.html#d:rankdir + """ + + font_name = "Helvetica" + """Graph font face name""" + + state_font_size = "10" + """State font size""" + + state_active_penwidth = 2 + """Active state external line width""" + + state_active_fillcolor = "turquoise" + + transition_font_size = "9" + """Transition font size""" + + def __init__(self, machine): + self.machine = machine + + def _build_config(self) -> DotRendererConfig: + return DotRendererConfig( + graph_rankdir=self.graph_rankdir, + font_name=self.font_name, + state_font_size=self.state_font_size, + state_active_penwidth=self.state_active_penwidth, + state_active_fillcolor=self.state_active_fillcolor, + transition_font_size=self.transition_font_size, + ) + + def get_graph(self): + ir = extract(self.machine) + renderer = DotRenderer(config=self._build_config()) + return renderer.render(ir) + + def __call__(self): + return self.get_graph() + + +def quickchart_write_svg(sm, path: str): + """ + If the default dependency of GraphViz installed locally doesn't work for you. As an option, + you can generate the image online from the output of the `dot` language, + using one of the many services available. + + To get the **dot** representation of your state machine is as easy as follows: + + >>> from tests.examples.order_control_machine import OrderControl + >>> sm = OrderControl() + >>> print(sm._graph().to_string()) # doctest: +ELLIPSIS + digraph OrderControl { + ... + } + + To give you an example, we included this method that will serialize the dot, request the graph + to https://quickchart.io, and persist the result locally as an ``.svg`` file. + + + .. warning:: + Quickchart is an external graph service that supports many formats to generate diagrams. + + By using this method, you should trust http://quickchart.io. + + Please read https://quickchart.io/documentation/faq/ for more information. + + >>> quickchart_write_svg(sm, "docs/images/oc_machine_processing.svg") # doctest: +SKIP + + """ + dot_representation = sm._graph().to_string() + + url = f"https://quickchart.io/graphviz?graph={quote(dot_representation)}" + + response = urlopen(url) + data = response.read() + + with open(path, "wb") as f: + f.write(data) + + +def _find_sm_class(module): + """Find the first StateChart subclass defined in a module.""" + import inspect + + from statemachine.statemachine import StateChart + + for _name, obj in inspect.getmembers(module, inspect.isclass): + if ( + issubclass(obj, StateChart) + and obj is not StateChart + and obj.__module__ == module.__name__ + ): + return obj + return None + + +def import_sm(qualname): + from statemachine.statemachine import StateChart + + module_name, class_name = qualname.rsplit(".", 1) + module = importlib.import_module(module_name) + smclass = getattr(module, class_name, None) + if smclass is not None and isinstance(smclass, type) and issubclass(smclass, StateChart): + return smclass + + # qualname may be a module path without a class name — try importing + # the whole path as a module and find the first StateChart subclass. + try: + module = importlib.import_module(qualname) + except ImportError as err: + raise ValueError(f"{class_name} is not a subclass of StateMachine") from err + + smclass = _find_sm_class(module) + if smclass is None: + raise ValueError(f"No StateMachine subclass found in module {qualname!r}") + + return smclass + + +def write_image(qualname, out): + """ + Given a `qualname`, that is the fully qualified dotted path to a StateMachine + classes, imports the class and generates a dot graph using the `pydot` lib. + Writes the graph representation to the filename 'out' that will + open/create and truncate such file and write on it a representation of + the graph defined by the statemachine, in the format specified by + the extension contained in the out path (out.ext). + """ + smclass = import_sm(qualname) + + graph = DotGraphMachine(smclass).get_graph() + out_extension = out.rsplit(".", 1)[1] + graph.write(out, format=out_extension) + + +def main(argv=None): + import argparse + + parser = argparse.ArgumentParser( + usage="%(prog)s [OPTION] ", + description="Generate diagrams for StateMachine classes.", + ) + parser.add_argument( + "class_path", help="A fully-qualified dotted path to the StateMachine class." + ) + parser.add_argument( + "out", + help="File to generate the image using extension as the output format.", + ) + + args = parser.parse_args(argv) + write_image(qualname=args.class_path, out=args.out) diff --git a/statemachine/contrib/diagram/__main__.py b/statemachine/contrib/diagram/__main__.py new file mode 100644 index 00000000..daf509ab --- /dev/null +++ b/statemachine/contrib/diagram/__main__.py @@ -0,0 +1,6 @@ +import sys + +from . import main + +if __name__ == "__main__": + sys.exit(main()) diff --git a/statemachine/contrib/diagram/extract.py b/statemachine/contrib/diagram/extract.py new file mode 100644 index 00000000..37f4fc88 --- /dev/null +++ b/statemachine/contrib/diagram/extract.py @@ -0,0 +1,258 @@ +from typing import TYPE_CHECKING +from typing import List +from typing import Set +from typing import Union + +from .model import ActionType +from .model import DiagramAction +from .model import DiagramGraph +from .model import DiagramState +from .model import DiagramTransition +from .model import StateType + +if TYPE_CHECKING: + from statemachine.state import State + from statemachine.statemachine import StateChart + + # A StateChart class or instance — both expose the same structural metadata. + MachineRef = Union["StateChart", "type[StateChart]"] + + +def _determine_state_type(state: "State") -> StateType: + from statemachine.state import HistoryState + from statemachine.state import HistoryType + + if isinstance(state, HistoryState): + if state.type == HistoryType.DEEP: + return StateType.HISTORY_DEEP + return StateType.HISTORY_SHALLOW + if getattr(state, "parallel", False): + return StateType.PARALLEL + if state.final: + return StateType.FINAL + return StateType.REGULAR + + +def _actions_getter(machine: "MachineRef"): + from statemachine.statemachine import StateChart + + if isinstance(machine, StateChart): + + def getter(grouper): # pyright: ignore[reportRedeclaration] + return machine._callbacks.str(grouper.key) + else: + + def getter(grouper): + all_names = set(dir(machine)) + return ", ".join(str(c) for c in grouper if not c.is_convention or c.func in all_names) + + return getter + + +def _extract_state_actions(state: "State", getter) -> List[DiagramAction]: + actions: List[DiagramAction] = [] + + entry = str(getter(state.enter)) + exit_ = str(getter(state.exit)) + + if entry: + actions.append(DiagramAction(type=ActionType.ENTRY, body=entry)) + if exit_: + actions.append(DiagramAction(type=ActionType.EXIT, body=exit_)) + + for transition in state.transitions: + if transition.internal: + on_text = str(getter(transition.on)) + if on_text: + actions.append( + DiagramAction(type=ActionType.INTERNAL, body=f"{transition.event} / {on_text}") + ) + + return actions + + +def _extract_state( + state: "State", + machine: "MachineRef", + getter, + active_values: set, +) -> DiagramState: + state_type = _determine_state_type(state) + is_active = state.value in active_values + is_parallel_area = bool(state.parent and getattr(state.parent, "parallel", False)) + + children: List[DiagramState] = [] + for substate in state.states: + children.append(_extract_state(substate, machine, getter, active_values)) + for history_state in getattr(state, "history", []): + children.append(_extract_state(history_state, machine, getter, active_values)) + + actions = _extract_state_actions(state, getter) + + return DiagramState( + id=state.id, + name=state.name, + type=state_type, + actions=actions, + children=children, + is_active=is_active, + is_parallel_area=is_parallel_area, + is_initial=getattr(state, "initial", False), + ) + + +def _extract_transitions_from_state(state: "State") -> List[DiagramTransition]: + """Extract transitions from a single state (non-recursive).""" + result: List[DiagramTransition] = [] + for transition in state.transitions: + targets = transition.targets if transition.targets else [] + target_ids = [t.id for t in targets] + + cond_strs = [str(c) for c in transition.cond] + + result.append( + DiagramTransition( + source=transition.source.id, + targets=target_ids, + event=transition.event, + guards=cond_strs, + is_internal=transition.internal, + ) + ) + return result + + +def _extract_all_transitions(states) -> List[DiagramTransition]: + """Recursively extract transitions from all states.""" + result: List[DiagramTransition] = [] + for state in states: + result.extend(_extract_transitions_from_state(state)) + if state.states: + result.extend(_extract_all_transitions(state.states)) + for history_state in getattr(state, "history", []): + result.extend(_extract_transitions_from_state(history_state)) + if history_state.states: # pragma: no cover + result.extend(_extract_all_transitions(history_state.states)) + return result + + +def _collect_compound_ids(states: List[DiagramState]) -> Set[str]: + """Collect IDs of states that have children (compound/parallel).""" + result: Set[str] = set() + for state in states: + if state.children: + result.add(state.id) + result.update(_collect_compound_ids(state.children)) + return result + + +def _collect_bidirectional_compound_ids( + transitions: List[DiagramTransition], + compound_ids: Set[str], +) -> Set[str]: + """Find compound states that have both outgoing and incoming explicit edges.""" + outgoing: Set[str] = set() + incoming: Set[str] = set() + for t in transitions: + if t.is_internal: + continue + # Skip implicit initial transitions + if t.source in compound_ids and not t.event and t.targets: + continue + if t.source in compound_ids: + outgoing.add(t.source) + for target_id in t.targets: + if target_id in compound_ids: + incoming.add(target_id) + return outgoing & incoming + + +def _mark_initial_transitions( + transitions: List[DiagramTransition], + compound_ids: Set[str], +) -> None: + """Mark implicit initial transitions (compound state → child, no event).""" + for t in transitions: + if t.source in compound_ids and not t.event and t.targets and not t.is_internal: + t.is_initial = True + + +def _resolve_initial_states(states: List[DiagramState]) -> None: + """Ensure exactly one state per level has is_initial=True. + + Skips parallel areas and history states. Falls back to document order + (first non-history, non-parallel-area state) when no explicit initial exists. + Recurses into children. + + Parallel areas (children of a parallel state) have their is_initial flag + cleared: all regions are auto-activated, so no initial arrow is needed. + """ + # Clear is_initial on parallel areas — all children of a parallel state + # are simultaneously active; initial arrows would be misleading. + for s in states: + if s.is_parallel_area: + s.is_initial = False + + candidates = [ + s + for s in states + if s.type not in (StateType.HISTORY_SHALLOW, StateType.HISTORY_DEEP) + and not s.is_parallel_area + ] + + has_explicit_initial = any(s.is_initial for s in candidates) + if not has_explicit_initial and candidates: + candidates[0].is_initial = True + + for state in states: + if state.children: + _resolve_initial_states(state.children) + + +def extract(machine_or_class: "MachineRef") -> DiagramGraph: + """Extract a DiagramGraph IR from a state machine instance or class. + + Accepts either a class or an instance. The class is **never** instantiated + — all structural metadata (states, transitions, name) is available on the + class itself thanks to the metaclass. Active-state highlighting is only + produced when an *instance* is passed. + + Args: + machine_or_class: A StateMachine/StateChart instance or class. + + Returns: + A DiagramGraph representing the machine's structure. + """ + from statemachine.statemachine import StateChart + + if isinstance(machine_or_class, StateChart): + machine: "MachineRef" = machine_or_class + elif isinstance(machine_or_class, type) and issubclass(machine_or_class, StateChart): + machine = machine_or_class + else: + raise TypeError(f"Expected a StateChart instance or class, got {type(machine_or_class)}") + + getter = _actions_getter(machine) + + active_values: set = set() + if isinstance(machine, StateChart) and hasattr(machine, "configuration_values"): + active_values = set(machine.configuration_values) + + states: List[DiagramState] = [] + for state in machine.states: + states.append(_extract_state(state, machine, getter, active_values)) + + transitions = _extract_all_transitions(machine.states) + + compound_ids = _collect_compound_ids(states) + bidir_ids = _collect_bidirectional_compound_ids(transitions, compound_ids) + _mark_initial_transitions(transitions, compound_ids) + _resolve_initial_states(states) + + return DiagramGraph( + name=machine.name, + states=states, + transitions=transitions, + compound_state_ids=compound_ids, + bidirectional_compound_ids=bidir_ids, + ) diff --git a/statemachine/contrib/diagram/model.py b/statemachine/contrib/diagram/model.py new file mode 100644 index 00000000..3770bba1 --- /dev/null +++ b/statemachine/contrib/diagram/model.py @@ -0,0 +1,63 @@ +from dataclasses import dataclass +from dataclasses import field +from enum import Enum +from typing import List +from typing import Set + + +class StateType(Enum): + INITIAL = "initial" + REGULAR = "regular" + FINAL = "final" + HISTORY_SHALLOW = "history_shallow" + HISTORY_DEEP = "history_deep" + CHOICE = "choice" + FORK = "fork" + JOIN = "join" + JUNCTION = "junction" + PARALLEL = "parallel" + TERMINATE = "terminate" + + +class ActionType(Enum): + ENTRY = "entry" + EXIT = "exit" + INTERNAL = "internal" + + +@dataclass +class DiagramAction: + type: ActionType + body: str + + +@dataclass +class DiagramState: + id: str + name: str + type: StateType + actions: List[DiagramAction] = field(default_factory=list) + children: List["DiagramState"] = field(default_factory=list) + is_active: bool = False + is_parallel_area: bool = False + is_initial: bool = False + + +@dataclass +class DiagramTransition: + source: str + targets: List[str] = field(default_factory=list) + event: str = "" + guards: List[str] = field(default_factory=list) + actions: List[str] = field(default_factory=list) + is_internal: bool = False + is_initial: bool = False + + +@dataclass +class DiagramGraph: + name: str + states: List[DiagramState] = field(default_factory=list) + transitions: List[DiagramTransition] = field(default_factory=list) + compound_state_ids: Set[str] = field(default_factory=set) + bidirectional_compound_ids: Set[str] = field(default_factory=set) diff --git a/statemachine/contrib/diagram/renderers/__init__.py b/statemachine/contrib/diagram/renderers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/statemachine/contrib/diagram/renderers/dot.py b/statemachine/contrib/diagram/renderers/dot.py new file mode 100644 index 00000000..a33db791 --- /dev/null +++ b/statemachine/contrib/diagram/renderers/dot.py @@ -0,0 +1,531 @@ +from dataclasses import dataclass +from dataclasses import field +from typing import Dict +from typing import List +from typing import Optional +from typing import Set + +import pydot + +from ..model import ActionType +from ..model import DiagramAction +from ..model import DiagramGraph +from ..model import DiagramState +from ..model import DiagramTransition +from ..model import StateType + + +def _escape_html(text: str) -> str: + """Escape text for use inside HTML labels.""" + return text.replace("&", "&").replace("<", "<").replace(">", ">") + + +@dataclass +class DotRendererConfig: + """Configuration for the DOT renderer, matching DotGraphMachine's class attributes.""" + + graph_rankdir: str = "LR" + font_name: str = "Helvetica" + state_font_size: str = "12" + state_active_penwidth: int = 2 + state_active_fillcolor: str = "turquoise" + transition_font_size: str = "10" + graph_attrs: Dict[str, str] = field(default_factory=dict) + node_attrs: Dict[str, str] = field(default_factory=dict) + edge_attrs: Dict[str, str] = field(default_factory=dict) + + +class DotRenderer: + """Renders a DiagramGraph into a pydot.Dot graph with UML-inspired styling. + + Uses techniques inspired by state-machine-cat for cleaner visual output: + - HTML TABLE labels for states with UML compartments + - plaintext nodes with near-transparent fill + - Refined graph/node/edge defaults + """ + + def __init__(self, config: Optional[DotRendererConfig] = None): + self.config = config or DotRendererConfig() + self._compound_ids: Set[str] = set() + self._compound_bidir_ids: Set[str] = set() + + def render(self, graph: DiagramGraph) -> pydot.Dot: + """Render a DiagramGraph to a pydot.Dot object.""" + self._compound_ids = graph.compound_state_ids + self._compound_bidir_ids = graph.bidirectional_compound_ids + dot = self._create_graph(graph.name) + self._render_states(graph.states, graph.transitions, dot) + return dot + + def _create_graph(self, name: str) -> pydot.Dot: + cfg = self.config + graph_attrs = { + "fontname": cfg.font_name, + "fontsize": cfg.state_font_size, + "penwidth": "2.0", + "splines": "true", + "ordering": "out", + "compound": "true", + "nodesep": "0.3", + "ranksep": "0.3", + "forcelabels": "true", + } + graph_attrs.update(cfg.graph_attrs) + + dot = pydot.Dot( + name, + graph_type="digraph", + label=name, + rankdir=cfg.graph_rankdir, + **graph_attrs, + ) + + # Set default node attributes + node_defaults = { + "fontname": cfg.font_name, + "fontsize": cfg.state_font_size, + "penwidth": "2.0", + } + node_defaults.update(cfg.node_attrs) + dot.set_node_defaults(**node_defaults) + + # Set default edge attributes + edge_defaults = { + "fontname": cfg.font_name, + "fontsize": cfg.transition_font_size, + "labeldistance": "1.5", + } + edge_defaults.update(cfg.edge_attrs) + dot.set_edge_defaults(**edge_defaults) + + return dot + + def _state_node_id(self, state_id: str) -> str: + """Get the node ID to use for edges. Compound states use an anchor node.""" + if state_id in self._compound_ids: + return f"{state_id}_anchor" + return state_id + + def _compound_edge_anchor(self, state_id: str, direction: str) -> str: + """Return the appropriate anchor node ID for a compound ↔ other edge. + + Compound states that have both incoming and outgoing explicit transitions + get separate ``_anchor_out`` / ``_anchor_in`` nodes so Graphviz can route + the two directions through physically distinct points, avoiding overlap. + """ + if state_id in self._compound_bidir_ids: + return f"{state_id}_anchor_{direction}" + return f"{state_id}_anchor" + + def _render_states( + self, + states: List[DiagramState], + transitions: List[DiagramTransition], + parent_graph: "pydot.Dot | pydot.Subgraph", + extra_nodes: Optional[List[pydot.Node]] = None, + ) -> None: + """Render states and transitions into the parent graph.""" + initial_state = next((s for s in states if s.is_initial), None) + + # The atomic subgraph groups all non-compound states and the inner + # initial dot (when inside a compound cluster) so Graphviz places them + # in the same rank region, keeping the initial arrow short. + atomic_subgraph = pydot.Subgraph( + graph_name=f"cluster___atomic_{id(parent_graph)}", + label="", + peripheries=0, + margin=0, + cluster="true", + ) + has_atomic = False + + if initial_state: + has_atomic = ( + self._render_initial_arrow(initial_state, parent_graph, atomic_subgraph) + or has_atomic + ) + + for state in states: + if state.type in (StateType.HISTORY_SHALLOW, StateType.HISTORY_DEEP): + atomic_subgraph.add_node(self._create_history_node(state)) + has_atomic = True + elif state.children: + subgraph = self._create_compound_subgraph(state) + anchor_nodes = self._create_compound_anchor_nodes(state) + self._render_states( + state.children, transitions, subgraph, extra_nodes=anchor_nodes + ) + parent_graph.add_subgraph(subgraph) + # Add transitions originating from this compound state + self._add_transitions_for_state(state, transitions, parent_graph) + else: + atomic_subgraph.add_node(self._create_atomic_node(state)) + has_atomic = True + + has_atomic = self._place_extra_nodes( + extra_nodes, atomic_subgraph, parent_graph, has_atomic + ) + + if has_atomic: + parent_graph.add_subgraph(atomic_subgraph) + + # Add transitions for atomic/history states + for state in states: + if not state.children: + self._add_transitions_for_state(state, transitions, parent_graph) + + @staticmethod + def _place_extra_nodes( + extra_nodes: Optional[List[pydot.Node]], + atomic_subgraph: pydot.Subgraph, + parent_graph: "pydot.Dot | pydot.Subgraph", + has_atomic: bool, + ) -> bool: + """Place anchor nodes from the parent compound into the graph. + + Co-locates them with real states when possible. If there are no atomic + states at this level (e.g. a parallel state with only compound children), + adds them directly to the parent graph to avoid an empty cluster. + + Returns the updated ``has_atomic`` flag. + """ + if not extra_nodes: + return has_atomic + target = atomic_subgraph if has_atomic else parent_graph + for node in extra_nodes: + target.add_node(node) + return has_atomic or (target is atomic_subgraph) + + def _render_initial_arrow( + self, + initial_state: DiagramState, + parent_graph: "pydot.Dot | pydot.Subgraph", + atomic_subgraph: pydot.Subgraph, + ) -> bool: + """Render the black-dot initial arrow pointing to ``initial_state``. + + Returns True if nodes were added to ``atomic_subgraph``. + """ + initial_node_id = f"__initial_{id(parent_graph)}" + initial_node = self._create_initial_node(initial_node_id) + added_to_atomic = False + + extra = {} + if initial_state.children: + extra["lhead"] = f"cluster_{initial_state.id}" + + if initial_state.children or isinstance(parent_graph, pydot.Dot): + # Compound initial state, or top-level atomic initial state: + # keep the dot in a plain wrapper subgraph attached to parent. + wrapper = pydot.Subgraph( + graph_name=f"{initial_node_id}_sg", + label="", + peripheries=0, + margin=0, + ) + wrapper.add_node(initial_node) + parent_graph.add_subgraph(wrapper) + else: + # Inner (compound parent) with atomic initial state: add the + # dot directly into the atomic cluster so it shares the same + # rank region as the target state, avoiding a long arrow caused + # by the compound cluster's anchor nodes pushing step1 further. + atomic_subgraph.add_node(initial_node) + added_to_atomic = True + + parent_graph.add_edge( + pydot.Edge( + initial_node_id, + self._state_node_id(initial_state.id), + label="", + minlen=1, + weight=100, + **extra, + ) + ) + return added_to_atomic + + def _create_initial_node(self, node_id: str) -> pydot.Node: + return pydot.Node( + node_id, + label="", + shape="circle", + style="filled", + fillcolor="black", + color="black", + fixedsize="true", + width=0.15, + height=0.15, + penwidth="0", + ) + + def _create_atomic_node(self, state: DiagramState) -> pydot.Node: + """Create a node for an atomic state. + + All states use a native ``shape="rectangle"`` with ``style="rounded, filled"`` + so that Graphviz clips edges at the actual rounded border. States with + entry/exit actions embed an HTML TABLE (``border="0"``) inside the native + shape to render UML-style compartments (name + separator + actions). + """ + actions = [a for a in state.actions if a.type != ActionType.INTERNAL or a.body] + fillcolor = self.config.state_active_fillcolor if state.is_active else "white" + penwidth = self.config.state_active_penwidth if state.is_active else 2 + + if not actions: + # Simple state: native rounded rectangle + node = pydot.Node( + state.id, + label=state.name, + shape="rectangle", + style="rounded, filled", + fontname=self.config.font_name, + fontsize=self.config.state_font_size, + fillcolor=fillcolor, + penwidth=penwidth, + peripheries=2 if state.type == StateType.FINAL else 1, + ) + else: + # State with actions: native shape + HTML TABLE label (border=0). + # The native shape handles edge clipping; the TABLE provides + # UML compartment layout with
separator. + label = self._build_html_table_label(state, actions) + node = pydot.Node( + state.id, + label=f"<{label}>", + shape="rectangle", + style="rounded, filled", + fontname=self.config.font_name, + fontsize=self.config.state_font_size, + fillcolor=fillcolor, + penwidth=penwidth, + margin="0", + peripheries=2 if state.type == StateType.FINAL else 1, + ) + + return node + + def _build_html_table_label( + self, + state: DiagramState, + actions: List[DiagramAction], + ) -> str: + """Build an HTML TABLE label with UML compartments (name | actions). + + The TABLE has ``border="0"`` because the visible border is drawn by + the native Graphviz shape, ensuring edges are clipped correctly. + """ + name = _escape_html(state.name) + font_size = self.config.state_font_size + action_font_size = self.config.transition_font_size + + action_lines = "
".join( + f'{_escape_html(self._format_action(a))}' + for a in actions + ) + + return ( + f'' + f'" + f"
" + f'" + f"
' + f'{name}' + f"
' + f"{action_lines}" + f"
" + ) + + @staticmethod + def _format_action(action: DiagramAction) -> str: + if action.type == ActionType.INTERNAL: + return action.body + return f"{action.type.value} / {action.body}" + + def _create_history_node(self, state: DiagramState) -> pydot.Node: + label = "H*" if state.type == StateType.HISTORY_DEEP else "H" + return pydot.Node( + state.id, + label=label, + shape="circle", + style="filled", + fillcolor="white", + fontname=self.config.font_name, + fontsize="8pt", + fixedsize="true", + width=0.3, + height=0.3, + ) + + def _create_compound_anchor_nodes(self, state: DiagramState) -> List[pydot.Node]: + """Create invisible anchor nodes for edge routing inside a compound cluster. + + These nodes are injected into the children's atomic_subgraph so they + share the same layout row as the real states, avoiding blank space at + the top of the compound cluster. + """ + # For bidirectional compounds, all edges route through _anchor_in/_anchor_out; + # the generic _anchor node is never used and would become an orphan that + # Graphviz places arbitrarily, creating blank vertical space in the cluster. + if state.id not in self._compound_bidir_ids: + nodes = [ + pydot.Node( + f"{state.id}_anchor", + shape="point", + style="invis", + width=0, + height=0, + fixedsize="true", + ) + ] + else: + nodes = [] + for direction in ("in", "out"): + nodes.append( + pydot.Node( + f"{state.id}_anchor_{direction}", + shape="point", + style="invis", + width=0, + height=0, + fixedsize="true", + ) + ) + return nodes + + def _create_compound_subgraph(self, state: DiagramState) -> pydot.Subgraph: + """Create a cluster subgraph for a compound/parallel state.""" + style = "rounded, solid" + if state.is_parallel_area: + style = "rounded, dashed" + + label = self._build_compound_label(state) + + return pydot.Subgraph( + graph_name=f"cluster_{state.id}", + label=f"<{label}>", + style=style, + cluster="true", + penwidth="2.0", + fontname=self.config.font_name, + fontsize=self.config.state_font_size, + margin="4", + ) + + def _build_compound_label(self, state: DiagramState) -> str: + """Build HTML label for a compound/parallel subgraph.""" + name = _escape_html(state.name) + if state.type == StateType.PARALLEL: + return f"{name} ☷" + + actions = [a for a in state.actions if a.type != ActionType.INTERNAL or a.body] + if not actions: + return f"{name}" + + rows = [f"{name}"] + for action in actions: + action_text = _escape_html(self._format_action(action)) + rows.append( + f'{action_text}' + ) + return "
".join(rows) + + def _add_transitions_for_state( + self, + state: DiagramState, + all_transitions: List[DiagramTransition], + graph: "pydot.Dot | pydot.Subgraph", + ) -> None: + """Add edges for all non-internal transitions originating from this state.""" + for transition in all_transitions: + if transition.source != state.id or transition.is_internal: + continue + # Skip implicit initial transitions — represented by the black-dot initial node. + if transition.is_initial: + continue + for edge in self._create_edges(transition): + graph.add_edge(edge) + + def _create_edges(self, transition: DiagramTransition) -> List[pydot.Edge]: + """Create pydot.Edge objects for a transition.""" + target_ids: List[Optional[str]] = ( + list(transition.targets) if transition.targets else [None] + ) + + cond = ", ".join(transition.guards) + cond_html = f"
[{_escape_html(cond)}]" if cond else "" + + return [ + self._create_single_edge(transition, target_id, i, cond_html) + for i, target_id in enumerate(target_ids) + ] + + def _create_single_edge( + self, + transition: DiagramTransition, + target_id: Optional[str], + index: int, + cond_html: str, + ) -> pydot.Edge: + """Create a single pydot.Edge for one target of a transition.""" + src, dst, extra = self._resolve_edge_endpoints(transition, target_id) + has_substates = bool(extra) + html_label = self._build_edge_label(transition.event, cond_html, index) + + return pydot.Edge( + src, + dst, + label=html_label, + minlen=2 if has_substates else 1, + **extra, + ) + + def _resolve_edge_endpoints( + self, + transition: DiagramTransition, + target_id: Optional[str], + ) -> "tuple[str, str, Dict[str, str]]": + """Resolve source/destination node IDs and cluster attributes for an edge.""" + extra: Dict[str, str] = {} + source_is_compound = transition.source in self._compound_ids + target_is_compound = target_id is not None and target_id in self._compound_ids + + if source_is_compound: + extra["ltail"] = f"cluster_{transition.source}" + if target_is_compound: + extra["lhead"] = f"cluster_{target_id}" + + dst = ( + self._state_node_id(target_id) + if target_id is not None + else self._state_node_id(transition.source) + ) + src = self._state_node_id(transition.source) + + # For compound states in bidirectional pairs, route outgoing edges + # through _anchor_out and incoming through _anchor_in so Graphviz + # places them at different physical positions inside the cluster. + if source_is_compound and transition.source in self._compound_bidir_ids: + src = self._compound_edge_anchor(transition.source, "out") + extra["ltail"] = f"cluster_{transition.source}" + if target_is_compound and target_id in self._compound_bidir_ids: + dst = self._compound_edge_anchor(target_id, "in") + extra["lhead"] = f"cluster_{target_id}" + + return src, dst, extra + + def _build_edge_label(self, event: str, cond_html: str, index: int) -> str: + """Build the HTML label for a transition edge.""" + event_text = _escape_html(event) if index == 0 else "" + if not event_text and not (cond_html and index == 0): + return "" + + label_content = f"{event_text}{cond_html}" if index == 0 else "" + font_size = self.config.transition_font_size + return ( + f'<' + f'" + f'' + f"
' + f'{label_content}' + f"
>" + ) diff --git a/statemachine/contrib/diagram/sphinx_ext.py b/statemachine/contrib/diagram/sphinx_ext.py new file mode 100644 index 00000000..bbc9a8ac --- /dev/null +++ b/statemachine/contrib/diagram/sphinx_ext.py @@ -0,0 +1,236 @@ +"""Sphinx extension providing the ``statemachine-diagram`` directive. + +Usage in MyST Markdown:: + + ```{statemachine-diagram} mypackage.module.MyMachine + :events: start, ship + :caption: After shipping + ``` + +The directive imports the state machine class, optionally instantiates it and +sends events, then renders an SVG diagram inline in the documentation. +""" + +from __future__ import annotations + +import hashlib +import html as html_mod +import os +import re +from typing import TYPE_CHECKING +from typing import Any +from typing import ClassVar + +from docutils import nodes +from docutils.parsers.rst import directives +from sphinx.util.docutils import SphinxDirective + +if TYPE_CHECKING: + from sphinx.application import Sphinx + + +def _align_spec(argument: str) -> str: + return str(directives.choice(argument, ("left", "center", "right"))) + + +def _parse_events(value: str) -> list[str]: + """Parse a comma-separated list of event names.""" + return [e.strip() for e in value.split(",") if e.strip()] + + +# Match the outer ... element, stripping XML prologue/DOCTYPE. +_SVG_TAG_RE = re.compile(rb"()", re.DOTALL) + +# Match fixed width/height attributes (e.g. width="702pt" height="170pt"). +_SVG_WIDTH_RE = re.compile(r'\bwidth="([^"]*(?:pt|px))"') +_SVG_HEIGHT_RE = re.compile(r'\bheight="([^"]*(?:pt|px))"') + + +class StateMachineDiagram(SphinxDirective): + """Render a state machine diagram from an importable class path. + + Supports the same layout options as the standard ``image`` and ``figure`` + directives (``width``, ``height``, ``scale``, ``align``, ``target``, + ``class``, ``name``), plus state-machine-specific options (``events``, + ``caption``, ``figclass``). + """ + + has_content: ClassVar[bool] = False + required_arguments: ClassVar[int] = 1 + optional_arguments: ClassVar[int] = 0 + option_spec: ClassVar[dict[str, Any]] = { + # State-machine options + "events": directives.unchanged, + # Standard image/figure options + "caption": directives.unchanged, + "alt": directives.unchanged, + "width": directives.unchanged, + "height": directives.unchanged, + "scale": directives.unchanged, + "align": _align_spec, + "target": directives.unchanged, + "class": directives.class_option, + "name": directives.unchanged, + "figclass": directives.class_option, + } + + def run(self) -> list[nodes.Node]: + qualname = self.arguments[0] + + try: + from statemachine.contrib.diagram import DotGraphMachine + from statemachine.contrib.diagram import import_sm + + sm_class = import_sm(qualname) + except (ImportError, ValueError) as exc: + return [ + self.state_machine.reporter.warning( + f"statemachine-diagram: could not import {qualname!r}: {exc}", + line=self.lineno, + ) + ] + + if "events" in self.options: + machine = sm_class() + for event_name in _parse_events(self.options["events"]): + machine.send(event_name) + else: + machine = sm_class + + try: + graph = DotGraphMachine(machine).get_graph() + svg_bytes: bytes = graph.create_svg() # type: ignore[attr-defined] + except Exception as exc: + return [ + self.state_machine.reporter.warning( + f"statemachine-diagram: failed to generate diagram for {qualname!r}: {exc}", + line=self.lineno, + ) + ] + + svg_tag, intrinsic_width, intrinsic_height = self._prepare_svg(svg_bytes) + svg_styles = self._build_svg_styles(intrinsic_width, intrinsic_height) + svg_tag = svg_tag.replace("{svg_tag}' + if target: + img_html = f'{img_html}' + + wrapper_classes = self._build_wrapper_classes() + class_attr = f' class="{" ".join(wrapper_classes)}"' + + if "caption" in self.options: + caption = html_mod.escape(self.options["caption"]) + figclass = self.options.get("figclass", []) + if figclass: + class_attr = f' class="{" ".join(wrapper_classes + figclass)}"' + html = ( + f"\n" + f" {img_html}\n" + f"
{caption}
\n" + f"" + ) + else: + html = f"{img_html}" + + raw_node = nodes.raw("", html, format="html") + + if "name" in self.options: + self.add_name(raw_node) + + return [raw_node] + + def _prepare_svg(self, svg_bytes: bytes) -> tuple[str, str, str]: + """Extract the ```` element and its intrinsic dimensions.""" + match = _SVG_TAG_RE.search(svg_bytes) + svg_tag = match.group(1).decode("utf-8") if match else svg_bytes.decode("utf-8") + + width_match = _SVG_WIDTH_RE.search(svg_tag) + height_match = _SVG_HEIGHT_RE.search(svg_tag) + intrinsic_width = width_match.group(1) if width_match else "" + intrinsic_height = height_match.group(1) if height_match else "" + + # Remove fixed dimensions — sizing is controlled via inline styles. + svg_tag = _SVG_WIDTH_RE.sub("", svg_tag) + svg_tag = _SVG_HEIGHT_RE.sub("", svg_tag) + + return svg_tag, intrinsic_width, intrinsic_height + + def _build_svg_styles(self, intrinsic_width: str, intrinsic_height: str) -> str: + """Build an inline ``style`` attribute for the ```` element.""" + parts: list[str] = [] + + # Width: explicit > scaled intrinsic > intrinsic as max-width. + user_width = self.options.get("width", "") + scale = self.options.get("scale", "") + if user_width: + parts.append(f"width: {user_width}") + elif scale and intrinsic_width: + factor = int(scale.rstrip("%")) / 100 + value, unit = _split_length(intrinsic_width) + parts.append(f"width: {value * factor:.1f}{unit}") + elif intrinsic_width: + parts.append(f"max-width: {intrinsic_width}") + + # Height: explicit > scaled intrinsic > auto. + user_height = self.options.get("height", "") + if user_height: + parts.append(f"height: {user_height}") + elif scale and intrinsic_height: + factor = int(scale.rstrip("%")) / 100 + value, unit = _split_length(intrinsic_height) + parts.append(f"height: {value * factor:.1f}{unit}") + else: + parts.append("height: auto") + + return f'style="{"; ".join(parts)}"' + + def _resolve_target(self, svg_bytes: bytes) -> str: + """Return the href for the wrapper ```` tag, if any. + + When ``:target:`` is given without a value (or as empty string), the + raw SVG is written to ``_images/`` and linked so the user can open + the full diagram in a new browser tab for zooming. + """ + if "target" not in self.options: + return "" + target = (self.options["target"] or "").strip() + if target: + return target + + # Auto-generate a standalone SVG file for zoom. + qualname = self.arguments[0] + events_key = self.options.get("events", "") + identity = f"{qualname}:{events_key}" + digest = hashlib.sha1(identity.encode()).hexdigest()[:8] + filename = f"statemachine-{digest}.svg" + + outdir = os.path.join(self.env.app.outdir, "_images") + os.makedirs(outdir, exist_ok=True) + outpath = os.path.join(outdir, filename) + with open(outpath, "wb") as f: + f.write(svg_bytes) + + return f"/_images/{filename}" + + def _build_wrapper_classes(self) -> list[str]: + """Build CSS class list for the outer wrapper element.""" + css_classes: list[str] = self.options.get("class", []) + align = self.options.get("align", "center") + return ["statemachine-diagram", f"align-{align}"] + css_classes + + +def _split_length(value: str) -> tuple[float, str]: + """Split a CSS length like ``'702pt'`` into ``(702.0, 'pt')``.""" + match = re.match(r"([0-9.]+)(.*)", value) + if match: + return float(match.group(1)), match.group(2) + return 0.0, value + + +def setup(app: "Sphinx") -> dict[str, Any]: + app.add_directive("statemachine-diagram", StateMachineDiagram) + return {"version": "0.1", "parallel_read_safe": True, "parallel_write_safe": True} diff --git a/tests/machines/__init__.py b/tests/machines/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/machines/showcase_actions.py b/tests/machines/showcase_actions.py new file mode 100644 index 00000000..5569b838 --- /dev/null +++ b/tests/machines/showcase_actions.py @@ -0,0 +1,16 @@ +from statemachine import State +from statemachine import StateChart + + +class ActionsSC(StateChart): + off = State(initial=True) + on = State() + done = State(final=True) + + power_on = off.to(on) + shutdown = on.to(done) + + def on_exit_off(self): ... + def on_enter_on(self): ... + def on_exit_on(self): ... + def on_enter_done(self): ... diff --git a/tests/machines/showcase_compound.py b/tests/machines/showcase_compound.py new file mode 100644 index 00000000..2125c1e5 --- /dev/null +++ b/tests/machines/showcase_compound.py @@ -0,0 +1,15 @@ +from statemachine import State +from statemachine import StateChart + + +class CompoundSC(StateChart): + class active(State.Compound, name="Active"): + idle = State(initial=True) + working = State() + begin = idle.to(working) + + off = State(initial=True) + done = State(final=True) + + turn_on = off.to(active) + turn_off = active.to(done) diff --git a/tests/machines/showcase_deep_history.py b/tests/machines/showcase_deep_history.py new file mode 100644 index 00000000..1963c590 --- /dev/null +++ b/tests/machines/showcase_deep_history.py @@ -0,0 +1,21 @@ +from statemachine import HistoryState +from statemachine import State +from statemachine import StateChart + + +class DeepHistorySC(StateChart): + class outer(State.Compound, name="Outer"): + class inner(State.Compound, name="Inner"): + a = State(initial=True) + b = State() + go = a.to(b) + + start = State(initial=True) + enter_inner = start.to(inner) + h = HistoryState(type="deep") + + away = State(initial=True) + + dive = away.to(outer) + leave = outer.to(away) + restore = away.to(outer.h) diff --git a/tests/machines/showcase_guards.py b/tests/machines/showcase_guards.py new file mode 100644 index 00000000..8619e986 --- /dev/null +++ b/tests/machines/showcase_guards.py @@ -0,0 +1,16 @@ +from statemachine import State +from statemachine import StateChart + + +class GuardSC(StateChart): + pending = State(initial=True) + approved = State(final=True) + rejected = State(final=True) + + def is_valid(self): + return True + + def is_invalid(self): + return False + + review = pending.to(approved, cond="is_valid") | pending.to(rejected, cond="is_invalid") diff --git a/tests/machines/showcase_history.py b/tests/machines/showcase_history.py new file mode 100644 index 00000000..73c3f404 --- /dev/null +++ b/tests/machines/showcase_history.py @@ -0,0 +1,17 @@ +from statemachine import HistoryState +from statemachine import State +from statemachine import StateChart + + +class HistorySC(StateChart): + class process(State.Compound, name="Process"): + step1 = State(initial=True) + step2 = State() + advance = step1.to(step2) + h = HistoryState() + + paused = State(initial=True) + + pause = process.to(paused) + resume = paused.to(process.h) + begin = paused.to(process) diff --git a/tests/machines/showcase_internal.py b/tests/machines/showcase_internal.py new file mode 100644 index 00000000..530296f2 --- /dev/null +++ b/tests/machines/showcase_internal.py @@ -0,0 +1,12 @@ +from statemachine import State +from statemachine import StateChart + + +class InternalSC(StateChart): + monitoring = State(initial=True) + done = State(final=True) + + def log_status(self): ... + + check = monitoring.to.itself(internal=True, on="log_status") + stop = monitoring.to(done) diff --git a/tests/machines/showcase_parallel.py b/tests/machines/showcase_parallel.py new file mode 100644 index 00000000..7ade93ae --- /dev/null +++ b/tests/machines/showcase_parallel.py @@ -0,0 +1,21 @@ +from statemachine import State +from statemachine import StateChart + + +class ParallelSC(StateChart): + class both(State.Parallel, name="Both"): + class left(State.Compound, name="Left"): + l1 = State(initial=True) + l2 = State(final=True) + go_l = l1.to(l2) + + class right(State.Compound, name="Right"): + r1 = State(initial=True) + r2 = State(final=True) + go_r = r1.to(r2) + + start = State(initial=True) + end = State(final=True) + + enter = start.to(both) + done_state_both = both.to(end) diff --git a/tests/machines/showcase_self_transition.py b/tests/machines/showcase_self_transition.py new file mode 100644 index 00000000..59995db9 --- /dev/null +++ b/tests/machines/showcase_self_transition.py @@ -0,0 +1,10 @@ +from statemachine import State +from statemachine import StateChart + + +class SelfTransitionSC(StateChart): + counting = State(initial=True) + done = State(final=True) + + increment = counting.to.itself() + stop = counting.to(done) diff --git a/tests/machines/showcase_simple.py b/tests/machines/showcase_simple.py new file mode 100644 index 00000000..affc1ce1 --- /dev/null +++ b/tests/machines/showcase_simple.py @@ -0,0 +1,11 @@ +from statemachine import State +from statemachine import StateChart + + +class SimpleSC(StateChart): + idle = State(initial=True) + running = State() + done = State(final=True) + + start = idle.to(running) + finish = running.to(done) diff --git a/tests/machines/transition_from_any.py b/tests/machines/transition_from_any.py new file mode 100644 index 00000000..3006fc69 --- /dev/null +++ b/tests/machines/transition_from_any.py @@ -0,0 +1,30 @@ +from statemachine import State +from statemachine import StateChart + + +class OrderWorkflow(StateChart): + pending = State(initial=True) + processing = State() + done = State() + completed = State(final=True) + cancelled = State(final=True) + + process = pending.to(processing) + complete = processing.to(done) + finish = done.to(completed) + cancel = cancelled.from_.any() + + +class OrderWorkflowCompound(StateChart): + class active(State.Compound): + pending = State(initial=True) + processing = State() + done = State(final=True) + + process = pending.to(processing) + complete = processing.to(done) + + completed = State(final=True) + cancelled = State(final=True) + done_state_active = active.to(completed) + cancel = active.to(cancelled) diff --git a/tests/machines/tutorial_coffee_order.py b/tests/machines/tutorial_coffee_order.py new file mode 100644 index 00000000..f29ecd07 --- /dev/null +++ b/tests/machines/tutorial_coffee_order.py @@ -0,0 +1,13 @@ +from statemachine import State +from statemachine import StateChart + + +class CoffeeOrder(StateChart): + pending = State(initial=True) + preparing = State() + ready = State() + picked_up = State(final=True) + + start = pending.to(preparing) + finish = preparing.to(ready) + pick_up = ready.to(picked_up) diff --git a/tests/test_contrib_diagram.py b/tests/test_contrib_diagram.py index 54c94cbd..b4d7f373 100644 --- a/tests/test_contrib_diagram.py +++ b/tests/test_contrib_diagram.py @@ -1,18 +1,66 @@ +import re from contextlib import contextmanager from unittest import mock +from xml.etree import ElementTree import pytest +from docutils import nodes from statemachine.contrib.diagram import DotGraphMachine from statemachine.contrib.diagram import main from statemachine.contrib.diagram import quickchart_write_svg -from statemachine.transition import Transition +from statemachine.contrib.diagram.model import ActionType +from statemachine.contrib.diagram.model import StateType +from statemachine.contrib.diagram.renderers.dot import DotRenderer -from statemachine import HistoryState from statemachine import State from statemachine import StateChart pytestmark = pytest.mark.usefixtures("requires_dot_installed") +SVG_NS = {"svg": "http://www.w3.org/2000/svg"} + + +def _parse_svg(graph): + """Generate SVG from a pydot graph and parse it as XML.""" + svg_bytes = graph.create_svg() + return ElementTree.fromstring(svg_bytes) + + +def _find_state_node(svg_root, state_id): + """Find the SVG element for a state node by its title text.""" + for g in svg_root.iter("{http://www.w3.org/2000/svg}g"): + if g.get("class") != "node": + continue + title = g.find("{http://www.w3.org/2000/svg}title") + if title is not None and title.text == state_id: + return g + return None + + +def _has_rectangular_fill(node_g): + """Check if a node group has a with a colored fill. + + A fill inside a state node means the background is rectangular + (no rounded corners), which is a visual regression — state backgrounds + should use with curves to match the rounded border. + + Ignores white fills and arrow-related polygons (which are in edge groups). + """ + for polygon in node_g.findall("{http://www.w3.org/2000/svg}polygon"): + fill = polygon.get("fill", "none") + if fill not in ("none", "white", "black", "#ffffff"): + return True + return False + + +def _path_has_curves(d_attr): + """Check if an SVG path `d` attribute contains curve commands (C, c, Q, q, A, a). + + Rounded corners are drawn with cubic Bezier curves (C command). + A rectangular shape only has M (move) and L (line) commands. + """ + return bool(re.search(r"[CcQqAa]", d_attr)) + @pytest.fixture( params=[ @@ -186,28 +234,34 @@ class inner(State.Compound, name="Inner"): def test_subgraph_dashed_style_for_parallel_parent(): """Subgraph uses dashed border when parent state is parallel.""" - child = State("child", initial=True) - child._set_id("child") - parent = State("parent", parallel=True, states=[child]) - parent._set_id("parent") - graph_maker = DotGraphMachine.__new__(DotGraphMachine) - subgraph = graph_maker._get_subgraph(child) - assert "dashed" in subgraph.obj_dict["attributes"].get("style", "") + class SM(StateChart): + class p(State.Parallel, name="p"): + class r1(State.Compound, name="r1"): + a = State(initial=True) + + start = State(initial=True) + begin = start.to(p) + + dot = DotGraphMachine(SM)().to_string() + # The region subgraph inside a parallel state should have dashed style + assert "dashed" in dot def test_initial_edge_with_compound_state_has_lhead(): """Initial edge to a compound state sets lhead cluster attribute.""" - inner = State("inner", initial=True) - inner._set_id("inner") - compound = State("compound", states=[inner], initial=True) - compound._set_id("compound") - graph_maker = DotGraphMachine.__new__(DotGraphMachine) - initial_node = graph_maker._initial_node(compound) - edge = graph_maker._initial_edge(initial_node, compound) - attrs = edge.obj_dict["attributes"] - assert attrs.get("lhead") == f"cluster_{compound.id}" + class SM(StateChart): + class parent(State.Compound, name="Parent"): + child1 = State(initial=True) + child2 = State(final=True) + go = child1.to(child2) + + start = State(initial=True) + enter = start.to(parent) + + dot = DotGraphMachine(SM)().to_string() + assert "lhead=cluster_parent" in dot def test_initial_edge_inside_compound_subgraph(): @@ -231,18 +285,20 @@ class parent(State.Compound, name="Parent"): # The compound subgraph should contain an initial point node and an edge to child1 assert "parent_anchor" in dot assert "child1" in dot - # Verify the initial edge exists (from parent's initial node to child1) - assert "parent_anchor -> child1" in dot + # Verify the initial edge exists (from the black-dot initial node to child1) + # The implicit initial transition from the compound state itself is NOT rendered + # as an edge — it is represented only by the black-dot initial node inside the cluster. + assert "parent_anchor -> child1" not in dot + assert "-> child1" in dot def test_history_state_shallow_diagram(): """DOT output contains an 'H' circle node for shallow history state.""" - h = HistoryState(name="H") - h._set_id("h_shallow") + from statemachine.contrib.diagram.model import DiagramState - graph_maker = DotGraphMachine.__new__(DotGraphMachine) - graph_maker.font_name = "Arial" - node = graph_maker._history_node(h) + state = DiagramState(id="h_shallow", name="H", type=StateType.HISTORY_SHALLOW) + renderer = DotRenderer() + node = renderer._create_history_node(state) attrs = node.obj_dict["attributes"] assert attrs["label"] in ("H", '"H"') assert attrs["shape"] == "circle" @@ -250,13 +306,11 @@ def test_history_state_shallow_diagram(): def test_history_state_deep_diagram(): """DOT output contains an 'H*' circle node for deep history state.""" - h = HistoryState(name="H*", type="deep") - h._set_id("h_deep") + from statemachine.contrib.diagram.model import DiagramState - graph_maker = DotGraphMachine.__new__(DotGraphMachine) - graph_maker.font_name = "Arial" - node = graph_maker._history_node(h) - # Verify the node renders correctly in DOT output + state = DiagramState(id="h_deep", name="H*", type=StateType.HISTORY_DEEP) + renderer = DotRenderer() + node = renderer._create_history_node(state) dot_str = node.to_string() assert "H*" in dot_str assert "circle" in dot_str @@ -264,25 +318,12 @@ def test_history_state_deep_diagram(): def test_history_state_default_transition(): """History state's default transition appears as an edge in the diagram.""" - child1 = State("child1", initial=True) - child1._set_id("child1") - child2 = State("child2") - child2._set_id("child2") - - h = HistoryState(name="H") - h._set_id("hist") - # Add a default transition from history to child1 - t = Transition(source=h, target=child1, initial=True) - h.transitions.add_transitions(t) + from statemachine.contrib.diagram.model import DiagramTransition - parent = State("parent", states=[child1, child2], history=[h]) - parent._set_id("parent") - - graph_maker = DotGraphMachine.__new__(DotGraphMachine) - graph_maker.font_name = "Arial" - graph_maker.transition_font_size = "9pt" - - edges = graph_maker._transition_as_edges(t) + transition = DiagramTransition(source="hist", targets=["child1"], event="") + renderer = DotRenderer() + renderer._compound_ids = set() + edges = renderer._create_edges(transition) assert len(edges) == 1 edge = edges[0] assert edge.obj_dict["points"] == ("hist", "child1") @@ -320,26 +361,17 @@ def test_history_state_in_graph_states(): def test_multi_target_transition_diagram(): """Edges are created for all targets of a multi-target transition.""" - source = State("source", initial=True) - source._set_id("source") - target1 = State("target1") - target1._set_id("target1") - target2 = State("target2") - target2._set_id("target2") - - t = Transition(source=source, target=[target1, target2]) - t._events.add("go") - - graph_maker = DotGraphMachine.__new__(DotGraphMachine) - graph_maker.font_name = "Arial" - graph_maker.transition_font_size = "9pt" + from statemachine.contrib.diagram.model import DiagramTransition - edges = graph_maker._transition_as_edges(t) + transition = DiagramTransition(source="source", targets=["target1", "target2"], event="go") + renderer = DotRenderer() + renderer._compound_ids = set() + edges = renderer._create_edges(transition) assert len(edges) == 2 assert edges[0].obj_dict["points"] == ("source", "target1") assert edges[1].obj_dict["points"] == ("source", "target2") # Only the first edge gets a label - assert edges[0].obj_dict["attributes"]["label"] == "go" + assert "go" in edges[0].obj_dict["attributes"]["label"] assert edges[1].obj_dict["attributes"]["label"] == "" @@ -373,5 +405,689 @@ class region2(State.Compound, name="Region2"): assert "cluster_region2" in dot # Parallel indicator assert "☷" in dot - # Verify initial edges exist for compound states (top and regions) - assert "top_anchor -> entry" in dot + # Implicit initial transitions from compound states are NOT rendered as edges — + # they are represented by the black-dot initial node inside each cluster. + assert "top_anchor -> entry" not in dot + assert "-> entry" in dot + + +class TestSVGShapeConsistency: + """Verify that active and inactive states render with the same shape in SVG. + + These tests parse the generated SVG to catch visual regressions that are + hard to spot by inspecting DOT source alone. For example, using `bgcolor` + on a `` instead of a `
` causes Graphviz to render a rectangular + `` behind a rounded `` border — the DOT looks fine but the + visual result is broken. + """ + + def test_active_state_has_no_rectangular_fill(self): + """Active state background must use rounded , not rectangular .""" + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() # starts in Green + graph = DotGraphMachine(sm).get_graph() + svg = _parse_svg(graph) + + green_node = _find_state_node(svg, "green") + assert green_node is not None, "Could not find 'green' node in SVG" + assert not _has_rectangular_fill(green_node), ( + "Active state 'green' has a rectangular fill — " + "expected a rounded fill to match the border shape" + ) + + def test_active_and_inactive_states_use_same_svg_element_type(self): + """Active and inactive states must both render as rounded elements. + + With ``shape=rectangle`` + ``style="rounded, filled"``, Graphviz renders + each state as a single ```` with cubic Bezier curves (``C`` commands) + for rounded corners. Both the fill and stroke are in the same ````. + + A regression would be if the active state rendered differently — e.g., a + rectangular ```` for the fill behind a rounded ```` border. + """ + from tests.examples.traffic_light_machine import TrafficLightMachine + + sm = TrafficLightMachine() + graph = DotGraphMachine(sm).get_graph() + svg = _parse_svg(graph) + + for state_id in ("green", "yellow", "red"): + node = _find_state_node(svg, state_id) + assert node is not None, f"Could not find '{state_id}' node in SVG" + + # Each state should have at least one with rounded curves + paths = node.findall("{http://www.w3.org/2000/svg}path") + assert len(paths) >= 1, ( + f"State '{state_id}' should have at least 1 , found {len(paths)}" + ) + for p in paths: + assert _path_has_curves(p.get("d", "")), ( + f"State '{state_id}' has a without curves — not rounded" + ) + + def test_no_state_node_has_rectangular_colored_fill(self): + """No state in the diagram should have a rectangular colored fill.""" + + class SM(StateChart): + s1 = State(initial=True) + s2 = State() + s3 = State(final=True) + go = s1.to(s2) + finish = s2.to(s3) + + sm = SM() + sm.go() # move to s2 + graph = DotGraphMachine(sm).get_graph() + svg = _parse_svg(graph) + + for state_id in ("s1", "s2", "s3"): + node = _find_state_node(svg, state_id) + if node is None: + continue + assert not _has_rectangular_fill(node), ( + f"State '{state_id}' has a rectangular colored fill" + ) + + +class TestExtract: + """Tests for extract.py edge cases.""" + + def test_deep_history_state_type(self): + """Deep history state is correctly typed in the extracted graph.""" + from statemachine.contrib.diagram.extract import extract + + from tests.machines.showcase_deep_history import DeepHistorySC + + graph = extract(DeepHistorySC) + # Find the history state in the outer compound's children + outer = next(s for s in graph.states if s.id == "outer") + h_state = next(s for s in outer.children if s.type == StateType.HISTORY_DEEP) + assert h_state is not None + + def test_internal_transition_actions_extracted(self): + """Internal transitions with actions are extracted into state actions.""" + from statemachine.contrib.diagram.extract import extract + + from tests.machines.showcase_internal import InternalSC + + graph = extract(InternalSC) + monitoring = next(s for s in graph.states if s.id == "monitoring") + internal_actions = [a for a in monitoring.actions if a.type == ActionType.INTERNAL] + assert len(internal_actions) >= 1 + assert any("check" in a.body for a in internal_actions) + + def test_internal_transition_skipped_in_bidirectional(self): + """Internal transitions are skipped in _collect_bidirectional_compound_ids.""" + from statemachine.contrib.diagram.extract import extract + + class SM(StateChart): + class parent(State.Compound, name="Parent"): + child1 = State(initial=True) + child2 = State(final=True) + + def log(self): ... + + check = child1.to.itself(internal=True, on="log") + go = child1.to(child2) + + start = State(initial=True) + end = State(final=True) + + enter = start.to(parent) + finish = parent.to(end) + + graph = extract(SM) + # parent has both incoming and outgoing, so it should be bidirectional + assert "parent" in graph.bidirectional_compound_ids + + def test_internal_transition_without_action(self): + """Internal transition without on action has no internal action in diagram.""" + from statemachine.contrib.diagram.extract import extract + + class SM(StateChart): + s1 = State(initial=True) + s2 = State(final=True) + + noop = s1.to.itself(internal=True) + go = s1.to(s2) + + graph = extract(SM) + s1 = next(s for s in graph.states if s.id == "s1") + internal_actions = [a for a in s1.actions if a.type == ActionType.INTERNAL] + assert internal_actions == [] + + def test_extract_invalid_type_raises(self): + """extract() raises TypeError for invalid input.""" + from statemachine.contrib.diagram.extract import extract + + with pytest.raises(TypeError, match="Expected a StateChart"): + extract("not a machine") # type: ignore[arg-type] + + def test_resolve_initial_fallback(self): + """When no explicit initial, first candidate gets is_initial=True.""" + from statemachine.contrib.diagram.extract import _resolve_initial_states + from statemachine.contrib.diagram.model import DiagramState + + states = [ + DiagramState(id="a", name="A", type=StateType.REGULAR), + DiagramState(id="b", name="B", type=StateType.REGULAR), + ] + _resolve_initial_states(states) + assert states[0].is_initial is True + + +class TestDotRendererEdgeCases: + """Tests for dot.py edge cases.""" + + def test_compound_state_with_actions_label(self): + """Compound state with entry/exit actions renders action rows in label.""" + + class SM(StateChart): + class parent(State.Compound, name="Parent"): + child = State(initial=True) + + def on_enter_parent(self): ... + + start = State(initial=True) + enter = start.to(parent) + + dot = DotGraphMachine(SM)().to_string() + # The compound label should contain the entry action + assert "entry" in dot.lower() or "on_enter_parent" in dot + + def test_internal_action_format(self): + """Internal action uses body directly (no 'entry /' prefix).""" + renderer = DotRenderer() + from statemachine.contrib.diagram.model import DiagramAction + + action = DiagramAction(type=ActionType.INTERNAL, body="check / log_status") + result = renderer._format_action(action) + assert result == "check / log_status" + + def test_targetless_transition_self_loop(self): + """Transition with no target falls back to source as destination.""" + from statemachine.contrib.diagram.model import DiagramTransition + + transition = DiagramTransition(source="s1", targets=[], event="tick") + renderer = DotRenderer() + renderer._compound_ids = set() + edges = renderer._create_edges(transition) + assert len(edges) == 1 + # With no targets, target_ids becomes [None], and dst becomes source + assert edges[0].obj_dict["points"][1] == "s1" + + def test_compound_edge_anchor_non_bidirectional(self): + """Non-bidirectional compound state uses generic _anchor node.""" + renderer = DotRenderer() + renderer._compound_bidir_ids = {"other"} + result = renderer._compound_edge_anchor("my_state", "out") + assert result == "my_state_anchor" + + +class TestDiagramMainModule: + """Tests for __main__.py.""" + + def test_main_module_execution(self, tmp_path): + """python -m statemachine.contrib.diagram works.""" + import runpy + + out = tmp_path / "sm.svg" + with mock.patch( + "sys.argv", + [ + "statemachine.contrib.diagram", + "tests.examples.traffic_light_machine.TrafficLightMachine", + str(out), + ], + ): + with pytest.raises(SystemExit) as exc_info: + runpy.run_module( + "statemachine.contrib.diagram", run_name="__main__", alter_sys=True + ) + assert exc_info.value.code is None + assert out.exists() + + +class TestSphinxDirective: + """Unit tests for the statemachine-diagram Sphinx directive.""" + + def test_parse_events(self): + from statemachine.contrib.diagram.sphinx_ext import _parse_events + + assert _parse_events("start, ship") == ["start", "ship"] + assert _parse_events("single") == ["single"] + assert _parse_events(" a , b , c ") == ["a", "b", "c"] + assert _parse_events("") == [] + + def test_import_and_render_class(self, tmp_path): + """Directive logic: import a class and generate SVG.""" + from statemachine.contrib.diagram import DotGraphMachine + from statemachine.contrib.diagram import import_sm + + sm_class = import_sm("tests.examples.order_control_machine.OrderControl") + graph = DotGraphMachine(sm_class).get_graph() + svg_bytes = graph.create_svg() + assert svg_bytes.startswith(b"\n\n' + b'' + b"" + ) + directive = self._make_directive() + svg_tag, _, _ = directive._prepare_svg(svg_bytes) + + assert not svg_tag.startswith("" in svg_tag + + def test_extracts_intrinsic_dimensions(self): + svg_bytes = b'' + directive = self._make_directive() + _, w, h = directive._prepare_svg(svg_bytes) + + assert w == "702pt" + assert h == "170pt" + + def test_removes_fixed_dimensions(self): + svg_bytes = b'' + directive = self._make_directive() + svg_tag, _, _ = directive._prepare_svg(svg_bytes) + + assert 'width="702pt"' not in svg_tag + assert 'height="170pt"' not in svg_tag + assert "viewBox" in svg_tag + + def test_handles_no_dimensions(self): + svg_bytes = b'' + directive = self._make_directive() + _, w, h = directive._prepare_svg(svg_bytes) + + assert w == "" + assert h == "" + + def test_handles_px_dimensions(self): + svg_bytes = b'' + directive = self._make_directive() + _, w, h = directive._prepare_svg(svg_bytes) + + assert w == "200px" + assert h == "100px" + + +class TestBuildSvgStyles: + """Tests for StateMachineDiagram._build_svg_styles.""" + + def _make_directive(self, options=None): + from statemachine.contrib.diagram.sphinx_ext import StateMachineDiagram + + directive = StateMachineDiagram.__new__(StateMachineDiagram) + directive.options = options or {} + return directive + + def test_intrinsic_width_as_max_width(self): + directive = self._make_directive() + result = directive._build_svg_styles("702pt", "170pt") + assert "max-width: 702pt" in result + assert "height: auto" in result + + def test_explicit_width(self): + directive = self._make_directive({"width": "400px"}) + result = directive._build_svg_styles("702pt", "170pt") + assert "width: 400px" in result + assert "max-width" not in result + + def test_explicit_height(self): + directive = self._make_directive({"height": "200px"}) + result = directive._build_svg_styles("702pt", "170pt") + assert "height: 200px" in result + assert "height: auto" not in result + + def test_scale(self): + directive = self._make_directive({"scale": "50%"}) + result = directive._build_svg_styles("702pt", "170pt") + assert "width: 351.0pt" in result + assert "height: 85.0pt" in result + + def test_scale_without_intrinsic(self): + directive = self._make_directive({"scale": "50%"}) + result = directive._build_svg_styles("", "") + # No width/height when no intrinsic dimensions to scale + assert "max-width" not in result + assert "height: auto" in result + + def test_no_dimensions(self): + directive = self._make_directive() + result = directive._build_svg_styles("", "") + assert "height: auto" in result + + def test_explicit_width_overrides_scale(self): + directive = self._make_directive({"width": "300px", "scale": "50%"}) + result = directive._build_svg_styles("702pt", "170pt") + assert "width: 300px" in result + assert "351" not in result + + +class TestBuildWrapperClasses: + """Tests for StateMachineDiagram._build_wrapper_classes.""" + + def _make_directive(self, options=None): + from statemachine.contrib.diagram.sphinx_ext import StateMachineDiagram + + directive = StateMachineDiagram.__new__(StateMachineDiagram) + directive.options = options or {} + return directive + + def test_default_center_align(self): + directive = self._make_directive() + classes = directive._build_wrapper_classes() + assert classes == ["statemachine-diagram", "align-center"] + + def test_custom_align(self): + directive = self._make_directive({"align": "left"}) + classes = directive._build_wrapper_classes() + assert classes == ["statemachine-diagram", "align-left"] + + def test_extra_css_classes(self): + directive = self._make_directive({"class": ["my-class", "another"]}) + classes = directive._build_wrapper_classes() + assert classes == ["statemachine-diagram", "align-center", "my-class", "another"] + + +class TestResolveTarget: + """Tests for StateMachineDiagram._resolve_target.""" + + def _make_directive(self, options=None, tmp_path=None): + from statemachine.contrib.diagram.sphinx_ext import StateMachineDiagram + + directive = StateMachineDiagram.__new__(StateMachineDiagram) + directive.options = options or {} + directive.arguments = ["my.module.MyMachine"] + if tmp_path is not None: + directive.state = mock.MagicMock() + directive.state.document.settings.env.app.outdir = str(tmp_path) + return directive + + def test_no_target_option(self): + directive = self._make_directive() + assert directive._resolve_target(b"") == "" + + def test_explicit_target_url(self): + directive = self._make_directive({"target": "https://example.com/diagram.svg"}) + assert directive._resolve_target(b"") == "https://example.com/diagram.svg" + + def test_empty_target_generates_file(self, tmp_path): + directive = self._make_directive({"target": ""}, tmp_path=tmp_path) + svg_data = b"" + result = directive._resolve_target(svg_data) + + assert result.startswith("/_images/statemachine-") + assert result.endswith(".svg") + + # Verify the file was written + images_dir = tmp_path / "_images" + svg_files = list(images_dir.glob("statemachine-*.svg")) + assert len(svg_files) == 1 + assert svg_files[0].read_bytes() == svg_data + + def test_empty_target_deterministic_filename(self, tmp_path): + """Same qualname + events produces the same filename.""" + directive1 = self._make_directive({"target": "", "events": "go"}, tmp_path=tmp_path) + directive2 = self._make_directive({"target": "", "events": "go"}, tmp_path=tmp_path) + result1 = directive1._resolve_target(b"1") + result2 = directive2._resolve_target(b"2") + assert result1 == result2 + + def test_different_events_different_filename(self, tmp_path): + """Different events produce different filenames.""" + d1 = self._make_directive({"target": "", "events": "a"}, tmp_path=tmp_path) + d2 = self._make_directive({"target": "", "events": "b"}, tmp_path=tmp_path) + assert d1._resolve_target(b"") != d2._resolve_target(b"") + + +class TestDirectiveRun: + """Integration tests for StateMachineDiagram.run().""" + + _QUALNAME = "tests.examples.traffic_light_machine.TrafficLightMachine" + + def _make_directive(self, tmp_path, options=None): + from statemachine.contrib.diagram.sphinx_ext import StateMachineDiagram + + directive = StateMachineDiagram.__new__(StateMachineDiagram) + directive.options = options or {} + directive.lineno = 1 + directive.state_machine = mock.MagicMock() + directive.state = mock.MagicMock() + directive.state.document.settings.env.app.outdir = str(tmp_path) + directive.content_offset = 0 + return directive + + def _run(self, tmp_path, qualname=None, options=None): + directive = self._make_directive(tmp_path, options=options) + directive.arguments = [qualname or self._QUALNAME] + return directive, directive.run() + + def test_render_class_diagram(self, tmp_path): + """Renders a class diagram (no events) as inline SVG.""" + _, result = self._run(tmp_path) + + assert len(result) == 1 + node = result[0] + assert isinstance(node, nodes.raw) + assert node["format"] == "html" + html = node.astext() + assert " element.""" + _, result = self._run(tmp_path, options={"caption": "My caption"}) + + html = result[0].astext() + assert "My caption" in html + + def test_render_with_figclass(self, tmp_path): + """figclass adds extra CSS classes to the figure wrapper.""" + _, result = self._run(tmp_path, options={"caption": "Test", "figclass": ["extra-fig"]}) + + assert "extra-fig" in result[0].astext() + + def test_render_with_alt(self, tmp_path): + """Custom alt text appears in aria-label.""" + _, result = self._run(tmp_path, options={"alt": "Traffic light diagram"}) + + assert 'aria-label="Traffic light diagram"' in result[0].astext() + + def test_render_default_alt(self, tmp_path): + """Default alt text uses the class name from the qualname.""" + _, result = self._run(tmp_path) + + assert 'aria-label="TrafficLightMachine"' in result[0].astext() + + def test_render_with_explicit_target(self, tmp_path): + """Explicit target wraps diagram in a link.""" + _, result = self._run(tmp_path, options={"target": "https://example.com"}) + + html = result[0].astext() + assert 'href="https://example.com"' in html + assert 'target="_blank"' in html + + def test_render_with_empty_target(self, tmp_path): + """Empty target auto-generates a zoom SVG file.""" + _, result = self._run(tmp_path, options={"target": ""}) + + assert 'href="/_images/statemachine-' in result[0].astext() + images_dir = tmp_path / "_images" + assert any(images_dir.glob("statemachine-*.svg")) + + def test_render_with_align(self, tmp_path): + """Align option controls CSS class.""" + _, result = self._run(tmp_path, options={"align": "left"}) + + assert "align-left" in result[0].astext() + + def test_render_with_width(self, tmp_path): + """Width option is applied as inline style.""" + _, result = self._run(tmp_path, options={"width": "400px"}) + + assert "width: 400px" in result[0].astext() + + def test_render_with_name(self, tmp_path): + """Name option calls add_name for cross-referencing.""" + directive = self._make_directive(tmp_path, options={"name": "my-diagram"}) + directive.arguments = [self._QUALNAME] + result = directive.run() + + assert len(result) == 1 + + def test_render_with_class(self, tmp_path): + """Custom CSS classes appear in the wrapper.""" + _, result = self._run(tmp_path, options={"class": ["custom-class"]}) + + assert "custom-class" in result[0].astext() + + def test_invalid_qualname_returns_warning(self, tmp_path): + """Invalid qualname returns a warning node.""" + directive, result = self._run(tmp_path, qualname="nonexistent.module.BadMachine") + + assert len(result) == 1 + directive.state_machine.reporter.warning.assert_called_once() + call_args = directive.state_machine.reporter.warning.call_args + assert "could not import" in call_args[0][0] + + def test_render_failure_returns_warning(self, tmp_path): + """Diagram generation failure returns a warning node.""" + with mock.patch( + "statemachine.contrib.diagram.DotGraphMachine", + side_effect=RuntimeError("render failed"), + ): + directive, result = self._run(tmp_path) + + assert len(result) == 1 + directive.state_machine.reporter.warning.assert_called_once() + call_args = directive.state_machine.reporter.warning.call_args + assert "failed to generate" in call_args[0][0] + + def test_render_without_caption_uses_div(self, tmp_path): + """Without caption, the wrapper is a plain
.""" + _, result = self._run(tmp_path) + + html = result[0].astext() + assert "