diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 2aeea175..7698ffcf 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -91,6 +91,77 @@ NO_NATIVE_PARAMS: List = [] +# All recognized **kwargs keys consumed by Connection.__init__ and the callees it +# dispatches to (build_client_context, Session.__init__, get_python_sql_connector_auth_provider). +# Any key passed by the caller that is NOT in this set will trigger a warning. +KNOWN_KWARGS: frozenset = frozenset( + { + # client.py + "access_token", + "enable_metric_view_metadata", + "_disable_pandas", + "enable_query_result_lz4_compression", + "use_cloud_fetch", + "telemetry_batch_size", + "enable_telemetry", + "force_enable_telemetry", + "use_inline_params", + "staging_allowed_local_path", + "fetch_autocommit_from_server", + "pool_maxsize", + "use_hybrid_disposition", + # session.py + "_port", + "user_agent_entry", + "_user_agent_entry", + "use_sea", + # SSL / TLS options (session.py, utils.py, auth/auth.py) + "_tls_no_verify", + "_tls_verify_hostname", + "_tls_trusted_ca_file", + "_tls_client_cert_file", + "_tls_client_cert_key_file", + "_tls_client_cert_key_password", + "_use_cert_as_auth", + "_enable_ssl", + "_skip_routing_headers", + # auth/auth.py + "auth_type", + "username", + "password", + "oauth_client_id", + "oauth_redirect_port", + "azure_client_id", + "azure_client_secret", + "azure_tenant_id", + "azure_workspace_resource_id", + "experimental_oauth_persistence", + "credentials_provider", + "identity_federation_client_id", + # utils.py / build_client_context + "_socket_timeout", + "_retry_stop_after_attempts_count", + "_retry_delay_min", + "_retry_delay_max", + "_retry_stop_after_attempts_duration", + "_retry_delay_default", + "_retry_dangerous_codes", + "_proxy_auth_method", + "_pool_connections", + "_pool_maxsize", + "telemetry_circuit_breaker_enabled", + # thrift_backend.py + "_connection_uri", + "_use_arrow_native_decimals", + "_use_arrow_native_timestamps", + "max_download_threads", + "_enable_v3_retries", + "_retry_max_redirects", + # sea/utils/http_client.py + "max_connections", + } +) + # Transaction isolation level constants (extension to PEP 249) TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ" @@ -271,6 +342,10 @@ def read(self) -> Optional[OAuthToken]: http_path, ) + unknown = set(kwargs.keys()) - KNOWN_KWARGS + if unknown: + logger.warning("Unrecognized connection parameter(s): %s", unknown) + if access_token: access_token_kv = {"access_token": access_token} kwargs = {**kwargs, **access_token_kv} @@ -326,9 +401,9 @@ def read(self) -> Optional[OAuthToken]: http_path=http_path, port=kwargs.get("_port", 443), client_context=client_context, - user_agent=self.session.useragent_header - if hasattr(self, "session") - else None, + user_agent=( + self.session.useragent_header if hasattr(self, "session") else None + ), enable_telemetry=enable_telemetry, ) raise e @@ -375,9 +450,11 @@ def read(self) -> Optional[OAuthToken]: driver_connection_params = DriverConnectionParameters( http_path=http_path, - mode=DatabricksClientType.SEA - if self.session.use_sea - else DatabricksClientType.THRIFT, + mode=( + DatabricksClientType.SEA + if self.session.use_sea + else DatabricksClientType.THRIFT + ), host_info=HostDetails(host_url=server_hostname, port=self.session.port), auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider), auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider), diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 5b699193..86e3902d 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -509,7 +509,7 @@ def test_column_name_api(self): expected_values = [["val1", 321, 52.32], ["val2", 2321, 252.32]] - for (row, expected) in zip(data, expected_values): + for row, expected in zip(data, expected_values): self.assertEqual(row.first_col, expected[0]) self.assertEqual(row.second_col, expected[1]) self.assertEqual(row.third_col, expected[2]) @@ -633,6 +633,38 @@ def mock_close_normal(): cursors_closed, [1, 2], "Both cursors should have close called" ) + @patch("databricks.sql.session.ThriftDatabricksClient") + def test_unknown_connection_param_issues_warning(self, mock_client_class): + """Passing an unrecognized kwarg should trigger a logger.warning call.""" + with patch("databricks.sql.client.logger") as mock_logger: + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, unknown_param_xyz=True) + mock_logger.warning.assert_called() + + @patch("databricks.sql.session.ThriftDatabricksClient") + def test_unknown_connection_param_warning_names_the_param(self, mock_client_class): + """The warning message should include the unknown parameter name.""" + with patch("databricks.sql.client.logger") as mock_logger: + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, unknown_param_xyz=True) + warning_calls = mock_logger.warning.call_args_list + warning_messages = " ".join(str(call) for call in warning_calls) + self.assertIn("unknown_param_xyz", warning_messages) + + @patch("databricks.sql.session.ThriftDatabricksClient") + def test_known_connection_params_do_not_issue_warning(self, mock_client_class): + """Passing only recognized kwargs should not trigger an unknown-param warning.""" + with patch("databricks.sql.client.logger") as mock_logger: + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, + use_cloud_fetch=True, + _socket_timeout=30, + ) + # Ensure no warning was issued about unrecognized parameters + for call in mock_logger.warning.call_args_list: + self.assertNotIn( + "Unrecognized connection parameter", + str(call), + ) + class TransactionTestSuite(unittest.TestCase): """