Skip to content

Commit 0dd5a69

Browse files
feat: always group diagram nodes by schema with module labels
- Remove group_by_schema parameter (always enabled)
- Show Python module name as cluster label when available
- Assign alias nodes (orange dots) to child table's schema
- Add schema grouping (subgraphs) to Mermaid output

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent f98cdbf commit 0dd5a69

File tree

1 file changed

+98
-51
lines changed

1 file changed

+98
-51
lines changed

src/datajoint/diagram.py

Lines changed: 98 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -361,16 +361,10 @@ def _encapsulate_node_names(graph: nx.DiGraph) -> None:
361361
copy=False,
362362
)
363363

364-
def make_dot(self, group_by_schema: bool = False):
364+
def make_dot(self):
365365
"""
366366
Generate a pydot graph object.
367367
368-
Parameters
369-
----------
370-
group_by_schema : bool, optional
371-
If True, group nodes into clusters by their database schema.
372-
Default False.
373-
374368
Returns
375369
-------
376370
pydot.Dot
@@ -379,21 +373,39 @@ def make_dot(self, group_by_schema: bool = False):
379373
Notes
380374
-----
381375
Layout direction is controlled via ``dj.config.display.diagram_direction``.
376+
Tables are grouped by schema, with the Python module name shown as the
377+
group label when available.
382378
"""
383379
direction = config.display.diagram_direction
384380
graph = self._make_graph()
385381

386-
# Build schema mapping if grouping is requested
387-
schema_map = {}
388-
if group_by_schema:
389-
for full_name in self.nodes_to_show:
390-
# Extract schema from full table name like `schema`.`table` or "schema"."table"
391-
parts = full_name.replace('"', '`').split('`')
392-
if len(parts) >= 2:
393-
schema_name = parts[1] # schema is between first pair of backticks
394-
# Find the class name for this full_name
395-
class_name = lookup_class_name(full_name, self.context) or full_name
396-
schema_map[class_name] = schema_name
382+
# Build schema mapping: class_name -> (schema_name, module_name)
383+
# Group by database schema, but label with Python module name when available
384+
schema_map = {} # class_name -> schema_name
385+
module_map = {} # schema_name -> module_name (for cluster labels)
386+
387+
for full_name in self.nodes_to_show:
388+
# Extract schema from full table name like `schema`.`table` or "schema"."table"
389+
parts = full_name.replace('"', '`').split('`')
390+
if len(parts) >= 2:
391+
schema_name = parts[1] # schema is between first pair of backticks
392+
class_name = lookup_class_name(full_name, self.context) or full_name
393+
schema_map[class_name] = schema_name
394+
395+
# Try to get Python module name for the cluster label
396+
if schema_name not in module_map:
397+
cls = self._resolve_class(class_name)
398+
if cls is not None and hasattr(cls, "__module__"):
399+
# Use the last part of the module path (e.g., "my_pipeline" from "package.my_pipeline")
400+
module_map[schema_name] = cls.__module__.split(".")[-1]
401+
402+
# Assign alias nodes (orange dots) to the same schema as their child table
403+
for node, data in graph.nodes(data=True):
404+
if data.get("node_type") is _AliasNode:
405+
# Find the child (successor) - the table that declares the renamed FK
406+
successors = list(graph.successors(node))
407+
if successors and successors[0] in schema_map:
408+
schema_map[node] = schema_map[successors[0]]
397409

398410
scale = 1.2 # scaling factor for fonts and boxes
399411
label_props = { # http://matplotlib.org/examples/color/named_colors.html
@@ -498,8 +510,8 @@ def make_dot(self, group_by_schema: bool = False):
498510
edge.set_arrowhead("none")
499511
edge.set_penwidth(0.75 if props["multi"] else 2)
500512

501-
# Group nodes into schema clusters if requested
502-
if group_by_schema and schema_map:
513+
# Group nodes into schema clusters (always on)
514+
if schema_map:
503515
import pydot
504516

505517
# Group nodes by schema
@@ -513,10 +525,12 @@ def make_dot(self, group_by_schema: bool = False):
513525
schemas[schema_name].append(node)
514526

515527
# Create clusters for each schema
528+
# Use Python module name as label when available, otherwise database schema name
516529
for schema_name, nodes in schemas.items():
530+
label = module_map.get(schema_name, schema_name)
517531
cluster = pydot.Cluster(
518532
f"cluster_{schema_name}",
519-
label=schema_name,
533+
label=label,
520534
style="dashed",
521535
color="gray",
522536
fontcolor="gray",
@@ -527,17 +541,17 @@ def make_dot(self, group_by_schema: bool = False):
527541

528542
return dot
529543

530-
def make_svg(self, group_by_schema: bool = False):
544+
def make_svg(self):
531545
from IPython.display import SVG
532546

533-
return SVG(self.make_dot(group_by_schema=group_by_schema).create_svg())
547+
return SVG(self.make_dot().create_svg())
534548

535-
def make_png(self, group_by_schema: bool = False):
536-
return io.BytesIO(self.make_dot(group_by_schema=group_by_schema).create_png())
549+
def make_png(self):
550+
return io.BytesIO(self.make_dot().create_png())
537551

538-
def make_image(self, group_by_schema: bool = False):
552+
def make_image(self):
539553
if plot_active:
540-
return plt.imread(self.make_png(group_by_schema=group_by_schema))
554+
return plt.imread(self.make_png())
541555
else:
542556
raise DataJointError("pyplot was not imported")
543557

@@ -556,20 +570,47 @@ def make_mermaid(self) -> str:
556570
Notes
557571
-----
558572
Layout direction is controlled via ``dj.config.display.diagram_direction``.
573+
Tables are grouped by schema using Mermaid subgraphs, with the Python
574+
module name shown as the group label when available.
559575
560576
Examples
561577
--------
562578
>>> print(dj.Diagram(schema).make_mermaid())
563579
flowchart TB
564-
Mouse[Mouse]:::manual
565-
Session[Session]:::manual
566-
Neuron([Neuron]):::computed
580+
subgraph my_pipeline
581+
Mouse[Mouse]:::manual
582+
Session[Session]:::manual
583+
Neuron([Neuron]):::computed
584+
end
567585
Mouse --> Session
568586
Session --> Neuron
569587
"""
570588
graph = self._make_graph()
571589
direction = config.display.diagram_direction
572590

591+
# Build schema mapping for grouping
592+
schema_map = {} # class_name -> schema_name
593+
module_map = {} # schema_name -> module_name (for subgraph labels)
594+
595+
for full_name in self.nodes_to_show:
596+
parts = full_name.replace('"', '`').split('`')
597+
if len(parts) >= 2:
598+
schema_name = parts[1]
599+
class_name = lookup_class_name(full_name, self.context) or full_name
600+
schema_map[class_name] = schema_name
601+
602+
if schema_name not in module_map:
603+
cls = self._resolve_class(class_name)
604+
if cls is not None and hasattr(cls, "__module__"):
605+
module_map[schema_name] = cls.__module__.split(".")[-1]
606+
607+
# Assign alias nodes to the same schema as their child table
608+
for node, data in graph.nodes(data=True):
609+
if data.get("node_type") is _AliasNode:
610+
successors = list(graph.successors(node))
611+
if successors and successors[0] in schema_map:
612+
schema_map[node] = schema_map[successors[0]]
613+
573614
lines = [f"flowchart {direction}"]
574615

575616
# Define class styles matching Graphviz colors
@@ -601,15 +642,27 @@ def make_mermaid(self) -> str:
601642
None: "",
602643
}
603644

604-
# Add nodes
645+
# Group nodes by schema into subgraphs
646+
schemas = {}
605647
for node, data in graph.nodes(data=True):
606-
tier = data.get("node_type")
607-
left, right = shape_map.get(tier, ("[", "]"))
608-
cls = tier_class.get(tier, "")
609-
# Mermaid node IDs can't have dots, replace with underscores
610-
safe_id = node.replace(".", "_").replace(" ", "_")
611-
class_suffix = f":::{cls}" if cls else ""
612-
lines.append(f" {safe_id}{left}{node}{right}{class_suffix}")
648+
schema_name = schema_map.get(node)
649+
if schema_name:
650+
if schema_name not in schemas:
651+
schemas[schema_name] = []
652+
schemas[schema_name].append((node, data))
653+
654+
# Add nodes grouped by schema subgraphs
655+
for schema_name, nodes in schemas.items():
656+
label = module_map.get(schema_name, schema_name)
657+
lines.append(f" subgraph {label}")
658+
for node, data in nodes:
659+
tier = data.get("node_type")
660+
left, right = shape_map.get(tier, ("[", "]"))
661+
cls = tier_class.get(tier, "")
662+
safe_id = node.replace(".", "_").replace(" ", "_")
663+
class_suffix = f":::{cls}" if cls else ""
664+
lines.append(f" {safe_id}{left}{node}{right}{class_suffix}")
665+
lines.append(" end")
613666

614667
lines.append("")
615668

@@ -626,20 +679,15 @@ def make_mermaid(self) -> str:
626679
def _repr_svg_(self):
627680
return self.make_svg()._repr_svg_()
628681

629-
def draw(self, group_by_schema: bool = False):
682+
def draw(self):
630683
if plot_active:
631-
plt.imshow(self.make_image(group_by_schema=group_by_schema))
684+
plt.imshow(self.make_image())
632685
plt.gca().axis("off")
633686
plt.show()
634687
else:
635688
raise DataJointError("pyplot was not imported")
636689

637-
def save(
638-
self,
639-
filename: str,
640-
format: str | None = None,
641-
group_by_schema: bool = False,
642-
) -> None:
690+
def save(self, filename: str, format: str | None = None) -> None:
643691
"""
644692
Save diagram to file.
645693
@@ -650,9 +698,6 @@ def save(
650698
format : str, optional
651699
File format (``'png'``, ``'svg'``, or ``'mermaid'``).
652700
Inferred from extension if None.
653-
group_by_schema : bool, optional
654-
If True, group nodes into clusters by their database schema.
655-
Default False. Only applies to png and svg formats.
656701
657702
Raises
658703
------
@@ -662,6 +707,8 @@ def save(
662707
Notes
663708
-----
664709
Layout direction is controlled via ``dj.config.display.diagram_direction``.
710+
Tables are grouped by schema, with the Python module name shown as the
711+
group label when available.
665712
"""
666713
if format is None:
667714
if filename.lower().endswith(".png"):
@@ -674,10 +721,10 @@ def save(
674721
raise DataJointError("Could not infer format from filename. Specify format explicitly.")
675722
if format.lower() == "png":
676723
with open(filename, "wb") as f:
677-
f.write(self.make_png(group_by_schema=group_by_schema).getbuffer().tobytes())
724+
f.write(self.make_png().getbuffer().tobytes())
678725
elif format.lower() == "svg":
679726
with open(filename, "w") as f:
680-
f.write(self.make_svg(group_by_schema=group_by_schema).data)
727+
f.write(self.make_svg().data)
681728
elif format.lower() == "mermaid":
682729
with open(filename, "w") as f:
683730
f.write(self.make_mermaid())

0 commit comments

Comments
 (0)