1616
1717from .dependencies import topo_sort
1818from .errors import DataJointError
19+ from .settings import config
1920from .table import Table , lookup_class_name
2021from .user_tables import Computed , Imported , Lookup , Manual , Part , _AliasNode , _get_tier
2122
@@ -90,6 +91,12 @@ class Diagram(nx.DiGraph):
9091 -----
9192 ``diagram + 1 - 1`` may differ from ``diagram - 1 + 1``.
9293 Only tables loaded in the connection are displayed.
94+
95+ Layout direction is controlled via ``dj.config.display.diagram_direction``
96+ (default ``"TB"``). Use ``dj.config.override()`` to change temporarily::
97+
98+ with dj.config.override(display_diagram_direction="LR"):
99+ dj.Diagram(schema).draw()
93100 """
94101
95102 def __init__ (self , source , context = None ) -> None :
@@ -286,18 +293,42 @@ def _make_graph(self) -> nx.DiGraph:
286293 gaps = set (nx .algorithms .boundary .node_boundary (self , self .nodes_to_show )).intersection (
287294 nx .algorithms .boundary .node_boundary (nx .DiGraph (self ).reverse (), self .nodes_to_show )
288295 )
289- nodes = self .nodes_to_show .union (a for a in gaps if a .isdigit )
296+ nodes = self .nodes_to_show .union (a for a in gaps if a .isdigit () )
290297 # construct subgraph and rename nodes to class names
291298 graph = nx .DiGraph (nx .DiGraph (self ).subgraph (nodes ))
292299 nx .set_node_attributes (graph , name = "node_type" , values = {n : _get_tier (n ) for n in graph })
293300 # relabel nodes to class names
294301 mapping = {node : lookup_class_name (node , self .context ) or node for node in graph .nodes ()}
295- new_names = [ mapping .values ()]
302+ new_names = list ( mapping .values ())
296303 if len (new_names ) > len (set (new_names )):
297304 raise DataJointError ("Some classes have identical names. The Diagram cannot be plotted." )
298305 nx .relabel_nodes (graph , mapping , copy = False )
299306 return graph
300307
308+ def _resolve_class (self , name : str ):
309+ """
310+ Safely resolve a table class from a dotted name without eval().
311+
312+ Parameters
313+ ----------
314+ name : str
315+ Dotted class name like "MyTable" or "Module.MyTable".
316+
317+ Returns
318+ -------
319+ type or None
320+ The table class if found, otherwise None.
321+ """
322+ parts = name .split ("." )
323+ obj = self .context .get (parts [0 ])
324+ for part in parts [1 :]:
325+ if obj is None :
326+ return None
327+ obj = getattr (obj , part , None )
328+ if obj is not None and isinstance (obj , type ) and issubclass (obj , Table ):
329+ return obj
330+ return None
331+
301332 @staticmethod
302333 def _encapsulate_edge_attributes (graph : nx .DiGraph ) -> None :
303334 """
@@ -330,9 +361,39 @@ def _encapsulate_node_names(graph: nx.DiGraph) -> None:
330361 copy = False ,
331362 )
332363
333- def make_dot (self ):
364+ def make_dot (self , group_by_schema : bool = False ):
365+ """
366+ Generate a pydot graph object.
367+
368+ Parameters
369+ ----------
370+ group_by_schema : bool, optional
371+ If True, group nodes into clusters by their database schema.
372+ Default False.
373+
374+ Returns
375+ -------
376+ pydot.Dot
377+ The graph object ready for rendering.
378+
379+ Notes
380+ -----
381+ Layout direction is controlled via ``dj.config.display.diagram_direction``.
382+ """
383+ direction = config .display .diagram_direction
334384 graph = self ._make_graph ()
335- graph .nodes ()
385+
386+ # Build schema mapping if grouping is requested
387+ schema_map = {}
388+ if group_by_schema :
389+ for full_name in self .nodes_to_show :
390+ # Extract schema from full table name like `schema`.`table` or "schema"."table"
391+ parts = full_name .replace ('"' , '`' ).split ('`' )
392+ if len (parts ) >= 2 :
393+ schema_name = parts [1 ] # schema is between first pair of backticks
394+ # Find the class name for this full_name
395+ class_name = lookup_class_name (full_name , self .context ) or full_name
396+ schema_map [class_name ] = schema_name
336397
337398 scale = 1.2 # scaling factor for fonts and boxes
338399 label_props = { # http://matplotlib.org/examples/color/named_colors.html
@@ -386,7 +447,7 @@ def make_dot(self):
386447 ),
387448 Part : dict (
388449 shape = "plaintext" ,
389- color = "#0000000 " ,
450+ color = "#00000000 " ,
390451 fontcolor = "black" ,
391452 fontsize = round (scale * 8 ),
392453 size = 0.1 * scale ,
@@ -398,6 +459,7 @@ def make_dot(self):
398459 self ._encapsulate_node_names (graph )
399460 self ._encapsulate_edge_attributes (graph )
400461 dot = nx .drawing .nx_pydot .to_pydot (graph )
462+ dot .set_rankdir (direction )
401463 for node in dot .get_nodes ():
402464 node .set_shape ("circle" )
403465 name = node .get_name ().strip ('"' )
@@ -409,9 +471,8 @@ def make_dot(self):
409471 node .set_fixedsize ("shape" if props ["fixed" ] else False )
410472 node .set_width (props ["size" ])
411473 node .set_height (props ["size" ])
412- if name .split ("." )[0 ] in self .context :
413- cls = eval (name , self .context )
414- assert issubclass (cls , Table )
474+ cls = self ._resolve_class (name )
475+ if cls is not None :
415476 description = cls ().describe (context = self .context ).split ("\n " )
416477 description = (
417478 ("-" * 30 if q .startswith ("---" ) else (q .replace ("->" , "→" ) if "->" in q else q .split (":" )[0 ]))
@@ -437,34 +498,148 @@ def make_dot(self):
437498 edge .set_arrowhead ("none" )
438499 edge .set_penwidth (0.75 if props ["multi" ] else 2 )
439500
501+ # Group nodes into schema clusters if requested
502+ if group_by_schema and schema_map :
503+ import pydot
504+
505+ # Group nodes by schema
506+ schemas = {}
507+ for node in list (dot .get_nodes ()):
508+ name = node .get_name ().strip ('"' )
509+ schema_name = schema_map .get (name )
510+ if schema_name :
511+ if schema_name not in schemas :
512+ schemas [schema_name ] = []
513+ schemas [schema_name ].append (node )
514+
515+ # Create clusters for each schema
516+ for schema_name , nodes in schemas .items ():
517+ cluster = pydot .Cluster (
518+ f"cluster_{ schema_name } " ,
519+ label = schema_name ,
520+ style = "dashed" ,
521+ color = "gray" ,
522+ fontcolor = "gray" ,
523+ )
524+ for node in nodes :
525+ cluster .add_node (node )
526+ dot .add_subgraph (cluster )
527+
440528 return dot
441529
442- def make_svg (self ):
530+ def make_svg (self , group_by_schema : bool = False ):
443531 from IPython .display import SVG
444532
445- return SVG (self .make_dot ().create_svg ())
533+ return SVG (self .make_dot (group_by_schema = group_by_schema ).create_svg ())
446534
447- def make_png (self ):
448- return io .BytesIO (self .make_dot ().create_png ())
535+ def make_png (self , group_by_schema : bool = False ):
536+ return io .BytesIO (self .make_dot (group_by_schema = group_by_schema ).create_png ())
449537
450- def make_image (self ):
538+ def make_image (self , group_by_schema : bool = False ):
451539 if plot_active :
452- return plt .imread (self .make_png ())
540+ return plt .imread (self .make_png (group_by_schema = group_by_schema ))
453541 else :
454542 raise DataJointError ("pyplot was not imported" )
455543
544+ def make_mermaid (self ) -> str :
545+ """
546+ Generate Mermaid diagram syntax.
547+
548+ Produces a flowchart in Mermaid syntax that can be rendered in
549+ Markdown documentation, GitHub, or https://mermaid.live.
550+
551+ Returns
552+ -------
553+ str
554+ Mermaid flowchart syntax.
555+
556+ Notes
557+ -----
558+ Layout direction is controlled via ``dj.config.display.diagram_direction``.
559+
560+ Examples
561+ --------
562+ >>> print(dj.Diagram(schema).make_mermaid())
563+ flowchart TB
564+ Mouse[Mouse]:::manual
565+ Session[Session]:::manual
566+ Neuron([Neuron]):::computed
567+ Mouse --> Session
568+ Session --> Neuron
569+ """
570+ graph = self ._make_graph ()
571+ direction = config .display .diagram_direction
572+
573+ lines = [f"flowchart { direction } " ]
574+
575+ # Define class styles matching Graphviz colors
576+ lines .append (" classDef manual fill:#90EE90,stroke:#006400" )
577+ lines .append (" classDef lookup fill:#D3D3D3,stroke:#696969" )
578+ lines .append (" classDef computed fill:#FFB6C1,stroke:#8B0000" )
579+ lines .append (" classDef imported fill:#ADD8E6,stroke:#00008B" )
580+ lines .append (" classDef part fill:#FFFFFF,stroke:#000000" )
581+ lines .append ("" )
582+
583+ # Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box
584+ shape_map = {
585+ Manual : ("[" , "]" ), # box
586+ Lookup : ("[" , "]" ), # box
587+ Computed : ("([" , "])" ), # stadium/pill
588+ Imported : ("([" , "])" ), # stadium/pill
589+ Part : ("[" , "]" ), # box
590+ _AliasNode : ("((" , "))" ), # circle
591+ None : ("((" , "))" ), # circle
592+ }
593+
594+ tier_class = {
595+ Manual : "manual" ,
596+ Lookup : "lookup" ,
597+ Computed : "computed" ,
598+ Imported : "imported" ,
599+ Part : "part" ,
600+ _AliasNode : "" ,
601+ None : "" ,
602+ }
603+
604+ # Add nodes
605+ for node , data in graph .nodes (data = True ):
606+ tier = data .get ("node_type" )
607+ left , right = shape_map .get (tier , ("[" , "]" ))
608+ cls = tier_class .get (tier , "" )
609+ # Mermaid node IDs can't have dots, replace with underscores
610+ safe_id = node .replace ("." , "_" ).replace (" " , "_" )
611+ class_suffix = f":::{ cls } " if cls else ""
612+ lines .append (f" { safe_id } { left } { node } { right } { class_suffix } " )
613+
614+ lines .append ("" )
615+
616+ # Add edges
617+ for src , dest , data in graph .edges (data = True ):
618+ safe_src = src .replace ("." , "_" ).replace (" " , "_" )
619+ safe_dest = dest .replace ("." , "_" ).replace (" " , "_" )
620+ # Solid arrow for primary FK, dotted for non-primary
621+ style = "-->" if data .get ("primary" ) else "-.->"
622+ lines .append (f" { safe_src } { style } { safe_dest } " )
623+
624+ return "\n " .join (lines )
625+
456626 def _repr_svg_ (self ):
457627 return self .make_svg ()._repr_svg_ ()
458628
459- def draw (self ):
629+ def draw (self , group_by_schema : bool = False ):
460630 if plot_active :
461- plt .imshow (self .make_image ())
631+ plt .imshow (self .make_image (group_by_schema = group_by_schema ))
462632 plt .gca ().axis ("off" )
463633 plt .show ()
464634 else :
465635 raise DataJointError ("pyplot was not imported" )
466636
467- def save (self , filename : str , format : str | None = None ) -> None :
637+ def save (
638+ self ,
639+ filename : str ,
640+ format : str | None = None ,
641+ group_by_schema : bool = False ,
642+ ) -> None :
468643 """
469644 Save diagram to file.
470645
@@ -473,24 +648,39 @@ def save(self, filename: str, format: str | None = None) -> None:
473648 filename : str
474649 Output filename.
475650 format : str, optional
476- File format (``'png'`` or ``'svg'``). Inferred from extension if None.
651+ File format (``'png'``, ``'svg'``, or ``'mermaid'``).
652+ Inferred from extension if None.
653+ group_by_schema : bool, optional
654+ If True, group nodes into clusters by their database schema.
655+ Default False. Only applies to png and svg formats.
477656
478657 Raises
479658 ------
480659 DataJointError
481660 If format is unsupported.
661+
662+ Notes
663+ -----
664+ Layout direction is controlled via ``dj.config.display.diagram_direction``.
482665 """
483666 if format is None :
484667 if filename .lower ().endswith (".png" ):
485668 format = "png"
486669 elif filename .lower ().endswith (".svg" ):
487670 format = "svg"
671+ elif filename .lower ().endswith ((".mmd" , ".mermaid" )):
672+ format = "mermaid"
673+ if format is None :
674+ raise DataJointError ("Could not infer format from filename. Specify format explicitly." )
488675 if format .lower () == "png" :
489676 with open (filename , "wb" ) as f :
490- f .write (self .make_png ().getbuffer ().tobytes ())
677+ f .write (self .make_png (group_by_schema = group_by_schema ).getbuffer ().tobytes ())
491678 elif format .lower () == "svg" :
492679 with open (filename , "w" ) as f :
493- f .write (self .make_svg ().data )
680+ f .write (self .make_svg (group_by_schema = group_by_schema ).data )
681+ elif format .lower () == "mermaid" :
682+ with open (filename , "w" ) as f :
683+ f .write (self .make_mermaid ())
494684 else :
495685 raise DataJointError ("Unsupported file format" )
496686
0 commit comments