Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions integration/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from typing import Dict, Optional

import grpc
import httpx
import pytest
from _pytest.fixtures import SubRequest
Expand Down Expand Up @@ -234,6 +235,20 @@ def test_api_key() -> None:
client.collections.list_all()


@pytest.mark.parametrize("creds", [None, grpc.ssl_channel_credentials()])
def test_custom_grpc_credentials(creds: Optional[grpc.ChannelCredentials]) -> None:
assert is_auth_enabled(f"localhost:{WCS_PORT}")
with weaviate.connect_to_local(
port=WCS_PORT,
grpc_port=WCS_PORT_GRPC,
auth_credentials=wvc.init.Auth.api_key(api_key="my-secret-key"),
additional_config=wvc.init.AdditionalConfig(
grpc_config=wvc.init.GrpcConfig(credentials=creds)
),
) as client:
assert client.is_live()


@pytest.mark.parametrize("header_name", ["Authorization", "authorization"])
def test_api_key_in_header(header_name: str) -> None:
assert is_auth_enabled(f"localhost:{WCS_PORT}")
Expand Down
106 changes: 106 additions & 0 deletions mock_tests/test_grpc_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import Any, List, Tuple
from unittest.mock import MagicMock

import grpc
import pytest

from weaviate.config import GrpcConfig
from weaviate.connect import base as base_module
from weaviate.connect.base import ConnectionParams, ProtocolParams


@pytest.fixture
def secure_params() -> ConnectionParams:
return ConnectionParams(
http=ProtocolParams(host="localhost", port=8080, secure=False),
grpc=ProtocolParams(host="localhost", port=50051, secure=True),
)


@pytest.fixture
def insecure_params() -> ConnectionParams:
return ConnectionParams(
http=ProtocolParams(host="localhost", port=8080, secure=False),
grpc=ProtocolParams(host="localhost", port=50051, secure=False),
)


@pytest.fixture
def mock_grpc(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
mock = MagicMock()
mock.aio = MagicMock()
monkeypatch.setattr(base_module, "grpc", mock)
return mock


@pytest.fixture
def mock_ssl_creds(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
mock = MagicMock()
monkeypatch.setattr(base_module, "ssl_channel_credentials", mock)
return mock


def test_grpc_config_channel_options() -> None:
opts: List[Tuple[str, Any]] = [("grpc.ssl_target_name_override", "my-host")]
config = GrpcConfig(channel_options=opts)
assert config.channel_options == opts


def test_secure_channel_default_credentials(
secure_params: ConnectionParams, mock_grpc: MagicMock, mock_ssl_creds: MagicMock
) -> None:
mock_channel = MagicMock()
mock_grpc.secure_channel.return_value = mock_channel

result = secure_params._grpc_channel(proxies={}, grpc_msg_size=None, is_async=False)

mock_ssl_creds.assert_called_once_with()
mock_grpc.secure_channel.assert_called_once()
assert result is mock_channel


def test_insecure_channel_no_config(
insecure_params: ConnectionParams, mock_grpc: MagicMock
) -> None:
mock_channel = MagicMock()
mock_grpc.insecure_channel.return_value = mock_channel

result = insecure_params._grpc_channel(proxies={}, grpc_msg_size=None, is_async=False)

mock_grpc.insecure_channel.assert_called_once()
assert result is mock_channel


def test_channel_options_appended_secure(
secure_params: ConnectionParams, mock_grpc: MagicMock, mock_ssl_creds: MagicMock
) -> None:
config = GrpcConfig(
channel_options=[("grpc.ssl_target_name_override", "my-gateway.example.com")]
)
secure_params._grpc_channel(proxies={}, grpc_msg_size=None, is_async=False, grpc_config=config)

options = mock_grpc.secure_channel.call_args.kwargs["options"]
assert ("grpc.ssl_target_name_override", "my-gateway.example.com") in options


def test_channel_options_appended_insecure(
insecure_params: ConnectionParams, mock_grpc: MagicMock
) -> None:
config = GrpcConfig(channel_options=[("grpc.keepalive_time_ms", 30000)])
insecure_params._grpc_channel(
proxies={}, grpc_msg_size=None, is_async=False, grpc_config=config
)

options = mock_grpc.insecure_channel.call_args.kwargs["options"]
assert ("grpc.keepalive_time_ms", 30000) in options


def test_credentials(
secure_params: ConnectionParams, mock_grpc: MagicMock, mock_ssl_creds: MagicMock
) -> None:
creds = grpc.ssl_channel_credentials()
config = GrpcConfig(credentials=creds)
secure_params._grpc_channel(proxies={}, grpc_msg_size=None, is_async=False, grpc_config=config)

mock_ssl_creds.assert_not_called()
assert mock_grpc.secure_channel.call_args.kwargs["credentials"] is creds
4 changes: 2 additions & 2 deletions weaviate/classes/init.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from weaviate.auth import Auth
from weaviate.config import AdditionalConfig, Proxies, Timeout
from weaviate.config import AdditionalConfig, GrpcConfig, Proxies, Timeout

__all__ = ["Auth", "AdditionalConfig", "Proxies", "Timeout"]
__all__ = ["Auth", "AdditionalConfig", "GrpcConfig", "Proxies", "Timeout"]
1 change: 1 addition & 0 deletions weaviate/client_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(
proxies=config.proxies,
trust_env=config.trust_env,
skip_init_checks=skip_init_checks,
grpc_config=config.grpc_config,
)

self.integrations = _Integrations(self._connection)
Expand Down
2 changes: 1 addition & 1 deletion weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
_Vectors,
_VectorsUpdate,
)
from weaviate.exceptions import WeaviateInvalidInputError, WeaviateInsertInvalidPropertyError
from weaviate.exceptions import WeaviateInsertInvalidPropertyError, WeaviateInvalidInputError
from weaviate.str_enum import BaseEnum
from weaviate.util import _capitalize_first_letter
from weaviate.warnings import _Warnings
Expand Down
35 changes: 34 additions & 1 deletion weaviate/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from dataclasses import dataclass, field
from typing import Optional, Tuple, Union

from pydantic import BaseModel, Field
from grpc import ChannelCredentials
from grpc.aio._typing import ChannelArgumentType
from pydantic import BaseModel, ConfigDict, Field


@dataclass
Expand Down Expand Up @@ -66,6 +68,36 @@ class Proxies(BaseModel):
grpc: Optional[str] = Field(default=None)


class GrpcConfig(BaseModel):
"""Configuration for the gRPC channel used by the Weaviate client. Use this to customize TLS/SSL settings for gRPC connections.

To provide your own `channel_options`, supply a list of tuples where each tuple contains the name of the gRPC channel option and its corresponding value.
[Reference](https://grpc.github.io/grpc/python/glossary.html#term-channel_arguments)

To provide your own `credentials`, use the `ssl_channel_credentials()` function from the `grpc` library to build a `ChannelCredentials` object.
[Reference](https://grpc.github.io/grpc/python/grpc.html#grpc.ssl_channel_credentials)

Example usage:
```python
from grpc import ssl_channel_credentials
import weaviate.classes as wvc

conf = wvc.init.GrpcConfig(
channel_options=[
("grpc.keepalive_time_ms", 10000),
("grpc.keepalive_timeout_ms", 5000),
],
credentials=ssl_channel_credentials(...),
)
```
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

channel_options: Optional[ChannelArgumentType] = Field(default=None)
credentials: Optional[ChannelCredentials] = Field(default=None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we add a footgun-check that makes sure that ssl is not deactivated when this is checked?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have added a test for this in integration/test_auth.py



class AdditionalConfig(BaseModel):
"""Use this class to specify the connection and proxy settings for your client when connecting to Weaviate.

Expand All @@ -81,6 +113,7 @@ class AdditionalConfig(BaseModel):
proxies: Union[str, Proxies, None] = Field(default=None)
timeout_: Union[Tuple[int, int], Timeout] = Field(default_factory=Timeout, alias="timeout")
trust_env: bool = Field(default=False)
grpc_config: Optional[GrpcConfig] = Field(default=None)

@property
def timeout(self) -> Timeout:
Expand Down
17 changes: 14 additions & 3 deletions weaviate/connect/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from grpc.aio import Channel as AsyncChannel # type: ignore
from pydantic import BaseModel, field_validator, model_validator

from weaviate.config import Proxies
from weaviate.config import GrpcConfig, Proxies
from weaviate.types import NUMBER

# from grpclib.client import Channel
Expand Down Expand Up @@ -105,7 +105,11 @@ def _grpc_target(self) -> str:
return f"{self.grpc.host}:{self.grpc.port}"

def _grpc_channel(
self, proxies: Dict[str, str], grpc_msg_size: Optional[int], is_async: bool
self,
proxies: Dict[str, str],
grpc_msg_size: Optional[int],
is_async: bool,
grpc_config: Optional[GrpcConfig] = None,
) -> Union[AsyncChannel, SyncChannel]:
if grpc_msg_size is None:
grpc_msg_size = MAX_GRPC_MESSAGE_LENGTH
Expand All @@ -120,14 +124,21 @@ def _grpc_channel(
else:
options = opts

if grpc_config is not None and grpc_config.channel_options is not None:
options.extend(grpc_config.channel_options)

if is_async:
mod = grpc.aio
else:
mod = grpc
if self.grpc.secure:
if grpc_config is not None and grpc_config.credentials is not None:
creds = grpc_config.credentials
else:
creds = ssl_channel_credentials()
return mod.secure_channel(
target=self._grpc_target,
credentials=ssl_channel_credentials(),
credentials=creds,
options=options,
)
else:
Expand Down
5 changes: 4 additions & 1 deletion weaviate/connect/v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

from weaviate import __version__ as client_version
from weaviate.auth import AuthApiKey, AuthClientCredentials, AuthCredentials
from weaviate.config import ConnectionConfig, Proxies
from weaviate.config import ConnectionConfig, GrpcConfig, Proxies
from weaviate.config import Timeout as TimeoutConfig
from weaviate.connect import executor
from weaviate.connect.authentication import _Auth
Expand Down Expand Up @@ -132,6 +132,7 @@ def __init__(
connection_config: ConnectionConfig,
embedded_db: Optional[EmbeddedV4] = None,
skip_init_checks: bool = False,
grpc_config: Optional[GrpcConfig] = None,
):
self.url = connection_params._http_url
self.embedded_db = embedded_db
Expand All @@ -149,6 +150,7 @@ def __init__(
self._grpc_max_msg_size: Optional[int] = None
self._connected = False
self._skip_init_checks = skip_init_checks
self._grpc_config = grpc_config

client_type = "sync" if isinstance(self, ConnectionSync) else "async"
embedded_suffix = "-embedded" if self.embedded_db is not None else ""
Expand Down Expand Up @@ -370,6 +372,7 @@ def open_connection_grpc(self, colour: executor.Colour) -> None:
proxies=self._proxies,
grpc_msg_size=self._grpc_max_msg_size,
is_async=colour == "async",
grpc_config=self._grpc_config,
)
self._grpc_channel = channel
assert self._grpc_channel is not None
Expand Down
Loading