# Patch summary (reviewer notes; this text precedes the first "diff --git"
# header and is not part of any hunk).
#
# duckdb_engine/__init__.py -- connect():
#   For each name in preload_extensions the new code first runs
#   "SELECT loaded FROM duckdb_extensions() WHERE extension_name = ?" and
#   skips the extension when the first column of the returned row is truthy.
#   Otherwise it executes LOAD; an exception whose message contains
#   "already exists" or "already registered" is swallowed, and any other
#   exception is re-raised.
#   NOTE(review): the extension name is still interpolated into the LOAD
#   statement with an f-string, exactly as in the removed line.
#
# duckdb_engine/tests/test_basic.py:
#   * imports MonkeyPatch from _pytest.monkeypatch;
#   * adds _Row, a stub whose fetchone() returns the value given to __init__;
#   * adds test_skip_load_extension_if_already_loaded, which replaces
#     duckdb.connect with a factory that wraps the real connection in a Proxy.
#     The Proxy answers the duckdb_extensions probe from a per-connection
#     flag, lets the first LOAD succeed and raises
#     "extension already registered" on every later LOAD; all other
#     statements are forwarded to the real connection.
#   NOTE(review): on the second engine the probe reports loaded=False, so the
#   test appears to exercise the swallowed-exception branch rather than the
#   `continue` branch its name refers to (assuming each engine opens exactly
#   one connection) -- confirm this is the intended coverage.
#
# NOTE(review): line breaks, indentation and blank context lines below were
# reconstructed from the hunk line counts and Python syntax -- confirm against
# the target files before applying.
diff --git a/duckdb_engine/__init__.py b/duckdb_engine/__init__.py
index e6b1680e2..2649d9c80 100644
--- a/duckdb_engine/__init__.py
+++ b/duckdb_engine/__init__.py
@@ -301,7 +301,22 @@ def connect(self, *cargs: Any, **cparams: Any) -> "Connection":
         conn = duckdb.connect(*cargs, **cparams)
 
         for extension in preload_extensions:
-            conn.execute(f"LOAD {extension}")
+            # skip if already loaded in this connection
+            row = conn.execute(
+                "SELECT loaded FROM duckdb_extensions() WHERE extension_name = ?",
+                [extension],
+            ).fetchone()
+            if row and row[0]:  # True == already loaded
+                continue
+
+            try:
+                conn.execute(f"LOAD {extension}")
+            except Exception as e:
+                # tolerate idempotent re-load race if it ever happens
+                if "already exists" in str(e) or "already registered" in str(e):
+                    pass
+                else:
+                    raise
 
         for filesystem in filesystems:
             conn.register_filesystem(filesystem)
diff --git a/duckdb_engine/tests/test_basic.py b/duckdb_engine/tests/test_basic.py
index 15fda1fea..d2d6dfb9e 100644
--- a/duckdb_engine/tests/test_basic.py
+++ b/duckdb_engine/tests/test_basic.py
@@ -9,6 +9,7 @@
 import duckdb
 import fsspec
 import sqlalchemy
+from _pytest.monkeypatch import MonkeyPatch
 from hypothesis import assume, given, settings
 from hypothesis.strategies import text as text_strat
 from packaging.version import Version
@@ -112,6 +113,14 @@ class IntervalModel(Base):
     field = Column(Interval)
 
 
+class _Row:
+    def __init__(self, val: Any) -> None:
+        self._val = val
+
+    def fetchone(self) -> Any:
+        return self._val
+
+
 @fixture
 def session(engine: Engine) -> Session:
     return sessionmaker(bind=engine)()
@@ -706,3 +715,75 @@ def test_register_filesystem() -> None:
     with engine.connect() as conn:
         duckdb_conn = getattr(conn.connection.dbapi_connection, "_ConnectionWrapper__c")
         assert duckdb.list_filesystems(connection=duckdb_conn) == ["memory", "file"]
+
+
+def test_skip_load_extension_if_already_loaded(monkeypatch: MonkeyPatch) -> None:
+    """
+    First engine LOADs successfully.
+    Second engine attempts LOAD again; DB says 'already registered'.
+    New connect() swallows that benign error; old connect() bubbles it.
+    """
+    real_connect = duckdb.connect
+    globally_registered: bool = False  # simulate process-global registration
+    first_loads: int = 0
+    second_load_attempts: int = 0
+
+    def fake_connect(*args: Any, **kwargs: Any) -> Any:
+        inner = real_connect(*args, **kwargs)
+        loaded_here: bool = False  # per-connection view
+
+        class Proxy:
+            def __init__(self, inner_conn: Any) -> None:
+                self._inner = inner_conn
+
+            def execute(
+                self, query: Any, params: Optional[Sequence[Any]] = None
+            ) -> Any:
+                nonlocal \
+                    globally_registered, \
+                    loaded_here, \
+                    first_loads, \
+                    second_load_attempts
+                q = str(query).strip().lower()
+
+                # Dialect probes per connection
+                if q.startswith("select loaded from duckdb_extensions"):
+                    return _Row((loaded_here,))
+
+                # LOAD path
+                if q.startswith("load "):
+                    if not globally_registered:
+                        globally_registered = True
+                        loaded_here = True
+                        first_loads += 1
+                        return _Row(None)
+                    # simulate tolerated idempotent error
+                    second_load_attempts += 1
+                    raise Exception("extension already registered")
+
+                return self._inner.execute(query, params)
+
+            def register_filesystem(self, fs: Any) -> Any:
+                return self._inner.register_filesystem(fs)
+
+            def __getattr__(self, name: str) -> Any:
+                return getattr(self._inner, name)
+
+        return Proxy(inner)
+
+    monkeypatch.setattr(duckdb, "connect", fake_connect)
+
+    eng1 = create_engine(
+        "duckdb:///:memory:", connect_args={"preload_extensions": ["httpfs"]}
+    )
+    with eng1.connect():
+        pass
+
+    eng2 = create_engine(
+        "duckdb:///:memory:", connect_args={"preload_extensions": ["httpfs"]}
+    )
+    with eng2.connect():
+        pass
+
+    assert first_loads == 1
+    assert second_load_attempts == 1