@@ -70,6 +70,9 @@ class Diagram(nx.DiGraph):
7070 context : dict, optional
7171 Namespace for resolving table class names. If None, uses caller's
7272 frame globals/locals.
73+ direction : str, optional
74+ Layout direction: "TB" (top-to-bottom, default), "LR" (left-to-right),
75+ "BT" (bottom-to-top), or "RL" (right-to-left).
7376
7477 Examples
7578 --------
@@ -92,13 +95,15 @@ class Diagram(nx.DiGraph):
9295 Only tables loaded in the connection are displayed.
9396 """
9497
95- def __init__ (self , source , context = None ) -> None :
98+ def __init__ (self , source , context = None , direction : str = "TB" ) -> None :
9699 if isinstance (source , Diagram ):
97100 # copy constructor
98101 self .nodes_to_show = set (source .nodes_to_show )
99102 self .context = source .context
103+ self .direction = source .direction
100104 super ().__init__ (source )
101105 return
106+ self .direction = direction
102107
103108 # get the caller's context
104109 if context is None :
@@ -286,18 +291,42 @@ def _make_graph(self) -> nx.DiGraph:
286291 gaps = set (nx .algorithms .boundary .node_boundary (self , self .nodes_to_show )).intersection (
287292 nx .algorithms .boundary .node_boundary (nx .DiGraph (self ).reverse (), self .nodes_to_show )
288293 )
289- nodes = self .nodes_to_show .union (a for a in gaps if a .isdigit )
294+ nodes = self .nodes_to_show .union (a for a in gaps if a .isdigit () )
290295 # construct subgraph and rename nodes to class names
291296 graph = nx .DiGraph (nx .DiGraph (self ).subgraph (nodes ))
292297 nx .set_node_attributes (graph , name = "node_type" , values = {n : _get_tier (n ) for n in graph })
293298 # relabel nodes to class names
294299 mapping = {node : lookup_class_name (node , self .context ) or node for node in graph .nodes ()}
295- new_names = [ mapping .values ()]
300+ new_names = list ( mapping .values ())
296301 if len (new_names ) > len (set (new_names )):
297302 raise DataJointError ("Some classes have identical names. The Diagram cannot be plotted." )
298303 nx .relabel_nodes (graph , mapping , copy = False )
299304 return graph
300305
306+ def _resolve_class (self , name : str ):
307+ """
308+ Safely resolve a table class from a dotted name without eval().
309+
310+ Parameters
311+ ----------
312+ name : str
313+ Dotted class name like "MyTable" or "Module.MyTable".
314+
315+ Returns
316+ -------
317+ type or None
318+ The table class if found, otherwise None.
319+ """
320+ parts = name .split ("." )
321+ obj = self .context .get (parts [0 ])
322+ for part in parts [1 :]:
323+ if obj is None :
324+ return None
325+ obj = getattr (obj , part , None )
326+ if obj is not None and isinstance (obj , type ) and issubclass (obj , Table ):
327+ return obj
328+ return None
329+
301330 @staticmethod
302331 def _encapsulate_edge_attributes (graph : nx .DiGraph ) -> None :
303332 """
@@ -330,9 +359,34 @@ def _encapsulate_node_names(graph: nx.DiGraph) -> None:
330359 copy = False ,
331360 )
332361
333- def make_dot (self ):
362+ def make_dot (self , group_by_schema : bool = False ):
363+ """
364+ Generate a pydot graph object.
365+
366+ Parameters
367+ ----------
368+ group_by_schema : bool, optional
369+ If True, group nodes into clusters by their database schema.
370+ Default False.
371+
372+ Returns
373+ -------
374+ pydot.Dot
375+ The graph object ready for rendering.
376+ """
334377 graph = self ._make_graph ()
335- graph .nodes ()
378+
379+ # Build schema mapping if grouping is requested
380+ schema_map = {}
381+ if group_by_schema :
382+ for full_name in self .nodes_to_show :
383+ # Extract schema from full table name like `schema`.`table` or "schema"."table"
384+ parts = full_name .replace ('"' , '`' ).split ('`' )
385+ if len (parts ) >= 2 :
386+ schema_name = parts [1 ] # schema is between first pair of backticks
387+ # Find the class name for this full_name
388+ class_name = lookup_class_name (full_name , self .context ) or full_name
389+ schema_map [class_name ] = schema_name
336390
337391 scale = 1.2 # scaling factor for fonts and boxes
338392 label_props = { # http://matplotlib.org/examples/color/named_colors.html
@@ -386,7 +440,7 @@ def make_dot(self):
386440 ),
387441 Part : dict (
388442 shape = "plaintext" ,
389- color = "#0000000 " ,
443+ color = "#00000000 " ,
390444 fontcolor = "black" ,
391445 fontsize = round (scale * 8 ),
392446 size = 0.1 * scale ,
@@ -398,6 +452,7 @@ def make_dot(self):
398452 self ._encapsulate_node_names (graph )
399453 self ._encapsulate_edge_attributes (graph )
400454 dot = nx .drawing .nx_pydot .to_pydot (graph )
455+ dot .set_rankdir (self .direction )
401456 for node in dot .get_nodes ():
402457 node .set_shape ("circle" )
403458 name = node .get_name ().strip ('"' )
@@ -409,9 +464,8 @@ def make_dot(self):
409464 node .set_fixedsize ("shape" if props ["fixed" ] else False )
410465 node .set_width (props ["size" ])
411466 node .set_height (props ["size" ])
412- if name .split ("." )[0 ] in self .context :
413- cls = eval (name , self .context )
414- assert issubclass (cls , Table )
467+ cls = self ._resolve_class (name )
468+ if cls is not None :
415469 description = cls ().describe (context = self .context ).split ("\n " )
416470 description = (
417471 ("-" * 30 if q .startswith ("---" ) else (q .replace ("->" , "→" ) if "->" in q else q .split (":" )[0 ]))
@@ -437,34 +491,150 @@ def make_dot(self):
437491 edge .set_arrowhead ("none" )
438492 edge .set_penwidth (0.75 if props ["multi" ] else 2 )
439493
494+ # Group nodes into schema clusters if requested
495+ if group_by_schema and schema_map :
496+ import pydot
497+
498+ # Group nodes by schema
499+ schemas = {}
500+ for node in list (dot .get_nodes ()):
501+ name = node .get_name ().strip ('"' )
502+ schema_name = schema_map .get (name )
503+ if schema_name :
504+ if schema_name not in schemas :
505+ schemas [schema_name ] = []
506+ schemas [schema_name ].append (node )
507+
508+ # Create clusters for each schema
509+ for schema_name , nodes in schemas .items ():
510+ cluster = pydot .Cluster (
511+ f"cluster_{ schema_name } " ,
512+ label = schema_name ,
513+ style = "dashed" ,
514+ color = "gray" ,
515+ fontcolor = "gray" ,
516+ )
517+ for node in nodes :
518+ cluster .add_node (node )
519+ dot .add_subgraph (cluster )
520+
440521 return dot
441522
442- def make_svg (self ):
523+ def make_svg (self , group_by_schema : bool = False ):
443524 from IPython .display import SVG
444525
445- return SVG (self .make_dot ().create_svg ())
526+ return SVG (self .make_dot (group_by_schema = group_by_schema ).create_svg ())
446527
447- def make_png (self ):
448- return io .BytesIO (self .make_dot ().create_png ())
528+ def make_png (self , group_by_schema : bool = False ):
529+ return io .BytesIO (self .make_dot (group_by_schema = group_by_schema ).create_png ())
449530
450- def make_image (self ):
531+ def make_image (self , group_by_schema : bool = False ):
451532 if plot_active :
452- return plt .imread (self .make_png ())
533+ return plt .imread (self .make_png (group_by_schema = group_by_schema ))
453534 else :
454535 raise DataJointError ("pyplot was not imported" )
455536
537+ def make_mermaid (self , direction : str | None = None ) -> str :
538+ """
539+ Generate Mermaid diagram syntax.
540+
541+ Produces a flowchart in Mermaid syntax that can be rendered in
542+ Markdown documentation, GitHub, or https://mermaid.live.
543+
544+ Parameters
545+ ----------
546+ direction : str, optional
547+ Override layout direction for this render. Uses instance direction
548+ if not specified.
549+
550+ Returns
551+ -------
552+ str
553+ Mermaid flowchart syntax.
554+
555+ Examples
556+ --------
557+ >>> print(dj.Diagram(schema).make_mermaid())
558+ flowchart TB
559+ Mouse[Mouse]:::manual
560+ Session[Session]:::manual
561+ Neuron([Neuron]):::computed
562+ Mouse --> Session
563+ Session --> Neuron
564+ """
565+ graph = self ._make_graph ()
566+ direction = direction or self .direction
567+
568+ lines = [f"flowchart { direction } " ]
569+
570+ # Define class styles matching Graphviz colors
571+ lines .append (" classDef manual fill:#90EE90,stroke:#006400" )
572+ lines .append (" classDef lookup fill:#D3D3D3,stroke:#696969" )
573+ lines .append (" classDef computed fill:#FFB6C1,stroke:#8B0000" )
574+ lines .append (" classDef imported fill:#ADD8E6,stroke:#00008B" )
575+ lines .append (" classDef part fill:#FFFFFF,stroke:#000000" )
576+ lines .append ("" )
577+
578+ # Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box
579+ shape_map = {
580+ Manual : ("[" , "]" ), # box
581+ Lookup : ("[" , "]" ), # box
582+ Computed : ("([" , "])" ), # stadium/pill
583+ Imported : ("([" , "])" ), # stadium/pill
584+ Part : ("[" , "]" ), # box
585+ _AliasNode : ("((" , "))" ), # circle
586+ None : ("((" , "))" ), # circle
587+ }
588+
589+ tier_class = {
590+ Manual : "manual" ,
591+ Lookup : "lookup" ,
592+ Computed : "computed" ,
593+ Imported : "imported" ,
594+ Part : "part" ,
595+ _AliasNode : "" ,
596+ None : "" ,
597+ }
598+
599+ # Add nodes
600+ for node , data in graph .nodes (data = True ):
601+ tier = data .get ("node_type" )
602+ left , right = shape_map .get (tier , ("[" , "]" ))
603+ cls = tier_class .get (tier , "" )
604+ # Mermaid node IDs can't have dots, replace with underscores
605+ safe_id = node .replace ("." , "_" ).replace (" " , "_" )
606+ class_suffix = f":::{ cls } " if cls else ""
607+ lines .append (f" { safe_id } { left } { node } { right } { class_suffix } " )
608+
609+ lines .append ("" )
610+
611+ # Add edges
612+ for src , dest , data in graph .edges (data = True ):
613+ safe_src = src .replace ("." , "_" ).replace (" " , "_" )
614+ safe_dest = dest .replace ("." , "_" ).replace (" " , "_" )
615+ # Solid arrow for primary FK, dotted for non-primary
616+ style = "-->" if data .get ("primary" ) else "-.->"
617+ lines .append (f" { safe_src } { style } { safe_dest } " )
618+
619+ return "\n " .join (lines )
620+
456621 def _repr_svg_ (self ):
457622 return self .make_svg ()._repr_svg_ ()
458623
459- def draw (self ):
624+ def draw (self , group_by_schema : bool = False ):
460625 if plot_active :
461- plt .imshow (self .make_image ())
626+ plt .imshow (self .make_image (group_by_schema = group_by_schema ))
462627 plt .gca ().axis ("off" )
463628 plt .show ()
464629 else :
465630 raise DataJointError ("pyplot was not imported" )
466631
467- def save (self , filename : str , format : str | None = None ) -> None :
632+ def save (
633+ self ,
634+ filename : str ,
635+ format : str | None = None ,
636+ group_by_schema : bool = False ,
637+ ) -> None :
468638 """
469639 Save diagram to file.
470640
@@ -473,7 +643,11 @@ def save(self, filename: str, format: str | None = None) -> None:
473643 filename : str
474644 Output filename.
475645 format : str, optional
476- File format (``'png'`` or ``'svg'``). Inferred from extension if None.
646+ File format (``'png'``, ``'svg'``, or ``'mermaid'``).
647+ Inferred from extension if None.
648+ group_by_schema : bool, optional
649+ If True, group nodes into clusters by their database schema.
650+ Default False. Only applies to png and svg formats.
477651
478652 Raises
479653 ------
@@ -485,12 +659,19 @@ def save(self, filename: str, format: str | None = None) -> None:
485659 format = "png"
486660 elif filename .lower ().endswith (".svg" ):
487661 format = "svg"
662+ elif filename .lower ().endswith ((".mmd" , ".mermaid" )):
663+ format = "mermaid"
664+ if format is None :
665+ raise DataJointError ("Could not infer format from filename. Specify format explicitly." )
488666 if format .lower () == "png" :
489667 with open (filename , "wb" ) as f :
490- f .write (self .make_png ().getbuffer ().tobytes ())
668+ f .write (self .make_png (group_by_schema = group_by_schema ).getbuffer ().tobytes ())
491669 elif format .lower () == "svg" :
492670 with open (filename , "w" ) as f :
493- f .write (self .make_svg ().data )
671+ f .write (self .make_svg (group_by_schema = group_by_schema ).data )
672+ elif format .lower () == "mermaid" :
673+ with open (filename , "w" ) as f :
674+ f .write (self .make_mermaid ())
494675 else :
495676 raise DataJointError ("Unsupported file format" )
496677
0 commit comments