Skip to content

Commit d5bdf51

Browse files
refactor: simplify collapse logic to use single _expanded_nodes set
Replace complex _explicit_nodes + _is_collapsed with simpler design: - Fresh diagrams: all nodes expanded - collapse(): clears _expanded_nodes - + operator: union of _expanded_nodes (expanded wins) Bump version to 2.1.0a7 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 9e87106 commit d5bdf51

File tree

2 files changed

+29
-55
lines changed

2 files changed

+29
-55
lines changed

src/datajoint/diagram.py

Lines changed: 28 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,7 @@ def __init__(self, source, context=None) -> None:
103103
if isinstance(source, Diagram):
104104
# copy constructor
105105
self.nodes_to_show = set(source.nodes_to_show)
106-
self._explicit_nodes = set(source._explicit_nodes)
107-
self._is_collapsed = source._is_collapsed
106+
self._expanded_nodes = set(source._expanded_nodes)
108107
self.context = source.context
109108
super().__init__(source)
110109
return
@@ -132,8 +131,6 @@ def __init__(self, source, context=None) -> None:
132131

133132
# Enumerate nodes from all the items in the list
134133
self.nodes_to_show = set()
135-
self._explicit_nodes = set() # nodes that should never be collapsed
136-
self._is_collapsed = False # whether this diagram's nodes should be collapsed when combined
137134
try:
138135
self.nodes_to_show.add(source.full_table_name)
139136
except AttributeError:
@@ -148,6 +145,8 @@ def __init__(self, source, context=None) -> None:
148145
# Handle both MySQL backticks and PostgreSQL double quotes
149146
if node.startswith("`%s`" % database) or node.startswith('"%s"' % database):
150147
self.nodes_to_show.add(node)
148+
# All nodes start as expanded
149+
self._expanded_nodes = set(self.nodes_to_show)
151150

152151
@classmethod
153152
def from_sequence(cls, sequence) -> "Diagram":
@@ -187,27 +186,30 @@ def is_part(part, master):
187186

188187
def collapse(self) -> "Diagram":
189188
"""
190-
Mark this diagram for collapsing when combined with other diagrams.
189+
Mark all nodes in this diagram as collapsed.
191190
192-
When a collapsed diagram is added to a non-collapsed diagram, its nodes
193-
are shown as a single collapsed node per schema, unless they also appear
194-
in the non-collapsed diagram (expanded wins).
191+
Collapsed nodes are shown as a single node per schema. When combined
192+
with other diagrams using ``+``, expanded nodes win: if a node is
193+
expanded in either operand, it remains expanded in the result.
195194
196195
Returns
197196
-------
198197
Diagram
199-
A copy of this diagram marked for collapsing.
198+
A copy of this diagram with all nodes collapsed.
200199
201200
Examples
202201
--------
203202
>>> # Show schema1 expanded, schema2 collapsed into single nodes
204203
>>> dj.Diagram(schema1) + dj.Diagram(schema2).collapse()
205204
206-
>>> # Explicitly expand one table from schema2
207-
>>> dj.Diagram(schema1) + dj.Diagram(TableFromSchema2) + dj.Diagram(schema2).collapse()
205+
>>> # Collapse all three schemas together
206+
>>> (dj.Diagram(schema1) + dj.Diagram(schema2) + dj.Diagram(schema3)).collapse()
207+
208+
>>> # Expand one table from collapsed schema
209+
>>> dj.Diagram(schema).collapse() + dj.Diagram(SingleTable)
208210
"""
209211
result = Diagram(self)
210-
result._is_collapsed = True
212+
result._expanded_nodes = set() # All nodes collapsed
211213
return result
212214

213215
def __add__(self, arg) -> "Diagram":
@@ -232,30 +234,12 @@ def __add__(self, arg) -> "Diagram":
232234
result.nodes_to_show.update(arg.nodes_to_show)
233235
# Merge contexts for class name lookups
234236
result.context = {**result.context, **arg.context}
235-
# Handle collapse: track which nodes should be explicit (expanded)
236-
# - Always preserve existing _explicit_nodes from both sides
237-
# - For a fresh (non-combined) non-collapsed diagram, add all its nodes to explicit
238-
# - A fresh diagram has empty _explicit_nodes and _is_collapsed=False
239-
# This ensures "expanded wins" and chained collapsed diagrams stay collapsed
240-
result._explicit_nodes = set()
241-
# Add self's explicit nodes
242-
result._explicit_nodes.update(self._explicit_nodes)
243-
# If self is a fresh non-collapsed diagram (not combined, not marked collapsed),
244-
# treat all its nodes as explicit
245-
if not self._is_collapsed and not self._explicit_nodes:
246-
result._explicit_nodes.update(self.nodes_to_show)
247-
# Add arg's explicit nodes
248-
result._explicit_nodes.update(arg._explicit_nodes)
249-
# If arg is a fresh non-collapsed diagram, treat all its nodes as explicit
250-
if not arg._is_collapsed and not arg._explicit_nodes:
251-
result._explicit_nodes.update(arg.nodes_to_show)
252-
# Result is "collapsed" if BOTH operands were collapsed (no explicit nodes added)
253-
# This allows chained collapsed diagrams to stay collapsed: A.collapse() + B.collapse() + C.collapse()
254-
result._is_collapsed = self._is_collapsed and arg._is_collapsed
237+
# Expanded wins: union of expanded nodes from both operands
238+
result._expanded_nodes = self._expanded_nodes | arg._expanded_nodes
255239
except AttributeError:
256240
try:
257241
result.nodes_to_show.add(arg.full_table_name)
258-
result._explicit_nodes.add(arg.full_table_name)
242+
result._expanded_nodes.add(arg.full_table_name)
259243
except AttributeError:
260244
for i in range(arg):
261245
new = nx.algorithms.boundary.node_boundary(result, result.nodes_to_show)
@@ -264,9 +248,8 @@ def __add__(self, arg) -> "Diagram":
264248
# add nodes referenced by aliased nodes
265249
new.update(nx.algorithms.boundary.node_boundary(result, (a for a in new if a.isdigit())))
266250
result.nodes_to_show.update(new)
267-
# Expanded nodes from + N expansion are explicit
268-
if not self._is_collapsed:
269-
result._explicit_nodes = result.nodes_to_show.copy()
251+
# New nodes from expansion are expanded
252+
result._expanded_nodes = result._expanded_nodes | result.nodes_to_show
270253
return result
271254

272255
def __sub__(self, arg) -> "Diagram":
@@ -369,7 +352,7 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]
369352
"""
370353
Apply collapse logic to the graph.
371354
372-
Nodes in nodes_to_show but not in _explicit_nodes are collapsed into
355+
Nodes in nodes_to_show but not in _expanded_nodes are collapsed into
373356
single schema nodes.
374357
375358
Parameters
@@ -384,19 +367,10 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]
384367
"""
385368
# Filter to valid nodes (those that exist in the underlying graph)
386369
valid_nodes = self.nodes_to_show.intersection(set(self.nodes()))
387-
valid_explicit = self._explicit_nodes.intersection(set(self.nodes()))
388-
389-
# Determine if collapse should be applied:
390-
# - If _explicit_nodes is empty AND _is_collapsed is False, this is a fresh
391-
# diagram that was never combined with collapsed diagrams → no collapse
392-
# - If _explicit_nodes is empty AND _is_collapsed is True, this is the result
393-
# of combining only collapsed diagrams → collapse all nodes
394-
# - If _explicit_nodes equals valid_nodes, all nodes are explicit → no collapse
395-
if not valid_explicit and not self._is_collapsed:
396-
# Fresh diagram, never combined with collapsed diagrams
397-
return graph, {}
398-
if valid_explicit == valid_nodes:
399-
# All nodes are explicit (expanded) - no collapse needed
370+
valid_expanded = self._expanded_nodes.intersection(set(self.nodes()))
371+
372+
# If all nodes are expanded, no collapse needed
373+
if valid_expanded >= valid_nodes:
400374
return graph, {}
401375

402376
# Map full_table_names to class_names
@@ -406,13 +380,13 @@ def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]
406380
}
407381
class_to_full = {v: k for k, v in full_to_class.items()}
408382

409-
# Identify explicit class names (should be expanded)
410-
explicit_class_names = {
411-
full_to_class.get(node, node) for node in valid_explicit
383+
# Identify expanded class names
384+
expanded_class_names = {
385+
full_to_class.get(node, node) for node in valid_expanded
412386
}
413387

414388
# Identify nodes to collapse (class names)
415-
nodes_to_collapse = set(graph.nodes()) - explicit_class_names
389+
nodes_to_collapse = set(graph.nodes()) - expanded_class_names
416390

417391
if not nodes_to_collapse:
418392
return graph, {}

src/datajoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# version bump auto managed by Github Actions:
22
# label_prs.yaml(prep), release.yaml(bump), post_release.yaml(edit)
33
# manually set this version will be eventually overwritten by the above actions
4-
__version__ = "2.1.0a6"
4+
__version__ = "2.1.0a7"

0 commit comments

Comments
 (0)