Skip to content

Commit ddcf592

Browse files
feat(diagram): add direction, Mermaid output, and schema grouping
Bug fixes:
- Fix isdigit() missing parentheses in _make_graph
- Fix nested list creation in _make_graph
- Remove dead code in make_dot
- Fix invalid color code for Part tier
- Replace eval() with safe _resolve_class() method

New features:
- Add direction parameter ("TB", "LR", "BT", "RL") for layout control
- Add make_mermaid() method for web-friendly diagram output
- Add group_by_schema parameter to cluster nodes by database schema
- Update save() to support .mmd/.mermaid file extensions

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 648bd1a commit ddcf592

File tree

1 file changed

+202
-21
lines changed

1 file changed

+202
-21
lines changed

src/datajoint/diagram.py

Lines changed: 202 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ class Diagram(nx.DiGraph):
7070
context : dict, optional
7171
Namespace for resolving table class names. If None, uses caller's
7272
frame globals/locals.
73+
direction : str, optional
74+
Layout direction: "TB" (top-to-bottom, default), "LR" (left-to-right),
75+
"BT" (bottom-to-top), or "RL" (right-to-left).
7376
7477
Examples
7578
--------
@@ -92,13 +95,15 @@ class Diagram(nx.DiGraph):
9295
Only tables loaded in the connection are displayed.
9396
"""
9497

95-
def __init__(self, source, context=None) -> None:
98+
def __init__(self, source, context=None, direction: str = "TB") -> None:
9699
if isinstance(source, Diagram):
97100
# copy constructor
98101
self.nodes_to_show = set(source.nodes_to_show)
99102
self.context = source.context
103+
self.direction = source.direction
100104
super().__init__(source)
101105
return
106+
self.direction = direction
102107

103108
# get the caller's context
104109
if context is None:
@@ -286,18 +291,42 @@ def _make_graph(self) -> nx.DiGraph:
286291
gaps = set(nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)).intersection(
287292
nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), self.nodes_to_show)
288293
)
289-
nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit)
294+
nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit())
290295
# construct subgraph and rename nodes to class names
291296
graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes))
292297
nx.set_node_attributes(graph, name="node_type", values={n: _get_tier(n) for n in graph})
293298
# relabel nodes to class names
294299
mapping = {node: lookup_class_name(node, self.context) or node for node in graph.nodes()}
295-
new_names = [mapping.values()]
300+
new_names = list(mapping.values())
296301
if len(new_names) > len(set(new_names)):
297302
raise DataJointError("Some classes have identical names. The Diagram cannot be plotted.")
298303
nx.relabel_nodes(graph, mapping, copy=False)
299304
return graph
300305

306+
def _resolve_class(self, name: str):
    """
    Look up a table class by dotted name in ``self.context`` without eval().

    The first component of the name is fetched from the context mapping;
    every following component is resolved with ``getattr`` on the object
    found so far.

    Parameters
    ----------
    name : str
        Dotted class name like "MyTable" or "Module.MyTable".

    Returns
    -------
    type or None
        The resolved object when it is a class derived from ``Table``;
        None when any component is missing or the result is not such a class.
    """
    head, *attrs = name.split(".")
    target = self.context.get(head)
    for attr in attrs:
        # A missing link anywhere in the chain means the name is unresolvable.
        if target is None:
            return None
        target = getattr(target, attr, None)
    # isinstance() is False for None, so a failed lookup falls through to None.
    is_table_class = isinstance(target, type) and issubclass(target, Table)
    return target if is_table_class else None
329+
301330
@staticmethod
302331
def _encapsulate_edge_attributes(graph: nx.DiGraph) -> None:
303332
"""
@@ -330,9 +359,34 @@ def _encapsulate_node_names(graph: nx.DiGraph) -> None:
330359
copy=False,
331360
)
332361

333-
def make_dot(self):
362+
def make_dot(self, group_by_schema: bool = False):
363+
"""
364+
Generate a pydot graph object.
365+
366+
Parameters
367+
----------
368+
group_by_schema : bool, optional
369+
If True, group nodes into clusters by their database schema.
370+
Default False.
371+
372+
Returns
373+
-------
374+
pydot.Dot
375+
The graph object ready for rendering.
376+
"""
334377
graph = self._make_graph()
335-
graph.nodes()
378+
379+
# Build schema mapping if grouping is requested
380+
schema_map = {}
381+
if group_by_schema:
382+
for full_name in self.nodes_to_show:
383+
# Extract schema from full table name like `schema`.`table` or "schema"."table"
384+
parts = full_name.replace('"', '`').split('`')
385+
if len(parts) >= 2:
386+
schema_name = parts[1] # schema is between first pair of backticks
387+
# Find the class name for this full_name
388+
class_name = lookup_class_name(full_name, self.context) or full_name
389+
schema_map[class_name] = schema_name
336390

337391
scale = 1.2 # scaling factor for fonts and boxes
338392
label_props = { # http://matplotlib.org/examples/color/named_colors.html
@@ -386,7 +440,7 @@ def make_dot(self):
386440
),
387441
Part: dict(
388442
shape="plaintext",
389-
color="#0000000",
443+
color="#00000000",
390444
fontcolor="black",
391445
fontsize=round(scale * 8),
392446
size=0.1 * scale,
@@ -398,6 +452,7 @@ def make_dot(self):
398452
self._encapsulate_node_names(graph)
399453
self._encapsulate_edge_attributes(graph)
400454
dot = nx.drawing.nx_pydot.to_pydot(graph)
455+
dot.set_rankdir(self.direction)
401456
for node in dot.get_nodes():
402457
node.set_shape("circle")
403458
name = node.get_name().strip('"')
@@ -409,9 +464,8 @@ def make_dot(self):
409464
node.set_fixedsize("shape" if props["fixed"] else False)
410465
node.set_width(props["size"])
411466
node.set_height(props["size"])
412-
if name.split(".")[0] in self.context:
413-
cls = eval(name, self.context)
414-
assert issubclass(cls, Table)
467+
cls = self._resolve_class(name)
468+
if cls is not None:
415469
description = cls().describe(context=self.context).split("\n")
416470
description = (
417471
("-" * 30 if q.startswith("---") else (q.replace("->", "&#8594;") if "->" in q else q.split(":")[0]))
@@ -437,34 +491,150 @@ def make_dot(self):
437491
edge.set_arrowhead("none")
438492
edge.set_penwidth(0.75 if props["multi"] else 2)
439493

494+
# Group nodes into schema clusters if requested
495+
if group_by_schema and schema_map:
496+
import pydot
497+
498+
# Group nodes by schema
499+
schemas = {}
500+
for node in list(dot.get_nodes()):
501+
name = node.get_name().strip('"')
502+
schema_name = schema_map.get(name)
503+
if schema_name:
504+
if schema_name not in schemas:
505+
schemas[schema_name] = []
506+
schemas[schema_name].append(node)
507+
508+
# Create clusters for each schema
509+
for schema_name, nodes in schemas.items():
510+
cluster = pydot.Cluster(
511+
f"cluster_{schema_name}",
512+
label=schema_name,
513+
style="dashed",
514+
color="gray",
515+
fontcolor="gray",
516+
)
517+
for node in nodes:
518+
cluster.add_node(node)
519+
dot.add_subgraph(cluster)
520+
440521
return dot
441522

442-
def make_svg(self):
523+
def make_svg(self, group_by_schema: bool = False):
443524
from IPython.display import SVG
444525

445-
return SVG(self.make_dot().create_svg())
526+
return SVG(self.make_dot(group_by_schema=group_by_schema).create_svg())
446527

447-
def make_png(self):
448-
return io.BytesIO(self.make_dot().create_png())
528+
def make_png(self, group_by_schema: bool = False):
529+
return io.BytesIO(self.make_dot(group_by_schema=group_by_schema).create_png())
449530

450-
def make_image(self):
531+
def make_image(self, group_by_schema: bool = False):
451532
if plot_active:
452-
return plt.imread(self.make_png())
533+
return plt.imread(self.make_png(group_by_schema=group_by_schema))
453534
else:
454535
raise DataJointError("pyplot was not imported")
455536

537+
def make_mermaid(self, direction: str | None = None) -> str:
    """
    Generate Mermaid diagram syntax.

    Produces a flowchart in Mermaid syntax that can be rendered in
    Markdown documentation, GitHub, or https://mermaid.live.

    Parameters
    ----------
    direction : str, optional
        Override layout direction for this render. Uses instance direction
        if not specified.

    Returns
    -------
    str
        Mermaid flowchart syntax: a ``flowchart`` header, the ``classDef``
        style definitions, one line per node, and one line per edge.

    Raises
    ------
    DataJointError
        If the effective direction is not one of "TB", "TD", "BT", "LR"
        or "RL".

    Examples
    --------
    Illustrative output for three tables:

    >>> print(dj.Diagram(schema).make_mermaid())
    flowchart TB
        classDef manual fill:#90EE90,stroke:#006400
        classDef lookup fill:#D3D3D3,stroke:#696969
        classDef computed fill:#FFB6C1,stroke:#8B0000
        classDef imported fill:#ADD8E6,stroke:#00008B
        classDef part fill:#FFFFFF,stroke:#000000
    <BLANKLINE>
        Mouse[Mouse]:::manual
        Session[Session]:::manual
        Neuron([Neuron]):::computed
    <BLANKLINE>
        Mouse --> Session
        Session --> Neuron
    """
    direction = direction or self.direction
    # Fail fast: an unknown direction keyword produces text Mermaid cannot parse.
    if direction not in ("TB", "TD", "BT", "LR", "RL"):
        raise DataJointError(
            f"Invalid Mermaid direction {direction!r}; expected one of 'TB', 'TD', 'BT', 'LR', 'RL'."
        )

    def node_id(name: str) -> str:
        # Mermaid node IDs can't have dots, spaces, quotes or backticks;
        # every character that is not alphanumeric or "_" becomes "_".
        return "".join(ch if ch.isalnum() or ch == "_" else "_" for ch in name)

    def node_label(name: str) -> str:
        # Plain names are emitted bare so ordinary output stays unchanged.
        if all(ch.isalnum() or ch in "_. " for ch in name):
            return name
        # Anything else (e.g. a raw backtick-quoted table name kept when no
        # class name was found) is quoted, with special characters written
        # as Mermaid entity codes. "#" is escaped first so the codes added
        # afterwards are not escaped again.
        escaped = name.replace("#", "#35;").replace('"', "#quot;").replace("`", "#96;")
        return f'"{escaped}"'

    graph = self._make_graph()

    # Header followed by class styles matching the Graphviz colors.
    lines = [
        f"flowchart {direction}",
        "    classDef manual fill:#90EE90,stroke:#006400",
        "    classDef lookup fill:#D3D3D3,stroke:#696969",
        "    classDef computed fill:#FFB6C1,stroke:#8B0000",
        "    classDef imported fill:#ADD8E6,stroke:#00008B",
        "    classDef part fill:#FFFFFF,stroke:#000000",
        "",
    ]

    # Per tier: (opening bracket, closing bracket, style class).
    # Manual/Lookup/Part are boxes, Computed/Imported are stadiums,
    # alias nodes and untyped nodes are unstyled circles.
    tier_style = {
        Manual: ("[", "]", "manual"),
        Lookup: ("[", "]", "lookup"),
        Computed: ("([", "])", "computed"),
        Imported: ("([", "])", "imported"),
        Part: ("[", "]", "part"),
        _AliasNode: ("((", "))", ""),
        None: ("((", "))", ""),
    }

    # Add nodes
    for node, data in graph.nodes(data=True):
        # Unknown tiers fall back to an unstyled box.
        left, right, css = tier_style.get(data.get("node_type"), ("[", "]", ""))
        class_suffix = f":::{css}" if css else ""
        lines.append(f"    {node_id(node)}{left}{node_label(node)}{right}{class_suffix}")

    lines.append("")

    # Add edges: solid arrow for primary FK, dotted for non-primary.
    for src, dest, data in graph.edges(data=True):
        arrow = "-->" if data.get("primary") else "-.->"
        lines.append(f"    {node_id(src)} {arrow} {node_id(dest)}")

    return "\n".join(lines)
620+
456621
def _repr_svg_(self):
457622
return self.make_svg()._repr_svg_()
458623

459-
def draw(self):
624+
def draw(self, group_by_schema: bool = False):
460625
if plot_active:
461-
plt.imshow(self.make_image())
626+
plt.imshow(self.make_image(group_by_schema=group_by_schema))
462627
plt.gca().axis("off")
463628
plt.show()
464629
else:
465630
raise DataJointError("pyplot was not imported")
466631

467-
def save(self, filename: str, format: str | None = None) -> None:
632+
def save(
633+
self,
634+
filename: str,
635+
format: str | None = None,
636+
group_by_schema: bool = False,
637+
) -> None:
468638
"""
469639
Save diagram to file.
470640
@@ -473,7 +643,11 @@ def save(self, filename: str, format: str | None = None) -> None:
473643
filename : str
474644
Output filename.
475645
format : str, optional
476-
File format (``'png'`` or ``'svg'``). Inferred from extension if None.
646+
File format (``'png'``, ``'svg'``, or ``'mermaid'``).
647+
Inferred from extension if None.
648+
group_by_schema : bool, optional
649+
If True, group nodes into clusters by their database schema.
650+
Default False. Only applies to png and svg formats.
477651
478652
Raises
479653
------
@@ -485,12 +659,19 @@ def save(self, filename: str, format: str | None = None) -> None:
485659
format = "png"
486660
elif filename.lower().endswith(".svg"):
487661
format = "svg"
662+
elif filename.lower().endswith((".mmd", ".mermaid")):
663+
format = "mermaid"
664+
if format is None:
665+
raise DataJointError("Could not infer format from filename. Specify format explicitly.")
488666
if format.lower() == "png":
489667
with open(filename, "wb") as f:
490-
f.write(self.make_png().getbuffer().tobytes())
668+
f.write(self.make_png(group_by_schema=group_by_schema).getbuffer().tobytes())
491669
elif format.lower() == "svg":
492670
with open(filename, "w") as f:
493-
f.write(self.make_svg().data)
671+
f.write(self.make_svg(group_by_schema=group_by_schema).data)
672+
elif format.lower() == "mermaid":
673+
with open(filename, "w") as f:
674+
f.write(self.make_mermaid())
494675
else:
495676
raise DataJointError("Unsupported file format")
496677

0 commit comments

Comments
 (0)