Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions scim2_models/rfc7643/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Generic
from typing import Optional
from typing import TypeVar
from typing import Union
from typing import get_args
from typing import get_origin

Expand All @@ -24,6 +23,7 @@
from ..base import Uniqueness
from ..base import URIReference
from ..base import is_complex_attribute
from ..utils import UNION_TYPES
from ..utils import normalize_attribute_name


Expand Down Expand Up @@ -117,7 +117,7 @@ def __new__(cls, name, bases, attrs, **kwargs):
extensions = kwargs["__pydantic_generic_metadata__"]["args"][0]
extensions = (
get_args(extensions)
if get_origin(extensions) == Union
if get_origin(extensions) in UNION_TYPES
else [extensions]
)
for extension in extensions:
Expand Down Expand Up @@ -183,7 +183,8 @@ def get_extension_models(cls) -> dict[str, type[Extension]]:
extension_models = cls.__pydantic_generic_metadata__.get("args", [])
extension_models = (
get_args(extension_models[0])
if len(extension_models) == 1 and get_origin(extension_models[0]) == Union
if len(extension_models) == 1
and get_origin(extension_models[0]) in UNION_TYPES
else extension_models
)

Expand Down Expand Up @@ -301,7 +302,7 @@ def model_to_schema(model: type[BaseModel]):

def get_reference_types(type) -> list[str]:
first_arg = get_args(type)[0]
types = get_args(first_arg) if get_origin(first_arg) == Union else [first_arg]
types = get_args(first_arg) if get_origin(first_arg) in UNION_TYPES else [first_arg]

def serialize_ref_type(ref_type):
if ref_type == URIReference:
Expand Down
14 changes: 13 additions & 1 deletion scim2_models/rfc7643/resource_type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Annotated
from typing import Optional
from typing import get_args
from typing import get_origin

from pydantic import Field
from typing_extensions import Self
Expand All @@ -11,6 +13,7 @@
from ..base import Required
from ..base import Returned
from ..base import URIReference
from ..utils import UNION_TYPES
from .resource import Resource


Expand Down Expand Up @@ -82,7 +85,16 @@ def from_resource(cls, resource_model: type[Resource]) -> Self:
"""Build a naive ResourceType from a resource model."""
schema = resource_model.model_fields["schemas"].default[0]
name = schema.split(":")[-1]
extensions = resource_model.__pydantic_generic_metadata__["args"]
if resource_model.__pydantic_generic_metadata__["args"]:
extensions = resource_model.__pydantic_generic_metadata__["args"][0]
extensions = (
get_args(extensions)
if get_origin(extensions) in UNION_TYPES
else [extensions]
)
else:
extensions = []

return ResourceType(
id=name,
name=name,
Expand Down
3 changes: 2 additions & 1 deletion scim2_models/rfc7644/list_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ..base import Context
from ..base import Required
from ..rfc7643.resource import AnyResource
from ..utils import UNION_TYPES
from .message import Message


Expand All @@ -29,7 +30,7 @@ def tagged_resource_union(resource_union):

https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
"""
if not get_origin(resource_union) == Union:
if get_origin(resource_union) not in UNION_TYPES:
return resource_union

resource_types = get_args(resource_union)
Expand Down
34 changes: 34 additions & 0 deletions tests/test_resource_type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from typing import Annotated
from typing import Union

from scim2_models import EnterpriseUser
from scim2_models import Extension
from scim2_models import Reference
from scim2_models import Required
from scim2_models import ResourceType
from scim2_models import User

Expand Down Expand Up @@ -61,3 +66,32 @@ def test_from_resource_with_extensions():
== "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
)
assert not enterprise_user_rt.schema_extensions[0].required


def test_from_resource_with_mulitple_extensions():
class TestExtension(Extension):
schemas: Annotated[list[str], Required.true] = [
"urn:ietf:params:scim:schemas:extension:Test:1.0:User"
]

test: Union[str, None] = None
test2: Union[list[str], None] = None

enterprise_user_rt = ResourceType.from_resource(
User[Union[EnterpriseUser, TestExtension]]
)
assert enterprise_user_rt.id == "User"
assert enterprise_user_rt.name == "User"
assert enterprise_user_rt.description == "User"
assert enterprise_user_rt.endpoint == "/Users"
assert enterprise_user_rt.schema_ == "urn:ietf:params:scim:schemas:core:2.0:User"
assert (
enterprise_user_rt.schema_extensions[0].schema_
== "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
)
assert not enterprise_user_rt.schema_extensions[0].required
assert (
enterprise_user_rt.schema_extensions[1].schema_
== "urn:ietf:params:scim:schemas:extension:Test:1.0:User"
)
assert not enterprise_user_rt.schema_extensions[1].required
Loading