Skip to content

Commit ea443f4

Browse files
feat(diagram): add direction, Mermaid output, and schema grouping
Bug fixes: - Fix isdigit() missing parentheses in _make_graph - Fix nested list creation in _make_graph - Remove dead code in make_dot - Fix invalid color code for Part tier - Replace eval() with safe _resolve_class() method New features: - Add direction parameter ("TB", "LR", "BT", "RL") for layout control - Add make_mermaid() method for web-friendly diagram output - Add group_by_schema parameter to cluster nodes by database schema - Update save() to support .mmd/.mermaid file extensions Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 648bd1a commit ea443f4

File tree

1 file changed

+238
-21
lines changed

1 file changed

+238
-21
lines changed

src/datajoint/diagram.py

Lines changed: 238 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ class Diagram(nx.DiGraph):
7070
context : dict, optional
7171
Namespace for resolving table class names. If None, uses caller's
7272
frame globals/locals.
73+
direction : str, optional
74+
Default layout direction: "TB" (top-to-bottom, default) or
75+
"LR" (left-to-right). Can be overridden in output methods.
7376
7477
Examples
7578
--------
@@ -92,13 +95,43 @@ class Diagram(nx.DiGraph):
9295
Only tables loaded in the connection are displayed.
9396
"""
9497

95-
def __init__(self, source, context=None) -> None:
98+
_VALID_DIRECTIONS = ("TB", "LR")
99+
100+
@staticmethod
101+
def _validate_direction(direction: str) -> str:
102+
"""Validate and normalize direction parameter."""
103+
if direction not in Diagram._VALID_DIRECTIONS:
104+
raise ValueError(
105+
f"Invalid direction '{direction}'. Must be one of: {Diagram._VALID_DIRECTIONS}"
106+
)
107+
return direction
108+
109+
def set_direction(self, direction: str) -> None:
110+
"""
111+
Set the default layout direction.
112+
113+
Parameters
114+
----------
115+
direction : str
116+
Layout direction: "TB" (top-to-bottom) or "LR" (left-to-right).
117+
118+
Examples
119+
--------
120+
>>> diag = dj.Diagram(Mouse) + 1 - 1
121+
>>> diag.set_direction("LR")
122+
>>> diag.draw()
123+
"""
124+
self.direction = self._validate_direction(direction)
125+
126+
def __init__(self, source, context=None, direction: str = "TB") -> None:
96127
if isinstance(source, Diagram):
97128
# copy constructor
98129
self.nodes_to_show = set(source.nodes_to_show)
99130
self.context = source.context
131+
self.direction = source.direction
100132
super().__init__(source)
101133
return
134+
self.direction = self._validate_direction(direction)
102135

103136
# get the caller's context
104137
if context is None:
@@ -286,18 +319,42 @@ def _make_graph(self) -> nx.DiGraph:
286319
gaps = set(nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)).intersection(
287320
nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), self.nodes_to_show)
288321
)
289-
nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit)
322+
nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit())
290323
# construct subgraph and rename nodes to class names
291324
graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes))
292325
nx.set_node_attributes(graph, name="node_type", values={n: _get_tier(n) for n in graph})
293326
# relabel nodes to class names
294327
mapping = {node: lookup_class_name(node, self.context) or node for node in graph.nodes()}
295-
new_names = [mapping.values()]
328+
new_names = list(mapping.values())
296329
if len(new_names) > len(set(new_names)):
297330
raise DataJointError("Some classes have identical names. The Diagram cannot be plotted.")
298331
nx.relabel_nodes(graph, mapping, copy=False)
299332
return graph
300333

334+
def _resolve_class(self, name: str):
335+
"""
336+
Safely resolve a table class from a dotted name without eval().
337+
338+
Parameters
339+
----------
340+
name : str
341+
Dotted class name like "MyTable" or "Module.MyTable".
342+
343+
Returns
344+
-------
345+
type or None
346+
The table class if found, otherwise None.
347+
"""
348+
parts = name.split(".")
349+
obj = self.context.get(parts[0])
350+
for part in parts[1:]:
351+
if obj is None:
352+
return None
353+
obj = getattr(obj, part, None)
354+
if obj is not None and isinstance(obj, type) and issubclass(obj, Table):
355+
return obj
356+
return None
357+
301358
@staticmethod
302359
def _encapsulate_edge_attributes(graph: nx.DiGraph) -> None:
303360
"""
@@ -330,9 +387,38 @@ def _encapsulate_node_names(graph: nx.DiGraph) -> None:
330387
copy=False,
331388
)
332389

333-
def make_dot(self):
390+
def make_dot(self, direction: str | None = None, group_by_schema: bool = False):
391+
"""
392+
Generate a pydot graph object.
393+
394+
Parameters
395+
----------
396+
direction : str, optional
397+
Layout direction: "TB" (top-to-bottom) or "LR" (left-to-right).
398+
Defaults to instance direction.
399+
group_by_schema : bool, optional
400+
If True, group nodes into clusters by their database schema.
401+
Default False.
402+
403+
Returns
404+
-------
405+
pydot.Dot
406+
The graph object ready for rendering.
407+
"""
408+
direction = self._validate_direction(direction) if direction else self.direction
334409
graph = self._make_graph()
335-
graph.nodes()
410+
411+
# Build schema mapping if grouping is requested
412+
schema_map = {}
413+
if group_by_schema:
414+
for full_name in self.nodes_to_show:
415+
# Extract schema from full table name like `schema`.`table` or "schema"."table"
416+
parts = full_name.replace('"', '`').split('`')
417+
if len(parts) >= 2:
418+
schema_name = parts[1] # schema is between first pair of backticks
419+
# Find the class name for this full_name
420+
class_name = lookup_class_name(full_name, self.context) or full_name
421+
schema_map[class_name] = schema_name
336422

337423
scale = 1.2 # scaling factor for fonts and boxes
338424
label_props = { # http://matplotlib.org/examples/color/named_colors.html
@@ -386,7 +472,7 @@ def make_dot(self):
386472
),
387473
Part: dict(
388474
shape="plaintext",
389-
color="#0000000",
475+
color="#00000000",
390476
fontcolor="black",
391477
fontsize=round(scale * 8),
392478
size=0.1 * scale,
@@ -398,6 +484,7 @@ def make_dot(self):
398484
self._encapsulate_node_names(graph)
399485
self._encapsulate_edge_attributes(graph)
400486
dot = nx.drawing.nx_pydot.to_pydot(graph)
487+
dot.set_rankdir(direction)
401488
for node in dot.get_nodes():
402489
node.set_shape("circle")
403490
name = node.get_name().strip('"')
@@ -409,9 +496,8 @@ def make_dot(self):
409496
node.set_fixedsize("shape" if props["fixed"] else False)
410497
node.set_width(props["size"])
411498
node.set_height(props["size"])
412-
if name.split(".")[0] in self.context:
413-
cls = eval(name, self.context)
414-
assert issubclass(cls, Table)
499+
cls = self._resolve_class(name)
500+
if cls is not None:
415501
description = cls().describe(context=self.context).split("\n")
416502
description = (
417503
("-" * 30 if q.startswith("---") else (q.replace("->", "&#8594;") if "->" in q else q.split(":")[0]))
@@ -437,34 +523,151 @@ def make_dot(self):
437523
edge.set_arrowhead("none")
438524
edge.set_penwidth(0.75 if props["multi"] else 2)
439525

526+
# Group nodes into schema clusters if requested
527+
if group_by_schema and schema_map:
528+
import pydot
529+
530+
# Group nodes by schema
531+
schemas = {}
532+
for node in list(dot.get_nodes()):
533+
name = node.get_name().strip('"')
534+
schema_name = schema_map.get(name)
535+
if schema_name:
536+
if schema_name not in schemas:
537+
schemas[schema_name] = []
538+
schemas[schema_name].append(node)
539+
540+
# Create clusters for each schema
541+
for schema_name, nodes in schemas.items():
542+
cluster = pydot.Cluster(
543+
f"cluster_{schema_name}",
544+
label=schema_name,
545+
style="dashed",
546+
color="gray",
547+
fontcolor="gray",
548+
)
549+
for node in nodes:
550+
cluster.add_node(node)
551+
dot.add_subgraph(cluster)
552+
440553
return dot
441554

442-
def make_svg(self):
555+
def make_svg(self, direction: str | None = None, group_by_schema: bool = False):
443556
from IPython.display import SVG
444557

445-
return SVG(self.make_dot().create_svg())
558+
return SVG(self.make_dot(direction=direction, group_by_schema=group_by_schema).create_svg())
446559

447-
def make_png(self):
448-
return io.BytesIO(self.make_dot().create_png())
560+
def make_png(self, direction: str | None = None, group_by_schema: bool = False):
561+
return io.BytesIO(self.make_dot(direction=direction, group_by_schema=group_by_schema).create_png())
449562

450-
def make_image(self):
563+
def make_image(self, direction: str | None = None, group_by_schema: bool = False):
451564
if plot_active:
452-
return plt.imread(self.make_png())
565+
return plt.imread(self.make_png(direction=direction, group_by_schema=group_by_schema))
453566
else:
454567
raise DataJointError("pyplot was not imported")
455568

569+
def make_mermaid(self, direction: str | None = None) -> str:
570+
"""
571+
Generate Mermaid diagram syntax.
572+
573+
Produces a flowchart in Mermaid syntax that can be rendered in
574+
Markdown documentation, GitHub, or https://mermaid.live.
575+
576+
Parameters
577+
----------
578+
direction : str, optional
579+
Layout direction: "TB" (top-to-bottom) or "LR" (left-to-right).
580+
Defaults to instance direction.
581+
582+
Returns
583+
-------
584+
str
585+
Mermaid flowchart syntax.
586+
587+
Examples
588+
--------
589+
>>> print(dj.Diagram(schema).make_mermaid())
590+
flowchart TB
591+
Mouse[Mouse]:::manual
592+
Session[Session]:::manual
593+
Neuron([Neuron]):::computed
594+
Mouse --> Session
595+
Session --> Neuron
596+
"""
597+
graph = self._make_graph()
598+
direction = self._validate_direction(direction) if direction else self.direction
599+
600+
lines = [f"flowchart {direction}"]
601+
602+
# Define class styles matching Graphviz colors
603+
lines.append(" classDef manual fill:#90EE90,stroke:#006400")
604+
lines.append(" classDef lookup fill:#D3D3D3,stroke:#696969")
605+
lines.append(" classDef computed fill:#FFB6C1,stroke:#8B0000")
606+
lines.append(" classDef imported fill:#ADD8E6,stroke:#00008B")
607+
lines.append(" classDef part fill:#FFFFFF,stroke:#000000")
608+
lines.append("")
609+
610+
# Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box
611+
shape_map = {
612+
Manual: ("[", "]"), # box
613+
Lookup: ("[", "]"), # box
614+
Computed: ("([", "])"), # stadium/pill
615+
Imported: ("([", "])"), # stadium/pill
616+
Part: ("[", "]"), # box
617+
_AliasNode: ("((", "))"), # circle
618+
None: ("((", "))"), # circle
619+
}
620+
621+
tier_class = {
622+
Manual: "manual",
623+
Lookup: "lookup",
624+
Computed: "computed",
625+
Imported: "imported",
626+
Part: "part",
627+
_AliasNode: "",
628+
None: "",
629+
}
630+
631+
# Add nodes
632+
for node, data in graph.nodes(data=True):
633+
tier = data.get("node_type")
634+
left, right = shape_map.get(tier, ("[", "]"))
635+
cls = tier_class.get(tier, "")
636+
# Mermaid node IDs can't have dots, replace with underscores
637+
safe_id = node.replace(".", "_").replace(" ", "_")
638+
class_suffix = f":::{cls}" if cls else ""
639+
lines.append(f" {safe_id}{left}{node}{right}{class_suffix}")
640+
641+
lines.append("")
642+
643+
# Add edges
644+
for src, dest, data in graph.edges(data=True):
645+
safe_src = src.replace(".", "_").replace(" ", "_")
646+
safe_dest = dest.replace(".", "_").replace(" ", "_")
647+
# Solid arrow for primary FK, dotted for non-primary
648+
style = "-->" if data.get("primary") else "-.->"
649+
lines.append(f" {safe_src} {style} {safe_dest}")
650+
651+
return "\n".join(lines)
652+
456653
def _repr_svg_(self):
457654
return self.make_svg()._repr_svg_()
458655

459-
def draw(self):
656+
def draw(self, direction: str | None = None, group_by_schema: bool = False):
460657
if plot_active:
461-
plt.imshow(self.make_image())
658+
plt.imshow(self.make_image(direction=direction, group_by_schema=group_by_schema))
462659
plt.gca().axis("off")
463660
plt.show()
464661
else:
465662
raise DataJointError("pyplot was not imported")
466663

467-
def save(self, filename: str, format: str | None = None) -> None:
664+
def save(
665+
self,
666+
filename: str,
667+
format: str | None = None,
668+
direction: str | None = None,
669+
group_by_schema: bool = False,
670+
) -> None:
468671
"""
469672
Save diagram to file.
470673
@@ -473,7 +676,14 @@ def save(self, filename: str, format: str | None = None) -> None:
473676
filename : str
474677
Output filename.
475678
format : str, optional
476-
File format (``'png'`` or ``'svg'``). Inferred from extension if None.
679+
File format (``'png'``, ``'svg'``, or ``'mermaid'``).
680+
Inferred from extension if None.
681+
direction : str, optional
682+
Layout direction: "TB" (top-to-bottom) or "LR" (left-to-right).
683+
Defaults to instance direction.
684+
group_by_schema : bool, optional
685+
If True, group nodes into clusters by their database schema.
686+
Default False. Only applies to png and svg formats.
477687
478688
Raises
479689
------
@@ -485,12 +695,19 @@ def save(self, filename: str, format: str | None = None) -> None:
485695
format = "png"
486696
elif filename.lower().endswith(".svg"):
487697
format = "svg"
698+
elif filename.lower().endswith((".mmd", ".mermaid")):
699+
format = "mermaid"
700+
if format is None:
701+
raise DataJointError("Could not infer format from filename. Specify format explicitly.")
488702
if format.lower() == "png":
489703
with open(filename, "wb") as f:
490-
f.write(self.make_png().getbuffer().tobytes())
704+
f.write(self.make_png(direction=direction, group_by_schema=group_by_schema).getbuffer().tobytes())
491705
elif format.lower() == "svg":
492706
with open(filename, "w") as f:
493-
f.write(self.make_svg().data)
707+
f.write(self.make_svg(direction=direction, group_by_schema=group_by_schema).data)
708+
elif format.lower() == "mermaid":
709+
with open(filename, "w") as f:
710+
f.write(self.make_mermaid(direction=direction))
494711
else:
495712
raise DataJointError("Unsupported file format")
496713

0 commit comments

Comments
 (0)