Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
320 changes: 275 additions & 45 deletions sdk/ml/azure-ai-ml/scripts/regenerate_restclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,38 @@
# license information.
# --------------------------------------------------------------------------

"""
Script to regenerate restclient from TypeSpec definitions.

Usage:
python regenerate_restclient.py -a v2025_10_01_preview --spec-repo C:\\Repos\\azure-rest-api-specs
python regenerate_restclient.py -a v2025_10_01_preview --spec-repo C:\\Repos\\azure-rest-api-specs -v

This script:
1. Compiles TypeSpec from azure-rest-api-specs repo
2. Generates Python SDK to a temp directory
3. Copies only the restclient folder to the target location
4. Fixes known TypeSpec emitter bugs (duplicate api_version)
5. Cleans up temp directory
"""

import logging
import os
import re
import shutil
import subprocess
import sys
import tempfile
import time
from argparse import ArgumentParser
from pathlib import Path
from platform import system
from urllib.request import urlopen

module_logger = logging.getLogger(__name__)

MULTI_API_TAG = "multiapi"
# Paths relative to the spec repo
TYPESPEC_PROJECT_PATH = "specification/machinelearningservices/MachineLearningServices.Management"
RESTCLIENT_RELATIVE_PATH = "sdk/ml/azure-ai-ml/azure/ai/ml/_restclient"


class Color:
Expand Down Expand Up @@ -94,64 +113,275 @@ def print_blue(message):
print(Color.BLUE + message + Color.END)


def download_file(from_url: str, to_path: Path, with_file_name: str) -> None:
print_blue(f"- Downloading {with_file_name} from {from_url} to {to_path}")
def print_green(message):
print(Color.GREEN + message + Color.END)


def print_yellow(message):
print(Color.YELLOW + message + Color.END)


def print_red(message):
print(Color.RED + message + Color.END)


def step_banner(step_num, message):
"""Print a step banner for visibility."""
print()
print(Color.CYAN + "=" * 60 + Color.END)
print(Color.CYAN + f"STEP {step_num}: {message}" + Color.END)
print(Color.CYAN + "=" * 60 + Color.END)


def fix_duplicate_api_version(file_path: Path, verbose: bool = False) -> int:
"""
Fix the TypeSpec Python emitter bug that generates duplicate api_version parameters.
See: https://github.com/microsoft/typespec/issues/9384

Returns the number of duplicates fixed.
"""
if not file_path.exists():
return 0

content = file_path.read_text(encoding="utf-8")

# Pattern to find consecutive duplicate api_version lines
pattern = r"(\s+api_version=self\._config\.api_version,\r?\n)\s+api_version=self\._config\.api_version,"

# Count matches before fixing
matches = re.findall(pattern, content)
count = len(matches)

if count > 0:
fixed_content = re.sub(pattern, r"\1", content)
file_path.write_text(fixed_content, encoding="utf-8")
if verbose:
print_yellow(f" Fixed {count} duplicate api_version occurrences in {file_path.name}")

return count


def regenerate_restclient(api_version: str, spec_repo: Path, verbose: bool = False):
"""
Regenerate restclient from TypeSpec definitions.

Args:
api_version: The API version to generate (e.g., "v2025_10_01_preview")
spec_repo: Path to the azure-rest-api-specs repository
verbose: Whether to show verbose output
"""
# Normalize api_version format (support both v2025_10_01_preview and v2025-10-01-preview)
api_version_normalized = api_version.lower().replace("-", "_")
if not api_version_normalized.startswith("v"):
api_version_normalized = "v" + api_version_normalized

# Paths
typespec_project_dir = spec_repo / TYPESPEC_PROJECT_PATH
tspconfig_path = typespec_project_dir / "tspconfig.yaml"
main_tsp_path = typespec_project_dir / "main.tsp"

# Get the script's directory to find the restclient path
script_dir = Path(__file__).parent.absolute()
sdk_package_dir = script_dir.parent # azure-ai-ml directory
restclient_base_path = sdk_package_dir / "azure" / "ai" / "ml" / "_restclient"
target_restclient_path = restclient_base_path / api_version_normalized

command_args = {"shell": system() == "Windows", "stream_stdout": verbose}

# =========================================================================
# STEP 1: Validate inputs
# =========================================================================
step_banner(1, "Validating inputs")

print_blue(f" API Version: {api_version_normalized}")
print_blue(f" Spec repo: {spec_repo}")
print_blue(f" TypeSpec project: {typespec_project_dir}")
print_blue(f" Target restclient path: {target_restclient_path}")

if not spec_repo.exists():
print_red(f"ERROR: Spec repo not found at {spec_repo}")
print_yellow("Please clone azure-rest-api-specs or specify --spec-repo path")
sys.exit(1)

if not tspconfig_path.exists():
print_red(f"ERROR: tspconfig.yaml not found at {tspconfig_path}")
sys.exit(1)

if not main_tsp_path.exists():
print_red(f"ERROR: main.tsp not found at {main_tsp_path}")
sys.exit(1)

print_green(" ✓ All inputs validated")

# =========================================================================
# STEP 2: Create temp directory for generation
# =========================================================================
step_banner(2, "Creating temp directory")

temp_dir = Path(tempfile.mkdtemp(prefix="azure_sdk_gen_"))
print_blue(f" Temp directory: {temp_dir}")
print_green(" ✓ Temp directory created")

try:
with urlopen(from_url) as response:
with open(f"{to_path}/{with_file_name}", "w", encoding="utf-8") as f:
f.write(response.read().decode("utf-8"))
except (OSError, URLError, HTTPError) as e:
sys.exit(
f"Connection error while trying to download file from {from_url}: {e}. Please try running the script again."
# =====================================================================
# STEP 3: Run TypeSpec compilation
# =====================================================================
step_banner(3, "Running TypeSpec compilation")

print_blue(f" Working directory: {typespec_project_dir}")
print_blue(f" Output directory: {temp_dir}")

# Build the tsp compile command
commands = [
"npx",
"tsp",
"compile",
"main.tsp",
"--emit",
"@azure-tools/typespec-python",
"--output-dir",
str(temp_dir),
]

print_blue(f" Command: {' '.join(commands)}")
print()

run_command(
commands,
cwd=str(typespec_project_dir),
throw_on_retcode=True,
**command_args,
)

print_green(" ✓ TypeSpec compilation completed")

def regenerate_restclient(api_tag, verbose):
readme_path = Path("./swagger/machinelearningservices/resource-manager/readme.md")
restclient_path = Path("./azure/ai/ml/_restclient/")
command_args = {"shell": system() == "Windows", "stream_stdout": verbose}
# =====================================================================
# STEP 4: Find the generated restclient folder
# =====================================================================
step_banner(4, "Finding generated restclient")

api_tag_arg = api_tag.lower() if api_tag else None
if not api_tag_arg or api_tag_arg == MULTI_API_TAG:
tag_arg = f"--{MULTI_API_TAG}"
else:
tag_arg = f"--tag={api_tag_arg}"

commands = [
"autorest",
"--python",
"--track2",
f"--python-sdks-folder={restclient_path.absolute()}",
"--package-version=0.1.0",
tag_arg,
str(readme_path.absolute()),
"--modelerfour.lenient-model-deduplication",
'--title="Azure Machine Learning Workspaces"',
]
print_blue(f"- Running autorest command: {' '.join(commands)}")
run_command(
commands,
throw_on_retcode=True,
**command_args,
)
# The TypeSpec emitter generates to: {output-dir}/{service-dir}/...
# Based on tspconfig.yaml, service-dir is "sdk/ml/azure-ai-ml"
generated_restclient_base = temp_dir / RESTCLIENT_RELATIVE_PATH

print_blue(f" Looking in: {generated_restclient_base}")

if not generated_restclient_base.exists():
print_red(f"ERROR: Generated restclient path not found: {generated_restclient_base}")
print_yellow(" Listing temp directory contents:")
for item in temp_dir.rglob("*"):
if item.is_dir():
print(f" [DIR] {item.relative_to(temp_dir)}")
sys.exit(1)

# Find the version folder
generated_version_path = generated_restclient_base / api_version_normalized

if not generated_version_path.exists():
print_yellow(f" Version folder {api_version_normalized} not found")
print_yellow(" Available folders in restclient directory:")
for item in generated_restclient_base.iterdir():
if item.is_dir():
print(f" - {item.name}")

# Try to find any version folder
version_folders = [d for d in generated_restclient_base.iterdir() if d.is_dir() and d.name.startswith("v")]
if len(version_folders) == 1:
generated_version_path = version_folders[0]
print_yellow(f" Using found version folder: {generated_version_path.name}")
else:
print_red("ERROR: Could not determine which version folder to use")
sys.exit(1)

print_green(f" ✓ Found generated restclient at: {generated_version_path}")

# =====================================================================
# STEP 5: Fix TypeSpec emitter bugs
# =====================================================================
step_banner(5, "Fixing TypeSpec emitter bugs")

total_fixes = 0

# Fix duplicate api_version in sync operations
sync_ops_path = generated_version_path / "operations" / "_operations.py"
total_fixes += fix_duplicate_api_version(sync_ops_path, verbose)

# Fix duplicate api_version in async operations
async_ops_path = generated_version_path / "aio" / "operations" / "_operations.py"
total_fixes += fix_duplicate_api_version(async_ops_path, verbose)

if total_fixes > 0:
print_green(f" ✓ Fixed {total_fixes} total duplicate api_version occurrences")
else:
print_green(" ✓ No duplicate api_version bugs found (may be fixed in newer emitter)")

# =====================================================================
# STEP 6: Copy restclient to target location
# =====================================================================
step_banner(6, "Copying restclient to target location")

print_blue(f" Source: {generated_version_path}")
print_blue(f" Target: {target_restclient_path}")

# Remove existing target if it exists
if target_restclient_path.exists():
print_yellow(f" Removing existing directory: {target_restclient_path}")
shutil.rmtree(target_restclient_path)

# Copy the generated restclient
shutil.copytree(generated_version_path, target_restclient_path)

print_green(f" ✓ Restclient copied to {target_restclient_path}")

# =====================================================================
# STEP 7: Summary
# =====================================================================
step_banner(7, "Summary")

# Count files
py_files = list(target_restclient_path.rglob("*.py"))
print_green(f" ✓ Generated {len(py_files)} Python files")
print_green(f" ✓ Restclient location: {target_restclient_path}")
print()
print_green(" ✓ Regeneration completed successfully!")
print()
print_blue(" Next steps:")
print_blue(" 1. Review the generated code")
print_blue(" 3. Commit the changes")

finally:
# =====================================================================
# Cleanup temp directory
# =====================================================================
print()
print_blue(f" Cleaning up temp directory: {temp_dir}")
shutil.rmtree(temp_dir, ignore_errors=True)
print_green(" ✓ Temp directory cleaned up")


if __name__ == "__main__":
parser = ArgumentParser()
parser = ArgumentParser(
description="Regenerate restclient from TypeSpec definitions",
epilog="Example: python regenerate_restclient.py -a v2025_10_01_preview --spec-repo C:\\Repos\\azure-rest-api-specs -v",
)

parser.add_argument(
"-a",
"--api-tag",
required=False,
"--api-version",
required=True,
help=(
"Specifies which API to generate using autorest. If not supplied, all APIs are targeted.\n"
"Must match the name of a tag in the sdk/ml/azure-ai-ml/swagger/machinelearningservices/"
"resource-manager/readme.md file."
"Specifies which API version to generate (e.g., v2025_10_01_preview).\n"
"This should match a version defined in the TypeSpec project."
),
)
parser.add_argument("-v", "--verbose", action="store_true", required=False, help="turn on verbose output")
parser.add_argument(
"--spec-repo",
type=Path,
required=True,
help="Path to the azure-rest-api-specs repository (e.g., C:\\Repos\\azure-rest-api-specs)",
)
parser.add_argument("-v", "--verbose", action="store_true", required=False, help="Turn on verbose output")

args = parser.parse_args()

regenerate_restclient(args.api_tag, args.verbose)
regenerate_restclient(args.api_version, args.spec_repo, args.verbose)