Skip to content

Commit 99bc1f5

Browse files
committed
Merge lineage functionality into heading.py
- Move LineageTable class and lineage computation functions to heading.py - Simplify LineageTable to use direct SQL instead of inheriting from Table - Break circular import between heading and lineage modules - Keep lineage.py as thin re-export for backward compatibility
1 parent 02c8aa4 commit 99bc1f5

File tree

3 files changed

+380
-319
lines changed

3 files changed

+380
-319
lines changed

datajoint/heading.py

Lines changed: 375 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,380 @@
1717

1818
logger = logging.getLogger(__name__.split(".")[0])
1919

20+
21+
# =============================================================================
22+
# Lineage tracking for semantic matching in joins
23+
# =============================================================================
24+
25+
26+
def _parse_full_table_name(full_name):
27+
"""
28+
Parse a full table name like `schema`.`table` into (schema, table).
29+
30+
:param full_name: full table name in format `schema`.`table`
31+
:return: tuple (schema, table)
32+
"""
33+
match = re.match(r"`(\w+)`\.`(\w+)`", full_name)
34+
if not match:
35+
raise DataJointError(f"Invalid table name format: {full_name}")
36+
return match.group(1), match.group(2)
37+
38+
39+
def _get_primary_key_attrs(connection, schema, table_name):
40+
"""
41+
Get the primary key attributes for a table by querying the database.
42+
43+
:param connection: database connection
44+
:param schema: schema name
45+
:param table_name: table name
46+
:return: set of primary key attribute names
47+
"""
48+
result = connection.query(
49+
"""
50+
SELECT column_name
51+
FROM information_schema.key_column_usage
52+
WHERE table_schema = %s
53+
AND table_name = %s
54+
AND constraint_name = 'PRIMARY'
55+
""",
56+
args=(schema, table_name),
57+
)
58+
return {row[0] for row in result}
59+
60+
61+
def _compute_lineage_from_dependencies(connection, schema, table_name, attribute_name):
62+
"""
63+
Compute lineage by traversing the FK graph.
64+
65+
Uses connection.dependencies which loads FK info from INFORMATION_SCHEMA.
66+
This is the fallback when the ~lineage table doesn't exist.
67+
68+
:param connection: database connection
69+
:param schema: schema name
70+
:param table_name: table name
71+
:param attribute_name: attribute name
72+
:return: lineage string "schema.table.attribute" or None for native secondary attrs
73+
"""
74+
connection.dependencies.load(force=False)
75+
76+
full_table_name = f"`{schema}`.`{table_name}`"
77+
78+
# Check if the table exists in the dependency graph
79+
if full_table_name not in connection.dependencies:
80+
# Table not in graph - compute lineage based on primary key status
81+
pk_attrs = _get_primary_key_attrs(connection, schema, table_name)
82+
if attribute_name in pk_attrs:
83+
return f"{schema}.{table_name}.{attribute_name}"
84+
else:
85+
return None
86+
87+
# Check incoming edges (foreign keys TO this table's parents)
88+
for parent, props in connection.dependencies.parents(full_table_name).items():
89+
attr_map = props.get("attr_map", {})
90+
if attribute_name in attr_map:
91+
# This attribute is inherited from parent - recurse to find origin
92+
parent_attr = attr_map[attribute_name]
93+
# Handle alias nodes (numeric string nodes in the graph)
94+
if parent.isdigit():
95+
# Find the actual parent by traversing through alias
96+
for grandparent, gprops in connection.dependencies.parents(
97+
parent
98+
).items():
99+
if not grandparent.isdigit():
100+
parent = grandparent
101+
parent_attr = gprops.get("attr_map", {}).get(
102+
attribute_name, parent_attr
103+
)
104+
break
105+
parent_schema, parent_table = _parse_full_table_name(parent)
106+
return _compute_lineage_from_dependencies(
107+
connection, parent_schema, parent_table, parent_attr
108+
)
109+
110+
# Not inherited - check if it's a primary key attribute
111+
node_data = connection.dependencies.nodes.get(full_table_name, {})
112+
pk_attrs = node_data.get("primary_key", set())
113+
114+
if attribute_name in pk_attrs:
115+
# Native primary key attribute - has lineage to itself
116+
return f"{schema}.{table_name}.{attribute_name}"
117+
else:
118+
# Native secondary attribute - no lineage
119+
return None
120+
121+
122+
def _compute_all_lineage_for_table(connection, schema, table_name):
123+
"""
124+
Compute lineage for all attributes in a table.
125+
126+
:param connection: database connection
127+
:param schema: schema name
128+
:param table_name: table name
129+
:return: dict mapping attribute_name -> lineage (or None)
130+
"""
131+
# Get all attributes using Heading
132+
heading = Heading(
133+
table_info=dict(
134+
conn=connection,
135+
database=schema,
136+
table_name=table_name,
137+
context=None,
138+
)
139+
)
140+
141+
# Compute lineage for each attribute
142+
return {
143+
attr: _compute_lineage_from_dependencies(connection, schema, table_name, attr)
144+
for attr in heading.names
145+
}
146+
147+
148+
def _get_lineage_for_heading(connection, schema, table_name):
149+
"""
150+
Get lineage information for all attributes in a table.
151+
152+
First tries to load from ~lineage table, falls back to FK graph computation.
153+
154+
:param connection: database connection
155+
:param schema: schema name
156+
:param table_name: table name
157+
:return: dict mapping attribute_name -> lineage (or None)
158+
"""
159+
# Check if ~lineage table exists
160+
lineage_table_exists = (
161+
connection.query(
162+
"""
163+
SELECT COUNT(*) FROM information_schema.tables
164+
WHERE table_schema = %s AND table_name = '~lineage'
165+
""",
166+
args=(schema,),
167+
).fetchone()[0]
168+
> 0
169+
)
170+
171+
if lineage_table_exists:
172+
# Load from ~lineage table
173+
lineage_table = LineageTable(connection, schema)
174+
return lineage_table.get_table_lineage(table_name)
175+
else:
176+
# Compute from FK graph
177+
return _compute_all_lineage_for_table(connection, schema, table_name)
178+
179+
180+
class LineageTable:
181+
"""
182+
Hidden table for storing attribute lineage information.
183+
184+
Each row maps (table_name, attribute_name) -> lineage string.
185+
Only attributes with lineage are stored; absence means no lineage.
186+
"""
187+
188+
definition = """
189+
# Attribute lineage tracking for semantic matching
190+
table_name : varchar(64) # name of the table
191+
attribute_name : varchar(64) # name of the attribute
192+
---
193+
lineage : varchar(200) # "schema.table.attribute"
194+
"""
195+
196+
def __init__(self, connection, database):
197+
# Lazy import to avoid circular dependency
198+
from .table import Table
199+
200+
self._table_class = Table
201+
self.database = database
202+
self._connection = connection
203+
self._heading = Heading(
204+
table_info=dict(
205+
conn=connection,
206+
database=database,
207+
table_name=self.table_name,
208+
context=None,
209+
)
210+
)
211+
self._support = [self.full_table_name]
212+
213+
if not self.is_declared:
214+
self._declare()
215+
216+
@property
217+
def table_name(self):
218+
return "~lineage"
219+
220+
@property
221+
def full_table_name(self):
222+
return f"`{self.database}`.`{self.table_name}`"
223+
224+
@property
225+
def is_declared(self):
226+
return (
227+
self._connection.query(
228+
"""
229+
SELECT COUNT(*) FROM information_schema.tables
230+
WHERE table_schema = %s AND table_name = %s
231+
""",
232+
args=(self.database, self.table_name),
233+
).fetchone()[0]
234+
> 0
235+
)
236+
237+
def _declare(self):
238+
"""Create the ~lineage table."""
239+
self._connection.query(
240+
f"""
241+
CREATE TABLE IF NOT EXISTS {self.full_table_name} (
242+
table_name VARCHAR(64) NOT NULL,
243+
attribute_name VARCHAR(64) NOT NULL,
244+
lineage VARCHAR(200) NOT NULL,
245+
PRIMARY KEY (table_name, attribute_name)
246+
) ENGINE=InnoDB
247+
"""
248+
)
249+
250+
def insert1(self, row, replace=False):
251+
"""Insert a single row."""
252+
if replace:
253+
self._connection.query(
254+
f"""
255+
REPLACE INTO {self.full_table_name}
256+
(table_name, attribute_name, lineage)
257+
VALUES (%s, %s, %s)
258+
""",
259+
args=(row["table_name"], row["attribute_name"], row["lineage"]),
260+
)
261+
else:
262+
self._connection.query(
263+
f"""
264+
INSERT INTO {self.full_table_name}
265+
(table_name, attribute_name, lineage)
266+
VALUES (%s, %s, %s)
267+
""",
268+
args=(row["table_name"], row["attribute_name"], row["lineage"]),
269+
)
270+
271+
def delete_quick(self, table_name=None, attribute_name=None):
272+
"""Delete rows without prompts."""
273+
if table_name and attribute_name:
274+
self._connection.query(
275+
f"DELETE FROM {self.full_table_name} WHERE table_name=%s AND attribute_name=%s",
276+
args=(table_name, attribute_name),
277+
)
278+
elif table_name:
279+
self._connection.query(
280+
f"DELETE FROM {self.full_table_name} WHERE table_name=%s",
281+
args=(table_name,),
282+
)
283+
else:
284+
self._connection.query(f"DELETE FROM {self.full_table_name}")
285+
286+
def store_lineage(self, table_name, attribute_name, lineage):
287+
"""
288+
Store lineage for an attribute. Only stores if lineage is not None.
289+
290+
:param table_name: name of the table (without schema)
291+
:param attribute_name: name of the attribute
292+
:param lineage: lineage string "schema.table.attribute" or None
293+
"""
294+
if lineage is None:
295+
# No lineage - delete any existing entry
296+
self.delete_quick(table_name, attribute_name)
297+
else:
298+
self.insert1(
299+
dict(
300+
table_name=table_name,
301+
attribute_name=attribute_name,
302+
lineage=lineage,
303+
),
304+
replace=True,
305+
)
306+
307+
def get_lineage(self, table_name, attribute_name):
308+
"""
309+
Get lineage for an attribute.
310+
311+
:param table_name: name of the table (without schema)
312+
:param attribute_name: name of the attribute
313+
:return: lineage string or None if no lineage
314+
"""
315+
result = self._connection.query(
316+
f"SELECT lineage FROM {self.full_table_name} WHERE table_name=%s AND attribute_name=%s",
317+
args=(table_name, attribute_name),
318+
).fetchone()
319+
return result[0] if result else None
320+
321+
def get_table_lineage(self, table_name):
322+
"""
323+
Get lineage for all attributes in a table.
324+
325+
:param table_name: name of the table (without schema)
326+
:return: dict mapping attribute_name -> lineage (only attributes with lineage)
327+
"""
328+
result = self._connection.query(
329+
f"SELECT attribute_name, lineage FROM {self.full_table_name} WHERE table_name=%s",
330+
args=(table_name,),
331+
).fetchall()
332+
return {row[0]: row[1] for row in result}
333+
334+
def delete_table_lineage(self, table_name):
335+
"""
336+
Delete all lineage records for a table.
337+
338+
:param table_name: name of the table (without schema)
339+
"""
340+
self.delete_quick(table_name)
341+
342+
343+
def compute_schema_lineage(connection, schema):
344+
"""
345+
Compute and populate the ~lineage table for a schema.
346+
347+
Analyzes foreign key relationships to determine attribute origins.
348+
349+
:param connection: database connection
350+
:param schema: schema object or schema name
351+
"""
352+
from .schemas import Schema
353+
354+
if isinstance(schema, Schema):
355+
schema_name = schema.database
356+
else:
357+
schema_name = schema
358+
359+
# Create or get the lineage table
360+
lineage_table = LineageTable(connection, schema_name)
361+
362+
# Get all user tables in the schema (excluding hidden tables)
363+
result = connection.query(
364+
"""
365+
SELECT table_name
366+
FROM information_schema.tables
367+
WHERE table_schema = %s
368+
AND table_name NOT LIKE '~%%'
369+
AND table_type = 'BASE TABLE'
370+
""",
371+
args=(schema_name,),
372+
)
373+
tables = [row[0] for row in result]
374+
375+
# Ensure dependencies are loaded
376+
connection.dependencies.load(force=True)
377+
378+
# Compute and store lineage for each table
379+
for table_name in tables:
380+
lineage_map = _compute_all_lineage_for_table(
381+
connection, schema_name, table_name
382+
)
383+
for attr_name, lineage in lineage_map.items():
384+
if lineage is not None:
385+
lineage_table.store_lineage(table_name, attr_name, lineage)
386+
387+
logger.info(f"Computed lineage for schema `{schema_name}`: {len(tables)} tables")
388+
389+
390+
# =============================================================================
391+
# End of lineage tracking
392+
# =============================================================================
393+
20394
default_attribute_properties = dict( # these default values are set in computed attributes
21395
name=None,
22396
type="expression",
@@ -433,9 +807,7 @@ def _init_from_database(self):
433807

434808
# Load lineage information for semantic matching
435809
try:
436-
from .lineage import get_lineage_for_heading
437-
438-
lineage_map = get_lineage_for_heading(conn, database, table_name, None)
810+
lineage_map = _get_lineage_for_heading(conn, database, table_name)
439811
for attr in attributes:
440812
attr["lineage"] = lineage_map.get(attr["name"])
441813
except Exception as e:

0 commit comments

Comments
 (0)