diff --git a/sdk/ml/azure-ai-ml/scripts/regenerate_restclient.py b/sdk/ml/azure-ai-ml/scripts/regenerate_restclient.py index ff07da3179a0..a4ee0733e56d 100644 --- a/sdk/ml/azure-ai-ml/scripts/regenerate_restclient.py +++ b/sdk/ml/azure-ai-ml/scripts/regenerate_restclient.py @@ -6,19 +6,38 @@ # license information. # -------------------------------------------------------------------------- +""" +Script to regenerate restclient from TypeSpec definitions. + +Usage: + python regenerate_restclient.py -a v2025_10_01_preview --spec-repo C:\\Repos\\azure-rest-api-specs + python regenerate_restclient.py -a v2025_10_01_preview --spec-repo C:\\Repos\\azure-rest-api-specs -v + +This script: +1. Compiles TypeSpec from azure-rest-api-specs repo +2. Generates Python SDK to a temp directory +3. Copies only the restclient folder to the target location +4. Fixes known TypeSpec emitter bugs (duplicate api_version) +5. Cleans up temp directory +""" + import logging import os +import re +import shutil import subprocess import sys +import tempfile import time from argparse import ArgumentParser from pathlib import Path from platform import system -from urllib.request import urlopen module_logger = logging.getLogger(__name__) -MULTI_API_TAG = "multiapi" +# Paths relative to the spec repo +TYPESPEC_PROJECT_PATH = "specification/machinelearningservices/MachineLearningServices.Management" +RESTCLIENT_RELATIVE_PATH = "sdk/ml/azure-ai-ml/azure/ai/ml/_restclient" class Color: @@ -94,64 +113,275 @@ def print_blue(message): print(Color.BLUE + message + Color.END) -def download_file(from_url: str, to_path: Path, with_file_name: str) -> None: - print_blue(f"- Downloading {with_file_name} from {from_url} to {to_path}") +def print_green(message): + print(Color.GREEN + message + Color.END) + + +def print_yellow(message): + print(Color.YELLOW + message + Color.END) + + +def print_red(message): + print(Color.RED + message + Color.END) + + +def step_banner(step_num, message): + """Print a step banner for visibility.""" + print() + print(Color.CYAN + "=" * 60 + Color.END) + print(Color.CYAN + f"STEP {step_num}: {message}" + Color.END) + print(Color.CYAN + "=" * 60 + Color.END) + + +def fix_duplicate_api_version(file_path: Path, verbose: bool = False) -> int: + """ + Fix the TypeSpec Python emitter bug that generates duplicate api_version parameters. + See: https://github.com/microsoft/typespec/issues/9384 + + Returns the number of duplicates fixed. + """ + if not file_path.exists(): + return 0 + + content = file_path.read_text(encoding="utf-8") + + # Pattern to find consecutive duplicate api_version lines + pattern = r"(\s+api_version=self\._config\.api_version,\r?\n)\s+api_version=self\._config\.api_version," + + # Count matches before fixing + matches = re.findall(pattern, content) + count = len(matches) + + if count > 0: + fixed_content = re.sub(pattern, r"\1", content) + file_path.write_text(fixed_content, encoding="utf-8") + if verbose: + print_yellow(f" Fixed {count} duplicate api_version occurrences in {file_path.name}") + + return count + + +def regenerate_restclient(api_version: str, spec_repo: Path, verbose: bool = False): + """ + Regenerate restclient from TypeSpec definitions. + + Args: + api_version: The API version to generate (e.g., "v2025_10_01_preview") + spec_repo: Path to the azure-rest-api-specs repository + verbose: Whether to show verbose output + """ + # Normalize api_version format (support both v2025_10_01_preview and v2025-10-01-preview) + api_version_normalized = api_version.lower().replace("-", "_") + if not api_version_normalized.startswith("v"): + api_version_normalized = "v" + api_version_normalized + + # Paths + typespec_project_dir = spec_repo / TYPESPEC_PROJECT_PATH + tspconfig_path = typespec_project_dir / "tspconfig.yaml" + main_tsp_path = typespec_project_dir / "main.tsp" + + # Get the script's directory to find the restclient path + script_dir = Path(__file__).parent.absolute() + sdk_package_dir = script_dir.parent # azure-ai-ml directory + restclient_base_path = sdk_package_dir / "azure" / "ai" / "ml" / "_restclient" + target_restclient_path = restclient_base_path / api_version_normalized + + command_args = {"shell": system() == "Windows", "stream_stdout": verbose} + + # ========================================================================= + # STEP 1: Validate inputs + # ========================================================================= + step_banner(1, "Validating inputs") + + print_blue(f" API Version: {api_version_normalized}") + print_blue(f" Spec repo: {spec_repo}") + print_blue(f" TypeSpec project: {typespec_project_dir}") + print_blue(f" Target restclient path: {target_restclient_path}") + + if not spec_repo.exists(): + print_red(f"ERROR: Spec repo not found at {spec_repo}") + print_yellow("Please clone azure-rest-api-specs or specify --spec-repo path") + sys.exit(1) + + if not tspconfig_path.exists(): + print_red(f"ERROR: tspconfig.yaml not found at {tspconfig_path}") + sys.exit(1) + + if not main_tsp_path.exists(): + print_red(f"ERROR: main.tsp not found at {main_tsp_path}") + sys.exit(1) + + print_green(" ✓ All inputs validated") + + # ========================================================================= + # STEP 2: Create temp directory for generation + # ========================================================================= + step_banner(2, "Creating temp directory") + + temp_dir = Path(tempfile.mkdtemp(prefix="azure_sdk_gen_")) + print_blue(f" Temp directory: {temp_dir}") + print_green(" ✓ Temp directory created") try: - with urlopen(from_url) as response: - with open(f"{to_path}/{with_file_name}", "w", encoding="utf-8") as f: - f.write(response.read().decode("utf-8")) - except (OSError, URLError, HTTPError) as e: - sys.exit( - f"Connection error while trying to download file from {from_url}: {e}. Please try running the script again." + # ===================================================================== + # STEP 3: Run TypeSpec compilation + # ===================================================================== + step_banner(3, "Running TypeSpec compilation") + + print_blue(f" Working directory: {typespec_project_dir}") + print_blue(f" Output directory: {temp_dir}") + + # Build the tsp compile command + commands = [ + "npx", + "tsp", + "compile", + "main.tsp", + "--emit", + "@azure-tools/typespec-python", + "--output-dir", + str(temp_dir), + ] + + print_blue(f" Command: {' '.join(commands)}") + print() + + run_command( + commands, + cwd=str(typespec_project_dir), + throw_on_retcode=True, + **command_args, ) + print_green(" ✓ TypeSpec compilation completed") -def regenerate_restclient(api_tag, verbose): - readme_path = Path("./swagger/machinelearningservices/resource-manager/readme.md") - restclient_path = Path("./azure/ai/ml/_restclient/") - command_args = {"shell": system() == "Windows", "stream_stdout": verbose} + # ===================================================================== + # STEP 4: Find the generated restclient folder + # ===================================================================== + step_banner(4, "Finding generated restclient") - api_tag_arg = api_tag.lower() if api_tag else None - if not api_tag_arg or api_tag_arg == MULTI_API_TAG: - tag_arg = f"--{MULTI_API_TAG}" - else: - tag_arg = f"--tag={api_tag_arg}" - - commands = [ - "autorest", - "--python", - "--track2", - f"--python-sdks-folder={restclient_path.absolute()}", - "--package-version=0.1.0", - tag_arg, - str(readme_path.absolute()), - "--modelerfour.lenient-model-deduplication", - '--title="Azure Machine Learning Workspaces"', - ] - print_blue(f"- Running autorest command: {' '.join(commands)}") - run_command( - commands, - throw_on_retcode=True, - **command_args, - ) + # The TypeSpec emitter generates to: {output-dir}/{service-dir}/... + # Based on tspconfig.yaml, service-dir is "sdk/ml/azure-ai-ml" + generated_restclient_base = temp_dir / RESTCLIENT_RELATIVE_PATH + + print_blue(f" Looking in: {generated_restclient_base}") + + if not generated_restclient_base.exists(): + print_red(f"ERROR: Generated restclient path not found: {generated_restclient_base}") + print_yellow(" Listing temp directory contents:") + for item in temp_dir.rglob("*"): + if item.is_dir(): + print(f" [DIR] {item.relative_to(temp_dir)}") + sys.exit(1) + + # Find the version folder + generated_version_path = generated_restclient_base / api_version_normalized + + if not generated_version_path.exists(): + print_yellow(f" Version folder {api_version_normalized} not found") + print_yellow(" Available folders in restclient directory:") + for item in generated_restclient_base.iterdir(): + if item.is_dir(): + print(f" - {item.name}") + + # Try to find any version folder + version_folders = [d for d in generated_restclient_base.iterdir() if d.is_dir() and d.name.startswith("v")] + if len(version_folders) == 1: + generated_version_path = version_folders[0] + print_yellow(f" Using found version folder: {generated_version_path.name}") + else: + print_red("ERROR: Could not determine which version folder to use") + sys.exit(1) + + print_green(f" ✓ Found generated restclient at: {generated_version_path}") + + # ===================================================================== + # STEP 5: Fix TypeSpec emitter bugs + # ===================================================================== + step_banner(5, "Fixing TypeSpec emitter bugs") + + total_fixes = 0 + + # Fix duplicate api_version in sync operations + sync_ops_path = generated_version_path / "operations" / "_operations.py" + total_fixes += fix_duplicate_api_version(sync_ops_path, verbose) + + # Fix duplicate api_version in async operations + async_ops_path = generated_version_path / "aio" / "operations" / "_operations.py" + total_fixes += fix_duplicate_api_version(async_ops_path, verbose) + + if total_fixes > 0: + print_green(f" ✓ Fixed {total_fixes} total duplicate api_version occurrences") + else: + print_green(" ✓ No duplicate api_version bugs found (may be fixed in newer emitter)") + + # ===================================================================== + # STEP 6: Copy restclient to target location + # ===================================================================== + step_banner(6, "Copying restclient to target location") + + print_blue(f" Source: {generated_version_path}") + print_blue(f" Target: {target_restclient_path}") + + # Remove existing target if it exists + if target_restclient_path.exists(): + print_yellow(f" Removing existing directory: {target_restclient_path}") + shutil.rmtree(target_restclient_path) + + # Copy the generated restclient + shutil.copytree(generated_version_path, target_restclient_path) + + print_green(f" ✓ Restclient copied to {target_restclient_path}") + + # ===================================================================== + # STEP 7: Summary + # ===================================================================== + step_banner(7, "Summary") + + # Count files + py_files = list(target_restclient_path.rglob("*.py")) + print_green(f" ✓ Generated {len(py_files)} Python files") + print_green(f" ✓ Restclient location: {target_restclient_path}") + print() + print_green(" ✓ Regeneration completed successfully!") + print() + print_blue(" Next steps:") + print_blue(" 1. Review the generated code") + print_blue(" 3. Commit the changes") + + finally: + # ===================================================================== + # Cleanup temp directory + # ===================================================================== + print() + print_blue(f" Cleaning up temp directory: {temp_dir}") + shutil.rmtree(temp_dir, ignore_errors=True) + print_green(" ✓ Temp directory cleaned up") if __name__ == "__main__": - parser = ArgumentParser() + parser = ArgumentParser( + description="Regenerate restclient from TypeSpec definitions", + epilog="Example: python regenerate_restclient.py -a v2025_10_01_preview --spec-repo C:\\Repos\\azure-rest-api-specs -v", + ) parser.add_argument( "-a", - "--api-tag", - required=False, + "--api-version", + required=True, help=( - "Specifies which API to generate using autorest. If not supplied, all APIs are targeted.\n" - "Must match the name of a tag in the sdk/ml/azure-ai-ml/swagger/machinelearningservices/" - "resource-manager/readme.md file." + "Specifies which API version to generate (e.g., v2025_10_01_preview).\n" + "This should match a version defined in the TypeSpec project." ), ) - parser.add_argument("-v", "--verbose", action="store_true", required=False, help="turn on verbose output") + parser.add_argument( + "--spec-repo", + type=Path, + required=True, + help="Path to the azure-rest-api-specs repository (e.g., C:\\Repos\\azure-rest-api-specs)", + ) + parser.add_argument("-v", "--verbose", action="store_true", required=False, help="Turn on verbose output") args = parser.parse_args() - regenerate_restclient(args.api_tag, args.verbose) + regenerate_restclient(args.api_version, args.spec_repo, args.verbose)