|
| 1 | +import asyncio |
| 2 | +import importlib |
| 3 | +import json |
| 4 | +import os |
| 5 | +import shlex |
| 6 | +import sys |
| 7 | +import tempfile |
| 8 | +import time |
| 9 | +from importlib.metadata import entry_points |
| 10 | +from importlib.util import find_spec |
| 11 | +from typing import Any |
| 12 | + |
| 13 | +import click |
| 14 | +from aiohttp import ClientSession, UnixConnector, web |
| 15 | + |
| 16 | +from ._utils._console import ConsoleLogger |
| 17 | +from .cli_debug import debug |
| 18 | +from .cli_eval import eval |
| 19 | +from .cli_run import run |
| 20 | + |
| 21 | +console = ConsoleLogger() |
| 22 | + |
| 23 | +SOCKET_ENV_VAR = "UIPATH_SERVER_SOCKET" |
| 24 | +DEFAULT_SOCKET_PATH = "/tmp/uipath-server.sock" |
| 25 | +DEFAULT_PORT = 8765 |
| 26 | + |
| 27 | +IS_WINDOWS = sys.platform == "win32" |
| 28 | + |
| 29 | +COMMANDS = { |
| 30 | + "run": run, |
| 31 | + "debug": debug, |
| 32 | + "eval": eval, |
| 33 | +} |
| 34 | + |
| 35 | +DEFAULT_PRELOAD_MODULES = [ |
| 36 | + # Network/async - slowest to load |
| 37 | + "pysignalr.client", |
| 38 | + "socketio", |
| 39 | + "httpx", |
| 40 | + # Validation/serialization |
| 41 | + "pydantic", |
| 42 | + "pydantic_function_models", |
| 43 | + # CLI/UI |
| 44 | + "click", |
| 45 | + "rich", |
| 46 | +] |
| 47 | + |
| 48 | + |
| 49 | +def preload_modules() -> None: |
| 50 | + """Pre-load modules registered by all uipath packages.""" |
| 51 | + console.info("Pre-loading modules...") |
| 52 | + start = time.perf_counter() |
| 53 | + |
| 54 | + modules_to_load: set[str] = set(DEFAULT_PRELOAD_MODULES) |
| 55 | + |
| 56 | + for ep in entry_points(group="uipath.preload"): |
| 57 | + try: |
| 58 | + get_modules = ep.load() |
| 59 | + modules_to_load.update(get_modules()) |
| 60 | + except Exception as e: |
| 61 | + console.warning(f"Failed to load entry point {ep.name}: {e}") |
| 62 | + |
| 63 | + for module_name in modules_to_load: |
| 64 | + if module_name in sys.modules: |
| 65 | + continue |
| 66 | + if find_spec(module_name) is None: |
| 67 | + continue |
| 68 | + try: |
| 69 | + importlib.import_module(module_name) |
| 70 | + console.success(f"Pre-loaded module: {module_name}") |
| 71 | + except ImportError as e: |
| 72 | + console.warning(f"Failed to load {module_name}: {e}") |
| 73 | + |
| 74 | + elapsed = time.perf_counter() - start |
| 75 | + console.success(f"Modules pre-loaded in {elapsed:.2f}s") |
| 76 | + |
| 77 | + |
| 78 | +def generate_socket_path() -> str: |
| 79 | + """Generate a unique socket path for the server to listen on.""" |
| 80 | + return os.path.join(tempfile.gettempdir(), f"uipath-server-{os.getpid()}.sock") |
| 81 | + |
| 82 | + |
| 83 | +def get_field(message: dict[str, Any], *keys: str) -> Any: |
| 84 | + """Get a field from message, trying multiple key variations.""" |
| 85 | + for key in keys: |
| 86 | + if key in message: |
| 87 | + return message[key] |
| 88 | + return None |
| 89 | + |
| 90 | + |
| 91 | +def parse_args(args: str | list[str] | None) -> list[str]: |
| 92 | + """Parse args into a list of strings.""" |
| 93 | + if args is None: |
| 94 | + return [] |
| 95 | + if isinstance(args, list): |
| 96 | + return args |
| 97 | + if isinstance(args, str): |
| 98 | + return shlex.split(args) |
| 99 | + return [] |
| 100 | + |
| 101 | + |
| 102 | +async def send_ack(ack_socket_path: str, server_socket_path: str) -> None: |
| 103 | + """Send acknowledgment via HTTP POST to the ack socket.""" |
| 104 | + ack_message: dict[str, str] = { |
| 105 | + "status": "ready", |
| 106 | + "socket": server_socket_path, |
| 107 | + } |
| 108 | + |
| 109 | + conn = UnixConnector(path=ack_socket_path) |
| 110 | + try: |
| 111 | + async with ClientSession(connector=conn) as session: |
| 112 | + async with session.post( |
| 113 | + "http://localhost/ack", # placeholder URL for Unix socket |
| 114 | + json=ack_message, |
| 115 | + ) as response: |
| 116 | + if response.status == 200: |
| 117 | + console.success(f"Sent ack to {ack_socket_path}") |
| 118 | + else: |
| 119 | + console.error(f"Ack failed with status {response.status}") |
| 120 | + raise RuntimeError(f"Ack failed: {response.status}") |
| 121 | + except Exception as e: |
| 122 | + console.error(f"Failed to send ack to {ack_socket_path}: {e}") |
| 123 | + raise |
| 124 | + |
| 125 | + |
| 126 | +async def handle_health(request: web.Request) -> web.Response: |
| 127 | + """Handle GET /health endpoint.""" |
| 128 | + return web.Response(text="OK", status=200) |
| 129 | + |
| 130 | + |
| 131 | +async def handle_start(request: web.Request) -> web.Response: |
| 132 | + """Handle POST /jobs/{job_key}/start endpoint.""" |
| 133 | + job_key = request.match_info.get("job_key") |
| 134 | + if not job_key: |
| 135 | + return web.json_response( |
| 136 | + {"success": False, "error": "Missing job_key"}, |
| 137 | + status=400, |
| 138 | + ) |
| 139 | + |
| 140 | + try: |
| 141 | + message: dict[str, Any] = await request.json() |
| 142 | + except json.JSONDecodeError: |
| 143 | + return web.json_response( |
| 144 | + {"success": False, "error": "Invalid JSON"}, |
| 145 | + status=400, |
| 146 | + ) |
| 147 | + |
| 148 | + command_name = get_field(message, "command", "Command") |
| 149 | + if not isinstance(command_name, str): |
| 150 | + return web.json_response( |
| 151 | + {"success": False, "error": "Missing or invalid field: 'command'"}, |
| 152 | + status=400, |
| 153 | + ) |
| 154 | + |
| 155 | + args_raw = get_field(message, "args", "Args") |
| 156 | + args = parse_args(args_raw) |
| 157 | + |
| 158 | + env_vars = get_field(message, "environmentVariables", "EnvironmentVariables") or {} |
| 159 | + working_dir = get_field(message, "workingDirectory", "WorkingDirectory") |
| 160 | + |
| 161 | + console.info(f"Starting job {job_key}: {command_name} {args}") |
| 162 | + |
| 163 | + cmd = COMMANDS.get(command_name) |
| 164 | + if cmd is None: |
| 165 | + return web.json_response( |
| 166 | + {"success": False, "error": f"Unknown command: {command_name}"}, |
| 167 | + status=400, |
| 168 | + ) |
| 169 | + |
| 170 | + # Save original state |
| 171 | + original_cwd = os.getcwd() |
| 172 | + original_env = os.environ.copy() |
| 173 | + |
| 174 | + console.info(f"Original cwd: {original_cwd}") |
| 175 | + console.info(f"Requested working_dir: {working_dir}") |
| 176 | + |
| 177 | + try: |
| 178 | + if isinstance(env_vars, dict): |
| 179 | + os.environ.update(env_vars) |
| 180 | + |
| 181 | + if working_dir and isinstance(working_dir, str): |
| 182 | + os.chdir(working_dir) |
| 183 | + |
| 184 | + result = await asyncio.to_thread(cmd.main, args, standalone_mode=False) |
| 185 | + |
| 186 | + return web.json_response( |
| 187 | + { |
| 188 | + "success": True, |
| 189 | + "job_key": job_key, |
| 190 | + "result": result, |
| 191 | + } |
| 192 | + ) |
| 193 | + except SystemExit as e: |
| 194 | + exit_code = e.code if isinstance(e.code, int) else 1 |
| 195 | + return web.json_response( |
| 196 | + { |
| 197 | + "success": exit_code == 0, |
| 198 | + "job_key": job_key, |
| 199 | + "error": None if exit_code == 0 else f"Exit code: {exit_code}", |
| 200 | + } |
| 201 | + ) |
| 202 | + except Exception as e: |
| 203 | + return web.json_response( |
| 204 | + {"success": False, "job_key": job_key, "error": str(e)}, |
| 205 | + status=500, |
| 206 | + ) |
| 207 | + finally: |
| 208 | + # Restore original state |
| 209 | + os.chdir(original_cwd) |
| 210 | + os.environ.clear() |
| 211 | + os.environ.update(original_env) |
| 212 | + |
| 213 | + |
| 214 | +def create_app() -> web.Application: |
| 215 | + """Create the aiohttp application.""" |
| 216 | + app = web.Application() |
| 217 | + app.router.add_get("/health", handle_health) |
| 218 | + app.router.add_post("/jobs/{job_key}/start", handle_start) |
| 219 | + return app |
| 220 | + |
| 221 | + |
| 222 | +async def start_unix_server(ack_socket_path: str) -> None: |
| 223 | + """Start Unix domain socket HTTP server.""" |
| 224 | + server_socket_path = generate_socket_path() |
| 225 | + |
| 226 | + if os.path.exists(server_socket_path): |
| 227 | + os.unlink(server_socket_path) |
| 228 | + |
| 229 | + app = create_app() |
| 230 | + runner = web.AppRunner(app) |
| 231 | + await runner.setup() |
| 232 | + |
| 233 | + try: |
| 234 | + site = web.UnixSite(runner, server_socket_path) |
| 235 | + await site.start() |
| 236 | + |
| 237 | + console.success(f"Server listening on unix://{server_socket_path}") |
| 238 | + |
| 239 | + await send_ack(ack_socket_path, server_socket_path) |
| 240 | + |
| 241 | + while True: |
| 242 | + await asyncio.sleep(3600) |
| 243 | + finally: |
| 244 | + await runner.cleanup() |
| 245 | + if os.path.exists(server_socket_path): |
| 246 | + os.unlink(server_socket_path) |
| 247 | + |
| 248 | + |
| 249 | +async def start_tcp_server(host: str, port: int) -> None: |
| 250 | + """Start TCP HTTP server (Windows fallback).""" |
| 251 | + app = create_app() |
| 252 | + runner = web.AppRunner(app) |
| 253 | + await runner.setup() |
| 254 | + |
| 255 | + try: |
| 256 | + site = web.TCPSite(runner, host, port) |
| 257 | + await site.start() |
| 258 | + |
| 259 | + console.success(f"Server listening on http://{host}:{port}") |
| 260 | + |
| 261 | + while True: |
| 262 | + await asyncio.sleep(3600) |
| 263 | + finally: |
| 264 | + await runner.cleanup() |
| 265 | + |
| 266 | + |
| 267 | +@click.command() |
| 268 | +@click.option( |
| 269 | + "--socket", |
| 270 | + type=str, |
| 271 | + default=None, |
| 272 | + help=f"Unix socket path to send ready ack to (default: ${SOCKET_ENV_VAR} or {DEFAULT_SOCKET_PATH})", |
| 273 | +) |
| 274 | +@click.option( |
| 275 | + "--port", |
| 276 | + type=int, |
| 277 | + default=None, |
| 278 | + help=f"TCP port, used on Windows or when --tcp flag is set (default: {DEFAULT_PORT})", |
| 279 | +) |
| 280 | +@click.option( |
| 281 | + "--tcp", |
| 282 | + is_flag=True, |
| 283 | + help="Force TCP mode even on Unix systems", |
| 284 | +) |
| 285 | +def server(socket: str | None, port: int | None, tcp: bool) -> None: |
| 286 | + """Start an HTTP server that forwards commands to run/debug/eval. |
| 287 | +
|
| 288 | + Creates its own socket to listen on and sends an ack to --socket with: |
| 289 | + {"status": "ready", "socket": "/path/to/server.sock"} |
| 290 | +
|
| 291 | + Endpoint: POST /jobs/{job_key}/start |
| 292 | + Body: {"command": "run", "args": "agent.json '{}'", "environmentVariables": {}, "workingDirectory": "/path"} |
| 293 | +
|
| 294 | + Endpoint: GET /health |
| 295 | + """ |
| 296 | + use_tcp = IS_WINDOWS or tcp |
| 297 | + |
| 298 | + preload_modules() |
| 299 | + |
| 300 | + try: |
| 301 | + if use_tcp: |
| 302 | + asyncio.run(start_tcp_server("127.0.0.1", port or DEFAULT_PORT)) |
| 303 | + else: |
| 304 | + ack_socket_path = ( |
| 305 | + socket or os.environ.get(SOCKET_ENV_VAR) or DEFAULT_SOCKET_PATH |
| 306 | + ) |
| 307 | + asyncio.run(start_unix_server(ack_socket_path)) |
| 308 | + except KeyboardInterrupt: |
| 309 | + console.info("Shutting down") |
0 commit comments