diff --git a/integration/test_auth.py b/integration/test_auth.py index 2ddecb756..dae8f863b 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -2,6 +2,7 @@ import warnings from typing import Dict, Optional +import grpc import httpx import pytest from _pytest.fixtures import SubRequest @@ -234,6 +235,20 @@ def test_api_key() -> None: client.collections.list_all() +@pytest.mark.parametrize("creds", [None, grpc.ssl_channel_credentials()]) +def test_custom_grpc_credentials(creds: Optional[grpc.ChannelCredentials]) -> None: + assert is_auth_enabled(f"localhost:{WCS_PORT}") + with weaviate.connect_to_local( + port=WCS_PORT, + grpc_port=WCS_PORT_GRPC, + auth_credentials=wvc.init.Auth.api_key(api_key="my-secret-key"), + additional_config=wvc.init.AdditionalConfig( + grpc_config=wvc.init.GrpcConfig(credentials=creds) + ), + ) as client: + assert client.is_live() + + @pytest.mark.parametrize("header_name", ["Authorization", "authorization"]) def test_api_key_in_header(header_name: str) -> None: assert is_auth_enabled(f"localhost:{WCS_PORT}") diff --git a/mock_tests/test_grpc_config.py b/mock_tests/test_grpc_config.py new file mode 100644 index 000000000..e0626118c --- /dev/null +++ b/mock_tests/test_grpc_config.py @@ -0,0 +1,106 @@ +from typing import Any, List, Tuple +from unittest.mock import MagicMock + +import grpc +import pytest + +from weaviate.config import GrpcConfig +from weaviate.connect import base as base_module +from weaviate.connect.base import ConnectionParams, ProtocolParams + + +@pytest.fixture +def secure_params() -> ConnectionParams: + return ConnectionParams( + http=ProtocolParams(host="localhost", port=8080, secure=False), + grpc=ProtocolParams(host="localhost", port=50051, secure=True), + ) + + +@pytest.fixture +def insecure_params() -> ConnectionParams: + return ConnectionParams( + http=ProtocolParams(host="localhost", port=8080, secure=False), + grpc=ProtocolParams(host="localhost", port=50051, secure=False), + ) + + +@pytest.fixture +def mock_grpc(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + mock = MagicMock() + mock.aio = MagicMock() + monkeypatch.setattr(base_module, "grpc", mock) + return mock + + +@pytest.fixture +def mock_ssl_creds(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + mock = MagicMock() + monkeypatch.setattr(base_module, "ssl_channel_credentials", mock) + return mock + + +def test_grpc_config_channel_options() -> None: + opts: List[Tuple[str, Any]] = [("grpc.ssl_target_name_override", "my-host")] + config = GrpcConfig(channel_options=opts) + assert config.channel_options == opts + + +def test_secure_channel_default_credentials( + secure_params: ConnectionParams, mock_grpc: MagicMock, mock_ssl_creds: MagicMock +) -> None: + mock_channel = MagicMock() + mock_grpc.secure_channel.return_value = mock_channel + + result = secure_params._grpc_channel(proxies={}, grpc_msg_size=None, is_async=False) + + mock_ssl_creds.assert_called_once_with() + mock_grpc.secure_channel.assert_called_once() + assert result is mock_channel + + +def test_insecure_channel_no_config( + insecure_params: ConnectionParams, mock_grpc: MagicMock +) -> None: + mock_channel = MagicMock() + mock_grpc.insecure_channel.return_value = mock_channel + + result = insecure_params._grpc_channel(proxies={}, grpc_msg_size=None, is_async=False) + + mock_grpc.insecure_channel.assert_called_once() + assert result is mock_channel + + +def test_channel_options_appended_secure( + secure_params: ConnectionParams, mock_grpc: MagicMock, mock_ssl_creds: MagicMock +) -> None: + config = GrpcConfig( + channel_options=[("grpc.ssl_target_name_override", "my-gateway.example.com")] + ) + secure_params._grpc_channel(proxies={}, grpc_msg_size=None, is_async=False, grpc_config=config) + + options = mock_grpc.secure_channel.call_args.kwargs["options"] + assert ("grpc.ssl_target_name_override", "my-gateway.example.com") in options + + +def test_channel_options_appended_insecure( + insecure_params: ConnectionParams, mock_grpc: MagicMock +) -> None: + config = GrpcConfig(channel_options=[("grpc.keepalive_time_ms", 30000)]) + insecure_params._grpc_channel( + proxies={}, grpc_msg_size=None, is_async=False, grpc_config=config + ) + + options = mock_grpc.insecure_channel.call_args.kwargs["options"] + assert ("grpc.keepalive_time_ms", 30000) in options + + +def test_credentials( + secure_params: ConnectionParams, mock_grpc: MagicMock, mock_ssl_creds: MagicMock +) -> None: + creds = grpc.ssl_channel_credentials() + config = GrpcConfig(credentials=creds) + secure_params._grpc_channel(proxies={}, grpc_msg_size=None, is_async=False, grpc_config=config) + + mock_ssl_creds.assert_not_called() + assert mock_grpc.secure_channel.call_args.kwargs["credentials"] is creds diff --git a/weaviate/classes/init.py b/weaviate/classes/init.py index 7a2730adc..7dd5dace1 100644 --- a/weaviate/classes/init.py +++ b/weaviate/classes/init.py @@ -1,4 +1,4 @@ from weaviate.auth import Auth -from weaviate.config import AdditionalConfig, Proxies, Timeout +from weaviate.config import AdditionalConfig, GrpcConfig, Proxies, Timeout -__all__ = ["Auth", "AdditionalConfig", "Proxies", "Timeout"] +__all__ = ["Auth", "AdditionalConfig", "GrpcConfig", "Proxies", "Timeout"] diff --git a/weaviate/client_executor.py b/weaviate/client_executor.py index cec92f80b..3125fd9cd 100644 --- a/weaviate/client_executor.py +++ b/weaviate/client_executor.py @@ -81,6 +81,7 @@ def __init__( proxies=config.proxies, trust_env=config.trust_env, skip_init_checks=skip_init_checks, + grpc_config=config.grpc_config, ) self.integrations = _Integrations(self._connection) diff --git a/weaviate/collections/classes/config.py b/weaviate/collections/classes/config.py index 4bd95f814..6d8d5bdf1 100644 --- a/weaviate/collections/classes/config.py +++ b/weaviate/collections/classes/config.py @@ -74,7 +74,7 @@ _Vectors, _VectorsUpdate, ) -from weaviate.exceptions import WeaviateInvalidInputError, WeaviateInsertInvalidPropertyError +from weaviate.exceptions import WeaviateInsertInvalidPropertyError, WeaviateInvalidInputError from weaviate.str_enum import BaseEnum from weaviate.util import _capitalize_first_letter from weaviate.warnings import _Warnings diff --git a/weaviate/config.py b/weaviate/config.py index bc0525531..a41f335fa 100644 --- a/weaviate/config.py +++ b/weaviate/config.py @@ -1,7 +1,9 @@ from dataclasses import dataclass, field from typing import Optional, Tuple, Union -from pydantic import BaseModel, Field +from grpc import ChannelCredentials +from grpc.aio._typing import ChannelArgumentType +from pydantic import BaseModel, ConfigDict, Field @dataclass @@ -66,6 +68,36 @@ class Proxies(BaseModel): grpc: Optional[str] = Field(default=None) +class GrpcConfig(BaseModel): + """Configuration for the gRPC channel used by the Weaviate client. Use this to customize TLS/SSL settings for gRPC connections. + + To provide your own `channel_options`, supply a list of tuples where each tuple contains the name of the gRPC channel option and its corresponding value. + [Reference](https://grpc.github.io/grpc/python/glossary.html#term-channel_arguments) + + To provide your own `credentials`, use the `ssl_channel_credentials()` function from the `grpc` library to build a `ChannelCredentials` object. + [Reference](https://grpc.github.io/grpc/python/grpc.html#grpc.ssl_channel_credentials) + + Example usage: + ```python + from grpc import ssl_channel_credentials + import weaviate.classes as wvc + + conf = wvc.init.GrpcConfig( + channel_options=[ + ("grpc.keepalive_time_ms", 10000), + ("grpc.keepalive_timeout_ms", 5000), + ], + credentials=ssl_channel_credentials(...), + ) + ``` + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + channel_options: Optional[ChannelArgumentType] = Field(default=None) + credentials: Optional[ChannelCredentials] = Field(default=None) + + class AdditionalConfig(BaseModel): """Use this class to specify the connection and proxy settings for your client when connecting to Weaviate. @@ -81,6 +113,7 @@ class AdditionalConfig(BaseModel): proxies: Union[str, Proxies, None] = Field(default=None) timeout_: Union[Tuple[int, int], Timeout] = Field(default_factory=Timeout, alias="timeout") trust_env: bool = Field(default=False) + grpc_config: Optional[GrpcConfig] = Field(default=None) @property def timeout(self) -> Timeout: diff --git a/weaviate/connect/base.py b/weaviate/connect/base.py index 5b9d8718c..ba983d85a 100644 --- a/weaviate/connect/base.py +++ b/weaviate/connect/base.py @@ -8,7 +8,7 @@ from grpc.aio import Channel as AsyncChannel # type: ignore from pydantic import BaseModel, field_validator, model_validator -from weaviate.config import Proxies +from weaviate.config import GrpcConfig, Proxies from weaviate.types import NUMBER # from grpclib.client import Channel @@ -105,7 +105,11 @@ def _grpc_target(self) -> str: return f"{self.grpc.host}:{self.grpc.port}" def _grpc_channel( - self, proxies: Dict[str, str], grpc_msg_size: Optional[int], is_async: bool + self, + proxies: Dict[str, str], + grpc_msg_size: Optional[int], + is_async: bool, + grpc_config: Optional[GrpcConfig] = None, ) -> Union[AsyncChannel, SyncChannel]: if grpc_msg_size is None: grpc_msg_size = MAX_GRPC_MESSAGE_LENGTH @@ -120,14 +124,21 @@ def _grpc_channel( else: options = opts + if grpc_config is not None and grpc_config.channel_options is not None: + options.extend(grpc_config.channel_options) + if is_async: mod = grpc.aio else: mod = grpc if self.grpc.secure: + if grpc_config is not None and grpc_config.credentials is not None: + creds = grpc_config.credentials + else: + creds = ssl_channel_credentials() return mod.secure_channel( target=self._grpc_target, - credentials=ssl_channel_credentials(), + credentials=creds, options=options, ) else: diff --git a/weaviate/connect/v4.py b/weaviate/connect/v4.py index 2b47cface..01fa0e6b3 100644 --- a/weaviate/connect/v4.py +++ b/weaviate/connect/v4.py @@ -50,7 +50,7 @@ from weaviate import __version__ as client_version from weaviate.auth import AuthApiKey, AuthClientCredentials, AuthCredentials -from weaviate.config import ConnectionConfig, Proxies +from weaviate.config import ConnectionConfig, GrpcConfig, Proxies from weaviate.config import Timeout as TimeoutConfig from weaviate.connect import executor from weaviate.connect.authentication import _Auth @@ -132,6 +132,7 @@ def __init__( connection_config: ConnectionConfig, embedded_db: Optional[EmbeddedV4] = None, skip_init_checks: bool = False, + grpc_config: Optional[GrpcConfig] = None, ): self.url = connection_params._http_url self.embedded_db = embedded_db @@ -149,6 +150,7 @@ def __init__( self._grpc_max_msg_size: Optional[int] = None self._connected = False self._skip_init_checks = skip_init_checks + self._grpc_config = grpc_config client_type = "sync" if isinstance(self, ConnectionSync) else "async" embedded_suffix = "-embedded" if self.embedded_db is not None else "" @@ -370,6 +372,7 @@ def open_connection_grpc(self, colour: executor.Colour) -> None: proxies=self._proxies, grpc_msg_size=self._grpc_max_msg_size, is_async=colour == "async", + grpc_config=self._grpc_config, ) self._grpc_channel = channel assert self._grpc_channel is not None