Skip to content

Commit 74c05ca

Browse files
committed
fix: read/write splitting switch to writer
1 parent 2f07ddd commit 74c05ca

4 files changed

Lines changed: 153 additions & 41 deletions

File tree

aws_advanced_python_wrapper/read_write_splitting_plugin.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,6 @@ def _set_reader_connection(
185185

186186
def _initialize_writer_connection(self):
187187
conn, writer_host = self._connection_handler.open_new_writer_connection(lambda x: self._plugin_service.connect(x, self._properties, self))
188-
189188
if conn is None:
190189
self.log_and_raise_exception(
191190
"ReadWriteSplittingPlugin.FailedToConnectToWriter"
@@ -280,13 +279,18 @@ def _switch_to_writer_connection(self):
280279
# Already connected to the intended writer.
281280
return
282281

282+
self._writer_host_info = self._connection_handler.get_writer()
283283
self._in_read_write_split = True
284284
if not self._is_connection_usable(self._writer_connection, driver_dialect):
285285
self._initialize_writer_connection()
286286
elif self._writer_connection is not None and self._writer_host_info is not None:
287-
self._switch_current_connection_to(
288-
self._writer_connection, self._writer_host_info
289-
)
287+
if self._connection_handler.can_host_be_used(self._writer_host_info):
288+
self._switch_current_connection_to(
289+
self._writer_connection, self._writer_host_info
290+
)
291+
else:
292+
ReadWriteSplittingConnectionManager.log_and_raise_exception(
293+
"ReadWriteSplittingPlugin.NoWriterFound")
290294

291295
if self._is_reader_conn_from_internal_pool:
292296
self._close_connection_if_idle(self._reader_connection)
@@ -508,6 +512,10 @@ def refresh_and_store_host_list(
508512
"""Refreshes the host list and then stores it."""
509513
...
510514

515+
def get_writer(self) -> Optional[HostInfo]:
516+
"""Get the current writer host info."""
517+
...
518+
511519

512520
class TopologyBasedConnectionHandler(ReadWriteConnectionHandler):
513521
"""Topology based implementation of connection handling logic."""
@@ -538,7 +546,7 @@ def open_new_writer_connection(
538546
self,
539547
plugin_service_connect_func: Callable[[HostInfo], Connection],
540548
) -> tuple[Optional[Connection], Optional[HostInfo]]:
541-
writer_host = self._get_writer()
549+
writer_host = self.get_writer()
542550
if writer_host is None:
543551
return None, None
544552

@@ -621,7 +629,7 @@ def can_host_be_used(self, host_info: HostInfo) -> bool:
621629

622630
def has_no_readers(self) -> bool:
623631
if len(self._hosts) == 1:
624-
return self._get_writer() is not None
632+
return self.get_writer() is not None
625633
return False
626634

627635
def refresh_and_store_host_list(
@@ -657,14 +665,11 @@ def is_writer_host(self, current_host: HostInfo) -> bool:
657665
def is_reader_host(self, current_host) -> bool:
658666
return current_host.role == HostRole.READER
659667

660-
def _get_writer(self) -> Optional[HostInfo]:
668+
def get_writer(self) -> Optional[HostInfo]:
661669
for host in self._hosts:
662670
if host.role == HostRole.WRITER:
663671
return host
664672

665-
ReadWriteSplittingConnectionManager.log_and_raise_exception(
666-
"ReadWriteSplittingPlugin.NoWriterFound"
667-
)
668673
return None
669674

670675

aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,9 @@ def is_reader_host(self, current_host: HostInfo) -> bool:
228228
or current_host.url.casefold() == self._read_endpoint
229229
)
230230

231+
def get_writer(self) -> Optional[HostInfo]:
232+
return self._write_endpoint_host_info
233+
231234
def _create_host_info(self, endpoint: str, role: HostRole) -> HostInfo:
232235
endpoint = endpoint.strip()
233236
host = endpoint

tests/integration/container/test_custom_endpoint.py

Lines changed: 115 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from aws_advanced_python_wrapper import AwsWrapperConnection
3030
from aws_advanced_python_wrapper.errors import (FailoverSuccessError,
3131
ReadWriteSplittingError)
32+
from aws_advanced_python_wrapper.hostinfo import HostRole
3233
from aws_advanced_python_wrapper.utils.log import Logger
3334
from aws_advanced_python_wrapper.utils.properties import (Properties,
3435
WrapperProperties)
@@ -195,7 +196,42 @@ def test_custom_endpoint_failover(self, test_driver: TestDriver, conn_utils, pro
195196

196197
conn.close()
197198

198-
def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes(
199+
def _setup_custom_endpoint_role(self, target_driver_connect, conn_kwargs, rds_utils, host_role: HostRole):
200+
self.logger.debug("Setting up custom endpoint instance with role: " + host_role.name)
201+
props = {'plugins': ''}
202+
original_writer = rds_utils.get_cluster_writer_instance_id()
203+
failover_target = None
204+
with AwsWrapperConnection.connect(target_driver_connect, **conn_kwargs, **props) as conn:
205+
endpoint_members = self.endpoint_info["StaticMembers"]
206+
original_instance_id = rds_utils.query_instance_id(conn)
207+
self.logger.debug("Original instance id: " + original_instance_id)
208+
assert original_instance_id in endpoint_members
209+
210+
if host_role == HostRole.WRITER:
211+
if original_instance_id == original_writer:
212+
self.logger.debug("Role is already " + host_role.name + ", no failover needed.")
213+
return # Do nothing, no need to failover.
214+
self.logger.debug("Failing over to get writer role...")
215+
elif host_role == HostRole.READER:
216+
if original_instance_id != original_writer:
217+
self.logger.debug("Role is already " + host_role.name + ", no failover needed.")
218+
return # Do nothing, no need to failover.
219+
self.logger.debug("Failing over to get reader role...")
220+
221+
rds_utils.failover_cluster_and_wait_until_writer_changed(target_id=failover_target)
222+
223+
self.logger.debug("Verifying that new connection has role: " + host_role.name)
224+
# Verify that new connection is now the correct role
225+
with AwsWrapperConnection.connect(target_driver_connect, **conn_kwargs, **props) as conn:
226+
endpoint_members = self.endpoint_info["StaticMembers"]
227+
original_instance_id = rds_utils.query_instance_id(conn)
228+
assert original_instance_id in endpoint_members
229+
230+
new_role = rds_utils.query_host_role(conn, TestEnvironment.get_current().get_engine())
231+
assert new_role == host_role
232+
self.logger.debug("Custom endpoint instance successfully set to role: " + host_role.name)
233+
234+
def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__with_reader_as_init_conn(
199235
self, test_driver: TestDriver, conn_utils, props, rds_utils):
200236
target_driver_connect = DriverHelper.get_connect_func(test_driver)
201237
kwargs = conn_utils.get_connect_params()
@@ -204,34 +240,87 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes(
204240
# it takes more than 30 seconds to modify the cluster endpoint (usually around 140s).
205241
props["custom_endpoint_idle_monitor_expiration_ms"] = 30_000
206242
props["wait_for_custom_endpoint_info_timeout_ms"] = 30_000
207-
conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props)
208243

244+
# Ensure that we are starting with a reader connection
245+
self._setup_custom_endpoint_role(target_driver_connect, kwargs, rds_utils, HostRole.READER)
246+
247+
conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props)
209248
endpoint_members = self.endpoint_info["StaticMembers"]
210249
original_instance_id = rds_utils.query_instance_id(conn)
211250
assert original_instance_id in endpoint_members
212251

213252
# Attempt to switch to an instance of the opposite role. This should fail since the custom endpoint consists
214253
# only of the current host.
215-
new_read_only_value = original_instance_id == rds_utils.get_cluster_writer_instance_id()
216-
if new_read_only_value:
217-
# We are connected to the writer. Attempting to switch to the reader will not work but will intentionally
218-
# not throw an exception. In this scenario we log a warning and purposefully stick with the writer.
219-
self.logger.debug("Initial connection is to the writer. Attempting to switch to reader...")
220-
conn.read_only = new_read_only_value
254+
self.logger.debug("Initial connection is to a reader. Attempting to switch to writer...")
255+
with pytest.raises(ReadWriteSplittingError):
256+
conn.read_only = False
257+
258+
writer_id = rds_utils.get_cluster_writer_instance_id()
259+
new_member = writer_id
260+
261+
rds_client = client('rds', region_name=TestEnvironment.get_current().get_aurora_region())
262+
rds_client.modify_db_cluster_endpoint(
263+
DBClusterEndpointIdentifier=self.endpoint_id,
264+
StaticMembers=[original_instance_id, new_member]
265+
)
266+
267+
try:
268+
self.wait_until_endpoint_has_members(rds_client, {original_instance_id, new_member})
269+
270+
# We should now be able to switch to the writer (new_member).
271+
conn.read_only = False
221272
new_instance_id = rds_utils.query_instance_id(conn)
222-
assert new_instance_id == original_instance_id
223-
else:
224-
# We are connected to the reader. Attempting to switch to the writer will throw an exception.
225-
self.logger.debug("Initial connection is to a reader. Attempting to switch to writer...")
226-
with pytest.raises(ReadWriteSplittingError):
227-
conn.read_only = new_read_only_value
273+
assert new_instance_id == new_member
274+
275+
# Switch back to original instance
276+
conn.read_only = True
277+
finally:
278+
rds_client.modify_db_cluster_endpoint(
279+
DBClusterEndpointIdentifier=self.endpoint_id,
280+
StaticMembers=[original_instance_id])
281+
self.wait_until_endpoint_has_members(rds_client, {original_instance_id})
282+
283+
# We should not be able to switch again because new_member was removed from the custom endpoint.
284+
# We are connected to the reader. Attempting to switch to the writer will throw an exception.
285+
with pytest.raises(ReadWriteSplittingError):
286+
conn.read_only = False
287+
288+
conn.close()
289+
290+
def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__with_writer_as_init_conn(
291+
self, test_driver: TestDriver, conn_utils, props, rds_utils):
292+
target_driver_connect = DriverHelper.get_connect_func(test_driver)
293+
kwargs = conn_utils.get_connect_params()
294+
kwargs["host"] = self.endpoint_info["Endpoint"]
295+
# This setting is not required for the test, but it allows us to also test re-creation of expired monitors since
296+
# it takes more than 30 seconds to modify the cluster endpoint (usually around 140s).
297+
props["custom_endpoint_idle_monitor_expiration_ms"] = 30_000
298+
props["wait_for_custom_endpoint_info_timeout_ms"] = 30_000
299+
300+
# Ensure that we are starting with a writer connection
301+
self._setup_custom_endpoint_role(target_driver_connect, kwargs, rds_utils, HostRole.WRITER)
302+
conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props)
303+
304+
endpoint_members = self.endpoint_info["StaticMembers"]
305+
original_instance_id = str(rds_utils.query_instance_id(conn))
306+
assert original_instance_id in endpoint_members
307+
308+
# We are connected to the writer. Attempting to switch to the reader will not work but will intentionally
309+
# not throw an exception. In this scenario we log a warning and purposefully stick with the writer.
310+
self.logger.debug("Initial connection is to the writer. Attempting to switch to reader...")
311+
conn.read_only = True
312+
new_instance_id = rds_utils.query_instance_id(conn)
313+
assert new_instance_id == original_instance_id
228314

229315
instances = TestEnvironment.get_current().get_instances()
230-
writer_id = rds_utils.get_cluster_writer_instance_id()
231-
if original_instance_id == writer_id:
232-
new_member = instances[1].get_instance_id()
233-
else:
234-
new_member = writer_id
316+
writer_id = str(rds_utils.get_cluster_writer_instance_id())
317+
318+
new_member = ""
319+
# Get any reader id
320+
for instance in instances:
321+
if instance.get_instance_id() != writer_id:
322+
new_member = instance.get_instance_id()
323+
break
235324

236325
rds_client = client('rds', region_name=TestEnvironment.get_current().get_aurora_region())
237326
rds_client.modify_db_cluster_endpoint(
@@ -243,28 +332,23 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes(
243332
self.wait_until_endpoint_has_members(rds_client, {original_instance_id, new_member})
244333

245334
# We should now be able to switch to new_member.
246-
conn.read_only = new_read_only_value
335+
conn.read_only = True
247336
new_instance_id = rds_utils.query_instance_id(conn)
248337
assert new_instance_id == new_member
249338

250339
# Switch back to original instance
251-
conn.read_only = not new_read_only_value
340+
conn.read_only = False
252341
finally:
253342
rds_client.modify_db_cluster_endpoint(
254343
DBClusterEndpointIdentifier=self.endpoint_id,
255344
StaticMembers=[original_instance_id])
256345
self.wait_until_endpoint_has_members(rds_client, {original_instance_id})
257346

258347
# We should not be able to switch again because new_member was removed from the custom endpoint.
259-
if new_read_only_value:
260-
# We are connected to the writer. Attempting to switch to the reader will not work but will intentionally
261-
# not throw an exception. In this scenario we log a warning and purposefully stick with the writer.
262-
conn.read_only = new_read_only_value
263-
new_instance_id = rds_utils.query_instance_id(conn)
264-
assert new_instance_id == original_instance_id
265-
else:
266-
# We are connected to the reader. Attempting to switch to the writer will throw an exception.
267-
with pytest.raises(ReadWriteSplittingError):
268-
conn.read_only = new_read_only_value
348+
# We are connected to the writer. Attempting to switch to the reader will not work but will intentionally
349+
# not throw an exception. In this scenario we log a warning and purposefully stick with the writer.
350+
conn.read_only = True
351+
new_instance_id = rds_utils.query_instance_id(conn)
352+
assert new_instance_id == original_instance_id
269353

270354
conn.close()

tests/integration/container/utils/rds_test_utility.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
from aws_advanced_python_wrapper.driver_info import DriverInfo
3838
from aws_advanced_python_wrapper.errors import UnsupportedOperationError
39+
from aws_advanced_python_wrapper.hostinfo import HostRole
3940
from aws_advanced_python_wrapper.utils.log import Logger
4041
from aws_advanced_python_wrapper.utils.messages import Messages
4142
from .database_engine import DatabaseEngine
@@ -255,6 +256,25 @@ def query_instance_id(
255256
raise RuntimeError(Messages.get_formatted(
256257
"RdsTestUtility.MethodNotSupportedForDeployment", "query_instance_id", database_deployment))
257258

259+
def query_host_role(
260+
self,
261+
conn,
262+
database_engine: DatabaseEngine) -> HostRole:
263+
if database_engine == DatabaseEngine.MYSQL:
264+
is_reader_query = "SELECT @@innodb_read_only"
265+
elif database_engine == DatabaseEngine.PG:
266+
is_reader_query = "SELECT pg_catalog.pg_is_in_recovery()"
267+
268+
with closing(conn.cursor()) as cursor:
269+
cursor.execute(is_reader_query)
270+
record = cursor.fetchone()
271+
is_reader = record[0]
272+
273+
if is_reader in (1, True):
274+
return HostRole.READER
275+
else:
276+
return HostRole.WRITER
277+
258278
def _query_aurora_instance_id(self, conn: Connection, engine: DatabaseEngine) -> str:
259279
if engine == DatabaseEngine.MYSQL:
260280
sql = "SELECT @@aurora_server_id"

0 commit comments

Comments
 (0)