Skip to content

Commit 5fa0f56

Browse files
fix: Backend-agnostic fixes for cascade delete and FreeTable
- Fix FreeTable.__init__ to strip both backticks and double quotes
- Fix heading.py error message to not add hardcoded backticks
- Fix Attribute.original_name to accept both quote types
- Fix delete_quick() to use cursor.rowcount instead of ROW_COUNT()
- Update PostgreSQL FK error parser with clearer naming
- Add cascade delete integration tests

All 4 PostgreSQL multi-backend tests passing. Cascade delete logic working correctly.
1 parent 9800381 commit 5fa0f56

File tree

4 files changed

+226
-31
lines changed

4 files changed

+226
-31
lines changed

src/datajoint/adapters/postgres.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -685,13 +685,19 @@ def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_
685685
)
686686

687687
def parse_foreign_key_error(self, error_message: str) -> dict[str, str | list[str]] | None:
688-
"""Parse PostgreSQL foreign key violation error message."""
688+
"""
689+
Parse PostgreSQL foreign key violation error message.
690+
691+
PostgreSQL FK error format:
692+
'update or delete on table "X" violates foreign key constraint "Y" on table "Z"'
693+
Where:
694+
- "X" is the referenced table (being deleted/updated)
695+
- "Z" is the referencing table (has the FK, needs cascade delete)
696+
"""
689697
import re
690698

691-
# PostgreSQL FK error pattern
692-
# Example: 'update or delete on table "parent" violates foreign key constraint "child_parent_id_fkey" on table "child"'
693699
pattern = re.compile(
694-
r'.*table "(?P<parent_table>[^"]+)" violates foreign key constraint "(?P<name>[^"]+)" on table "(?P<child_table>[^"]+)"'
700+
r'.*table "(?P<referenced_table>[^"]+)" violates foreign key constraint "(?P<name>[^"]+)" on table "(?P<referencing_table>[^"]+)"'
695701
)
696702

697703
match = pattern.match(error_message)
@@ -700,16 +706,17 @@ def parse_foreign_key_error(self, error_message: str) -> dict[str, str | list[st
700706

701707
result = match.groupdict()
702708

703-
# Build child table name (assume same schema as parent for now)
709+
# The child is the referencing table (the one with the FK that needs cascade delete)
710+
# The parent is the referenced table (the one being deleted)
704711
# The error doesn't include schema, so we return unqualified names
705-
# and let the caller add schema context
706-
child = f'"{result["child_table"]}"'
712+
child = f'"{result["referencing_table"]}"'
713+
parent = f'"{result["referenced_table"]}"'
707714

708715
return {
709716
"child": child,
710717
"name": f'"{result["name"]}"',
711718
"fk_attrs": None, # Not in error message, will need constraint query
712-
"parent": f'"{result["parent_table"]}"',
719+
"parent": parent,
713720
"pk_attrs": None, # Not in error message, will need constraint query
714721
}
715722

src/datajoint/heading.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,9 @@ def original_name(self) -> str:
164164
"""
165165
if self.attribute_expression is None:
166166
return self.name
167-
assert self.attribute_expression.startswith("`")
168-
return self.attribute_expression.strip("`")
167+
# Backend-agnostic quote stripping (MySQL uses `, PostgreSQL uses ")
168+
assert self.attribute_expression.startswith(("`", '"'))
169+
return self.attribute_expression.strip('`"')
169170

170171

171172
class Heading:
@@ -365,7 +366,7 @@ def _init_from_database(self) -> None:
365366
).fetchone()
366367
if info is None:
367368
raise DataJointError(
368-
"The table `{database}`.`{table_name}` is not defined.".format(table_name=table_name, database=database)
369+
f"The table {database}.{table_name} is not defined."
369370
)
370371
# Normalize table_comment to comment for backward compatibility
371372
self._table_status = {k.lower(): v for k, v in info.items()}

src/datajoint/table.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -834,8 +834,9 @@ def delete_quick(self, get_count=False):
834834
If this table has populated dependent tables, this will fail.
835835
"""
836836
query = "DELETE FROM " + self.full_table_name + self.where_clause()
837-
self.connection.query(query)
838-
count = self.connection.query("SELECT ROW_COUNT()").fetchone()[0] if get_count else None
837+
cursor = self.connection.query(query)
838+
# Use cursor.rowcount (DB-API 2.0 standard, works for both MySQL and PostgreSQL)
839+
count = cursor.rowcount if get_count else None
839840
return count
840841

841842
def delete(
@@ -876,9 +877,17 @@ def cascade(table):
876877
"""service function to perform cascading deletes recursively."""
877878
max_attempts = 50
878879
for _ in range(max_attempts):
880+
# Set savepoint before delete attempt (for PostgreSQL transaction handling)
881+
savepoint_name = f"cascade_delete_{id(table)}"
882+
if transaction:
883+
table.connection.query(f"SAVEPOINT {savepoint_name}")
884+
879885
try:
880886
delete_count = table.delete_quick(get_count=True)
881887
except IntegrityError as error:
888+
# Rollback to savepoint so we can continue querying (PostgreSQL requirement)
889+
if transaction:
890+
table.connection.query(f"ROLLBACK TO SAVEPOINT {savepoint_name}")
882891
# Use adapter to parse FK error message
883892
match = table.connection.adapter.parse_foreign_key_error(error.args[0])
884893
if match is None:
@@ -895,43 +904,47 @@ def strip_quotes(s):
895904
return s.strip('`"')
896905
return s
897906

898-
# Ensure child table has schema
899-
child_table = match["child"]
900-
if "." not in strip_quotes(child_table):
907+
# Extract schema and table name from child (work with unquoted names)
908+
child_table_raw = strip_quotes(match["child"])
909+
if "." in child_table_raw:
910+
child_parts = child_table_raw.split(".")
911+
child_schema = strip_quotes(child_parts[0])
912+
child_table_name = strip_quotes(child_parts[1])
913+
else:
901914
# Add schema from current table
902-
schema = table.full_table_name.split(".")[0].strip('`"')
903-
child_unquoted = strip_quotes(child_table)
904-
child_table = f"{table.connection.adapter.quote_identifier(schema)}.{table.connection.adapter.quote_identifier(child_unquoted)}"
905-
match["child"] = child_table
915+
schema_parts = table.full_table_name.split(".")
916+
child_schema = strip_quotes(schema_parts[0])
917+
child_table_name = child_table_raw
906918

907919
# If FK/PK attributes not in error message, query information_schema
908920
if match["fk_attrs"] is None or match["pk_attrs"] is None:
909-
# Extract schema and table name from child
910-
child_parts = [strip_quotes(p) for p in child_table.split(".")]
911-
if len(child_parts) == 2:
912-
child_schema, child_table_name = child_parts
913-
else:
914-
child_schema = table.full_table_name.split(".")[0].strip('`"')
915-
child_table_name = child_parts[0]
916-
917921
constraint_query = table.connection.adapter.get_constraint_info_sql(
918922
strip_quotes(match["name"]),
919923
child_schema,
920924
child_table_name,
921925
)
922926

923-
results = table.connection.query(constraint_query).fetchall()
927+
results = table.connection.query(
928+
constraint_query,
929+
args=(strip_quotes(match["name"]), child_schema, child_table_name),
930+
).fetchall()
924931
if results:
925932
match["fk_attrs"], match["parent"], match["pk_attrs"] = list(
926933
map(list, zip(*results))
927934
)
928935
match["parent"] = match["parent"][0] # All rows have same parent
929936

937+
# Build properly quoted full table name for FreeTable
938+
child_full_name = (
939+
f"{table.connection.adapter.quote_identifier(child_schema)}."
940+
f"{table.connection.adapter.quote_identifier(child_table_name)}"
941+
)
942+
930943
# Restrict child by table if
931944
# 1. if table's restriction attributes are not in child's primary key
932945
# 2. if child renames any attributes
933946
# Otherwise restrict child by table's restriction.
934-
child = FreeTable(table.connection, match["child"])
947+
child = FreeTable(table.connection, child_full_name)
935948
if set(table.restriction_attributes) <= set(child.primary_key) and match["fk_attrs"] == match["pk_attrs"]:
936949
child._restriction = table._restriction
937950
child._restriction_attributes = table.restriction_attributes
@@ -961,6 +974,9 @@ def strip_quotes(s):
961974
else:
962975
cascade(child)
963976
else:
977+
# Successful delete - release savepoint
978+
if transaction:
979+
table.connection.query(f"RELEASE SAVEPOINT {savepoint_name}")
964980
deleted.add(table.full_table_name)
965981
logger.info("Deleting {count} rows from {table}".format(count=delete_count, table=table.full_table_name))
966982
break
@@ -1381,7 +1397,8 @@ class FreeTable(Table):
13811397
"""
13821398

13831399
def __init__(self, conn, full_table_name):
1384-
self.database, self._table_name = (s.strip("`") for s in full_table_name.split("."))
1400+
# Backend-agnostic quote stripping (MySQL uses `, PostgreSQL uses ")
1401+
self.database, self._table_name = (s.strip('`"') for s in full_table_name.split("."))
13851402
self._connection = conn
13861403
self._support = [full_table_name]
13871404
self._heading = Heading(
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
"""
2+
Integration tests for cascade delete on multiple backends.
3+
"""
4+
5+
import os
6+
7+
import pytest
8+
9+
import datajoint as dj
10+
11+
12+
@pytest.fixture(scope="function")
def schema_by_backend(connection_by_backend, db_creds_by_backend, request):
    """
    Provide a freshly created, uniquely named schema for one cascade-delete test.

    The schema is created on ``connection_by_backend`` before the test body
    runs, and a best-effort drop is attempted both before creation and after
    the test finishes.

    Parameters
    ----------
    connection_by_backend
        Connection for the backend under test; used to create and drop the schema.
    db_creds_by_backend
        Mapping whose ``"backend"`` entry is embedded in the schema name.
    request
        pytest request object. Unused; kept so the fixture signature is unchanged.

    Yields
    ------
    dj.Schema
        The schema bound to ``connection_by_backend``.
    """
    import logging
    import uuid

    log = logging.getLogger(__name__)
    backend = db_creds_by_backend["backend"]

    # A random suffix cannot collide when two tests start within the same
    # millisecond (e.g. parallel workers), unlike a truncated timestamp.
    # 63 is the stricter identifier limit of the two backends
    # (PostgreSQL: 63 bytes, MySQL: 64 characters).
    schema_name = f"djtest_cascade_{backend}_{uuid.uuid4().hex[:8]}"[:63]

    def drop_if_exists():
        """Best-effort drop of the test schema; failures are logged, never raised."""
        if not connection_by_backend.is_connected:
            return
        try:
            connection_by_backend.query(
                f"DROP DATABASE IF EXISTS {connection_by_backend.adapter.quote_identifier(schema_name)}"
            )
        except Exception as error:
            # Cleanup must never fail a test, but a silent pass would hide leaked
            # schemas, so record the failure.
            # NOTE(review): DROP DATABASE is the MySQL spelling; on PostgreSQL this
            # statement may not remove the schema created below - confirm whether
            # the adapter exposes a backend-specific drop statement.
            log.warning("Could not drop test schema %s: %s", schema_name, error)

    # Remove leftovers from any previous failed run, then start from a clean schema.
    drop_if_exists()
    schema = dj.Schema(schema_name, connection=connection_by_backend)

    yield schema

    # Teardown after the test has finished.
    drop_if_exists()
43+
44+
45+
def test_simple_cascade_delete(schema_by_backend):
    """Deleting one parent row removes exactly the child rows that reference it."""

    @schema_by_backend
    class Parent(dj.Manual):
        definition = """
        parent_id : int
        ---
        name : varchar(255)
        """

    @schema_by_backend
    class Child(dj.Manual):
        definition = """
        -> Parent
        child_id : int
        ---
        data : varchar(255)
        """

    # Populate: two parents; parent 1 has two children, parent 2 has one.
    parent_rows = [(1, "Parent1"), (2, "Parent2")]
    child_rows = [(1, 1, "Child1-1"), (1, 2, "Child1-2"), (2, 1, "Child2-1")]
    for row in parent_rows:
        Parent.insert1(row)
    for row in child_rows:
        Child.insert1(row)

    assert len(Parent()) == 2
    assert len(Child()) == 3

    # Remove parent 1; its two children must be deleted along with it.
    doomed_parent = Parent & {"parent_id": 1}
    doomed_parent.delete()

    # Only parent 2 and its single child remain.
    surviving_child = Child & {"parent_id": 2, "child_id": 1}
    assert len(Parent()) == 1
    assert len(Child()) == 1
    assert surviving_child.fetch1("data") == "Child2-1"
82+
83+
84+
def test_multi_level_cascade_delete(schema_by_backend):
    """Deleting a grandparent row empties the parent and child tables beneath it."""

    @schema_by_backend
    class GrandParent(dj.Manual):
        definition = """
        gp_id : int
        ---
        name : varchar(255)
        """

    @schema_by_backend
    class Parent(dj.Manual):
        definition = """
        -> GrandParent
        parent_id : int
        ---
        name : varchar(255)
        """

    @schema_by_backend
    class Child(dj.Manual):
        definition = """
        -> Parent
        child_id : int
        ---
        data : varchar(255)
        """

    # Populate a three-level hierarchy hanging off a single grandparent.
    GrandParent.insert1((1, "GP1"))
    for parent_row in [(1, 1, "P1"), (1, 2, "P2")]:
        Parent.insert1(parent_row)
    for child_row in [(1, 1, 1, "C1"), (1, 1, 2, "C2"), (1, 2, 1, "C3")]:
        Child.insert1(child_row)

    assert len(GrandParent()) == 1
    assert len(Parent()) == 2
    assert len(Child()) == 3

    # The delete has to walk GrandParent -> Parent -> Child.
    root = GrandParent & {"gp_id": 1}
    root.delete()

    # Nothing may survive at any level.
    assert len(GrandParent()) == 0
    assert len(Parent()) == 0
    assert len(Child()) == 0
132+
133+
134+
def test_cascade_delete_with_renamed_attrs(schema_by_backend):
    """
    Test cascade delete when the foreign key renames the referenced attribute.

    ``Observation`` references ``Animal`` through a projection that renames
    ``animal_id`` to ``subject_id``, so the delete restricts on ``animal_id``
    while the dependent rows are keyed by ``subject_id``.
    """

    @schema_by_backend
    class Animal(dj.Manual):
        definition = """
        animal_id : int
        ---
        species : varchar(255)
        """

    @schema_by_backend
    class Observation(dj.Manual):
        definition = """
        obs_id : int
        ---
        -> Animal.proj(subject_id='animal_id')
        measurement : float
        """

    # Insert test data: animal 1 has two observations, animal 2 has one.
    Animal.insert1((1, "Mouse"))
    Animal.insert1((2, "Rat"))
    Observation.insert1((1, 1, 10.5))
    Observation.insert1((2, 1, 11.2))
    Observation.insert1((3, 2, 15.3))

    assert len(Animal()) == 2
    assert len(Observation()) == 3

    # Delete animal 1 - should cascade to its observations
    (Animal & {"animal_id": 1}).delete()

    # Check cascade worked: only animal 2 and its observation remain
    assert len(Animal()) == 1
    assert len(Observation()) == 1
    # The value round-trips through a database float column, so compare with a
    # tolerance instead of exact equality.
    assert (Observation & {"obs_id": 3}).fetch1("measurement") == pytest.approx(15.3)

0 commit comments

Comments
 (0)