diff --git a/CHANGES.md b/CHANGES.md index 7216b1a..ced8275 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,81 @@ # Changelog +## 0.5.0 (2025-12-24) + +**Major Feature Release: NetworkX Plotter API** (terapyon) + +### New Features + +- **High-level Plotter API**: Direct NetworkX graph visualization without manual JSON conversion + - `Plotter.add_networkx()` method for seamless graph rendering in JupyterLab + - Support for all 4 NetworkX graph types: Graph, DiGraph, MultiGraph, MultiDiGraph + - Automatic node/edge extraction with attribute preservation + +- **Custom Styling Support**: + - Node color mapping via attribute names or callable functions + - Node label mapping with flexible attribute selection + - Edge label mapping for relationship visualization + - Automatic color scale detection (continuous vs. categorical) + +- **Layout Control**: + - 5 built-in layout algorithms: spring, kamada_kawai, spectral, circular, random + - Custom layout function support + - Existing position attribute detection + - Automatic fallback with NaN/inf validation + +- **Multi-Graph Type Support**: + - Edge direction preservation for DiGraph (via metadata) + - Edge key preservation for MultiGraph/MultiDiGraph + - Multiple edge expansion into independent Edge objects + - Automatic graph type detection and dispatch + +### API Examples + +```python +from net_vis import Plotter +import networkx as nx + +# Basic visualization +G = nx.karate_club_graph() +plotter = Plotter(title="Karate Club Network") +plotter.add_networkx(G) + +# Custom styling +plotter.add_networkx(G, + node_color="club", + node_label=lambda d: f"Node {d.get('name', '')}", + layout='kamada_kawai' +) +``` + +### Implementation Details + +- **NetworkXAdapter**: 650+ lines of conversion logic with comprehensive type hints +- **Test Coverage**: 60+ test methods covering all public APIs +- **Python 3.10+ type hints**: Full type annotation support +- **Comprehensive docstrings**: All public methods documented + +### Installation Options + +- **Basic**: `pip install net_vis` - Includes spring, circular, and random layouts +- **Full**: `pip install net_vis[full]` - Includes all layouts (adds SciPy for kamada_kawai and spectral) + +### Dependencies + +**Core:** + +- NetworkX 3.0+ +- NumPy 2.0+ (required for layout algorithms) + +**Optional (installed with [full]):** + +- SciPy 1.8+ (required for kamada_kawai and spectral layouts) + +### Compatibility + +- JupyterLab: 3.x and 4.x +- Python: 3.10+ + ## 0.4.0 (2025-11-21) **Major Release: Migration to MIME Renderer Architecture** (terapyon) diff --git a/README.md b/README.md index bee73e5..4a65e03 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,77 @@ # netvis -NetVis is a package for interactive visualization of Python NetworkX graphs within JupyterLab. It leverages D3.js for dynamic rendering and supports HTML export, making network analysis effortless. +NetVis is a package for interactive visualization of Python NetworkX graphs within JupyterLab. It leverages D3.js for dynamic rendering and provides a high-level Plotter API for effortless network analysis. -**Version 0.4.0** introduces a MIME renderer architecture that simplifies installation and improves compatibility with modern JupyterLab environments. +**Version 0.5.0** introduces the NetworkX Plotter API, enabling direct visualization of NetworkX graph objects without manual JSON conversion. ## Installation +### Basic Installation + You can install using `pip`: ```bash pip install net_vis ``` -**Note for version 0.4.0+**: The nbextension is no longer required. NetVis now uses a MIME renderer that works automatically in JupyterLab 3.x and 4.x environments. +This provides core functionality with layouts: **spring**, **circular**, and **random**. + +### Full Installation (Recommended) + +For all layout algorithms including **kamada_kawai** and **spectral**: + +```bash +pip install net_vis[full] +``` + +This installs optional dependencies (scipy) required for advanced layout algorithms. + +**Note**: NetVis uses a MIME renderer that works automatically in JupyterLab 3.x and 4.x environments. No manual extension enabling is required. ## Quick Start -This section provides a simple guide to get started with the project using JupyterLab. +### NetworkX Plotter API (New in v0.5.0) + +The easiest way to visualize NetworkX graphs in JupyterLab: + +```python +from net_vis import Plotter +import networkx as nx + +# Create a NetworkX graph +G = nx.karate_club_graph() + +# Visualize with one line +plotter = Plotter(title="Karate Club Network") +plotter.add_networkx(G) +``` + +#### Custom Styling + +Control node colors, labels, and layouts: + +```python +# Color nodes by attribute, customize labels +plotter = Plotter(title="Styled Network") +plotter.add_networkx( + G, + node_color="club", # Use 'club' attribute for colors + node_label=lambda d: f"Node {d.get('name', '')}", # Custom labels + edge_label="weight", # Show edge weights + layout='kamada_kawai' # Choose layout algorithm +) +``` + +#### Supported Features + +- **Graph Types**: Graph, DiGraph, MultiGraph, MultiDiGraph +- **Layouts**: spring (default), kamada_kawai, spectral, circular, random, or custom functions +- **Styling**: Attribute-based or function-based color/label mapping +- **Automatic**: Node/edge attribute preservation in metadata + +### Low-Level API (Advanced) -### Example +For manual control over the visualization data structure: ```python import net_vis diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst index 3f052cc..6e8c566 100644 --- a/docs/source/examples/index.rst +++ b/docs/source/examples/index.rst @@ -4,10 +4,31 @@ Examples This section contains examples of using NetVis for interactive graph visualization in JupyterLab. -Basic Usage ------------ -The most basic usage of NetVis:: +NetworkX Plotter API (Recommended) +----------------------------------- + +**New in v0.5.0**: The easiest way to visualize NetworkX graphs:: + + from net_vis import Plotter + import networkx as nx + + # Create and visualize NetworkX graph + G = nx.karate_club_graph() + plotter = Plotter(title="Karate Club Network") + plotter.add_networkx(G) + +For comprehensive examples including custom styling, layouts, and multi-graph support, see the :ref:`NetworkX Plotter API notebook ` below. + + +Low-Level API +------------- + +For advanced control with manual JSON, you can use the low-level NetVis API: + +**Basic Usage** + +The most basic usage of the low-level API:: import net_vis @@ -63,7 +84,11 @@ NetVis can handle large graphs efficiently. The force-directed layout automatica For more examples, see the `examples directory `_ in the GitHub repository. +Detailed Examples +----------------- + .. toctree:: - :glob: + :maxdepth: 2 - * + networkx_plotter + introduction diff --git a/docs/source/examples/networkx_plotter.nblink b/docs/source/examples/networkx_plotter.nblink new file mode 100644 index 0000000..6f0de5e --- /dev/null +++ b/docs/source/examples/networkx_plotter.nblink @@ -0,0 +1,3 @@ +{ + "path": "../../../examples/networkx_plotter.ipynb" +} diff --git a/docs/source/index.rst b/docs/source/index.rst index 0573c4d..c7dd332 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -4,9 +4,9 @@ net_vis Version: |release| -NetVis is a package for interactive visualization of Python NetworkX graphs within JupyterLab. It leverages D3.js for dynamic rendering and supports HTML export, making network analysis effortless. +NetVis is a package for interactive visualization of Python NetworkX graphs within JupyterLab. It leverages D3.js for dynamic rendering and provides a high-level Plotter API for effortless network analysis. -**Version 0.4.0** introduces a MIME renderer architecture that simplifies installation and improves compatibility with modern JupyterLab environments. +**Version 0.5.0** introduces the NetworkX Plotter API, enabling direct visualization of NetworkX graph objects without manual JSON conversion. Quickstart @@ -16,7 +16,17 @@ To get started with net_vis, install with pip:: pip install net_vis -**Note**: As of version 0.4.0, NetVis uses a MIME renderer that works automatically in JupyterLab 3.x and 4.x. Manual extension enabling is no longer required. +**NetworkX Plotter API (New in v0.5.0)**:: + + from net_vis import Plotter + import networkx as nx + + # Create and visualize NetworkX graph + G = nx.karate_club_graph() + plotter = Plotter(title="Karate Club Network") + plotter.add_networkx(G) + +**Note**: NetVis uses a MIME renderer that works automatically in JupyterLab 3.x and 4.x. Manual extension enabling is not required. Contents diff --git a/docs/source/installing.rst b/docs/source/installing.rst index 9daa2cc..f1974e9 100644 --- a/docs/source/installing.rst +++ b/docs/source/installing.rst @@ -4,19 +4,47 @@ Installation ============ +Basic Installation +------------------ The simplest way to install net_vis is via pip:: pip install net_vis +This provides core functionality with the following layout algorithms: + +- **spring** (force-directed) +- **circular** +- **random** + +Full Installation (Recommended) +-------------------------------- + +For all layout algorithms including advanced options:: + + pip install net_vis[full] + +This installs optional dependencies (scipy) and enables additional layouts: + +- **kamada_kawai** (stress-minimization) +- **spectral** (eigenvalue-based) + **That's it!** As of version 0.4.0, NetVis uses a MIME renderer that works automatically in JupyterLab 3.x and 4.x environments. No additional installation or configuration steps are required. Requirements ------------ +**Core Dependencies:** + - Python 3.10 or later - JupyterLab 3.x or 4.x +- NetworkX 3.0+ (automatically installed) +- NumPy 2.0+ (automatically installed, required for layout algorithms) + +**Optional Dependencies (installed with [full]):** + +- SciPy 1.8+ (required for kamada_kawai and spectral layouts) **Note**: Jupyter Notebook Classic is no longer supported as of version 0.4.0. Please use JupyterLab instead. diff --git a/docs/source/introduction.rst b/docs/source/introduction.rst index dc7c9c4..0e82513 100644 --- a/docs/source/introduction.rst +++ b/docs/source/introduction.rst @@ -8,17 +8,47 @@ NetVis is a package for interactive visualization of Python NetworkX graphs with Key Features ------------ +- **NetworkX Plotter API (v0.5.0)**: Direct visualization of NetworkX graphs without JSON conversion - **Interactive D3.js Visualization**: Force-directed graph layout with interactive node dragging, zooming, and panning -- **Simple Python API**: Works seamlessly with NetworkX graph data structures +- **Multiple Graph Types**: Support for Graph, DiGraph, MultiGraph, and MultiDiGraph +- **Layout Control**: 5 built-in algorithms (spring, kamada_kawai, spectral, circular, random) plus custom functions +- **Custom Styling**: Attribute-based or function-based color and label mapping - **MIME Renderer Architecture**: Automatic rendering in JupyterLab 3.x and 4.x without manual extension configuration -- **Customizable Appearance**: Support for custom node colors, sizes, and categories - **Modern Stack**: Built with TypeScript and modern JupyterLab extension architecture -Quick Example -------------- +Quick Example (NetworkX Plotter API) +------------------------------------- -Here's a simple example to get you started:: +The easiest way to visualize NetworkX graphs (new in v0.5.0):: + + from net_vis import Plotter + import networkx as nx + + # Create NetworkX graph + G = nx.karate_club_graph() + + # Visualize with one line + plotter = Plotter(title="Karate Club Network") + plotter.add_networkx(G) + + # Custom styling + plotter = Plotter() + plotter.add_networkx( + G, + node_color="club", # Use 'club' attribute for colors + node_label=lambda d: f"Node {d.get('name', '')}", + edge_label="weight", + layout='kamada_kawai' # Choose layout algorithm + ) + +When executed in JupyterLab, this displays an interactive force-directed graph. + + +Low-Level API Example +---------------------- + +For advanced control, you can use the low-level API with manual JSON:: import net_vis @@ -47,13 +77,43 @@ When executed in JupyterLab, this displays an interactive force-directed graph w - **Click nodes** to pin/unpin them +What's New in 0.5.0 +------------------- + +Version 0.5.0 introduces the **NetworkX Plotter API**, a high-level interface for visualizing NetworkX graphs: + +**NetworkX Plotter API** + - Direct visualization of NetworkX graph objects without manual JSON conversion + - Support for all 4 NetworkX graph types: Graph, DiGraph, MultiGraph, MultiDiGraph + - Automatic node and edge extraction with full attribute preservation + +**Layout Control** + - 5 built-in layout algorithms: spring, kamada_kawai, spectral, circular, random + - Custom layout function support + - Automatic fallback for invalid positions (NaN/inf) + +**Custom Styling** + - Node color mapping via attribute names or callable functions + - Node label mapping with flexible attribute selection + - Edge label mapping for relationship visualization + - Automatic color scale detection (continuous vs. categorical) + +**Multi-Graph Type Support** + - Edge direction preservation for DiGraph (stored in metadata) + - Edge key preservation for MultiGraph/MultiDiGraph + - Multiple edges expanded to independent Edge objects + - Automatic graph type detection and dispatch + +See the :doc:`examples/index` for complete usage examples. + + Architecture (v0.4.0) --------------------- -Version 0.4.0 introduces a major architectural change: +Version 0.4.0 introduced a major architectural change: **MIME Renderer** - NetVis now uses JupyterLab's MIME renderer system instead of ipywidgets. This means: + NetVis uses JupyterLab's MIME renderer system instead of ipywidgets. This means: - Simpler installation (no manual extension enabling) - Better performance and integration with JupyterLab @@ -62,19 +122,8 @@ Version 0.4.0 introduces a major architectural change: **JupyterLab Only** NetVis 0.4.0+ exclusively supports JupyterLab 3.x and 4.x. Jupyter Notebook Classic is no longer supported. -**Python API Unchanged** - Despite the internal changes, the Python API remains 100% compatible with previous versions. - - -What's New in 0.4.0 -------------------- - -- **MIME renderer architecture** replacing ipywidgets -- **Simplified installation** - just ``pip install net_vis`` -- **Removed nbextension support** - JupyterLab only -- **Python 3.10+ support** including 3.13 and 3.14 -- **Comprehensive test suite** with 41 TypeScript tests and 16 Python tests -- **Code quality tools** - ruff and pyright for Python linting and type checking +**Python API** + The low-level NetVis API remains compatible with previous versions, and the new Plotter API provides a higher-level interface. Migrating from 0.3.x diff --git a/examples/networkx_plotter.ipynb b/examples/networkx_plotter.ipynb new file mode 100644 index 0000000..169ddeb --- /dev/null +++ b/examples/networkx_plotter.ipynb @@ -0,0 +1,276 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# NetworkX Plotter API\n", + "\n", + "This notebook demonstrates the NetworkX Plotter API introduced in version 0.5.0.\n", + "\n", + "The Plotter API provides a high-level interface for visualizing NetworkX graphs directly in JupyterLab without manual JSON conversion." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic Usage\n", + "\n", + "The simplest way to visualize a NetworkX graph:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from net_vis import Plotter\n", + "import networkx as nx" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a simple graph\n", + "G = nx.karate_club_graph()\n", + "\n", + "# Visualize with one line\n", + "plotter = Plotter(title=\"Karate Club Network\")\n", + "plotter.add_networkx(G)\n", + "plotter" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom Styling with Attributes\n", + "\n", + "Map node colors and labels from graph attributes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The karate club graph has a 'club' attribute\n", + "plotter2 = Plotter(title=\"Karate Club - Colored by Club\")\n", + "plotter2.add_networkx(\n", + " G,\n", + " node_color=\"club\", # Use 'club' attribute for colors\n", + " layout='kamada_kawai' # Use Kamada-Kawai layout\n", + ")\n", + "plotter2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom Styling with Functions\n", + "\n", + "Use callable functions for more complex styling logic:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a graph with custom attributes\n", + "G2 = nx.Graph()\n", + "G2.add_node(1, name=\"Alice\", value=10)\n", + "G2.add_node(2, name=\"Bob\", value=20)\n", + "G2.add_node(3, name=\"Charlie\", value=30)\n", + "G2.add_edge(1, 2, relation=\"friend\", weight=5.0)\n", + "G2.add_edge(2, 3, relation=\"colleague\", weight=3.0)\n", + "G2.add_edge(1, 3, relation=\"neighbor\", weight=2.0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use functions for custom styling\n", + "plotter3 = Plotter(title=\"Social Network with Custom Styling\")\n", + "plotter3.add_networkx(\n", + " G2,\n", + " node_color=lambda d: f\"value_{d.get('value', 0)}\",\n", + " node_label=lambda d: d.get('name', 'Unknown'),\n", + " edge_label=lambda d: f\"{d.get('relation', '')} ({d.get('weight', 0)})\",\n", + " layout='circular'\n", + ")\n", + "plotter3" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Layout Algorithms\n", + "\n", + "NetVis supports 5 built-in layout algorithms:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a test graph\n", + "G3 = nx.random_geometric_graph(20, 0.3)\n", + "\n", + "layouts = ['spring', 'kamada_kawai', 'spectral', 'circular', 'random']\n", + "\n", + "for layout_name in layouts:\n", + " plotter = Plotter(title=f\"Layout: {layout_name}\")\n", + " plotter.add_networkx(G3, layout=layout_name)\n", + " display(plotter)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Directed Graphs (DiGraph)\n", + "\n", + "DiGraph edges have direction preserved in metadata:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a directed graph\n", + "DG = nx.DiGraph()\n", + "DG.add_edges_from([(1, 2), (1, 3), (2, 3), (3, 4), (4, 2)])\n", + "\n", + "plotter4 = Plotter(title=\"Directed Graph\")\n", + "plotter4.add_networkx(DG, layout='spring')\n", + "plotter4" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MultiGraph Support\n", + "\n", + "MultiGraph allows multiple edges between the same pair of nodes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a multigraph\n", + "MG = nx.MultiGraph()\n", + "MG.add_edge(1, 2, relation=\"friend\", weight=1.0)\n", + "MG.add_edge(1, 2, relation=\"colleague\", weight=2.0)\n", + "MG.add_edge(2, 3, relation=\"family\", weight=5.0)\n", + "\n", + "plotter5 = Plotter(title=\"MultiGraph with Multiple Edges\")\n", + "plotter5.add_networkx(\n", + " MG,\n", + " edge_label=\"relation\",\n", + " layout='spring'\n", + ")\n", + "plotter5" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom Layout Functions\n", + "\n", + "You can provide your own layout function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def grid_layout(graph):\n", + " \"\"\"Place nodes in a grid layout.\"\"\"\n", + " import math\n", + " n = len(graph.nodes())\n", + " cols = int(math.ceil(math.sqrt(n)))\n", + " \n", + " positions = {}\n", + " for i, node in enumerate(graph.nodes()):\n", + " row = i // cols\n", + " col = i % cols\n", + " positions[node] = (col * 0.2, row * 0.2)\n", + " \n", + " return positions\n", + "\n", + "G4 = nx.complete_graph(16)\n", + "plotter6 = Plotter(title=\"Custom Grid Layout\")\n", + "plotter6.add_networkx(G4, layout=grid_layout)\n", + "plotter6" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Export to JSON\n", + "\n", + "You can export the scene structure as JSON:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter = Plotter()\n", + "G = nx.path_graph(5)\n", + "plotter.add_networkx(G)\n", + "\n", + "json_data = plotter.to_json()\n", + "print(json_data[:500] + \"...\") # Print first 500 characters" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/net_vis/__init__.py b/net_vis/__init__.py index 8a05c26..fd08357 100644 --- a/net_vis/__init__.py +++ b/net_vis/__init__.py @@ -1,2 +1,3 @@ from ._version import __version__, version_info from .netvis import NetVis +from .plotter import Plotter diff --git a/net_vis/_frontend.py b/net_vis/_frontend.py deleted file mode 100644 index c304a31..0000000 --- a/net_vis/_frontend.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Information about the frontend package of the widgets. -""" - -module_name = "net_vis" -module_version = "^0.4.0" diff --git a/net_vis/_version.py b/net_vis/_version.py index 2bb206e..009fe53 100644 --- a/net_vis/_version.py +++ b/net_vis/_version.py @@ -1,2 +1,2 @@ -version_info = (0, 4, 0) +version_info = (0, 5, 0) __version__ = ".".join(map(str, version_info)) diff --git a/net_vis/adapters/__init__.py b/net_vis/adapters/__init__.py new file mode 100644 index 0000000..d30348c --- /dev/null +++ b/net_vis/adapters/__init__.py @@ -0,0 +1,5 @@ +"""Adapters for converting graph formats to netvis data structures.""" + +from net_vis.adapters.networkx_adapter import NetworkXAdapter + +__all__ = ["NetworkXAdapter"] diff --git a/net_vis/adapters/networkx_adapter.py b/net_vis/adapters/networkx_adapter.py new file mode 100644 index 0000000..aa4044c --- /dev/null +++ b/net_vis/adapters/networkx_adapter.py @@ -0,0 +1,650 @@ +"""NetworkX graph adapter for converting to netvis data structures.""" + +import warnings +from collections.abc import Callable +from typing import Any + +import networkx as nx + +from ..models import Edge, GraphLayer, Node + + +class NetworkXAdapter: + """Converts NetworkX graph objects to netvis GraphLayer format. + + Handles node/edge extraction, attribute preservation, layout computation, + and visual property mapping for all NetworkX graph types (Graph, DiGraph, + MultiGraph, MultiDiGraph). + """ + + @staticmethod + def _detect_graph_type(graph: Any) -> str: + """Detect NetworkX graph type. + + Args: + graph: NetworkX graph object + + Returns: + Graph type string: 'graph', 'digraph', 'multigraph', 'multidigraph' + """ + # Check class name to determine type + class_name = type(graph).__name__.lower() + + if "multidigraph" in class_name: + return "multidigraph" + elif "multigraph" in class_name: + return "multigraph" + elif "digraph" in class_name: + return "digraph" + else: + return "graph" + + @staticmethod + def _extract_nodes( + graph: Any, + positions: dict[Any, Any], + node_color: str | Callable | None = None, + node_label: str | Callable | None = None, + ) -> list[Node]: + """Extract nodes from NetworkX graph with ID conversion to string. + + Args: + graph: NetworkX graph object + positions: Dictionary mapping node IDs to (x, y) positions + node_color: Attribute name or function for color mapping + node_label: Attribute name or function for label mapping + + Returns: + List of Node objects with positions and metadata + """ + nodes = [] + + for node_id in graph.nodes(): + # Convert node ID to string + node_id_str = str(node_id) + + # Get position from layout (default to (0, 0) if missing) + x, y = positions.get(node_id, (0.0, 0.0)) + + # Get node attributes and preserve them in metadata + node_attrs = dict(graph.nodes[node_id]) if graph.nodes[node_id] else {} + + # Apply color mapping + color = NetworkXAdapter._map_node_color(node_id, node_attrs, node_color) + + # Apply label mapping + label = NetworkXAdapter._map_node_label(node_id, node_attrs, node_label) + + # Create Node object + node = Node( + id=node_id_str, + x=float(x), + y=float(y), + color=color, + label=label, + metadata=node_attrs, + ) + + nodes.append(node) + + return nodes + + @staticmethod + def _extract_edges( + graph: Any, + edge_label: str | Callable | None = None, + ) -> list[Edge]: + """Extract edges from NetworkX graph with automatic type dispatch. + + Args: + graph: NetworkX graph object + edge_label: Attribute name or function for label mapping + + Returns: + List of Edge objects with metadata + """ + # Detect graph type and dispatch to appropriate extractor + graph_type = NetworkXAdapter._detect_graph_type(graph) + + if graph_type in ("multigraph", "multidigraph"): + return NetworkXAdapter._expand_multigraph_edges(graph, edge_label) + elif graph_type == "digraph": + return NetworkXAdapter._extract_edges_digraph(graph, edge_label) + else: + # Basic Graph type + return NetworkXAdapter._extract_edges_simple(graph, edge_label) + + @staticmethod + def _extract_edges_simple( + graph: Any, + edge_label: str | Callable | None = None, + ) -> list[Edge]: + """Extract edges from NetworkX Graph (undirected, simple). + + Args: + graph: NetworkX graph object + edge_label: Attribute name or function for label mapping + + Returns: + List of Edge objects with metadata + """ + edges = [] + + for source, target in graph.edges(): + # Convert node IDs to strings + source_str = str(source) + target_str = str(target) + + # Get edge attributes and preserve them in metadata + edge_attrs = dict(graph[source][target]) if graph[source][target] else {} + + # Apply label mapping + label = NetworkXAdapter._map_edge_label(edge_attrs, edge_label) + + # Create Edge object + edge = Edge(source=source_str, target=target_str, label=label, metadata=edge_attrs) + + edges.append(edge) + + return edges + + @staticmethod + def _extract_edges_digraph( + graph: Any, + edge_label: str | Callable | None = None, + ) -> list[Edge]: + """Extract edges from NetworkX DiGraph (directed). + + Args: + graph: NetworkX DiGraph object + edge_label: Attribute name or function for label mapping + + Returns: + List of Edge objects with direction preserved in metadata + """ + edges = [] + + for source, target in graph.edges(): + # Convert node IDs to strings + source_str = str(source) + target_str = str(target) + + # Get edge attributes and preserve them in metadata + edge_attrs = dict(graph[source][target]) if graph[source][target] else {} + + # Add direction indicator to metadata for DiGraph + edge_attrs["directed"] = True + + # Apply label mapping + label = NetworkXAdapter._map_edge_label(edge_attrs, edge_label) + + # Create Edge object + edge = Edge(source=source_str, target=target_str, label=label, metadata=edge_attrs) + + edges.append(edge) + + return edges + + @staticmethod + def _expand_multigraph_edges( + graph: Any, + edge_label: str | Callable | None = None, + ) -> list[Edge]: + """Extract and expand edges from NetworkX MultiGraph/MultiDiGraph. + + Multiple edges between the same pair of nodes are expanded into + independent Edge objects, with edge keys preserved in metadata. + + Args: + graph: NetworkX MultiGraph or MultiDiGraph object + edge_label: Attribute name or function for label mapping + + Returns: + List of Edge objects with edge keys preserved in metadata + """ + edges = [] + + # Check if this is a directed multigraph + graph_type = NetworkXAdapter._detect_graph_type(graph) + is_directed = graph_type == "multidigraph" + + # MultiGraph.edges() returns (source, target, key) tuples + for source, target, key in graph.edges(keys=True): + # Convert node IDs to strings + source_str = str(source) + target_str = str(target) + + # Get edge attributes for this specific edge key + edge_attrs = dict(graph[source][target][key]) if graph[source][target][key] else {} + + # Preserve edge key in metadata + edge_attrs["edge_key"] = key + + # Add direction indicator for MultiDiGraph + if is_directed: + edge_attrs["directed"] = True + + # Apply label mapping + label = NetworkXAdapter._map_edge_label(edge_attrs, edge_label) + + # Create Edge object + edge = Edge(source=source_str, target=target_str, label=label, metadata=edge_attrs) + + edges.append(edge) + + return edges + + @staticmethod + def _get_existing_positions(graph: Any) -> dict[Any, Any] | None: + """Extract existing 'pos' attribute from nodes. + + Args: + graph: NetworkX graph object + + Returns: + Dictionary mapping node IDs to (x, y) positions, or None if not available + """ + positions = {} + has_positions = False + + for node_id in graph.nodes(): + node_data = graph.nodes[node_id] + if "pos" in node_data: + positions[node_id] = node_data["pos"] + has_positions = True + + return positions if has_positions else None + + @staticmethod + def _apply_spring_layout(graph: Any) -> dict[Any, Any]: + """Apply spring (force-directed) layout. + + Args: + graph: NetworkX graph object + + Returns: + Dictionary mapping node IDs to (x, y) positions + """ + return nx.spring_layout(graph) + + @staticmethod + def _apply_kamada_kawai_layout(graph: Any) -> dict[Any, Any]: + """Apply Kamada-Kawai layout. + + Args: + graph: NetworkX graph object + + Returns: + Dictionary mapping node IDs to (x, y) positions + + Raises: + ImportError: If scipy is not installed + """ + try: + import scipy # type: ignore[import-not-found] # noqa: F401 + except ImportError: + raise ImportError( + "Layout 'kamada_kawai' requires scipy. Install with: pip install net_vis[full]" + ) + + return nx.kamada_kawai_layout(graph) + + @staticmethod + def _apply_spectral_layout(graph: Any) -> dict[Any, Any]: + """Apply spectral layout. + + Args: + graph: NetworkX graph object + + Returns: + Dictionary mapping node IDs to (x, y) positions + + Raises: + ImportError: If scipy is not installed + """ + try: + import scipy # type: ignore[import-not-found] # noqa: F401 + except ImportError: + raise ImportError( + "Layout 'spectral' requires scipy. Install with: pip install net_vis[full]" + ) + + return nx.spectral_layout(graph) + + @staticmethod + def _apply_circular_layout(graph: Any) -> dict[Any, Any]: + """Apply circular layout. + + Args: + graph: NetworkX graph object + + Returns: + Dictionary mapping node IDs to (x, y) positions + """ + return nx.circular_layout(graph) + + @staticmethod + def _apply_random_layout(graph: Any) -> dict[Any, Any]: + """Apply random layout. + + Args: + graph: NetworkX graph object + + Returns: + Dictionary mapping node IDs to (x, y) positions + """ + return nx.random_layout(graph) + + @staticmethod + def _apply_custom_layout(graph: Any, layout_func: Callable) -> dict[Any, Any]: + """Apply custom layout function. + + Args: + graph: NetworkX graph object + layout_func: Custom function that takes graph and returns position dict + + Returns: + Dictionary mapping node IDs to (x, y) positions + """ + return layout_func(graph) + + @staticmethod + def _validate_positions(positions: dict[Any, Any]) -> bool: + """Validate that positions don't contain NaN or inf values. + + Args: + positions: Dictionary mapping node IDs to (x, y) positions + + Returns: + True if valid, False otherwise + """ + import math + + for node_id, (x, y) in positions.items(): + if math.isnan(x) or math.isnan(y) or math.isinf(x) or math.isinf(y): + return False + return True + + @staticmethod + def _compute_layout(graph: Any, layout: str | Callable | None = None) -> dict[Any, Any]: + """Compute node positions using specified layout algorithm. + + Args: + graph: NetworkX graph object + layout: Layout algorithm name, custom function, or None + + Returns: + Dictionary mapping node IDs to (x, y) positions + """ + # Handle empty graphs + if len(graph.nodes()) == 0: + return {} + + # Determine which layout to use + positions = None + + if layout is None: + # Try to use existing 'pos' attribute, fall back to spring + positions = NetworkXAdapter._get_existing_positions(graph) + if positions is None: + try: + positions = NetworkXAdapter._apply_spring_layout(graph) + except Exception as e: + warnings.warn(f"Spring layout failed: {e}, falling back to random layout") + positions = NetworkXAdapter._apply_random_layout(graph) + elif callable(layout): + # Custom layout function + try: + positions = NetworkXAdapter._apply_custom_layout(graph, layout) + except Exception as e: + warnings.warn(f"Custom layout failed: {e}, falling back to random layout") + positions = NetworkXAdapter._apply_random_layout(graph) + else: + # Named layout algorithm + layout_str = str(layout).lower() + try: + if layout_str == "spring": + positions = NetworkXAdapter._apply_spring_layout(graph) + elif layout_str == "kamada_kawai": + positions = NetworkXAdapter._apply_kamada_kawai_layout(graph) + elif layout_str == "spectral": + positions = NetworkXAdapter._apply_spectral_layout(graph) + elif layout_str == "circular": + positions = NetworkXAdapter._apply_circular_layout(graph) + elif layout_str == "random": + positions = NetworkXAdapter._apply_random_layout(graph) + else: + warnings.warn(f"Unknown layout '{layout}', using spring layout") + positions = NetworkXAdapter._apply_spring_layout(graph) + except Exception as e: + warnings.warn(f"Layout '{layout}' failed: {e}, falling back to random layout") + positions = NetworkXAdapter._apply_random_layout(graph) + + # Validate positions + if not NetworkXAdapter._validate_positions(positions): + warnings.warn( + "Layout produced invalid positions (NaN/inf), falling back to random layout" + ) + positions = NetworkXAdapter._apply_random_layout(graph) + + return positions + + @staticmethod + def convert_graph( + graph: Any, + layout: str | Callable | None = None, + node_color: str | Callable | None = None, + node_label: str | Callable | None = None, + edge_label: str | Callable | None = None, + ) -> GraphLayer: + """Convert NetworkX graph to GraphLayer with layout and styling. + + Args: + graph: NetworkX graph object + layout: Layout algorithm name, custom function, or None + node_color: Attribute name or function for node color mapping + node_label: Attribute name or function for node label mapping + edge_label: Attribute name or function for edge label mapping + + Returns: + GraphLayer object with nodes, edges, and metadata + + Raises: + ValueError: If layout computation fails + """ + # Detect graph type + graph_type = NetworkXAdapter._detect_graph_type(graph) + + # Compute layout positions + positions = NetworkXAdapter._compute_layout(graph, layout=layout) + + # Extract nodes with positions and styling + nodes = NetworkXAdapter._extract_nodes( + graph, + positions, + node_color=node_color, + node_label=node_label, + ) + + # Extract edges with styling + edges = NetworkXAdapter._extract_edges( + graph, + edge_label=edge_label, + ) + + # Create GraphLayer with metadata + layer = GraphLayer( + layer_id="", # Will be set by Plotter + nodes=nodes, + edges=edges, + metadata={"graph_type": graph_type}, + ) + + return layer + + @staticmethod + def _map_node_color( + node_id: Any, node_data: dict, mapping: str | Callable | None + ) -> str | None: + """Map node attribute to color value. + + Args: + node_id: Node identifier + node_data: Node attributes dictionary + mapping: Attribute name (str) or function (node_data -> color_value) + + Returns: + Color value (string) or None if not mapped + """ + if mapping is None: + return None + + if callable(mapping): + # Call function with node_data + try: + result = mapping(node_data) + return str(result) if result is not None else None + except Exception: + return None + else: + # mapping is attribute name (str) + value = node_data.get(mapping) + return str(value) if value is not None else None + + @staticmethod + def _map_node_label( + node_id: Any, node_data: dict, mapping: str | Callable | None + ) -> str | None: + """Map node attribute to label value. + + Args: + node_id: Node identifier + node_data: Node attributes dictionary + mapping: Attribute name (str) or function (node_data -> label_str) + + Returns: + Label string or None if not mapped + """ + if mapping is None: + return None + + if callable(mapping): + # Call function with node_data + try: + result = mapping(node_data) + return str(result) if result is not None else None + except Exception: + return None + else: + # mapping is attribute name (str) + value = node_data.get(mapping) + return str(value) if value is not None else None + + @staticmethod + def _map_edge_label(edge_data: dict, mapping: str | Callable | None) -> str | None: + """Map edge attribute to label value. + + Args: + edge_data: Edge attributes dictionary + mapping: Attribute name (str) or function (edge_data -> label_str) + + Returns: + Label string or None if not mapped + """ + if mapping is None: + return None + + if callable(mapping): + # Call function with edge_data + try: + result = mapping(edge_data) + return str(result) if result is not None else None + except Exception: + return None + else: + # mapping is attribute name (str) + value = edge_data.get(mapping) + return str(value) if value is not None else None + + @staticmethod + def _detect_color_type(values: list) -> str: + """Detect if color values are numeric or categorical. + + Args: + values: List of color values + + Returns: + 'numeric' or 'categorical' + """ + # Check if all non-None values are numeric + numeric_count = 0 + total_count = 0 + + for val in values: + if val is not None: + total_count += 1 + if isinstance(val, (int, float)): + numeric_count += 1 + + # If majority are numeric, treat as numeric + if total_count > 0 and numeric_count / total_count > 0.5: + return "numeric" + return "categorical" + + @staticmethod + def _apply_continuous_color_scale(value: float, min_val: float, max_val: float) -> str: + """Apply continuous color scale to numeric value. + + Args: + value: Numeric value to map + min_val: Minimum value in dataset + max_val: Maximum value in dataset + + Returns: + Hex color string + """ + # Simple linear interpolation from blue to red + if max_val == min_val: + ratio = 0.5 + else: + ratio = (value - min_val) / (max_val - min_val) + + # Clamp ratio to [0, 1] + ratio = max(0.0, min(1.0, ratio)) + + # Blue (0) to Red (1) + red = int(255 * ratio) + blue = int(255 * (1 - ratio)) + green = 0 + + return f"#{red:02x}{green:02x}{blue:02x}" + + @staticmethod + def _apply_categorical_color_palette(category: str) -> str: + """Apply categorical color palette. + + Args: + category: Category value + + Returns: + Hex color string from palette + """ + # D3.js Category10 palette + palette = [ + "#1f77b4", + "#ff7f0e", + "#2ca02c", + "#d62728", + "#9467bd", + "#8c564b", + "#e377c2", + "#7f7f7f", + "#bcbd22", + "#17becf", + ] + + # Use hash of category string to select color + category_hash = hash(category) + color_index = category_hash % len(palette) + + return palette[color_index] diff --git a/net_vis/models.py b/net_vis/models.py new file mode 100644 index 0000000..98f9cc6 --- /dev/null +++ b/net_vis/models.py @@ -0,0 +1,130 @@ +"""Data models for graph visualization.""" + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class Node: + """Represents a graph vertex with position and visual properties. + + Attributes: + id: Unique node identifier (converted to string) + label: Optional display label + x: X-coordinate position + y: Y-coordinate position + color: Optional color value (hex string or color name) + metadata: Additional node attributes from source graph + """ + + id: str + label: str | None = None + x: float = 0.0 + y: float = 0.0 + color: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class Edge: + """Represents a graph edge with optional visual properties. + + Attributes: + source: Source node ID + target: Target node ID + label: Optional display label + weight: Optional edge weight + metadata: Additional edge attributes from source graph + """ + + source: str + target: str + label: str | None = None + weight: float | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class GraphLayer: + """Represents a single network visualization layer. + + Corresponds to one NetworkX graph object in a scene. + + Attributes: + layer_id: Unique layer identifier + nodes: List of nodes in this layer + edges: List of edges in this layer + metadata: Additional layer metadata + """ + + layer_id: str + nodes: list[Node] = field(default_factory=list) + edges: list[Edge] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class Scene: + """Represents a complete visualization container. + + Top-level structure for JSON/HTML export, containing one or more graph layers. + + Attributes: + layers: List of graph layers to visualize + title: Optional scene title + metadata: Additional scene metadata + """ + + layers: list[GraphLayer] = field(default_factory=list) + title: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Convert scene to dictionary format for MIME renderer. + + Returns: + Dictionary representation compatible with netvis MIME renderer format. + """ + # Combine all nodes and links from all layers + all_nodes = [] + all_links = [] + + for layer in self.layers: + # Convert nodes to netvis format + for node in layer.nodes: + node_dict: dict[str, Any] = { + "id": node.id, + "x": node.x, + "y": node.y, + } + if node.label is not None: + node_dict["name"] = node.label + if node.color is not None: + node_dict["category"] = node.color + # Add metadata as additional fields + node_dict.update(node.metadata) + all_nodes.append(node_dict) + + # Convert edges to netvis format (links) + for edge in layer.edges: + link_dict: dict[str, Any] = { + "source": edge.source, + "target": edge.target, + } + if edge.label is not None: + link_dict["label"] = edge.label + if edge.weight is not None: + link_dict["value"] = edge.weight + # Add metadata as additional fields + link_dict.update(edge.metadata) + all_links.append(link_dict) + + result: dict[str, Any] = { + "nodes": all_nodes, + "links": all_links, + } + + if self.title: + result["title"] = self.title + + return result diff --git a/net_vis/plotter.py b/net_vis/plotter.py new file mode 100644 index 0000000..fa058af --- /dev/null +++ b/net_vis/plotter.py @@ -0,0 +1,190 @@ +"""High-level API for plotting NetworkX graphs in JupyterLab.""" + +import json +from collections.abc import Callable +from typing import Any + +from .adapters.networkx_adapter import NetworkXAdapter +from .models import Scene + + +class Plotter: + """Main API for visualizing NetworkX graphs in JupyterLab. + + Provides a simple interface to convert NetworkX graph objects into + interactive visualizations using the netvis MIME renderer. Supports + all NetworkX graph types (Graph, DiGraph, MultiGraph, MultiDiGraph) + with automatic attribute preservation and customizable styling. + + Examples: + Basic visualization: + >>> from net_vis import Plotter + >>> import networkx as nx + >>> G = nx.karate_club_graph() + >>> plotter = Plotter(title="Karate Club") + >>> plotter.add_networkx(G) + + Custom styling with attribute mapping: + >>> plotter = Plotter() + >>> plotter.add_networkx( + ... G, + ... node_color="club", + ... node_label="name", + ... layout='kamada_kawai' + ... ) + + Custom styling with functions: + >>> plotter.add_networkx( + ... G, + ... node_color=lambda d: f"group_{d.get('club', 0)}", + ... node_label=lambda d: f"Node {d.get('id', '')}", + ... edge_label=lambda d: f"w={d.get('weight', 1.0)}" + ... ) + + Attributes: + _scene: Internal Scene object containing all visualization layers + _layer_counter: Counter for auto-generating unique layer IDs + """ + + def __init__(self, title: str | None = None) -> None: + """Initialize plotter with optional scene title. + + Args: + title: Optional title for the visualization scene + """ + self._scene = Scene(title=title) + self._layer_counter = 0 + + def _generate_layer_id(self) -> str: + """Generate unique layer ID. + + Returns: + Unique layer identifier string + """ + layer_id = f"layer_{self._layer_counter}" + self._layer_counter += 1 + return layer_id + + def add_networkx( + self, + graph: Any, + *, + layer_id: str | None = None, + layout: str | Callable | None = None, + node_color: str | Callable | None = None, + node_label: str | Callable | None = None, + edge_label: str | Callable | None = None, + ) -> str: + """Add NetworkX graph as visualization layer. + + Converts a NetworkX graph to a visualization layer with automatic + node/edge extraction, layout computation, and styling. Supports all + NetworkX graph types with automatic type detection. + + Args: + graph: NetworkX graph object (Graph/DiGraph/MultiGraph/MultiDiGraph). + All node and edge attributes are preserved in metadata. + layer_id: Custom layer ID (auto-generated if None). + layout: Layout algorithm or custom function: + - 'spring': Force-directed layout (default) + - 'kamada_kawai': Kamada-Kawai path-length cost minimization + - 'spectral': Spectral layout using graph Laplacian + - 'circular': Nodes arranged in a circle + - 'random': Random node positions + - callable: Custom function(graph) -> dict[node_id, (x, y)] + - None: Use existing 'pos' attribute or fall back to spring + node_color: Node color mapping: + - str: Attribute name to use for color values + - callable: Function(node_data) -> color_value + - None: No color mapping (default) + node_label: Node label mapping: + - str: Attribute name to use for labels + - callable: Function(node_data) -> label_string + - None: No label mapping (default) + edge_label: Edge label mapping: + - str: Attribute name to use for labels + - callable: Function(edge_data) -> label_string + - None: No label mapping (default) + + Returns: + str: ID of the added layer (auto-generated or custom) + + Raises: + TypeError: If graph is not a NetworkX graph object + ValueError: If layout computation fails + + Examples: + Basic usage: + >>> plotter = Plotter() + >>> G = nx.karate_club_graph() + >>> layer_id = plotter.add_networkx(G) + + With layout control: + >>> plotter.add_networkx(G, layout='kamada_kawai') + + With attribute-based styling: + >>> G.nodes[0]['color'] = 'red' + >>> plotter.add_networkx(G, node_color='color') + + With function-based styling: + >>> plotter.add_networkx( + ... G, + ... node_color=lambda d: 'red' if d.get('club') == 0 else 'blue', + ... node_label=lambda d: f"Node {d.get('id', '')}" + ... ) + + Notes: + - All graph types (Graph, DiGraph, MultiGraph, MultiDiGraph) are supported + - DiGraph edges include 'directed': True in metadata + - MultiGraph edges include 'edge_key' in metadata + - Multiple edges are expanded to independent Edge objects + - NaN/inf positions trigger automatic fallback to random layout + """ + # Validate input is a NetworkX graph + if not hasattr(graph, "nodes") or not hasattr(graph, "edges"): + raise TypeError(f"Expected NetworkX graph object, got {type(graph).__name__}") + + # Generate layer ID if not provided + if layer_id is None: + layer_id = self._generate_layer_id() + + # Convert NetworkX graph to GraphLayer using adapter + graph_layer = NetworkXAdapter.convert_graph( + graph, + layout=layout, + node_color=node_color, + node_label=node_label, + edge_label=edge_label, + ) + graph_layer.layer_id = layer_id + + # Add layer to scene + self._scene.layers.append(graph_layer) + + return layer_id + + def to_json(self) -> str: + """Export scene structure as JSON string. + + Returns: + JSON string representation of the scene + """ + scene_dict = self._scene.to_dict() + return json.dumps(scene_dict, indent=2) + + def _repr_mimebundle_(self, include=None, exclude=None) -> dict: + """Return MIME bundle for IPython/JupyterLab display. + + Args: + include: Optional list of MIME types to include + exclude: Optional list of MIME types to exclude + + Returns: + Dictionary mapping MIME types to content + """ + scene_dict = self._scene.to_dict() + + return { + "application/vnd.netvis+json": {"data": json.dumps(scene_dict)}, + "text/plain": f"", + } diff --git a/net_vis/tests/test_networkx_adapter.py b/net_vis/tests/test_networkx_adapter.py new file mode 100644 index 0000000..9cbd4c3 --- /dev/null +++ b/net_vis/tests/test_networkx_adapter.py @@ -0,0 +1,550 @@ +"""Tests for NetworkXAdapter conversion functionality.""" + +import pytest + +# Skip all tests if networkx is not installed +pytest.importorskip("networkx") + +import networkx as nx + +from net_vis.adapters.networkx_adapter import NetworkXAdapter + + +class TestNetworkXAdapterConversion: + """Tests for basic graph conversion.""" + + def test_convert_graph_with_simple_graph(self): + """Test NetworkXAdapter.convert_graph with simple nx.Graph.""" + G = nx.Graph() + G.add_edge(1, 2) + G.add_edge(2, 3) + + layer = NetworkXAdapter.convert_graph(G) + + assert layer is not None + assert len(layer.nodes) == 3 + assert len(layer.edges) == 2 + assert layer.metadata["graph_type"] == "graph" + + def test_convert_graph_empty_graph(self): + """Test NetworkXAdapter handles empty graph (0 nodes).""" + G = nx.Graph() + + layer = NetworkXAdapter.convert_graph(G) + + assert layer is not None + assert len(layer.nodes) == 0 + assert len(layer.edges) == 0 + + +class TestNetworkXAdapterAttributes: + """Tests for attribute preservation.""" + + def test_preserves_node_attributes_in_metadata(self): + """Test NetworkXAdapter preserves all node attributes in metadata.""" + G = nx.Graph() + G.add_node(1, name="Node 1", value=10, category="A") + G.add_node(2, name="Node 2", value=20, category="B") + + layer = NetworkXAdapter.convert_graph(G) + + node1 = next(n for n in layer.nodes if n.id == "1") + assert node1.metadata == {"name": "Node 1", "value": 10, "category": "A"} + + node2 = next(n for n in layer.nodes if n.id == "2") + assert node2.metadata == {"name": "Node 2", "value": 20, "category": "B"} + + def test_preserves_edge_attributes_in_metadata(self): + """Test NetworkXAdapter preserves all edge attributes in metadata.""" + G = nx.Graph() + G.add_edge(1, 2, weight=5.0, label="connects", type="strong") + G.add_edge(2, 3, weight=3.0, label="links") + + layer = NetworkXAdapter.convert_graph(G) + + edge1 = next(e for e in layer.edges if e.source == "1" and e.target == "2") + assert edge1.metadata == {"weight": 5.0, "label": "connects", "type": "strong"} + + edge2 = next(e for e in layer.edges if e.source == "2" and e.target == "3") + assert edge2.metadata == {"weight": 3.0, "label": "links"} + + +class TestNetworkXAdapterLayout: + """Tests for layout computation.""" + + def test_applies_spring_layout_by_default(self): + """Test NetworkXAdapter applies spring layout by default.""" + G = nx.Graph() + G.add_edge(1, 2) + G.add_edge(2, 3) + + layer = NetworkXAdapter.convert_graph(G) + + # Verify all nodes have non-zero positions (spring layout computed) + for node in layer.nodes: + # Positions should exist and be floats + assert isinstance(node.x, float) + assert isinstance(node.y, float) + # At least some nodes should have non-zero positions + # (spring layout spreads nodes out) + + # Verify at least one node has non-zero position + has_nonzero = any(node.x != 0.0 or node.y != 0.0 for node in layer.nodes) + assert has_nonzero + + +class TestNetworkXAdapterGraphTypes: + """Tests for different NetworkX graph types.""" + + def test_detect_graph_type_graph(self): + """Test _detect_graph_type identifies Graph.""" + G = nx.Graph() + graph_type = NetworkXAdapter._detect_graph_type(G) + assert graph_type == "graph" + + def test_detect_graph_type_digraph(self): + """Test _detect_graph_type identifies DiGraph.""" + G = nx.DiGraph() + graph_type = NetworkXAdapter._detect_graph_type(G) + assert graph_type == "digraph" + + def test_detect_graph_type_multigraph(self): + """Test _detect_graph_type identifies MultiGraph.""" + G = nx.MultiGraph() + graph_type = NetworkXAdapter._detect_graph_type(G) + assert graph_type == "multigraph" + + def test_detect_graph_type_multidigraph(self): + """Test _detect_graph_type identifies MultiDiGraph.""" + G = nx.MultiDiGraph() + graph_type = NetworkXAdapter._detect_graph_type(G) + assert graph_type == "multidigraph" + + +class TestNetworkXAdapterStyling: + """Tests for node and edge styling.""" + + def test_node_color_with_attribute_name(self): + """Test node_color with attribute name (string).""" + G = nx.Graph() + G.add_node(1, color="red") + G.add_node(2, color="blue") + G.add_edge(1, 2) + + layer = NetworkXAdapter.convert_graph(G, node_color="color") + + node1 = next(n for n in layer.nodes if n.id == "1") + assert node1.color == "red" + + node2 = next(n for n in layer.nodes if n.id == "2") + assert node2.color == "blue" + + def test_node_color_with_callable_function(self): + """Test node_color with callable function.""" + G = nx.Graph() + G.add_node(1, value=10) + G.add_node(2, value=20) + G.add_edge(1, 2) + + def color_fn(node_data): + return f"value_{node_data.get('value', 0)}" + + layer = NetworkXAdapter.convert_graph(G, node_color=color_fn) + + node1 = next(n for n in layer.nodes if n.id == "1") + assert node1.color == "value_10" + + node2 = next(n for n in layer.nodes if n.id == "2") + assert node2.color == "value_20" + + def test_node_label_with_attribute_name(self): + """Test node_label with attribute name (string).""" + G = nx.Graph() + G.add_node(1, name="Alice") + G.add_node(2, name="Bob") + G.add_edge(1, 2) + + layer = NetworkXAdapter.convert_graph(G, node_label="name") + + node1 = next(n for n in layer.nodes if n.id == "1") + assert node1.label == "Alice" + + node2 = next(n for n in layer.nodes if n.id == "2") + assert node2.label == "Bob" + + def test_node_label_with_callable_function(self): + """Test node_label with callable function.""" + G = nx.Graph() + G.add_node(1, value=10) + G.add_node(2, value=20) + G.add_edge(1, 2) + + def label_fn(node_data): + return f"Node {node_data.get('value', 0)}" + + layer = NetworkXAdapter.convert_graph(G, node_label=label_fn) + + node1 = next(n for n in layer.nodes if n.id == "1") + assert node1.label == "Node 10" + + def test_edge_label_with_attribute_name(self): + """Test edge_label with attribute name (string).""" + G = nx.Graph() + G.add_edge(1, 2, relation="friend") + G.add_edge(2, 3, relation="colleague") + + layer = NetworkXAdapter.convert_graph(G, edge_label="relation") + + edge1 = next(e for e in layer.edges if e.source == "1" and e.target == "2") + assert edge1.label == "friend" + + edge2 = next(e for e in layer.edges if e.source == "2" and e.target == "3") + assert edge2.label == "colleague" + + def test_edge_label_with_callable_function(self): + """Test edge_label with callable function.""" + G = nx.Graph() + G.add_edge(1, 2, weight=5.0) + G.add_edge(2, 3, weight=3.0) + + def label_fn(edge_data): + return f"w={edge_data.get('weight', 0)}" + + layer = NetworkXAdapter.convert_graph(G, edge_label=label_fn) + + edge1 = next(e for e in layer.edges if e.source == "1" and e.target == "2") + assert edge1.label == "w=5.0" + + def test_numeric_color_values_trigger_continuous_scale(self): + """Test numeric color values trigger continuous scale.""" + values = [1.0, 2.0, 3.0, 4.0] + color_type = NetworkXAdapter._detect_color_type(values) + assert color_type == "numeric" + + def test_string_color_values_trigger_categorical_palette(self): + """Test string color values trigger categorical palette.""" + values = ["red", "blue", "green"] + color_type = NetworkXAdapter._detect_color_type(values) + assert color_type == "categorical" + + def test_missing_attribute_uses_default_none(self): + """Test missing attribute uses default (None) without error.""" + G = nx.Graph() + G.add_node(1) # No color attribute + G.add_node(2, color="red") + G.add_edge(1, 2) + + layer = NetworkXAdapter.convert_graph(G, node_color="color") + + node1 = next(n for n in layer.nodes if n.id == "1") + assert node1.color is None # Missing attribute + + node2 = next(n for n in layer.nodes if n.id == "2") + assert node2.color == "red" + + +class TestNetworkXAdapterLayouts: + """Tests for layout algorithms.""" + + def test_layout_spring_applies_spring_layout(self): + """Test layout='spring' applies spring_layout.""" + G = nx.Graph() + G.add_edge(1, 2) + G.add_edge(2, 3) + + layer = NetworkXAdapter.convert_graph(G, layout="spring") + + # Verify all nodes have positions + assert len(layer.nodes) == 3 + for node in layer.nodes: + assert isinstance(node.x, float) + assert isinstance(node.y, float) + + def test_layout_kamada_kawai_applies_kamada_kawai_layout(self): + """Test layout='kamada_kawai' applies kamada_kawai_layout.""" + G = nx.Graph() + G.add_edge(1, 2) + G.add_edge(2, 3) + + layer = NetworkXAdapter.convert_graph(G, layout="kamada_kawai") + + # Verify all nodes have positions + assert len(layer.nodes) == 3 + for node in layer.nodes: + assert isinstance(node.x, float) + assert isinstance(node.y, float) + + def test_layout_spectral_applies_spectral_layout(self): + """Test layout='spectral' applies spectral_layout.""" + G = nx.Graph() + G.add_edge(1, 2) + G.add_edge(2, 3) + + layer = NetworkXAdapter.convert_graph(G, layout="spectral") + + # Verify all nodes have positions + assert len(layer.nodes) == 3 + for node in layer.nodes: + assert isinstance(node.x, float) + assert isinstance(node.y, float) + + def test_layout_circular_applies_circular_layout(self): + """Test layout='circular' applies circular_layout.""" + G = nx.Graph() + G.add_edge(1, 2) + G.add_edge(2, 3) + + layer = NetworkXAdapter.convert_graph(G, layout="circular") + + # Verify all nodes have positions + assert len(layer.nodes) == 3 + for node in layer.nodes: + assert isinstance(node.x, float) + assert isinstance(node.y, float) + + def test_layout_random_applies_random_layout(self): + """Test layout='random' applies random_layout.""" + G = nx.Graph() + G.add_edge(1, 2) + G.add_edge(2, 3) + + layer = NetworkXAdapter.convert_graph(G, layout="random") + + # Verify all nodes have positions + assert len(layer.nodes) == 3 + for node in layer.nodes: + assert isinstance(node.x, float) + assert isinstance(node.y, float) + + def test_layout_with_custom_callable_function(self): + """Test layout with custom callable function.""" + G = nx.Graph() + G.add_node(1) + G.add_node(2) + G.add_edge(1, 2) + + def custom_layout(graph): + """Custom layout placing nodes at specific positions.""" + return {1: (0.0, 0.0), 2: (1.0, 1.0)} + + layer = NetworkXAdapter.convert_graph(G, layout=custom_layout) + + node1 = next(n for n in layer.nodes if n.id == "1") + assert node1.x == 0.0 + assert node1.y == 0.0 + + node2 = next(n for n in layer.nodes if n.id == "2") + assert node2.x == 1.0 + assert node2.y == 1.0 + + def test_layout_none_uses_existing_pos_attribute(self): + """Test layout=None uses existing 'pos' attribute if present.""" + G = nx.Graph() + G.add_node(1, pos=(0.5, 0.5)) + G.add_node(2, pos=(0.7, 0.3)) + G.add_edge(1, 2) + + layer = NetworkXAdapter.convert_graph(G, layout=None) + + node1 = next(n for n in layer.nodes if n.id == "1") + assert node1.x == 0.5 + assert node1.y == 0.5 + + node2 = next(n for n in layer.nodes if n.id == "2") + assert node2.x == 0.7 + assert node2.y == 0.3 + + def test_layout_none_defaults_to_spring_when_no_pos(self): + """Test layout=None defaults to spring when no 'pos' attribute.""" + G = nx.Graph() + G.add_edge(1, 2) + G.add_edge(2, 3) + + layer = NetworkXAdapter.convert_graph(G, layout=None) + + # Verify all nodes have positions (spring layout applied) + assert len(layer.nodes) == 3 + for node in layer.nodes: + assert isinstance(node.x, float) + assert isinstance(node.y, float) + + def test_explicit_layout_overrides_existing_pos(self): + """Test explicit layout overrides existing 'pos' attribute.""" + G = nx.Graph() + G.add_node(1, pos=(0.5, 0.5)) + G.add_node(2, pos=(0.7, 0.3)) + G.add_edge(1, 2) + + layer = NetworkXAdapter.convert_graph(G, layout="circular") + + # Positions should be different from original pos attribute + # (we can't predict exact values, but they should be valid floats) + node1 = next(n for n in layer.nodes if n.id == "1") + assert isinstance(node1.x, float) + assert isinstance(node1.y, float) + + def test_layout_failure_falls_back_to_random_with_warning(self): + """Test layout failure (NaN, inf) falls back to random with warning.""" + import warnings + + G = nx.Graph() + G.add_edge(1, 2) + + def failing_layout(graph): + """Layout that returns NaN values.""" + return {1: (float("nan"), 0.0), 2: (1.0, 1.0)} + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + layer = NetworkXAdapter.convert_graph(G, layout=failing_layout) + + # Should have warned about invalid positions + assert len(w) > 0 + assert "invalid positions" in str(w[-1].message).lower() + + # Should still have valid positions (from fallback) + for node in layer.nodes: + assert isinstance(node.x, float) + assert isinstance(node.y, float) + import math + + assert not math.isnan(node.x) + assert not math.isnan(node.y) + + +class TestNetworkXAdapterMultipleGraphTypes: + """Tests for all NetworkX graph types support.""" + + def test_convert_graph_with_undirected_graph(self): + """Test NetworkXAdapter with nx.Graph (undirected).""" + G = nx.Graph() + G.add_edge(1, 2) + G.add_edge(2, 3) + + layer = NetworkXAdapter.convert_graph(G) + + assert layer is not None + assert len(layer.nodes) == 3 + assert len(layer.edges) == 2 + assert layer.metadata["graph_type"] == "graph" + + # Verify edges don't have 'directed' flag + for edge in layer.edges: + assert "directed" not in edge.metadata or not edge.metadata["directed"] + + def test_convert_graph_with_digraph(self): + """Test NetworkXAdapter with nx.DiGraph (directed).""" + G = nx.DiGraph() + G.add_edge(1, 2) + G.add_edge(2, 3) + + layer = NetworkXAdapter.convert_graph(G) + + assert layer is not None + assert len(layer.nodes) == 3 + assert len(layer.edges) == 2 + assert layer.metadata["graph_type"] == "digraph" + + # Verify all edges have 'directed' flag set to True + for edge in layer.edges: + assert edge.metadata["directed"] is True + + def test_convert_graph_with_multigraph(self): + """Test NetworkXAdapter with nx.MultiGraph (multi-undirected).""" + G = nx.MultiGraph() + G.add_edge(1, 2, weight=1.0) + G.add_edge(1, 2, weight=2.0) # Second edge between same nodes + G.add_edge(2, 3, weight=3.0) + + layer = NetworkXAdapter.convert_graph(G) + + assert layer is not None + assert len(layer.nodes) == 3 + # Should have 3 edges (2 between nodes 1-2, 1 between nodes 2-3) + assert len(layer.edges) == 3 + assert layer.metadata["graph_type"] == "multigraph" + + # Verify edge keys are preserved + for edge in layer.edges: + assert "edge_key" in edge.metadata + + def test_convert_graph_with_multidigraph(self): + """Test NetworkXAdapter with nx.MultiDiGraph (multi-directed).""" + G = nx.MultiDiGraph() + G.add_edge(1, 2, relation="friend") + G.add_edge(1, 2, relation="colleague") # Second edge between same nodes + G.add_edge(2, 3, relation="manager") + + layer = NetworkXAdapter.convert_graph(G) + + assert layer is not None + assert len(layer.nodes) == 3 + # Should have 3 edges + assert len(layer.edges) == 3 + assert layer.metadata["graph_type"] == "multidigraph" + + # Verify both edge keys and direction are preserved + for edge in layer.edges: + assert "edge_key" in edge.metadata + assert edge.metadata["directed"] is True + + def test_digraph_edge_direction_preserved(self): + """Test DiGraph edge direction preserved in output.""" + G = nx.DiGraph() + G.add_edge(1, 2) + G.add_edge(2, 1) # Opposite direction + + layer = NetworkXAdapter.convert_graph(G) + + # Should have 2 edges (one in each direction) + assert len(layer.edges) == 2 + + # Find edges by source/target + edge_1_to_2 = next(e for e in layer.edges if e.source == "1" and e.target == "2") + edge_2_to_1 = next(e for e in layer.edges if e.source == "2" and e.target == "1") + + # Both should be marked as directed + assert edge_1_to_2.metadata["directed"] is True + assert edge_2_to_1.metadata["directed"] is True + + def test_multigraph_edge_keys_preserved(self): + """Test MultiGraph edge keys preserved in edge metadata.""" + G = nx.MultiGraph() + # Add multiple edges with custom keys + G.add_edge(1, 2, key="first", weight=1.0) + G.add_edge(1, 2, key="second", weight=2.0) + G.add_edge(1, 2, key="third", weight=3.0) + + layer = NetworkXAdapter.convert_graph(G) + + # Should have 3 edges + assert len(layer.edges) == 3 + + # Verify all edges have edge_key preserved + edge_keys = [edge.metadata["edge_key"] for edge in layer.edges] + # NetworkX may use integer keys by default, but our custom keys should be preserved + assert len(edge_keys) == 3 + assert all("edge_key" in edge.metadata for edge in layer.edges) + + def test_multigraph_expands_multiple_edges(self): + """Test MultiGraph expands multiple edges to independent Edge objects.""" + G = nx.MultiGraph() + # Add 3 edges between nodes 1 and 2 + G.add_edge(1, 2, label="edge_a") + G.add_edge(1, 2, label="edge_b") + G.add_edge(1, 2, label="edge_c") + + layer = NetworkXAdapter.convert_graph(G) + + # Should create 3 independent Edge objects + assert len(layer.edges) == 3 + + # All edges should be between nodes "1" and "2" + for edge in layer.edges: + assert (edge.source == "1" and edge.target == "2") or ( + edge.source == "2" and edge.target == "1" + ) + + # Each edge should have unique edge_key + edge_keys = [edge.metadata["edge_key"] for edge in layer.edges] + assert len(set(edge_keys)) == 3 # All keys should be unique diff --git a/net_vis/tests/test_plotter.py b/net_vis/tests/test_plotter.py new file mode 100644 index 0000000..1739b7b --- /dev/null +++ b/net_vis/tests/test_plotter.py @@ -0,0 +1,282 @@ +"""Tests for Plotter class public API.""" + +import json + +import pytest + +# Skip all tests if networkx is not installed +pytest.importorskip("networkx") + +import networkx as nx + +from net_vis import Plotter + + +def parse_mime_data(bundle: dict) -> dict: + """Parse MIME bundle data to get nodes and links.""" + mime_data = bundle["application/vnd.netvis+json"] + return json.loads(mime_data["data"]) + + +class TestPlotterInit: + """Tests for Plotter initialization.""" + + def test_init_without_title(self): + """Test Plotter.__init__ without title.""" + plotter = Plotter() + assert plotter._scene is not None + assert plotter._scene.title is None + assert plotter._layer_counter == 0 + + def test_init_with_title(self): + """Test Plotter.__init__ with optional title.""" + plotter = Plotter(title="Test Graph") + assert plotter._scene is not None + assert plotter._scene.title == "Test Graph" + assert plotter._layer_counter == 0 + + +class TestPlotterAddNetworkX: + """Tests for Plotter.add_networkx method.""" + + def test_add_networkx_accepts_graph(self): + """Test Plotter.add_networkx accepts nx.Graph.""" + plotter = Plotter() + G = nx.Graph() + G.add_edge(1, 2) + + layer_id = plotter.add_networkx(G) + assert layer_id is not None + assert isinstance(layer_id, str) + + def test_add_networkx_returns_layer_id(self): + """Test Plotter.add_networkx returns layer_id.""" + plotter = Plotter() + G = nx.Graph() + G.add_edge(1, 2) + + layer_id = plotter.add_networkx(G) + assert layer_id == "layer_0" + + # Add another graph + layer_id2 = plotter.add_networkx(G) + assert layer_id2 == "layer_1" + + def test_add_networkx_with_custom_layer_id(self): + """Test Plotter.add_networkx with custom layer_id.""" + plotter = Plotter() + G = nx.Graph() + G.add_edge(1, 2) + + layer_id = plotter.add_networkx(G, layer_id="custom_layer") + assert layer_id == "custom_layer" + + def test_add_networkx_invalid_type_raises_typeerror(self): + """Test Plotter.add_networkx raises TypeError for non-NetworkX objects.""" + plotter = Plotter() + + with pytest.raises(TypeError, match="Expected NetworkX graph object"): + plotter.add_networkx("not a graph") + + with pytest.raises(TypeError, match="Expected NetworkX graph object"): + plotter.add_networkx({"nodes": [], "edges": []}) + + +class TestPlotterReprMimeBundle: + """Tests for Plotter._repr_mimebundle_ method.""" + + def test_repr_mimebundle_returns_dict(self): + """Test Plotter._repr_mimebundle_ returns dict with application/vnd.netvis+json.""" + plotter = Plotter() + G = nx.Graph() + G.add_edge(1, 2) + plotter.add_networkx(G) + + bundle = plotter._repr_mimebundle_() + assert isinstance(bundle, dict) + assert "application/vnd.netvis+json" in bundle + assert "text/plain" in bundle + + def test_repr_mimebundle_contains_valid_data(self): + """Test _repr_mimebundle_ contains valid netvis JSON data.""" + plotter = Plotter() + G = nx.Graph() + G.add_edge(1, 2) + plotter.add_networkx(G) + + bundle = plotter._repr_mimebundle_() + mime_data = bundle["application/vnd.netvis+json"] + + # New format: data is wrapped with 'data' key as JSON string + assert "data" in mime_data + data = json.loads(mime_data["data"]) + assert "nodes" in data + assert "links" in data + assert len(data["nodes"]) == 2 # Nodes 1 and 2 + assert len(data["links"]) == 1 # Edge 1-2 + + +class TestPlotterIntegration: + """Integration tests for Plotter with real NetworkX graphs.""" + + def test_plotter_with_karate_club_graph(self): + """Test Plotter with nx.karate_club_graph integration.""" + plotter = Plotter(title="Karate Club") + G = nx.karate_club_graph() + + layer_id = plotter.add_networkx(G) + + assert layer_id == "layer_0" + assert len(plotter._scene.layers) == 1 + + bundle = plotter._repr_mimebundle_() + data = parse_mime_data(bundle) + + # Karate club has 34 nodes and 78 edges + assert len(data["nodes"]) == 34 + assert len(data["links"]) == 78 + + def test_plotter_to_json(self): + """Test Plotter.to_json returns valid JSON string.""" + plotter = Plotter() + G = nx.Graph() + G.add_edge(1, 2) + plotter.add_networkx(G) + + json_str = plotter.to_json() + assert isinstance(json_str, str) + assert "nodes" in json_str + assert "links" in json_str + + # Verify it's valid JSON + import json + + data = json.loads(json_str) + assert "nodes" in data + assert "links" in data + + +class TestPlotterStyling: + """Tests for Plotter styling parameters.""" + + def test_add_networkx_with_all_styling_parameters(self): + """Test Plotter.add_networkx with all styling parameters.""" + plotter = Plotter() + G = nx.Graph() + G.add_node(1, color="red", name="Node A") + G.add_node(2, color="blue", name="Node B") + G.add_edge(1, 2, relation="connects") + + layer_id = plotter.add_networkx( + G, + node_color="color", + node_label="name", + edge_label="relation", + ) + + assert layer_id == "layer_0" + bundle = plotter._repr_mimebundle_() + data = parse_mime_data(bundle) + + # Verify nodes have colors and labels + assert len(data["nodes"]) == 2 + node1 = next(n for n in data["nodes"] if n["id"] == "1") + assert node1["category"] == "red" + assert node1["name"] == "Node A" + + # Verify edges have labels + assert len(data["links"]) == 1 + assert data["links"][0]["label"] == "connects" + + +class TestPlotterMultipleGraphTypes: + """Tests for Plotter with all NetworkX graph types.""" + + def test_plotter_accepts_graph(self): + """Test Plotter accepts nx.Graph with same API.""" + plotter = Plotter() + G = nx.Graph() + G.add_edge(1, 2) + + layer_id = plotter.add_networkx(G) + + assert layer_id is not None + assert len(plotter._scene.layers) == 1 + bundle = plotter._repr_mimebundle_() + data = parse_mime_data(bundle) + assert len(data["nodes"]) == 2 + + def test_plotter_accepts_digraph(self): + """Test Plotter accepts nx.DiGraph with same API.""" + plotter = Plotter() + G = nx.DiGraph() + G.add_edge(1, 2) + + layer_id = plotter.add_networkx(G) + + assert layer_id is not None + assert len(plotter._scene.layers) == 1 + bundle = plotter._repr_mimebundle_() + data = parse_mime_data(bundle) + assert len(data["nodes"]) == 2 + # Verify directed edges are marked + assert len(data["links"]) == 1 + + def test_plotter_accepts_multigraph(self): + """Test Plotter accepts nx.MultiGraph with same API.""" + plotter = Plotter() + G = nx.MultiGraph() + G.add_edge(1, 2) + G.add_edge(1, 2) # Multiple edges + + layer_id = plotter.add_networkx(G) + + assert layer_id is not None + assert len(plotter._scene.layers) == 1 + bundle = plotter._repr_mimebundle_() + data = parse_mime_data(bundle) + assert len(data["nodes"]) == 2 + # Should have 2 edges (expanded) + assert len(data["links"]) == 2 + + def test_plotter_accepts_multidigraph(self): + """Test Plotter accepts nx.MultiDiGraph with same API.""" + plotter = Plotter() + G = nx.MultiDiGraph() + G.add_edge(1, 2) + G.add_edge(1, 2) # Multiple directed edges + + layer_id = plotter.add_networkx(G) + + assert layer_id is not None + assert len(plotter._scene.layers) == 1 + bundle = plotter._repr_mimebundle_() + data = parse_mime_data(bundle) + assert len(data["nodes"]) == 2 + # Should have 2 edges (expanded) + assert len(data["links"]) == 2 + + def test_plotter_accepts_all_graph_types_with_same_api(self): + """Test Plotter.add_networkx accepts all 4 graph types with same API.""" + graph_types = [ + nx.Graph(), + nx.DiGraph(), + nx.MultiGraph(), + nx.MultiDiGraph(), + ] + + for graph in graph_types: + plotter = Plotter() + graph.add_edge(1, 2) + + # All graph types should work with same API + layer_id = plotter.add_networkx( + graph, + layout="spring", + node_color=None, + node_label=None, + edge_label=None, + ) + + assert layer_id is not None + assert len(plotter._scene.layers) == 1 diff --git a/package.json b/package.json index b94ff72..85e545b 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "net_vis", - "version": "0.4.0", - "description": "NetVis is a package for interactive visualization Python NetworkX graphs within Jupyter Lab. It leverages D3.js for dynamic rendering and supports HTML export, making network analysis effortless.", + "version": "0.5.0", + "description": "NetVis is a package for interactive visualization of Python NetworkX graphs within JupyterLab. It leverages D3.js for dynamic rendering and provides a high-level Plotter API for effortless network analysis.", "keywords": [ "jupyter", "jupyterlab", diff --git a/pyproject.toml b/pyproject.toml index 3ad164c..1c9f91f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,9 +8,9 @@ build-backend = "hatchling.build" [project] name = "net_vis" -version = "0.4.0" +version = "0.5.0" # dynamic = ["version"] -description = "NetVis is a package for interactive visualization Python NetworkX graphs within Jupyter Lab. It leverages D3.js for dynamic rendering and supports HTML export, making network analysis effortless." +description = "NetVis is a package for interactive visualization of Python NetworkX graphs within JupyterLab. It leverages D3.js for dynamic rendering and provides a high-level Plotter API for effortless network analysis." readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" @@ -36,9 +36,14 @@ classifiers = [ "Programming Language :: Python :: 3.14", ] dependencies = [ + "networkx>=3.0", + "numpy>=2.0", ] [project.optional-dependencies] +full = [ + "scipy>=1.8", # Required for kamada_kawai and spectral layouts +] docs = [ "jupyter_sphinx", "nbsphinx", @@ -56,6 +61,7 @@ test = [ "pytest>=6.0", "ruff>=0.8.0", "pyright>=1.1.0", + "scipy>=1.8", # Include full dependencies for comprehensive testing ] [project.urls] @@ -128,6 +134,7 @@ exclude = [ ".venv", "venv", "venv-docker", + "net_vis/tests", ] pythonVersion = "3.10" typeCheckingMode = "basic" diff --git a/src/__tests__/mimePlugin.test.ts b/src/__tests__/mimePlugin.test.ts index 6bf965d..e358a61 100644 --- a/src/__tests__/mimePlugin.test.ts +++ b/src/__tests__/mimePlugin.test.ts @@ -112,10 +112,10 @@ describe('validateVersion', () => { }); it('should log success for matching versions', () => { - validateVersion('0.4.0'); + validateVersion('0.5.0'); expect(consoleLogSpy).toHaveBeenCalledWith( - expect.stringContaining('Version check passed: v0.4.0'), + expect.stringContaining('Version check passed: v0.5.0'), ); expect(consoleWarnSpy).not.toHaveBeenCalled(); }); @@ -127,7 +127,7 @@ describe('validateVersion', () => { expect.stringContaining('Version mismatch'), ); expect(consoleWarnSpy).toHaveBeenCalledWith( - expect.stringContaining('Frontend v0.4.0'), + expect.stringContaining('Frontend v0.5.0'), ); expect(consoleWarnSpy).toHaveBeenCalledWith( expect.stringContaining('Backend v0.3.0'), diff --git a/src/__tests__/renderer.test.ts b/src/__tests__/renderer.test.ts index b256304..06491d2 100644 --- a/src/__tests__/renderer.test.ts +++ b/src/__tests__/renderer.test.ts @@ -1,5 +1,6 @@ import { NetVisRenderer } from '../renderer'; import { IRenderMime } from '@jupyterlab/rendermime-interfaces'; +import packageJson from '../../package.json'; // Mock MIME model type interface IMimeModel { @@ -7,6 +8,8 @@ interface IMimeModel { metadata?: { [key: string]: any }; } +const CURRENT_VERSION = packageJson.version; + describe('NetVisRenderer', () => { const MIME_TYPE = 'application/vnd.netvis+json'; @@ -29,7 +32,7 @@ describe('NetVisRenderer', () => { data: { [MIME_TYPE]: { data: graphData, - version: '0.4.0', + version: CURRENT_VERSION, }, }, }; @@ -59,7 +62,7 @@ describe('NetVisRenderer', () => { data: { [MIME_TYPE]: { data: graphData, - version: '0.4.0', + version: CURRENT_VERSION, }, }, }; @@ -84,7 +87,7 @@ describe('NetVisRenderer', () => { const model: IMimeModel = { data: { [MIME_TYPE]: { - version: '0.4.0', + version: CURRENT_VERSION, }, }, }; @@ -107,7 +110,7 @@ describe('NetVisRenderer', () => { data: { [MIME_TYPE]: { data: 'invalid json', - version: '0.4.0', + version: CURRENT_VERSION, }, }, }; @@ -132,7 +135,7 @@ describe('NetVisRenderer', () => { data: { [MIME_TYPE]: { data: graphData, - version: '0.4.0', + version: CURRENT_VERSION, }, }, }; @@ -159,7 +162,7 @@ describe('NetVisRenderer', () => { data: { [MIME_TYPE]: { data: graphData, - version: '0.4.0', + version: CURRENT_VERSION, }, }, }; @@ -202,7 +205,7 @@ describe('NetVisRenderer', () => { data: { [MIME_TYPE]: { data: graphData1, - version: '0.4.0', + version: CURRENT_VERSION, }, }, }; @@ -211,7 +214,7 @@ describe('NetVisRenderer', () => { data: { [MIME_TYPE]: { data: graphData2, - version: '0.4.0', + version: CURRENT_VERSION, }, }, }; @@ -249,7 +252,7 @@ describe('NetVisRenderer', () => { data: { [MIME_TYPE]: { data: graphData, - version: '0.4.0', + version: CURRENT_VERSION, }, }, }; @@ -276,7 +279,7 @@ describe('NetVisRenderer', () => { data: { [MIME_TYPE]: { data: graphData, - version: '0.4.0', + version: CURRENT_VERSION, }, }, }; @@ -307,7 +310,7 @@ describe('NetVisRenderer', () => { data: { [MIME_TYPE]: { data: graphData, - version: '0.4.0', + version: CURRENT_VERSION, }, }, }; @@ -350,7 +353,7 @@ describe('NetVisRenderer', () => { data: { [MIME_TYPE]: { data: graphData, - version: '0.4.0', + version: CURRENT_VERSION, }, }, }; @@ -391,7 +394,7 @@ describe('NetVisRenderer', () => { data: { [MIME_TYPE]: { data: graphData, - version: '0.4.0', + version: CURRENT_VERSION, }, }, }; diff --git a/src/mimePlugin.ts b/src/mimePlugin.ts index f0510f7..024a224 100644 --- a/src/mimePlugin.ts +++ b/src/mimePlugin.ts @@ -1,5 +1,6 @@ import { IRenderMime } from '@jupyterlab/rendermime-interfaces'; import { Widget } from '@lumino/widgets'; +import packageJson from '../package.json'; /** * MIME type for NetVis graph data @@ -7,9 +8,9 @@ import { Widget } from '@lumino/widgets'; export const MIME_TYPE = 'application/vnd.netvis+json'; /** - * Frontend package version (should match package.json) + * Frontend package version (automatically loaded from package.json) */ -const FRONTEND_VERSION = '0.4.0'; +const FRONTEND_VERSION = packageJson.version; /** * Parse graph data string and handle empty data case.