Skip to content

Commit 37336ad

Browse files
committed
refactor: move get_reference_types in reference.py
1 parent e73b6a1 commit 37336ad

File tree

2 files changed

+31
-19
lines changed

2 files changed

+31
-19
lines changed

scim2_models/reference.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22
from typing import Any
33
from typing import Generic
44
from typing import TypeVar
5+
from typing import get_args
6+
from typing import get_origin
57

68
from pydantic import GetCoreSchemaHandler
79
from pydantic_core import core_schema
810
from typing_extensions import NewType
911

12+
from .utils import UNION_TYPES
13+
1014
ReferenceTypes = TypeVar("ReferenceTypes")
1115

1216
URIReference = NewType("URIReference", str)
@@ -48,3 +52,28 @@ def __get_pydantic_core_schema__(
4852
@classmethod
4953
def _validate(cls, input_value: str, /) -> str:
5054
return input_value
55+
56+
@classmethod
57+
def get_types(cls, type_annotation: Any) -> list[str]:
58+
"""Get reference types from a type annotation.
59+
60+
:param type_annotation: Type annotation to extract reference types from
61+
:type type_annotation: Any
62+
:return: List of reference type strings
63+
:rtype: list[str]
64+
"""
65+
first_arg = get_args(type_annotation)[0]
66+
types = (
67+
get_args(first_arg) if get_origin(first_arg) in UNION_TYPES else [first_arg]
68+
)
69+
70+
def serialize_ref_type(ref_type: Any) -> str:
71+
if ref_type == URIReference:
72+
return "uri"
73+
74+
elif ref_type == ExternalReference:
75+
return "external"
76+
77+
return get_args(ref_type)[0]
78+
79+
return list(map(serialize_ref_type, types))

scim2_models/rfc7643/resource.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929
from ..attributes import is_complex_attribute
3030
from ..base import BaseModel
3131
from ..base import BaseModelType
32-
from ..reference import ExternalReference
33-
from ..reference import URIReference
32+
from ..reference import Reference
3433
from ..utils import UNION_TYPES
3534
from ..utils import normalize_attribute_name
3635

@@ -317,22 +316,6 @@ def model_to_schema(model: type[BaseModel]) -> "Schema":
317316
return schema
318317

319318

320-
def get_reference_types(type) -> list[str]:
321-
first_arg = get_args(type)[0]
322-
types = get_args(first_arg) if get_origin(first_arg) in UNION_TYPES else [first_arg]
323-
324-
def serialize_ref_type(ref_type: type) -> str:
325-
if ref_type == URIReference:
326-
return "uri"
327-
328-
elif ref_type == ExternalReference:
329-
return "external"
330-
331-
return get_args(ref_type)[0]
332-
333-
return list(map(serialize_ref_type, types))
334-
335-
336319
def model_attribute_to_scim_attribute(
337320
model: type[BaseModel], attribute_name: str
338321
) -> "Attribute":
@@ -369,7 +352,7 @@ def model_attribute_to_scim_attribute(
369352
returned=model.get_field_annotation(attribute_name, Returned),
370353
uniqueness=model.get_field_annotation(attribute_name, Uniqueness),
371354
sub_attributes=sub_attributes,
372-
reference_types=get_reference_types(root_type)
355+
reference_types=Reference.get_types(root_type)
373356
if attribute_type == Attribute.Type.reference
374357
else None,
375358
)

0 commit comments

Comments
 (0)