From a6a356443d4e97107d606e3b9da8d4b7bede94ab Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Wed, 21 Jan 2026 16:55:18 +0530 Subject: [PATCH 1/3] Implemented the functionality of poll_pre_ping which executes a query to see if the connection is alive. If not it'll recycle the session --- src/databricks/sqlalchemy/base.py | 25 ++++++++++++++ tests/test_local/e2e/test_basic.py | 55 ++++++++++++++++++++++++++++++ tests/test_local/test_ping.py | 47 +++++++++++++++++++++++++ 3 files changed, 127 insertions(+) create mode 100644 tests/test_local/test_ping.py diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py index 9148de7..a0f8598 100644 --- a/src/databricks/sqlalchemy/base.py +++ b/src/databricks/sqlalchemy/base.py @@ -336,6 +336,31 @@ def do_rollback(self, dbapi_connection): # Databricks SQL Does not support transactions pass + def do_ping(self, dbapi_connection): + """Test if a database connection is alive. + + This method is called by SQLAlchemy when pool_pre_ping=True to verify + connections are still valid before using them from the pool. + + Args: + dbapi_connection: A raw DBAPI connection (from databricks-sql-connector) + + Returns: + True if the connection is alive, False otherwise. + """ + try: + cursor = dbapi_connection.cursor() + try: + cursor.execute("SELECT VERSION()") + cursor.fetchone() + return True + finally: + cursor.close() + except Exception: + # Any exception means the connection is dead + # SQLAlchemy will discard it and create a new one + return False + @reflection.cache def has_table( self, connection, table_name, schema=None, catalog=None, **kwargs diff --git a/tests/test_local/e2e/test_basic.py b/tests/test_local/e2e/test_basic.py index ce0b5d8..0d4ab10 100644 --- a/tests/test_local/e2e/test_basic.py +++ b/tests/test_local/e2e/test_basic.py @@ -541,3 +541,58 @@ def test_table_comment_reflection(self, inspector: Inspector, table: Table): def test_column_comment(self, inspector: Inspector, table: Table): result = inspector.get_columns(table.name)[0].get("comment") assert result == "column comment" + + +def test_pool_pre_ping_with_closed_connection(connection_details): + """Test that pool_pre_ping detects closed connections and creates new ones. + + This test verifies that when a connection is manually closed (simulating + session expiration), pool_pre_ping detects it and automatically creates + a new connection without raising an error to the user. + """ + conn_string, connect_args = version_agnostic_connect_arguments(connection_details) + + # Create engine with pool_pre_ping enabled + engine = create_engine( + conn_string, + connect_args=connect_args, + pool_pre_ping=True, + pool_size=1, + max_overflow=0 + ) + + # Step 1: Create connection and get session ID + with engine.connect() as conn: + result = conn.execute(text("SELECT VERSION()")).scalar() + assert result is not None + + # Get session ID of first connection + raw_conn = conn.connection.dbapi_connection + session_id_1 = raw_conn.get_session_id_hex() + assert session_id_1 is not None + + # Step 2: Manually close the connection to simulate expiration + pooled_conn = engine.pool._pool.queue[0] + pooled_conn.driver_connection.close() + + # Verify connection is closed + assert not pooled_conn.driver_connection.open + + # Step 3: Try to use the closed connection - pool_pre_ping should detect and recycle + with engine.connect() as conn: + result = conn.execute(text("SELECT VERSION()")).scalar() + assert result is not None + + # Get session ID of new connection + raw_conn = conn.connection.dbapi_connection + session_id_2 = raw_conn.get_session_id_hex() + assert session_id_2 is not None + + # Verify a NEW connection was created (different session ID) + assert session_id_1 != session_id_2, ( + "pool_pre_ping should have detected the closed connection " + "and created a new one with a different session ID" + ) + + # Cleanup + engine.dispose() diff --git a/tests/test_local/test_ping.py b/tests/test_local/test_ping.py new file mode 100644 index 0000000..34f3d2e --- /dev/null +++ b/tests/test_local/test_ping.py @@ -0,0 +1,47 @@ +"""Unit tests for do_ping() method in DatabricksDialect.""" +import pytest +from unittest.mock import Mock +from databricks.sqlalchemy import DatabricksDialect + + +class TestDoPing: + """Test the do_ping() method for connection health checks.""" + + @pytest.fixture + def dialect(self): + """Create a DatabricksDialect instance.""" + return DatabricksDialect() + + def test_do_ping_success(self, dialect): + """Test do_ping returns True when connection is alive.""" + mock_connection = Mock() + mock_cursor = Mock() + mock_connection.cursor.return_value = mock_cursor + + result = dialect.do_ping(mock_connection) + + assert result is True + mock_cursor.execute.assert_called_once_with("SELECT VERSION()") + mock_cursor.fetchone.assert_called_once() + mock_cursor.close.assert_called_once() + + def test_do_ping_failure_cursor_creation(self, dialect): + """Test do_ping returns False when cursor creation fails.""" + mock_connection = Mock() + mock_connection.cursor.side_effect = Exception("Connection closed") + + result = dialect.do_ping(mock_connection) + + assert result is False + + def test_do_ping_failure_execute_and_cursor_closes(self, dialect): + """Test do_ping returns False on execute error and cursor is closed.""" + mock_connection = Mock() + mock_cursor = Mock() + mock_connection.cursor.return_value = mock_cursor + mock_cursor.execute.side_effect = Exception("Query failed") + + result = dialect.do_ping(mock_connection) + + assert result is False + mock_cursor.close.assert_called_once() From d2a3339ce4f7d5b50d467d89ba1032b31f6ad10e Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Thu, 22 Jan 2026 10:25:52 +0530 Subject: [PATCH 2/3] Added functionality for is_disconnect() --- src/databricks/sqlalchemy/base.py | 37 ++++++++++++++ tests/test_local/e2e/test_basic.py | 70 ++++++++++++++++++++++++++ tests/test_local/test_is_disconnect.py | 45 +++++++++++++++++ 3 files changed, 152 insertions(+) create mode 100644 tests/test_local/test_is_disconnect.py diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py index a0f8598..e6cca14 100644 --- a/src/databricks/sqlalchemy/base.py +++ b/src/databricks/sqlalchemy/base.py @@ -342,6 +342,10 @@ def do_ping(self, dbapi_connection): This method is called by SQLAlchemy when pool_pre_ping=True to verify connections are still valid before using them from the pool. + This implementation improves upon SQLAlchemy's default do_ping() by + wrapping the cursor creation in a try block, which properly handles + cases where the connection is closed and cursor() itself raises an exception. + Args: dbapi_connection: A raw DBAPI connection (from databricks-sql-connector) @@ -361,6 +365,39 @@ def do_ping(self, dbapi_connection): # SQLAlchemy will discard it and create a new one return False + def is_disconnect(self, e, connection, cursor): + """Determine if an exception indicates the connection was lost. + + This method is called by SQLAlchemy after exceptions occur during query + execution to determine if the error was due to a lost connection. If this + returns True, SQLAlchemy will invalidate the connection and create a new + one for the next operation. + + This is complementary to do_ping(): + - do_ping() is proactive: checks connection health BEFORE queries + - is_disconnect() is reactive: classifies errors AFTER they occur + + Args: + e: The exception that was raised + connection: The connection that raised the exception (may be None) + cursor: The cursor that raised the exception (may be None) + + Returns: + True if the error indicates a disconnect, False otherwise + """ + from databricks.sql.exc import InterfaceError, DatabaseError + + # InterfaceError: Client-side errors (e.g., connection already closed) + if isinstance(e, InterfaceError): + return True + + # DatabaseError: Server-side errors with invalid handle indicate session expired + if isinstance(e, DatabaseError): + error_msg = str(e).lower() + return "invalid" in error_msg and "handle" in error_msg + + return False + @reflection.cache def has_table( self, connection, table_name, schema=None, catalog=None, **kwargs diff --git a/tests/test_local/e2e/test_basic.py b/tests/test_local/e2e/test_basic.py index 0d4ab10..067991c 100644 --- a/tests/test_local/e2e/test_basic.py +++ b/tests/test_local/e2e/test_basic.py @@ -596,3 +596,73 @@ def test_pool_pre_ping_with_closed_connection(connection_details): # Cleanup engine.dispose() + + +def test_is_disconnect_handles_runtime_errors(db_engine): + """Test that is_disconnect() properly classifies disconnect errors during query execution. + + This tests the reactive error handling (complementary to pool_pre_ping's proactive checking). + When a connection fails DURING a query, is_disconnect() should recognize the error + and tell SQLAlchemy to invalidate the connection. + """ + from sqlalchemy import create_engine, text + from sqlalchemy.exc import DBAPIError + + engine = create_engine( + db_engine.url, + pool_pre_ping=False, # Disabled - we want to test is_disconnect, not do_ping + pool_size=1, + max_overflow=0, + ) + + # Step 1: Execute a successful query + with engine.connect() as conn: + result = conn.execute(text("SELECT VERSION()")).scalar() + assert result is not None + + # Get session ID of working connection + raw_conn = conn.connection.dbapi_connection + session_id_1 = raw_conn.get_session_id_hex() + assert session_id_1 is not None + + # Step 2: Manually close the connection to simulate server-side session expiration + pooled_conn = engine.pool._pool.queue[0] + pooled_conn.driver_connection.close() + + # Step 3: Try to execute query on closed connection + # This should: + # 1. Fail with an exception + # 2. is_disconnect() gets called by SQLAlchemy + # 3. Returns True (recognizes it as disconnect error) + # 4. SQLAlchemy invalidates the connection + # 5. Next operation gets a fresh connection + + # First query will fail because connection is closed + try: + with engine.connect() as conn: + conn.execute(text("SELECT VERSION()")).scalar() + # If we get here without exception, the connection wasn't actually closed + pytest.skip("Connection wasn't properly closed - cannot test is_disconnect") + except DBAPIError as e: + # Expected - connection was closed + # is_disconnect() should have been called and returned True + # This causes SQLAlchemy to invalidate the connection + assert "closed" in str(e).lower() or "invalid" in str(e).lower() + + # Step 4: Next query should work because is_disconnect() invalidated the bad connection + with engine.connect() as conn: + result = conn.execute(text("SELECT VERSION()")).scalar() + assert result is not None + + # Verify we got a NEW connection + raw_conn = conn.connection.dbapi_connection + session_id_2 = raw_conn.get_session_id_hex() + assert session_id_2 is not None + + # Different session ID proves connection was invalidated and recreated + assert session_id_1 != session_id_2, ( + "is_disconnect() should have invalidated the bad connection, " + "causing SQLAlchemy to create a new one with different session ID" + ) + + engine.dispose() diff --git a/tests/test_local/test_is_disconnect.py b/tests/test_local/test_is_disconnect.py new file mode 100644 index 0000000..e516745 --- /dev/null +++ b/tests/test_local/test_is_disconnect.py @@ -0,0 +1,45 @@ +"""Tests for DatabricksDialect.is_disconnect() method.""" +import pytest +from databricks.sqlalchemy import DatabricksDialect +from databricks.sql.exc import InterfaceError, DatabaseError, OperationalError + + +class TestIsDisconnect: + @pytest.fixture + def dialect(self): + return DatabricksDialect() + + def test_interface_error_is_disconnect(self, dialect): + """InterfaceError (client-side) is always a disconnect.""" + error = InterfaceError("Cannot create cursor from closed connection") + assert dialect.is_disconnect(error, None, None) is True + + def test_database_error_with_invalid_handle(self, dialect): + """DatabaseError with 'invalid handle' is a disconnect.""" + test_cases = [ + DatabaseError("Invalid SessionHandle"), + DatabaseError("[Errno INVALID_HANDLE] Session does not exist"), + DatabaseError("INVALID HANDLE"), + DatabaseError("invalid handle"), + ] + for error in test_cases: + assert dialect.is_disconnect(error, None, None) is True + + def test_database_error_without_invalid_handle(self, dialect): + """DatabaseError without 'invalid handle' is not a disconnect.""" + test_cases = [ + DatabaseError("Syntax error in SQL"), + DatabaseError("Table not found"), + DatabaseError("Permission denied"), + ] + for error in test_cases: + assert dialect.is_disconnect(error, None, None) is False + + def test_other_errors_not_disconnect(self, dialect): + """Other exception types are not disconnects.""" + test_cases = [ + OperationalError("Timeout waiting for query"), + Exception("Some random error"), + ] + for error in test_cases: + assert dialect.is_disconnect(error, None, None) is False From 77138f3cee415dd4261d48d39fee10004a8b630b Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Thu, 22 Jan 2026 11:07:30 +0530 Subject: [PATCH 3/3] Implemented is_disconnect and removed the do_ping implementation keeping the default behaviour --- src/databricks/sqlalchemy/base.py | 36 +++-------------------- tests/test_local/e2e/test_basic.py | 3 +- tests/test_local/test_ping.py | 47 ------------------------------ 3 files changed, 5 insertions(+), 81 deletions(-) delete mode 100644 tests/test_local/test_ping.py diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py index e6cca14..3c02d46 100644 --- a/src/databricks/sqlalchemy/base.py +++ b/src/databricks/sqlalchemy/base.py @@ -336,35 +336,6 @@ def do_rollback(self, dbapi_connection): # Databricks SQL Does not support transactions pass - def do_ping(self, dbapi_connection): - """Test if a database connection is alive. - - This method is called by SQLAlchemy when pool_pre_ping=True to verify - connections are still valid before using them from the pool. - - This implementation improves upon SQLAlchemy's default do_ping() by - wrapping the cursor creation in a try block, which properly handles - cases where the connection is closed and cursor() itself raises an exception. - - Args: - dbapi_connection: A raw DBAPI connection (from databricks-sql-connector) - - Returns: - True if the connection is alive, False otherwise. - """ - try: - cursor = dbapi_connection.cursor() - try: - cursor.execute("SELECT VERSION()") - cursor.fetchone() - return True - finally: - cursor.close() - except Exception: - # Any exception means the connection is dead - # SQLAlchemy will discard it and create a new one - return False - def is_disconnect(self, e, connection, cursor): """Determine if an exception indicates the connection was lost. @@ -373,9 +344,10 @@ def is_disconnect(self, e, connection, cursor): returns True, SQLAlchemy will invalidate the connection and create a new one for the next operation. - This is complementary to do_ping(): - - do_ping() is proactive: checks connection health BEFORE queries - - is_disconnect() is reactive: classifies errors AFTER they occur + This method is also used by SQLAlchemy's default do_ping() implementation + when pool_pre_ping=True. If do_ping() encounters an exception, it calls + is_disconnect() to classify the error and determine whether to invalidate + the connection. Args: e: The exception that was raised diff --git a/tests/test_local/e2e/test_basic.py b/tests/test_local/e2e/test_basic.py index 067991c..48b46d9 100644 --- a/tests/test_local/e2e/test_basic.py +++ b/tests/test_local/e2e/test_basic.py @@ -601,9 +601,8 @@ def test_pool_pre_ping_with_closed_connection(connection_details): def test_is_disconnect_handles_runtime_errors(db_engine): """Test that is_disconnect() properly classifies disconnect errors during query execution. - This tests the reactive error handling (complementary to pool_pre_ping's proactive checking). When a connection fails DURING a query, is_disconnect() should recognize the error - and tell SQLAlchemy to invalidate the connection. + and tell SQLAlchemy to invalidate the connection so the next query gets a fresh one. """ from sqlalchemy import create_engine, text from sqlalchemy.exc import DBAPIError diff --git a/tests/test_local/test_ping.py b/tests/test_local/test_ping.py deleted file mode 100644 index 34f3d2e..0000000 --- a/tests/test_local/test_ping.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Unit tests for do_ping() method in DatabricksDialect.""" -import pytest -from unittest.mock import Mock -from databricks.sqlalchemy import DatabricksDialect - - -class TestDoPing: - """Test the do_ping() method for connection health checks.""" - - @pytest.fixture - def dialect(self): - """Create a DatabricksDialect instance.""" - return DatabricksDialect() - - def test_do_ping_success(self, dialect): - """Test do_ping returns True when connection is alive.""" - mock_connection = Mock() - mock_cursor = Mock() - mock_connection.cursor.return_value = mock_cursor - - result = dialect.do_ping(mock_connection) - - assert result is True - mock_cursor.execute.assert_called_once_with("SELECT VERSION()") - mock_cursor.fetchone.assert_called_once() - mock_cursor.close.assert_called_once() - - def test_do_ping_failure_cursor_creation(self, dialect): - """Test do_ping returns False when cursor creation fails.""" - mock_connection = Mock() - mock_connection.cursor.side_effect = Exception("Connection closed") - - result = dialect.do_ping(mock_connection) - - assert result is False - - def test_do_ping_failure_execute_and_cursor_closes(self, dialect): - """Test do_ping returns False on execute error and cursor is closed.""" - mock_connection = Mock() - mock_cursor = Mock() - mock_connection.cursor.return_value = mock_cursor - mock_cursor.execute.side_effect = Exception("Query failed") - - result = dialect.do_ping(mock_connection) - - assert result is False - mock_cursor.close.assert_called_once()