diff --git a/ajet/__init__.py b/ajet/__init__.py index b0731e74..c7081e95 100644 --- a/ajet/__init__.py +++ b/ajet/__init__.py @@ -1,8 +1,4 @@ -from ajet.copilot.job import AgentJetJob -from ajet.schema.task import WorkflowOutput, WorkflowTask -from ajet.tuner import AjetTuner -from ajet.workflow import Workflow -from ajet.utils.vsdb import vscode_conditional_breakpoint as bp +__version__ = "0.1.0" __all__ = [ "Workflow", @@ -13,4 +9,29 @@ "bp" ] -__version__ = "0.1.0" +_LAZY_IMPORTS = { + "AjetTuner": "ajet.tuner", + "AgentJetJob": "ajet.copilot.job", + "WorkflowOutput": "ajet.schema.task", + "WorkflowTask": "ajet.schema.task", + "Workflow": "ajet.workflow", + "bp": "ajet.utils.vsdb", +} + +_ATTR_MAPPING = { + "bp": "vscode_conditional_breakpoint" +} + +def __getattr__(name): + if name in _LAZY_IMPORTS: + import importlib + module_path = _LAZY_IMPORTS[name] + module = importlib.import_module(module_path) + + attr_name = _ATTR_MAPPING.get(name, name) + value = getattr(module, attr_name) # type: ignore + + globals()[name] = value + return value + + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") \ No newline at end of file diff --git a/ajet/backbone/main_verl.py b/ajet/backbone/main_verl.py index dcd575f4..4ec0cf86 100644 --- a/ajet/backbone/main_verl.py +++ b/ajet/backbone/main_verl.py @@ -22,11 +22,15 @@ import hydra import ray from beast_logger import print_dict -from loguru import logger -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf from verl.trainer.ppo.reward import load_reward_manager from verl.utils.device import is_cuda_available +from verl.utils.dataset.rl_dataset import collate_fn +from torch.utils.data import Dataset as TorchDataset +# Create training and validation datasets. +from ajet.task_reader import RouterTaskReader, task_to_standard_dataset +from ajet.utils.process_dataset import create_rl_sampler from ajet.utils.core_env_vars import get_runtime_env from ajet.utils.launch_utils import set_loguru_default_color @@ -34,17 +38,17 @@ @hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) -def main(config): +def main(config: DictConfig) -> None: """Main entry point for PPO training with Hydra configuration management. Args: - config_dict: Hydra configuration dictionary containing training parameters. + config: Hydra configuration dictionary containing training parameters. """ run_ppo(config) # Define a function to run the PPO-like training process -def run_ppo(config) -> None: +def run_ppo(config: DictConfig) -> None: """Initialize Ray cluster and run distributed PPO training process. Args: @@ -56,7 +60,6 @@ def run_ppo(config) -> None: if not ray.is_initialized(): # this is for local ray cluster runtime_env = get_runtime_env(config) - print_dict(runtime_env["env_vars"], "runtime_env") ray.init( runtime_env=runtime_env, num_cpus=config.ray_init.num_cpus, @@ -110,6 +113,7 @@ def run(self, config): # Print the initial configuration. `resolve=True` will evaluate symbolic values. from pprint import pprint + from loguru import logger from omegaconf import OmegaConf from verl.utils.fs import copy_to_local @@ -227,21 +231,13 @@ def run(self, config): resource_pool_spec=resource_pool_spec, mapping=mapping ) - from verl.utils.dataset.rl_dataset import collate_fn - - # Create training and validation datasets. 
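+        # Create training and validation datasets lazily from the task reader's generator API.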
- from ajet.task_reader import ( - RouterTaskReader, - task_to_standard_dataset, - ) - from ajet.utils.process_dataset import create_rl_sampler - task_reader = RouterTaskReader( config.ajet.task_reader.type, config.ajet.task_reader, ) - val_dataset = task_to_standard_dataset(task_reader.get_validation_tasks()) - train_dataset = task_to_standard_dataset(task_reader.get_training_tasks()) + + train_dataset: TorchDataset = task_to_standard_dataset(task_reader.generate_training_tasks) # type: ignore + val_dataset: TorchDataset = task_to_standard_dataset(task_reader.generate_validation_tasks) # type: ignore train_sampler = create_rl_sampler(config.data, train_dataset) from ajet.backbone.trainer_verl import AjetRayPPOTrainer diff --git a/ajet/backbone/main_vllm.py b/ajet/backbone/main_vllm.py index 686a35cd..2cdde610 100644 --- a/ajet/backbone/main_vllm.py +++ b/ajet/backbone/main_vllm.py @@ -144,6 +144,7 @@ def run(config): max_parallel = config.ajet.debug.debug_max_parallel n_task = config.ajet.debug.debug_first_n_tasks vllm_port = config.ajet.debug.debug_vllm_port + enable_swarm_mode = config.ajet.enable_swarm_mode # --------- init --------- async_rollout_manager = ChatCompletionScheduler( @@ -166,8 +167,10 @@ def run(config): tasks = task_reader.get_validation_tasks() logger.info(tasks[:n_task]) ctx_tracker = parallel_env.rollout( - tasks=tasks[:n_task], mode="sample", epoch="1" - ) # "sample" or "validate" + tasks=tasks[:n_task], + mode="sample" if not enable_swarm_mode else "sample-ts", # type: ignore + epoch="1" + ) _ = parallel_env.to_dataproto(ctx_tracker) @@ -186,6 +189,9 @@ def main(config): if config.ajet.enable_experimental_interchange_server: from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server start_interchange_server(config) + if config.ajet.enable_swarm_mode: + from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status + http_change_engine_status(config, "ENGINE.ROLLING") def companion_launch(): import torch diff --git a/ajet/backbone/trainer_trinity.py b/ajet/backbone/trainer_trinity.py index 8000a636..7ab60715 100644 --- a/ajet/backbone/trainer_trinity.py +++ b/ajet/backbone/trainer_trinity.py @@ -206,11 +206,9 @@ def __init__(self, config): dataset_segments = [] if "train" in self.split: - dataset_segments.append(task_to_standard_dataset(task_reader.get_training_tasks())) + dataset_segments.append(task_to_standard_dataset(task_reader.generate_training_tasks)) # type: ignore if "val" in self.split: - dataset_segments.append( - task_to_standard_dataset(task_reader.get_validation_tasks()) - ) + dataset_segments.append(task_to_standard_dataset(task_reader.generate_validation_tasks)) # type: ignore if not dataset_segments: raise ValueError( f"Unsupported split '{self.split}'. Expected to contain 'train' or 'val'." diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index cb573457..f1e07407 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -99,16 +99,29 @@ def parse_reward_from_dataproto(data: DataProto, return_dict=False) -> dict | to return reward_tensor -def union_gen_batch_via_task_id(tasks, batch: DataProto, gen_batch_output: DataProto): +def union_gen_batch_via_task_id(tasks, batch: DataProto, gen_batch_output: DataProto, discard_original_batch=False): """ Union the gen_batch_output with the batch based on task_id. 
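+
+    If `discard_original_batch` is True (swarm mode), the original `batch` is dropped:
+    `uid` is copied from `task_ids` so advantage groups form per task, and each rollout
+    gets a `rollout_ids` label of the form "T{task_id}R{n}".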
""" - map_task_id_to_index = {t.task_id: i for i, t in enumerate(tasks)} - gen_task_task_ids = gen_batch_output.non_tensor_batch["task_ids"] - indices = [map_task_id_to_index[tid] for tid in gen_task_task_ids] - batch_extend = batch.select_idxs(indices) - batch_final = batch_extend.union(gen_batch_output) - return batch_final + if not discard_original_batch: + map_task_id_to_index = {t.task_id: i for i, t in enumerate(tasks)} + gen_task_task_ids = gen_batch_output.non_tensor_batch["task_ids"] + indices = [map_task_id_to_index[tid] for tid in gen_task_task_ids] + batch_extend = batch.select_idxs(indices) + batch_final = batch_extend.union(gen_batch_output) + return batch_final + else: + gen_batch_output.non_tensor_batch['uid'] = gen_batch_output.non_tensor_batch["task_ids"] + task_id_counter = {} + for i, tid in enumerate(gen_batch_output.non_tensor_batch["task_ids"]): + if tid in task_id_counter: + task_id_counter[tid] += 1 + else: + task_id_counter[tid] = 1 + current_id = task_id_counter[tid] + gen_batch_output.non_tensor_batch['rollout_ids'][i] = f"T{tid}R{current_id}" + logger.info(f'task_id_counter: {task_id_counter}') + return gen_batch_output def compute_advantage( @@ -443,6 +456,12 @@ def init_workers(self): tokenizer=self.tokenizer, ) + def _update_interchange_server_status_flag(self, status: str): + if self.config.ajet.enable_experimental_interchange_server: + if self.config.ajet.enable_swarm_mode: + from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status + http_change_engine_status(self.config, status) + # ####################################### # training loop # ####################################### @@ -474,7 +493,7 @@ def fit(self): # noqa: C901 # perform validation before training # currently, we only support validation using the reward_function. 
- if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + if (self.val_reward_fn is not None) and (self.config.trainer.get("val_before_train", True)) and (not self.config.ajet.enable_swarm_mode): val_metrics = self._validate() assert val_metrics, f"{val_metrics=}" pprint(f"Initial validation metrics: {val_metrics}") @@ -547,12 +566,13 @@ def fit(self): # noqa: C901 with marked_timer("step", timing_raw): # generate a batch - logger.info("=== + rollout step begin ===") + logger.info("rollout step begin") with marked_timer("gen", timing_raw, color="red"): assert self.async_rollout_mode - logger.info("=== wake up begin ===") + logger.info("wake up begin") self.async_rollout_manager.wake_up() - logger.info("=== wake up end ===") + self._update_interchange_server_status_flag("ENGINE.ROLLING") + logger.info("wake up end") tasks: List[Task] = [ dict_to_ajet_task(dict( task_id=gen_batch.non_tensor_batch["task_id"][i], @@ -571,15 +591,17 @@ def fit(self): # noqa: C901 ] ) ) - logger.info("=" * 10 + "start fit rollout" + "=" * 10) + logger.info("start fit rollout") self.parallel_env.current_global_steps = self.global_steps context_tracker_arr: List[BaseContextTracker] = self.parallel_env.rollout( tasks, mode="sample", epoch=f"train.{epoch}" ) - logger.info("=" * 10 + "end fit rollout" + "=" * 10) - logger.info("begin to convert context_tracker_arr to dataproto") + + # from ajet import bp; bp("BATCH") + + logger.info("end fit rollout") gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr) - logger.info("end convertion") + logger.info("end dataproto convertion") success_rate = [ traj.reward_structure.success_rate for traj in context_tracker_arr @@ -622,17 +644,17 @@ def fit(self): # noqa: C901 logger.info( f"gen_batch_output.info batch.keys={gen_batch_output.batch.keys()}" ) + self._update_interchange_server_status_flag("ENGINE.WEIGHT_SYNCING") self.async_rollout_manager.sleep() - logger.info("=== - rollout step end ===") + logger.info("rollout step end") - if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - raise NotImplementedError("REMAX is not supported in GRPO yet.") batch.non_tensor_batch["uid"] = np.array( [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object, ) - batch = union_gen_batch_via_task_id(tasks, batch, gen_batch_output) + discard_original_batch = self.config.ajet.enable_swarm_mode + batch = union_gen_batch_via_task_id(tasks, batch, gen_batch_output, discard_original_batch) batch.batch["response_mask"] = compute_response_mask(batch) if "response_mask" not in batch.batch.keys(): @@ -666,7 +688,7 @@ def fit(self): # noqa: C901 ) # recompute old_log_probs - logger.info("=== + compute log_probs begin ===") + logger.info("+ compute log_probs begin") with marked_timer("old_log_prob", timing_raw, color="blue"): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) entropys = old_log_prob.batch["entropys"] @@ -764,6 +786,7 @@ def fit(self): # noqa: C901 self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + and (not self.config.ajet.enable_swarm_mode) ): with marked_timer("testing", timing_raw, color="green"): val_metrics: dict = self._validate() @@ -914,17 +937,16 @@ def _validate(self): self.async_rollout_manager.wake_up() main_val_dataset = self.get_eval_dataset() - logger.info("=" * 10 + "start validate rollout" + "=" * 10) + logger.info("Starting validate rollout") context_tracker_arr, tasks, val_metrics = 
self.eval_dataset( target_dataset=main_val_dataset, target_dataset_name="main_val_dataset", mode="validate", epoch="test.1", ) - logger.info("=" * 10 + "end validate rollout" + "=" * 10) + logger.info("Completed validate rollout") test_output_gen_batch = self.parallel_env.to_dataproto(context_tracker_arr) self.async_rollout_manager.sleep() - logger.info("validation generation end") # Store generated outputs output_ids = test_output_gen_batch.batch["responses"] @@ -938,7 +960,8 @@ def _validate(self): dtype=object, ) tasks = tasks[: len(main_val_dataset)] - test_batch = union_gen_batch_via_task_id(tasks, test_batch, test_output_gen_batch) + discard_original_batch = self.config.ajet.enable_swarm_mode + test_batch = union_gen_batch_via_task_id(tasks, test_batch, test_output_gen_batch, discard_original_batch) # test_batch = test_batch.union(test_output_gen_batch) test_batch.meta_info["validate"] = True diff --git a/ajet/context_tracker/base_tracker.py b/ajet/context_tracker/base_tracker.py index 856cd89c..be87f012 100644 --- a/ajet/context_tracker/base_tracker.py +++ b/ajet/context_tracker/base_tracker.py @@ -113,34 +113,45 @@ def replace_token_ids( class BaseTracker(object): def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs): + # disable read only mode + self._read_only = False + self._discarded = False + + # task related info self.workflow_task = workflow_task self.task_batch_index = self.workflow_task.task_batch_index - self.task_tag = self.workflow_task.task_tag - self.task_id = self.workflow_task.task_id + self.task_tag: str = self.workflow_task.task_tag + self.task_id: str = self.workflow_task.task_id self.episode_uuid = self.workflow_task.episode_uuid - self.config = config + # tokenizer self.tokenizer = tokenizer + self.blackout_token_combo = tokenizer.encode("<|im_start|>assistant\n") + self._im_start_token_id = tokenizer.encode("<|im_start|>")[0] + + # config + self.config = config self.saved_timelines: List[List[ExtendedMessage]] = [] self.current_context_status = "" + + # length control max_response_length = self.config.ajet.rollout.max_response_length_in_one_turn max_model_len: int = self.config.ajet.rollout.max_model_len self.max_seq_length: int = max_model_len - max_response_length - self.blackout_token_combo = tokenizer.encode("<|im_start|>assistant\n") - self._im_start_token_id = tokenizer.encode("<|im_start|>")[0] - self.generated_token_cnt = 0 - self.terminal_rewards_dict = {} - self.discarded = False - self.is_terminated = False - self.reward_structure: Union[Reward, None] = None - self.context_time_cost = 0 + + self.generation_prompt_token = None + self.log_metrics: Optional[Dict[str, Union[float, List[float], Dict[str, Any]]]] = None # Initialize workflow_metadata to store tool statistics + + # meta data attributes self.tag = "" + self.round_cnt = 0 + self.generated_token_cnt = 0 self.current_batch_success_rate: float = float("-inf") self.current_batch_reward: float = float("-inf") + + # reward and madness detection + self.reward_structure: Union[Reward, None] = None self.already_mad_flag: bool = False - self.round_cnt = 0 - self.generation_prompt_token = None - self.log_metrics: Optional[Dict[str, Union[float, List[float], Dict[str, Any]]]] = None # Initialize workflow_metadata to store tool statistics assert ( self.config.ajet.data.max_prompt_length @@ -148,6 +159,21 @@ def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs): <= max_model_len ) + def reset(self): + # disable read only mode + self._read_only = False + 
self._discarded = False + + self.saved_timelines: List[List[ExtendedMessage]] = [] + self.current_context_status = "" + self.reward_structure: Union[Reward, None] = None + self.tag = "" + self.current_batch_success_rate: float = float("-inf") + self.current_batch_reward: float = float("-inf") + self.already_mad_flag: bool = False + self.round_cnt = 0 + self.log_metrics: Optional[Dict[str, Union[float, List[float], Dict[str, Any]]]] = None + def group_tokenize(self): raise NotImplementedError diff --git a/ajet/context_tracker/basic_tracker.py b/ajet/context_tracker/basic_tracker.py index 44d81cb7..5df5c682 100644 --- a/ajet/context_tracker/basic_tracker.py +++ b/ajet/context_tracker/basic_tracker.py @@ -24,7 +24,6 @@ class BaseContextTracker(BaseTracker): full_context (List[ExtendedMessage]): List of all messages in the conversation current_context_status (str): Current status of the context max_seq_length (int): Maximum sequence length for the context window - terminal_rewards_dict (dict): Dictionary storing terminal rewards """ def __init__(self, config, tokenizer, **kwargs): @@ -262,30 +261,28 @@ def tokenize_steps( # check reward structure self.reward_structure: Reward # type: ignore - assert ( - self.reward_structure.step_reward_arr is not None - ), "must call `process_reward` before tokenize_steps" - assert len(self.reward_structure.step_reward_arr) == total_steps + assert self.reward_structure.step_reward_arr is not None, "must call `process_reward` before tokenize_steps" + assert len(self.reward_structure.step_reward_arr) == total_steps, f"reward step count {len(self.reward_structure.step_reward_arr)} != total_steps {total_steps}" # mapping input_ids = [] input_logprobs = [] attention_mask = [] loss_mask = [] - split_prompt_reponse_index = -1 + split_prompt_response_index = -1 split_point_message_left_index = -1 input_ids_len = [] # cat all messages for i, ext_msg in enumerate(ext_steps): # find split index, this have to be done before input_ids += ext_msg.token_arr - if (split_prompt_reponse_index == -1) and (ext_msg.need_training): - split_prompt_reponse_index = len(input_ids) + if (split_prompt_response_index == -1) and (ext_msg.need_training): + split_prompt_response_index = len(input_ids) split_point_message_left_index = i - 1 assert ( split_point_message_left_index >= 0 ), "There should be at least one message before the first training message" - assert split_prompt_reponse_index == input_ids_len[split_point_message_left_index] + assert split_prompt_response_index == input_ids_len[split_point_message_left_index] assert ( ext_msg.author == "llm" ), "The first message after initialization should be from LLM, not from env or user" @@ -304,37 +301,37 @@ def tokenize_steps( # move the split index forward MAX_FORWARD_STEPS = 100 for i in range(MAX_FORWARD_STEPS): - if loss_mask[split_prompt_reponse_index] == 0: - split_prompt_reponse_index += 1 + if loss_mask[split_prompt_response_index] == 0: + split_prompt_response_index += 1 else: break # no matter what, the split index should not exceed max prompt length # make sure that the prompt length does not exceed `config.ajet.data.max_prompt_length` - if split_prompt_reponse_index > self.config.ajet.data.max_prompt_length: - split_prompt_reponse_index = self.config.ajet.data.max_prompt_length + if split_prompt_response_index > self.config.ajet.data.max_prompt_length: + split_prompt_response_index = self.config.ajet.data.max_prompt_length # check assert len(ext_steps) == len( input_ids_len ), "length of ext_steps and input_ids_len should 
be equal" assert ( - split_prompt_reponse_index != -1 - ), "split_prompt_reponse_index should not be -1, at least one message should be in the context" + split_prompt_response_index != -1 + ), "split_prompt_response_index should not be -1, at least one message should be in the context" position_ids = compute_position_id_with_mask(torch.tensor(attention_mask)).tolist() # sperate prompt and response - prompt_ids = input_ids[:split_prompt_reponse_index] - prompt_attention_mask = attention_mask[:split_prompt_reponse_index] - prompt_position_ids = position_ids[:split_prompt_reponse_index] - prompt_loss_mask = loss_mask[:split_prompt_reponse_index] - prompt_logprobs = input_logprobs[:split_prompt_reponse_index] - - response_ids = input_ids[split_prompt_reponse_index:] - response_attention_mask = attention_mask[split_prompt_reponse_index:] - response_position_ids = position_ids[split_prompt_reponse_index:] - response_loss_mask = loss_mask[split_prompt_reponse_index:] - response_logprobs = input_logprobs[split_prompt_reponse_index:] + prompt_ids = input_ids[:split_prompt_response_index] + prompt_attention_mask = attention_mask[:split_prompt_response_index] + prompt_position_ids = position_ids[:split_prompt_response_index] + prompt_loss_mask = loss_mask[:split_prompt_response_index] + prompt_logprobs = input_logprobs[:split_prompt_response_index] + + response_ids = input_ids[split_prompt_response_index:] + response_attention_mask = attention_mask[split_prompt_response_index:] + response_position_ids = position_ids[split_prompt_response_index:] + response_loss_mask = loss_mask[split_prompt_response_index:] + response_logprobs = input_logprobs[split_prompt_response_index:] tracker_tokenized = {} tracker_tokenized["input_ids"] = input_ids diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py index dc192aa6..a2ce50d6 100644 --- a/ajet/context_tracker/multiagent_tracking.py +++ b/ajet/context_tracker/multiagent_tracking.py @@ -48,13 +48,15 @@ def __init__( self, tokenizer: PreTrainedTokenizer, config, - should_interrupt_fn, + should_interrupt_soft_fn, + should_interrupt_hard_fn, generated_token_callback_fn, **kwargs, ): super().__init__(config, tokenizer, **kwargs) self.tokenizer = tokenizer - self.should_interrupt_fn = should_interrupt_fn + self.should_interrupt_soft_fn = should_interrupt_soft_fn + self.should_interrupt_hard_fn = should_interrupt_hard_fn self.generated_token_callback_fn = generated_token_callback_fn self.context_overflow = False self.output_kwargs = {} @@ -214,7 +216,12 @@ def step_track( timeline_uuid: str = "", ): assert timeline_uuid in self.timeline_cache, "Timeline UUID not found in cache. Please ensure `step_prepare` is called before `step_track`." - timeline = self.timeline_cache.get(timeline_uuid, []) + + # round ++ + self.round_cnt += 1 + + # get timeline from cache + timeline = self.timeline_cache.pop(timeline_uuid, []) if not self.already_mad_flag: if ( compute_string_madness( @@ -289,10 +296,15 @@ def save_llm_interaction_timeline(self, tools, llm_ext_msg, timeline): for i in range(1, len(timeline)): assert not timeline[i].first_message + # no longer write anything + if self._read_only: + logger.exception("Timeline is in read-only mode, should not save new timeline. 
Please report a github issue if you see this error.") + return + # save to self.saved_timelines self.saved_timelines += [copy.deepcopy(timeline)] - # DEBUG = True # warn when merge fails + # warn when merge fails timeline_merging_policy: TimelineMergingPolicyConfig = self.config.ajet.context_tracker.timeline_merging_policy if ( self.config.ajet.context_tracker.detect_timeline_snap @@ -554,6 +566,8 @@ def generate_log(self, task_id=None, global_step="NA"): def group_merge(self) -> List[List[ExtendedMessage]]: timeline_merging_policy: TimelineMergingPolicyConfig = self.config.ajet.context_tracker.timeline_merging_policy self.saved_timelines = merge_tracker_timelines(self.saved_timelines, timeline_merging_policy) + self._read_only = True + return self.saved_timelines @@ -599,7 +613,7 @@ def check_context_token_num_safe( token_overflow = False else: token_overflow = True - if self.should_interrupt_fn(): + if self.should_interrupt_soft_fn(): ret = (False, token_overflow, "externally_interrupted") elif self.already_mad_flag and self.config.ajet.rollout.agent_madness_termination: ret = (False, token_overflow, "already_mad") diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py index 373af631..4f0f5c7b 100644 --- a/ajet/copilot/job.py +++ b/ajet/copilot/job.py @@ -17,11 +17,7 @@ import yaml from loguru import logger -from ajet.launcher import ( - check_avail_gpu, - get_backbone_target, - setup_environment_vars, -) + from ajet.default_config.ajet_default import Config from ajet.utils.config_utils import ( expand_ajet_hierarchical_config, @@ -29,7 +25,12 @@ read_ajet_hierarchical_config, ) from ajet.utils.dynamic_import import cls_to_path -from ajet.utils.launch_utils import execute_training_process +from ajet.utils.launch_utils import ( + execute_training_process, + check_avail_gpu, + get_backbone_target, + setup_environment_vars, +) class AgentJetJob: @@ -37,30 +38,40 @@ class AgentJetJob: def __init__( self, - backbone: str = "trinity", + backbone: str = "verl", model: str = "Qwen/Qwen2___5-7B-Instruct", n_gpu: int = 8, algorithm: str = "grpo", n_gpu_for_infer: int | None = None, # only for trinity backbone + grpo_n: int = 8, + batch_size: int = 32, + swarm_mode: bool = True, *kwargs, ) -> None: self.backbone = backbone - self.config_as_dict: dict = self.build_job_from_yaml(None) + if swarm_mode: + default_yaml = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml")) + else: + default_yaml = None + self.config_as_dict: dict = self.build_job_from_yaml(default_yaml) self.config = Config.update_from_dict_recursive(Config(), self.config_as_dict) self.config.ajet.backbone = backbone self.config.ajet.model.path = model self.config.ajet.trainer_common.n_gpus_per_node = n_gpu self.config.ajet.trainer_common.algorithm.adv_estimator = algorithm + self.config.ajet.rollout.num_repeat = grpo_n + self.config.ajet.data.train_batch_size = batch_size if n_gpu_for_infer is None and backbone == "trinity": raise ValueError("Please specify `n_gpu_for_infer` (n_gpu_for_infer < n_gpu) for trinity backbone.") - if n_gpu_for_infer is not None and backbone == "verl": + if (n_gpu_for_infer is not None) and backbone == "verl": raise ValueError("n_gpu_for_infer is only for trinity backbone, please set it to `None`.") else: - assert isinstance(n_gpu_for_infer, int) - assert n_gpu_for_infer < n_gpu, "`n_gpu_for_infer` should be less than `n_gpu`." 
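+            # Only the trinity backbone splits GPUs: reserve `n_gpu_for_infer` of them for the vLLM engines.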
- self.config.ajet.rollout.n_vllm_engine = n_gpu_for_infer - self.config.ajet.rollout.tensor_model_parallel_size = 1 + if backbone == "trinity": + assert isinstance(n_gpu_for_infer, int), f"`n_gpu_for_infer` should be int, got {type(n_gpu_for_infer)}." + assert n_gpu_for_infer < n_gpu, "`n_gpu_for_infer` should be less than `n_gpu`." + self.config.ajet.rollout.n_vllm_engine = n_gpu_for_infer + self.config.ajet.rollout.tensor_model_parallel_size = 1 def build_job_from_yaml(self, yaml_path: str | None) -> dict: self.exp_name = datetime.now().strftime("ajet_job_%Y%m%d_%H%M%S") diff --git a/ajet/default_config/ajet_default.py b/ajet/default_config/ajet_default.py index 9d0732e9..18ff2def 100644 --- a/ajet/default_config/ajet_default.py +++ b/ajet/default_config/ajet_default.py @@ -30,6 +30,7 @@ class AjetRollout: user_workflow: str = "tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow" n_vllm_engine: int = 1 tensor_model_parallel_size: int = 1 + num_repeat: int = 8 @dataclass diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index 622669f8..f1f65b0a 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -86,6 +86,7 @@ ajet: task_reader: + # how to read dataset / environment type: huggingface_dat_repo # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` # when `type == jsonl_dataset_file` @@ -281,13 +282,17 @@ ajet: # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature + enable_swarm_mode: False + # both swarm / oai share the same interchange server enable_experimental_interchange_server: False + # interchange server configuration interchange_server: interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) interchange_server_port: 'auto' num_fastapi_process: 2 # 1, 2 or 4 is fine - max_fastapi_threads: 128 # 64 or 128 is fine + max_fastapi_threads: 512 # 64 or 128 is fine max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` + already_started: False # do not edit, used by `swarm` task_runner: diff --git a/ajet/default_config/ajet_ts_default.yaml b/ajet/default_config/ajet_ts_default.yaml new file mode 100644 index 00000000..8f631421 --- /dev/null +++ b/ajet/default_config/ajet_ts_default.yaml @@ -0,0 +1,51 @@ +# ------------------ main configuration ------------------ +ajet: + project_name: "ajet_default_project" + experiment_name: "read_yaml_name" + experiment_dir: "auto" # {exp-dir}/{experiment_name} + backbone: debug # `debug` or `trinity` or `verl` + + model: + # which model should be trained + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-3B-Instruct + + rollout: + # the path to the workflow class + user_workflow: null + + task_reader: + type: random_dummy # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` + + task_judge: + judge_type: customized_protocol # Options: 'customized_protocol', 'rubrics_auto_grader' + judge_protocol: null # reward must come from remote user agent workflow, so set to null + + # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature + enable_experimental_interchange_server: True + # train in cloud, run episode locally + enable_swarm_mode: True + # both swarm / oai share the same interchange server + interchange_server: + interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + 
interchange_server_port: 10086 + num_fastapi_process: 2 # 1, 2 or 4 is fine + max_fastapi_threads: 512 # 64 or 128 is fine + max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` + already_started: False # do not edit, used by `swarm` + + rollout: + # maximum number of parallel environments / simulate workers + max_env_worker: 128 + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - ajet_default + - _self_ diff --git a/ajet/launcher.py b/ajet/launcher.py index 40557137..3e4d2cb5 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -1,6 +1,7 @@ import argparse import os import subprocess +from types import SimpleNamespace from dotenv import load_dotenv from loguru import logger @@ -12,6 +13,11 @@ launch_logview, set_loguru_default_color, start_ray_service, + check_debugpy_version, + check_avail_gpu, + dict_to_namespace, + get_backbone_target, + setup_environment_vars, ) from ajet.utils.pty import pty_launch @@ -28,6 +34,12 @@ def parse_args(): required=False, help="verl or trinity or debug", ) + parser.add_argument( + "--swarm-server", + action="store_true", + default=False, + help="Enable Swarm server mode", + ) parser.add_argument( "--conf", type=str, @@ -50,9 +62,18 @@ def parse_args(): required=False, help="Path to configuration file", ) - - parser.add_argument("--with-ray", action="store_true", default=False, help="Launch ray") - parser.add_argument("--with-ray-cluster", action="store_true", default=False, help="Launch ray") + parser.add_argument( + "--with-ray", + action="store_true", + default=False, + help="Launch ray" + ) + parser.add_argument( + "--with-ray-cluster", + action="store_true", + default=False, + help="Launch ray" + ) parser.add_argument( "--with-appworld", action="store_true", @@ -71,7 +92,12 @@ def parse_args(): default=False, help="Launch webshop", ) - parser.add_argument("--with-bfcl", action="store_true", default=False, help="Launch bfcl") + parser.add_argument( + "--with-bfcl", + action="store_true", + default=False, + help="Launch bfcl" + ) parser.add_argument( "--with-logview", action="store_true", @@ -84,8 +110,12 @@ def parse_args(): default=False, help="Launch Crafters Env Simulation", ) - parser.add_argument("--reboot", action="store_true", default=False, help="reboot flag") - parser.add_argument("--skip-check-avail-gpu", action="store_true", default=False, help="Skip GPU availability check") + parser.add_argument( + "--skip-check-avail-gpu", + action="store_true", + default=False, + help="Skip GPU availability check" + ) parser.add_argument( "--kill", type=str, @@ -99,156 +129,31 @@ def parse_args(): default=False, help="Kill system processes (ray + vllm + python) that may block the current experiment", ) - parser.add_argument("--prefix", type=str, default="", required=False, help="Prefix for deepfinance service names") - return parser.parse_args() - - -def check_debugpy_version(): - try: - import debugpy - except ImportError: - raise RuntimeError( - "Module 'debugpy>=1.8.0' cannot be loaded. " - "Ray Debugpy Debugger will not work without 'debugpy>=1.8.0' installed. 
" - "Install this module using 'pip install debugpy>=1.8.0'" - ) - version = getattr(debugpy, "__version__", "0.0.0") - from packaging import version as packaging_version - - if packaging_version.parse(version) < packaging_version.parse("1.8.0"): - raise RuntimeError( - f"debugpy version {version} is too old. " - "Ray Debugpy Debugger requires 'debugpy>=1.8.0'. " - "Upgrade using 'pip install debugpy>=1.8.0'" - ) - logger.info(f"✓ debugpy version {version} meets requirement (>=1.8.0)") - - -def check_avail_gpu(min_free_ratio: float = 0.95): - """ - Ensure there is at least one GPU and all GPUs have >= min_free_ratio free memory. - - Uses `nvidia-smi` to query total and used memory for each GPU. - Raises RuntimeError if no GPU is found or any GPU violates the free ratio threshold. - """ - try: - # Query GPU memory via nvidia-smi; output in MiB - result = subprocess.run( - [ - "nvidia-smi", - "--query-gpu=name,memory.total,memory.used", - "--format=csv,noheader,nounits", - ], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - check=False, - ) - except FileNotFoundError: - raise RuntimeError("nvidia-smi not found. NVIDIA drivers/GPU may be unavailable.") - - if result.returncode != 0: - raise RuntimeError(f"Failed to query GPUs via nvidia-smi: {result.stderr.strip()}") - - lines = [line.strip() for line in result.stdout.splitlines() if line.strip()] - if not lines: - raise RuntimeError("No GPUs detected by nvidia-smi.") - - violations = [] - for idx, line in enumerate(lines): - # Expected format: ", , " - parts = [p.strip() for p in line.split(",")] - if len(parts) < 3: - violations.append((idx, "parse-error", line)) - continue - name, total_str, used_str = parts[0], parts[1], parts[2] - try: - total = float(total_str) - used = float(used_str) - except ValueError: - violations.append((idx, "parse-error", line)) - continue - free = max(total - used, 0.0) - free_ratio = free / total if total > 0 else 0.0 - logger.info( - f"GPU {idx} ({name}): total={total:.0f} MiB, used={used:.0f} MiB, free_ratio={free_ratio:.3f}" - ) - if free_ratio < min_free_ratio: - violations.append((idx, name, f"free_ratio={free_ratio:.3f} < {min_free_ratio:.3f}")) - - if violations: - details = "; ".join([f"GPU {i} ({n}): {msg}" for i, n, msg in violations]) - raise RuntimeError( - "GPU memory check failed: all GPUs must have >= " - f"{int(min_free_ratio*100)}% free. Violations: {details}" - ) - logger.info( - f"✓ GPU check passed: {len(lines)} GPUs, all >= {int(min_free_ratio*100)}% free memory" + parser.add_argument( + "--prefix", + type=str, + default="", + required=False, + help="Prefix for deepfinance service names" ) - - -def get_backbone_target(backbone): - """ - Determine the appropriate backbone target module based on the backbone name. - - Args: - backbone (str): The backbone name (e.g., "verl", "debug", "trinity") - - Returns: - str: The full module path for the specified backbone - """ - backbone_target = "ajet.backbone.main_verl" # Default to trinity - if backbone == "verl": - backbone_target = "ajet.backbone.main_verl" - if backbone == "debug": - backbone_target = "ajet.backbone.main_vllm" - if backbone == "trinity": - backbone_target = "ajet.backbone.main_trinity" - return backbone_target - - -def setup_environment_vars(args, exp_config, main_yaml_fp): - """ - Configure environment variables based on command line arguments. 
- - Args: - args: Command line arguments - exp_config: Experiment configuration dictionary - main_yaml_fp: Path to main YAML configuration file - - Returns: - dict: Configured environment variables dictionary - """ - env = os.environ.copy() - if args.debug: - env["RAY_DEBUG_POST_MORTEM"] = "1" - env["DEBUG_TAGS"] = args.debug - env["RAY_record_task_actor_creation_sites"] = "true" - # assert exp_config["ajet"]["rollout"]["max_env_worker"] <= 4, "parallel worker too many for debugging mode" # type: ignore - if exp_config["ajet"]["rollout"]["max_env_worker"] > 1: # type: ignore - exp_config["ajet"]["rollout"]["max_env_worker"] = 1 - logger.warning( - "For debugging mode, max_env_worker is set to 1 to facilitate debugging." - ) - logger.warning("Debug mode is ON") - else: - logger.warning("Debug mode is OFF") - # if args.conf: - # assert exp_config["ajet"]["rollout"]["max_env_worker"] > 4, "parallel worker too few" # type: ignore - if args.backbone == "trinity": - env["AJET_CONFIG_REDIRECT"] = main_yaml_fp # type: ignore - if args.backbone == "debug": - env["AJET_DEBUG"] = "1" # type: ignore - return env, exp_config + return parser.parse_args() def check_model_file_exists(exp_config): model_path = exp_config["ajet"]["model"]["path"] # if model_path has more than 2 '/', we consider it as a dir path if model_path.count("/") > 2: - assert os.path.exists( - model_path - ), f"Model path {model_path} does not exist. Please check your configuration." + assert os.path.exists(model_path), f"Model path {model_path} does not exist. Please check your configuration." + + +def start_swarm_server(env, config): + config = dict_to_namespace(config) + assert config.ajet.enable_swarm_mode, \ + "Please enable_swarm_mode in config to start swarm server." + assert config.ajet.enable_experimental_interchange_server, \ + "Please enable_experimental_interchange_server in config to start swarm server." + from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server + start_interchange_server(config, blocking=True, env=env) def main(): @@ -283,8 +188,12 @@ def main(): # switch backbone target backbone_target = get_backbone_target(args.backbone) + # read configuration from yaml exp_config = None exp_dir = args.exp_dir or "saved_experiments" + if args.swarm_server and (not args.conf): + args.conf = os.path.abspath(os.path.join(os.path.dirname(__file__), "default_config/ajet_ts_default.yaml")) + assert os.path.exists(args.conf), "Please provide a valid config file for swarm server mode." 
if args.conf: yaml_path = args.conf ( @@ -294,7 +203,13 @@ def main(): exp_config, ) = prepare_experiment_config(yaml_path, exp_dir, args.backbone) + # setup environment variables env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp) + + if args.swarm_server: + start_swarm_server(env, exp_config) + return + if args.with_ray: assert ( not args.with_ray_cluster diff --git a/ajet/schema/task.py b/ajet/schema/task.py index a20a4b59..1553e20d 100644 --- a/ajet/schema/task.py +++ b/ajet/schema/task.py @@ -8,11 +8,11 @@ class Task(BaseModel): - main_query: str = Field(default="") - init_messages: List[dict] = Field(default=[]) - task_id: str = Field(default="") - env_type: str = Field(default="") - metadata: dict = Field(default_factory=dict) + main_query: str = Field(default="", description="main query or instruction for the task, maybe absent if the task has valid init_messages.") + init_messages: List[dict] = Field(default=[], description="initial messages for the task, maybe absent if the task has valid main_query.") + task_id: str = Field(default="", description="same task_id mean same task, and of course, same GRPO group.") + env_type: str = Field(default="", description="valid when the task need to interact with a gym env.") + metadata: dict = Field(default_factory=dict, description="additional metadata for the task, e.g., reference answer for eval tasks.") """ diff --git a/ajet/task_reader/__init__.py b/ajet/task_reader/__init__.py index b431456f..e1a35b5d 100644 --- a/ajet/task_reader/__init__.py +++ b/ajet/task_reader/__init__.py @@ -1,8 +1,8 @@ -from typing import List - import datasets import numpy as np +from typing import List, List, Union +from datasets import Dataset from ajet.schema.task import Task from ajet.task_reader.data_generator_reader import DataGeneratorTaskReader from ajet.task_reader.env_service_reader import EnvServiceTaskReader @@ -10,6 +10,7 @@ from ajet.task_reader.jsonl_reader import JsonlTaskReader from ajet.task_reader.task_reader_base import BaseTaskReader from ajet.task_reader.tracing_reader import TracingReader +from typing import Generator class RandomDummyTaskReader(BaseTaskReader): @@ -44,6 +45,10 @@ def get_validation_tasks(self) -> List[Task]: return self._load_dataset_split("dataset_name", "split") +def list_to_generator(tasks: List[Task]) -> Generator: + for task in tasks: + yield task + class RouterTaskReader(BaseTaskReader): def __init__(self, reader_type, reader_config): super().__init__(None) @@ -78,33 +83,39 @@ def get_validation_tasks(self) -> List[Task]: np.random.shuffle(result) # type: ignore return result + def generate_training_tasks(self) -> Generator: + if hasattr(self.task_reader, "generate_training_tasks"): + result = self.task_reader.generate_training_tasks() # type: ignore + else: + result = list_to_generator(self.task_reader.get_training_tasks()) + return result -def task_to_standard_dataset(tasks: List[Task]) -> datasets.Dataset: + def generate_validation_tasks(self) -> Generator: + if hasattr(self.task_reader, "generate_validation_tasks"): + result = self.task_reader.generate_validation_tasks() # type: ignore + else: + result = list_to_generator(self.task_reader.get_validation_tasks()) + return result + + + +def task_to_standard_dataset(gen_tasks) -> Dataset: """ - Convert a list of Task objects to a standard Hugging Face Dataset. + Convert a potentially large/infinite generator of Task objects + to a streaming Hugging Face Dataset. Args: - tasks (List[Task]): List of Task objects. 
+ tasks: A generator or iterable producing Task objects. Returns: - datasets.Dataset: Hugging Face Dataset containing the tasks. + datasets.Dataset: A Hugging Face Dataset with streaming enabled. """ - data = { - "task_id": [], - "main_query": [], - "init_messages": [], - "env_type": [], - "metadata": [], - } + def gen(): + for task in gen_tasks(): + yield task.model_dump() - for task in tasks: - data["task_id"].append(task.task_id) - data["main_query"].append(task.main_query) - data["init_messages"].append(task.init_messages) - data["env_type"].append(task.env_type) - data["metadata"].append(task.metadata) + return Dataset.from_generator(gen) # type: ignore - return datasets.Dataset.from_dict(data) def dict_to_ajet_task(task_dict: dict) -> Task: diff --git a/ajet/task_reader/document_reader/doc_reader.py b/ajet/task_reader/document_reader/doc_reader.py index e73c3bdf..5083d33a 100644 --- a/ajet/task_reader/document_reader/doc_reader.py +++ b/ajet/task_reader/document_reader/doc_reader.py @@ -11,7 +11,7 @@ try: from unstructured.partition.auto import partition except Exception: - logger.warning("Cannot import dependency `unstructured`") + logger.info("`unstructured` is not installed.") from ajet.schema.document import Document from ajet.task_reader.document_reader.document_reader_base import ( diff --git a/ajet/task_reader/hf_dataset_reader.py b/ajet/task_reader/hf_dataset_reader.py index 381e48e2..33a269e2 100644 --- a/ajet/task_reader/hf_dataset_reader.py +++ b/ajet/task_reader/hf_dataset_reader.py @@ -1,8 +1,8 @@ -from typing import List import datasets from ajet.schema.task import Task +from typing import List, Generator from ajet.task_reader.task_reader_base import BaseTaskReader @@ -17,29 +17,38 @@ class HuggingFaceTaskReader(BaseTaskReader): def __init__(self, reader_config): super().__init__(reader_config) self.reader_config = reader_config + self.as_generator = False + self.dataset_name = self.reader_config.huggingface_dat_repo.dataset_path - def _load_dataset_split(self, dataset_name: str, split: str) -> List[Task]: + def _load_dataset_split(self, split: str): """ Load a dataset split from Hugging Face datasets. Args: - dataset_name: Name of the dataset in Hugging Face format (e.g., 'gsm8k') split: Name of the split to load (e.g., 'train', 'validation') Returns: - List[Task]: List of Task objects created from the dataset. + Generator: List of Task objects created from the dataset. 
""" try: - dataset = datasets.load_dataset(dataset_name, split=split) + if self.dataset_name.endswith(".parquet"): + # Load from local parquet file + dataset = datasets.load_dataset("parquet", data_files=self.dataset_name, split=split) + else: + # Load from Hugging Face hub + dataset = datasets.load_dataset(self.dataset_name, split=split) + # shuffle dataset + dataset = dataset.shuffle() except Exception as e: raise ValueError( - f"Failed to load dataset '{dataset_name}' with split '{split}': {str(e)}" + f"Failed to load dataset '{self.dataset_name}' with split '{split}': {str(e)}" ) - # if len(dataset) == 0: - # raise ValueError(f"No examples found in dataset '{dataset_name}' with split '{split}'") + if len(dataset) == 0: + raise ValueError(f"No examples found in dataset '{self.dataset_name}' with split '{split}'") + + self.as_generator = True - tasks = [] for idx, example in enumerate(dataset): # Create Task object task = Task( @@ -49,28 +58,32 @@ def _load_dataset_split(self, dataset_name: str, split: str) -> List[Task]: env_type="no_env", metadata=example, ) - tasks.append(task) + yield task - return tasks + return - def get_training_tasks(self) -> List[Task]: + def generate_training_tasks(self): """ Get training tasks from the Hugging Face dataset specified in the config. Returns: - List[Task]: List of training Task objects. + A generator of training Task objects. """ - dataset_name = self.reader_config.huggingface_dat_repo.dataset_path split = self.reader_config.huggingface_dat_repo.training_split - return self._load_dataset_split(dataset_name, split) + return self._load_dataset_split(split) - def get_validation_tasks(self) -> List[Task]: + def generate_validation_tasks(self): """ Get validation tasks from the Hugging Face dataset specified in the config. Returns: - List[Task]: List of validation Task objects. + A generator of validation Task objects. 
""" - dataset_name = self.reader_config.huggingface_dat_repo.dataset_path split = self.reader_config.huggingface_dat_repo.validation_split - return self._load_dataset_split(dataset_name, split) + return self._load_dataset_split(split) + + def get_training_tasks(self): + return list(self.generate_training_tasks()) + + def get_validation_tasks(self): + return list(self.generate_validation_tasks()) diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py index 7f35aa10..898b2a3c 100644 --- a/ajet/task_rollout/native_parallel_worker.py +++ b/ajet/task_rollout/native_parallel_worker.py @@ -2,12 +2,13 @@ import os import time -from concurrent.futures import Future, ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor, wait, ALL_COMPLETED, FIRST_COMPLETED from typing import Dict, List, Literal from urllib.parse import quote import numpy as np import torch +import threading from loguru import logger from tensordict import TensorDict from torch.nn.utils.rnn import pad_sequence @@ -15,10 +16,11 @@ from verl import DataProto from verl.utils.torch_functional import pad_sequence_to_length -from ajet.context_tracker.basic_tracker import BaseContextTracker from ajet.schema.task import Task from ajet.schema.trajectory import Sample from ajet.task_rollout.single_worker import BaseRolloutManager +from ajet.context_tracker.basic_tracker import BaseContextTracker +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status class DynamicRolloutManager(BaseRolloutManager): @@ -59,6 +61,9 @@ def step_status_printer(self, observation_window): if start == -1: print_buf += [f"[finished]:{count} threads"] print(f"Rollout progress ({token_gen_per_sec_str}): " + " // ".join(print_buf)) + # if "info" in observation_window: + # print_buf2 = "\t".join(observation_window["info"]) + # print(print_buf2) def rollout_static( self, @@ -139,7 +144,9 @@ def rollout( epoch: str, ) -> List[BaseContextTracker]: """Delegate to dynamic rollout when oversampling is enabled.""" - if ( + if self.config.ajet.enable_swarm_mode: + return self.rollout_swarm(tasks, mode, epoch) + elif ( mode == "sample" and (self.rollout_n != 1) and self.config.ajet.rollout.enable_oversample @@ -248,7 +255,7 @@ def rollout_dynamic( # noqa: C901 completed_task_futures = [f for f in task_future_array if f.done()] completed_results = [f.result() for f in completed_task_futures] completed_results = [ - tracker for tracker in completed_results if not tracker.discarded + tracker for tracker in completed_results if not tracker._discarded ] reward = [ tracker.reward_structure.performance_reward for tracker in completed_results @@ -299,7 +306,7 @@ def rollout_dynamic( # noqa: C901 ) time.sleep(5) - # We have enough number of samples, but we need to wait for all threads to finish, including discarded threads + # We have enough number of samples, but we need to wait for all threads to finish, including ._discarded threads tic = -1 while any(f.running() for task_future_array in futures for f in task_future_array): tic += 1 @@ -318,7 +325,7 @@ def rollout_dynamic( # noqa: C901 completed_task_futures = [f for f in task_future_array if f.done()] completed_results = [f.result() for f in completed_task_futures] completed_results = [ - tracker for tracker in completed_results if not tracker.discarded + tracker for tracker in completed_results if not tracker._discarded ] task_cmd_reward_array = [ tracker.reward_structure.performance_reward for tracker in 
completed_results @@ -402,7 +409,7 @@ def rollout_dynamic( # noqa: C901 completed_task_futures = [f for f in task_future_array if f.done()] completed_results = [f.result() for f in completed_task_futures] completed_results = [ - tracker for tracker in completed_results if not tracker.discarded + tracker for tracker in completed_results if not tracker._discarded ] # in-group success rate and reward task_cmd_reward_array = [ @@ -459,6 +466,140 @@ def rollout_dynamic( # noqa: C901 return tracker_array + + def rollout_swarm( # noqa: C901 + self, + tasks: List[Task], + mode: Literal["sample", "validate"], + epoch: str, + allow_sample_num_change=True, + allow_force_stop=True, + ) -> List[BaseContextTracker]: + """ + Build a pool of threads to run context trackers in parallel, + each thread re-spawn after complete, until reaching conditions to stop. + """ + + tracker_array: List[BaseContextTracker] = [] + assert mode != "validate" + rollout_n = self.rollout_n + n_batch_task = len(tasks) + n_task = min(len(tasks), self.max_parallel // rollout_n) + assert n_task > 0, f"n_task is not valid, n_task = min(len(tasks), self.max_parallel // rollout_n) = {n_task}" + self.current_token_count_time = time.time() + + # initialize observation window + observation_window: Dict[str, List[int | bool | str]] = { + "info": ["" for _ in range(n_task * rollout_n)], + "step": [0 for _ in range(n_task * rollout_n)], + "stop": [False for _ in range(n_task * rollout_n)], + "hard_stop": [False for _ in range(n_task * rollout_n)], + "token": [0 for _ in range(n_task * rollout_n)], + } + executor = ThreadPoolExecutor(max_workers=self.max_parallel) + futures: List[Future] = [] + completed_task_id_map_ct: Dict[str, List[BaseContextTracker]] = {} + executor_lock = threading.Lock() + + # submit initial tasks + dummy_task = Task(main_query="dummy task") + for task_batch_index in range(n_task): + for task_rollout_index in range(rollout_n): + task_thread_index = task_batch_index * rollout_n + task_rollout_index + future = executor.submit( + self.rollout_env_worker_loop, + task=dummy_task, + task_tag="", + mode=mode, + task_batch_index=task_batch_index, + task_thread_index=task_thread_index, + observation_window=observation_window, + completed_task_id_map_ct=completed_task_id_map_ct, + executor_lock=executor_lock, + ) + observation_window["info"][task_thread_index] = "1" + futures.append(future) + + def enough_sample_stop_condition(completed_task_id_map_ct) -> bool: + n = 0 + for ct_list in completed_task_id_map_ct.values(): + n += len(ct_list) + print(f"Current collected samples: {n}, target: {n_batch_task * rollout_n}") + return (n >= n_batch_task * rollout_n) + + def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool: + n_finish_roll_task = 0 + for ct_list in completed_task_id_map_ct.values(): + if len(ct_list) >= rollout_n: + n_finish_roll_task += 1 + return (n_finish_roll_task >= n_batch_task) + + def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool: + n_finish_roll_task = 0 + for ct_list in completed_task_id_map_ct.values(): + task_cmd_reward_array = [ + tracker.reward_structure.performance_reward for tracker in ct_list + ] + if (len(ct_list) >= rollout_n): + all_equal = all(x == task_cmd_reward_array[0] for x in task_cmd_reward_array) + if all_equal: continue + n_finish_roll_task += 1 + return (n_finish_roll_task >= n_batch_task) + + stop_condition = enough_sample_stop_condition + + def stop_all_threads_soft(): + for k in range(len(observation_window["stop"])): 
observation_window["stop"][k] = True + http_change_engine_status(self.config, "ENGINE.ROLLING_POST") + return + + def stop_all_threads_hard(): + for k in range(len(observation_window["hard_stop"])): observation_window["hard_stop"][k] = True + http_change_engine_status(self.config, "ENGINE.WEIGHT_SYNCING") + return + + cnt = 0 + while True: + cnt += 1 + time.sleep(2) + if (cnt % 5 == 0): + self.step_status_printer(observation_window) + meet_stop_condition_after_new_results = stop_condition(completed_task_id_map_ct) + if meet_stop_condition_after_new_results: + print("Sending soft stop signal to all threads...") + stop_all_threads_soft() + break + + # wait for all threads to complete + print('Finalizing all threads...') + executor.shutdown(wait=True) + + # stop all threads hard + print("Sending hard stop signal to all threads...") + stop_all_threads_hard() + + # build tracker_array + print('Collecting results...') + for ct_list in completed_task_id_map_ct.values(): + tracker_array.extend(ct_list) + + + # TODO: support multi-step reward + task_success_rate = np.mean( + [tracker.reward_structure.success_rate for tracker in tracker_array] + ) + task_scalar_reward = np.mean( + [tracker.reward_structure.final_scalar_reward for tracker in tracker_array] + ) + + for tracker in tracker_array: + tracker.current_batch_success_rate = float(task_success_rate) + tracker.current_batch_reward = float(task_scalar_reward) + + # return all trackers + return tracker_array + + class VerlRolloutManager(DynamicRolloutManager): """High-level manager orchestrating rollouts and batch conversion.""" @@ -478,6 +619,7 @@ def trajectories_to_samples(self, tracker_array: List[BaseContextTracker]) -> Li except Exception as e: raise e finally: + logger.bind(exception=True).exception("Error during tracker.tokenize()") # for debugging tracker.generate_log(global_step=self.current_global_steps) if os.environ.get("BEST_LOGGER_PATH", None) and os.environ.get( "AJET_DEBUG", None diff --git a/ajet/task_rollout/single_worker.py b/ajet/task_rollout/single_worker.py index d65a979a..103f9621 100644 --- a/ajet/task_rollout/single_worker.py +++ b/ajet/task_rollout/single_worker.py @@ -1,10 +1,13 @@ """Single worker primitives for environment rollouts.""" import uuid +import time +import threading from typing import Literal from loguru import logger from omegaconf import DictConfig +from typing import Dict, List, Literal from transformers.tokenization_utils import PreTrainedTokenizer from ajet.context_tracker.basic_tracker import BaseContextTracker @@ -12,7 +15,9 @@ from ajet.task_rollout.async_llm_bridge import AsyncLlmBridge from ajet.task_rollout.resource_keeper import ResourceKeeper from ajet.task_runner.general_runner import GeneralRunner +from ajet.task_runner.swarm_runner import SwarmRunner from ajet.utils.retry import retry_with_backoff +from ajet.utils.retry import SwarmReceiveAbortException from ajet.utils.sample import get_sample_params from ajet.utils.testing_utils import TestFailException, TestSuccessException @@ -59,6 +64,7 @@ def __init__( assert isinstance(self.pad_token_id, int), "pad_token_id must be an integer" self.current_token = 0 self.current_global_steps: int | str = "NA" + self.enable_swarm_mode = config.ajet.enable_swarm_mode self.async_llm_bridge = AsyncLlmBridge( config=config, async_rollout_manager=async_rollout_manager, @@ -110,12 +116,20 @@ def rollout_env_worker( with ResourceKeeper(workflow_task, config=self.config) as resource_keeper: try: workflow_task = resource_keeper.prepare() - agent_runner = 
GeneralRunner( - llm_inference_fn=llm_inference_fn, tokenizer=self.tokenizer, config=self.config - ) + if self.enable_swarm_mode: + agent_runner = SwarmRunner( + llm_inference_fn=llm_inference_fn, tokenizer=self.tokenizer, config=self.config + ) + else: + agent_runner = GeneralRunner( + llm_inference_fn=llm_inference_fn, tokenizer=self.tokenizer, config=self.config + ) tracker = agent_runner.execute( workflow_task=workflow_task, ) + except SwarmReceiveAbortException as exc: # noqa: BLE001 + # print('SwarmReceiveAbortException caught in rollout_env_worker') + return None # type: ignore except TestSuccessException as e: logger.success( f"env_worker.agent_flow completed with TestSuccessException: {e.args}" @@ -131,3 +145,60 @@ def rollout_env_worker( raise e return tracker + + + def rollout_env_worker_loop( + self, + task: Task, + task_batch_index: int, + task_tag: str, + mode: Literal["sample", "validate"], + task_thread_index: int, + observation_window: dict, + completed_task_id_map_ct: Dict[str, List[BaseContextTracker]], + executor_lock: threading.Lock, + **kwargs, + ): + try: + + cnt = 1 + + while True: + + if observation_window["stop"][task_thread_index]: # since we use multi-threading, the best way to communicate with main thread is through shared memory. + return + + observation_window["info"][task_thread_index] = str(cnt) # observe how many iterations have been done in the loop + + # Let's begin working on the task, the result `tracker` will contain everything: reward, llm calls, conversation history, etc. + # Later we will gather all trackers and do post-processing, generating samples for VeRL. + tracker = self.rollout_env_worker( + task=task, + task_batch_index=task_batch_index, + task_tag=task_tag, + mode=mode, + task_thread_index=task_thread_index, + observation_window=observation_window, + **kwargs, + ) + + # avoid write conflict + if tracker and tracker.reward_structure: + with executor_lock: + if tracker.task_id not in completed_task_id_map_ct: + completed_task_id_map_ct[tracker.task_id] = [tracker] + else: + completed_task_id_map_ct[tracker.task_id] += [tracker] + + cnt += 1 + + if observation_window["stop"][task_thread_index]: + return + else: + del tracker + + except Exception as e: + logger.exception( + f"encounter exception in env_worker_loop error={e.args}" + ) + raise e \ No newline at end of file diff --git a/ajet/task_runner/base_runner.py b/ajet/task_runner/base_runner.py index d8c15492..32f47fa1 100644 --- a/ajet/task_runner/base_runner.py +++ b/ajet/task_runner/base_runner.py @@ -11,6 +11,7 @@ from ajet.utils.async_utils import run_async_coroutine_with_timeout from ajet.utils.dynamic_import import dynamic_import from ajet.workflow import Workflow +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import is_episode_claimed gc_lock = Lock() @@ -48,18 +49,28 @@ def get_judge(self) -> BaseJudge: # type: ignore def runner_hooks(self, observation_window, task_thread_index, workflow_task): - def should_interrupt_fn() -> bool: - if (observation_window["stop"] is not None) and observation_window["stop"][ - task_thread_index - ]: # Check if the thread should stop (because other threads have completed, making this thread useless) + def should_interrupt_soft_fn() -> bool: + if (observation_window["stop"] is not None) and observation_window["stop"][task_thread_index]: # Check if the thread should stop (because other threads have completed, making this thread useless) return True return False + def should_interrupt_hard_fn() -> bool: + if 
(observation_window["hard_stop"] is not None) and observation_window["hard_stop"][task_thread_index]: # Check if the thread should stop (because other threads have completed, making this thread useless) + return True + if (observation_window["stop"] is not None) and observation_window["stop"][task_thread_index]: # check soft condition + # if soft condition met, check if episode is claimed + has_claimed = is_episode_claimed(self.config, workflow_task.episode_uuid) + if not has_claimed: + # if not claimed by now (ENGINE.ROLLING_POST), this episode will never be claimed again, so we can hard stop + return True + return False + def generated_token_callback_fn(token_array): observation_window["token"][task_thread_index] += len(token_array) return { - "should_interrupt_fn": should_interrupt_fn, + "should_interrupt_soft_fn": should_interrupt_soft_fn, + "should_interrupt_hard_fn": should_interrupt_hard_fn, "generated_token_callback_fn": generated_token_callback_fn, } diff --git a/ajet/task_runner/general_runner.py b/ajet/task_runner/general_runner.py index 88f9ab11..a3e3db92 100644 --- a/ajet/task_runner/general_runner.py +++ b/ajet/task_runner/general_runner.py @@ -1,6 +1,6 @@ -from ajet import AjetTuner -from ajet import WorkflowOutput +from ajet.tuner import AjetTuner +from ajet.schema.task import WorkflowOutput, WorkflowTask from ajet.context_tracker.multiagent_tracking import ( MultiAgentContextTracker, ) diff --git a/ajet/task_runner/swarm_runner.py b/ajet/task_runner/swarm_runner.py new file mode 100644 index 00000000..03d27c85 --- /dev/null +++ b/ajet/task_runner/swarm_runner.py @@ -0,0 +1,208 @@ + +import atexit +import json +import zmq +import os +from ajet.tuner import AjetTuner +from ajet.schema.task import WorkflowOutput +from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker +from ajet.context_tracker.basic_tracker import BaseContextTracker +from ajet.schema.task import WorkflowTask +from ajet.schema.trajectory import Reward +from ajet.task_runner.base_runner import BaseAgentRunner +from ajet.utils.retry import SwarmReceiveAbortException +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_register_episode, get_zmq_socket, is_episode_claimed +from loguru import logger +from ajet import Workflow +from typing import Callable + +# DEBUG = True +DEBUG = False + +context = zmq.Context() +atexit.register(context.term) + +class SwarmRunner(BaseAgentRunner): + + def register_episode_and_wait_output( + self, + episode_uuid: str, + openai_base_url: str, + openai_api_key: str, + context_tracker: BaseContextTracker, + tuner:AjetTuner, + should_exit_soft:Callable, + should_exit_hard:Callable + ) -> WorkflowOutput | None: + """Register the episode as ready in the Swarm data interchange center.""" + # parse episode_uuid, openai_base_url, openai_api_key + zmq_listen_result_addr, ipc_path = get_zmq_socket(self.config, episode_uuid, tag="workflow") + success = http_register_episode( + self.config, + episode_uuid=episode_uuid, + openai_base_url=openai_base_url, + openai_api_key=openai_api_key, + zmq_listen_result_addr=zmq_listen_result_addr, + should_exit_soft=should_exit_soft, + ) + if not success: + return None # type: ignore + + if DEBUG: logger.info(f"zmq_listen_result_addr: {zmq_listen_result_addr}") + + # begin wait for result + zmq_socket = zmq.Context().socket(zmq.REP) + zmq_socket.bind(zmq_listen_result_addr) + zmq_socket.setsockopt(zmq.RCVTIMEO, 1*1000) # 1 second timeout for REP + + speicial_messages = [ + "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER", 
+ "RUNNER.SPECIAL.ABORT" + ] + + try: + + while True: + # : + # : ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py + # : socket.send_string(workflow_output.model_dump_json()) + # : workflow_output: WorkflowOutput + # : + # : ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py + # : socket.send_string("RUNNER.SPECIAL.RESET_CONTEXT_TRACKER") + # : "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER" + try: + message = zmq_socket.recv_string() + except zmq.Again as e: + if should_exit_hard(): + # logger.warning(f'{episode_uuid} Exiting workflow due to should_exit_hard signal.') + context_tracker.reset() + raise SwarmReceiveAbortException(f"Episode {episode_uuid} aborted due to system exit.") + else: + continue + # process messages + if message not in speicial_messages: + zmq_socket.send_string("ack") + break + elif message == "RUNNER.SPECIAL.RESET_CONTEXT_TRACKER": + logger.warning(f"Received reset command for episode {episode_uuid}.") + context_tracker.reset() + zmq_socket.send_string("ack") + continue + elif message == "RUNNER.SPECIAL.ABORT": + logger.warning(f"Received abort command for episode {episode_uuid}.") + context_tracker.reset() + zmq_socket.send_string("ack") + return None + else: + raise RuntimeError(f"Unknown special message received: {message}") + + final_output = WorkflowOutput(**json.loads(message)) + reward = final_output.reward + logger.success(f"Received workflow output for episode {episode_uuid} (Reward: {reward})") + + except Exception as exc: + raise exc + + finally: + tuner.terminate_episode() # this is very important to avoid resource leak + zmq_socket.close() + if ipc_path and os.path.exists(ipc_path): os.remove(ipc_path) + + return final_output + + + def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: + + observation_window = workflow_task.observation_window + task_thread_index = workflow_task.task_thread_index + + hooks = self.runner_hooks( + observation_window=observation_window, + task_thread_index=task_thread_index, + workflow_task=workflow_task, + ) + + should_exit_soft = hooks['should_interrupt_soft_fn'] + should_exit_hard = hooks['should_interrupt_hard_fn'] + + if should_exit_soft() or should_exit_hard(): + # print(f'Exiting workflow worker due to interrupt signal for episode {workflow_task.episode_uuid}.') + raise SwarmReceiveAbortException(f"Episode {workflow_task.episode_uuid} aborted due to interrupt signal.") + + # context tracker will trace and gather everything we need for training + context_tracker = MultiAgentContextTracker( + llm_inference_fn=self.llm_inference_fn, + tokenizer=self.tokenizer, + config=self.config, + workflow_task = workflow_task, + **hooks, + ) + # tuner will handle the communication and provide `baseurl_apikey` + tuner = AjetTuner( + context_tracker=context_tracker, + llm_inference_fn=self.llm_inference_fn, + workflow_cls=Workflow, + config=self.config, + ) + + # from tuner, we get base_url and api_key + baseurl_apikey = tuner.as_oai_baseurl_apikey() + + base_url = baseurl_apikey.base_url + api_key = baseurl_apikey.api_key + + # wait for remote client to return workflow output + workflow_output: WorkflowOutput | None = self.register_episode_and_wait_output( + episode_uuid=context_tracker.episode_uuid, + openai_base_url=base_url, + openai_api_key=api_key, + context_tracker=context_tracker, + tuner=tuner, + should_exit_soft=should_exit_soft, + should_exit_hard=should_exit_hard, + ) + if not workflow_output: + return None # type: ignore + + # the most important thing is to fix task_id to client task_id, set 
task_id to workflow_task and context_tracker task_id + assert "task_id" in workflow_output.metadata, "workflow_output.metadata must contain task_id" + task_id = workflow_output.metadata.get("task_id", "") + workflow_task.task_id = task_id + context_tracker.task_id = task_id + + # process reward + if workflow_output.reward is not None: + raw_reward, is_success = ( + workflow_output.reward, + workflow_output.is_success, + ) + else: + raise ValueError("workflow_output.reward is None in SwarmRunner, this is currently not allowed.") + + # release gym_env + workflow_task.gym_env = None # clear gym env client reference to avoid serialization issue + + # check reward + assert not isinstance(raw_reward, list), "AgentJet will support step reward in future versions." + + # register reward + # TODO: support multi-step reward + reward = Reward( + raw_reward=raw_reward, + raw_step_reward=None, # "AgentJet will support step reward in future versions." + success_rate=1.0 if is_success else 0.0, + madness=0, + description="", + ) + # process reward + context_tracker.process_reward(reward) + # generate token before merging + context_tracker.group_merge() + # after merging, process and align reward again + context_tracker.process_reward(reward) + # mark the thread as ended + observation_window["step"][task_thread_index] = -1 + tuner.terminate_episode() + context_tracker.log_metrics = workflow_output.log_metrics + return context_tracker diff --git a/ajet/tuner.py b/ajet/tuner.py index aacc3ab9..f8be6ab4 100644 --- a/ajet/tuner.py +++ b/ajet/tuner.py @@ -1,9 +1,6 @@ from typing import TYPE_CHECKING, Callable, Union, Type -from ajet.context_tracker.multiagent_tracking import ( - MultiAgentContextTracker, -) - +from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker from ajet.tuner_lib.weight_tuner import AgentScopeModelTuner from ajet.tuner_lib.weight_tuner import OpenaiClientModelTuner from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiClientBaseUrlTuner @@ -146,6 +143,7 @@ def _register(self, target_name: str, agent_name: str, explicit_tuner: TunerType self.target2proxy_registry[target_name][agent_name] = explicit_tuner return explicit_tuner + def _is_target_trainable(self, target_name) -> bool: """Determine whether user have used `trainable_targets` to explicitly control training targets. 
""" diff --git a/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py b/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py index 90c2cc72..3edf46c8 100644 --- a/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py +++ b/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py @@ -18,6 +18,18 @@ class MockAsyncChat(AsyncChat): def completions(self) -> MockAsyncCompletions: # type: ignore return MockAsyncCompletions(self._client) +class OpenaiBaseUrlAndApiKey(BaseModel): + """ At this layer, we will determine which model to use: + - training model + - debug model assigned by user, used when this target is not being trained + """ + + base_url: str = Field(default="http://localhost:27788/v1", description="The base URL for the Ajet's fake OpenAI API") + api_key: str = Field(default="invalid_apikey", description="The Ajet's fake key, which is not a real key, it is a encoded string contain episode_uuid and other stuff.") + model: str = Field(default="reserved_field", description="reserved field.") + episode_uuid: str = Field(default="episode_id", description="reserved field.") + + class OpenaiClientBaseUrlTuner(BaseModel): """ At this layer, we will determine which model to use: - training model @@ -40,6 +52,9 @@ def __init__( ): port = os.getenv("AJET_DAT_INTERCHANGE_PORT") + if config.ajet.interchange_server.interchange_server_port != 'auto': + port = str(int(config.ajet.interchange_server.interchange_server_port)) + assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set" master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py index 720c6c77..dd8b191f 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py @@ -14,8 +14,7 @@ from openai.types.chat.chat_completion import ChatCompletion from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import InterchangeCompletionRequest, API_KEY_PREFIX from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor -from ajet.utils.networking import find_free_port - +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import get_zmq_socket, is_episode_claimed context = zmq.Context() atexit.register(context.term) @@ -68,17 +67,11 @@ def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker self.llm_inference_fn = llm_inference_fn self.config = config self._should_terminate = False - + self.episode_contect_address, ipc_path = get_zmq_socket(config, episode_uuid, tag="llm") + self.ipc_path = ipc_path self.interchange_method = config.ajet.interchange_server.interchange_method - if self.interchange_method == 'tcp': - master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") - self.episode_contect_address = f"tcp://{master_node_ip}:{find_free_port()}" - elif self.interchange_method == 'ipc': - self.ipc_path = f"/tmp/ajet/{self.episode_uuid}.sock" - self.episode_contect_address = f"ipc://{self.ipc_path}" self.max_inference_tracker_threads = config.ajet.interchange_server.max_inference_tracker_threads - async def llm_infer( self, req: ChatCompletionRequest, @@ -110,18 +103,33 @@ async def llm_infer( @property - def should_terminate(self) -> bool: - return self._should_terminate + def should_soft_terminate(self) -> bool: + if self._should_terminate: + return True + return self.context_tracker.should_interrupt_soft_fn() + + 
@property + def should_hard_terminate(self) -> bool: + if self._should_terminate: + return True + if not self.config.ajet.enable_swarm_mode: + return self.should_soft_terminate + else: + return self.context_tracker.should_interrupt_hard_fn() + def begin_service(self): """ Starts the zmq communication loop. """ + if self.should_soft_terminate or self.should_hard_terminate: + return self.episode_contect_address + if DEBUG: logger.info(f"[client] {self.episode_uuid} | Starting InterchangeClient service loop...") self.socket = context.socket(zmq.REP) self.socket.bind(f"{self.episode_contect_address}") - self.socket.setsockopt(zmq.RCVTIMEO, 3*1000) # 3 second timeout for REP + self.socket.setsockopt(zmq.RCVTIMEO, 1*1000) # 3 second timeout for REP self.executor = SharedInterchangeThreadExecutor(self.max_inference_tracker_threads).get_shared_executor() if DEBUG: logger.info(f"[client] {self.episode_uuid} | Submitting _begin_service_threading to executor...") @@ -129,10 +137,15 @@ def begin_service(self): # wait till service begin running time.sleep(0.5) - w_time = 1 + wait_time = 1 while future._state == 'PENDING': - time.sleep(min(w_time * 2, 10)) - w_time += 1 + if self.should_soft_terminate or self.should_hard_terminate: + future.cancel() + self.socket.close() + if os.path.exists(self.ipc_path): os.remove(self.ipc_path) + return self.episode_contect_address + time.sleep(min(wait_time * 2, 10)) + wait_time += 1 if DEBUG: logger.info(f"[client] {self.episode_uuid} | Future ready...") return self.episode_contect_address @@ -146,19 +159,20 @@ def _begin_service_threading(self): if DEBUG: logger.info(f"[client] {self.episode_uuid} | Starting ZMQ socket bind complete") try: - while not self.should_terminate: + while not self.should_hard_terminate: # listen for next request from remote try: - if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() has begun") + # if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() has begun (should_terminate {self.should_terminate})") message = self.socket.recv_string() - if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() is done") + # if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() is done") except zmq.Again as e: - if self.should_terminate: + if self.should_hard_terminate: + # abort_episode() if DEBUG: logger.info(f"[client] {self.episode_uuid} | episode over") break timepassed = time.time() - begin_time - if timepassed > 60: - logger.warning(f"[client] {self.episode_uuid} | Still waiting for first message... (time passed {timepassed}) for episode_uuid:{self.episode_uuid}...") + if timepassed > 100: + if DEBUG: logger.warning(f"[client] {self.episode_uuid} | Still waiting for first message... 
(time passed {timepassed}) for episode_uuid:{self.episode_uuid}...") continue # parse the incoming request @@ -199,3 +213,4 @@ def _begin_service_threading(self): if os.path.exists(self.ipc_path): os.remove(self.ipc_path) if DEBUG: logger.info(f"[client] {self.episode_uuid} | IPC socket file {self.ipc_path} removed.") + diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py index 089d11eb..43ee39da 100644 --- a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py @@ -26,12 +26,16 @@ from pydantic import BaseModel from fastapi import FastAPI, Header, HTTPException, Request from contextlib import asynccontextmanager -from multiprocessing import Process +from multiprocessing import Manager, Process from concurrent.futures import ThreadPoolExecutor +from typing import Coroutine, Optional, Tuple from vllm.entrypoints.openai.protocol import ChatCompletionRequest from openai.types.chat.chat_completion import ChatCompletion +from ajet.utils.networking import find_free_port, get_host_ip +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import EpisodeStatus + API_KEY_PREFIX = "sk-ajet-" class InterchangeCompletionRequest(BaseModel): @@ -53,12 +57,19 @@ class HealthCheckRequest(BaseModel): DEBUG = False # DEBUG = True - context = zmq.Context() atexit.register(context.term) -def get_app(max_fastapi_threads: int = 512) -> FastAPI: + + + + + + + +def get_app(max_fastapi_threads: int = 512, enable_swarm_mode=False, shared_mem_dict=None, shared_mem_dict_lock=None) -> Tuple[FastAPI, Optional[Coroutine]]: + @asynccontextmanager async def lifespan(app: FastAPI): @@ -70,28 +81,35 @@ async def lifespan(app: FastAPI): SERVER_SHUTDOWN_EVENT.set() app.state.executor.shutdown(wait=False, cancel_futures=True) + app = FastAPI(title="AJet Interchange Endpoint", lifespan=lifespan) - def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletionRequest, episode_uuid, timeline_uuid, client_offline: threading.Event): + def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletionRequest, episode_uuid): """ run this in thread to avoid blocking main event loop """ if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | Received new chat completion request (inside thread)") socket = context.socket(zmq.REQ) - socket.setsockopt(zmq.RCVTIMEO, 60*1000) # 1 minute recv timeout + socket.setsockopt(zmq.RCVTIMEO, 6*1000) # 6 second recv timeout socket.connect(f"{episode_address}") if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | connect done") socket.send_string(int_req.model_dump_json()) if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | send_string") result_str = "" - for _ in range(5): # max 5 minutes wait + for _ in range(50): # max 5 minutes wait try: if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") result_str = socket.recv_string() break except zmq.Again as e: + # check whether server is still in rolling status + if enable_swarm_mode: + assert shared_mem_dict is not None + if shared_mem_dict['engine_status'] not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + raise HTTPException(status_code=404, detail="The server is not in ENGINE.ROLLING status, cannot accept new requests.") + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string timeout, retrying.") continue @@ -114,6 +132,8 @@ async def chat_completions(request: 
Request, authorization: str = Header(None)): OpenAI-compatible chat completions endpoint. Receives ChatCompletionRequest and returns ChatCompletion. """ + if DEBUG: logger.info("Received /v1/chat/completions request") + # Parse authorization header (base64 encoded JSON) if not authorization: return HTTPException(status_code=401, detail="Missing authorization header") @@ -139,9 +159,29 @@ async def chat_completions(request: Request, authorization: str = Header(None)): new_req = ChatCompletionRequest.model_validate(body) if new_req.stream: return HTTPException(status_code=400, detail="Streaming responses not supported in current AgentJet version, please set `stream=false` for now.") + # Create timeline UUID timeline_uuid = uuid.uuid4().hex + # enable_swarm_mode + if enable_swarm_mode: + from ajet.tuner_lib.weight_tuner.experimental.as_swarm_server import ep_key + assert shared_mem_dict is not None + assert shared_mem_dict_lock is not None + + if shared_mem_dict['engine_status'] not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + logger.error(f"The server is not in ENGINE.ROLLING status (current status: [{shared_mem_dict['engine_status']}]), cannot accept new requests.") + raise HTTPException(status_code=404, detail="The server is not in ENGINE.ROLLING status, cannot accept new requests.") + + if ep_key(episode_uuid) not in shared_mem_dict: + raise HTTPException(status_code=404, detail=f"Episode {episode_uuid} not found.") + + # update activate timestamp + with shared_mem_dict_lock: + es:EpisodeStatus = shared_mem_dict[ep_key(episode_uuid)] + es.latest_activity_timestamp = time.time() + shared_mem_dict[ep_key(episode_uuid)] = es + # Add to received queue int_req = InterchangeCompletionRequest( completion_request = new_req, @@ -151,59 +191,104 @@ async def chat_completions(request: Request, authorization: str = Header(None)): timeline_uuid = timeline_uuid, ) if DEBUG: logger.info(f"episode_uuid: {episode_uuid} | Received new chat completion request (outside thread)") - client_offline = threading.Event() - try: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, episode_address, int_req, episode_uuid, timeline_uuid, client_offline) - finally: - client_offline.set() + loop = asyncio.get_running_loop() + return await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, episode_address, int_req, episode_uuid) + + + if enable_swarm_mode: + from ajet.tuner_lib.weight_tuner.experimental.as_swarm_server import register_enable_swarm_mode_routes + assert shared_mem_dict is not None, "shared_mem_dict must not be None when enable_swarm_mode is True." + assert shared_mem_dict_lock is not None, "shared_mem_dict_lock must not be None when enable_swarm_mode is True." 
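+        # `register_enable_swarm_mode_routes` attaches the swarm endpoints to this
+        # app using the shared Manager dict and lock and hands back an optional
+        # background coroutine, which `serve_with_monitor` later schedules next to
+        # the uvicorn server.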
+ app, additional_coro = register_enable_swarm_mode_routes(app, zmq_context=context, shared_mem_dict=shared_mem_dict, shared_mem_dict_lock=shared_mem_dict_lock) + else: + additional_coro = None + + + return app, additional_coro + + + + + + + + + + - @app.post("/reset") - async def reset(): - return {"status": "reset_complete"} - return app class InterchangeServer(Process): - def __init__(self, experiment_dir: str, port: int, num_fastapi_process: int = 2, max_fastapi_threads: int = 512): + def __init__(self, experiment_dir: str, port: int, num_fastapi_process: int = 2, max_fastapi_threads: int = 512, enable_swarm_mode=False): super().__init__() self.experiment_dir = experiment_dir self.port = port self.num_fastapi_process = num_fastapi_process self.max_fastapi_threads = max_fastapi_threads + self.enable_swarm_mode = enable_swarm_mode def run(self): logger.info(f"Starting Interchange Server on port {self.port} with {self.num_fastapi_process} processes and {self.max_fastapi_threads} threads per process.") - app = get_app(self.max_fastapi_threads) - async def serve_with_monitor(): + + if self.enable_swarm_mode: + manager = Manager() + shared_mem_dict = manager.dict() + shared_mem_dict_lock = manager.Lock() + else: + shared_mem_dict = None + shared_mem_dict_lock = None + + app, additional_coro = get_app(self.max_fastapi_threads, self.enable_swarm_mode, shared_mem_dict, shared_mem_dict_lock) + + async def serve_with_monitor(additional_coro): # Start the server config = uvicorn.Config( app=app, host="0.0.0.0", port=self.port, - log_level="error", + log_level="info", workers=self.num_fastapi_process ) server = uvicorn.Server(config) - await server.serve() + if additional_coro: + coro_task_1 = asyncio.create_task(additional_coro) + coro_task_2 = asyncio.create_task(server.serve()) + await asyncio.gather(coro_task_1, coro_task_2) + else: + await server.serve() try: - asyncio.run(serve_with_monitor()) + asyncio.run(serve_with_monitor(additional_coro)) except KeyboardInterrupt as e: SERVER_SHUTDOWN_EVENT.set() raise e + + + + + + + + + + + + # Convenience function for quick server startup -def start_interchange_server(config) -> int: +def start_interchange_server(config, blocking=False, env={}) -> int: + # Read config + already_started = config.ajet.interchange_server.already_started experiment_dir = config.ajet.experiment_dir num_fastapi_process = config.ajet.interchange_server.num_fastapi_process max_fastapi_threads = config.ajet.interchange_server.max_fastapi_threads + enable_swarm_mode = config.ajet.enable_swarm_mode + # Find a free port if not specified or invalid port = int(os.environ.get("AJET_DAT_INTERCHANGE_PORT", -1)) - if config.ajet.interchange_server.interchange_server_port != 'auto': port = int(config.ajet.interchange_server.interchange_server_port) - + os.environ["AJET_DAT_INTERCHANGE_PORT"] = str(port) if port <= 0: import socket with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -211,14 +296,33 @@ def start_interchange_server(config) -> int: port = s.getsockname()[1] os.environ["AJET_DAT_INTERCHANGE_PORT"] = str(port) - interchange_server = InterchangeServer(experiment_dir, port, num_fastapi_process, max_fastapi_threads) - interchange_server.start() + # init interchage server sub-process + if not already_started: + # apply env vars + os.environ.update(env) + # start interchange server + interchange_server = InterchangeServer( + experiment_dir, + port, + num_fastapi_process, + max_fastapi_threads, + enable_swarm_mode, + ) + interchange_server.start() + else: + 
interchange_server = None # Wait for server to be ready - health_url = f"http://localhost:{port}/health" + health_url = f"http://127.0.0.1:{port}/health" + localhost_url = f"http://127.0.0.1:{port}" + master_node_ip = get_host_ip(os.environ.get("NETWORK_INTERFACE", None)) + host_url = f"http://{master_node_ip}:{port}" + os.environ["MASTER_NODE_IP"] = str(master_node_ip) + + # polling for server ready start_time = time.time() while True: - if interchange_server.exitcode is not None: + if interchange_server and interchange_server.exitcode is not None: logger.error(f"Interchange server subprocess failed to start. Return code: {interchange_server.exitcode}") raise RuntimeError("Interchange server subprocess failed to start.") if time.time() - start_time > 30: @@ -234,8 +338,29 @@ def start_interchange_server(config) -> int: time.sleep(1) # register a termination handler - if DEBUG: logger.info(f"Interchange server subprocess started on port {port} (pid: {interchange_server.pid})") - atexit.register(lambda: interchange_server.terminate()) - - # return port - return port + if interchange_server: + if DEBUG: logger.info(f"Interchange server subprocess started on port {port} (pid: {interchange_server.pid})") + atexit.register(lambda: interchange_server.terminate()) + + if not blocking: + # return port + return port + else: + logger.success(f"Interchange server is running in blocking mode on:\n------\n" + f"URL 1: {localhost_url}\n------\n" + f"URL 2: {host_url}\n------\n" + f"Press Ctrl+C to stop.") + try: + if interchange_server: + interchange_server.join() + except KeyboardInterrupt: + logger.info("Shutting down interchange server...") + try: httpx.post(f"http://127.0.0.1:{port}/stop_engine", timeout=8).status_code + except Exception: pass + + if interchange_server: + interchange_server.terminate() + if enable_swarm_mode: + from ajet.tuner_lib.weight_tuner.experimental.as_swarm_server import kill_process_tree + kill_process_tree(None, None) + return -1 \ No newline at end of file diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py new file mode 100644 index 00000000..09b05f88 --- /dev/null +++ b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py @@ -0,0 +1,346 @@ +import uuid +import time +import httpx +import yaml +from typing import List, Tuple +from loguru import logger +from ajet.schema.task import WorkflowOutput, Task +from ajet.copilot.job import AgentJetJob +from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import ( + SyncTrainConfigRequest, + ClaimEpisodeRequest, + ClaimEpisodeResponse, + CanContinueEpisodeRequest, + CanContinueEpisodeResponse, + EndEpisodeRequest, + EndEpisodeResponse, + EpisodeStatus, + EpisodeBufferResponse, +) + + +class SwarmClient(object): + + def __init__(self, server_url: str): + self.server_url = server_url + self.client_uuid = str(uuid.uuid4()) + self.previous_warning_time = 0 + self.record_episode_expire_time = {} + + + def begin_episode(self, allow_discard_timeout=60, episode_type="train") -> Tuple[str, OpenaiBaseUrlAndApiKey]: + """ + Block until an episode is claimed. 
+ Return (episode_uuid, openai_base_url, openai_api_key) + """ + while True: + try: + req_obj = ClaimEpisodeRequest( + client_uuid=self.client_uuid, + episode_type=episode_type, + allow_discard_timeout=allow_discard_timeout, + ) + resp = httpx.post( + f"{self.server_url}/claim_episode", + json=req_obj.model_dump(), + timeout=30 + ) + resp.raise_for_status() + data = ClaimEpisodeResponse.model_validate(resp.json()) + episode_uuid = data.episode_uuid + self.record_episode_expire_time[episode_uuid] = time.time() + allow_discard_timeout + + if data.success: + episode_uuid = data.episode_uuid + openai_base_url = data.openai_base_url + openai_api_key = data.openai_api_key + logger.info(f"Claimed episode {episode_uuid}") + return episode_uuid, OpenaiBaseUrlAndApiKey( + base_url=openai_base_url, + api_key=openai_api_key, + episode_uuid=episode_uuid + ) + else: + need_wait_scenarios =[ + "Engine is syncing weights", + "Engine is in post-rolling phase", + "No available episodes to claim.", + ] + if any(scenario in data.fail_cause for scenario in need_wait_scenarios): + if time.time() - self.previous_warning_time > 60: + logger.info(f"{data.fail_cause}. Retrying in 30s...") + self.previous_warning_time = time.time() + time.sleep(30) + else: + logger.warning(f"Failed to claim episode: {data.fail_cause}. Retrying in 5s...") + time.sleep(5) + except Exception as e: + logger.error(f"Error claiming episode: {e}. Retrying in 5s...") + time.sleep(5) + + def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOutput): + if not episode_uuid: + logger.error("No episode to end.") + return + + remain_time = self.record_episode_expire_time.get(episode_uuid, 0) - time.time() + if remain_time < 0: + logger.warning(f"Episode {episode_uuid} has expired (expired {remain_time} seconds ago). Please use a larger `allow_discard_timeout` when `begin_episode`. Skipping end_episode.") + return + + try: + task_id = task.task_id + workflow_output.metadata["task_id"] = task_id + req_obj = EndEpisodeRequest( + client_uuid=self.client_uuid, + episode_uuid=episode_uuid, + workflow_output=workflow_output, + task_id=task_id + ) + + resp = httpx.post( + f"{self.server_url}/end_episode", + json=req_obj.model_dump(), + timeout=30 + ) + resp.raise_for_status() + data = EndEpisodeResponse.model_validate(resp.json()) + + if data.success: + logger.info(f"Ended episode {episode_uuid}") + else: + logger.error(f"Failed to end episode {episode_uuid}") + + except Exception as e: + logger.error(f"Error ending episode: {e}") + + def abort_episode(self, episode_uuid: str): + if not episode_uuid: + logger.error("No episode to end.") + return + + try: + workflow_output = WorkflowOutput(reward=0.0, metadata={}) + req_obj = EndEpisodeRequest( + client_uuid=self.client_uuid, + episode_uuid=episode_uuid, + workflow_output=workflow_output, + task_id="" + ) + + resp = httpx.post( + f"{self.server_url}/abort_episode", + json=req_obj.model_dump(), + timeout=30 + ) + resp.raise_for_status() + data = EndEpisodeResponse.model_validate(resp.json()) + + if data.success: + logger.info(f"Aborted episode {episode_uuid}") + else: + logger.error(f"Failed to end episode {episode_uuid}") + + except Exception as e: + logger.error(f"Error ending episode: {e}") + + def sync_train_config(self, agent_jet_job: AgentJetJob): + """ + Sync training configuration to the Swarm server. + This sends the AgentJetJob config as YAML to the remote server. 
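+
+        A minimal usage sketch (hedged: the server URL is illustrative and `job` is
+        assumed to be an AgentJetJob built by the caller):
+
+            client = SwarmClient("http://127.0.0.1:27788")
+            client.sync_train_config(job)  # only allowed while the engine is ENGINE.OFFLINE
+            client.start_engine()          # blocks until the status reaches ENGINE.ROLLING
+            # or, equivalently: client.auto_sync_train_config_and_start_engine(job)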
+ """ + # try get init status + current_status = self.get_engine_status() + if current_status != "ENGINE.OFFLINE": + raise RuntimeError(f"Cannot sync train config when engine is NOT ENGINE.OFFLINE. (current status: {current_status})") + + try: + config_dict = agent_jet_job.config.to_dict() + yaml_str = yaml.safe_dump(config_dict, sort_keys=False) + + req_obj = SyncTrainConfigRequest(yaml_as_string=yaml_str) + + resp = httpx.post( + f"{self.server_url}/sync_train_config", + json=req_obj.model_dump(), + timeout=30 + ) + resp.raise_for_status() + logger.info("Synced train config to Swarm server") + except Exception as e: + logger.error(f"Error syncing train config: {e}") + raise + + def start_engine(self): + """ + Start the training engine on the Swarm server. + This triggers the server to begin the training process. + Polls until engine status is "ENGINE.ROLLING". + """ + # try get init status + current_status = self.get_engine_status() + if current_status != "ENGINE.OFFLINE": + raise RuntimeError(f"Cannot start engine when engine is NOT ENGINE.OFFLINE. (current status: {current_status})") + + # Send start engine request + try: + resp = httpx.post( + f"{self.server_url}/start_engine", + json={}, + timeout=600 + ) + resp.raise_for_status() + result = resp.json() + if result.get("success"): + logger.info("Successfully started training engine on Swarm server") + else: + logger.error("Failed to start training engine") + raise RuntimeError("Failed to start training engine") + except Exception as e: + logger.error(f"Error starting engine: {e}") + raise + + # Poll until engine status is "ENGINE.ROLLING" + self._wait_until_status_change_to(desired_status="ENGINE.ROLLING") + logger.success("Training engine is now ROLLING and ready.") + + def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING"): + """ + Poll engine status until it reaches desired_status. + Reports status every 5 seconds while waiting. 
+ """ + logger.info(f"Polling engine status until {desired_status}...") + last_report_time = time.time() + init_poll_time = last_report_time + + while True: + try: + current_status = self.get_engine_status() + current_time = time.time() + + # Report status every 5 seconds + if current_time - last_report_time >= 10: + logger.info(f"Current engine status (already waited {current_time - init_poll_time:.1f}s): {current_status}") + last_report_time = current_time + + # Check if engine has reached the desired status + if current_status == desired_status: + logger.info(f"Engine status is {desired_status}.") + break + + # Wait a bit before next poll + time.sleep(5) + + except Exception as e: + logger.error(f"Error polling engine status: {e}") + time.sleep(5) + + def get_engine_status(self) -> str: + try: + resp = httpx.get( + f"{self.server_url}/get_engine_status", + timeout=10 + ) + resp.raise_for_status() + result = resp.json().get("engine_status", "unknown") + if result == "unknown": + logger.warning("get_engine_status: " + resp.json()) + return result + except Exception as e: + logger.error(f"Error getting engine status: {e}") + return "ENGINE.CANNOT_CONNECT" + + def can_continue_episode(self, episode_uuid: str) -> bool: + if not episode_uuid: + return False + + try: + req_obj = CanContinueEpisodeRequest( + client_uuid=self.client_uuid, + episode_uuid=episode_uuid + ) + resp = httpx.post( + f"{self.server_url}/can_continue_episode", + json=req_obj.model_dump(), + timeout=10 + ) + resp.raise_for_status() + data = CanContinueEpisodeResponse.model_validate(resp.json()) + return data.can_continue + except Exception as e: + logger.error(f"Error checking can_continue_episode: {e}") + return False + + def get_episode_buffer(self) -> List[EpisodeStatus]: + try: + resp = httpx.post( + f"{self.server_url}/get_episode_buffer", + json={}, + timeout=10 + ) + resp.raise_for_status() + data = EpisodeBufferResponse.model_validate(resp.json()) + return data.buffer + except Exception as e: + logger.error(f"Error getting episode buffer: {e}") + return [] + + def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob, force_restart=False): + """ + Automatically sync training configuration and start the engine if needed. + This checks the current engine status and performs actions accordingly. + + Args: + - agent_jet_job: The AgentJetJob configuration to sync. + - force_restart: If True, forces a restart of the engine. + """ + if force_restart: + logger.warning("Force restarting the engine...") + self.stop_engine() + time.sleep(8) + current_status = self.get_engine_status() + if current_status == "ENGINE.OFFLINE": + logger.info("Engine is OFFLINE. Syncing train config and starting engine...") + self.sync_train_config(agent_jet_job) + self.start_engine() + elif current_status == "ENGINE.ROLLING": + logger.info("Engine is already ROLLING. No action needed.") + elif current_status == "ENGINE.ROLLING_POST": + logger.info("Engine is already ROLLING. No action needed.") + elif current_status == "ENGINE.BOOTING": + logger.info("Engine is BOOTING. Waiting until it becomes ROLLING...") + self._wait_until_status_change_to(desired_status="ENGINE.ROLLING") + logger.success("Training engine is now ROLLING and ready.") + elif current_status == "ENGINE.CANNOT_CONNECT": + logger.error("Cannot connect to the engine. 
Please check the network.") + self._wait_until_status_change_to(desired_status="ENGINE.ROLLING") + logger.success("Training engine is now ROLLING and ready.") + else: + raise RuntimeError(f"Cannot sync train config or start engine when engine is in status: {current_status}") + + def stop_engine(self): + """ + Stop the training engine on the Swarm server. + This triggers the server to stop the training process. + """ + current_status = self.get_engine_status() + if current_status == "ENGINE.OFFLINE": + logger.info("Engine is already OFFLINE. No action needed.") + return + + try: + resp = httpx.post( + f"{self.server_url}/stop_engine", + json={}, + timeout=600 + ) + resp.raise_for_status() + result = resp.json() + if result.get("success"): + logger.info("Successfully stopped training engine on Swarm server") + else: + logger.error("Failed to stop training engine") + self._wait_until_status_change_to(desired_status="ENGINE.OFFLINE") + except Exception as e: + logger.error(f"Error stopping engine: {e}") diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py new file mode 100644 index 00000000..392b0fe5 --- /dev/null +++ b/ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py @@ -0,0 +1,736 @@ +import multiprocessing +import time +import zmq +import os +import asyncio +import threading +from loguru import logger +from functools import lru_cache +from types import SimpleNamespace +from fastapi import FastAPI, HTTPException +from multiprocessing.managers import DictProxy +from typing import Coroutine, Optional, Tuple, List +from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import ( + SyncTrainConfigRequest, + ClaimEpisodeRequest, + ClaimEpisodeResponse, + CanContinueEpisodeRequest, + CanContinueEpisodeResponse, + EndEpisodeRequest, + EndEpisodeResponse, + EpisodeStatus, + EpisodeBufferResponse, + BoolResponse, + RegisterEpisodeRequest, + UpdateEngineStatusRequest, + VALID_STATUSES, +) + +# DEBUG = True +DEBUG = False +RCVTIMEO = 2 * 1000 +RCVTIMEO_OUT = 300 * 1000 +RCVTIMEO_WAIT_N = RCVTIMEO_OUT // RCVTIMEO + + +def is_key_epsisode_status(key: str) -> bool: + return key.startswith("episodes-") + + +@lru_cache(maxsize=128) +def ep_key(episode_uuid: str) -> str: + return f"episodes-{episode_uuid}" + + +def register_enable_swarm_mode_routes( + app, + zmq_context, + shared_mem_dict:DictProxy, + shared_mem_dict_lock:threading.Lock, + ) -> Tuple[FastAPI, Optional[Coroutine]]: + + if 'episodes' not in shared_mem_dict: + shared_mem_dict["episodes"] = {} + + if 'unclaimed_episodes' not in shared_mem_dict: + shared_mem_dict['unclaimed_episodes'] = [] + + # ------------------------------------------------------------------------------------------------ + # ------ Recycle claimed episodes that client failed to complete in (promised) time -------------- + # --------------------------------- claimed -> unclaimed ---------------------------------------- + # ------------------------------------------------------------------------------------------------ + + async def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]: + to_unclaim_episodes = [] + current_time = time.time() + + for k, v in shared_mem_dict.items(): + if is_key_epsisode_status(k): + es:EpisodeStatus = v + if es.episode_status == "claimed": + if (current_time - es.latest_activity_timestamp) > es.allow_discard_timeout: + to_unclaim_episodes.append(es.episode_uuid) + + for episode_uuid in to_unclaim_episodes: + await 
_revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock) + + return to_unclaim_episodes + + def _context_tracker_reset_blocking(episode_uuid, shared_mem_dict): # must async + # send message to context tracker + assert 'episodes' in shared_mem_dict + zmq_addr = shared_mem_dict[ep_key(episode_uuid)].zmq_listen_result_addr + socket = zmq_context.socket(zmq.REQ) + socket.setsockopt(zmq.RCVTIMEO, RCVTIMEO) # 2 seconds recv timeout + socket.connect(zmq_addr) + + # + # : ajet/task_runner/swarm_runner.py + # : message = zmq_socket.recv_string() + socket.send_string("RUNNER.SPECIAL.RESET_CONTEXT_TRACKER") + + # + for _ in range(RCVTIMEO_WAIT_N): # max 5 minutes wait + try: + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") + # : + # : ajet/task_runner/swarm_runner.py + # : zmq_socket.send_string("ack") + # : "ack" + socket.recv_string() + break + except zmq.Again as e: + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string timeout, retrying.") + + if shared_mem_dict["engine_status"] not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + logger.info(f"[server] episode_uuid: {episode_uuid} | Engine is no longer rolling, aborting wait for ack.") + raise RuntimeError("Engine is no longer rolling, aborting wait for ack.") + continue + + async def _revert_episode_to_unclaimed(episode_uuid: str, shared_mem_dict, shared_mem_dict_lock): + # check status again, because other thread may have changed it + if shared_mem_dict[ep_key(episode_uuid)].episode_status != "claimed": + if episode_uuid in shared_mem_dict['unclaimed_episodes']: pass + else: shared_mem_dict['unclaimed_episodes'] += [episode_uuid] + return + + # reset context tracker + # _context_tracker_reset_blocking(episode_uuid, shared_mem_dict) # must async + await asyncio.to_thread(_context_tracker_reset_blocking, episode_uuid, shared_mem_dict) + + # revert + logger.warning(f"Reverting episode {episode_uuid} to unclaimed due to client timeout.") + if ep_key(episode_uuid) in shared_mem_dict: + es:EpisodeStatus = shared_mem_dict[ep_key(episode_uuid)] + es.episode_status = "registered" + es.client_uuid = "" + es.latest_activity_timestamp = time.time() + es.allow_discard_timeout = -1 + with shared_mem_dict_lock: + shared_mem_dict[ep_key(episode_uuid)] = es + if episode_uuid in shared_mem_dict['unclaimed_episodes']: pass + else: shared_mem_dict['unclaimed_episodes'] += [episode_uuid] + + def _delete_episode_record(episode_uuid: str, shared_mem_dict, shared_mem_dict_lock): + + with shared_mem_dict_lock: + # remove episode record + if ep_key(episode_uuid) in shared_mem_dict: + del shared_mem_dict[ep_key(episode_uuid)] # RM-- + logger.info(f"Deleted episode record for {episode_uuid}.") + # remove from unclaimed list if present + if episode_uuid in shared_mem_dict['unclaimed_episodes']: + shared_mem_dict['unclaimed_episodes'].remove(episode_uuid) + + + # -------------------------------------------------------------------------------------- + # -------------------------- return workflow output ------------------------------------ + # -------------------------------------------------------------------------------------- + + def _register_final_episode_output_blocking(episode_uuid, workflow_output, shared_mem_dict, shared_mem_dict_lock): # must async + + # begin send workflow_output + zmq_addr = shared_mem_dict[ep_key(episode_uuid)].zmq_listen_result_addr + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | Received new chat completion request") + socket = 
zmq_context.socket(zmq.REQ) + socket.setsockopt(zmq.RCVTIMEO, RCVTIMEO) # 2 seconds recv timeout + socket.connect(zmq_addr) + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | connect done") + socket.send_string(workflow_output.model_dump_json()) + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | send_string") + # wait for ack + for _ in range(RCVTIMEO_WAIT_N): # max 5 minutes wait + try: + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") + # : + # : ajet/task_runner/swarm_runner.py + # : zmq_socket.send_string("ack") + # : "ack" + socket.recv_string() + break + except zmq.Again as e: + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string timeout, retrying.") + if shared_mem_dict["engine_status"] not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + logger.info(f"[server] episode_uuid: {episode_uuid} | Engine is no longer rolling, aborting wait for ack.") + raise RuntimeError("Engine is no longer rolling, aborting wait for ack.") + continue + # clean up episode records + with shared_mem_dict_lock: + del shared_mem_dict[ep_key(episode_uuid)] + if episode_uuid in shared_mem_dict['unclaimed_episodes']: + shared_mem_dict['unclaimed_episodes'].remove(episode_uuid) + + + + # -------------------------------------------------------------------------------------- + # -------------------------- status monitor -------------------------------------------- + # -------------------------------------------------------------------------------------- + + async def register_episode_ready_listener(): + while True: + await asyncio.sleep(10) # check every 10 seconds + await find_claimed_episodes_that_need_to_be_unclaimed() + read_all_episode_status() + + def read_all_episode_status() -> Optional[EpisodeStatus]: + group_by_status = {} + + for k, v in shared_mem_dict.items(): + if is_key_epsisode_status(k): + es:EpisodeStatus = v + if es.episode_status not in group_by_status: + group_by_status[es.episode_status] = [] + group_by_status[es.episode_status].append(es) + + print_buffer_str = f"Registered: {len(group_by_status.get('registered', []))}, Claimed: {len(group_by_status.get('claimed', []))}" + logger.info(f"Current engine status: [{shared_mem_dict['engine_status']}], " + print_buffer_str) + + return None + + + # -------------------------------------------------------------------------------------- + # -------------------------- fastapi routes -------------------------------------------- + # -------------------------------------------------------------------------------------- + + @app.post("/sync_train_config") + async def sync_train_config(req: SyncTrainConfigRequest): + """ + Receive training configuration from client as YAML string. + Store it in shared memory for later use by start_engine. + """ + + if shared_mem_dict['engine_status'] != "ENGINE.OFFLINE": + raise HTTPException(status_code=400, detail="Engine is already started. 
Call `stop_engine` first before syncing new training configuration.") + + try: + yaml_str = req.yaml_as_string + logger.info("[sync_train_config] Received training configuration") + if DEBUG: + logger.debug(f"[sync_train_config] YAML content:\n{yaml_str}...") + + # Store the YAML config in shared memory for start_engine to use + with shared_mem_dict_lock: + shared_mem_dict['train_config_yaml'] = yaml_str + + logger.info("[sync_train_config] Successfully stored training configuration") + return {"success": True} + except Exception as e: + logger.error(f"[sync_train_config] Error: {e}") + return {"success": False, "error": str(e)} + + + @app.post("/start_engine") + async def start_engine(): + """ + Start the training engine using the previously synced configuration. + This creates a temporary YAML file and spawns a training process. + """ + try: + import ray + import tempfile + import yaml as yaml_module + from ajet.utils.launch_utils import execute_training_process + from ajet.utils.config_utils import prepare_experiment_config + from ajet.launcher import get_backbone_target, setup_environment_vars + + # Check if config has been synced + if 'train_config_yaml' not in shared_mem_dict: + logger.error("[start_engine] No training config found. Please call sync_train_config first.") + return {"success": False, "error": "No training config found"} + + # Parse YAML to get backbone + yaml_str = shared_mem_dict['train_config_yaml'] + config_dict = yaml_module.safe_load(yaml_str) + backbone = config_dict.get('ajet', {}).get('backbone', 'verl') + exp_dir_final = config_dict.get('ajet', {}).get('experiment_dir', 'saved_experiments') + + # Save YAML to temporary file + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.yaml') as temp_file: + temp_file.write(yaml_str) + main_yaml_fp = temp_file.name + logger.info(f"[start_engine] Saved config to temporary file: {main_yaml_fp}") + + # Create args namespace + args = SimpleNamespace( + conf=main_yaml_fp, backbone=backbone, exp_dir=exp_dir_final, with_logview=False, + debug=False, + ) + # get debug param + should_debug = os.environ.get("RAY_DEBUG_POST_MORTEM", "0") == "1" + debug_tags = os.environ.get("DEBUG_TAGS", "") + if should_debug: + args.debug = debug_tags + + def override_param_callback(config): + config['ajet']['interchange_server']['already_started'] = True + config['ajet']['interchange_server']['interchange_server_port'] = int(os.getenv("AJET_DAT_INTERCHANGE_PORT")) # type: ignore + return config + + # Finalize experiment config + main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config( + main_yaml_fp, exp_dir_final, backbone, override_param_callback + ) + + # Setup environment variables + env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp) + + # Start ray if not already started + if not ray.is_initialized(): + from ajet.utils.launch_utils import start_ray_service + logger.info("[start_engine] Starting Ray service...") + start_ray_service(args, env) + else: + logger.info("[start_engine] Ray already initialized") + + # Start training process in a separate process + p = multiprocessing.Process( + target=execute_training_process, + args=( + args, get_backbone_target(args.backbone), main_yaml_fp, + exe_exp_base, main_yaml_fp, env, exp_config, + ) + ) + p.daemon = True + p.start() + + # wait until p.pid is available + while not isinstance(p.pid, int): time.sleep(1) + + # set new process group + os.setpgid(p.pid, p.pid) + + # Store process info in shared memory + 
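+            # (stale episode records from any previous run are purged before the new
+            # PID is recorded and the status flips to ENGINE.BOOTING)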
clean_up_engine_status(shared_mem_dict_lock, shared_mem_dict) + with shared_mem_dict_lock: + shared_mem_dict['training_process_pid'] = p.pid + shared_mem_dict['engine_status'] = "ENGINE.BOOTING" + + logger.info(f"[start_engine] Successfully started training process (PID: {p.pid})") + return {"success": True, "pid": p.pid} + + except Exception as e: + logger.error(f"[start_engine] Error starting engine: {e}") + import traceback + traceback.print_exc() + return {"success": False, "error": str(e)} + + + # --- engine status --- + shared_mem_dict['engine_status'] = "ENGINE.OFFLINE" # initial status + def clean_up_engine_status(shared_mem_dict_lock, shared_mem_dict): + with shared_mem_dict_lock: + episode_keys = [k for k in shared_mem_dict.keys() if is_key_epsisode_status(k)] + # remove all episodes + for key in episode_keys: + del shared_mem_dict[key] + logger.info(f"[clean_up_engine_status] Removed episode: {key}") + + # clear unclaimed episodes list + if 'unclaimed_episodes' in shared_mem_dict: + num_unclaimed = len(shared_mem_dict['unclaimed_episodes']) + shared_mem_dict['unclaimed_episodes'] = [] + logger.info(f"[clean_up_engine_status] Cleared {num_unclaimed} unclaimed episodes") + + @app.post("/update_engine_status", response_model=BoolResponse) + async def update_engine_status(req: UpdateEngineStatusRequest): + """Update the current engine status.""" + if req.engine_status not in VALID_STATUSES: + return BoolResponse(success=False, failure_reason="Invalid engine status") + previous_status = shared_mem_dict['engine_status'] + shared_mem_dict['engine_status'] = req.engine_status + if previous_status in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"] and req.engine_status not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + clean_up_engine_status(shared_mem_dict_lock, shared_mem_dict) + + logger.info(f"[update_engine_status] Engine status set to {req.engine_status}") + return BoolResponse(success=True) + + + @app.get("/get_engine_status") + async def get_engine_status(): + """Get the current engine status.""" + status = shared_mem_dict['engine_status'] + return {"engine_status": status} + + + # --- episode status --- + @app.post("/register_episode", response_model=BoolResponse) + async def register_episode(req: RegisterEpisodeRequest): + """(From task_runner) Register a new episode as ready to roll.""" + engine_status = shared_mem_dict['engine_status'] + if engine_status not in ["ENGINE.ROLLING"]: + return BoolResponse( + success=False, + failure_reason=f"Engine is not in rolling state. Cannot register episode.", + ) + + episode_uuid = req.episode_uuid + es = EpisodeStatus( + episode_uuid=req.episode_uuid, + openai_base_url=req.openai_base_url, + openai_api_key=req.openai_api_key, + episode_status="registered", + zmq_listen_result_addr=req.zmq_listen_result_addr, + allow_discard_timeout=-1, + ) + es.latest_activity_timestamp = time.time() + + with shared_mem_dict_lock: + shared_mem_dict[ep_key(episode_uuid)] = es + shared_mem_dict['unclaimed_episodes'] += [req.episode_uuid] + + return BoolResponse(success=True) + + + @app.post("/claim_episode", response_model=ClaimEpisodeResponse) + async def claim_episode(req: ClaimEpisodeRequest): + """(From client) Claim an available episode to rollout.""" + # find_claimed_episodes_that_need_to_be_unclaimed() + + engine_status = shared_mem_dict['engine_status'] + + if engine_status != "ENGINE.ROLLING": + fail_cause = f"Engine not ready. Current status: [{engine_status}]." 
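+            # Status-specific advice matters here: SwarmClient.begin_episode
+            # substring-matches the "syncing weights" / "post-rolling phase" /
+            # "No available episodes" causes to decide whether to quietly wait and retry.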
+ advise = "" + if engine_status == "ENGINE.OFFLINE": + advise = "Please start the engine first. Please use one of the client to run `client.sync_train_config() + client.start_engine()` to start the engine." + elif engine_status == "ENGINE.BOOTING": + advise = "Please wait until the engine is fully booted. Try again (maybe 1 minute) later." + elif engine_status == "ENGINE.WEIGHT_SYNCING": + advise = "Engine is syncing weights. Try again (maybe 1 minute) later." + elif engine_status == "ENGINE.WEIGHT_EXPORTING": + advise = "Engine is exporting weights (fsdp -> hf safetensor). Try again (maybe 1 minute) later." + elif engine_status == "ENGINE.ROLLING_POST": + advise = "Engine is in post-rolling phase. Try again (maybe 1 minute) later." + return ClaimEpisodeResponse( + success=False, + client_uuid=req.client_uuid, + episode_uuid="", + openai_base_url="", + openai_api_key="", + fail_cause=fail_cause + " " + advise, + ) + + if req.episode_type == "train" or req.episode_type == "eval": + + with shared_mem_dict_lock: + if len(shared_mem_dict['unclaimed_episodes']) <= 0: + return ClaimEpisodeResponse( + success=False, + client_uuid=req.client_uuid, + episode_uuid="", + openai_base_url="", + openai_api_key="", + fail_cause="No available episodes to claim. Try again (maybe 1 minute) later.", + ) + + # Hint: do NOT optimize these two lines + episode_uuid = shared_mem_dict['unclaimed_episodes'][0] + shared_mem_dict['unclaimed_episodes'] = shared_mem_dict['unclaimed_episodes'][1:] + + # get episode + es:EpisodeStatus = shared_mem_dict[ep_key(episode_uuid)] + es.episode_status = "claimed" + es.episode_type = req.episode_type + es.client_uuid = req.client_uuid + es.latest_activity_timestamp = time.time() + es.allow_discard_timeout = req.allow_discard_timeout + + shared_mem_dict[ep_key(episode_uuid)] = es + openai_base_url = es.openai_base_url + openai_api_key = es.openai_api_key + + return ClaimEpisodeResponse( + success=True, + client_uuid=req.client_uuid, + episode_uuid=episode_uuid, + openai_base_url=openai_base_url, + openai_api_key=openai_api_key, + fail_cause="", + ) + + else: + raise HTTPException(status_code=400, detail=f"Unknown episode_type: {req.episode_type}") + + + @app.post("/end_episode", response_model=EndEpisodeResponse) + async def end_episode(req: EndEpisodeRequest): + + engine_status = shared_mem_dict['engine_status'] + if engine_status not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + raise HTTPException(status_code=400, detail=f"Engine is not in rolling state. Current status: [{engine_status}]. 
Cannot end episode.") + + # receive workflow output data + client_uuid = req.client_uuid + episode_uuid = req.episode_uuid + workflow_output = req.workflow_output + task_id = req.task_id + + + assert "task_id" in workflow_output.metadata, "workflow_output.metadata must contain task_id" + assert workflow_output.metadata["task_id"] == task_id, "workflow_output.metadata.task_id must match req.task_id" + + if 'episodes' not in shared_mem_dict: + logger.error(f"[server] No episodes registered yet.") + raise HTTPException(status_code=400, detail=f"No episodes registered yet.") + + if (ep_key(episode_uuid)) not in shared_mem_dict: + logger.error(f"[server] Episode {episode_uuid} not found.") + raise HTTPException(status_code=400, detail=f"Episode {episode_uuid} not found.") + + # send workflow_output to zmq + assert 'episodes' in shared_mem_dict + ep_stat = shared_mem_dict[ep_key(episode_uuid)] + episode_type = ep_stat.episode_type + episode_status = ep_stat.episode_status + client_uuid_recorded = ep_stat.client_uuid + if client_uuid_recorded != client_uuid: + logger.error(f"[server] Episode {episode_uuid} is claimed by different client: {client_uuid_recorded}, but got {client_uuid}.") + raise HTTPException(status_code=404, detail=f"Episode {episode_uuid} is claimed by different client: {client_uuid_recorded}, but got {client_uuid}.") + + if episode_status != "claimed": + logger.error(f"[server] Episode {episode_uuid} is not in claimed status.") + raise HTTPException(status_code=400, detail=f"Episode {episode_uuid} is not in claimed status, maybe you take too long to submit.") + + if episode_type == "train": + # _register_final_episode_output_blocking(episode_uuid, workflow_output, shared_mem_dict, shared_mem_dict_lock) # must async + await asyncio.to_thread(_register_final_episode_output_blocking, episode_uuid, workflow_output, shared_mem_dict, shared_mem_dict_lock) + + elif episode_type == "eval": + if engine_status in ["ENGINE.ROLLING"]: + await _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock) + else: + _delete_episode_record(episode_uuid, shared_mem_dict, shared_mem_dict_lock) + + else: + raise HTTPException(status_code=400, detail=f"Unknown episode_type: {episode_type}") + + # return success + return EndEpisodeResponse(success=True) + + + @app.post("/abort_episode", response_model=EndEpisodeResponse) + async def abort_episode(req: EndEpisodeRequest): + + engine_status = shared_mem_dict['engine_status'] + if engine_status not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + return EndEpisodeResponse(success=True) + + # receive workflow output data + episode_uuid = req.episode_uuid + workflow_output = req.workflow_output + task_id = req.task_id + + assert "task_id" in workflow_output.metadata, "workflow_output.metadata must contain task_id" + assert workflow_output.metadata["task_id"] == task_id, "workflow_output.metadata.task_id must match req.task_id" + + if 'episodes' not in shared_mem_dict: + logger.error(f"[server] No episodes registered yet.") + return EndEpisodeResponse(success=True) + + if (ep_key(episode_uuid)) not in shared_mem_dict: + logger.error(f"[server] Episode {episode_uuid} not found.") + return EndEpisodeResponse(success=True) + + if engine_status in ["ENGINE.ROLLING"]: + await _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock) + else: + _delete_episode_record(episode_uuid, shared_mem_dict, shared_mem_dict_lock) + + return EndEpisodeResponse(success=True) + + + @app.post("/can_continue_episode", 
response_model=CanContinueEpisodeResponse) + async def can_continue_episode(req: CanContinueEpisodeRequest): + engine_status = shared_mem_dict['engine_status'] + if engine_status not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + return CanContinueEpisodeResponse(can_continue=False) + + can_continue = (ep_key(req.episode_uuid) in shared_mem_dict) + can_continue = can_continue and shared_mem_dict[ep_key(req.episode_uuid)].episode_status == "claimed" + + return CanContinueEpisodeResponse(can_continue=can_continue) + + + @app.post("/is_episode_claimed", response_model=BoolResponse) + async def is_episode_claimed(req: CanContinueEpisodeRequest): + engine_status = shared_mem_dict['engine_status'] + if engine_status not in ["ENGINE.ROLLING", "ENGINE.ROLLING_POST"]: + return BoolResponse(success=False) + if ep_key(req.episode_uuid) not in shared_mem_dict: + return BoolResponse(success=False) + es = shared_mem_dict[ep_key(req.episode_uuid)] + if not es: + return BoolResponse(success=False) + if es.episode_status == "claimed": + return BoolResponse(success=True) + else: + return BoolResponse(success=False) + + + @app.post("/get_episode_buffer", response_model=EpisodeBufferResponse) + async def get_episode_buffer(): + result = [ + v for k, v in shared_mem_dict.items() if is_key_epsisode_status(k) + ] + return EpisodeBufferResponse(buffer=result) + + + + + # -------------------------------------------------------------------- + # ------------ bring engine back to ENGINE.OFFLINE ------------------- + # -------------------------------------------------------------------- + @app.post("/stop_engine") + async def stop_engine(): + """ + Terminate the training engine and reset all state. + This will: + - Kill the training process and all its subprocesses (forcefully if necessary) + - Set engine status to OFFLINE + - Remove all episodes (registered, claimed, and unclaimed) + - Clean up shared memory state + """ + kill_process_tree(shared_mem_dict_lock, shared_mem_dict) + + return app, register_episode_ready_listener() + + + +def kill_process_tree(shared_mem_dict_lock=None, shared_mem_dict=None): + logger.exception("[stop_engine] Initiating engine shutdown and cleanup...") + try: + import psutil + + killed_pids = [] + errors = [] + + # Get the training process PID if it exists + if shared_mem_dict and shared_mem_dict_lock: + training_pid = shared_mem_dict.get('training_process_pid', None) + else: + training_pid = os.getpid() + + if training_pid is not None: + try: + # Try to get the process and all its children + try: + parent = psutil.Process(training_pid) + children = parent.children(recursive=True) + + # Kill all child processes first + for child in children: + try: + logger.info(f"[stop_engine] Terminating child process PID: {child.pid}") + child.terminate() + killed_pids.append(child.pid) + except psutil.NoSuchProcess: + logger.warning(f"[stop_engine] Child process {child.pid} already terminated") + except Exception as e: + logger.error(f"[stop_engine] Error terminating child process {child.pid}: {e}") + errors.append(f"Child {child.pid}: {str(e)}") + + # Wait for children to terminate gracefully + gone, alive = psutil.wait_procs(children, timeout=16) + + # Force kill any remaining children + for p in alive: + try: + logger.warning(f"[stop_engine] Force killing child process PID: {p.pid}") + p.kill() + except Exception as e: + logger.error(f"[stop_engine] Error force killing child {p.pid}: {e}") + errors.append(f"Force kill child {p.pid}: {str(e)}") + + # Now terminate the parent process + 
logger.info(f"[stop_engine] Terminating parent process PID: {training_pid}") + parent.terminate() + killed_pids.append(training_pid) + + # Wait for parent to terminate gracefully + try: + parent.wait(timeout=3) + except psutil.TimeoutExpired: + logger.warning(f"[stop_engine] Force killing parent process PID: {training_pid}") + parent.kill() + + except psutil.NoSuchProcess: + logger.warning(f"[stop_engine] Process {training_pid} not found (may have already terminated)") + + except Exception as e: + logger.error(f"[stop_engine] Error killing training process: {e}") + errors.append(f"Training process: {str(e)}") + else: + logger.info("[stop_engine] No training process PID found in shared memory") + + # Clean up all episodes from shared memory + episode_keys = [] + if shared_mem_dict and shared_mem_dict_lock: + with shared_mem_dict_lock: + episode_keys = [k for k in shared_mem_dict.keys() if is_key_epsisode_status(k)] + for key in episode_keys: + del shared_mem_dict[key] + logger.info(f"[stop_engine] Removed episode: {key}") + + # Clear unclaimed episodes list + if 'unclaimed_episodes' in shared_mem_dict: + num_unclaimed = len(shared_mem_dict['unclaimed_episodes']) + shared_mem_dict['unclaimed_episodes'] = [] + logger.info(f"[stop_engine] Cleared {num_unclaimed} unclaimed episodes") + + # Reset engine status to OFFLINE + shared_mem_dict['engine_status'] = "ENGINE.OFFLINE" + + # Remove training process PID + if 'training_process_pid' in shared_mem_dict: + del shared_mem_dict['training_process_pid'] + + logger.info("[stop_engine] Engine status set to OFFLINE") + + result = { + "success": True, + "killed_pids": killed_pids, + "episodes_removed": len(episode_keys) if 'episode_keys' in locals() else 0, + } + + if errors: + result["warnings"] = errors + logger.warning(f"[stop_engine] Completed with warnings: {errors}") + else: + logger.info(f"[stop_engine] Successfully terminated engine and reset state") + + return result + + except Exception as e: + logger.error(f"[stop_engine] Unexpected error: {e}") + import traceback + traceback.print_exc() + + # Even if there's an error, try to reset the status + try: + if shared_mem_dict and shared_mem_dict_lock: + with shared_mem_dict_lock: + shared_mem_dict['engine_status'] = "ENGINE.OFFLINE" + except: + pass + + return {"success": False, "error": str(e)} \ No newline at end of file diff --git a/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py new file mode 100644 index 00000000..f82def38 --- /dev/null +++ b/ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py @@ -0,0 +1,167 @@ +import os +import time +import httpx +from typing import List +from pydantic import BaseModel +from loguru import logger +from ajet.schema.task import WorkflowOutput +from ajet.utils.networking import find_free_port +from ajet.utils.retry import retry_with_backoff + +VALID_STATUSES = [ + "ENGINE.OFFLINE", + "ENGINE.BOOTING", + "ENGINE.ROLLING", + "ENGINE.ROLLING_POST", + "ENGINE.WEIGHT_SYNCING", + "ENGINE.WEIGHT_EXPORTING" +] + +class SyncTrainConfigRequest(BaseModel): + yaml_as_string: str + +class ClaimEpisodeRequest(BaseModel): + client_uuid: str + episode_type: str + allow_discard_timeout: float + +class ClaimEpisodeResponse(BaseModel): + success: bool + client_uuid: str + episode_uuid: str + openai_base_url: str = "" + openai_api_key: str = "" + fail_cause: str = "" + +class CanContinueEpisodeRequest(BaseModel): + client_uuid: str + episode_uuid: str + +class 
CanContinueEpisodeResponse(BaseModel): + can_continue: bool + +class EndEpisodeRequest(BaseModel): + client_uuid: str + episode_uuid: str + workflow_output: WorkflowOutput + task_id: str + +class EndEpisodeResponse(BaseModel): + success: bool + + +class EpisodeStatus(BaseModel): + episode_uuid: str + episode_status: str = "rolling" + episode_type: str = "train" + openai_base_url: str = "" + openai_api_key: str = "" + client_uuid: str = "" + zmq_listen_result_addr: str = "" + latest_activity_timestamp: float = time.time() + allow_discard_timeout: float + +class EpisodeBufferResponse(BaseModel): + buffer: List[EpisodeStatus] + + +class BoolResponse(BaseModel): + success: bool + failure_reason: str = "" + +class RegisterEpisodeRequest(BaseModel): + episode_uuid: str + openai_base_url: str = "" + openai_api_key: str = "" + zmq_listen_result_addr: str = "" + + +class UpdateEngineStatusRequest(BaseModel): + engine_status: str = "" + + +DEBUG = False + +def get_interchange_server_url(config): + port = os.getenv("AJET_DAT_INTERCHANGE_PORT") + if config.ajet.interchange_server.interchange_server_port != 'auto': + port = str(int(config.ajet.interchange_server.interchange_server_port)) + assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set" + master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") + base_url = f"http://{master_node_ip}:{port}" + return base_url + + +def http_change_engine_status(config, new_status: str): + if new_status not in VALID_STATUSES: + raise ValueError(f"Invalid engine status: {new_status}") + + resp = httpx.post( + f"{get_interchange_server_url(config)}/update_engine_status", + json={"engine_status": new_status}, + timeout=10 + ) + resp.raise_for_status() + logger.success(f"Changed engine status to {new_status}") + + +def is_episode_claimed(config, episode_uuid: str) -> bool: + # TODO: add cache to reduce communication overhead + resp = httpx.post( + f"{get_interchange_server_url(config)}/is_episode_claimed", + json={"client_uuid": "", "episode_uuid": episode_uuid}, + timeout=5 + ) + resp.raise_for_status() + result = BoolResponse.model_validate(resp.json()) + return result.success + + +@retry_with_backoff(max_retry=15, backoff_fn=lambda attempt: 2) +def http_register_episode(config, + episode_uuid: str, + openai_base_url: str, + openai_api_key: str, + zmq_listen_result_addr: str, + should_exit_soft): + + if should_exit_soft(): + logger.warning(f"Exiting before registering episode {episode_uuid}") + return None + + # parse episode_uuid, openai_base_url, openai_api_key + interchange_http_addr = get_interchange_server_url(config) + rer = RegisterEpisodeRequest( + episode_uuid=episode_uuid, + openai_base_url=openai_base_url, + openai_api_key=openai_api_key, + zmq_listen_result_addr=zmq_listen_result_addr, + ) + # send http request to swarm server to register episode + response = httpx.post( + f"{interchange_http_addr}/register_episode", + json=rer.model_dump(), # 或者 rer.model_dump() 如果使用 Pydantic v2 + timeout=2 + ) + response.raise_for_status() + result = response.json() + if not result.get('success'): + logger.warning(f"Failed to register episode {episode_uuid}") + return None + if DEBUG: logger.info(f"Successfully registered episode {episode_uuid}") + + return True + + +def get_zmq_socket(config, episode_uuid: str, tag: str = ""): + interchange_method = config.ajet.interchange_server.interchange_method + if interchange_method == 'tcp': + ipc_path = "" + master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") + zmq_contect_address = 
f"tcp://{master_node_ip}:{find_free_port()}" + elif interchange_method == 'ipc': + ipc_path = f"/tmp/ajet/{episode_uuid}-{tag}.sock" + zmq_contect_address = f"ipc://{ipc_path}" + else: + raise RuntimeError(f"Unknown interchange_method: {interchange_method}") + return zmq_contect_address, ipc_path diff --git a/ajet/utils/config_utils.py b/ajet/utils/config_utils.py index 57663ff1..273c2cda 100644 --- a/ajet/utils/config_utils.py +++ b/ajet/utils/config_utils.py @@ -168,7 +168,7 @@ def config_safe_guard(config: dict, backbone: str) -> dict: def read_ajet_hierarchical_config( - yaml_fp, exp_name, backbone, write_to=None, exp_dir="saved_experiments" + yaml_fp, exp_name, backbone, write_to=None, exp_dir="saved_experiments", override_param_callback=None ): if yaml_fp is None: config = { @@ -210,6 +210,9 @@ def read_ajet_hierarchical_config( config["defaults"].remove("trinity_default") config["hydra"]["searchpath"].remove("file://ajet/default_config/trinity") + if override_param_callback is not None: + config = override_param_callback(config) + if write_to: with open(write_to, "w") as file: yaml.dump(config, file) @@ -239,7 +242,7 @@ def expand_ajet_hierarchical_config(config, write_to=None): return config_final -def prepare_experiment_config(yaml_path, exp_dir, backbone): +def prepare_experiment_config(yaml_path, exp_dir, backbone, override_param_callback=None): """ Prepare experiment configuration by reading YAML, setting up backup directories, and copying necessary files for the experiment. @@ -253,7 +256,7 @@ def prepare_experiment_config(yaml_path, exp_dir, backbone): tuple: (yaml_backup_dst, exe_exp_base, exp_name, config_final) """ assert yaml_path.endswith(".yaml"), "Configuration file must be a YAML file" - exp_base = os.path.dirname(yaml_path) + exp_base = os.path.exists(os.path.dirname(yaml_path)) if not os.path.exists(exp_base): raise FileNotFoundError(f"Configuration file not found: {exp_base}") @@ -317,7 +320,7 @@ def prepare_experiment_config(yaml_path, exp_dir, backbone): ## 4. edit new yaml config = read_ajet_hierarchical_config( - yaml_backup_dst, exp_name, backbone, write_to=yaml_backup_dst, exp_dir=exp_dir + yaml_backup_dst, exp_name, backbone, write_to=yaml_backup_dst, exp_dir=exp_dir, override_param_callback=override_param_callback ) config_final = expand_ajet_hierarchical_config(config, write_to=yaml_backup_dst) diff --git a/ajet/utils/core_env_vars.py b/ajet/utils/core_env_vars.py index 91fdf736..e48e1dda 100644 --- a/ajet/utils/core_env_vars.py +++ b/ajet/utils/core_env_vars.py @@ -1,4 +1,5 @@ import os +import copy from pathlib import Path from beast_logger import print_dict @@ -18,6 +19,11 @@ def get_runtime_env(config, is_trinity: bool = False) -> dict: if config.ajet.interchange_server.interchange_method == "ipc": raise ValueError("IPC interchange method is not supported for multi-node setup. 
Please set `ajet.interchange_server.interchange_method: tcp` ") + if config.ajet.interchange_server.interchange_server_port != 'auto': + data_interchange_port = str(int(config.ajet.interchange_server.interchange_server_port)) + else: + data_interchange_port = str(find_free_port()) + runtime_env = { "env_vars": { "VLLM_USE_V1": "1", @@ -29,8 +35,8 @@ def get_runtime_env(config, is_trinity: bool = False) -> dict: # "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true", "SWANLAB_API_KEY": os.getenv("SWANLAB_API_KEY", ""), "AJET_CONFIG_REDIRECT": os.getenv("AJET_CONFIG_REDIRECT", ""), - "AJET_DAT_INTERCHANGE_PORT": str(find_free_port()), - "MASTER_NODE_IP": master_node_ip, + "AJET_DAT_INTERCHANGE_PORT": os.getenv("AJET_DAT_INTERCHANGE_PORT", data_interchange_port), + "MASTER_NODE_IP": os.getenv("MASTER_NODE_IP", master_node_ip), } } @@ -56,5 +62,12 @@ def get_runtime_env(config, is_trinity: bool = False) -> dict: if is_trinity: assert "AJET_CONFIG_REDIRECT" in runtime_env["env_vars"] - print_dict(runtime_env["env_vars"], "runtime_env") + print_env_dict = copy.deepcopy(runtime_env["env_vars"]) + # limit value length for printing + for k, v in print_env_dict.items(): + _len_limit = 500 + _len_limit_half = _len_limit // 2 + if len(v) > _len_limit: + print_env_dict[k] = v[:_len_limit_half] + "..." + v[-_len_limit_half:] + print_dict(print_env_dict, "runtime_env") return runtime_env diff --git a/ajet/utils/launch_utils.py b/ajet/utils/launch_utils.py index 978539ab..441fbc93 100644 --- a/ajet/utils/launch_utils.py +++ b/ajet/utils/launch_utils.py @@ -7,11 +7,69 @@ from beast_logger import print_dict from loguru import logger +from types import SimpleNamespace from ajet.utils.config_utils import align_parameters from ajet.utils.smart_daemon import LaunchCommandWhenAbsent + +def get_backbone_target(backbone): + """ + Determine the appropriate backbone target module based on the backbone name. + + Args: + backbone (str): The backbone name (e.g., "verl", "debug", "trinity") + + Returns: + str: The full module path for the specified backbone + """ + backbone_target = "ajet.backbone.main_verl" # Default to trinity + if backbone == "verl": + backbone_target = "ajet.backbone.main_verl" + if backbone == "debug": + backbone_target = "ajet.backbone.main_vllm" + if backbone == "trinity": + backbone_target = "ajet.backbone.main_trinity" + return backbone_target + + +def setup_environment_vars(args, exp_config, main_yaml_fp): + """ + Configure environment variables based on command line arguments. + + Args: + args: Command line arguments + exp_config: Experiment configuration dictionary + main_yaml_fp: Path to main YAML configuration file + + Returns: + dict: Configured environment variables dictionary + """ + env = os.environ.copy() + if args.debug: + env["RAY_DEBUG_POST_MORTEM"] = "1" + env["DEBUG_TAGS"] = args.debug + env["RAY_record_task_actor_creation_sites"] = "true" + # assert exp_config["ajet"]["rollout"]["max_env_worker"] <= 4, "parallel worker too many for debugging mode" # type: ignore + if exp_config["ajet"]["rollout"]["max_env_worker"] > 1: # type: ignore + # exp_config["ajet"]["rollout"]["max_env_worker"] = 1 + logger.warning( + "For debugging mode, please set max_env_worker to 1 to facilitate debugging." 
+ ) + logger.warning("Debug mode is ON") + else: + logger.warning("Debug mode is OFF") + # if args.conf: + # assert exp_config["ajet"]["rollout"]["max_env_worker"] > 4, "parallel worker too few" # type: ignore + if args.backbone == "trinity": + env["AJET_CONFIG_REDIRECT"] = main_yaml_fp # type: ignore + if args.backbone == "debug": + env["AJET_DEBUG"] = "1" # type: ignore + return env, exp_config + + + def set_loguru_default_color(): logger.remove() colorize = os.environ.get("LOGURU_COLORIZE", "YES").upper() not in ["NO", "0", "FALSE"] @@ -25,6 +83,101 @@ def set_loguru_default_color(): return + +def check_debugpy_version(): + try: + import debugpy + except ImportError: + raise RuntimeError( + "Module 'debugpy>=1.8.0' cannot be loaded. " + "Ray Debugpy Debugger will not work without 'debugpy>=1.8.0' installed. " + "Install this module using 'pip install debugpy>=1.8.0'" + ) + version = getattr(debugpy, "__version__", "0.0.0") + from packaging import version as packaging_version + + if packaging_version.parse(version) < packaging_version.parse("1.8.0"): + raise RuntimeError( + f"debugpy version {version} is too old. " + "Ray Debugpy Debugger requires 'debugpy>=1.8.0'. " + "Upgrade using 'pip install debugpy>=1.8.0'" + ) + logger.info(f"✓ debugpy version {version} meets requirement (>=1.8.0)") + + +def check_avail_gpu(min_free_ratio: float = 0.95): + """ + Ensure there is at least one GPU and all GPUs have >= min_free_ratio free memory. + + Uses `nvidia-smi` to query total and used memory for each GPU. + Raises RuntimeError if no GPU is found or any GPU violates the free ratio threshold. + """ + try: + # Query GPU memory via nvidia-smi; output in MiB + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=name,memory.total,memory.used", + "--format=csv,noheader,nounits", + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=False, + ) + except FileNotFoundError: + raise RuntimeError("nvidia-smi not found. NVIDIA drivers/GPU may be unavailable.") + + if result.returncode != 0: + raise RuntimeError(f"Failed to query GPUs via nvidia-smi: {result.stderr.strip()}") + + lines = [line.strip() for line in result.stdout.splitlines() if line.strip()] + if not lines: + raise RuntimeError("No GPUs detected by nvidia-smi.") + + violations = [] + for idx, line in enumerate(lines): + # Expected format: ", , " + parts = [p.strip() for p in line.split(",")] + if len(parts) < 3: + violations.append((idx, "parse-error", line)) + continue + name, total_str, used_str = parts[0], parts[1], parts[2] + try: + total = float(total_str) + used = float(used_str) + except ValueError: + violations.append((idx, "parse-error", line)) + continue + free = max(total - used, 0.0) + free_ratio = free / total if total > 0 else 0.0 + logger.info( + f"GPU {idx} ({name}): total={total:.0f} MiB, used={used:.0f} MiB, free_ratio={free_ratio:.3f}" + ) + if free_ratio < min_free_ratio: + violations.append((idx, name, f"free_ratio={free_ratio:.3f} < {min_free_ratio:.3f}")) + + if violations: + details = "; ".join([f"GPU {i} ({n}): {msg}" for i, n, msg in violations]) + raise RuntimeError( + "GPU memory check failed: all GPUs must have >= " + f"{int(min_free_ratio*100)}% free. 
Violations: {details}" + ) + logger.info( + f"✓ GPU check passed: {len(lines)} GPUs, all >= {int(min_free_ratio*100)}% free memory" + ) + + +def dict_to_namespace(d): + """Recursively convert a nested dictionary to a SimpleNamespace.""" + if isinstance(d, dict): + return SimpleNamespace(**{k: dict_to_namespace(v) for k, v in d.items()}) + elif isinstance(d, list): # 如果字典中嵌套了列表,递归处理列表中的每个元素 + return [dict_to_namespace(item) for item in d] + else: + return d + + def launch_logview(exp_name=None): """ Launch the log viewer service and open the web browser to view logs. diff --git a/ajet/utils/retry.py b/ajet/utils/retry.py index 339eb7bb..7f33466b 100644 --- a/ajet/utils/retry.py +++ b/ajet/utils/retry.py @@ -1,10 +1,11 @@ import time from functools import wraps from typing import Any, Callable, Optional, TypeVar - from loguru import logger -from ajet.utils.testing_utils import TestFailException, TestSuccessException +class SwarmReceiveAbortException(Exception): + pass + T = TypeVar("T") @@ -17,6 +18,8 @@ def retry_with_backoff( """Retry decorator with exponential backoff and structured logging.""" def decorator(func: Callable[..., T]) -> Callable[..., T]: + from ajet.utils.testing_utils import TestFailException, TestSuccessException + @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> T: target_max_retry = max_retry @@ -27,27 +30,31 @@ def wrapper(*args: Any, **kwargs: Any) -> T: if target_max_retry < 1: target_max_retry = 1 - for attempt in range(target_max_retry): - try: - return func(*args, **kwargs) - except TestSuccessException as exc: # noqa: BLE001 - raise exc - except TestFailException as exc: # noqa: BLE001 - raise exc - except Exception as exc: # noqa: BLE001 - if attempt < target_max_retry - 1: - logger.bind(exception=True).exception( - f"{func.__name__} error: {exc.args}, retrying {attempt + 1}/{target_max_retry}" - ) - sleep_seconds = backoff_fn(attempt) if backoff_fn else 2**attempt - time.sleep(sleep_seconds) - else: - logger.bind(exception=True).exception( - f"{func.__name__} failed after {target_max_retry} retries: {exc.args}" - ) - raise - - raise RuntimeError("retry_with_backoff exhausted attempts") + try: + for attempt in range(target_max_retry): + try: + return func(*args, **kwargs) + except TestSuccessException as exc: # noqa: BLE001 + raise exc + except TestFailException as exc: # noqa: BLE001 + raise exc + except Exception as exc: # noqa: BLE001 + if attempt < target_max_retry - 1: + logger.bind(exception=True).exception( + f"{func.__name__} error: {exc.args}, retrying {attempt + 1}/{target_max_retry}" + ) + sleep_seconds = backoff_fn(attempt) if backoff_fn else 2**attempt + time.sleep(sleep_seconds) + else: + logger.bind(exception=True).exception( + f"{func.__name__} failed after {target_max_retry} retries: {exc.args}" + ) + raise + + raise RuntimeError("retry_with_backoff exhausted attempts") + except SwarmReceiveAbortException as exc: # noqa: BLE001 + # ignore exception, return None silently + return None # type: ignore return wrapper diff --git a/ajet/utils/thread_executors.py b/ajet/utils/thread_executors.py index 1ab02baf..797ac54c 100644 --- a/ajet/utils/thread_executors.py +++ b/ajet/utils/thread_executors.py @@ -1,14 +1,14 @@ +from concurrent.futures import ThreadPoolExecutor from ajet.utils.sington import singleton -import concurrent.futures - +import threading @singleton class SharedInterchangeThreadExecutor: def __init__(self, max_workers=64): - self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) + self.executor = 
ThreadPoolExecutor(max_workers=max_workers) - def get_shared_executor(self) -> concurrent.futures.ThreadPoolExecutor: + def get_shared_executor(self) -> ThreadPoolExecutor: return self.executor @@ -16,7 +16,27 @@ def get_shared_executor(self) -> concurrent.futures.ThreadPoolExecutor: @singleton class SharedInferenceTrackerThreadExecutor: def __init__(self, max_workers=64): - self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) + self.executor = ThreadPoolExecutor(max_workers=max_workers) - def get_shared_executor(self) -> concurrent.futures.ThreadPoolExecutor: + def get_shared_executor(self) -> ThreadPoolExecutor: return self.executor + + +class BoundedThreadPoolExecutor: + def __init__(self, max_workers, max_queue_size=100): + self.executor = ThreadPoolExecutor(max_workers=max_workers) + self.semaphore = threading.Semaphore(max_queue_size) + + def submit(self, fn, *args, **kwargs): + self.semaphore.acquire() + + def wrapped_fn(*args, **kwargs): + try: + return fn(*args, **kwargs) + finally: + self.semaphore.release() + + return self.executor.submit(wrapped_fn, *args, **kwargs) + + def shutdown(self, wait=True): + self.executor.shutdown(wait=wait) \ No newline at end of file diff --git a/ajet/workflow.py b/ajet/workflow.py index 58c8757d..b2eaaf16 100644 --- a/ajet/workflow.py +++ b/ajet/workflow.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field -from ajet import AjetTuner +from ajet.tuner import AjetTuner from ajet.schema.task import WorkflowOutput, WorkflowTask diff --git a/ajet_swarm_threading.py b/ajet_swarm_threading.py new file mode 100644 index 00000000..ed3f84b2 --- /dev/null +++ b/ajet_swarm_threading.py @@ -0,0 +1,130 @@ +import re +import time +import threading +import requests +from loguru import logger +from textwrap import dedent +from ajet.schema.task import Task, WorkflowOutput +from ajet.copilot.job import AgentJetJob +from ajet.task_reader import RouterTaskReader +from ajet.utils.retry import retry_with_backoff +from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet.tuner_lib.weight_tuner.experimental.as_swarm_client import SwarmClient +from concurrent.futures import ThreadPoolExecutor + +# --------- configurations that take effect locally ------------- +LOCAL_GRPO_N = 4 # grpo group size +LOCAL_NUM_EPOCH = 10000 +LOCAL_NUM_EPOCH = 1 +LOCAL_MAX_PARALLEL = 64 +LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main" +REMOTE_SWARM_URL = "http://localhost:10086" # Change to your swarm remote url + +# --------- configurations that take effect remotely ------------- +REMOTE_BATCH_SIZE = 32 +REMOTE_ALLOCATE_GPU_PER_NODE = 4 +REMOTE_TRAIN_MODEL_01 = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct' + + + +class WeightUpdatedHalfway(Exception): + """Raised when the remote side starts updating model weights halfway through an episode.""" + + +def main(): + + # Handshake with swarm remote, then send training param to swarm remote (such as model to be trained, algorithm, etc) + dataset = RouterTaskReader( + reader_type = "huggingface_dat_repo", + reader_config = AjetTaskReader( + huggingface_dat_repo = HuggingfaceDatRepo( + dataset_path = LOCAL_DATASET_PATH + ) + ) + ) + + # # Hand shake with remote swarm server + swarm_remote = SwarmClient(REMOTE_SWARM_URL) + swarm_remote.auto_sync_train_config_and_start_engine( + AgentJetJob( + algorithm="grpo", + n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, + 
model=REMOTE_TRAIN_MODEL_01, + batch_size=REMOTE_BATCH_SIZE, + grpo_n=LOCAL_GRPO_N, + ) + ) + + def rollout(task): + group_reward = [] + try: + for _ in range(LOCAL_GRPO_N): + try: + # begin episode + episode_uuid, api_baseurl_key = swarm_remote.begin_episode() + # execute agent + workflow_output = execute_agent(task, api_baseurl_key) + # report output back to swarm remote + swarm_remote.end_episode(task, episode_uuid, workflow_output) + # collect reward + group_reward.append(workflow_output.reward) + except Exception as e: + logger.exception("Exception during rollout:", e) + + print(f"Group reward mean & std: {sum(group_reward)/len(group_reward)} +/- { (max(group_reward)-min(group_reward))/2 }") + except Exception as e: + logger.exception("Exception during rollout group", e) + + task_batch = [] + for i, task in enumerate(dataset.generate_training_tasks()): + task_batch += [task] + + if len(task_batch) == REMOTE_BATCH_SIZE: + print('*********** beginning a new batch of tasks... ***********') + with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as executor: + for task in task_batch: + executor.submit(rollout, task) + executor.shutdown(wait=True) + task_batch = [] + print('*********** tasks completed, wait a minute... ***********') + time.sleep(3) + + + return None + + + + +@retry_with_backoff(max_retry=2) +def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): + # Prepare base_url, api_key + base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key) + # Read dataset item + query, reference_answer = (task.main_query, task.metadata["answer"]) + # Prepare messages + messages = [ + { "role": "system", "content": dedent("""You are an agent specialized in solving math problems. Please solve the math problem given to you. + You can write and execute Python code to perform calculation or verify your answer. You should return your final answer within \\boxed{{}}.""") }, + { "role": "user", "content": query } + ] + # Use raw http requests (non-streaming) to get response + response = requests.post( f"{base_url}/chat/completions", json = { "model": "fill_whatever_model", "messages": messages, }, + headers = { "Authorization": f"Bearer {api_key}" } ) + final_answer = response.json()['choices'][0]['message']['content'] + # print(final_answer) + # Compute reward + reference_answer = reference_answer.split("####")[-1].strip() + pattern = r"\\boxed\{([^}]*)\}" + match = re.search(pattern, final_answer) + if match: is_success = match.group(1) == reference_answer + else: is_success = False + raw_reward = 1.0 if is_success else 0.0 + # Return + return WorkflowOutput(reward=raw_reward, metadata={"final_answer": final_answer}) + + + + +if __name__ == "__main__": + main() diff --git a/ajet_tinkerscript.md b/ajet_tinkerscript.md new file mode 100644 index 00000000..c256853d --- /dev/null +++ b/ajet_tinkerscript.md @@ -0,0 +1 @@ +python -m ajet.launcher --conf tutorial/demo_tinkerjet/ajet_swarm_default.yaml --backbone="debug" --autokill diff --git a/docs/en/platform_comparison.md b/docs/en/platform_comparison.md index fa263918..04ffe968 100644 --- a/docs/en/platform_comparison.md +++ b/docs/en/platform_comparison.md @@ -5,7 +5,7 @@ - Multi OSS Training Backbone: Support switching between multiple open-source training backbones quickly. - Multi OSS Infer Backbone: Support both vLLM and SGLang. - Low Code Change: Do not require too many edits to convert a user‑defined (multi) agent workflow into trainable workflows. 
-- Without-GPU (Cloud-Computing): Rollout and power RL training in a laptop without GPU, using Tinker (AgentLightning) or without Tinker (AgentJet-TinkerScript, comming soon) +- Without-GPU (Cloud-Computing): Roll out and power RL training on a laptop without a GPU, using Tinker (AgentLightning) or without Tinker (AgentJet-Swarm, coming soon) - Timeline Optimization: Automatically merge shared-history context generated by the same agents to promote training speed. - Open Bench Platform: Trace baseline environment's performance across git history in different training backbones. - Multi-Agent Optimization: Deal with sophisticated multi-agent interaction efficiently, automatically clustering and merging samples generated by the same agents. diff --git a/docs/en/workflow.md b/docs/en/workflow.md index 1137f02c..1647a219 100644 --- a/docs/en/workflow.md +++ b/docs/en/workflow.md @@ -241,7 +241,7 @@ Here's a complete example with multiple agent roles (Werewolves game): - You can flexibly switch training targets by modifying `trainable_targets` -## TinkerJet +## Swarm Wrapping and training your agent on a machine without GPU. diff --git a/docs/index.md b/docs/index.md index c9dc7f3f..ce15a67c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -67,7 +67,7 @@

Any Training Engine

- Support multiple training engines as backbone (VeRL and Trinity-RFT). Tinker backbone support will be released soon. + Support multiple training engines as backbones (VeRL and Trinity-RFT). Swarm backbone support will be released soon. Choose from vLLM and SGLang as you wish. Say goodbye to training engine gaps.

diff --git a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.py b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.py index 762609ce..697f0bad 100644 --- a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.py +++ b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.py @@ -27,7 +27,7 @@ def __init__(self): # step : [low, high ] 50 : [2.3, 99999.0], 100 : [2.5, 99999.0], - 200 : [2.9, 99999.0], + 200 : [2.6, 99999.0], } # fmt: on self.probe_list = ["reward_probe"] diff --git a/tinkerscript.md b/tinkerscript.md new file mode 100644 index 00000000..05dfcce3 --- /dev/null +++ b/tinkerscript.md @@ -0,0 +1,155 @@ +# Swarm Design Blueprint / Swarm 设计蓝图 + +[English](#english-version) | [中文](#chinese-version) + +--- + + +## 🇬🇧 English Version + +### 1. Overview +**Swarm** is an experimental component of AgentJet designed to decouple the **Training Logic** from the **Agent Execution Logic**. It allows users to train **full-weight LLM models** on machines without GPUs (e.g., a laptop) by offloading the actual model computation to a remote GPU server. + +Unlike traditional setups where the user code must run inside the training cluster, Swarm allows you to verify and run your agent logic locally while the heavy lifting (training & inference) happens remotely. + + +> +> Relationship between **Swarm** and **Tinker**: +> +> **No relationship at all** (just like **JavaScript** and **Java**). **Swarm** is open-source and free. **Tinker** is close-source and not free. + + +## Tinker 与 AgentJet-Swarm 对比表 + +| 特征 | Tinker | AgentJet-Swarm | +|------|--------|--------------| +| **开源性质** | ❌ 闭源 | **✅ 开源免费** | +| **收费模式** | 付费服务 | **✅ 完全免费** | +| **目标用户** | 研究人员和开发者 | 研究人员和开发者 | +| **任务** | 各种 LLM 训练 | 专精 LLM Agent RL训练 | +| **核心功能** | LLM 微调训练 API | **✅ LLM 微调训练整套解决方案** | +| **架构模式** | 托管服务 + 单点客户端 API | **✅ 服务器和客户端都可按需拓展** | +| **多客户端共同参与训练** | ❌ 不支持 | **✅ 支持** | +| **远程算力部署方式** | Thinking Machines Lab 公司提供定价 | **✅ 自建 GPU 服务器端 或 使用阿里云灵骏** | +| **训练方式** | ❌ LoRA 微调 | **✅ 全量 LLM 模型训练** | +| **支持的模型** | ❌ 少部分 LLM 模型 | **✅ 大多数新旧 LLM 模型** | +| **最大模型规模** | Llama 70B、Qwen 235B | **✅ 取决于用户 GPU 集群配置** | +| **通信协议** | 专有 API | **✅ 专有API + OpenAI兼容API** | +| **推理引擎后端** | 内置未知推理服务 | **✅ vLLM/SGLang任选** | + + + +### 2. Core Architecture + +The system involves two main parties: the **Swarm Server** (running on the GPU cluster) and the **Swarm Client** (running on your local machine). + +```mermaid +graph TD + subgraph "GPU Cluster (Server Side)" + TrainingLoop["Training Loop (AgentJet/GRPO)"] + TSS["Swarm Server (FastAPI)"] + ZMQ["ZeroMQ / IPC"] + SharedMem[("Shared Memory")] + LLM["LLM Engine (vLLM/SGLang)"] + end + + subgraph "User Laptop / CPU Cluster (Client Side)" + UserScript["User Script (python while loop)"] + AgentLogic["Agent Logic / Tools"] + end + + TrainingLoop -- "1. Generate Task" --> SharedMem + SharedMem -- "2. Register Episode" --> TSS + + UserScript -- "3. Claim Episode (HTTP)" --> TSS + TSS -- "4. Return API Key & Base URL" --> UserScript + + UserScript -- "5. Inference (OpenAI API)" --> LLM + LLM -- "Token Stream" --> UserScript + + UserScript -- "6. Submit Reward (HTTP)" --> TSS + TSS -- "7. Push Result" --> ZMQ + ZMQ -- "8. Update Weights" --> TrainingLoop +``` + +### 3. Detailed Workflow + +The workflow relies on a "Claim & Submit" model. The training loop generates tasks ("Episodes") and waits for external workers to pick them up. 
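+The sequence diagram below walks through one full exchange. As a rough reference, a single claim-and-submit round trip over the raw HTTP API might look like the sketch below (illustrative only: the server URL, model name, task id and reward rule are placeholders; the field names follow the request/response models in `interchange_utils.py`, and the sketch assumes `WorkflowOutput` serializes to at least `reward` and `metadata`; the `begin_episode()` / `end_episode()` helpers in section 5 wrap calls of this shape):
+
+```python
+# Minimal claim-and-submit sketch against the raw HTTP API (illustrative only).
+import uuid
+import httpx
+
+SERVER = "http://localhost:10086"          # placeholder: your Swarm server URL
+client_uuid = str(uuid.uuid4())
+
+# 1. Claim an unclaimed episode (retry later if `success` is False).
+claim = httpx.post(f"{SERVER}/claim_episode", json={
+    "client_uuid": client_uuid,
+    "episode_type": "train",
+    "allow_discard_timeout": 3600.0,
+}, timeout=10).json()
+
+if claim["success"]:
+    # 2. Run the agent against the per-episode OpenAI-compatible endpoint.
+    chat = httpx.post(
+        f"{claim['openai_base_url']}/chat/completions",
+        headers={"Authorization": f"Bearer {claim['openai_api_key']}"},
+        json={"model": "agentjet-model",   # placeholder model name
+              "messages": [{"role": "user", "content": "1 + 1 = ?"}]},
+        timeout=120,
+    ).json()
+    answer = chat["choices"][0]["message"]["content"]
+
+    # 3. Compute a reward locally and submit it back to the training loop.
+    task_id = "demo-task-0"                # placeholder task id
+    httpx.post(f"{SERVER}/end_episode", json={
+        "client_uuid": client_uuid,
+        "episode_uuid": claim["episode_uuid"],
+        "task_id": task_id,
+        "workflow_output": {"reward": 1.0 if "2" in answer else 0.0,
+                            "metadata": {"task_id": task_id}},
+    }, timeout=30)
+```
+
+In practice the loop should also handle the `fail_cause` hints returned by `/claim_episode` (for example when the engine is booting or syncing weights) and simply retry after a short wait.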
+ +```mermaid +sequenceDiagram + participant TL as Training Loop (Internal) + participant S as Server (FastAPI) + participant C as Client (User Script) + participant M as LLM Model + + Note over TL, S: 1. Task Generation + TL->>S: Register Episode (Status: Unclaimed) + + Note over C, S: 2. Task Acquisition + loop Worker Loop + C->>S: POST /claim_episode + alt No Tasks + S-->>C: Retry Later + else Task Available + S->>S: Mark as "Claimed" + S-->>C: Return {EpisodeID, OpenAI_BaseURL, API_Key} + end + + Note over C, M: 3. Execution (Rollout) + C->>M: Chat Completion Request (Inference) + M-->>C: Response (Generation) + C->>C: Calculate Reward (e.g., Verify Math Answer) + + Note over C, S: 4. Result Submission + C->>S: POST /end_episode {Reward, Metadata} + S->>TL: Forward Result via ZeroMQ + S->>S: Delete Episode Record (Complete) + end +``` + +### 4. Episode State Machine + +To handle network failures or client crashes, the server maintains a state machine for every episode. + +```mermaid +stateDiagram-v2 + [*] --> Registered + Registered --> Unclaimed_Queue : Add to Queue + + Unclaimed_Queue --> Claimed : Client requests task + + Claimed --> Completed : Client submits result + Claimed --> Registered : Client Timeout / Crash + + Completed --> [*] : Removed from Memory +``` + +* **Registered**: Task created by the training algorithm. +* **Claimed**: A client is currently working on it. +* **Timeout**: If a client claims a task but doesn't report back within `allow_discard_timeout`, the server reverts the status to **Registered** so another client can try. + +### 5. Implementation Example + +The user experience is designed to be minimal. You simply query the remote server for a "job", do the work, and report the "score". + +```python +# User-side Code Concept +def rollout(task): + # 1. Handshake & Claim (Get credentials for this specific episode) + api_baseurl_key = tinkerjet_remote.begin_episode() + + # 2. Run your existing agent logic using standard OpenAI format + workflow_output = execute_agent(task, api_baseurl_key) + + # 3. Submit results + tinkerjet_remote.end_episode(workflow_output) + return workflow_output.reward +``` + + + diff --git a/tinkerscript_1.md b/tinkerscript_1.md new file mode 100644 index 00000000..c15df6db --- /dev/null +++ b/tinkerscript_1.md @@ -0,0 +1,120 @@ +# Swarm Design Blueprint + +Swarm represents a client-server architecture designed to decouple the **Training Loop** (Server-side) from the **Rollout Execution** (Client-side). This allows for distributed, flexible, and potentially remote execution of agent rollouts (inference + reward calculation) while centralizing the model training and weight updates. + +## 1. System Architecture + +The system consists of three main components: + +### A. Swarm Server (The Trainer) +* **Role**: Manages the training lifecycle, generates tasks (episodes), serves the model (LLM) API, and updates model weights. +* **Technology**: Python, FastAPI, ZeroMQ (IPC/TCP), Shared Memory (Multiprocessing). +* **Location**: Runs on the GPU cluster/Training node. +* **Key Functionality**: + * Maintains a queue of "Episodes" (training tasks). + * Exposes an HTTP API for external clients to claim tasks and submit results. + * Acts as a bridge between the HTTP world and the internal ZeroMQ-based training pipeline. + +### B. Swarm Client (The User Script) +* **Role**: Fetches tasks, runs the agent logic, computes rewards, and reports back. +* **Technology**: Python (Requests/HTTPX). 
+* **Location**: Can run locally, on a separate CPU cluster, or even a different cloud environment. +* **Key Functionality**: + * Connects to the Server URL. + * Claims episodes via `begin_episode()`. + * Executes the agent logic (e.g., calling the LLM, running Python code). + * Calculates rewards (e.g., verifying math answers). + * Submits results via `end_episode()`. + +### C. The LLM Serving Layer (Implicit) +* The system provides an OpenAI-compatible API endpoint (`base_url`, `api_key`) to the client for LLM inference. This is likely hosted by the training system itself or a proxy, enabling the client to query the model being trained. + +--- + +## 2. Detailed Workflow + +### Step 1: Episode Generation & Registration (Server Side) +The training loop (e.g., RL algorithm like GRPO) generates a new task. +1. An internal component registers a new episode via `register_episode`. +2. The server stores this in `shared_mem_dict` with status `registered`. +3. The episode is added to the `unclaimed_episodes` queue. +4. The server sets up a ZeroMQ socket to listen for the result of this specific episode. + +### Step 2: Task Claiming (Client Side) +The user's script calls `tinkerjet_remote.begin_episode()`. +1. **Request**: `POST /claim_episode` +2. **Server Logic**: + * Checks `unclaimed_episodes`. + * If available, pops one episode. + * Updates status to `claimed`. + * Records `client_uuid` and `latest_activity_timestamp`. +3. **Response**: Returns `episode_uuid` and **OpenAI Credentials** (Base URL + API Key) specific to this session/model. + +### Step 3: Rollout & Execution (Client Side) +The user's script (`execute_agent`) runs: +1. Uses the provided OpenAI API to chat with the model (performing the actual inference step of the RL loop). +2. Parses the model's output. +3. Computes a reward (e.g., checking if `\boxed{answer}` matches ground truth). + +### Step 4: Result Submission (Client Side) +The user's script calls `tinkerjet_remote.end_episode()`. +1. **Request**: `POST /end_episode` with `workflow_output` (Reward + Metadata). +2. **Server Logic**: + * Validates the episode exists and is claimed by this client. + * Connects to the internal ZeroMQ socket associated with this episode. + * Forwards the `workflow_output` payload into the ZeroMQ socket, effectively pushing it back into the training loop. + * Waits for an acknowledgment. + * Deletes the episode record from memory upon success. + +### Step 5: Failure Recovery & Timeouts +* **Crash Recovery**: If a client crashes after claiming a task, the server tracks `latest_activity_timestamp`. +* **Requisition**: A background check (`find_claimed_episodes_that_need_to_be_unclaimed`) reverts "stale" claimed episodes back to `registered` status so other clients can pick them up. +* **Weight Updates**: If the server moves to a weight update phase, it might verify if an episode is still valid via `can_continue_episode`. + +--- + +## 3. Data Structures & API Design + +### Episode Status Object +Stored in Server Shared Memory: +```python +class EpisodeStatus: + episode_uuid: str # Unique ID for the task + client_uuid: str # ID of the worker claiming it + episode_status: str # "registered", "claimed" + openai_base_url: str # Endpoint for the model + openai_api_key: str # Auth for the model + zmq_listen_result_addr: str # Internal address to forward results to + latest_activity_timestamp: float +``` + +### API Endpoints + +| Method | Endpoint | Description | +| :--- | :--- | :--- | +| `POST` | `/claim_episode` | Worker requests a job. 
Returns UUID + LLM credentials. | +| `POST` | `/end_episode` | Worker submits results (Reward). Completes the cycle. | +| `POST` | `/can_continue_episode` | Checks if the episode is still valid (e.g., weights haven't changed). | +| `POST` | `/register_episode` | (Internal/Debug) Adds a task to the queue. | +| `GET` | `/get_engine_status` | Returns system health/state (e.g., "booting", "ready"). | +| `POST` | `/sync_train_config` | Syncs configuration yaml (logging/debug). | + +--- + +## 4. Key Configurations + +From `ajet_swarm_default.yaml`, we see how this mode is activated: + +```yaml +experiment_dir: "auto" +enable_swarm_mode: True # Activates the HTTP API Server +interchange_server: + interchange_method: 'ipc' # Internal communication (ZeroMQ) + interchange_server_port: 10086 # HTTP API Port +``` + +## 5. Benefits of this Design + +1. **Flexibility**: Users can write custom python logic for "Rollout" without modifying the core C++/Python training engine. +2. **Distributed Generation**: You can have 1 training node and 1000 cheap CPU nodes just running the python script to generate data. +3. **Complex Logic Support**: Since the rollout is just a client script, it can call external tools, Sandboxed code interpreters, or APIs (Google Search) easily before calculating the reward. diff --git a/tutorial/demo_tinkerjet/README.md b/tutorial/demo_tinkerjet/README.md new file mode 100644 index 00000000..d7e1ad20 --- /dev/null +++ b/tutorial/demo_tinkerjet/README.md @@ -0,0 +1,64 @@ +# Swarm + + +Swarm is an experimental component of AgentJet, +allowing users to +- run, debug and train **full-weight** LLM model behind user-defined LLM workflows in **machines without GPU**. + +Similar to Tinker & Open-Tinker, the basic idea behind Swarm is to: +- use remote (or cloud) GPU machine(s) as computation media. + +However, Swarm goes even further on this path: + +- Users only need to write and run their agents in a big `while` loop (e.g., in their laptop), and provide samples generated in this process. + +- Swarm will take care of everything else. + +- Swarm trains **full-weight** LLM model instead of lora. + +- Upon the termination of the training session, user can call `download_tuned_model` to download tuned LLM(s). + + +# Core Training Code + +The core code at user-side is as simple as: + +```python + +# step 1: ... write user-defined `execute_agent` +# step 2: ... init `tinkerjet_remote` to handshake with remote GPU server +# step 3: ... define hyper-parameters `NUM_EPOCH`, `GRPO_N` +# step 4: ... spawn `dataset` from dataset file + +# step 5: rock & roll +## rollout +def rollout(task): + try: + api_baseurl_key = tinkerjet_remote.begin_episode() + workflow_output = execute_agent(task, api_baseurl_key) + tinkerjet_remote.end_episode(workflow_output) + return workflow_output.reward + except Exception as e: + print(f"Episode abandoned") + return 0.0 +## Main Training loop +for epoch in range(NUM_EPOCH): + for task in dataset.get_training_tasks(): + for i in range(GRPO_N): + reward = rollout(task) + print(f"{epoch}-{task}-run:{i}-{reward}") + +# step 6: get trained model and shutdown +tuned_model_checkpoint = tinkerjet_remote.download_tuned_model() +tinkerjet_remote.close() + +``` + +# Limitation + +- Users are only limited to use OpenAI `baseurl` + `apikey` to build applications. Features such as `tuner.as_agentscope_model` is no longer available. + +- AgentJet are not able to explicitly distinguish different agents in multi-agent scenario. 
+ But **do not worry**, AgentJet will still try its best to recognize shards of llm timelines and merge them behind the curtain, automatically. + +- Swarm does not support prompt tuning. diff --git a/tutorial/demo_tinkerjet/ajet_tinkerscript_default.yaml b/tutorial/demo_tinkerjet/ajet_tinkerscript_default.yaml new file mode 100644 index 00000000..c0baa7f4 --- /dev/null +++ b/tutorial/demo_tinkerjet/ajet_tinkerscript_default.yaml @@ -0,0 +1,49 @@ +# ------------------ main configuration ------------------ +ajet: + project_name: "ajet_default_project" + experiment_name: "read_yaml_name" + experiment_dir: "auto" # {exp-dir}/{experiment_name} + backbone: debug # `debug` or `trinity` or `verl` + + model: + # which model should be trained + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + + rollout: + # the path to the workflow class + user_workflow: null + + task_reader: + type: random_dummy # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` + + task_judge: + judge_type: customized_protocol # Options: 'customized_protocol', 'rubrics_auto_grader' + judge_protocol: null + + # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature + enable_experimental_interchange_server: True + # train in cloud, run episode locally + enable_swarm_mode: True + # both swarm / oai share the same interchange server + interchange_server: + interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + interchange_server_port: 10086 + num_fastapi_process: 2 # 1, 2 or 4 is fine + max_fastapi_threads: 128 # 64 or 128 is fine + max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` + + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/demo_tinkerjet/demo_tinkerjet_math.py b/tutorial/demo_tinkerjet/demo_tinkerjet_math.py new file mode 100644 index 00000000..e69de29b diff --git a/tutorial/example_academic_trans/trans.py b/tutorial/example_academic_trans/trans.py new file mode 100644 index 00000000..49120d41 --- /dev/null +++ b/tutorial/example_academic_trans/trans.py @@ -0,0 +1,167 @@ + +import re +import os +import time +import asyncio +import threading +from loguru import logger +from textwrap import dedent +from openai import OpenAI + +from ajet import WorkflowOutput +from ajet.schema.task import Task +from ajet.copilot.job import AgentJetJob +from ajet.task_reader import RouterTaskReader +from ajet.utils.retry import retry_with_backoff +from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from beast_logger import print_listofdict + +# Import reward computation from trans_reward.py +from openjudge.models import OpenAIChatModel +from .trans_reward import TranslationQualityGrader, build_translation_quality_messages, examples + + + +@retry_with_backoff(max_retry=3) +def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): + # Prepare base_url, api_key + base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key) + grader_base_url, grader_api_key = 
("https://dashscope.aliyuncs.com/compatible-mode/v1", os.environ.get("DASHSCOPE_API_KEY", "")) + # Read dataset item + title = task.metadata['title'] + authors = task.metadata['authors'] + abstract = task.metadata['abstract'] + + messages, rough_translate = rough_translate_agent(base_url, api_key, abstract) + # print_listofdict(messages, header="rough_translate_agent", mod="c") + + # messages, fix_nouns = detect_hard_proper_nouns(messages, base_url, api_key, abstract, rough_translate) + messages, fix_nouns = detect_hard_proper_nouns(messages, grader_base_url, grader_api_key, abstract, rough_translate) + # print_listofdict(messages, header="detect_hard_proper_nouns", mod="c") + + messages, final_translation = produce_final_translation(messages, base_url, api_key, abstract, rough_translate, fix_nouns) + print_listofdict(messages, header="final_translation", mod="c") + + if final_translation is None: + raw_reward = 0.0 + else: + grader = TranslationQualityGrader( + model=OpenAIChatModel(base_url=grader_base_url, api_key=grader_api_key, model="qwen3-max-2026-01-23") + ) + grader_score = asyncio.run(asyncio.wait_for(grader.aevaluate(original_text=abstract, translation=final_translation), timeout=120)) + raw_reward = grader_score.score + print(f"Grader Score: {grader_score.score}, Reason: {grader_score.reason}, Metadata: {grader_score.metadata}") + return WorkflowOutput(reward=raw_reward, metadata={ + "rough_translate": rough_translate, + "fix_nouns": fix_nouns, + "final_translation": final_translation + }) + + +def produce_final_translation(messages, base_url, api_key, abstract, rough_translate, fix_nouns): + messages = messages + [ + { + "role": "user", + "content": "Please produce the final, corrected Chinese translation by applying all the corrections listed above. " + "Output only the final translation between ... , so I will extract result with regex." + }, + ] + + client = OpenAI(base_url=base_url, api_key=api_key) + response = client.chat.completions.create( + model="agentjet-model", + messages=messages + ) + final_translation = response.choices[0].message.content + + messages += [ + { + "role": "assistant", + "content": final_translation + } + ] + + # Extract final translation + match = re.search(r"(.*?)", final_translation, re.DOTALL) + if match: + final_translation = match.group(1).strip() + else: + final_translation = None + + return messages, final_translation + + + +def detect_hard_proper_nouns(messages, base_url, api_key, abstract, rough_translate): + messages = messages + [ + + { + "role": "user", + "content": "You new job is to detect translation errors of discipline-specific proper nouns. " + "Use json to list all errors found in the translation result and provide correction. " + "Json format: [{\"original_word\": \"xxx\", \"wrong_translation\": \"xxx\", \"wrong_reason\": \"xxx\", \"correct_translation\": \"xxx\"}, ...]. " + "If no errors are found, return an empty list []." + "Please list all translation errors of discipline-specific proper nouns found in the translation result according to the requirements." 
+ }, + + ] + + client = OpenAI(base_url=base_url, api_key=api_key) + response = client.chat.completions.create( + model="qwen3-max-2026-01-23", + messages=messages, + timeout=60, + # extra_body={"enable_thinking":True} + ) + fix_nouns = response.choices[0].message.content + messages += [ + { + "role": "assistant", + "content": fix_nouns + } + ] + return messages, fix_nouns + + +def rough_translate_agent(base_url, api_key, abstract): + messages = [ + { + "role": "system", + "content": + "You are a professional language translator. " + "Translate the given Academic English text into Chinese accurately. " + "During the translation process, it is necessary to meet the linguistic norms of Chinese academic papers " + "such as conforming to the logic of the Chinese language, being simple, rigorous, and concise, " + "and avoiding the use of first-person pronouns when passive voice is appropriate. " + "Ensure that specialized terms are translated correctly according to academic standards. " + "Replace 我/我们 with 本研究 or 本文 or 研究者 or simply remove it and rephrase the sentence. " + "If an English abbreviation is short in Chinese, use Chinese. " + "If an English abbreviation is long in Chinese, use English abbreviation. " + "To use an English abbreviation, if the author has mentioned the full form first, mention the full form at its first appearance. " + "e.g. `We have used the LAsMA heterodyne array installed on the Atacama Pathfinder EXperiment (APEX)` should be translated as " + "`本研究使用了安装在阿塔卡马探路者实验望远镜(APEX, Atacama Pathfinder EXperiment)上的LAsMA外差阵列`. " + }, + { + "role": "user", + "content": abstract + } + ] + + for ex in examples: + messages[0]['content'] += f"\n\nExample:\n\tOriginal: {ex['original']}\n\tBad Translation: {ex['bad']}\n\tHint: {ex['hint']}\n\tGood Translation: {ex['good']}" + + client = OpenAI(base_url=base_url, api_key=api_key) + response = client.chat.completions.create( + model="agentjet-model", + messages=messages + ) + rough_translate = response.choices[0].message.content + messages += [ + { + "role": "assistant", + "content": rough_translate + } + ] + + return messages, rough_translate diff --git a/tutorial/example_academic_trans/trans_reward.py b/tutorial/example_academic_trans/trans_reward.py new file mode 100644 index 00000000..663d7107 --- /dev/null +++ b/tutorial/example_academic_trans/trans_reward.py @@ -0,0 +1,193 @@ +import re +from openjudge.graders.base_grader import GraderError, GraderMode, GraderScore +from openjudge.graders.llm_grader import LLMGrader +from openjudge.models.base_chat_model import BaseChatModel +from typing import List +from textwrap import dedent +from beast_logger import print_listofdict + + +examples = [ + { + "original": "We find that the EMBB is dominated by GW bursts from stellar mass black holes", + "bad": "我们发现,EMBB主要由恒星级黑洞发出的GWs爆发主导", + "hint": "1) 我们->本研究/本文(删除第一人称) 2) GWs->引力波(有简洁的中文表达),但EMBB保留(没有简洁的中文表达) 3. 
调换语序,这句话中的重点是“恒星级黑洞发出的引力波”,所以调换语序突出重点。", + "good": "本研究发现恒星级黑洞发出的引力波爆发在EMBB中占主导地位" + }, + { + "original": "In a previous paper (Gayon & Bois 2008a), we have shown the general efficiency of retrograde resonances for stabilizing compact planetary systems.", + "bad": "在先前的一篇论文(Gayon & Bois 2008a)中,本文展示了逆向共振在稳定紧凑行星系统中的普遍效率。", + "hint": "修复主语,删除冗余的逗号,替换“效率”为“有效性”更符合学术表达。", + "good": "先前的一篇论文(Gayon & Bois 2008a)阐释了逆向共振在稳定紧凑行星系统中的普遍有效性。" + }, + { + "original": "To improve the transferability of ViT, we introduce a novel and effective module, named Domain Transferable-guided Attention Block (DTAB).", + "bad": "为了提高ViT的迁移能力,本文引入了一个新颖且有效的模块,称为域可迁移引导注意力块(DTAB)", + "hint": "1)语言顺序和表达不符合中文习惯 2)没有在首次出现自定义缩写时,给出英文全称", + "good": "为提高ViT的迁移能力,本文引入了名为“域可迁移引导注意力块”(Domain Transferable-guided Attention Block,DTAB)的新颖且有效的模块。" + }, + { + "original": "Extensive experiments were conducted on UCF-HMDB, Kinetics-Gameplay, and Kinetics-NEC Drone datasets, with different backbones, like ResNet101, I3D, and STAM, to verify the effectiveness of TransferAttn compared with state-of-the-art approaches.", + "bad": "在UCF-HMDB、Kinetics-Gameplay和Kinetics-NEC Drone数据集上进行了广泛的实验,使用了不同的骨干网络,如ResNet101、I3D和STAM,以验证TransferAttn与现有最先进方法相比的有效性。", + "hint": "1)改变语言顺序后,主语缺失 2)举例时,表述不够简洁", + "good": "本研究在UCF-HMDB、Kinetics-Gameplay和Kinetics-NEC Drone数据集上进行了广泛的实验,使用了ResNet101、I3D和STAM等骨干网络来验证TransferAttn与现有最先进方法相比的有效性。" + } +] + + +examples_eval = examples + [ + +] + + + +TRANSLATION_QUALITY_USER_PROMPT = """ +Evaluate the quality of this Chinese translation based on the specific error types demonstrated in the examples. + +Original English text: +{original} + +Chinese translation to evaluate: +{translation} +""" + + + +def get_translation_quality_system_prompt() -> str: + """Get the translation quality system prompt.""" + examples_text = "" + for i, ex in enumerate(examples_eval, 1): + examples_text += dedent(f""" + Example {i}: + - Original: "{ex['original']}" + - Bad Translation: "{ex['bad']}" + - Issues: {ex['hint']} + - Good Translation: "{ex['good']}" + """) + + + return dedent(""" + You are an objective translation quality evaluator for academic paper translations from English to Chinese. Your task is to identify ONLY the specific types of errors demonstrated in the provided examples - not general translation quality issues. + + 重点关注(但不限于)以下问题类型(如示例所示): + + 1. **错误使用第一人称代词** - 禁止使用"我们"。正确的方法是使用"本研究"、"本文"、“研究者”,或者直接删除we并改写句子替换主语。不要漏掉出现的任何第一人称代词。 + 2. **缩写翻译错误** - 当存在简洁的中文表达时使用缩写(例如,使用"GWs"而非"引力波"),或翻译本应保留英文的缩写(如"EMBB") + 3. **语序问题** - 未调整句子结构以符合中文学术风格强调重点的习惯 + 4. **主谓不一致、主语缺失** - 由于句子结构不当导致主语混乱(例如,"在...中,本文展示..."中主语混淆) + 5. **用词不当** - 使用口语化或不正确的术语而非恰当的学术表达 + 6. **多余标点和停顿** - 不必要的逗号或其他标点符号影响中文阅读流畅性 + 7. **主语不清晰** - 中文句子主语缺失或不明确。例如:“通过该实验,证明了该药物对癌细胞有抑制作用”(缺少主语) + 8. **缩写问题** - 首次出现自定义缩写、且原文中已经提供自定义缩写的英文全称时,没有在首次出现的地方提供英文全称。 + (正确的例子:`We have used the LAsMA heterodyne array installed on the Atacama Pathfinder EXperiment (APEX)`->`本研究使用了安装在阿塔卡马探路者实验望远镜(APEX, Atacama Pathfinder EXperiment)上的LAsMA外差阵列`) + 9. **专有名词翻译错误** - 领域特定的专有名词翻译错误,例如技术术语、学科术语等。如错把Agent翻译成“代理”(实际上应为“智能体”)等。 + 10. 
**表意偏差** - 翻译结果与原文在意义上存在偏差,导致信息传达不准确。
+
+        **Examples of these errors:**
+        [[examples_text]]
+        Rate the translation on a scale of 0-2:
+
+        0 = Severely impairs readability (multiple critical errors from the categories above that make the text difficult to understand)
+        1 = Contains errors that reduce Chinese reading efficiency (many instances of the error types above)
+        2 = No errors from the example categories detected (translation is free of the specific error types demonstrated)
+
+        Note:
+        * For each key issue found, provide the specific error, its type, and where it appears in the translation.
+        * Be precise about which error category each issue belongs to.
+        * Focus on objective errors matching the example patterns, not subjective preferences.
+        * 当出现 **语序问题**、**主谓不一致、主语缺失**、**主语不清晰**、**专有名词翻译错误**、**表意偏差** 等严重问题时,直接给 0 分。
+        * 逐句分析,切勿遗漏。
+
+        Think carefully before flagging any error. Ask yourself: Does this match one of the specific error types from the examples? Is this truly an objective error or just a stylistic preference?
+
+        Return your response in this format:
+
+        <analysis>
+        Your analysis
+        </analysis>
+
+        <key_issues>
+        - Error Type: [category]. Error: [specific issue]. Location: [where it appears in the translation]
+        </key_issues>

+        <score>
+        X
+        </score>
+
+        The score must be 0, 1, or 2. Each key issue should be on its own line starting with a dash. If no errors are found, the key_issues section should be empty or state "None detected".
+    """.replace("[[examples_text]]", examples_text))
+
+
+
+def parse_translation_quality_response(text: str) -> dict:
+    """Parse XML-formatted translation quality response."""
+    # The tag names must match the response format requested in the system prompt above.
+    score_match = re.search(r"<score>\s*(\d+)\s*</score>", text)
+    reasoning_match = re.search(r"<analysis>(.*?)</analysis>", text, re.DOTALL)
+    issues_match = re.search(r"<key_issues>(.*?)</key_issues>", text, re.DOTALL)
+
+    score = int(score_match.group(1)) if score_match else 0
+    reasoning = reasoning_match.group(1).strip() if reasoning_match else text
+
+    key_issues = []
+    if issues_match:
+        issues_text = issues_match.group(1)
+        # Filter out empty lines and "None detected" type messages
+        key_issues = [
+            line.strip().lstrip("- ")
+            for line in issues_text.strip().split("\n")
+            if line.strip() and not line.strip().lstrip("- ").lower().startswith("none")
+        ]
+
+    return {"score": score, "reason": reasoning, "key_issues": key_issues}
+
+
+def build_translation_quality_messages(original_text: str, translation: str) -> List[dict]:
+    return [
+        {
+            "role": "system",
+            "content": get_translation_quality_system_prompt()
+        },
+        {
+            "role": "user",
+            "content": TRANSLATION_QUALITY_USER_PROMPT.format(
+                original=original_text,
+                translation=translation
+            ),
+        },
+    ]
+
+
+class TranslationQualityGrader(LLMGrader):
+    def __init__(self, model: BaseChatModel | dict):
+        super().__init__(
+            name="translation_quality",
+            mode=GraderMode.POINTWISE,
+            description="Evaluate translation quality based on specific error patterns",
+            model=model,
+            template="",  # Placeholder, not used
+        )
+
+    async def aevaluate(self, original_text: str, translation: str, normalize=True) -> GraderScore:
+        try:
+            messages = build_translation_quality_messages(original_text, translation)
+            response = await self.model.achat(messages=messages)
+            content = await extract_response_content(response)
+            parsed = parse_translation_quality_response(content)
+
+            if normalize:
+                parsed["score"] = parsed["score"] / 2.0
+
+            return GraderScore(
+                name=self.name,
+                score=parsed["score"],
+                reason=parsed["reason"],
+                metadata={"key_issues": parsed["key_issues"]},
+            )
+        except Exception as e:
+            return GraderError(name=self.name, error=str(e))
+
+
+async def extract_response_content(response) -> str:
+    if hasattr(response, 'content'):
+        return response.content
+    elif isinstance(response, dict) and 'content' in response:
+        return response['content']
+    elif isinstance(response, str):
+        return response
+    else:
+        raise ValueError(f"Unable to extract content from response: {type(response)}")
diff --git a/tutorial/example_academic_trans/trans_roll.py b/tutorial/example_academic_trans/trans_roll.py
new file mode 100644
index 00000000..4630218d
--- /dev/null
+++ b/tutorial/example_academic_trans/trans_roll.py
@@ -0,0 +1,100 @@
+import re
+import threading
+import requests
+import time
+from loguru import logger
+from textwrap import dedent
+from ajet.copilot.job import AgentJetJob
+from ajet.tuner_lib.weight_tuner.experimental.as_swarm_client import SwarmClient
+from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo
+from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey
+from ajet.utils.thread_executors import BoundedThreadPoolExecutor
+from ajet.schema.task import Task
+from ajet.task_reader import RouterTaskReader
+from ajet.utils.retry import retry_with_backoff
+from concurrent.futures import ThreadPoolExecutor
+from tutorial.example_academic_trans.trans import execute_agent
+
+# python -m tutorial.example_academic_trans.trans_roll
+
+
+# --------- configurations that take effect locally -------------
+LOCAL_GRPO_N = 4  # grpo group size
+LOCAL_NUM_EPOCH = 1
+LOCAL_MAX_PARALLEL = 32
+LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/agentjet/agentjet/tmp/arxiv_papers/train.parquet"
+REMOTE_SWARM_URL = "http://localhost:10086"  # Change to your swarm remote url
+
+# --------- configurations that take effect remotely -------------
+REMOTE_BATCH_SIZE = 32
+REMOTE_ALLOCATE_GPU_PER_NODE = 8
+REMOTE_TRAIN_MODEL_01 = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct'

+class WeightUpdatedHalfway(Exception):
+    """Raised when the remote side starts updating model weights halfway through an episode."""
+
+
+def main():
+
+    # Build the local task reader
+    dataset = RouterTaskReader(
+        reader_type = "huggingface_dat_repo",
+        reader_config = AjetTaskReader(
+            huggingface_dat_repo = HuggingfaceDatRepo(
+                dataset_path = LOCAL_DATASET_PATH
+            )
+        )
+    )
+
+    # Handshake with the remote swarm server, then send the training params (model to be trained, algorithm, etc.)
+    swarm_remote = SwarmClient(REMOTE_SWARM_URL)
+    swarm_remote.auto_sync_train_config_and_start_engine(
+        AgentJetJob(
+            algorithm="grpo",
+            n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE,
+            model=REMOTE_TRAIN_MODEL_01,
+            batch_size=REMOTE_BATCH_SIZE,
+            grpo_n=LOCAL_GRPO_N,
+        ),
+    )
+
+    def rollout(task):
+        group_reward = []
+        try:
+            for _ in range(LOCAL_GRPO_N):
+                try:
+                    # begin episode
+                    episode_uuid, api_baseurl_key = swarm_remote.begin_episode()
+                    # execute agent
+                    workflow_output = execute_agent(task, api_baseurl_key)
+                    # report output back to swarm remote
+                    swarm_remote.end_episode(task, episode_uuid, workflow_output)
+                    # collect reward
+                    group_reward.append(workflow_output.reward)
+                except Exception:
+                    logger.exception("Exception during rollout")
+
+            if group_reward:
+                mean_reward = sum(group_reward) / len(group_reward)
+                half_range = (max(group_reward) - min(group_reward)) / 2
+                print(f"Group reward mean +/- half-range: {mean_reward} +/- {half_range}")
+        except Exception:
+            logger.exception("Exception during rollout group")
+
+    task_batch = []
+    executor = BoundedThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL, max_queue_size=LOCAL_MAX_PARALLEL*2)
+    for i, task in enumerate(dataset.generate_training_tasks()):
+        task_batch += [task]
+
+        if len(task_batch) == REMOTE_BATCH_SIZE:
+            print('*********** beginning a new batch of tasks... ***********')
+            for task in task_batch:
+                executor.submit(rollout, task)
+            task_batch = []
+
+    executor.shutdown(wait=True)
+    return None
+
+
+
+
+if __name__ == "__main__":
+    main()
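A quick sketch of how the grader reply defined in `trans_reward.py` above is parsed and turned into a reward. The reply text below is fabricated for illustration only, and the `<analysis>` / `<key_issues>` / `<score>` tags simply mirror the format requested by the system prompt; running it assumes the repo's dependencies (openjudge, beast_logger) are installed and that the tutorial package is importable from the repo root, as in `python -m tutorial.example_academic_trans.trans_roll`.

```python
# Minimal sanity check of the response parsing used by TranslationQualityGrader.
# The sample reply is fabricated; only its tag structure matters here.
from tutorial.example_academic_trans.trans_reward import parse_translation_quality_response

sample_reply = """
<analysis>
The translation keeps a first-person pronoun and mistranslates one proper noun.
</analysis>

<key_issues>
- Error Type: first-person pronoun. Error: "我们" kept. Location: first sentence.
- Error Type: proper noun. Error: "Agent" rendered as "代理". Location: second sentence.
</key_issues>

<score>
1
</score>
"""

parsed = parse_translation_quality_response(sample_reply)
print(parsed["score"])        # 1 (raw 0-2 scale)
print(parsed["key_issues"])   # two issue strings, leading dashes stripped
print(parsed["score"] / 2.0)  # 0.5, the normalized reward returned by aevaluate(normalize=True)
```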
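The rollout loop in `trans_roll.py` prints a per-group reward spread; the half-range shown there is only a rough proxy for dispersion. If a true sample standard deviation is preferred, a small standard-library helper along these lines could replace the inline print (the helper name `summarize_group_rewards` is ours, not part of AgentJet):

```python
import statistics
from typing import Sequence

def summarize_group_rewards(group_reward: Sequence[float]) -> str:
    """Return a short 'mean +/- std' summary for one GRPO group of episode rewards."""
    if not group_reward:
        return "no rewards collected (all episodes failed)"
    mean = statistics.mean(group_reward)
    # stdev needs at least two samples; fall back to 0.0 for a single reward
    std = statistics.stdev(group_reward) if len(group_reward) > 1 else 0.0
    return f"{mean:.3f} +/- {std:.3f} (n={len(group_reward)})"

# Example: one GRPO group of LOCAL_GRPO_N = 4 episode rewards
print("Group reward:", summarize_group_rewards([0.5, 1.0, 0.5, 0.0]))
```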