diff --git a/scim2_models/rfc7643/resource.py b/scim2_models/rfc7643/resource.py index ed89b4f..b8a68c8 100644 --- a/scim2_models/rfc7643/resource.py +++ b/scim2_models/rfc7643/resource.py @@ -4,7 +4,6 @@ from typing import Generic from typing import Optional from typing import TypeVar -from typing import Union from typing import get_args from typing import get_origin @@ -24,6 +23,7 @@ from ..base import Uniqueness from ..base import URIReference from ..base import is_complex_attribute +from ..utils import UNION_TYPES from ..utils import normalize_attribute_name @@ -117,7 +117,7 @@ def __new__(cls, name, bases, attrs, **kwargs): extensions = kwargs["__pydantic_generic_metadata__"]["args"][0] extensions = ( get_args(extensions) - if get_origin(extensions) == Union + if get_origin(extensions) in UNION_TYPES else [extensions] ) for extension in extensions: @@ -183,7 +183,8 @@ def get_extension_models(cls) -> dict[str, type[Extension]]: extension_models = cls.__pydantic_generic_metadata__.get("args", []) extension_models = ( get_args(extension_models[0]) - if len(extension_models) == 1 and get_origin(extension_models[0]) == Union + if len(extension_models) == 1 + and get_origin(extension_models[0]) in UNION_TYPES else extension_models ) @@ -301,7 +302,7 @@ def model_to_schema(model: type[BaseModel]): def get_reference_types(type) -> list[str]: first_arg = get_args(type)[0] - types = get_args(first_arg) if get_origin(first_arg) == Union else [first_arg] + types = get_args(first_arg) if get_origin(first_arg) in UNION_TYPES else [first_arg] def serialize_ref_type(ref_type): if ref_type == URIReference: diff --git a/scim2_models/rfc7643/resource_type.py b/scim2_models/rfc7643/resource_type.py index 17befba..e7b60e3 100644 --- a/scim2_models/rfc7643/resource_type.py +++ b/scim2_models/rfc7643/resource_type.py @@ -1,5 +1,7 @@ from typing import Annotated from typing import Optional +from typing import get_args +from typing import get_origin from pydantic import Field from typing_extensions import Self @@ -11,6 +13,7 @@ from ..base import Required from ..base import Returned from ..base import URIReference +from ..utils import UNION_TYPES from .resource import Resource @@ -82,7 +85,16 @@ def from_resource(cls, resource_model: type[Resource]) -> Self: """Build a naive ResourceType from a resource model.""" schema = resource_model.model_fields["schemas"].default[0] name = schema.split(":")[-1] - extensions = resource_model.__pydantic_generic_metadata__["args"] + if resource_model.__pydantic_generic_metadata__["args"]: + extensions = resource_model.__pydantic_generic_metadata__["args"][0] + extensions = ( + get_args(extensions) + if get_origin(extensions) in UNION_TYPES + else [extensions] + ) + else: + extensions = [] + return ResourceType( id=name, name=name, diff --git a/scim2_models/rfc7644/list_response.py b/scim2_models/rfc7644/list_response.py index af55ec6..1616890 100644 --- a/scim2_models/rfc7644/list_response.py +++ b/scim2_models/rfc7644/list_response.py @@ -20,6 +20,7 @@ from ..base import Context from ..base import Required from ..rfc7643.resource import AnyResource +from ..utils import UNION_TYPES from .message import Message @@ -29,7 +30,7 @@ def tagged_resource_union(resource_union): https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions """ - if not get_origin(resource_union) == Union: + if get_origin(resource_union) not in UNION_TYPES: return resource_union resource_types = get_args(resource_union) diff --git a/tests/test_resource_type.py b/tests/test_resource_type.py index 2c60fa7..1332bf0 100644 --- a/tests/test_resource_type.py +++ b/tests/test_resource_type.py @@ -1,5 +1,10 @@ +from typing import Annotated +from typing import Union + from scim2_models import EnterpriseUser +from scim2_models import Extension from scim2_models import Reference +from scim2_models import Required from scim2_models import ResourceType from scim2_models import User @@ -61,3 +66,32 @@ def test_from_resource_with_extensions(): == "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User" ) assert not enterprise_user_rt.schema_extensions[0].required + + +def test_from_resource_with_mulitple_extensions(): + class TestExtension(Extension): + schemas: Annotated[list[str], Required.true] = [ + "urn:ietf:params:scim:schemas:extension:Test:1.0:User" + ] + + test: Union[str, None] = None + test2: Union[list[str], None] = None + + enterprise_user_rt = ResourceType.from_resource( + User[Union[EnterpriseUser, TestExtension]] + ) + assert enterprise_user_rt.id == "User" + assert enterprise_user_rt.name == "User" + assert enterprise_user_rt.description == "User" + assert enterprise_user_rt.endpoint == "/Users" + assert enterprise_user_rt.schema_ == "urn:ietf:params:scim:schemas:core:2.0:User" + assert ( + enterprise_user_rt.schema_extensions[0].schema_ + == "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User" + ) + assert not enterprise_user_rt.schema_extensions[0].required + assert ( + enterprise_user_rt.schema_extensions[1].schema_ + == "urn:ietf:params:scim:schemas:extension:Test:1.0:User" + ) + assert not enterprise_user_rt.schema_extensions[1].required