diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py index 9148de7..3c02d46 100644 --- a/src/databricks/sqlalchemy/base.py +++ b/src/databricks/sqlalchemy/base.py @@ -336,6 +336,40 @@ def do_rollback(self, dbapi_connection): # Databricks SQL Does not support transactions pass + def is_disconnect(self, e, connection, cursor): + """Determine if an exception indicates the connection was lost. + + This method is called by SQLAlchemy after exceptions occur during query + execution to determine if the error was due to a lost connection. If this + returns True, SQLAlchemy will invalidate the connection and create a new + one for the next operation. + + This method is also used by SQLAlchemy's default do_ping() implementation + when pool_pre_ping=True. If do_ping() encounters an exception, it calls + is_disconnect() to classify the error and determine whether to invalidate + the connection. + + Args: + e: The exception that was raised + connection: The connection that raised the exception (may be None) + cursor: The cursor that raised the exception (may be None) + + Returns: + True if the error indicates a disconnect, False otherwise + """ + from databricks.sql.exc import InterfaceError, DatabaseError + + # InterfaceError: Client-side errors (e.g., connection already closed) + if isinstance(e, InterfaceError): + return True + + # DatabaseError: Server-side errors with invalid handle indicate session expired + if isinstance(e, DatabaseError): + error_msg = str(e).lower() + return "invalid" in error_msg and "handle" in error_msg + + return False + @reflection.cache def has_table( self, connection, table_name, schema=None, catalog=None, **kwargs diff --git a/tests/test_local/e2e/test_basic.py b/tests/test_local/e2e/test_basic.py index ce0b5d8..48b46d9 100644 --- a/tests/test_local/e2e/test_basic.py +++ b/tests/test_local/e2e/test_basic.py @@ -541,3 +541,127 @@ def test_table_comment_reflection(self, inspector: Inspector, table: Table): def test_column_comment(self, inspector: Inspector, table: Table): result = inspector.get_columns(table.name)[0].get("comment") assert result == "column comment" + + +def test_pool_pre_ping_with_closed_connection(connection_details): + """Test that pool_pre_ping detects closed connections and creates new ones. + + This test verifies that when a connection is manually closed (simulating + session expiration), pool_pre_ping detects it and automatically creates + a new connection without raising an error to the user. + """ + conn_string, connect_args = version_agnostic_connect_arguments(connection_details) + + # Create engine with pool_pre_ping enabled + engine = create_engine( + conn_string, + connect_args=connect_args, + pool_pre_ping=True, + pool_size=1, + max_overflow=0 + ) + + # Step 1: Create connection and get session ID + with engine.connect() as conn: + result = conn.execute(text("SELECT VERSION()")).scalar() + assert result is not None + + # Get session ID of first connection + raw_conn = conn.connection.dbapi_connection + session_id_1 = raw_conn.get_session_id_hex() + assert session_id_1 is not None + + # Step 2: Manually close the connection to simulate expiration + pooled_conn = engine.pool._pool.queue[0] + pooled_conn.driver_connection.close() + + # Verify connection is closed + assert not pooled_conn.driver_connection.open + + # Step 3: Try to use the closed connection - pool_pre_ping should detect and recycle + with engine.connect() as conn: + result = conn.execute(text("SELECT VERSION()")).scalar() + assert result is not None + + # Get session ID of new connection + raw_conn = conn.connection.dbapi_connection + session_id_2 = raw_conn.get_session_id_hex() + assert session_id_2 is not None + + # Verify a NEW connection was created (different session ID) + assert session_id_1 != session_id_2, ( + "pool_pre_ping should have detected the closed connection " + "and created a new one with a different session ID" + ) + + # Cleanup + engine.dispose() + + +def test_is_disconnect_handles_runtime_errors(db_engine): + """Test that is_disconnect() properly classifies disconnect errors during query execution. + + When a connection fails DURING a query, is_disconnect() should recognize the error + and tell SQLAlchemy to invalidate the connection so the next query gets a fresh one. + """ + from sqlalchemy import create_engine, text + from sqlalchemy.exc import DBAPIError + + engine = create_engine( + db_engine.url, + pool_pre_ping=False, # Disabled - we want to test is_disconnect, not do_ping + pool_size=1, + max_overflow=0, + ) + + # Step 1: Execute a successful query + with engine.connect() as conn: + result = conn.execute(text("SELECT VERSION()")).scalar() + assert result is not None + + # Get session ID of working connection + raw_conn = conn.connection.dbapi_connection + session_id_1 = raw_conn.get_session_id_hex() + assert session_id_1 is not None + + # Step 2: Manually close the connection to simulate server-side session expiration + pooled_conn = engine.pool._pool.queue[0] + pooled_conn.driver_connection.close() + + # Step 3: Try to execute query on closed connection + # This should: + # 1. Fail with an exception + # 2. is_disconnect() gets called by SQLAlchemy + # 3. Returns True (recognizes it as disconnect error) + # 4. SQLAlchemy invalidates the connection + # 5. Next operation gets a fresh connection + + # First query will fail because connection is closed + try: + with engine.connect() as conn: + conn.execute(text("SELECT VERSION()")).scalar() + # If we get here without exception, the connection wasn't actually closed + pytest.skip("Connection wasn't properly closed - cannot test is_disconnect") + except DBAPIError as e: + # Expected - connection was closed + # is_disconnect() should have been called and returned True + # This causes SQLAlchemy to invalidate the connection + assert "closed" in str(e).lower() or "invalid" in str(e).lower() + + # Step 4: Next query should work because is_disconnect() invalidated the bad connection + with engine.connect() as conn: + result = conn.execute(text("SELECT VERSION()")).scalar() + assert result is not None + + # Verify we got a NEW connection + raw_conn = conn.connection.dbapi_connection + session_id_2 = raw_conn.get_session_id_hex() + assert session_id_2 is not None + + # Different session ID proves connection was invalidated and recreated + assert session_id_1 != session_id_2, ( + "is_disconnect() should have invalidated the bad connection, " + "causing SQLAlchemy to create a new one with different session ID" + ) + + engine.dispose() diff --git a/tests/test_local/test_is_disconnect.py b/tests/test_local/test_is_disconnect.py new file mode 100644 index 0000000..e516745 --- /dev/null +++ b/tests/test_local/test_is_disconnect.py @@ -0,0 +1,45 @@ +"""Tests for DatabricksDialect.is_disconnect() method.""" +import pytest +from databricks.sqlalchemy import DatabricksDialect +from databricks.sql.exc import InterfaceError, DatabaseError, OperationalError + + +class TestIsDisconnect: + @pytest.fixture + def dialect(self): + return DatabricksDialect() + + def test_interface_error_is_disconnect(self, dialect): + """InterfaceError (client-side) is always a disconnect.""" + error = InterfaceError("Cannot create cursor from closed connection") + assert dialect.is_disconnect(error, None, None) is True + + def test_database_error_with_invalid_handle(self, dialect): + """DatabaseError with 'invalid handle' is a disconnect.""" + test_cases = [ + DatabaseError("Invalid SessionHandle"), + DatabaseError("[Errno INVALID_HANDLE] Session does not exist"), + DatabaseError("INVALID HANDLE"), + DatabaseError("invalid handle"), + ] + for error in test_cases: + assert dialect.is_disconnect(error, None, None) is True + + def test_database_error_without_invalid_handle(self, dialect): + """DatabaseError without 'invalid handle' is not a disconnect.""" + test_cases = [ + DatabaseError("Syntax error in SQL"), + DatabaseError("Table not found"), + DatabaseError("Permission denied"), + ] + for error in test_cases: + assert dialect.is_disconnect(error, None, None) is False + + def test_other_errors_not_disconnect(self, dialect): + """Other exception types are not disconnects.""" + test_cases = [ + OperationalError("Timeout waiting for query"), + Exception("Some random error"), + ] + for error in test_cases: + assert dialect.is_disconnect(error, None, None) is False