Skip to content

Commit 88f6041

Browse files
committed
[PECOBLR-1968] Implement fix
1 parent e916f71 commit 88f6041

File tree

2 files changed

+116
-7
lines changed

2 files changed

+116
-7
lines changed

src/databricks/sql/client.py

Lines changed: 83 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,77 @@
9191

9292
NO_NATIVE_PARAMS: List = []
9393

94+
# All recognized **kwargs keys consumed by Connection.__init__ and the callees it
95+
# dispatches to (build_client_context, Session.__init__, get_python_sql_connector_auth_provider).
96+
# Any key passed by the caller that is NOT in this set will trigger a warning.
97+
KNOWN_KWARGS: frozenset = frozenset(
98+
{
99+
# client.py
100+
"access_token",
101+
"enable_metric_view_metadata",
102+
"_disable_pandas",
103+
"enable_query_result_lz4_compression",
104+
"use_cloud_fetch",
105+
"telemetry_batch_size",
106+
"enable_telemetry",
107+
"force_enable_telemetry",
108+
"use_inline_params",
109+
"staging_allowed_local_path",
110+
"fetch_autocommit_from_server",
111+
"pool_maxsize",
112+
"use_hybrid_disposition",
113+
# session.py
114+
"_port",
115+
"user_agent_entry",
116+
"_user_agent_entry",
117+
"use_sea",
118+
# SSL / TLS options (session.py, utils.py, auth/auth.py)
119+
"_tls_no_verify",
120+
"_tls_verify_hostname",
121+
"_tls_trusted_ca_file",
122+
"_tls_client_cert_file",
123+
"_tls_client_cert_key_file",
124+
"_tls_client_cert_key_password",
125+
"_use_cert_as_auth",
126+
"_enable_ssl",
127+
"_skip_routing_headers",
128+
# auth/auth.py
129+
"auth_type",
130+
"username",
131+
"password",
132+
"oauth_client_id",
133+
"oauth_redirect_port",
134+
"azure_client_id",
135+
"azure_client_secret",
136+
"azure_tenant_id",
137+
"azure_workspace_resource_id",
138+
"experimental_oauth_persistence",
139+
"credentials_provider",
140+
"identity_federation_client_id",
141+
# utils.py / build_client_context
142+
"_socket_timeout",
143+
"_retry_stop_after_attempts_count",
144+
"_retry_delay_min",
145+
"_retry_delay_max",
146+
"_retry_stop_after_attempts_duration",
147+
"_retry_delay_default",
148+
"_retry_dangerous_codes",
149+
"_proxy_auth_method",
150+
"_pool_connections",
151+
"_pool_maxsize",
152+
"telemetry_circuit_breaker_enabled",
153+
# thrift_backend.py
154+
"_connection_uri",
155+
"_use_arrow_native_decimals",
156+
"_use_arrow_native_timestamps",
157+
"max_download_threads",
158+
"_enable_v3_retries",
159+
"_retry_max_redirects",
160+
# sea/utils/http_client.py
161+
"max_connections",
162+
}
163+
)
164+
94165
# Transaction isolation level constants (extension to PEP 249)
95166
TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ"
96167

@@ -271,6 +342,10 @@ def read(self) -> Optional[OAuthToken]:
271342
http_path,
272343
)
273344

345+
unknown = set(kwargs.keys()) - KNOWN_KWARGS
346+
if unknown:
347+
logger.warning("Unrecognized connection parameter(s): %s", unknown)
348+
274349
if access_token:
275350
access_token_kv = {"access_token": access_token}
276351
kwargs = {**kwargs, **access_token_kv}
@@ -326,9 +401,9 @@ def read(self) -> Optional[OAuthToken]:
326401
http_path=http_path,
327402
port=kwargs.get("_port", 443),
328403
client_context=client_context,
329-
user_agent=self.session.useragent_header
330-
if hasattr(self, "session")
331-
else None,
404+
user_agent=(
405+
self.session.useragent_header if hasattr(self, "session") else None
406+
),
332407
enable_telemetry=enable_telemetry,
333408
)
334409
raise e
@@ -375,9 +450,11 @@ def read(self) -> Optional[OAuthToken]:
375450

376451
driver_connection_params = DriverConnectionParameters(
377452
http_path=http_path,
378-
mode=DatabricksClientType.SEA
379-
if self.session.use_sea
380-
else DatabricksClientType.THRIFT,
453+
mode=(
454+
DatabricksClientType.SEA
455+
if self.session.use_sea
456+
else DatabricksClientType.THRIFT
457+
),
381458
host_info=HostDetails(host_url=server_hostname, port=self.session.port),
382459
auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider),
383460
auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider),

tests/unit/test_client.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ def test_column_name_api(self):
509509

510510
expected_values = [["val1", 321, 52.32], ["val2", 2321, 252.32]]
511511

512-
for (row, expected) in zip(data, expected_values):
512+
for row, expected in zip(data, expected_values):
513513
self.assertEqual(row.first_col, expected[0])
514514
self.assertEqual(row.second_col, expected[1])
515515
self.assertEqual(row.third_col, expected[2])
@@ -633,6 +633,38 @@ def mock_close_normal():
633633
cursors_closed, [1, 2], "Both cursors should have close called"
634634
)
635635

636+
@patch("databricks.sql.session.ThriftDatabricksClient")
637+
def test_unknown_connection_param_issues_warning(self, mock_client_class):
638+
"""Passing an unrecognized kwarg should trigger a logger.warning call."""
639+
with patch("databricks.sql.client.logger") as mock_logger:
640+
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, unknown_param_xyz=True)
641+
mock_logger.warning.assert_called()
642+
643+
@patch("databricks.sql.session.ThriftDatabricksClient")
644+
def test_unknown_connection_param_warning_names_the_param(self, mock_client_class):
645+
"""The warning message should include the unknown parameter name."""
646+
with patch("databricks.sql.client.logger") as mock_logger:
647+
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, unknown_param_xyz=True)
648+
warning_calls = mock_logger.warning.call_args_list
649+
warning_messages = " ".join(str(call) for call in warning_calls)
650+
self.assertIn("unknown_param_xyz", warning_messages)
651+
652+
@patch("databricks.sql.session.ThriftDatabricksClient")
653+
def test_known_connection_params_do_not_issue_warning(self, mock_client_class):
654+
"""Passing only recognized kwargs should not trigger an unknown-param warning."""
655+
with patch("databricks.sql.client.logger") as mock_logger:
656+
databricks.sql.connect(
657+
**self.DUMMY_CONNECTION_ARGS,
658+
use_cloud_fetch=True,
659+
_socket_timeout=30,
660+
)
661+
# Ensure no warning was issued about unrecognized parameters
662+
for call in mock_logger.warning.call_args_list:
663+
self.assertNotIn(
664+
"Unrecognized connection parameter",
665+
str(call),
666+
)
667+
636668

637669
class TransactionTestSuite(unittest.TestCase):
638670
"""

0 commit comments

Comments
 (0)