Skip to content

Commit 0c593c8

Browse files
committed
Implement semantic matching for joins
This implements semantic matching for DataJoint 2.0 joins as specified in docs/src/design/semantic-matching-spec.md. Key changes: 1. Lineage tracking: - Add `lineage` field to Attribute class (heading.py) - Create lineage.py module for ~lineage table management - Populate lineage at table declaration time - Clean up lineage entries when tables are dropped - Load lineage from database when fetching headings 2. Semantic matching in joins: - Update assert_join_compatibility() to check for non-homologous namesakes - Update join() to only match on homologous namesakes (same name AND lineage) - Lineage is preserved through projections and renames 3. API changes: - Remove @ operator (raises error directing to .join(semantic_check=False)) - dj.U * table raises deprecation error (use dj.U & table instead) - dj.U - table raises error (infinite set) - dj.U is always compatible (contains all possible lineages) 4. Tests: - Add comprehensive tests for lineage tracking - Test homologous and non-homologous namesake handling - Test deprecated operator errors - Test dj.U operations with semantic matching
1 parent c597c52 commit 0c593c8

File tree

6 files changed

+636
-37
lines changed

6 files changed

+636
-37
lines changed

src/datajoint/condition.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,13 @@ def __init__(self, restriction):
9797

9898
def assert_join_compatibility(expr1, expr2):
9999
"""
100-
Determine if expressions expr1 and expr2 are join-compatible. To be join-compatible,
101-
the matching attributes in the two expressions must be in the primary key of one or the
102-
other expression.
103-
Raises an exception if not compatible.
100+
Check semantic compatibility of two expressions for joining.
101+
102+
Uses semantic matching: attributes are only matched when they share both
103+
the same name AND the same lineage (origin).
104+
105+
Raises DataJointError if non-homologous namesakes are detected (same name
106+
but different lineage).
104107
105108
:param expr1: A QueryExpression object
106109
:param expr2: A QueryExpression object
@@ -110,14 +113,25 @@ def assert_join_compatibility(expr1, expr2):
110113
for rel in (expr1, expr2):
111114
if not isinstance(rel, (U, QueryExpression)):
112115
raise DataJointError("Object %r is not a QueryExpression and cannot be joined." % rel)
113-
if not isinstance(expr1, U) and not isinstance(expr2, U): # dj.U is always compatible
114-
try:
116+
117+
# dj.U is always compatible - it contains all possible lineages
118+
if isinstance(expr1, U) or isinstance(expr2, U):
119+
return
120+
121+
# Find namesake attributes (same name in both expressions)
122+
namesakes = set(expr1.heading.names) & set(expr2.heading.names)
123+
124+
for name in namesakes:
125+
lineage1 = expr1.heading[name].lineage
126+
lineage2 = expr2.heading[name].lineage
127+
128+
# Non-homologous namesakes: same name, different lineage
129+
if lineage1 != lineage2:
115130
raise DataJointError(
116-
"Cannot join query expressions on dependent attribute `%s`"
117-
% next(r for r in set(expr1.heading.secondary_attributes).intersection(expr2.heading.secondary_attributes))
131+
f"Cannot join on attribute `{name}`: different lineages "
132+
f"({lineage1} vs {lineage2}). "
133+
f"Use .proj() to rename one of the attributes."
118134
)
119-
except StopIteration:
120-
pass # all ok
121135

122136

123137
def make_condition(query_expression, condition, columns):

src/datajoint/expression.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -275,30 +275,45 @@ def __mul__(self, other):
275275

276276
def __matmul__(self, other):
277277
"""
278-
Permissive join of query expressions `self` and `other` ignoring compatibility check
279-
e.g. ``q1 @ q2``.
278+
The @ operator has been removed in DataJoint 2.0.
279+
Use .join(other, semantic_check=False) for permissive joins.
280280
"""
281-
if inspect.isclass(other) and issubclass(other, QueryExpression):
282-
other = other() # instantiate
283-
return self.join(other, semantic_check=False)
281+
raise DataJointError(
282+
"The @ operator has been removed in DataJoint 2.0. "
283+
"Use .join(other, semantic_check=False) for permissive joins."
284+
)
284285

285286
def join(self, other, semantic_check=True, left=False):
286287
"""
287-
create the joined QueryExpression.
288-
a * b is short for A.join(B)
289-
a @ b is short for A.join(B, semantic_check=False)
290-
Additionally, left=True will retain the rows of self, effectively performing a left join.
288+
Create the joined QueryExpression.
289+
290+
Uses semantic matching: only attributes with the same name AND the same
291+
lineage (homologous namesakes) are used for joining.
292+
293+
:param other: QueryExpression to join with
294+
:param semantic_check: If True (default), raise error on non-homologous namesakes.
295+
If False, bypass semantic check (use for legacy compatibility).
296+
:param left: If True, perform a left join retaining all rows from self.
297+
298+
Examples:
299+
a * b is short for a.join(b)
300+
a.join(b, semantic_check=False) for permissive joins
291301
"""
292-
# trigger subqueries if joining on renamed attributes
302+
# Handle U objects: redirect to U's restriction operation
293303
if isinstance(other, U):
294-
return other * self
304+
return other & self
295305
if inspect.isclass(other) and issubclass(other, QueryExpression):
296306
other = other() # instantiate
297307
if not isinstance(other, QueryExpression):
298308
raise DataJointError("The argument of join must be a QueryExpression")
299309
if semantic_check:
300310
assert_join_compatibility(self, other)
301-
join_attributes = set(n for n in self.heading.names if n in other.heading.names)
311+
# Only join on homologous namesakes (same name AND same lineage)
312+
join_attributes = set(
313+
n
314+
for n in self.heading.names
315+
if n in other.heading.names and self.heading[n].lineage == other.heading[n].lineage
316+
)
302317
# needs subquery if self's FROM clause has common attributes with other's FROM clause
303318
need_subquery1 = need_subquery2 = bool(
304319
(set(self.original_heading.names) & set(other.original_heading.names)) - join_attributes
@@ -735,9 +750,9 @@ class U:
735750
"""
736751
dj.U objects are the universal sets representing all possible values of their attributes.
737752
dj.U objects cannot be queried on their own but are useful for forming some queries.
738-
dj.U('attr1', ..., 'attrn') represents the universal set with the primary key attributes attr1 ... attrn.
739-
The universal set is the set of all possible combinations of values of the attributes.
740-
Without any attributes, dj.U() represents the set with one element that has no attributes.
753+
dj.U() or dj.U('attr1', ..., 'attrn') represents the universal set with the primary key
754+
attributes attr1 ... attrn. Without any attributes, dj.U() represents the set with one
755+
element that has no attributes.
741756
742757
Restriction:
743758
@@ -747,11 +762,15 @@ class U:
747762
748763
>>> dj.U('contrast', 'brightness') & stimulus
749764
765+
Empty U for distinct primary keys:
766+
767+
>>> dj.U() & expr
768+
750769
Aggregation:
751770
752771
In aggregation, dj.U is used for summary calculation over an entire set:
753772
754-
The following expression yields one element with one attribute `s` containing the total number of elements in
773+
The following expression yields one element with one attribute `n` containing the total number of elements in
755774
query expression `expr`:
756775
757776
>>> dj.U().aggr(expr, n='count(*)')
@@ -760,7 +779,7 @@ class U:
760779
query expression `expr`.
761780
762781
>>> dj.U().aggr(expr, n='count(distinct attr)')
763-
>>> dj.U().aggr(dj.U('attr').aggr(expr), 'n=count(*)')
782+
>>> dj.U().aggr(dj.U('attr').aggr(expr), n='count(*)')
764783
765784
The following expression yields one element and one attribute `s` containing the sum of values of attribute `attr`
766785
over entire result set of expression `expr`:
@@ -770,16 +789,13 @@ class U:
770789
The following expression yields the set of all unique combinations of attributes `attr1`, `attr2` and the number of
771790
their occurrences in the result set of query expression `expr`.
772791
773-
>>> dj.U(attr1,attr2).aggr(expr, n='count(*)')
792+
>>> dj.U('attr1', 'attr2').aggr(expr, n='count(*)')
774793
775-
Joins:
794+
Homology:
776795
777-
If expression `expr` has attributes 'attr1' and 'attr2', then expr * dj.U('attr1','attr2') yields the same result
778-
as `expr` but `attr1` and `attr2` are promoted to the the primary key. This is useful for producing a join on
779-
non-primary key attributes.
780-
For example, if `attr` is in both expr1 and expr2 but not in their primary keys, then expr1 * expr2 will throw
781-
an error because in most cases, it does not make sense to join on non-primary key attributes and users must first
782-
rename `attr` in one of the operands. The expression dj.U('attr') * rel1 * rel2 overrides this constraint.
796+
Since dj.U conceptually contains all possible lineages, its attributes are homologous to
797+
any namesake attribute in other expressions. This makes dj.U always compatible for
798+
semantic matching in joins and restrictions.
783799
"""
784800

785801
def __init__(self, *primary_key):
@@ -826,8 +842,22 @@ def join(self, other, left=False):
826842
return result
827843

828844
def __mul__(self, other):
829-
"""shorthand for join"""
830-
return self.join(other)
845+
"""
846+
dj.U * table is deprecated in DataJoint 2.0.
847+
Use dj.U & table instead.
848+
"""
849+
raise DataJointError(
850+
"dj.U(...) * table is deprecated in DataJoint 2.0. "
851+
"Use dj.U(...) & table instead."
852+
)
853+
854+
def __sub__(self, other):
855+
"""
856+
dj.U - table produces an infinite set and is not supported.
857+
"""
858+
raise DataJointError(
859+
"dj.U(...) - table produces an infinite set and is not supported."
860+
)
831861

832862
def aggr(self, group, **named_attributes):
833863
"""

src/datajoint/heading.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from .attribute_adapter import get_adapter
99
from .attribute_type import AttributeType
10+
from .lineage import get_all_lineages
1011
from .declare import (
1112
EXTERNAL_TYPES,
1213
NATIVE_TYPES,
@@ -73,6 +74,7 @@ def decode(self, stored, *, key=None):
7374
attribute_expression=None,
7475
database=None,
7576
dtype=object,
77+
lineage=None, # Origin of attribute: "schema.table.attribute" or None for native secondary
7678
)
7779

7880

@@ -406,6 +408,11 @@ def _init_from_database(self):
406408
# restore adapted type name
407409
attr["type"] = adapter_name
408410

411+
# Load lineage data from ~lineage table
412+
lineages = get_all_lineages(conn, database, table_name)
413+
for attr in attributes:
414+
attr["lineage"] = lineages.get(attr["name"])
415+
409416
self._attributes = dict(((q["name"], Attribute(**q)) for q in attributes))
410417

411418
# Read and tabulate secondary indexes

src/datajoint/lineage.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
"""
2+
Lineage tracking for semantic matching in joins.
3+
4+
Lineage identifies the origin of an attribute - where it was first defined.
5+
It is represented as a string in the format: "schema.table.attribute"
6+
7+
Only attributes WITH lineage are stored in the ~lineage table:
8+
- Native primary key attributes: lineage is this table
9+
- FK-inherited attributes: lineage is traced to the origin
10+
- Native secondary attributes: no lineage (no entry in table)
11+
"""
12+
13+
import logging
14+
15+
logger = logging.getLogger(__name__.split(".")[0])
16+
17+
LINEAGE_TABLE_NAME = "~lineage"
18+
19+
20+
def _lineage_table_sql(database):
21+
"""Generate SQL to create the ~lineage table."""
22+
return f"""
23+
CREATE TABLE IF NOT EXISTS `{database}`.`{LINEAGE_TABLE_NAME}` (
24+
table_name VARCHAR(64) NOT NULL,
25+
attribute_name VARCHAR(64) NOT NULL,
26+
lineage VARCHAR(255) NOT NULL,
27+
PRIMARY KEY (table_name, attribute_name)
28+
) ENGINE=InnoDB
29+
"""
30+
31+
32+
def ensure_lineage_table(connection, database):
33+
"""Create the ~lineage table if it doesn't exist."""
34+
connection.query(_lineage_table_sql(database))
35+
36+
37+
def lineage_table_exists(connection, database):
38+
"""Check if the ~lineage table exists in the schema."""
39+
result = connection.query(
40+
"""
41+
SELECT COUNT(*) FROM information_schema.tables
42+
WHERE table_schema = %s AND table_name = %s
43+
""",
44+
args=(database, LINEAGE_TABLE_NAME),
45+
)
46+
return result.fetchone()[0] > 0
47+
48+
49+
def get_lineage(connection, database, table_name, attribute_name):
50+
"""
51+
Get lineage for an attribute from the ~lineage table.
52+
53+
Returns the lineage string if found, None otherwise (indicating no lineage
54+
or attribute is a native secondary).
55+
"""
56+
if not lineage_table_exists(connection, database):
57+
return None
58+
59+
result = connection.query(
60+
f"""
61+
SELECT lineage FROM `{database}`.`{LINEAGE_TABLE_NAME}`
62+
WHERE table_name = %s AND attribute_name = %s
63+
""",
64+
args=(table_name, attribute_name),
65+
)
66+
row = result.fetchone()
67+
return row[0] if row else None
68+
69+
70+
def get_all_lineages(connection, database, table_name):
71+
"""
72+
Get all lineage entries for a table.
73+
74+
Returns a dict mapping attribute_name -> lineage.
75+
Attributes not in the dict have no lineage (native secondary).
76+
"""
77+
if not lineage_table_exists(connection, database):
78+
return {}
79+
80+
result = connection.query(
81+
f"""
82+
SELECT attribute_name, lineage FROM `{database}`.`{LINEAGE_TABLE_NAME}`
83+
WHERE table_name = %s
84+
""",
85+
args=(table_name,),
86+
)
87+
return {row[0]: row[1] for row in result}
88+
89+
90+
def delete_lineage_entries(connection, database, table_name):
91+
"""Delete all lineage entries for a table."""
92+
if not lineage_table_exists(connection, database):
93+
return
94+
95+
connection.query(
96+
f"""
97+
DELETE FROM `{database}`.`{LINEAGE_TABLE_NAME}`
98+
WHERE table_name = %s
99+
""",
100+
args=(table_name,),
101+
)
102+
103+
104+
def insert_lineage_entries(connection, database, entries):
105+
"""
106+
Insert lineage entries for a table.
107+
108+
:param entries: list of (table_name, attribute_name, lineage) tuples
109+
"""
110+
if not entries:
111+
return
112+
113+
ensure_lineage_table(connection, database)
114+
115+
# Use INSERT ... ON DUPLICATE KEY UPDATE to handle re-declarations
116+
for table_name, attribute_name, lineage in entries:
117+
connection.query(
118+
f"""
119+
INSERT INTO `{database}`.`{LINEAGE_TABLE_NAME}`
120+
(table_name, attribute_name, lineage)
121+
VALUES (%s, %s, %s)
122+
ON DUPLICATE KEY UPDATE lineage = VALUES(lineage)
123+
""",
124+
args=(table_name, attribute_name, lineage),
125+
)
126+
127+
128+
def compute_lineage_from_dependencies(connection, full_table_name, attribute_name, primary_key):
129+
"""
130+
Compute lineage by traversing FK relationships.
131+
132+
Fallback method when ~lineage table doesn't exist.
133+
134+
:param connection: database connection
135+
:param full_table_name: fully qualified table name like `schema`.`table`
136+
:param attribute_name: the attribute to compute lineage for
137+
:param primary_key: list of primary key attribute names for this table
138+
:return: lineage string or None
139+
"""
140+
connection.dependencies.load(force=False)
141+
142+
# Parse database and table name
143+
parts = full_table_name.replace("`", "").split(".")
144+
database = parts[0]
145+
table_name = parts[1]
146+
147+
# Check if attribute is inherited via FK
148+
parents = connection.dependencies.parents(full_table_name)
149+
for parent_table, props in parents.items():
150+
# Skip alias nodes (numeric strings)
151+
if parent_table.isdigit():
152+
# Get the actual parent through the alias
153+
grandparents = connection.dependencies.parents(parent_table)
154+
if grandparents:
155+
parent_table, props = next(iter(grandparents.items()))
156+
157+
attr_map = props.get("attr_map", {})
158+
if attribute_name in attr_map:
159+
parent_attr = attr_map[attribute_name]
160+
parent_parts = parent_table.replace("`", "").split(".")
161+
parent_db = parent_parts[0]
162+
parent_tbl = parent_parts[1]
163+
164+
# Get parent's primary key
165+
parent_pk = connection.dependencies.nodes.get(parent_table, {}).get("primary_key", set())
166+
167+
# Recursively trace to origin
168+
return compute_lineage_from_dependencies(
169+
connection, parent_table, parent_attr, list(parent_pk)
170+
)
171+
172+
# Not inherited - check if primary key
173+
if attribute_name in primary_key:
174+
return f"{database}.{table_name}.{attribute_name}"
175+
176+
# Native secondary - no lineage
177+
return None

0 commit comments

Comments
 (0)