Skip to content

Commit 420b1d0

Browse files
authored
Provide options allowing users to specify custom grpc SSL credentials (#1946)
* Provide options allowing users to specify custom grpc SSL credentials * Simplify `GrpcConfig` object * Update docstring of new config class with better descr and example usage * Fix tests * Add integration test validating that custom creds don't override default
1 parent 6dab8cf commit 420b1d0

8 files changed

Lines changed: 177 additions & 8 deletions

File tree

integration/test_auth.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33
from typing import Dict, Optional
44

5+
import grpc
56
import httpx
67
import pytest
78
from _pytest.fixtures import SubRequest
@@ -234,6 +235,20 @@ def test_api_key() -> None:
234235
client.collections.list_all()
235236

236237

238+
@pytest.mark.parametrize("creds", [None, grpc.ssl_channel_credentials()])
239+
def test_custom_grpc_credentials(creds: Optional[grpc.ChannelCredentials]) -> None:
240+
assert is_auth_enabled(f"localhost:{WCS_PORT}")
241+
with weaviate.connect_to_local(
242+
port=WCS_PORT,
243+
grpc_port=WCS_PORT_GRPC,
244+
auth_credentials=wvc.init.Auth.api_key(api_key="my-secret-key"),
245+
additional_config=wvc.init.AdditionalConfig(
246+
grpc_config=wvc.init.GrpcConfig(credentials=creds)
247+
),
248+
) as client:
249+
assert client.is_live()
250+
251+
237252
@pytest.mark.parametrize("header_name", ["Authorization", "authorization"])
238253
def test_api_key_in_header(header_name: str) -> None:
239254
assert is_auth_enabled(f"localhost:{WCS_PORT}")

mock_tests/test_grpc_config.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
from typing import Any, List, Tuple
2+
from unittest.mock import MagicMock
3+
4+
import grpc
5+
import pytest
6+
7+
from weaviate.config import GrpcConfig
8+
from weaviate.connect import base as base_module
9+
from weaviate.connect.base import ConnectionParams, ProtocolParams
10+
11+
12+
@pytest.fixture
13+
def secure_params() -> ConnectionParams:
14+
return ConnectionParams(
15+
http=ProtocolParams(host="localhost", port=8080, secure=False),
16+
grpc=ProtocolParams(host="localhost", port=50051, secure=True),
17+
)
18+
19+
20+
@pytest.fixture
21+
def insecure_params() -> ConnectionParams:
22+
return ConnectionParams(
23+
http=ProtocolParams(host="localhost", port=8080, secure=False),
24+
grpc=ProtocolParams(host="localhost", port=50051, secure=False),
25+
)
26+
27+
28+
@pytest.fixture
29+
def mock_grpc(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
30+
mock = MagicMock()
31+
mock.aio = MagicMock()
32+
monkeypatch.setattr(base_module, "grpc", mock)
33+
return mock
34+
35+
36+
@pytest.fixture
37+
def mock_ssl_creds(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
38+
mock = MagicMock()
39+
monkeypatch.setattr(base_module, "ssl_channel_credentials", mock)
40+
return mock
41+
42+
43+
def test_grpc_config_channel_options() -> None:
44+
opts: List[Tuple[str, Any]] = [("grpc.ssl_target_name_override", "my-host")]
45+
config = GrpcConfig(channel_options=opts)
46+
assert config.channel_options == opts
47+
48+
49+
def test_secure_channel_default_credentials(
50+
secure_params: ConnectionParams, mock_grpc: MagicMock, mock_ssl_creds: MagicMock
51+
) -> None:
52+
mock_channel = MagicMock()
53+
mock_grpc.secure_channel.return_value = mock_channel
54+
55+
result = secure_params._grpc_channel(proxies={}, grpc_msg_size=None, is_async=False)
56+
57+
mock_ssl_creds.assert_called_once_with()
58+
mock_grpc.secure_channel.assert_called_once()
59+
assert result is mock_channel
60+
61+
62+
def test_insecure_channel_no_config(
63+
insecure_params: ConnectionParams, mock_grpc: MagicMock
64+
) -> None:
65+
mock_channel = MagicMock()
66+
mock_grpc.insecure_channel.return_value = mock_channel
67+
68+
result = insecure_params._grpc_channel(proxies={}, grpc_msg_size=None, is_async=False)
69+
70+
mock_grpc.insecure_channel.assert_called_once()
71+
assert result is mock_channel
72+
73+
74+
def test_channel_options_appended_secure(
75+
secure_params: ConnectionParams, mock_grpc: MagicMock, mock_ssl_creds: MagicMock
76+
) -> None:
77+
config = GrpcConfig(
78+
channel_options=[("grpc.ssl_target_name_override", "my-gateway.example.com")]
79+
)
80+
secure_params._grpc_channel(proxies={}, grpc_msg_size=None, is_async=False, grpc_config=config)
81+
82+
options = mock_grpc.secure_channel.call_args.kwargs["options"]
83+
assert ("grpc.ssl_target_name_override", "my-gateway.example.com") in options
84+
85+
86+
def test_channel_options_appended_insecure(
87+
insecure_params: ConnectionParams, mock_grpc: MagicMock
88+
) -> None:
89+
config = GrpcConfig(channel_options=[("grpc.keepalive_time_ms", 30000)])
90+
insecure_params._grpc_channel(
91+
proxies={}, grpc_msg_size=None, is_async=False, grpc_config=config
92+
)
93+
94+
options = mock_grpc.insecure_channel.call_args.kwargs["options"]
95+
assert ("grpc.keepalive_time_ms", 30000) in options
96+
97+
98+
def test_credentials(
99+
secure_params: ConnectionParams, mock_grpc: MagicMock, mock_ssl_creds: MagicMock
100+
) -> None:
101+
creds = grpc.ssl_channel_credentials()
102+
config = GrpcConfig(credentials=creds)
103+
secure_params._grpc_channel(proxies={}, grpc_msg_size=None, is_async=False, grpc_config=config)
104+
105+
mock_ssl_creds.assert_not_called()
106+
assert mock_grpc.secure_channel.call_args.kwargs["credentials"] is creds

weaviate/classes/init.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from weaviate.auth import Auth
2-
from weaviate.config import AdditionalConfig, Proxies, Timeout
2+
from weaviate.config import AdditionalConfig, GrpcConfig, Proxies, Timeout
33

4-
__all__ = ["Auth", "AdditionalConfig", "Proxies", "Timeout"]
4+
__all__ = ["Auth", "AdditionalConfig", "GrpcConfig", "Proxies", "Timeout"]

weaviate/client_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(
8181
proxies=config.proxies,
8282
trust_env=config.trust_env,
8383
skip_init_checks=skip_init_checks,
84+
grpc_config=config.grpc_config,
8485
)
8586

8687
self.integrations = _Integrations(self._connection)

weaviate/collections/classes/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
_Vectors,
7575
_VectorsUpdate,
7676
)
77-
from weaviate.exceptions import WeaviateInvalidInputError, WeaviateInsertInvalidPropertyError
77+
from weaviate.exceptions import WeaviateInsertInvalidPropertyError, WeaviateInvalidInputError
7878
from weaviate.str_enum import BaseEnum
7979
from weaviate.util import _capitalize_first_letter
8080
from weaviate.warnings import _Warnings

weaviate/config.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from dataclasses import dataclass, field
22
from typing import Optional, Tuple, Union
33

4-
from pydantic import BaseModel, Field
4+
from grpc import ChannelCredentials
5+
from grpc.aio._typing import ChannelArgumentType
6+
from pydantic import BaseModel, ConfigDict, Field
57

68

79
@dataclass
@@ -66,6 +68,36 @@ class Proxies(BaseModel):
6668
grpc: Optional[str] = Field(default=None)
6769

6870

71+
class GrpcConfig(BaseModel):
72+
"""Configuration for the gRPC channel used by the Weaviate client. Use this to customize TLS/SSL settings for gRPC connections.
73+
74+
To provide your own `channel_options`, supply a list of tuples where each tuple contains the name of the gRPC channel option and its corresponding value.
75+
[Reference](https://grpc.github.io/grpc/python/glossary.html#term-channel_arguments)
76+
77+
To provide your own `credentials`, use the `ssl_channel_credentials()` function from the `grpc` library to build a `ChannelCredentials` object.
78+
[Reference](https://grpc.github.io/grpc/python/grpc.html#grpc.ssl_channel_credentials)
79+
80+
Example usage:
81+
```python
82+
from grpc import ssl_channel_credentials
83+
import weaviate.classes as wvc
84+
85+
conf = wvc.init.GrpcConfig(
86+
channel_options=[
87+
("grpc.keepalive_time_ms", 10000),
88+
("grpc.keepalive_timeout_ms", 5000),
89+
],
90+
credentials=ssl_channel_credentials(...),
91+
)
92+
```
93+
"""
94+
95+
model_config = ConfigDict(arbitrary_types_allowed=True)
96+
97+
channel_options: Optional[ChannelArgumentType] = Field(default=None)
98+
credentials: Optional[ChannelCredentials] = Field(default=None)
99+
100+
69101
class AdditionalConfig(BaseModel):
70102
"""Use this class to specify the connection and proxy settings for your client when connecting to Weaviate.
71103
@@ -81,6 +113,7 @@ class AdditionalConfig(BaseModel):
81113
proxies: Union[str, Proxies, None] = Field(default=None)
82114
timeout_: Union[Tuple[int, int], Timeout] = Field(default_factory=Timeout, alias="timeout")
83115
trust_env: bool = Field(default=False)
116+
grpc_config: Optional[GrpcConfig] = Field(default=None)
84117

85118
@property
86119
def timeout(self) -> Timeout:

weaviate/connect/base.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from grpc.aio import Channel as AsyncChannel # type: ignore
99
from pydantic import BaseModel, field_validator, model_validator
1010

11-
from weaviate.config import Proxies
11+
from weaviate.config import GrpcConfig, Proxies
1212
from weaviate.types import NUMBER
1313

1414
# from grpclib.client import Channel
@@ -105,7 +105,11 @@ def _grpc_target(self) -> str:
105105
return f"{self.grpc.host}:{self.grpc.port}"
106106

107107
def _grpc_channel(
108-
self, proxies: Dict[str, str], grpc_msg_size: Optional[int], is_async: bool
108+
self,
109+
proxies: Dict[str, str],
110+
grpc_msg_size: Optional[int],
111+
is_async: bool,
112+
grpc_config: Optional[GrpcConfig] = None,
109113
) -> Union[AsyncChannel, SyncChannel]:
110114
if grpc_msg_size is None:
111115
grpc_msg_size = MAX_GRPC_MESSAGE_LENGTH
@@ -120,14 +124,21 @@ def _grpc_channel(
120124
else:
121125
options = opts
122126

127+
if grpc_config is not None and grpc_config.channel_options is not None:
128+
options.extend(grpc_config.channel_options)
129+
123130
if is_async:
124131
mod = grpc.aio
125132
else:
126133
mod = grpc
127134
if self.grpc.secure:
135+
if grpc_config is not None and grpc_config.credentials is not None:
136+
creds = grpc_config.credentials
137+
else:
138+
creds = ssl_channel_credentials()
128139
return mod.secure_channel(
129140
target=self._grpc_target,
130-
credentials=ssl_channel_credentials(),
141+
credentials=creds,
131142
options=options,
132143
)
133144
else:

weaviate/connect/v4.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050

5151
from weaviate import __version__ as client_version
5252
from weaviate.auth import AuthApiKey, AuthClientCredentials, AuthCredentials
53-
from weaviate.config import ConnectionConfig, Proxies
53+
from weaviate.config import ConnectionConfig, GrpcConfig, Proxies
5454
from weaviate.config import Timeout as TimeoutConfig
5555
from weaviate.connect import executor
5656
from weaviate.connect.authentication import _Auth
@@ -132,6 +132,7 @@ def __init__(
132132
connection_config: ConnectionConfig,
133133
embedded_db: Optional[EmbeddedV4] = None,
134134
skip_init_checks: bool = False,
135+
grpc_config: Optional[GrpcConfig] = None,
135136
):
136137
self.url = connection_params._http_url
137138
self.embedded_db = embedded_db
@@ -149,6 +150,7 @@ def __init__(
149150
self._grpc_max_msg_size: Optional[int] = None
150151
self._connected = False
151152
self._skip_init_checks = skip_init_checks
153+
self._grpc_config = grpc_config
152154

153155
client_type = "sync" if isinstance(self, ConnectionSync) else "async"
154156
embedded_suffix = "-embedded" if self.embedded_db is not None else ""
@@ -370,6 +372,7 @@ def open_connection_grpc(self, colour: executor.Colour) -> None:
370372
proxies=self._proxies,
371373
grpc_msg_size=self._grpc_max_msg_size,
372374
is_async=colour == "async",
375+
grpc_config=self._grpc_config,
373376
)
374377
self._grpc_channel = channel
375378
assert self._grpc_channel is not None

0 commit comments

Comments
 (0)