Skip to content

Commit be7b54a

Browse files
committed
schema hadling changes, move tests to pytest
1 parent 6bbfa3e commit be7b54a

File tree

7 files changed

+804
-337
lines changed

7 files changed

+804
-337
lines changed

pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ dependencies = [
2727
[tool.setuptools.packages.find]
2828
where = ["src"]
2929

30+
[dependency-groups]
31+
dev = [
32+
"pytest>=8.3.5",
33+
"pytest-mock>=3.14.1",
34+
"requests-mock>=1.12.1",
35+
]
36+
3037
[project.scripts]
3138
openapi-gen = "serverless_openapi_generator.openapi_generator:main"
3239
openapi-validate = "openapi_spec_validator.__main__:main"

src/serverless_openapi_generator/openapi_generator.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,30 @@ def create_operation_object(self, documentation, func, http_event):
188188
obj['requestBody'] = self.create_request_body(documentation.get('requestBody'))
189189
elif http_event.get('request', {}).get('schemas'):
190190
# Handle request schemas defined directly on the event
191-
request_body_doc = {
192-
'description': 'Request body inferred from event schema',
193-
'requestModels': http_event['request']['schemas']
194-
}
195-
obj['requestBody'] = self.create_request_body(request_body_doc)
191+
schemas = http_event['request']['schemas']
192+
content = {}
193+
for media_type, schema_info in schemas.items():
194+
if isinstance(schema_info, str):
195+
schema_name = schema_info
196+
# Find the model name from the standardized models
197+
model = next((model for model in self.schema_handler.models if model.get('key') == schema_name), None)
198+
if model:
199+
content[media_type] = {
200+
'schema': {
201+
'$ref': f"#/components/schemas/{model['name']}"
202+
}
203+
}
204+
elif isinstance(schema_info, dict):
205+
schema_name = ''.join(word.capitalize() for word in re.split(r'[/_-]', media_type))
206+
schema_ref = self.schema_handler.create_schema(schema_name, schema_info)
207+
content[media_type] = {'schema': {'$ref': schema_ref}}
208+
209+
if content:
210+
obj['requestBody'] = {
211+
'description': 'Request body inferred from event schema',
212+
'content': content,
213+
'required': True
214+
}
196215

197216
# Handle private endpoints
198217
if http_event.get('private') is True:
@@ -264,24 +283,16 @@ def create_responses(self, documentation):
264283
return responses
265284

266285
def create_media_type_object(self, models):
267-
media_type_obj = {}
268-
for media_type, model_info in models.items():
269-
# Handle both simple model names and complex objects with schema
270-
if isinstance(model_info, str):
271-
model_name = model_info
272-
elif isinstance(model_info, dict) and 'name' in model_info:
273-
model_name = model_info['name']
274-
else:
275-
# Fallback for inline schemas, though we'll focus on named models
276-
continue
277-
278-
if model_name in self.open_api['components']['schemas']:
279-
media_type_obj[media_type] = {
280-
'schema': {
281-
'$ref': f"#/components/schemas/{model_name}"
282-
}
283-
}
284-
return media_type_obj
286+
content = {}
287+
if models:
288+
for media_type, schema_name in models.items():
289+
print(f"Schema name: {schema_name}")
290+
print(f"Models: {self.schema_handler.models}")
291+
model_info = next((model for model in self.schema_handler.models if model['name'] == schema_name), None)
292+
if model_info:
293+
schema_ref = self.schema_handler.create_schema(schema_name, model_info.get('schema'))
294+
content[media_type] = {'schema': {'$ref': schema_ref}}
295+
return content
285296

286297
def create_request_body(self, request_body_doc):
287298
content = self.create_media_type_object(request_body_doc.get('requestModels', {}))

src/serverless_openapi_generator/schema_handler.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ def _resolve_file_references(self, value):
3737

3838
def _standardize_models(self):
3939
documentation = self.serverless_config.get('custom', {}).get('documentation', {})
40-
41-
standardized_models = []
40+
4241
model_sources = []
4342
if 'models' in documentation:
4443
models_config = self._resolve_file_references(documentation['models'])
@@ -51,6 +50,15 @@ def _standardize_models(self):
5150
if 'modelsList' in documentation:
5251
model_sources.extend(self._resolve_file_references(documentation['modelsList']))
5352

53+
api_gateway_models = self.serverless_config.get('provider', {}).get('apiGateway', {}).get('request', {}).get('schemas', {})
54+
if api_gateway_models:
55+
for model_key, model_definition in api_gateway_models.items():
56+
if 'name' not in model_definition:
57+
model_definition['name'] = model_key
58+
model_definition['key'] = model_key
59+
model_sources.append(model_definition)
60+
61+
standardized_models = []
5462
for model in model_sources:
5563
if not isinstance(model, dict) or 'name' not in model:
5664
continue
@@ -61,6 +69,8 @@ def _standardize_models(self):
6169
'name': model.get('name'),
6270
'description': model.get('description', ''),
6371
}
72+
if 'key' in model:
73+
std_model['key'] = model['key']
6474
if 'contentType' in model and 'schema' in model:
6575
std_model['contentType'] = model['contentType']
6676
std_model['schema'] = model['schema']
@@ -163,27 +173,23 @@ def _resolve_schema_references(self, schema):
163173

164174
schema = self._clean_schema(schema)
165175

166-
registry = Registry()
167-
if "definitions" in schema:
168-
for name, sub_schema in schema["definitions"].items():
169-
resource = Resource.from_contents(self._clean_schema(sub_schema), default_specification=DRAFT4)
170-
registry = registry.with_resource(f"#/definitions/{name}", resource)
176+
resource = Resource.from_contents(schema, default_specification=DRAFT4)
177+
registry = Registry().with_resource("root", resource).crawl()
178+
resolver = registry.resolver(base_uri="root")
171179

172-
main_resource = Resource.from_contents(schema, default_specification=DRAFT4)
173-
registry = registry.with_resource("root", main_resource)
174-
resolver = registry.resolver("root")
175180
dereferenced_schema = self._recursive_dereference(schema, resolver)
176181

177182
if isinstance(dereferenced_schema, dict):
178183
dereferenced_schema.pop("definitions", None)
184+
179185
return dereferenced_schema
180186

181187
def _recursive_dereference(self, node, resolver):
182188
if isinstance(node, dict):
183189
if "$ref" in node:
184190
try:
185191
resolved = resolver.lookup(node["$ref"])
186-
return self._recursive_dereference(resolved.contents, resolver)
192+
return self._recursive_dereference(resolved.contents, resolved.resolver)
187193
except Exception:
188194
return node
189195
return {k: self._recursive_dereference(v, resolver) for k, v in node.items()}

0 commit comments

Comments
 (0)