
Commit bc245d3

fix: PostgreSQL compatibility improvements for DataJoint 2.1
Multiple fixes for PostgreSQL backend compatibility:

1. Fix composite FK column mapping in dependencies.py
   - Use pg_constraint with unnest() to correctly map FK columns
   - The previous information_schema query created a Cartesian product
   - Fixes "Attribute already exists" errors during key_source

2. Fix Part table full_table_name quoting
   - PartMeta.full_table_name now uses adapter.quote_identifier()
   - Previously hardcoded MySQL backticks
   - Fixes "syntax error at or near `" errors with Part tables

3. Fix char type length preservation in postgres.py
   - Reconstruct parametrized types from the PostgreSQL info schema
   - Fixes char(n) being truncated to char(1) for FK columns

4. Implement HAVING clause subquery wrapping for PostgreSQL
   - PostgreSQL doesn't allow column aliases in HAVING
   - Aggregation.make_sql() wraps the query as a subquery with WHERE on PostgreSQL
   - MySQL continues to use HAVING directly (more efficient)

5. Implement GROUP_CONCAT/STRING_AGG translation
   - The base adapter gains a translate_expression() method
   - PostgreSQL: GROUP_CONCAT → STRING_AGG
   - MySQL: STRING_AGG → GROUP_CONCAT
   - heading.py calls translate_expression() in as_sql()

6. Register numpy type adapters for PostgreSQL
   - numpy.bool_, int*, float* types now work with psycopg2
   - Prevents "can't adapt type 'numpy.bool_'" errors

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent d2e89ba commit bc245d3

File tree

7 files changed: +242 -44 lines changed

src/datajoint/adapters/base.py
src/datajoint/adapters/mysql.py
src/datajoint/adapters/postgres.py
src/datajoint/dependencies.py
src/datajoint/expression.py
src/datajoint/heading.py
src/datajoint/user_tables.py


src/datajoint/adapters/base.py

Lines changed: 28 additions & 0 deletions
@@ -938,6 +938,34 @@ def json_path_expr(self, column: str, path: str, return_type: str | None = None)
         """
         ...
 
+    def translate_expression(self, expr: str) -> str:
+        """
+        Translate SQL expression for backend compatibility.
+
+        Converts database-specific function calls to the equivalent syntax
+        for the current backend. This enables portable DataJoint code that
+        uses common aggregate functions.
+
+        Translations performed:
+        - GROUP_CONCAT(col) ↔ STRING_AGG(col, ',')
+
+        Parameters
+        ----------
+        expr : str
+            SQL expression that may contain function calls.
+
+        Returns
+        -------
+        str
+            Translated expression for the current backend.
+
+        Notes
+        -----
+        The base implementation returns the expression unchanged.
+        Subclasses override to provide backend-specific translations.
+        """
+        return expr
+
     # =========================================================================
     # DDL Generation
     # =========================================================================
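The base method is an identity transform, so backends that need no rewriting simply inherit a no-op. A minimal sketch of the override contract (the NoOpAdapter class here is hypothetical, for illustration only):

class NoOpAdapter:
    def translate_expression(self, expr: str) -> str:
        return expr  # base behavior: pass the expression through unchanged

assert NoOpAdapter().translate_expression("GROUP_CONCAT(x)") == "GROUP_CONCAT(x)"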

src/datajoint/adapters/mysql.py

Lines changed: 44 additions & 0 deletions
@@ -827,6 +827,50 @@ def json_path_expr(self, column: str, path: str, return_type: str | None = None)
         return_clause = f" returning {return_type}" if return_type else ""
         return f"json_value({quoted_col}, _utf8mb4'$.{path}'{return_clause})"
 
+    def translate_expression(self, expr: str) -> str:
+        """
+        Translate SQL expression for MySQL compatibility.
+
+        Converts PostgreSQL-specific functions to MySQL equivalents:
+        - STRING_AGG(col, 'sep') → GROUP_CONCAT(col SEPARATOR 'sep')
+        - STRING_AGG(col, ',') → GROUP_CONCAT(col)
+
+        Parameters
+        ----------
+        expr : str
+            SQL expression that may contain function calls.
+
+        Returns
+        -------
+        str
+            Translated expression for MySQL.
+        """
+        import re
+
+        # STRING_AGG(col, 'sep') → GROUP_CONCAT(col SEPARATOR 'sep')
+        def replace_string_agg(match):
+            inner = match.group(1).strip()
+            # Parse arguments: col, 'separator'
+            # Handle both single and double quoted separators
+            arg_match = re.match(r"(.+?)\s*,\s*(['\"])(.+?)\2", inner)
+            if arg_match:
+                col = arg_match.group(1).strip()
+                sep = arg_match.group(3)
+                # Remove ::text cast if present (PostgreSQL-specific)
+                col = re.sub(r"::text$", "", col)
+                if sep == ",":
+                    return f"GROUP_CONCAT({col})"
+                else:
+                    return f"GROUP_CONCAT({col} SEPARATOR '{sep}')"
+            else:
+                # No separator found, just use the expression
+                col = re.sub(r"::text$", "", inner)
+                return f"GROUP_CONCAT({col})"
+
+        expr = re.sub(r"STRING_AGG\s*\((.+?)\)", replace_string_agg, expr, flags=re.IGNORECASE)
+
+        return expr
+
     # =========================================================================
     # DDL Generation
     # =========================================================================
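A standalone sketch of the rewrite above, runnable without an adapter instance (the function name translate_for_mysql is hypothetical; the regex logic mirrors the diff):

import re

def translate_for_mysql(expr: str) -> str:
    # Mirror of the adapter logic: STRING_AGG(col, 'sep') → GROUP_CONCAT(...)
    def replace_string_agg(match):
        inner = match.group(1).strip()
        arg_match = re.match(r"(.+?)\s*,\s*(['\"])(.+?)\2", inner)
        if arg_match:
            col = re.sub(r"::text$", "", arg_match.group(1).strip())
            sep = arg_match.group(3)
            if sep == ",":
                return f"GROUP_CONCAT({col})"
            return f"GROUP_CONCAT({col} SEPARATOR '{sep}')"
        col = re.sub(r"::text$", "", inner)
        return f"GROUP_CONCAT({col})"
    return re.sub(r"STRING_AGG\s*\((.+?)\)", replace_string_agg, expr, flags=re.IGNORECASE)

assert translate_for_mysql("STRING_AGG(name::text, ',')") == "GROUP_CONCAT(name)"
assert translate_for_mysql("STRING_AGG(note, '; ')") == "GROUP_CONCAT(note SEPARATOR '; ')"

Note that the non-greedy (.+?)\) stops at the first closing parenthesis, so arguments containing nested parentheses (e.g. a coalesce() call) would be mis-split; the sketch shares that limitation with the adapter code.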

src/datajoint/adapters/postgres.py

Lines changed: 86 additions & 0 deletions
@@ -130,8 +130,38 @@ def connect(
         # DataJoint manages transactions explicitly via start_transaction()
         # Set autocommit=True to avoid implicit transactions
         conn.autocommit = True
+
+        # Register numpy type adapters so numpy types can be used directly in queries
+        self._register_numpy_adapters()
+
         return conn
 
+    def _register_numpy_adapters(self) -> None:
+        """
+        Register psycopg2 adapters for numpy types.
+
+        This allows numpy scalar types (bool_, int64, float64, etc.) to be used
+        directly in queries without explicit conversion to Python native types.
+        """
+        try:
+            import numpy as np
+            from psycopg2.extensions import register_adapter, AsIs
+
+            # Numpy bool type
+            register_adapter(np.bool_, lambda x: AsIs(str(bool(x)).upper()))
+
+            # Numpy integer types
+            for np_type in (np.int8, np.int16, np.int32, np.int64,
+                            np.uint8, np.uint16, np.uint32, np.uint64):
+                register_adapter(np_type, lambda x: AsIs(int(x)))
+
+            # Numpy float types
+            for np_type in (np.float16, np.float32, np.float64):
+                register_adapter(np_type, lambda x: AsIs(repr(float(x))))
+
+        except ImportError:
+            pass  # numpy not available
+
     def close(self, connection: Any) -> None:
         """Close the PostgreSQL connection."""
         connection.close()
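A quick way to see the effect of the registration, assuming psycopg2 and numpy are installed (standalone sketch, not part of the diff):

import numpy as np
from psycopg2.extensions import adapt, register_adapter, AsIs

register_adapter(np.bool_, lambda x: AsIs(str(bool(x)).upper()))

# Without the adapter, adapt() raises "can't adapt type 'numpy.bool_'";
# with it, the value renders as a SQL literal:
print(adapt(np.bool_(True)).getquoted())  # b'TRUE'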
@@ -853,6 +883,25 @@ def parse_column_info(self, row: dict[str, Any]) -> dict[str, Any]:
         data_type = row["data_type"]
         if data_type == "USER-DEFINED":
             data_type = row["udt_name"]
+
+        # Reconstruct parametrized types that PostgreSQL splits into separate fields
+        char_max_len = row.get("character_maximum_length")
+        num_precision = row.get("numeric_precision")
+        num_scale = row.get("numeric_scale")
+
+        if data_type == "character" and char_max_len is not None:
+            # char(n) - PostgreSQL reports as "character" with length in separate field
+            data_type = f"char({char_max_len})"
+        elif data_type == "character varying" and char_max_len is not None:
+            # varchar(n)
+            data_type = f"varchar({char_max_len})"
+        elif data_type == "numeric" and num_precision is not None:
+            # numeric(p,s) - reconstruct decimal type
+            if num_scale is not None and num_scale > 0:
+                data_type = f"decimal({num_precision},{num_scale})"
+            else:
+                data_type = f"decimal({num_precision})"
+
         return {
             "name": row["column_name"],
             "type": data_type,
@@ -959,6 +1008,43 @@ def json_path_expr(self, column: str, path: str, return_type: str | None = None)
         # Note: PostgreSQL jsonb_extract_path_text doesn't use return type parameter
         return f"jsonb_extract_path_text({quoted_col}, {path_args})"
 
+    def translate_expression(self, expr: str) -> str:
+        """
+        Translate SQL expression for PostgreSQL compatibility.
+
+        Converts MySQL-specific functions to PostgreSQL equivalents:
+        - GROUP_CONCAT(col) → STRING_AGG(col::text, ',')
+        - GROUP_CONCAT(col SEPARATOR 'sep') → STRING_AGG(col::text, 'sep')
+
+        Parameters
+        ----------
+        expr : str
+            SQL expression that may contain function calls.
+
+        Returns
+        -------
+        str
+            Translated expression for PostgreSQL.
+        """
+        import re
+
+        # GROUP_CONCAT(col) → STRING_AGG(col::text, ',')
+        # GROUP_CONCAT(col SEPARATOR 'sep') → STRING_AGG(col::text, 'sep')
+        def replace_group_concat(match):
+            inner = match.group(1).strip()
+            # Check for SEPARATOR clause
+            sep_match = re.match(r"(.+?)\s+SEPARATOR\s+(['\"])(.+?)\2", inner, re.IGNORECASE)
+            if sep_match:
+                col = sep_match.group(1).strip()
+                sep = sep_match.group(3)
+                return f"STRING_AGG({col}::text, '{sep}')"
+            else:
+                return f"STRING_AGG({inner}::text, ',')"
+
+        expr = re.sub(r"GROUP_CONCAT\s*\((.+?)\)", replace_group_concat, expr, flags=re.IGNORECASE)
+
+        return expr
+
     # =========================================================================
     # DDL Generation
     # =========================================================================
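The mirror-image sketch for the PostgreSQL direction (translate_for_postgres is a hypothetical standalone name; the ::text cast is added because PostgreSQL's string_agg requires a text argument for non-text columns):

import re

def translate_for_postgres(expr: str) -> str:
    def replace_group_concat(match):
        inner = match.group(1).strip()
        sep_match = re.match(r"(.+?)\s+SEPARATOR\s+(['\"])(.+?)\2", inner, re.IGNORECASE)
        if sep_match:
            col = sep_match.group(1).strip()
            sep = sep_match.group(3)
            return f"STRING_AGG({col}::text, '{sep}')"
        return f"STRING_AGG({inner}::text, ',')"
    return re.sub(r"GROUP_CONCAT\s*\((.+?)\)", replace_group_concat, expr, flags=re.IGNORECASE)

assert translate_for_postgres("GROUP_CONCAT(name)") == "STRING_AGG(name::text, ',')"
assert translate_for_postgres("GROUP_CONCAT(note SEPARATOR '; ')") == "STRING_AGG(note::text, '; ')"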

src/datajoint/dependencies.py

Lines changed: 23 additions & 17 deletions
@@ -221,25 +221,31 @@ def load(self, force: bool = True) -> None:
         for key in keys:
             pks[key[0]].add(key[1])
 
-        # load foreign keys (PostgreSQL requires joining multiple tables)
-        ref_tab_expr = "'\"' || ccu.table_schema || '\".\"' || ccu.table_name || '\"'"
+        # load foreign keys using pg_constraint system catalogs
+        # The information_schema approach creates a Cartesian product for composite FKs
+        # because constraint_column_usage doesn't have ordinal_position.
+        # Using pg_constraint with unnest(conkey, confkey) WITH ORDINALITY gives correct mapping.
         fk_keys = self._conn.query(
             f"""
-            SELECT kcu.constraint_name,
-                   {tab_expr} as referencing_table,
-                   {ref_tab_expr} as referenced_table,
-                   kcu.column_name, ccu.column_name as referenced_column_name
-            FROM information_schema.key_column_usage kcu
-            JOIN information_schema.referential_constraints rc
-                ON kcu.constraint_name = rc.constraint_name
-                AND kcu.constraint_schema = rc.constraint_schema
-            JOIN information_schema.constraint_column_usage ccu
-                ON rc.unique_constraint_name = ccu.constraint_name
-                AND rc.unique_constraint_schema = ccu.constraint_schema
-            WHERE kcu.table_name NOT LIKE {like_pattern}
-                AND (ccu.table_schema in ({schemas_list})
-                     OR kcu.table_schema in ({schemas_list}))
-            ORDER BY kcu.constraint_name, kcu.ordinal_position
+            SELECT
+                c.conname as constraint_name,
+                '"' || ns1.nspname || '"."' || cl1.relname || '"' as referencing_table,
+                '"' || ns2.nspname || '"."' || cl2.relname || '"' as referenced_table,
+                a1.attname as column_name,
+                a2.attname as referenced_column_name
+            FROM pg_constraint c
+            JOIN pg_class cl1 ON c.conrelid = cl1.oid
+            JOIN pg_namespace ns1 ON cl1.relnamespace = ns1.oid
+            JOIN pg_class cl2 ON c.confrelid = cl2.oid
+            JOIN pg_namespace ns2 ON cl2.relnamespace = ns2.oid
+            CROSS JOIN LATERAL unnest(c.conkey, c.confkey) WITH ORDINALITY AS cols(conkey, confkey, ord)
+            JOIN pg_attribute a1 ON a1.attrelid = cl1.oid AND a1.attnum = cols.conkey
+            JOIN pg_attribute a2 ON a2.attrelid = cl2.oid AND a2.attnum = cols.confkey
+            WHERE c.contype = 'f'
+                AND cl1.relname NOT LIKE {like_pattern}
+                AND (ns2.nspname in ({schemas_list})
+                     OR ns1.nspname in ({schemas_list}))
+            ORDER BY c.conname, cols.ord
             """,
             as_dict=True,
        )
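The core of the fix: pg_constraint stores a composite FK as two parallel arrays of column numbers (conkey, confkey) that must be paired positionally, which is exactly what unnest(...) WITH ORDINALITY does. Joining information_schema.constraint_column_usage (which has no ordinal position) instead paired every referencing column with every referenced column. The difference, sketched in Python with hypothetical column numbers:

conkey = [2, 3]    # referencing column numbers, in FK order
confkey = [1, 2]   # referenced column numbers, parallel to conkey

pairs_correct = list(zip(conkey, confkey))                    # unnest WITH ORDINALITY
pairs_cartesian = [(a, b) for a in conkey for b in confkey]   # the old join

assert pairs_correct == [(2, 1), (3, 2)]
assert len(pairs_cartesian) == 4  # duplicate mappings → "Attribute already exists"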

src/datajoint/expression.py

Lines changed: 44 additions & 17 deletions
@@ -1019,27 +1019,54 @@ def where_clause(self):
         return "" if not self._left_restrict else " WHERE (%s)" % ")AND(".join(str(s) for s in self._left_restrict)
 
     def make_sql(self, fields=None):
-        fields = self.heading.as_sql(fields or self.heading.names, adapter=self.connection.adapter)
+        adapter = self.connection.adapter
+        fields = self.heading.as_sql(fields or self.heading.names, adapter=adapter)
         assert self._grouping_attributes or not self.restriction
         distinct = set(self.heading.names) == set(self.primary_key)
-        return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}{sorting}".format(
-            distinct="DISTINCT " if distinct else "",
-            fields=fields,
-            from_=self.from_clause(),
-            where=self.where_clause(),
-            group_by=(
-                ""
-                if not self.primary_key
-                else (
-                    " GROUP BY {}".format(
-                        ", ".join(self.connection.adapter.quote_identifier(col) for col in self._grouping_attributes)
-                    )
-                    + ("" if not self.restriction else " HAVING (%s)" % ")AND(".join(self.restriction))
-                )
-            ),
-            sorting=self.sorting_clauses(),
+
+        # PostgreSQL doesn't allow column aliases in HAVING clause (SQL standard).
+        # For PostgreSQL with restrictions, wrap aggregation in subquery and use WHERE.
+        use_subquery_for_having = (
+            adapter.backend == "postgresql"
+            and self.restriction
+            and self._grouping_attributes
         )
 
+        if use_subquery_for_having:
+            # Generate inner query without HAVING
+            inner_sql = "SELECT {distinct}{fields} FROM {from_}{where}{group_by}".format(
+                distinct="DISTINCT " if distinct else "",
+                fields=fields,
+                from_=self.from_clause(),
+                where=self.where_clause(),
+                group_by=" GROUP BY {}".format(
+                    ", ".join(adapter.quote_identifier(col) for col in self._grouping_attributes)
+                ),
+            )
+            # Wrap in subquery with WHERE for the HAVING conditions
+            subquery_alias = adapter.quote_identifier(f"_aggr{next(self._subquery_alias_count)}")
+            outer_where = " WHERE (%s)" % ")AND(".join(self.restriction)
+            return f"SELECT * FROM ({inner_sql}) AS {subquery_alias}{outer_where}{self.sorting_clauses()}"
+        else:
+            # MySQL path: use HAVING directly
+            return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}{sorting}".format(
+                distinct="DISTINCT " if distinct else "",
+                fields=fields,
+                from_=self.from_clause(),
+                where=self.where_clause(),
+                group_by=(
+                    ""
+                    if not self.primary_key
+                    else (
+                        " GROUP BY {}".format(
+                            ", ".join(adapter.quote_identifier(col) for col in self._grouping_attributes)
+                        )
+                        + ("" if not self.restriction else " HAVING (%s)" % ")AND(".join(self.restriction))
+                    )
+                ),
+                sorting=self.sorting_clauses(),
+            )
+
     def __len__(self):
         alias = self.connection.adapter.quote_identifier(f"${next(self._subquery_alias_count):x}")
         return self.connection.query(f"SELECT count(1) FROM ({self.make_sql()}) {alias}").fetchone()[0]
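Illustrative SQL for the two paths, with hypothetical table/column names and alias index (MySQL accepts a select-list alias in HAVING as an extension; PostgreSQL follows the standard and rejects it, hence the subquery):

# MySQL path: HAVING may reference the select-list alias directly
mysql_sql = 'SELECT subject_id, COUNT(*) as n FROM session GROUP BY subject_id HAVING (n > 2)'

# PostgreSQL path: the aggregation becomes a subquery whose alias n is
# visible to an ordinary outer WHERE
postgres_sql = (
    'SELECT * FROM ('
    'SELECT subject_id, COUNT(*) as n FROM session GROUP BY subject_id'
    ') AS "_aggr0" WHERE (n > 2)'
)

Keeping the plain HAVING path on MySQL avoids the extra subquery layer, which is the efficiency note in the commit message.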

src/datajoint/heading.py

Lines changed: 14 additions & 8 deletions
@@ -349,14 +349,20 @@ def quote(name):
             # Use adapter if available, otherwise use ANSI SQL double quotes (not backticks)
             return adapter.quote_identifier(name) if adapter else f'"{name}"'
 
-        return ",".join(
-            (
-                quote(name)
-                if self.attributes[name].attribute_expression is None
-                else self.attributes[name].attribute_expression + (f" as {quote(name)}" if include_aliases else "")
-            )
-            for name in fields
-        )
+        def render_field(name):
+            attr = self.attributes[name]
+            if attr.attribute_expression is None:
+                return quote(name)
+            else:
+                # Translate expression for backend compatibility (e.g., GROUP_CONCAT ↔ STRING_AGG)
+                expr = attr.attribute_expression
+                if adapter:
+                    expr = adapter.translate_expression(expr)
+                if include_aliases:
+                    return f"{expr} as {quote(name)}"
+                return expr
+
+        return ",".join(render_field(name) for name in fields)
 
     def __iter__(self):
         return iter(self.attributes)
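A minimal sketch of the new rendering rule in isolation (all names hypothetical; the real method pulls attribute_expression, adapter, and quote from the heading and connection):

def render(expr, name, translate, quote, include_aliases=True):
    expr = translate(expr)  # backend rewrite happens before aliasing
    return f"{expr} as {quote(name)}" if include_aliases else expr

rendered = render(
    "GROUP_CONCAT(note)",
    "notes",
    translate=lambda e: "STRING_AGG(note::text, ',')",  # stand-in for adapter.translate_expression
    quote=lambda n: f'"{n}"',
)
assert rendered == "STRING_AGG(note::text, ',') as \"notes\""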

src/datajoint/user_tables.py

Lines changed: 3 additions & 2 deletions
@@ -182,10 +182,11 @@ def table_name(cls):
 
     @property
     def full_table_name(cls):
-        """The fully qualified table name (`database`.`table`)."""
+        """The fully qualified table name (quoted per backend)."""
         if cls.database is None or cls.table_name is None:
             return None
-        return r"`{0:s}`.`{1:s}`".format(cls.database, cls.table_name)
+        adapter = cls._connection.adapter
+        return f"{adapter.quote_identifier(cls.database)}.{adapter.quote_identifier(cls.table_name)}"
 
     @property
     def master(cls):
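The practical effect, sketched with hypothetical schema and table names: the same Part table now renders with the active backend's quoting instead of unconditional MySQL backticks:

def full_table_name(database, table_name, quote_identifier):
    return f"{quote_identifier(database)}.{quote_identifier(table_name)}"

mysql_quote = lambda n: f"`{n}`"
postgres_quote = lambda n: f'"{n}"'

assert full_table_name("lab", "session__trial", mysql_quote) == "`lab`.`session__trial`"
assert full_table_name("lab", "session__trial", postgres_quote) == '"lab"."session__trial"'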
