Skip to content

Commit 5ddd3b7

Browse files
feat: Use adapter for query expression SQL (Phase 6 Part 3)
Update expression.py to use database adapter for backend-agnostic SQL: - from_clause() subquery aliases (line 110) - from_clause() JOIN USING clause (line 123) - Aggregation.make_sql() GROUP BY clause (line 1031) - Aggregation.__len__() alias (line 1042) - Union.make_sql() alias (line 1084) - Union.__len__() alias (line 1100) - Refactor _wrap_attributes() to accept adapter parameter (line 1245) - Update sorting_clauses() to pass adapter (line 141) All query expression SQL (JOIN, FROM, SELECT, GROUP BY, ORDER BY) now uses proper identifier quoting for both MySQL and PostgreSQL. Maintains backward compatibility with MySQL backend. All existing tests pass (175 passed, 25 skipped). Part of Phase 6: Multi-backend PostgreSQL support. Related: #1338 Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent 77e2d4c commit 5ddd3b7

File tree

1 file changed

+29
-24
lines changed

1 file changed

+29
-24
lines changed

src/datajoint/expression.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,10 @@ def primary_key(self):
104104
_subquery_alias_count = count() # count for alias names used in the FROM clause
105105

106106
def from_clause(self):
107+
adapter = self.connection.adapter
107108
support = (
108109
(
109-
"(" + src.make_sql() + ") as `$%x`" % next(self._subquery_alias_count)
110+
"({}) as {}".format(src.make_sql(), adapter.quote_identifier(f"${next(self._subquery_alias_count):x}"))
110111
if isinstance(src, QueryExpression)
111112
else src
112113
)
@@ -116,7 +117,8 @@ def from_clause(self):
116117
for s, (is_left, using_attrs) in zip(support, self._joins):
117118
left_kw = "LEFT " if is_left else ""
118119
if using_attrs:
119-
using = "USING ({})".format(", ".join(f"`{a}`" for a in using_attrs))
120+
quoted_attrs = ", ".join(adapter.quote_identifier(a) for a in using_attrs)
121+
using = f"USING ({quoted_attrs})"
120122
clause += f" {left_kw}JOIN {s} {using}"
121123
else:
122124
# Cross join (no common non-hidden attributes)
@@ -134,7 +136,8 @@ def sorting_clauses(self):
134136
return ""
135137
# Default to KEY ordering if order_by is None (inherit with no existing order)
136138
order_by = self._top.order_by if self._top.order_by is not None else ["KEY"]
137-
clause = ", ".join(_wrap_attributes(_flatten_attribute_list(self.primary_key, order_by)))
139+
adapter = self.connection.adapter
140+
clause = ", ".join(_wrap_attributes(_flatten_attribute_list(self.primary_key, order_by), adapter))
138141
if clause:
139142
clause = f" ORDER BY {clause}"
140143
if self._top.limit is not None:
@@ -1024,19 +1027,18 @@ def make_sql(self, fields=None):
10241027
""
10251028
if not self.primary_key
10261029
else (
1027-
" GROUP BY `%s`" % "`,`".join(self._grouping_attributes)
1030+
" GROUP BY {}".format(
1031+
", ".join(self.connection.adapter.quote_identifier(col) for col in self._grouping_attributes)
1032+
)
10281033
+ ("" if not self.restriction else " HAVING (%s)" % ")AND(".join(self.restriction))
10291034
)
10301035
),
10311036
sorting=self.sorting_clauses(),
10321037
)
10331038

10341039
def __len__(self):
1035-
return self.connection.query(
1036-
"SELECT count(1) FROM ({subquery}) `${alias:x}`".format(
1037-
subquery=self.make_sql(), alias=next(self._subquery_alias_count)
1038-
)
1039-
).fetchone()[0]
1040+
alias = self.connection.adapter.quote_identifier(f"${next(self._subquery_alias_count):x}")
1041+
return self.connection.query(f"SELECT count(1) FROM ({self.make_sql()}) {alias}").fetchone()[0]
10401042

10411043
def __bool__(self):
10421044
return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())).fetchone()[0])
@@ -1072,12 +1074,11 @@ def make_sql(self):
10721074
if not arg1.heading.secondary_attributes and not arg2.heading.secondary_attributes:
10731075
# no secondary attributes: use UNION DISTINCT
10741076
fields = arg1.primary_key
1075-
return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}{sorting}`".format(
1076-
sql1=(arg1.make_sql() if isinstance(arg1, Union) else arg1.make_sql(fields)),
1077-
sql2=(arg2.make_sql() if isinstance(arg2, Union) else arg2.make_sql(fields)),
1078-
alias=next(self.__count),
1079-
sorting=self.sorting_clauses(),
1080-
)
1077+
alias_name = f"_u{next(self.__count)}{self.sorting_clauses()}"
1078+
alias_quoted = self.connection.adapter.quote_identifier(alias_name)
1079+
sql1 = arg1.make_sql() if isinstance(arg1, Union) else arg1.make_sql(fields)
1080+
sql2 = arg2.make_sql() if isinstance(arg2, Union) else arg2.make_sql(fields)
1081+
return f"SELECT * FROM (({sql1}) UNION ({sql2})) as {alias_quoted}"
10811082
# with secondary attributes, use union of left join with anti-restriction
10821083
fields = self.heading.names
10831084
sql1 = arg1.join(arg2, left=True).make_sql(fields)
@@ -1093,12 +1094,8 @@ def where_clause(self):
10931094
raise NotImplementedError("Union does not use a WHERE clause")
10941095

10951096
def __len__(self):
1096-
return self.connection.query(
1097-
"SELECT count(1) FROM ({subquery}) `${alias:x}`".format(
1098-
subquery=self.make_sql(),
1099-
alias=next(QueryExpression._subquery_alias_count),
1100-
)
1101-
).fetchone()[0]
1097+
alias = self.connection.adapter.quote_identifier(f"${next(QueryExpression._subquery_alias_count):x}")
1098+
return self.connection.query(f"SELECT count(1) FROM ({self.make_sql()}) {alias}").fetchone()[0]
11021099

11031100
def __bool__(self):
11041101
return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())).fetchone()[0])
@@ -1242,6 +1239,14 @@ def _flatten_attribute_list(primary_key, attrs):
12421239
yield a
12431240

12441241

1245-
def _wrap_attributes(attr):
1246-
for entry in attr: # wrap attribute names in backquotes
1247-
yield re.sub(r"\b((?!asc|desc)\w+)\b", r"`\1`", entry, flags=re.IGNORECASE)
1242+
def _wrap_attributes(attr, adapter):
1243+
"""Wrap attribute names with database-specific quotes."""
1244+
for entry in attr:
1245+
# Replace word boundaries (not 'asc' or 'desc') with quoted version
1246+
def quote_match(match):
1247+
word = match.group(1)
1248+
if word.lower() not in ("asc", "desc"):
1249+
return adapter.quote_identifier(word)
1250+
return word
1251+
1252+
yield re.sub(r"\b((?!asc|desc)\w+)\b", quote_match, entry, flags=re.IGNORECASE)

0 commit comments

Comments
 (0)