66from urllib .parse import urlparse
77
88import pytest
9+ from redis import RedisCluster
910
1011from redis .backoff import NoBackoff , ExponentialBackoff
1112from redis .event import EventDispatcher , EventListenerInterface
2223from redis .client import Redis
2324from redis .maint_notifications import EndpointType , MaintNotificationsConfig
2425from redis .retry import Retry
25- from tests .test_scenario .fault_injector_client import FaultInjectorClient
26+ from tests .test_scenario .fault_injector_client import (
27+ ProxyServerFaultInjector ,
28+ REFaultInjector ,
29+ )
2630
2731RELAXED_TIMEOUT = 30
2832CLIENT_TIMEOUT = 5
2933
3034DEFAULT_ENDPOINT_NAME = "m-standard"
35+ DEFAULT_OSS_API_ENDPOINT_NAME = "oss-api"
3136
3237
3338class CheckActiveDatabaseChangedListener (EventListenerInterface ):
@@ -38,13 +43,24 @@ def listen(self, event: ActiveDatabaseChanged):
3843 self .is_changed_flag = True
3944
4045
46+ def use_mock_proxy ():
47+ return os .getenv ("REDIS_ENTERPRISE_TESTS" , "true" ).lower () == "false"
48+
49+
4150@pytest .fixture ()
4251def endpoint_name (request ):
4352 return request .config .getoption ("--endpoint-name" ) or os .getenv (
4453 "REDIS_ENDPOINT_NAME" , DEFAULT_ENDPOINT_NAME
4554 )
4655
4756
57+ @pytest .fixture ()
58+ def cluster_endpoint_name (request ):
59+ return request .config .getoption ("--cluster-endpoint-name" ) or os .getenv (
60+ "REDIS_CLUSTER_ENDPOINT_NAME" , DEFAULT_OSS_API_ENDPOINT_NAME
61+ )
62+
63+
4864def get_endpoints_config (endpoint_name : str ):
4965 endpoints_config = os .getenv ("REDIS_ENDPOINTS_CONFIG_PATH" , None )
5066
@@ -67,10 +83,27 @@ def endpoints_config(endpoint_name: str):
6783 return get_endpoints_config (endpoint_name )
6884
6985
86+ @pytest .fixture ()
87+ def cluster_endpoints_config (cluster_endpoint_name : str ):
88+ return get_endpoints_config (cluster_endpoint_name )
89+
90+
7091@pytest .fixture ()
7192def fault_injector_client ():
72- url = os .getenv ("FAULT_INJECTION_API_URL" , "http://127.0.0.1:20324" )
73- return FaultInjectorClient (url )
93+ if use_mock_proxy ():
94+ return ProxyServerFaultInjector (oss_cluster = False )
95+ else :
96+ url = os .getenv ("FAULT_INJECTION_API_URL" , "http://127.0.0.1:20324" )
97+ return REFaultInjector (url )
98+
99+
100+ @pytest .fixture ()
101+ def fault_injector_client_oss_api ():
102+ if use_mock_proxy ():
103+ return ProxyServerFaultInjector (oss_cluster = True )
104+ else :
105+ url = os .getenv ("FAULT_INJECTION_API_URL" , "http://127.0.0.1:20324" )
106+ return REFaultInjector (url )
74107
75108
76109@pytest .fixture ()
@@ -208,8 +241,6 @@ def _get_client_maint_notifications(
208241 endpoint_type = endpoint_type ,
209242 )
210243
211- # Create Redis client with maintenance notifications config
212- # This will automatically create the MaintNotificationsPoolHandler
213244 if disable_retries :
214245 retry = Retry (NoBackoff (), 0 )
215246 else :
@@ -218,6 +249,8 @@ def _get_client_maint_notifications(
218249 tls_enabled = True if parsed .scheme == "rediss" else False
219250 logging .info (f"TLS enabled: { tls_enabled } " )
220251
252+ # Create Redis client with maintenance notifications config
253+ # This will automatically create the MaintNotificationsPoolHandler
221254 client = Redis (
222255 host = host ,
223256 port = port ,
@@ -235,3 +268,76 @@ def _get_client_maint_notifications(
235268 logging .info (f"Client uses Protocol: { client .connection_pool .get_protocol ()} " )
236269
237270 return client
271+
272+
273+ @pytest .fixture ()
274+ def cluster_client_maint_notifications (cluster_endpoints_config ):
275+ return _get_cluster_client_maint_notifications (cluster_endpoints_config )
276+
277+
278+ def _get_cluster_client_maint_notifications (
279+ endpoints_config ,
280+ protocol : int = 3 ,
281+ enable_maintenance_notifications : bool = True ,
282+ endpoint_type : Optional [EndpointType ] = None ,
283+ enable_relaxed_timeout : bool = True ,
284+ enable_proactive_reconnect : bool = True ,
285+ disable_retries : bool = False ,
286+ socket_timeout : Optional [float ] = None ,
287+ host_config : Optional [str ] = None ,
288+ ):
289+ """Create Redis cluster client with maintenance notifications enabled."""
290+ # Get credentials from the configuration
291+ username = endpoints_config .get ("username" )
292+ password = endpoints_config .get ("password" )
293+
294+ # Parse host and port from endpoints URL
295+ endpoints = endpoints_config .get ("endpoints" , [])
296+ if not endpoints :
297+ raise ValueError ("No endpoints found in configuration" )
298+
299+ parsed = urlparse (endpoints [0 ])
300+ host = parsed .hostname
301+ port = parsed .port
302+
303+ if not host :
304+ raise ValueError (f"Could not parse host from endpoint URL: { endpoints [0 ]} " )
305+
306+ logging .info (f"Connecting to Redis Enterprise: { host } :{ port } with user: { username } " )
307+
308+ if disable_retries :
309+ retry = Retry (NoBackoff (), 0 )
310+ else :
311+ retry = Retry (backoff = ExponentialWithJitterBackoff (base = 1 , cap = 10 ), retries = 3 )
312+
313+ tls_enabled = True if parsed .scheme == "rediss" else False
314+ logging .info (f"TLS enabled: { tls_enabled } " )
315+
316+ # Configure maintenance notifications
317+ maintenance_config = MaintNotificationsConfig (
318+ enabled = enable_maintenance_notifications ,
319+ proactive_reconnect = enable_proactive_reconnect ,
320+ relaxed_timeout = RELAXED_TIMEOUT if enable_relaxed_timeout else - 1 ,
321+ endpoint_type = endpoint_type ,
322+ )
323+
324+ # Create Redis cluster client with maintenance notifications config
325+ client = RedisCluster (
326+ host = host ,
327+ port = port ,
328+ socket_timeout = CLIENT_TIMEOUT if socket_timeout is None else socket_timeout ,
329+ username = username ,
330+ password = password ,
331+ ssl = tls_enabled ,
332+ ssl_cert_reqs = "none" ,
333+ ssl_check_hostname = False ,
334+ protocol = protocol , # RESP3 required for push notifications
335+ maint_notifications_config = maintenance_config ,
336+ retry = retry ,
337+ )
338+ logging .info ("Redis cluster client created with maintenance notifications enabled" )
339+ logging .info (
340+ f"Cluster working with the following nodes: { [(node .name , node .server_type ) for node in client .get_nodes ()]} "
341+ )
342+
343+ return client
0 commit comments