Skip to content

Commit 17ba43d

Browse files
author
Chojan Shang
committed
refactor: make gen_schema step by step
Signed-off-by: Chojan Shang <chojan.shang@vesoft.com>
1 parent 0d6141e commit 17ba43d

File tree

1 file changed

+126
-27
lines changed

1 file changed

+126
-27
lines changed

scripts/gen_schema.py

Lines changed: 126 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44
import argparse
55
import ast
6+
import json
67
import re
78
import shutil
89
import subprocess
910
import sys
1011
from collections.abc import Callable
12+
from dataclasses import dataclass
1113
from pathlib import Path
1214

1315
ROOT = Path(__file__).resolve().parents[1]
@@ -116,6 +118,14 @@
116118
)
117119

118120

121+
@dataclass(frozen=True)
122+
class _ProcessingStep:
123+
"""A named transformation applied to the generated schema content."""
124+
125+
name: str
126+
apply: Callable[[str], str]
127+
128+
119129
def parse_args() -> argparse.Namespace:
120130
parser = argparse.ArgumentParser(description="Generate src/acp/schema.py from the ACP JSON schema.")
121131
parser.add_argument(
@@ -166,69 +176,158 @@ def generate_schema(*, format_output: bool = True) -> None:
166176
]
167177

168178
subprocess.check_call(cmd) # noqa: S603
169-
warnings = rename_types(SCHEMA_OUT)
179+
warnings = postprocess_generated_schema(SCHEMA_OUT)
170180
for warning in warnings:
171181
print(f"Warning: {warning}", file=sys.stderr)
172182

173183
if format_output:
174184
format_with_ruff(SCHEMA_OUT)
175185

176186

177-
def rename_types(output_path: Path) -> list[str]:
187+
def postprocess_generated_schema(output_path: Path) -> list[str]:
    """Rewrite the generated schema file in place and return any warnings.

    The file at ``output_path`` is read, its existing header and
    backward-compatibility block are removed, models are renamed, the ordered
    processing steps are applied, and the result is written back between a
    freshly built header block and alias block.

    Args:
        output_path: Location of the generated schema module.

    Returns:
        Warning messages about unrenamed models, missing rename targets and
        whatever ``_validate_schema_alignment`` reports.

    Raises:
        RuntimeError: If ``output_path`` does not exist.
    """
    if not output_path.exists():
        raise RuntimeError(f"Generated schema not found at {output_path}")  # noqa: TRY003

    original_text = output_path.read_text(encoding="utf-8")
    header = _build_header_block()

    # Clean-up and renaming run before the step pipeline.
    body = _remove_backcompat_block(_strip_existing_header(original_text))
    body, unrenamed = _rename_numbered_models(body)

    pipeline: tuple[_ProcessingStep, ...] = (
        _ProcessingStep("apply field overrides", _apply_field_overrides),
        _ProcessingStep("apply default overrides", _apply_default_overrides),
        _ProcessingStep("normalize stdio literal", _normalize_stdio_model),
        _ProcessingStep("attach description comments", _add_description_comments),
        _ProcessingStep("ensure custom BaseModel", _ensure_custom_base_model),
    )
    for stage in pipeline:
        body = stage.apply(body)

    # Rename targets are looked up before the enum aliases are injected.
    absent_targets = _find_missing_targets(body)
    body = _inject_enum_aliases(body)
    aliases = _build_alias_block()

    rendered = header + body.rstrip() + "\n\n" + aliases
    if not rendered.endswith("\n"):
        rendered += "\n"
    output_path.write_text(rendered, encoding="utf-8")

    messages: list[str] = []
    if unrenamed:
        unrenamed_list = ", ".join(unrenamed)
        messages.append(
            "Unrenamed schema models detected: "
            + unrenamed_list
            + ". Update RENAME_MAP in scripts/gen_schema.py."
        )
    if absent_targets:
        absent_list = ", ".join(sorted(absent_targets))
        messages.append(
            "Renamed schema targets not found after generation: "
            + absent_list
            + ". Check RENAME_MAP or upstream schema changes."
        )
    messages.extend(_validate_schema_alignment())
    return messages
234+
182235

236+
def _build_header_block() -> str:
    """Return the comment banner placed at the top of the generated schema.

    The banner always has the "generated" notice. A ``# Schema ref:`` line is
    added when ``VERSION_FILE`` exists and is non-empty after stripping
    whitespace. The result ends with a blank line.
    """
    schema_ref = ""
    if VERSION_FILE.exists():
        schema_ref = VERSION_FILE.read_text(encoding="utf-8").strip()

    banner = "# Generated from schema/schema.json. Do not edit by hand." + "\n"
    if schema_ref:
        banner += f"# Schema ref: {schema_ref}" + "\n"
    return banner + "\n"
188243

244+
245+
def _build_alias_block() -> str:
    """Return the backward-compatibility block of ``old = new`` alias assignments.

    The block starts with ``BACKCOMPAT_MARKER`` on its own line, followed by one
    assignment per ``RENAME_MAP`` entry ordered by the old name.
    """
    assignments = "\n".join(
        f"{legacy_name} = {RENAME_MAP[legacy_name]}" for legacy_name in sorted(RENAME_MAP)
    )
    return BACKCOMPAT_MARKER + "\n" + assignments + "\n"
248+
249+
250+
def _strip_existing_header(content: str) -> str:
189251
existing_header = re.match(r"(#.*\n)+", content)
190252
if existing_header:
191-
content = content[existing_header.end() :]
192-
content = content.lstrip("\n")
253+
return content[existing_header.end() :].lstrip("\n")
254+
return content.lstrip("\n")
255+
193256

257+
def _remove_backcompat_block(content: str) -> str:
    """Cut ``content`` at the first ``BACKCOMPAT_MARKER`` occurrence.

    When the marker is present, everything from it onwards is dropped and
    trailing whitespace is stripped from what remains. Otherwise ``content`` is
    returned unchanged.
    """
    if BACKCOMPAT_MARKER not in content:
        return content
    cut_at = content.index(BACKCOMPAT_MARKER)
    kept = content[:cut_at]
    return kept.rstrip()
197262

263+
264+
def _rename_numbered_models(content: str) -> tuple[str, list[str]]:
    """Apply ``RENAME_MAP`` to ``content`` and report classes that still end in digits.

    Returns:
        A pair of the renamed text and the sorted, de-duplicated names of
        classes whose name still ends with a digit after renaming.
    """
    # Longer names are substituted first, one whole-word pass per entry.
    longest_first = sorted(RENAME_MAP, key=len, reverse=True)
    for numbered_name in longest_first:
        whole_word = rf"\b{re.escape(numbered_name)}\b"
        content = re.sub(whole_word, RENAME_MAP[numbered_name], content)

    still_numbered = {
        found.group(1) for found in re.finditer(r"^class (\w+\d+)\(", content, re.MULTILINE)
    }
    return content, sorted(still_numbered)
204273

205-
header_block = "\n".join(header_lines) + "\n\n"
206-
content = _apply_field_overrides(content)
207-
content = _apply_default_overrides(content)
208-
content = _normalize_stdio_model(content)
209-
content = _add_description_comments(content)
210-
content = _ensure_custom_base_model(content)
211274

212-
alias_lines = [f"{old} = {new}" for old, new in sorted(RENAME_MAP.items())]
213-
alias_block = BACKCOMPAT_MARKER + "\n" + "\n".join(alias_lines) + "\n"
275+
def _find_missing_targets(content: str) -> list[str]:
    """Return the ``RENAME_MAP`` target names that have no class definition in ``content``.

    A target counts as present when a line starts with ``class <name>(``.
    Names are returned in ``RENAME_MAP`` value order.
    """
    return [
        target
        for target in RENAME_MAP.values()
        if re.search(rf"^class {re.escape(target)}\(", content, re.MULTILINE) is None
    ]
214282

215-
content = _inject_enum_aliases(content)
216-
content = header_block + content.rstrip() + "\n\n" + alias_block
217-
if not content.endswith("\n"):
218-
content += "\n"
219-
output_path.write_text(content, encoding="utf-8")
220283

284+
def _validate_schema_alignment() -> list[str]:
    """Compare ``ENUM_LITERAL_MAP`` with the enum literals found in the JSON schema.

    Returns:
        Warning messages: a single message when ``SCHEMA_JSON`` is missing or
        cannot be parsed as JSON, otherwise one message per enum that is absent
        from the schema or whose literal values differ from the expected ones.
    """
    if not SCHEMA_JSON.exists():
        return ["schema/schema.json missing; unable to validate enum aliases."]

    try:
        literals_by_enum = _load_schema_enum_literals()
    except json.JSONDecodeError as exc:
        return [f"Failed to parse schema/schema.json: {exc}"]

    findings: list[str] = []
    for enum_name, expected_values in ENUM_LITERAL_MAP.items():
        schema_values = literals_by_enum.get(enum_name)
        if schema_values is None:
            findings.append(
                f"Enum '{enum_name}' not found in schema.json; update ENUM_LITERAL_MAP or investigate schema changes."
            )
        elif tuple(schema_values) != expected_values:
            findings.append(
                f"Enum mismatch for '{enum_name}': schema.json -> {schema_values}, generated aliases -> {expected_values}"
            )
    return findings
230308

231309

310+
def _load_schema_enum_literals() -> dict[str, tuple[str, ...]]:
311+
schema_data = json.loads(SCHEMA_JSON.read_text(encoding="utf-8"))
312+
defs = schema_data.get("$defs", {})
313+
enum_literals: dict[str, tuple[str, ...]] = {}
314+
315+
for name, definition in defs.items():
316+
values: list[str] = []
317+
if "enum" in definition:
318+
values = [str(item) for item in definition["enum"]]
319+
elif "oneOf" in definition:
320+
values = [
321+
str(option["const"])
322+
for option in definition.get("oneOf", [])
323+
if isinstance(option, dict) and "const" in option
324+
]
325+
if values:
326+
enum_literals[name] = tuple(values)
327+
328+
return enum_literals
329+
330+
232331
def _ensure_custom_base_model(content: str) -> str:
233332
if "class BaseModel(_BaseModel):" in content:
234333
return content

0 commit comments

Comments
 (0)