|
3 | 3 |
|
4 | 4 | import argparse |
5 | 5 | import ast |
| 6 | +import json |
6 | 7 | import re |
7 | 8 | import shutil |
8 | 9 | import subprocess |
9 | 10 | import sys |
10 | 11 | from collections.abc import Callable |
| 12 | +from dataclasses import dataclass |
11 | 13 | from pathlib import Path |
12 | 14 |
|
13 | 15 | ROOT = Path(__file__).resolve().parents[1] |
|
116 | 118 | ) |
117 | 119 |
|
118 | 120 |
|
| 121 | +@dataclass(frozen=True) |
| 122 | +class _ProcessingStep: |
| 123 | + """A named transformation applied to the generated schema content.""" |
| 124 | + |
| 125 | + name: str |
| 126 | + apply: Callable[[str], str] |
| 127 | + |
| 128 | + |
119 | 129 | def parse_args() -> argparse.Namespace: |
120 | 130 | parser = argparse.ArgumentParser(description="Generate src/acp/schema.py from the ACP JSON schema.") |
121 | 131 | parser.add_argument( |
@@ -166,69 +176,158 @@ def generate_schema(*, format_output: bool = True) -> None: |
166 | 176 | ] |
167 | 177 |
|
168 | 178 | subprocess.check_call(cmd) # noqa: S603 |
169 | | - warnings = rename_types(SCHEMA_OUT) |
| 179 | + warnings = postprocess_generated_schema(SCHEMA_OUT) |
170 | 180 | for warning in warnings: |
171 | 181 | print(f"Warning: {warning}", file=sys.stderr) |
172 | 182 |
|
173 | 183 | if format_output: |
174 | 184 | format_with_ruff(SCHEMA_OUT) |
175 | 185 |
|
176 | 186 |
|
177 | | -def rename_types(output_path: Path) -> list[str]: |
| 187 | +def postprocess_generated_schema(output_path: Path) -> list[str]: |
178 | 188 | if not output_path.exists(): |
179 | 189 | raise RuntimeError(f"Generated schema not found at {output_path}") # noqa: TRY003 |
180 | 190 |
|
181 | | - content = output_path.read_text(encoding="utf-8") |
| 191 | + raw_content = output_path.read_text(encoding="utf-8") |
| 192 | + header_block = _build_header_block() |
| 193 | + |
| 194 | + content = _strip_existing_header(raw_content) |
| 195 | + content = _remove_backcompat_block(content) |
| 196 | + content, leftover_classes = _rename_numbered_models(content) |
| 197 | + |
| 198 | + processing_steps: tuple[_ProcessingStep, ...] = ( |
| 199 | + _ProcessingStep("apply field overrides", _apply_field_overrides), |
| 200 | + _ProcessingStep("apply default overrides", _apply_default_overrides), |
| 201 | + _ProcessingStep("normalize stdio literal", _normalize_stdio_model), |
| 202 | + _ProcessingStep("attach description comments", _add_description_comments), |
| 203 | + _ProcessingStep("ensure custom BaseModel", _ensure_custom_base_model), |
| 204 | + ) |
| 205 | + |
| 206 | + for step in processing_steps: |
| 207 | + content = step.apply(content) |
| 208 | + |
| 209 | + missing_targets = _find_missing_targets(content) |
| 210 | + |
| 211 | + content = _inject_enum_aliases(content) |
| 212 | + alias_block = _build_alias_block() |
| 213 | + final_content = header_block + content.rstrip() + "\n\n" + alias_block |
| 214 | + if not final_content.endswith("\n"): |
| 215 | + final_content += "\n" |
| 216 | + output_path.write_text(final_content, encoding="utf-8") |
| 217 | + |
| 218 | + warnings: list[str] = [] |
| 219 | + if leftover_classes: |
| 220 | + warnings.append( |
| 221 | + "Unrenamed schema models detected: " |
| 222 | + + ", ".join(leftover_classes) |
| 223 | + + ". Update RENAME_MAP in scripts/gen_schema.py." |
| 224 | + ) |
| 225 | + if missing_targets: |
| 226 | + warnings.append( |
| 227 | + "Renamed schema targets not found after generation: " |
| 228 | + + ", ".join(sorted(missing_targets)) |
| 229 | + + ". Check RENAME_MAP or upstream schema changes." |
| 230 | + ) |
| 231 | + warnings.extend(_validate_schema_alignment()) |
| 232 | + |
| 233 | + return warnings |
| 234 | + |
182 | 235 |
|
| 236 | +def _build_header_block() -> str: |
183 | 237 | header_lines = ["# Generated from schema/schema.json. Do not edit by hand."] |
184 | 238 | if VERSION_FILE.exists(): |
185 | 239 | ref = VERSION_FILE.read_text(encoding="utf-8").strip() |
186 | 240 | if ref: |
187 | 241 | header_lines.append(f"# Schema ref: {ref}") |
| 242 | + return "\n".join(header_lines) + "\n\n" |
188 | 243 |
|
| 244 | + |
| 245 | +def _build_alias_block() -> str: |
| 246 | + alias_lines = [f"{old} = {new}" for old, new in sorted(RENAME_MAP.items())] |
| 247 | + return BACKCOMPAT_MARKER + "\n" + "\n".join(alias_lines) + "\n" |
| 248 | + |
| 249 | + |
| 250 | +def _strip_existing_header(content: str) -> str: |
189 | 251 | existing_header = re.match(r"(#.*\n)+", content) |
190 | 252 | if existing_header: |
191 | | - content = content[existing_header.end() :] |
192 | | - content = content.lstrip("\n") |
| 253 | + return content[existing_header.end() :].lstrip("\n") |
| 254 | + return content.lstrip("\n") |
| 255 | + |
193 | 256 |
|
| 257 | +def _remove_backcompat_block(content: str) -> str: |
194 | 258 | marker_index = content.find(BACKCOMPAT_MARKER) |
195 | 259 | if marker_index != -1: |
196 | | - content = content[:marker_index].rstrip() |
| 260 | + return content[:marker_index].rstrip() |
| 261 | + return content |
197 | 262 |
|
| 263 | + |
| 264 | +def _rename_numbered_models(content: str) -> tuple[str, list[str]]: |
| 265 | + renamed = content |
198 | 266 | for old, new in sorted(RENAME_MAP.items(), key=lambda item: len(item[0]), reverse=True): |
199 | 267 | pattern = re.compile(rf"\b{re.escape(old)}\b") |
200 | | - content = pattern.sub(new, content) |
| 268 | + renamed = pattern.sub(new, renamed) |
201 | 269 |
|
202 | 270 | leftover_class_pattern = re.compile(r"^class (\w+\d+)\(", re.MULTILINE) |
203 | | - leftover_classes = sorted(set(leftover_class_pattern.findall(content))) |
| 271 | + leftover_classes = sorted(set(leftover_class_pattern.findall(renamed))) |
| 272 | + return renamed, leftover_classes |
204 | 273 |
|
205 | | - header_block = "\n".join(header_lines) + "\n\n" |
206 | | - content = _apply_field_overrides(content) |
207 | | - content = _apply_default_overrides(content) |
208 | | - content = _normalize_stdio_model(content) |
209 | | - content = _add_description_comments(content) |
210 | | - content = _ensure_custom_base_model(content) |
211 | 274 |
|
212 | | - alias_lines = [f"{old} = {new}" for old, new in sorted(RENAME_MAP.items())] |
213 | | - alias_block = BACKCOMPAT_MARKER + "\n" + "\n".join(alias_lines) + "\n" |
| 275 | +def _find_missing_targets(content: str) -> list[str]: |
| 276 | + missing: list[str] = [] |
| 277 | + for new_name in RENAME_MAP.values(): |
| 278 | + pattern = re.compile(rf"^class {re.escape(new_name)}\(", re.MULTILINE) |
| 279 | + if not pattern.search(content): |
| 280 | + missing.append(new_name) |
| 281 | + return missing |
214 | 282 |
|
215 | | - content = _inject_enum_aliases(content) |
216 | | - content = header_block + content.rstrip() + "\n\n" + alias_block |
217 | | - if not content.endswith("\n"): |
218 | | - content += "\n" |
219 | | - output_path.write_text(content, encoding="utf-8") |
220 | 283 |
|
| 284 | +def _validate_schema_alignment() -> list[str]: |
221 | 285 | warnings: list[str] = [] |
222 | | - if leftover_classes: |
223 | | - warnings.append( |
224 | | - "Unrenamed schema models detected: " |
225 | | - + ", ".join(leftover_classes) |
226 | | - + ". Update RENAME_MAP in scripts/gen_schema.py." |
227 | | - ) |
| 286 | + if not SCHEMA_JSON.exists(): |
| 287 | + warnings.append("schema/schema.json missing; unable to validate enum aliases.") |
| 288 | + return warnings |
228 | 289 |
|
| 290 | + try: |
| 291 | + schema_enums = _load_schema_enum_literals() |
| 292 | + except json.JSONDecodeError as exc: |
| 293 | + warnings.append(f"Failed to parse schema/schema.json: {exc}") |
| 294 | + return warnings |
| 295 | + |
| 296 | + for enum_name, expected_values in ENUM_LITERAL_MAP.items(): |
| 297 | + schema_values = schema_enums.get(enum_name) |
| 298 | + if schema_values is None: |
| 299 | + warnings.append( |
| 300 | + f"Enum '{enum_name}' not found in schema.json; update ENUM_LITERAL_MAP or investigate schema changes." |
| 301 | + ) |
| 302 | + continue |
| 303 | + if tuple(schema_values) != expected_values: |
| 304 | + warnings.append( |
| 305 | + f"Enum mismatch for '{enum_name}': schema.json -> {schema_values}, generated aliases -> {expected_values}" |
| 306 | + ) |
229 | 307 | return warnings |
230 | 308 |
|
231 | 309 |
|
| 310 | +def _load_schema_enum_literals() -> dict[str, tuple[str, ...]]: |
| 311 | + schema_data = json.loads(SCHEMA_JSON.read_text(encoding="utf-8")) |
| 312 | + defs = schema_data.get("$defs", {}) |
| 313 | + enum_literals: dict[str, tuple[str, ...]] = {} |
| 314 | + |
| 315 | + for name, definition in defs.items(): |
| 316 | + values: list[str] = [] |
| 317 | + if "enum" in definition: |
| 318 | + values = [str(item) for item in definition["enum"]] |
| 319 | + elif "oneOf" in definition: |
| 320 | + values = [ |
| 321 | + str(option["const"]) |
| 322 | + for option in definition.get("oneOf", []) |
| 323 | + if isinstance(option, dict) and "const" in option |
| 324 | + ] |
| 325 | + if values: |
| 326 | + enum_literals[name] = tuple(values) |
| 327 | + |
| 328 | + return enum_literals |
| 329 | + |
| 330 | + |
232 | 331 | def _ensure_custom_base_model(content: str) -> str: |
233 | 332 | if "class BaseModel(_BaseModel):" in content: |
234 | 333 | return content |
|
0 commit comments