Skip to content

Commit 828f70e

Browse files
feat(diagram): add direction, Mermaid output, and schema grouping
Bug fixes:
- Fix isdigit() missing parentheses in _make_graph
- Fix nested list creation in _make_graph
- Remove dead code in make_dot
- Fix invalid color code for Part tier
- Replace eval() with safe _resolve_class() method

New features:
- Add direction parameter ("TB", "LR", "BT", "RL") for layout control
- Add make_mermaid() method for web-friendly diagram output
- Add group_by_schema parameter to cluster nodes by database schema
- Update save() to support .mmd/.mermaid file extensions

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 648bd1a commit 828f70e

File tree

2 files changed

+214
-20
lines changed

2 files changed

+214
-20
lines changed

src/datajoint/diagram.py

Lines changed: 210 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616

1717
from .dependencies import topo_sort
1818
from .errors import DataJointError
19+
from .settings import config
1920
from .table import Table, lookup_class_name
2021
from .user_tables import Computed, Imported, Lookup, Manual, Part, _AliasNode, _get_tier
2122

@@ -90,6 +91,12 @@ class Diagram(nx.DiGraph):
9091
-----
9192
``diagram + 1 - 1`` may differ from ``diagram - 1 + 1``.
9293
Only tables loaded in the connection are displayed.
94+
95+
Layout direction is controlled via ``dj.config.display.diagram_direction``
96+
(default ``"TB"``). Use ``dj.config.override()`` to change temporarily::
97+
98+
with dj.config.override(display_diagram_direction="LR"):
99+
dj.Diagram(schema).draw()
93100
"""
94101

95102
def __init__(self, source, context=None) -> None:
@@ -286,18 +293,42 @@ def _make_graph(self) -> nx.DiGraph:
286293
gaps = set(nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)).intersection(
287294
nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), self.nodes_to_show)
288295
)
289-
nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit)
296+
nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit())
290297
# construct subgraph and rename nodes to class names
291298
graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes))
292299
nx.set_node_attributes(graph, name="node_type", values={n: _get_tier(n) for n in graph})
293300
# relabel nodes to class names
294301
mapping = {node: lookup_class_name(node, self.context) or node for node in graph.nodes()}
295-
new_names = [mapping.values()]
302+
new_names = list(mapping.values())
296303
if len(new_names) > len(set(new_names)):
297304
raise DataJointError("Some classes have identical names. The Diagram cannot be plotted.")
298305
nx.relabel_nodes(graph, mapping, copy=False)
299306
return graph
300307

308+
def _resolve_class(self, name: str):
309+
"""
310+
Safely resolve a table class from a dotted name without eval().
311+
312+
Parameters
313+
----------
314+
name : str
315+
Dotted class name like "MyTable" or "Module.MyTable".
316+
317+
Returns
318+
-------
319+
type or None
320+
The table class if found, otherwise None.
321+
"""
322+
parts = name.split(".")
323+
obj = self.context.get(parts[0])
324+
for part in parts[1:]:
325+
if obj is None:
326+
return None
327+
obj = getattr(obj, part, None)
328+
if obj is not None and isinstance(obj, type) and issubclass(obj, Table):
329+
return obj
330+
return None
331+
301332
@staticmethod
302333
def _encapsulate_edge_attributes(graph: nx.DiGraph) -> None:
303334
"""
@@ -330,9 +361,39 @@ def _encapsulate_node_names(graph: nx.DiGraph) -> None:
330361
copy=False,
331362
)
332363

333-
def make_dot(self):
364+
def make_dot(self, group_by_schema: bool = False):
365+
"""
366+
Generate a pydot graph object.
367+
368+
Parameters
369+
----------
370+
group_by_schema : bool, optional
371+
If True, group nodes into clusters by their database schema.
372+
Default False.
373+
374+
Returns
375+
-------
376+
pydot.Dot
377+
The graph object ready for rendering.
378+
379+
Notes
380+
-----
381+
Layout direction is controlled via ``dj.config.display.diagram_direction``.
382+
"""
383+
direction = config.display.diagram_direction
334384
graph = self._make_graph()
335-
graph.nodes()
385+
386+
# Build schema mapping if grouping is requested
387+
schema_map = {}
388+
if group_by_schema:
389+
for full_name in self.nodes_to_show:
390+
# Extract schema from full table name like `schema`.`table` or "schema"."table"
391+
parts = full_name.replace('"', '`').split('`')
392+
if len(parts) >= 2:
393+
schema_name = parts[1] # schema is between first pair of backticks
394+
# Find the class name for this full_name
395+
class_name = lookup_class_name(full_name, self.context) or full_name
396+
schema_map[class_name] = schema_name
336397

337398
scale = 1.2 # scaling factor for fonts and boxes
338399
label_props = { # http://matplotlib.org/examples/color/named_colors.html
@@ -386,7 +447,7 @@ def make_dot(self):
386447
),
387448
Part: dict(
388449
shape="plaintext",
389-
color="#0000000",
450+
color="#00000000",
390451
fontcolor="black",
391452
fontsize=round(scale * 8),
392453
size=0.1 * scale,
@@ -398,6 +459,7 @@ def make_dot(self):
398459
self._encapsulate_node_names(graph)
399460
self._encapsulate_edge_attributes(graph)
400461
dot = nx.drawing.nx_pydot.to_pydot(graph)
462+
dot.set_rankdir(direction)
401463
for node in dot.get_nodes():
402464
node.set_shape("circle")
403465
name = node.get_name().strip('"')
@@ -409,9 +471,8 @@ def make_dot(self):
409471
node.set_fixedsize("shape" if props["fixed"] else False)
410472
node.set_width(props["size"])
411473
node.set_height(props["size"])
412-
if name.split(".")[0] in self.context:
413-
cls = eval(name, self.context)
414-
assert issubclass(cls, Table)
474+
cls = self._resolve_class(name)
475+
if cls is not None:
415476
description = cls().describe(context=self.context).split("\n")
416477
description = (
417478
("-" * 30 if q.startswith("---") else (q.replace("->", "&#8594;") if "->" in q else q.split(":")[0]))
@@ -437,34 +498,148 @@ def make_dot(self):
437498
edge.set_arrowhead("none")
438499
edge.set_penwidth(0.75 if props["multi"] else 2)
439500

501+
# Group nodes into schema clusters if requested
502+
if group_by_schema and schema_map:
503+
import pydot
504+
505+
# Group nodes by schema
506+
schemas = {}
507+
for node in list(dot.get_nodes()):
508+
name = node.get_name().strip('"')
509+
schema_name = schema_map.get(name)
510+
if schema_name:
511+
if schema_name not in schemas:
512+
schemas[schema_name] = []
513+
schemas[schema_name].append(node)
514+
515+
# Create clusters for each schema
516+
for schema_name, nodes in schemas.items():
517+
cluster = pydot.Cluster(
518+
f"cluster_{schema_name}",
519+
label=schema_name,
520+
style="dashed",
521+
color="gray",
522+
fontcolor="gray",
523+
)
524+
for node in nodes:
525+
cluster.add_node(node)
526+
dot.add_subgraph(cluster)
527+
440528
return dot
441529

442-
def make_svg(self):
530+
def make_svg(self, group_by_schema: bool = False):
443531
from IPython.display import SVG
444532

445-
return SVG(self.make_dot().create_svg())
533+
return SVG(self.make_dot(group_by_schema=group_by_schema).create_svg())
446534

447-
def make_png(self):
448-
return io.BytesIO(self.make_dot().create_png())
535+
def make_png(self, group_by_schema: bool = False):
536+
return io.BytesIO(self.make_dot(group_by_schema=group_by_schema).create_png())
449537

450-
def make_image(self):
538+
def make_image(self, group_by_schema: bool = False):
451539
if plot_active:
452-
return plt.imread(self.make_png())
540+
return plt.imread(self.make_png(group_by_schema=group_by_schema))
453541
else:
454542
raise DataJointError("pyplot was not imported")
455543

544+
def make_mermaid(self) -> str:
545+
"""
546+
Generate Mermaid diagram syntax.
547+
548+
Produces a flowchart in Mermaid syntax that can be rendered in
549+
Markdown documentation, GitHub, or https://mermaid.live.
550+
551+
Returns
552+
-------
553+
str
554+
Mermaid flowchart syntax.
555+
556+
Notes
557+
-----
558+
Layout direction is controlled via ``dj.config.display.diagram_direction``.
559+
560+
Examples
561+
--------
562+
>>> print(dj.Diagram(schema).make_mermaid())
563+
flowchart TB
564+
Mouse[Mouse]:::manual
565+
Session[Session]:::manual
566+
Neuron([Neuron]):::computed
567+
Mouse --> Session
568+
Session --> Neuron
569+
"""
570+
graph = self._make_graph()
571+
direction = config.display.diagram_direction
572+
573+
lines = [f"flowchart {direction}"]
574+
575+
# Define class styles matching Graphviz colors
576+
lines.append(" classDef manual fill:#90EE90,stroke:#006400")
577+
lines.append(" classDef lookup fill:#D3D3D3,stroke:#696969")
578+
lines.append(" classDef computed fill:#FFB6C1,stroke:#8B0000")
579+
lines.append(" classDef imported fill:#ADD8E6,stroke:#00008B")
580+
lines.append(" classDef part fill:#FFFFFF,stroke:#000000")
581+
lines.append("")
582+
583+
# Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box
584+
shape_map = {
585+
Manual: ("[", "]"), # box
586+
Lookup: ("[", "]"), # box
587+
Computed: ("([", "])"), # stadium/pill
588+
Imported: ("([", "])"), # stadium/pill
589+
Part: ("[", "]"), # box
590+
_AliasNode: ("((", "))"), # circle
591+
None: ("((", "))"), # circle
592+
}
593+
594+
tier_class = {
595+
Manual: "manual",
596+
Lookup: "lookup",
597+
Computed: "computed",
598+
Imported: "imported",
599+
Part: "part",
600+
_AliasNode: "",
601+
None: "",
602+
}
603+
604+
# Add nodes
605+
for node, data in graph.nodes(data=True):
606+
tier = data.get("node_type")
607+
left, right = shape_map.get(tier, ("[", "]"))
608+
cls = tier_class.get(tier, "")
609+
# Mermaid node IDs can't have dots, replace with underscores
610+
safe_id = node.replace(".", "_").replace(" ", "_")
611+
class_suffix = f":::{cls}" if cls else ""
612+
lines.append(f" {safe_id}{left}{node}{right}{class_suffix}")
613+
614+
lines.append("")
615+
616+
# Add edges
617+
for src, dest, data in graph.edges(data=True):
618+
safe_src = src.replace(".", "_").replace(" ", "_")
619+
safe_dest = dest.replace(".", "_").replace(" ", "_")
620+
# Solid arrow for primary FK, dotted for non-primary
621+
style = "-->" if data.get("primary") else "-.->"
622+
lines.append(f" {safe_src} {style} {safe_dest}")
623+
624+
return "\n".join(lines)
625+
456626
def _repr_svg_(self):
457627
return self.make_svg()._repr_svg_()
458628

459-
def draw(self):
629+
def draw(self, group_by_schema: bool = False):
460630
if plot_active:
461-
plt.imshow(self.make_image())
631+
plt.imshow(self.make_image(group_by_schema=group_by_schema))
462632
plt.gca().axis("off")
463633
plt.show()
464634
else:
465635
raise DataJointError("pyplot was not imported")
466636

467-
def save(self, filename: str, format: str | None = None) -> None:
637+
def save(
638+
self,
639+
filename: str,
640+
format: str | None = None,
641+
group_by_schema: bool = False,
642+
) -> None:
468643
"""
469644
Save diagram to file.
470645
@@ -473,24 +648,39 @@ def save(self, filename: str, format: str | None = None) -> None:
473648
filename : str
474649
Output filename.
475650
format : str, optional
476-
File format (``'png'`` or ``'svg'``). Inferred from extension if None.
651+
File format (``'png'``, ``'svg'``, or ``'mermaid'``).
652+
Inferred from extension if None.
653+
group_by_schema : bool, optional
654+
If True, group nodes into clusters by their database schema.
655+
Default False. Only applies to png and svg formats.
477656
478657
Raises
479658
------
480659
DataJointError
481660
If format is unsupported.
661+
662+
Notes
663+
-----
664+
Layout direction is controlled via ``dj.config.display.diagram_direction``.
482665
"""
483666
if format is None:
484667
if filename.lower().endswith(".png"):
485668
format = "png"
486669
elif filename.lower().endswith(".svg"):
487670
format = "svg"
671+
elif filename.lower().endswith((".mmd", ".mermaid")):
672+
format = "mermaid"
673+
if format is None:
674+
raise DataJointError("Could not infer format from filename. Specify format explicitly.")
488675
if format.lower() == "png":
489676
with open(filename, "wb") as f:
490-
f.write(self.make_png().getbuffer().tobytes())
677+
f.write(self.make_png(group_by_schema=group_by_schema).getbuffer().tobytes())
491678
elif format.lower() == "svg":
492679
with open(filename, "w") as f:
493-
f.write(self.make_svg().data)
680+
f.write(self.make_svg(group_by_schema=group_by_schema).data)
681+
elif format.lower() == "mermaid":
682+
with open(filename, "w") as f:
683+
f.write(self.make_mermaid())
494684
else:
495685
raise DataJointError("Unsupported file format")
496686

src/datajoint/settings.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -221,6 +221,10 @@ class DisplaySettings(BaseSettings):
221221
limit: int = 12
222222
width: int = 14
223223
show_tuple_count: bool = True
224+
diagram_direction: Literal["TB", "LR"] = Field(
225+
default="TB",
226+
description="Default diagram layout direction: 'TB' (top-to-bottom) or 'LR' (left-to-right)",
227+
)
224228

225229

226230
class StoresSettings(BaseSettings):

0 commit comments

Comments
 (0)