@@ -70,6 +70,9 @@ class Diagram(nx.DiGraph):
7070 context : dict, optional
7171 Namespace for resolving table class names. If None, uses caller's
7272 frame globals/locals.
73+ direction : str, optional
74+ Default layout direction: "TB" (top-to-bottom, default) or
75+ "LR" (left-to-right). Can be overridden in output methods.
7376
7477 Examples
7578 --------
@@ -92,13 +95,43 @@ class Diagram(nx.DiGraph):
9295 Only tables loaded in the connection are displayed.
9396 """
9497
95- def __init__ (self , source , context = None ) -> None :
98+ _VALID_DIRECTIONS = ("TB" , "LR" )
99+
100+ @staticmethod
101+ def _validate_direction (direction : str ) -> str :
102+ """Validate and normalize direction parameter."""
103+ if direction not in Diagram ._VALID_DIRECTIONS :
104+ raise ValueError (
105+ f"Invalid direction '{ direction } '. Must be one of: { Diagram ._VALID_DIRECTIONS } "
106+ )
107+ return direction
108+
109+ def set_direction (self , direction : str ) -> None :
110+ """
111+ Set the default layout direction.
112+
113+ Parameters
114+ ----------
115+ direction : str
116+ Layout direction: "TB" (top-to-bottom) or "LR" (left-to-right).
117+
118+ Examples
119+ --------
120+ >>> diag = dj.Diagram(Mouse) + 1 - 1
121+ >>> diag.set_direction("LR")
122+ >>> diag.draw()
123+ """
124+ self .direction = self ._validate_direction (direction )
125+
126+ def __init__ (self , source , context = None , direction : str = "TB" ) -> None :
96127 if isinstance (source , Diagram ):
97128 # copy constructor
98129 self .nodes_to_show = set (source .nodes_to_show )
99130 self .context = source .context
131+ self .direction = source .direction
100132 super ().__init__ (source )
101133 return
134+ self .direction = self ._validate_direction (direction )
102135
103136 # get the caller's context
104137 if context is None :
@@ -286,18 +319,42 @@ def _make_graph(self) -> nx.DiGraph:
286319 gaps = set (nx .algorithms .boundary .node_boundary (self , self .nodes_to_show )).intersection (
287320 nx .algorithms .boundary .node_boundary (nx .DiGraph (self ).reverse (), self .nodes_to_show )
288321 )
289- nodes = self .nodes_to_show .union (a for a in gaps if a .isdigit )
322+ nodes = self .nodes_to_show .union (a for a in gaps if a .isdigit () )
290323 # construct subgraph and rename nodes to class names
291324 graph = nx .DiGraph (nx .DiGraph (self ).subgraph (nodes ))
292325 nx .set_node_attributes (graph , name = "node_type" , values = {n : _get_tier (n ) for n in graph })
293326 # relabel nodes to class names
294327 mapping = {node : lookup_class_name (node , self .context ) or node for node in graph .nodes ()}
295- new_names = [ mapping .values ()]
328+ new_names = list ( mapping .values ())
296329 if len (new_names ) > len (set (new_names )):
297330 raise DataJointError ("Some classes have identical names. The Diagram cannot be plotted." )
298331 nx .relabel_nodes (graph , mapping , copy = False )
299332 return graph
300333
334+ def _resolve_class (self , name : str ):
335+ """
336+ Safely resolve a table class from a dotted name without eval().
337+
338+ Parameters
339+ ----------
340+ name : str
341+ Dotted class name like "MyTable" or "Module.MyTable".
342+
343+ Returns
344+ -------
345+ type or None
346+ The table class if found, otherwise None.
347+ """
348+ parts = name .split ("." )
349+ obj = self .context .get (parts [0 ])
350+ for part in parts [1 :]:
351+ if obj is None :
352+ return None
353+ obj = getattr (obj , part , None )
354+ if obj is not None and isinstance (obj , type ) and issubclass (obj , Table ):
355+ return obj
356+ return None
357+
301358 @staticmethod
302359 def _encapsulate_edge_attributes (graph : nx .DiGraph ) -> None :
303360 """
@@ -330,9 +387,38 @@ def _encapsulate_node_names(graph: nx.DiGraph) -> None:
330387 copy = False ,
331388 )
332389
333- def make_dot (self ):
390+ def make_dot (self , direction : str | None = None , group_by_schema : bool = False ):
391+ """
392+ Generate a pydot graph object.
393+
394+ Parameters
395+ ----------
396+ direction : str, optional
397+ Layout direction: "TB" (top-to-bottom) or "LR" (left-to-right).
398+ Defaults to instance direction.
399+ group_by_schema : bool, optional
400+ If True, group nodes into clusters by their database schema.
401+ Default False.
402+
403+ Returns
404+ -------
405+ pydot.Dot
406+ The graph object ready for rendering.
407+ """
408+ direction = self ._validate_direction (direction ) if direction else self .direction
334409 graph = self ._make_graph ()
335- graph .nodes ()
410+
411+ # Build schema mapping if grouping is requested
412+ schema_map = {}
413+ if group_by_schema :
414+ for full_name in self .nodes_to_show :
415+ # Extract schema from full table name like `schema`.`table` or "schema"."table"
416+ parts = full_name .replace ('"' , '`' ).split ('`' )
417+ if len (parts ) >= 2 :
418+ schema_name = parts [1 ] # schema is between first pair of backticks
419+ # Find the class name for this full_name
420+ class_name = lookup_class_name (full_name , self .context ) or full_name
421+ schema_map [class_name ] = schema_name
336422
337423 scale = 1.2 # scaling factor for fonts and boxes
338424 label_props = { # http://matplotlib.org/examples/color/named_colors.html
@@ -386,7 +472,7 @@ def make_dot(self):
386472 ),
387473 Part : dict (
388474 shape = "plaintext" ,
389- color = "#0000000 " ,
475+ color = "#00000000 " ,
390476 fontcolor = "black" ,
391477 fontsize = round (scale * 8 ),
392478 size = 0.1 * scale ,
@@ -398,6 +484,7 @@ def make_dot(self):
398484 self ._encapsulate_node_names (graph )
399485 self ._encapsulate_edge_attributes (graph )
400486 dot = nx .drawing .nx_pydot .to_pydot (graph )
487+ dot .set_rankdir (direction )
401488 for node in dot .get_nodes ():
402489 node .set_shape ("circle" )
403490 name = node .get_name ().strip ('"' )
@@ -409,9 +496,8 @@ def make_dot(self):
409496 node .set_fixedsize ("shape" if props ["fixed" ] else False )
410497 node .set_width (props ["size" ])
411498 node .set_height (props ["size" ])
412- if name .split ("." )[0 ] in self .context :
413- cls = eval (name , self .context )
414- assert issubclass (cls , Table )
499+ cls = self ._resolve_class (name )
500+ if cls is not None :
415501 description = cls ().describe (context = self .context ).split ("\n " )
416502 description = (
417503 ("-" * 30 if q .startswith ("---" ) else (q .replace ("->" , "→" ) if "->" in q else q .split (":" )[0 ]))
@@ -437,34 +523,151 @@ def make_dot(self):
437523 edge .set_arrowhead ("none" )
438524 edge .set_penwidth (0.75 if props ["multi" ] else 2 )
439525
526+ # Group nodes into schema clusters if requested
527+ if group_by_schema and schema_map :
528+ import pydot
529+
530+ # Group nodes by schema
531+ schemas = {}
532+ for node in list (dot .get_nodes ()):
533+ name = node .get_name ().strip ('"' )
534+ schema_name = schema_map .get (name )
535+ if schema_name :
536+ if schema_name not in schemas :
537+ schemas [schema_name ] = []
538+ schemas [schema_name ].append (node )
539+
540+ # Create clusters for each schema
541+ for schema_name , nodes in schemas .items ():
542+ cluster = pydot .Cluster (
543+ f"cluster_{ schema_name } " ,
544+ label = schema_name ,
545+ style = "dashed" ,
546+ color = "gray" ,
547+ fontcolor = "gray" ,
548+ )
549+ for node in nodes :
550+ cluster .add_node (node )
551+ dot .add_subgraph (cluster )
552+
440553 return dot
441554
442- def make_svg (self ):
555+ def make_svg (self , direction : str | None = None , group_by_schema : bool = False ):
443556 from IPython .display import SVG
444557
445- return SVG (self .make_dot ().create_svg ())
558+ return SVG (self .make_dot (direction = direction , group_by_schema = group_by_schema ).create_svg ())
446559
447- def make_png (self ):
448- return io .BytesIO (self .make_dot ().create_png ())
560+ def make_png (self , direction : str | None = None , group_by_schema : bool = False ):
561+ return io .BytesIO (self .make_dot (direction = direction , group_by_schema = group_by_schema ).create_png ())
449562
450- def make_image (self ):
563+ def make_image (self , direction : str | None = None , group_by_schema : bool = False ):
451564 if plot_active :
452- return plt .imread (self .make_png ())
565+ return plt .imread (self .make_png (direction = direction , group_by_schema = group_by_schema ))
453566 else :
454567 raise DataJointError ("pyplot was not imported" )
455568
569+ def make_mermaid (self , direction : str | None = None ) -> str :
570+ """
571+ Generate Mermaid diagram syntax.
572+
573+ Produces a flowchart in Mermaid syntax that can be rendered in
574+ Markdown documentation, GitHub, or https://mermaid.live.
575+
576+ Parameters
577+ ----------
578+ direction : str, optional
579+ Layout direction: "TB" (top-to-bottom) or "LR" (left-to-right).
580+ Defaults to instance direction.
581+
582+ Returns
583+ -------
584+ str
585+ Mermaid flowchart syntax.
586+
587+ Examples
588+ --------
589+ >>> print(dj.Diagram(schema).make_mermaid())
590+ flowchart TB
591+ Mouse[Mouse]:::manual
592+ Session[Session]:::manual
593+ Neuron([Neuron]):::computed
594+ Mouse --> Session
595+ Session --> Neuron
596+ """
597+ graph = self ._make_graph ()
598+ direction = self ._validate_direction (direction ) if direction else self .direction
599+
600+ lines = [f"flowchart { direction } " ]
601+
602+ # Define class styles matching Graphviz colors
603+ lines .append (" classDef manual fill:#90EE90,stroke:#006400" )
604+ lines .append (" classDef lookup fill:#D3D3D3,stroke:#696969" )
605+ lines .append (" classDef computed fill:#FFB6C1,stroke:#8B0000" )
606+ lines .append (" classDef imported fill:#ADD8E6,stroke:#00008B" )
607+ lines .append (" classDef part fill:#FFFFFF,stroke:#000000" )
608+ lines .append ("" )
609+
610+ # Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box
611+ shape_map = {
612+ Manual : ("[" , "]" ), # box
613+ Lookup : ("[" , "]" ), # box
614+ Computed : ("([" , "])" ), # stadium/pill
615+ Imported : ("([" , "])" ), # stadium/pill
616+ Part : ("[" , "]" ), # box
617+ _AliasNode : ("((" , "))" ), # circle
618+ None : ("((" , "))" ), # circle
619+ }
620+
621+ tier_class = {
622+ Manual : "manual" ,
623+ Lookup : "lookup" ,
624+ Computed : "computed" ,
625+ Imported : "imported" ,
626+ Part : "part" ,
627+ _AliasNode : "" ,
628+ None : "" ,
629+ }
630+
631+ # Add nodes
632+ for node , data in graph .nodes (data = True ):
633+ tier = data .get ("node_type" )
634+ left , right = shape_map .get (tier , ("[" , "]" ))
635+ cls = tier_class .get (tier , "" )
636+ # Mermaid node IDs can't have dots, replace with underscores
637+ safe_id = node .replace ("." , "_" ).replace (" " , "_" )
638+ class_suffix = f":::{ cls } " if cls else ""
639+ lines .append (f" { safe_id } { left } { node } { right } { class_suffix } " )
640+
641+ lines .append ("" )
642+
643+ # Add edges
644+ for src , dest , data in graph .edges (data = True ):
645+ safe_src = src .replace ("." , "_" ).replace (" " , "_" )
646+ safe_dest = dest .replace ("." , "_" ).replace (" " , "_" )
647+ # Solid arrow for primary FK, dotted for non-primary
648+ style = "-->" if data .get ("primary" ) else "-.->"
649+ lines .append (f" { safe_src } { style } { safe_dest } " )
650+
651+ return "\n " .join (lines )
652+
456653 def _repr_svg_ (self ):
457654 return self .make_svg ()._repr_svg_ ()
458655
459- def draw (self ):
656+ def draw (self , direction : str | None = None , group_by_schema : bool = False ):
460657 if plot_active :
461- plt .imshow (self .make_image ())
658+ plt .imshow (self .make_image (direction = direction , group_by_schema = group_by_schema ))
462659 plt .gca ().axis ("off" )
463660 plt .show ()
464661 else :
465662 raise DataJointError ("pyplot was not imported" )
466663
467- def save (self , filename : str , format : str | None = None ) -> None :
664+ def save (
665+ self ,
666+ filename : str ,
667+ format : str | None = None ,
668+ direction : str | None = None ,
669+ group_by_schema : bool = False ,
670+ ) -> None :
468671 """
469672 Save diagram to file.
470673
@@ -473,7 +676,14 @@ def save(self, filename: str, format: str | None = None) -> None:
473676 filename : str
474677 Output filename.
475678 format : str, optional
476- File format (``'png'`` or ``'svg'``). Inferred from extension if None.
679+ File format (``'png'``, ``'svg'``, or ``'mermaid'``).
680+ Inferred from extension if None.
681+ direction : str, optional
682+ Layout direction: "TB" (top-to-bottom) or "LR" (left-to-right).
683+ Defaults to instance direction.
684+ group_by_schema : bool, optional
685+ If True, group nodes into clusters by their database schema.
686+ Default False. Only applies to png and svg formats.
477687
478688 Raises
479689 ------
@@ -485,12 +695,19 @@ def save(self, filename: str, format: str | None = None) -> None:
485695 format = "png"
486696 elif filename .lower ().endswith (".svg" ):
487697 format = "svg"
698+ elif filename .lower ().endswith ((".mmd" , ".mermaid" )):
699+ format = "mermaid"
700+ if format is None :
701+ raise DataJointError ("Could not infer format from filename. Specify format explicitly." )
488702 if format .lower () == "png" :
489703 with open (filename , "wb" ) as f :
490- f .write (self .make_png ().getbuffer ().tobytes ())
704+ f .write (self .make_png (direction = direction , group_by_schema = group_by_schema ).getbuffer ().tobytes ())
491705 elif format .lower () == "svg" :
492706 with open (filename , "w" ) as f :
493- f .write (self .make_svg ().data )
707+ f .write (self .make_svg (direction = direction , group_by_schema = group_by_schema ).data )
708+ elif format .lower () == "mermaid" :
709+ with open (filename , "w" ) as f :
710+ f .write (self .make_mermaid (direction = direction ))
494711 else :
495712 raise DataJointError ("Unsupported file format" )
496713
0 commit comments