From b2c70db9ac1c376099f6fc38d97d158646590118 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Mon, 19 Jan 2026 19:31:55 +0800 Subject: [PATCH 01/25] tinkerscript-v1 --- ajet/backbone/main_vllm.py | 7 +- ajet/default_config/ajet_default.yaml | 3 + ajet/task_rollout/single_worker.py | 13 +- ajet/task_runner/tinkerscript_runner.py | 156 ++++++++++++ .../weight_tuner/as_oai_baseurl_apikey.py | 14 ++ .../experimental/as_oai_model_client.py | 3 +- .../experimental/as_oai_model_server.py | 87 ++++++- .../experimental/as_tinkerscript_client.py | 124 ++++++++++ .../experimental/as_tinkerscript_server.py | 232 ++++++++++++++++++ ajet/utils/config_utils.py | 2 +- ajet/utils/core_env_vars.py | 8 +- ajet_tinkerscript.md | 1 + ajet_tinkerscript.py | 90 +++++++ docs/en/workflow.md | 2 +- tutorial/demo_tinkerjet/README.md | 64 +++++ .../ajet_tinkerscript_default.yaml | 49 ++++ .../demo_tinkerjet/demo_tinkerjet_math.py | 87 +++++++ 17 files changed, 922 insertions(+), 20 deletions(-) create mode 100644 ajet/task_runner/tinkerscript_runner.py create mode 100644 ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py create mode 100644 ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py create mode 100644 ajet_tinkerscript.md create mode 100644 ajet_tinkerscript.py create mode 100644 tutorial/demo_tinkerjet/README.md create mode 100644 tutorial/demo_tinkerjet/ajet_tinkerscript_default.yaml create mode 100644 tutorial/demo_tinkerjet/demo_tinkerjet_math.py diff --git a/ajet/backbone/main_vllm.py b/ajet/backbone/main_vllm.py index 686a35cd..f610b6ce 100644 --- a/ajet/backbone/main_vllm.py +++ b/ajet/backbone/main_vllm.py @@ -144,6 +144,7 @@ def run(config): max_parallel = config.ajet.debug.debug_max_parallel n_task = config.ajet.debug.debug_first_n_tasks vllm_port = config.ajet.debug.debug_vllm_port + enable_tinkerscript_mode = config.ajet.enable_tinkerscript_mode # --------- init --------- async_rollout_manager = ChatCompletionScheduler( @@ -166,8 +167,10 @@ def run(config): tasks = task_reader.get_validation_tasks() logger.info(tasks[:n_task]) ctx_tracker = parallel_env.rollout( - tasks=tasks[:n_task], mode="sample", epoch="1" - ) # "sample" or "validate" + tasks=tasks[:n_task], + mode="sample" if not enable_tinkerscript_mode else "sample-ts", # type: ignore + epoch="1" + ) _ = parallel_env.to_dataproto(ctx_tracker) diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index fb4a6143..41f6f18d 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -282,6 +282,9 @@ ajet: # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature enable_experimental_interchange_server: True + # train in cloud, run episode locally + enable_tinkerscript_mode: False + # both tinkerscript / oai share the same interchange server interchange_server: interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) interchange_server_port: 'auto' diff --git a/ajet/task_rollout/single_worker.py b/ajet/task_rollout/single_worker.py index d65a979a..32dfabf2 100644 --- a/ajet/task_rollout/single_worker.py +++ b/ajet/task_rollout/single_worker.py @@ -12,6 +12,7 @@ from ajet.task_rollout.async_llm_bridge import AsyncLlmBridge from ajet.task_rollout.resource_keeper import ResourceKeeper from ajet.task_runner.general_runner import GeneralRunner +from ajet.task_runner.tinkerscript_runner import TinkerScriptRunner from ajet.utils.retry import retry_with_backoff from ajet.utils.sample import get_sample_params from ajet.utils.testing_utils import TestFailException, TestSuccessException @@ -59,6 +60,7 @@ def __init__( assert isinstance(self.pad_token_id, int), "pad_token_id must be an integer" self.current_token = 0 self.current_global_steps: int | str = "NA" + self.enable_tinkerscript_mode = config.ajet.enable_tinkerscript_mode self.async_llm_bridge = AsyncLlmBridge( config=config, async_rollout_manager=async_rollout_manager, @@ -110,9 +112,14 @@ def rollout_env_worker( with ResourceKeeper(workflow_task, config=self.config) as resource_keeper: try: workflow_task = resource_keeper.prepare() - agent_runner = GeneralRunner( - llm_inference_fn=llm_inference_fn, tokenizer=self.tokenizer, config=self.config - ) + if self.enable_tinkerscript_mode: + agent_runner = TinkerScriptRunner( + llm_inference_fn=llm_inference_fn, tokenizer=self.tokenizer, config=self.config + ) + else: + agent_runner = GeneralRunner( + llm_inference_fn=llm_inference_fn, tokenizer=self.tokenizer, config=self.config + ) tracker = agent_runner.execute( workflow_task=workflow_task, ) diff --git a/ajet/task_runner/tinkerscript_runner.py b/ajet/task_runner/tinkerscript_runner.py new file mode 100644 index 00000000..fbd617c8 --- /dev/null +++ b/ajet/task_runner/tinkerscript_runner.py @@ -0,0 +1,156 @@ + +import atexit +import json +import requests +import zmq +import os +import time +from ajet import AjetTuner +from ajet import WorkflowOutput +from ajet.context_tracker.multiagent_tracking import ( + MultiAgentContextTracker, +) +from ajet.context_tracker.basic_tracker import BaseContextTracker +from ajet.schema.task import WorkflowTask +from ajet.schema.trajectory import Reward +from ajet.task_runner.base_runner import BaseAgentRunner +from ajet.utils.networking import find_free_port +from loguru import logger +from ajet import Workflow + +context = zmq.Context() +atexit.register(context.term) + +class TinkerScriptRunner(BaseAgentRunner): + + def get_zmq_socket(self, episode_uuid: str): + interchange_method = self.config.ajet.interchange_server.interchange_method + if interchange_method == 'tcp': + master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") + episode_contect_address = f"tcp://{master_node_ip}:{find_free_port()}" + elif interchange_method == 'ipc': + ipc_path = f"/tmp/ajet/{episode_uuid}-workflow.sock" + episode_contect_address = f"ipc://{ipc_path}" + else: + raise RuntimeError(f"Unknown interchange_method: {interchange_method}") + return episode_contect_address + + + def get_interchange_server_url(self): + port = os.getenv("AJET_DAT_INTERCHANGE_PORT") + if self.config.ajet.interchange_server.interchange_server_port != 'auto': + port = str(int(self.config.ajet.interchange_server.interchange_server_port)) + assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set" + master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") + base_url = f"http://{master_node_ip}:{port}" + return base_url + + + def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: str, openai_api_key: str) -> WorkflowOutput: + """Register the episode as ready in the TinkerScript data interchange center.""" + from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_server import RegisterEpisodeRequest + + # parse episode_uuid, openai_base_url, openai_api_key + zmq_listen_result_addr = self.get_zmq_socket(episode_uuid) + interchange_http_addr = self.get_interchange_server_url() + rer = RegisterEpisodeRequest( + episode_uuid=episode_uuid, + openai_base_url=openai_base_url, + openai_api_key=openai_api_key, + zmq_listen_result_addr=zmq_listen_result_addr, + ) + logger.info(f"zmq_listen_result_addr: {zmq_listen_result_addr}, interchange_http_addr: {interchange_http_addr}") + + # send http request to tinkerscript server to register episode + while True: + try: + response = requests.post( + f"{interchange_http_addr}/register_episode", + json=rer.model_dump(), # 或者 rer.model_dump() 如果使用 Pydantic v2 + timeout=30 + ) + response.raise_for_status() + result = response.json() + if not result.get('success'): + raise RuntimeError(f"Failed to register episode {episode_uuid}") + logger.info(f"Successfully registered episode {episode_uuid}") + break + except requests.RequestException as e: + logger.error(f"Error registering episode {episode_uuid}: {e}. Retrying...") + time.sleep(5) + + # begin wait for result + zmq_socket = zmq.Context().socket(zmq.REP) + zmq_socket.bind(zmq_listen_result_addr) + message = zmq_socket.recv_string() + logger.success(f"Received workflow output for episode {episode_uuid}") + zmq_socket.send_string("ack") + return WorkflowOutput(**json.loads(message)) + + + def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: + observation_window = workflow_task.observation_window + task_thread_index = workflow_task.task_thread_index + + hooks = self.runner_hooks( + observation_window=observation_window, + task_thread_index=task_thread_index, + workflow_task=workflow_task, + ) + context_tracker = MultiAgentContextTracker( + llm_inference_fn=self.llm_inference_fn, + tokenizer=self.tokenizer, + config=self.config, + workflow_task = workflow_task, + **hooks, + ) + tuner = AjetTuner( + context_tracker=context_tracker, + llm_inference_fn=self.llm_inference_fn, + workflow_cls=Workflow, + config=self.config, + ) + + baseurl_apikey = tuner.as_oai_baseurl_apikey() + base_url = baseurl_apikey.base_url + api_key = baseurl_apikey.api_key + + workflow_output: WorkflowOutput = self.register_episode_and_wait_output( + episode_uuid=context_tracker.episode_uuid, + openai_base_url=base_url, + openai_api_key=api_key, + ) + + if workflow_output.reward is not None: + raw_reward, is_success = ( + workflow_output.reward, + workflow_output.is_success, + ) + else: + raise ValueError("workflow_output.reward is None in TinkerScriptRunner, this is currently not allowed.") + + workflow_task.gym_env = None # clear gym env client reference to avoid serialization issue + + assert not isinstance( + raw_reward, list + ), "AgentJet will support step reward in future versions." + + # register reward + # TODO: support multi-step reward + reward = Reward( + raw_reward=raw_reward, + raw_step_reward=None, # "AgentJet will support step reward in future versions." + success_rate=1.0 if is_success else 0.0, + madness=0, + description="", + ) + context_tracker.process_reward(reward) + # generate token before merging + context_tracker.group_merge() + # after merging, process and align reward again + context_tracker.process_reward(reward) + # mark the thread as ended + observation_window["step"][task_thread_index] = -1 + tuner.terminate_episode() + context_tracker.log_metrics = workflow_output.log_metrics + return context_tracker diff --git a/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py b/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py index 90c2cc72..925e2a11 100644 --- a/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py +++ b/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py @@ -18,6 +18,17 @@ class MockAsyncChat(AsyncChat): def completions(self) -> MockAsyncCompletions: # type: ignore return MockAsyncCompletions(self._client) +class OpenaiBaseUrlAndApiKey(BaseModel): + """ At this layer, we will determine which model to use: + - training model + - debug model assigned by user, used when this target is not being trained + """ + + base_url: str = Field(default="http://localhost:27788/v1", description="The base URL for the Ajet's fake OpenAI API") + api_key: str = Field(default="invalid_apikey", description="The Ajet's fake key, which is not a real key, it is a encoded string contain episode_uuid and other stuff.") + model: str = Field(default="reserved_field", description="reserved field.") + + class OpenaiClientBaseUrlTuner(BaseModel): """ At this layer, we will determine which model to use: - training model @@ -40,6 +51,9 @@ def __init__( ): port = os.getenv("AJET_DAT_INTERCHANGE_PORT") + if config.ajet.interchange_server.interchange_server_port != 'auto': + port = str(int(config.ajet.interchange_server.interchange_server_port)) + assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set" master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py index 720c6c77..5faf1add 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py @@ -16,7 +16,6 @@ from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor from ajet.utils.networking import find_free_port - context = zmq.Context() atexit.register(context.term) @@ -158,7 +157,7 @@ def _begin_service_threading(self): break timepassed = time.time() - begin_time if timepassed > 60: - logger.warning(f"[client] {self.episode_uuid} | Still waiting for first message... (time passed {timepassed}) for episode_uuid:{self.episode_uuid}...") + if DEBUG: logger.warning(f"[client] {self.episode_uuid} | Still waiting for first message... (time passed {timepassed}) for episode_uuid:{self.episode_uuid}...") continue # parse the incoming request diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py index 089d11eb..fc017bca 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py @@ -26,8 +26,9 @@ from pydantic import BaseModel from fastapi import FastAPI, Header, HTTPException, Request from contextlib import asynccontextmanager -from multiprocessing import Process +from multiprocessing import Manager, Process from concurrent.futures import ThreadPoolExecutor +from typing import Coroutine, Optional, Tuple from vllm.entrypoints.openai.protocol import ChatCompletionRequest from openai.types.chat.chat_completion import ChatCompletion @@ -53,12 +54,19 @@ class HealthCheckRequest(BaseModel): DEBUG = False # DEBUG = True - context = zmq.Context() atexit.register(context.term) -def get_app(max_fastapi_threads: int = 512) -> FastAPI: + + + + + + + +def get_app(max_fastapi_threads: int = 512, enable_tinkerscript_mode=False, shared_mem_dict=None, shared_mem_dict_lock=None) -> Tuple[FastAPI, Optional[Coroutine]]: + @asynccontextmanager async def lifespan(app: FastAPI): @@ -70,6 +78,7 @@ async def lifespan(app: FastAPI): SERVER_SHUTDOWN_EVENT.set() app.state.executor.shutdown(wait=False, cancel_futures=True) + app = FastAPI(title="AJet Interchange Endpoint", lifespan=lifespan) @@ -158,24 +167,58 @@ async def chat_completions(request: Request, authorization: str = Header(None)): finally: client_offline.set() + @app.post("/reset") async def reset(): return {"status": "reset_complete"} - return app + + if enable_tinkerscript_mode: + from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_server import register_enable_tinkerscript_mode_routes + assert shared_mem_dict is not None, "shared_mem_dict must not be None when enable_tinkerscript_mode is True." + assert shared_mem_dict_lock is not None, "shared_mem_dict_lock must not be None when enable_tinkerscript_mode is True." + app, additional_coro = register_enable_tinkerscript_mode_routes(app, zmq_context=context, shared_mem_dict=shared_mem_dict, shared_mem_dict_lock=shared_mem_dict_lock) + else: + additional_coro = None + + + return app, additional_coro + + + + + + + + + + + + class InterchangeServer(Process): - def __init__(self, experiment_dir: str, port: int, num_fastapi_process: int = 2, max_fastapi_threads: int = 512): + def __init__(self, experiment_dir: str, port: int, num_fastapi_process: int = 2, max_fastapi_threads: int = 512, enable_tinkerscript_mode=False): super().__init__() self.experiment_dir = experiment_dir self.port = port self.num_fastapi_process = num_fastapi_process self.max_fastapi_threads = max_fastapi_threads + self.enable_tinkerscript_mode = enable_tinkerscript_mode def run(self): logger.info(f"Starting Interchange Server on port {self.port} with {self.num_fastapi_process} processes and {self.max_fastapi_threads} threads per process.") - app = get_app(self.max_fastapi_threads) - async def serve_with_monitor(): + + if self.enable_tinkerscript_mode: + manager = Manager() + shared_mem_dict = manager.dict() + shared_mem_dict_lock = manager.Lock() + else: + shared_mem_dict = None + shared_mem_dict_lock = None + + app, additional_coro = get_app(self.max_fastapi_threads, self.enable_tinkerscript_mode, shared_mem_dict, shared_mem_dict_lock) + + async def serve_with_monitor(additional_coro): # Start the server config = uvicorn.Config( app=app, @@ -185,19 +228,37 @@ async def serve_with_monitor(): workers=self.num_fastapi_process ) server = uvicorn.Server(config) - await server.serve() + if additional_coro: + coro_task_1 = asyncio.create_task(additional_coro) + coro_task_2 = asyncio.create_task(server.serve()) + await asyncio.gather(coro_task_1, coro_task_2) + else: + await server.serve() try: - asyncio.run(serve_with_monitor()) + asyncio.run(serve_with_monitor(additional_coro)) except KeyboardInterrupt as e: SERVER_SHUTDOWN_EVENT.set() raise e + + + + + + + + + + + + # Convenience function for quick server startup def start_interchange_server(config) -> int: experiment_dir = config.ajet.experiment_dir num_fastapi_process = config.ajet.interchange_server.num_fastapi_process max_fastapi_threads = config.ajet.interchange_server.max_fastapi_threads + enable_tinkerscript_mode = config.ajet.enable_tinkerscript_mode # Find a free port if not specified or invalid port = int(os.environ.get("AJET_DAT_INTERCHANGE_PORT", -1)) @@ -211,7 +272,13 @@ def start_interchange_server(config) -> int: port = s.getsockname()[1] os.environ["AJET_DAT_INTERCHANGE_PORT"] = str(port) - interchange_server = InterchangeServer(experiment_dir, port, num_fastapi_process, max_fastapi_threads) + interchange_server = InterchangeServer( + experiment_dir, + port, + num_fastapi_process, + max_fastapi_threads, + enable_tinkerscript_mode, + ) interchange_server.start() # Wait for server to be ready diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py new file mode 100644 index 00000000..4a423782 --- /dev/null +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py @@ -0,0 +1,124 @@ +import uuid +import time +import httpx +import yaml +from loguru import logger +from pydantic import BaseModel +from ajet.schema.task import WorkflowOutput +from ajet.copilot.job import AgentJetJob +from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey + +# --- Schema Definitions --- + +class SyncTrainConfigRequest(BaseModel): + yaml_as_string: str + +class ClaimEpisodeRequest(BaseModel): + client_uuid: str + episode_type: str + +class ClaimEpisodeResponse(BaseModel): + success: bool + client_uuid: str + episode_uuid: str + openai_base_url: str = "" + openai_api_key: str = "" + fail_cause: str = "" + +class EndEpisodeRequest(BaseModel): + client_uuid: str + episode_uuid: str + workflow_output: WorkflowOutput + +class EndEpisodeResponse(BaseModel): + success: bool + +class TinkerScriptClient(object): + + def __init__(self, server_url: str): + self.server_url = server_url + self.client_uuid = str(uuid.uuid4()) + self.episode_uuid = None + self.openai_base_url = None + self.openai_api_key = None + + def begin_episode(self) -> OpenaiBaseUrlAndApiKey: + """ + Block until an episode is claimed. + Return (episode_uuid, openai_base_url, openai_api_key) + """ + while True: + try: + req_obj = ClaimEpisodeRequest( + client_uuid=self.client_uuid, + episode_type="default" + ) + resp = httpx.post( + f"{self.server_url}/claim_episode", + json=req_obj.model_dump(), + timeout=30 + ) + resp.raise_for_status() + data = ClaimEpisodeResponse.model_validate(resp.json()) + + if data.success: + self.episode_uuid = data.episode_uuid + self.openai_base_url = data.openai_base_url + self.openai_api_key = data.openai_api_key + logger.info(f"Claimed episode {self.episode_uuid}") + return OpenaiBaseUrlAndApiKey( + base_url=self.openai_base_url, + api_key=self.openai_api_key, + ) + else: + logger.info(f"Failed to claim episode: {data.fail_cause}. Retrying in 5s...") + time.sleep(5) + except Exception as e: + logger.error(f"Error claiming episode: {e}. Retrying in 5s...") + time.sleep(5) + + def end_episode(self, workflow_output: WorkflowOutput): + if not self.episode_uuid: + logger.error("No episode to end.") + return + + try: + req_obj = EndEpisodeRequest( + client_uuid=self.client_uuid, + episode_uuid=self.episode_uuid, + workflow_output=workflow_output + ) + + resp = httpx.post( + f"{self.server_url}/end_episode", + json=req_obj.model_dump(), + timeout=30 + ) + resp.raise_for_status() + data = EndEpisodeResponse.model_validate(resp.json()) + + if data.success: + logger.info(f"Ended episode {self.episode_uuid}") + self.episode_uuid = None + else: + logger.error(f"Failed to end episode {self.episode_uuid}") + + except Exception as e: + logger.error(f"Error ending episode: {e}") + + def sync_train_config(self, agent_jet_job: AgentJetJob): + try: + config_dict = agent_jet_job.config.to_dict() + yaml_str = yaml.safe_dump(config_dict, sort_keys=False) + + req_obj = SyncTrainConfigRequest(yaml_as_string=yaml_str) + + resp = httpx.post( + f"{self.server_url}/sync_train_config", + json=req_obj.model_dump(), + timeout=30 + ) + resp.raise_for_status() + logger.info("Synced train config") + except Exception as e: + logger.error(f"Error syncing train config: {e}") diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py new file mode 100644 index 00000000..e793919a --- /dev/null +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py @@ -0,0 +1,232 @@ +from multiprocessing.managers import DictProxy +import threading + +import zmq + +from loguru import logger +from pydantic import BaseModel +from fastapi import FastAPI, HTTPException +from typing import List + +from typing import Coroutine, Optional, Tuple +from ajet.schema.task import WorkflowOutput + + +class SyncTrainConfigRequest(BaseModel): + yaml_as_string: str + +class ClaimEpisodeRequest(BaseModel): + client_uuid: str + episode_type: str + +class ClaimEpisodeResponse(BaseModel): + success: bool + client_uuid: str + episode_uuid: str + openai_base_url: str = "" + openai_api_key: str = "" + fail_cause: str = "" + +class CanContinueEpisodeRequest(BaseModel): + client_uuid: str + episode_uuid: str + +class CanContinueEpisodeResponse(BaseModel): + can_continue: bool + +class EndEpisodeRequest(BaseModel): + client_uuid: str + episode_uuid: str + workflow_output: WorkflowOutput + +class EndEpisodeResponse(BaseModel): + success: bool + + +class EpisodeStatus(BaseModel): + episode_uuid: str + episode_status: str = "rolling" + openai_base_url: str = "" + openai_api_key: str = "" + client_uuid: str = "" + zmq_listen_result_addr: str = "" + +class EpisodeBufferResponse(BaseModel): + buffer: List[EpisodeStatus] + + +class BoolResponse(BaseModel): + success: bool + +class RegisterEpisodeRequest(BaseModel): + episode_uuid: str + openai_base_url: str = "" + openai_api_key: str = "" + zmq_listen_result_addr: str = "" + + +class UpdateEngineStatusRequest(BaseModel): + engine_status: str = "" + + +DEBUG = True + +def register_enable_tinkerscript_mode_routes(app, zmq_context, shared_mem_dict:DictProxy, shared_mem_dict_lock:threading.Lock) -> Tuple[FastAPI, Optional[Coroutine]]: + + if 'episodes' not in shared_mem_dict: + shared_mem_dict["episodes"] = {} + + if 'unclaimed_episodes' not in shared_mem_dict: + shared_mem_dict['unclaimed_episodes'] = [] + + @app.post("/sync_train_config") + async def sync_train_config(req: SyncTrainConfigRequest): + # dummy: just print the yaml string + try: + print("[sync_train_config] received yaml:", req.yaml_as_string) + except Exception: + pass + return {"success": True} + + + # --- engine status --- + shared_mem_dict['engine_status'] = "booting" + @app.post("/update_engine_status", response_model=BoolResponse) + async def update_engine_status(req: UpdateEngineStatusRequest): + shared_mem_dict['engine_status'] = req.engine_status + return BoolResponse(success=True) + + + @app.get("/get_engine_status") + async def get_engine_status(): + status = shared_mem_dict['engine_status'] + return {"engine_status": status} + + + # --- episode status --- + @app.post("/register_episode", response_model=BoolResponse) + async def register_episode(req: RegisterEpisodeRequest): + + episode_uuid = req.episode_uuid + es = EpisodeStatus( + episode_uuid=req.episode_uuid, + openai_base_url=req.openai_base_url, + openai_api_key=req.openai_api_key, + episode_status="registered", + zmq_listen_result_addr=req.zmq_listen_result_addr, + ) + + with shared_mem_dict_lock: + shared_mem_dict[f"episodes-{episode_uuid}"] = es + shared_mem_dict['unclaimed_episodes'] += [req.episode_uuid] + + return BoolResponse( + success=True, + ) + + + @app.post("/claim_episode", response_model=ClaimEpisodeResponse) + async def claim_episode(req: ClaimEpisodeRequest): + # placeholder implementation — real logic should check episode_semaphore + + with shared_mem_dict_lock: + if len(shared_mem_dict['unclaimed_episodes']) <= 0: + return ClaimEpisodeResponse( + success=False, + client_uuid=req.client_uuid, + episode_uuid="", + openai_base_url="", + openai_api_key="", + fail_cause="No available episodes to claim. Try again (maybe 1 minute) later.", + ) + + # hint: do not optimize this + episode_uuid = shared_mem_dict['unclaimed_episodes'][0] + shared_mem_dict['unclaimed_episodes'] = shared_mem_dict['unclaimed_episodes'][1:] + + # get episode + es:EpisodeStatus = shared_mem_dict[f"episodes-{episode_uuid}"] + es.episode_status = "claimed" + es.client_uuid = req.client_uuid + shared_mem_dict[f"episodes-{episode_uuid}"] = es + openai_base_url = es.openai_base_url + openai_api_key = es.openai_api_key + + + return ClaimEpisodeResponse( + success=True, + client_uuid=req.client_uuid, + episode_uuid=episode_uuid, + openai_base_url=openai_base_url, + openai_api_key=openai_api_key, + fail_cause="", + ) + + + + @app.post("/end_episode", response_model=EndEpisodeResponse) + async def end_episode(req: EndEpisodeRequest): + # receive workflow output data + client_uuid = req.client_uuid + episode_uuid = req.episode_uuid + workflow_output = req.workflow_output + + if 'episodes' not in shared_mem_dict: + raise HTTPException(status_code=400, detail=f"No episodes registered yet.") + if (f"episodes-{episode_uuid}") not in shared_mem_dict: + raise HTTPException(status_code=400, detail=f"Episode {episode_uuid} not found.") + + # send workflow_output to zmq + assert 'episodes' in shared_mem_dict + zmq_addr = shared_mem_dict[f"episodes-{episode_uuid}"].zmq_listen_result_addr + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | Received new chat completion request (inside thread)") + socket = zmq_context.socket(zmq.REQ) + socket.setsockopt(zmq.RCVTIMEO, 60*1000) # 1 minute recv timeout + socket.connect(zmq_addr) + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | connect done") + socket.send_string(workflow_output.model_dump_json()) + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | send_string") + + # wait for ack + for _ in range(5): # max 5 minutes wait + try: + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") + result_str = socket.recv_string() + break + except zmq.Again as e: + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string timeout, retrying.") + continue + + # clean up episode records + with shared_mem_dict_lock: + del shared_mem_dict[f"episodes-{episode_uuid}"] + if episode_uuid in shared_mem_dict['unclaimed_episodes']: + shared_mem_dict['unclaimed_episodes'].remove(episode_uuid) + + # return success + return EndEpisodeResponse(success=True) + + + + @app.post("/can_continue_episode", response_model=CanContinueEpisodeResponse) + async def can_continue_episode(req: CanContinueEpisodeRequest): + can_continue = (f"episodes-{req.episode_uuid}" in shared_mem_dict) + can_continue = can_continue and shared_mem_dict[f"episodes-{req.episode_uuid}"]["episode_status"] == "claimed" + return CanContinueEpisodeResponse(can_continue=can_continue) + + + + @app.post("/get_episode_buffer", response_model=EpisodeBufferResponse) + async def get_episode_buffer(): + result = [ + v for k, v in shared_mem_dict.items() if k.startswith("episodes-") + ] + return EpisodeBufferResponse(buffer=result) + + + + async def register_episode_ready_listener(): + pass + + + return app, register_episode_ready_listener() diff --git a/ajet/utils/config_utils.py b/ajet/utils/config_utils.py index e9dd5d3b..2bbae3f9 100644 --- a/ajet/utils/config_utils.py +++ b/ajet/utils/config_utils.py @@ -253,7 +253,7 @@ def prepare_experiment_config(yaml_path, exp_dir, backbone): tuple: (yaml_backup_dst, exe_exp_base, exp_name, config_final) """ assert yaml_path.endswith(".yaml"), "Configuration file must be a YAML file" - exp_base = os.path.dirname(yaml_path) + exp_base = os.path.exists(os.path.dirname(yaml_path)) if not os.path.exists(exp_base): raise FileNotFoundError(f"Configuration file not found: {exp_base}") diff --git a/ajet/utils/core_env_vars.py b/ajet/utils/core_env_vars.py index 91fdf736..ee1dbd82 100644 --- a/ajet/utils/core_env_vars.py +++ b/ajet/utils/core_env_vars.py @@ -18,6 +18,11 @@ def get_runtime_env(config, is_trinity: bool = False) -> dict: if config.ajet.interchange_server.interchange_method == "ipc": raise ValueError("IPC interchange method is not supported for multi-node setup. Please set `ajet.interchange_server.interchange_method: tcp` ") + if config.ajet.interchange_server.interchange_server_port != 'auto': + data_interchange_port = str(int(config.ajet.interchange_server.interchange_server_port)) + else: + data_interchange_port = str(find_free_port()) + runtime_env = { "env_vars": { "VLLM_USE_V1": "1", @@ -29,7 +34,8 @@ def get_runtime_env(config, is_trinity: bool = False) -> dict: # "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true", "SWANLAB_API_KEY": os.getenv("SWANLAB_API_KEY", ""), "AJET_CONFIG_REDIRECT": os.getenv("AJET_CONFIG_REDIRECT", ""), - "AJET_DAT_INTERCHANGE_PORT": str(find_free_port()), + "AJET_DAT_INTERCHANGE_PORT": data_interchange_port, + "AJET_DAT_INTERCHANGE_ZMQ_PORT": str(find_free_port()), "MASTER_NODE_IP": master_node_ip, } } diff --git a/ajet_tinkerscript.md b/ajet_tinkerscript.md new file mode 100644 index 00000000..f7bc610c --- /dev/null +++ b/ajet_tinkerscript.md @@ -0,0 +1 @@ +python -m ajet.launcher --conf tutorial/demo_tinkerjet/ajet_tinkerscript_default.yaml --backbone="debug" --autokill diff --git a/ajet_tinkerscript.py b/ajet_tinkerscript.py new file mode 100644 index 00000000..7577364d --- /dev/null +++ b/ajet_tinkerscript.py @@ -0,0 +1,90 @@ +import re +import requests +from textwrap import dedent +from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_client import TinkerScriptClient +from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet import WorkflowOutput +from ajet.task_reader import RouterTaskReader +from ajet.utils.retry import retry_with_backoff + + +TINKERJET_URL = "http://localhost:10086" # Change to your tinkerjet remote url +NUM_EPOCH = 100 +GRPO_N = 4 # grpo group size + + +class WeightUpdatedHalfway(Exception): + """Raised when the remote side starts updating model weights halfway through an episode.""" + + +def main(): + + # Handshake with tinkerjet remote, then send training param to tinkerjet remote (such as model to be trained, algorithm, etc) + tinkerjet_remote = TinkerScriptClient(TINKERJET_URL) + dataset = RouterTaskReader( + reader_type = "huggingface_dat_repo", + reader_config = AjetTaskReader( + huggingface_dat_repo = HuggingfaceDatRepo( + dataset_path = "/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main" + ) + ) + ) + + # Define rollout + def rollout(task): + # Q: Can I run episodes in parallel? + # A: Yes, wrap `rollout` in a thread or process pool. + api_baseurl_key = tinkerjet_remote.begin_episode() + workflow_output = execute_agent(task, api_baseurl_key) + tinkerjet_remote.end_episode(workflow_output) + return workflow_output.reward + + # Main Training loop + for epoch in range(NUM_EPOCH): + for task in dataset.get_training_tasks(): + try: + for i in range(GRPO_N): + reward = rollout(task) + print(f"{epoch}-{task}-run:{i}-{reward}") + except WeightUpdatedHalfway as e: + print(f"The remote side has gone into the LLM model weight update phrase halfway through an episode." + f"This is **normal**." + f"The remote no longer need this task anymore, so let's go to next task.") + + # Get tuned model from tinkerjet remote + return None + + + + +@retry_with_backoff(max_retry=2) +def execute_agent(task, api_baseurl_key: OpenaiBaseUrlAndApiKey): + # Prepare base_url, api_key + base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key) + # Read dataset item + query, reference_answer = (task.main_query, task.metadata["answer"]) + # Prepare messages + messages = [ + { "role": "system", "content": dedent("""You are an agent specialized in solving math problems. Please solve the math problem given to you. + You can write and execute Python code to perform calculation or verify your answer. You should return your final answer within \\boxed{{}}.""") }, + { "role": "user", "content": query } + ] + # Use raw http requests (non-streaming) to get response + response = requests.post( f"{base_url}/chat/completions", json = { "model": "fill_whatever_model", "messages": messages, }, + headers = { "Authorization": f"Bearer {api_key}" } ) + print(response.json()) + final_answer = response.json()['choices'][0]['message']['content'] + # Compute reward + reference_answer = reference_answer.split("####")[-1].strip() + pattern = r"\\boxed\{([^}]*)\}" + match = re.search(pattern, final_answer) + if match: is_success = match.group(1) == reference_answer + else: is_success = False + raw_reward = 1.0 if is_success else 0.0 + # Return + return WorkflowOutput(reward=raw_reward, metadata={"final_answer": final_answer}) + + +if __name__ == "__main__": + main() diff --git a/docs/en/workflow.md b/docs/en/workflow.md index 1137f02c..94cdc825 100644 --- a/docs/en/workflow.md +++ b/docs/en/workflow.md @@ -241,7 +241,7 @@ Here's a complete example with multiple agent roles (Werewolves game): - You can flexibly switch training targets by modifying `trainable_targets` -## TinkerJet +## TinkerScript Wrapping and training your agent on a machine without GPU. diff --git a/tutorial/demo_tinkerjet/README.md b/tutorial/demo_tinkerjet/README.md new file mode 100644 index 00000000..10985526 --- /dev/null +++ b/tutorial/demo_tinkerjet/README.md @@ -0,0 +1,64 @@ +# TinkerScript + + +TinkerScript is an experimental component of AgentJet, +allowing users to +- run, debug and train **full-weight** LLM model behind user-defined LLM workflows in **machines without GPU**. + +Similar to Tinker & Open-Tinker, the basic idea behind TinkerScript is to: +- use remote (or cloud) GPU machine(s) as computation media. + +However, TinkerScript goes even further on this path: + +- Users only need to write and run their agents in a big `while` loop (e.g., in their laptop), and provide samples generated in this process. + +- TinkerScript will take care of everything else. + +- TinkerScript trains **full-weight** LLM model instead of lora. + +- Upon the termination of the training session, user can call `download_tuned_model` to download tuned LLM(s). + + +# Core Training Code + +The core code at user-side is as simple as: + +```python + +# step 1: ... write user-defined `execute_agent` +# step 2: ... init `tinkerjet_remote` to handshake with remote GPU server +# step 3: ... define hyper-parameters `NUM_EPOCH`, `GRPO_N` +# step 4: ... spawn `dataset` from dataset file + +# step 5: rock & roll +## rollout +def rollout(task): + try: + api_baseurl_key = tinkerjet_remote.begin_episode() + workflow_output = execute_agent(task, api_baseurl_key) + tinkerjet_remote.end_episode(workflow_output) + return workflow_output.reward + except Exception as e: + print(f"Episode abandoned") + return 0.0 +## Main Training loop +for epoch in range(NUM_EPOCH): + for task in dataset.get_training_tasks(): + for i in range(GRPO_N): + reward = rollout(task) + print(f"{epoch}-{task}-run:{i}-{reward}") + +# step 6: get trained model and shutdown +tuned_model_checkpoint = tinkerjet_remote.download_tuned_model() +tinkerjet_remote.close() + +``` + +# Limitation + +- Users are only limited to use OpenAI `baseurl` + `apikey` to build applications. Features such as `tuner.as_agentscope_model` is no longer available. + +- AgentJet are not able to explicitly distinguish different agents in multi-agent scenario. + But **do not worry**, AgentJet will still try its best to recognize shards of llm timelines and merge them behind the curtain, automatically. + +- TinkerScript does not support prompt tuning. diff --git a/tutorial/demo_tinkerjet/ajet_tinkerscript_default.yaml b/tutorial/demo_tinkerjet/ajet_tinkerscript_default.yaml new file mode 100644 index 00000000..c6913470 --- /dev/null +++ b/tutorial/demo_tinkerjet/ajet_tinkerscript_default.yaml @@ -0,0 +1,49 @@ +# ------------------ main configuration ------------------ +ajet: + project_name: "ajet_default_project" + experiment_name: "read_yaml_name" + experiment_dir: "auto" # {exp-dir}/{experiment_name} + backbone: debug # `debug` or `trinity` or `verl` + + model: + # which model should be trained + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + + rollout: + # the path to the workflow class + user_workflow: null + + task_reader: + type: random_dummy # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` + + task_judge: + judge_type: customized_protocol # Options: 'customized_protocol', 'rubrics_auto_grader' + judge_protocol: null + + # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature + enable_experimental_interchange_server: True + # train in cloud, run episode locally + enable_tinkerscript_mode: True + # both tinkerscript / oai share the same interchange server + interchange_server: + interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + interchange_server_port: 10086 + num_fastapi_process: 2 # 1, 2 or 4 is fine + max_fastapi_threads: 128 # 64 or 128 is fine + max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` + + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/demo_tinkerjet/demo_tinkerjet_math.py b/tutorial/demo_tinkerjet/demo_tinkerjet_math.py new file mode 100644 index 00000000..a084cf51 --- /dev/null +++ b/tutorial/demo_tinkerjet/demo_tinkerjet_math.py @@ -0,0 +1,87 @@ +import re +import requests +from textwrap import dedent +from ajet import AgentJetJob +from ajet.copilot.tinkerjet.remote import TinkerScriptRemote +from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import AgentJetAsOpenAI +from ajet import WorkflowOutput +from ajet.task_reader import RouterTaskReader +from ajet.utils.retry import retry_with_backoff +TINKERJET_URL = "http://localhost:10086" # Change to your tinkerjet remote url +NUM_EPOCH = 100 +GRPO_N = 4 # grpo group size + +class WeightUpdatedHalfway(Exception): + """Raised when the remote side starts updating model weights halfway through an episode.""" + +def main(): + # Handshake with tinkerjet remote, then send training param to tinkerjet remote (such as model to be trained, algorithm, etc) + tinkerjet_remote = TinkerScriptRemote(TINKERJET_URL) + tinkerjet_remote.sync_train_config( + AgentJetJob(backbone="verl", n_gpu=2, algorithm="grpo", model='qwen/Qwen2.5-1.5B-instruct') + ) + + # Dataset reader (read in your local machine only) + dataset = RouterTaskReader( + reader_type = "huggingface_dat_repo", + reader_config = AjetTaskReader( + huggingface_dat_repo = HuggingfaceDatRepo( dataset_path = "openai/gsm8k" ) + ) + ) + + # Define rollout + def rollout(task): + # Q: Can I run episodes in parallel? + # A: Yes, wrap `rollout` in a thread or process pool. + api_baseurl_key = tinkerjet_remote.begin_episode() + workflow_output = execute_agent(task, api_baseurl_key) + tinkerjet_remote.end_episode(workflow_output) + return workflow_output.reward + + # Main Training loop + for epoch in range(NUM_EPOCH): + for task in dataset.get_training_tasks(): + try: + for i in range(GRPO_N): + reward = rollout(task) + print(f"{epoch}-{task}-run:{i}-{reward}") + except WeightUpdatedHalfway as e: + print(f"The remote side has gone into the LLM model weight update phrase halfway through an episode." + f"This is **normal**." + f"The remote no longer need this task anymore, so let's go to next task.") + + # Get tuned model from tinkerjet remote + tuned_model_checkpoint = tinkerjet_remote.download_tuned_model() + return tuned_model_checkpoint + + +@retry_with_backoff(max_retry=2) +def execute_agent(task, api_baseurl_key: AgentJetAsOpenAI): + # Prepare base_url, api_key + base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key) + # Read dataset item + query, reference_answer = (task.main_query, task.metadata["answer"]) + # Prepare messages + messages = [ + { "role": "system", "content": dedent("""You are an agent specialized in solving math problems. Please solve the math problem given to you. + You can write and execute Python code to perform calculation or verify your answer. You should return your final answer within \\boxed{{}}.""") }, + { "role": "user", "content": query } + ] + # Use raw http requests (non-streaming) to get response + response = requests.post( f"{base_url}/chat/completions", json = { "model": "fill_whatever_model", "messages": messages, }, + headers = { "Authorization": f"Bearer {api_key}" } ) + final_answer = response.json()['choices'][0]['message']['content'] + # Compute reward + reference_answer = reference_answer.split("####")[-1].strip() + pattern = r"\\boxed\{([^}]*)\}" + match = re.search(pattern, final_answer) + if match: is_success = match.group(1) == reference_answer + else: is_success = False + raw_reward = 1.0 if is_success else 0.0 + # Return + return WorkflowOutput(reward=raw_reward, metadata={"final_answer": final_answer}) + + +if __name__ == "__main__": + main() From 21f9bb8f788150980a5bcfbb90e98c0b25d7c803 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Tue, 20 Jan 2026 00:39:05 +0800 Subject: [PATCH 02/25] improve tinkerscript --- ajet/backbone/main_vllm.py | 3 + ajet/backbone/trainer_verl.py | 9 ++ ajet/task_runner/tinkerscript_runner.py | 59 ++------ .../weight_tuner/as_oai_baseurl_apikey.py | 1 + .../experimental/as_oai_model_client.py | 12 +- .../experimental/as_oai_model_server.py | 17 +++ .../experimental/as_tinkerscript_client.py | 120 ++++++++++------ .../experimental/as_tinkerscript_server.py | 125 ++++++++-------- .../experimental/interchange_utils.py | 136 ++++++++++++++++++ ajet_tinkerscript.py | 2 +- ajet_tinkerscript_threading.py | 94 ++++++++++++ 11 files changed, 413 insertions(+), 165 deletions(-) create mode 100644 ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py create mode 100644 ajet_tinkerscript_threading.py diff --git a/ajet/backbone/main_vllm.py b/ajet/backbone/main_vllm.py index f610b6ce..7e63e216 100644 --- a/ajet/backbone/main_vllm.py +++ b/ajet/backbone/main_vllm.py @@ -189,6 +189,9 @@ def main(config): if config.ajet.enable_experimental_interchange_server: from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server start_interchange_server(config) + if config.ajet.enable_tinkerscript_mode: + from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status + http_change_engine_status(config, "ROLLING") def companion_launch(): import torch diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 5b9d0853..a48a6a08 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -443,6 +443,13 @@ def init_workers(self): tokenizer=self.tokenizer, ) + def _update_interchange_server_status_flag(self, status: str): + # if interchange server is enabled, change engine status to ROLLING + if self.config.ajet.enable_experimental_interchange_server: + if self.config.ajet.enable_tinkerscript_mode: + from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status + http_change_engine_status(self.config, status) + # ####################################### # training loop # ####################################### @@ -552,6 +559,7 @@ def fit(self): # noqa: C901 assert self.async_rollout_mode logger.info("=== wake up begin ===") self.async_rollout_manager.wake_up() + self._update_interchange_server_status_flag("ROLLING") logger.info("=== wake up end ===") tasks: List[Task] = [ dict_to_ajet_task(dict( @@ -577,6 +585,7 @@ def fit(self): # noqa: C901 tasks, mode="sample", epoch=f"train.{epoch}" ) logger.info("=" * 10 + "end fit rollout" + "=" * 10) + self._update_interchange_server_status_flag("UPDATE_WEIGHT") logger.info("begin to convert context_tracker_arr to dataproto") gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr) logger.info("end convertion") diff --git a/ajet/task_runner/tinkerscript_runner.py b/ajet/task_runner/tinkerscript_runner.py index fbd617c8..b9303b29 100644 --- a/ajet/task_runner/tinkerscript_runner.py +++ b/ajet/task_runner/tinkerscript_runner.py @@ -1,10 +1,8 @@ import atexit import json -import requests import zmq import os -import time from ajet import AjetTuner from ajet import WorkflowOutput from ajet.context_tracker.multiagent_tracking import ( @@ -14,70 +12,28 @@ from ajet.schema.task import WorkflowTask from ajet.schema.trajectory import Reward from ajet.task_runner.base_runner import BaseAgentRunner -from ajet.utils.networking import find_free_port +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_register_episode, get_zmq_socket from loguru import logger from ajet import Workflow context = zmq.Context() atexit.register(context.term) +DEBUG = True class TinkerScriptRunner(BaseAgentRunner): - def get_zmq_socket(self, episode_uuid: str): - interchange_method = self.config.ajet.interchange_server.interchange_method - if interchange_method == 'tcp': - master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") - episode_contect_address = f"tcp://{master_node_ip}:{find_free_port()}" - elif interchange_method == 'ipc': - ipc_path = f"/tmp/ajet/{episode_uuid}-workflow.sock" - episode_contect_address = f"ipc://{ipc_path}" - else: - raise RuntimeError(f"Unknown interchange_method: {interchange_method}") - return episode_contect_address - - - def get_interchange_server_url(self): - port = os.getenv("AJET_DAT_INTERCHANGE_PORT") - if self.config.ajet.interchange_server.interchange_server_port != 'auto': - port = str(int(self.config.ajet.interchange_server.interchange_server_port)) - assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set" - master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") - base_url = f"http://{master_node_ip}:{port}" - return base_url - - def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: str, openai_api_key: str) -> WorkflowOutput: """Register the episode as ready in the TinkerScript data interchange center.""" - from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_server import RegisterEpisodeRequest - # parse episode_uuid, openai_base_url, openai_api_key - zmq_listen_result_addr = self.get_zmq_socket(episode_uuid) - interchange_http_addr = self.get_interchange_server_url() - rer = RegisterEpisodeRequest( + zmq_listen_result_addr, ipc_path = get_zmq_socket(self.config, episode_uuid, tag="workflow") + http_register_episode( + self.config, episode_uuid=episode_uuid, openai_base_url=openai_base_url, openai_api_key=openai_api_key, zmq_listen_result_addr=zmq_listen_result_addr, ) - logger.info(f"zmq_listen_result_addr: {zmq_listen_result_addr}, interchange_http_addr: {interchange_http_addr}") - - # send http request to tinkerscript server to register episode - while True: - try: - response = requests.post( - f"{interchange_http_addr}/register_episode", - json=rer.model_dump(), # 或者 rer.model_dump() 如果使用 Pydantic v2 - timeout=30 - ) - response.raise_for_status() - result = response.json() - if not result.get('success'): - raise RuntimeError(f"Failed to register episode {episode_uuid}") - logger.info(f"Successfully registered episode {episode_uuid}") - break - except requests.RequestException as e: - logger.error(f"Error registering episode {episode_uuid}: {e}. Retrying...") - time.sleep(5) + logger.info(f"zmq_listen_result_addr: {zmq_listen_result_addr}") # begin wait for result zmq_socket = zmq.Context().socket(zmq.REP) @@ -85,6 +41,9 @@ def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: s message = zmq_socket.recv_string() logger.success(f"Received workflow output for episode {episode_uuid}") zmq_socket.send_string("ack") + zmq_socket.close() + if ipc_path and os.path.exists(ipc_path): os.remove(ipc_path) + return WorkflowOutput(**json.loads(message)) diff --git a/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py b/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py index 925e2a11..3edf46c8 100644 --- a/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py +++ b/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py @@ -27,6 +27,7 @@ class OpenaiBaseUrlAndApiKey(BaseModel): base_url: str = Field(default="http://localhost:27788/v1", description="The base URL for the Ajet's fake OpenAI API") api_key: str = Field(default="invalid_apikey", description="The Ajet's fake key, which is not a real key, it is a encoded string contain episode_uuid and other stuff.") model: str = Field(default="reserved_field", description="reserved field.") + episode_uuid: str = Field(default="episode_id", description="reserved field.") class OpenaiClientBaseUrlTuner(BaseModel): diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py index 5faf1add..67d9bfb6 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py @@ -14,7 +14,7 @@ from openai.types.chat.chat_completion import ChatCompletion from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import InterchangeCompletionRequest, API_KEY_PREFIX from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor -from ajet.utils.networking import find_free_port +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import get_zmq_socket context = zmq.Context() atexit.register(context.term) @@ -67,17 +67,11 @@ def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker self.llm_inference_fn = llm_inference_fn self.config = config self._should_terminate = False - + self.episode_contect_address, ipc_path = get_zmq_socket(config, episode_uuid, tag="llm") + self.ipc_path = ipc_path self.interchange_method = config.ajet.interchange_server.interchange_method - if self.interchange_method == 'tcp': - master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") - self.episode_contect_address = f"tcp://{master_node_ip}:{find_free_port()}" - elif self.interchange_method == 'ipc': - self.ipc_path = f"/tmp/ajet/{self.episode_uuid}.sock" - self.episode_contect_address = f"ipc://{self.ipc_path}" self.max_inference_tracker_threads = config.ajet.interchange_server.max_inference_tracker_threads - async def llm_infer( self, req: ChatCompletionRequest, diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py index fc017bca..98aa4192 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py @@ -33,6 +33,8 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest from openai.types.chat.chat_completion import ChatCompletion +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import EpisodeStatus + API_KEY_PREFIX = "sk-ajet-" class InterchangeCompletionRequest(BaseModel): @@ -151,6 +153,21 @@ async def chat_completions(request: Request, authorization: str = Header(None)): # Create timeline UUID timeline_uuid = uuid.uuid4().hex + # enable_tinkerscript_mode + if enable_tinkerscript_mode: + assert shared_mem_dict is not None + assert shared_mem_dict_lock is not None + if shared_mem_dict['engine_status'] != "ROLLING": + logger.error(f"The server is not in ROLLING status (current status: [{shared_mem_dict['engine_status']}]), cannot accept new requests.") + raise HTTPException(status_code=503, detail="The server is not in ROLLING status, cannot accept new requests.") + if (f"episodes-{episode_uuid}") not in shared_mem_dict: + raise HTTPException(status_code=404, detail=f"Episode {episode_uuid} not found.") + # update activate timestamp + with shared_mem_dict_lock: + es:EpisodeStatus = shared_mem_dict[f"episodes-{episode_uuid}"] + es.latest_activity_timestamp = time.time() + shared_mem_dict[f"episodes-{episode_uuid}"] = es + # Add to received queue int_req = InterchangeCompletionRequest( completion_request = new_req, diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py index 4a423782..581042c8 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py @@ -2,47 +2,32 @@ import time import httpx import yaml +from typing import List, Tuple from loguru import logger -from pydantic import BaseModel from ajet.schema.task import WorkflowOutput from ajet.copilot.job import AgentJetJob from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import ( + SyncTrainConfigRequest, + ClaimEpisodeRequest, + ClaimEpisodeResponse, + CanContinueEpisodeRequest, + CanContinueEpisodeResponse, + EndEpisodeRequest, + EndEpisodeResponse, + EpisodeStatus, + EpisodeBufferResponse, +) -# --- Schema Definitions --- - -class SyncTrainConfigRequest(BaseModel): - yaml_as_string: str - -class ClaimEpisodeRequest(BaseModel): - client_uuid: str - episode_type: str - -class ClaimEpisodeResponse(BaseModel): - success: bool - client_uuid: str - episode_uuid: str - openai_base_url: str = "" - openai_api_key: str = "" - fail_cause: str = "" - -class EndEpisodeRequest(BaseModel): - client_uuid: str - episode_uuid: str - workflow_output: WorkflowOutput - -class EndEpisodeResponse(BaseModel): - success: bool class TinkerScriptClient(object): def __init__(self, server_url: str): self.server_url = server_url self.client_uuid = str(uuid.uuid4()) - self.episode_uuid = None - self.openai_base_url = None - self.openai_api_key = None - def begin_episode(self) -> OpenaiBaseUrlAndApiKey: + + def begin_episode(self, allow_discard_timeout=60) -> Tuple[str, OpenaiBaseUrlAndApiKey]: """ Block until an episode is claimed. Return (episode_uuid, openai_base_url, openai_api_key) @@ -51,7 +36,8 @@ def begin_episode(self) -> OpenaiBaseUrlAndApiKey: try: req_obj = ClaimEpisodeRequest( client_uuid=self.client_uuid, - episode_type="default" + episode_type="default", + allow_discard_timeout=allow_discard_timeout, ) resp = httpx.post( f"{self.server_url}/claim_episode", @@ -60,15 +46,17 @@ def begin_episode(self) -> OpenaiBaseUrlAndApiKey: ) resp.raise_for_status() data = ClaimEpisodeResponse.model_validate(resp.json()) + episode_uuid = data.episode_uuid if data.success: - self.episode_uuid = data.episode_uuid - self.openai_base_url = data.openai_base_url - self.openai_api_key = data.openai_api_key - logger.info(f"Claimed episode {self.episode_uuid}") - return OpenaiBaseUrlAndApiKey( - base_url=self.openai_base_url, - api_key=self.openai_api_key, + episode_uuid = data.episode_uuid + openai_base_url = data.openai_base_url + openai_api_key = data.openai_api_key + logger.info(f"Claimed episode {episode_uuid}") + return episode_uuid, OpenaiBaseUrlAndApiKey( + base_url=openai_base_url, + api_key=openai_api_key, + episode_uuid=episode_uuid ) else: logger.info(f"Failed to claim episode: {data.fail_cause}. Retrying in 5s...") @@ -77,15 +65,15 @@ def begin_episode(self) -> OpenaiBaseUrlAndApiKey: logger.error(f"Error claiming episode: {e}. Retrying in 5s...") time.sleep(5) - def end_episode(self, workflow_output: WorkflowOutput): - if not self.episode_uuid: + def end_episode(self, episode_uuid: str, workflow_output: WorkflowOutput): + if not episode_uuid: logger.error("No episode to end.") return try: req_obj = EndEpisodeRequest( client_uuid=self.client_uuid, - episode_uuid=self.episode_uuid, + episode_uuid=episode_uuid, workflow_output=workflow_output ) @@ -98,10 +86,9 @@ def end_episode(self, workflow_output: WorkflowOutput): data = EndEpisodeResponse.model_validate(resp.json()) if data.success: - logger.info(f"Ended episode {self.episode_uuid}") - self.episode_uuid = None + logger.info(f"Ended episode {episode_uuid}") else: - logger.error(f"Failed to end episode {self.episode_uuid}") + logger.error(f"Failed to end episode {episode_uuid}") except Exception as e: logger.error(f"Error ending episode: {e}") @@ -122,3 +109,50 @@ def sync_train_config(self, agent_jet_job: AgentJetJob): logger.info("Synced train config") except Exception as e: logger.error(f"Error syncing train config: {e}") + + def get_engine_status(self) -> str: + try: + resp = httpx.get( + f"{self.server_url}/get_engine_status", + timeout=10 + ) + resp.raise_for_status() + return resp.json().get("engine_status", "unknown") + except Exception as e: + logger.error(f"Error getting engine status: {e}") + return "unknown" + + def can_continue_episode(self, episode_uuid: str) -> bool: + if not episode_uuid: + return False + + try: + req_obj = CanContinueEpisodeRequest( + client_uuid=self.client_uuid, + episode_uuid=episode_uuid + ) + resp = httpx.post( + f"{self.server_url}/can_continue_episode", + json=req_obj.model_dump(), + timeout=10 + ) + resp.raise_for_status() + data = CanContinueEpisodeResponse.model_validate(resp.json()) + return data.can_continue + except Exception as e: + logger.error(f"Error checking can_continue_episode: {e}") + return False + + def get_episode_buffer(self) -> List[EpisodeStatus]: + try: + resp = httpx.post( + f"{self.server_url}/get_episode_buffer", + json={}, + timeout=10 + ) + resp.raise_for_status() + data = EpisodeBufferResponse.model_validate(resp.json()) + return data.buffer + except Exception as e: + logger.error(f"Error getting episode buffer: {e}") + return [] diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py index e793919a..3d419514 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py @@ -1,77 +1,38 @@ +import time from multiprocessing.managers import DictProxy import threading import zmq from loguru import logger -from pydantic import BaseModel from fastapi import FastAPI, HTTPException from typing import List from typing import Coroutine, Optional, Tuple -from ajet.schema.task import WorkflowOutput - - -class SyncTrainConfigRequest(BaseModel): - yaml_as_string: str - -class ClaimEpisodeRequest(BaseModel): - client_uuid: str - episode_type: str - -class ClaimEpisodeResponse(BaseModel): - success: bool - client_uuid: str - episode_uuid: str - openai_base_url: str = "" - openai_api_key: str = "" - fail_cause: str = "" - -class CanContinueEpisodeRequest(BaseModel): - client_uuid: str - episode_uuid: str - -class CanContinueEpisodeResponse(BaseModel): - can_continue: bool - -class EndEpisodeRequest(BaseModel): - client_uuid: str - episode_uuid: str - workflow_output: WorkflowOutput - -class EndEpisodeResponse(BaseModel): - success: bool - - -class EpisodeStatus(BaseModel): - episode_uuid: str - episode_status: str = "rolling" - openai_base_url: str = "" - openai_api_key: str = "" - client_uuid: str = "" - zmq_listen_result_addr: str = "" - -class EpisodeBufferResponse(BaseModel): - buffer: List[EpisodeStatus] - - -class BoolResponse(BaseModel): - success: bool - -class RegisterEpisodeRequest(BaseModel): - episode_uuid: str - openai_base_url: str = "" - openai_api_key: str = "" - zmq_listen_result_addr: str = "" - - -class UpdateEngineStatusRequest(BaseModel): - engine_status: str = "" +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import ( + SyncTrainConfigRequest, + ClaimEpisodeRequest, + ClaimEpisodeResponse, + CanContinueEpisodeRequest, + CanContinueEpisodeResponse, + EndEpisodeRequest, + EndEpisodeResponse, + EpisodeStatus, + EpisodeBufferResponse, + BoolResponse, + RegisterEpisodeRequest, + UpdateEngineStatusRequest, +) DEBUG = True -def register_enable_tinkerscript_mode_routes(app, zmq_context, shared_mem_dict:DictProxy, shared_mem_dict_lock:threading.Lock) -> Tuple[FastAPI, Optional[Coroutine]]: +def register_enable_tinkerscript_mode_routes( + app, + zmq_context, + shared_mem_dict:DictProxy, + shared_mem_dict_lock:threading.Lock, + ) -> Tuple[FastAPI, Optional[Coroutine]]: if 'episodes' not in shared_mem_dict: shared_mem_dict["episodes"] = {} @@ -114,7 +75,9 @@ async def register_episode(req: RegisterEpisodeRequest): openai_api_key=req.openai_api_key, episode_status="registered", zmq_listen_result_addr=req.zmq_listen_result_addr, + allow_discard_timeout=-1, ) + es.latest_activity_timestamp = time.time() with shared_mem_dict_lock: shared_mem_dict[f"episodes-{episode_uuid}"] = es @@ -127,7 +90,7 @@ async def register_episode(req: RegisterEpisodeRequest): @app.post("/claim_episode", response_model=ClaimEpisodeResponse) async def claim_episode(req: ClaimEpisodeRequest): - # placeholder implementation — real logic should check episode_semaphore + find_claimed_episodes_that_need_to_be_unclaimed() with shared_mem_dict_lock: if len(shared_mem_dict['unclaimed_episodes']) <= 0: @@ -148,11 +111,13 @@ async def claim_episode(req: ClaimEpisodeRequest): es:EpisodeStatus = shared_mem_dict[f"episodes-{episode_uuid}"] es.episode_status = "claimed" es.client_uuid = req.client_uuid + es.latest_activity_timestamp = time.time() + es.allow_discard_timeout = req.allow_discard_timeout + shared_mem_dict[f"episodes-{episode_uuid}"] = es openai_base_url = es.openai_base_url openai_api_key = es.openai_api_key - return ClaimEpisodeResponse( success=True, client_uuid=req.client_uuid, @@ -163,6 +128,40 @@ async def claim_episode(req: ClaimEpisodeRequest): ) + def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]: + result = [] + current_time = time.time() + + for k, v in shared_mem_dict.items(): + if k.startswith("episodes-"): + es:EpisodeStatus = v + if es.episode_status == "claimed": + if (current_time - es.latest_activity_timestamp) > es.allow_discard_timeout: + result.append(es.episode_uuid) + + for episode_uuid in result: + _revert_episode_to_unclaimed(episode_uuid) + + return result + + + def _revert_episode_to_unclaimed(episode_uuid: str): + with shared_mem_dict_lock: + # check status again, because other thread may have changed it + if shared_mem_dict[f"episodes-{episode_uuid}"].episode_status != "claimed": + return + + # revert + logger.info(f"Reverting episode {episode_uuid} to unclaimed due to client timeout.") + if f"episodes-{episode_uuid}" in shared_mem_dict: + es:EpisodeStatus = shared_mem_dict[f"episodes-{episode_uuid}"] + es.episode_status = "registered" + es.client_uuid = "" + es.latest_activity_timestamp = time.time() + es.allow_discard_timeout = -1 + shared_mem_dict[f"episodes-{episode_uuid}"] = es + shared_mem_dict['unclaimed_episodes'] += [episode_uuid] + @app.post("/end_episode", response_model=EndEpisodeResponse) async def end_episode(req: EndEpisodeRequest): @@ -172,8 +171,10 @@ async def end_episode(req: EndEpisodeRequest): workflow_output = req.workflow_output if 'episodes' not in shared_mem_dict: + logger.error(f"[server] No episodes registered yet.") raise HTTPException(status_code=400, detail=f"No episodes registered yet.") if (f"episodes-{episode_uuid}") not in shared_mem_dict: + logger.error(f"[server] Episode {episode_uuid} not found.") raise HTTPException(status_code=400, detail=f"Episode {episode_uuid} not found.") # send workflow_output to zmq diff --git a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py new file mode 100644 index 00000000..35728e31 --- /dev/null +++ b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py @@ -0,0 +1,136 @@ +import os +import time +import httpx +from typing import List +from pydantic import BaseModel +from ajet.schema.task import WorkflowOutput +from loguru import logger +from ajet.utils.networking import find_free_port + + +class SyncTrainConfigRequest(BaseModel): + yaml_as_string: str + +class ClaimEpisodeRequest(BaseModel): + client_uuid: str + episode_type: str + allow_discard_timeout: float + +class ClaimEpisodeResponse(BaseModel): + success: bool + client_uuid: str + episode_uuid: str + openai_base_url: str = "" + openai_api_key: str = "" + fail_cause: str = "" + +class CanContinueEpisodeRequest(BaseModel): + client_uuid: str + episode_uuid: str + +class CanContinueEpisodeResponse(BaseModel): + can_continue: bool + +class EndEpisodeRequest(BaseModel): + client_uuid: str + episode_uuid: str + workflow_output: WorkflowOutput + +class EndEpisodeResponse(BaseModel): + success: bool + + +class EpisodeStatus(BaseModel): + episode_uuid: str + episode_status: str = "rolling" + openai_base_url: str = "" + openai_api_key: str = "" + client_uuid: str = "" + zmq_listen_result_addr: str = "" + latest_activity_timestamp: float = time.time() + allow_discard_timeout: float + +class EpisodeBufferResponse(BaseModel): + buffer: List[EpisodeStatus] + + +class BoolResponse(BaseModel): + success: bool + +class RegisterEpisodeRequest(BaseModel): + episode_uuid: str + openai_base_url: str = "" + openai_api_key: str = "" + zmq_listen_result_addr: str = "" + + +class UpdateEngineStatusRequest(BaseModel): + engine_status: str = "" + + +def get_interchange_server_url(config): + port = os.getenv("AJET_DAT_INTERCHANGE_PORT") + if config.ajet.interchange_server.interchange_server_port != 'auto': + port = str(int(config.ajet.interchange_server.interchange_server_port)) + assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set" + master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") + base_url = f"http://{master_node_ip}:{port}" + return base_url + + +def http_change_engine_status(config: str, new_status: str): + resp = httpx.post( + f"{get_interchange_server_url(config)}/update_engine_status", + json={"engine_status": new_status}, + timeout=10 + ) + resp.raise_for_status() + logger.info(f"Changed engine status to {new_status}") + + + +def http_register_episode(config, episode_uuid: str, + openai_base_url: str, openai_api_key: str, + zmq_listen_result_addr: str): + + # parse episode_uuid, openai_base_url, openai_api_key + interchange_http_addr = get_interchange_server_url(config) + rer = RegisterEpisodeRequest( + episode_uuid=episode_uuid, + openai_base_url=openai_base_url, + openai_api_key=openai_api_key, + zmq_listen_result_addr=zmq_listen_result_addr, + ) + # send http request to tinkerscript server to register episode + while True: + try: + response = httpx.post( + f"{interchange_http_addr}/register_episode", + json=rer.model_dump(), # 或者 rer.model_dump() 如果使用 Pydantic v2 + timeout=30 + ) + response.raise_for_status() + result = response.json() + if not result.get('success'): + raise RuntimeError(f"Failed to register episode {episode_uuid}") + logger.info(f"Successfully registered episode {episode_uuid}") + break + except httpx.HTTPError as e: + logger.error(f"Error registering episode {episode_uuid}: {e}. Retrying...") + time.sleep(5) + + return rer + + +def get_zmq_socket(config, episode_uuid: str, tag: str = ""): + interchange_method = config.ajet.interchange_server.interchange_method + if interchange_method == 'tcp': + ipc_path = "" + master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") + zmq_contect_address = f"tcp://{master_node_ip}:{find_free_port()}" + elif interchange_method == 'ipc': + ipc_path = f"/tmp/ajet/{episode_uuid}-{tag}.sock" + zmq_contect_address = f"ipc://{ipc_path}" + else: + raise RuntimeError(f"Unknown interchange_method: {interchange_method}") + return zmq_contect_address, ipc_path diff --git a/ajet_tinkerscript.py b/ajet_tinkerscript.py index 7577364d..878d9610 100644 --- a/ajet_tinkerscript.py +++ b/ajet_tinkerscript.py @@ -73,7 +73,7 @@ def execute_agent(task, api_baseurl_key: OpenaiBaseUrlAndApiKey): # Use raw http requests (non-streaming) to get response response = requests.post( f"{base_url}/chat/completions", json = { "model": "fill_whatever_model", "messages": messages, }, headers = { "Authorization": f"Bearer {api_key}" } ) - print(response.json()) + # print(response.json()) final_answer = response.json()['choices'][0]['message']['content'] # Compute reward reference_answer = reference_answer.split("####")[-1].strip() diff --git a/ajet_tinkerscript_threading.py b/ajet_tinkerscript_threading.py new file mode 100644 index 00000000..c0d35ea8 --- /dev/null +++ b/ajet_tinkerscript_threading.py @@ -0,0 +1,94 @@ +import re +import requests +from textwrap import dedent +from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_client import TinkerScriptClient +from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet import WorkflowOutput +from ajet.task_reader import RouterTaskReader +from ajet.utils.retry import retry_with_backoff +from concurrent.futures import ThreadPoolExecutor + + +TINKERJET_URL = "http://localhost:10086" # Change to your tinkerscript remote url +NUM_EPOCH = 100 +GRPO_N = 4 # grpo group size +MAX_PARALLEL = 2 + +class WeightUpdatedHalfway(Exception): + """Raised when the remote side starts updating model weights halfway through an episode.""" + + +def main(): + + # Handshake with tinkerscript remote, then send training param to tinkerscript remote (such as model to be trained, algorithm, etc) + dataset = RouterTaskReader( + reader_type = "huggingface_dat_repo", + reader_config = AjetTaskReader( + huggingface_dat_repo = HuggingfaceDatRepo( + dataset_path = "/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main" + ) + ) + ) + tinkerscript_remote = TinkerScriptClient(TINKERJET_URL) + + # Define rollout + def rollout(task): + group_reward = [] + for i in range(GRPO_N): + # begin episode + episode_uuid, api_baseurl_key = tinkerscript_remote.begin_episode() + # execute agent + workflow_output = execute_agent(task, api_baseurl_key) + # report output back to tinkerscript remote + tinkerscript_remote.end_episode(episode_uuid, workflow_output) + # collect reward + group_reward.append(workflow_output.reward) + print(f"Group reward mean & std: {sum(group_reward)/len(group_reward)} +/- { (max(group_reward)-min(group_reward))/2 }") + + + # Main Training loop + with ThreadPoolExecutor(max_workers=MAX_PARALLEL) as executor: + for epoch in range(NUM_EPOCH): + for task in dataset.get_training_tasks(): + print(f"Submitting task for epoch {epoch}") + executor.submit(rollout, task) + + # Get tuned model from tinkerscript remote + return None + + + + +@retry_with_backoff(max_retry=2) +def execute_agent(task, api_baseurl_key: OpenaiBaseUrlAndApiKey): + # Prepare base_url, api_key + base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key) + # Read dataset item + query, reference_answer = (task.main_query, task.metadata["answer"]) + # Prepare messages + messages = [ + { "role": "system", "content": dedent("""You are an agent specialized in solving math problems. Please solve the math problem given to you. + You can write and execute Python code to perform calculation or verify your answer. You should return your final answer within \\boxed{{}}.""") }, + { "role": "user", "content": query } + ] + # Use raw http requests (non-streaming) to get response + response = requests.post( f"{base_url}/chat/completions", json = { "model": "fill_whatever_model", "messages": messages, }, + headers = { "Authorization": f"Bearer {api_key}" } ) + final_answer = response.json()['choices'][0]['message']['content'] + print(final_answer) + # Compute reward + reference_answer = reference_answer.split("####")[-1].strip() + pattern = r"\\boxed\{([^}]*)\}" + match = re.search(pattern, final_answer) + if match: is_success = match.group(1) == reference_answer + else: is_success = False + raw_reward = 1.0 if is_success else 0.0 + # Return + return WorkflowOutput(reward=raw_reward, metadata={"final_answer": final_answer}) + + + + +if __name__ == "__main__": + main() From 217848132b1e4c08dc03f03c3c037bf3b45e8bd5 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Tue, 20 Jan 2026 00:53:42 +0800 Subject: [PATCH 03/25] feat(tinkerscript): Add comprehensive design blueprint and workflow documentation --- ajet_tinkerscript.py | 90 ---------------- tinkerscript.md | 251 +++++++++++++++++++++++++++++++++++++++++++ tinkerscript_1.md | 120 +++++++++++++++++++++ 3 files changed, 371 insertions(+), 90 deletions(-) delete mode 100644 ajet_tinkerscript.py create mode 100644 tinkerscript.md create mode 100644 tinkerscript_1.md diff --git a/ajet_tinkerscript.py b/ajet_tinkerscript.py deleted file mode 100644 index 878d9610..00000000 --- a/ajet_tinkerscript.py +++ /dev/null @@ -1,90 +0,0 @@ -import re -import requests -from textwrap import dedent -from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_client import TinkerScriptClient -from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo -from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey -from ajet import WorkflowOutput -from ajet.task_reader import RouterTaskReader -from ajet.utils.retry import retry_with_backoff - - -TINKERJET_URL = "http://localhost:10086" # Change to your tinkerjet remote url -NUM_EPOCH = 100 -GRPO_N = 4 # grpo group size - - -class WeightUpdatedHalfway(Exception): - """Raised when the remote side starts updating model weights halfway through an episode.""" - - -def main(): - - # Handshake with tinkerjet remote, then send training param to tinkerjet remote (such as model to be trained, algorithm, etc) - tinkerjet_remote = TinkerScriptClient(TINKERJET_URL) - dataset = RouterTaskReader( - reader_type = "huggingface_dat_repo", - reader_config = AjetTaskReader( - huggingface_dat_repo = HuggingfaceDatRepo( - dataset_path = "/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main" - ) - ) - ) - - # Define rollout - def rollout(task): - # Q: Can I run episodes in parallel? - # A: Yes, wrap `rollout` in a thread or process pool. - api_baseurl_key = tinkerjet_remote.begin_episode() - workflow_output = execute_agent(task, api_baseurl_key) - tinkerjet_remote.end_episode(workflow_output) - return workflow_output.reward - - # Main Training loop - for epoch in range(NUM_EPOCH): - for task in dataset.get_training_tasks(): - try: - for i in range(GRPO_N): - reward = rollout(task) - print(f"{epoch}-{task}-run:{i}-{reward}") - except WeightUpdatedHalfway as e: - print(f"The remote side has gone into the LLM model weight update phrase halfway through an episode." - f"This is **normal**." - f"The remote no longer need this task anymore, so let's go to next task.") - - # Get tuned model from tinkerjet remote - return None - - - - -@retry_with_backoff(max_retry=2) -def execute_agent(task, api_baseurl_key: OpenaiBaseUrlAndApiKey): - # Prepare base_url, api_key - base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key) - # Read dataset item - query, reference_answer = (task.main_query, task.metadata["answer"]) - # Prepare messages - messages = [ - { "role": "system", "content": dedent("""You are an agent specialized in solving math problems. Please solve the math problem given to you. - You can write and execute Python code to perform calculation or verify your answer. You should return your final answer within \\boxed{{}}.""") }, - { "role": "user", "content": query } - ] - # Use raw http requests (non-streaming) to get response - response = requests.post( f"{base_url}/chat/completions", json = { "model": "fill_whatever_model", "messages": messages, }, - headers = { "Authorization": f"Bearer {api_key}" } ) - # print(response.json()) - final_answer = response.json()['choices'][0]['message']['content'] - # Compute reward - reference_answer = reference_answer.split("####")[-1].strip() - pattern = r"\\boxed\{([^}]*)\}" - match = re.search(pattern, final_answer) - if match: is_success = match.group(1) == reference_answer - else: is_success = False - raw_reward = 1.0 if is_success else 0.0 - # Return - return WorkflowOutput(reward=raw_reward, metadata={"final_answer": final_answer}) - - -if __name__ == "__main__": - main() diff --git a/tinkerscript.md b/tinkerscript.md new file mode 100644 index 00000000..03cae640 --- /dev/null +++ b/tinkerscript.md @@ -0,0 +1,251 @@ +# TinkerScript Design Blueprint / TinkerScript 设计蓝图 + +[English](#english-version) | [中文](#chinese-version) + +--- + + +## 🇬🇧 English Version + +### 1. Overview +**TinkerScript** is an experimental component of AgentJet designed to decouple the **Training Logic** from the **Agent Execution Logic**. It allows users to train **full-weight LLM models** on machines without GPUs (e.g., a laptop) by offloading the actual model computation to a remote GPU server. + +Unlike traditional setups where the user code must run inside the training cluster, TinkerScript allows you to verify and run your agent logic locally while the heavy lifting (training & inference) happens remotely. + +### 2. Core Architecture + +The system involves two main parties: the **TinkerScript Server** (running on the GPU cluster) and the **TinkerScript Client** (running on your local machine). + +```mermaid +graph TD + subgraph "GPU Cluster (Server Side)" + TrainingLoop[Training Loop (AgentJet/GRPO)] + TSS[TinkerScript Server (FastAPI)] + ZMQ[ZeroMQ / IPC] + SharedMem[(Shared Memory)] + LLM[LLM Engine (vLLM/SGLang)] + end + + subgraph "User Laptop / CPU Cluster (Client Side)" + UserScript[User Script (python while loop)] + AgentLogic[Agent Logic / Tools] + end + + TrainingLoop -- "1. Generate Task" --> SharedMem + SharedMem -- "2. Register Episode" --> TSS + + UserScript -- "3. Claim Episode (HTTP)" --> TSS + TSS -- "4. Return API Key & Base URL" --> UserScript + + UserScript -- "5. Inference (OpenAI API)" --> LLM + LLM -- "Token Stream" --> UserScript + + UserScript -- "6. Submit Reward (HTTP)" --> TSS + TSS -- "7. Push Result" --> ZMQ + ZMQ -- "8. Update Weights" --> TrainingLoop +``` + +### 3. Detailed Workflow + +The workflow relies on a "Claim & Submit" model. The training loop generates tasks ("Episodes") and waits for external workers to pick them up. + +```mermaid +sequenceDiagram + participant TL as Training Loop (Internal) + participant S as Server (FastAPI) + participant C as Client (User Script) + participant M as LLM Model + + Note over TL, S: 1. Task Generation + TL->>S: Register Episode (Status: Unclaimed) + + Note over C, S: 2. Task Acquisition + loop Worker Loop + C->>S: POST /claim_episode + alt No Tasks + S-->>C: Retry Later + else Task Available + S->>S: Mark as "Claimed" + S-->>C: Return {EpisodeID, OpenAI_BaseURL, API_Key} + end + + Note over C, M: 3. Execution (Rollout) + C->>M: Chat Completion Request (Inference) + M-->>C: Response (Generation) + C->>C: Calculate Reward (e.g., Verify Math Answer) + + Note over C, S: 4. Result Submission + C->>S: POST /end_episode {Reward, Metadata} + S->>TL: Forward Result via ZeroMQ + S->>S: Delete Episode Record (Complete) + end +``` + +### 4. Episode State Machine + +To handle network failures or client crashes, the server maintains a state machine for every episode. + +```mermaid +stateDiagram-v2 + [*] --> Registered + Registered --> Unclaimed_Queue : Add to Queue + + Unclaimed_Queue --> Claimed : Client requests task + + Claimed --> Completed : Client submits result + Claimed --> Registered : Client Timeout / Crash + + Completed --> [*] : Removed from Memory +``` + +* **Registered**: Task created by the training algorithm. +* **Claimed**: A client is currently working on it. +* **Timeout**: If a client claims a task but doesn't report back within `allow_discard_timeout`, the server reverts the status to **Registered** so another client can try. + +### 5. Implementation Example + +The user experience is designed to be minimal. You simply query the remote server for a "job", do the work, and report the "score". + +```python +# User-side Code Concept +def rollout(task): + # 1. Handshake & Claim (Get credentials for this specific episode) + api_baseurl_key = tinkerjet_remote.begin_episode() + + # 2. Run your existing agent logic using standard OpenAI format + workflow_output = execute_agent(task, api_baseurl_key) + + # 3. Submit results + tinkerjet_remote.end_episode(workflow_output) + return workflow_output.reward +``` + +### 6. Limitations + +1. **Strict OpenAI Protocol**: Users must use the OpenAI `base_url` + `api_key` pattern. Internal access (like direct model object access) is not available. +2. **Implicit Multi-Agent Handling**: AgentJet cannot explicitly distinguish different agents in a multi-agent scenario via API, though it attempts to merge timeline shards automatically. +3. **No Prompt Tuning**: TinkerScript is designed for full-weight model training, not for soft-prompt tuning. + +--- + + +## 🇨🇳 中文版本 (Chinese Version) + +### 1. 概述 (Overview) +**TinkerScript** 是 AgentJet 的一个实验性组件,旨在将 **训练逻辑 (Training Logic)** 与 **Agent 执行逻辑 (Execution Logic)** 解耦。它允许用户在 **没有 GPU** 的机器上(例如普通笔记本电脑)训练 **全参数 LLM 模型**,计算压力完全由远程 GPU 服务器承担。 + +与传统的将用户代码嵌入训练集群的方式不同,TinkerScript 允许你在本地运行并验证 Agent 逻辑,通过网络与远程训练循环交互。 + +### 2. 核心架构 (Core Architecture) + +系统包含两个主要部分:运行在 GPU 集群上的 **TinkerScript Server** 和运行在本地的 **TinkerScript Client**。 + +```mermaid +graph TD + subgraph "GPU 集群 (Server 端)" + TrainingLoop[训练循环 (AgentJet/GRPO)] + TSS[TinkerScript Server (FastAPI)] + ZMQ[ZeroMQ / IPC 通信] + SharedMem[(共享内存)] + LLM[LLM 推理引擎 (vLLM/SGLang)] + end + + subgraph "用户笔记本 / CPU 集群 (Client 端)" + UserScript[用户脚本 (Python While Loop)] + AgentLogic[Agent 业务逻辑 / 工具调用] + end + + TrainingLoop -- "1. 生成任务 (Task)" --> SharedMem + SharedMem -- "2. 注册 Episode" --> TSS + + UserScript -- "3. 领取任务 (HTTP Claim)" --> TSS + TSS -- "4. 返回 API Key 与 Base URL" --> UserScript + + UserScript -- "5. 推理请求 (OpenAI 协议)" --> LLM + LLM -- "生成 Token 流" --> UserScript + + UserScript -- "6. 提交 Reward (HTTP End)" --> TSS + TSS -- "7. 推送结果" --> ZMQ + ZMQ -- "8. 更新权重" --> TrainingLoop +``` + +### 3. 详细工作流 (Detailed Workflow) + +基于“领取 (Claim) - 提交 (Submit)”模式。训练循环生成任务(Episode),等待外部 Worker 领取执行。 + +```mermaid +sequenceDiagram + participant TL as 训练循环 (内部) + participant S as Server (FastAPI) + participant C as Client (用户脚本) + participant M as LLM 模型服务 + + Note over TL, S: 1. 任务生成阶段 + TL->>S: 注册 Episode (状态: Unclaimed) + + Note over C, S: 2. 任务领取阶段 + loop Worker Loop + C->>S: POST /claim_episode (请求任务) + alt 无可用任务 + S-->>C: 请稍后重试 + else 有可用任务 + S->>S: 标记为 "Claimed" + S-->>C: 返回 {EpisodeID, OpenAI_BaseURL, API_Key} + end + + Note over C, M: 3. 执行阶段 (Rollout) + C->>M: Chat Completion 请求 (推理通过网络回传) + M-->>C: 返回生成结果 + C->>C: 计算 Reward (例如: 验证数学答案) + + Note over C, S: 4. 结果提交阶段 + C->>S: POST /end_episode {Reward, Metadata} + S->>TL: 通过 ZeroMQ 转发结果给训练器 + S->>S: 删除 Episode 记录 (完成) + end +``` + +### 4. 状态机管理 (Episode State Machine) + +为了处理网络波动或客户端崩溃(Crash),服务端为每个 Episode 维护了一个状态机。 + +```mermaid +stateDiagram-v2 + [*] --> Registered (已注册) + Registered --> Unclaimed_Queue : 加入待领取队列 + + Unclaimed_Queue --> Claimed (已被领取) : 客户端请求任务 + + Claimed --> Completed (已完成) : 客户端提交结果 + Claimed --> Registered (已注册) : 客户端超时 / 崩溃 + + Completed --> [*] : 从内存中移除 +``` + +* **Registered (已注册)**: 训练算法生成了该任务,等待被执行。 +* **Claimed (已被领取)**: 某个 Client 正在处理该任务。 +* **Timeout (超时)**: 如果 Client 领取任务后在规定时间 (`allow_discard_timeout`) 内未提交结果,服务器会将状态重置为 **Registered**,允许其他 Client 重新领取该任务(容错机制)。 + +### 5. 实现代码示例 + +用户侧的代码非常简洁。简而言之:向远程服务器要一个“活儿”,干完活,上报“得分”。 + +```python +# 用户侧代码概念演示 +def rollout(task): + # 1. 握手 & 领取任务 (获取当前 Episode 专属的鉴权信息) + api_baseurl_key = tinkerjet_remote.begin_episode() + + # 2. 运行你现有的 Agent 逻辑 (使用标准 OpenAI 接口) + workflow_output = execute_agent(task, api_baseurl_key) + + # 3. 提交结果 + tinkerjet_remote.end_episode(workflow_output) + return workflow_output.reward +``` + +### 6. 局限性 (Limitations) + +1. **严格依赖 OpenAI 协议**: 用户必须使用 OpenAI `base_url` + `api_key` 的方式与模型交互。无法获取模型内部对象(Weights/Gradients)。 +2. **隐式多智能体处理**: 在多智能体(Multi-Agent)场景下,AgentJet 无法通过 API 显式区分不同的 Agent 角色,但后台会尝试自动合并时间线片段。 +3. **不支持 Prompt Tuning**: TinkerScript 专为全量模型微调设计,不支持 Soft-Prompt Tuning 等轻量级微调。 diff --git a/tinkerscript_1.md b/tinkerscript_1.md new file mode 100644 index 00000000..076087dd --- /dev/null +++ b/tinkerscript_1.md @@ -0,0 +1,120 @@ +# TinkerScript Design Blueprint + +TinkerScript represents a client-server architecture designed to decouple the **Training Loop** (Server-side) from the **Rollout Execution** (Client-side). This allows for distributed, flexible, and potentially remote execution of agent rollouts (inference + reward calculation) while centralizing the model training and weight updates. + +## 1. System Architecture + +The system consists of three main components: + +### A. TinkerScript Server (The Trainer) +* **Role**: Manages the training lifecycle, generates tasks (episodes), serves the model (LLM) API, and updates model weights. +* **Technology**: Python, FastAPI, ZeroMQ (IPC/TCP), Shared Memory (Multiprocessing). +* **Location**: Runs on the GPU cluster/Training node. +* **Key Functionality**: + * Maintains a queue of "Episodes" (training tasks). + * Exposes an HTTP API for external clients to claim tasks and submit results. + * Acts as a bridge between the HTTP world and the internal ZeroMQ-based training pipeline. + +### B. TinkerScript Client (The User Script) +* **Role**: Fetches tasks, runs the agent logic, computes rewards, and reports back. +* **Technology**: Python (Requests/HTTPX). +* **Location**: Can run locally, on a separate CPU cluster, or even a different cloud environment. +* **Key Functionality**: + * Connects to the Server URL. + * Claims episodes via `begin_episode()`. + * Executes the agent logic (e.g., calling the LLM, running Python code). + * Calculates rewards (e.g., verifying math answers). + * Submits results via `end_episode()`. + +### C. The LLM Serving Layer (Implicit) +* The system provides an OpenAI-compatible API endpoint (`base_url`, `api_key`) to the client for LLM inference. This is likely hosted by the training system itself or a proxy, enabling the client to query the model being trained. + +--- + +## 2. Detailed Workflow + +### Step 1: Episode Generation & Registration (Server Side) +The training loop (e.g., RL algorithm like GRPO) generates a new task. +1. An internal component registers a new episode via `register_episode`. +2. The server stores this in `shared_mem_dict` with status `registered`. +3. The episode is added to the `unclaimed_episodes` queue. +4. The server sets up a ZeroMQ socket to listen for the result of this specific episode. + +### Step 2: Task Claiming (Client Side) +The user's script calls `tinkerjet_remote.begin_episode()`. +1. **Request**: `POST /claim_episode` +2. **Server Logic**: + * Checks `unclaimed_episodes`. + * If available, pops one episode. + * Updates status to `claimed`. + * Records `client_uuid` and `latest_activity_timestamp`. +3. **Response**: Returns `episode_uuid` and **OpenAI Credentials** (Base URL + API Key) specific to this session/model. + +### Step 3: Rollout & Execution (Client Side) +The user's script (`execute_agent`) runs: +1. Uses the provided OpenAI API to chat with the model (performing the actual inference step of the RL loop). +2. Parses the model's output. +3. Computes a reward (e.g., checking if `\boxed{answer}` matches ground truth). + +### Step 4: Result Submission (Client Side) +The user's script calls `tinkerjet_remote.end_episode()`. +1. **Request**: `POST /end_episode` with `workflow_output` (Reward + Metadata). +2. **Server Logic**: + * Validates the episode exists and is claimed by this client. + * Connects to the internal ZeroMQ socket associated with this episode. + * Forwards the `workflow_output` payload into the ZeroMQ socket, effectively pushing it back into the training loop. + * Waits for an acknowledgment. + * Deletes the episode record from memory upon success. + +### Step 5: Failure Recovery & Timeouts +* **Crash Recovery**: If a client crashes after claiming a task, the server tracks `latest_activity_timestamp`. +* **Requisition**: A background check (`find_claimed_episodes_that_need_to_be_unclaimed`) reverts "stale" claimed episodes back to `registered` status so other clients can pick them up. +* **Weight Updates**: If the server moves to a weight update phase, it might verify if an episode is still valid via `can_continue_episode`. + +--- + +## 3. Data Structures & API Design + +### Episode Status Object +Stored in Server Shared Memory: +```python +class EpisodeStatus: + episode_uuid: str # Unique ID for the task + client_uuid: str # ID of the worker claiming it + episode_status: str # "registered", "claimed" + openai_base_url: str # Endpoint for the model + openai_api_key: str # Auth for the model + zmq_listen_result_addr: str # Internal address to forward results to + latest_activity_timestamp: float +``` + +### API Endpoints + +| Method | Endpoint | Description | +| :--- | :--- | :--- | +| `POST` | `/claim_episode` | Worker requests a job. Returns UUID + LLM credentials. | +| `POST` | `/end_episode` | Worker submits results (Reward). Completes the cycle. | +| `POST` | `/can_continue_episode` | Checks if the episode is still valid (e.g., weights haven't changed). | +| `POST` | `/register_episode` | (Internal/Debug) Adds a task to the queue. | +| `GET` | `/get_engine_status` | Returns system health/state (e.g., "booting", "ready"). | +| `POST` | `/sync_train_config` | Syncs configuration yaml (logging/debug). | + +--- + +## 4. Key Configurations + +From `ajet_tinkerscript_default.yaml`, we see how this mode is activated: + +```yaml +experiment_dir: "auto" +enable_tinkerscript_mode: True # Activates the HTTP API Server +interchange_server: + interchange_method: 'ipc' # Internal communication (ZeroMQ) + interchange_server_port: 10086 # HTTP API Port +``` + +## 5. Benefits of this Design + +1. **Flexibility**: Users can write custom python logic for "Rollout" without modifying the core C++/Python training engine. +2. **Distributed Generation**: You can have 1 training node and 1000 cheap CPU nodes just running the python script to generate data. +3. **Complex Logic Support**: Since the rollout is just a client script, it can call external tools, Sandboxed code interpreters, or APIs (Google Search) easily before calculating the reward. From 04446323ef5a88dea1215ee0357d4feb6469945b Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Tue, 20 Jan 2026 00:57:20 +0800 Subject: [PATCH 04/25] fix mermaid --- tinkerscript.md | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tinkerscript.md b/tinkerscript.md index 03cae640..0ae38d6d 100644 --- a/tinkerscript.md +++ b/tinkerscript.md @@ -19,16 +19,16 @@ The system involves two main parties: the **TinkerScript Server** (running on th ```mermaid graph TD subgraph "GPU Cluster (Server Side)" - TrainingLoop[Training Loop (AgentJet/GRPO)] - TSS[TinkerScript Server (FastAPI)] - ZMQ[ZeroMQ / IPC] - SharedMem[(Shared Memory)] - LLM[LLM Engine (vLLM/SGLang)] + TrainingLoop["Training Loop (AgentJet/GRPO)"] + TSS["TinkerScript Server (FastAPI)"] + ZMQ["ZeroMQ / IPC"] + SharedMem[("Shared Memory")] + LLM["LLM Engine (vLLM/SGLang)"] end subgraph "User Laptop / CPU Cluster (Client Side)" - UserScript[User Script (python while loop)] - AgentLogic[Agent Logic / Tools] + UserScript["User Script (python while loop)"] + AgentLogic["Agent Logic / Tools"] end TrainingLoop -- "1. Generate Task" --> SharedMem @@ -143,16 +143,16 @@ def rollout(task): ```mermaid graph TD subgraph "GPU 集群 (Server 端)" - TrainingLoop[训练循环 (AgentJet/GRPO)] - TSS[TinkerScript Server (FastAPI)] - ZMQ[ZeroMQ / IPC 通信] - SharedMem[(共享内存)] - LLM[LLM 推理引擎 (vLLM/SGLang)] + TrainingLoop["训练循环 (AgentJet/GRPO)"] + TSS["TinkerScript Server (FastAPI)"] + ZMQ["ZeroMQ / IPC 通信"] + SharedMem[("共享内存")] + LLM["LLM 推理引擎 (vLLM/SGLang)"] end subgraph "用户笔记本 / CPU 集群 (Client 端)" - UserScript[用户脚本 (Python While Loop)] - AgentLogic[Agent 业务逻辑 / 工具调用] + UserScript["用户脚本 (Python While Loop)"] + AgentLogic["Agent 业务逻辑 / 工具调用"] end TrainingLoop -- "1. 生成任务 (Task)" --> SharedMem From adbeadc74c789992dfbee1bc0eea1f7d9e8d8a77 Mon Sep 17 00:00:00 2001 From: binary-husky <96192199+binary-husky@users.noreply.github.com> Date: Tue, 20 Jan 2026 02:16:21 +0800 Subject: [PATCH 05/25] Remove limitations and Chinese version from documentation Removed limitations section and Chinese version from TinkerScript documentation. --- tinkerscript.md | 130 ------------------------------------------------ 1 file changed, 130 deletions(-) diff --git a/tinkerscript.md b/tinkerscript.md index 0ae38d6d..364d7c52 100644 --- a/tinkerscript.md +++ b/tinkerscript.md @@ -119,133 +119,3 @@ def rollout(task): tinkerjet_remote.end_episode(workflow_output) return workflow_output.reward ``` - -### 6. Limitations - -1. **Strict OpenAI Protocol**: Users must use the OpenAI `base_url` + `api_key` pattern. Internal access (like direct model object access) is not available. -2. **Implicit Multi-Agent Handling**: AgentJet cannot explicitly distinguish different agents in a multi-agent scenario via API, though it attempts to merge timeline shards automatically. -3. **No Prompt Tuning**: TinkerScript is designed for full-weight model training, not for soft-prompt tuning. - ---- - - -## 🇨🇳 中文版本 (Chinese Version) - -### 1. 概述 (Overview) -**TinkerScript** 是 AgentJet 的一个实验性组件,旨在将 **训练逻辑 (Training Logic)** 与 **Agent 执行逻辑 (Execution Logic)** 解耦。它允许用户在 **没有 GPU** 的机器上(例如普通笔记本电脑)训练 **全参数 LLM 模型**,计算压力完全由远程 GPU 服务器承担。 - -与传统的将用户代码嵌入训练集群的方式不同,TinkerScript 允许你在本地运行并验证 Agent 逻辑,通过网络与远程训练循环交互。 - -### 2. 核心架构 (Core Architecture) - -系统包含两个主要部分:运行在 GPU 集群上的 **TinkerScript Server** 和运行在本地的 **TinkerScript Client**。 - -```mermaid -graph TD - subgraph "GPU 集群 (Server 端)" - TrainingLoop["训练循环 (AgentJet/GRPO)"] - TSS["TinkerScript Server (FastAPI)"] - ZMQ["ZeroMQ / IPC 通信"] - SharedMem[("共享内存")] - LLM["LLM 推理引擎 (vLLM/SGLang)"] - end - - subgraph "用户笔记本 / CPU 集群 (Client 端)" - UserScript["用户脚本 (Python While Loop)"] - AgentLogic["Agent 业务逻辑 / 工具调用"] - end - - TrainingLoop -- "1. 生成任务 (Task)" --> SharedMem - SharedMem -- "2. 注册 Episode" --> TSS - - UserScript -- "3. 领取任务 (HTTP Claim)" --> TSS - TSS -- "4. 返回 API Key 与 Base URL" --> UserScript - - UserScript -- "5. 推理请求 (OpenAI 协议)" --> LLM - LLM -- "生成 Token 流" --> UserScript - - UserScript -- "6. 提交 Reward (HTTP End)" --> TSS - TSS -- "7. 推送结果" --> ZMQ - ZMQ -- "8. 更新权重" --> TrainingLoop -``` - -### 3. 详细工作流 (Detailed Workflow) - -基于“领取 (Claim) - 提交 (Submit)”模式。训练循环生成任务(Episode),等待外部 Worker 领取执行。 - -```mermaid -sequenceDiagram - participant TL as 训练循环 (内部) - participant S as Server (FastAPI) - participant C as Client (用户脚本) - participant M as LLM 模型服务 - - Note over TL, S: 1. 任务生成阶段 - TL->>S: 注册 Episode (状态: Unclaimed) - - Note over C, S: 2. 任务领取阶段 - loop Worker Loop - C->>S: POST /claim_episode (请求任务) - alt 无可用任务 - S-->>C: 请稍后重试 - else 有可用任务 - S->>S: 标记为 "Claimed" - S-->>C: 返回 {EpisodeID, OpenAI_BaseURL, API_Key} - end - - Note over C, M: 3. 执行阶段 (Rollout) - C->>M: Chat Completion 请求 (推理通过网络回传) - M-->>C: 返回生成结果 - C->>C: 计算 Reward (例如: 验证数学答案) - - Note over C, S: 4. 结果提交阶段 - C->>S: POST /end_episode {Reward, Metadata} - S->>TL: 通过 ZeroMQ 转发结果给训练器 - S->>S: 删除 Episode 记录 (完成) - end -``` - -### 4. 状态机管理 (Episode State Machine) - -为了处理网络波动或客户端崩溃(Crash),服务端为每个 Episode 维护了一个状态机。 - -```mermaid -stateDiagram-v2 - [*] --> Registered (已注册) - Registered --> Unclaimed_Queue : 加入待领取队列 - - Unclaimed_Queue --> Claimed (已被领取) : 客户端请求任务 - - Claimed --> Completed (已完成) : 客户端提交结果 - Claimed --> Registered (已注册) : 客户端超时 / 崩溃 - - Completed --> [*] : 从内存中移除 -``` - -* **Registered (已注册)**: 训练算法生成了该任务,等待被执行。 -* **Claimed (已被领取)**: 某个 Client 正在处理该任务。 -* **Timeout (超时)**: 如果 Client 领取任务后在规定时间 (`allow_discard_timeout`) 内未提交结果,服务器会将状态重置为 **Registered**,允许其他 Client 重新领取该任务(容错机制)。 - -### 5. 实现代码示例 - -用户侧的代码非常简洁。简而言之:向远程服务器要一个“活儿”,干完活,上报“得分”。 - -```python -# 用户侧代码概念演示 -def rollout(task): - # 1. 握手 & 领取任务 (获取当前 Episode 专属的鉴权信息) - api_baseurl_key = tinkerjet_remote.begin_episode() - - # 2. 运行你现有的 Agent 逻辑 (使用标准 OpenAI 接口) - workflow_output = execute_agent(task, api_baseurl_key) - - # 3. 提交结果 - tinkerjet_remote.end_episode(workflow_output) - return workflow_output.reward -``` - -### 6. 局限性 (Limitations) - -1. **严格依赖 OpenAI 协议**: 用户必须使用 OpenAI `base_url` + `api_key` 的方式与模型交互。无法获取模型内部对象(Weights/Gradients)。 -2. **隐式多智能体处理**: 在多智能体(Multi-Agent)场景下,AgentJet 无法通过 API 显式区分不同的 Agent 角色,但后台会尝试自动合并时间线片段。 -3. **不支持 Prompt Tuning**: TinkerScript 专为全量模型微调设计,不支持 Soft-Prompt Tuning 等轻量级微调。 From 4ebf055b49fcbed0db2fc3c6e5253283b343ce24 Mon Sep 17 00:00:00 2001 From: binary-husky <96192199+binary-husky@users.noreply.github.com> Date: Tue, 20 Jan 2026 02:25:53 +0800 Subject: [PATCH 06/25] Clarify relationship between TinkerScript and Tinker --- tinkerscript.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tinkerscript.md b/tinkerscript.md index 364d7c52..3eaaf9d0 100644 --- a/tinkerscript.md +++ b/tinkerscript.md @@ -12,6 +12,13 @@ Unlike traditional setups where the user code must run inside the training cluster, TinkerScript allows you to verify and run your agent logic locally while the heavy lifting (training & inference) happens remotely. + +> +> Relationship between **TinkerScript** and **Tinker**: +> +> **No relationship at all** (just like **JavaScript** and **Java**). **TinkerScript** is open-source and free. **Tinker** is close-source and not free. +> + ### 2. Core Architecture The system involves two main parties: the **TinkerScript Server** (running on the GPU cluster) and the **TinkerScript Client** (running on your local machine). From ba10f35b5ccd1976fb2925debb13c37cd39ffb9f Mon Sep 17 00:00:00 2001 From: binary-husky <96192199+binary-husky@users.noreply.github.com> Date: Tue, 20 Jan 2026 02:43:50 +0800 Subject: [PATCH 07/25] Update tinkerscript.md --- tinkerscript.md | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tinkerscript.md b/tinkerscript.md index 3eaaf9d0..cfd67114 100644 --- a/tinkerscript.md +++ b/tinkerscript.md @@ -17,7 +17,27 @@ Unlike traditional setups where the user code must run inside the training clust > Relationship between **TinkerScript** and **Tinker**: > > **No relationship at all** (just like **JavaScript** and **Java**). **TinkerScript** is open-source and free. **Tinker** is close-source and not free. -> + + +## Tinker 与 AgentJet-TinkerScript 对比表 + +| 特征 | Tinker | AgentJet-TinkerScript | +|------|--------|--------------| +| **开源性质** | ❌ 闭源 | **✅ 开源免费** | +| **收费模式** | 付费服务 | **✅ 完全免费** | +| **目标用户** | 研究人员和开发者 | 研究人员和开发者 | +| **任务** | 各种 LLM 训练 | 专精 LLM Agent RL训练 | +| **核心功能** | LLM 微调训练 API | **✅ LLM 微调训练整套解决方案** | +| **架构模式** | 托管服务 + 单点客户端 API | **✅ 服务器和客户端都可按需拓展** | +| **多客户端共同参与训练** | ❌ 不支持 | **✅ 支持** | +| **远程算力部署方式** | Thinking Machines Lab 公司提供定价 | **✅ 自建 GPU 服务器端 或 使用阿里云灵骏** | +| **训练方式** | ❌ LoRA 微调 | **✅ 全量 LLM 模型训练** | +| **支持的模型** | ❌ 少部分 LLM 模型 | **✅ 大多数新旧 LLM 模型** | +| **最大模型规模** | Llama 70B、Qwen 235B | **✅ 取决于用户 GPU 集群配置** | +| **通信协议** | 专有 API | **✅ 专有API + OpenAI兼容API** | +| **推理引擎后端** | 内置未知推理服务 | **✅ vLLM/SGLang任选** | + + ### 2. Core Architecture From c4b86a842a6ccfd9d1fa9b952ef9b02aa751bb1b Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Wed, 21 Jan 2026 10:58:16 +0800 Subject: [PATCH 08/25] remove trinity --- ajet/copilot/job.py | 2 +- .../experimental/as_tinkerscript_client.py | 2 +- ajet_tinkerscript_threading.py | 13 +++++++++++++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py index 373af631..082f6185 100644 --- a/ajet/copilot/job.py +++ b/ajet/copilot/job.py @@ -37,7 +37,7 @@ class AgentJetJob: def __init__( self, - backbone: str = "trinity", + backbone: str = "verl", model: str = "Qwen/Qwen2___5-7B-Instruct", n_gpu: int = 8, algorithm: str = "grpo", diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py index 581042c8..c4a67e2e 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py @@ -93,7 +93,7 @@ def end_episode(self, episode_uuid: str, workflow_output: WorkflowOutput): except Exception as e: logger.error(f"Error ending episode: {e}") - def sync_train_config(self, agent_jet_job: AgentJetJob): + def sync_config(self, agent_jet_job: AgentJetJob): try: config_dict = agent_jet_job.config.to_dict() yaml_str = yaml.safe_dump(config_dict, sort_keys=False) diff --git a/ajet_tinkerscript_threading.py b/ajet_tinkerscript_threading.py index c0d35ea8..8d981624 100644 --- a/ajet_tinkerscript_threading.py +++ b/ajet_tinkerscript_threading.py @@ -1,6 +1,7 @@ import re import requests from textwrap import dedent +from ajet.copilot.job import AgentJetJob from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_client import TinkerScriptClient from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey @@ -32,6 +33,15 @@ def main(): ) tinkerscript_remote = TinkerScriptClient(TINKERJET_URL) + tinkerscript_remote.sync_config( + AgentJetJob( + n_gpu=2, + algorithm="grpo", + model='qwen/Qwen2.5-1.5B-instruct' + ) + ) + tinkerscript_remote.begin_engine() + # Define rollout def rollout(task): group_reward = [] @@ -54,6 +64,9 @@ def rollout(task): print(f"Submitting task for epoch {epoch}") executor.submit(rollout, task) + + model_path = tinkerscript_remote.download_latest_model(path='./tinkerscript_saved_model') + # Get tuned model from tinkerscript remote return None From a04001ac942983344f1820053d23b6a8ae90cf12 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Wed, 21 Jan 2026 16:36:04 +0800 Subject: [PATCH 09/25] Add AgentJet image to TinkerScript documentation --- tinkerscript.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tinkerscript.md b/tinkerscript.md index cfd67114..325ca26a 100644 --- a/tinkerscript.md +++ b/tinkerscript.md @@ -146,3 +146,10 @@ def rollout(task): tinkerjet_remote.end_episode(workflow_output) return workflow_output.reward ``` + + + From ae13326aa6f10cb8492b3c5c2a015dbf9fec92fd Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Wed, 28 Jan 2026 16:51:04 +0800 Subject: [PATCH 10/25] feat: implement TinkerScript server functionality and enhance configuration syncing --- ajet/copilot/job.py | 11 +- ajet/default_config/ajet_default.yaml | 2 + ajet/default_config/ajet_ts_default.yaml | 323 ++++++++++++++++++ ajet/launcher.py | 156 ++++----- .../experimental/as_oai_model_server.py | 22 +- .../experimental/as_tinkerscript_client.py | 31 +- .../experimental/as_tinkerscript_server.py | 121 ++++++- .../experimental/interchange_utils.py | 1 + ajet/utils/launch_utils.py | 96 ++++++ ajet_tinkerscript_threading.py | 9 +- 10 files changed, 658 insertions(+), 114 deletions(-) create mode 100644 ajet/default_config/ajet_ts_default.yaml diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py index 082f6185..4e3be609 100644 --- a/ajet/copilot/job.py +++ b/ajet/copilot/job.py @@ -54,13 +54,14 @@ def __init__( self.config.ajet.trainer_common.algorithm.adv_estimator = algorithm if n_gpu_for_infer is None and backbone == "trinity": raise ValueError("Please specify `n_gpu_for_infer` (n_gpu_for_infer < n_gpu) for trinity backbone.") - if n_gpu_for_infer is not None and backbone == "verl": + if (n_gpu_for_infer is not None) and backbone == "verl": raise ValueError("n_gpu_for_infer is only for trinity backbone, please set it to `None`.") else: - assert isinstance(n_gpu_for_infer, int) - assert n_gpu_for_infer < n_gpu, "`n_gpu_for_infer` should be less than `n_gpu`." - self.config.ajet.rollout.n_vllm_engine = n_gpu_for_infer - self.config.ajet.rollout.tensor_model_parallel_size = 1 + if backbone == "trinity": + assert isinstance(n_gpu_for_infer, int), f"`n_gpu_for_infer` should be int, got {type(n_gpu_for_infer)}." + assert n_gpu_for_infer < n_gpu, "`n_gpu_for_infer` should be less than `n_gpu`." + self.config.ajet.rollout.n_vllm_engine = n_gpu_for_infer + self.config.ajet.rollout.tensor_model_parallel_size = 1 def build_job_from_yaml(self, yaml_path: str | None) -> dict: self.exp_name = datetime.now().strftime("ajet_job_%Y%m%d_%H%M%S") diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index f6ef2446..b5f3ac45 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -86,6 +86,7 @@ ajet: task_reader: + # how to read dataset / environment type: huggingface_dat_repo # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` # when `type == jsonl_dataset_file` @@ -284,6 +285,7 @@ ajet: enable_tinkerscript_mode: False # both tinkerscript / oai share the same interchange server enable_experimental_interchange_server: False + # interchange server configuration interchange_server: interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) interchange_server_port: 'auto' diff --git a/ajet/default_config/ajet_ts_default.yaml b/ajet/default_config/ajet_ts_default.yaml new file mode 100644 index 00000000..5e8837a1 --- /dev/null +++ b/ajet/default_config/ajet_ts_default.yaml @@ -0,0 +1,323 @@ +# ------------------ main configuration ------------------ +ajet: + project_name: "ajet_default_project" + experiment_name: "read_yaml_name" + experiment_dir: "auto" # {exp-dir}/{experiment_name} + backbone: debug # `debug` or `trinity` or `verl` + + + model: + # which model should be trained + path: /path/to/model/such/as/Qwen/Qwen2___5-14B-Instruct + + data: + # max number of tokens for prompt + max_prompt_length: 3000 + # max number of tokens for response + max_response_length: 15000 + # how many tasks per training batch + train_batch_size: 32 + # [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps) + + + rollout: + + # the path to the workflow class + user_workflow: tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow + + # whether or not to disable all tool calls + force_disable_toolcalls: False + + # maximum number of parallel environments / simulate workers + max_env_worker: 64 + + # step reward gamma (experimental, do not change) + gamma: 1.0 + + # monitor LLM's abormal behaviors during rollout + compute_madness_checklist: + - "nonsense" + # send signal to terminate context tracing when LLM is losing control + agent_madness_termination: True # terminate_after_gone_mad + # punish the LLM when it is detected as lost control + agent_madness_reward: -1.0 + + # max response length in one turn + max_response_length_in_one_turn: 4096 + + # max token length allowed for the model during rollout + max_model_len: 18000 + + multi_turn: + # how many samples should be collected for each task run + max_sample_per_task: 30 + # limit the maximum steps for each task + max_steps: 30 + # the expected steps for each task, used to calculate the training batch size for trinity + expected_steps: 1 + + # TP size for rollout engine + tensor_model_parallel_size: 1 + + # the number of vllm engines, number of gpus for infer is `n_vllm_engine*tensor_model_parallel_size`, this argument is NOT effective when NOT using trinity + n_vllm_engine: 1 + + # how many sequences are allowed to be processed in parallel by each vllm engine + max_num_seqs: 10 + + # the usage of infer engine, options: (vllm, sglang) + name: vllm + + # how many times a task should be repeated + num_repeat: 4 + + # rollout kwargs + temperature: 0.9 + top_p: 1.0 + + # validation kwargs + val_kwargs: + # when doing validation, the sample setting when generating response + temperature: 0.0 + top_k: -1 + top_p: 1.0 + do_sample: False + num_repeat: 1 + + + task_reader: + # how to read dataset / environment + type: huggingface_dat_repo # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` + + # when `type == jsonl_dataset_file` + jsonl_dataset_file: + training: + file_path: "/path/to/training/data.jsonl" + validation: + file_path: "/path/to/validation/data.jsonl" + + # when `type == env_service` + env_service: + env_type: "appworld" + env_url: "http://127.0.0.1:8080" + env_action_preference: code # code, text, box + training_split: train + validation_split: dev + + # when `type == huggingface_dat_repo` + huggingface_dat_repo: + dataset_path: "gsm8k" + training_split: "train" + validation_split: "validation" + + # when `type == data_generation` + data_generation: + document_reader: + document_path: + - 'dataset/document/your-document1.pdf' + - 'dataset/document/your-document2.pdf' + languages: + - eng + chunk_size: 5120 + split_by: "sentence" + cache_enabled: true + query_reader: + type: jsonl_dataset_file + jsonl_dataset_file: + training: + file_path: 'dataset/jsonl/your-queries.jsonl' + task_num: 10 + llm_model: qwen-long + llm_response_length: 8192 + num_workers: 32 + sampling_params: + temperature: 0 + deduplication_filter: + enabled: true + params: + similarity_threshold: 0.8 + db_path: ./.similarity_db + model: text-embedding-v4 + api_key: null # load from the env + base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 + + + task_judge: + judge_type: customized_protocol # Options: 'customized_protocol', 'rubrics_auto_grader' + + # when `judge_type == customized_protocol` + judge_protocol: ajet.task_judge.env_service_as_judge->EnvServiceJudge + + # the helper LLM model used for LLM-AS-Judge + alien_llm_model: qwen3-235b-a22b-instruct-2507 + alien_llm_response_length: 512 + + # when `judge_type == rubrics_auto_grader` + rubrics_auto_grader: + model_name: qwen-max + grader_mode: pointwise + language: en + query_specific_generate_number: 1 + enable_categorization: false + categories_number: 5 + grader_name: "auto_grader" + query_field: main_query + answer_field: final_answer + reference_field: answer + custom_evaluation_prompt: null # dict or PromptTemplate or None + input_data_type: jsonl_dataset_file # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` + jsonl_dataset_file: + training: + file_path: "tutorial/example_rm_auto_grader/rubrics_train.jsonl" + # Pointwise mode settings + min_score: 0 + max_score: 1 + + + + # context tracker protocol is valid ONLY when `use_agentscope_protocol=False` + context_tracker: + + # timeline merging policy used in Context Tracker + timeline_merging_policy: + + # compare_level = "text": relaxed compare with text, more easier to match, at very little cost + # compare_level = "token": strict compare with token, cause less aggressive merging + timeline_compare_level: "text" # options: "text", "token" + + # whether or not to ignore tool calls when comparing steps, default to `True` to make merging more aggressive + ignore_tools: True + + # Fix Retokenization Drift: inconsistencies between training and inference token array + # Related reading: https://github.com/vllm-project/vllm/pull/22587 (note that the implementation is very different) + fix_retokenization_drift: True + + # log tool format check results + log_tool_format_check: False + + # log tool format check results + log_tool_format_error_detail: False + + # detect at which point timeline stop growing linearly and cause a snap during a episode: this will cause additional computation. + detect_timeline_snap: False + + # deprecated + alien_llm_model: qwen3-235b-a22b-instruct-2507 + + # deprecated + alien_llm_response_length: 512 + + + # when backbone is `debug`, debug related configurations + debug: + + # max parallel runners in debug mode + debug_max_parallel: 4 + + # how many task to sample from training set + debug_first_n_tasks: 2 + + # what is the vllm engine port in the background + debug_vllm_port: 18000 + + # what is the seed of the vllm engine in the background + debug_vllm_seed: 12345 + + # what is the TP size in debug mode + debug_tensor_parallel_size: 4 + + + # trainer common configurations + trainer_common: + + # validation before training + val_before_train: False + val_pass_n: 4 + + # save and test frequency (in step) + save_freq: 20 + test_freq: 20 + + # total training epochs + total_epochs: 50 + + nnodes: 1 + n_gpus_per_node: 8 + + # logger selection + logger: swanlab + + # algorithm setting + algorithm: + adv_estimator: grpo + use_kl_in_reward: False + + # number of optimizer.step per big batch + mini_batch_num: 1 + + # verl offload configs + fsdp_config: + param_offload: True + optimizer_offload: True + + # learning rate + optim: + lr: 1e-6 + + # enable KL loss regularization + use_kl_loss: True + + # kl divergence loss coefficient + kl_loss_coef: 0.002 + kl_loss_type: low_var_kl + + # Ulysses specific configs + ulysses_sequence_parallel_size: 1 + + # base directory to save checkpoints + checkpoint_base_dir: ./saved_checkpoints + + # whether to save train/eval trajectories to JSON files + save_trajectory_as_json_file: False + + + # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature + enable_tinkerscript_mode: True + # both tinkerscript / oai share the same interchange server + enable_experimental_interchange_server: True + # interchange server configuration + interchange_server: + interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + interchange_server_port: 10086 + num_fastapi_process: 4 # 1, 2 or 4 is fine + max_fastapi_threads: 128 # 64 or 128 is fine + max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` + + + task_runner: + # submit llm infer submit method + llm_infer_submit_method: "async" # options: "sync", "async" + + # how to wrap the user-defined workflow + wrapper_type: "asyncio-with-gc" + # - wrapper_type: "asyncio-with-gc": safe, with periodic garbage collection to prevent event loop leaks (recommended) + # - wrapper_type: "asyncio": fast, but may cause event loop leak in long run + # - wrapper_type: "multi-processing": safe, but resource consuming + + # when `wrapper_type` is `multi-processing`, the timeout for each task + wrapper_multiprocessing_timeout: 3600 # in seconds + + # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. + execute_test: False # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. + execute_testing_lambda: "" # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. + + +# ------------------ do not edit ------------------ +hydra: + searchpath: + - file://ajet/default_config/verl + +# ------------------ do not edit ------------------ +defaults: + - verl_default # verl inherit 1/1 + - _self_ diff --git a/ajet/launcher.py b/ajet/launcher.py index 40557137..caf739ff 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -1,6 +1,7 @@ import argparse import os import subprocess +from types import SimpleNamespace from dotenv import load_dotenv from loguru import logger @@ -12,6 +13,9 @@ launch_logview, set_loguru_default_color, start_ray_service, + check_debugpy_version, + check_avail_gpu, + dict_to_namespace, ) from ajet.utils.pty import pty_launch @@ -28,6 +32,12 @@ def parse_args(): required=False, help="verl or trinity or debug", ) + parser.add_argument( + "--tinkerscript-server", + action="store_true", + default=False, + help="Enable TinkerScript server mode", + ) parser.add_argument( "--conf", type=str, @@ -50,9 +60,18 @@ def parse_args(): required=False, help="Path to configuration file", ) - - parser.add_argument("--with-ray", action="store_true", default=False, help="Launch ray") - parser.add_argument("--with-ray-cluster", action="store_true", default=False, help="Launch ray") + parser.add_argument( + "--with-ray", + action="store_true", + default=False, + help="Launch ray" + ) + parser.add_argument( + "--with-ray-cluster", + action="store_true", + default=False, + help="Launch ray" + ) parser.add_argument( "--with-appworld", action="store_true", @@ -71,7 +90,12 @@ def parse_args(): default=False, help="Launch webshop", ) - parser.add_argument("--with-bfcl", action="store_true", default=False, help="Launch bfcl") + parser.add_argument( + "--with-bfcl", + action="store_true", + default=False, + help="Launch bfcl" + ) parser.add_argument( "--with-logview", action="store_true", @@ -84,8 +108,12 @@ def parse_args(): default=False, help="Launch Crafters Env Simulation", ) - parser.add_argument("--reboot", action="store_true", default=False, help="reboot flag") - parser.add_argument("--skip-check-avail-gpu", action="store_true", default=False, help="Skip GPU availability check") + parser.add_argument( + "--skip-check-avail-gpu", + action="store_true", + default=False, + help="Skip GPU availability check" + ) parser.add_argument( "--kill", type=str, @@ -99,92 +127,14 @@ def parse_args(): default=False, help="Kill system processes (ray + vllm + python) that may block the current experiment", ) - parser.add_argument("--prefix", type=str, default="", required=False, help="Prefix for deepfinance service names") - return parser.parse_args() - - -def check_debugpy_version(): - try: - import debugpy - except ImportError: - raise RuntimeError( - "Module 'debugpy>=1.8.0' cannot be loaded. " - "Ray Debugpy Debugger will not work without 'debugpy>=1.8.0' installed. " - "Install this module using 'pip install debugpy>=1.8.0'" - ) - version = getattr(debugpy, "__version__", "0.0.0") - from packaging import version as packaging_version - - if packaging_version.parse(version) < packaging_version.parse("1.8.0"): - raise RuntimeError( - f"debugpy version {version} is too old. " - "Ray Debugpy Debugger requires 'debugpy>=1.8.0'. " - "Upgrade using 'pip install debugpy>=1.8.0'" - ) - logger.info(f"✓ debugpy version {version} meets requirement (>=1.8.0)") - - -def check_avail_gpu(min_free_ratio: float = 0.95): - """ - Ensure there is at least one GPU and all GPUs have >= min_free_ratio free memory. - - Uses `nvidia-smi` to query total and used memory for each GPU. - Raises RuntimeError if no GPU is found or any GPU violates the free ratio threshold. - """ - try: - # Query GPU memory via nvidia-smi; output in MiB - result = subprocess.run( - [ - "nvidia-smi", - "--query-gpu=name,memory.total,memory.used", - "--format=csv,noheader,nounits", - ], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - check=False, - ) - except FileNotFoundError: - raise RuntimeError("nvidia-smi not found. NVIDIA drivers/GPU may be unavailable.") - - if result.returncode != 0: - raise RuntimeError(f"Failed to query GPUs via nvidia-smi: {result.stderr.strip()}") - - lines = [line.strip() for line in result.stdout.splitlines() if line.strip()] - if not lines: - raise RuntimeError("No GPUs detected by nvidia-smi.") - - violations = [] - for idx, line in enumerate(lines): - # Expected format: ", , " - parts = [p.strip() for p in line.split(",")] - if len(parts) < 3: - violations.append((idx, "parse-error", line)) - continue - name, total_str, used_str = parts[0], parts[1], parts[2] - try: - total = float(total_str) - used = float(used_str) - except ValueError: - violations.append((idx, "parse-error", line)) - continue - free = max(total - used, 0.0) - free_ratio = free / total if total > 0 else 0.0 - logger.info( - f"GPU {idx} ({name}): total={total:.0f} MiB, used={used:.0f} MiB, free_ratio={free_ratio:.3f}" - ) - if free_ratio < min_free_ratio: - violations.append((idx, name, f"free_ratio={free_ratio:.3f} < {min_free_ratio:.3f}")) - - if violations: - details = "; ".join([f"GPU {i} ({n}): {msg}" for i, n, msg in violations]) - raise RuntimeError( - "GPU memory check failed: all GPUs must have >= " - f"{int(min_free_ratio*100)}% free. Violations: {details}" - ) - logger.info( - f"✓ GPU check passed: {len(lines)} GPUs, all >= {int(min_free_ratio*100)}% free memory" + parser.add_argument( + "--prefix", + type=str, + default="", + required=False, + help="Prefix for deepfinance service names" ) + return parser.parse_args() def get_backbone_target(backbone): @@ -246,9 +196,17 @@ def check_model_file_exists(exp_config): model_path = exp_config["ajet"]["model"]["path"] # if model_path has more than 2 '/', we consider it as a dir path if model_path.count("/") > 2: - assert os.path.exists( - model_path - ), f"Model path {model_path} does not exist. Please check your configuration." + assert os.path.exists(model_path), f"Model path {model_path} does not exist. Please check your configuration." + + +def start_tinkerscript_server(env, config): + config = dict_to_namespace(config) + assert config.ajet.enable_tinkerscript_mode, \ + "Please enable_tinkerscript_mode in config to start tinkerscript server." + assert config.ajet.enable_experimental_interchange_server, \ + "Please enable_experimental_interchange_server in config to start tinkerscript server." + from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server + start_interchange_server(config, blocking=True) def main(): @@ -283,8 +241,12 @@ def main(): # switch backbone target backbone_target = get_backbone_target(args.backbone) + # read configuration from yaml exp_config = None exp_dir = args.exp_dir or "saved_experiments" + if args.tinkerscript_server and (not args.conf): + args.conf = os.path.abspath(os.path.join(os.path.dirname(__file__), "default_config/ajet_ts_default.yaml")) + assert os.path.exists(args.conf), "Please provide a valid config file for tinkerscript server mode." if args.conf: yaml_path = args.conf ( @@ -294,7 +256,13 @@ def main(): exp_config, ) = prepare_experiment_config(yaml_path, exp_dir, args.backbone) + # setup environment variables env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp) + + if args.tinkerscript_server: + start_tinkerscript_server(env, exp_config) + return + if args.with_ray: assert ( not args.with_ray_cluster diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py index 98aa4192..72080769 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py @@ -34,7 +34,7 @@ from openai.types.chat.chat_completion import ChatCompletion from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import EpisodeStatus - +from ajet.utils.networking import find_free_port, get_host_ip API_KEY_PREFIX = "sk-ajet-" class InterchangeCompletionRequest(BaseModel): @@ -271,7 +271,7 @@ async def serve_with_monitor(additional_coro): # Convenience function for quick server startup -def start_interchange_server(config) -> int: +def start_interchange_server(config, blocking=False) -> int: experiment_dir = config.ajet.experiment_dir num_fastapi_process = config.ajet.interchange_server.num_fastapi_process max_fastapi_threads = config.ajet.interchange_server.max_fastapi_threads @@ -299,8 +299,12 @@ def start_interchange_server(config) -> int: interchange_server.start() # Wait for server to be ready - health_url = f"http://localhost:{port}/health" + health_url = f"http://127.0.0.1:{port}/health" start_time = time.time() + localhost_url = f"http://127.0.0.1:{port}" + host_ip = get_host_ip(os.environ.get("NETWORK_INTERFACE", None)) + host_url = f"http://{host_ip}:{port}" + while True: if interchange_server.exitcode is not None: logger.error(f"Interchange server subprocess failed to start. Return code: {interchange_server.exitcode}") @@ -321,5 +325,13 @@ def start_interchange_server(config) -> int: if DEBUG: logger.info(f"Interchange server subprocess started on port {port} (pid: {interchange_server.pid})") atexit.register(lambda: interchange_server.terminate()) - # return port - return port + if not blocking: + # return port + return port + else: + logger.success(f"Interchange server is running in blocking mode on:\n------\n" + f"URL 1: {localhost_url}\n------\n" + f"URL 2: {host_url}\n------\n" + f"Press Ctrl+C to stop.") + interchange_server.join() + return -1 \ No newline at end of file diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py index c4a67e2e..cb888671 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py @@ -93,7 +93,11 @@ def end_episode(self, episode_uuid: str, workflow_output: WorkflowOutput): except Exception as e: logger.error(f"Error ending episode: {e}") - def sync_config(self, agent_jet_job: AgentJetJob): + def sync_train_config(self, agent_jet_job: AgentJetJob): + """ + Sync training configuration to the TinkerScript server. + This sends the AgentJetJob config as YAML to the remote server. + """ try: config_dict = agent_jet_job.config.to_dict() yaml_str = yaml.safe_dump(config_dict, sort_keys=False) @@ -106,9 +110,32 @@ def sync_config(self, agent_jet_job: AgentJetJob): timeout=30 ) resp.raise_for_status() - logger.info("Synced train config") + logger.info("Synced train config to TinkerScript server") except Exception as e: logger.error(f"Error syncing train config: {e}") + raise + + def start_engine(self): + """ + Start the training engine on the TinkerScript server. + This triggers the server to begin the training process. + """ + try: + resp = httpx.post( + f"{self.server_url}/start_engine", + json={}, + timeout=30 + ) + resp.raise_for_status() + result = resp.json() + if result.get("success"): + logger.info("Successfully started training engine on TinkerScript server") + else: + logger.error("Failed to start training engine") + raise RuntimeError("Failed to start training engine") + except Exception as e: + logger.error(f"Error starting engine: {e}") + raise def get_engine_status(self) -> str: try: diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py index 3d419514..113178db 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py @@ -1,8 +1,11 @@ +import multiprocessing import time from multiprocessing.managers import DictProxy import threading +from types import SimpleNamespace import zmq +import asyncio from loguru import logger from fastapi import FastAPI, HTTPException @@ -24,7 +27,6 @@ UpdateEngineStatusRequest, ) - DEBUG = True def register_enable_tinkerscript_mode_routes( @@ -42,12 +44,119 @@ def register_enable_tinkerscript_mode_routes( @app.post("/sync_train_config") async def sync_train_config(req: SyncTrainConfigRequest): - # dummy: just print the yaml string + """ + Receive training configuration from client as YAML string. + Store it in shared memory for later use by start_engine. + """ + try: + yaml_str = req.yaml_as_string + logger.info("[sync_train_config] Received training configuration") + if DEBUG: + logger.debug(f"[sync_train_config] YAML content:\n{yaml_str}...") + + # Store the YAML config in shared memory for start_engine to use + with shared_mem_dict_lock: + shared_mem_dict['train_config_yaml'] = yaml_str + + logger.info("[sync_train_config] Successfully stored training configuration") + return {"success": True} + except Exception as e: + logger.error(f"[sync_train_config] Error: {e}") + return {"success": False, "error": str(e)} + + + @app.post("/start_engine") + async def start_engine(): + """ + Start the training engine using the previously synced configuration. + This creates a temporary YAML file and spawns a training process. + """ try: - print("[sync_train_config] received yaml:", req.yaml_as_string) - except Exception: - pass - return {"success": True} + from ajet.utils.launch_utils import execute_training_process + from ajet.launcher import ( + get_backbone_target, + setup_environment_vars, + ) + from ajet.utils.config_utils import ( + prepare_experiment_config, + ) + import ray + import tempfile + import yaml as yaml_module + + # Check if config has been synced + if 'train_config_yaml' not in shared_mem_dict: + logger.error("[start_engine] No training config found. Please call sync_train_config first.") + return {"success": False, "error": "No training config found"} + + yaml_str = shared_mem_dict['train_config_yaml'] + + # Parse YAML to get backbone + config_dict = yaml_module.safe_load(yaml_str) + backbone = config_dict.get('ajet', {}).get('backbone', 'verl') + exp_dir_final = config_dict.get('ajet', {}).get('experiment_dir', 'saved_experiments') + + # Save YAML to temporary file + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.yaml') as temp_file: + temp_file.write(yaml_str) + main_yaml_fp = temp_file.name + + logger.info(f"[start_engine] Saved config to temporary file: {main_yaml_fp}") + + # Create args namespace + args = SimpleNamespace( + conf=main_yaml_fp, + backbone=backbone, + exp_dir=exp_dir_final, + with_logview=False, + debug=False, + ) + + # Finalize experiment config + main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config( + main_yaml_fp, exp_dir_final, backbone + ) + + # Setup environment variables + env = setup_environment_vars(args, exp_config, main_yaml_fp) + + # Start ray if not already started + if not ray.is_initialized(): + from ajet.utils.launch_utils import start_ray_service + logger.info("[start_engine] Starting Ray service...") + start_ray_service(args, env) + else: + logger.info("[start_engine] Ray already initialized") + + # Start training process in a separate process + p = multiprocessing.Process( + target=execute_training_process, + args=( + args, + get_backbone_target(args.backbone), + main_yaml_fp, + exe_exp_base, + main_yaml_fp, + env, + exp_config, + ) + ) + p.daemon = True + p.start() + + # Store process info in shared memory + with shared_mem_dict_lock: + shared_mem_dict['training_process_pid'] = p.pid + shared_mem_dict['engine_status'] = "running" + + logger.info(f"[start_engine] Successfully started training process (PID: {p.pid})") + return {"success": True, "pid": p.pid} + + except Exception as e: + logger.error(f"[start_engine] Error starting engine: {e}") + import traceback + traceback.print_exc() + return {"success": False, "error": str(e)} # --- engine status --- diff --git a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py index 35728e31..ab6d816e 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py +++ b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py @@ -79,6 +79,7 @@ def get_interchange_server_url(config): def http_change_engine_status(config: str, new_status: str): + # STATES = ["ENGINE.BOOTING", "ENGINE.ROLLING", "ENGINE.WEIGHT_SYNCING", "ENGINE.WEIGHT_EXPORTING"] resp = httpx.post( f"{get_interchange_server_url(config)}/update_engine_status", json={"engine_status": new_status}, diff --git a/ajet/utils/launch_utils.py b/ajet/utils/launch_utils.py index 978539ab..0fa64e6f 100644 --- a/ajet/utils/launch_utils.py +++ b/ajet/utils/launch_utils.py @@ -7,6 +7,7 @@ from beast_logger import print_dict from loguru import logger +from types import SimpleNamespace from ajet.utils.config_utils import align_parameters from ajet.utils.smart_daemon import LaunchCommandWhenAbsent @@ -25,6 +26,101 @@ def set_loguru_default_color(): return + +def check_debugpy_version(): + try: + import debugpy + except ImportError: + raise RuntimeError( + "Module 'debugpy>=1.8.0' cannot be loaded. " + "Ray Debugpy Debugger will not work without 'debugpy>=1.8.0' installed. " + "Install this module using 'pip install debugpy>=1.8.0'" + ) + version = getattr(debugpy, "__version__", "0.0.0") + from packaging import version as packaging_version + + if packaging_version.parse(version) < packaging_version.parse("1.8.0"): + raise RuntimeError( + f"debugpy version {version} is too old. " + "Ray Debugpy Debugger requires 'debugpy>=1.8.0'. " + "Upgrade using 'pip install debugpy>=1.8.0'" + ) + logger.info(f"✓ debugpy version {version} meets requirement (>=1.8.0)") + + +def check_avail_gpu(min_free_ratio: float = 0.95): + """ + Ensure there is at least one GPU and all GPUs have >= min_free_ratio free memory. + + Uses `nvidia-smi` to query total and used memory for each GPU. + Raises RuntimeError if no GPU is found or any GPU violates the free ratio threshold. + """ + try: + # Query GPU memory via nvidia-smi; output in MiB + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=name,memory.total,memory.used", + "--format=csv,noheader,nounits", + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=False, + ) + except FileNotFoundError: + raise RuntimeError("nvidia-smi not found. NVIDIA drivers/GPU may be unavailable.") + + if result.returncode != 0: + raise RuntimeError(f"Failed to query GPUs via nvidia-smi: {result.stderr.strip()}") + + lines = [line.strip() for line in result.stdout.splitlines() if line.strip()] + if not lines: + raise RuntimeError("No GPUs detected by nvidia-smi.") + + violations = [] + for idx, line in enumerate(lines): + # Expected format: ", , " + parts = [p.strip() for p in line.split(",")] + if len(parts) < 3: + violations.append((idx, "parse-error", line)) + continue + name, total_str, used_str = parts[0], parts[1], parts[2] + try: + total = float(total_str) + used = float(used_str) + except ValueError: + violations.append((idx, "parse-error", line)) + continue + free = max(total - used, 0.0) + free_ratio = free / total if total > 0 else 0.0 + logger.info( + f"GPU {idx} ({name}): total={total:.0f} MiB, used={used:.0f} MiB, free_ratio={free_ratio:.3f}" + ) + if free_ratio < min_free_ratio: + violations.append((idx, name, f"free_ratio={free_ratio:.3f} < {min_free_ratio:.3f}")) + + if violations: + details = "; ".join([f"GPU {i} ({n}): {msg}" for i, n, msg in violations]) + raise RuntimeError( + "GPU memory check failed: all GPUs must have >= " + f"{int(min_free_ratio*100)}% free. Violations: {details}" + ) + logger.info( + f"✓ GPU check passed: {len(lines)} GPUs, all >= {int(min_free_ratio*100)}% free memory" + ) + + +def dict_to_namespace(d): + """Recursively convert a nested dictionary to a SimpleNamespace.""" + if isinstance(d, dict): + return SimpleNamespace(**{k: dict_to_namespace(v) for k, v in d.items()}) + elif isinstance(d, list): # 如果字典中嵌套了列表,递归处理列表中的每个元素 + return [dict_to_namespace(item) for item in d] + else: + return d + + def launch_logview(exp_name=None): """ Launch the log viewer service and open the web browser to view logs. diff --git a/ajet_tinkerscript_threading.py b/ajet_tinkerscript_threading.py index 8d981624..b02ddb9c 100644 --- a/ajet_tinkerscript_threading.py +++ b/ajet_tinkerscript_threading.py @@ -1,5 +1,6 @@ import re import requests +import time from textwrap import dedent from ajet.copilot.job import AgentJetJob from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_client import TinkerScriptClient @@ -33,14 +34,18 @@ def main(): ) tinkerscript_remote = TinkerScriptClient(TINKERJET_URL) - tinkerscript_remote.sync_config( + print("TinkerScript remote handshake.") + tinkerscript_remote.sync_train_config( AgentJetJob( n_gpu=2, algorithm="grpo", model='qwen/Qwen2.5-1.5B-instruct' ) ) - tinkerscript_remote.begin_engine() + print("TinkerScript remote handshake and train config sync done.") + tinkerscript_remote.start_engine() + print("TinkerScript remote engine started.") + time.sleep(1000) # Define rollout def rollout(task): From 5cc72976fc94069ff0cd5e0fbf2761bd409dbf55 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Wed, 28 Jan 2026 18:36:46 +0800 Subject: [PATCH 11/25] feat: enhance TinkerScript integration with improved engine status handling and configuration updates --- ajet/backbone/main_verl.py | 2 +- ajet/backbone/main_vllm.py | 2 +- ajet/backbone/trainer_verl.py | 4 +- ajet/copilot/job.py | 9 +- ajet/default_config/ajet_default.py | 1 + ajet/default_config/ajet_default.yaml | 1 + ajet/default_config/ajet_ts_default.yaml | 303 +----------------- .../experimental/as_oai_model_server.py | 44 ++- .../experimental/as_tinkerscript_client.py | 29 +- .../experimental/as_tinkerscript_server.py | 27 +- .../experimental/interchange_utils.py | 11 +- ajet/utils/core_env_vars.py | 5 +- ajet_tinkerscript_threading.py | 74 +++-- 13 files changed, 165 insertions(+), 347 deletions(-) diff --git a/ajet/backbone/main_verl.py b/ajet/backbone/main_verl.py index dcd575f4..47a48cfc 100644 --- a/ajet/backbone/main_verl.py +++ b/ajet/backbone/main_verl.py @@ -22,7 +22,6 @@ import hydra import ray from beast_logger import print_dict -from loguru import logger from omegaconf import OmegaConf from verl.trainer.ppo.reward import load_reward_manager from verl.utils.device import is_cuda_available @@ -110,6 +109,7 @@ def run(self, config): # Print the initial configuration. `resolve=True` will evaluate symbolic values. from pprint import pprint + from loguru import logger from omegaconf import OmegaConf from verl.utils.fs import copy_to_local diff --git a/ajet/backbone/main_vllm.py b/ajet/backbone/main_vllm.py index 7e63e216..c697bff4 100644 --- a/ajet/backbone/main_vllm.py +++ b/ajet/backbone/main_vllm.py @@ -191,7 +191,7 @@ def main(config): start_interchange_server(config) if config.ajet.enable_tinkerscript_mode: from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status - http_change_engine_status(config, "ROLLING") + http_change_engine_status(config, "ENGINE.ROLLING") def companion_launch(): import torch diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 9261fc91..e34581f6 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -559,7 +559,7 @@ def fit(self): # noqa: C901 assert self.async_rollout_mode logger.info("=== wake up begin ===") self.async_rollout_manager.wake_up() - self._update_interchange_server_status_flag("ROLLING") + self._update_interchange_server_status_flag("ENGINE.ROLLING") logger.info("=== wake up end ===") tasks: List[Task] = [ dict_to_ajet_task(dict( @@ -585,7 +585,7 @@ def fit(self): # noqa: C901 tasks, mode="sample", epoch=f"train.{epoch}" ) logger.info("=" * 10 + "end fit rollout" + "=" * 10) - self._update_interchange_server_status_flag("UPDATE_WEIGHT") + self._update_interchange_server_status_flag("ENGINE.WEIGHT_SYNCING") logger.info("begin to convert context_tracker_arr to dataproto") gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr) logger.info("end convertion") diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py index 4e3be609..190d1345 100644 --- a/ajet/copilot/job.py +++ b/ajet/copilot/job.py @@ -42,16 +42,23 @@ def __init__( n_gpu: int = 8, algorithm: str = "grpo", n_gpu_for_infer: int | None = None, # only for trinity backbone + grpo_n: int = 8, + tinkerscript_mode: bool = True, *kwargs, ) -> None: self.backbone = backbone - self.config_as_dict: dict = self.build_job_from_yaml(None) + if tinkerscript_mode: + default_yaml = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml")) + else: + default_yaml = None + self.config_as_dict: dict = self.build_job_from_yaml(default_yaml) self.config = Config.update_from_dict_recursive(Config(), self.config_as_dict) self.config.ajet.backbone = backbone self.config.ajet.model.path = model self.config.ajet.trainer_common.n_gpus_per_node = n_gpu self.config.ajet.trainer_common.algorithm.adv_estimator = algorithm + self.config.ajet.rollout.num_repeat = grpo_n if n_gpu_for_infer is None and backbone == "trinity": raise ValueError("Please specify `n_gpu_for_infer` (n_gpu_for_infer < n_gpu) for trinity backbone.") if (n_gpu_for_infer is not None) and backbone == "verl": diff --git a/ajet/default_config/ajet_default.py b/ajet/default_config/ajet_default.py index 9d0732e9..18ff2def 100644 --- a/ajet/default_config/ajet_default.py +++ b/ajet/default_config/ajet_default.py @@ -30,6 +30,7 @@ class AjetRollout: user_workflow: str = "tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow" n_vllm_engine: int = 1 tensor_model_parallel_size: int = 1 + num_repeat: int = 8 @dataclass diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index b5f3ac45..9def61da 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -292,6 +292,7 @@ ajet: num_fastapi_process: 2 # 1, 2 or 4 is fine max_fastapi_threads: 128 # 64 or 128 is fine max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` + already_started: False # do not edit, used by `tinkerscript` task_runner: diff --git a/ajet/default_config/ajet_ts_default.yaml b/ajet/default_config/ajet_ts_default.yaml index 5e8837a1..1b8cc3a1 100644 --- a/ajet/default_config/ajet_ts_default.yaml +++ b/ajet/default_config/ajet_ts_default.yaml @@ -5,319 +5,44 @@ ajet: experiment_dir: "auto" # {exp-dir}/{experiment_name} backbone: debug # `debug` or `trinity` or `verl` - model: # which model should be trained - path: /path/to/model/such/as/Qwen/Qwen2___5-14B-Instruct - - data: - # max number of tokens for prompt - max_prompt_length: 3000 - # max number of tokens for response - max_response_length: 15000 - # how many tasks per training batch - train_batch_size: 32 - # [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps) - + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct rollout: - # the path to the workflow class - user_workflow: tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow - - # whether or not to disable all tool calls - force_disable_toolcalls: False - - # maximum number of parallel environments / simulate workers - max_env_worker: 64 - - # step reward gamma (experimental, do not change) - gamma: 1.0 - - # monitor LLM's abormal behaviors during rollout - compute_madness_checklist: - - "nonsense" - # send signal to terminate context tracing when LLM is losing control - agent_madness_termination: True # terminate_after_gone_mad - # punish the LLM when it is detected as lost control - agent_madness_reward: -1.0 - - # max response length in one turn - max_response_length_in_one_turn: 4096 - - # max token length allowed for the model during rollout - max_model_len: 18000 - - multi_turn: - # how many samples should be collected for each task run - max_sample_per_task: 30 - # limit the maximum steps for each task - max_steps: 30 - # the expected steps for each task, used to calculate the training batch size for trinity - expected_steps: 1 - - # TP size for rollout engine - tensor_model_parallel_size: 1 - - # the number of vllm engines, number of gpus for infer is `n_vllm_engine*tensor_model_parallel_size`, this argument is NOT effective when NOT using trinity - n_vllm_engine: 1 - - # how many sequences are allowed to be processed in parallel by each vllm engine - max_num_seqs: 10 - - # the usage of infer engine, options: (vllm, sglang) - name: vllm - - # how many times a task should be repeated - num_repeat: 4 - - # rollout kwargs - temperature: 0.9 - top_p: 1.0 - - # validation kwargs - val_kwargs: - # when doing validation, the sample setting when generating response - temperature: 0.0 - top_k: -1 - top_p: 1.0 - do_sample: False - num_repeat: 1 - + user_workflow: null task_reader: - # how to read dataset / environment - type: huggingface_dat_repo # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` - - # when `type == jsonl_dataset_file` - jsonl_dataset_file: - training: - file_path: "/path/to/training/data.jsonl" - validation: - file_path: "/path/to/validation/data.jsonl" - - # when `type == env_service` - env_service: - env_type: "appworld" - env_url: "http://127.0.0.1:8080" - env_action_preference: code # code, text, box - training_split: train - validation_split: dev - - # when `type == huggingface_dat_repo` - huggingface_dat_repo: - dataset_path: "gsm8k" - training_split: "train" - validation_split: "validation" - - # when `type == data_generation` - data_generation: - document_reader: - document_path: - - 'dataset/document/your-document1.pdf' - - 'dataset/document/your-document2.pdf' - languages: - - eng - chunk_size: 5120 - split_by: "sentence" - cache_enabled: true - query_reader: - type: jsonl_dataset_file - jsonl_dataset_file: - training: - file_path: 'dataset/jsonl/your-queries.jsonl' - task_num: 10 - llm_model: qwen-long - llm_response_length: 8192 - num_workers: 32 - sampling_params: - temperature: 0 - deduplication_filter: - enabled: true - params: - similarity_threshold: 0.8 - db_path: ./.similarity_db - model: text-embedding-v4 - api_key: null # load from the env - base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 - + type: random_dummy # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` task_judge: judge_type: customized_protocol # Options: 'customized_protocol', 'rubrics_auto_grader' - - # when `judge_type == customized_protocol` - judge_protocol: ajet.task_judge.env_service_as_judge->EnvServiceJudge - - # the helper LLM model used for LLM-AS-Judge - alien_llm_model: qwen3-235b-a22b-instruct-2507 - alien_llm_response_length: 512 - - # when `judge_type == rubrics_auto_grader` - rubrics_auto_grader: - model_name: qwen-max - grader_mode: pointwise - language: en - query_specific_generate_number: 1 - enable_categorization: false - categories_number: 5 - grader_name: "auto_grader" - query_field: main_query - answer_field: final_answer - reference_field: answer - custom_evaluation_prompt: null # dict or PromptTemplate or None - input_data_type: jsonl_dataset_file # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` - jsonl_dataset_file: - training: - file_path: "tutorial/example_rm_auto_grader/rubrics_train.jsonl" - # Pointwise mode settings - min_score: 0 - max_score: 1 - - - - # context tracker protocol is valid ONLY when `use_agentscope_protocol=False` - context_tracker: - - # timeline merging policy used in Context Tracker - timeline_merging_policy: - - # compare_level = "text": relaxed compare with text, more easier to match, at very little cost - # compare_level = "token": strict compare with token, cause less aggressive merging - timeline_compare_level: "text" # options: "text", "token" - - # whether or not to ignore tool calls when comparing steps, default to `True` to make merging more aggressive - ignore_tools: True - - # Fix Retokenization Drift: inconsistencies between training and inference token array - # Related reading: https://github.com/vllm-project/vllm/pull/22587 (note that the implementation is very different) - fix_retokenization_drift: True - - # log tool format check results - log_tool_format_check: False - - # log tool format check results - log_tool_format_error_detail: False - - # detect at which point timeline stop growing linearly and cause a snap during a episode: this will cause additional computation. - detect_timeline_snap: False - - # deprecated - alien_llm_model: qwen3-235b-a22b-instruct-2507 - - # deprecated - alien_llm_response_length: 512 - - - # when backbone is `debug`, debug related configurations - debug: - - # max parallel runners in debug mode - debug_max_parallel: 4 - - # how many task to sample from training set - debug_first_n_tasks: 2 - - # what is the vllm engine port in the background - debug_vllm_port: 18000 - - # what is the seed of the vllm engine in the background - debug_vllm_seed: 12345 - - # what is the TP size in debug mode - debug_tensor_parallel_size: 4 - - - # trainer common configurations - trainer_common: - - # validation before training - val_before_train: False - val_pass_n: 4 - - # save and test frequency (in step) - save_freq: 20 - test_freq: 20 - - # total training epochs - total_epochs: 50 - - nnodes: 1 - n_gpus_per_node: 8 - - # logger selection - logger: swanlab - - # algorithm setting - algorithm: - adv_estimator: grpo - use_kl_in_reward: False - - # number of optimizer.step per big batch - mini_batch_num: 1 - - # verl offload configs - fsdp_config: - param_offload: True - optimizer_offload: True - - # learning rate - optim: - lr: 1e-6 - - # enable KL loss regularization - use_kl_loss: True - - # kl divergence loss coefficient - kl_loss_coef: 0.002 - kl_loss_type: low_var_kl - - # Ulysses specific configs - ulysses_sequence_parallel_size: 1 - - # base directory to save checkpoints - checkpoint_base_dir: ./saved_checkpoints - - # whether to save train/eval trajectories to JSON files - save_trajectory_as_json_file: False - + judge_protocol: null # reward must come from remote user agent workflow, so set to null # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature + enable_experimental_interchange_server: True + # train in cloud, run episode locally enable_tinkerscript_mode: True # both tinkerscript / oai share the same interchange server - enable_experimental_interchange_server: True - # interchange server configuration interchange_server: interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) interchange_server_port: 10086 - num_fastapi_process: 4 # 1, 2 or 4 is fine + num_fastapi_process: 2 # 1, 2 or 4 is fine max_fastapi_threads: 128 # 64 or 128 is fine max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` + already_started: False # do not edit, used by `tinkerscript` - task_runner: - # submit llm infer submit method - llm_infer_submit_method: "async" # options: "sync", "async" - - # how to wrap the user-defined workflow - wrapper_type: "asyncio-with-gc" - # - wrapper_type: "asyncio-with-gc": safe, with periodic garbage collection to prevent event loop leaks (recommended) - # - wrapper_type: "asyncio": fast, but may cause event loop leak in long run - # - wrapper_type: "multi-processing": safe, but resource consuming - - # when `wrapper_type` is `multi-processing`, the timeout for each task - wrapper_multiprocessing_timeout: 3600 # in seconds - - # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. - execute_test: False # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. - execute_testing_lambda: "" # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. - -# ------------------ do not edit ------------------ +# ------------------ 不需要修改 ------------------ hydra: searchpath: - - file://ajet/default_config/verl + - file://ajet/default_config + - file://ajet/default_config/verl # verl only -# ------------------ do not edit ------------------ +# ------------------ 不需要修改 ------------------ defaults: - - verl_default # verl inherit 1/1 + - verl_default # verl inherit 1/1 + - ajet_default - _self_ diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py index 72080769..455380bd 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py @@ -272,16 +272,18 @@ async def serve_with_monitor(additional_coro): # Convenience function for quick server startup def start_interchange_server(config, blocking=False) -> int: + # Read config + already_started = config.ajet.interchange_server.already_started experiment_dir = config.ajet.experiment_dir num_fastapi_process = config.ajet.interchange_server.num_fastapi_process max_fastapi_threads = config.ajet.interchange_server.max_fastapi_threads enable_tinkerscript_mode = config.ajet.enable_tinkerscript_mode + # Find a free port if not specified or invalid port = int(os.environ.get("AJET_DAT_INTERCHANGE_PORT", -1)) - if config.ajet.interchange_server.interchange_server_port != 'auto': port = int(config.ajet.interchange_server.interchange_server_port) - + os.environ["AJET_DAT_INTERCHANGE_PORT"] = str(port) if port <= 0: import socket with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -289,24 +291,30 @@ def start_interchange_server(config, blocking=False) -> int: port = s.getsockname()[1] os.environ["AJET_DAT_INTERCHANGE_PORT"] = str(port) - interchange_server = InterchangeServer( - experiment_dir, - port, - num_fastapi_process, - max_fastapi_threads, - enable_tinkerscript_mode, - ) - interchange_server.start() + # init interchage server sub-process + if not already_started: + interchange_server = InterchangeServer( + experiment_dir, + port, + num_fastapi_process, + max_fastapi_threads, + enable_tinkerscript_mode, + ) + interchange_server.start() + else: + interchange_server = None # Wait for server to be ready health_url = f"http://127.0.0.1:{port}/health" - start_time = time.time() localhost_url = f"http://127.0.0.1:{port}" - host_ip = get_host_ip(os.environ.get("NETWORK_INTERFACE", None)) - host_url = f"http://{host_ip}:{port}" + master_node_ip = get_host_ip(os.environ.get("NETWORK_INTERFACE", None)) + host_url = f"http://{master_node_ip}:{port}" + os.environ["MASTER_NODE_IP"] = str(master_node_ip) + # polling for server ready + start_time = time.time() while True: - if interchange_server.exitcode is not None: + if interchange_server and interchange_server.exitcode is not None: logger.error(f"Interchange server subprocess failed to start. Return code: {interchange_server.exitcode}") raise RuntimeError("Interchange server subprocess failed to start.") if time.time() - start_time > 30: @@ -322,8 +330,9 @@ def start_interchange_server(config, blocking=False) -> int: time.sleep(1) # register a termination handler - if DEBUG: logger.info(f"Interchange server subprocess started on port {port} (pid: {interchange_server.pid})") - atexit.register(lambda: interchange_server.terminate()) + if interchange_server: + if DEBUG: logger.info(f"Interchange server subprocess started on port {port} (pid: {interchange_server.pid})") + atexit.register(lambda: interchange_server.terminate()) if not blocking: # return port @@ -333,5 +342,6 @@ def start_interchange_server(config, blocking=False) -> int: f"URL 1: {localhost_url}\n------\n" f"URL 2: {host_url}\n------\n" f"Press Ctrl+C to stop.") - interchange_server.join() + if interchange_server: + interchange_server.join() return -1 \ No newline at end of file diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py index cb888671..43b977ad 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py @@ -119,12 +119,13 @@ def start_engine(self): """ Start the training engine on the TinkerScript server. This triggers the server to begin the training process. + Polls until engine status is "ENGINE.ROLLING". """ try: resp = httpx.post( f"{self.server_url}/start_engine", json={}, - timeout=30 + timeout=600 ) resp.raise_for_status() result = resp.json() @@ -137,6 +138,32 @@ def start_engine(self): logger.error(f"Error starting engine: {e}") raise + # Poll until engine status is "ENGINE.ROLLING" + logger.info("Polling engine status until ENGINE.ROLLING...") + last_report_time = time.time() + + while True: + try: + current_status = self.get_engine_status() + current_time = time.time() + + # Report status every 5 seconds + if current_time - last_report_time >= 5: + logger.info(f"Current engine status: {current_status}") + last_report_time = current_time + + # Check if engine has reached the desired status + if current_status == "ENGINE.ROLLING": + logger.info("Engine status is ENGINE.ROLLING - engine is ready") + break + + # Wait a bit before next poll + time.sleep(1) + + except Exception as e: + logger.error(f"Error polling engine status: {e}") + time.sleep(5) + def get_engine_status(self) -> str: try: resp = httpx.get( diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py index 113178db..0214cd68 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py @@ -1,11 +1,12 @@ import multiprocessing import time -from multiprocessing.managers import DictProxy +import zmq +import os +import asyncio import threading +from multiprocessing.managers import DictProxy from types import SimpleNamespace -import zmq -import asyncio from loguru import logger from fastapi import FastAPI, HTTPException @@ -89,9 +90,9 @@ async def start_engine(): logger.error("[start_engine] No training config found. Please call sync_train_config first.") return {"success": False, "error": "No training config found"} - yaml_str = shared_mem_dict['train_config_yaml'] # Parse YAML to get backbone + yaml_str = shared_mem_dict['train_config_yaml'] config_dict = yaml_module.safe_load(yaml_str) backbone = config_dict.get('ajet', {}).get('backbone', 'verl') exp_dir_final = config_dict.get('ajet', {}).get('experiment_dir', 'saved_experiments') @@ -100,7 +101,6 @@ async def start_engine(): with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.yaml') as temp_file: temp_file.write(yaml_str) main_yaml_fp = temp_file.name - logger.info(f"[start_engine] Saved config to temporary file: {main_yaml_fp}") # Create args namespace @@ -117,8 +117,11 @@ async def start_engine(): main_yaml_fp, exp_dir_final, backbone ) + # Setup environment variables - env = setup_environment_vars(args, exp_config, main_yaml_fp) + exp_config['ajet']['interchange_server']['already_started'] = True + exp_config['ajet']['interchange_server']['interchange_server_port'] = int(os.getenv("AJET_DAT_INTERCHANGE_PORT")) + env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp) # Start ray if not already started if not ray.is_initialized(): @@ -147,7 +150,7 @@ async def start_engine(): # Store process info in shared memory with shared_mem_dict_lock: shared_mem_dict['training_process_pid'] = p.pid - shared_mem_dict['engine_status'] = "running" + shared_mem_dict['engine_status'] = "ENGINE.BOOTING" logger.info(f"[start_engine] Successfully started training process (PID: {p.pid})") return {"success": True, "pid": p.pid} @@ -160,9 +163,17 @@ async def start_engine(): # --- engine status --- - shared_mem_dict['engine_status'] = "booting" + shared_mem_dict['engine_status'] = "ENGINE.OFF" @app.post("/update_engine_status", response_model=BoolResponse) async def update_engine_status(req: UpdateEngineStatusRequest): + if req.engine_status not in [ + "ENGINE.OFF", + "ENGINE.BOOTING", + "ENGINE.ROLLING", + "ENGINE.WEIGHT_SYNCING", + "ENGINE.WEIGHT_EXPORTING" + ]: + return BoolResponse(success=False, failure_reason="Invalid engine status") shared_mem_dict['engine_status'] = req.engine_status return BoolResponse(success=True) diff --git a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py index ab6d816e..05a80736 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py +++ b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py @@ -56,6 +56,7 @@ class EpisodeBufferResponse(BaseModel): class BoolResponse(BaseModel): success: bool + failure_reason: str = "" class RegisterEpisodeRequest(BaseModel): episode_uuid: str @@ -79,7 +80,15 @@ def get_interchange_server_url(config): def http_change_engine_status(config: str, new_status: str): - # STATES = ["ENGINE.BOOTING", "ENGINE.ROLLING", "ENGINE.WEIGHT_SYNCING", "ENGINE.WEIGHT_EXPORTING"] + if new_status not in [ + "ENGINE.OFF", + "ENGINE.BOOTING", + "ENGINE.ROLLING", + "ENGINE.WEIGHT_SYNCING", + "ENGINE.WEIGHT_EXPORTING" + ]: + raise ValueError(f"Invalid engine status: {new_status}") + resp = httpx.post( f"{get_interchange_server_url(config)}/update_engine_status", json={"engine_status": new_status}, diff --git a/ajet/utils/core_env_vars.py b/ajet/utils/core_env_vars.py index ee1dbd82..7078bf2f 100644 --- a/ajet/utils/core_env_vars.py +++ b/ajet/utils/core_env_vars.py @@ -34,9 +34,8 @@ def get_runtime_env(config, is_trinity: bool = False) -> dict: # "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true", "SWANLAB_API_KEY": os.getenv("SWANLAB_API_KEY", ""), "AJET_CONFIG_REDIRECT": os.getenv("AJET_CONFIG_REDIRECT", ""), - "AJET_DAT_INTERCHANGE_PORT": data_interchange_port, - "AJET_DAT_INTERCHANGE_ZMQ_PORT": str(find_free_port()), - "MASTER_NODE_IP": master_node_ip, + "AJET_DAT_INTERCHANGE_PORT": os.getenv("AJET_DAT_INTERCHANGE_PORT", data_interchange_port), + "MASTER_NODE_IP": os.getenv("MASTER_NODE_IP", master_node_ip), } } diff --git a/ajet_tinkerscript_threading.py b/ajet_tinkerscript_threading.py index b02ddb9c..66ea81f8 100644 --- a/ajet_tinkerscript_threading.py +++ b/ajet_tinkerscript_threading.py @@ -11,16 +11,55 @@ from ajet.utils.retry import retry_with_backoff from concurrent.futures import ThreadPoolExecutor +# --------- configurations that take effect locally ------------- +LOCAL_GRPO_N = 4 # grpo group size +LOCAL_NUM_EPOCH = 10000 +LOCAL_MAX_PARALLEL = 2 +LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main" +REMOTE_TINKERJET_URL = "http://localhost:10086" # Change to your tinkerscript remote url + +# --------- configurations that take effect remotely ------------- +REMOTE_ALLOCATE_GPU_PER_NODE = 4 +REMOTE_TRAIN_MODEL_01 = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct' + -TINKERJET_URL = "http://localhost:10086" # Change to your tinkerscript remote url -NUM_EPOCH = 100 -GRPO_N = 4 # grpo group size -MAX_PARALLEL = 2 class WeightUpdatedHalfway(Exception): """Raised when the remote side starts updating model weights halfway through an episode.""" +def connect_to_tinkerscript_server( + create_server_via_ssh: bool = False, + create_server_locally: bool = False, + sync_train_config: bool = True, + start_engine: bool = True, +): + if create_server_via_ssh: + raise NotImplementedError("Creating tinkerscript server via SSH is not implemented yet.") + + if create_server_locally: + raise NotImplementedError("Creating tinkerscript server is not implemented yet, please run `ajet launch --tinkerscript-server` to start manually.") + + tinkerscript_remote = TinkerScriptClient(REMOTE_TINKERJET_URL) + + if sync_train_config: + tinkerscript_remote.sync_train_config( + AgentJetJob( + algorithm="grpo", + n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, + model=REMOTE_TRAIN_MODEL_01, + grpo_n=LOCAL_GRPO_N, + ) + ) + print("TinkerScript remote handshake and train config sync done.") + + if start_engine: + tinkerscript_remote.start_engine() + print("TinkerScript remote engine started.") + + return tinkerscript_remote + + def main(): # Handshake with tinkerscript remote, then send training param to tinkerscript remote (such as model to be trained, algorithm, etc) @@ -28,29 +67,18 @@ def main(): reader_type = "huggingface_dat_repo", reader_config = AjetTaskReader( huggingface_dat_repo = HuggingfaceDatRepo( - dataset_path = "/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main" + dataset_path = LOCAL_DATASET_PATH ) ) ) - tinkerscript_remote = TinkerScriptClient(TINKERJET_URL) - - print("TinkerScript remote handshake.") - tinkerscript_remote.sync_train_config( - AgentJetJob( - n_gpu=2, - algorithm="grpo", - model='qwen/Qwen2.5-1.5B-instruct' - ) - ) - print("TinkerScript remote handshake and train config sync done.") - tinkerscript_remote.start_engine() - print("TinkerScript remote engine started.") - time.sleep(1000) + + # Hand shake with remote tinkerscript server + tinkerscript_remote = connect_to_tinkerscript_server(create_server_locally=True, sync_train_config=True, start_engine=True) # Define rollout def rollout(task): group_reward = [] - for i in range(GRPO_N): + for i in range(LOCAL_GRPO_N): # begin episode episode_uuid, api_baseurl_key = tinkerscript_remote.begin_episode() # execute agent @@ -63,14 +91,14 @@ def rollout(task): # Main Training loop - with ThreadPoolExecutor(max_workers=MAX_PARALLEL) as executor: - for epoch in range(NUM_EPOCH): + with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as executor: + for epoch in range(LOCAL_NUM_EPOCH): for task in dataset.get_training_tasks(): print(f"Submitting task for epoch {epoch}") executor.submit(rollout, task) - model_path = tinkerscript_remote.download_latest_model(path='./tinkerscript_saved_model') + # model_path = tinkerscript_remote.download_latest_model(path='./tinkerscript_saved_model') # Get tuned model from tinkerscript remote return None From 968c2cf6487cd82daeaccd5c107816ebdd6e209a Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Thu, 29 Jan 2026 01:04:36 +0800 Subject: [PATCH 12/25] feat: enhance TinkerScript functionality with improved engine status handling and episode management --- ajet/backbone/trainer_verl.py | 1 - ajet/task_runner/tinkerscript_runner.py | 11 +- .../experimental/as_oai_model_server.py | 6 +- .../experimental/as_tinkerscript_client.py | 56 ++++++- .../experimental/as_tinkerscript_server.py | 156 +++++++++++++----- .../experimental/interchange_utils.py | 6 +- ajet_tinkerscript_threading.py | 79 ++++----- 7 files changed, 212 insertions(+), 103 deletions(-) diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index e34581f6..0fc224d5 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -444,7 +444,6 @@ def init_workers(self): ) def _update_interchange_server_status_flag(self, status: str): - # if interchange server is enabled, change engine status to ROLLING if self.config.ajet.enable_experimental_interchange_server: if self.config.ajet.enable_tinkerscript_mode: from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status diff --git a/ajet/task_runner/tinkerscript_runner.py b/ajet/task_runner/tinkerscript_runner.py index b9303b29..e06f7262 100644 --- a/ajet/task_runner/tinkerscript_runner.py +++ b/ajet/task_runner/tinkerscript_runner.py @@ -16,9 +16,10 @@ from loguru import logger from ajet import Workflow +DEBUG = False + context = zmq.Context() atexit.register(context.term) -DEBUG = True class TinkerScriptRunner(BaseAgentRunner): @@ -33,12 +34,18 @@ def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: s openai_api_key=openai_api_key, zmq_listen_result_addr=zmq_listen_result_addr, ) - logger.info(f"zmq_listen_result_addr: {zmq_listen_result_addr}") + if DEBUG: logger.info(f"zmq_listen_result_addr: {zmq_listen_result_addr}") # begin wait for result zmq_socket = zmq.Context().socket(zmq.REP) zmq_socket.bind(zmq_listen_result_addr) + + # : + # : ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py + # : socket.send_string(workflow_output.model_dump_json()) + # : workflow_output: WorkflowOutput message = zmq_socket.recv_string() + logger.success(f"Received workflow output for episode {episode_uuid}") zmq_socket.send_string("ack") zmq_socket.close() diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py index 455380bd..072632e2 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py @@ -157,9 +157,9 @@ async def chat_completions(request: Request, authorization: str = Header(None)): if enable_tinkerscript_mode: assert shared_mem_dict is not None assert shared_mem_dict_lock is not None - if shared_mem_dict['engine_status'] != "ROLLING": - logger.error(f"The server is not in ROLLING status (current status: [{shared_mem_dict['engine_status']}]), cannot accept new requests.") - raise HTTPException(status_code=503, detail="The server is not in ROLLING status, cannot accept new requests.") + if shared_mem_dict['engine_status'] != "ENGINE.ROLLING": + logger.error(f"The server is not in ENGINE.ROLLING status (current status: [{shared_mem_dict['engine_status']}]), cannot accept new requests.") + raise HTTPException(status_code=503, detail="The server is not in ENGINE.ROLLING status, cannot accept new requests.") if (f"episodes-{episode_uuid}") not in shared_mem_dict: raise HTTPException(status_code=404, detail=f"Episode {episode_uuid} not found.") # update activate timestamp diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py index 43b977ad..e6182151 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py @@ -25,6 +25,7 @@ class TinkerScriptClient(object): def __init__(self, server_url: str): self.server_url = server_url self.client_uuid = str(uuid.uuid4()) + self.previous_warning_time = 0 def begin_episode(self, allow_discard_timeout=60) -> Tuple[str, OpenaiBaseUrlAndApiKey]: @@ -59,8 +60,18 @@ def begin_episode(self, allow_discard_timeout=60) -> Tuple[str, OpenaiBaseUrlAnd episode_uuid=episode_uuid ) else: - logger.info(f"Failed to claim episode: {data.fail_cause}. Retrying in 5s...") - time.sleep(5) + need_wait_scenarios =[ + "Engine is syncing weights", + "No available episodes to claim.", + ] + if any(scenario in data.fail_cause for scenario in need_wait_scenarios): + if time.time() - self.previous_warning_time > 60: + logger.info(f"{data.fail_cause}. Retrying in 30s...") + self.previous_warning_time = time.time() + time.sleep(30) + else: + logger.warning(f"Failed to claim episode: {data.fail_cause}. Retrying in 5s...") + time.sleep(5) except Exception as e: logger.error(f"Error claiming episode: {e}. Retrying in 5s...") time.sleep(5) @@ -98,6 +109,11 @@ def sync_train_config(self, agent_jet_job: AgentJetJob): Sync training configuration to the TinkerScript server. This sends the AgentJetJob config as YAML to the remote server. """ + # try get init status + current_status = self.get_engine_status() + if current_status != "ENGINE.OFFLINE": + raise RuntimeError(f"Cannot sync train config when engine is NOT ENGINE.OFFLINE. (current status: {current_status})") + try: config_dict = agent_jet_job.config.to_dict() yaml_str = yaml.safe_dump(config_dict, sort_keys=False) @@ -121,6 +137,12 @@ def start_engine(self): This triggers the server to begin the training process. Polls until engine status is "ENGINE.ROLLING". """ + # try get init status + current_status = self.get_engine_status() + if current_status != "ENGINE.OFFLINE": + raise RuntimeError(f"Cannot start engine when engine is NOT ENGINE.OFFLINE. (current status: {current_status})") + + # Send start engine request try: resp = httpx.post( f"{self.server_url}/start_engine", @@ -139,8 +161,17 @@ def start_engine(self): raise # Poll until engine status is "ENGINE.ROLLING" + self._wait_until_avail() + logger.success("Training engine is now ROLLING and ready.") + + def _wait_until_avail(self): + """ + Poll engine status until it reaches ENGINE.ROLLING state. + Reports status every 5 seconds while waiting. + """ logger.info("Polling engine status until ENGINE.ROLLING...") last_report_time = time.time() + init_poll_time = last_report_time while True: try: @@ -149,7 +180,7 @@ def start_engine(self): # Report status every 5 seconds if current_time - last_report_time >= 5: - logger.info(f"Current engine status: {current_status}") + logger.info(f"Current engine status (already waited {current_time - init_poll_time:.1f}s): {current_status}") last_report_time = current_time # Check if engine has reached the desired status @@ -210,3 +241,22 @@ def get_episode_buffer(self) -> List[EpisodeStatus]: except Exception as e: logger.error(f"Error getting episode buffer: {e}") return [] + + def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob): + """ + Automatically sync training configuration and start the engine if needed. + This checks the current engine status and performs actions accordingly. + """ + current_status = self.get_engine_status() + if current_status == "ENGINE.OFFLINE": + logger.info("Engine is OFFLINE. Syncing train config and starting engine...") + self.sync_train_config(agent_jet_job) + self.start_engine() + elif current_status == "ENGINE.ROLLING": + logger.info("Engine is already ROLLING. No action needed.") + elif current_status == "ENGINE.BOOTING": + logger.info("Engine is BOOTING. Waiting until it becomes ROLLING...") + self._wait_until_avail() + logger.success("Training engine is now ROLLING and ready.") + else: + raise RuntimeError(f"Cannot sync train config or start engine when engine is in status: {current_status}") diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py index 0214cd68..f631e742 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py @@ -28,7 +28,7 @@ UpdateEngineStatusRequest, ) -DEBUG = True +DEBUG = False def register_enable_tinkerscript_mode_routes( app, @@ -43,6 +43,84 @@ def register_enable_tinkerscript_mode_routes( if 'unclaimed_episodes' not in shared_mem_dict: shared_mem_dict['unclaimed_episodes'] = [] + def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]: + result = [] + current_time = time.time() + + for k, v in shared_mem_dict.items(): + if k.startswith("episodes-"): + es:EpisodeStatus = v + if es.episode_status == "claimed": + if (current_time - es.latest_activity_timestamp) > es.allow_discard_timeout: + result.append(es.episode_uuid) + + for episode_uuid in result: + _revert_episode_to_unclaimed(episode_uuid) + + return result + + def _revert_episode_to_unclaimed(episode_uuid: str): + with shared_mem_dict_lock: + # check status again, because other thread may have changed it + if shared_mem_dict[f"episodes-{episode_uuid}"].episode_status != "claimed": + return + + # revert + logger.warning(f"Reverting episode {episode_uuid} to unclaimed due to client timeout.") + if f"episodes-{episode_uuid}" in shared_mem_dict: + es:EpisodeStatus = shared_mem_dict[f"episodes-{episode_uuid}"] + es.episode_status = "registered" + es.client_uuid = "" + es.latest_activity_timestamp = time.time() + es.allow_discard_timeout = -1 + shared_mem_dict[f"episodes-{episode_uuid}"] = es + shared_mem_dict['unclaimed_episodes'] += [episode_uuid] + + + async def register_episode_ready_listener(): + while True: + read_all_episode_status() + await asyncio.sleep(10) # check every 10 seconds + find_claimed_episodes_that_need_to_be_unclaimed() + + + def read_all_episode_status() -> Optional[EpisodeStatus]: + print_buffer = [] + group_by_status = {} + + for k, v in shared_mem_dict.items(): + if k.startswith("episodes-"): + es:EpisodeStatus = v + if es.episode_status not in group_by_status: + group_by_status[es.episode_status] = [] + group_by_status[es.episode_status].append(es) + + for status, es_list in group_by_status.items(): + print_buffer.append(f"--- {status} (time since last activity) ---") + in_line_buffer = "" + for es in es_list: + time_since_last_activity = time.time() - es.latest_activity_timestamp + in_line_buffer += f"{es.episode_uuid[:6]}({time_since_last_activity:.1f}s)\t" + print_buffer.append(in_line_buffer) + + print_buffer_str = "\n".join(print_buffer) + logger.info(f"Current engine status: [{shared_mem_dict['engine_status']}]") + if print_buffer: + logger.info(f"Current episode statuses:\n{print_buffer_str}") + else: + logger.info(f"Current episode statuses: [NA]") + + return None + + + # hiefwu1(15.1s ago) hiefwu2(20.3s ago) hiefwu3(5.0s ago) + + + + # -------------------------------------------------------------------- + # -------------------------- fastapi routes -------------------------- + # -------------------------------------------------------------------- + @app.post("/sync_train_config") async def sync_train_config(req: SyncTrainConfigRequest): """ @@ -120,7 +198,7 @@ async def start_engine(): # Setup environment variables exp_config['ajet']['interchange_server']['already_started'] = True - exp_config['ajet']['interchange_server']['interchange_server_port'] = int(os.getenv("AJET_DAT_INTERCHANGE_PORT")) + exp_config['ajet']['interchange_server']['interchange_server_port'] = int(os.getenv("AJET_DAT_INTERCHANGE_PORT")) # type: ignore env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp) # Start ray if not already started @@ -163,11 +241,12 @@ async def start_engine(): # --- engine status --- - shared_mem_dict['engine_status'] = "ENGINE.OFF" + shared_mem_dict['engine_status'] = "ENGINE.OFFLINE" @app.post("/update_engine_status", response_model=BoolResponse) async def update_engine_status(req: UpdateEngineStatusRequest): + """Update the current engine status.""" if req.engine_status not in [ - "ENGINE.OFF", + "ENGINE.OFFLINE", "ENGINE.BOOTING", "ENGINE.ROLLING", "ENGINE.WEIGHT_SYNCING", @@ -180,6 +259,7 @@ async def update_engine_status(req: UpdateEngineStatusRequest): @app.get("/get_engine_status") async def get_engine_status(): + """Get the current engine status.""" status = shared_mem_dict['engine_status'] return {"engine_status": status} @@ -187,7 +267,7 @@ async def get_engine_status(): # --- episode status --- @app.post("/register_episode", response_model=BoolResponse) async def register_episode(req: RegisterEpisodeRequest): - + """(From task_runner) Register a new episode as ready to roll.""" episode_uuid = req.episode_uuid es = EpisodeStatus( episode_uuid=req.episode_uuid, @@ -210,8 +290,30 @@ async def register_episode(req: RegisterEpisodeRequest): @app.post("/claim_episode", response_model=ClaimEpisodeResponse) async def claim_episode(req: ClaimEpisodeRequest): + """(From client) Claim an available episode to rollout.""" find_claimed_episodes_that_need_to_be_unclaimed() + engine_status = shared_mem_dict['engine_status'] + if engine_status != "ENGINE.ROLLING": + fail_cause = f"Engine not ready. Current status: [{engine_status}]." + advise = "" + if engine_status == "ENGINE.OFFLINE": + advise = "Please start the engine first. Please use one of the client to run `client.sync_train_config() + client.start_engine()` to start the engine." + elif engine_status == "ENGINE.BOOTING": + advise = "Please wait until the engine is fully booted. Try again (maybe 1 minute) later." + elif engine_status == "ENGINE.WEIGHT_SYNCING": + advise = "Engine is syncing weights. Try again (maybe 1 minute) later." + elif engine_status == "ENGINE.WEIGHT_EXPORTING": + advise = "Engine is exporting weights (fsdp -> hf safetensor). Try again (maybe 1 minute) later." + return ClaimEpisodeResponse( + success=False, + client_uuid=req.client_uuid, + episode_uuid="", + openai_base_url="", + openai_api_key="", + fail_cause=fail_cause + " " + advise, + ) + with shared_mem_dict_lock: if len(shared_mem_dict['unclaimed_episodes']) <= 0: return ClaimEpisodeResponse( @@ -248,41 +350,6 @@ async def claim_episode(req: ClaimEpisodeRequest): ) - def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]: - result = [] - current_time = time.time() - - for k, v in shared_mem_dict.items(): - if k.startswith("episodes-"): - es:EpisodeStatus = v - if es.episode_status == "claimed": - if (current_time - es.latest_activity_timestamp) > es.allow_discard_timeout: - result.append(es.episode_uuid) - - for episode_uuid in result: - _revert_episode_to_unclaimed(episode_uuid) - - return result - - - def _revert_episode_to_unclaimed(episode_uuid: str): - with shared_mem_dict_lock: - # check status again, because other thread may have changed it - if shared_mem_dict[f"episodes-{episode_uuid}"].episode_status != "claimed": - return - - # revert - logger.info(f"Reverting episode {episode_uuid} to unclaimed due to client timeout.") - if f"episodes-{episode_uuid}" in shared_mem_dict: - es:EpisodeStatus = shared_mem_dict[f"episodes-{episode_uuid}"] - es.episode_status = "registered" - es.client_uuid = "" - es.latest_activity_timestamp = time.time() - es.allow_discard_timeout = -1 - shared_mem_dict[f"episodes-{episode_uuid}"] = es - shared_mem_dict['unclaimed_episodes'] += [episode_uuid] - - @app.post("/end_episode", response_model=EndEpisodeResponse) async def end_episode(req: EndEpisodeRequest): # receive workflow output data @@ -312,6 +379,10 @@ async def end_episode(req: EndEpisodeRequest): for _ in range(5): # max 5 minutes wait try: if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") + # : + # : ajet/task_runner/tinkerscript_runner.py + # : zmq_socket.send_string("ack") + # : "ack" result_str = socket.recv_string() break except zmq.Again as e: @@ -345,9 +416,4 @@ async def get_episode_buffer(): return EpisodeBufferResponse(buffer=result) - - async def register_episode_ready_listener(): - pass - - return app, register_episode_ready_listener() diff --git a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py index 05a80736..fecff76f 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py +++ b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py @@ -69,6 +69,8 @@ class UpdateEngineStatusRequest(BaseModel): engine_status: str = "" +DEBUG = False + def get_interchange_server_url(config): port = os.getenv("AJET_DAT_INTERCHANGE_PORT") if config.ajet.interchange_server.interchange_server_port != 'auto': @@ -95,7 +97,7 @@ def http_change_engine_status(config: str, new_status: str): timeout=10 ) resp.raise_for_status() - logger.info(f"Changed engine status to {new_status}") + logger.success(f"Changed engine status to {new_status}") @@ -123,7 +125,7 @@ def http_register_episode(config, episode_uuid: str, result = response.json() if not result.get('success'): raise RuntimeError(f"Failed to register episode {episode_uuid}") - logger.info(f"Successfully registered episode {episode_uuid}") + if DEBUG: logger.info(f"Successfully registered episode {episode_uuid}") break except httpx.HTTPError as e: logger.error(f"Error registering episode {episode_uuid}: {e}. Retrying...") diff --git a/ajet_tinkerscript_threading.py b/ajet_tinkerscript_threading.py index 66ea81f8..6fea8282 100644 --- a/ajet_tinkerscript_threading.py +++ b/ajet_tinkerscript_threading.py @@ -1,6 +1,7 @@ import re +import threading import requests -import time +from loguru import logger from textwrap import dedent from ajet.copilot.job import AgentJetJob from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_client import TinkerScriptClient @@ -14,7 +15,7 @@ # --------- configurations that take effect locally ------------- LOCAL_GRPO_N = 4 # grpo group size LOCAL_NUM_EPOCH = 10000 -LOCAL_MAX_PARALLEL = 2 +LOCAL_MAX_PARALLEL = 32 LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main" REMOTE_TINKERJET_URL = "http://localhost:10086" # Change to your tinkerscript remote url @@ -28,38 +29,6 @@ class WeightUpdatedHalfway(Exception): """Raised when the remote side starts updating model weights halfway through an episode.""" -def connect_to_tinkerscript_server( - create_server_via_ssh: bool = False, - create_server_locally: bool = False, - sync_train_config: bool = True, - start_engine: bool = True, -): - if create_server_via_ssh: - raise NotImplementedError("Creating tinkerscript server via SSH is not implemented yet.") - - if create_server_locally: - raise NotImplementedError("Creating tinkerscript server is not implemented yet, please run `ajet launch --tinkerscript-server` to start manually.") - - tinkerscript_remote = TinkerScriptClient(REMOTE_TINKERJET_URL) - - if sync_train_config: - tinkerscript_remote.sync_train_config( - AgentJetJob( - algorithm="grpo", - n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, - model=REMOTE_TRAIN_MODEL_01, - grpo_n=LOCAL_GRPO_N, - ) - ) - print("TinkerScript remote handshake and train config sync done.") - - if start_engine: - tinkerscript_remote.start_engine() - print("TinkerScript remote engine started.") - - return tinkerscript_remote - - def main(): # Handshake with tinkerscript remote, then send training param to tinkerscript remote (such as model to be trained, algorithm, etc) @@ -73,28 +42,44 @@ def main(): ) # Hand shake with remote tinkerscript server - tinkerscript_remote = connect_to_tinkerscript_server(create_server_locally=True, sync_train_config=True, start_engine=True) + tinkerscript_remote = TinkerScriptClient(REMOTE_TINKERJET_URL) + tinkerscript_remote.auto_sync_train_config_and_start_engine( + AgentJetJob( + algorithm="grpo", + n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, + model=REMOTE_TRAIN_MODEL_01, + grpo_n=LOCAL_GRPO_N, + ) + ) + + # tinkerscript_remote = connect_to_tinkerscript_server(sync_train_config=False, start_engine=False) + submit_sem = threading.BoundedSemaphore(LOCAL_MAX_PARALLEL) # Define rollout def rollout(task): - group_reward = [] - for i in range(LOCAL_GRPO_N): - # begin episode - episode_uuid, api_baseurl_key = tinkerscript_remote.begin_episode() - # execute agent - workflow_output = execute_agent(task, api_baseurl_key) - # report output back to tinkerscript remote - tinkerscript_remote.end_episode(episode_uuid, workflow_output) - # collect reward - group_reward.append(workflow_output.reward) - print(f"Group reward mean & std: {sum(group_reward)/len(group_reward)} +/- { (max(group_reward)-min(group_reward))/2 }") - + try: + group_reward = [] + for i in range(LOCAL_GRPO_N): + # begin episode + episode_uuid, api_baseurl_key = tinkerscript_remote.begin_episode() + # execute agent + workflow_output = execute_agent(task, api_baseurl_key) + # report output back to tinkerscript remote + tinkerscript_remote.end_episode(episode_uuid, workflow_output) + # collect reward + group_reward.append(workflow_output.reward) + print(f"Group reward mean & std: {sum(group_reward)/len(group_reward)} +/- { (max(group_reward)-min(group_reward))/2 }") + except Exception as e: + logger.exception("Exception during rollout:", e) + finally: + submit_sem.release() # Main Training loop with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as executor: for epoch in range(LOCAL_NUM_EPOCH): for task in dataset.get_training_tasks(): print(f"Submitting task for epoch {epoch}") + submit_sem.acquire() executor.submit(rollout, task) From a6c7e0e431bbef9381ac9d52c4f72d77d9e2b52f Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Thu, 29 Jan 2026 02:49:32 +0800 Subject: [PATCH 13/25] stage eval code ( to be tested ) --- ajet/context_tracker/base_tracker.py | 15 + ajet/task_runner/tinkerscript_runner.py | 32 +- .../experimental/as_tinkerscript_client.py | 45 ++- .../experimental/as_tinkerscript_server.py | 371 +++++++++++++----- .../experimental/interchange_utils.py | 1 + ajet_tinkerscript_threading.py | 6 +- 6 files changed, 352 insertions(+), 118 deletions(-) diff --git a/ajet/context_tracker/base_tracker.py b/ajet/context_tracker/base_tracker.py index 856cd89c..9e04c4bb 100644 --- a/ajet/context_tracker/base_tracker.py +++ b/ajet/context_tracker/base_tracker.py @@ -148,6 +148,21 @@ def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs): <= max_model_len ) + def reset(self): + self.saved_timelines: List[List[ExtendedMessage]] = [] + self.current_context_status = "" + self.terminal_rewards_dict = {} + self.discarded = False + self.is_terminated = False + self.reward_structure: Union[Reward, None] = None + self.context_time_cost = 0 + self.tag = "" + self.current_batch_success_rate: float = float("-inf") + self.current_batch_reward: float = float("-inf") + self.already_mad_flag: bool = False + self.round_cnt = 0 + self.log_metrics: Optional[Dict[str, Union[float, List[float], Dict[str, Any]]]] = None + def group_tokenize(self): raise NotImplementedError diff --git a/ajet/task_runner/tinkerscript_runner.py b/ajet/task_runner/tinkerscript_runner.py index e06f7262..d0351767 100644 --- a/ajet/task_runner/tinkerscript_runner.py +++ b/ajet/task_runner/tinkerscript_runner.py @@ -23,7 +23,7 @@ class TinkerScriptRunner(BaseAgentRunner): - def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: str, openai_api_key: str) -> WorkflowOutput: + def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: str, openai_api_key: str, context_tracker: BaseContextTracker) -> WorkflowOutput: """Register the episode as ready in the TinkerScript data interchange center.""" # parse episode_uuid, openai_base_url, openai_api_key zmq_listen_result_addr, ipc_path = get_zmq_socket(self.config, episode_uuid, tag="workflow") @@ -39,15 +39,30 @@ def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: s # begin wait for result zmq_socket = zmq.Context().socket(zmq.REP) zmq_socket.bind(zmq_listen_result_addr) - - # : - # : ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py - # : socket.send_string(workflow_output.model_dump_json()) - # : workflow_output: WorkflowOutput - message = zmq_socket.recv_string() + speicial_messages = [ + "RUNNER.RESET_CONTEXT_TRACKER" + ] + while True: + # : + # : ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py + # : socket.send_string(workflow_output.model_dump_json()) + # : workflow_output: WorkflowOutput + # : + # : ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py + # : socket.send_string("RUNNER.SPECIAL.RESET_CONTEXT_TRACKER") + # : "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER" + message = zmq_socket.recv_string() + if message not in speicial_messages: + zmq_socket.send_string("ack") + break + elif message == "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER": + logger.warning(f"Received reset command for episode {episode_uuid}.") + context_tracker.reset() + zmq_socket.send_string("ack") + else: + raise RuntimeError(f"Unknown special message received: {message}") logger.success(f"Received workflow output for episode {episode_uuid}") - zmq_socket.send_string("ack") zmq_socket.close() if ipc_path and os.path.exists(ipc_path): os.remove(ipc_path) @@ -85,6 +100,7 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: episode_uuid=context_tracker.episode_uuid, openai_base_url=base_url, openai_api_key=api_key, + context_tracker=context_tracker, ) if workflow_output.reward is not None: diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py index e6182151..b9a7b515 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py @@ -28,7 +28,7 @@ def __init__(self, server_url: str): self.previous_warning_time = 0 - def begin_episode(self, allow_discard_timeout=60) -> Tuple[str, OpenaiBaseUrlAndApiKey]: + def begin_episode(self, allow_discard_timeout=60, episode_type="train") -> Tuple[str, OpenaiBaseUrlAndApiKey]: """ Block until an episode is claimed. Return (episode_uuid, openai_base_url, openai_api_key) @@ -37,7 +37,7 @@ def begin_episode(self, allow_discard_timeout=60) -> Tuple[str, OpenaiBaseUrlAnd try: req_obj = ClaimEpisodeRequest( client_uuid=self.client_uuid, - episode_type="default", + episode_type=episode_type, allow_discard_timeout=allow_discard_timeout, ) resp = httpx.post( @@ -161,15 +161,15 @@ def start_engine(self): raise # Poll until engine status is "ENGINE.ROLLING" - self._wait_until_avail() + self._wait_until_status_change_to(desired_status="ENGINE.ROLLING") logger.success("Training engine is now ROLLING and ready.") - def _wait_until_avail(self): + def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING"): """ - Poll engine status until it reaches ENGINE.ROLLING state. + Poll engine status until it reaches desired_status. Reports status every 5 seconds while waiting. """ - logger.info("Polling engine status until ENGINE.ROLLING...") + logger.info(f"Polling engine status until {desired_status}...") last_report_time = time.time() init_poll_time = last_report_time @@ -184,8 +184,8 @@ def _wait_until_avail(self): last_report_time = current_time # Check if engine has reached the desired status - if current_status == "ENGINE.ROLLING": - logger.info("Engine status is ENGINE.ROLLING - engine is ready") + if current_status == desired_status: + logger.info(f"Engine status is {desired_status}.") break # Wait a bit before next poll @@ -256,7 +256,34 @@ def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob): logger.info("Engine is already ROLLING. No action needed.") elif current_status == "ENGINE.BOOTING": logger.info("Engine is BOOTING. Waiting until it becomes ROLLING...") - self._wait_until_avail() + self._wait_until_status_change_to(desired_status="ENGINE.ROLLING") logger.success("Training engine is now ROLLING and ready.") else: raise RuntimeError(f"Cannot sync train config or start engine when engine is in status: {current_status}") + + def stop_engine(self): + """ + Stop the training engine on the TinkerScript server. + This triggers the server to stop the training process. + """ + current_status = self.get_engine_status() + if current_status == "ENGINE.OFFLINE": + logger.info("Engine is already OFFLINE. No action needed.") + return + + try: + resp = httpx.post( + f"{self.server_url}/stop_engine", + json={}, + timeout=600 + ) + resp.raise_for_status() + result = resp.json() + if result.get("success"): + logger.info("Successfully stopped training engine on TinkerScript server") + else: + logger.error("Failed to stop training engine") + self._wait_until_status_change_to(desired_status="ENGINE.OFFLINE") + except Exception as e: + logger.error(f"Error stopping engine: {e}") + diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py index f631e742..29a725c5 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py @@ -6,13 +6,9 @@ import threading from multiprocessing.managers import DictProxy from types import SimpleNamespace - - from loguru import logger from fastapi import FastAPI, HTTPException -from typing import List - -from typing import Coroutine, Optional, Tuple +from typing import Coroutine, Optional, Tuple, List from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import ( SyncTrainConfigRequest, ClaimEpisodeRequest, @@ -43,6 +39,11 @@ def register_enable_tinkerscript_mode_routes( if 'unclaimed_episodes' not in shared_mem_dict: shared_mem_dict['unclaimed_episodes'] = [] + # ------------------------------------------------------------------------------------------------ + # ------ Recycle claimed episodes that client failed to complete in (promised) time -------------- + # --------------------------------- claimed -> unclaimed ---------------------------------------- + # ------------------------------------------------------------------------------------------------ + def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]: result = [] current_time = time.time() @@ -59,12 +60,40 @@ def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]: return result + def _context_tracker_reset(episode_uuid, shared_mem_dict): + # send message to context tracker + assert 'episodes' in shared_mem_dict + zmq_addr = shared_mem_dict[f"episodes-{episode_uuid}"].zmq_listen_result_addr + socket = zmq_context.socket(zmq.REQ) + socket.setsockopt(zmq.RCVTIMEO, 60*1000) # 1 minute recv timeout + socket.connect(zmq_addr) + # + # : ajet/task_runner/tinkerscript_runner.py + # : message = zmq_socket.recv_string() + socket.send_string("RUNNER.SPECIAL.RESET_CONTEXT_TRACKER") + # + for _ in range(5): # max 5 minutes wait + try: + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") + # : + # : ajet/task_runner/tinkerscript_runner.py + # : zmq_socket.send_string("ack") + # : "ack" + result_str = socket.recv_string() + break + except zmq.Again as e: + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string timeout, retrying.") + continue + def _revert_episode_to_unclaimed(episode_uuid: str): with shared_mem_dict_lock: # check status again, because other thread may have changed it if shared_mem_dict[f"episodes-{episode_uuid}"].episode_status != "claimed": return + # reset context tracker + _context_tracker_reset(episode_uuid, shared_mem_dict) + # revert logger.warning(f"Reverting episode {episode_uuid} to unclaimed due to client timeout.") if f"episodes-{episode_uuid}" in shared_mem_dict: @@ -77,13 +106,53 @@ def _revert_episode_to_unclaimed(episode_uuid: str): shared_mem_dict['unclaimed_episodes'] += [episode_uuid] + + + # -------------------------------------------------------------------------------------- + # -------------------------- return workflow output ------------------------------------ + # -------------------------------------------------------------------------------------- + + def _register_final_episode_output(episode_uuid, workflow_output, shared_mem_dict, shared_mem_dict_lock): + # begin send workflow_output + zmq_addr = shared_mem_dict[f"episodes-{episode_uuid}"].zmq_listen_result_addr + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | Received new chat completion request (inside thread)") + socket = zmq_context.socket(zmq.REQ) + socket.setsockopt(zmq.RCVTIMEO, 60*1000) # 1 minute recv timeout + socket.connect(zmq_addr) + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | connect done") + socket.send_string(workflow_output.model_dump_json()) + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | send_string") + # wait for ack + for _ in range(5): # max 5 minutes wait + try: + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") + # : + # : ajet/task_runner/tinkerscript_runner.py + # : zmq_socket.send_string("ack") + # : "ack" + result_str = socket.recv_string() + break + except zmq.Again as e: + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string timeout, retrying.") + continue + # clean up episode records + with shared_mem_dict_lock: + del shared_mem_dict[f"episodes-{episode_uuid}"] + if episode_uuid in shared_mem_dict['unclaimed_episodes']: + shared_mem_dict['unclaimed_episodes'].remove(episode_uuid) + + + + # -------------------------------------------------------------------------------------- + # -------------------------- status monitor -------------------------------------------- + # -------------------------------------------------------------------------------------- + async def register_episode_ready_listener(): while True: read_all_episode_status() await asyncio.sleep(10) # check every 10 seconds find_claimed_episodes_that_need_to_be_unclaimed() - def read_all_episode_status() -> Optional[EpisodeStatus]: print_buffer = [] group_by_status = {} @@ -95,13 +164,13 @@ def read_all_episode_status() -> Optional[EpisodeStatus]: group_by_status[es.episode_status] = [] group_by_status[es.episode_status].append(es) - for status, es_list in group_by_status.items(): - print_buffer.append(f"--- {status} (time since last activity) ---") - in_line_buffer = "" - for es in es_list: - time_since_last_activity = time.time() - es.latest_activity_timestamp - in_line_buffer += f"{es.episode_uuid[:6]}({time_since_last_activity:.1f}s)\t" - print_buffer.append(in_line_buffer) + # for status, es_list in group_by_status.items(): + # print_buffer.append(f"{status} (time since last activity)") + # in_line_buffer = "" + # for es in es_list: + # time_since_last_activity = time.time() - es.latest_activity_timestamp + # in_line_buffer += f"{es.episode_uuid[:6]}({time_since_last_activity:.1f}s)\t" + # print_buffer.append(in_line_buffer) print_buffer_str = "\n".join(print_buffer) logger.info(f"Current engine status: [{shared_mem_dict['engine_status']}]") @@ -113,13 +182,9 @@ def read_all_episode_status() -> Optional[EpisodeStatus]: return None - # hiefwu1(15.1s ago) hiefwu2(20.3s ago) hiefwu3(5.0s ago) - - - - # -------------------------------------------------------------------- - # -------------------------- fastapi routes -------------------------- - # -------------------------------------------------------------------- + # -------------------------------------------------------------------------------------- + # -------------------------- fastapi routes -------------------------------------------- + # -------------------------------------------------------------------------------------- @app.post("/sync_train_config") async def sync_train_config(req: SyncTrainConfigRequest): @@ -151,24 +216,18 @@ async def start_engine(): This creates a temporary YAML file and spawns a training process. """ try: - from ajet.utils.launch_utils import execute_training_process - from ajet.launcher import ( - get_backbone_target, - setup_environment_vars, - ) - from ajet.utils.config_utils import ( - prepare_experiment_config, - ) import ray import tempfile import yaml as yaml_module + from ajet.utils.launch_utils import execute_training_process + from ajet.utils.config_utils import prepare_experiment_config + from ajet.launcher import get_backbone_target, setup_environment_vars # Check if config has been synced if 'train_config_yaml' not in shared_mem_dict: logger.error("[start_engine] No training config found. Please call sync_train_config first.") return {"success": False, "error": "No training config found"} - # Parse YAML to get backbone yaml_str = shared_mem_dict['train_config_yaml'] config_dict = yaml_module.safe_load(yaml_str) @@ -183,11 +242,7 @@ async def start_engine(): # Create args namespace args = SimpleNamespace( - conf=main_yaml_fp, - backbone=backbone, - exp_dir=exp_dir_final, - with_logview=False, - debug=False, + conf=main_yaml_fp, backbone=backbone, exp_dir=exp_dir_final, with_logview=False, debug=False, ) # Finalize experiment config @@ -195,7 +250,6 @@ async def start_engine(): main_yaml_fp, exp_dir_final, backbone ) - # Setup environment variables exp_config['ajet']['interchange_server']['already_started'] = True exp_config['ajet']['interchange_server']['interchange_server_port'] = int(os.getenv("AJET_DAT_INTERCHANGE_PORT")) # type: ignore @@ -213,18 +267,16 @@ async def start_engine(): p = multiprocessing.Process( target=execute_training_process, args=( - args, - get_backbone_target(args.backbone), - main_yaml_fp, - exe_exp_base, - main_yaml_fp, - env, - exp_config, + args, get_backbone_target(args.backbone), main_yaml_fp, + exe_exp_base, main_yaml_fp, env, exp_config, ) ) p.daemon = True p.start() - + # wait until p.pid is available + while not isinstance(p.pid, int): time.sleep(1) + # set new process group + os.setpgid(p.pid, p.pid) # Store process info in shared memory with shared_mem_dict_lock: shared_mem_dict['training_process_pid'] = p.pid @@ -280,6 +332,13 @@ async def register_episode(req: RegisterEpisodeRequest): es.latest_activity_timestamp = time.time() with shared_mem_dict_lock: + engine_status = shared_mem_dict['engine_status'] + if engine_status not in ["ENGINE.ROLLING"]: + return BoolResponse( + success=False, + failure_reason=f"Engine has already shutdown. Cannot register episode.", + ) + shared_mem_dict[f"episodes-{episode_uuid}"] = es shared_mem_dict['unclaimed_episodes'] += [req.episode_uuid] @@ -314,40 +373,46 @@ async def claim_episode(req: ClaimEpisodeRequest): fail_cause=fail_cause + " " + advise, ) - with shared_mem_dict_lock: - if len(shared_mem_dict['unclaimed_episodes']) <= 0: - return ClaimEpisodeResponse( - success=False, - client_uuid=req.client_uuid, - episode_uuid="", - openai_base_url="", - openai_api_key="", - fail_cause="No available episodes to claim. Try again (maybe 1 minute) later.", - ) + if req.episode_type == "train" or req.episode_type == "eval": - # hint: do not optimize this - episode_uuid = shared_mem_dict['unclaimed_episodes'][0] - shared_mem_dict['unclaimed_episodes'] = shared_mem_dict['unclaimed_episodes'][1:] + with shared_mem_dict_lock: + if len(shared_mem_dict['unclaimed_episodes']) <= 0: + return ClaimEpisodeResponse( + success=False, + client_uuid=req.client_uuid, + episode_uuid="", + openai_base_url="", + openai_api_key="", + fail_cause="No available episodes to claim. Try again (maybe 1 minute) later.", + ) + + # Hint: do NOT optimize these two lines + episode_uuid = shared_mem_dict['unclaimed_episodes'][0] + shared_mem_dict['unclaimed_episodes'] = shared_mem_dict['unclaimed_episodes'][1:] + + # get episode + es:EpisodeStatus = shared_mem_dict[f"episodes-{episode_uuid}"] + es.episode_status = "claimed" + es.episode_type = req.episode_type + es.client_uuid = req.client_uuid + es.latest_activity_timestamp = time.time() + es.allow_discard_timeout = req.allow_discard_timeout - # get episode - es:EpisodeStatus = shared_mem_dict[f"episodes-{episode_uuid}"] - es.episode_status = "claimed" - es.client_uuid = req.client_uuid - es.latest_activity_timestamp = time.time() - es.allow_discard_timeout = req.allow_discard_timeout + shared_mem_dict[f"episodes-{episode_uuid}"] = es + openai_base_url = es.openai_base_url + openai_api_key = es.openai_api_key - shared_mem_dict[f"episodes-{episode_uuid}"] = es - openai_base_url = es.openai_base_url - openai_api_key = es.openai_api_key + return ClaimEpisodeResponse( + success=True, + client_uuid=req.client_uuid, + episode_uuid=episode_uuid, + openai_base_url=openai_base_url, + openai_api_key=openai_api_key, + fail_cause="", + ) - return ClaimEpisodeResponse( - success=True, - client_uuid=req.client_uuid, - episode_uuid=episode_uuid, - openai_base_url=openai_base_url, - openai_api_key=openai_api_key, - fail_cause="", - ) + else: + raise HTTPException(status_code=400, detail=f"Unknown episode_type: {req.episode_type}") @app.post("/end_episode", response_model=EndEpisodeResponse) @@ -366,34 +431,16 @@ async def end_episode(req: EndEpisodeRequest): # send workflow_output to zmq assert 'episodes' in shared_mem_dict - zmq_addr = shared_mem_dict[f"episodes-{episode_uuid}"].zmq_listen_result_addr - if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | Received new chat completion request (inside thread)") - socket = zmq_context.socket(zmq.REQ) - socket.setsockopt(zmq.RCVTIMEO, 60*1000) # 1 minute recv timeout - socket.connect(zmq_addr) - if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | connect done") - socket.send_string(workflow_output.model_dump_json()) - if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | send_string") + episode_type = shared_mem_dict[f"episodes-{episode_uuid}"].episode_type - # wait for ack - for _ in range(5): # max 5 minutes wait - try: - if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") - # : - # : ajet/task_runner/tinkerscript_runner.py - # : zmq_socket.send_string("ack") - # : "ack" - result_str = socket.recv_string() - break - except zmq.Again as e: - if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string timeout, retrying.") - continue + if episode_type == "train": + _register_final_episode_output(episode_uuid, workflow_output, shared_mem_dict, shared_mem_dict_lock) - # clean up episode records - with shared_mem_dict_lock: - del shared_mem_dict[f"episodes-{episode_uuid}"] - if episode_uuid in shared_mem_dict['unclaimed_episodes']: - shared_mem_dict['unclaimed_episodes'].remove(episode_uuid) + elif episode_type == "eval": + _revert_episode_to_unclaimed(episode_uuid) + + else: + raise HTTPException(status_code=400, detail=f"Unknown episode_type: {episode_type}") # return success return EndEpisodeResponse(success=True) @@ -416,4 +463,130 @@ async def get_episode_buffer(): return EpisodeBufferResponse(buffer=result) + + + # -------------------------------------------------------------------- + # ------------ bring engine back to ENGINE.OFFLINE ------------------- + # -------------------------------------------------------------------- + @app.post("/stop_engine") + async def stop_engine(): + """ + Terminate the training engine and reset all state. + This will: + - Kill the training process and all its subprocesses (forcefully if necessary) + - Set engine status to OFFLINE + - Remove all episodes (registered, claimed, and unclaimed) + - Clean up shared memory state + """ + try: + import psutil + + killed_pids = [] + errors = [] + + # Get the training process PID if it exists + training_pid = shared_mem_dict.get('training_process_pid', None) + + if training_pid is not None: + try: + # Try to get the process and all its children + try: + parent = psutil.Process(training_pid) + children = parent.children(recursive=True) + + # Kill all child processes first + for child in children: + try: + logger.info(f"[stop_engine] Terminating child process PID: {child.pid}") + child.terminate() + killed_pids.append(child.pid) + except psutil.NoSuchProcess: + logger.warning(f"[stop_engine] Child process {child.pid} already terminated") + except Exception as e: + logger.error(f"[stop_engine] Error terminating child process {child.pid}: {e}") + errors.append(f"Child {child.pid}: {str(e)}") + + # Wait for children to terminate gracefully + gone, alive = psutil.wait_procs(children, timeout=16) + + # Force kill any remaining children + for p in alive: + try: + logger.warning(f"[stop_engine] Force killing child process PID: {p.pid}") + p.kill() + except Exception as e: + logger.error(f"[stop_engine] Error force killing child {p.pid}: {e}") + errors.append(f"Force kill child {p.pid}: {str(e)}") + + # Now terminate the parent process + logger.info(f"[stop_engine] Terminating parent process PID: {training_pid}") + parent.terminate() + killed_pids.append(training_pid) + + # Wait for parent to terminate gracefully + try: + parent.wait(timeout=3) + except psutil.TimeoutExpired: + logger.warning(f"[stop_engine] Force killing parent process PID: {training_pid}") + parent.kill() + + except psutil.NoSuchProcess: + logger.warning(f"[stop_engine] Process {training_pid} not found (may have already terminated)") + + except Exception as e: + logger.error(f"[stop_engine] Error killing training process: {e}") + errors.append(f"Training process: {str(e)}") + else: + logger.info("[stop_engine] No training process PID found in shared memory") + + # Clean up all episodes from shared memory + with shared_mem_dict_lock: + episode_keys = [k for k in shared_mem_dict.keys() if k.startswith("episodes-")] + for key in episode_keys: + del shared_mem_dict[key] + logger.info(f"[stop_engine] Removed episode: {key}") + + # Clear unclaimed episodes list + if 'unclaimed_episodes' in shared_mem_dict: + num_unclaimed = len(shared_mem_dict['unclaimed_episodes']) + shared_mem_dict['unclaimed_episodes'] = [] + logger.info(f"[stop_engine] Cleared {num_unclaimed} unclaimed episodes") + + # Reset engine status to OFFLINE + shared_mem_dict['engine_status'] = "ENGINE.OFFLINE" + + # Remove training process PID + if 'training_process_pid' in shared_mem_dict: + del shared_mem_dict['training_process_pid'] + + logger.info("[stop_engine] Engine status set to OFFLINE") + + result = { + "success": True, + "killed_pids": killed_pids, + "episodes_removed": len(episode_keys) if 'episode_keys' in locals() else 0, + } + + if errors: + result["warnings"] = errors + logger.warning(f"[stop_engine] Completed with warnings: {errors}") + else: + logger.info(f"[stop_engine] Successfully terminated engine and reset state") + + return result + + except Exception as e: + logger.error(f"[stop_engine] Unexpected error: {e}") + import traceback + traceback.print_exc() + + # Even if there's an error, try to reset the status + try: + with shared_mem_dict_lock: + shared_mem_dict['engine_status'] = "ENGINE.OFFLINE" + except: + pass + + return {"success": False, "error": str(e)} + return app, register_episode_ready_listener() diff --git a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py index fecff76f..04caa015 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py +++ b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py @@ -43,6 +43,7 @@ class EndEpisodeResponse(BaseModel): class EpisodeStatus(BaseModel): episode_uuid: str episode_status: str = "rolling" + episode_type: str = "train" openai_base_url: str = "" openai_api_key: str = "" client_uuid: str = "" diff --git a/ajet_tinkerscript_threading.py b/ajet_tinkerscript_threading.py index 6fea8282..00e53b44 100644 --- a/ajet_tinkerscript_threading.py +++ b/ajet_tinkerscript_threading.py @@ -15,6 +15,7 @@ # --------- configurations that take effect locally ------------- LOCAL_GRPO_N = 4 # grpo group size LOCAL_NUM_EPOCH = 10000 +LOCAL_NUM_EPOCH = 1 LOCAL_MAX_PARALLEL = 32 LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main" REMOTE_TINKERJET_URL = "http://localhost:10086" # Change to your tinkerscript remote url @@ -51,6 +52,7 @@ def main(): grpo_n=LOCAL_GRPO_N, ) ) + # tinkerscript_remote.stop_engine() # tinkerscript_remote = connect_to_tinkerscript_server(sync_train_config=False, start_engine=False) submit_sem = threading.BoundedSemaphore(LOCAL_MAX_PARALLEL) @@ -77,12 +79,12 @@ def rollout(task): # Main Training loop with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as executor: for epoch in range(LOCAL_NUM_EPOCH): - for task in dataset.get_training_tasks(): + for task in dataset.get_training_tasks()[:100]: print(f"Submitting task for epoch {epoch}") submit_sem.acquire() executor.submit(rollout, task) - + tinkerscript_remote.stop_engine() # model_path = tinkerscript_remote.download_latest_model(path='./tinkerscript_saved_model') # Get tuned model from tinkerscript remote From 766000731695fb11b243629e27f913dfdafbb47e Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Thu, 29 Jan 2026 18:31:36 +0800 Subject: [PATCH 14/25] union_gen_batch_via_task_id is to be tested --- ajet/backbone/trainer_verl.py | 48 ++++++++++--------- ajet/launcher.py | 6 +-- ajet/schema/task.py | 10 ++-- ajet/task_runner/tinkerscript_runner.py | 8 +++- .../experimental/as_oai_model_server.py | 17 +++++-- .../experimental/as_tinkerscript_client.py | 9 ++-- .../experimental/as_tinkerscript_server.py | 30 ++++++++---- .../experimental/interchange_utils.py | 1 + ajet/utils/config_utils.py | 9 ++-- ajet_tinkerscript_threading.py | 8 ++-- 10 files changed, 94 insertions(+), 52 deletions(-) diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 0fc224d5..74498bbf 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -99,16 +99,20 @@ def parse_reward_from_dataproto(data: DataProto, return_dict=False) -> dict | to return reward_tensor -def union_gen_batch_via_task_id(tasks, batch: DataProto, gen_batch_output: DataProto): +def union_gen_batch_via_task_id(tasks, batch: DataProto, gen_batch_output: DataProto, discard_original_batch=False): """ Union the gen_batch_output with the batch based on task_id. """ - map_task_id_to_index = {t.task_id: i for i, t in enumerate(tasks)} - gen_task_task_ids = gen_batch_output.non_tensor_batch["task_ids"] - indices = [map_task_id_to_index[tid] for tid in gen_task_task_ids] - batch_extend = batch.select_idxs(indices) - batch_final = batch_extend.union(gen_batch_output) - return batch_final + if not discard_original_batch: + map_task_id_to_index = {t.task_id: i for i, t in enumerate(tasks)} + gen_task_task_ids = gen_batch_output.non_tensor_batch["task_ids"] + indices = [map_task_id_to_index[tid] for tid in gen_task_task_ids] + batch_extend = batch.select_idxs(indices) + batch_final = batch_extend.union(gen_batch_output) + return batch_final + else: + gen_batch_output.non_tensor_batch['uid'] = gen_batch_output.non_tensor_batch["task_ids"] + return gen_batch_output def compute_advantage( @@ -550,16 +554,17 @@ def fit(self): # noqa: C901 # pass global_steps to trace gen_batch.meta_info["global_steps"] = self.global_steps is_last_step = self.global_steps >= self.total_training_steps - + from ajet import bp + bp("BATCH") with marked_timer("step", timing_raw): # generate a batch - logger.info("=== + rollout step begin ===") + logger.info("rollout step begin") with marked_timer("gen", timing_raw, color="red"): assert self.async_rollout_mode - logger.info("=== wake up begin ===") + logger.info("wake up begin") self.async_rollout_manager.wake_up() self._update_interchange_server_status_flag("ENGINE.ROLLING") - logger.info("=== wake up end ===") + logger.info("wake up end") tasks: List[Task] = [ dict_to_ajet_task(dict( task_id=gen_batch.non_tensor_batch["task_id"][i], @@ -578,16 +583,14 @@ def fit(self): # noqa: C901 ] ) ) - logger.info("=" * 10 + "start fit rollout" + "=" * 10) + logger.info("start fit rollout") self.parallel_env.current_global_steps = self.global_steps context_tracker_arr: List[BaseContextTracker] = self.parallel_env.rollout( tasks, mode="sample", epoch=f"train.{epoch}" ) - logger.info("=" * 10 + "end fit rollout" + "=" * 10) - self._update_interchange_server_status_flag("ENGINE.WEIGHT_SYNCING") - logger.info("begin to convert context_tracker_arr to dataproto") + logger.info("end fit rollout") gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr) - logger.info("end convertion") + logger.info("end dataproto convertion") success_rate = [ traj.reward_structure.success_rate for traj in context_tracker_arr @@ -630,17 +633,17 @@ def fit(self): # noqa: C901 logger.info( f"gen_batch_output.info batch.keys={gen_batch_output.batch.keys()}" ) + self._update_interchange_server_status_flag("ENGINE.WEIGHT_SYNCING") self.async_rollout_manager.sleep() - logger.info("=== - rollout step end ===") + logger.info("rollout step end") - if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - raise NotImplementedError("REMAX is not supported in GRPO yet.") batch.non_tensor_batch["uid"] = np.array( [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object, ) - batch = union_gen_batch_via_task_id(tasks, batch, gen_batch_output) + discard_original_batch = self.config.ajet.enable_tinkerscript_mode + batch = union_gen_batch_via_task_id(tasks, batch, gen_batch_output, discard_original_batch) batch.batch["response_mask"] = compute_response_mask(batch) if "response_mask" not in batch.batch.keys(): @@ -674,7 +677,7 @@ def fit(self): # noqa: C901 ) # recompute old_log_probs - logger.info("=== + compute log_probs begin ===") + logger.info("+ compute log_probs begin") with marked_timer("old_log_prob", timing_raw, color="blue"): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) entropys = old_log_prob.batch["entropys"] @@ -946,7 +949,8 @@ def _validate(self): dtype=object, ) tasks = tasks[: len(main_val_dataset)] - test_batch = union_gen_batch_via_task_id(tasks, test_batch, test_output_gen_batch) + discard_original_batch = self.config.ajet.enable_tinkerscript_mode + test_batch = union_gen_batch_via_task_id(tasks, test_batch, test_output_gen_batch, discard_original_batch) # test_batch = test_batch.union(test_output_gen_batch) test_batch.meta_info["validate"] = True diff --git a/ajet/launcher.py b/ajet/launcher.py index caf739ff..93c68ab4 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -176,9 +176,9 @@ def setup_environment_vars(args, exp_config, main_yaml_fp): env["RAY_record_task_actor_creation_sites"] = "true" # assert exp_config["ajet"]["rollout"]["max_env_worker"] <= 4, "parallel worker too many for debugging mode" # type: ignore if exp_config["ajet"]["rollout"]["max_env_worker"] > 1: # type: ignore - exp_config["ajet"]["rollout"]["max_env_worker"] = 1 + # exp_config["ajet"]["rollout"]["max_env_worker"] = 1 logger.warning( - "For debugging mode, max_env_worker is set to 1 to facilitate debugging." + "For debugging mode, please set max_env_worker to 1 to facilitate debugging." ) logger.warning("Debug mode is ON") else: @@ -206,7 +206,7 @@ def start_tinkerscript_server(env, config): assert config.ajet.enable_experimental_interchange_server, \ "Please enable_experimental_interchange_server in config to start tinkerscript server." from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server - start_interchange_server(config, blocking=True) + start_interchange_server(config, blocking=True, env=env) def main(): diff --git a/ajet/schema/task.py b/ajet/schema/task.py index a20a4b59..1553e20d 100644 --- a/ajet/schema/task.py +++ b/ajet/schema/task.py @@ -8,11 +8,11 @@ class Task(BaseModel): - main_query: str = Field(default="") - init_messages: List[dict] = Field(default=[]) - task_id: str = Field(default="") - env_type: str = Field(default="") - metadata: dict = Field(default_factory=dict) + main_query: str = Field(default="", description="main query or instruction for the task, maybe absent if the task has valid init_messages.") + init_messages: List[dict] = Field(default=[], description="initial messages for the task, maybe absent if the task has valid main_query.") + task_id: str = Field(default="", description="same task_id mean same task, and of course, same GRPO group.") + env_type: str = Field(default="", description="valid when the task need to interact with a gym env.") + metadata: dict = Field(default_factory=dict, description="additional metadata for the task, e.g., reference answer for eval tasks.") """ diff --git a/ajet/task_runner/tinkerscript_runner.py b/ajet/task_runner/tinkerscript_runner.py index d0351767..975c2b24 100644 --- a/ajet/task_runner/tinkerscript_runner.py +++ b/ajet/task_runner/tinkerscript_runner.py @@ -40,7 +40,7 @@ def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: s zmq_socket = zmq.Context().socket(zmq.REP) zmq_socket.bind(zmq_listen_result_addr) speicial_messages = [ - "RUNNER.RESET_CONTEXT_TRACKER" + "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER" ] while True: # : @@ -103,6 +103,12 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: context_tracker=context_tracker, ) + # the most important thing is to fix task_id to client task_id, set task_id to workflow_task and context_tracker task_id + assert "task_id" in workflow_output.metadata, "workflow_output.metadata must contain task_id" + task_id = workflow_output.metadata.get("task_id", "") + workflow_task.task_id = task_id + context_tracker.task_id = task_id + if workflow_output.reward is not None: raw_reward, is_success = ( workflow_output.reward, diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py index 072632e2..22f5916a 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py @@ -271,7 +271,7 @@ async def serve_with_monitor(additional_coro): # Convenience function for quick server startup -def start_interchange_server(config, blocking=False) -> int: +def start_interchange_server(config, blocking=False, env={}) -> int: # Read config already_started = config.ajet.interchange_server.already_started experiment_dir = config.ajet.experiment_dir @@ -293,6 +293,9 @@ def start_interchange_server(config, blocking=False) -> int: # init interchage server sub-process if not already_started: + # apply env vars + os.environ.update(env) + # start interchange server interchange_server = InterchangeServer( experiment_dir, port, @@ -342,6 +345,14 @@ def start_interchange_server(config, blocking=False) -> int: f"URL 1: {localhost_url}\n------\n" f"URL 2: {host_url}\n------\n" f"Press Ctrl+C to stop.") - if interchange_server: - interchange_server.join() + try: + if interchange_server: + interchange_server.join() + except KeyboardInterrupt: + logger.info("Shutting down interchange server...") + try: httpx.get(f"http://127.0.0.1:{port}/stop_engine", timeout=8).status_code + except Exception: pass + + if interchange_server: + interchange_server.terminate() return -1 \ No newline at end of file diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py index b9a7b515..7b8401b1 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py @@ -4,7 +4,7 @@ import yaml from typing import List, Tuple from loguru import logger -from ajet.schema.task import WorkflowOutput +from ajet.schema.task import WorkflowOutput, Task from ajet.copilot.job import AgentJetJob from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import ( @@ -76,16 +76,19 @@ def begin_episode(self, allow_discard_timeout=60, episode_type="train") -> Tuple logger.error(f"Error claiming episode: {e}. Retrying in 5s...") time.sleep(5) - def end_episode(self, episode_uuid: str, workflow_output: WorkflowOutput): + def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOutput): if not episode_uuid: logger.error("No episode to end.") return try: + task_id = task.task_id + workflow_output.metadata["task_id"] = task_id req_obj = EndEpisodeRequest( client_uuid=self.client_uuid, episode_uuid=episode_uuid, - workflow_output=workflow_output + workflow_output=workflow_output, + task_id=task_id ) resp = httpx.post( diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py index 29a725c5..2c70ed3c 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py @@ -148,10 +148,11 @@ def _register_final_episode_output(episode_uuid, workflow_output, shared_mem_dic # -------------------------------------------------------------------------------------- async def register_episode_ready_listener(): - while True: - read_all_episode_status() - await asyncio.sleep(10) # check every 10 seconds - find_claimed_episodes_that_need_to_be_unclaimed() + pass + # while True: + # read_all_episode_status() + # await asyncio.sleep(10) # check every 10 seconds + # find_claimed_episodes_that_need_to_be_unclaimed() def read_all_episode_status() -> Optional[EpisodeStatus]: print_buffer = [] @@ -242,17 +243,26 @@ async def start_engine(): # Create args namespace args = SimpleNamespace( - conf=main_yaml_fp, backbone=backbone, exp_dir=exp_dir_final, with_logview=False, debug=False, + conf=main_yaml_fp, backbone=backbone, exp_dir=exp_dir_final, with_logview=False, + debug=False, ) + # get debug param + should_debug = os.environ.get("RAY_DEBUG_POST_MORTEM", "0") == "1" + debug_tags = os.environ.get("DEBUG_TAGS", "") + if should_debug: + args.debug = debug_tags + + def override_param_callback(config): + config['ajet']['interchange_server']['already_started'] = True + config['ajet']['interchange_server']['interchange_server_port'] = int(os.getenv("AJET_DAT_INTERCHANGE_PORT")) # type: ignore + return config # Finalize experiment config main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config( - main_yaml_fp, exp_dir_final, backbone + main_yaml_fp, exp_dir_final, backbone, override_param_callback ) # Setup environment variables - exp_config['ajet']['interchange_server']['already_started'] = True - exp_config['ajet']['interchange_server']['interchange_server_port'] = int(os.getenv("AJET_DAT_INTERCHANGE_PORT")) # type: ignore env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp) # Start ray if not already started @@ -421,6 +431,10 @@ async def end_episode(req: EndEpisodeRequest): client_uuid = req.client_uuid episode_uuid = req.episode_uuid workflow_output = req.workflow_output + task_id = req.task_id + + assert "task_id" in workflow_output.metadata, "workflow_output.metadata must contain task_id" + assert workflow_output.metadata["task_id"] == task_id, "workflow_output.metadata.task_id must match req.task_id" if 'episodes' not in shared_mem_dict: logger.error(f"[server] No episodes registered yet.") diff --git a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py index 04caa015..12d549aa 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py +++ b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py @@ -35,6 +35,7 @@ class EndEpisodeRequest(BaseModel): client_uuid: str episode_uuid: str workflow_output: WorkflowOutput + task_id: str class EndEpisodeResponse(BaseModel): success: bool diff --git a/ajet/utils/config_utils.py b/ajet/utils/config_utils.py index 0d98421f..273c2cda 100644 --- a/ajet/utils/config_utils.py +++ b/ajet/utils/config_utils.py @@ -168,7 +168,7 @@ def config_safe_guard(config: dict, backbone: str) -> dict: def read_ajet_hierarchical_config( - yaml_fp, exp_name, backbone, write_to=None, exp_dir="saved_experiments" + yaml_fp, exp_name, backbone, write_to=None, exp_dir="saved_experiments", override_param_callback=None ): if yaml_fp is None: config = { @@ -210,6 +210,9 @@ def read_ajet_hierarchical_config( config["defaults"].remove("trinity_default") config["hydra"]["searchpath"].remove("file://ajet/default_config/trinity") + if override_param_callback is not None: + config = override_param_callback(config) + if write_to: with open(write_to, "w") as file: yaml.dump(config, file) @@ -239,7 +242,7 @@ def expand_ajet_hierarchical_config(config, write_to=None): return config_final -def prepare_experiment_config(yaml_path, exp_dir, backbone): +def prepare_experiment_config(yaml_path, exp_dir, backbone, override_param_callback=None): """ Prepare experiment configuration by reading YAML, setting up backup directories, and copying necessary files for the experiment. @@ -317,7 +320,7 @@ def prepare_experiment_config(yaml_path, exp_dir, backbone): ## 4. edit new yaml config = read_ajet_hierarchical_config( - yaml_backup_dst, exp_name, backbone, write_to=yaml_backup_dst, exp_dir=exp_dir + yaml_backup_dst, exp_name, backbone, write_to=yaml_backup_dst, exp_dir=exp_dir, override_param_callback=override_param_callback ) config_final = expand_ajet_hierarchical_config(config, write_to=yaml_backup_dst) diff --git a/ajet_tinkerscript_threading.py b/ajet_tinkerscript_threading.py index 00e53b44..b9907449 100644 --- a/ajet_tinkerscript_threading.py +++ b/ajet_tinkerscript_threading.py @@ -10,6 +10,7 @@ from ajet import WorkflowOutput from ajet.task_reader import RouterTaskReader from ajet.utils.retry import retry_with_backoff +from ajet.schema.task import Task from concurrent.futures import ThreadPoolExecutor # --------- configurations that take effect locally ------------- @@ -44,6 +45,7 @@ def main(): # Hand shake with remote tinkerscript server tinkerscript_remote = TinkerScriptClient(REMOTE_TINKERJET_URL) + tinkerscript_remote.stop_engine() tinkerscript_remote.auto_sync_train_config_and_start_engine( AgentJetJob( algorithm="grpo", @@ -52,8 +54,6 @@ def main(): grpo_n=LOCAL_GRPO_N, ) ) - # tinkerscript_remote.stop_engine() - # tinkerscript_remote = connect_to_tinkerscript_server(sync_train_config=False, start_engine=False) submit_sem = threading.BoundedSemaphore(LOCAL_MAX_PARALLEL) @@ -67,7 +67,7 @@ def rollout(task): # execute agent workflow_output = execute_agent(task, api_baseurl_key) # report output back to tinkerscript remote - tinkerscript_remote.end_episode(episode_uuid, workflow_output) + tinkerscript_remote.end_episode(task, episode_uuid, workflow_output) # collect reward group_reward.append(workflow_output.reward) print(f"Group reward mean & std: {sum(group_reward)/len(group_reward)} +/- { (max(group_reward)-min(group_reward))/2 }") @@ -94,7 +94,7 @@ def rollout(task): @retry_with_backoff(max_retry=2) -def execute_agent(task, api_baseurl_key: OpenaiBaseUrlAndApiKey): +def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): # Prepare base_url, api_key base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key) # Read dataset item From f2f3b16efab73c375bdc58e982bda2647a6e8c32 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Fri, 30 Jan 2026 15:47:52 +0800 Subject: [PATCH 15/25] stage dataset io improvement --- ajet/backbone/main_verl.py | 28 +++--- ajet/backbone/trainer_trinity.py | 6 +- ajet/copilot/job.py | 13 +-- ajet/launcher.py | 57 +----------- ajet/task_reader/__init__.py | 51 ++++++----- ajet/task_reader/hf_dataset_reader.py | 49 +++++++---- ajet/utils/core_env_vars.py | 10 ++- ajet/utils/launch_utils.py | 57 ++++++++++++ ajet_tinkerscript_threading.py | 2 +- .../benchmark_learn2ask.py | 2 +- .../demo_tinkerjet/demo_tinkerjet_math.py | 87 ------------------- 11 files changed, 152 insertions(+), 210 deletions(-) diff --git a/ajet/backbone/main_verl.py b/ajet/backbone/main_verl.py index 47a48cfc..4ec0cf86 100644 --- a/ajet/backbone/main_verl.py +++ b/ajet/backbone/main_verl.py @@ -22,10 +22,15 @@ import hydra import ray from beast_logger import print_dict -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf from verl.trainer.ppo.reward import load_reward_manager from verl.utils.device import is_cuda_available +from verl.utils.dataset.rl_dataset import collate_fn +from torch.utils.data import Dataset as TorchDataset +# Create training and validation datasets. +from ajet.task_reader import RouterTaskReader, task_to_standard_dataset +from ajet.utils.process_dataset import create_rl_sampler from ajet.utils.core_env_vars import get_runtime_env from ajet.utils.launch_utils import set_loguru_default_color @@ -33,17 +38,17 @@ @hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) -def main(config): +def main(config: DictConfig) -> None: """Main entry point for PPO training with Hydra configuration management. Args: - config_dict: Hydra configuration dictionary containing training parameters. + config: Hydra configuration dictionary containing training parameters. """ run_ppo(config) # Define a function to run the PPO-like training process -def run_ppo(config) -> None: +def run_ppo(config: DictConfig) -> None: """Initialize Ray cluster and run distributed PPO training process. Args: @@ -55,7 +60,6 @@ def run_ppo(config) -> None: if not ray.is_initialized(): # this is for local ray cluster runtime_env = get_runtime_env(config) - print_dict(runtime_env["env_vars"], "runtime_env") ray.init( runtime_env=runtime_env, num_cpus=config.ray_init.num_cpus, @@ -227,21 +231,13 @@ def run(self, config): resource_pool_spec=resource_pool_spec, mapping=mapping ) - from verl.utils.dataset.rl_dataset import collate_fn - - # Create training and validation datasets. - from ajet.task_reader import ( - RouterTaskReader, - task_to_standard_dataset, - ) - from ajet.utils.process_dataset import create_rl_sampler - task_reader = RouterTaskReader( config.ajet.task_reader.type, config.ajet.task_reader, ) - val_dataset = task_to_standard_dataset(task_reader.get_validation_tasks()) - train_dataset = task_to_standard_dataset(task_reader.get_training_tasks()) + + train_dataset: TorchDataset = task_to_standard_dataset(task_reader.generate_training_tasks) # type: ignore + val_dataset: TorchDataset = task_to_standard_dataset(task_reader.generate_validation_tasks) # type: ignore train_sampler = create_rl_sampler(config.data, train_dataset) from ajet.backbone.trainer_verl import AjetRayPPOTrainer diff --git a/ajet/backbone/trainer_trinity.py b/ajet/backbone/trainer_trinity.py index 8000a636..7ab60715 100644 --- a/ajet/backbone/trainer_trinity.py +++ b/ajet/backbone/trainer_trinity.py @@ -206,11 +206,9 @@ def __init__(self, config): dataset_segments = [] if "train" in self.split: - dataset_segments.append(task_to_standard_dataset(task_reader.get_training_tasks())) + dataset_segments.append(task_to_standard_dataset(task_reader.generate_training_tasks)) # type: ignore if "val" in self.split: - dataset_segments.append( - task_to_standard_dataset(task_reader.get_validation_tasks()) - ) + dataset_segments.append(task_to_standard_dataset(task_reader.generate_validation_tasks)) # type: ignore if not dataset_segments: raise ValueError( f"Unsupported split '{self.split}'. Expected to contain 'train' or 'val'." diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py index 190d1345..21f96d4a 100644 --- a/ajet/copilot/job.py +++ b/ajet/copilot/job.py @@ -17,11 +17,7 @@ import yaml from loguru import logger -from ajet.launcher import ( - check_avail_gpu, - get_backbone_target, - setup_environment_vars, -) + from ajet.default_config.ajet_default import Config from ajet.utils.config_utils import ( expand_ajet_hierarchical_config, @@ -29,7 +25,12 @@ read_ajet_hierarchical_config, ) from ajet.utils.dynamic_import import cls_to_path -from ajet.utils.launch_utils import execute_training_process +from ajet.utils.launch_utils import ( + execute_training_process, + check_avail_gpu, + get_backbone_target, + setup_environment_vars, +) class AgentJetJob: diff --git a/ajet/launcher.py b/ajet/launcher.py index 93c68ab4..3bd484a5 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -16,6 +16,8 @@ check_debugpy_version, check_avail_gpu, dict_to_namespace, + get_backbone_target, + setup_environment_vars, ) from ajet.utils.pty import pty_launch @@ -137,61 +139,6 @@ def parse_args(): return parser.parse_args() -def get_backbone_target(backbone): - """ - Determine the appropriate backbone target module based on the backbone name. - - Args: - backbone (str): The backbone name (e.g., "verl", "debug", "trinity") - - Returns: - str: The full module path for the specified backbone - """ - backbone_target = "ajet.backbone.main_verl" # Default to trinity - if backbone == "verl": - backbone_target = "ajet.backbone.main_verl" - if backbone == "debug": - backbone_target = "ajet.backbone.main_vllm" - if backbone == "trinity": - backbone_target = "ajet.backbone.main_trinity" - return backbone_target - - -def setup_environment_vars(args, exp_config, main_yaml_fp): - """ - Configure environment variables based on command line arguments. - - Args: - args: Command line arguments - exp_config: Experiment configuration dictionary - main_yaml_fp: Path to main YAML configuration file - - Returns: - dict: Configured environment variables dictionary - """ - env = os.environ.copy() - if args.debug: - env["RAY_DEBUG_POST_MORTEM"] = "1" - env["DEBUG_TAGS"] = args.debug - env["RAY_record_task_actor_creation_sites"] = "true" - # assert exp_config["ajet"]["rollout"]["max_env_worker"] <= 4, "parallel worker too many for debugging mode" # type: ignore - if exp_config["ajet"]["rollout"]["max_env_worker"] > 1: # type: ignore - # exp_config["ajet"]["rollout"]["max_env_worker"] = 1 - logger.warning( - "For debugging mode, please set max_env_worker to 1 to facilitate debugging." - ) - logger.warning("Debug mode is ON") - else: - logger.warning("Debug mode is OFF") - # if args.conf: - # assert exp_config["ajet"]["rollout"]["max_env_worker"] > 4, "parallel worker too few" # type: ignore - if args.backbone == "trinity": - env["AJET_CONFIG_REDIRECT"] = main_yaml_fp # type: ignore - if args.backbone == "debug": - env["AJET_DEBUG"] = "1" # type: ignore - return env, exp_config - - def check_model_file_exists(exp_config): model_path = exp_config["ajet"]["model"]["path"] # if model_path has more than 2 '/', we consider it as a dir path diff --git a/ajet/task_reader/__init__.py b/ajet/task_reader/__init__.py index b431456f..e1a35b5d 100644 --- a/ajet/task_reader/__init__.py +++ b/ajet/task_reader/__init__.py @@ -1,8 +1,8 @@ -from typing import List - import datasets import numpy as np +from typing import List, List, Union +from datasets import Dataset from ajet.schema.task import Task from ajet.task_reader.data_generator_reader import DataGeneratorTaskReader from ajet.task_reader.env_service_reader import EnvServiceTaskReader @@ -10,6 +10,7 @@ from ajet.task_reader.jsonl_reader import JsonlTaskReader from ajet.task_reader.task_reader_base import BaseTaskReader from ajet.task_reader.tracing_reader import TracingReader +from typing import Generator class RandomDummyTaskReader(BaseTaskReader): @@ -44,6 +45,10 @@ def get_validation_tasks(self) -> List[Task]: return self._load_dataset_split("dataset_name", "split") +def list_to_generator(tasks: List[Task]) -> Generator: + for task in tasks: + yield task + class RouterTaskReader(BaseTaskReader): def __init__(self, reader_type, reader_config): super().__init__(None) @@ -78,33 +83,39 @@ def get_validation_tasks(self) -> List[Task]: np.random.shuffle(result) # type: ignore return result + def generate_training_tasks(self) -> Generator: + if hasattr(self.task_reader, "generate_training_tasks"): + result = self.task_reader.generate_training_tasks() # type: ignore + else: + result = list_to_generator(self.task_reader.get_training_tasks()) + return result -def task_to_standard_dataset(tasks: List[Task]) -> datasets.Dataset: + def generate_validation_tasks(self) -> Generator: + if hasattr(self.task_reader, "generate_validation_tasks"): + result = self.task_reader.generate_validation_tasks() # type: ignore + else: + result = list_to_generator(self.task_reader.get_validation_tasks()) + return result + + + +def task_to_standard_dataset(gen_tasks) -> Dataset: """ - Convert a list of Task objects to a standard Hugging Face Dataset. + Convert a potentially large/infinite generator of Task objects + to a streaming Hugging Face Dataset. Args: - tasks (List[Task]): List of Task objects. + tasks: A generator or iterable producing Task objects. Returns: - datasets.Dataset: Hugging Face Dataset containing the tasks. + datasets.Dataset: A Hugging Face Dataset with streaming enabled. """ - data = { - "task_id": [], - "main_query": [], - "init_messages": [], - "env_type": [], - "metadata": [], - } + def gen(): + for task in gen_tasks(): + yield task.model_dump() - for task in tasks: - data["task_id"].append(task.task_id) - data["main_query"].append(task.main_query) - data["init_messages"].append(task.init_messages) - data["env_type"].append(task.env_type) - data["metadata"].append(task.metadata) + return Dataset.from_generator(gen) # type: ignore - return datasets.Dataset.from_dict(data) def dict_to_ajet_task(task_dict: dict) -> Task: diff --git a/ajet/task_reader/hf_dataset_reader.py b/ajet/task_reader/hf_dataset_reader.py index 381e48e2..0d0e5164 100644 --- a/ajet/task_reader/hf_dataset_reader.py +++ b/ajet/task_reader/hf_dataset_reader.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Generator import datasets @@ -17,29 +17,36 @@ class HuggingFaceTaskReader(BaseTaskReader): def __init__(self, reader_config): super().__init__(reader_config) self.reader_config = reader_config + self.as_generator = False + self.dataset_name = self.reader_config.huggingface_dat_repo.dataset_path - def _load_dataset_split(self, dataset_name: str, split: str) -> List[Task]: + def _load_dataset_split(self, split: str): """ Load a dataset split from Hugging Face datasets. Args: - dataset_name: Name of the dataset in Hugging Face format (e.g., 'gsm8k') split: Name of the split to load (e.g., 'train', 'validation') Returns: - List[Task]: List of Task objects created from the dataset. + Generator: List of Task objects created from the dataset. """ try: - dataset = datasets.load_dataset(dataset_name, split=split) + if self.dataset_name.endswith(".parquet"): + # Load from local parquet file + dataset = datasets.load_dataset("parquet", data_files=self.dataset_name, split=split) + else: + # Load from Hugging Face hub + dataset = datasets.load_dataset(self.dataset_name, split=split) except Exception as e: raise ValueError( - f"Failed to load dataset '{dataset_name}' with split '{split}': {str(e)}" + f"Failed to load dataset '{self.dataset_name}' with split '{split}': {str(e)}" ) - # if len(dataset) == 0: - # raise ValueError(f"No examples found in dataset '{dataset_name}' with split '{split}'") + if len(dataset) == 0: + raise ValueError(f"No examples found in dataset '{self.dataset_name}' with split '{split}'") + + self.as_generator = True - tasks = [] for idx, example in enumerate(dataset): # Create Task object task = Task( @@ -49,28 +56,32 @@ def _load_dataset_split(self, dataset_name: str, split: str) -> List[Task]: env_type="no_env", metadata=example, ) - tasks.append(task) + yield task - return tasks + return - def get_training_tasks(self) -> List[Task]: + def generate_training_tasks(self): """ Get training tasks from the Hugging Face dataset specified in the config. Returns: - List[Task]: List of training Task objects. + A generator of training Task objects. """ - dataset_name = self.reader_config.huggingface_dat_repo.dataset_path split = self.reader_config.huggingface_dat_repo.training_split - return self._load_dataset_split(dataset_name, split) + return self._load_dataset_split(split) - def get_validation_tasks(self) -> List[Task]: + def generate_validation_tasks(self): """ Get validation tasks from the Hugging Face dataset specified in the config. Returns: - List[Task]: List of validation Task objects. + A generator of validation Task objects. """ - dataset_name = self.reader_config.huggingface_dat_repo.dataset_path split = self.reader_config.huggingface_dat_repo.validation_split - return self._load_dataset_split(dataset_name, split) + return self._load_dataset_split(split) + + def get_training_tasks(self): + return list(self.generate_training_tasks()) + + def get_validation_tasks(self): + return list(self.generate_validation_tasks()) diff --git a/ajet/utils/core_env_vars.py b/ajet/utils/core_env_vars.py index 7078bf2f..e48e1dda 100644 --- a/ajet/utils/core_env_vars.py +++ b/ajet/utils/core_env_vars.py @@ -1,4 +1,5 @@ import os +import copy from pathlib import Path from beast_logger import print_dict @@ -61,5 +62,12 @@ def get_runtime_env(config, is_trinity: bool = False) -> dict: if is_trinity: assert "AJET_CONFIG_REDIRECT" in runtime_env["env_vars"] - print_dict(runtime_env["env_vars"], "runtime_env") + print_env_dict = copy.deepcopy(runtime_env["env_vars"]) + # limit value length for printing + for k, v in print_env_dict.items(): + _len_limit = 500 + _len_limit_half = _len_limit // 2 + if len(v) > _len_limit: + print_env_dict[k] = v[:_len_limit_half] + "..." + v[-_len_limit_half:] + print_dict(print_env_dict, "runtime_env") return runtime_env diff --git a/ajet/utils/launch_utils.py b/ajet/utils/launch_utils.py index 0fa64e6f..441fbc93 100644 --- a/ajet/utils/launch_utils.py +++ b/ajet/utils/launch_utils.py @@ -13,6 +13,63 @@ from ajet.utils.smart_daemon import LaunchCommandWhenAbsent + +def get_backbone_target(backbone): + """ + Determine the appropriate backbone target module based on the backbone name. + + Args: + backbone (str): The backbone name (e.g., "verl", "debug", "trinity") + + Returns: + str: The full module path for the specified backbone + """ + backbone_target = "ajet.backbone.main_verl" # Default to trinity + if backbone == "verl": + backbone_target = "ajet.backbone.main_verl" + if backbone == "debug": + backbone_target = "ajet.backbone.main_vllm" + if backbone == "trinity": + backbone_target = "ajet.backbone.main_trinity" + return backbone_target + + +def setup_environment_vars(args, exp_config, main_yaml_fp): + """ + Configure environment variables based on command line arguments. + + Args: + args: Command line arguments + exp_config: Experiment configuration dictionary + main_yaml_fp: Path to main YAML configuration file + + Returns: + dict: Configured environment variables dictionary + """ + env = os.environ.copy() + if args.debug: + env["RAY_DEBUG_POST_MORTEM"] = "1" + env["DEBUG_TAGS"] = args.debug + env["RAY_record_task_actor_creation_sites"] = "true" + # assert exp_config["ajet"]["rollout"]["max_env_worker"] <= 4, "parallel worker too many for debugging mode" # type: ignore + if exp_config["ajet"]["rollout"]["max_env_worker"] > 1: # type: ignore + # exp_config["ajet"]["rollout"]["max_env_worker"] = 1 + logger.warning( + "For debugging mode, please set max_env_worker to 1 to facilitate debugging." + ) + logger.warning("Debug mode is ON") + else: + logger.warning("Debug mode is OFF") + # if args.conf: + # assert exp_config["ajet"]["rollout"]["max_env_worker"] > 4, "parallel worker too few" # type: ignore + if args.backbone == "trinity": + env["AJET_CONFIG_REDIRECT"] = main_yaml_fp # type: ignore + if args.backbone == "debug": + env["AJET_DEBUG"] = "1" # type: ignore + return env, exp_config + + + def set_loguru_default_color(): logger.remove() colorize = os.environ.get("LOGURU_COLORIZE", "YES").upper() not in ["NO", "0", "FALSE"] diff --git a/ajet_tinkerscript_threading.py b/ajet_tinkerscript_threading.py index b9907449..47b1dce9 100644 --- a/ajet_tinkerscript_threading.py +++ b/ajet_tinkerscript_threading.py @@ -79,7 +79,7 @@ def rollout(task): # Main Training loop with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as executor: for epoch in range(LOCAL_NUM_EPOCH): - for task in dataset.get_training_tasks()[:100]: + for task in dataset.get_training_tasks(): print(f"Submitting task for epoch {epoch}") submit_sem.acquire() executor.submit(rollout, task) diff --git a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.py b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.py index 762609ce..697f0bad 100644 --- a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.py +++ b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.py @@ -27,7 +27,7 @@ def __init__(self): # step : [low, high ] 50 : [2.3, 99999.0], 100 : [2.5, 99999.0], - 200 : [2.9, 99999.0], + 200 : [2.6, 99999.0], } # fmt: on self.probe_list = ["reward_probe"] diff --git a/tutorial/demo_tinkerjet/demo_tinkerjet_math.py b/tutorial/demo_tinkerjet/demo_tinkerjet_math.py index a084cf51..e69de29b 100644 --- a/tutorial/demo_tinkerjet/demo_tinkerjet_math.py +++ b/tutorial/demo_tinkerjet/demo_tinkerjet_math.py @@ -1,87 +0,0 @@ -import re -import requests -from textwrap import dedent -from ajet import AgentJetJob -from ajet.copilot.tinkerjet.remote import TinkerScriptRemote -from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo -from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import AgentJetAsOpenAI -from ajet import WorkflowOutput -from ajet.task_reader import RouterTaskReader -from ajet.utils.retry import retry_with_backoff -TINKERJET_URL = "http://localhost:10086" # Change to your tinkerjet remote url -NUM_EPOCH = 100 -GRPO_N = 4 # grpo group size - -class WeightUpdatedHalfway(Exception): - """Raised when the remote side starts updating model weights halfway through an episode.""" - -def main(): - # Handshake with tinkerjet remote, then send training param to tinkerjet remote (such as model to be trained, algorithm, etc) - tinkerjet_remote = TinkerScriptRemote(TINKERJET_URL) - tinkerjet_remote.sync_train_config( - AgentJetJob(backbone="verl", n_gpu=2, algorithm="grpo", model='qwen/Qwen2.5-1.5B-instruct') - ) - - # Dataset reader (read in your local machine only) - dataset = RouterTaskReader( - reader_type = "huggingface_dat_repo", - reader_config = AjetTaskReader( - huggingface_dat_repo = HuggingfaceDatRepo( dataset_path = "openai/gsm8k" ) - ) - ) - - # Define rollout - def rollout(task): - # Q: Can I run episodes in parallel? - # A: Yes, wrap `rollout` in a thread or process pool. - api_baseurl_key = tinkerjet_remote.begin_episode() - workflow_output = execute_agent(task, api_baseurl_key) - tinkerjet_remote.end_episode(workflow_output) - return workflow_output.reward - - # Main Training loop - for epoch in range(NUM_EPOCH): - for task in dataset.get_training_tasks(): - try: - for i in range(GRPO_N): - reward = rollout(task) - print(f"{epoch}-{task}-run:{i}-{reward}") - except WeightUpdatedHalfway as e: - print(f"The remote side has gone into the LLM model weight update phrase halfway through an episode." - f"This is **normal**." - f"The remote no longer need this task anymore, so let's go to next task.") - - # Get tuned model from tinkerjet remote - tuned_model_checkpoint = tinkerjet_remote.download_tuned_model() - return tuned_model_checkpoint - - -@retry_with_backoff(max_retry=2) -def execute_agent(task, api_baseurl_key: AgentJetAsOpenAI): - # Prepare base_url, api_key - base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key) - # Read dataset item - query, reference_answer = (task.main_query, task.metadata["answer"]) - # Prepare messages - messages = [ - { "role": "system", "content": dedent("""You are an agent specialized in solving math problems. Please solve the math problem given to you. - You can write and execute Python code to perform calculation or verify your answer. You should return your final answer within \\boxed{{}}.""") }, - { "role": "user", "content": query } - ] - # Use raw http requests (non-streaming) to get response - response = requests.post( f"{base_url}/chat/completions", json = { "model": "fill_whatever_model", "messages": messages, }, - headers = { "Authorization": f"Bearer {api_key}" } ) - final_answer = response.json()['choices'][0]['message']['content'] - # Compute reward - reference_answer = reference_answer.split("####")[-1].strip() - pattern = r"\\boxed\{([^}]*)\}" - match = re.search(pattern, final_answer) - if match: is_success = match.group(1) == reference_answer - else: is_success = False - raw_reward = 1.0 if is_success else 0.0 - # Return - return WorkflowOutput(reward=raw_reward, metadata={"final_answer": final_answer}) - - -if __name__ == "__main__": - main() From 920e4d55b88d2210fd0ace00d71a81bafba986e9 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Tue, 3 Feb 2026 01:15:03 +0800 Subject: [PATCH 16/25] stage academic translation agent --- ajet/backbone/trainer_verl.py | 9 + ajet/task_reader/hf_dataset_reader.py | 2 + .../experimental/as_tinkerscript_client.py | 2 +- t-agent.py | 224 ++++++++++++++++++ t_agent_reward.py | 174 ++++++++++++++ 5 files changed, 410 insertions(+), 1 deletion(-) create mode 100644 t-agent.py create mode 100644 t_agent_reward.py diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 74498bbf..557f1948 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -112,6 +112,15 @@ def union_gen_batch_via_task_id(tasks, batch: DataProto, gen_batch_output: DataP return batch_final else: gen_batch_output.non_tensor_batch['uid'] = gen_batch_output.non_tensor_batch["task_ids"] + task_id_counter = {} + for i, tid in enumerate(gen_batch_output.non_tensor_batch["task_ids"]): + if tid in task_id_counter: + task_id_counter[tid] += 1 + else: + task_id_counter[tid] = 1 + current_id = task_id_counter[tid] + gen_batch_output.non_tensor_batch['rollout_ids'][i] = f"T{tid}R{current_id}" + logger.info(f'task_id_counter: {task_id_counter}') return gen_batch_output diff --git a/ajet/task_reader/hf_dataset_reader.py b/ajet/task_reader/hf_dataset_reader.py index 0d0e5164..21c13538 100644 --- a/ajet/task_reader/hf_dataset_reader.py +++ b/ajet/task_reader/hf_dataset_reader.py @@ -37,6 +37,8 @@ def _load_dataset_split(self, split: str): else: # Load from Hugging Face hub dataset = datasets.load_dataset(self.dataset_name, split=split) + # shuffle dataset + dataset = dataset.shuffle(seed=42) except Exception as e: raise ValueError( f"Failed to load dataset '{self.dataset_name}' with split '{split}': {str(e)}" diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py index 7b8401b1..a48088fa 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py @@ -182,7 +182,7 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING"): current_time = time.time() # Report status every 5 seconds - if current_time - last_report_time >= 5: + if current_time - last_report_time >= 10: logger.info(f"Current engine status (already waited {current_time - init_poll_time:.1f}s): {current_status}") last_report_time = current_time diff --git a/t-agent.py b/t-agent.py new file mode 100644 index 00000000..617df2c8 --- /dev/null +++ b/t-agent.py @@ -0,0 +1,224 @@ + +from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.task_reader import RouterTaskReader +import re +import os +import threading +import time +import requests +from loguru import logger +from textwrap import dedent +from ajet.copilot.job import AgentJetJob +from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet import WorkflowOutput +from ajet.task_reader import RouterTaskReader +from ajet.utils.retry import retry_with_backoff +from ajet.schema.task import Task +from concurrent.futures import ThreadPoolExecutor +from beast_logger import print_listofdict +import asyncio + +# Import reward computation from t-agent-reward.py +from t_agent_reward import TranslationQualityGrader, build_translation_quality_messages + + + +LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/agentjet/agentjet/tmp/arxiv_papers/train.parquet" + + +# Handshake with tinkerscript remote, then send training param to tinkerscript remote (such as model to be trained, algorithm, etc) +dataset = RouterTaskReader( + reader_type = "huggingface_dat_repo", + reader_config = AjetTaskReader( + huggingface_dat_repo = HuggingfaceDatRepo( + dataset_path = LOCAL_DATASET_PATH + ) + ) +) + + +def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): + # Prepare base_url, api_key + base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key) + # Read dataset item + title = task.metadata['title'] + authors = task.metadata['authors'] + abstract = task.metadata['abstract'] + # Prepare messages + messages, rough_translate = rough_translate_agent(base_url, api_key, abstract) + + messages, fix_nouns = detect_hard_proper_nouns(base_url, api_key, abstract, rough_translate) + print_listofdict(messages, header="detect_hard_proper_nouns", mod="c") + + messages, final_translation = produce_final_translation(base_url, api_key, abstract, rough_translate, fix_nouns) + print_listofdict(messages, header="final_translation", mod="c") + + # Compute reward + time.sleep(1) + # Use the translation quality grader from t-agent-reward.py + from openjudge.models import OpenAIChatModel + # from openjudge.models.openai_model import OpenAIModel + grader = TranslationQualityGrader( + model=OpenAIChatModel(base_url=base_url, api_key=api_key, model="qwen-max") + ) + grader_score = asyncio.run(grader.aevaluate(original_text=abstract, translation=final_translation)) + raw_reward = grader_score.score / 3.0 # Normalize to 0-1 range (score is 0-3) + # Return + return WorkflowOutput(reward=raw_reward, metadata={ + "rough_translate": rough_translate, + "fix_nouns": fix_nouns, + "final_translation": final_translation + }) + + +def detect_hard_proper_nouns(base_url, api_key, abstract, rough_translate): + messages = [ + { + "role": "system", + "content": "You are responsible for detecting translation errors of discipline-specific proper nouns. " + "Use json to list all errors found in the translation result and provide correction. " + "Json format: [{\"original_word\": \"xxx\", \"wrong_translation\": \"xxx\", \"wrong_reason\": \"xxx\", \"correct_translation\": \"xxx\"}, ...]. " + "If no errors are found, return an empty list []." + }, + { + "role": "user", + "content": abstract + }, + { + "role": "assistant", + "content": rough_translate + }, + { + "role": "user", + "content": "Please list all translation errors of discipline-specific proper nouns found in the translation result according to the requirements." + }, + ] + + # Use raw http requests (non-streaming) to get response + response = requests.post( f"{base_url}/chat/completions", json = { "model": "qwen-max", "messages": messages, }, + headers = { "Authorization": f"Bearer {api_key}" } ) + fix_nouns = response.json()['choices'][0]['message']['content'] + messages += [ + { + "role": "assistant", + "content": fix_nouns + } + ] + return messages, fix_nouns + + +def produce_final_translation(base_url, api_key, abstract, rough_translate, fix_nouns): + """ + Third agent: Apply the corrections from fix_nouns to produce the final polished translation. + """ + messages = [ + { + "role": "system", + "content": "You are a professional academic translator responsible for producing the final, polished Chinese translation. " + "You will receive: 1) the original English abstract, 2) an initial translation, and 3) a list of corrections for proper nouns. " + "Your task is to apply all the corrections to produce a final translation that is accurate, fluent, and meets Chinese academic writing standards. " + "Ensure that all discipline-specific proper nouns are translated correctly according to the provided corrections. " + "Maintain the academic tone and ensure the translation is concise, rigorous, and natural in Chinese." + }, + { + "role": "user", + "content": f"Original English Abstract:\n{abstract}" + }, + { + "role": "user", + "content": f"Initial Translation:\n{rough_translate}" + }, + { + "role": "user", + "content": f"Corrections for Proper Nouns:\n{fix_nouns}" + }, + { + "role": "user", + "content": "Please produce the final, corrected Chinese translation by applying all the corrections listed above. " + "Output only the final translation without any explanations or additional text." + }, + ] + + # Use raw http requests (non-streaming) to get response + response = requests.post( f"{base_url}/chat/completions", json = { "model": "qwen-max", "messages": messages, }, + headers = { "Authorization": f"Bearer {api_key}" } ) + final_translation = response.json()['choices'][0]['message']['content'] + + messages += [ + { + "role": "assistant", + "content": final_translation + } + ] + + return messages, final_translation + + +def rough_translate_agent(base_url, api_key, abstract): + messages = [ + { + "role": "system", + "content": + "You are a professional language translator. " + "Translate the given Academic English text into Chinese accurately. " + "During the translation process, it is necessary to meet the linguistic norms of Chinese academic papers " + "such as conforming to the logic of the Chinese language, being simple, rigorous, and concise, " + "and avoiding the use of first-person pronouns when passive voice is appropriate. " + "Ensure that specialized terms are translated correctly according to academic standards. " + "Replace 我们 with 本研究 or 本文. " + "If an abbreviation is short in Chinese, use Chinese. " + "If an abbreviation is long in Chinese, use abbreviation. " + }, + { + "role": "user", + "content": abstract + } + ] + + examples = [ + { + "original": "We find that the EMBB is dominated by GW bursts from stellar mass black holes", + "hint": "1. 我们->本研究/本文(删除第一人称) 2. GWs->引力波(有简洁的中文表达),但EMBB保留(没有简洁的中文表达) 3. 调换语序,这句话中的重点是“恒星级黑洞发出的引力波”,所以调换语序突出重点。", + "bad": "我们发现,EMBB主要由恒星级黑洞发出的GWs爆发主导", + "good": "本研究发现恒星级黑洞发出的引力波爆发在EMBB中占主导地位", + }, + { + "original": "In a previous paper (Gayon & Bois 2008a), we have shown the general efficiency of retrograde resonances for stabilizing compact planetary systems.", + "bad": "在先前的一篇论文(Gayon & Bois 2008a)中,本文展示了逆向共振在稳定紧凑行星系统中的普遍效率。", + "hint": "修复主语,删除冗余的逗号,替换“效率”为“有效性”更符合学术表达。", + "good": "先前的一篇论文(Gayon & Bois 2008a)阐释了逆向共振在稳定紧凑行星系统中的普遍有效性。", + }, + ] + + # add examples to system prompt + for ex in examples: + messages[0]['content'] += f"\n\nExample:\n\tOriginal: {ex['original']}\n\tHint: {ex['hint']}\n\tBad Translation: {ex['bad']}\n\tGood Translation: {ex['good']}" + + # Use raw http requests (non-streaming) to get response + response = requests.post( f"{base_url}/chat/completions", json = { "model": "qwen-max", "messages": messages, }, + headers = { "Authorization": f"Bearer {api_key}" } ) + rough_translate = response.json()['choices'][0]['message']['content'] + # print(rough_translate) + + messages += [ + { + "role": "assistant", + "content": rough_translate + } + ] + + return messages, rough_translate + + + + +for i, task in enumerate(dataset.generate_training_tasks()): + if i >= 2: + execute_agent( + task, + OpenaiBaseUrlAndApiKey( + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", + api_key=os.environ.get("DASHSCOPE_API_KEY", "") + ) + ) diff --git a/t_agent_reward.py b/t_agent_reward.py new file mode 100644 index 00000000..fb1e83c3 --- /dev/null +++ b/t_agent_reward.py @@ -0,0 +1,174 @@ +from openjudge.graders.base_grader import GraderError, GraderMode, GraderScore +from openjudge.graders.llm_grader import LLMGrader +from openjudge.models.base_chat_model import BaseChatModel +import re +from typing import List + + +def get_translation_quality_system_prompt() -> str: + """Get the translation quality system prompt.""" + return """ +You are an objective translation quality evaluator for academic paper translations from English to Chinese. Your task is to identify ONLY the specific types of errors demonstrated in the provided examples - not general translation quality issues. + +Focus (but do not limit to) on issues below (as shown in the examples): + +1. **First-person pronoun issues** - Using "我们" instead of "本研究" or "本文" in academic contexts +2. **Abbreviation translation errors** - Using abbreviations when concise Chinese exists (e.g., "GWs" instead of "引力波"), or translating abbreviations that should remain in English (like "EMBB") +3. **Word order problems** - Not adjusting sentence structure to emphasize key points in Chinese academic style +4. **Subject-verb inconsistencies** - Mismatched subjects due to improper sentence structure (e.g., "在...中,本文展示..." where the subject is confused) +5. **Inappropriate word choices** - Using colloquial or incorrect terms instead of proper academic expressions (e.g., "效率" vs "有效性" in certain contexts) +6. **Redundant punctuation** - Unnecessary commas or other punctuation that disrupts Chinese reading flow + +**Examples of these errors:** + +Example 1: +- Original: "We find that the EMBB is dominated by GW bursts from stellar mass black holes" +- Bad Translation: "我们发现,EMBB主要由恒星级黑洞发出的GWs爆发主导" +- Issues: 1) "我们" should be "本研究/本文", 2) "GWs" should be "引力波" (has concise Chinese), 3) Word order doesn't emphasize the key point +- Good Translation: "本研究发现恒星级黑洞发出的引力波爆发在EMBB中占主导地位" + +Example 2: +- Original: "In a previous paper (Gayon & Bois 2008a), we have shown the general efficiency of retrograde resonances for stabilizing compact planetary systems." +- Bad Translation: "在先前的一篇论文(Gayon & Bois 2008a)中,本文展示了逆向共振在稳定紧凑行星系统中的普遍效率。" +- Issues: 1) Subject confusion (在...中,本文...), 2) Redundant comma, 3) "效率" should be "有效性" for better academic expression +- Good Translation: "先前的一篇论文(Gayon & Bois 2008a)阐释了逆向共振在稳定紧凑行星系统中的普遍有效性。" + +Example 3: +- Original: "To improve the transferability of ViT, we introduce a novel and effective module, named Domain Transferable-guided Attention Block (DTAB)." +- Bad Translation: "为了提高ViT的迁移能力,本文引入了一个新颖且有效的模块,称为域可迁移引导注意力块(DTAB)" +- Issues: 1) 语言顺序和表达不符合中文习惯。2) 没有在首次出现自定义缩写时,给出英文全称 +- Good Translation: "为提高ViT的迁移能力,本文引入了名为“域可迁移引导注意力块”(Transferable-guided Attention Block,DTAB)的新颖模块。" + +Example 4: +- Original: Extensive experiments were conducted on UCF-HMDB, Kinetics-Gameplay, and Kinetics-NEC Drone datasets, with different backbones, like ResNet101, I3D, and STAM, to verify the effectiveness of TransferAttn compared with state-of-the-art approaches. +- Bad Translation: 在UCF-HMDB、Kinetics-Gameplay和Kinetics-NEC Drone数据集上进行了广泛的实验,使用了不同的骨干网络,如ResNet101、I3D和STAM,以验证TransferAttn与现有最先进方法相比的有效性。 +- Issues: 1) 改变语言顺序后,主语缺失。应当填充主语“本研究”或者“本文”。2) 举例时,表述不够简洁。 +- Good Translation: 本研究在UCF-HMDB、Kinetics-Gameplay和Kinetics-NEC Drone数据集上进行了广泛的实验, 使用了ResNet101、I3D和STAM等骨干网络来验证TransferAttn与现有最先进方法相比的有效性。 + +Rate the translation on a scale of 0-3: + +0 = Severely impairs readability (multiple critical errors from the categories above that make the text difficult to understand) +1 = Does not impair readability, but numerous errors or significantly reduces Chinese reading efficiency (many instances of the error types above) +2 = Does not impair readability, few errors and not severe (minor instances of the error types above) +3 = No errors from the example categories detected (translation is free of the specific error types demonstrated) + +Note: +* For each key issue found, provide the specific error, its type, and where it appears in the translation. +* Be precise about which error category each issue belongs to. +* Focus on objective errors matching the example patterns, not subjective preferences. + +Think carefully before flagging any error. Ask yourself: Does this match one of the specific error types from the examples? Is this truly an objective error or just a stylistic preference? + +Return your response in this format: +X +Your detailed step-by-step reasoning analyzing the translation against the error categories + +- Error Type: [category]. Error: [specific issue]. Location: [where it appears in the translation] + + +The score must be 0, 1, 2, or 3. Each key issue should be on its own line starting with a dash. If no errors are found, the key_issues section should be empty or state "None detected". +""" + + +TRANSLATION_QUALITY_USER_PROMPT = """ +Evaluate the quality of this Chinese translation based on the specific error types demonstrated in the examples. + +Original English text: +{original} + +Chinese translation to evaluate: +{translation} +""" + + +def parse_translation_quality_response(text: str) -> dict: + """Parse XML-formatted translation quality response.""" + score_match = re.search(r"\s*(\d+)\s*", text) + reasoning_match = re.search(r"(.*?)", text, re.DOTALL) + issues_match = re.search(r"(.*?)", text, re.DOTALL) + + score = int(score_match.group(1)) if score_match else 3 + reasoning = reasoning_match.group(1).strip() if reasoning_match else text + + key_issues = [] + if issues_match: + issues_text = issues_match.group(1) + # Filter out empty lines and "None detected" type messages + key_issues = [ + line.strip().lstrip("- ") + for line in issues_text.strip().split("\n") + if line.strip() and not line.strip().lstrip("- ").lower().startswith("none") + ] + + return {"score": score, "reason": reasoning, "key_issues": key_issues} + + +def build_translation_quality_messages(original_text: str, translation: str) -> List[dict]: + """Build messages for translation quality evaluation.""" + return [ + {"role": "system", "content": get_translation_quality_system_prompt()}, + { + "role": "user", + "content": TRANSLATION_QUALITY_USER_PROMPT.format( + original=original_text, + translation=translation + ), + }, + ] + + +class TranslationQualityGrader(LLMGrader): + """Grader for evaluating translation quality based on specific error patterns. + + Score range: 0-3 + 0 = Severely impairs readability (multiple critical errors) + 1 = Does not impair readability, but numerous errors or reduces efficiency + 2 = Does not impair readability, few errors and not severe + 3 = No errors from the example categories detected + """ + + def __init__(self, model: BaseChatModel | dict): + super().__init__( + name="translation_quality", + mode=GraderMode.POINTWISE, + description="Evaluate translation quality based on specific error patterns", + model=model, + template="", # Placeholder, not used + ) + + async def aevaluate(self, original_text: str, translation: str) -> GraderScore: + """Evaluate translation quality. + + Args: + original_text: Original English text + translation: Chinese translation to evaluate + + Returns: + GraderScore with score 0-3 and identified issues + """ + try: + messages = build_translation_quality_messages(original_text, translation) + response = await self.model.achat(messages=messages) + content = await extract_response_content(response) + parsed = parse_translation_quality_response(content) + + return GraderScore( + name=self.name, + score=parsed["score"], + reason=parsed["reason"], + metadata={"key_issues": parsed["key_issues"]}, + ) + except Exception as e: + return GraderError(name=self.name, error=str(e)) + + +async def extract_response_content(response) -> str: + """Extract content from model response.""" + if hasattr(response, 'content'): + return response.content + elif isinstance(response, dict) and 'content' in response: + return response['content'] + elif isinstance(response, str): + return response + else: + raise ValueError(f"Unable to extract content from response: {type(response)}") From 8777bcb95a8dd6112056c95a65da4a623cd1defa Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Tue, 3 Feb 2026 17:50:27 +0800 Subject: [PATCH 17/25] stage swarm server --- ajet/backbone/trainer_verl.py | 8 +- ajet/context_tracker/base_tracker.py | 4 +- ajet/task_reader/hf_dataset_reader.py | 4 +- ajet/task_rollout/native_parallel_worker.py | 147 +++++++++++- ajet/task_rollout/single_worker.py | 3 + ajet/task_runner/tinkerscript_runner.py | 70 ++++-- ajet/tuner.py | 1 + .../experimental/as_oai_model_client.py | 6 +- .../experimental/as_oai_model_server.py | 3 + .../experimental/as_tinkerscript_client.py | 46 +++- .../experimental/as_tinkerscript_server.py | 183 ++++++++------ .../experimental/interchange_utils.py | 27 +-- ajet/utils/retry.py | 48 ++-- ajet_tinkerscript_threading.py | 2 +- t-agent.py | 224 ------------------ t_agent_reward.py | 174 -------------- tutorial/example_academic_trans/trans.py | 162 +++++++++++++ .../example_academic_trans/trans_reward.py | 180 ++++++++++++++ tutorial/example_academic_trans/trans_roll.py | 107 +++++++++ 19 files changed, 859 insertions(+), 540 deletions(-) delete mode 100644 t-agent.py delete mode 100644 t_agent_reward.py create mode 100644 tutorial/example_academic_trans/trans.py create mode 100644 tutorial/example_academic_trans/trans_reward.py create mode 100644 tutorial/example_academic_trans/trans_roll.py diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 557f1948..70e5570b 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -493,7 +493,7 @@ def fit(self): # noqa: C901 # perform validation before training # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + if (self.val_reward_fn is not None) and (self.config.trainer.get("val_before_train", True)) and (not self.config.ajet.enable_tinkerscript_mode): val_metrics = self._validate() assert val_metrics, f"{val_metrics=}" pprint(f"Initial validation metrics: {val_metrics}") @@ -784,6 +784,7 @@ def fit(self): # noqa: C901 self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + and (not self.config.ajet.enable_tinkerscript_mode) ): with marked_timer("testing", timing_raw, color="green"): val_metrics: dict = self._validate() @@ -934,17 +935,16 @@ def _validate(self): self.async_rollout_manager.wake_up() main_val_dataset = self.get_eval_dataset() - logger.info("=" * 10 + "start validate rollout" + "=" * 10) + logger.info("Starting validate rollout") context_tracker_arr, tasks, val_metrics = self.eval_dataset( target_dataset=main_val_dataset, target_dataset_name="main_val_dataset", mode="validate", epoch="test.1", ) - logger.info("=" * 10 + "end validate rollout" + "=" * 10) + logger.info("Completed validate rollout") test_output_gen_batch = self.parallel_env.to_dataproto(context_tracker_arr) self.async_rollout_manager.sleep() - logger.info("validation generation end") # Store generated outputs output_ids = test_output_gen_batch.batch["responses"] diff --git a/ajet/context_tracker/base_tracker.py b/ajet/context_tracker/base_tracker.py index 9e04c4bb..3baa7095 100644 --- a/ajet/context_tracker/base_tracker.py +++ b/ajet/context_tracker/base_tracker.py @@ -115,8 +115,8 @@ def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs): self.workflow_task = workflow_task self.task_batch_index = self.workflow_task.task_batch_index - self.task_tag = self.workflow_task.task_tag - self.task_id = self.workflow_task.task_id + self.task_tag: str = self.workflow_task.task_tag + self.task_id: str = self.workflow_task.task_id self.episode_uuid = self.workflow_task.episode_uuid self.config = config diff --git a/ajet/task_reader/hf_dataset_reader.py b/ajet/task_reader/hf_dataset_reader.py index 21c13538..33a269e2 100644 --- a/ajet/task_reader/hf_dataset_reader.py +++ b/ajet/task_reader/hf_dataset_reader.py @@ -1,8 +1,8 @@ -from typing import List, Generator import datasets from ajet.schema.task import Task +from typing import List, Generator from ajet.task_reader.task_reader_base import BaseTaskReader @@ -38,7 +38,7 @@ def _load_dataset_split(self, split: str): # Load from Hugging Face hub dataset = datasets.load_dataset(self.dataset_name, split=split) # shuffle dataset - dataset = dataset.shuffle(seed=42) + dataset = dataset.shuffle() except Exception as e: raise ValueError( f"Failed to load dataset '{self.dataset_name}' with split '{split}': {str(e)}" diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py index 7f35aa10..1e12ec81 100644 --- a/ajet/task_rollout/native_parallel_worker.py +++ b/ajet/task_rollout/native_parallel_worker.py @@ -2,7 +2,7 @@ import os import time -from concurrent.futures import Future, ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor, wait, ALL_COMPLETED, FIRST_COMPLETED from typing import Dict, List, Literal from urllib.parse import quote @@ -59,6 +59,9 @@ def step_status_printer(self, observation_window): if start == -1: print_buf += [f"[finished]:{count} threads"] print(f"Rollout progress ({token_gen_per_sec_str}): " + " // ".join(print_buf)) + if "info" in observation_window: + print_buf2 = "\t".join(observation_window["info"]) + print(print_buf2) def rollout_static( self, @@ -139,7 +142,9 @@ def rollout( epoch: str, ) -> List[BaseContextTracker]: """Delegate to dynamic rollout when oversampling is enabled.""" - if ( + if self.config.ajet.enable_tinkerscript_mode: + return self.rollout_swarm(tasks, mode, epoch) + elif ( mode == "sample" and (self.rollout_n != 1) and self.config.ajet.rollout.enable_oversample @@ -459,6 +464,144 @@ def rollout_dynamic( # noqa: C901 return tracker_array + + def rollout_swarm( # noqa: C901 + self, + tasks: List[Task], + mode: Literal["sample", "validate"], + epoch: str, + allow_sample_num_change=True, + allow_force_stop=True, + ) -> List[BaseContextTracker]: + """ + Build a pool of threads to run context trackers in parallel, + each thread re-spawn after complete, until reaching conditions to stop. + """ + + tracker_array: List[BaseContextTracker] = [] + assert mode != "validate" + rollout_n = self.rollout_n + n_task = len(tasks) + self.current_token_count_time = time.time() + + # initialize observation window + observation_window: Dict[str, List[int | bool | str]] = { + "info": ["" for _ in range(n_task * rollout_n)], + "step": [0 for _ in range(n_task * rollout_n)], + "stop": [False for _ in range(n_task * rollout_n)], + "token": [0 for _ in range(n_task * rollout_n)], + } + executor = ThreadPoolExecutor(max_workers=self.max_parallel) + futures: List[Future] = [] + completed_task_id_map_ct: Dict[str, List[BaseContextTracker]] = {} + + # submit initial tasks + dummy_task = Task(main_query="dummy task") + for task_batch_index in range(n_task): + for task_rollout_index in range(rollout_n): + task_thread_index = task_batch_index * rollout_n + task_rollout_index + future = executor.submit( + self.rollout_env_worker, + task=dummy_task, + task_tag="", + mode=mode, + task_batch_index=task_batch_index, + task_thread_index=task_thread_index, + observation_window=observation_window, + ) + observation_window["info"][task_thread_index] = "1" + futures.append(future) + + def enough_sample_stop_condition(completed_task_id_map_ct) -> bool: + n = 0 + for ct_list in completed_task_id_map_ct.values(): + n += len(ct_list) + return (n >= n_task * rollout_n) + + def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool: + n_finish_roll_task = 0 + for ct_list in completed_task_id_map_ct.values(): + if len(ct_list) >= rollout_n: + n_finish_roll_task += 1 + return (n_finish_roll_task >= n_task) + + def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool: + n_finish_roll_task = 0 + for ct_list in completed_task_id_map_ct.values(): + task_cmd_reward_array = [ + tracker.reward_structure.performance_reward for tracker in ct_list + ] + if (len(ct_list) >= rollout_n): + all_equal = all(x == task_cmd_reward_array[0] for x in task_cmd_reward_array) + if all_equal: continue + n_finish_roll_task += 1 + return (n_finish_roll_task >= n_task) + + stop_condition = enough_sample_stop_condition + + def force_stop_all_threads(): + for k in range(len(observation_window["stop"])): + observation_window["stop"][k] = True + return + + tic = time.time() + while True: + # wait for a completed task + done_arr, pending_arr = wait(futures, timeout=10, return_when=FIRST_COMPLETED) + print(f"Done tasks: {len(done_arr)}, Pending tasks: {len(pending_arr)}") + toc = time.time() + if (toc - tic) > 8: + tic = toc + self.step_status_printer(observation_window) + # get result + for future in done_arr: + ct: BaseContextTracker = future.result() + if ct.task_id not in completed_task_id_map_ct: + completed_task_id_map_ct[ct.task_id] = [ct] + else: + completed_task_id_map_ct[ct.task_id] += [ct] + # if meet stop condition + meet_stop_condition_after_new_results = stop_condition(completed_task_id_map_ct) + if meet_stop_condition_after_new_results: + force_stop_all_threads() + break + else: + # re-spawn new tasks for done futures + for task_batch_index in range(n_task): + for task_rollout_index in range(rollout_n): + task_thread_index = task_batch_index * rollout_n + task_rollout_index + has_done = (futures[task_thread_index] in done_arr) + + observation_window["info"][task_thread_index] = str(int(observation_window["info"][task_thread_index]) + 1) + observation_window["stop"][task_thread_index] = False + observation_window["step"][task_thread_index] = 0 + + if has_done: + print(f"Re-spawning thread {task_thread_index}...") + future = executor.submit( + self.rollout_env_worker, + task=dummy_task, + task_tag="", + mode=mode, + task_batch_index=task_batch_index, + task_thread_index=task_thread_index, + observation_window=observation_window, + ) + futures[task_thread_index] = future + + # wait for all threads to complete + print('Finalizing all threads...') + wait(futures, return_when=ALL_COMPLETED) + + # build tracker_array + print('Collecting results...') + for ct_list in completed_task_id_map_ct.values(): + tracker_array.extend(ct_list) + + # return all trackers + return tracker_array + + class VerlRolloutManager(DynamicRolloutManager): """High-level manager orchestrating rollouts and batch conversion.""" diff --git a/ajet/task_rollout/single_worker.py b/ajet/task_rollout/single_worker.py index 32dfabf2..3e71492e 100644 --- a/ajet/task_rollout/single_worker.py +++ b/ajet/task_rollout/single_worker.py @@ -16,6 +16,7 @@ from ajet.utils.retry import retry_with_backoff from ajet.utils.sample import get_sample_params from ajet.utils.testing_utils import TestFailException, TestSuccessException +from ajet.task_runner.tinkerscript_runner import SwarmReceiveAbortException class BaseRolloutManager: @@ -123,6 +124,8 @@ def rollout_env_worker( tracker = agent_runner.execute( workflow_task=workflow_task, ) + except SwarmReceiveAbortException as exc: # noqa: BLE001 + return None # type: ignore except TestSuccessException as e: logger.success( f"env_worker.agent_flow completed with TestSuccessException: {e.args}" diff --git a/ajet/task_runner/tinkerscript_runner.py b/ajet/task_runner/tinkerscript_runner.py index 975c2b24..86e40899 100644 --- a/ajet/task_runner/tinkerscript_runner.py +++ b/ajet/task_runner/tinkerscript_runner.py @@ -15,32 +15,50 @@ from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_register_episode, get_zmq_socket from loguru import logger from ajet import Workflow +from typing import Callable DEBUG = False context = zmq.Context() atexit.register(context.term) +class SwarmReceiveAbortException(Exception): + pass + class TinkerScriptRunner(BaseAgentRunner): - def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: str, openai_api_key: str, context_tracker: BaseContextTracker) -> WorkflowOutput: + def register_episode_and_wait_output( + self, + episode_uuid: str, + openai_base_url: str, + openai_api_key: str, + context_tracker: BaseContextTracker, + tuner:AjetTuner, + should_exit:Callable + ) -> WorkflowOutput: """Register the episode as ready in the TinkerScript data interchange center.""" # parse episode_uuid, openai_base_url, openai_api_key zmq_listen_result_addr, ipc_path = get_zmq_socket(self.config, episode_uuid, tag="workflow") - http_register_episode( - self.config, - episode_uuid=episode_uuid, - openai_base_url=openai_base_url, - openai_api_key=openai_api_key, - zmq_listen_result_addr=zmq_listen_result_addr, - ) + try: + http_register_episode( + self.config, + episode_uuid=episode_uuid, + openai_base_url=openai_base_url, + openai_api_key=openai_api_key, + zmq_listen_result_addr=zmq_listen_result_addr, + ) + except Exception as e: + raise SwarmReceiveAbortException(f"Episode {episode_uuid} cannot be registered.") + if DEBUG: logger.info(f"zmq_listen_result_addr: {zmq_listen_result_addr}") # begin wait for result zmq_socket = zmq.Context().socket(zmq.REP) zmq_socket.bind(zmq_listen_result_addr) + zmq_socket.setsockopt(zmq.RCVTIMEO, 3*1000) # 3 second timeout for REP speicial_messages = [ "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER" + "RUNNER.SPECIAL.ABORT" ] while True: # : @@ -51,7 +69,16 @@ def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: s # : ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py # : socket.send_string("RUNNER.SPECIAL.RESET_CONTEXT_TRACKER") # : "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER" - message = zmq_socket.recv_string() + try: + message = zmq_socket.recv_string() + except zmq.Again as e: + if should_exit(): + context_tracker.reset() + tuner.terminate_episode() + raise SwarmReceiveAbortException(f"Episode {episode_uuid} aborted due to system exit.") + else: + continue + # process messages if message not in speicial_messages: zmq_socket.send_string("ack") break @@ -59,14 +86,23 @@ def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: s logger.warning(f"Received reset command for episode {episode_uuid}.") context_tracker.reset() zmq_socket.send_string("ack") + elif message == "RUNNER.SPECIAL.ABORT": + logger.warning(f"Received abort command for episode {episode_uuid}.") + context_tracker.reset() + zmq_socket.send_string("ack") + tuner.terminate_episode() + raise SwarmReceiveAbortException(f"Episode {episode_uuid} aborted.") else: + tuner.terminate_episode() raise RuntimeError(f"Unknown special message received: {message}") - logger.success(f"Received workflow output for episode {episode_uuid}") + final_output = WorkflowOutput(**json.loads(message)) + reward = final_output.reward + logger.success(f"Received workflow output for episode {episode_uuid} (Reward: {reward})") zmq_socket.close() if ipc_path and os.path.exists(ipc_path): os.remove(ipc_path) - return WorkflowOutput(**json.loads(message)) + return final_output def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: @@ -92,15 +128,19 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: config=self.config, ) + # from tuner, we get base_url and api_key baseurl_apikey = tuner.as_oai_baseurl_apikey() base_url = baseurl_apikey.base_url api_key = baseurl_apikey.api_key + # wait for remote client to return workflow output workflow_output: WorkflowOutput = self.register_episode_and_wait_output( episode_uuid=context_tracker.episode_uuid, openai_base_url=base_url, openai_api_key=api_key, context_tracker=context_tracker, + tuner=tuner, + should_exit=(lambda: observation_window["stop"][task_thread_index]) ) # the most important thing is to fix task_id to client task_id, set task_id to workflow_task and context_tracker task_id @@ -109,6 +149,7 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: workflow_task.task_id = task_id context_tracker.task_id = task_id + # process reward if workflow_output.reward is not None: raw_reward, is_success = ( workflow_output.reward, @@ -117,11 +158,11 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: else: raise ValueError("workflow_output.reward is None in TinkerScriptRunner, this is currently not allowed.") + # release gym_env workflow_task.gym_env = None # clear gym env client reference to avoid serialization issue - assert not isinstance( - raw_reward, list - ), "AgentJet will support step reward in future versions." + # check reward + assert not isinstance(raw_reward, list), "AgentJet will support step reward in future versions." # register reward # TODO: support multi-step reward @@ -132,6 +173,7 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: madness=0, description="", ) + # process reward context_tracker.process_reward(reward) # generate token before merging context_tracker.group_merge() diff --git a/ajet/tuner.py b/ajet/tuner.py index aacc3ab9..24258473 100644 --- a/ajet/tuner.py +++ b/ajet/tuner.py @@ -146,6 +146,7 @@ def _register(self, target_name: str, agent_name: str, explicit_tuner: TunerType self.target2proxy_registry[target_name][agent_name] = explicit_tuner return explicit_tuner + def _is_target_trainable(self, target_name) -> bool: """Determine whether user have used `trainable_targets` to explicitly control training targets. """ diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py index 67d9bfb6..4ff26b4d 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py @@ -122,10 +122,10 @@ def begin_service(self): # wait till service begin running time.sleep(0.5) - w_time = 1 + wait_time = 1 while future._state == 'PENDING': - time.sleep(min(w_time * 2, 10)) - w_time += 1 + time.sleep(min(wait_time * 2, 10)) + wait_time += 1 if DEBUG: logger.info(f"[client] {self.episode_uuid} | Future ready...") return self.episode_contect_address diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py index 22f5916a..1e4bdc43 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py @@ -355,4 +355,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int: if interchange_server: interchange_server.terminate() + if enable_tinkerscript_mode: + from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_server import kill_process_tree + kill_process_tree(None, None) return -1 \ No newline at end of file diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py index a48088fa..6b4d3518 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py @@ -107,6 +107,37 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut except Exception as e: logger.error(f"Error ending episode: {e}") + def abort_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOutput): + if not episode_uuid: + logger.error("No episode to end.") + return + + try: + task_id = task.task_id + workflow_output.metadata["task_id"] = task_id + req_obj = EndEpisodeRequest( + client_uuid=self.client_uuid, + episode_uuid=episode_uuid, + workflow_output=workflow_output, + task_id=task_id + ) + + resp = httpx.post( + f"{self.server_url}/abort_episode", + json=req_obj.model_dump(), + timeout=30 + ) + resp.raise_for_status() + data = EndEpisodeResponse.model_validate(resp.json()) + + if data.success: + logger.info(f"Ended episode {episode_uuid}") + else: + logger.error(f"Failed to end episode {episode_uuid}") + + except Exception as e: + logger.error(f"Error ending episode: {e}") + def sync_train_config(self, agent_jet_job: AgentJetJob): """ Sync training configuration to the TinkerScript server. @@ -205,7 +236,10 @@ def get_engine_status(self) -> str: timeout=10 ) resp.raise_for_status() - return resp.json().get("engine_status", "unknown") + result = resp.json().get("engine_status", "unknown") + if result == "unknown": + logger.warning("get_engine_status: " + resp.json()) + return result except Exception as e: logger.error(f"Error getting engine status: {e}") return "unknown" @@ -245,11 +279,19 @@ def get_episode_buffer(self) -> List[EpisodeStatus]: logger.error(f"Error getting episode buffer: {e}") return [] - def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob): + def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob, force_restart=False): """ Automatically sync training configuration and start the engine if needed. This checks the current engine status and performs actions accordingly. + + Args: + - agent_jet_job: The AgentJetJob configuration to sync. + - force_restart: If True, forces a restart of the engine. """ + if force_restart: + logger.warning("Force restarting the engine...") + self.stop_engine() + time.sleep(8) current_status = self.get_engine_status() if current_status == "ENGINE.OFFLINE": logger.info("Engine is OFFLINE. Syncing train config and starting engine...") diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py index 2c70ed3c..20d503f3 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py @@ -460,6 +460,28 @@ async def end_episode(req: EndEpisodeRequest): return EndEpisodeResponse(success=True) + @app.post("/abort_episode", response_model=EndEpisodeResponse) + async def abort_episode(req: EndEpisodeRequest): + # receive workflow output data + episode_uuid = req.episode_uuid + workflow_output = req.workflow_output + task_id = req.task_id + + assert "task_id" in workflow_output.metadata, "workflow_output.metadata must contain task_id" + assert workflow_output.metadata["task_id"] == task_id, "workflow_output.metadata.task_id must match req.task_id" + + if 'episodes' not in shared_mem_dict: + logger.error(f"[server] No episodes registered yet.") + return EndEpisodeResponse(success=True) + if (f"episodes-{episode_uuid}") not in shared_mem_dict: + logger.error(f"[server] Episode {episode_uuid} not found.") + return EndEpisodeResponse(success=True) + + _revert_episode_to_unclaimed(episode_uuid) + + # return success + return EndEpisodeResponse(success=True) + @app.post("/can_continue_episode", response_model=CanContinueEpisodeResponse) async def can_continue_episode(req: CanContinueEpisodeRequest): @@ -492,68 +514,80 @@ async def stop_engine(): - Remove all episodes (registered, claimed, and unclaimed) - Clean up shared memory state """ - try: - import psutil + kill_process_tree(shared_mem_dict_lock, shared_mem_dict) + + return app, register_episode_ready_listener() + - killed_pids = [] - errors = [] - # Get the training process PID if it exists +def kill_process_tree(shared_mem_dict_lock=None, shared_mem_dict=None): + try: + import psutil + + killed_pids = [] + errors = [] + + # Get the training process PID if it exists + if shared_mem_dict and shared_mem_dict_lock: training_pid = shared_mem_dict.get('training_process_pid', None) + else: + training_pid = os.getpid() - if training_pid is not None: + if training_pid is not None: + try: + # Try to get the process and all its children try: - # Try to get the process and all its children - try: - parent = psutil.Process(training_pid) - children = parent.children(recursive=True) - - # Kill all child processes first - for child in children: - try: - logger.info(f"[stop_engine] Terminating child process PID: {child.pid}") - child.terminate() - killed_pids.append(child.pid) - except psutil.NoSuchProcess: - logger.warning(f"[stop_engine] Child process {child.pid} already terminated") - except Exception as e: - logger.error(f"[stop_engine] Error terminating child process {child.pid}: {e}") - errors.append(f"Child {child.pid}: {str(e)}") - - # Wait for children to terminate gracefully - gone, alive = psutil.wait_procs(children, timeout=16) - - # Force kill any remaining children - for p in alive: - try: - logger.warning(f"[stop_engine] Force killing child process PID: {p.pid}") - p.kill() - except Exception as e: - logger.error(f"[stop_engine] Error force killing child {p.pid}: {e}") - errors.append(f"Force kill child {p.pid}: {str(e)}") - - # Now terminate the parent process - logger.info(f"[stop_engine] Terminating parent process PID: {training_pid}") - parent.terminate() - killed_pids.append(training_pid) - - # Wait for parent to terminate gracefully + parent = psutil.Process(training_pid) + children = parent.children(recursive=True) + + # Kill all child processes first + for child in children: try: - parent.wait(timeout=3) - except psutil.TimeoutExpired: - logger.warning(f"[stop_engine] Force killing parent process PID: {training_pid}") - parent.kill() + logger.info(f"[stop_engine] Terminating child process PID: {child.pid}") + child.terminate() + killed_pids.append(child.pid) + except psutil.NoSuchProcess: + logger.warning(f"[stop_engine] Child process {child.pid} already terminated") + except Exception as e: + logger.error(f"[stop_engine] Error terminating child process {child.pid}: {e}") + errors.append(f"Child {child.pid}: {str(e)}") + + # Wait for children to terminate gracefully + gone, alive = psutil.wait_procs(children, timeout=16) + + # Force kill any remaining children + for p in alive: + try: + logger.warning(f"[stop_engine] Force killing child process PID: {p.pid}") + p.kill() + except Exception as e: + logger.error(f"[stop_engine] Error force killing child {p.pid}: {e}") + errors.append(f"Force kill child {p.pid}: {str(e)}") + + # Now terminate the parent process + logger.info(f"[stop_engine] Terminating parent process PID: {training_pid}") + parent.terminate() + killed_pids.append(training_pid) + + # Wait for parent to terminate gracefully + try: + parent.wait(timeout=3) + except psutil.TimeoutExpired: + logger.warning(f"[stop_engine] Force killing parent process PID: {training_pid}") + parent.kill() - except psutil.NoSuchProcess: - logger.warning(f"[stop_engine] Process {training_pid} not found (may have already terminated)") + except psutil.NoSuchProcess: + logger.warning(f"[stop_engine] Process {training_pid} not found (may have already terminated)") - except Exception as e: - logger.error(f"[stop_engine] Error killing training process: {e}") - errors.append(f"Training process: {str(e)}") - else: - logger.info("[stop_engine] No training process PID found in shared memory") + except Exception as e: + logger.error(f"[stop_engine] Error killing training process: {e}") + errors.append(f"Training process: {str(e)}") + else: + logger.info("[stop_engine] No training process PID found in shared memory") - # Clean up all episodes from shared memory + # Clean up all episodes from shared memory + episode_keys = [] + if shared_mem_dict and shared_mem_dict_lock: with shared_mem_dict_lock: episode_keys = [k for k in shared_mem_dict.keys() if k.startswith("episodes-")] for key in episode_keys: @@ -575,32 +609,31 @@ async def stop_engine(): logger.info("[stop_engine] Engine status set to OFFLINE") - result = { - "success": True, - "killed_pids": killed_pids, - "episodes_removed": len(episode_keys) if 'episode_keys' in locals() else 0, - } + result = { + "success": True, + "killed_pids": killed_pids, + "episodes_removed": len(episode_keys) if 'episode_keys' in locals() else 0, + } - if errors: - result["warnings"] = errors - logger.warning(f"[stop_engine] Completed with warnings: {errors}") - else: - logger.info(f"[stop_engine] Successfully terminated engine and reset state") + if errors: + result["warnings"] = errors + logger.warning(f"[stop_engine] Completed with warnings: {errors}") + else: + logger.info(f"[stop_engine] Successfully terminated engine and reset state") - return result + return result - except Exception as e: - logger.error(f"[stop_engine] Unexpected error: {e}") - import traceback - traceback.print_exc() + except Exception as e: + logger.error(f"[stop_engine] Unexpected error: {e}") + import traceback + traceback.print_exc() - # Even if there's an error, try to reset the status - try: + # Even if there's an error, try to reset the status + try: + if shared_mem_dict and shared_mem_dict_lock: with shared_mem_dict_lock: shared_mem_dict['engine_status'] = "ENGINE.OFFLINE" - except: - pass + except: + pass - return {"success": False, "error": str(e)} - - return app, register_episode_ready_listener() + return {"success": False, "error": str(e)} \ No newline at end of file diff --git a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py index 12d549aa..949ee679 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py +++ b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py @@ -116,22 +116,17 @@ def http_register_episode(config, episode_uuid: str, zmq_listen_result_addr=zmq_listen_result_addr, ) # send http request to tinkerscript server to register episode - while True: - try: - response = httpx.post( - f"{interchange_http_addr}/register_episode", - json=rer.model_dump(), # 或者 rer.model_dump() 如果使用 Pydantic v2 - timeout=30 - ) - response.raise_for_status() - result = response.json() - if not result.get('success'): - raise RuntimeError(f"Failed to register episode {episode_uuid}") - if DEBUG: logger.info(f"Successfully registered episode {episode_uuid}") - break - except httpx.HTTPError as e: - logger.error(f"Error registering episode {episode_uuid}: {e}. Retrying...") - time.sleep(5) + + response = httpx.post( + f"{interchange_http_addr}/register_episode", + json=rer.model_dump(), # 或者 rer.model_dump() 如果使用 Pydantic v2 + timeout=30 + ) + response.raise_for_status() + result = response.json() + if not result.get('success'): + raise RuntimeError(f"Failed to register episode {episode_uuid}") + if DEBUG: logger.info(f"Successfully registered episode {episode_uuid}") return rer diff --git a/ajet/utils/retry.py b/ajet/utils/retry.py index 339eb7bb..9773c255 100644 --- a/ajet/utils/retry.py +++ b/ajet/utils/retry.py @@ -5,6 +5,7 @@ from loguru import logger from ajet.utils.testing_utils import TestFailException, TestSuccessException +from ajet.task_runner.tinkerscript_runner import SwarmReceiveAbortException T = TypeVar("T") @@ -27,27 +28,32 @@ def wrapper(*args: Any, **kwargs: Any) -> T: if target_max_retry < 1: target_max_retry = 1 - for attempt in range(target_max_retry): - try: - return func(*args, **kwargs) - except TestSuccessException as exc: # noqa: BLE001 - raise exc - except TestFailException as exc: # noqa: BLE001 - raise exc - except Exception as exc: # noqa: BLE001 - if attempt < target_max_retry - 1: - logger.bind(exception=True).exception( - f"{func.__name__} error: {exc.args}, retrying {attempt + 1}/{target_max_retry}" - ) - sleep_seconds = backoff_fn(attempt) if backoff_fn else 2**attempt - time.sleep(sleep_seconds) - else: - logger.bind(exception=True).exception( - f"{func.__name__} failed after {target_max_retry} retries: {exc.args}" - ) - raise - - raise RuntimeError("retry_with_backoff exhausted attempts") + try: + for attempt in range(target_max_retry): + try: + return func(*args, **kwargs) + except TestSuccessException as exc: # noqa: BLE001 + raise exc + except TestFailException as exc: # noqa: BLE001 + raise exc + + except Exception as exc: # noqa: BLE001 + if attempt < target_max_retry - 1: + logger.bind(exception=True).exception( + f"{func.__name__} error: {exc.args}, retrying {attempt + 1}/{target_max_retry}" + ) + sleep_seconds = backoff_fn(attempt) if backoff_fn else 2**attempt + time.sleep(sleep_seconds) + else: + logger.bind(exception=True).exception( + f"{func.__name__} failed after {target_max_retry} retries: {exc.args}" + ) + raise + + raise RuntimeError("retry_with_backoff exhausted attempts") + except SwarmReceiveAbortException as exc: # noqa: BLE001 + # ignore exception, return None silently + return None # type: ignore return wrapper diff --git a/ajet_tinkerscript_threading.py b/ajet_tinkerscript_threading.py index 47b1dce9..52546320 100644 --- a/ajet_tinkerscript_threading.py +++ b/ajet_tinkerscript_threading.py @@ -45,7 +45,7 @@ def main(): # Hand shake with remote tinkerscript server tinkerscript_remote = TinkerScriptClient(REMOTE_TINKERJET_URL) - tinkerscript_remote.stop_engine() + # tinkerscript_remote.stop_engine() tinkerscript_remote.auto_sync_train_config_and_start_engine( AgentJetJob( algorithm="grpo", diff --git a/t-agent.py b/t-agent.py deleted file mode 100644 index 617df2c8..00000000 --- a/t-agent.py +++ /dev/null @@ -1,224 +0,0 @@ - -from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo -from ajet.task_reader import RouterTaskReader -import re -import os -import threading -import time -import requests -from loguru import logger -from textwrap import dedent -from ajet.copilot.job import AgentJetJob -from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo -from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey -from ajet import WorkflowOutput -from ajet.task_reader import RouterTaskReader -from ajet.utils.retry import retry_with_backoff -from ajet.schema.task import Task -from concurrent.futures import ThreadPoolExecutor -from beast_logger import print_listofdict -import asyncio - -# Import reward computation from t-agent-reward.py -from t_agent_reward import TranslationQualityGrader, build_translation_quality_messages - - - -LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/agentjet/agentjet/tmp/arxiv_papers/train.parquet" - - -# Handshake with tinkerscript remote, then send training param to tinkerscript remote (such as model to be trained, algorithm, etc) -dataset = RouterTaskReader( - reader_type = "huggingface_dat_repo", - reader_config = AjetTaskReader( - huggingface_dat_repo = HuggingfaceDatRepo( - dataset_path = LOCAL_DATASET_PATH - ) - ) -) - - -def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): - # Prepare base_url, api_key - base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key) - # Read dataset item - title = task.metadata['title'] - authors = task.metadata['authors'] - abstract = task.metadata['abstract'] - # Prepare messages - messages, rough_translate = rough_translate_agent(base_url, api_key, abstract) - - messages, fix_nouns = detect_hard_proper_nouns(base_url, api_key, abstract, rough_translate) - print_listofdict(messages, header="detect_hard_proper_nouns", mod="c") - - messages, final_translation = produce_final_translation(base_url, api_key, abstract, rough_translate, fix_nouns) - print_listofdict(messages, header="final_translation", mod="c") - - # Compute reward - time.sleep(1) - # Use the translation quality grader from t-agent-reward.py - from openjudge.models import OpenAIChatModel - # from openjudge.models.openai_model import OpenAIModel - grader = TranslationQualityGrader( - model=OpenAIChatModel(base_url=base_url, api_key=api_key, model="qwen-max") - ) - grader_score = asyncio.run(grader.aevaluate(original_text=abstract, translation=final_translation)) - raw_reward = grader_score.score / 3.0 # Normalize to 0-1 range (score is 0-3) - # Return - return WorkflowOutput(reward=raw_reward, metadata={ - "rough_translate": rough_translate, - "fix_nouns": fix_nouns, - "final_translation": final_translation - }) - - -def detect_hard_proper_nouns(base_url, api_key, abstract, rough_translate): - messages = [ - { - "role": "system", - "content": "You are responsible for detecting translation errors of discipline-specific proper nouns. " - "Use json to list all errors found in the translation result and provide correction. " - "Json format: [{\"original_word\": \"xxx\", \"wrong_translation\": \"xxx\", \"wrong_reason\": \"xxx\", \"correct_translation\": \"xxx\"}, ...]. " - "If no errors are found, return an empty list []." - }, - { - "role": "user", - "content": abstract - }, - { - "role": "assistant", - "content": rough_translate - }, - { - "role": "user", - "content": "Please list all translation errors of discipline-specific proper nouns found in the translation result according to the requirements." - }, - ] - - # Use raw http requests (non-streaming) to get response - response = requests.post( f"{base_url}/chat/completions", json = { "model": "qwen-max", "messages": messages, }, - headers = { "Authorization": f"Bearer {api_key}" } ) - fix_nouns = response.json()['choices'][0]['message']['content'] - messages += [ - { - "role": "assistant", - "content": fix_nouns - } - ] - return messages, fix_nouns - - -def produce_final_translation(base_url, api_key, abstract, rough_translate, fix_nouns): - """ - Third agent: Apply the corrections from fix_nouns to produce the final polished translation. - """ - messages = [ - { - "role": "system", - "content": "You are a professional academic translator responsible for producing the final, polished Chinese translation. " - "You will receive: 1) the original English abstract, 2) an initial translation, and 3) a list of corrections for proper nouns. " - "Your task is to apply all the corrections to produce a final translation that is accurate, fluent, and meets Chinese academic writing standards. " - "Ensure that all discipline-specific proper nouns are translated correctly according to the provided corrections. " - "Maintain the academic tone and ensure the translation is concise, rigorous, and natural in Chinese." - }, - { - "role": "user", - "content": f"Original English Abstract:\n{abstract}" - }, - { - "role": "user", - "content": f"Initial Translation:\n{rough_translate}" - }, - { - "role": "user", - "content": f"Corrections for Proper Nouns:\n{fix_nouns}" - }, - { - "role": "user", - "content": "Please produce the final, corrected Chinese translation by applying all the corrections listed above. " - "Output only the final translation without any explanations or additional text." - }, - ] - - # Use raw http requests (non-streaming) to get response - response = requests.post( f"{base_url}/chat/completions", json = { "model": "qwen-max", "messages": messages, }, - headers = { "Authorization": f"Bearer {api_key}" } ) - final_translation = response.json()['choices'][0]['message']['content'] - - messages += [ - { - "role": "assistant", - "content": final_translation - } - ] - - return messages, final_translation - - -def rough_translate_agent(base_url, api_key, abstract): - messages = [ - { - "role": "system", - "content": - "You are a professional language translator. " - "Translate the given Academic English text into Chinese accurately. " - "During the translation process, it is necessary to meet the linguistic norms of Chinese academic papers " - "such as conforming to the logic of the Chinese language, being simple, rigorous, and concise, " - "and avoiding the use of first-person pronouns when passive voice is appropriate. " - "Ensure that specialized terms are translated correctly according to academic standards. " - "Replace 我们 with 本研究 or 本文. " - "If an abbreviation is short in Chinese, use Chinese. " - "If an abbreviation is long in Chinese, use abbreviation. " - }, - { - "role": "user", - "content": abstract - } - ] - - examples = [ - { - "original": "We find that the EMBB is dominated by GW bursts from stellar mass black holes", - "hint": "1. 我们->本研究/本文(删除第一人称) 2. GWs->引力波(有简洁的中文表达),但EMBB保留(没有简洁的中文表达) 3. 调换语序,这句话中的重点是“恒星级黑洞发出的引力波”,所以调换语序突出重点。", - "bad": "我们发现,EMBB主要由恒星级黑洞发出的GWs爆发主导", - "good": "本研究发现恒星级黑洞发出的引力波爆发在EMBB中占主导地位", - }, - { - "original": "In a previous paper (Gayon & Bois 2008a), we have shown the general efficiency of retrograde resonances for stabilizing compact planetary systems.", - "bad": "在先前的一篇论文(Gayon & Bois 2008a)中,本文展示了逆向共振在稳定紧凑行星系统中的普遍效率。", - "hint": "修复主语,删除冗余的逗号,替换“效率”为“有效性”更符合学术表达。", - "good": "先前的一篇论文(Gayon & Bois 2008a)阐释了逆向共振在稳定紧凑行星系统中的普遍有效性。", - }, - ] - - # add examples to system prompt - for ex in examples: - messages[0]['content'] += f"\n\nExample:\n\tOriginal: {ex['original']}\n\tHint: {ex['hint']}\n\tBad Translation: {ex['bad']}\n\tGood Translation: {ex['good']}" - - # Use raw http requests (non-streaming) to get response - response = requests.post( f"{base_url}/chat/completions", json = { "model": "qwen-max", "messages": messages, }, - headers = { "Authorization": f"Bearer {api_key}" } ) - rough_translate = response.json()['choices'][0]['message']['content'] - # print(rough_translate) - - messages += [ - { - "role": "assistant", - "content": rough_translate - } - ] - - return messages, rough_translate - - - - -for i, task in enumerate(dataset.generate_training_tasks()): - if i >= 2: - execute_agent( - task, - OpenaiBaseUrlAndApiKey( - base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", - api_key=os.environ.get("DASHSCOPE_API_KEY", "") - ) - ) diff --git a/t_agent_reward.py b/t_agent_reward.py deleted file mode 100644 index fb1e83c3..00000000 --- a/t_agent_reward.py +++ /dev/null @@ -1,174 +0,0 @@ -from openjudge.graders.base_grader import GraderError, GraderMode, GraderScore -from openjudge.graders.llm_grader import LLMGrader -from openjudge.models.base_chat_model import BaseChatModel -import re -from typing import List - - -def get_translation_quality_system_prompt() -> str: - """Get the translation quality system prompt.""" - return """ -You are an objective translation quality evaluator for academic paper translations from English to Chinese. Your task is to identify ONLY the specific types of errors demonstrated in the provided examples - not general translation quality issues. - -Focus (but do not limit to) on issues below (as shown in the examples): - -1. **First-person pronoun issues** - Using "我们" instead of "本研究" or "本文" in academic contexts -2. **Abbreviation translation errors** - Using abbreviations when concise Chinese exists (e.g., "GWs" instead of "引力波"), or translating abbreviations that should remain in English (like "EMBB") -3. **Word order problems** - Not adjusting sentence structure to emphasize key points in Chinese academic style -4. **Subject-verb inconsistencies** - Mismatched subjects due to improper sentence structure (e.g., "在...中,本文展示..." where the subject is confused) -5. **Inappropriate word choices** - Using colloquial or incorrect terms instead of proper academic expressions (e.g., "效率" vs "有效性" in certain contexts) -6. **Redundant punctuation** - Unnecessary commas or other punctuation that disrupts Chinese reading flow - -**Examples of these errors:** - -Example 1: -- Original: "We find that the EMBB is dominated by GW bursts from stellar mass black holes" -- Bad Translation: "我们发现,EMBB主要由恒星级黑洞发出的GWs爆发主导" -- Issues: 1) "我们" should be "本研究/本文", 2) "GWs" should be "引力波" (has concise Chinese), 3) Word order doesn't emphasize the key point -- Good Translation: "本研究发现恒星级黑洞发出的引力波爆发在EMBB中占主导地位" - -Example 2: -- Original: "In a previous paper (Gayon & Bois 2008a), we have shown the general efficiency of retrograde resonances for stabilizing compact planetary systems." -- Bad Translation: "在先前的一篇论文(Gayon & Bois 2008a)中,本文展示了逆向共振在稳定紧凑行星系统中的普遍效率。" -- Issues: 1) Subject confusion (在...中,本文...), 2) Redundant comma, 3) "效率" should be "有效性" for better academic expression -- Good Translation: "先前的一篇论文(Gayon & Bois 2008a)阐释了逆向共振在稳定紧凑行星系统中的普遍有效性。" - -Example 3: -- Original: "To improve the transferability of ViT, we introduce a novel and effective module, named Domain Transferable-guided Attention Block (DTAB)." -- Bad Translation: "为了提高ViT的迁移能力,本文引入了一个新颖且有效的模块,称为域可迁移引导注意力块(DTAB)" -- Issues: 1) 语言顺序和表达不符合中文习惯。2) 没有在首次出现自定义缩写时,给出英文全称 -- Good Translation: "为提高ViT的迁移能力,本文引入了名为“域可迁移引导注意力块”(Transferable-guided Attention Block,DTAB)的新颖模块。" - -Example 4: -- Original: Extensive experiments were conducted on UCF-HMDB, Kinetics-Gameplay, and Kinetics-NEC Drone datasets, with different backbones, like ResNet101, I3D, and STAM, to verify the effectiveness of TransferAttn compared with state-of-the-art approaches. -- Bad Translation: 在UCF-HMDB、Kinetics-Gameplay和Kinetics-NEC Drone数据集上进行了广泛的实验,使用了不同的骨干网络,如ResNet101、I3D和STAM,以验证TransferAttn与现有最先进方法相比的有效性。 -- Issues: 1) 改变语言顺序后,主语缺失。应当填充主语“本研究”或者“本文”。2) 举例时,表述不够简洁。 -- Good Translation: 本研究在UCF-HMDB、Kinetics-Gameplay和Kinetics-NEC Drone数据集上进行了广泛的实验, 使用了ResNet101、I3D和STAM等骨干网络来验证TransferAttn与现有最先进方法相比的有效性。 - -Rate the translation on a scale of 0-3: - -0 = Severely impairs readability (multiple critical errors from the categories above that make the text difficult to understand) -1 = Does not impair readability, but numerous errors or significantly reduces Chinese reading efficiency (many instances of the error types above) -2 = Does not impair readability, few errors and not severe (minor instances of the error types above) -3 = No errors from the example categories detected (translation is free of the specific error types demonstrated) - -Note: -* For each key issue found, provide the specific error, its type, and where it appears in the translation. -* Be precise about which error category each issue belongs to. -* Focus on objective errors matching the example patterns, not subjective preferences. - -Think carefully before flagging any error. Ask yourself: Does this match one of the specific error types from the examples? Is this truly an objective error or just a stylistic preference? - -Return your response in this format: -X -Your detailed step-by-step reasoning analyzing the translation against the error categories - -- Error Type: [category]. Error: [specific issue]. Location: [where it appears in the translation] - - -The score must be 0, 1, 2, or 3. Each key issue should be on its own line starting with a dash. If no errors are found, the key_issues section should be empty or state "None detected". -""" - - -TRANSLATION_QUALITY_USER_PROMPT = """ -Evaluate the quality of this Chinese translation based on the specific error types demonstrated in the examples. - -Original English text: -{original} - -Chinese translation to evaluate: -{translation} -""" - - -def parse_translation_quality_response(text: str) -> dict: - """Parse XML-formatted translation quality response.""" - score_match = re.search(r"\s*(\d+)\s*", text) - reasoning_match = re.search(r"(.*?)", text, re.DOTALL) - issues_match = re.search(r"(.*?)", text, re.DOTALL) - - score = int(score_match.group(1)) if score_match else 3 - reasoning = reasoning_match.group(1).strip() if reasoning_match else text - - key_issues = [] - if issues_match: - issues_text = issues_match.group(1) - # Filter out empty lines and "None detected" type messages - key_issues = [ - line.strip().lstrip("- ") - for line in issues_text.strip().split("\n") - if line.strip() and not line.strip().lstrip("- ").lower().startswith("none") - ] - - return {"score": score, "reason": reasoning, "key_issues": key_issues} - - -def build_translation_quality_messages(original_text: str, translation: str) -> List[dict]: - """Build messages for translation quality evaluation.""" - return [ - {"role": "system", "content": get_translation_quality_system_prompt()}, - { - "role": "user", - "content": TRANSLATION_QUALITY_USER_PROMPT.format( - original=original_text, - translation=translation - ), - }, - ] - - -class TranslationQualityGrader(LLMGrader): - """Grader for evaluating translation quality based on specific error patterns. - - Score range: 0-3 - 0 = Severely impairs readability (multiple critical errors) - 1 = Does not impair readability, but numerous errors or reduces efficiency - 2 = Does not impair readability, few errors and not severe - 3 = No errors from the example categories detected - """ - - def __init__(self, model: BaseChatModel | dict): - super().__init__( - name="translation_quality", - mode=GraderMode.POINTWISE, - description="Evaluate translation quality based on specific error patterns", - model=model, - template="", # Placeholder, not used - ) - - async def aevaluate(self, original_text: str, translation: str) -> GraderScore: - """Evaluate translation quality. - - Args: - original_text: Original English text - translation: Chinese translation to evaluate - - Returns: - GraderScore with score 0-3 and identified issues - """ - try: - messages = build_translation_quality_messages(original_text, translation) - response = await self.model.achat(messages=messages) - content = await extract_response_content(response) - parsed = parse_translation_quality_response(content) - - return GraderScore( - name=self.name, - score=parsed["score"], - reason=parsed["reason"], - metadata={"key_issues": parsed["key_issues"]}, - ) - except Exception as e: - return GraderError(name=self.name, error=str(e)) - - -async def extract_response_content(response) -> str: - """Extract content from model response.""" - if hasattr(response, 'content'): - return response.content - elif isinstance(response, dict) and 'content' in response: - return response['content'] - elif isinstance(response, str): - return response - else: - raise ValueError(f"Unable to extract content from response: {type(response)}") diff --git a/tutorial/example_academic_trans/trans.py b/tutorial/example_academic_trans/trans.py new file mode 100644 index 00000000..eed5c5df --- /dev/null +++ b/tutorial/example_academic_trans/trans.py @@ -0,0 +1,162 @@ + +import re +import os +import time +import asyncio +import requests +import threading +from loguru import logger +from textwrap import dedent + +from ajet import WorkflowOutput +from ajet.schema.task import Task +from ajet.copilot.job import AgentJetJob +from ajet.task_reader import RouterTaskReader +from ajet.utils.retry import retry_with_backoff +from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from beast_logger import print_listofdict + +# Import reward computation from trans_reward.py +from openjudge.models import OpenAIChatModel +from .trans_reward import TranslationQualityGrader, build_translation_quality_messages, examples + + +LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/agentjet/agentjet/tmp/arxiv_papers/train.parquet" + + +# Handshake with tinkerscript remote, then send training param to tinkerscript remote (such as model to be trained, algorithm, etc) +dataset = RouterTaskReader( + reader_type = "huggingface_dat_repo", + reader_config = AjetTaskReader( + huggingface_dat_repo = HuggingfaceDatRepo( + dataset_path = LOCAL_DATASET_PATH + ) + ) +) + +@retry_with_backoff(max_retry=3) +def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): + # Prepare base_url, api_key + base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key) + grader_base_url, grader_api_key = ("https://dashscope.aliyuncs.com/compatible-mode/v1", os.environ.get("DASHSCOPE_API_KEY", "")) + # Read dataset item + title = task.metadata['title'] + authors = task.metadata['authors'] + abstract = task.metadata['abstract'] + + messages, rough_translate = rough_translate_agent(base_url, api_key, abstract) + # print_listofdict(messages, header="rough_translate_agent", mod="c") + + messages, fix_nouns = detect_hard_proper_nouns(messages, base_url, api_key, abstract, rough_translate) + # print_listofdict(messages, header="detect_hard_proper_nouns", mod="c") + + messages, final_translation = produce_final_translation(messages, base_url, api_key, abstract, rough_translate, fix_nouns) + print_listofdict(messages, header="final_translation", mod="c") + + grader = TranslationQualityGrader( + model=OpenAIChatModel(base_url=grader_base_url, api_key=grader_api_key, model="qwen-max") + ) + grader_score = asyncio.run(grader.aevaluate(original_text=abstract, translation=final_translation)) + raw_reward = grader_score.score # Normalize to 0-1 range (score is 0-3) + return WorkflowOutput(reward=raw_reward, metadata={ + "rough_translate": rough_translate, + "fix_nouns": fix_nouns, + "final_translation": final_translation + }) + + +def detect_hard_proper_nouns(messages, base_url, api_key, abstract, rough_translate): + messages = messages + [ + + { + "role": "user", + "content": "You new job is to detect translation errors of discipline-specific proper nouns. " + "Use json to list all errors found in the translation result and provide correction. " + "Json format: [{\"original_word\": \"xxx\", \"wrong_translation\": \"xxx\", \"wrong_reason\": \"xxx\", \"correct_translation\": \"xxx\"}, ...]. " + "If no errors are found, return an empty list []." + "Please list all translation errors of discipline-specific proper nouns found in the translation result according to the requirements." + }, + ] + + response = requests.post( f"{base_url}/chat/completions", json = { "model": "qwen-turbo", "messages": messages, }, headers = { "Authorization": f"Bearer {api_key}" } ) + fix_nouns = response.json()['choices'][0]['message']['content'] + messages += [ + { + "role": "assistant", + "content": fix_nouns + } + ] + return messages, fix_nouns + + +def produce_final_translation(messages, base_url, api_key, abstract, rough_translate, fix_nouns): + messages = messages + [ + { + "role": "user", + "content": "Please produce the final, corrected Chinese translation by applying all the corrections listed above. " + "Output only the final translation without any explanations or additional text." + }, + ] + + response = requests.post( f"{base_url}/chat/completions", json = { "model": "qwen-turbo", "messages": messages, }, headers = { "Authorization": f"Bearer {api_key}" } ) + final_translation = response.json()['choices'][0]['message']['content'] + + messages += [ + { + "role": "assistant", + "content": final_translation + } + ] + + return messages, final_translation + + +def rough_translate_agent(base_url, api_key, abstract): + messages = [ + { + "role": "system", + "content": + "You are a professional language translator. " + "Translate the given Academic English text into Chinese accurately. " + "During the translation process, it is necessary to meet the linguistic norms of Chinese academic papers " + "such as conforming to the logic of the Chinese language, being simple, rigorous, and concise, " + "and avoiding the use of first-person pronouns when passive voice is appropriate. " + "Ensure that specialized terms are translated correctly according to academic standards. " + "Replace 我们 with 本研究 or 本文. " + "If an abbreviation is short in Chinese, use Chinese. " + "If an abbreviation is long in Chinese, use abbreviation. " + }, + { + "role": "user", + "content": abstract + } + ] + + for ex in examples: + messages[0]['content'] += f"\n\nExample:\n\tOriginal: {ex['original']}\n\tBad Translation: {ex['bad']}\n\tHint: {ex['hint']}\n\tGood Translation: {ex['good']}" + response = requests.post( f"{base_url}/chat/completions", json = { "model": "qwen-turbo", "messages": messages, }, headers = { "Authorization": f"Bearer {api_key}" } ) + rough_translate = response.json()['choices'][0]['message']['content'] + messages += [ + { + "role": "assistant", + "content": rough_translate + } + ] + + return messages, rough_translate + + + +if __name__ == "__main__": + + for i, task in enumerate(dataset.generate_training_tasks()): + execute_agent( + task, + OpenaiBaseUrlAndApiKey( + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", + api_key=os.environ.get("DASHSCOPE_API_KEY", "") + ) + ) + + diff --git a/tutorial/example_academic_trans/trans_reward.py b/tutorial/example_academic_trans/trans_reward.py new file mode 100644 index 00000000..b8acae84 --- /dev/null +++ b/tutorial/example_academic_trans/trans_reward.py @@ -0,0 +1,180 @@ +import re +from openjudge.graders.base_grader import GraderError, GraderMode, GraderScore +from openjudge.graders.llm_grader import LLMGrader +from openjudge.models.base_chat_model import BaseChatModel +from typing import List +from textwrap import dedent + + +examples = [ + { + "original": "We find that the EMBB is dominated by GW bursts from stellar mass black holes", + "bad": "我们发现,EMBB主要由恒星级黑洞发出的GWs爆发主导", + "hint": "1) 我们->本研究/本文(删除第一人称) 2) GWs->引力波(有简洁的中文表达),但EMBB保留(没有简洁的中文表达) 3. 调换语序,这句话中的重点是“恒星级黑洞发出的引力波”,所以调换语序突出重点。", + "good": "本研究发现恒星级黑洞发出的引力波爆发在EMBB中占主导地位" + }, + { + "original": "In a previous paper (Gayon & Bois 2008a), we have shown the general efficiency of retrograde resonances for stabilizing compact planetary systems.", + "bad": "在先前的一篇论文(Gayon & Bois 2008a)中,本文展示了逆向共振在稳定紧凑行星系统中的普遍效率。", + "hint": "修复主语,删除冗余的逗号,替换“效率”为“有效性”更符合学术表达。", + "good": "先前的一篇论文(Gayon & Bois 2008a)阐释了逆向共振在稳定紧凑行星系统中的普遍有效性。" + }, + { + "original": "To improve the transferability of ViT, we introduce a novel and effective module, named Domain Transferable-guided Attention Block (DTAB).", + "bad": "为了提高ViT的迁移能力,本文引入了一个新颖且有效的模块,称为域可迁移引导注意力块(DTAB)", + "hint": "1)语言顺序和表达不符合中文习惯 2)没有在首次出现自定义缩写时,给出英文全称", + "good": "为提高ViT的迁移能力,本文引入了名为“域可迁移引导注意力块”(Domain Transferable-guided Attention Block,DTAB)的新颖且有效的模块。" + }, + { + "original": "Extensive experiments were conducted on UCF-HMDB, Kinetics-Gameplay, and Kinetics-NEC Drone datasets, with different backbones, like ResNet101, I3D, and STAM, to verify the effectiveness of TransferAttn compared with state-of-the-art approaches.", + "bad": "在UCF-HMDB、Kinetics-Gameplay和Kinetics-NEC Drone数据集上进行了广泛的实验,使用了不同的骨干网络,如ResNet101、I3D和STAM,以验证TransferAttn与现有最先进方法相比的有效性。", + "hint": "1)改变语言顺序后,主语缺失 2)举例时,表述不够简洁", + "good": "本研究在UCF-HMDB、Kinetics-Gameplay和Kinetics-NEC Drone数据集上进行了广泛的实验,使用了ResNet101、I3D和STAM等骨干网络来验证TransferAttn与现有最先进方法相比的有效性。" + } +] + + +examples_eval = examples + [ + +] + + + +TRANSLATION_QUALITY_USER_PROMPT = """ +Evaluate the quality of this Chinese translation based on the specific error types demonstrated in the examples. + +Original English text: +{original} + +Chinese translation to evaluate: +{translation} +""" + + + +def get_translation_quality_system_prompt() -> str: + """Get the translation quality system prompt.""" + examples_text = "" + for i, ex in enumerate(examples_eval, 1): + examples_text += dedent(f""" + Example {i}: + - Original: "{ex['original']}" + - Bad Translation: "{ex['bad']}" + - Issues: {ex['hint']} + - Good Translation: "{ex['good']}" + """) + + + return dedent(""" + You are an objective translation quality evaluator for academic paper translations from English to Chinese. Your task is to identify ONLY the specific types of errors demonstrated in the provided examples - not general translation quality issues. + + Focus (but do not limit to) on issues below (as shown in the examples): + + 1. **First-person pronoun issues** - Using "我们" instead of "本研究" or "本文" in academic contexts + 2. **Abbreviation translation errors** - Using abbreviations when concise Chinese exists (e.g., "GWs" instead of "引力波"), or translating abbreviations that should remain in English (like "EMBB") + 3. **Word order problems** - Not adjusting sentence structure to emphasize key points in Chinese academic style + 4. **Subject-verb inconsistencies** - Mismatched subjects due to improper sentence structure (e.g., "在...中,本文展示..." where the subject is confused) + 5. **Inappropriate word choices** - Using colloquial or incorrect terms instead of proper academic expressions (e.g., "效率" vs "有效性" in certain contexts) + 6. **Redundant punctuation** - Unnecessary commas or other punctuation that disrupts Chinese reading flow + + **Examples of these errors:** + [[examples_text]] + Rate the translation on a scale of 0-2: + + 0 = Severely impairs readability (multiple critical errors from the categories above that make the text difficult to understand) + 1 = Contain errors, reduces Chinese reading efficiency (many instances of the error types above) + 2 = No errors from the example categories detected (translation is free of the specific error types demonstrated) + + Note: + * For each key issue found, provide the specific error, its type, and where it appears in the translation. + * Be precise about which error category each issue belongs to. + * Focus on objective errors matching the example patterns, not subjective preferences. + + Think carefully before flagging any error. Ask yourself: Does this match one of the specific error types from the examples? Is this truly an objective error or just a stylistic preference? + + Return your response in this format: + X + Your detailed step-by-step reasoning analyzing the translation against the error categories + + - Error Type: [category]. Error: [specific issue]. Location: [where it appears in the translation] + + + The score must be 0, 1, 2. Each key issue should be on its own line starting with a dash. If no errors are found, the key_issues section should be empty or state "None detected". + """.replace("[[examples_text]]", examples_text)) + + + +def parse_translation_quality_response(text: str) -> dict: + """Parse XML-formatted translation quality response.""" + score_match = re.search(r"\s*(\d+)\s*", text) + reasoning_match = re.search(r"(.*?)", text, re.DOTALL) + issues_match = re.search(r"(.*?)", text, re.DOTALL) + + score = int(score_match.group(1)) if score_match else 0 + reasoning = reasoning_match.group(1).strip() if reasoning_match else text + + key_issues = [] + if issues_match: + issues_text = issues_match.group(1) + # Filter out empty lines and "None detected" type messages + key_issues = [ + line.strip().lstrip("- ") + for line in issues_text.strip().split("\n") + if line.strip() and not line.strip().lstrip("- ").lower().startswith("none") + ] + + return {"score": score, "reason": reasoning, "key_issues": key_issues} + + +def build_translation_quality_messages(original_text: str, translation: str) -> List[dict]: + return [ + {"role": "system", "content": get_translation_quality_system_prompt()}, + { + "role": "user", + "content": TRANSLATION_QUALITY_USER_PROMPT.format( + original=original_text, + translation=translation + ), + }, + ] + + +class TranslationQualityGrader(LLMGrader): + def __init__(self, model: BaseChatModel | dict): + super().__init__( + name="translation_quality", + mode=GraderMode.POINTWISE, + description="Evaluate translation quality based on specific error patterns", + model=model, + template="", # Placeholder, not used + ) + + async def aevaluate(self, original_text: str, translation: str, normalize=True) -> GraderScore: + try: + messages = build_translation_quality_messages(original_text, translation) + response = await self.model.achat(messages=messages) + content = await extract_response_content(response) + parsed = parse_translation_quality_response(content) + + if normalize: + parsed["score"] = parsed["score"] / 2.0 + + return GraderScore( + name=self.name, + score=parsed["score"], + reason=parsed["reason"], + metadata={"key_issues": parsed["key_issues"]}, + ) + except Exception as e: + return GraderError(name=self.name, error=str(e)) + + +async def extract_response_content(response) -> str: + if hasattr(response, 'content'): + return response.content + elif isinstance(response, dict) and 'content' in response: + return response['content'] + elif isinstance(response, str): + return response + else: + raise ValueError(f"Unable to extract content from response: {type(response)}") diff --git a/tutorial/example_academic_trans/trans_roll.py b/tutorial/example_academic_trans/trans_roll.py new file mode 100644 index 00000000..e060d895 --- /dev/null +++ b/tutorial/example_academic_trans/trans_roll.py @@ -0,0 +1,107 @@ +import re +import threading +import requests +import time +from loguru import logger +from textwrap import dedent +from ajet.copilot.job import AgentJetJob +from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_client import TinkerScriptClient +from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet import WorkflowOutput +from ajet.schema.task import Task +from ajet.task_reader import RouterTaskReader +from ajet.utils.retry import retry_with_backoff +from concurrent.futures import ThreadPoolExecutor +from tutorial.example_academic_trans.trans import execute_agent + +# python -m tutorial.example_academic_trans.trans_roll + + +# --------- configurations that take effect locally ------------- +LOCAL_GRPO_N = 4 # grpo group size +LOCAL_NUM_EPOCH = 10000 +LOCAL_NUM_EPOCH = 1 +LOCAL_MAX_PARALLEL = 32 +LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/agentjet/agentjet/tmp/arxiv_papers/train.parquet" +REMOTE_TINKERJET_URL = "http://localhost:10086" # Change to your tinkerscript remote url + +# --------- configurations that take effect remotely ------------- +REMOTE_ALLOCATE_GPU_PER_NODE = 8 +REMOTE_TRAIN_MODEL_01 = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct' +REMOTE_BATCH_SIZE = 32 + +class WeightUpdatedHalfway(Exception): + """Raised when the remote side starts updating model weights halfway through an episode.""" + + +def main(): + + # Handshake with tinkerscript remote, then send training param to tinkerscript remote (such as model to be trained, algorithm, etc) + dataset = RouterTaskReader( + reader_type = "huggingface_dat_repo", + reader_config = AjetTaskReader( + huggingface_dat_repo = HuggingfaceDatRepo( + dataset_path = LOCAL_DATASET_PATH + ) + ) + ) + + # Hand shake with remote tinkerscript server + tinkerscript_remote = TinkerScriptClient(REMOTE_TINKERJET_URL) + # tinkerscript_remote.stop_engine() + tinkerscript_remote.auto_sync_train_config_and_start_engine( + AgentJetJob( + algorithm="grpo", + n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, + model=REMOTE_TRAIN_MODEL_01, + grpo_n=LOCAL_GRPO_N, + ), + force_restart=True, + ) + + # Define rollout + def rollout(task): + group_reward = [] + for i in range(LOCAL_GRPO_N): + episode_uuid = None + try: + # begin episode + episode_uuid, api_baseurl_key = tinkerscript_remote.begin_episode() + # execute agent + workflow_output = execute_agent(task, api_baseurl_key) + # report output back to tinkerscript remote + tinkerscript_remote.end_episode(task, episode_uuid, workflow_output) + # collect reward + group_reward.append(workflow_output.reward) + except Exception as e: + logger.exception("Exception during rollout:", e) + if episode_uuid: + tinkerscript_remote.abort_episode(episode_uuid) + print(f"Group reward mean & std: {sum(group_reward)/len(group_reward)} +/- { (max(group_reward)-min(group_reward))/2 }") + + # Main Training loop + futures = [] + with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as executor: + for epoch in range(LOCAL_NUM_EPOCH): + for i, task in enumerate(dataset.generate_training_tasks()): + print(f"Submitting task for epoch {epoch}") + future = executor.submit(rollout, task) + + futures += [future] + while (i % REMOTE_BATCH_SIZE) == (REMOTE_BATCH_SIZE - 1) and futures: + futures = [f for f in futures if not f.done()] + time.sleep(1) + + + tinkerscript_remote.stop_engine() + # model_path = tinkerscript_remote.download_latest_model(path='./tinkerscript_saved_model') + + # Get tuned model from tinkerscript remote + return None + + + + +if __name__ == "__main__": + main() From 3157658766109950e1772a2c1d918c66f4835e29 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Thu, 5 Feb 2026 01:15:08 +0800 Subject: [PATCH 18/25] fix state machine bugs --- ajet/__init__.py | 33 ++- ajet/context_tracker/multiagent_tracking.py | 2 + ajet/copilot/job.py | 2 + ajet/default_config/ajet_default.yaml | 2 +- ajet/default_config/ajet_ts_default.yaml | 7 +- ajet/task_rollout/native_parallel_worker.py | 83 +++---- ajet/task_rollout/single_worker.py | 57 ++++- ajet/task_runner/base_runner.py | 10 +- ajet/task_runner/general_runner.py | 4 +- ajet/task_runner/tinkerscript_runner.py | 154 +++++++----- ajet/tuner.py | 6 +- .../experimental/as_oai_model_client.py | 29 ++- .../experimental/as_oai_model_server.py | 53 ++-- .../experimental/as_tinkerscript_client.py | 9 +- .../experimental/as_tinkerscript_server.py | 231 ++++++++++++------ .../experimental/interchange_utils.py | 53 ++-- ajet/utils/retry.py | 9 +- ajet/workflow.py | 2 +- ajet_tinkerscript_threading.py | 96 ++++---- tutorial/example_academic_trans/trans_roll.py | 4 +- 20 files changed, 533 insertions(+), 313 deletions(-) diff --git a/ajet/__init__.py b/ajet/__init__.py index b0731e74..c7081e95 100644 --- a/ajet/__init__.py +++ b/ajet/__init__.py @@ -1,8 +1,4 @@ -from ajet.copilot.job import AgentJetJob -from ajet.schema.task import WorkflowOutput, WorkflowTask -from ajet.tuner import AjetTuner -from ajet.workflow import Workflow -from ajet.utils.vsdb import vscode_conditional_breakpoint as bp +__version__ = "0.1.0" __all__ = [ "Workflow", @@ -13,4 +9,29 @@ "bp" ] -__version__ = "0.1.0" +_LAZY_IMPORTS = { + "AjetTuner": "ajet.tuner", + "AgentJetJob": "ajet.copilot.job", + "WorkflowOutput": "ajet.schema.task", + "WorkflowTask": "ajet.schema.task", + "Workflow": "ajet.workflow", + "bp": "ajet.utils.vsdb", +} + +_ATTR_MAPPING = { + "bp": "vscode_conditional_breakpoint" +} + +def __getattr__(name): + if name in _LAZY_IMPORTS: + import importlib + module_path = _LAZY_IMPORTS[name] + module = importlib.import_module(module_path) + + attr_name = _ATTR_MAPPING.get(name, name) + value = getattr(module, attr_name) # type: ignore + + globals()[name] = value + return value + + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") \ No newline at end of file diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py index dc192aa6..35607224 100644 --- a/ajet/context_tracker/multiagent_tracking.py +++ b/ajet/context_tracker/multiagent_tracking.py @@ -49,12 +49,14 @@ def __init__( tokenizer: PreTrainedTokenizer, config, should_interrupt_fn, + should_interrupt_hard_fn, generated_token_callback_fn, **kwargs, ): super().__init__(config, tokenizer, **kwargs) self.tokenizer = tokenizer self.should_interrupt_fn = should_interrupt_fn + self.should_interrupt_hard_fn = should_interrupt_hard_fn self.generated_token_callback_fn = generated_token_callback_fn self.context_overflow = False self.output_kwargs = {} diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py index 21f96d4a..2b6acde5 100644 --- a/ajet/copilot/job.py +++ b/ajet/copilot/job.py @@ -44,6 +44,7 @@ def __init__( algorithm: str = "grpo", n_gpu_for_infer: int | None = None, # only for trinity backbone grpo_n: int = 8, + batch_size: int = 32, tinkerscript_mode: bool = True, *kwargs, ) -> None: @@ -60,6 +61,7 @@ def __init__( self.config.ajet.trainer_common.n_gpus_per_node = n_gpu self.config.ajet.trainer_common.algorithm.adv_estimator = algorithm self.config.ajet.rollout.num_repeat = grpo_n + self.config.ajet.data.train_batch_size = batch_size if n_gpu_for_infer is None and backbone == "trinity": raise ValueError("Please specify `n_gpu_for_infer` (n_gpu_for_infer < n_gpu) for trinity backbone.") if (n_gpu_for_infer is not None) and backbone == "verl": diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index 9def61da..49a6c9a8 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -290,7 +290,7 @@ ajet: interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) interchange_server_port: 'auto' num_fastapi_process: 2 # 1, 2 or 4 is fine - max_fastapi_threads: 128 # 64 or 128 is fine + max_fastapi_threads: 512 # 64 or 128 is fine max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` already_started: False # do not edit, used by `tinkerscript` diff --git a/ajet/default_config/ajet_ts_default.yaml b/ajet/default_config/ajet_ts_default.yaml index 1b8cc3a1..be7ca5a9 100644 --- a/ajet/default_config/ajet_ts_default.yaml +++ b/ajet/default_config/ajet_ts_default.yaml @@ -7,7 +7,7 @@ ajet: model: # which model should be trained - path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-3B-Instruct rollout: # the path to the workflow class @@ -29,10 +29,13 @@ ajet: interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) interchange_server_port: 10086 num_fastapi_process: 2 # 1, 2 or 4 is fine - max_fastapi_threads: 128 # 64 or 128 is fine + max_fastapi_threads: 512 # 64 or 128 is fine max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` already_started: False # do not edit, used by `tinkerscript` + rollout: + # maximum number of parallel environments / simulate workers + max_env_worker: 128 # ------------------ 不需要修改 ------------------ diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py index 1e12ec81..35b56172 100644 --- a/ajet/task_rollout/native_parallel_worker.py +++ b/ajet/task_rollout/native_parallel_worker.py @@ -8,6 +8,7 @@ import numpy as np import torch +import threading from loguru import logger from tensordict import TensorDict from torch.nn.utils.rnn import pad_sequence @@ -15,10 +16,11 @@ from verl import DataProto from verl.utils.torch_functional import pad_sequence_to_length -from ajet.context_tracker.basic_tracker import BaseContextTracker from ajet.schema.task import Task from ajet.schema.trajectory import Sample from ajet.task_rollout.single_worker import BaseRolloutManager +from ajet.context_tracker.basic_tracker import BaseContextTracker +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status class DynamicRolloutManager(BaseRolloutManager): @@ -481,7 +483,9 @@ def rollout_swarm( # noqa: C901 tracker_array: List[BaseContextTracker] = [] assert mode != "validate" rollout_n = self.rollout_n - n_task = len(tasks) + n_batch_task = len(tasks) + n_task = min(len(tasks), self.max_parallel // rollout_n) + assert n_task > 0, f"n_task is not valid, n_task = min(len(tasks), self.max_parallel // rollout_n) = {n_task}" self.current_token_count_time = time.time() # initialize observation window @@ -489,11 +493,13 @@ def rollout_swarm( # noqa: C901 "info": ["" for _ in range(n_task * rollout_n)], "step": [0 for _ in range(n_task * rollout_n)], "stop": [False for _ in range(n_task * rollout_n)], + "hard_stop": [False for _ in range(n_task * rollout_n)], "token": [0 for _ in range(n_task * rollout_n)], } executor = ThreadPoolExecutor(max_workers=self.max_parallel) futures: List[Future] = [] completed_task_id_map_ct: Dict[str, List[BaseContextTracker]] = {} + executor_lock = threading.Lock() # submit initial tasks dummy_task = Task(main_query="dummy task") @@ -501,13 +507,15 @@ def rollout_swarm( # noqa: C901 for task_rollout_index in range(rollout_n): task_thread_index = task_batch_index * rollout_n + task_rollout_index future = executor.submit( - self.rollout_env_worker, + self.rollout_env_worker_loop, task=dummy_task, task_tag="", mode=mode, task_batch_index=task_batch_index, task_thread_index=task_thread_index, observation_window=observation_window, + completed_task_id_map_ct=completed_task_id_map_ct, + executor_lock=executor_lock, ) observation_window["info"][task_thread_index] = "1" futures.append(future) @@ -516,14 +524,15 @@ def enough_sample_stop_condition(completed_task_id_map_ct) -> bool: n = 0 for ct_list in completed_task_id_map_ct.values(): n += len(ct_list) - return (n >= n_task * rollout_n) + print(f"Current collected samples: {n}, target: {n_batch_task * rollout_n}") + return (n >= n_batch_task * rollout_n) def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool: n_finish_roll_task = 0 for ct_list in completed_task_id_map_ct.values(): if len(ct_list) >= rollout_n: n_finish_roll_task += 1 - return (n_finish_roll_task >= n_task) + return (n_finish_roll_task >= n_batch_task) def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool: n_finish_roll_task = 0 @@ -535,63 +544,39 @@ def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool: all_equal = all(x == task_cmd_reward_array[0] for x in task_cmd_reward_array) if all_equal: continue n_finish_roll_task += 1 - return (n_finish_roll_task >= n_task) + return (n_finish_roll_task >= n_batch_task) stop_condition = enough_sample_stop_condition - def force_stop_all_threads(): - for k in range(len(observation_window["stop"])): - observation_window["stop"][k] = True + def stop_all_threads_soft(): + for k in range(len(observation_window["stop"])): observation_window["stop"][k] = True + http_change_engine_status(self.config, "ENGINE.ROLLING_POST") + return + + def stop_all_threads_hard(): + for k in range(len(observation_window["hard_stop"])): observation_window["hard_stop"][k] = True + http_change_engine_status(self.config, "ENGINE.WEIGHT_SYNCING") return - tic = time.time() + cnt = 0 while True: - # wait for a completed task - done_arr, pending_arr = wait(futures, timeout=10, return_when=FIRST_COMPLETED) - print(f"Done tasks: {len(done_arr)}, Pending tasks: {len(pending_arr)}") - toc = time.time() - if (toc - tic) > 8: - tic = toc + cnt += 1 + time.sleep(2) + if (cnt % 5 == 0): self.step_status_printer(observation_window) - # get result - for future in done_arr: - ct: BaseContextTracker = future.result() - if ct.task_id not in completed_task_id_map_ct: - completed_task_id_map_ct[ct.task_id] = [ct] - else: - completed_task_id_map_ct[ct.task_id] += [ct] - # if meet stop condition meet_stop_condition_after_new_results = stop_condition(completed_task_id_map_ct) if meet_stop_condition_after_new_results: - force_stop_all_threads() + print("Sending soft stop signal to all threads...") + stop_all_threads_soft() break - else: - # re-spawn new tasks for done futures - for task_batch_index in range(n_task): - for task_rollout_index in range(rollout_n): - task_thread_index = task_batch_index * rollout_n + task_rollout_index - has_done = (futures[task_thread_index] in done_arr) - - observation_window["info"][task_thread_index] = str(int(observation_window["info"][task_thread_index]) + 1) - observation_window["stop"][task_thread_index] = False - observation_window["step"][task_thread_index] = 0 - - if has_done: - print(f"Re-spawning thread {task_thread_index}...") - future = executor.submit( - self.rollout_env_worker, - task=dummy_task, - task_tag="", - mode=mode, - task_batch_index=task_batch_index, - task_thread_index=task_thread_index, - observation_window=observation_window, - ) - futures[task_thread_index] = future # wait for all threads to complete print('Finalizing all threads...') - wait(futures, return_when=ALL_COMPLETED) + executor.shutdown(wait=True) + + # stop all threads hard + print("Sending hard stop signal to all threads...") + stop_all_threads_hard() # build tracker_array print('Collecting results...') diff --git a/ajet/task_rollout/single_worker.py b/ajet/task_rollout/single_worker.py index 3e71492e..908c9d47 100644 --- a/ajet/task_rollout/single_worker.py +++ b/ajet/task_rollout/single_worker.py @@ -1,10 +1,13 @@ """Single worker primitives for environment rollouts.""" import uuid +import time +import threading from typing import Literal from loguru import logger from omegaconf import DictConfig +from typing import Dict, List, Literal from transformers.tokenization_utils import PreTrainedTokenizer from ajet.context_tracker.basic_tracker import BaseContextTracker @@ -14,9 +17,9 @@ from ajet.task_runner.general_runner import GeneralRunner from ajet.task_runner.tinkerscript_runner import TinkerScriptRunner from ajet.utils.retry import retry_with_backoff +from ajet.utils.retry import SwarmReceiveAbortException from ajet.utils.sample import get_sample_params from ajet.utils.testing_utils import TestFailException, TestSuccessException -from ajet.task_runner.tinkerscript_runner import SwarmReceiveAbortException class BaseRolloutManager: @@ -125,6 +128,7 @@ def rollout_env_worker( workflow_task=workflow_task, ) except SwarmReceiveAbortException as exc: # noqa: BLE001 + print('SwarmReceiveAbortException caught in rollout_env_worker') return None # type: ignore except TestSuccessException as e: logger.success( @@ -141,3 +145,54 @@ def rollout_env_worker( raise e return tracker + + + def rollout_env_worker_loop( + self, + task: Task, + task_batch_index: int, + task_tag: str, + mode: Literal["sample", "validate"], + task_thread_index: int, + observation_window: dict, + completed_task_id_map_ct: Dict[str, List[BaseContextTracker]], + executor_lock: threading.Lock, + **kwargs, + ): + try: + cnt = 1 + while True: + + if observation_window["stop"][task_thread_index]: + print('rollout_env_worker_loop received stop signal, exiting...') + return + + observation_window["info"][task_thread_index] = str(cnt) + tracker = self.rollout_env_worker( + task=task, + task_batch_index=task_batch_index, + task_tag=task_tag, + mode=mode, + task_thread_index=task_thread_index, + observation_window=observation_window, + **kwargs, + ) + + # avoid write conflict + if tracker and tracker.reward_structure: + with executor_lock: + if tracker.task_id not in completed_task_id_map_ct: + completed_task_id_map_ct[tracker.task_id] = [tracker] + else: + completed_task_id_map_ct[tracker.task_id] += [tracker] + cnt += 1 + if observation_window["stop"][task_thread_index]: + return + else: + del tracker + + except Exception as e: + logger.exception( + f"encounter exception in env_worker_loop error={e.args}" + ) + raise e \ No newline at end of file diff --git a/ajet/task_runner/base_runner.py b/ajet/task_runner/base_runner.py index d8c15492..ad457a5c 100644 --- a/ajet/task_runner/base_runner.py +++ b/ajet/task_runner/base_runner.py @@ -49,9 +49,12 @@ def get_judge(self) -> BaseJudge: # type: ignore def runner_hooks(self, observation_window, task_thread_index, workflow_task): def should_interrupt_fn() -> bool: - if (observation_window["stop"] is not None) and observation_window["stop"][ - task_thread_index - ]: # Check if the thread should stop (because other threads have completed, making this thread useless) + if (observation_window["stop"] is not None) and observation_window["stop"][task_thread_index]: # Check if the thread should stop (because other threads have completed, making this thread useless) + return True + return False + + def should_interrupt_hard_fn() -> bool: + if (observation_window["hard_stop"] is not None) and observation_window["hard_stop"][task_thread_index]: # Check if the thread should stop (because other threads have completed, making this thread useless) return True return False @@ -60,6 +63,7 @@ def generated_token_callback_fn(token_array): return { "should_interrupt_fn": should_interrupt_fn, + "should_interrupt_hard_fn": should_interrupt_hard_fn, "generated_token_callback_fn": generated_token_callback_fn, } diff --git a/ajet/task_runner/general_runner.py b/ajet/task_runner/general_runner.py index 88f9ab11..a3e3db92 100644 --- a/ajet/task_runner/general_runner.py +++ b/ajet/task_runner/general_runner.py @@ -1,6 +1,6 @@ -from ajet import AjetTuner -from ajet import WorkflowOutput +from ajet.tuner import AjetTuner +from ajet.schema.task import WorkflowOutput, WorkflowTask from ajet.context_tracker.multiagent_tracking import ( MultiAgentContextTracker, ) diff --git a/ajet/task_runner/tinkerscript_runner.py b/ajet/task_runner/tinkerscript_runner.py index 86e40899..b84f1f3d 100644 --- a/ajet/task_runner/tinkerscript_runner.py +++ b/ajet/task_runner/tinkerscript_runner.py @@ -3,28 +3,24 @@ import json import zmq import os -from ajet import AjetTuner -from ajet import WorkflowOutput -from ajet.context_tracker.multiagent_tracking import ( - MultiAgentContextTracker, -) +from ajet.tuner import AjetTuner +from ajet.schema.task import WorkflowOutput +from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker from ajet.context_tracker.basic_tracker import BaseContextTracker from ajet.schema.task import WorkflowTask from ajet.schema.trajectory import Reward from ajet.task_runner.base_runner import BaseAgentRunner -from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_register_episode, get_zmq_socket +from ajet.utils.retry import SwarmReceiveAbortException +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_register_episode, get_zmq_socket, is_episode_claimed from loguru import logger from ajet import Workflow from typing import Callable -DEBUG = False +DEBUG = True context = zmq.Context() atexit.register(context.term) -class SwarmReceiveAbortException(Exception): - pass - class TinkerScriptRunner(BaseAgentRunner): def register_episode_and_wait_output( @@ -34,78 +30,94 @@ def register_episode_and_wait_output( openai_api_key: str, context_tracker: BaseContextTracker, tuner:AjetTuner, - should_exit:Callable - ) -> WorkflowOutput: + should_exit_soft:Callable, + should_exit_hard:Callable + ) -> WorkflowOutput | None: """Register the episode as ready in the TinkerScript data interchange center.""" # parse episode_uuid, openai_base_url, openai_api_key zmq_listen_result_addr, ipc_path = get_zmq_socket(self.config, episode_uuid, tag="workflow") - try: - http_register_episode( - self.config, - episode_uuid=episode_uuid, - openai_base_url=openai_base_url, - openai_api_key=openai_api_key, - zmq_listen_result_addr=zmq_listen_result_addr, - ) - except Exception as e: - raise SwarmReceiveAbortException(f"Episode {episode_uuid} cannot be registered.") + success = http_register_episode( + self.config, + episode_uuid=episode_uuid, + openai_base_url=openai_base_url, + openai_api_key=openai_api_key, + zmq_listen_result_addr=zmq_listen_result_addr, + should_exit_soft=should_exit_soft, + ) + if not success: + return None # type: ignore if DEBUG: logger.info(f"zmq_listen_result_addr: {zmq_listen_result_addr}") # begin wait for result zmq_socket = zmq.Context().socket(zmq.REP) zmq_socket.bind(zmq_listen_result_addr) - zmq_socket.setsockopt(zmq.RCVTIMEO, 3*1000) # 3 second timeout for REP + zmq_socket.setsockopt(zmq.RCVTIMEO, 1*1000) # 1 second timeout for REP + speicial_messages = [ - "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER" + "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER", "RUNNER.SPECIAL.ABORT" ] - while True: - # : - # : ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py - # : socket.send_string(workflow_output.model_dump_json()) - # : workflow_output: WorkflowOutput - # : - # : ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py - # : socket.send_string("RUNNER.SPECIAL.RESET_CONTEXT_TRACKER") - # : "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER" - try: - message = zmq_socket.recv_string() - except zmq.Again as e: - if should_exit(): + + try: + + while True: + # : + # : ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py + # : socket.send_string(workflow_output.model_dump_json()) + # : workflow_output: WorkflowOutput + # : + # : ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py + # : socket.send_string("RUNNER.SPECIAL.RESET_CONTEXT_TRACKER") + # : "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER" + try: + message = zmq_socket.recv_string() + except zmq.Again as e: + if should_exit_hard(): + logger.warning(f'{episode_uuid} Exiting workflow due to should_exit_hard signal.') + context_tracker.reset() + raise SwarmReceiveAbortException(f"Episode {episode_uuid} aborted due to system exit.") + elif should_exit_soft(): + has_claimed = is_episode_claimed(self.config, episode_uuid) + if not has_claimed: + raise SwarmReceiveAbortException(f"Episode {episode_uuid} aborted due to system exit.") + else: + continue + else: + continue + # process messages + if message not in speicial_messages: + zmq_socket.send_string("ack") + break + elif message == "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER": + logger.warning(f"Received reset command for episode {episode_uuid}.") + context_tracker.reset() + zmq_socket.send_string("ack") + elif message == "RUNNER.SPECIAL.ABORT": + logger.warning(f"Received abort command for episode {episode_uuid}.") context_tracker.reset() - tuner.terminate_episode() - raise SwarmReceiveAbortException(f"Episode {episode_uuid} aborted due to system exit.") + zmq_socket.send_string("ack") + return None else: - continue - # process messages - if message not in speicial_messages: - zmq_socket.send_string("ack") - break - elif message == "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER": - logger.warning(f"Received reset command for episode {episode_uuid}.") - context_tracker.reset() - zmq_socket.send_string("ack") - elif message == "RUNNER.SPECIAL.ABORT": - logger.warning(f"Received abort command for episode {episode_uuid}.") - context_tracker.reset() - zmq_socket.send_string("ack") - tuner.terminate_episode() - raise SwarmReceiveAbortException(f"Episode {episode_uuid} aborted.") - else: - tuner.terminate_episode() - raise RuntimeError(f"Unknown special message received: {message}") - - final_output = WorkflowOutput(**json.loads(message)) - reward = final_output.reward - logger.success(f"Received workflow output for episode {episode_uuid} (Reward: {reward})") - zmq_socket.close() - if ipc_path and os.path.exists(ipc_path): os.remove(ipc_path) + raise RuntimeError(f"Unknown special message received: {message}") + + final_output = WorkflowOutput(**json.loads(message)) + reward = final_output.reward + logger.success(f"Received workflow output for episode {episode_uuid} (Reward: {reward})") + + except Exception as exc: + raise exc + + finally: + zmq_socket.close() + tuner.terminate_episode() + if ipc_path and os.path.exists(ipc_path): os.remove(ipc_path) return final_output def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: + observation_window = workflow_task.observation_window task_thread_index = workflow_task.task_thread_index @@ -114,6 +126,14 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: task_thread_index=task_thread_index, workflow_task=workflow_task, ) + + should_exit_soft = hooks['should_interrupt_fn'] + should_exit_hard = hooks['should_interrupt_hard_fn'] + + if should_exit_soft() or should_exit_hard(): + print(f'Exiting workflow worker due to interrupt signal for episode {workflow_task.episode_uuid}.') + raise SwarmReceiveAbortException(f"Episode {workflow_task.episode_uuid} aborted due to interrupt signal.") + context_tracker = MultiAgentContextTracker( llm_inference_fn=self.llm_inference_fn, tokenizer=self.tokenizer, @@ -130,18 +150,22 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: # from tuner, we get base_url and api_key baseurl_apikey = tuner.as_oai_baseurl_apikey() + base_url = baseurl_apikey.base_url api_key = baseurl_apikey.api_key # wait for remote client to return workflow output - workflow_output: WorkflowOutput = self.register_episode_and_wait_output( + workflow_output: WorkflowOutput | None = self.register_episode_and_wait_output( episode_uuid=context_tracker.episode_uuid, openai_base_url=base_url, openai_api_key=api_key, context_tracker=context_tracker, tuner=tuner, - should_exit=(lambda: observation_window["stop"][task_thread_index]) + should_exit_soft=should_exit_soft, + should_exit_hard=should_exit_hard, ) + if not workflow_output: + return None # type: ignore # the most important thing is to fix task_id to client task_id, set task_id to workflow_task and context_tracker task_id assert "task_id" in workflow_output.metadata, "workflow_output.metadata must contain task_id" diff --git a/ajet/tuner.py b/ajet/tuner.py index 24258473..90fcbbfd 100644 --- a/ajet/tuner.py +++ b/ajet/tuner.py @@ -1,9 +1,6 @@ from typing import TYPE_CHECKING, Callable, Union, Type -from ajet.context_tracker.multiagent_tracking import ( - MultiAgentContextTracker, -) - +from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker from ajet.tuner_lib.weight_tuner import AgentScopeModelTuner from ajet.tuner_lib.weight_tuner import OpenaiClientModelTuner from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiClientBaseUrlTuner @@ -189,3 +186,4 @@ def terminate_episode(self): if self.enable_interchange_server: if (self.proxy_client_started is True) and hasattr(self, "interchange_client"): self.interchange_client._should_terminate = True + print(f'-->self.interchange_client._should_terminate = {self.interchange_client.should_terminate}') diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py index 4ff26b4d..396457b8 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py @@ -22,8 +22,8 @@ if TYPE_CHECKING: from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker -DEBUG = False -# DEBUG = True +# DEBUG = False +DEBUG = True def generate_auth_token(agent_name, target_tag, episode_uuid, episode_address): """ @@ -104,17 +104,23 @@ async def llm_infer( @property def should_terminate(self) -> bool: - return self._should_terminate + try: + should_interrupt = self.context_tracker.should_interrupt_hard_fn() + return self._should_terminate or should_interrupt + except: + return self._should_terminate def begin_service(self): """ Starts the zmq communication loop. """ + if self.should_terminate: + return self.episode_contect_address if DEBUG: logger.info(f"[client] {self.episode_uuid} | Starting InterchangeClient service loop...") self.socket = context.socket(zmq.REP) self.socket.bind(f"{self.episode_contect_address}") - self.socket.setsockopt(zmq.RCVTIMEO, 3*1000) # 3 second timeout for REP + self.socket.setsockopt(zmq.RCVTIMEO, 1*1000) # 3 second timeout for REP self.executor = SharedInterchangeThreadExecutor(self.max_inference_tracker_threads).get_shared_executor() if DEBUG: logger.info(f"[client] {self.episode_uuid} | Submitting _begin_service_threading to executor...") @@ -124,9 +130,14 @@ def begin_service(self): time.sleep(0.5) wait_time = 1 while future._state == 'PENDING': + if self.should_terminate: + future.cancel() + return self.episode_contect_address time.sleep(min(wait_time * 2, 10)) wait_time += 1 - + if self.should_terminate: + future.cancel() + return self.episode_contect_address if DEBUG: logger.info(f"[client] {self.episode_uuid} | Future ready...") return self.episode_contect_address @@ -142,15 +153,16 @@ def _begin_service_threading(self): while not self.should_terminate: # listen for next request from remote try: - if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() has begun") + # if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() has begun (should_terminate {self.should_terminate})") message = self.socket.recv_string() - if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() is done") + # if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() is done") except zmq.Again as e: if self.should_terminate: + # abort_episode() if DEBUG: logger.info(f"[client] {self.episode_uuid} | episode over") break timepassed = time.time() - begin_time - if timepassed > 60: + if timepassed > 100: if DEBUG: logger.warning(f"[client] {self.episode_uuid} | Still waiting for first message... (time passed {timepassed}) for episode_uuid:{self.episode_uuid}...") continue @@ -192,3 +204,4 @@ def _begin_service_threading(self): if os.path.exists(self.ipc_path): os.remove(self.ipc_path) if DEBUG: logger.info(f"[client] {self.episode_uuid} | IPC socket file {self.ipc_path} removed.") + diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py index 1e4bdc43..366ff7d9 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py @@ -33,8 +33,9 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest from openai.types.chat.chat_completion import ChatCompletion -from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import EpisodeStatus from ajet.utils.networking import find_free_port, get_host_ip +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import EpisodeStatus + API_KEY_PREFIX = "sk-ajet-" class InterchangeCompletionRequest(BaseModel): @@ -53,8 +54,8 @@ class HealthCheckRequest(BaseModel): # Create FastAPI app SERVER_SHUTDOWN_EVENT = threading.Event() -DEBUG = False -# DEBUG = True +# DEBUG = False +DEBUG = True context = zmq.Context() atexit.register(context.term) @@ -84,25 +85,31 @@ async def lifespan(app: FastAPI): app = FastAPI(title="AJet Interchange Endpoint", lifespan=lifespan) - def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletionRequest, episode_uuid, timeline_uuid, client_offline: threading.Event): + def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletionRequest, episode_uuid): """ run this in thread to avoid blocking main event loop """ if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | Received new chat completion request (inside thread)") socket = context.socket(zmq.REQ) - socket.setsockopt(zmq.RCVTIMEO, 60*1000) # 1 minute recv timeout + socket.setsockopt(zmq.RCVTIMEO, 6*1000) # 6 second recv timeout socket.connect(f"{episode_address}") if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | connect done") socket.send_string(int_req.model_dump_json()) if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | send_string") result_str = "" - for _ in range(5): # max 5 minutes wait + for _ in range(50): # max 5 minutes wait try: if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") result_str = socket.recv_string() break except zmq.Again as e: + # check whether server is still in rolling status + if enable_tinkerscript_mode: + assert shared_mem_dict is not None + if shared_mem_dict['engine_status'] not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + raise HTTPException(status_code=404, detail="The server is not in ENGINE.ROLLING status, cannot accept new requests.") + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string timeout, retrying.") continue @@ -125,6 +132,8 @@ async def chat_completions(request: Request, authorization: str = Header(None)): OpenAI-compatible chat completions endpoint. Receives ChatCompletionRequest and returns ChatCompletion. """ + if DEBUG: logger.info("Received /v1/chat/completions request") + # Parse authorization header (base64 encoded JSON) if not authorization: return HTTPException(status_code=401, detail="Missing authorization header") @@ -150,23 +159,28 @@ async def chat_completions(request: Request, authorization: str = Header(None)): new_req = ChatCompletionRequest.model_validate(body) if new_req.stream: return HTTPException(status_code=400, detail="Streaming responses not supported in current AgentJet version, please set `stream=false` for now.") + # Create timeline UUID timeline_uuid = uuid.uuid4().hex # enable_tinkerscript_mode if enable_tinkerscript_mode: + from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_server import ep_key assert shared_mem_dict is not None assert shared_mem_dict_lock is not None - if shared_mem_dict['engine_status'] != "ENGINE.ROLLING": + + if shared_mem_dict['engine_status'] not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: logger.error(f"The server is not in ENGINE.ROLLING status (current status: [{shared_mem_dict['engine_status']}]), cannot accept new requests.") - raise HTTPException(status_code=503, detail="The server is not in ENGINE.ROLLING status, cannot accept new requests.") - if (f"episodes-{episode_uuid}") not in shared_mem_dict: + raise HTTPException(status_code=404, detail="The server is not in ENGINE.ROLLING status, cannot accept new requests.") + + if ep_key(episode_uuid) not in shared_mem_dict: raise HTTPException(status_code=404, detail=f"Episode {episode_uuid} not found.") + # update activate timestamp with shared_mem_dict_lock: - es:EpisodeStatus = shared_mem_dict[f"episodes-{episode_uuid}"] + es:EpisodeStatus = shared_mem_dict[ep_key(episode_uuid)] es.latest_activity_timestamp = time.time() - shared_mem_dict[f"episodes-{episode_uuid}"] = es + shared_mem_dict[ep_key(episode_uuid)] = es # Add to received queue int_req = InterchangeCompletionRequest( @@ -177,17 +191,8 @@ async def chat_completions(request: Request, authorization: str = Header(None)): timeline_uuid = timeline_uuid, ) if DEBUG: logger.info(f"episode_uuid: {episode_uuid} | Received new chat completion request (outside thread)") - client_offline = threading.Event() - try: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, episode_address, int_req, episode_uuid, timeline_uuid, client_offline) - finally: - client_offline.set() - - - @app.post("/reset") - async def reset(): - return {"status": "reset_complete"} + loop = asyncio.get_running_loop() + return await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, episode_address, int_req, episode_uuid) if enable_tinkerscript_mode: @@ -241,7 +246,7 @@ async def serve_with_monitor(additional_coro): app=app, host="0.0.0.0", port=self.port, - log_level="error", + log_level="info", workers=self.num_fastapi_process ) server = uvicorn.Server(config) @@ -350,7 +355,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int: interchange_server.join() except KeyboardInterrupt: logger.info("Shutting down interchange server...") - try: httpx.get(f"http://127.0.0.1:{port}/stop_engine", timeout=8).status_code + try: httpx.post(f"http://127.0.0.1:{port}/stop_engine", timeout=8).status_code except Exception: pass if interchange_server: diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py index 6b4d3518..71927c75 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py @@ -223,7 +223,7 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING"): break # Wait a bit before next poll - time.sleep(1) + time.sleep(5) except Exception as e: logger.error(f"Error polling engine status: {e}") @@ -242,7 +242,7 @@ def get_engine_status(self) -> str: return result except Exception as e: logger.error(f"Error getting engine status: {e}") - return "unknown" + return "ENGINE.CANNOT_CONNECT" def can_continue_episode(self, episode_uuid: str) -> bool: if not episode_uuid: @@ -303,6 +303,10 @@ def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob, fo logger.info("Engine is BOOTING. Waiting until it becomes ROLLING...") self._wait_until_status_change_to(desired_status="ENGINE.ROLLING") logger.success("Training engine is now ROLLING and ready.") + elif current_status == "ENGINE.CANNOT_CONNECT": + logger.error("Cannot connect to the engine. Please check the network.") + self._wait_until_status_change_to(desired_status="ENGINE.ROLLING") + logger.success("Training engine is now ROLLING and ready.") else: raise RuntimeError(f"Cannot sync train config or start engine when engine is in status: {current_status}") @@ -331,4 +335,3 @@ def stop_engine(self): self._wait_until_status_change_to(desired_status="ENGINE.OFFLINE") except Exception as e: logger.error(f"Error stopping engine: {e}") - diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py index 20d503f3..dad07b6b 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py @@ -4,10 +4,11 @@ import os import asyncio import threading -from multiprocessing.managers import DictProxy -from types import SimpleNamespace from loguru import logger +from functools import lru_cache +from types import SimpleNamespace from fastapi import FastAPI, HTTPException +from multiprocessing.managers import DictProxy from typing import Coroutine, Optional, Tuple, List from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import ( SyncTrainConfigRequest, @@ -22,9 +23,23 @@ BoolResponse, RegisterEpisodeRequest, UpdateEngineStatusRequest, + VALID_STATUSES, ) -DEBUG = False +DEBUG = True +RCVTIMEO = 2 * 1000 +RCVTIMEO_OUT = 300 * 1000 +RCVTIMEO_WAIT_N = RCVTIMEO_OUT // RCVTIMEO + + +def is_key_epsisode_status(key: str) -> bool: + return key.startswith("episodes-") + + +@lru_cache(maxsize=128) +def ep_key(episode_uuid: str) -> str: + return f"episodes-{episode_uuid}" + def register_enable_tinkerscript_mode_routes( app, @@ -49,63 +64,81 @@ def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]: current_time = time.time() for k, v in shared_mem_dict.items(): - if k.startswith("episodes-"): + if is_key_epsisode_status(k): es:EpisodeStatus = v if es.episode_status == "claimed": if (current_time - es.latest_activity_timestamp) > es.allow_discard_timeout: result.append(es.episode_uuid) for episode_uuid in result: - _revert_episode_to_unclaimed(episode_uuid) + _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock) return result def _context_tracker_reset(episode_uuid, shared_mem_dict): # send message to context tracker assert 'episodes' in shared_mem_dict - zmq_addr = shared_mem_dict[f"episodes-{episode_uuid}"].zmq_listen_result_addr + zmq_addr = shared_mem_dict[ep_key(episode_uuid)].zmq_listen_result_addr socket = zmq_context.socket(zmq.REQ) - socket.setsockopt(zmq.RCVTIMEO, 60*1000) # 1 minute recv timeout + socket.setsockopt(zmq.RCVTIMEO, RCVTIMEO) # 2 seconds recv timeout socket.connect(zmq_addr) + # # : ajet/task_runner/tinkerscript_runner.py # : message = zmq_socket.recv_string() socket.send_string("RUNNER.SPECIAL.RESET_CONTEXT_TRACKER") + # - for _ in range(5): # max 5 minutes wait + for _ in range(RCVTIMEO_WAIT_N): # max 5 minutes wait try: if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") # : # : ajet/task_runner/tinkerscript_runner.py # : zmq_socket.send_string("ack") # : "ack" - result_str = socket.recv_string() + socket.recv_string() break except zmq.Again as e: if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string timeout, retrying.") + + if shared_mem_dict["engine_status"] not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + logger.info(f"[server] episode_uuid: {episode_uuid} | Engine is no longer rolling, aborting wait for ack.") + raise RuntimeError("Engine is no longer rolling, aborting wait for ack.") continue - def _revert_episode_to_unclaimed(episode_uuid: str): - with shared_mem_dict_lock: - # check status again, because other thread may have changed it - if shared_mem_dict[f"episodes-{episode_uuid}"].episode_status != "claimed": - return + def _revert_episode_to_unclaimed(episode_uuid: str, shared_mem_dict, shared_mem_dict_lock): + # check status again, because other thread may have changed it + if shared_mem_dict[ep_key(episode_uuid)].episode_status != "claimed": + return + with shared_mem_dict_lock: # reset context tracker _context_tracker_reset(episode_uuid, shared_mem_dict) # revert logger.warning(f"Reverting episode {episode_uuid} to unclaimed due to client timeout.") - if f"episodes-{episode_uuid}" in shared_mem_dict: - es:EpisodeStatus = shared_mem_dict[f"episodes-{episode_uuid}"] + if ep_key(episode_uuid) in shared_mem_dict: + es:EpisodeStatus = shared_mem_dict[ep_key(episode_uuid)] es.episode_status = "registered" es.client_uuid = "" es.latest_activity_timestamp = time.time() es.allow_discard_timeout = -1 - shared_mem_dict[f"episodes-{episode_uuid}"] = es - shared_mem_dict['unclaimed_episodes'] += [episode_uuid] + shared_mem_dict[ep_key(episode_uuid)] = es + if episode_uuid in shared_mem_dict['unclaimed_episodes']: + pass + else: + shared_mem_dict['unclaimed_episodes'] += [episode_uuid] + def _delete_episode_record(episode_uuid: str, shared_mem_dict, shared_mem_dict_lock): + with shared_mem_dict_lock: + # remove episode record + if ep_key(episode_uuid) in shared_mem_dict: + del shared_mem_dict[ep_key(episode_uuid)] + logger.info(f"Deleted episode record for {episode_uuid}.") + # remove from unclaimed list if present + if episode_uuid in shared_mem_dict['unclaimed_episodes']: + shared_mem_dict['unclaimed_episodes'].remove(episode_uuid) # -------------------------------------------------------------------------------------- @@ -113,31 +146,35 @@ def _revert_episode_to_unclaimed(episode_uuid: str): # -------------------------------------------------------------------------------------- def _register_final_episode_output(episode_uuid, workflow_output, shared_mem_dict, shared_mem_dict_lock): + # begin send workflow_output - zmq_addr = shared_mem_dict[f"episodes-{episode_uuid}"].zmq_listen_result_addr + zmq_addr = shared_mem_dict[ep_key(episode_uuid)].zmq_listen_result_addr if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | Received new chat completion request (inside thread)") socket = zmq_context.socket(zmq.REQ) - socket.setsockopt(zmq.RCVTIMEO, 60*1000) # 1 minute recv timeout + socket.setsockopt(zmq.RCVTIMEO, RCVTIMEO) # 2 seconds recv timeout socket.connect(zmq_addr) if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | connect done") socket.send_string(workflow_output.model_dump_json()) if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | send_string") # wait for ack - for _ in range(5): # max 5 minutes wait + for _ in range(RCVTIMEO_WAIT_N): # max 5 minutes wait try: if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") # : # : ajet/task_runner/tinkerscript_runner.py # : zmq_socket.send_string("ack") # : "ack" - result_str = socket.recv_string() + socket.recv_string() break except zmq.Again as e: if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string timeout, retrying.") + if shared_mem_dict["engine_status"] not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + logger.info(f"[server] episode_uuid: {episode_uuid} | Engine is no longer rolling, aborting wait for ack.") + raise RuntimeError("Engine is no longer rolling, aborting wait for ack.") continue # clean up episode records with shared_mem_dict_lock: - del shared_mem_dict[f"episodes-{episode_uuid}"] + del shared_mem_dict[ep_key(episode_uuid)] if episode_uuid in shared_mem_dict['unclaimed_episodes']: shared_mem_dict['unclaimed_episodes'].remove(episode_uuid) @@ -148,37 +185,23 @@ def _register_final_episode_output(episode_uuid, workflow_output, shared_mem_dic # -------------------------------------------------------------------------------------- async def register_episode_ready_listener(): - pass - # while True: - # read_all_episode_status() - # await asyncio.sleep(10) # check every 10 seconds - # find_claimed_episodes_that_need_to_be_unclaimed() + while True: + await asyncio.sleep(10) # check every 10 seconds + find_claimed_episodes_that_need_to_be_unclaimed() + read_all_episode_status() def read_all_episode_status() -> Optional[EpisodeStatus]: - print_buffer = [] group_by_status = {} for k, v in shared_mem_dict.items(): - if k.startswith("episodes-"): + if is_key_epsisode_status(k): es:EpisodeStatus = v if es.episode_status not in group_by_status: group_by_status[es.episode_status] = [] group_by_status[es.episode_status].append(es) - # for status, es_list in group_by_status.items(): - # print_buffer.append(f"{status} (time since last activity)") - # in_line_buffer = "" - # for es in es_list: - # time_since_last_activity = time.time() - es.latest_activity_timestamp - # in_line_buffer += f"{es.episode_uuid[:6]}({time_since_last_activity:.1f}s)\t" - # print_buffer.append(in_line_buffer) - - print_buffer_str = "\n".join(print_buffer) - logger.info(f"Current engine status: [{shared_mem_dict['engine_status']}]") - if print_buffer: - logger.info(f"Current episode statuses:\n{print_buffer_str}") - else: - logger.info(f"Current episode statuses: [NA]") + print_buffer_str = f"Registered: {len(group_by_status.get('registered', []))}, Claimed: {len(group_by_status.get('claimed', []))}" + logger.info(f"Current engine status: [{shared_mem_dict['engine_status']}], " + print_buffer_str) return None @@ -193,6 +216,10 @@ async def sync_train_config(req: SyncTrainConfigRequest): Receive training configuration from client as YAML string. Store it in shared memory for later use by start_engine. """ + + if shared_mem_dict['engine_status'] != "ENGINE.OFFLINE": + raise HTTPException(status_code=400, detail="Engine is already started. Call `stop_engine` first before syncing new training configuration.") + try: yaml_str = req.yaml_as_string logger.info("[sync_train_config] Received training configuration") @@ -283,11 +310,15 @@ def override_param_callback(config): ) p.daemon = True p.start() + # wait until p.pid is available while not isinstance(p.pid, int): time.sleep(1) + # set new process group os.setpgid(p.pid, p.pid) + # Store process info in shared memory + clean_up_engine_status(shared_mem_dict_lock, shared_mem_dict) with shared_mem_dict_lock: shared_mem_dict['training_process_pid'] = p.pid shared_mem_dict['engine_status'] = "ENGINE.BOOTING" @@ -303,19 +334,32 @@ def override_param_callback(config): # --- engine status --- - shared_mem_dict['engine_status'] = "ENGINE.OFFLINE" + shared_mem_dict['engine_status'] = "ENGINE.OFFLINE" # initial status + def clean_up_engine_status(shared_mem_dict_lock, shared_mem_dict): + with shared_mem_dict_lock: + episode_keys = [k for k in shared_mem_dict.keys() if is_key_epsisode_status(k)] + # remove all episodes + for key in episode_keys: + del shared_mem_dict[key] + logger.info(f"[clean_up_engine_status] Removed episode: {key}") + + # clear unclaimed episodes list + if 'unclaimed_episodes' in shared_mem_dict: + num_unclaimed = len(shared_mem_dict['unclaimed_episodes']) + shared_mem_dict['unclaimed_episodes'] = [] + logger.info(f"[clean_up_engine_status] Cleared {num_unclaimed} unclaimed episodes") + @app.post("/update_engine_status", response_model=BoolResponse) async def update_engine_status(req: UpdateEngineStatusRequest): """Update the current engine status.""" - if req.engine_status not in [ - "ENGINE.OFFLINE", - "ENGINE.BOOTING", - "ENGINE.ROLLING", - "ENGINE.WEIGHT_SYNCING", - "ENGINE.WEIGHT_EXPORTING" - ]: + if req.engine_status not in VALID_STATUSES: return BoolResponse(success=False, failure_reason="Invalid engine status") + previous_status = shared_mem_dict['engine_status'] shared_mem_dict['engine_status'] = req.engine_status + if previous_status in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"] and req.engine_status not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + clean_up_engine_status(shared_mem_dict_lock, shared_mem_dict) + + logger.info(f"[update_engine_status] Engine status set to {req.engine_status}") return BoolResponse(success=True) @@ -330,6 +374,13 @@ async def get_engine_status(): @app.post("/register_episode", response_model=BoolResponse) async def register_episode(req: RegisterEpisodeRequest): """(From task_runner) Register a new episode as ready to roll.""" + engine_status = shared_mem_dict['engine_status'] + if engine_status not in ["ENGINE.ROLLING"]: + return BoolResponse( + success=False, + failure_reason=f"Engine is not in rolling state. Cannot register episode.", + ) + episode_uuid = req.episode_uuid es = EpisodeStatus( episode_uuid=req.episode_uuid, @@ -342,27 +393,19 @@ async def register_episode(req: RegisterEpisodeRequest): es.latest_activity_timestamp = time.time() with shared_mem_dict_lock: - engine_status = shared_mem_dict['engine_status'] - if engine_status not in ["ENGINE.ROLLING"]: - return BoolResponse( - success=False, - failure_reason=f"Engine has already shutdown. Cannot register episode.", - ) - - shared_mem_dict[f"episodes-{episode_uuid}"] = es + shared_mem_dict[ep_key(episode_uuid)] = es shared_mem_dict['unclaimed_episodes'] += [req.episode_uuid] - return BoolResponse( - success=True, - ) + return BoolResponse(success=True) @app.post("/claim_episode", response_model=ClaimEpisodeResponse) async def claim_episode(req: ClaimEpisodeRequest): """(From client) Claim an available episode to rollout.""" - find_claimed_episodes_that_need_to_be_unclaimed() + # find_claimed_episodes_that_need_to_be_unclaimed() engine_status = shared_mem_dict['engine_status'] + if engine_status != "ENGINE.ROLLING": fail_cause = f"Engine not ready. Current status: [{engine_status}]." advise = "" @@ -374,6 +417,8 @@ async def claim_episode(req: ClaimEpisodeRequest): advise = "Engine is syncing weights. Try again (maybe 1 minute) later." elif engine_status == "ENGINE.WEIGHT_EXPORTING": advise = "Engine is exporting weights (fsdp -> hf safetensor). Try again (maybe 1 minute) later." + elif engine_status == "ENGINE.ROLLING_POST": + advise = "Engine is in post-rolling phase. Try again (maybe 1 minute) later." return ClaimEpisodeResponse( success=False, client_uuid=req.client_uuid, @@ -401,14 +446,14 @@ async def claim_episode(req: ClaimEpisodeRequest): shared_mem_dict['unclaimed_episodes'] = shared_mem_dict['unclaimed_episodes'][1:] # get episode - es:EpisodeStatus = shared_mem_dict[f"episodes-{episode_uuid}"] + es:EpisodeStatus = shared_mem_dict[ep_key(episode_uuid)] es.episode_status = "claimed" es.episode_type = req.episode_type es.client_uuid = req.client_uuid es.latest_activity_timestamp = time.time() es.allow_discard_timeout = req.allow_discard_timeout - shared_mem_dict[f"episodes-{episode_uuid}"] = es + shared_mem_dict[ep_key(episode_uuid)] = es openai_base_url = es.openai_base_url openai_api_key = es.openai_api_key @@ -427,31 +472,41 @@ async def claim_episode(req: ClaimEpisodeRequest): @app.post("/end_episode", response_model=EndEpisodeResponse) async def end_episode(req: EndEpisodeRequest): + + engine_status = shared_mem_dict['engine_status'] + if engine_status not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + raise HTTPException(status_code=400, detail=f"Engine is not in rolling state. Current status: [{engine_status}]. Cannot end episode.") + # receive workflow output data client_uuid = req.client_uuid episode_uuid = req.episode_uuid workflow_output = req.workflow_output task_id = req.task_id + assert "task_id" in workflow_output.metadata, "workflow_output.metadata must contain task_id" assert workflow_output.metadata["task_id"] == task_id, "workflow_output.metadata.task_id must match req.task_id" if 'episodes' not in shared_mem_dict: logger.error(f"[server] No episodes registered yet.") raise HTTPException(status_code=400, detail=f"No episodes registered yet.") - if (f"episodes-{episode_uuid}") not in shared_mem_dict: + + if (ep_key(episode_uuid)) not in shared_mem_dict: logger.error(f"[server] Episode {episode_uuid} not found.") raise HTTPException(status_code=400, detail=f"Episode {episode_uuid} not found.") # send workflow_output to zmq assert 'episodes' in shared_mem_dict - episode_type = shared_mem_dict[f"episodes-{episode_uuid}"].episode_type + episode_type = shared_mem_dict[ep_key(episode_uuid)].episode_type if episode_type == "train": _register_final_episode_output(episode_uuid, workflow_output, shared_mem_dict, shared_mem_dict_lock) elif episode_type == "eval": - _revert_episode_to_unclaimed(episode_uuid) + if engine_status in ["ENGINE.ROLLING"]: + _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock) + else: + _delete_episode_record(episode_uuid, shared_mem_dict, shared_mem_dict_lock) else: raise HTTPException(status_code=400, detail=f"Unknown episode_type: {episode_type}") @@ -462,6 +517,11 @@ async def end_episode(req: EndEpisodeRequest): @app.post("/abort_episode", response_model=EndEpisodeResponse) async def abort_episode(req: EndEpisodeRequest): + + engine_status = shared_mem_dict['engine_status'] + if engine_status not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + return EndEpisodeResponse(success=True) + # receive workflow output data episode_uuid = req.episode_uuid workflow_output = req.workflow_output @@ -473,28 +533,46 @@ async def abort_episode(req: EndEpisodeRequest): if 'episodes' not in shared_mem_dict: logger.error(f"[server] No episodes registered yet.") return EndEpisodeResponse(success=True) - if (f"episodes-{episode_uuid}") not in shared_mem_dict: + + if (ep_key(episode_uuid)) not in shared_mem_dict: logger.error(f"[server] Episode {episode_uuid} not found.") return EndEpisodeResponse(success=True) - _revert_episode_to_unclaimed(episode_uuid) + if engine_status in ["ENGINE.ROLLING"]: + _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock) + else: + _delete_episode_record(episode_uuid, shared_mem_dict, shared_mem_dict_lock) - # return success return EndEpisodeResponse(success=True) @app.post("/can_continue_episode", response_model=CanContinueEpisodeResponse) async def can_continue_episode(req: CanContinueEpisodeRequest): - can_continue = (f"episodes-{req.episode_uuid}" in shared_mem_dict) - can_continue = can_continue and shared_mem_dict[f"episodes-{req.episode_uuid}"]["episode_status"] == "claimed" + engine_status = shared_mem_dict['engine_status'] + if engine_status not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + return CanContinueEpisodeResponse(can_continue=False) + + can_continue = (ep_key(req.episode_uuid) in shared_mem_dict) + can_continue = can_continue and shared_mem_dict[ep_key(req.episode_uuid)].episode_status == "claimed" + return CanContinueEpisodeResponse(can_continue=can_continue) + @app.post("/is_episode_claimed", response_model=BoolResponse) + async def is_episode_claimed(req: CanContinueEpisodeRequest): + engine_status = shared_mem_dict['engine_status'] + if engine_status not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + return BoolResponse(success=False) + if shared_mem_dict[ep_key(req.episode_uuid)].episode_status == "claimed": + return BoolResponse(success=True) + else: + return BoolResponse(success=False) + @app.post("/get_episode_buffer", response_model=EpisodeBufferResponse) async def get_episode_buffer(): result = [ - v for k, v in shared_mem_dict.items() if k.startswith("episodes-") + v for k, v in shared_mem_dict.items() if is_key_epsisode_status(k) ] return EpisodeBufferResponse(buffer=result) @@ -521,6 +599,7 @@ async def stop_engine(): def kill_process_tree(shared_mem_dict_lock=None, shared_mem_dict=None): + logger.exception("[stop_engine] Initiating engine shutdown and cleanup...") try: import psutil @@ -589,7 +668,7 @@ def kill_process_tree(shared_mem_dict_lock=None, shared_mem_dict=None): episode_keys = [] if shared_mem_dict and shared_mem_dict_lock: with shared_mem_dict_lock: - episode_keys = [k for k in shared_mem_dict.keys() if k.startswith("episodes-")] + episode_keys = [k for k in shared_mem_dict.keys() if is_key_epsisode_status(k)] for key in episode_keys: del shared_mem_dict[key] logger.info(f"[stop_engine] Removed episode: {key}") diff --git a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py index 949ee679..fe7c3387 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py +++ b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py @@ -3,10 +3,19 @@ import httpx from typing import List from pydantic import BaseModel -from ajet.schema.task import WorkflowOutput from loguru import logger +from ajet.schema.task import WorkflowOutput from ajet.utils.networking import find_free_port +from ajet.utils.retry import retry_with_backoff +VALID_STATUSES = [ + "ENGINE.OFFLINE", + "ENGINE.BOOTING", + "ENGINE.ROLLING", + "ENGINE.ROLLING_POST", + "ENGINE.WEIGHT_SYNCING", + "ENGINE.WEIGHT_EXPORTING" +] class SyncTrainConfigRequest(BaseModel): yaml_as_string: str @@ -83,14 +92,8 @@ def get_interchange_server_url(config): return base_url -def http_change_engine_status(config: str, new_status: str): - if new_status not in [ - "ENGINE.OFF", - "ENGINE.BOOTING", - "ENGINE.ROLLING", - "ENGINE.WEIGHT_SYNCING", - "ENGINE.WEIGHT_EXPORTING" - ]: +def http_change_engine_status(config, new_status: str): + if new_status not in VALID_STATUSES: raise ValueError(f"Invalid engine status: {new_status}") resp = httpx.post( @@ -102,10 +105,28 @@ def http_change_engine_status(config: str, new_status: str): logger.success(f"Changed engine status to {new_status}") +def is_episode_claimed(config, episode_uuid: str) -> bool: + resp = httpx.post( + f"{get_interchange_server_url(config)}/is_episode_claimed", + json={"client_uuid": "", "episode_uuid": episode_uuid}, + timeout=5 + ) + resp.raise_for_status() + result = BoolResponse.model_validate(resp.json()) + return result.success + + +@retry_with_backoff(max_retry=15, backoff_fn=lambda attempt: 2) +def http_register_episode(config, + episode_uuid: str, + openai_base_url: str, + openai_api_key: str, + zmq_listen_result_addr: str, + should_exit_soft): -def http_register_episode(config, episode_uuid: str, - openai_base_url: str, openai_api_key: str, - zmq_listen_result_addr: str): + if should_exit_soft(): + logger.warning(f"Exiting before registering episode {episode_uuid}") + return None # parse episode_uuid, openai_base_url, openai_api_key interchange_http_addr = get_interchange_server_url(config) @@ -116,19 +137,19 @@ def http_register_episode(config, episode_uuid: str, zmq_listen_result_addr=zmq_listen_result_addr, ) # send http request to tinkerscript server to register episode - response = httpx.post( f"{interchange_http_addr}/register_episode", json=rer.model_dump(), # 或者 rer.model_dump() 如果使用 Pydantic v2 - timeout=30 + timeout=2 ) response.raise_for_status() result = response.json() if not result.get('success'): - raise RuntimeError(f"Failed to register episode {episode_uuid}") + logger.warning(f"Failed to register episode {episode_uuid}") + return None if DEBUG: logger.info(f"Successfully registered episode {episode_uuid}") - return rer + return True def get_zmq_socket(config, episode_uuid: str, tag: str = ""): diff --git a/ajet/utils/retry.py b/ajet/utils/retry.py index 9773c255..7f33466b 100644 --- a/ajet/utils/retry.py +++ b/ajet/utils/retry.py @@ -1,11 +1,11 @@ import time from functools import wraps from typing import Any, Callable, Optional, TypeVar - from loguru import logger -from ajet.utils.testing_utils import TestFailException, TestSuccessException -from ajet.task_runner.tinkerscript_runner import SwarmReceiveAbortException +class SwarmReceiveAbortException(Exception): + pass + T = TypeVar("T") @@ -18,6 +18,8 @@ def retry_with_backoff( """Retry decorator with exponential backoff and structured logging.""" def decorator(func: Callable[..., T]) -> Callable[..., T]: + from ajet.utils.testing_utils import TestFailException, TestSuccessException + @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> T: target_max_retry = max_retry @@ -36,7 +38,6 @@ def wrapper(*args: Any, **kwargs: Any) -> T: raise exc except TestFailException as exc: # noqa: BLE001 raise exc - except Exception as exc: # noqa: BLE001 if attempt < target_max_retry - 1: logger.bind(exception=True).exception( diff --git a/ajet/workflow.py b/ajet/workflow.py index 58c8757d..b2eaaf16 100644 --- a/ajet/workflow.py +++ b/ajet/workflow.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field -from ajet import AjetTuner +from ajet.tuner import AjetTuner from ajet.schema.task import WorkflowOutput, WorkflowTask diff --git a/ajet_tinkerscript_threading.py b/ajet_tinkerscript_threading.py index 52546320..7b31d4e7 100644 --- a/ajet_tinkerscript_threading.py +++ b/ajet_tinkerscript_threading.py @@ -1,27 +1,29 @@ import re +import time import threading import requests from loguru import logger from textwrap import dedent -from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_client import TinkerScriptClient -from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo -from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey from ajet import WorkflowOutput +from ajet.schema.task import Task +from ajet.copilot.job import AgentJetJob from ajet.task_reader import RouterTaskReader from ajet.utils.retry import retry_with_backoff -from ajet.schema.task import Task +from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_client import TinkerScriptClient from concurrent.futures import ThreadPoolExecutor # --------- configurations that take effect locally ------------- LOCAL_GRPO_N = 4 # grpo group size LOCAL_NUM_EPOCH = 10000 LOCAL_NUM_EPOCH = 1 -LOCAL_MAX_PARALLEL = 32 +LOCAL_MAX_PARALLEL = 8 LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main" REMOTE_TINKERJET_URL = "http://localhost:10086" # Change to your tinkerscript remote url # --------- configurations that take effect remotely ------------- +REMOTE_BATCH_SIZE = 4 REMOTE_ALLOCATE_GPU_PER_NODE = 4 REMOTE_TRAIN_MODEL_01 = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct' @@ -43,51 +45,53 @@ def main(): ) ) - # Hand shake with remote tinkerscript server + # # Hand shake with remote tinkerscript server tinkerscript_remote = TinkerScriptClient(REMOTE_TINKERJET_URL) - # tinkerscript_remote.stop_engine() - tinkerscript_remote.auto_sync_train_config_and_start_engine( - AgentJetJob( - algorithm="grpo", - n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, - model=REMOTE_TRAIN_MODEL_01, - grpo_n=LOCAL_GRPO_N, - ) - ) - # tinkerscript_remote = connect_to_tinkerscript_server(sync_train_config=False, start_engine=False) - submit_sem = threading.BoundedSemaphore(LOCAL_MAX_PARALLEL) + # tinkerscript_remote.auto_sync_train_config_and_start_engine( + # AgentJetJob( + # algorithm="grpo", + # n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, + # model=REMOTE_TRAIN_MODEL_01, + # batch_size=REMOTE_BATCH_SIZE, + # grpo_n=LOCAL_GRPO_N, + # ) + # ) - # Define rollout def rollout(task): + group_reward = [] try: - group_reward = [] - for i in range(LOCAL_GRPO_N): - # begin episode - episode_uuid, api_baseurl_key = tinkerscript_remote.begin_episode() - # execute agent - workflow_output = execute_agent(task, api_baseurl_key) - # report output back to tinkerscript remote - tinkerscript_remote.end_episode(task, episode_uuid, workflow_output) - # collect reward - group_reward.append(workflow_output.reward) + for _ in range(LOCAL_GRPO_N): + try: + # begin episode + episode_uuid, api_baseurl_key = tinkerscript_remote.begin_episode() + # execute agent + workflow_output = execute_agent(task, api_baseurl_key) + # report output back to tinkerscript remote + tinkerscript_remote.end_episode(task, episode_uuid, workflow_output) + # collect reward + group_reward.append(workflow_output.reward) + except Exception as e: + logger.exception("Exception during rollout:", e) + print(f"Group reward mean & std: {sum(group_reward)/len(group_reward)} +/- { (max(group_reward)-min(group_reward))/2 }") except Exception as e: - logger.exception("Exception during rollout:", e) - finally: - submit_sem.release() - - # Main Training loop - with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as executor: - for epoch in range(LOCAL_NUM_EPOCH): - for task in dataset.get_training_tasks(): - print(f"Submitting task for epoch {epoch}") - submit_sem.acquire() - executor.submit(rollout, task) - - tinkerscript_remote.stop_engine() - # model_path = tinkerscript_remote.download_latest_model(path='./tinkerscript_saved_model') - - # Get tuned model from tinkerscript remote + logger.exception("Exception during rollout group", e) + + task_batch = [] + for i, task in enumerate(dataset.get_training_tasks()): + task_batch += [task] + + if len(task_batch) == 3*REMOTE_BATCH_SIZE: + print('*********** beginning a new batch of tasks... ***********') + with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as executor: + for task in task_batch: + executor.submit(rollout, task) + executor.shutdown(wait=True) + task_batch = [] + print('*********** tasks completed, wait a minute... ***********') + time.sleep(60) + + return None @@ -109,7 +113,7 @@ def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): response = requests.post( f"{base_url}/chat/completions", json = { "model": "fill_whatever_model", "messages": messages, }, headers = { "Authorization": f"Bearer {api_key}" } ) final_answer = response.json()['choices'][0]['message']['content'] - print(final_answer) + # print(final_answer) # Compute reward reference_answer = reference_answer.split("####")[-1].strip() pattern = r"\\boxed\{([^}]*)\}" diff --git a/tutorial/example_academic_trans/trans_roll.py b/tutorial/example_academic_trans/trans_roll.py index e060d895..f7c50730 100644 --- a/tutorial/example_academic_trans/trans_roll.py +++ b/tutorial/example_academic_trans/trans_roll.py @@ -94,9 +94,9 @@ def rollout(task): time.sleep(1) - tinkerscript_remote.stop_engine() + # tinkerscript_remote.stop_engine() # model_path = tinkerscript_remote.download_latest_model(path='./tinkerscript_saved_model') - + time.sleep(10000) # Get tuned model from tinkerscript remote return None From 175e259e6742d25efc132f8822db6ff2b202b4e1 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Thu, 5 Feb 2026 03:32:24 +0800 Subject: [PATCH 19/25] rename to agentjet swarm --- ajet/backbone/main_vllm.py | 6 +- ajet/backbone/trainer_verl.py | 10 +-- ajet/context_tracker/multiagent_tracking.py | 6 +- ajet/copilot/job.py | 4 +- ajet/default_config/ajet_default.yaml | 6 +- ajet/default_config/ajet_ts_default.yaml | 6 +- ajet/launcher.py | 20 +++--- .../task_reader/document_reader/doc_reader.py | 2 +- ajet/task_rollout/native_parallel_worker.py | 2 +- ajet/task_rollout/single_worker.py | 8 +-- ajet/task_runner/base_runner.py | 11 +++- ...tinkerscript_runner.py => swarm_runner.py} | 18 ++---- ajet/tuner.py | 1 - .../experimental/as_oai_model_client.py | 37 ++++++----- .../experimental/as_oai_model_server.py | 36 +++++------ ...kerscript_client.py => as_swarm_client.py} | 17 +++--- ...kerscript_server.py => as_swarm_server.py} | 61 +++++++++++-------- .../experimental/interchange_utils.py | 3 +- ...pt_threading.py => ajet_swarm_threading.py | 45 +++++++------- ajet_tinkerscript.md | 2 +- docs/en/platform_comparison.md | 2 +- docs/en/workflow.md | 2 +- docs/index.md | 2 +- tinkerscript.md | 18 +++--- tinkerscript_1.md | 12 ++-- tutorial/demo_tinkerjet/README.md | 14 ++--- .../ajet_tinkerscript_default.yaml | 4 +- tutorial/example_academic_trans/trans.py | 2 +- tutorial/example_academic_trans/trans_roll.py | 28 ++++----- 29 files changed, 202 insertions(+), 183 deletions(-) rename ajet/task_runner/{tinkerscript_runner.py => swarm_runner.py} (92%) rename ajet/tuner_lib/weight_tuner/experimental/{as_tinkerscript_client.py => as_swarm_client.py} (96%) rename ajet/tuner_lib/weight_tuner/experimental/{as_tinkerscript_server.py => as_swarm_server.py} (92%) rename ajet_tinkerscript_threading.py => ajet_swarm_threading.py (77%) diff --git a/ajet/backbone/main_vllm.py b/ajet/backbone/main_vllm.py index c697bff4..2cdde610 100644 --- a/ajet/backbone/main_vllm.py +++ b/ajet/backbone/main_vllm.py @@ -144,7 +144,7 @@ def run(config): max_parallel = config.ajet.debug.debug_max_parallel n_task = config.ajet.debug.debug_first_n_tasks vllm_port = config.ajet.debug.debug_vllm_port - enable_tinkerscript_mode = config.ajet.enable_tinkerscript_mode + enable_swarm_mode = config.ajet.enable_swarm_mode # --------- init --------- async_rollout_manager = ChatCompletionScheduler( @@ -168,7 +168,7 @@ def run(config): logger.info(tasks[:n_task]) ctx_tracker = parallel_env.rollout( tasks=tasks[:n_task], - mode="sample" if not enable_tinkerscript_mode else "sample-ts", # type: ignore + mode="sample" if not enable_swarm_mode else "sample-ts", # type: ignore epoch="1" ) _ = parallel_env.to_dataproto(ctx_tracker) @@ -189,7 +189,7 @@ def main(config): if config.ajet.enable_experimental_interchange_server: from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server start_interchange_server(config) - if config.ajet.enable_tinkerscript_mode: + if config.ajet.enable_swarm_mode: from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status http_change_engine_status(config, "ENGINE.ROLLING") diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 70e5570b..1ddb0fef 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -458,7 +458,7 @@ def init_workers(self): def _update_interchange_server_status_flag(self, status: str): if self.config.ajet.enable_experimental_interchange_server: - if self.config.ajet.enable_tinkerscript_mode: + if self.config.ajet.enable_swarm_mode: from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status http_change_engine_status(self.config, status) @@ -493,7 +493,7 @@ def fit(self): # noqa: C901 # perform validation before training # currently, we only support validation using the reward_function. - if (self.val_reward_fn is not None) and (self.config.trainer.get("val_before_train", True)) and (not self.config.ajet.enable_tinkerscript_mode): + if (self.val_reward_fn is not None) and (self.config.trainer.get("val_before_train", True)) and (not self.config.ajet.enable_swarm_mode): val_metrics = self._validate() assert val_metrics, f"{val_metrics=}" pprint(f"Initial validation metrics: {val_metrics}") @@ -651,7 +651,7 @@ def fit(self): # noqa: C901 [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object, ) - discard_original_batch = self.config.ajet.enable_tinkerscript_mode + discard_original_batch = self.config.ajet.enable_swarm_mode batch = union_gen_batch_via_task_id(tasks, batch, gen_batch_output, discard_original_batch) batch.batch["response_mask"] = compute_response_mask(batch) @@ -784,7 +784,7 @@ def fit(self): # noqa: C901 self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) - and (not self.config.ajet.enable_tinkerscript_mode) + and (not self.config.ajet.enable_swarm_mode) ): with marked_timer("testing", timing_raw, color="green"): val_metrics: dict = self._validate() @@ -958,7 +958,7 @@ def _validate(self): dtype=object, ) tasks = tasks[: len(main_val_dataset)] - discard_original_batch = self.config.ajet.enable_tinkerscript_mode + discard_original_batch = self.config.ajet.enable_swarm_mode test_batch = union_gen_batch_via_task_id(tasks, test_batch, test_output_gen_batch, discard_original_batch) # test_batch = test_batch.union(test_output_gen_batch) test_batch.meta_info["validate"] = True diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py index 35607224..f0c3fef1 100644 --- a/ajet/context_tracker/multiagent_tracking.py +++ b/ajet/context_tracker/multiagent_tracking.py @@ -48,14 +48,14 @@ def __init__( self, tokenizer: PreTrainedTokenizer, config, - should_interrupt_fn, + should_interrupt_soft_fn, should_interrupt_hard_fn, generated_token_callback_fn, **kwargs, ): super().__init__(config, tokenizer, **kwargs) self.tokenizer = tokenizer - self.should_interrupt_fn = should_interrupt_fn + self.should_interrupt_soft_fn = should_interrupt_soft_fn self.should_interrupt_hard_fn = should_interrupt_hard_fn self.generated_token_callback_fn = generated_token_callback_fn self.context_overflow = False @@ -601,7 +601,7 @@ def check_context_token_num_safe( token_overflow = False else: token_overflow = True - if self.should_interrupt_fn(): + if self.should_interrupt_soft_fn(): ret = (False, token_overflow, "externally_interrupted") elif self.already_mad_flag and self.config.ajet.rollout.agent_madness_termination: ret = (False, token_overflow, "already_mad") diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py index 2b6acde5..4f0f5c7b 100644 --- a/ajet/copilot/job.py +++ b/ajet/copilot/job.py @@ -45,11 +45,11 @@ def __init__( n_gpu_for_infer: int | None = None, # only for trinity backbone grpo_n: int = 8, batch_size: int = 32, - tinkerscript_mode: bool = True, + swarm_mode: bool = True, *kwargs, ) -> None: self.backbone = backbone - if tinkerscript_mode: + if swarm_mode: default_yaml = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml")) else: default_yaml = None diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index 49a6c9a8..f1f65b0a 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -282,8 +282,8 @@ ajet: # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature - enable_tinkerscript_mode: False - # both tinkerscript / oai share the same interchange server + enable_swarm_mode: False + # both swarm / oai share the same interchange server enable_experimental_interchange_server: False # interchange server configuration interchange_server: @@ -292,7 +292,7 @@ ajet: num_fastapi_process: 2 # 1, 2 or 4 is fine max_fastapi_threads: 512 # 64 or 128 is fine max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` - already_started: False # do not edit, used by `tinkerscript` + already_started: False # do not edit, used by `swarm` task_runner: diff --git a/ajet/default_config/ajet_ts_default.yaml b/ajet/default_config/ajet_ts_default.yaml index be7ca5a9..8f631421 100644 --- a/ajet/default_config/ajet_ts_default.yaml +++ b/ajet/default_config/ajet_ts_default.yaml @@ -23,15 +23,15 @@ ajet: # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature enable_experimental_interchange_server: True # train in cloud, run episode locally - enable_tinkerscript_mode: True - # both tinkerscript / oai share the same interchange server + enable_swarm_mode: True + # both swarm / oai share the same interchange server interchange_server: interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) interchange_server_port: 10086 num_fastapi_process: 2 # 1, 2 or 4 is fine max_fastapi_threads: 512 # 64 or 128 is fine max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` - already_started: False # do not edit, used by `tinkerscript` + already_started: False # do not edit, used by `swarm` rollout: # maximum number of parallel environments / simulate workers diff --git a/ajet/launcher.py b/ajet/launcher.py index 3bd484a5..3e4d2cb5 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -35,10 +35,10 @@ def parse_args(): help="verl or trinity or debug", ) parser.add_argument( - "--tinkerscript-server", + "--swarm-server", action="store_true", default=False, - help="Enable TinkerScript server mode", + help="Enable Swarm server mode", ) parser.add_argument( "--conf", @@ -146,12 +146,12 @@ def check_model_file_exists(exp_config): assert os.path.exists(model_path), f"Model path {model_path} does not exist. Please check your configuration." -def start_tinkerscript_server(env, config): +def start_swarm_server(env, config): config = dict_to_namespace(config) - assert config.ajet.enable_tinkerscript_mode, \ - "Please enable_tinkerscript_mode in config to start tinkerscript server." + assert config.ajet.enable_swarm_mode, \ + "Please enable_swarm_mode in config to start swarm server." assert config.ajet.enable_experimental_interchange_server, \ - "Please enable_experimental_interchange_server in config to start tinkerscript server." + "Please enable_experimental_interchange_server in config to start swarm server." from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server start_interchange_server(config, blocking=True, env=env) @@ -191,9 +191,9 @@ def main(): # read configuration from yaml exp_config = None exp_dir = args.exp_dir or "saved_experiments" - if args.tinkerscript_server and (not args.conf): + if args.swarm_server and (not args.conf): args.conf = os.path.abspath(os.path.join(os.path.dirname(__file__), "default_config/ajet_ts_default.yaml")) - assert os.path.exists(args.conf), "Please provide a valid config file for tinkerscript server mode." + assert os.path.exists(args.conf), "Please provide a valid config file for swarm server mode." if args.conf: yaml_path = args.conf ( @@ -206,8 +206,8 @@ def main(): # setup environment variables env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp) - if args.tinkerscript_server: - start_tinkerscript_server(env, exp_config) + if args.swarm_server: + start_swarm_server(env, exp_config) return if args.with_ray: diff --git a/ajet/task_reader/document_reader/doc_reader.py b/ajet/task_reader/document_reader/doc_reader.py index e73c3bdf..5083d33a 100644 --- a/ajet/task_reader/document_reader/doc_reader.py +++ b/ajet/task_reader/document_reader/doc_reader.py @@ -11,7 +11,7 @@ try: from unstructured.partition.auto import partition except Exception: - logger.warning("Cannot import dependency `unstructured`") + logger.info("`unstructured` is not installed.") from ajet.schema.document import Document from ajet.task_reader.document_reader.document_reader_base import ( diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py index 35b56172..a47db114 100644 --- a/ajet/task_rollout/native_parallel_worker.py +++ b/ajet/task_rollout/native_parallel_worker.py @@ -144,7 +144,7 @@ def rollout( epoch: str, ) -> List[BaseContextTracker]: """Delegate to dynamic rollout when oversampling is enabled.""" - if self.config.ajet.enable_tinkerscript_mode: + if self.config.ajet.enable_swarm_mode: return self.rollout_swarm(tasks, mode, epoch) elif ( mode == "sample" diff --git a/ajet/task_rollout/single_worker.py b/ajet/task_rollout/single_worker.py index 908c9d47..56fcb54b 100644 --- a/ajet/task_rollout/single_worker.py +++ b/ajet/task_rollout/single_worker.py @@ -15,7 +15,7 @@ from ajet.task_rollout.async_llm_bridge import AsyncLlmBridge from ajet.task_rollout.resource_keeper import ResourceKeeper from ajet.task_runner.general_runner import GeneralRunner -from ajet.task_runner.tinkerscript_runner import TinkerScriptRunner +from ajet.task_runner.swarm_runner import SwarmRunner from ajet.utils.retry import retry_with_backoff from ajet.utils.retry import SwarmReceiveAbortException from ajet.utils.sample import get_sample_params @@ -64,7 +64,7 @@ def __init__( assert isinstance(self.pad_token_id, int), "pad_token_id must be an integer" self.current_token = 0 self.current_global_steps: int | str = "NA" - self.enable_tinkerscript_mode = config.ajet.enable_tinkerscript_mode + self.enable_swarm_mode = config.ajet.enable_swarm_mode self.async_llm_bridge = AsyncLlmBridge( config=config, async_rollout_manager=async_rollout_manager, @@ -116,8 +116,8 @@ def rollout_env_worker( with ResourceKeeper(workflow_task, config=self.config) as resource_keeper: try: workflow_task = resource_keeper.prepare() - if self.enable_tinkerscript_mode: - agent_runner = TinkerScriptRunner( + if self.enable_swarm_mode: + agent_runner = SwarmRunner( llm_inference_fn=llm_inference_fn, tokenizer=self.tokenizer, config=self.config ) else: diff --git a/ajet/task_runner/base_runner.py b/ajet/task_runner/base_runner.py index ad457a5c..32f47fa1 100644 --- a/ajet/task_runner/base_runner.py +++ b/ajet/task_runner/base_runner.py @@ -11,6 +11,7 @@ from ajet.utils.async_utils import run_async_coroutine_with_timeout from ajet.utils.dynamic_import import dynamic_import from ajet.workflow import Workflow +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import is_episode_claimed gc_lock = Lock() @@ -48,7 +49,7 @@ def get_judge(self) -> BaseJudge: # type: ignore def runner_hooks(self, observation_window, task_thread_index, workflow_task): - def should_interrupt_fn() -> bool: + def should_interrupt_soft_fn() -> bool: if (observation_window["stop"] is not None) and observation_window["stop"][task_thread_index]: # Check if the thread should stop (because other threads have completed, making this thread useless) return True return False @@ -56,13 +57,19 @@ def should_interrupt_fn() -> bool: def should_interrupt_hard_fn() -> bool: if (observation_window["hard_stop"] is not None) and observation_window["hard_stop"][task_thread_index]: # Check if the thread should stop (because other threads have completed, making this thread useless) return True + if (observation_window["stop"] is not None) and observation_window["stop"][task_thread_index]: # check soft condition + # if soft condition met, check if episode is claimed + has_claimed = is_episode_claimed(self.config, workflow_task.episode_uuid) + if not has_claimed: + # if not claimed by now (ENGINE.ROLLING_POST), this episode will never be claimed again, so we can hard stop + return True return False def generated_token_callback_fn(token_array): observation_window["token"][task_thread_index] += len(token_array) return { - "should_interrupt_fn": should_interrupt_fn, + "should_interrupt_soft_fn": should_interrupt_soft_fn, "should_interrupt_hard_fn": should_interrupt_hard_fn, "generated_token_callback_fn": generated_token_callback_fn, } diff --git a/ajet/task_runner/tinkerscript_runner.py b/ajet/task_runner/swarm_runner.py similarity index 92% rename from ajet/task_runner/tinkerscript_runner.py rename to ajet/task_runner/swarm_runner.py index b84f1f3d..07f39901 100644 --- a/ajet/task_runner/tinkerscript_runner.py +++ b/ajet/task_runner/swarm_runner.py @@ -21,7 +21,7 @@ context = zmq.Context() atexit.register(context.term) -class TinkerScriptRunner(BaseAgentRunner): +class SwarmRunner(BaseAgentRunner): def register_episode_and_wait_output( self, @@ -33,7 +33,7 @@ def register_episode_and_wait_output( should_exit_soft:Callable, should_exit_hard:Callable ) -> WorkflowOutput | None: - """Register the episode as ready in the TinkerScript data interchange center.""" + """Register the episode as ready in the Swarm data interchange center.""" # parse episode_uuid, openai_base_url, openai_api_key zmq_listen_result_addr, ipc_path = get_zmq_socket(self.config, episode_uuid, tag="workflow") success = http_register_episode( @@ -63,11 +63,11 @@ def register_episode_and_wait_output( while True: # : - # : ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py + # : ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py # : socket.send_string(workflow_output.model_dump_json()) # : workflow_output: WorkflowOutput # : - # : ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py + # : ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py # : socket.send_string("RUNNER.SPECIAL.RESET_CONTEXT_TRACKER") # : "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER" try: @@ -77,12 +77,6 @@ def register_episode_and_wait_output( logger.warning(f'{episode_uuid} Exiting workflow due to should_exit_hard signal.') context_tracker.reset() raise SwarmReceiveAbortException(f"Episode {episode_uuid} aborted due to system exit.") - elif should_exit_soft(): - has_claimed = is_episode_claimed(self.config, episode_uuid) - if not has_claimed: - raise SwarmReceiveAbortException(f"Episode {episode_uuid} aborted due to system exit.") - else: - continue else: continue # process messages @@ -127,7 +121,7 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: workflow_task=workflow_task, ) - should_exit_soft = hooks['should_interrupt_fn'] + should_exit_soft = hooks['should_interrupt_soft_fn'] should_exit_hard = hooks['should_interrupt_hard_fn'] if should_exit_soft() or should_exit_hard(): @@ -180,7 +174,7 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: workflow_output.is_success, ) else: - raise ValueError("workflow_output.reward is None in TinkerScriptRunner, this is currently not allowed.") + raise ValueError("workflow_output.reward is None in SwarmRunner, this is currently not allowed.") # release gym_env workflow_task.gym_env = None # clear gym env client reference to avoid serialization issue diff --git a/ajet/tuner.py b/ajet/tuner.py index 90fcbbfd..f8be6ab4 100644 --- a/ajet/tuner.py +++ b/ajet/tuner.py @@ -186,4 +186,3 @@ def terminate_episode(self): if self.enable_interchange_server: if (self.proxy_client_started is True) and hasattr(self, "interchange_client"): self.interchange_client._should_terminate = True - print(f'-->self.interchange_client._should_terminate = {self.interchange_client.should_terminate}') diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py index 396457b8..08a7cfb0 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py @@ -14,7 +14,7 @@ from openai.types.chat.chat_completion import ChatCompletion from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import InterchangeCompletionRequest, API_KEY_PREFIX from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor -from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import get_zmq_socket +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import get_zmq_socket, is_episode_claimed context = zmq.Context() atexit.register(context.term) @@ -103,20 +103,29 @@ async def llm_infer( @property - def should_terminate(self) -> bool: - try: - should_interrupt = self.context_tracker.should_interrupt_hard_fn() - return self._should_terminate or should_interrupt - except: - return self._should_terminate + def should_soft_terminate(self) -> bool: + if self._should_terminate: + return True + return self.context_tracker.should_interrupt_soft_fn() + + @property + def should_hard_terminate(self) -> bool: + if self._should_terminate: + return True + if not self.config.ajet.enable_swarm_mode: + return self.should_soft_terminate + else: + return self.context_tracker.should_interrupt_hard_fn() + def begin_service(self): """ Starts the zmq communication loop. """ - if self.should_terminate: + if self.should_soft_terminate or self.should_hard_terminate: return self.episode_contect_address + if DEBUG: logger.info(f"[client] {self.episode_uuid} | Starting InterchangeClient service loop...") self.socket = context.socket(zmq.REP) self.socket.bind(f"{self.episode_contect_address}") @@ -130,14 +139,14 @@ def begin_service(self): time.sleep(0.5) wait_time = 1 while future._state == 'PENDING': - if self.should_terminate: + if self.should_soft_terminate or self.should_hard_terminate: future.cancel() + self.socket.close() + if os.path.exists(self.ipc_path): os.remove(self.ipc_path) return self.episode_contect_address time.sleep(min(wait_time * 2, 10)) wait_time += 1 - if self.should_terminate: - future.cancel() - return self.episode_contect_address + if DEBUG: logger.info(f"[client] {self.episode_uuid} | Future ready...") return self.episode_contect_address @@ -150,14 +159,14 @@ def _begin_service_threading(self): if DEBUG: logger.info(f"[client] {self.episode_uuid} | Starting ZMQ socket bind complete") try: - while not self.should_terminate: + while not self.should_hard_terminate: # listen for next request from remote try: # if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() has begun (should_terminate {self.should_terminate})") message = self.socket.recv_string() # if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() is done") except zmq.Again as e: - if self.should_terminate: + if self.should_hard_terminate: # abort_episode() if DEBUG: logger.info(f"[client] {self.episode_uuid} | episode over") break diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py index 366ff7d9..fcdff782 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py @@ -68,7 +68,7 @@ class HealthCheckRequest(BaseModel): -def get_app(max_fastapi_threads: int = 512, enable_tinkerscript_mode=False, shared_mem_dict=None, shared_mem_dict_lock=None) -> Tuple[FastAPI, Optional[Coroutine]]: +def get_app(max_fastapi_threads: int = 512, enable_swarm_mode=False, shared_mem_dict=None, shared_mem_dict_lock=None) -> Tuple[FastAPI, Optional[Coroutine]]: @asynccontextmanager @@ -105,7 +105,7 @@ def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletio break except zmq.Again as e: # check whether server is still in rolling status - if enable_tinkerscript_mode: + if enable_swarm_mode: assert shared_mem_dict is not None if shared_mem_dict['engine_status'] not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: raise HTTPException(status_code=404, detail="The server is not in ENGINE.ROLLING status, cannot accept new requests.") @@ -163,9 +163,9 @@ async def chat_completions(request: Request, authorization: str = Header(None)): # Create timeline UUID timeline_uuid = uuid.uuid4().hex - # enable_tinkerscript_mode - if enable_tinkerscript_mode: - from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_server import ep_key + # enable_swarm_mode + if enable_swarm_mode: + from ajet.tuner_lib.weight_tuner.experimental.as_swarm_server import ep_key assert shared_mem_dict is not None assert shared_mem_dict_lock is not None @@ -195,11 +195,11 @@ async def chat_completions(request: Request, authorization: str = Header(None)): return await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, episode_address, int_req, episode_uuid) - if enable_tinkerscript_mode: - from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_server import register_enable_tinkerscript_mode_routes - assert shared_mem_dict is not None, "shared_mem_dict must not be None when enable_tinkerscript_mode is True." - assert shared_mem_dict_lock is not None, "shared_mem_dict_lock must not be None when enable_tinkerscript_mode is True." - app, additional_coro = register_enable_tinkerscript_mode_routes(app, zmq_context=context, shared_mem_dict=shared_mem_dict, shared_mem_dict_lock=shared_mem_dict_lock) + if enable_swarm_mode: + from ajet.tuner_lib.weight_tuner.experimental.as_swarm_server import register_enable_swarm_mode_routes + assert shared_mem_dict is not None, "shared_mem_dict must not be None when enable_swarm_mode is True." + assert shared_mem_dict_lock is not None, "shared_mem_dict_lock must not be None when enable_swarm_mode is True." + app, additional_coro = register_enable_swarm_mode_routes(app, zmq_context=context, shared_mem_dict=shared_mem_dict, shared_mem_dict_lock=shared_mem_dict_lock) else: additional_coro = None @@ -219,18 +219,18 @@ async def chat_completions(request: Request, authorization: str = Header(None)): class InterchangeServer(Process): - def __init__(self, experiment_dir: str, port: int, num_fastapi_process: int = 2, max_fastapi_threads: int = 512, enable_tinkerscript_mode=False): + def __init__(self, experiment_dir: str, port: int, num_fastapi_process: int = 2, max_fastapi_threads: int = 512, enable_swarm_mode=False): super().__init__() self.experiment_dir = experiment_dir self.port = port self.num_fastapi_process = num_fastapi_process self.max_fastapi_threads = max_fastapi_threads - self.enable_tinkerscript_mode = enable_tinkerscript_mode + self.enable_swarm_mode = enable_swarm_mode def run(self): logger.info(f"Starting Interchange Server on port {self.port} with {self.num_fastapi_process} processes and {self.max_fastapi_threads} threads per process.") - if self.enable_tinkerscript_mode: + if self.enable_swarm_mode: manager = Manager() shared_mem_dict = manager.dict() shared_mem_dict_lock = manager.Lock() @@ -238,7 +238,7 @@ def run(self): shared_mem_dict = None shared_mem_dict_lock = None - app, additional_coro = get_app(self.max_fastapi_threads, self.enable_tinkerscript_mode, shared_mem_dict, shared_mem_dict_lock) + app, additional_coro = get_app(self.max_fastapi_threads, self.enable_swarm_mode, shared_mem_dict, shared_mem_dict_lock) async def serve_with_monitor(additional_coro): # Start the server @@ -282,7 +282,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int: experiment_dir = config.ajet.experiment_dir num_fastapi_process = config.ajet.interchange_server.num_fastapi_process max_fastapi_threads = config.ajet.interchange_server.max_fastapi_threads - enable_tinkerscript_mode = config.ajet.enable_tinkerscript_mode + enable_swarm_mode = config.ajet.enable_swarm_mode # Find a free port if not specified or invalid port = int(os.environ.get("AJET_DAT_INTERCHANGE_PORT", -1)) @@ -306,7 +306,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int: port, num_fastapi_process, max_fastapi_threads, - enable_tinkerscript_mode, + enable_swarm_mode, ) interchange_server.start() else: @@ -360,7 +360,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int: if interchange_server: interchange_server.terminate() - if enable_tinkerscript_mode: - from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_server import kill_process_tree + if enable_swarm_mode: + from ajet.tuner_lib.weight_tuner.experimental.as_swarm_server import kill_process_tree kill_process_tree(None, None) return -1 \ No newline at end of file diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py similarity index 96% rename from ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py rename to ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py index 71927c75..7f4d6493 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py @@ -20,7 +20,7 @@ ) -class TinkerScriptClient(object): +class SwarmClient(object): def __init__(self, server_url: str): self.server_url = server_url @@ -62,6 +62,7 @@ def begin_episode(self, allow_discard_timeout=60, episode_type="train") -> Tuple else: need_wait_scenarios =[ "Engine is syncing weights", + "Engine is in post-rolling phase", "No available episodes to claim.", ] if any(scenario in data.fail_cause for scenario in need_wait_scenarios): @@ -140,7 +141,7 @@ def abort_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowO def sync_train_config(self, agent_jet_job: AgentJetJob): """ - Sync training configuration to the TinkerScript server. + Sync training configuration to the Swarm server. This sends the AgentJetJob config as YAML to the remote server. """ # try get init status @@ -160,14 +161,14 @@ def sync_train_config(self, agent_jet_job: AgentJetJob): timeout=30 ) resp.raise_for_status() - logger.info("Synced train config to TinkerScript server") + logger.info("Synced train config to Swarm server") except Exception as e: logger.error(f"Error syncing train config: {e}") raise def start_engine(self): """ - Start the training engine on the TinkerScript server. + Start the training engine on the Swarm server. This triggers the server to begin the training process. Polls until engine status is "ENGINE.ROLLING". """ @@ -186,7 +187,7 @@ def start_engine(self): resp.raise_for_status() result = resp.json() if result.get("success"): - logger.info("Successfully started training engine on TinkerScript server") + logger.info("Successfully started training engine on Swarm server") else: logger.error("Failed to start training engine") raise RuntimeError("Failed to start training engine") @@ -299,6 +300,8 @@ def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob, fo self.start_engine() elif current_status == "ENGINE.ROLLING": logger.info("Engine is already ROLLING. No action needed.") + elif current_status == "ENGINE.ROLLING_POST": + logger.info("Engine is already ROLLING. No action needed.") elif current_status == "ENGINE.BOOTING": logger.info("Engine is BOOTING. Waiting until it becomes ROLLING...") self._wait_until_status_change_to(desired_status="ENGINE.ROLLING") @@ -312,7 +315,7 @@ def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob, fo def stop_engine(self): """ - Stop the training engine on the TinkerScript server. + Stop the training engine on the Swarm server. This triggers the server to stop the training process. """ current_status = self.get_engine_status() @@ -329,7 +332,7 @@ def stop_engine(self): resp.raise_for_status() result = resp.json() if result.get("success"): - logger.info("Successfully stopped training engine on TinkerScript server") + logger.info("Successfully stopped training engine on Swarm server") else: logger.error("Failed to stop training engine") self._wait_until_status_change_to(desired_status="ENGINE.OFFLINE") diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py similarity index 92% rename from ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py rename to ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py index dad07b6b..ea043100 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py @@ -41,7 +41,7 @@ def ep_key(episode_uuid: str) -> str: return f"episodes-{episode_uuid}" -def register_enable_tinkerscript_mode_routes( +def register_enable_swarm_mode_routes( app, zmq_context, shared_mem_dict:DictProxy, @@ -59,7 +59,7 @@ def register_enable_tinkerscript_mode_routes( # --------------------------------- claimed -> unclaimed ---------------------------------------- # ------------------------------------------------------------------------------------------------ - def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]: + async def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]: result = [] current_time = time.time() @@ -71,11 +71,11 @@ def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]: result.append(es.episode_uuid) for episode_uuid in result: - _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock) + await _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock) return result - def _context_tracker_reset(episode_uuid, shared_mem_dict): + def _context_tracker_reset_blocking(episode_uuid, shared_mem_dict): # must async # send message to context tracker assert 'episodes' in shared_mem_dict zmq_addr = shared_mem_dict[ep_key(episode_uuid)].zmq_listen_result_addr @@ -84,7 +84,7 @@ def _context_tracker_reset(episode_uuid, shared_mem_dict): socket.connect(zmq_addr) # - # : ajet/task_runner/tinkerscript_runner.py + # : ajet/task_runner/swarm_runner.py # : message = zmq_socket.recv_string() socket.send_string("RUNNER.SPECIAL.RESET_CONTEXT_TRACKER") @@ -93,7 +93,7 @@ def _context_tracker_reset(episode_uuid, shared_mem_dict): try: if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") # : - # : ajet/task_runner/tinkerscript_runner.py + # : ajet/task_runner/swarm_runner.py # : zmq_socket.send_string("ack") # : "ack" socket.recv_string() @@ -106,23 +106,24 @@ def _context_tracker_reset(episode_uuid, shared_mem_dict): raise RuntimeError("Engine is no longer rolling, aborting wait for ack.") continue - def _revert_episode_to_unclaimed(episode_uuid: str, shared_mem_dict, shared_mem_dict_lock): + async def _revert_episode_to_unclaimed(episode_uuid: str, shared_mem_dict, shared_mem_dict_lock): # check status again, because other thread may have changed it if shared_mem_dict[ep_key(episode_uuid)].episode_status != "claimed": return - with shared_mem_dict_lock: - # reset context tracker - _context_tracker_reset(episode_uuid, shared_mem_dict) - - # revert - logger.warning(f"Reverting episode {episode_uuid} to unclaimed due to client timeout.") - if ep_key(episode_uuid) in shared_mem_dict: - es:EpisodeStatus = shared_mem_dict[ep_key(episode_uuid)] - es.episode_status = "registered" - es.client_uuid = "" - es.latest_activity_timestamp = time.time() - es.allow_discard_timeout = -1 + # reset context tracker + # _context_tracker_reset_blocking(episode_uuid, shared_mem_dict) # must async + await asyncio.to_thread(_context_tracker_reset_blocking, episode_uuid, shared_mem_dict) + + # revert + logger.warning(f"Reverting episode {episode_uuid} to unclaimed due to client timeout.") + if ep_key(episode_uuid) in shared_mem_dict: + es:EpisodeStatus = shared_mem_dict[ep_key(episode_uuid)] + es.episode_status = "registered" + es.client_uuid = "" + es.latest_activity_timestamp = time.time() + es.allow_discard_timeout = -1 + with shared_mem_dict_lock: shared_mem_dict[ep_key(episode_uuid)] = es if episode_uuid in shared_mem_dict['unclaimed_episodes']: pass @@ -145,11 +146,11 @@ def _delete_episode_record(episode_uuid: str, shared_mem_dict, shared_mem_dict_l # -------------------------- return workflow output ------------------------------------ # -------------------------------------------------------------------------------------- - def _register_final_episode_output(episode_uuid, workflow_output, shared_mem_dict, shared_mem_dict_lock): + def _register_final_episode_output_blocking(episode_uuid, workflow_output, shared_mem_dict, shared_mem_dict_lock): # must async # begin send workflow_output zmq_addr = shared_mem_dict[ep_key(episode_uuid)].zmq_listen_result_addr - if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | Received new chat completion request (inside thread)") + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | Received new chat completion request") socket = zmq_context.socket(zmq.REQ) socket.setsockopt(zmq.RCVTIMEO, RCVTIMEO) # 2 seconds recv timeout socket.connect(zmq_addr) @@ -161,7 +162,7 @@ def _register_final_episode_output(episode_uuid, workflow_output, shared_mem_dic try: if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") # : - # : ajet/task_runner/tinkerscript_runner.py + # : ajet/task_runner/swarm_runner.py # : zmq_socket.send_string("ack") # : "ack" socket.recv_string() @@ -187,7 +188,7 @@ def _register_final_episode_output(episode_uuid, workflow_output, shared_mem_dic async def register_episode_ready_listener(): while True: await asyncio.sleep(10) # check every 10 seconds - find_claimed_episodes_that_need_to_be_unclaimed() + await find_claimed_episodes_that_need_to_be_unclaimed() read_all_episode_status() def read_all_episode_status() -> Optional[EpisodeStatus]: @@ -500,11 +501,12 @@ async def end_episode(req: EndEpisodeRequest): episode_type = shared_mem_dict[ep_key(episode_uuid)].episode_type if episode_type == "train": - _register_final_episode_output(episode_uuid, workflow_output, shared_mem_dict, shared_mem_dict_lock) + # _register_final_episode_output_blocking(episode_uuid, workflow_output, shared_mem_dict, shared_mem_dict_lock) # must async + await asyncio.to_thread(_register_final_episode_output_blocking, episode_uuid, workflow_output, shared_mem_dict, shared_mem_dict_lock) elif episode_type == "eval": if engine_status in ["ENGINE.ROLLING"]: - _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock) + await _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock) else: _delete_episode_record(episode_uuid, shared_mem_dict, shared_mem_dict_lock) @@ -539,7 +541,7 @@ async def abort_episode(req: EndEpisodeRequest): return EndEpisodeResponse(success=True) if engine_status in ["ENGINE.ROLLING"]: - _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock) + await _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock) else: _delete_episode_record(episode_uuid, shared_mem_dict, shared_mem_dict_lock) @@ -563,7 +565,12 @@ async def is_episode_claimed(req: CanContinueEpisodeRequest): engine_status = shared_mem_dict['engine_status'] if engine_status not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: return BoolResponse(success=False) - if shared_mem_dict[ep_key(req.episode_uuid)].episode_status == "claimed": + if ep_key(req.episode_uuid) not in shared_mem_dict: + return BoolResponse(success=False) + es = shared_mem_dict[ep_key(req.episode_uuid)] + if not es: + return BoolResponse(success=False) + if es.episode_status == "claimed": return BoolResponse(success=True) else: return BoolResponse(success=False) diff --git a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py index fe7c3387..f82def38 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py +++ b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py @@ -106,6 +106,7 @@ def http_change_engine_status(config, new_status: str): def is_episode_claimed(config, episode_uuid: str) -> bool: + # TODO: add cache to reduce communication overhead resp = httpx.post( f"{get_interchange_server_url(config)}/is_episode_claimed", json={"client_uuid": "", "episode_uuid": episode_uuid}, @@ -136,7 +137,7 @@ def http_register_episode(config, openai_api_key=openai_api_key, zmq_listen_result_addr=zmq_listen_result_addr, ) - # send http request to tinkerscript server to register episode + # send http request to swarm server to register episode response = httpx.post( f"{interchange_http_addr}/register_episode", json=rer.model_dump(), # 或者 rer.model_dump() 如果使用 Pydantic v2 diff --git a/ajet_tinkerscript_threading.py b/ajet_swarm_threading.py similarity index 77% rename from ajet_tinkerscript_threading.py rename to ajet_swarm_threading.py index 7b31d4e7..24534383 100644 --- a/ajet_tinkerscript_threading.py +++ b/ajet_swarm_threading.py @@ -4,26 +4,25 @@ import requests from loguru import logger from textwrap import dedent -from ajet import WorkflowOutput -from ajet.schema.task import Task +from ajet.schema.task import Task, WorkflowOutput from ajet.copilot.job import AgentJetJob from ajet.task_reader import RouterTaskReader from ajet.utils.retry import retry_with_backoff from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey -from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_client import TinkerScriptClient +from ajet.tuner_lib.weight_tuner.experimental.as_swarm_client import SwarmClient from concurrent.futures import ThreadPoolExecutor # --------- configurations that take effect locally ------------- LOCAL_GRPO_N = 4 # grpo group size LOCAL_NUM_EPOCH = 10000 LOCAL_NUM_EPOCH = 1 -LOCAL_MAX_PARALLEL = 8 +LOCAL_MAX_PARALLEL = 64 LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main" -REMOTE_TINKERJET_URL = "http://localhost:10086" # Change to your tinkerscript remote url +REMOTE_SWARM_URL = "http://localhost:10086" # Change to your swarm remote url # --------- configurations that take effect remotely ------------- -REMOTE_BATCH_SIZE = 4 +REMOTE_BATCH_SIZE = 32 REMOTE_ALLOCATE_GPU_PER_NODE = 4 REMOTE_TRAIN_MODEL_01 = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct' @@ -35,7 +34,7 @@ class WeightUpdatedHalfway(Exception): def main(): - # Handshake with tinkerscript remote, then send training param to tinkerscript remote (such as model to be trained, algorithm, etc) + # Handshake with swarm remote, then send training param to swarm remote (such as model to be trained, algorithm, etc) dataset = RouterTaskReader( reader_type = "huggingface_dat_repo", reader_config = AjetTaskReader( @@ -45,17 +44,17 @@ def main(): ) ) - # # Hand shake with remote tinkerscript server - tinkerscript_remote = TinkerScriptClient(REMOTE_TINKERJET_URL) - # tinkerscript_remote.auto_sync_train_config_and_start_engine( - # AgentJetJob( - # algorithm="grpo", - # n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, - # model=REMOTE_TRAIN_MODEL_01, - # batch_size=REMOTE_BATCH_SIZE, - # grpo_n=LOCAL_GRPO_N, - # ) - # ) + # # Hand shake with remote swarm server + swarm_remote = SwarmClient(REMOTE_SWARM_URL) + swarm_remote.auto_sync_train_config_and_start_engine( + AgentJetJob( + algorithm="grpo", + n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, + model=REMOTE_TRAIN_MODEL_01, + batch_size=REMOTE_BATCH_SIZE, + grpo_n=LOCAL_GRPO_N, + ) + ) def rollout(task): group_reward = [] @@ -63,11 +62,11 @@ def rollout(task): for _ in range(LOCAL_GRPO_N): try: # begin episode - episode_uuid, api_baseurl_key = tinkerscript_remote.begin_episode() + episode_uuid, api_baseurl_key = swarm_remote.begin_episode() # execute agent workflow_output = execute_agent(task, api_baseurl_key) - # report output back to tinkerscript remote - tinkerscript_remote.end_episode(task, episode_uuid, workflow_output) + # report output back to swarm remote + swarm_remote.end_episode(task, episode_uuid, workflow_output) # collect reward group_reward.append(workflow_output.reward) except Exception as e: @@ -81,7 +80,7 @@ def rollout(task): for i, task in enumerate(dataset.get_training_tasks()): task_batch += [task] - if len(task_batch) == 3*REMOTE_BATCH_SIZE: + if len(task_batch) == REMOTE_BATCH_SIZE: print('*********** beginning a new batch of tasks... ***********') with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as executor: for task in task_batch: @@ -89,7 +88,7 @@ def rollout(task): executor.shutdown(wait=True) task_batch = [] print('*********** tasks completed, wait a minute... ***********') - time.sleep(60) + time.sleep(3) return None diff --git a/ajet_tinkerscript.md b/ajet_tinkerscript.md index f7bc610c..c256853d 100644 --- a/ajet_tinkerscript.md +++ b/ajet_tinkerscript.md @@ -1 +1 @@ -python -m ajet.launcher --conf tutorial/demo_tinkerjet/ajet_tinkerscript_default.yaml --backbone="debug" --autokill +python -m ajet.launcher --conf tutorial/demo_tinkerjet/ajet_swarm_default.yaml --backbone="debug" --autokill diff --git a/docs/en/platform_comparison.md b/docs/en/platform_comparison.md index fa263918..04ffe968 100644 --- a/docs/en/platform_comparison.md +++ b/docs/en/platform_comparison.md @@ -5,7 +5,7 @@ - Multi OSS Training Backbone: Support switching between multiple open-source training backbones quickly. - Multi OSS Infer Backbone: Support both vLLM and SGLang. - Low Code Change: Do not require too many edits to convert a user‑defined (multi) agent workflow into trainable workflows. -- Without-GPU (Cloud-Computing): Rollout and power RL training in a laptop without GPU, using Tinker (AgentLightning) or without Tinker (AgentJet-TinkerScript, comming soon) +- Without-GPU (Cloud-Computing): Rollout and power RL training in a laptop without GPU, using Tinker (AgentLightning) or without Tinker (AgentJet-Swarm, comming soon) - Timeline Optimization: Automatically merge shared-history context generated by the same agents to promote training speed. - Open Bench Platform: Trace baseline environment's performance across git history in different training backbones. - Multi-Agent Optimization: Deal with sophisticated multi-agent interaction efficiently, automatically clustering and merging samples generated by the same agents. diff --git a/docs/en/workflow.md b/docs/en/workflow.md index 94cdc825..1647a219 100644 --- a/docs/en/workflow.md +++ b/docs/en/workflow.md @@ -241,7 +241,7 @@ Here's a complete example with multiple agent roles (Werewolves game): - You can flexibly switch training targets by modifying `trainable_targets` -## TinkerScript +## Swarm Wrapping and training your agent on a machine without GPU. diff --git a/docs/index.md b/docs/index.md index c9dc7f3f..ce15a67c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -67,7 +67,7 @@

Any Training Engine

- Support multiple training engines as backbone (VeRL and Trinity-RFT). Tinker backbone support will be released soon. + Support multiple training engines as backbone (VeRL and Trinity-RFT). Swarm backbone support will be released soon. Choose from vLLM and SGLang as you wish. Say goodbye to training engine gaps.

diff --git a/tinkerscript.md b/tinkerscript.md index 325ca26a..05dfcce3 100644 --- a/tinkerscript.md +++ b/tinkerscript.md @@ -1,4 +1,4 @@ -# TinkerScript Design Blueprint / TinkerScript 设计蓝图 +# Swarm Design Blueprint / Swarm 设计蓝图 [English](#english-version) | [中文](#chinese-version) @@ -8,20 +8,20 @@ ## 🇬🇧 English Version ### 1. Overview -**TinkerScript** is an experimental component of AgentJet designed to decouple the **Training Logic** from the **Agent Execution Logic**. It allows users to train **full-weight LLM models** on machines without GPUs (e.g., a laptop) by offloading the actual model computation to a remote GPU server. +**Swarm** is an experimental component of AgentJet designed to decouple the **Training Logic** from the **Agent Execution Logic**. It allows users to train **full-weight LLM models** on machines without GPUs (e.g., a laptop) by offloading the actual model computation to a remote GPU server. -Unlike traditional setups where the user code must run inside the training cluster, TinkerScript allows you to verify and run your agent logic locally while the heavy lifting (training & inference) happens remotely. +Unlike traditional setups where the user code must run inside the training cluster, Swarm allows you to verify and run your agent logic locally while the heavy lifting (training & inference) happens remotely. > -> Relationship between **TinkerScript** and **Tinker**: +> Relationship between **Swarm** and **Tinker**: > -> **No relationship at all** (just like **JavaScript** and **Java**). **TinkerScript** is open-source and free. **Tinker** is close-source and not free. +> **No relationship at all** (just like **JavaScript** and **Java**). **Swarm** is open-source and free. **Tinker** is close-source and not free. -## Tinker 与 AgentJet-TinkerScript 对比表 +## Tinker 与 AgentJet-Swarm 对比表 -| 特征 | Tinker | AgentJet-TinkerScript | +| 特征 | Tinker | AgentJet-Swarm | |------|--------|--------------| | **开源性质** | ❌ 闭源 | **✅ 开源免费** | | **收费模式** | 付费服务 | **✅ 完全免费** | @@ -41,13 +41,13 @@ Unlike traditional setups where the user code must run inside the training clust ### 2. Core Architecture -The system involves two main parties: the **TinkerScript Server** (running on the GPU cluster) and the **TinkerScript Client** (running on your local machine). +The system involves two main parties: the **Swarm Server** (running on the GPU cluster) and the **Swarm Client** (running on your local machine). ```mermaid graph TD subgraph "GPU Cluster (Server Side)" TrainingLoop["Training Loop (AgentJet/GRPO)"] - TSS["TinkerScript Server (FastAPI)"] + TSS["Swarm Server (FastAPI)"] ZMQ["ZeroMQ / IPC"] SharedMem[("Shared Memory")] LLM["LLM Engine (vLLM/SGLang)"] diff --git a/tinkerscript_1.md b/tinkerscript_1.md index 076087dd..c15df6db 100644 --- a/tinkerscript_1.md +++ b/tinkerscript_1.md @@ -1,12 +1,12 @@ -# TinkerScript Design Blueprint +# Swarm Design Blueprint -TinkerScript represents a client-server architecture designed to decouple the **Training Loop** (Server-side) from the **Rollout Execution** (Client-side). This allows for distributed, flexible, and potentially remote execution of agent rollouts (inference + reward calculation) while centralizing the model training and weight updates. +Swarm represents a client-server architecture designed to decouple the **Training Loop** (Server-side) from the **Rollout Execution** (Client-side). This allows for distributed, flexible, and potentially remote execution of agent rollouts (inference + reward calculation) while centralizing the model training and weight updates. ## 1. System Architecture The system consists of three main components: -### A. TinkerScript Server (The Trainer) +### A. Swarm Server (The Trainer) * **Role**: Manages the training lifecycle, generates tasks (episodes), serves the model (LLM) API, and updates model weights. * **Technology**: Python, FastAPI, ZeroMQ (IPC/TCP), Shared Memory (Multiprocessing). * **Location**: Runs on the GPU cluster/Training node. @@ -15,7 +15,7 @@ The system consists of three main components: * Exposes an HTTP API for external clients to claim tasks and submit results. * Acts as a bridge between the HTTP world and the internal ZeroMQ-based training pipeline. -### B. TinkerScript Client (The User Script) +### B. Swarm Client (The User Script) * **Role**: Fetches tasks, runs the agent logic, computes rewards, and reports back. * **Technology**: Python (Requests/HTTPX). * **Location**: Can run locally, on a separate CPU cluster, or even a different cloud environment. @@ -103,11 +103,11 @@ class EpisodeStatus: ## 4. Key Configurations -From `ajet_tinkerscript_default.yaml`, we see how this mode is activated: +From `ajet_swarm_default.yaml`, we see how this mode is activated: ```yaml experiment_dir: "auto" -enable_tinkerscript_mode: True # Activates the HTTP API Server +enable_swarm_mode: True # Activates the HTTP API Server interchange_server: interchange_method: 'ipc' # Internal communication (ZeroMQ) interchange_server_port: 10086 # HTTP API Port diff --git a/tutorial/demo_tinkerjet/README.md b/tutorial/demo_tinkerjet/README.md index 10985526..d7e1ad20 100644 --- a/tutorial/demo_tinkerjet/README.md +++ b/tutorial/demo_tinkerjet/README.md @@ -1,20 +1,20 @@ -# TinkerScript +# Swarm -TinkerScript is an experimental component of AgentJet, +Swarm is an experimental component of AgentJet, allowing users to - run, debug and train **full-weight** LLM model behind user-defined LLM workflows in **machines without GPU**. -Similar to Tinker & Open-Tinker, the basic idea behind TinkerScript is to: +Similar to Tinker & Open-Tinker, the basic idea behind Swarm is to: - use remote (or cloud) GPU machine(s) as computation media. -However, TinkerScript goes even further on this path: +However, Swarm goes even further on this path: - Users only need to write and run their agents in a big `while` loop (e.g., in their laptop), and provide samples generated in this process. -- TinkerScript will take care of everything else. +- Swarm will take care of everything else. -- TinkerScript trains **full-weight** LLM model instead of lora. +- Swarm trains **full-weight** LLM model instead of lora. - Upon the termination of the training session, user can call `download_tuned_model` to download tuned LLM(s). @@ -61,4 +61,4 @@ tinkerjet_remote.close() - AgentJet are not able to explicitly distinguish different agents in multi-agent scenario. But **do not worry**, AgentJet will still try its best to recognize shards of llm timelines and merge them behind the curtain, automatically. -- TinkerScript does not support prompt tuning. +- Swarm does not support prompt tuning. diff --git a/tutorial/demo_tinkerjet/ajet_tinkerscript_default.yaml b/tutorial/demo_tinkerjet/ajet_tinkerscript_default.yaml index c6913470..c0baa7f4 100644 --- a/tutorial/demo_tinkerjet/ajet_tinkerscript_default.yaml +++ b/tutorial/demo_tinkerjet/ajet_tinkerscript_default.yaml @@ -23,8 +23,8 @@ ajet: # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature enable_experimental_interchange_server: True # train in cloud, run episode locally - enable_tinkerscript_mode: True - # both tinkerscript / oai share the same interchange server + enable_swarm_mode: True + # both swarm / oai share the same interchange server interchange_server: interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) interchange_server_port: 10086 diff --git a/tutorial/example_academic_trans/trans.py b/tutorial/example_academic_trans/trans.py index eed5c5df..6e506d4a 100644 --- a/tutorial/example_academic_trans/trans.py +++ b/tutorial/example_academic_trans/trans.py @@ -25,7 +25,7 @@ LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/agentjet/agentjet/tmp/arxiv_papers/train.parquet" -# Handshake with tinkerscript remote, then send training param to tinkerscript remote (such as model to be trained, algorithm, etc) +# Handshake with swarm remote, then send training param to swarm remote (such as model to be trained, algorithm, etc) dataset = RouterTaskReader( reader_type = "huggingface_dat_repo", reader_config = AjetTaskReader( diff --git a/tutorial/example_academic_trans/trans_roll.py b/tutorial/example_academic_trans/trans_roll.py index f7c50730..67dbcf8d 100644 --- a/tutorial/example_academic_trans/trans_roll.py +++ b/tutorial/example_academic_trans/trans_roll.py @@ -5,7 +5,7 @@ from loguru import logger from textwrap import dedent from ajet.copilot.job import AgentJetJob -from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_client import TinkerScriptClient +from ajet.tuner_lib.weight_tuner.experimental.as_swarm_client import SwarmClient from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey from ajet import WorkflowOutput @@ -24,7 +24,7 @@ LOCAL_NUM_EPOCH = 1 LOCAL_MAX_PARALLEL = 32 LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/agentjet/agentjet/tmp/arxiv_papers/train.parquet" -REMOTE_TINKERJET_URL = "http://localhost:10086" # Change to your tinkerscript remote url +REMOTE_SWARM_URL = "http://localhost:10086" # Change to your swarm remote url # --------- configurations that take effect remotely ------------- REMOTE_ALLOCATE_GPU_PER_NODE = 8 @@ -37,7 +37,7 @@ class WeightUpdatedHalfway(Exception): def main(): - # Handshake with tinkerscript remote, then send training param to tinkerscript remote (such as model to be trained, algorithm, etc) + # Handshake with swarm remote, then send training param to swarm remote (such as model to be trained, algorithm, etc) dataset = RouterTaskReader( reader_type = "huggingface_dat_repo", reader_config = AjetTaskReader( @@ -47,10 +47,10 @@ def main(): ) ) - # Hand shake with remote tinkerscript server - tinkerscript_remote = TinkerScriptClient(REMOTE_TINKERJET_URL) - # tinkerscript_remote.stop_engine() - tinkerscript_remote.auto_sync_train_config_and_start_engine( + # Hand shake with remote swarm server + swarm_remote = SwarmClient(REMOTE_SWARM_URL) + # swarm_remote.stop_engine() + swarm_remote.auto_sync_train_config_and_start_engine( AgentJetJob( algorithm="grpo", n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, @@ -67,17 +67,17 @@ def rollout(task): episode_uuid = None try: # begin episode - episode_uuid, api_baseurl_key = tinkerscript_remote.begin_episode() + episode_uuid, api_baseurl_key = swarm_remote.begin_episode() # execute agent workflow_output = execute_agent(task, api_baseurl_key) - # report output back to tinkerscript remote - tinkerscript_remote.end_episode(task, episode_uuid, workflow_output) + # report output back to swarm remote + swarm_remote.end_episode(task, episode_uuid, workflow_output) # collect reward group_reward.append(workflow_output.reward) except Exception as e: logger.exception("Exception during rollout:", e) if episode_uuid: - tinkerscript_remote.abort_episode(episode_uuid) + swarm_remote.abort_episode(episode_uuid) print(f"Group reward mean & std: {sum(group_reward)/len(group_reward)} +/- { (max(group_reward)-min(group_reward))/2 }") # Main Training loop @@ -94,10 +94,10 @@ def rollout(task): time.sleep(1) - # tinkerscript_remote.stop_engine() - # model_path = tinkerscript_remote.download_latest_model(path='./tinkerscript_saved_model') + # swarm_remote.stop_engine() + # model_path = swarm_remote.download_latest_model(path='./swarm_saved_model') time.sleep(10000) - # Get tuned model from tinkerscript remote + # Get tuned model from swarm remote return None From f1edf194f632221adbc9aa38f77b9cba51e2ae4d Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Thu, 5 Feb 2026 10:47:34 +0800 Subject: [PATCH 20/25] update pro-academic-trans agent --- .../experimental/as_swarm_client.py | 7 +- ajet_swarm_threading.py | 2 +- .../example_academic_trans/trans_reward.py | 1 + tutorial/example_academic_trans/trans_roll.py | 66 +++++++++---------- 4 files changed, 36 insertions(+), 40 deletions(-) diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py index 7f4d6493..e7117db2 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py @@ -108,19 +108,18 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut except Exception as e: logger.error(f"Error ending episode: {e}") - def abort_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOutput): + def abort_episode(self, episode_uuid: str): if not episode_uuid: logger.error("No episode to end.") return try: - task_id = task.task_id - workflow_output.metadata["task_id"] = task_id + workflow_output = WorkflowOutput(reward=0.0, metadata={}) req_obj = EndEpisodeRequest( client_uuid=self.client_uuid, episode_uuid=episode_uuid, workflow_output=workflow_output, - task_id=task_id + task_id="" ) resp = httpx.post( diff --git a/ajet_swarm_threading.py b/ajet_swarm_threading.py index 24534383..ed3f84b2 100644 --- a/ajet_swarm_threading.py +++ b/ajet_swarm_threading.py @@ -77,7 +77,7 @@ def rollout(task): logger.exception("Exception during rollout group", e) task_batch = [] - for i, task in enumerate(dataset.get_training_tasks()): + for i, task in enumerate(dataset.generate_training_tasks()): task_batch += [task] if len(task_batch) == REMOTE_BATCH_SIZE: diff --git a/tutorial/example_academic_trans/trans_reward.py b/tutorial/example_academic_trans/trans_reward.py index b8acae84..01aed962 100644 --- a/tutorial/example_academic_trans/trans_reward.py +++ b/tutorial/example_academic_trans/trans_reward.py @@ -76,6 +76,7 @@ def get_translation_quality_system_prompt() -> str: 4. **Subject-verb inconsistencies** - Mismatched subjects due to improper sentence structure (e.g., "在...中,本文展示..." where the subject is confused) 5. **Inappropriate word choices** - Using colloquial or incorrect terms instead of proper academic expressions (e.g., "效率" vs "有效性" in certain contexts) 6. **Redundant punctuation** - Unnecessary commas or other punctuation that disrupts Chinese reading flow + 7. **主语不清晰** - 中文句子主语缺失或不明确。例如:“通过该实验,证明了该药物对癌细胞有抑制作用”(缺少主语) **Examples of these errors:** [[examples_text]] diff --git a/tutorial/example_academic_trans/trans_roll.py b/tutorial/example_academic_trans/trans_roll.py index 67dbcf8d..f68733b5 100644 --- a/tutorial/example_academic_trans/trans_roll.py +++ b/tutorial/example_academic_trans/trans_roll.py @@ -27,9 +27,9 @@ REMOTE_SWARM_URL = "http://localhost:10086" # Change to your swarm remote url # --------- configurations that take effect remotely ------------- +REMOTE_BATCH_SIZE = 32 REMOTE_ALLOCATE_GPU_PER_NODE = 8 REMOTE_TRAIN_MODEL_01 = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct' -REMOTE_BATCH_SIZE = 32 class WeightUpdatedHalfway(Exception): """Raised when the remote side starts updating model weights halfway through an episode.""" @@ -49,55 +49,51 @@ def main(): # Hand shake with remote swarm server swarm_remote = SwarmClient(REMOTE_SWARM_URL) - # swarm_remote.stop_engine() swarm_remote.auto_sync_train_config_and_start_engine( AgentJetJob( algorithm="grpo", n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, model=REMOTE_TRAIN_MODEL_01, + batch_size=REMOTE_BATCH_SIZE, grpo_n=LOCAL_GRPO_N, - ), - force_restart=True, + ) ) - # Define rollout def rollout(task): group_reward = [] - for i in range(LOCAL_GRPO_N): - episode_uuid = None - try: - # begin episode - episode_uuid, api_baseurl_key = swarm_remote.begin_episode() - # execute agent - workflow_output = execute_agent(task, api_baseurl_key) - # report output back to swarm remote - swarm_remote.end_episode(task, episode_uuid, workflow_output) - # collect reward - group_reward.append(workflow_output.reward) - except Exception as e: - logger.exception("Exception during rollout:", e) - if episode_uuid: - swarm_remote.abort_episode(episode_uuid) + try: + for _ in range(LOCAL_GRPO_N): + try: + # begin episode + episode_uuid, api_baseurl_key = swarm_remote.begin_episode() + # execute agent + workflow_output = execute_agent(task, api_baseurl_key) + # report output back to swarm remote + swarm_remote.end_episode(task, episode_uuid, workflow_output) + # collect reward + group_reward.append(workflow_output.reward) + except Exception as e: + logger.exception("Exception during rollout:", e) + print(f"Group reward mean & std: {sum(group_reward)/len(group_reward)} +/- { (max(group_reward)-min(group_reward))/2 }") + except Exception as e: + logger.exception("Exception during rollout group", e) - # Main Training loop - futures = [] - with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as executor: - for epoch in range(LOCAL_NUM_EPOCH): - for i, task in enumerate(dataset.generate_training_tasks()): - print(f"Submitting task for epoch {epoch}") - future = executor.submit(rollout, task) + task_batch = [] + for i, task in enumerate(dataset.generate_training_tasks()): + task_batch += [task] - futures += [future] - while (i % REMOTE_BATCH_SIZE) == (REMOTE_BATCH_SIZE - 1) and futures: - futures = [f for f in futures if not f.done()] - time.sleep(1) + if len(task_batch) == REMOTE_BATCH_SIZE: + print('*********** beginning a new batch of tasks... ***********') + with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as executor: + for task in task_batch: + executor.submit(rollout, task) + executor.shutdown(wait=True) + task_batch = [] + print('*********** tasks completed, wait a minute... ***********') + time.sleep(60) - # swarm_remote.stop_engine() - # model_path = swarm_remote.download_latest_model(path='./swarm_saved_model') - time.sleep(10000) - # Get tuned model from swarm remote return None From 47812cba22ffcce9262e8b66b2db07f142845c36 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Thu, 5 Feb 2026 16:51:04 +0800 Subject: [PATCH 21/25] revise pro-trans --- ajet/context_tracker/multiagent_tracking.py | 2 +- ajet/task_runner/swarm_runner.py | 3 +- .../experimental/as_oai_model_client.py | 4 +- .../experimental/as_oai_model_server.py | 4 +- .../experimental/as_swarm_server.py | 3 +- tutorial/example_academic_trans/trans.py | 122 +++++++++--------- .../example_academic_trans/trans_reward.py | 32 +++-- 7 files changed, 94 insertions(+), 76 deletions(-) diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py index f0c3fef1..e8514222 100644 --- a/ajet/context_tracker/multiagent_tracking.py +++ b/ajet/context_tracker/multiagent_tracking.py @@ -294,7 +294,7 @@ def save_llm_interaction_timeline(self, tools, llm_ext_msg, timeline): # save to self.saved_timelines self.saved_timelines += [copy.deepcopy(timeline)] - # DEBUG = True # warn when merge fails + # warn when merge fails timeline_merging_policy: TimelineMergingPolicyConfig = self.config.ajet.context_tracker.timeline_merging_policy if ( self.config.ajet.context_tracker.detect_timeline_snap diff --git a/ajet/task_runner/swarm_runner.py b/ajet/task_runner/swarm_runner.py index 07f39901..cb839aa3 100644 --- a/ajet/task_runner/swarm_runner.py +++ b/ajet/task_runner/swarm_runner.py @@ -16,7 +16,8 @@ from ajet import Workflow from typing import Callable -DEBUG = True +# DEBUG = True +DEBUG = False context = zmq.Context() atexit.register(context.term) diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py index 08a7cfb0..dd8b191f 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py @@ -22,8 +22,8 @@ if TYPE_CHECKING: from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker -# DEBUG = False -DEBUG = True +DEBUG = False +# DEBUG = True def generate_auth_token(agent_name, target_tag, episode_uuid, episode_address): """ diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py index fcdff782..43ee39da 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py @@ -54,8 +54,8 @@ class HealthCheckRequest(BaseModel): # Create FastAPI app SERVER_SHUTDOWN_EVENT = threading.Event() -# DEBUG = False -DEBUG = True +DEBUG = False +# DEBUG = True context = zmq.Context() atexit.register(context.term) diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py index ea043100..9158ec6b 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py @@ -26,7 +26,8 @@ VALID_STATUSES, ) -DEBUG = True +# DEBUG = True +DEBUG = False RCVTIMEO = 2 * 1000 RCVTIMEO_OUT = 300 * 1000 RCVTIMEO_WAIT_N = RCVTIMEO_OUT // RCVTIMEO diff --git a/tutorial/example_academic_trans/trans.py b/tutorial/example_academic_trans/trans.py index 6e506d4a..e73844d4 100644 --- a/tutorial/example_academic_trans/trans.py +++ b/tutorial/example_academic_trans/trans.py @@ -3,10 +3,10 @@ import os import time import asyncio -import requests import threading from loguru import logger from textwrap import dedent +from openai import OpenAI from ajet import WorkflowOutput from ajet.schema.task import Task @@ -22,18 +22,6 @@ from .trans_reward import TranslationQualityGrader, build_translation_quality_messages, examples -LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/agentjet/agentjet/tmp/arxiv_papers/train.parquet" - - -# Handshake with swarm remote, then send training param to swarm remote (such as model to be trained, algorithm, etc) -dataset = RouterTaskReader( - reader_type = "huggingface_dat_repo", - reader_config = AjetTaskReader( - huggingface_dat_repo = HuggingfaceDatRepo( - dataset_path = LOCAL_DATASET_PATH - ) - ) -) @retry_with_backoff(max_retry=3) def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): @@ -48,17 +36,22 @@ def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): messages, rough_translate = rough_translate_agent(base_url, api_key, abstract) # print_listofdict(messages, header="rough_translate_agent", mod="c") - messages, fix_nouns = detect_hard_proper_nouns(messages, base_url, api_key, abstract, rough_translate) + # messages, fix_nouns = detect_hard_proper_nouns(messages, base_url, api_key, abstract, rough_translate) + messages, fix_nouns = detect_hard_proper_nouns(messages, grader_base_url, grader_api_key, abstract, rough_translate) # print_listofdict(messages, header="detect_hard_proper_nouns", mod="c") messages, final_translation = produce_final_translation(messages, base_url, api_key, abstract, rough_translate, fix_nouns) print_listofdict(messages, header="final_translation", mod="c") - grader = TranslationQualityGrader( - model=OpenAIChatModel(base_url=grader_base_url, api_key=grader_api_key, model="qwen-max") - ) - grader_score = asyncio.run(grader.aevaluate(original_text=abstract, translation=final_translation)) - raw_reward = grader_score.score # Normalize to 0-1 range (score is 0-3) + if final_translation is None: + raw_reward = 0.0 + else: + grader = TranslationQualityGrader( + model=OpenAIChatModel(base_url=grader_base_url, api_key=grader_api_key, model="qwen3-max-2026-01-23") + ) + grader_score = asyncio.run(grader.aevaluate(original_text=abstract, translation=final_translation)) + raw_reward = grader_score.score + print(f"Grader Score: {grader_score.score}, Reason: {grader_score.reason}, Metadata: {grader_score.metadata}") return WorkflowOutput(reward=raw_reward, metadata={ "rough_translate": rough_translate, "fix_nouns": fix_nouns, @@ -66,50 +59,68 @@ def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): }) -def detect_hard_proper_nouns(messages, base_url, api_key, abstract, rough_translate): +def produce_final_translation(messages, base_url, api_key, abstract, rough_translate, fix_nouns): messages = messages + [ - { "role": "user", - "content": "You new job is to detect translation errors of discipline-specific proper nouns. " - "Use json to list all errors found in the translation result and provide correction. " - "Json format: [{\"original_word\": \"xxx\", \"wrong_translation\": \"xxx\", \"wrong_reason\": \"xxx\", \"correct_translation\": \"xxx\"}, ...]. " - "If no errors are found, return an empty list []." - "Please list all translation errors of discipline-specific proper nouns found in the translation result according to the requirements." + "content": "Please produce the final, corrected Chinese translation by applying all the corrections listed above. " + "Output only the final translation between ... , so I will extract result with regex." }, ] - response = requests.post( f"{base_url}/chat/completions", json = { "model": "qwen-turbo", "messages": messages, }, headers = { "Authorization": f"Bearer {api_key}" } ) - fix_nouns = response.json()['choices'][0]['message']['content'] + client = OpenAI(base_url=base_url, api_key=api_key) + response = client.chat.completions.create( + model="agentjet-model", + messages=messages + ) + final_translation = response.choices[0].message.content + messages += [ { "role": "assistant", - "content": fix_nouns + "content": final_translation } ] - return messages, fix_nouns + # Extract final translation + match = re.search(r"(.*?)", final_translation, re.DOTALL) + if match: + final_translation = match.group(1).strip() + else: + final_translation = None -def produce_final_translation(messages, base_url, api_key, abstract, rough_translate, fix_nouns): + return messages, final_translation + + + +def detect_hard_proper_nouns(messages, base_url, api_key, abstract, rough_translate): messages = messages + [ + { "role": "user", - "content": "Please produce the final, corrected Chinese translation by applying all the corrections listed above. " - "Output only the final translation without any explanations or additional text." + "content": "You new job is to detect translation errors of discipline-specific proper nouns. " + "Use json to list all errors found in the translation result and provide correction. " + "Json format: [{\"original_word\": \"xxx\", \"wrong_translation\": \"xxx\", \"wrong_reason\": \"xxx\", \"correct_translation\": \"xxx\"}, ...]. " + "If no errors are found, return an empty list []." + "Please list all translation errors of discipline-specific proper nouns found in the translation result according to the requirements." }, - ] - response = requests.post( f"{base_url}/chat/completions", json = { "model": "qwen-turbo", "messages": messages, }, headers = { "Authorization": f"Bearer {api_key}" } ) - final_translation = response.json()['choices'][0]['message']['content'] + ] + client = OpenAI(base_url=base_url, api_key=api_key) + response = client.chat.completions.create( + model="qwen3-max-2026-01-23", + messages=messages, + extra_body={"enable_thinking":True} + ) + fix_nouns = response.choices[0].message.content messages += [ { "role": "assistant", - "content": final_translation + "content": fix_nouns } ] - - return messages, final_translation + return messages, fix_nouns def rough_translate_agent(base_url, api_key, abstract): @@ -123,9 +134,12 @@ def rough_translate_agent(base_url, api_key, abstract): "such as conforming to the logic of the Chinese language, being simple, rigorous, and concise, " "and avoiding the use of first-person pronouns when passive voice is appropriate. " "Ensure that specialized terms are translated correctly according to academic standards. " - "Replace 我们 with 本研究 or 本文. " - "If an abbreviation is short in Chinese, use Chinese. " - "If an abbreviation is long in Chinese, use abbreviation. " + "Replace 我/我们 with 本研究 or 本文 or 研究者 or simply remove it and rephrase the sentence. " + "If an English abbreviation is short in Chinese, use Chinese. " + "If an English abbreviation is long in Chinese, use English abbreviation. " + "To use an English abbreviation, if the author has mentioned the full form first, mention the full form at its first appearance. " + "e.g. `We have used the LAsMA heterodyne array installed on the Atacama Pathfinder EXperiment (APEX)` should be translated as " + "`本研究使用了安装在阿塔卡马探路者实验望远镜(APEX, Atacama Pathfinder EXperiment)上的LAsMA外差阵列`. " }, { "role": "user", @@ -135,8 +149,13 @@ def rough_translate_agent(base_url, api_key, abstract): for ex in examples: messages[0]['content'] += f"\n\nExample:\n\tOriginal: {ex['original']}\n\tBad Translation: {ex['bad']}\n\tHint: {ex['hint']}\n\tGood Translation: {ex['good']}" - response = requests.post( f"{base_url}/chat/completions", json = { "model": "qwen-turbo", "messages": messages, }, headers = { "Authorization": f"Bearer {api_key}" } ) - rough_translate = response.json()['choices'][0]['message']['content'] + + client = OpenAI(base_url=base_url, api_key=api_key) + response = client.chat.completions.create( + model="agentjet-model", + messages=messages + ) + rough_translate = response.choices[0].message.content messages += [ { "role": "assistant", @@ -145,18 +164,3 @@ def rough_translate_agent(base_url, api_key, abstract): ] return messages, rough_translate - - - -if __name__ == "__main__": - - for i, task in enumerate(dataset.generate_training_tasks()): - execute_agent( - task, - OpenaiBaseUrlAndApiKey( - base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", - api_key=os.environ.get("DASHSCOPE_API_KEY", "") - ) - ) - - diff --git a/tutorial/example_academic_trans/trans_reward.py b/tutorial/example_academic_trans/trans_reward.py index 01aed962..663d7107 100644 --- a/tutorial/example_academic_trans/trans_reward.py +++ b/tutorial/example_academic_trans/trans_reward.py @@ -4,6 +4,7 @@ from openjudge.models.base_chat_model import BaseChatModel from typing import List from textwrap import dedent +from beast_logger import print_listofdict examples = [ @@ -68,15 +69,19 @@ def get_translation_quality_system_prompt() -> str: return dedent(""" You are an objective translation quality evaluator for academic paper translations from English to Chinese. Your task is to identify ONLY the specific types of errors demonstrated in the provided examples - not general translation quality issues. - Focus (but do not limit to) on issues below (as shown in the examples): + 重点关注(但不限于)以下问题类型(如示例所示): - 1. **First-person pronoun issues** - Using "我们" instead of "本研究" or "本文" in academic contexts - 2. **Abbreviation translation errors** - Using abbreviations when concise Chinese exists (e.g., "GWs" instead of "引力波"), or translating abbreviations that should remain in English (like "EMBB") - 3. **Word order problems** - Not adjusting sentence structure to emphasize key points in Chinese academic style - 4. **Subject-verb inconsistencies** - Mismatched subjects due to improper sentence structure (e.g., "在...中,本文展示..." where the subject is confused) - 5. **Inappropriate word choices** - Using colloquial or incorrect terms instead of proper academic expressions (e.g., "效率" vs "有效性" in certain contexts) - 6. **Redundant punctuation** - Unnecessary commas or other punctuation that disrupts Chinese reading flow + 1. **错误使用第一人称代词** - 禁止使用"我们"。正确的方法是使用"本研究"、"本文"、“研究者”,或者直接删除we并改写句子替换主语。不要漏掉出现的任何第一人称代词。 + 2. **缩写翻译错误** - 当存在简洁的中文表达时使用缩写(例如,使用"GWs"而非"引力波"),或翻译本应保留英文的缩写(如"EMBB") + 3. **语序问题** - 未调整句子结构以符合中文学术风格强调重点的习惯 + 4. **主谓不一致、主语缺失** - 由于句子结构不当导致主语混乱(例如,"在...中,本文展示..."中主语混淆) + 5. **用词不当** - 使用口语化或不正确的术语而非恰当的学术表达 + 6. **多余标点和停顿** - 不必要的逗号或其他标点符号影响中文阅读流畅性 7. **主语不清晰** - 中文句子主语缺失或不明确。例如:“通过该实验,证明了该药物对癌细胞有抑制作用”(缺少主语) + 8. **缩写问题** - 首次出现自定义缩写、且原文中已经提供自定义缩写的英文全称时,没有在首次出现的地方提供英文全称。 + (正确的例子:`We have used the LAsMA heterodyne array installed on the Atacama Pathfinder EXperiment (APEX)`->`本研究使用了安装在阿塔卡马探路者实验望远镜(APEX, Atacama Pathfinder EXperiment)上的LAsMA外差阵列`) + 9. **专有名词翻译错误** - 领域特定的专有名词翻译错误,例如技术术语、学科术语等。如错把Agent翻译成“代理”(实际上应为“智能体”)等。 + 10. **表意偏差** - 翻译结果与原文在意义上存在偏差,导致信息传达不准确。 **Examples of these errors:** [[examples_text]] @@ -90,15 +95,19 @@ def get_translation_quality_system_prompt() -> str: * For each key issue found, provide the specific error, its type, and where it appears in the translation. * Be precise about which error category each issue belongs to. * Focus on objective errors matching the example patterns, not subjective preferences. + * 当出现 **语序问题**、**主谓不一致、主语缺失**、**主语不清晰**、**专有名词翻译错误**、**表意偏差** 等严重问题时,直接给 0 分。 + * 逐句分析,切勿遗漏。 Think carefully before flagging any error. Ask yourself: Does this match one of the specific error types from the examples? Is this truly an objective error or just a stylistic preference? Return your response in this format: - X - Your detailed step-by-step reasoning analyzing the translation against the error categories + + Your analysis + - Error Type: [category]. Error: [specific issue]. Location: [where it appears in the translation] + X The score must be 0, 1, 2. Each key issue should be on its own line starting with a dash. If no errors are found, the key_issues section should be empty or state "None detected". """.replace("[[examples_text]]", examples_text)) @@ -129,7 +138,10 @@ def parse_translation_quality_response(text: str) -> dict: def build_translation_quality_messages(original_text: str, translation: str) -> List[dict]: return [ - {"role": "system", "content": get_translation_quality_system_prompt()}, + { + "role": "system", + "content": get_translation_quality_system_prompt() + }, { "role": "user", "content": TRANSLATION_QUALITY_USER_PROMPT.format( From b15983acc212591cc3657dd0a3a0be8cff2c4bfa Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Thu, 5 Feb 2026 18:32:49 +0800 Subject: [PATCH 22/25] make rollout more robust --- ajet/task_rollout/native_parallel_worker.py | 6 ++-- ajet/task_runner/swarm_runner.py | 3 +- .../experimental/as_swarm_client.py | 9 +++++- .../experimental/as_swarm_server.py | 30 +++++++++++------ ajet/utils/thread_executors.py | 32 +++++++++++++++---- tutorial/example_academic_trans/trans.py | 5 +-- tutorial/example_academic_trans/trans_roll.py | 15 ++++----- 7 files changed, 68 insertions(+), 32 deletions(-) diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py index a47db114..43e4df5c 100644 --- a/ajet/task_rollout/native_parallel_worker.py +++ b/ajet/task_rollout/native_parallel_worker.py @@ -61,9 +61,9 @@ def step_status_printer(self, observation_window): if start == -1: print_buf += [f"[finished]:{count} threads"] print(f"Rollout progress ({token_gen_per_sec_str}): " + " // ".join(print_buf)) - if "info" in observation_window: - print_buf2 = "\t".join(observation_window["info"]) - print(print_buf2) + # if "info" in observation_window: + # print_buf2 = "\t".join(observation_window["info"]) + # print(print_buf2) def rollout_static( self, diff --git a/ajet/task_runner/swarm_runner.py b/ajet/task_runner/swarm_runner.py index cb839aa3..d8e34b77 100644 --- a/ajet/task_runner/swarm_runner.py +++ b/ajet/task_runner/swarm_runner.py @@ -88,6 +88,7 @@ def register_episode_and_wait_output( logger.warning(f"Received reset command for episode {episode_uuid}.") context_tracker.reset() zmq_socket.send_string("ack") + continue elif message == "RUNNER.SPECIAL.ABORT": logger.warning(f"Received abort command for episode {episode_uuid}.") context_tracker.reset() @@ -104,8 +105,8 @@ def register_episode_and_wait_output( raise exc finally: + tuner.terminate_episode() # this is very important to avoid resource leak zmq_socket.close() - tuner.terminate_episode() if ipc_path and os.path.exists(ipc_path): os.remove(ipc_path) return final_output diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py index e7117db2..09b05f88 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py @@ -26,6 +26,7 @@ def __init__(self, server_url: str): self.server_url = server_url self.client_uuid = str(uuid.uuid4()) self.previous_warning_time = 0 + self.record_episode_expire_time = {} def begin_episode(self, allow_discard_timeout=60, episode_type="train") -> Tuple[str, OpenaiBaseUrlAndApiKey]: @@ -48,6 +49,7 @@ def begin_episode(self, allow_discard_timeout=60, episode_type="train") -> Tuple resp.raise_for_status() data = ClaimEpisodeResponse.model_validate(resp.json()) episode_uuid = data.episode_uuid + self.record_episode_expire_time[episode_uuid] = time.time() + allow_discard_timeout if data.success: episode_uuid = data.episode_uuid @@ -82,6 +84,11 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut logger.error("No episode to end.") return + remain_time = self.record_episode_expire_time.get(episode_uuid, 0) - time.time() + if remain_time < 0: + logger.warning(f"Episode {episode_uuid} has expired (expired {remain_time} seconds ago). Please use a larger `allow_discard_timeout` when `begin_episode`. Skipping end_episode.") + return + try: task_id = task.task_id workflow_output.metadata["task_id"] = task_id @@ -131,7 +138,7 @@ def abort_episode(self, episode_uuid: str): data = EndEpisodeResponse.model_validate(resp.json()) if data.success: - logger.info(f"Ended episode {episode_uuid}") + logger.info(f"Aborted episode {episode_uuid}") else: logger.error(f"Failed to end episode {episode_uuid}") diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py index 9158ec6b..392b0fe5 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py @@ -61,7 +61,7 @@ def register_enable_swarm_mode_routes( # ------------------------------------------------------------------------------------------------ async def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]: - result = [] + to_unclaim_episodes = [] current_time = time.time() for k, v in shared_mem_dict.items(): @@ -69,12 +69,12 @@ async def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]: es:EpisodeStatus = v if es.episode_status == "claimed": if (current_time - es.latest_activity_timestamp) > es.allow_discard_timeout: - result.append(es.episode_uuid) + to_unclaim_episodes.append(es.episode_uuid) - for episode_uuid in result: + for episode_uuid in to_unclaim_episodes: await _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock) - return result + return to_unclaim_episodes def _context_tracker_reset_blocking(episode_uuid, shared_mem_dict): # must async # send message to context tracker @@ -110,6 +110,8 @@ def _context_tracker_reset_blocking(episode_uuid, shared_mem_dict): # must asyn async def _revert_episode_to_unclaimed(episode_uuid: str, shared_mem_dict, shared_mem_dict_lock): # check status again, because other thread may have changed it if shared_mem_dict[ep_key(episode_uuid)].episode_status != "claimed": + if episode_uuid in shared_mem_dict['unclaimed_episodes']: pass + else: shared_mem_dict['unclaimed_episodes'] += [episode_uuid] return # reset context tracker @@ -126,17 +128,15 @@ async def _revert_episode_to_unclaimed(episode_uuid: str, shared_mem_dict, share es.allow_discard_timeout = -1 with shared_mem_dict_lock: shared_mem_dict[ep_key(episode_uuid)] = es - if episode_uuid in shared_mem_dict['unclaimed_episodes']: - pass - else: - shared_mem_dict['unclaimed_episodes'] += [episode_uuid] + if episode_uuid in shared_mem_dict['unclaimed_episodes']: pass + else: shared_mem_dict['unclaimed_episodes'] += [episode_uuid] def _delete_episode_record(episode_uuid: str, shared_mem_dict, shared_mem_dict_lock): with shared_mem_dict_lock: # remove episode record if ep_key(episode_uuid) in shared_mem_dict: - del shared_mem_dict[ep_key(episode_uuid)] + del shared_mem_dict[ep_key(episode_uuid)] # RM-- logger.info(f"Deleted episode record for {episode_uuid}.") # remove from unclaimed list if present if episode_uuid in shared_mem_dict['unclaimed_episodes']: @@ -499,7 +499,17 @@ async def end_episode(req: EndEpisodeRequest): # send workflow_output to zmq assert 'episodes' in shared_mem_dict - episode_type = shared_mem_dict[ep_key(episode_uuid)].episode_type + ep_stat = shared_mem_dict[ep_key(episode_uuid)] + episode_type = ep_stat.episode_type + episode_status = ep_stat.episode_status + client_uuid_recorded = ep_stat.client_uuid + if client_uuid_recorded != client_uuid: + logger.error(f"[server] Episode {episode_uuid} is claimed by different client: {client_uuid_recorded}, but got {client_uuid}.") + raise HTTPException(status_code=404, detail=f"Episode {episode_uuid} is claimed by different client: {client_uuid_recorded}, but got {client_uuid}.") + + if episode_status != "claimed": + logger.error(f"[server] Episode {episode_uuid} is not in claimed status.") + raise HTTPException(status_code=400, detail=f"Episode {episode_uuid} is not in claimed status, maybe you take too long to submit.") if episode_type == "train": # _register_final_episode_output_blocking(episode_uuid, workflow_output, shared_mem_dict, shared_mem_dict_lock) # must async diff --git a/ajet/utils/thread_executors.py b/ajet/utils/thread_executors.py index 1ab02baf..797ac54c 100644 --- a/ajet/utils/thread_executors.py +++ b/ajet/utils/thread_executors.py @@ -1,14 +1,14 @@ +from concurrent.futures import ThreadPoolExecutor from ajet.utils.sington import singleton -import concurrent.futures - +import threading @singleton class SharedInterchangeThreadExecutor: def __init__(self, max_workers=64): - self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) + self.executor = ThreadPoolExecutor(max_workers=max_workers) - def get_shared_executor(self) -> concurrent.futures.ThreadPoolExecutor: + def get_shared_executor(self) -> ThreadPoolExecutor: return self.executor @@ -16,7 +16,27 @@ def get_shared_executor(self) -> concurrent.futures.ThreadPoolExecutor: @singleton class SharedInferenceTrackerThreadExecutor: def __init__(self, max_workers=64): - self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) + self.executor = ThreadPoolExecutor(max_workers=max_workers) - def get_shared_executor(self) -> concurrent.futures.ThreadPoolExecutor: + def get_shared_executor(self) -> ThreadPoolExecutor: return self.executor + + +class BoundedThreadPoolExecutor: + def __init__(self, max_workers, max_queue_size=100): + self.executor = ThreadPoolExecutor(max_workers=max_workers) + self.semaphore = threading.Semaphore(max_queue_size) + + def submit(self, fn, *args, **kwargs): + self.semaphore.acquire() + + def wrapped_fn(*args, **kwargs): + try: + return fn(*args, **kwargs) + finally: + self.semaphore.release() + + return self.executor.submit(wrapped_fn, *args, **kwargs) + + def shutdown(self, wait=True): + self.executor.shutdown(wait=wait) \ No newline at end of file diff --git a/tutorial/example_academic_trans/trans.py b/tutorial/example_academic_trans/trans.py index e73844d4..49120d41 100644 --- a/tutorial/example_academic_trans/trans.py +++ b/tutorial/example_academic_trans/trans.py @@ -49,7 +49,7 @@ def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): grader = TranslationQualityGrader( model=OpenAIChatModel(base_url=grader_base_url, api_key=grader_api_key, model="qwen3-max-2026-01-23") ) - grader_score = asyncio.run(grader.aevaluate(original_text=abstract, translation=final_translation)) + grader_score = asyncio.run(asyncio.wait_for(grader.aevaluate(original_text=abstract, translation=final_translation), timeout=120)) raw_reward = grader_score.score print(f"Grader Score: {grader_score.score}, Reason: {grader_score.reason}, Metadata: {grader_score.metadata}") return WorkflowOutput(reward=raw_reward, metadata={ @@ -111,7 +111,8 @@ def detect_hard_proper_nouns(messages, base_url, api_key, abstract, rough_transl response = client.chat.completions.create( model="qwen3-max-2026-01-23", messages=messages, - extra_body={"enable_thinking":True} + timeout=60, + # extra_body={"enable_thinking":True} ) fix_nouns = response.choices[0].message.content messages += [ diff --git a/tutorial/example_academic_trans/trans_roll.py b/tutorial/example_academic_trans/trans_roll.py index f68733b5..4630218d 100644 --- a/tutorial/example_academic_trans/trans_roll.py +++ b/tutorial/example_academic_trans/trans_roll.py @@ -8,7 +8,7 @@ from ajet.tuner_lib.weight_tuner.experimental.as_swarm_client import SwarmClient from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey -from ajet import WorkflowOutput +from ajet.utils.thread_executors import BoundedThreadPoolExecutor from ajet.schema.task import Task from ajet.task_reader import RouterTaskReader from ajet.utils.retry import retry_with_backoff @@ -56,7 +56,7 @@ def main(): model=REMOTE_TRAIN_MODEL_01, batch_size=REMOTE_BATCH_SIZE, grpo_n=LOCAL_GRPO_N, - ) + ), ) def rollout(task): @@ -80,20 +80,17 @@ def rollout(task): logger.exception("Exception during rollout group", e) task_batch = [] + executor = BoundedThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL, max_queue_size=LOCAL_MAX_PARALLEL*2) for i, task in enumerate(dataset.generate_training_tasks()): task_batch += [task] if len(task_batch) == REMOTE_BATCH_SIZE: print('*********** beginning a new batch of tasks... ***********') - with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as executor: - for task in task_batch: - executor.submit(rollout, task) - executor.shutdown(wait=True) + for task in task_batch: + executor.submit(rollout, task) task_batch = [] - print('*********** tasks completed, wait a minute... ***********') - time.sleep(60) - + executor.shutdown(wait=True) return None From 4cb513bc3a54d5849f63c437da18f0cf2b0a1025 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Fri, 6 Feb 2026 13:18:16 +0800 Subject: [PATCH 23/25] enhance error logging during tracker.tokenize() for better debugging --- ajet/context_tracker/basic_tracker.py | 48 ++++++++++----------- ajet/task_rollout/native_parallel_worker.py | 1 + 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/ajet/context_tracker/basic_tracker.py b/ajet/context_tracker/basic_tracker.py index 44d81cb7..9f70d8ab 100644 --- a/ajet/context_tracker/basic_tracker.py +++ b/ajet/context_tracker/basic_tracker.py @@ -262,30 +262,28 @@ def tokenize_steps( # check reward structure self.reward_structure: Reward # type: ignore - assert ( - self.reward_structure.step_reward_arr is not None - ), "must call `process_reward` before tokenize_steps" - assert len(self.reward_structure.step_reward_arr) == total_steps + assert self.reward_structure.step_reward_arr is not None, "must call `process_reward` before tokenize_steps" + assert len(self.reward_structure.step_reward_arr) == total_steps, f"reward step count {len(self.reward_structure.step_reward_arr)} != total_steps {total_steps}" # mapping input_ids = [] input_logprobs = [] attention_mask = [] loss_mask = [] - split_prompt_reponse_index = -1 + split_prompt_response_index = -1 split_point_message_left_index = -1 input_ids_len = [] # cat all messages for i, ext_msg in enumerate(ext_steps): # find split index, this have to be done before input_ids += ext_msg.token_arr - if (split_prompt_reponse_index == -1) and (ext_msg.need_training): - split_prompt_reponse_index = len(input_ids) + if (split_prompt_response_index == -1) and (ext_msg.need_training): + split_prompt_response_index = len(input_ids) split_point_message_left_index = i - 1 assert ( split_point_message_left_index >= 0 ), "There should be at least one message before the first training message" - assert split_prompt_reponse_index == input_ids_len[split_point_message_left_index] + assert split_prompt_response_index == input_ids_len[split_point_message_left_index] assert ( ext_msg.author == "llm" ), "The first message after initialization should be from LLM, not from env or user" @@ -304,37 +302,37 @@ def tokenize_steps( # move the split index forward MAX_FORWARD_STEPS = 100 for i in range(MAX_FORWARD_STEPS): - if loss_mask[split_prompt_reponse_index] == 0: - split_prompt_reponse_index += 1 + if loss_mask[split_prompt_response_index] == 0: + split_prompt_response_index += 1 else: break # no matter what, the split index should not exceed max prompt length # make sure that the prompt length does not exceed `config.ajet.data.max_prompt_length` - if split_prompt_reponse_index > self.config.ajet.data.max_prompt_length: - split_prompt_reponse_index = self.config.ajet.data.max_prompt_length + if split_prompt_response_index > self.config.ajet.data.max_prompt_length: + split_prompt_response_index = self.config.ajet.data.max_prompt_length # check assert len(ext_steps) == len( input_ids_len ), "length of ext_steps and input_ids_len should be equal" assert ( - split_prompt_reponse_index != -1 - ), "split_prompt_reponse_index should not be -1, at least one message should be in the context" + split_prompt_response_index != -1 + ), "split_prompt_response_index should not be -1, at least one message should be in the context" position_ids = compute_position_id_with_mask(torch.tensor(attention_mask)).tolist() # sperate prompt and response - prompt_ids = input_ids[:split_prompt_reponse_index] - prompt_attention_mask = attention_mask[:split_prompt_reponse_index] - prompt_position_ids = position_ids[:split_prompt_reponse_index] - prompt_loss_mask = loss_mask[:split_prompt_reponse_index] - prompt_logprobs = input_logprobs[:split_prompt_reponse_index] - - response_ids = input_ids[split_prompt_reponse_index:] - response_attention_mask = attention_mask[split_prompt_reponse_index:] - response_position_ids = position_ids[split_prompt_reponse_index:] - response_loss_mask = loss_mask[split_prompt_reponse_index:] - response_logprobs = input_logprobs[split_prompt_reponse_index:] + prompt_ids = input_ids[:split_prompt_response_index] + prompt_attention_mask = attention_mask[:split_prompt_response_index] + prompt_position_ids = position_ids[:split_prompt_response_index] + prompt_loss_mask = loss_mask[:split_prompt_response_index] + prompt_logprobs = input_logprobs[:split_prompt_response_index] + + response_ids = input_ids[split_prompt_response_index:] + response_attention_mask = attention_mask[split_prompt_response_index:] + response_position_ids = position_ids[split_prompt_response_index:] + response_loss_mask = loss_mask[split_prompt_response_index:] + response_logprobs = input_logprobs[split_prompt_response_index:] tracker_tokenized = {} tracker_tokenized["input_ids"] = input_ids diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py index 43e4df5c..23379819 100644 --- a/ajet/task_rollout/native_parallel_worker.py +++ b/ajet/task_rollout/native_parallel_worker.py @@ -606,6 +606,7 @@ def trajectories_to_samples(self, tracker_array: List[BaseContextTracker]) -> Li except Exception as e: raise e finally: + logger.bind(exception=True).exception("Error during tracker.tokenize()") # for debugging tracker.generate_log(global_step=self.current_global_steps) if os.environ.get("BEST_LOGGER_PATH", None) and os.environ.get( "AJET_DEBUG", None From 5132c2b2e2bc4998daf15a881a0c1ebf38cdb3a5 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Mon, 9 Feb 2026 11:47:22 +0800 Subject: [PATCH 24/25] improve readability --- ajet/backbone/trainer_verl.py | 5 ++- ajet/context_tracker/base_tracker.py | 43 +++++++++++++-------- ajet/context_tracker/basic_tracker.py | 1 - ajet/context_tracker/multiagent_tracking.py | 14 ++++++- ajet/task_rollout/native_parallel_worker.py | 21 ++++++++-- ajet/task_rollout/single_worker.py | 12 ++++-- ajet/task_runner/swarm_runner.py | 2 + 7 files changed, 71 insertions(+), 27 deletions(-) diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 1ddb0fef..c5da1334 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -563,8 +563,7 @@ def fit(self): # noqa: C901 # pass global_steps to trace gen_batch.meta_info["global_steps"] = self.global_steps is_last_step = self.global_steps >= self.total_training_steps - from ajet import bp - bp("BATCH") + with marked_timer("step", timing_raw): # generate a batch logger.info("rollout step begin") @@ -597,6 +596,8 @@ def fit(self): # noqa: C901 context_tracker_arr: List[BaseContextTracker] = self.parallel_env.rollout( tasks, mode="sample", epoch=f"train.{epoch}" ) + from ajet import bp + bp("BATCH") logger.info("end fit rollout") gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr) logger.info("end dataproto convertion") diff --git a/ajet/context_tracker/base_tracker.py b/ajet/context_tracker/base_tracker.py index 3baa7095..be87f012 100644 --- a/ajet/context_tracker/base_tracker.py +++ b/ajet/context_tracker/base_tracker.py @@ -113,34 +113,45 @@ def replace_token_ids( class BaseTracker(object): def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs): + # disable read only mode + self._read_only = False + self._discarded = False + + # task related info self.workflow_task = workflow_task self.task_batch_index = self.workflow_task.task_batch_index self.task_tag: str = self.workflow_task.task_tag self.task_id: str = self.workflow_task.task_id self.episode_uuid = self.workflow_task.episode_uuid - self.config = config + # tokenizer self.tokenizer = tokenizer + self.blackout_token_combo = tokenizer.encode("<|im_start|>assistant\n") + self._im_start_token_id = tokenizer.encode("<|im_start|>")[0] + + # config + self.config = config self.saved_timelines: List[List[ExtendedMessage]] = [] self.current_context_status = "" + + # length control max_response_length = self.config.ajet.rollout.max_response_length_in_one_turn max_model_len: int = self.config.ajet.rollout.max_model_len self.max_seq_length: int = max_model_len - max_response_length - self.blackout_token_combo = tokenizer.encode("<|im_start|>assistant\n") - self._im_start_token_id = tokenizer.encode("<|im_start|>")[0] - self.generated_token_cnt = 0 - self.terminal_rewards_dict = {} - self.discarded = False - self.is_terminated = False - self.reward_structure: Union[Reward, None] = None - self.context_time_cost = 0 + + self.generation_prompt_token = None + self.log_metrics: Optional[Dict[str, Union[float, List[float], Dict[str, Any]]]] = None # Initialize workflow_metadata to store tool statistics + + # meta data attributes self.tag = "" + self.round_cnt = 0 + self.generated_token_cnt = 0 self.current_batch_success_rate: float = float("-inf") self.current_batch_reward: float = float("-inf") + + # reward and madness detection + self.reward_structure: Union[Reward, None] = None self.already_mad_flag: bool = False - self.round_cnt = 0 - self.generation_prompt_token = None - self.log_metrics: Optional[Dict[str, Union[float, List[float], Dict[str, Any]]]] = None # Initialize workflow_metadata to store tool statistics assert ( self.config.ajet.data.max_prompt_length @@ -149,13 +160,13 @@ def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs): ) def reset(self): + # disable read only mode + self._read_only = False + self._discarded = False + self.saved_timelines: List[List[ExtendedMessage]] = [] self.current_context_status = "" - self.terminal_rewards_dict = {} - self.discarded = False - self.is_terminated = False self.reward_structure: Union[Reward, None] = None - self.context_time_cost = 0 self.tag = "" self.current_batch_success_rate: float = float("-inf") self.current_batch_reward: float = float("-inf") diff --git a/ajet/context_tracker/basic_tracker.py b/ajet/context_tracker/basic_tracker.py index 9f70d8ab..5df5c682 100644 --- a/ajet/context_tracker/basic_tracker.py +++ b/ajet/context_tracker/basic_tracker.py @@ -24,7 +24,6 @@ class BaseContextTracker(BaseTracker): full_context (List[ExtendedMessage]): List of all messages in the conversation current_context_status (str): Current status of the context max_seq_length (int): Maximum sequence length for the context window - terminal_rewards_dict (dict): Dictionary storing terminal rewards """ def __init__(self, config, tokenizer, **kwargs): diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py index e8514222..a2ce50d6 100644 --- a/ajet/context_tracker/multiagent_tracking.py +++ b/ajet/context_tracker/multiagent_tracking.py @@ -216,7 +216,12 @@ def step_track( timeline_uuid: str = "", ): assert timeline_uuid in self.timeline_cache, "Timeline UUID not found in cache. Please ensure `step_prepare` is called before `step_track`." - timeline = self.timeline_cache.get(timeline_uuid, []) + + # round ++ + self.round_cnt += 1 + + # get timeline from cache + timeline = self.timeline_cache.pop(timeline_uuid, []) if not self.already_mad_flag: if ( compute_string_madness( @@ -291,6 +296,11 @@ def save_llm_interaction_timeline(self, tools, llm_ext_msg, timeline): for i in range(1, len(timeline)): assert not timeline[i].first_message + # no longer write anything + if self._read_only: + logger.exception("Timeline is in read-only mode, should not save new timeline. Please report a github issue if you see this error.") + return + # save to self.saved_timelines self.saved_timelines += [copy.deepcopy(timeline)] @@ -556,6 +566,8 @@ def generate_log(self, task_id=None, global_step="NA"): def group_merge(self) -> List[List[ExtendedMessage]]: timeline_merging_policy: TimelineMergingPolicyConfig = self.config.ajet.context_tracker.timeline_merging_policy self.saved_timelines = merge_tracker_timelines(self.saved_timelines, timeline_merging_policy) + self._read_only = True + return self.saved_timelines diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py index 23379819..898b2a3c 100644 --- a/ajet/task_rollout/native_parallel_worker.py +++ b/ajet/task_rollout/native_parallel_worker.py @@ -255,7 +255,7 @@ def rollout_dynamic( # noqa: C901 completed_task_futures = [f for f in task_future_array if f.done()] completed_results = [f.result() for f in completed_task_futures] completed_results = [ - tracker for tracker in completed_results if not tracker.discarded + tracker for tracker in completed_results if not tracker._discarded ] reward = [ tracker.reward_structure.performance_reward for tracker in completed_results @@ -306,7 +306,7 @@ def rollout_dynamic( # noqa: C901 ) time.sleep(5) - # We have enough number of samples, but we need to wait for all threads to finish, including discarded threads + # We have enough number of samples, but we need to wait for all threads to finish, including ._discarded threads tic = -1 while any(f.running() for task_future_array in futures for f in task_future_array): tic += 1 @@ -325,7 +325,7 @@ def rollout_dynamic( # noqa: C901 completed_task_futures = [f for f in task_future_array if f.done()] completed_results = [f.result() for f in completed_task_futures] completed_results = [ - tracker for tracker in completed_results if not tracker.discarded + tracker for tracker in completed_results if not tracker._discarded ] task_cmd_reward_array = [ tracker.reward_structure.performance_reward for tracker in completed_results @@ -409,7 +409,7 @@ def rollout_dynamic( # noqa: C901 completed_task_futures = [f for f in task_future_array if f.done()] completed_results = [f.result() for f in completed_task_futures] completed_results = [ - tracker for tracker in completed_results if not tracker.discarded + tracker for tracker in completed_results if not tracker._discarded ] # in-group success rate and reward task_cmd_reward_array = [ @@ -583,6 +583,19 @@ def stop_all_threads_hard(): for ct_list in completed_task_id_map_ct.values(): tracker_array.extend(ct_list) + + # TODO: support multi-step reward + task_success_rate = np.mean( + [tracker.reward_structure.success_rate for tracker in tracker_array] + ) + task_scalar_reward = np.mean( + [tracker.reward_structure.final_scalar_reward for tracker in tracker_array] + ) + + for tracker in tracker_array: + tracker.current_batch_success_rate = float(task_success_rate) + tracker.current_batch_reward = float(task_scalar_reward) + # return all trackers return tracker_array diff --git a/ajet/task_rollout/single_worker.py b/ajet/task_rollout/single_worker.py index 56fcb54b..efbe1931 100644 --- a/ajet/task_rollout/single_worker.py +++ b/ajet/task_rollout/single_worker.py @@ -160,14 +160,18 @@ def rollout_env_worker_loop( **kwargs, ): try: + cnt = 1 + while True: - if observation_window["stop"][task_thread_index]: - print('rollout_env_worker_loop received stop signal, exiting...') + if observation_window["stop"][task_thread_index]: # since we use multi-threading, the best way to communicate with main thread is through shared memory. return - observation_window["info"][task_thread_index] = str(cnt) + observation_window["info"][task_thread_index] = str(cnt) # observe how many iterations have been done in the loop + + # Let's begin working on the task, the result `tracker` will contain everything: reward, llm calls, conversation history, etc. + # Later we will gather all trackers and do post-processing, generating samples for VeRL. tracker = self.rollout_env_worker( task=task, task_batch_index=task_batch_index, @@ -185,7 +189,9 @@ def rollout_env_worker_loop( completed_task_id_map_ct[tracker.task_id] = [tracker] else: completed_task_id_map_ct[tracker.task_id] += [tracker] + cnt += 1 + if observation_window["stop"][task_thread_index]: return else: diff --git a/ajet/task_runner/swarm_runner.py b/ajet/task_runner/swarm_runner.py index d8e34b77..4bc47a0d 100644 --- a/ajet/task_runner/swarm_runner.py +++ b/ajet/task_runner/swarm_runner.py @@ -130,6 +130,7 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: print(f'Exiting workflow worker due to interrupt signal for episode {workflow_task.episode_uuid}.') raise SwarmReceiveAbortException(f"Episode {workflow_task.episode_uuid} aborted due to interrupt signal.") + # context tracker will trace and gather everything we need for training context_tracker = MultiAgentContextTracker( llm_inference_fn=self.llm_inference_fn, tokenizer=self.tokenizer, @@ -137,6 +138,7 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: workflow_task = workflow_task, **hooks, ) + # tuner will handle the communication and provide `baseurl_apikey` tuner = AjetTuner( context_tracker=context_tracker, llm_inference_fn=self.llm_inference_fn, From 98db2b7fbb089c2f71bb2672222da270cdf22202 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Mon, 9 Feb 2026 12:18:41 +0800 Subject: [PATCH 25/25] delete exit message --- ajet/backbone/trainer_verl.py | 5 +++-- ajet/task_rollout/single_worker.py | 2 +- ajet/task_runner/swarm_runner.py | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index c5da1334..f1e07407 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -596,8 +596,9 @@ def fit(self): # noqa: C901 context_tracker_arr: List[BaseContextTracker] = self.parallel_env.rollout( tasks, mode="sample", epoch=f"train.{epoch}" ) - from ajet import bp - bp("BATCH") + + # from ajet import bp; bp("BATCH") + logger.info("end fit rollout") gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr) logger.info("end dataproto convertion") diff --git a/ajet/task_rollout/single_worker.py b/ajet/task_rollout/single_worker.py index efbe1931..103f9621 100644 --- a/ajet/task_rollout/single_worker.py +++ b/ajet/task_rollout/single_worker.py @@ -128,7 +128,7 @@ def rollout_env_worker( workflow_task=workflow_task, ) except SwarmReceiveAbortException as exc: # noqa: BLE001 - print('SwarmReceiveAbortException caught in rollout_env_worker') + # print('SwarmReceiveAbortException caught in rollout_env_worker') return None # type: ignore except TestSuccessException as e: logger.success( diff --git a/ajet/task_runner/swarm_runner.py b/ajet/task_runner/swarm_runner.py index 4bc47a0d..03d27c85 100644 --- a/ajet/task_runner/swarm_runner.py +++ b/ajet/task_runner/swarm_runner.py @@ -75,7 +75,7 @@ def register_episode_and_wait_output( message = zmq_socket.recv_string() except zmq.Again as e: if should_exit_hard(): - logger.warning(f'{episode_uuid} Exiting workflow due to should_exit_hard signal.') + # logger.warning(f'{episode_uuid} Exiting workflow due to should_exit_hard signal.') context_tracker.reset() raise SwarmReceiveAbortException(f"Episode {episode_uuid} aborted due to system exit.") else: @@ -127,7 +127,7 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: should_exit_hard = hooks['should_interrupt_hard_fn'] if should_exit_soft() or should_exit_hard(): - print(f'Exiting workflow worker due to interrupt signal for episode {workflow_task.episode_uuid}.') + # print(f'Exiting workflow worker due to interrupt signal for episode {workflow_task.episode_uuid}.') raise SwarmReceiveAbortException(f"Episode {workflow_task.episode_uuid} aborted due to interrupt signal.") # context tracker will trace and gather everything we need for training