@@ -630,6 +630,7 @@ class BranchTestCase(InstanceTestCase):
630630 BASE_TEST_CLASS = True
631631 TEARDOWN_RETRY_DROP_DB = 1
632632
633+ CLIENT_TYPE : ClassVar [type [TestClient | TestAsyncIOClient ] | None ]
633634 client : ClassVar [TestClient | TestAsyncIOClient ]
634635
635636 @classmethod
@@ -678,7 +679,9 @@ def setUp(self) -> None:
678679 if self .ISOLATED_TEST_BRANCHES :
679680 cls = type (self )
680681 testdb = cls .loop .run_until_complete (self .setup_branch_copy ())
681- client = cls .make_test_client (database = testdb )._with_debug (
682+ client = cls .make_test_client (
683+ database = testdb , client_class = self .CLIENT_TYPE
684+ )._with_debug (
682685 save_postcheck = True ,
683686 )
684687 self .client = client # type: ignore[misc]
@@ -717,6 +720,7 @@ def tearDown(self) -> None:
717720 def make_test_client (
718721 cls ,
719722 * ,
723+ client_class : type [TestClient | TestAsyncIOClient ] | None = None ,
720724 connection_class : type [
721725 asyncio_client .AsyncIOConnection
722726 | blocking_client .BlockingIOConnection
@@ -758,14 +762,17 @@ def make_blocking_test_client(
758762 cls ,
759763 * ,
760764 instance : _server .BaseInstance ,
765+ client_class : type [TestClient ] | None = None ,
761766 connection_class : type [blocking_client .BlockingIOConnection ]
762767 | None = None ,
763768 ** kwargs : str ,
764769 ) -> TestClient :
770+ if client_class is None :
771+ client_class = TestClient
765772 if connection_class is None :
766773 connection_class = blocking_client .BlockingIOConnection
767774 client = instance .create_blocking_client (
768- client_class = TestClient ,
775+ client_class = client_class ,
769776 connection_class = connection_class ,
770777 ** cls .get_connect_args (instance , ** kwargs ),
771778 )
@@ -799,13 +806,16 @@ def make_async_test_client(
799806 cls ,
800807 * ,
801808 instance : _server .BaseInstance ,
809+ client_class : type [TestAsyncIOClient ] | None = None ,
802810 connection_class : type [asyncio_client .AsyncIOConnection ] | None = None ,
803811 ** kwargs : str ,
804812 ) -> TestAsyncIOClient :
813+ if client_class is None :
814+ client_class = TestAsyncIOClient
805815 if connection_class is None :
806816 connection_class = asyncio_client .AsyncIOConnection
807817 client = instance .create_async_client (
808- client_class = TestAsyncIOClient ,
818+ client_class = client_class ,
809819 connection_class = connection_class ,
810820 ** cls .get_connect_args (instance , ** kwargs ),
811821 )
@@ -881,7 +891,9 @@ async def setup_and_connect(cls) -> None:
881891 await cls ._create_empty_branch (dbname )
882892
883893 if not cls .ISOLATED_TEST_BRANCHES :
884- cls .client = cls .make_test_client (database = dbname )
894+ cls .client = cls .make_test_client (
895+ database = dbname , client_class = cls .CLIENT_TYPE
896+ )
885897 if isinstance (cls .client , gel .AsyncIOClient ):
886898 await cls .client .ensure_connected ()
887899 else :
@@ -1021,11 +1033,13 @@ class AsyncQueryTestCase(BranchTestCase):
10211033 def make_test_client ( # pyright: ignore [reportIncompatibleMethodOverride]
10221034 cls ,
10231035 * ,
1036+ client_class : type [TestAsyncIOClient ] | None = None ,
10241037 connection_class : type [asyncio_client .AsyncIOConnection ] | None = None , # type: ignore [override]
10251038 ** kwargs : str ,
10261039 ) -> TestAsyncIOClient :
10271040 return cls .make_async_test_client (
10281041 instance = cls .instance ,
1042+ client_class = client_class ,
10291043 connection_class = connection_class ,
10301044 ** kwargs ,
10311045 )
@@ -1062,12 +1076,14 @@ def adapt_call(cls, coro: Any) -> Any:
10621076 def make_test_client ( # pyright: ignore [reportIncompatibleMethodOverride]
10631077 cls ,
10641078 * ,
1079+ client_class : type [TestClient ] | None = None ,
10651080 connection_class : type [blocking_client .BlockingIOConnection ] # type: ignore [override]
10661081 | None = None ,
10671082 ** kwargs : str ,
10681083 ) -> TestClient :
10691084 return cls .make_blocking_test_client (
10701085 instance = cls .instance ,
1086+ client_class = client_class ,
10711087 connection_class = connection_class ,
10721088 ** kwargs ,
10731089 )
0 commit comments