Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 83 additions & 6 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,77 @@

NO_NATIVE_PARAMS: List = []

# Every **kwargs key recognized by Connection.__init__ and the callees it
# dispatches to (build_client_context, Session.__init__,
# get_python_sql_connector_auth_provider). A caller-supplied key outside this
# set triggers a warning. Keys are grouped by the module that consumes them;
# the groups are concatenated into a single immutable set below.
KNOWN_KWARGS: frozenset = frozenset(
    # client.py
    (
        "access_token",
        "enable_metric_view_metadata",
        "_disable_pandas",
        "enable_query_result_lz4_compression",
        "use_cloud_fetch",
        "telemetry_batch_size",
        "enable_telemetry",
        "force_enable_telemetry",
        "use_inline_params",
        "staging_allowed_local_path",
        "fetch_autocommit_from_server",
        "pool_maxsize",
        "use_hybrid_disposition",
    )
    # session.py
    + (
        "_port",
        "user_agent_entry",
        "_user_agent_entry",
        "use_sea",
    )
    # SSL / TLS options (session.py, utils.py, auth/auth.py)
    + (
        "_tls_no_verify",
        "_tls_verify_hostname",
        "_tls_trusted_ca_file",
        "_tls_client_cert_file",
        "_tls_client_cert_key_file",
        "_tls_client_cert_key_password",
        "_use_cert_as_auth",
        "_enable_ssl",
        "_skip_routing_headers",
    )
    # auth/auth.py
    + (
        "auth_type",
        "username",
        "password",
        "oauth_client_id",
        "oauth_redirect_port",
        "azure_client_id",
        "azure_client_secret",
        "azure_tenant_id",
        "azure_workspace_resource_id",
        "experimental_oauth_persistence",
        "credentials_provider",
        "identity_federation_client_id",
    )
    # utils.py / build_client_context
    + (
        "_socket_timeout",
        "_retry_stop_after_attempts_count",
        "_retry_delay_min",
        "_retry_delay_max",
        "_retry_stop_after_attempts_duration",
        "_retry_delay_default",
        "_retry_dangerous_codes",
        "_proxy_auth_method",
        "_pool_connections",
        "_pool_maxsize",
        "telemetry_circuit_breaker_enabled",
    )
    # thrift_backend.py
    + (
        "_connection_uri",
        "_use_arrow_native_decimals",
        "_use_arrow_native_timestamps",
        "max_download_threads",
        "_enable_v3_retries",
        "_retry_max_redirects",
    )
    # sea/utils/http_client.py
    + ("max_connections",)
)

# Transaction isolation level constants (extension to PEP 249 — DB-API 2.0
# itself defines no isolation-level constants).
# NOTE(review): the string value presumably matches the server-side isolation
# level identifier — confirm against where this constant is consumed.
TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ"

Expand Down Expand Up @@ -271,6 +342,10 @@ def read(self) -> Optional[OAuthToken]:
http_path,
)

unknown = set(kwargs.keys()) - KNOWN_KWARGS
if unknown:
logger.warning("Unrecognized connection parameter(s): %s", unknown)

if access_token:
access_token_kv = {"access_token": access_token}
kwargs = {**kwargs, **access_token_kv}
Expand Down Expand Up @@ -326,9 +401,9 @@ def read(self) -> Optional[OAuthToken]:
http_path=http_path,
port=kwargs.get("_port", 443),
client_context=client_context,
user_agent=self.session.useragent_header
if hasattr(self, "session")
else None,
user_agent=(
self.session.useragent_header if hasattr(self, "session") else None
),
enable_telemetry=enable_telemetry,
)
raise e
Expand Down Expand Up @@ -375,9 +450,11 @@ def read(self) -> Optional[OAuthToken]:

driver_connection_params = DriverConnectionParameters(
http_path=http_path,
mode=DatabricksClientType.SEA
if self.session.use_sea
else DatabricksClientType.THRIFT,
mode=(
DatabricksClientType.SEA
if self.session.use_sea
else DatabricksClientType.THRIFT
),
host_info=HostDetails(host_url=server_hostname, port=self.session.port),
auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider),
auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider),
Expand Down
34 changes: 33 additions & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def test_column_name_api(self):

expected_values = [["val1", 321, 52.32], ["val2", 2321, 252.32]]

for (row, expected) in zip(data, expected_values):
for row, expected in zip(data, expected_values):
self.assertEqual(row.first_col, expected[0])
self.assertEqual(row.second_col, expected[1])
self.assertEqual(row.third_col, expected[2])
Expand Down Expand Up @@ -633,6 +633,38 @@ def mock_close_normal():
cursors_closed, [1, 2], "Both cursors should have close called"
)

@patch("databricks.sql.session.ThriftDatabricksClient")
def test_unknown_connection_param_issues_warning(self, mock_client_class):
    """Connecting with an unrecognized kwarg must emit a logger.warning."""
    connect_args = dict(self.DUMMY_CONNECTION_ARGS, unknown_param_xyz=True)
    with patch("databricks.sql.client.logger") as mock_logger:
        databricks.sql.connect(**connect_args)
    # The mock retains its call record after the context exits.
    self.assertTrue(mock_logger.warning.called)

@patch("databricks.sql.session.ThriftDatabricksClient")
def test_unknown_connection_param_warning_names_the_param(self, mock_client_class):
    """The emitted warning text must mention the offending parameter name."""
    with patch("databricks.sql.client.logger") as mock_logger:
        databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, unknown_param_xyz=True)
    # Flatten every recorded warning call into one searchable string.
    combined = " ".join(map(str, mock_logger.warning.call_args_list))
    self.assertIn("unknown_param_xyz", combined)

@patch("databricks.sql.session.ThriftDatabricksClient")
def test_known_connection_params_do_not_issue_warning(self, mock_client_class):
    """Recognized kwargs only: no unknown-parameter warning may be emitted."""
    with patch("databricks.sql.client.logger") as mock_logger:
        databricks.sql.connect(
            **self.DUMMY_CONNECTION_ARGS,
            use_cloud_fetch=True,
            _socket_timeout=30,
        )
    # Collect any warning call that mentions unrecognized parameters;
    # the list must be empty.
    offending = [
        call
        for call in mock_logger.warning.call_args_list
        if "Unrecognized connection parameter" in str(call)
    ]
    self.assertEqual(offending, [])


class TransactionTestSuite(unittest.TestCase):
"""
Expand Down
Loading