Skip to content

Commit d8b15c5

Browse files
fix: PostgreSQL adapter bugs for multi-backend support

Fixes multiple issues discovered during notebook execution against PostgreSQL:

1. Column comment retrieval (blob codec association)
   - PostgreSQL stores comments separately via COMMENT ON COLUMN
   - Updated get_columns_sql to use col_description() function
   - Updated parse_column_info to extract column_comment
   - This enables proper codec association (e.g., <blob>) from comments

2. ENUM type DDL generation
   - PostgreSQL requires CREATE TYPE for enums (not inline like MySQL)
   - Added enum type name generation based on value hash
   - Added get_pending_enum_ddl() method for pre-CREATE TABLE DDL
   - Updated declare() to return pre_ddl and post_ddl separately

3. Upsert/skip_duplicates syntax
   - MySQL: ON DUPLICATE KEY UPDATE pk=table.pk
   - PostgreSQL: ON CONFLICT (pk_cols) DO NOTHING
   - Added skip_duplicates_clause() method to both adapters
   - Updated table.py to use adapter method

4. String quoting in information_schema queries
   - Dependencies.py had hardcoded MySQL double quotes and concat()
   - Made queries backend-agnostic using adapter.quote_string()
   - Added backend property to adapters for conditional SQL generation

5. Index DDL syntax
   - MySQL: inline INDEX in CREATE TABLE
   - PostgreSQL: separate CREATE INDEX statements
   - Added supports_inline_indexes property
   - Added create_index_ddl() method to base adapter
   - Updated declare() to generate CREATE INDEX for PostgreSQL post_ddl

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 4797151 commit d8b15c5

File tree

6 files changed

+302
-38
lines changed

6 files changed

+302
-38
lines changed

src/datajoint/adapters/base.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,19 @@ def default_port(self) -> int:
114114
"""
115115
...
116116

117+
@property
118+
@abstractmethod
119+
def backend(self) -> str:
120+
"""
121+
Backend identifier string.
122+
123+
Returns
124+
-------
125+
str
126+
Backend name: 'mysql' or 'postgresql'.
127+
"""
128+
...
129+
117130
@abstractmethod
118131
def get_cursor(self, connection: Any, as_dict: bool = False) -> Any:
119132
"""
@@ -491,6 +504,83 @@ def upsert_on_duplicate_sql(
491504
"""
492505
...
493506

507+
@abstractmethod
508+
def skip_duplicates_clause(
509+
self,
510+
full_table_name: str,
511+
primary_key: list[str],
512+
) -> str:
513+
"""
514+
Generate clause to skip duplicate key insertions.
515+
516+
For MySQL: ON DUPLICATE KEY UPDATE pk=table.pk (no-op update)
517+
For PostgreSQL: ON CONFLICT (pk_cols) DO NOTHING
518+
519+
Parameters
520+
----------
521+
full_table_name : str
522+
Fully qualified table name (with quotes).
523+
primary_key : list[str]
524+
Primary key column names (unquoted).
525+
526+
Returns
527+
-------
528+
str
529+
SQL clause to append to INSERT statement.
530+
"""
531+
...
532+
533+
@property
534+
def supports_inline_indexes(self) -> bool:
535+
"""
536+
Whether this backend supports inline INDEX in CREATE TABLE.
537+
538+
MySQL supports inline index definitions in CREATE TABLE.
539+
PostgreSQL requires separate CREATE INDEX statements.
540+
541+
Returns
542+
-------
543+
bool
544+
True for MySQL, False for PostgreSQL.
545+
"""
546+
return True # Default for MySQL, override in PostgreSQL
547+
548+
def create_index_ddl(
549+
self,
550+
full_table_name: str,
551+
columns: list[str],
552+
unique: bool = False,
553+
index_name: str | None = None,
554+
) -> str:
555+
"""
556+
Generate CREATE INDEX statement.
557+
558+
Parameters
559+
----------
560+
full_table_name : str
561+
Fully qualified table name (with quotes).
562+
columns : list[str]
563+
Column names to index (unquoted).
564+
unique : bool, optional
565+
If True, create a unique index.
566+
index_name : str, optional
567+
Custom index name. If None, auto-generate from table/columns.
568+
569+
Returns
570+
-------
571+
str
572+
CREATE INDEX SQL statement.
573+
"""
574+
quoted_cols = ", ".join(self.quote_identifier(col) for col in columns)
575+
# Generate index name from table and columns if not provided
576+
if index_name is None:
577+
# Extract table name from full_table_name for index naming
578+
table_part = full_table_name.split(".")[-1].strip('`"')
579+
col_part = "_".join(columns)[:30] # Truncate for long column lists
580+
index_name = f"idx_{table_part}_{col_part}"
581+
unique_clause = "UNIQUE " if unique else ""
582+
return f"CREATE {unique_clause}INDEX {self.quote_identifier(index_name)} ON {full_table_name} ({quoted_cols})"
583+
494584
# =========================================================================
495585
# Introspection
496586
# =========================================================================

src/datajoint/adapters/mysql.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,11 @@ def default_port(self) -> int:
155155
"""MySQL default port 3306."""
156156
return 3306
157157

158+
@property
159+
def backend(self) -> str:
160+
"""Backend identifier: 'mysql'."""
161+
return "mysql"
162+
158163
def get_cursor(self, connection: Any, as_dict: bool = False) -> Any:
159164
"""
160165
Get a cursor from MySQL connection.
@@ -567,6 +572,32 @@ def upsert_on_duplicate_sql(
567572
ON DUPLICATE KEY UPDATE {update_clauses}
568573
"""
569574

575+
def skip_duplicates_clause(
576+
self,
577+
full_table_name: str,
578+
primary_key: list[str],
579+
) -> str:
580+
"""
581+
Generate clause to skip duplicate key insertions for MySQL.
582+
583+
Uses ON DUPLICATE KEY UPDATE with a no-op update (pk=pk) to effectively
584+
skip duplicates without raising an error.
585+
586+
Parameters
587+
----------
588+
full_table_name : str
589+
Fully qualified table name (with quotes).
590+
primary_key : list[str]
591+
Primary key column names (unquoted).
592+
593+
Returns
594+
-------
595+
str
596+
MySQL ON DUPLICATE KEY UPDATE clause.
597+
"""
598+
quoted_pk = self.quote_identifier(primary_key[0])
599+
return f" ON DUPLICATE KEY UPDATE {quoted_pk}={full_table_name}.{quoted_pk}"
600+
570601
# =========================================================================
571602
# Introspection
572603
# =========================================================================

src/datajoint/adapters/postgres.py

Lines changed: 104 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from __future__ import annotations
99

10+
import re
1011
from typing import Any
1112

1213
try:
@@ -170,6 +171,11 @@ def default_port(self) -> int:
170171
"""PostgreSQL default port 5432."""
171172
return 5432
172173

174+
@property
175+
def backend(self) -> str:
176+
"""Backend identifier: 'postgresql'."""
177+
return "postgresql"
178+
173179
def get_cursor(self, connection: Any, as_dict: bool = False) -> Any:
174180
"""
175181
Get a cursor from PostgreSQL connection.
@@ -292,9 +298,27 @@ def core_type_to_sql(self, core_type: str) -> str:
292298
return f"numeric{params}"
293299

294300
if core_type.startswith("enum("):
295-
# Enum requires special handling - caller must use CREATE TYPE
296-
# Return the type name pattern (will be replaced by caller)
297-
return "{{enum_type_name}}" # Placeholder for CREATE TYPE
301+
# PostgreSQL requires CREATE TYPE for enums
302+
# Extract enum values and generate a deterministic type name
303+
enum_match = re.match(r"enum\s*\((.+)\)", core_type, re.I)
304+
if enum_match:
305+
# Parse enum values: enum('M','F') -> ['M', 'F']
306+
values_str = enum_match.group(1)
307+
# Split by comma, handling quoted values
308+
values = [v.strip().strip("'\"") for v in values_str.split(",")]
309+
# Generate a deterministic type name based on values
310+
# Use a hash to keep name reasonable length
311+
import hashlib
312+
value_hash = hashlib.md5("_".join(sorted(values)).encode()).hexdigest()[:8]
313+
type_name = f"enum_{value_hash}"
314+
# Track this enum type for CREATE TYPE DDL
315+
if not hasattr(self, "_pending_enum_types"):
316+
self._pending_enum_types = {}
317+
self._pending_enum_types[type_name] = values
318+
# Return schema-qualified type reference using placeholder
319+
# {database} will be replaced with actual schema name in table.py
320+
return '"{database}".' + self.quote_identifier(type_name)
321+
return "text" # Fallback if parsing fails
298322

299323
raise ValueError(f"Unknown core type: {core_type}")
300324

@@ -611,6 +635,43 @@ def upsert_on_duplicate_sql(
611635
ON CONFLICT ({conflict_cols}) DO UPDATE SET {update_clauses}
612636
"""
613637

638+
def skip_duplicates_clause(
639+
self,
640+
full_table_name: str,
641+
primary_key: list[str],
642+
) -> str:
643+
"""
644+
Generate clause to skip duplicate key insertions for PostgreSQL.
645+
646+
Uses ON CONFLICT (pk_cols) DO NOTHING to skip duplicates without
647+
raising an error.
648+
649+
Parameters
650+
----------
651+
full_table_name : str
652+
Fully qualified table name (with quotes). Unused but kept for
653+
API compatibility with MySQL adapter.
654+
primary_key : list[str]
655+
Primary key column names (unquoted).
656+
657+
Returns
658+
-------
659+
str
660+
PostgreSQL ON CONFLICT DO NOTHING clause.
661+
"""
662+
pk_cols = ", ".join(self.quote_identifier(pk) for pk in primary_key)
663+
return f" ON CONFLICT ({pk_cols}) DO NOTHING"
664+
665+
@property
666+
def supports_inline_indexes(self) -> bool:
667+
"""
668+
PostgreSQL does not support inline INDEX in CREATE TABLE.
669+
670+
Returns False to indicate indexes must be created separately
671+
with CREATE INDEX statements.
672+
"""
673+
return False
674+
614675
# =========================================================================
615676
# Introspection
616677
# =========================================================================
@@ -639,14 +700,17 @@ def get_table_info_sql(self, schema_name: str, table_name: str) -> str:
639700
)
640701

641702
def get_columns_sql(self, schema_name: str, table_name: str) -> str:
642-
"""Query to get column definitions."""
703+
"""Query to get column definitions including comments."""
704+
# Use col_description() to retrieve column comments stored via COMMENT ON COLUMN
705+
# The regclass cast allows using schema.table notation to get the OID
643706
return (
644-
f"SELECT column_name, data_type, is_nullable, column_default, "
645-
f"character_maximum_length, numeric_precision, numeric_scale "
646-
f"FROM information_schema.columns "
647-
f"WHERE table_schema = {self.quote_string(schema_name)} "
648-
f"AND table_name = {self.quote_string(table_name)} "
649-
f"ORDER BY ordinal_position"
707+
f"SELECT c.column_name, c.data_type, c.is_nullable, c.column_default, "
708+
f"c.character_maximum_length, c.numeric_precision, c.numeric_scale, "
709+
f"col_description(({self.quote_string(schema_name)} || '.' || {self.quote_string(table_name)})::regclass, c.ordinal_position) as column_comment "
710+
f"FROM information_schema.columns c "
711+
f"WHERE c.table_schema = {self.quote_string(schema_name)} "
712+
f"AND c.table_name = {self.quote_string(table_name)} "
713+
f"ORDER BY c.ordinal_position"
650714
)
651715

652716
def get_primary_key_sql(self, schema_name: str, table_name: str) -> str:
@@ -761,7 +825,7 @@ def parse_column_info(self, row: dict[str, Any]) -> dict[str, Any]:
761825
Parameters
762826
----------
763827
row : dict
764-
Row from information_schema.columns query.
828+
Row from information_schema.columns query with col_description() join.
765829
766830
Returns
767831
-------
@@ -774,7 +838,7 @@ def parse_column_info(self, row: dict[str, Any]) -> dict[str, Any]:
774838
"type": row["data_type"],
775839
"nullable": row["is_nullable"] == "YES",
776840
"default": row["column_default"],
777-
"comment": None, # PostgreSQL stores comments separately
841+
"comment": row.get("column_comment"), # Retrieved via col_description()
778842
"key": "", # PostgreSQL key info retrieved separately
779843
"extra": "", # PostgreSQL doesn't have auto_increment in same way
780844
}
@@ -947,6 +1011,34 @@ def enum_type_ddl(self, type_name: str, values: list[str]) -> str | None:
9471011
quoted_values = ", ".join(f"'{v}'" for v in values)
9481012
return f"CREATE TYPE {self.quote_identifier(type_name)} AS ENUM ({quoted_values})"
9491013

1014+
def get_pending_enum_ddl(self, schema_name: str) -> list[str]:
1015+
"""
1016+
Get DDL statements for pending enum types and clear the pending list.
1017+
1018+
PostgreSQL requires CREATE TYPE statements before using enum types in
1019+
column definitions. This method returns DDL for enum types accumulated
1020+
during type conversion and clears the pending list.
1021+
1022+
Parameters
1023+
----------
1024+
schema_name : str
1025+
Schema name to qualify enum type names.
1026+
1027+
Returns
1028+
-------
1029+
list[str]
1030+
List of CREATE TYPE statements (if any pending).
1031+
"""
1032+
ddl_statements = []
1033+
if hasattr(self, "_pending_enum_types") and self._pending_enum_types:
1034+
for type_name, values in self._pending_enum_types.items():
1035+
# Generate CREATE TYPE with schema qualification
1036+
quoted_type = f"{self.quote_identifier(schema_name)}.{self.quote_identifier(type_name)}"
1037+
quoted_values = ", ".join(f"'{v}'" for v in values)
1038+
ddl_statements.append(f"CREATE TYPE {quoted_type} AS ENUM ({quoted_values})")
1039+
self._pending_enum_types = {}
1040+
return ddl_statements
1041+
9501042
def job_metadata_columns(self) -> list[str]:
9511043
"""
9521044
Return PostgreSQL-specific job metadata column definitions.

0 commit comments

Comments (0)