|
| 1 | +# type: ignore |
| 2 | +import importlib.util |
| 3 | +import inspect |
| 4 | +import json |
| 5 | +import re |
| 6 | +import sys |
| 7 | +import tomllib |
| 8 | +from pathlib import Path |
| 9 | +import yaml |
| 10 | + |
| 11 | +from pydantic import BaseModel |
| 12 | + |
try:
    # Pydantic v2 exposes a dedicated error for models whose fields cannot be
    # represented as JSON Schema. On versions where it is unavailable we fall
    # back to Exception so the `except PydanticInvalidForJsonSchema` clause in
    # generate_dto_schemas still works (it just becomes a broad catch).
    from pydantic.errors import PydanticInvalidForJsonSchema
except ImportError:
    PydanticInvalidForJsonSchema = Exception
| 17 | + |
| 18 | + |
def is_pydantic_model(obj):
    """Return True when *obj* is a Pydantic model class (but not ``BaseModel`` itself)."""
    if not inspect.isclass(obj):
        return False
    return obj is not BaseModel and issubclass(obj, BaseModel)
| 22 | + |
| 23 | + |
def patch_token_region_request_schema(schema_file_path):
    """Patch a generated TokenRegionRequest.json schema in place.

    The generated schema sometimes contains an empty object ({}) at
    ``properties.body.anyOf[1]``; an empty schema means "anything", which
    downstream OpenAPI tooling rejects, so it is pinned to
    ``{"type": "object"}``.

    The file is rewritten only when the patch was actually applied, so
    unaffected schemas keep their original bytes and mtime.

    Errors are reported and swallowed (best-effort) so schema generation
    for other models is not interrupted.
    """
    try:
        with open(schema_file_path, "r") as f:
            schema_data = json.load(f)

        any_of_list = schema_data.get("properties", {}).get("body", {}).get("anyOf")

        patched = False
        if isinstance(any_of_list, list) and len(any_of_list) > 1:
            # Only the known-problematic empty object at index 1 is touched.
            if any_of_list[1] == {}:
                print(
                    f" Patching {schema_file_path}: changing properties.body.anyOf[1] from {{}} to {{'type': 'object'}}"
                )
                any_of_list[1] = {"type": "object"}
                patched = True

        if patched:
            with open(schema_file_path, "w") as f:
                json.dump(schema_data, f, indent=2)
            print(f" Successfully patched {schema_file_path}")
    except Exception as e:
        print(f" Error patching {schema_file_path}: {e}")
| 45 | + |
| 46 | + |
def generate_dto_schemas(source_dir: Path, output_dir: Path, project_root: Path):
    """Generates JSON schemas for Pydantic DTOs and returns a dict of successful ones.

    Runs in three phases:
      1. Discover: import every ``dtos.py`` under *source_dir* and collect the
         Pydantic model classes defined in those modules.
      2. Rebuild: resolve forward references on each discovered model.
      3. Generate: write one ``<ModelName>.json`` schema file per model into
         *output_dir*.

    Args:
        source_dir: Directory tree scanned (recursively) for ``dtos.py`` files.
        output_dir: Destination for schema files; created if missing.
        project_root: Repo root; prepended to ``sys.path`` so DTO modules can
            resolve their own absolute imports.

    Returns:
        Mapping of model name -> schema file name for every schema generated
        successfully.
    """
    print(f"Searching for DTOs in: {source_dir}")
    # NOTE(review): project_root is left on sys.path after this call —
    # acceptable for a one-shot build script, not for library use.
    sys.path.insert(0, str(project_root))
    output_dir.mkdir(parents=True, exist_ok=True)

    discovered_models = []
    processed_dto_files = set()
    successfully_generated_schemas = {}

    # Phase 1: discovery.
    for dto_file_path in source_dir.rglob("**/dtos.py"):
        if dto_file_path in processed_dto_files:
            continue
        processed_dto_files.add(dto_file_path)

        print(f" Processing DTO file: {dto_file_path}")
        # Build a dotted module name from the path relative to the repo root
        # (e.g. src/app/dtos.py -> src.app.dtos), skipping __pycache__ parts.
        relative_path = dto_file_path.relative_to(project_root)
        module_name_parts = list(relative_path.parts)
        if module_name_parts[-1] == "dtos.py":
            module_name_parts[-1] = "dtos"
        module_name = ".".join(part for part in module_name_parts if part != "__pycache__")

        try:
            spec = importlib.util.spec_from_file_location(module_name, dto_file_path)
            if spec and spec.loader:
                module = importlib.util.module_from_spec(spec)
                # Register in sys.modules before exec so self-referential
                # imports inside the module can resolve.
                sys.modules[module_name] = module
                spec.loader.exec_module(module)
            else:
                print(f"\tCould not create module spec for {dto_file_path}")
                continue
        except Exception as e:
            print(f"\tError importing module {module_name} from {dto_file_path}: {e}")
            if module_name in sys.modules:
                # Undo the partial registration made above.
                del sys.modules[module_name]
            continue

        for name, obj in inspect.getmembers(module):
            if is_pydantic_model(obj):
                # Keep only models *defined* in this module, not ones it
                # merely imported.
                if hasattr(obj, "__module__") and obj.__module__ == module_name:
                    discovered_models.append((obj, name, module_name))

    print(f"\nFound {len(discovered_models)} Pydantic models from {len(processed_dto_files)} DTO file(s).")

    print("\nPhase 2: Rebuilding all discovered models...")
    rebuilt_models_count = 0
    models_for_schema_gen = []
    for model_class, model_name, module_name in discovered_models:
        try:
            # Pydantic v2 API first, v1 fallback second.
            if hasattr(model_class, "model_rebuild"):
                model_class.model_rebuild(force=True)
            elif hasattr(model_class, "update_forward_refs"):
                model_class.update_forward_refs()
            rebuilt_models_count += 1
            models_for_schema_gen.append((model_class, model_name, module_name))
        except Exception as e:
            print(f" Error rebuilding model {module_name}.{model_name}: {e}")
            # Schema generation is still attempted: the rebuild failure may
            # not prevent it.
            models_for_schema_gen.append((model_class, model_name, module_name))

    print(f"Attempted to rebuild {rebuilt_models_count} models.")

    print("\nPhase 3: Generating JSON schemas for DTOs...")
    for model_class, model_name, module_name in models_for_schema_gen:
        print(f" Generating schema for: {module_name}.{model_name}")
        try:
            if hasattr(model_class, "model_json_schema"):  # pydantic v2
                schema = model_class.model_json_schema()
            elif hasattr(model_class, "schema_json"):  # pydantic v1
                schema = json.loads(model_class.schema_json())
            else:
                print(f"\tCould not find schema generation method for model {model_name}")
                continue

            # NOTE(review): schemas are keyed by class name only, so two
            # models with the same name in different modules overwrite each
            # other — confirm names are unique across DTO modules.
            schema_file_name = f"{model_name}.json"
            schema_file_path = output_dir / schema_file_name
            with open(schema_file_path, "w") as f:
                json.dump(schema, f, indent=2)
            print(f"\tSchema saved to: {schema_file_path}")

            if model_name == "TokenRegionRequest":
                # Known-bad schema shape; see patch_token_region_request_schema.
                patch_token_region_request_schema(schema_file_path)

            successfully_generated_schemas[model_name] = schema_file_name
        except PydanticInvalidForJsonSchema as e:
            print(f"\tError: Cannot generate JSON schema for {module_name}.{model_name}. Details: {e}")
        except Exception as e:
            print(f"\tError generating/saving schema for model {module_name}.{model_name}: {e}")

    print(f"\nSuccessfully generated {len(successfully_generated_schemas)} DTO JSON schema file(s).")
    return successfully_generated_schemas
| 137 | + |
| 138 | + |
def parse_author_string(author_str):
    """Split an ``"Name <email>"`` author string into ``(name, email)``.

    When no ``<email>`` suffix is present, returns ``(stripped_string, None)``.
    """
    found = re.match(r"^(.*?)\s*<([^>]+)>$", author_str)
    if found is None:
        return author_str.strip(), None
    name, email = found.groups()
    return name.strip(), email.strip()
| 145 | + |
| 146 | + |
| 147 | +def load_project_meta(project_root: Path): |
| 148 | + """Loads project metadata from pyproject.toml.""" |
| 149 | + pyproject_file = project_root / "pyproject.toml" |
| 150 | + print(f"\nLoading project metadata from {pyproject_file}...") |
| 151 | + meta = { |
| 152 | + "title": "My API", |
| 153 | + "version": "0.1.0", |
| 154 | + "description": "API documentation", |
| 155 | + "contact_name": None, |
| 156 | + "contact_email": None, |
| 157 | + } |
| 158 | + try: |
| 159 | + with open(pyproject_file, "rb") as f: |
| 160 | + data = tomllib.load(f) |
| 161 | + poetry_data = data.get("tool", {}).get("poetry", {}) |
| 162 | + name = poetry_data.get("name", "my-api") |
| 163 | + meta["title"] = name.replace("_", " ").replace("-", " ").title() + " API" |
| 164 | + meta["version"] = poetry_data.get("version", "0.1.0") |
| 165 | + meta["description"] = poetry_data.get("description", "API documentation") |
| 166 | + authors = poetry_data.get("authors", []) |
| 167 | + if authors and isinstance(authors, list) and authors[0]: |
| 168 | + meta["contact_name"], meta["contact_email"] = parse_author_string(authors[0]) |
| 169 | + print(f" API Title: {meta['title']}, Version: {meta['version']}, Description: {meta['description']}") |
| 170 | + if meta["contact_name"]: |
| 171 | + print(f" Contact Name: {meta['contact_name']}, Email: {meta['contact_email']}") |
| 172 | + except FileNotFoundError: |
| 173 | + print(f" Error: {pyproject_file} not found. Using default API info.") |
| 174 | + except Exception as e: |
| 175 | + print(f" Error reading {pyproject_file}: {e}. Using default API info.") |
| 176 | + return meta |
| 177 | + |
| 178 | + |
def generate_serverless_config(successfully_generated_schemas, project_meta, project_root: Path):
    """Generates a serverless configuration in memory.

    Args:
        successfully_generated_schemas: Mapping of model name -> schema file
            name; files are expected under ``<project_root>/openapi_models``.
        project_meta: Metadata dict as returned by ``load_project_meta``.
        project_root: Repo root used to locate sibling serverless files.

    Returns:
        Dict ready to be dumped as a serverless.yml for the
        ``serverless-openapi-documenter`` plugin.
    """
    print("\nGenerating Serverless config for OpenAPI in memory...")
    schemas_dir = "openapi_models"
    python_runtime = "python3.12"
    try:
        # Reuse the runtime declared by the main serverless config when present.
        main_sls_file = project_root / "serverless-wo-cross-accounts.yml"
        if main_sls_file.exists():
            with open(main_sls_file, "r") as f:
                main_config = yaml.safe_load(f)
            if main_config and "provider" in main_config and "runtime" in main_config["provider"]:
                python_runtime = main_config["provider"]["runtime"]
                print(f" Using runtime '{python_runtime}' from {main_sls_file}")
    except Exception as e:
        print(f" Could not determine runtime, defaulting to {python_runtime}. Error: {e}")

    model_entries = []
    if successfully_generated_schemas:
        for model_name, schema_file_name in sorted(successfully_generated_schemas.items()):
            description = f"Schema for {model_name}"
            try:
                # Prefer the schema's own description when it provides one.
                with open(project_root / schemas_dir / schema_file_name, "r") as sf:
                    schema_content = json.load(sf)
                if schema_content.get("description"):
                    description = schema_content["description"]
            except Exception:  # nosec - best-effort enrichment only
                pass
            model_entries.append(
                {
                    "name": model_name,
                    "description": description,
                    "contentType": "application/json",
                    # Serverless resolves ${file(...)} itself at deploy time.
                    "schema": "${file(" + schemas_dir + "/" + schema_file_name + ")}",
                }
            )

    documentation_block = {
        "version": project_meta["version"],
        "title": project_meta["title"],
        "description": project_meta["description"],
        "models": model_entries,
    }
    # Guard email under name: indexing documentation_block["contact"] when no
    # contact name was set would raise KeyError.
    if project_meta["contact_name"]:
        documentation_block["contact"] = {"name": project_meta["contact_name"]}
        if project_meta["contact_email"]:
            documentation_block["contact"]["email"] = project_meta["contact_email"]

    functions_file = project_root / "serverless" / "functions.yml"
    # Plain literal (no interpolation needed) — was a pointless f-string.
    functions_ref = "${file(./serverless/functions.yml)}" if functions_file.exists() else {}

    config_content = {
        "service": "identity-oauth-docs-builder",
        "frameworkVersion": "^4.0",
        "provider": {"name": "aws", "runtime": python_runtime, "stage": "integration"},
        "plugins": ["serverless-openapi-documenter"],
        "custom": {
            "documentation": documentation_block,
            "variables": {"lambda_warm_instances": 1, "lambda_memory_size": 256},
        },
        "functions": functions_ref,
    }

    return config_content
0 commit comments