Skip to content

Commit 67b9d69

Browse files
committed
oneOf discriminator
1 parent 994a9c7 commit 67b9d69

File tree

2 files changed

+260
-2
lines changed

2 files changed

+260
-2
lines changed

bandwidth/models/callback.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,18 @@ class Callback(BaseModel):
4242
protected_namespaces=(),
4343
)
4444

45-
46-
discriminator_value_class_map: Dict[str, str] = {
45+
# Discriminator's property name (OpenAPI v3)
46+
__openapi_discriminator_name__ = 'type'
47+
48+
# Discriminator's mapping (OpenAPI v3)
49+
__discriminator_value_class_map__ = {
50+
'message-delivered': 'StatusCallback',
51+
'message-failed': 'StatusCallback',
52+
'message-read': 'StatusCallback',
53+
'message-received': 'InboundCallback',
54+
'message-sent': 'StatusCallback',
55+
'request-location-response': 'InboundCallback',
56+
'suggestion-response': 'InboundCallback'
4757
}
4858

4959
def __init__(self, *args, **kwargs) -> None:
@@ -88,6 +98,20 @@ def from_dict(cls, obj: Union[str, Dict[str, Any]]) -> Self:
8898
def from_json(cls, json_str: str) -> Self:
8999
"""Returns the object represented by the json string"""
90100
instance = cls.model_construct()
101+
102+
# Try to deserialize using the discriminator
103+
json_obj = json.loads(json_str)
104+
discriminator_value = json_obj.get(cls.__openapi_discriminator_name__)
105+
106+
if discriminator_value and discriminator_value in cls.__discriminator_value_class_map__:
107+
class_name = cls.__discriminator_value_class_map__[discriminator_value]
108+
target_class = globals()[class_name]
109+
try:
110+
instance.actual_instance = target_class.from_json(json_str)
111+
return instance
112+
except (ValidationError, ValueError) as e:
113+
raise ValueError(f"Failed to deserialize using discriminator '{discriminator_value}' -> {class_name}: {str(e)}")
114+
91115
error_messages = []
92116
match = 0
93117

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
from __future__ import annotations
2+
import json
3+
import pprint
4+
{{#vendorExtensions.x-py-other-imports}}
5+
{{{.}}}
6+
{{/vendorExtensions.x-py-other-imports}}
7+
{{#vendorExtensions.x-py-model-imports}}
8+
{{{.}}}
9+
{{/vendorExtensions.x-py-model-imports}}
10+
from pydantic import StrictStr, Field
11+
from typing import Union, List, Set, Optional, Dict
12+
from typing_extensions import Literal, Self
13+
14+
{{#lambda.uppercase}}{{{classname}}}{{/lambda.uppercase}}_ONE_OF_SCHEMAS = [{{#oneOf}}"{{.}}"{{^-last}}, {{/-last}}{{/oneOf}}]
15+
16+
class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}BaseModel{{/parent}}):
17+
"""
18+
{{{description}}}{{^description}}{{{classname}}}{{/description}}
19+
"""
20+
{{#composedSchemas.oneOf}}
21+
# data type: {{{dataType}}}
22+
{{vendorExtensions.x-py-name}}: {{{vendorExtensions.x-py-typing}}}
23+
{{/composedSchemas.oneOf}}
24+
actual_instance: Optional[Union[{{#oneOf}}{{{.}}}{{^-last}}, {{/-last}}{{/oneOf}}]] = None
25+
one_of_schemas: Set[str] = { {{#oneOf}}"{{.}}"{{^-last}}, {{/-last}}{{/oneOf}} }
26+
27+
model_config = ConfigDict(
28+
validate_assignment=True,
29+
protected_namespaces=(),
30+
)
31+
32+
{{#discriminator}}
33+
{{#propertyName}}
34+
# Discriminator's property name (OpenAPI v3)
35+
__openapi_discriminator_name__ = '{{{.}}}'
36+
{{/propertyName}}
37+
38+
{{#mappedModels}}
39+
{{#-first}}
40+
# Discriminator's mapping (OpenAPI v3)
41+
__discriminator_value_class_map__ = {
42+
{{/-first}}
43+
'{{{mappingName}}}': '{{{modelName}}}'{{^-last}},{{/-last}}
44+
{{#-last}}
45+
}
46+
{{/-last}}
47+
{{/mappedModels}}
48+
{{/discriminator}}
49+
50+
def __init__(self, *args, **kwargs) -> None:
51+
if args:
52+
if len(args) > 1:
53+
raise ValueError("If a position argument is used, only 1 is allowed to set `actual_instance`")
54+
if kwargs:
55+
raise ValueError("If a position argument is used, keyword arguments cannot be used.")
56+
super().__init__(actual_instance=args[0])
57+
else:
58+
super().__init__(**kwargs)
59+
60+
@field_validator('actual_instance')
61+
def actual_instance_must_validate_oneof(cls, v):
62+
{{#isNullable}}
63+
if v is None:
64+
return v
65+
66+
{{/isNullable}}
67+
instance = {{{classname}}}.model_construct()
68+
error_messages = []
69+
match = 0
70+
{{#composedSchemas.oneOf}}
71+
# validate data type: {{{dataType}}}
72+
{{#isContainer}}
73+
try:
74+
instance.{{vendorExtensions.x-py-name}} = v
75+
match += 1
76+
except (ValidationError, ValueError) as e:
77+
error_messages.append(str(e))
78+
{{/isContainer}}
79+
{{^isContainer}}
80+
{{#isPrimitiveType}}
81+
try:
82+
instance.{{vendorExtensions.x-py-name}} = v
83+
match += 1
84+
except (ValidationError, ValueError) as e:
85+
error_messages.append(str(e))
86+
{{/isPrimitiveType}}
87+
{{^isPrimitiveType}}
88+
if not isinstance(v, {{{dataType}}}):
89+
error_messages.append(f"Error! Input type `{type(v)}` is not `{{{dataType}}}`")
90+
else:
91+
match += 1
92+
{{/isPrimitiveType}}
93+
{{/isContainer}}
94+
{{/composedSchemas.oneOf}}
95+
if match > 1:
96+
# more than 1 match
97+
raise ValueError("Multiple matches found when setting `actual_instance` in {{{classname}}} with oneOf schemas: {{#oneOf}}{{{.}}}{{^-last}}, {{/-last}}{{/oneOf}}. Details: " + ", ".join(error_messages))
98+
elif match == 0:
99+
# no match
100+
raise ValueError("No match found when setting `actual_instance` in {{{classname}}} with oneOf schemas: {{#oneOf}}{{{.}}}{{^-last}}, {{/-last}}{{/oneOf}}. Details: " + ", ".join(error_messages))
101+
else:
102+
return v
103+
104+
@classmethod
105+
def from_dict(cls, obj: Union[str, Dict[str, Any]]) -> Self:
106+
return cls.from_json(json.dumps(obj))
107+
108+
@classmethod
109+
{{#isNullable}}
110+
def from_json(cls, json_str: Optional[str]) -> Self:
111+
{{/isNullable}}
112+
{{^isNullable}}
113+
def from_json(cls, json_str: str) -> Self:
114+
{{/isNullable}}
115+
"""Returns the object represented by the json string"""
116+
instance = cls.model_construct()
117+
{{#isNullable}}
118+
if json_str is None:
119+
return instance
120+
121+
{{/isNullable}}
122+
{{#discriminator}}
123+
124+
# Try to deserialize using the discriminator
125+
json_obj = json.loads(json_str)
126+
discriminator_value = json_obj.get(cls.__openapi_discriminator_name__)
127+
128+
if discriminator_value and discriminator_value in cls.__discriminator_value_class_map__:
129+
class_name = cls.__discriminator_value_class_map__[discriminator_value]
130+
target_class = globals()[class_name]
131+
try:
132+
instance.actual_instance = target_class.from_json(json_str)
133+
return instance
134+
except (ValidationError, ValueError) as e:
135+
raise ValueError(f"Failed to deserialize using discriminator '{discriminator_value}' -> {class_name}: {str(e)}")
136+
137+
{{/discriminator}}
138+
error_messages = []
139+
match = 0
140+
141+
{{#useOneOfDiscriminatorLookup}}
142+
{{#discriminator}}
143+
{{#mappedModels}}
144+
{{#-first}}
145+
# use oneOf discriminator to lookup the data type
146+
_data_type = json.loads(json_str).get("{{{propertyBaseName}}}")
147+
if not _data_type:
148+
raise ValueError("Failed to lookup data type from the field `{{{propertyBaseName}}}` in the input.")
149+
150+
{{/-first}}
151+
# check if data type is `{{{modelName}}}`
152+
if _data_type == "{{{mappingName}}}":
153+
instance.actual_instance = {{{modelName}}}.from_json(json_str)
154+
return instance
155+
156+
{{/mappedModels}}
157+
{{/discriminator}}
158+
{{/useOneOfDiscriminatorLookup}}
159+
{{#composedSchemas.oneOf}}
160+
{{#isContainer}}
161+
# deserialize data into {{{dataType}}}
162+
try:
163+
# validation
164+
instance.{{vendorExtensions.x-py-name}} = json.loads(json_str)
165+
# assign value to actual_instance
166+
instance.actual_instance = instance.{{vendorExtensions.x-py-name}}
167+
match += 1
168+
except (ValidationError, ValueError) as e:
169+
error_messages.append(str(e))
170+
{{/isContainer}}
171+
{{^isContainer}}
172+
{{#isPrimitiveType}}
173+
# deserialize data into {{{dataType}}}
174+
try:
175+
# validation
176+
instance.{{vendorExtensions.x-py-name}} = json.loads(json_str)
177+
# assign value to actual_instance
178+
instance.actual_instance = instance.{{vendorExtensions.x-py-name}}
179+
match += 1
180+
except (ValidationError, ValueError) as e:
181+
error_messages.append(str(e))
182+
{{/isPrimitiveType}}
183+
{{^isPrimitiveType}}
184+
# deserialize data into {{{dataType}}}
185+
try:
186+
instance.actual_instance = {{{dataType}}}.from_json(json_str)
187+
match += 1
188+
except (ValidationError, ValueError) as e:
189+
error_messages.append(str(e))
190+
{{/isPrimitiveType}}
191+
{{/isContainer}}
192+
{{/composedSchemas.oneOf}}
193+
194+
if match > 1:
195+
# more than 1 match
196+
raise ValueError("Multiple matches found when deserializing the JSON string into {{{classname}}} with oneOf schemas: {{#oneOf}}{{{.}}}{{^-last}}, {{/-last}}{{/oneOf}}. Details: " + ", ".join(error_messages))
197+
elif match == 0:
198+
# no match
199+
raise ValueError("No match found when deserializing the JSON string into {{{classname}}} with oneOf schemas: {{#oneOf}}{{{.}}}{{^-last}}, {{/-last}}{{/oneOf}}. Details: " + ", ".join(error_messages))
200+
else:
201+
return instance
202+
203+
def to_json(self) -> str:
204+
"""Returns the JSON representation of the actual instance"""
205+
if self.actual_instance is None:
206+
return "null"
207+
208+
if hasattr(self.actual_instance, "to_json") and callable(self.actual_instance.to_json):
209+
return self.actual_instance.to_json()
210+
else:
211+
return json.dumps(self.actual_instance)
212+
213+
def to_dict(self) -> Optional[Union[Dict[str, Any], {{#oneOf}}{{{.}}}{{^-last}}, {{/-last}}{{/oneOf}}]]:
214+
"""Returns the dict representation of the actual instance"""
215+
if self.actual_instance is None:
216+
return None
217+
218+
if hasattr(self.actual_instance, "to_dict") and callable(self.actual_instance.to_dict):
219+
return self.actual_instance.to_dict()
220+
else:
221+
# primitive type
222+
return self.actual_instance
223+
224+
def to_str(self) -> str:
225+
"""Returns the string representation of the actual instance"""
226+
return pprint.pformat(self.model_dump())
227+
228+
{{#vendorExtensions.x-py-postponed-model-imports.size}}
229+
{{#vendorExtensions.x-py-postponed-model-imports}}
230+
{{{.}}}
231+
{{/vendorExtensions.x-py-postponed-model-imports}}
232+
# TODO: Rewrite to not use raise_errors
233+
{{classname}}.model_rebuild(raise_errors=False)
234+
{{/vendorExtensions.x-py-postponed-model-imports.size}}

0 commit comments

Comments
 (0)