From e179fbd33d72a2b05d29d04fcf3e3b2188c5482a Mon Sep 17 00:00:00 2001 From: Fanis Tharropoulos Date: Mon, 4 May 2026 18:29:57 +0300 Subject: [PATCH] fix(types): type multi-search union requests and responses - add typed union/standard request schemas for multi-search - return `SearchResponse` when `union=true` in sync and async clients --- src/typesense/async_/multi_search.py | 44 +++++++++++++++++++++++----- src/typesense/sync/multi_search.py | 44 +++++++++++++++++++++++----- src/typesense/types/multi_search.py | 28 ++++++++++++++++-- 3 files changed, 97 insertions(+), 19 deletions(-) diff --git a/src/typesense/async_/multi_search.py b/src/typesense/async_/multi_search.py index 466ac51..179066e 100644 --- a/src/typesense/async_/multi_search.py +++ b/src/typesense/async_/multi_search.py @@ -20,8 +20,14 @@ from .api_call import AsyncApiCall from typesense.preprocess import stringify_search_params -from typesense.types.document import MultiSearchCommonParameters -from typesense.types.multi_search import MultiSearchRequestSchema, MultiSearchResponse +from typesense.types.document import MultiSearchCommonParameters, SearchResponse +from typesense.types.multi_search import ( + MultiSearchRequestSchema, + MultiSearchRequestSchemaMulti, + MultiSearchRequestSchemaUnion, + MultiSearchResponse, + MultiSearchResponseSchema, +) if sys.version_info >= (3, 11): import typing @@ -51,11 +57,27 @@ def __init__(self, api_call: AsyncApiCall) -> None: """ self.api_call = api_call + @typing.overload async def perform( self, - search_queries: MultiSearchRequestSchema, + search_queries: MultiSearchRequestSchemaUnion, + common_params: typing.Union[MultiSearchCommonParameters, None] = None, + ) -> SearchResponse[typing.Any]: + """Perform a union multi-search operation.""" + + @typing.overload + async def perform( # type: ignore[overload-cannot-match] + self, + search_queries: MultiSearchRequestSchemaMulti, common_params: typing.Union[MultiSearchCommonParameters, None] = None, ) -> MultiSearchResponse: + """Perform a standard multi-search operation.""" + + async def perform( + self, + search_queries: MultiSearchRequestSchema, + common_params: typing.Union[MultiSearchCommonParameters, None] = None, + ) -> MultiSearchResponseSchema: """ Perform a multi-search operation. @@ -72,9 +94,9 @@ async def perform( Common parameters to apply to all search queries. Defaults to None. Returns: - MultiSearchResponse: - The response from the multi-search operation, containing - the results of all search queries. + MultiSearchResponseSchema: + A standard multi-search response for non-union requests, + or a search response when ``union=True``. Example: >>> multi_search = AsyncMultiSearch(async_api_call) @@ -98,11 +120,17 @@ async def perform( "searches": stringified_search_params, "union": search_queries.get("union", False), } - response: MultiSearchResponse = await self.api_call.post( + entity_type: typing.Type[typing.Any] + if search_body["union"]: + entity_type = SearchResponse + else: + entity_type = MultiSearchResponse + + response: MultiSearchResponseSchema = await self.api_call.post( AsyncMultiSearch.resource_path, body=search_body, params=common_params, as_json=True, - entity_type=MultiSearchResponse, + entity_type=entity_type, ) return response diff --git a/src/typesense/sync/multi_search.py b/src/typesense/sync/multi_search.py index 2c81be6..38869e5 100644 --- a/src/typesense/sync/multi_search.py +++ b/src/typesense/sync/multi_search.py @@ -20,8 +20,14 @@ from .api_call import ApiCall from typesense.preprocess import stringify_search_params -from typesense.types.document import MultiSearchCommonParameters -from typesense.types.multi_search import MultiSearchRequestSchema, MultiSearchResponse +from typesense.types.document import MultiSearchCommonParameters, SearchResponse +from typesense.types.multi_search import ( + MultiSearchRequestSchema, + MultiSearchRequestSchemaMulti, + MultiSearchRequestSchemaUnion, + MultiSearchResponse, + MultiSearchResponseSchema, +) if sys.version_info >= (3, 11): import typing @@ -51,11 +57,27 @@ def __init__(self, api_call: ApiCall) -> None: """ self.api_call = api_call + @typing.overload def perform( self, - search_queries: MultiSearchRequestSchema, + search_queries: MultiSearchRequestSchemaUnion, + common_params: typing.Union[MultiSearchCommonParameters, None] = None, + ) -> SearchResponse[typing.Any]: + """Perform a union multi-search operation.""" + + @typing.overload + def perform( # type: ignore[overload-cannot-match] + self, + search_queries: MultiSearchRequestSchemaMulti, common_params: typing.Union[MultiSearchCommonParameters, None] = None, ) -> MultiSearchResponse: + """Perform a standard multi-search operation.""" + + def perform( + self, + search_queries: MultiSearchRequestSchema, + common_params: typing.Union[MultiSearchCommonParameters, None] = None, + ) -> MultiSearchResponseSchema: """ Perform a multi-search operation. @@ -72,9 +94,9 @@ def perform( Common parameters to apply to all search queries. Defaults to None. Returns: - MultiSearchResponse: - The response from the multi-search operation, containing - the results of all search queries. + MultiSearchResponseSchema: + A standard multi-search response for non-union requests, + or a search response when ``union=True``. Example: >>> multi_search = MultiSearch(async_api_call) @@ -98,11 +120,17 @@ def perform( "searches": stringified_search_params, "union": search_queries.get("union", False), } - response: MultiSearchResponse = self.api_call.post( + entity_type: typing.Type[typing.Any] + if search_body["union"]: + entity_type = SearchResponse + else: + entity_type = MultiSearchResponse + + response: MultiSearchResponseSchema = self.api_call.post( MultiSearch.resource_path, body=search_body, params=common_params, as_json=True, - entity_type=MultiSearchResponse, + entity_type=entity_type, ) return response diff --git a/src/typesense/types/multi_search.py b/src/typesense/types/multi_search.py index 3619c0b..d3ca860 100644 --- a/src/typesense/types/multi_search.py +++ b/src/typesense/types/multi_search.py @@ -21,13 +21,35 @@ class MultiSearchResponse(typing.TypedDict): results: typing.List[SearchResponse[typing.Any]] # noqa: WPS110 -class MultiSearchRequestSchema(typing.TypedDict): +class MultiSearchRequestSchemaUnion(typing.TypedDict): """ - Schema for multi-search request. + Schema for union multi-search request. Attributes: searches (list[MultiSearchParameters]): The search parameters. """ - union: typing.NotRequired[typing.Literal[True]] + union: typing.Literal[True] searches: typing.List[MultiSearchParameters] + + +class MultiSearchRequestSchemaMulti(typing.TypedDict): + """ + Schema for standard multi-search request. + + Attributes: + searches (list[MultiSearchParameters]): The search parameters. + """ + + union: typing.NotRequired[typing.Literal[False]] + searches: typing.List[MultiSearchParameters] + + +MultiSearchRequestSchema = typing.Union[ + MultiSearchRequestSchemaUnion, + MultiSearchRequestSchemaMulti, +] +MultiSearchResponseSchema = typing.Union[ + MultiSearchResponse, + SearchResponse[typing.Any], +]