diff --git a/redisvl/index/index.py b/redisvl/index/index.py index a2837793..91c4eeaf 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -1,9 +1,13 @@ import asyncio import json +import os import threading import time import warnings import weakref +from urllib.parse import urlparse, urlunparse + +import yaml from math import ceil from typing import ( TYPE_CHECKING, @@ -1294,6 +1298,69 @@ def info(self, name: Optional[str] = None) -> Dict[str, Any]: index_name = name or self.schema.index.name return self._info(index_name, self._redis_client) + def _sanitize_redis_url(self, url: Optional[str]) -> Optional[str]: + """Remove password from Redis URL for safe serialization.""" + if url is None: + return None + parsed = urlparse(url) + # Replace password with **** + netloc = parsed.hostname or "" + if parsed.port: + netloc = f"{netloc}:{parsed.port}" + if parsed.username: + netloc = f"{parsed.username}:****@{netloc}" + return urlunparse((parsed.scheme, netloc, parsed.path, "", "", "")) + + def to_dict(self, include_connection: bool = False) -> Dict[str, Any]: + """Serialize the index configuration to a dictionary. + + Args: + include_connection (bool, optional): Whether to include connection + parameters. Defaults to False for security (passwords/URLs + are excluded by default). + + Returns: + Dict[str, Any]: Dictionary representation of the index configuration. + + Example: + >>> config = index.to_dict() + >>> new_index = SearchIndex.from_dict(config) + """ + config = self.schema.to_dict() + if include_connection: + # Sanitize URL to remove password + sanitized_url = self._sanitize_redis_url(self._redis_url) + if sanitized_url is not None: + config["_redis_url"] = sanitized_url + # Only include non-sensitive connection kwargs + safe_keys = {"decode_responses", "ssl", "socket_timeout", "socket_connect_timeout"} + safe_kwargs = { + k: v for k, v in self._connection_kwargs.items() + if k in safe_keys + } + if safe_kwargs: + config["_connection_kwargs"] = safe_kwargs + return config + + def to_yaml(self, path: str, include_connection: bool = False, overwrite: bool = True) -> None: + """Serialize the index configuration to a YAML file. + + Args: + path (str): Path to write the YAML file. + include_connection (bool, optional): Whether to include connection + parameters. Defaults to False for security. + overwrite (bool, optional): Whether to overwrite existing file. + Defaults to True. If False and file exists, raises FileExistsError. + + Example: + >>> index.to_yaml("schemas/my_index.yaml") + """ + if not overwrite and os.path.exists(path): + raise FileExistsError(f"File already exists: {path}") + config = self.to_dict(include_connection=include_connection) + with open(path, "w") as f: + yaml.dump(config, f, default_flow_style=False, sort_keys=False) + def __enter__(self): return self @@ -2251,6 +2318,56 @@ def disconnect_sync(self): return sync_wrapper(self.disconnect)() + def to_dict(self, include_connection: bool = False) -> Dict[str, Any]: + """Serialize the index configuration to a dictionary. + + Args: + include_connection (bool, optional): Whether to include connection + parameters. Defaults to False for security (passwords/URLs + are excluded by default). + + Returns: + Dict[str, Any]: Dictionary representation of the index configuration. + + Example: + >>> config = index.to_dict() + >>> new_index = AsyncSearchIndex.from_dict(config) + """ + config = self.schema.to_dict() + if include_connection: + # Sanitize URL to remove password + sanitized_url = self._sanitize_redis_url(self._redis_url) + if sanitized_url is not None: + config["_redis_url"] = sanitized_url + # Only include non-sensitive connection kwargs + safe_keys = {"decode_responses", "ssl", "socket_timeout", "socket_connect_timeout"} + safe_kwargs = { + k: v for k, v in self._connection_kwargs.items() + if k in safe_keys + } + if safe_kwargs: + config["_connection_kwargs"] = safe_kwargs + return config + + def to_yaml(self, path: str, include_connection: bool = False, overwrite: bool = True) -> None: + """Serialize the index configuration to a YAML file. + + Args: + path (str): Path to write the YAML file. + include_connection (bool, optional): Whether to include connection + parameters. Defaults to False for security. + overwrite (bool, optional): Whether to overwrite existing file. + Defaults to True. If False and file exists, raises FileExistsError. + + Example: + >>> index.to_yaml("schemas/my_index.yaml") + """ + if not overwrite and os.path.exists(path): + raise FileExistsError(f"File already exists: {path}") + config = self.to_dict(include_connection=include_connection) + with open(path, "w") as f: + yaml.dump(config, f, default_flow_style=False, sort_keys=False) + async def __aenter__(self): return self diff --git a/tests/unit/test_index_serialization.py b/tests/unit/test_index_serialization.py new file mode 100644 index 00000000..8e27bb90 --- /dev/null +++ b/tests/unit/test_index_serialization.py @@ -0,0 +1,166 @@ +"""Tests for SearchIndex serialization helpers.""" +import tempfile +from pathlib import Path + +import pytest + +from redisvl.index import AsyncSearchIndex, SearchIndex +from redisvl.schema import IndexSchema + + +@pytest.fixture +def sample_schema(): + """Create a sample schema for testing.""" + return IndexSchema.from_dict({ + "index": { + "name": "test_index", + "prefix": "test:", + "storage_type": "hash", + }, + "fields": [ + {"name": "text", "type": "text"}, + {"name": "vector", "type": "vector", "attrs": {"dims": 128, "algorithm": "flat"}}, + ] + }) + + +class TestSearchIndexSerialization: + """Tests for SearchIndex serialization methods.""" + + def test_to_dict_without_connection(self, sample_schema): + """Test to_dict() excludes connection info by default.""" + index = SearchIndex( + schema=sample_schema, + redis_url="redis://localhost:6379", + ) + + config = index.to_dict() + + assert "index" in config + assert config["index"]["name"] == "test_index" + assert "_redis_url" not in config + assert "_connection_kwargs" not in config + + def test_to_dict_with_connection(self, sample_schema): + """Test to_dict() includes connection info when requested.""" + index = SearchIndex( + schema=sample_schema, + redis_url="redis://localhost:6379", + connection_kwargs={"ssl": True, "socket_timeout": 30, "password": "secret"}, + ) + + config = index.to_dict(include_connection=True) + + assert "_redis_url" in config + assert config["_redis_url"] == "redis://localhost:6379" + # Password should not be included (not in safe_keys) + assert "password" not in config.get("_connection_kwargs", {}) + # Safe keys should be included + assert config["_connection_kwargs"]["ssl"] is True + assert config["_connection_kwargs"]["socket_timeout"] == 30 + + def test_to_yaml_without_connection(self, sample_schema): + """Test to_yaml() writes schema to YAML file.""" + index = SearchIndex(schema=sample_schema) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + path = f.name + + try: + index.to_yaml(path) + + content = Path(path).read_text() + assert "test_index" in content + assert "text" in content + assert "vector" in content + finally: + Path(path).unlink() + + def test_to_yaml_with_connection(self, sample_schema): + """Test to_yaml() includes connection when requested.""" + index = SearchIndex( + schema=sample_schema, + redis_url="redis://localhost:6379", + ) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + path = f.name + + try: + index.to_yaml(path, include_connection=True) + + content = Path(path).read_text() + assert "_redis_url" in content + finally: + Path(path).unlink() + + def test_roundtrip_dict(self, sample_schema): + """Test that to_dict() and from_dict() roundtrip correctly.""" + original_index = SearchIndex( + schema=sample_schema, + connection_kwargs={"ssl": True}, + ) + + config = original_index.to_dict() + restored_index = SearchIndex.from_dict(config) + + assert restored_index.schema.index.name == original_index.schema.index.name + assert restored_index.schema.index.prefix == original_index.schema.index.prefix + + +class TestAsyncSearchIndexSerialization: + """Tests for AsyncSearchIndex serialization methods.""" + + def test_to_dict_without_connection(self, sample_schema): + """Test to_dict() excludes connection info by default.""" + index = AsyncSearchIndex( + schema=sample_schema, + redis_url="redis://localhost:6379", + ) + + config = index.to_dict() + + assert "index" in config + assert config["index"]["name"] == "test_index" + assert "_redis_url" not in config + + def test_to_dict_with_connection(self, sample_schema): + """Test to_dict() includes connection info when requested.""" + index = AsyncSearchIndex( + schema=sample_schema, + redis_url="redis://localhost:6379", + connection_kwargs={"ssl": True, "password": "secret"}, + ) + + config = index.to_dict(include_connection=True) + + assert "_redis_url" in config + # Password should not be included (not in safe_keys) + assert "password" not in config.get("_connection_kwargs", {}) + + def test_to_yaml(self, sample_schema): + """Test to_yaml() writes schema to YAML file.""" + index = AsyncSearchIndex(schema=sample_schema) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + path = f.name + + try: + index.to_yaml(path) + + content = Path(path).read_text() + assert "test_index" in content + finally: + Path(path).unlink() + + def test_roundtrip_dict(self, sample_schema): + """Test that to_dict() and from_dict() roundtrip correctly.""" + original_index = AsyncSearchIndex( + schema=sample_schema, + connection_kwargs={"ssl": True}, + ) + + config = original_index.to_dict() + restored_index = AsyncSearchIndex.from_dict(config) + + assert restored_index.schema.index.name == original_index.schema.index.name diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py index fa7ad69a..c0f96873 100644 --- a/tests/unit/test_query_types.py +++ b/tests/unit/test_query_types.py @@ -335,7 +335,6 @@ def test_text_query_with_string_filter(): assert "AND" not in query_string_wildcard -@pytest.mark.skip("Test is flaking") def test_text_query_word_weights(): # verify word weights get added into the raw Redis query syntax query = TextQuery( @@ -344,10 +343,19 @@ def test_text_query_word_weights(): text_weights={"alpha": 2, "delta": 0.555, "gamma": 0.95}, ) - assert ( - str(query) - == "@description:(query | string | alpha=>{$weight:2} | bravo | delta=>{$weight:0.555} | tango | alpha=>{$weight:2}) SCORER BM25STD WITHSCORES DIALECT 2 LIMIT 0 10" - ) + # Check query components without relying on exact token ordering + query_str = str(query) + assert "@description:(" in query_str + assert "alpha=>{$weight:2}" in query_str + assert "delta=>{$weight:0.555}" in query_str + assert "query" in query_str + assert "string" in query_str + assert "bravo" in query_str + assert "tango" in query_str + assert "SCORER BM25STD" in query_str + assert "WITHSCORES" in query_str + assert "DIALECT 2" in query_str + assert "LIMIT 0 10" in query_str # raise an error if weights are not positive floats with pytest.raises(ValueError):