diff --git a/src/deepgram/core/serialization.py b/src/deepgram/core/serialization.py index c36e865c..21cacdab 100644 --- a/src/deepgram/core/serialization.py +++ b/src/deepgram/core/serialization.py @@ -3,6 +3,7 @@ import collections import inspect import typing +from functools import lru_cache import pydantic import typing_extensions @@ -59,6 +60,10 @@ def convert_and_respect_annotation_metadata( inner_type = annotation clean_type = _remove_annotations(inner_type) + + # Locally cache getting origin for the cleaned type + clean_type_origin = _get_origin_cached(clean_type) + # Pydantic models if ( inspect.isclass(clean_type) @@ -67,17 +72,14 @@ def convert_and_respect_annotation_metadata( ): return _convert_mapping(object_, clean_type, direction) # TypedDicts - if typing_extensions.is_typeddict(clean_type) and isinstance(object_, typing.Mapping): + if _is_typeddict_cached(clean_type) and isinstance(object_, typing.Mapping): return _convert_mapping(object_, clean_type, direction) - if ( - typing_extensions.get_origin(clean_type) == typing.Dict - or typing_extensions.get_origin(clean_type) == dict - or clean_type == typing.Dict - ) and isinstance(object_, typing.Dict): - key_type = typing_extensions.get_args(clean_type)[0] - value_type = typing_extensions.get_args(clean_type)[1] - + # Dict + if (clean_type_origin == typing.Dict or clean_type_origin == dict or clean_type == typing.Dict) and isinstance( + object_, typing.Dict + ): + key_type, value_type = _get_args_cached(clean_type) return { key: convert_and_respect_annotation_metadata( object_=value, @@ -90,53 +92,46 @@ def convert_and_respect_annotation_metadata( # If you're iterating on a string, do not bother to coerce it to a sequence. if not isinstance(object_, str): - if ( - typing_extensions.get_origin(clean_type) == typing.Set - or typing_extensions.get_origin(clean_type) == set - or clean_type == typing.Set - ) and isinstance(object_, typing.Set): - inner_type = typing_extensions.get_args(clean_type)[0] + # Set + if (clean_type_origin == typing.Set or clean_type_origin == set or clean_type == typing.Set) and isinstance( + object_, typing.Set + ): + (inner_container_type,) = _get_args_cached(clean_type) return { convert_and_respect_annotation_metadata( object_=item, annotation=annotation, - inner_type=inner_type, + inner_type=inner_container_type, direction=direction, ) for item in object_ } + # List/Sequence elif ( - ( - typing_extensions.get_origin(clean_type) == typing.List - or typing_extensions.get_origin(clean_type) == list - or clean_type == typing.List - ) + (clean_type_origin == typing.List or clean_type_origin == list or clean_type == typing.List) and isinstance(object_, typing.List) ) or ( ( - typing_extensions.get_origin(clean_type) == typing.Sequence - or typing_extensions.get_origin(clean_type) == collections.abc.Sequence + clean_type_origin == typing.Sequence + or clean_type_origin == collections.abc.Sequence or clean_type == typing.Sequence ) and isinstance(object_, typing.Sequence) ): - inner_type = typing_extensions.get_args(clean_type)[0] + (inner_container_type,) = _get_args_cached(clean_type) return [ convert_and_respect_annotation_metadata( object_=item, annotation=annotation, - inner_type=inner_type, + inner_type=inner_container_type, direction=direction, ) for item in object_ ] - if typing_extensions.get_origin(clean_type) == typing.Union: - # We should be able to ~relatively~ safely try to convert keys against all - # member types in the union, the edge case here is if one member aliases a field - # of the same name to a different name from another member - # Or if another member aliases a field of the same name that another member does not. - for member in typing_extensions.get_args(clean_type): + # Union + if clean_type_origin == typing.Union: + for member in _get_args_cached(clean_type): object_ = convert_and_respect_annotation_metadata( object_=object_, annotation=annotation, @@ -274,3 +269,21 @@ def _alias_key( if direction == "read": return aliases_to_field_names.get(key, key) return _get_alias_from_type(type_=type_) or key + + +# Cached/getter function for get_origin +@lru_cache(maxsize=128) +def _get_origin_cached(type_: typing.Any) -> typing.Any: + return typing_extensions.get_origin(type_) + + +# Cached/getter function for get_args +@lru_cache(maxsize=128) +def _get_args_cached(type_: typing.Any) -> tuple: + return typing_extensions.get_args(type_) + + +# Cached/getter function for is_typeddict +@lru_cache(maxsize=128) +def _is_typeddict_cached(type_: typing.Any) -> bool: + return typing_extensions.is_typeddict(type_)