@@ -361,16 +361,10 @@ def _encapsulate_node_names(graph: nx.DiGraph) -> None:
361361 copy = False ,
362362 )
363363
364- def make_dot (self , group_by_schema : bool = False ):
364+ def make_dot (self ):
365365 """
366366 Generate a pydot graph object.
367367
368- Parameters
369- ----------
370- group_by_schema : bool, optional
371- If True, group nodes into clusters by their database schema.
372- Default False.
373-
374368 Returns
375369 -------
376370 pydot.Dot
@@ -379,21 +373,39 @@ def make_dot(self, group_by_schema: bool = False):
379373 Notes
380374 -----
381375 Layout direction is controlled via ``dj.config.display.diagram_direction``.
376+ Tables are grouped by schema, with the Python module name shown as the
377+ group label when available.
382378 """
383379 direction = config .display .diagram_direction
384380 graph = self ._make_graph ()
385381
386- # Build schema mapping if grouping is requested
387- schema_map = {}
388- if group_by_schema :
389- for full_name in self .nodes_to_show :
390- # Extract schema from full table name like `schema`.`table` or "schema"."table"
391- parts = full_name .replace ('"' , '`' ).split ('`' )
392- if len (parts ) >= 2 :
393- schema_name = parts [1 ] # schema is between first pair of backticks
394- # Find the class name for this full_name
395- class_name = lookup_class_name (full_name , self .context ) or full_name
396- schema_map [class_name ] = schema_name
382+ # Build schema mapping: class_name -> (schema_name, module_name)
383+ # Group by database schema, but label with Python module name when available
384+ schema_map = {} # class_name -> schema_name
385+ module_map = {} # schema_name -> module_name (for cluster labels)
386+
387+ for full_name in self .nodes_to_show :
388+ # Extract schema from full table name like `schema`.`table` or "schema"."table"
389+ parts = full_name .replace ('"' , '`' ).split ('`' )
390+ if len (parts ) >= 2 :
391+ schema_name = parts [1 ] # schema is between first pair of backticks
392+ class_name = lookup_class_name (full_name , self .context ) or full_name
393+ schema_map [class_name ] = schema_name
394+
395+ # Try to get Python module name for the cluster label
396+ if schema_name not in module_map :
397+ cls = self ._resolve_class (class_name )
398+ if cls is not None and hasattr (cls , "__module__" ):
399+ # Use the last part of the module path (e.g., "my_pipeline" from "package.my_pipeline")
400+ module_map [schema_name ] = cls .__module__ .split ("." )[- 1 ]
401+
402+ # Assign alias nodes (orange dots) to the same schema as their child table
403+ for node , data in graph .nodes (data = True ):
404+ if data .get ("node_type" ) is _AliasNode :
405+ # Find the child (successor) - the table that declares the renamed FK
406+ successors = list (graph .successors (node ))
407+ if successors and successors [0 ] in schema_map :
408+ schema_map [node ] = schema_map [successors [0 ]]
397409
398410 scale = 1.2 # scaling factor for fonts and boxes
399411 label_props = { # http://matplotlib.org/examples/color/named_colors.html
@@ -498,8 +510,8 @@ def make_dot(self, group_by_schema: bool = False):
498510 edge .set_arrowhead ("none" )
499511 edge .set_penwidth (0.75 if props ["multi" ] else 2 )
500512
501- # Group nodes into schema clusters if requested
502- if group_by_schema and schema_map :
513+ # Group nodes into schema clusters (always on)
514+ if schema_map :
503515 import pydot
504516
505517 # Group nodes by schema
@@ -513,10 +525,12 @@ def make_dot(self, group_by_schema: bool = False):
513525 schemas [schema_name ].append (node )
514526
515527 # Create clusters for each schema
528+ # Use Python module name as label when available, otherwise database schema name
516529 for schema_name , nodes in schemas .items ():
530+ label = module_map .get (schema_name , schema_name )
517531 cluster = pydot .Cluster (
518532 f"cluster_{ schema_name } " ,
519- label = schema_name ,
533+ label = label ,
520534 style = "dashed" ,
521535 color = "gray" ,
522536 fontcolor = "gray" ,
@@ -527,17 +541,17 @@ def make_dot(self, group_by_schema: bool = False):
527541
528542 return dot
529543
530- def make_svg (self , group_by_schema : bool = False ):
544+ def make_svg (self ):
531545 from IPython .display import SVG
532546
533- return SVG (self .make_dot (group_by_schema = group_by_schema ).create_svg ())
547+ return SVG (self .make_dot ().create_svg ())
534548
535- def make_png (self , group_by_schema : bool = False ):
536- return io .BytesIO (self .make_dot (group_by_schema = group_by_schema ).create_png ())
549+ def make_png (self ):
550+ return io .BytesIO (self .make_dot ().create_png ())
537551
538- def make_image (self , group_by_schema : bool = False ):
552+ def make_image (self ):
539553 if plot_active :
540- return plt .imread (self .make_png (group_by_schema = group_by_schema ))
554+ return plt .imread (self .make_png ())
541555 else :
542556 raise DataJointError ("pyplot was not imported" )
543557
@@ -556,20 +570,47 @@ def make_mermaid(self) -> str:
556570 Notes
557571 -----
558572 Layout direction is controlled via ``dj.config.display.diagram_direction``.
573+ Tables are grouped by schema using Mermaid subgraphs, with the Python
574+ module name shown as the group label when available.
559575
560576 Examples
561577 --------
562578 >>> print(dj.Diagram(schema).make_mermaid())
563579 flowchart TB
564- Mouse[Mouse]:::manual
565- Session[Session]:::manual
566- Neuron([Neuron]):::computed
580+ subgraph my_pipeline
581+ Mouse[Mouse]:::manual
582+ Session[Session]:::manual
583+ Neuron([Neuron]):::computed
584+ end
567585 Mouse --> Session
568586 Session --> Neuron
569587 """
570588 graph = self ._make_graph ()
571589 direction = config .display .diagram_direction
572590
591+ # Build schema mapping for grouping
592+ schema_map = {} # class_name -> schema_name
593+ module_map = {} # schema_name -> module_name (for subgraph labels)
594+
595+ for full_name in self .nodes_to_show :
596+ parts = full_name .replace ('"' , '`' ).split ('`' )
597+ if len (parts ) >= 2 :
598+ schema_name = parts [1 ]
599+ class_name = lookup_class_name (full_name , self .context ) or full_name
600+ schema_map [class_name ] = schema_name
601+
602+ if schema_name not in module_map :
603+ cls = self ._resolve_class (class_name )
604+ if cls is not None and hasattr (cls , "__module__" ):
605+ module_map [schema_name ] = cls .__module__ .split ("." )[- 1 ]
606+
607+ # Assign alias nodes to the same schema as their child table
608+ for node , data in graph .nodes (data = True ):
609+ if data .get ("node_type" ) is _AliasNode :
610+ successors = list (graph .successors (node ))
611+ if successors and successors [0 ] in schema_map :
612+ schema_map [node ] = schema_map [successors [0 ]]
613+
573614 lines = [f"flowchart { direction } " ]
574615
575616 # Define class styles matching Graphviz colors
@@ -601,15 +642,27 @@ def make_mermaid(self) -> str:
601642 None : "" ,
602643 }
603644
604- # Add nodes
645+ # Group nodes by schema into subgraphs
646+ schemas = {}
605647 for node , data in graph .nodes (data = True ):
606- tier = data .get ("node_type" )
607- left , right = shape_map .get (tier , ("[" , "]" ))
608- cls = tier_class .get (tier , "" )
609- # Mermaid node IDs can't have dots, replace with underscores
610- safe_id = node .replace ("." , "_" ).replace (" " , "_" )
611- class_suffix = f":::{ cls } " if cls else ""
612- lines .append (f" { safe_id } { left } { node } { right } { class_suffix } " )
648+ schema_name = schema_map .get (node )
649+ if schema_name :
650+ if schema_name not in schemas :
651+ schemas [schema_name ] = []
652+ schemas [schema_name ].append ((node , data ))
653+
654+ # Add nodes grouped by schema subgraphs
655+ for schema_name , nodes in schemas .items ():
656+ label = module_map .get (schema_name , schema_name )
657+ lines .append (f" subgraph { label } " )
658+ for node , data in nodes :
659+ tier = data .get ("node_type" )
660+ left , right = shape_map .get (tier , ("[" , "]" ))
661+ cls = tier_class .get (tier , "" )
662+ safe_id = node .replace ("." , "_" ).replace (" " , "_" )
663+ class_suffix = f":::{ cls } " if cls else ""
664+ lines .append (f" { safe_id } { left } { node } { right } { class_suffix } " )
665+ lines .append (" end" )
613666
614667 lines .append ("" )
615668
@@ -626,20 +679,15 @@ def make_mermaid(self) -> str:
626679 def _repr_svg_ (self ):
627680 return self .make_svg ()._repr_svg_ ()
628681
629- def draw (self , group_by_schema : bool = False ):
682+ def draw (self ):
630683 if plot_active :
631- plt .imshow (self .make_image (group_by_schema = group_by_schema ))
684+ plt .imshow (self .make_image ())
632685 plt .gca ().axis ("off" )
633686 plt .show ()
634687 else :
635688 raise DataJointError ("pyplot was not imported" )
636689
637- def save (
638- self ,
639- filename : str ,
640- format : str | None = None ,
641- group_by_schema : bool = False ,
642- ) -> None :
690+ def save (self , filename : str , format : str | None = None ) -> None :
643691 """
644692 Save diagram to file.
645693
@@ -650,9 +698,6 @@ def save(
650698 format : str, optional
651699 File format (``'png'``, ``'svg'``, or ``'mermaid'``).
652700 Inferred from extension if None.
653- group_by_schema : bool, optional
654- If True, group nodes into clusters by their database schema.
655- Default False. Only applies to png and svg formats.
656701
657702 Raises
658703 ------
@@ -662,6 +707,8 @@ def save(
662707 Notes
663708 -----
664709 Layout direction is controlled via ``dj.config.display.diagram_direction``.
710+ Tables are grouped by schema, with the Python module name shown as the
711+ group label when available.
665712 """
666713 if format is None :
667714 if filename .lower ().endswith (".png" ):
@@ -674,10 +721,10 @@ def save(
674721 raise DataJointError ("Could not infer format from filename. Specify format explicitly." )
675722 if format .lower () == "png" :
676723 with open (filename , "wb" ) as f :
677- f .write (self .make_png (group_by_schema = group_by_schema ).getbuffer ().tobytes ())
724+ f .write (self .make_png ().getbuffer ().tobytes ())
678725 elif format .lower () == "svg" :
679726 with open (filename , "w" ) as f :
680- f .write (self .make_svg (group_by_schema = group_by_schema ).data )
727+ f .write (self .make_svg ().data )
681728 elif format .lower () == "mermaid" :
682729 with open (filename , "w" ) as f :
683730 f .write (self .make_mermaid ())
0 commit comments