From a18048f803a5b60a9d35ba949347b10fef5d5d10 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Tue, 17 Mar 2026 01:16:18 +0800 Subject: [PATCH 01/17] set up --- .gitignore | 2 + ajet/backbone/main_verl.py | 66 +- ajet/backbone/trainer_verl.py | 202 +--- ajet/backbone/verl/__init__.py | 3 + ajet/backbone/verl/fsdp_workers.py | 430 +++++++ ajet/default_config/ajet_default.yaml | 21 +- .../verl/config_auto_convertion_verl.jsonc | 4 + .../verl/config_schema_rollout.py | 27 + ajet/default_config/verl/verl_default.yaml | 1022 +++++++++++------ ajet/task_rollout/async_llm_bridge.py | 14 +- ajet/utils/launch_utils.py | 2 +- ajet/utils/process_dataset.py | 5 +- docs/en/installation.md | 2 +- 13 files changed, 1200 insertions(+), 600 deletions(-) create mode 100644 ajet/backbone/verl/__init__.py create mode 100644 ajet/backbone/verl/fsdp_workers.py create mode 100644 ajet/default_config/verl/config_schema_rollout.py diff --git a/.gitignore b/.gitignore index 00da5134..1f863ee0 100644 --- a/.gitignore +++ b/.gitignore @@ -175,3 +175,5 @@ werewolves_swarm tensorboard_log tutorial/**/*.json node_modules +.agents +skills-lock.json diff --git a/ajet/backbone/main_verl.py b/ajet/backbone/main_verl.py index 0fe845ca..04a681ce 100644 --- a/ajet/backbone/main_verl.py +++ b/ajet/backbone/main_verl.py @@ -61,7 +61,6 @@ def run_ppo(config: DictConfig) -> None: runtime_env = get_runtime_env(config) ray.init( runtime_env=runtime_env, - num_cpus=config.ray_init.num_cpus, ) def on_shutdown(): @@ -93,12 +92,6 @@ def on_shutdown(): runner = TaskRunner.remote() ray.get(runner.run.remote(config)) - # [Optional] get the path of the timeline trace file from the configuration, default to None - # This file is used for performance analysis - timeline_json_file = config.ray_init.get("timeline_json_file", None) - if timeline_json_file: - ray.timeline(filename=timeline_json_file) - @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head class TaskRunner: @@ -148,23 +141,13 @@ def run(self, config): if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: assert config.critic.strategy in {"fsdp", "fsdp2"} from verl.single_controller.ray import RayWorkerGroup - from verl.workers.fsdp_workers import ( - ActorRolloutRefWorker, - AsyncActorRolloutRefWorker, - ) + from ajet.backbone.verl import AjetActorRolloutRefWorker + from ajet.backbone.verl import AjetAsyncActorRolloutRefWorker + - use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") - if use_legacy_worker_impl in ["auto", "enable"]: - # import warnings - # warnings.warn(f"Legacy worker impl is going to be deprecated, will be removed in the future. \ - # Please set trainer.use_legacy_worker_impl = false to switch to the new worker implementation.") - from verl.workers.fsdp_workers import CriticWorker - elif use_legacy_worker_impl == "disable": - from verl.workers.roles import CriticWorker - else: - raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") - - actor_rollout_cls = AsyncActorRolloutRefWorker + + ActorRolloutRefWorker = AjetActorRolloutRefWorker + actor_rollout_cls = AjetAsyncActorRolloutRefWorker ray_worker_group_cls = RayWorkerGroup elif config.actor_rollout_ref.actor.strategy == "megatron": @@ -172,11 +155,11 @@ def run(self, config): from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup from verl.workers.megatron_workers import ( ActorRolloutRefWorker, - AsyncActorRolloutRefWorker, + AjetAsyncActorRolloutRefWorker, CriticWorker, ) - actor_rollout_cls = AsyncActorRolloutRefWorker + actor_rollout_cls = AjetAsyncActorRolloutRefWorker ray_worker_group_cls = NVMegatronRayWorkerGroup else: @@ -187,7 +170,6 @@ def run(self, config): # Map roles to their corresponding remote worker classes. role_worker_mapping = { Role.ActorRollout: ray.remote(actor_rollout_cls), - Role.Critic: ray.remote(CriticWorker), } # Define the resource pool specification. @@ -198,43 +180,15 @@ def run(self, config): } mapping = { Role.ActorRollout: global_pool_id, - Role.Critic: global_pool_id, } - # We should adopt a multi-source reward function here: - # - for rule-based rm, we directly call a reward score - # - for model-based rm, we call a model - # - for code related prompt, we send to a sandbox if there are test cases - # finally, we combine all the rewards together - # The reward type depends on the tag of the data - if config.reward_model.enable: - if config.reward_model.strategy in {"fsdp", "fsdp2"}: - from verl.workers.fsdp_workers import RewardModelWorker - elif config.reward_model.strategy == "megatron": - from verl.workers.megatron_workers import RewardModelWorker - else: - raise NotImplementedError - role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) - mapping[Role.RewardModel] = global_pool_id # Add a reference policy worker if KL loss or KL reward is used. if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) mapping[Role.RefPolicy] = global_pool_id - # Load the reward manager for training and validation. - reward_fn = load_reward_manager( - config, - tokenizer, - num_examine=0, - **config.reward_model.get("reward_kwargs", {}), - ) - val_reward_fn = load_reward_manager( - config, - tokenizer, - num_examine=1, - **config.reward_model.get("reward_kwargs", {}), - ) + resource_pool_manager = ResourcePoolManager( resource_pool_spec=resource_pool_spec, mapping=mapping ) @@ -262,8 +216,6 @@ def run(self, config): role_worker_mapping=role_worker_mapping, resource_pool_manager=resource_pool_manager, ray_worker_group_cls=ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn, train_dataset=train_dataset, val_dataset=val_dataset, collate_fn=collate_fn, diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 00caaa6f..afe06434 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -16,7 +16,7 @@ import uuid from collections import defaultdict from pprint import pprint -from typing import List, Optional +from typing import Any, List, Optional import hydra import numpy as np @@ -27,6 +27,7 @@ from tqdm import tqdm from verl import DataProto from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.experimental.fully_async_policy.agent_loop.agent_loop import FullyAsyncAgentLoopManager from verl.single_controller.ray import RayClassWithInitArgs from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.config import AlgoConfig @@ -43,7 +44,6 @@ apply_kl_penalty, compute_response_mask, ) -from verl.trainer.ppo.reward import compute_reward from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi from verl.utils.config import omega_conf_to_dataclass from verl.utils.debug import marked_timer @@ -56,6 +56,27 @@ from ajet.task_rollout.native_parallel_worker import VerlRolloutManager from ajet.utils.metric_helper import save_trajectory_as_json_file, update_metrics + +def compute_reward(data: DataProto, reward_fn) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Compute reward for a batch of data. + Args: + data: DataProto object containing the input data. + reward_fn: Reward function to compute the reward. + Returns: + Tuple of reward tensor and extra info dictionary. + """ + try: + reward_result = reward_fn(data, return_dict=True) + reward_tensor = reward_result["reward_tensor"] + reward_extra_infos_dict = reward_result.get("reward_extra_info", {}) + except Exception as e: + print(f"Error in reward_fn: {e}") + reward_tensor = reward_fn(data) + reward_extra_infos_dict = {} + + return reward_tensor, reward_extra_infos_dict + def parse_reward_from_dataproto(data: DataProto, return_dict=False) -> dict | torch.Tensor: """ Compute reward for a batch of data. @@ -297,13 +318,6 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): "actor_rollout_ref.rollout", ) - # Check for reward model micro-batch size conflicts - if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: - check_mutually_exclusive( - config.reward_model.micro_batch_size, - config.reward_model.micro_batch_size_per_gpu, - "reward_model", - ) if self.config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: logger.warning("NOTICE: You have both enabled in-reward kl and kl loss.") @@ -329,122 +343,7 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): logger.success("[validate_config] All configuration checks passed successfully!") def init_workers(self): - """Initialize distributed training workers using Ray backend. - - Creates: - 1. Ray resource pools from configuration - 2. Worker groups for each role (actor, critic, etc.) - """ - self.resource_pool_manager.create_resource_pool() - - self.resource_pool_to_cls = { - pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values() - } - - # create actor and rollout - if self.hybrid_engine: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.ActorRollout], - config=self.config.actor_rollout_ref, - role="actor_rollout", - profile_option=self.config.trainer.npu_profile.options, - ) - self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls - else: - raise NotImplementedError - - # create critic - if self.use_critic: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) - critic_cfg = omega_conf_to_dataclass(self.config.critic) - critic_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.Critic], config=critic_cfg - ) - self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls - - # create reference policy if needed - if self.use_reference_policy: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) - ref_policy_cls = RayClassWithInitArgs( - self.role_worker_mapping[Role.RefPolicy], - config=self.config.actor_rollout_ref, - role="ref", - profile_option=self.config.trainer.npu_profile.options, - ) - self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls - - # create a reward model if reward_fn is None - if self.use_rm: - # we create a RM here - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) - rm_cls = RayClassWithInitArgs( - self.role_worker_mapping[Role.RewardModel], - config=self.config.reward_model, - ) - self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls - - # initialize WorkerGroup - # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, - # you should not use `create_colocated_worker_cls`. - # Instead, directly pass different resource pool to different worker groups. - # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. - all_wg = {} - wg_kwargs = {} # Setting up kwargs for RayWorkerGroup - if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: - wg_kwargs[ - "ray_wait_register_center_timeout" - ] = self.config.trainer.ray_wait_register_center_timeout - if OmegaConf.select(self.config.trainer, "profile_steps") is not None: - wg_kwargs["profile_steps"] = OmegaConf.select(self.config.trainer, "profile_steps") - assert ( - OmegaConf.select(self.config.trainer, "worker_nsight_options") is not None - ), "worker_nsight_options must be set when profile_steps is set" - wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( - OmegaConf.select(self.config.trainer, "worker_nsight_options") - ) - wg_kwargs["device_name"] = self.device_name - - for resource_pool, class_dict in self.resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls( - resource_pool=resource_pool, - ray_cls_with_init=worker_dict_cls, - **wg_kwargs, - ) - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - - if self.use_critic: - self.critic_wg = all_wg["critic"] - self.critic_wg.init_model() - - if self.use_reference_policy and not self.ref_in_actor: - self.ref_policy_wg = all_wg["ref"] - self.ref_policy_wg.init_model() - - if self.use_rm: - self.rm_wg = all_wg["rm"] - self.rm_wg.init_model() - - # we should create rollout at the end so that vllm can have a better estimation of kv cache memory - self.actor_rollout_wg = all_wg["actor_rollout"] - self.actor_rollout_wg.init_model() - - # create async rollout manager and request scheduler - self.async_rollout_mode = False - from verl.experimental.agent_loop.agent_loop import ( - AgentLoopManager, - AsyncLLMServerManager, - ) - - self.async_rollout_mode = True - agent_loop_manager = AgentLoopManager( - config=self.config, - worker_group=self.actor_rollout_wg, - ) - self.async_server_list = agent_loop_manager.async_llm_servers - self.async_rollout_manager = AsyncLLMServerManager(self.config, self.async_server_list) + super().init_workers() self.reward_fn = parse_reward_from_dataproto self.val_reward_fn = parse_reward_from_dataproto @@ -466,16 +365,12 @@ def _update_interchange_server_status_flag(self, status: str): # training loop # ####################################### def fit(self): # noqa: C901 - """ - The training loop of PPO. - The driver process only need to call the compute functions of the worker group through RPC - to construct the PPO dataflow. - The light-weight advantage computation is done on the driver process. - """ + from omegaconf import OmegaConf from verl.utils.tracking import Tracking warm_up_process(self.config) + self.verl_logger = Tracking( project_name=self.config.ajet.project_name, experiment_name=self.config.ajet.experiment_name, @@ -486,10 +381,7 @@ def fit(self): # noqa: C901 # load checkpoint before doing anything self._load_checkpoint() - - # wake and sleep to enforce param sync - self.async_rollout_manager.wake_up() - self.async_rollout_manager.sleep() + self.checkpoint_manager.update_weights(self.global_steps) # perform validation before training # currently, we only support validation using the reward_function. @@ -513,25 +405,11 @@ def fit(self): # noqa: C901 last_val_metrics = None self.max_steps_duration = 0 - prev_step_profile = False - curr_step_profile = ( - self.global_steps in self.config.trainer.profile_steps - if self.config.trainer.profile_steps is not None - else False - ) - next_step_profile = False - for epoch in range(self.config.trainer.total_epochs): for batch_dict in self.train_dataloader: metrics = {} timing_raw = {} - with marked_timer("start_profile", timing_raw): - self._start_profiling( - not prev_step_profile and curr_step_profile - if self.config.trainer.profile_continuous_steps - else curr_step_profile - ) batch_dict["index"] = torch.tensor( [i for i in range(len(batch_dict["task_id"]))], @@ -570,7 +448,7 @@ def fit(self): # noqa: C901 with marked_timer("gen", timing_raw, color="red"): assert self.async_rollout_mode logger.info("wake up begin") - self.async_rollout_manager.wake_up() + self.checkpoint_manager.update_weights(self.global_steps) self._update_interchange_server_status_flag("ENGINE.ROLLING") logger.info("wake up end") tasks: List[Task] = [ @@ -645,7 +523,7 @@ def fit(self): # noqa: C901 f"gen_batch_output.info batch.keys={gen_batch_output.batch.keys()}" ) self._update_interchange_server_status_flag("ENGINE.WEIGHT_SYNCING") - self.async_rollout_manager.sleep() + self.checkpoint_manager.sleep_replicas() logger.info("rollout step end") @@ -673,11 +551,6 @@ def fit(self): # noqa: C901 ).tolist() with marked_timer("reward", timing_raw, color="yellow"): - # compute reward model score - if self.use_rm: - reward_tensor = self.rm_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor) - if self.config.reward_model.launch_reward_fn_async: raise NotImplementedError( "launch_reward_fn_async is not supported in GRPO yet." @@ -816,19 +689,6 @@ def fit(self): # noqa: C901 with marked_timer("save_checkpoint", timing_raw, color="green"): self._save_checkpoint() - with marked_timer("stop_profile", timing_raw): - next_step_profile = ( - self.global_steps + 1 in self.config.trainer.profile_steps - if self.config.trainer.profile_steps is not None - else False - ) - self._stop_profiling( - curr_step_profile and not next_step_profile - if self.config.trainer.profile_continuous_steps - else curr_step_profile - ) - prev_step_profile = curr_step_profile - curr_step_profile = next_step_profile steps_duration = timing_raw["step"] self.max_steps_duration = max(self.max_steps_duration, steps_duration) @@ -934,7 +794,7 @@ def _validate(self): } logger.info(f"test_gen_batch meta info: {test_gen_batch.meta_info}") - self.async_rollout_manager.wake_up() + self.checkpoint_manager.update_weights(self.global_steps) main_val_dataset = self.get_eval_dataset() logger.info("Starting validate rollout") @@ -946,7 +806,7 @@ def _validate(self): ) logger.info("Completed validate rollout") test_output_gen_batch = self.parallel_env.to_dataproto(context_tracker_arr) - self.async_rollout_manager.sleep() + self.checkpoint_manager.sleep_replicas() # Store generated outputs output_ids = test_output_gen_batch.batch["responses"] diff --git a/ajet/backbone/verl/__init__.py b/ajet/backbone/verl/__init__.py new file mode 100644 index 00000000..77d833bb --- /dev/null +++ b/ajet/backbone/verl/__init__.py @@ -0,0 +1,3 @@ +from .fsdp_workers import AjetActorRolloutRefWorker, AjetAsyncActorRolloutRefWorker + +__all__ = ["AjetActorRolloutRefWorker", "AjetAsyncActorRolloutRefWorker"] diff --git a/ajet/backbone/verl/fsdp_workers.py b/ajet/backbone/verl/fsdp_workers.py new file mode 100644 index 00000000..e5c0fc88 --- /dev/null +++ b/ajet/backbone/verl/fsdp_workers.py @@ -0,0 +1,430 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Custom FSDP workers for AgentJet project +""" + +import datetime +import json +import logging +import os +import warnings +from dataclasses import asdict + +import psutil +import torch +import torch.distributed +import torch.distributed as dist +from codetiming import Timer +from omegaconf import DictConfig, OmegaConf, open_dict +from omegaconf.errors import ConfigAttributeError +from peft import LoraConfig, TaskType, get_peft_model +from safetensors.torch import save_file +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType + +try: + # for torch 2.5+ + from torch.distributed.tensor import DTensor +except ImportError: + from torch.distributed._tensor import DTensor + +from verl import DataProto +from verl.models.transformers.monkey_patch import apply_monkey_patch +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register +from verl.utils import hf_processor, hf_tokenizer +from verl.utils.activation_offload import enable_activation_offloading +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.device import ( + get_device_id, + get_device_name, + get_nccl_backend, + get_torch_device, + set_expandable_segments, +) +from verl.utils.flops_counter import FlopsCounter +from verl.utils.fs import copy_to_local +from verl.utils.fsdp_utils import ( + CPUOffloadPolicy, + MixedPrecisionPolicy, + apply_fsdp2, + collect_lora_params, + fsdp2_load_full_state_dict, + fsdp_version, + get_fsdp_wrap_policy, + get_init_weight_context_manager, + get_shard_placement_fn, + init_fn, + layered_summon_lora_params, + load_fsdp_model_to_gpu, + load_fsdp_optimizer, + offload_fsdp_model_to_cpu, + offload_fsdp_optimizer, + replace_lora_wrapper, +) +from verl.utils.import_utils import import_external_libs +from verl.utils.memory_utils import aggressive_empty_cache +from verl.utils.model import convert_weight_keys +from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer +from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max +from verl.utils.py_functional import convert_to_regular_types + +# QAT support +from verl.utils.qat import apply_qat, enable_qat_fuse +from verl.utils.ray_utils import get_event_loop +from verl.utils.transformers_compat import get_auto_model_for_vision2seq +from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig, HFModelConfig, RolloutConfig +from verl.workers.config.optimizer import build_optimizer +from verl.workers.rollout import get_rollout_class +from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager +from verl.workers.fsdp_workers import ActorRolloutRefWorker + + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +device_name = get_device_name() + + +def create_device_mesh(world_size, fsdp_size): + if fsdp_size < 0 or fsdp_size >= world_size: + device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) + else: + device_mesh = init_device_mesh( + device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] + ) + return device_mesh + + +def get_sharding_strategy(device_mesh, zero3_enable=True): + from torch.distributed.fsdp import ShardingStrategy + + if zero3_enable: + fsdp_strategy = ShardingStrategy.FULL_SHARD + hsdp_strategy = ShardingStrategy.HYBRID_SHARD + else: + fsdp_strategy = ShardingStrategy.SHARD_GRAD_OP + hsdp_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2 + + if device_mesh.ndim == 1: + sharding_strategy = fsdp_strategy + elif device_mesh.ndim == 2: + sharding_strategy = hsdp_strategy + else: + raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2") + return sharding_strategy + + +def get_vl_model_vision_tower(vl_model_instance): + """ + Util to extract Vision Tower from a VL model instance + """ + if hasattr(vl_model_instance, "model") and hasattr(vl_model_instance.model, "visual"): + # transformers >= 4.52.0 + return vl_model_instance.model.visual + elif hasattr(vl_model_instance, "visual"): + # transformers < 4.52.0 + return vl_model_instance.visual + return None + + +class AjetActorRolloutRefWorker(ActorRolloutRefWorker): + """Custom ActorRolloutRefWorker for AgentJet.""" + + def __init__(self, config: DictConfig, role: str, **kwargs): + Worker.__init__(self) + + self.config = config + import torch.distributed + + if not torch.distributed.is_initialized(): + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + torch.distributed.init_process_group( + backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}", + rank=rank, + world_size=world_size, + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + + # Apply NPU patches for FSDP backend + from verl.workers.engine.fsdp.utils import apply_npu_fsdp_patches + + apply_npu_fsdp_patches() + + # build device mesh for FSDP + world_size = torch.distributed.get_world_size() + # TODO(sgm): support FSDP hybrid shard for larger model + self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size) + + # build device mesh for Ulysses Sequence Parallel + self.ulysses_device_mesh = None + self.ulysses_sequence_parallel_size = self.config.actor.get("ulysses_sequence_parallel_size", 1) + dp = world_size // self.ulysses_sequence_parallel_size + if self.ulysses_sequence_parallel_size > 1: + self.ulysses_device_mesh = init_device_mesh( + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) + + # create training dispatch + if self.ulysses_device_mesh is not None: + is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0 + self._register_dispatch_collect_info( + "actor", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect + ) + else: + self._register_dispatch_collect_info("actor", dp_rank=self.rank, is_collect=True) + + self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + self._lora_rank = self.config.model.get("lora_rank", 0) + self._is_lora = self.config.model.get("lora_adapter_path") is not None or self._lora_rank > 0 + + self.role = role + assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] + + self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] + self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] + self._is_ref = self.role in ["ref", "actor_rollout_ref"] + self.use_orig_params = self.config.actor.fsdp_config.get("use_orig_params", False) + + # TODO(haibin.lin): + # As of now the type of config is DictConfig, if we assign config.profiler with ProfilerConfig, + # it will actually convert the ProfilerConfig dataclass back to a DictConfig. + # We can still use ProfilerConfig for testing purpose (tests/utils/test_nvtx_profile.py) + # as they provides DictConfig-like interface + # The benefit of creating the dataclass config is to perform validation during __post_init__ + if self._is_actor: + omega_profiler_config = config.actor.get("profiler", {}) + elif self._is_rollout: + # NOTE: In colocation mode, rollout config may not take effect (follow the actor config) + # This is for extendability in AsyncRL cases + omega_profiler_config = config.rollout.get("profiler", {}) + elif self._is_ref: + omega_profiler_config = config.ref.get("profiler", {}) + else: + raise ValueError( + f"Invalid role {self.role}, should be one of " + "['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']" + ) + # omega_profiler_config is DictConfig + # profiler_config is a ProfilerConfig dataclass + profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) + ) + + self._is_offload_param = False + self._is_offload_optimizer = False + if self._is_actor: + self._is_offload_param = self.config.actor.fsdp_config.get("param_offload", False) + self._is_offload_optimizer = self.config.actor.fsdp_config.get("optimizer_offload", False) + elif self._is_ref: + # TODO: it seems that manual offload is slowly than FSDP offload + self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False) + + # normalize config + if self._is_actor: + # self.config.actor.ppo_mini_batch_size *= self.config.rollout.n + self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size + if self.config.actor.ppo_mini_batch_size <= 0: + # `ppo_mini_batch_size` is deprecated + # replaced by the `actor_rollout_ref.actor.override_ppo_mini_batch_num` config + # which define the number of number of optimizer steps per train-batch-step + self.config.actor.ppo_mini_batch_size = 1 + # micro bsz + if self.config.actor.ppo_micro_batch_size is not None: + self.config.actor.ppo_micro_batch_size //= ( + self.device_mesh.size() // self.ulysses_sequence_parallel_size + ) + self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size + + if self.config.actor.ppo_micro_batch_size_per_gpu is not None: + assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, ( + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by " + f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + ) + assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, ( + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than " + f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + ) + + # normalize rollout config + if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None: + self.config.rollout.log_prob_micro_batch_size //= ( + self.device_mesh.size() // self.ulysses_sequence_parallel_size + ) + self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size + # normalize ref config + if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None: + self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size + self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + from verl.workers.actor import DataParallelPPOActor + + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) + + # Initialize QAT config before _build_model_optimizer + self._init_qat_config() + + override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) + use_remove_padding = self.config.model.get("use_remove_padding", False) + use_shm = self.config.model.get("use_shm", False) + use_fused_kernels = self.config.model.get("use_fused_kernels", False) + + if self._is_actor or self._is_rollout: + # we need the model for actor and rollout + if self._is_actor: + optim_config = self.config.actor.optim + fsdp_config = omega_conf_to_dataclass(self.config.actor.fsdp_config) + else: + optim_config = None + fsdp_config = FSDPEngineConfig() + + local_path = copy_to_local(self.config.model.path, use_shm=use_shm) + # TiledMLP configuration for memory-efficient MLP computation + tiled_mlp_config = self.config.model.get("tiled_mlp", {}) + use_tiled_mlp = tiled_mlp_config.get("enabled", False) + tiled_mlp_shards = tiled_mlp_config.get("num_shards", 4) + + ( + self.actor_module_fsdp, + self.actor_optimizer, + self.actor_lr_scheduler, + self.actor_model_config, + ) = self._build_model_optimizer( + model_path=local_path, + fsdp_config=fsdp_config, + optim_config=optim_config, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, + enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="actor", + enable_activation_offload=self.config.model.get("enable_activation_offload", False), + use_prefix_grouper=self.config.actor.get("use_prefix_grouper", False), + use_tiled_mlp=use_tiled_mlp, + tiled_mlp_shards=tiled_mlp_shards, + ) + + # get the original unwrapped module + if fsdp_version(self.actor_module_fsdp) == 1: + self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload actor model during init", logger=logger) + + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.actor_optimizer) + log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) + + if self._is_actor: + actor_cfg = self.config.actor + self.actor = DataParallelPPOActor( + config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer + ) + + if self._is_rollout: + self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) + + if self._is_ref: + ref_model_path = self.config.model.path + ref_model = self.config.ref.get("model", None) + if ref_model is not None: + ref_model_path = ref_model.get("path", self.config.model.path) + + if self.rank == 0: + print("reference model:", ref_model_path) + local_path = copy_to_local(ref_model_path, use_shm=use_shm) + use_prefix_grouper = hasattr(self.config, "actor") and self.config.actor.get("use_prefix_grouper", False) + + # TiledMLP for ref model: use ref config if specified, otherwise use actor config + ref_tiled_mlp_config = self.config.ref.get("tiled_mlp", None) + if ref_tiled_mlp_config is None: + ref_tiled_mlp_config = self.config.model.get("tiled_mlp", {}) + ref_use_tiled_mlp = ref_tiled_mlp_config.get("enabled", False) + ref_tiled_mlp_shards = ref_tiled_mlp_config.get("num_shards", 4) + + self.ref_module_fsdp = self._build_model_optimizer( + model_path=local_path, + fsdp_config=omega_conf_to_dataclass(self.config.ref.fsdp_config), + optim_config=None, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="ref", + use_prefix_grouper=use_prefix_grouper, + use_tiled_mlp=ref_use_tiled_mlp, + tiled_mlp_shards=ref_tiled_mlp_shards, + )[0] + OmegaConf.set_struct(self.config.ref, True) + with open_dict(self.config.ref): + self.config.ref.use_remove_padding = use_remove_padding + self.config.ref.use_fused_kernels = use_fused_kernels + if use_prefix_grouper: + self.config.ref.use_prefix_grouper = use_prefix_grouper + self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp) + + if self._is_actor: + self.flops_counter = FlopsCounter(self.actor_model_config) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.actor_module_fsdp, + optimizer=self.actor.actor_optimizer, + lr_scheduler=self.actor_lr_scheduler, + processing_class=self.processor if self.processor is not None else self.tokenizer, + checkpoint_config=self.config.actor.checkpoint, + trust_remote_code=self.config.model.get("trust_remote_code", False), + ) + + if not self._is_actor and self._is_rollout: + # If ActorRolloutRefWorker is initialized as a standalone rollout, + # create a checkpoint manager for FSDP model to allow loading FSDP checkpoints for rollout. + + checkpoint_contents = OmegaConf.create({"load_contents": ["model"], "save_contents": []}) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.actor_module_fsdp, + optimizer=None, + lr_scheduler=None, + processing_class=self.processor if self.processor is not None else self.tokenizer, + checkpoint_config=checkpoint_contents, + ) + + # Free cached GPU memory so colocated vLLM processes can see it via cudaMemGetInfo + aggressive_empty_cache(force_sync=True) + + +# ================================= Async related workers ================================= +class AjetAsyncActorRolloutRefWorker(AjetActorRolloutRefWorker): + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + async def update_weights(self, global_steps: int = None): + await self.rollout_mode() + return True diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index 1539d028..a880e049 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -82,7 +82,6 @@ ajet: top_k: -1 top_p: 1.0 do_sample: False - num_repeat: 1 @@ -261,9 +260,25 @@ ajet: param_offload: True optimizer_offload: True - # learning rate + # optimizer settings optim: - lr: 1e-6 + optimizer: AdamW + optimizer_impl: torch.optim + lr: 1.0e-06 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + min_lr_ratio: 0.0 + num_cycles: 0.5 + lr_scheduler_type: constant + zero_indexed_step: true + warmup_style: null + override_optimizer_config: null # enable KL loss regularization use_kl_loss: True diff --git a/ajet/default_config/verl/config_auto_convertion_verl.jsonc b/ajet/default_config/verl/config_auto_convertion_verl.jsonc index 378fd112..90abdaef 100644 --- a/ajet/default_config/verl/config_auto_convertion_verl.jsonc +++ b/ajet/default_config/verl/config_auto_convertion_verl.jsonc @@ -32,6 +32,10 @@ "ajet.rollout.multi_turn": "actor_rollout_ref.rollout.multi_turn", "ajet.rollout.val_kwargs": "actor_rollout_ref.rollout.val_kwargs", + "ajet.rollout.num_repeat": [ + "actor_rollout_ref.actor.rollout_n", + "actor_rollout_ref.rollout.n" + ], "ajet.model.path": "actor_rollout_ref.model.path", "ajet.project_name": "trainer.project_name", diff --git a/ajet/default_config/verl/config_schema_rollout.py b/ajet/default_config/verl/config_schema_rollout.py new file mode 100644 index 00000000..01627982 --- /dev/null +++ b/ajet/default_config/verl/config_schema_rollout.py @@ -0,0 +1,27 @@ +from verl.workers.config.rollout import MultiTurnConfig +from dataclasses import dataclass, field +from typing import Optional +from verl.base_config import BaseConfig + + +@dataclass +class AjetMultiTurnConfig(BaseConfig): + _mutable_fields = {"max_assistant_turns", "max_user_turns"} + + enable: bool = False + max_assistant_turns: Optional[int] = None + tool_config_path: Optional[str] = None + max_user_turns: Optional[int] = None + max_parallel_calls: int = 1 + max_sample_per_task: int = 30 + max_steps: int = 30 + expected_steps: Optional[int] = None + max_tool_response_length: int = 256 + tool_response_truncate_side: str = "middle" + interaction_config_path: Optional[str] = None + use_inference_chat_template: bool = False + tokenization_sanity_check_mode: str = "strict" + format: str = "hermes" + num_repeat_rollouts: Optional[int] = None + + diff --git a/ajet/default_config/verl/verl_default.yaml b/ajet/default_config/verl/verl_default.yaml index 76847473..bb26fb28 100644 --- a/ajet/default_config/verl/verl_default.yaml +++ b/ajet/default_config/verl/verl_default.yaml @@ -1,431 +1,735 @@ -# DO NOT EDIT: THIS FILE IS READ ONLY and ALWAYS FIXED, EDIT `ajet/default_config/ajet_default.yaml` INSTEAD -# DO NOT EDIT: THIS FILE IS READ ONLY and ALWAYS FIXED, EDIT `ajet/default_config/ajet_default.yaml` INSTEAD -# DO NOT EDIT: THIS FILE IS READ ONLY and ALWAYS FIXED, EDIT `ajet/default_config/ajet_default.yaml` INSTEAD +# coyp from verl's: +# verl/trainer/config/_generated_ppo_trainer.yaml -ajet: - rollout: - step_skip_action: 0 - submit_oversample_multiplier: 1.5 +# DO NOT EDIT MANUALLY +# DO NOT EDIT MANUALLY +# DO NOT EDIT MANUALLY +# DO NOT EDIT MANUALLY +# DO NOT EDIT MANUALLY +# DO NOT EDIT MANUALLY + +# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh' +# in which it invokes 'python3 scripts/print_cfg.py --cfg job ' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file. +# Do not modify this file directly. +# The file is usually only for reference and never used. actor_rollout_ref: actor: - _target_: verl.workers.config.FSDPActorConfig - checkpoint: - _target_: verl.trainer.config.CheckpointConfig - async_save: false - load_contents: - - model - - optimizer - - extra - save_contents: - - model - - optimizer - - extra - clip_ratio: 0.2 - clip_ratio_c: 3.0 - clip_ratio_high: 0.2 - clip_ratio_low: 0.2 - entropy_checkpointing: false - entropy_coeff: 0 - entropy_from_logits_with_chunking: false - fsdp_config: - _target_: verl.workers.config.FSDPEngineConfig - forward_prefetch: false - fsdp_size: -1 - offload_policy: false - optimizer_offload: true - param_offload: true - reshard_after_forward: true - wrap_policy: - min_num_params: 0 - grad_clip: 1.0 - kl_loss_coef: 0.002 - kl_loss_type: low_var_kl - loss_agg_mode: seq-mean-token-mean optim: _target_: verl.workers.config.FSDPOptimizerConfig + optimizer: AdamW + optimizer_impl: torch.optim lr: 1.0e-06 - lr_warmup_steps: -1 lr_warmup_steps_ratio: 0.0 - min_lr_ratio: 0.0 - num_cycles: 0.5 total_training_steps: -1 - warmup_style: constant weight_decay: 0.01 - override_ppo_mini_batch_num: 1 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + min_lr_ratio: 0.0 + num_cycles: 0.5 + lr_scheduler_type: constant + zero_indexed_step: true + warmup_style: null + override_optimizer_config: null + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: true + optimizer_offload: true + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: bfloat16 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: false + strategy: fsdp + dtype: bfloat16 + qat: + _target_: verl.workers.config.QATEngineConfig + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null + _target_: verl.workers.config.FSDPActorConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: fsdp + ppo_mini_batch_size: 256 + override_ppo_mini_batch_num: 1 # special in agentjet + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: false + ppo_max_token_len_per_gpu: 16384 + clip_ratio: 0.2 + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + tau_pos: 1.0 + tau_neg: 1.05 + freeze_vision_tower: false policy_loss: _target_: verl.workers.config.PolicyLossConfig - clip_cov_lb: 1.0 + loss_mode: vanilla clip_cov_ratio: 0.0002 + clip_cov_lb: 1.0 clip_cov_ub: 5.0 kl_cov_ratio: 0.0002 - loss_mode: vanilla ppo_kl_coef: 0.1 + clip_ratio_c: 3.0 + loss_agg_mode: token-mean + loss_scale_factor: null + entropy_coeff: 0 + calculate_entropy: false + use_kl_loss: false + use_prefix_grouper: false + use_torch_compile: true + kl_loss_coef: 0.001 + kl_loss_type: low_var_kl ppo_epochs: 1 - ppo_max_token_len_per_gpu: 13000 - ppo_micro_batch_size: null - ppo_micro_batch_size_per_gpu: 1 - ppo_mini_batch_size: 16 shuffle: false - strategy: fsdp + data_loader_seed: 42 + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + mbridge_config: {} + use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + grad_clip: 1.0 ulysses_sequence_parallel_size: 1 - use_dynamic_bsz: true - use_fused_kernels: false - use_kl_loss: true - use_remove_padding: true - use_torch_compile: true - hybrid_engine: true - model: - custom_chat_template: null - enable_activation_offload: false - enable_gradient_checkpointing: true - exclude_modules: null - external_lib: null - fused_kernel_options: - impl_backend: torch - lora_alpha: 16 - lora_rank: 0 - override_config: {} - path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct - target_modules: all-linear - trust_remote_code: false - use_fused_kernels: false - use_liger: false - use_remove_padding: true - use_shm: false - nccl_timeout: 600 - profiler: - _target_: verl.utils.profiler.ProfilerConfig - all_ranks: false - discrete: false - ranks: [] - ref: - entropy_checkpointing: false entropy_from_logits_with_chunking: false + entropy_checkpointing: false + use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false} + calculate_sum_pi_squared: false + sum_pi_squared_checkpointing: false + qat: + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null + ref: + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: ${actor_rollout_ref.actor.strategy} + use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null fsdp_config: _target_: verl.workers.config.FSDPEngineConfig - forward_prefetch: false - param_offload: true - reshard_after_forward: true wrap_policy: min_num_params: 0 - log_prob_max_token_len_per_gpu: 13000 - log_prob_micro_batch_size: null - log_prob_micro_batch_size_per_gpu: 4 - log_prob_use_dynamic_bsz: true - model: null - strategy: fsdp - ulysses_sequence_parallel_size: 1 - use_dynamic_bsz: true - use_torch_compile: true + param_offload: true + optimizer_offload: true + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: bfloat16 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: true + strategy: fsdp + dtype: bfloat16 + qat: + _target_: verl.workers.config.QATEngineConfig + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null + _target_: verl.workers.config.FSDPActorConfig + ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1} + entropy_from_logits_with_chunking: false + entropy_checkpointing: false rollout: - agent: - agent_loop_config_path: null - custom_async_server: - name: null - path: null - num_workers: 8 - calculate_log_probs: false - cudagraph_capture_sizes: null - custom_dataflow_cls: - name: '' - path: '' - disable_log_stats: true - do_sample: true + _target_: verl.workers.config.RolloutConfig + name: vllm + mode: async + nnodes: 0 + n_gpus_per_node: ${oc.select:trainer.n_gpus_per_node,8} + temperature: 1.0 + top_k: -1 + top_p: 1 + prompt_length: ${oc.select:data.max_prompt_length,512} + response_length: ${oc.select:data.max_response_length,512} dtype: bfloat16 - enable_chunked_prefill: true - enforce_eager: true - engine_kwargs: - sglang: - attention_backend: null - vllm: - disable_mm_preprocessor_cache: false - swap_space: null - free_cache_engine: true - gamma: 1.0 - gpu_memory_utilization: 0.9 + gpu_memory_utilization: 0.95 ignore_eos: false - layered_summon: false - load_format: dummy_dtensor - log_prob_max_token_len_per_gpu: 13000 - log_prob_micro_batch_size: null - log_prob_micro_batch_size_per_gpu: 4 - log_prob_use_dynamic_bsz: true - max_env_worker: 64 - max_model_len: 13000 + enforce_eager: false + cudagraph_capture_sizes: null + free_cache_engine: true + tensor_model_parallel_size: 1 + data_parallel_size: 1 + expert_parallel_size: 1 + pipeline_model_parallel_size: 1 max_num_batched_tokens: 8192 - max_num_seqs: 10 - mode: async + max_model_len: null + max_num_seqs: 1024 + enable_chunked_prefill: false + enable_prefix_caching: false + logprobs_mode: processed_logprobs + scheduling_policy: fcfs + load_format: dummy + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: 1 + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + disable_log_stats: true + do_sample: true + 'n': 1 + over_sample_rate: 0 multi_stage_wake_up: false + engine_kwargs: + vllm: {} + sglang: {} + trtllm: {} + val_kwargs: + _target_: verl.workers.config.SamplingConfig + top_k: -1 + top_p: 1.0 + temperature: 0 + 'n': 1 + do_sample: false multi_turn: - enable: true - format: hermes - interaction_config_path: null + _target_: ajet.default_config.verl.config_schema_rollout.AjetMultiTurnConfig + enable: false max_assistant_turns: null + tool_config_path: null + max_user_turns: null max_parallel_calls: 1 - max_sample_per_task: 4 expected_steps: 1 - max_steps: 30 max_tool_response_length: 256 - max_user_turns: null - tokenization_sanity_check_mode: strict - tool_config_path: null tool_response_truncate_side: middle + interaction_config_path: null use_inference_chat_template: false - n: 1 - name: vllm - ppo_micro_batch_size_per_gpu: 1 - prompt_length: 3000 - response_length: 10000 - skip_dump_dir: /tmp/rollout_dump - skip_rollout: false - temperature: 0.9 - tensor_model_parallel_size: 1 - top_k: -1 - top_p: 1.0 + tokenization_sanity_check_mode: strict + format: hermes + num_repeat_rollouts: null + calculate_log_probs: false + agent: + _target_: verl.workers.config.AgentLoopConfig + num_workers: 8 + default_agent_loop: single_turn_agent + agent_loop_config_path: null + custom_async_server: + _target_: verl.workers.config.CustomAsyncServerConfig + path: null + name: null + checkpoint_engine: + _target_: verl.workers.config.CheckpointEngineConfig + backend: naive + update_weights_bucket_megabytes: 4096 + engine_kwargs: {} trace: + _target_: verl.workers.config.TraceConfig + project_name: ${oc.select:trainer.project_name,null} + experiment_name: ${oc.select:trainer.experiment_name,null} backend: null token2text: false - update_weights_bucket_megabytes: 512 - val_kwargs: - do_sample: false - n: 1 - num_repeat: 1 - temperature: 0.0 - top_k: -1 - top_p: 1.0 - - -algorithm: - _target_: verl.trainer.config.AlgoConfig - adv_estimator: grpo - gamma: 1.0 - kl_ctrl: - _target_: verl.trainer.config.KLControlConfig - horizon: 10000 - kl_coef: 0.001 - target_kl: 0.1 - type: fixed - kl_penalty: kl - lam: 1.0 - norm_adv_by_std_in_grpo: true - pf_ppo: - reweight_method: pow - weight_pow: 2.0 - use_kl_in_reward: false - use_pf_ppo: false - - -critic: - _target_: verl.workers.config.FSDPCriticConfig - checkpoint: - _target_: verl.trainer.config.CheckpointConfig - async_save: false - load_contents: - - model - - optimizer - - extra - save_contents: - - model - - optimizer - - extra - cliprange_value: 0.5 - enable: false - forward_max_token_len_per_gpu: 32768 - forward_micro_batch_size: null - forward_micro_batch_size_per_gpu: null - grad_clip: 1.0 - loss_agg_mode: seq-mean-token-mean + max_samples_per_step_per_worker: null + skip_rollout: false + skip_dump_dir: /tmp/rollout_dump + skip_tokenizer_init: true + enable_rollout_routing_replay: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} + ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.contents,[]} + level: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.level,level0} + analysis: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.analysis,false} + discrete: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.discrete,false} + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.torch.contents,[]} + discrete: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.torch.discrete,false} + prometheus: + _target_: verl.workers.config.PrometheusConfig + enable: false + port: 9090 + file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml + served_model_name: ${oc.select:actor_rollout_ref.model.path,null} + quantization: null + quantization_config_file: null + mtp: ${oc.select:actor_rollout_ref.model.mtp, null} + qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,null} + layered_summon: false model: - _target_: verl.workers.config.FSDPCriticModelCfg - enable_activation_offload: false - enable_gradient_checkpointing: true + _target_: verl.workers.config.HFModelConfig + path: ~/models/deepseek-llm-7b-chat + hf_config_path: null + tokenizer_path: null + use_shm: false + trust_remote_code: false + custom_chat_template: null external_lib: null - fsdp_config: - _target_: verl.workers.config.FSDPEngineConfig - forward_prefetch: false - fsdp_size: -1 - offload_policy: false - optimizer_offload: false - param_offload: false - reshard_after_forward: true - wrap_policy: - min_num_params: 0 - lora_alpha: 16 - lora_rank: 0 override_config: {} - path: ~/models/deepseek-llm-7b-chat + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: true + lora_rank: 0 + lora_alpha: 16 target_modules: all-linear - trust_remote_code: false - use_remove_padding: false - use_shm: false - optim: - _target_: verl.workers.config.FSDPOptimizerConfig - lr: 1.0e-05 - lr_warmup_steps: -1 - lr_warmup_steps_ratio: 0.0 - min_lr_ratio: null - total_training_steps: -1 - warmup_style: constant - weight_decay: 0.01 - ppo_epochs: 1 - ppo_max_token_len_per_gpu: 32768 - ppo_micro_batch_size: null - ppo_micro_batch_size_per_gpu: null - ppo_mini_batch_size: 16 - profiler: - _target_: verl.utils.profiler.ProfilerConfig - all_ranks: false - discrete: false - ranks: [] - rollout_n: 1 - shuffle: false - strategy: fsdp - ulysses_sequence_parallel_size: 1 - use_dynamic_bsz: true - - -custom_reward_function: - name: compute_score - path: null - - + exclude_modules: null + lora_adapter_path: null + use_liger: false + use_fused_kernels: false + fused_kernel_options: + impl_backend: torch + tiled_mlp: + enabled: false + num_shards: 4 + mtp: + _target_: verl.workers.config.MtpConfig + enable: false + enable_train: false + enable_rollout: false + detach_encoder: false + mtp_loss_scaling_factor: 0.1 + speculative_algorithm: EAGLE + speculative_num_steps: 3 + speculative_eagle_topk: 1 + speculative_num_draft_tokens: 4 + method: mtp + num_speculative_tokens: 1 + hybrid_engine: true + nccl_timeout: 600 data: - custom_cls: - name: null - path: null - datagen: - name: null - path: null + tokenizer: null + use_shm: false + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + train_max_samples: -1 + val_max_samples: -1 + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: null + tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path, + null} + return_raw_input_ids: false + return_raw_chat: true + return_full_prompt: false + shuffle: true + seed: null dataloader_num_workers: 8 - fast_eval: true - filter_overlong_prompts: true + image_patch_size: 14 + validation_shuffle: false + filter_overlong_prompts: false filter_overlong_prompts_workers: 1 + truncation: error image_key: images - max_prompt_length: 3000 - max_response_length: 10000 - prompt_key: prompt - return_full_prompt: false + video_key: videos + trust_remote_code: false + custom_cls: + path: null + name: null return_multi_modal_inputs: true - return_raw_chat: true - return_raw_input_ids: false - reward_fn_key: data_source sampler: - class_name: null class_path: null - shuffle: true - tokenizer: null - train_batch_size: 264 - train_files: ~/data/rlhf/gsm8k/train.parquet - truncation: error - trust_remote_code: false - use_shm: false - val_batch_size: 100000000000 - val_files: ~/data/rlhf/gsm8k/test.parquet - validation_shuffle: false - video_key: videos - seed: 42 - - -ray_init: - num_cpus: null - timeline_json_file: null - - -reward_model: - enable: false - forward_max_token_len_per_gpu: 32768 - launch_reward_fn_async: false - max_length: null - micro_batch_size: null - micro_batch_size_per_gpu: null + class_name: null + datagen: + path: null + name: null + apply_chat_template_kwargs: {} +critic: + optim: + _target_: verl.workers.config.FSDPOptimizerConfig + optimizer: AdamW + optimizer_impl: torch.optim + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + min_lr_ratio: 0.0 + num_cycles: 0.5 + lr_scheduler_type: constant + zero_indexed_step: true + warmup_style: null + override_optimizer_config: null model: - external_lib: null fsdp_config: _target_: verl.workers.config.FSDPEngineConfig - forward_prefetch: false - fsdp_size: -1 - param_offload: false - reshard_after_forward: true wrap_policy: min_num_params: 0 - input_tokenizer: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct - path: ~/models/FsfairX-LLaMA3-RM-v0.1 - trust_remote_code: false - use_fused_kernels: false - use_remove_padding: false + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: bfloat16 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: false + strategy: fsdp + dtype: bfloat16 + qat: + _target_: verl.workers.config.QATEngineConfig + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"} + override_config: {} + external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null} + trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false} + _target_: verl.workers.config.FSDPCriticModelCfg use_shm: false + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: false + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + tiled_mlp: + enabled: false + num_shards: 4 + _target_: verl.workers.config.FSDPCriticConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: fsdp + enable: null + ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256} + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null} + use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + ppo_max_token_len_per_gpu: 32768 + forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1} + shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false} + data_loader_seed: 42 + cliprange_value: 0.5 + loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean} + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + mbridge_config: {} profiler: _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false all_ranks: false - discrete: false ranks: [] - reward_manager: naive + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null} + forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null} + ulysses_sequence_parallel_size: 1 + grad_clip: 1.0 +custom_reward_function: + path: null + name: null +reward_model: + num_workers: null + reward_manager: null + enable: null + enable_resource_pool: null + n_gpus_per_node: null + nnodes: null + reward_loop_source: null + reward_loop_module_path: null + reward_loop_class_name: null + model: + path: null + external_lib: null + trust_remote_code: null + rollout: + name: null + dtype: null + gpu_memory_utilization: null + enforce_eager: null + cudagraph_capture_sizes: null + free_cache_engine: null + data_parallel_size: null + expert_parallel_size: null + tensor_model_parallel_size: null + max_num_batched_tokens: null + max_model_len: null + max_num_seqs: null + load_format: null + engine_kwargs: null + limit_images: null + enable_chunked_prefill: null + enable_prefix_caching: null + disable_log_stats: null + skip_tokenizer_init: null + prompt_length: null + response_length: null +sandbox_fusion: + url: null + max_concurrent: null + memory_limit_mb: null +reward: + num_workers: 8 + custom_reward_function: + path: null + name: compute_score + reward_manager: + _target_: verl.workers.config.reward_model.RewardManagerConfig + source: register + name: naive + module: + _target_: verl.trainer.config.config.ModuleConfig + path: null + name: custom_reward_manager + reward_model: + enable: false + enable_resource_pool: false + n_gpus_per_node: 8 + nnodes: 0 + model_path: null + rollout: + _target_: verl.workers.config.RolloutConfig + name: ??? + dtype: bfloat16 + gpu_memory_utilization: 0.5 + enforce_eager: true + cudagraph_capture_sizes: null + free_cache_engine: true + data_parallel_size: 1 + expert_parallel_size: 1 + tensor_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + load_format: auto + engine_kwargs: {} + limit_images: null + enable_chunked_prefill: true + enable_prefix_caching: true + disable_log_stats: true + skip_tokenizer_init: false + prompt_length: 2048 + response_length: 2048 sandbox_fusion: + url: null max_concurrent: 64 memory_limit_mb: 1024 - url: null - strategy: fsdp - ulysses_sequence_parallel_size: 1 - use_dynamic_bsz: true - - +algorithm: + rollout_correction: + rollout_is: null + rollout_is_threshold: 2.0 + rollout_rs: null + rollout_rs_threshold: null + bypass_mode: false + loss_type: ppo_clip + rollout_is_batch_normalize: false + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: grpo + norm_adv_by_std_in_grpo: true + use_kl_in_reward: false + kl_penalty: kl + kl_ctrl: + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: false + pf_ppo: + reweight_method: pow + weight_pow: 2.0 trainer: balance_batch: true - controller_nsight_options: - cuda-graph-trace: graph - cuda-memory-usage: 'true' - trace: cuda,nvtx,cublas,ucx + total_epochs: 30 + total_training_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: + - console + - wandb + log_val_generations: 0 + rollout_data_dir: null + validation_data_dir: null + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + esi_redundant_time: 0 + resume_mode: auto + resume_from_path: null + val_before_train: true + val_only: false + test_freq: -1 critic_warmup: 0 default_hdfs_dir: null - checkpoint_base_dir: ./saved_checkpoints - default_local_dir: ${trainer.checkpoint_base_dir}/${trainer.project_name}/${trainer.experiment_name} del_local_ckpt_after_load: false - device: cuda - esi_redundant_time: 0 - experiment_name: read_yaml_name - hfmodelpath: '' - log_val_generations: 0 - logger: - - console - - swanlab + checkpoint_base_dir: checkpoints + default_local_dir: ${trainer.checkpoint_base_dir}/${trainer.project_name}/${trainer.experiment_name} max_actor_ckpt_to_keep: null max_critic_ckpt_to_keep: null - n_gpus_per_node: 8 - nnodes: 1 - npu_profile: - options: - analysis: true - level: level1 - record_shapes: false - roles: - - all - save_path: ./profiler_data - with_cpu: true - with_memory: false - with_module: false - with_npu: true - with_stack: false - profile_continuous_steps: false - profile_steps: null - project_name: project_name_placeholder ray_wait_register_center_timeout: 300 - resume_from_path: null - resume_mode: auto - rollout_data_dir: null - save_freq: 99999 - test_freq: 99999 - total_epochs: 99999 - total_training_steps: null + device: cuda use_legacy_worker_impl: auto - val_before_train: false - val_only: false - val_pass_n: 4 - validation_data_dir: null - worker_nsight_options: - capture-range: cudaProfilerApi - capture-range-end: null - cuda-graph-trace: graph - cuda-memory-usage: 'true' - kill: none - trace: cuda,nvtx,cublas,ucx +global_profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: null + steps: null + profile_continuous_steps: false + save_path: outputs/profile + global_tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: false + controller_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + worker_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + capture-range: cudaProfilerApi + capture-range-end: null + kill: none + torch_memory: + trace_alloc_max_entries: 100000 + stack_depth: 32 + context: all + stacks: all + kw_args: {} +transfer_queue: + enable: false +ray_kwargs: + ray_init: + num_cpus: null + timeline_json_file: null diff --git a/ajet/task_rollout/async_llm_bridge.py b/ajet/task_rollout/async_llm_bridge.py index 9aae1296..b855fd9d 100644 --- a/ajet/task_rollout/async_llm_bridge.py +++ b/ajet/task_rollout/async_llm_bridge.py @@ -10,6 +10,7 @@ from pydantic import BaseModel from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from vllm.outputs import RequestOutput as VerlVllmRequestOutput +from verl.workers.rollout.replica import TokenOutput from agentscope.model import ChatResponse as AgentScopeChatResponse from openai.types.chat.chat_completion import ChatCompletion as OpenAIChatCompletion @@ -86,18 +87,17 @@ async def llm_chat_verl( ) prompt_token_ids = self.tokenizer(prompt_text)["input_ids"] - final_res = await self.async_rollout_manager.generate( + final_res: TokenOutput = await self.async_rollout_manager.generate( request_id=request_id, prompt_ids=prompt_token_ids, sampling_params=updated_sampling_params, ) - if self.config.ajet.rollout.name == "vllm": - final_res: VerlVllmRequestOutput - token_array = final_res.outputs[0].token_ids - logprob_array = final_res.outputs[0].logprobs - elif self.config.ajet.rollout.name == "sglang": - token_array = final_res + """response token ids""" + token_array = final_res.token_ids + logprob_array = final_res.log_probs + # routed_experts = final_res.routed_experts + # vllm_stop_reason = final_res.stop_reason decoded_text = self.tokenizer.decode(token_array) # type: ignore diff --git a/ajet/utils/launch_utils.py b/ajet/utils/launch_utils.py index 46200ad9..a941dbbc 100644 --- a/ajet/utils/launch_utils.py +++ b/ajet/utils/launch_utils.py @@ -297,7 +297,7 @@ def verify_python_env(args, exp_config): time.sleep(5) raise ImportError(cause + " " + solution) elif args.backbone == "verl": - if not any([v in verl.__version__ for v in ["0.5.0.post", "0.5.0.dev", "0.7.0.post"]]): # you must install via `pip install -e .[verl]` to get every dependency right + if not any([v in verl.__version__ for v in ["0.5.0.post", "0.5.0.dev", "0.7.0.post", "0.8.0.dev"]]): # you must install via `pip install -e .[verl]` to get every dependency right cause = "Python environment does not match current backbone 'verl'." solution = "Please `cd /path/to/project/AgentJet` and run `(uv) pip install -e .[verl]` to install the correct environment." print_dict( diff --git a/ajet/utils/process_dataset.py b/ajet/utils/process_dataset.py index 44f23698..c0493b1a 100644 --- a/ajet/utils/process_dataset.py +++ b/ajet/utils/process_dataset.py @@ -55,7 +55,10 @@ def create_rl_sampler( # If shuffling is enabled in the data configuration, create a random sampler. elif data_config.shuffle: train_dataloader_generator = torch.Generator() - train_dataloader_generator.manual_seed(data_config.get("seed", int(time.time()))) + seed = data_config.get("seed", None) + if seed is None: + seed = int(time.time()) + train_dataloader_generator.manual_seed(seed) sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) else: # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. diff --git a/docs/en/installation.md b/docs/en/installation.md index 2909ca10..ba3806c6 100644 --- a/docs/en/installation.md +++ b/docs/en/installation.md @@ -77,7 +77,7 @@ AgentJet supports multiple backbones, you can choose any of them depending on yo ```bash # Install with `verl` training backbone: - uv venv --python=3.10 + uv venv --python=3.12 source .venv/bin/activate uv pip install -i https://mirrors.aliyun.com/pypi/simple/ -e .[verl] From edabb22b6a43dad219fd5c724426ebe608b5cbd2 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Tue, 17 Mar 2026 14:18:25 +0800 Subject: [PATCH 02/17] stage success merge --- ajet/backbone/trainer_verl.py | 128 +++++++----- ajet/backbone/verl/__init__.py | 10 +- ajet/backbone/verl/actor_config.py | 47 +++++ ajet/backbone/verl/dp_actor.py | 228 +++++++++++++++++++++ ajet/backbone/verl/fsdp_workers.py | 3 +- ajet/default_config/verl/verl_default.yaml | 10 +- ajet/task_rollout/async_llm_bridge.py | 11 +- ajet/utils/core_env_vars.py | 10 +- pyproject.toml | 4 +- 9 files changed, 391 insertions(+), 60 deletions(-) create mode 100644 ajet/backbone/verl/actor_config.py create mode 100644 ajet/backbone/verl/dp_actor.py diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index afe06434..93696434 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -27,7 +27,7 @@ from tqdm import tqdm from verl import DataProto from verl.experimental.dataset.sampler import AbstractCurriculumSampler -from verl.experimental.fully_async_policy.agent_loop.agent_loop import FullyAsyncAgentLoopManager +from verl.experimental.agent_loop.agent_loop import AsyncLLMServerManager, AgentLoopWorker from verl.single_controller.ray import RayClassWithInitArgs from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.config import AlgoConfig @@ -348,9 +348,19 @@ def init_workers(self): self.reward_fn = parse_reward_from_dataproto self.val_reward_fn = parse_reward_from_dataproto + assert hasattr(self.async_rollout_manager, "agent_loop_workers") + assert len(self.async_rollout_manager.agent_loop_workers) == 1, "Please set `num_workers = 1` in `ajet/default_config/verl/verl_default.yaml`" + + servers = list(zip(self.async_rollout_manager.server_addresses, self.async_rollout_manager.server_handles, strict=True)) + real_async_rollout_manager: AsyncLLMServerManager = AsyncLLMServerManager( + config = self.async_rollout_manager.config, + servers = servers, + load_balancer_handle = self.async_rollout_manager.global_load_balancer + ) + self.parallel_env = VerlRolloutManager( config=self.config, - async_rollout_manager=self.async_rollout_manager, + async_rollout_manager=real_async_rollout_manager, max_parallel=self.config.ajet.rollout.max_env_worker, tokenizer=self.tokenizer, ) @@ -382,6 +392,7 @@ def fit(self): # noqa: C901 # load checkpoint before doing anything self._load_checkpoint() self.checkpoint_manager.update_weights(self.global_steps) + self.checkpoint_manager.sleep_replicas() # perform validation before training # currently, we only support validation using the reward_function. @@ -446,7 +457,7 @@ def fit(self): # noqa: C901 # generate a batch logger.info("rollout step begin") with marked_timer("gen", timing_raw, color="red"): - assert self.async_rollout_mode + # assert self.async_rollout_mode logger.info("wake up begin") self.checkpoint_manager.update_weights(self.global_steps) self._update_interchange_server_status_flag("ENGINE.ROLLING") @@ -526,7 +537,6 @@ def fit(self): # noqa: C901 self.checkpoint_manager.sleep_replicas() logger.info("rollout step end") - batch.non_tensor_batch["uid"] = np.array( [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object, @@ -546,45 +556,60 @@ def fit(self): # noqa: C901 self._balance_batch(batch, metrics=metrics) # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum( - batch.batch["attention_mask"], dim=-1 - ).tolist() + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() with marked_timer("reward", timing_raw, color="yellow"): - if self.config.reward_model.launch_reward_fn_async: - raise NotImplementedError( - "launch_reward_fn_async is not supported in GRPO yet." - ) - else: - reward_tensor, reward_extra_infos_dict = compute_reward( - batch, self.reward_fn - ) + reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) # recompute old_log_probs - logger.info("+ compute log_probs begin") - with marked_timer("old_log_prob", timing_raw, color="blue"): - old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) - entropys = old_log_prob.batch["entropys"] - response_masks = batch.batch["response_mask"] - loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_loss = agg_loss( - loss_mat=entropys, - loss_mask=response_masks, - loss_agg_mode=loss_agg_mode, + # Operating Mode Selection: + # - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ) + # - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ) + # Note: π_old computed once per data batch, serves as stable reference during mini-batch updates + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False) + if bypass_recomputing_logprobs: # Use `rollout_log_probs` + from verl.trainer.ppo.rollout_corr_helper import apply_bypass_mode + + apply_bypass_mode( + batch=batch, + rollout_corr_config=rollout_corr_config, + policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss, ) - assert not torch.isnan( - entropy_loss - ).item(), "Entropy loss should not be NaN, something must have gone terribly wrong." - old_log_prob_metrics = {"actor/entropy": entropy_loss.detach().item()} - metrics.update(old_log_prob_metrics) - old_log_prob.batch.pop("entropys") - batch = batch.union(old_log_prob) - - if "rollout_log_probs" in batch.batch.keys(): - # TODO: we may want to add diff of probs too. - from verl.utils.debug.metrics import calculate_debug_metrics - - metrics.update(calculate_debug_metrics(batch)) + else: # Recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, color="blue"): + old_log_prob, old_log_prob_mfu = self._compute_old_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + actor_config = self.config.actor_rollout_ref.actor + entropy_agg = agg_loss( + loss_mat=entropys, + loss_mask=response_masks, + loss_agg_mode=actor_config.loss_agg_mode, + loss_scale_factor=actor_config.loss_scale_factor, + ) + old_log_prob_metrics = { + "actor/entropy": entropy_agg.detach().item(), + "perf/mfu/actor_infer": old_log_prob_mfu, + } + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + if "routed_experts" in batch.batch and "routed_experts" in old_log_prob.batch: + raise ValueError( + "Detected conflicting router replay configuration: " + "router_replay.mode='R2' and enable_rollout_routing_replay=True " + "cannot be enabled simultaneously. " + "The enable_rollout_routing_replay option is only used in R3 mode; " + "it should not be set when using R2 mode." + ) + batch = batch.union(old_log_prob) + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + from verl.utils.debug.metrics import calculate_debug_metrics + + metrics.update(calculate_debug_metrics(batch)) + + assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}' if self.use_reference_policy: # compute reference log_prob @@ -607,23 +632,33 @@ def fit(self): # noqa: C901 batch.batch["token_level_scores"] = reward_tensor if reward_extra_infos_dict: - batch.non_tensor_batch.update( - {k: np.array(v) for k, v in reward_extra_infos_dict.items()} - ) + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) # compute rewards. apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: batch, kl_metrics = apply_kl_penalty( - batch, - kl_ctrl=self.kl_ctrl_in_reward, - kl_penalty=self.config.algorithm.kl_penalty, + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty, ) metrics.update(kl_metrics) else: batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] - # compute advantages, executed on the driver process + # Compute rollout correction: IS weights, rejection sampling, and metrics + # Only runs in decoupled mode (computes once per batch using stable π_old) + # In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout + if ( + rollout_corr_config is not None + and "rollout_log_probs" in batch.batch + and not bypass_recomputing_logprobs # Only in decoupled mode + ): + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch + + # Compute IS weights, apply rejection sampling, compute metrics + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) + # compute advantages, executed on the driver process norm_adv_by_std_in_grpo = self.config.algorithm.get( "norm_adv_by_std_in_grpo", True ) # GRPO adv normalization factor @@ -649,8 +684,7 @@ def fit(self): # noqa: C901 if self.config.trainer.critic_warmup <= self.global_steps: # update actor with marked_timer("update_actor", timing_raw, color="red"): - batch.meta_info["multi_turn"] = True - actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output = self._update_actor(batch) actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) diff --git a/ajet/backbone/verl/__init__.py b/ajet/backbone/verl/__init__.py index 77d833bb..a92acb07 100644 --- a/ajet/backbone/verl/__init__.py +++ b/ajet/backbone/verl/__init__.py @@ -1,3 +1,11 @@ from .fsdp_workers import AjetActorRolloutRefWorker, AjetAsyncActorRolloutRefWorker +from .actor_config import AjetActorConfig, AjetFSDPActorConfig +from .dp_actor import AjetDataParallelPPOActor -__all__ = ["AjetActorRolloutRefWorker", "AjetAsyncActorRolloutRefWorker"] +__all__ = [ + "AjetActorRolloutRefWorker", + "AjetAsyncActorRolloutRefWorker", + "AjetActorConfig", + "AjetFSDPActorConfig", + "AjetDataParallelPPOActor", +] diff --git a/ajet/backbone/verl/actor_config.py b/ajet/backbone/verl/actor_config.py new file mode 100644 index 00000000..85315aba --- /dev/null +++ b/ajet/backbone/verl/actor_config.py @@ -0,0 +1,47 @@ +# Copyright 2025 Alibaba Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Ajet extensions for verl ActorConfig. +Adds `override_ppo_mini_batch_num` field to control the number of optimizer steps per train-batch-step. +""" + +from dataclasses import dataclass, field +from typing import Optional + +from verl.workers.config.actor import ActorConfig, FSDPActorConfig + + +@dataclass +class AjetActorConfig(ActorConfig): + """ActorConfig extended with ajet-specific fields. + + Additional fields: + override_ppo_mini_batch_num (Optional[int]): If > 0, overrides ppo_mini_batch_size + by computing mini_batch_split_size = ceil(batch_size / override_ppo_mini_batch_num). + """ + + override_ppo_mini_batch_num: Optional[int] = None + + +@dataclass +class AjetFSDPActorConfig(FSDPActorConfig): + """FSDPActorConfig extended with ajet-specific fields. + + Additional fields: + override_ppo_mini_batch_num (Optional[int]): If > 0, overrides ppo_mini_batch_size + by computing mini_batch_split_size = ceil(batch_size / override_ppo_mini_batch_num). + """ + + override_ppo_mini_batch_num: Optional[int] = None diff --git a/ajet/backbone/verl/dp_actor.py b/ajet/backbone/verl/dp_actor.py new file mode 100644 index 00000000..1f7ca3a9 --- /dev/null +++ b/ajet/backbone/verl/dp_actor.py @@ -0,0 +1,228 @@ +# Copyright 2025 Alibaba Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Ajet extension for verl DataParallelPPOActor. +Overrides `update_policy` to support `override_ppo_mini_batch_num` and add debug logging. +""" + +import logging +import math +import os + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.tensor import DTensor + +import verl.utils.torch_functional as verl_F +from verl import DataProto +from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty +from verl.utils.device import get_device_id +from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ +from verl.utils.profiler import GPUMemoryLogger +from verl.utils.py_functional import append_to_dict +from verl.utils.seqlen_balancing import prepare_dynamic_batch +from verl.workers.actor.dp_actor import DataParallelPPOActor + +__all__ = ["AjetDataParallelPPOActor"] + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class AjetDataParallelPPOActor(DataParallelPPOActor): + """DataParallelPPOActor with ajet-specific modifications: + + 1. Supports `override_ppo_mini_batch_num` to control the number of optimizer steps per train-batch-step. + 2. Adds debug print for tensor shapes during training. + """ + + @GPUMemoryLogger(role="dp actor", logger=logger) + def update_policy(self, data: DataProto): + # make sure we are in training mode + self.actor_module.train() + + temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error + pad_token_id = data.meta_info.get("pad_token_id", 0) + + select_keys = [ + "responses", + "response_mask", + "input_ids", + "attention_mask", + "position_ids", + "old_log_probs", + "advantages", + ] + if self.use_prefix_grouper and "prompts" in data.batch.keys(): + select_keys.append("prompts") + if self.config.use_kl_loss: + select_keys.append("ref_log_prob") + # Include pre-computed IS weights if present in batch + # Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True + if "rollout_is_weights" in data.batch.keys(): + select_keys.append("rollout_is_weights") + # Include rollout_log_probs for computing rollout_corr metrics in bypass mode + if "rollout_log_probs" in data.batch.keys(): + select_keys.append("rollout_log_probs") + + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + non_tensor_select_keys = [] + if has_multi_modal_inputs: + non_tensor_select_keys.append("multi_modal_inputs") + if self.use_prefix_grouper and "uid" in data.non_tensor_batch.keys(): + non_tensor_select_keys.append("uid") + + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. https://arxiv.org/abs/1707.06347 + # [AJET] Support override_ppo_mini_batch_num to control the number of optimizer steps + if self.config.override_ppo_mini_batch_num > 0: + mini_batch_split_size = math.ceil(data.batch.batch_size[0] / self.config.override_ppo_mini_batch_num) + else: + mini_batch_split_size = self.config.ppo_mini_batch_size + + mini_batches = data.split(mini_batch_split_size) + + on_policy = len(mini_batches) == 1 and self.config.ppo_epochs == 1 + + metrics = { + "actor/pg_loss": 0.0, + "actor/kl_loss": 0.0, + } + for _ in range(self.config.ppo_epochs): + for batch_idx, mini_batch in enumerate(mini_batches): + if self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len) + else: + self.gradient_accumulation = ( + self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + ) + micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) + + self.actor_optimizer.zero_grad() + + for micro_batch in micro_batches: + micro_batch = micro_batch.to(get_device_id()) + micro_batch_metrics = {} + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch, "pad_token_id": pad_token_id} + response_mask = model_inputs["response_mask"] + old_log_prob = model_inputs["old_log_probs"] + advantages = model_inputs["advantages"] + # [AJET] Debug logging for tensor shapes + input_ids = model_inputs["input_ids"] + print(f'*** Current tensor shape, input_ids {input_ids.shape}, response {response_mask.shape}') + + entropy_coeff = self.config.entropy_coeff + loss_agg_mode = self.config.loss_agg_mode + + calculate_entropy = self.config.calculate_entropy or (entropy_coeff != 0) + + if self.config.use_dynamic_bsz: + loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size + else: + loss_scale_factor = 1 / self.gradient_accumulation + + # all return: (bsz, response_length) + outputs = self._forward_micro_batch( + model_inputs, temperature=temperature, calculate_entropy=calculate_entropy + ) + log_prob = outputs["log_probs"] + entropy = outputs["entropys"] if calculate_entropy else None + + # for fully_async_policy + if hasattr(self.config, "use_rollout_log_probs") and self.config.use_rollout_log_probs: + old_log_prob = model_inputs["old_log_probs"] + else: + if on_policy: + old_log_prob = log_prob.detach() + else: + old_log_prob = model_inputs["old_log_probs"] + + loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") + # vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla + + # Extract pre-computed rollout correction weights if present + # Weights are computed centrally in trainer and added when algorithm.rollout_is=True + rollout_is_weights = model_inputs.get("rollout_is_weights", None) + + # gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg + # clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov + policy_loss_fn = get_policy_loss_fn(loss_mode) + + # Compute policy loss (any function is expected to return 2 values) + pg_loss, pg_metrics = policy_loss_fn( + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + loss_agg_mode=loss_agg_mode, + config=self.config, + rollout_is_weights=rollout_is_weights, + ) + micro_batch_metrics.update(pg_metrics) + + # Skip if using bypass_mode loss (metrics already computed in pg_metrics) + rollout_log_prob = model_inputs.get("rollout_log_probs", None) + if loss_mode != "bypass_mode" and rollout_log_prob is not None: + # Compute metrics using CURRENT policy π_θ vs π_rollout + # Tracks evolving off-policy gap as π_θ updates during mini-batch training + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_corr_metrics_from_logprobs + + rollout_corr_metrics = compute_rollout_corr_metrics_from_logprobs( + log_prob=log_prob, + rollout_log_prob=rollout_log_prob, + response_mask=response_mask, + ) + micro_batch_metrics.update(rollout_corr_metrics) + + policy_loss = pg_loss + if calculate_entropy and entropy is not None: + entropy_agg = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + micro_batch_metrics["actor/entropy"] = entropy_agg.detach().item() + if entropy_coeff != 0: + policy_loss -= entropy_agg * entropy_coeff + + if self.config.use_kl_loss: + ref_log_prob = model_inputs["ref_log_prob"] + # compute kl loss + kld = kl_penalty( + logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type + ) + kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef + metrics["actor/kl_loss"] += kl_loss.detach().item() * loss_scale_factor + micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef + + if self.config.use_dynamic_bsz: + # relative to the dynamic bsz + loss = policy_loss * loss_scale_factor + else: + loss = policy_loss * loss_scale_factor + if self.scaler is not None: + self.scaler.scale(loss).backward() + else: + loss.backward() + + metrics["actor/pg_loss"] += pg_loss.detach().item() * loss_scale_factor + append_to_dict(metrics, micro_batch_metrics) + + grad_norm = self._optimizer_step() + mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()} + append_to_dict(metrics, mini_batch_metrics) + self.actor_optimizer.zero_grad() + return metrics diff --git a/ajet/backbone/verl/fsdp_workers.py b/ajet/backbone/verl/fsdp_workers.py index e5c0fc88..826ad78c 100644 --- a/ajet/backbone/verl/fsdp_workers.py +++ b/ajet/backbone/verl/fsdp_workers.py @@ -283,7 +283,8 @@ def __init__(self, config: DictConfig, role: str, **kwargs): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): - from verl.workers.actor import DataParallelPPOActor + # from verl.workers.actor import DataParallelPPOActor + from ajet.backbone.verl.dp_actor import AjetDataParallelPPOActor as DataParallelPPOActor # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get("external_lib", None)) diff --git a/ajet/default_config/verl/verl_default.yaml b/ajet/default_config/verl/verl_default.yaml index bb26fb28..c433677d 100644 --- a/ajet/default_config/verl/verl_default.yaml +++ b/ajet/default_config/verl/verl_default.yaml @@ -73,7 +73,7 @@ actor_rollout_ref: override_ppo_mini_batch_num: 1 # special in agentjet ppo_micro_batch_size: null ppo_micro_batch_size_per_gpu: null - use_dynamic_bsz: false + use_dynamic_bsz: true ppo_max_token_len_per_gpu: 16384 clip_ratio: 0.2 clip_ratio_low: 0.2 @@ -92,6 +92,7 @@ actor_rollout_ref: clip_ratio_c: 3.0 loss_agg_mode: token-mean loss_scale_factor: null + global_batch_info: {} entropy_coeff: 0 calculate_entropy: false use_kl_loss: false @@ -245,7 +246,7 @@ actor_rollout_ref: prompt_length: ${oc.select:data.max_prompt_length,512} response_length: ${oc.select:data.max_response_length,512} dtype: bfloat16 - gpu_memory_utilization: 0.95 + gpu_memory_utilization: 0.80 ignore_eos: false enforce_eager: false cudagraph_capture_sizes: null @@ -300,7 +301,7 @@ actor_rollout_ref: calculate_log_probs: false agent: _target_: verl.workers.config.AgentLoopConfig - num_workers: 8 + num_workers: 1 default_agent_loop: single_turn_agent agent_loop_config_path: null custom_async_server: @@ -405,8 +406,7 @@ data: max_response_length: 512 train_batch_size: 1024 val_batch_size: null - tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path, - null} + tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path,null} return_raw_input_ids: false return_raw_chat: true return_full_prompt: false diff --git a/ajet/task_rollout/async_llm_bridge.py b/ajet/task_rollout/async_llm_bridge.py index b855fd9d..f954bf04 100644 --- a/ajet/task_rollout/async_llm_bridge.py +++ b/ajet/task_rollout/async_llm_bridge.py @@ -98,6 +98,10 @@ async def llm_chat_verl( logprob_array = final_res.log_probs # routed_experts = final_res.routed_experts # vllm_stop_reason = final_res.stop_reason + if "decoded_string" in final_res.extra_fields: + decoded_string_array = final_res.extra_fields["decoded_string"] + else: + decoded_string_array = [self.tokenizer.decode(token_x) for token_x in token_array] decoded_text = self.tokenizer.decode(token_array) # type: ignore @@ -146,6 +150,7 @@ async def llm_chat_verl( "completion_tokens": len(token_array), # type: ignore "total_tokens": len(prompt_token_ids) + len(token_array), # type: ignore } + # from ajet import bp; bp("DECODE") return { "role": "assistant", "request_id": request_id, @@ -156,10 +161,10 @@ async def llm_chat_verl( "tokens": [ TokenAndProb( token_id=token_id, - logprob=logprob[token_id].logprob, # Warning: vllm logprob does not participant training (not reliable enough), for log only. - decoded_string=logprob[token_id].decoded_token, + logprob=logprob, # Warning: vllm logprob does not participant training (not reliable enough), for log only. + decoded_string=decoded_string, ) - for token_id, logprob in zip(token_array, logprob_array) # type: ignore + for token_id, logprob, decoded_string in zip(token_array, logprob_array, decoded_string_array) # type: ignore ], } diff --git a/ajet/utils/core_env_vars.py b/ajet/utils/core_env_vars.py index 9df18216..daf5386a 100644 --- a/ajet/utils/core_env_vars.py +++ b/ajet/utils/core_env_vars.py @@ -26,9 +26,17 @@ def get_runtime_env(config, is_trinity: bool = False) -> dict: runtime_env = { "env_vars": { - "VLLM_USE_V1": "1", "NCCL_DEBUG": "WARN", + + "VLLM_USE_V1": "1", "VLLM_LOGGING_LEVEL": "WARN", + "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true", + "VLLM_DISABLE_COMPILE_CACHE": "1", + + "HCCL_HOST_SOCKET_PORT_RANGE": "auto", + "HCCL_NPU_SOCKET_PORT_RANGE": "auto", + + "CUDA_DEVICE_MAX_CONNECTIONS": "1", "TOKENIZERS_PARALLELISM": "true", # use ajet.backbone as plugin directory "TRINITY_PLUGIN_DIRS": str((Path(__file__).parent.parent / "backbone").resolve()), diff --git a/pyproject.toml b/pyproject.toml index 7aebcbc1..af95660b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "beast-logger>=0.1.3", "pytest>=8.0.0", "hydra-core", + "cachetools", "datasets>=4", "pip", ] @@ -38,8 +39,7 @@ dependencies = [ [project.optional-dependencies] verl = [ - "transformers<5", - "verl-bundle[vllm]==0.5.0.post2", + "verl[vllm] @ git+https://github.com/volcengine/verl.git@016c1d5a7a3f2973d68fda2f7abe5e7df9e05e00" ] trinity = [ From 67e809291bb408fdaf146cf276fdef2c7113cf69 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Tue, 17 Mar 2026 14:42:42 +0800 Subject: [PATCH 03/17] support qwen-3.5 --- ajet/default_config/ajet_default.py | 2 +- ajet/default_config/verl/verl_default.yaml | 8 ++++---- ajet/task_runner/base_runner.py | 12 ++++++------ scripts/download_model.py | 4 ++-- tests/bench/benchmark_math/benchmark_math.yaml | 6 +++--- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/ajet/default_config/ajet_default.py b/ajet/default_config/ajet_default.py index 22902d21..9832e824 100644 --- a/ajet/default_config/ajet_default.py +++ b/ajet/default_config/ajet_default.py @@ -28,7 +28,7 @@ class AjetData: @dataclass class AjetRollout: user_workflow: str = "tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow" - n_vllm_engine: int = 1 + n_vllm_engine: int = 1 # this argument is NOT effective when NOT using trinity tensor_model_parallel_size: int = 1 num_repeat: int = 8 diff --git a/ajet/default_config/verl/verl_default.yaml b/ajet/default_config/verl/verl_default.yaml index c433677d..391ee3f3 100644 --- a/ajet/default_config/verl/verl_default.yaml +++ b/ajet/default_config/verl/verl_default.yaml @@ -166,7 +166,7 @@ actor_rollout_ref: use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} log_prob_micro_batch_size: null log_prob_micro_batch_size_per_gpu: null - log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_use_dynamic_bsz: true log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} profiler: _target_: verl.utils.profiler.ProfilerConfig @@ -248,7 +248,7 @@ actor_rollout_ref: dtype: bfloat16 gpu_memory_utilization: 0.80 ignore_eos: false - enforce_eager: false + enforce_eager: true cudagraph_capture_sizes: null free_cache_engine: true tensor_model_parallel_size: 1 @@ -265,7 +265,7 @@ actor_rollout_ref: load_format: dummy log_prob_micro_batch_size: null log_prob_micro_batch_size_per_gpu: 1 - log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_use_dynamic_bsz: true log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} disable_log_stats: true do_sample: true @@ -508,7 +508,7 @@ critic: ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256} ppo_micro_batch_size: null ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null} - use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + use_dynamic_bsz: true ppo_max_token_len_per_gpu: 32768 forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1} diff --git a/ajet/task_runner/base_runner.py b/ajet/task_runner/base_runner.py index b3e68f6f..a52f15a4 100644 --- a/ajet/task_runner/base_runner.py +++ b/ajet/task_runner/base_runner.py @@ -82,12 +82,12 @@ async def wrapper_type_asyncio(self, workflow_cls: Type[Workflow], workflow_task # malloc garbage collection del user_workflow - # run gc in a thread-safe way - if gc_lock.acquire(blocking=False): - try: - gc.collect() - finally: - gc_lock.release() + # # run gc in a thread-safe way + # if gc_lock.acquire(blocking=False): + # try: + # gc.collect() + # finally: + # gc_lock.release() return result diff --git a/scripts/download_model.py b/scripts/download_model.py index 45ce6008..b3492f60 100644 --- a/scripts/download_model.py +++ b/scripts/download_model.py @@ -4,9 +4,9 @@ from loguru import logger from modelscope import snapshot_download - cache_dir = input("model path (./modelscope_cache): ").strip() + cache_dir = input("model path (/mnt/data_cpfs/model_cache/modelscope/hub/Qwen): ").strip() if not cache_dir: - cache_dir = "./modelscope_cache" + cache_dir = "/mnt/data_cpfs/model_cache/modelscope/hub/Qwen" res = snapshot_download(input("model name: ").strip(), cache_dir=cache_dir) logger.success(res) diff --git a/tests/bench/benchmark_math/benchmark_math.yaml b/tests/bench/benchmark_math/benchmark_math.yaml index f0f8d896..e4f5922f 100644 --- a/tests/bench/benchmark_math/benchmark_math.yaml +++ b/tests/bench/benchmark_math/benchmark_math.yaml @@ -14,7 +14,8 @@ ajet: model: # ✨✨✨✨ 设置待训练的模型 - path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct + # path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3.5-9B rollout: user_workflow: "tutorial.example_math_agent.math_agent->ExampleMathLearn" # ✨✨✨✨ 编写并选择Agent @@ -31,7 +32,6 @@ ajet: - "wrong_toolcall" max_response_length_in_one_turn: 1024 max_model_len: 10000 - n_vllm_engine: 2 data: train_batch_size: 100 @@ -48,7 +48,7 @@ ajet: total_epochs: 100 logger: swanlab nnodes: 1 - n_gpus_per_node: 4 + n_gpus_per_node: 8 From ad08055f5a7acc941b8bc82be17ca1b789b8cc92 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Tue, 17 Mar 2026 15:24:19 +0800 Subject: [PATCH 04/17] patch for higher version vllm --- ajet/task_rollout/async_llm_bridge.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ajet/task_rollout/async_llm_bridge.py b/ajet/task_rollout/async_llm_bridge.py index f954bf04..b2b48a91 100644 --- a/ajet/task_rollout/async_llm_bridge.py +++ b/ajet/task_rollout/async_llm_bridge.py @@ -8,7 +8,10 @@ from loguru import logger from omegaconf import DictConfig from pydantic import BaseModel -from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser +try: + from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser +except: + from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser # vllm 0.17.x moved this class elsewhere from vllm.outputs import RequestOutput as VerlVllmRequestOutput from verl.workers.rollout.replica import TokenOutput from agentscope.model import ChatResponse as AgentScopeChatResponse From 7909969ded683d547650497583f697d156a65f28 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Tue, 17 Mar 2026 16:19:17 +0800 Subject: [PATCH 05/17] stable vllm and verl 0.8.0 --- ajet/default_config/verl/verl_default.yaml | 6 +- docs/en/installation.md | 12 - requirements_stable_vllm.txt | 269 ++++++++++++++++++ .../bench/benchmark_math/benchmark_math.yaml | 4 +- 4 files changed, 274 insertions(+), 17 deletions(-) create mode 100644 requirements_stable_vllm.txt diff --git a/ajet/default_config/verl/verl_default.yaml b/ajet/default_config/verl/verl_default.yaml index 391ee3f3..52d967b7 100644 --- a/ajet/default_config/verl/verl_default.yaml +++ b/ajet/default_config/verl/verl_default.yaml @@ -246,7 +246,7 @@ actor_rollout_ref: prompt_length: ${oc.select:data.max_prompt_length,512} response_length: ${oc.select:data.max_response_length,512} dtype: bfloat16 - gpu_memory_utilization: 0.80 + gpu_memory_utilization: 0.8 ignore_eos: false enforce_eager: true cudagraph_capture_sizes: null @@ -618,9 +618,9 @@ reward: model_path: null rollout: _target_: verl.workers.config.RolloutConfig - name: ??? + name: vllm dtype: bfloat16 - gpu_memory_utilization: 0.5 + gpu_memory_utilization: 0.8 enforce_eager: true cudagraph_capture_sizes: null free_cache_engine: true diff --git a/docs/en/installation.md b/docs/en/installation.md index ba3806c6..5e3e9db6 100644 --- a/docs/en/installation.md +++ b/docs/en/installation.md @@ -104,18 +104,6 @@ AgentJet supports multiple backbones, you can choose any of them depending on yo ``` -=== "Trinity (aliyun)" - - ```bash - # Install with `trinity` training backbone for fully asynchronous RFT: - - uv venv --python=3.10 - source .venv/bin/activate - uv pip install -i https://mirrors.aliyun.com/pypi/simple/ -e .[trinity] - uv pip install -i https://mirrors.aliyun.com/pypi/simple/ --verbose flash-attn --no-deps --no-build-isolation --no-cache - ``` - - | Backbone | VERL | Trinity-RFT | | -------- |-------- | ------------- | | Core design | Share-GPU actor-rollout engine (colocate) | Async actor-rollout engine | diff --git a/requirements_stable_vllm.txt b/requirements_stable_vllm.txt new file mode 100644 index 00000000..ea753e95 --- /dev/null +++ b/requirements_stable_vllm.txt @@ -0,0 +1,269 @@ +absl-py==2.4.0 +accelerate==1.13.0 +agentscope==1.0.8 +aiohappyeyeballs==2.6.1 +aiohttp==3.13.3 +aiohttp-cors==0.8.1 +aioitertools==0.13.0 +aiosignal==1.4.0 +annotated-doc==0.0.4 +annotated-types==0.7.0 +anthropic==0.71.0 +antlr4-python3-runtime==4.9.3 +anyio==4.12.1 +apache-tvm-ffi==0.1.9 +astor==0.8.1 +attrs==25.4.0 +bcrypt==5.0.0 +beast-logger==0.1.4 +bidict==0.23.1 +blake3==1.0.8 +build==1.4.0 +cached-property==2.0.1 +cachetools==7.0.5 +cbor2==5.8.0 +certifi==2026.2.25 +cffi==2.0.0 +charset-normalizer==3.4.6 +chromadb==1.5.5 +click==8.3.1 +cloudpickle==3.1.2 +codetiming==1.4.0 +colorful==0.5.8 +compressed-tensors==0.12.2 +cryptography==46.0.5 +cuda-bindings==13.2.0 +cuda-pathfinder==1.4.2 +cuda-python==13.2.0 +cupy-cuda12x==13.6.0 +dashscope==1.25.14 +datasets==4.7.0 +debugpy==1.8.20 +depyf==0.20.0 +dill==0.4.0 +diskcache==5.6.3 +distlib==0.4.0 +distro==1.9.0 +dnspython==2.8.0 +docstring-parser==0.17.0 +durationpy==0.10 +einops==0.8.2 +email-validator==2.3.0 +farama-notifications==0.0.4 +fastapi==0.135.1 +fastapi-cli==0.0.24 +fastapi-cloud-cli==0.15.0 +fastar==0.8.0 +fastrlock==0.8.3 +filelock==3.25.2 +flash-attn==2.8.3 +flashinfer-python==0.5.3 +flatbuffers==25.12.19 +frozenlist==1.8.0 +fsspec==2026.2.0 +gguf==0.18.0 +gitdb==4.0.12 +gitpython==3.1.46 +google-api-core==2.30.0 +google-auth==2.49.1 +googleapis-common-protos==1.73.0 +grpcio==1.78.0 +gymnasium==1.2.3 +h11==0.16.0 +h2==4.3.0 +hf-xet==1.4.2 +hpack==4.1.0 +httpcore==1.0.9 +httptools==0.7.1 +httpx==0.28.1 +httpx-sse==0.4.3 +huggingface-hub==0.36.2 +hydra-core==1.3.2 +hyperframe==6.1.0 +idna==3.11 +importlib-metadata==8.7.1 +importlib-resources==6.5.2 +iniconfig==2.3.0 +interegular==0.3.3 +jieba==0.42.1 +jinja2==3.1.6 +jiter==0.13.0 +jmespath==1.1.0 +json-repair==0.58.6 +json5==0.13.0 +jsonschema==4.26.0 +jsonschema-specifications==2025.9.1 +kubernetes==35.0.0 +lark==1.2.2 +llguidance==1.3.0 +llvmlite==0.44.0 +lm-format-enforcer==0.11.3 +loguru==0.7.3 +markdown==3.10.2 +markdown-it-py==4.0.0 +markupsafe==3.0.3 +mcp==1.26.0 +mdurl==0.1.2 +mistral-common==1.10.0 +mmh3==5.2.1 +model-hosting-container-standards==0.1.13 +modelscope==1.35.0 +mpmath==1.3.0 +msgpack==1.1.2 +msgspec==0.20.0 +multidict==6.7.1 +multiprocess==0.70.18 +networkx==3.6.1 +ninja==1.13.0 +numba==0.61.2 +numpy==1.26.4 +nvidia-cublas-cu12==12.8.4.1 +nvidia-cuda-cupti-cu12==12.8.90 +nvidia-cuda-nvrtc-cu12==12.8.93 +nvidia-cuda-runtime-cu12==12.8.90 +nvidia-cudnn-cu12==9.10.2.21 +nvidia-cudnn-frontend==1.19.1 +nvidia-cufft-cu12==11.3.3.83 +nvidia-cufile-cu12==1.13.1.3 +nvidia-curand-cu12==10.3.9.90 +nvidia-cusolver-cu12==11.7.3.90 +nvidia-cusparse-cu12==12.5.8.93 +nvidia-cusparselt-cu12==0.7.1 +nvidia-cutlass-dsl==4.4.2 +nvidia-cutlass-dsl-libs-base==4.4.2 +nvidia-ml-py==13.590.48 +nvidia-nccl-cu12==2.27.5 +nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvshmem-cu12==3.3.20 +nvidia-nvtx-cu12==12.8.90 +oauthlib==3.3.1 +omegaconf==2.3.0 +onnxruntime==1.24.3 +openai==2.28.0 +openai-harmony==0.0.8 +opencensus==0.11.4 +opencensus-context==0.1.3 +opencv-python-headless==4.11.0.86 +opentelemetry-api==1.40.0 +opentelemetry-exporter-otlp==1.40.0 +opentelemetry-exporter-otlp-proto-common==1.40.0 +opentelemetry-exporter-otlp-proto-grpc==1.40.0 +opentelemetry-exporter-otlp-proto-http==1.40.0 +opentelemetry-exporter-prometheus==0.61b0 +opentelemetry-proto==1.40.0 +opentelemetry-sdk==1.40.0 +opentelemetry-semantic-conventions==0.61b0 +orjson==3.11.7 +outlines-core==0.2.11 +overrides==7.7.0 +packaging==25.0 +pandas==3.0.1 +partial-json-parser==0.2.1.1.post7 +peft==0.18.1 +pillow==12.1.1 +pip==26.0.1 +platformdirs==4.9.4 +pluggy==1.6.0 +prettytable==3.17.0 +prometheus-client==0.24.1 +prometheus-fastapi-instrumentator==7.1.0 +propcache==0.4.1 +proto-plus==1.27.1 +protobuf==6.33.5 +psutil==7.2.2 +py-cpuinfo==9.0.0 +py-spy==0.4.1 +pyarrow==23.0.1 +pyasn1==0.6.2 +pyasn1-modules==0.4.2 +pybase64==1.4.3 +pybind11==3.0.2 +pycountry==26.2.16 +pycparser==3.0 +pydantic==2.12.5 +pydantic-core==2.41.5 +pydantic-extra-types==2.11.1 +pydantic-settings==2.13.1 +pyecharts==2.1.0 +pygame==2.6.1 +pygments==2.19.2 +pyjwt==2.12.1 +pylatexenc==2.10 +pypika==0.51.1 +pyproject-hooks==1.2.0 +pytest==9.0.2 +python-datauri==3.0.2 +python-dateutil==2.9.0.post0 +python-discovery==1.1.3 +python-dotenv==1.2.2 +python-engineio==4.13.1 +python-json-logger==4.0.0 +python-multipart==0.0.22 +python-socketio==5.16.1 +pyvers==0.1.0 +pyyaml==6.0.3 +pyzmq==27.1.0 +ray==2.54.0 +referencing==0.37.0 +regex==2026.2.28 +requests==2.32.5 +requests-oauthlib==2.0.0 +rich==13.9.4 +rich-toolkit==0.19.7 +rignore==0.7.6 +rpds-py==0.30.0 +safetensors==0.7.0 +scipy==1.17.1 +sentencepiece==0.2.1 +sentry-sdk==2.54.0 +setproctitle==1.3.7 +setuptools==80.10.2 +shellingham==1.5.4 +shortuuid==1.0.13 +simple-websocket==1.1.0 +simplejson==3.20.2 +six==1.17.0 +smart-open==7.5.1 +smmap==5.0.3 +sniffio==1.3.1 +sounddevice==0.5.5 +sse-starlette==3.3.2 +starlette==0.52.1 +supervisor==4.3.0 +swanlab==0.7.11 +sympy==1.14.0 +tabulate==0.10.0 +tenacity==9.1.4 +tensorboard==2.20.0 +tensorboard-data-server==0.7.2 +tensordict==0.10.0 +tiktoken==0.12.0 +tokenizers==0.22.2 +torch==2.9.0 +torchaudio==2.9.0 +torchdata==0.11.0 +torchvision==0.24.0 +tqdm==4.67.3 +transformers==4.57.6 +triton==3.5.0 +typer==0.24.1 +typing-extensions==4.15.0 +typing-inspection==0.4.2 +urllib3==2.6.3 +uvicorn==0.42.0 +uvloop==0.22.1 +verl @ git+https://github.com/volcengine/verl.git@016c1d5a7a3f2973d68fda2f7abe5e7df9e05e00 +virtualenv==21.2.0 +vllm==0.12.0 +wandb==0.25.1 +watchfiles==1.1.1 +wcwidth==0.6.0 +websocket-client==1.9.0 +websockets==16.0 +werkzeug==3.1.6 +wrapt==2.1.2 +wsproto==1.3.2 +xgrammar==0.1.27 +xxhash==3.6.0 +yarl==1.23.0 +zipp==3.23.0 diff --git a/tests/bench/benchmark_math/benchmark_math.yaml b/tests/bench/benchmark_math/benchmark_math.yaml index e4f5922f..b7689c91 100644 --- a/tests/bench/benchmark_math/benchmark_math.yaml +++ b/tests/bench/benchmark_math/benchmark_math.yaml @@ -14,8 +14,8 @@ ajet: model: # ✨✨✨✨ 设置待训练的模型 - # path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct - path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3.5-9B + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct + # path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3.5-9B rollout: user_workflow: "tutorial.example_math_agent.math_agent->ExampleMathLearn" # ✨✨✨✨ 编写并选择Agent From 8e2a95491c3901058b8f76d6d27cdf0cb008372f Mon Sep 17 00:00:00 2001 From: Qingxu Fu Date: Tue, 17 Mar 2026 17:11:14 +0800 Subject: [PATCH 06/17] patch docker builder --- pyproject.toml | 2 +- scripts/docker/dockerfile | 4 +- scripts/docker/dockerfile_zh | 7 +- .../docker/pyproject_for_docker_build.toml | 276 +++++++++++++++++- 4 files changed, 279 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index af95660b..f95f087c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "AgentJet" -version = "0.0.2" +version = "0.0.3" readme = "README.md" classifiers = [ "Development Status :: 3 - Alpha", diff --git a/scripts/docker/dockerfile b/scripts/docker/dockerfile index c485dee2..4b26435c 100644 --- a/scripts/docker/dockerfile +++ b/scripts/docker/dockerfile @@ -31,7 +31,7 @@ COPY scripts/docker/pyproject_for_docker_build.toml pyproject.toml RUN pip install uv # use uv to create a virtual environment and install dependencies -RUN uv venv /opt/venv --python=3.10 +RUN uv venv /opt/venv --python=3.12 ENV UV_HTTP_TIMEOUT=9999 @@ -41,7 +41,7 @@ RUN . /opt/venv/bin/activate && uv pip install flash_attn==2.8.3 --no-deps --no- # cache friendly layer for code changes COPY . . -RUN . /opt/venv/bin/activate && uv pip install -e .[verl] -i https://mirrors.aliyun.com/pypi/simple/ +RUN . /opt/venv/bin/activate && uv pip install -e .[verl] RUN wget https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/dataset.tar.gz RUN mkdir -p /mnt/data_cpfs/model_cache/modelscope RUN tar -xzf dataset.tar.gz -C /mnt/data_cpfs/model_cache/modelscope/ && rm dataset.tar.gz diff --git a/scripts/docker/dockerfile_zh b/scripts/docker/dockerfile_zh index d3cd70fb..381880be 100644 --- a/scripts/docker/dockerfile_zh +++ b/scripts/docker/dockerfile_zh @@ -32,7 +32,7 @@ COPY scripts/docker/pyproject_for_docker_build.toml pyproject.toml RUN pip install uv -i https://mirrors.aliyun.com/pypi/simple/ # use uv to create a virtual environment and install dependencies -RUN uv venv /opt/venv --python=3.10 +RUN uv venv /opt/venv --python=3.12 ENV UV_HTTP_TIMEOUT=9999 @@ -42,8 +42,9 @@ ENV UV_HTTP_TIMEOUT=9999 # for ZH users, install dependencies from aliyun mirror and use prebuilt flash_attn wheel RUN . /opt/venv/bin/activate && uv pip install -e .[verl] -i https://mirrors.aliyun.com/pypi/simple/ -RUN wget https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/flash_attn-2.8.3%2Bcu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl -RUN . /opt/venv/bin/activate && uv pip install flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl +RUN . /opt/venv/bin/activate && uv pip install flash_attn==2.8.3 --no-deps --no-cache-dir --no-build-isolation +# RUN wget https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/flash_attn-2.8.3%2Bcu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl +# RUN . /opt/venv/bin/activate && uv pip install flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl # cache friendly layer for code changes diff --git a/scripts/docker/pyproject_for_docker_build.toml b/scripts/docker/pyproject_for_docker_build.toml index 4b29c7ff..faf4498f 100644 --- a/scripts/docker/pyproject_for_docker_build.toml +++ b/scripts/docker/pyproject_for_docker_build.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "AgentJet" -version = "0.0.0" +version = "0.0.3" readme = "README.md" classifiers = [ "Development Status :: 3 - Alpha", @@ -19,7 +19,7 @@ requires-python = ">=3.10,<3.13" dependencies = [ "agentscope==1.0.8", "chromadb", - "httpx", + "httpx[http2]", "tenacity", "loguru", "debugpy", @@ -30,6 +30,7 @@ dependencies = [ "beast-logger>=0.1.3", "pytest>=8.0.0", "hydra-core", + "cachetools", "datasets>=4", "pip", ] @@ -38,8 +39,275 @@ dependencies = [ [project.optional-dependencies] verl = [ - "transformers<5", - "verl-bundle[vllm]==0.5.0.post2", + "absl-py==2.4.0", + "accelerate==1.13.0", + "agentscope==1.0.8", + "aiohappyeyeballs==2.6.1", + "aiohttp==3.13.3", + "aiohttp-cors==0.8.1", + "aioitertools==0.13.0", + "aiosignal==1.4.0", + "annotated-doc==0.0.4", + "annotated-types==0.7.0", + "anthropic==0.71.0", + "antlr4-python3-runtime==4.9.3", + "anyio==4.12.1", + "apache-tvm-ffi==0.1.9", + "astor==0.8.1", + "attrs==25.4.0", + "bcrypt==5.0.0", + "beast-logger==0.1.4", + "bidict==0.23.1", + "blake3==1.0.8", + "build==1.4.0", + "cached-property==2.0.1", + "cachetools==7.0.5", + "cbor2==5.8.0", + "certifi==2026.2.25", + "cffi==2.0.0", + "charset-normalizer==3.4.6", + "chromadb==1.5.5", + "click==8.3.1", + "cloudpickle==3.1.2", + "codetiming==1.4.0", + "colorful==0.5.8", + "compressed-tensors==0.12.2", + "cryptography==46.0.5", + "cuda-bindings==13.2.0", + "cuda-pathfinder==1.4.2", + "cuda-python==13.2.0", + "cupy-cuda12x==13.6.0", + "dashscope==1.25.14", + "datasets==4.7.0", + "debugpy==1.8.20", + "depyf==0.20.0", + "dill==0.4.0", + "diskcache==5.6.3", + "distlib==0.4.0", + "distro==1.9.0", + "dnspython==2.8.0", + "docstring-parser==0.17.0", + "durationpy==0.10", + "einops==0.8.2", + "email-validator==2.3.0", + "farama-notifications==0.0.4", + "fastapi==0.135.1", + "fastapi-cli==0.0.24", + "fastapi-cloud-cli==0.15.0", + "fastar==0.8.0", + "fastrlock==0.8.3", + "filelock==3.25.2", + "flashinfer-python==0.5.3", + "flatbuffers==25.12.19", + "frozenlist==1.8.0", + "fsspec==2026.2.0", + "gguf==0.18.0", + "gitdb==4.0.12", + "gitpython==3.1.46", + "google-api-core==2.30.0", + "google-auth==2.49.1", + "googleapis-common-protos==1.73.0", + "grpcio==1.78.0", + "gymnasium==1.2.3", + "h11==0.16.0", + "h2==4.3.0", + "hf-xet==1.4.2", + "hpack==4.1.0", + "httpcore==1.0.9", + "httptools==0.7.1", + "httpx==0.28.1", + "httpx-sse==0.4.3", + "huggingface-hub==0.36.2", + "hydra-core==1.3.2", + "hyperframe==6.1.0", + "idna==3.11", + "importlib-metadata==8.7.1", + "importlib-resources==6.5.2", + "iniconfig==2.3.0", + "interegular==0.3.3", + "jieba==0.42.1", + "jinja2==3.1.6", + "jiter==0.13.0", + "jmespath==1.1.0", + "json-repair==0.58.6", + "json5==0.13.0", + "jsonschema==4.26.0", + "jsonschema-specifications==2025.9.1", + "kubernetes==35.0.0", + "lark==1.2.2", + "llguidance==1.3.0", + "llvmlite==0.44.0", + "lm-format-enforcer==0.11.3", + "loguru==0.7.3", + "markdown==3.10.2", + "markdown-it-py==4.0.0", + "markupsafe==3.0.3", + "mcp==1.26.0", + "mdurl==0.1.2", + "mistral-common==1.10.0", + "mmh3==5.2.1", + "model-hosting-container-standards==0.1.13", + "modelscope==1.35.0", + "mpmath==1.3.0", + "msgpack==1.1.2", + "msgspec==0.20.0", + "multidict==6.7.1", + "multiprocess==0.70.18", + "networkx==3.6.1", + "ninja==1.13.0", + "numba==0.61.2", + "numpy==1.26.4", + "nvidia-cublas-cu12==12.8.4.1", + "nvidia-cuda-cupti-cu12==12.8.90", + "nvidia-cuda-nvrtc-cu12==12.8.93", + "nvidia-cuda-runtime-cu12==12.8.90", + "nvidia-cudnn-cu12==9.10.2.21", + "nvidia-cudnn-frontend==1.19.1", + "nvidia-cufft-cu12==11.3.3.83", + "nvidia-cufile-cu12==1.13.1.3", + "nvidia-curand-cu12==10.3.9.90", + "nvidia-cusolver-cu12==11.7.3.90", + "nvidia-cusparse-cu12==12.5.8.93", + "nvidia-cusparselt-cu12==0.7.1", + "nvidia-cutlass-dsl==4.4.2", + "nvidia-cutlass-dsl-libs-base==4.4.2", + "nvidia-ml-py==13.590.48", + "nvidia-nccl-cu12==2.27.5", + "nvidia-nvjitlink-cu12==12.8.93", + "nvidia-nvshmem-cu12==3.3.20", + "nvidia-nvtx-cu12==12.8.90", + "oauthlib==3.3.1", + "omegaconf==2.3.0", + "onnxruntime==1.24.3", + "openai==2.28.0", + "openai-harmony==0.0.8", + "opencensus==0.11.4", + "opencensus-context==0.1.3", + "opencv-python-headless==4.11.0.86", + "opentelemetry-api==1.40.0", + "opentelemetry-exporter-otlp==1.40.0", + "opentelemetry-exporter-otlp-proto-common==1.40.0", + "opentelemetry-exporter-otlp-proto-grpc==1.40.0", + "opentelemetry-exporter-otlp-proto-http==1.40.0", + "opentelemetry-exporter-prometheus==0.61b0", + "opentelemetry-proto==1.40.0", + "opentelemetry-sdk==1.40.0", + "opentelemetry-semantic-conventions==0.61b0", + "orjson==3.11.7", + "outlines-core==0.2.11", + "overrides==7.7.0", + "packaging==25.0", + "pandas==3.0.1", + "partial-json-parser==0.2.1.1.post7", + "peft==0.18.1", + "pillow==12.1.1", + "pip==26.0.1", + "platformdirs==4.9.4", + "pluggy==1.6.0", + "prettytable==3.17.0", + "prometheus-client==0.24.1", + "prometheus-fastapi-instrumentator==7.1.0", + "propcache==0.4.1", + "proto-plus==1.27.1", + "protobuf==6.33.5", + "psutil==7.2.2", + "py-cpuinfo==9.0.0", + "py-spy==0.4.1", + "pyarrow==23.0.1", + "pyasn1==0.6.2", + "pyasn1-modules==0.4.2", + "pybase64==1.4.3", + "pybind11==3.0.2", + "pycountry==26.2.16", + "pycparser==3.0", + "pydantic==2.12.5", + "pydantic-core==2.41.5", + "pydantic-extra-types==2.11.1", + "pydantic-settings==2.13.1", + "pyecharts==2.1.0", + "pygame==2.6.1", + "pygments==2.19.2", + "pyjwt==2.12.1", + "pylatexenc==2.10", + "pypika==0.51.1", + "pyproject-hooks==1.2.0", + "pytest==9.0.2", + "python-datauri==3.0.2", + "python-dateutil==2.9.0.post0", + "python-discovery==1.1.3", + "python-dotenv==1.2.2", + "python-engineio==4.13.1", + "python-json-logger==4.0.0", + "python-multipart==0.0.22", + "python-socketio==5.16.1", + "pyvers==0.1.0", + "pyyaml==6.0.3", + "pyzmq==27.1.0", + "ray==2.54.0", + "referencing==0.37.0", + "regex==2026.2.28", + "requests==2.32.5", + "requests-oauthlib==2.0.0", + "rich==13.9.4", + "rich-toolkit==0.19.7", + "rignore==0.7.6", + "rpds-py==0.30.0", + "safetensors==0.7.0", + "scipy==1.17.1", + "sentencepiece==0.2.1", + "sentry-sdk==2.54.0", + "setproctitle==1.3.7", + "setuptools==80.10.2", + "shellingham==1.5.4", + "shortuuid==1.0.13", + "simple-websocket==1.1.0", + "simplejson==3.20.2", + "six==1.17.0", + "smart-open==7.5.1", + "smmap==5.0.3", + "sniffio==1.3.1", + "sounddevice==0.5.5", + "sse-starlette==3.3.2", + "starlette==0.52.1", + "supervisor==4.3.0", + "swanlab==0.7.11", + "sympy==1.14.0", + "tabulate==0.10.0", + "tenacity==9.1.4", + "tensorboard==2.20.0", + "tensorboard-data-server==0.7.2", + "tensordict==0.10.0", + "tiktoken==0.12.0", + "tokenizers==0.22.2", + "torch==2.9.0", + "torchaudio==2.9.0", + "torchdata==0.11.0", + "torchvision==0.24.0", + "tqdm==4.67.3", + "transformers==4.57.6", + "triton==3.5.0", + "typer==0.24.1", + "typing-extensions==4.15.0", + "typing-inspection==0.4.2", + "urllib3==2.6.3", + "uvicorn==0.42.0", + "uvloop==0.22.1", + "verl @ git+https://github.com/volcengine/verl.git@016c1d5a7a3f2973d68fda2f7abe5e7df9e05e00", + "virtualenv==21.2.0", + "vllm==0.12.0", + "wandb==0.25.1", + "watchfiles==1.1.1", + "wcwidth==0.6.0", + "websocket-client==1.9.0", + "websockets==16.0", + "werkzeug==3.1.6", + "wrapt==2.1.2", + "wsproto==1.3.2", + "xgrammar==0.1.27", + "xxhash==3.6.0", + "yarl==1.23.0", + "zipp==3.23.0", + ] trinity = [ From b36bf738606daf13e2fad5b880e8572d41391b37 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Tue, 17 Mar 2026 18:32:03 +0800 Subject: [PATCH 07/17] upgrade python to 3.12 --- ajet/default_config/verl/verl_default.yaml | 4 ++-- docs/en/installation.md | 7 +++---- pyproject.toml | 4 +++- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/ajet/default_config/verl/verl_default.yaml b/ajet/default_config/verl/verl_default.yaml index 52d967b7..16952ab4 100644 --- a/ajet/default_config/verl/verl_default.yaml +++ b/ajet/default_config/verl/verl_default.yaml @@ -246,7 +246,7 @@ actor_rollout_ref: prompt_length: ${oc.select:data.max_prompt_length,512} response_length: ${oc.select:data.max_response_length,512} dtype: bfloat16 - gpu_memory_utilization: 0.8 + gpu_memory_utilization: 0.85 ignore_eos: false enforce_eager: true cudagraph_capture_sizes: null @@ -620,7 +620,7 @@ reward: _target_: verl.workers.config.RolloutConfig name: vllm dtype: bfloat16 - gpu_memory_utilization: 0.8 + gpu_memory_utilization: 0.85 enforce_eager: true cudagraph_capture_sizes: null free_cache_engine: true diff --git a/docs/en/installation.md b/docs/en/installation.md index 5e3e9db6..74090759 100644 --- a/docs/en/installation.md +++ b/docs/en/installation.md @@ -11,7 +11,7 @@ This document provides a step-by-step guide to installing AgentJet. | Requirement | Detail | |-------------|---------| -| **Python** | 3.10 | +| **Python** | 3.12 | | Package Management | `uv` or `conda` | @@ -40,7 +40,7 @@ AgentJet supports multiple backbones, you can choose any of them depending on yo ```bash # Install with `verl` training backbone: - uv venv --python=3.10 + uv venv --python=3.12 source .venv/bin/activate uv pip install -e .[verl] @@ -57,7 +57,7 @@ AgentJet supports multiple backbones, you can choose any of them depending on yo ```bash # Install with `verl` training backbone: - conda create -n ajet-verl python=3.10 + conda create -n ajet-verl python=3.12 conda activate ajet-verl pip install -e .[verl] @@ -65,7 +65,6 @@ AgentJet supports multiple backbones, you can choose any of them depending on yo pip install --verbose flash-attn --no-deps --no-build-isolation --no-cache ``` - !!! warning "flash-attn Installation" - `flash-attn` must be installed **after** other dependencies. - If you find your machine spend a long time installing flash-attn, ensure a healthy connection to GitHub. diff --git a/pyproject.toml b/pyproject.toml index f95f087c..73d89cd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,9 @@ dependencies = [ [project.optional-dependencies] verl = [ - "verl[vllm] @ git+https://github.com/volcengine/verl.git@016c1d5a7a3f2973d68fda2f7abe5e7df9e05e00" + "verl[vllm] @ git+https://github.com/volcengine/verl.git@016c1d5a7a3f2973d68fda2f7abe5e7df9e05e00", + "vllm==0.11.0", + "transformers==4.57.6" ] trinity = [ From 1828d5822c4e2bf7f1ec22c85674c3f38b6f38e2 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Tue, 17 Mar 2026 20:22:26 +0800 Subject: [PATCH 08/17] patch green verl versions --- ajet/utils/launch_utils.py | 2 +- pyproject.toml | 2 +- tests/bench/README.md | 12 ++++++------ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ajet/utils/launch_utils.py b/ajet/utils/launch_utils.py index a941dbbc..6a8c2a28 100644 --- a/ajet/utils/launch_utils.py +++ b/ajet/utils/launch_utils.py @@ -297,7 +297,7 @@ def verify_python_env(args, exp_config): time.sleep(5) raise ImportError(cause + " " + solution) elif args.backbone == "verl": - if not any([v in verl.__version__ for v in ["0.5.0.post", "0.5.0.dev", "0.7.0.post", "0.8.0.dev"]]): # you must install via `pip install -e .[verl]` to get every dependency right + if not any([v in verl.__version__ for v in ["0.5.0.post", "0.5.0.dev", "0.7.0.post", "0.7.1", "0.8.0.dev"]]): # you must install via `pip install -e .[verl]` to get every dependency right cause = "Python environment does not match current backbone 'verl'." solution = "Please `cd /path/to/project/AgentJet` and run `(uv) pip install -e .[verl]` to install the correct environment." print_dict( diff --git a/pyproject.toml b/pyproject.toml index 73d89cd9..75a69215 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ dependencies = [ [project.optional-dependencies] verl = [ - "verl[vllm] @ git+https://github.com/volcengine/verl.git@016c1d5a7a3f2973d68fda2f7abe5e7df9e05e00", + "verl[vllm]==0.7.1", "vllm==0.11.0", "transformers==4.57.6" ] diff --git a/tests/bench/README.md b/tests/bench/README.md index c35005d3..c9ada294 100644 --- a/tests/bench/README.md +++ b/tests/bench/README.md @@ -11,7 +11,7 @@ Note: `tests/bench` source code is for test robot only, therefore `yaml` configu # prepare dataset path # prepare swanlab api -source .verl/bin/activate +source .venv/bin/activate python -m pytest -s tests/bench/benchmark_math/execute_benchmark_math.py python -m pytest -s tests/bench/benchmark_appworld/execute_benchmark_appworld.py @@ -19,11 +19,11 @@ python -m pytest -s tests/bench/benchmark_countdown/execute_benchmark_countdown. python -m pytest -s tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py python -m pytest -s tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py -VERL_PYTHON="./.verl/bin/python" python -m pytest -s tests/bench/benchmark_math/execute_benchmark_math.py::TestBenchmarkMath::test_01_begin_verl -VERL_PYTHON="./.verl/bin/python" python -m pytest -s tests/bench/benchmark_appworld/execute_benchmark_appworld.py::TestBenchmarkAppworld::test_01_begin_verl -VERL_PYTHON="./.verl/bin/python" python -m pytest -s tests/bench/benchmark_countdown/execute_benchmark_countdown.py::TestBenchmarkCountdown::test_01_begin_verl -VERL_PYTHON="./.verl/bin/python" python -m pytest -s tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py::TestBenchmarkLearnToAsk::test_01_begin_verl -VERL_PYTHON="./.verl/bin/python" python -m pytest -s tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py::TestBenchmarkFrozenLake::test_01_begin_verl +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_math/execute_benchmark_math.py::TestBenchmarkMath::test_01_begin_verl +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_appworld/execute_benchmark_appworld.py::TestBenchmarkAppworld::test_01_begin_verl +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_countdown/execute_benchmark_countdown.py::TestBenchmarkCountdown::test_01_begin_verl +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py::TestBenchmarkLearnToAsk::test_01_begin_verl +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py::TestBenchmarkFrozenLake::test_01_begin_verl export APPWORLD_PATH="/dev/shm/pack_all_in_one" From 0c96884a936511b4bf3a4f94592273714aa18b78 Mon Sep 17 00:00:00 2001 From: Qingxu Fu Date: Wed, 18 Mar 2026 15:54:14 +0800 Subject: [PATCH 09/17] supress tool parser warnings --- ajet/backbone/warm_up.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ajet/backbone/warm_up.py b/ajet/backbone/warm_up.py index 6e261b6a..21e0e78f 100644 --- a/ajet/backbone/warm_up.py +++ b/ajet/backbone/warm_up.py @@ -38,6 +38,8 @@ def init_parallel_rollout_logger(experiment_name, experiment_dir): target_logger = logging.getLogger("vllm.entrypoints.openai.tool_parsers.hermes_tool_parser") target_logger.setLevel(logging.CRITICAL) + target_logger = logging.getLogger("vllm.tool_parsers.hermes_tool_parser") + target_logger.setLevel(logging.CRITICAL) logging.getLogger("httpx").setLevel(logging.WARNING) From 01d3fa3b5f7fbc2f8782d146fa888823c17cf6dc Mon Sep 17 00:00:00 2001 From: Qingxu Fu Date: Wed, 18 Mar 2026 17:42:24 +0800 Subject: [PATCH 10/17] align benchmarking hyper-parameters --- ajet/backbone/verl/dp_actor.py | 4 ++-- ajet/default_config/ajet_default.yaml | 3 ++- ajet/default_config/verl/verl_default.yaml | 4 ++-- pyproject.toml | 2 +- tests/bench/benchmark_base.py | 3 +++ tests/bench/benchmark_math/benchmark_math.yaml | 2 +- 6 files changed, 11 insertions(+), 7 deletions(-) diff --git a/ajet/backbone/verl/dp_actor.py b/ajet/backbone/verl/dp_actor.py index 1f7ca3a9..2e9dc194 100644 --- a/ajet/backbone/verl/dp_actor.py +++ b/ajet/backbone/verl/dp_actor.py @@ -124,7 +124,7 @@ def update_policy(self, data: DataProto): advantages = model_inputs["advantages"] # [AJET] Debug logging for tensor shapes input_ids = model_inputs["input_ids"] - print(f'*** Current tensor shape, input_ids {input_ids.shape}, response {response_mask.shape}') + print(f'-> Current tensor shape, input_ids {input_ids.shape}, response {response_mask.shape}') entropy_coeff = self.config.entropy_coeff loss_agg_mode = self.config.loss_agg_mode @@ -220,7 +220,7 @@ def update_policy(self, data: DataProto): metrics["actor/pg_loss"] += pg_loss.detach().item() * loss_scale_factor append_to_dict(metrics, micro_batch_metrics) - + print(f'-> optimizer_step !') grad_norm = self._optimizer_step() mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()} append_to_dict(metrics, mini_batch_metrics) diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index a880e049..4017d1b2 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -72,8 +72,9 @@ ajet: num_repeat: 4 # rollout kwargs - temperature: 0.9 + temperature: 1.0 top_p: 1.0 + top_k: -1 # validation kwargs val_kwargs: diff --git a/ajet/default_config/verl/verl_default.yaml b/ajet/default_config/verl/verl_default.yaml index 16952ab4..8d1c3bea 100644 --- a/ajet/default_config/verl/verl_default.yaml +++ b/ajet/default_config/verl/verl_default.yaml @@ -257,12 +257,12 @@ actor_rollout_ref: pipeline_model_parallel_size: 1 max_num_batched_tokens: 8192 max_model_len: null - max_num_seqs: 1024 + max_num_seqs: 10 enable_chunked_prefill: false enable_prefix_caching: false logprobs_mode: processed_logprobs scheduling_policy: fcfs - load_format: dummy + load_format: auto log_prob_micro_batch_size: null log_prob_micro_batch_size_per_gpu: 1 log_prob_use_dynamic_bsz: true diff --git a/pyproject.toml b/pyproject.toml index 75a69215..a128c4a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ ] requires-python = ">=3.10,<3.13" dependencies = [ - "agentscope==1.0.8", + "agentscope==1.0.7", "chromadb", "httpx[http2]", "tenacity", diff --git a/tests/bench/benchmark_base.py b/tests/bench/benchmark_base.py index 7b6815f3..9eecf998 100644 --- a/tests/bench/benchmark_base.py +++ b/tests/bench/benchmark_base.py @@ -28,6 +28,9 @@ def execute_benchmark( enable_ray_for_trinity: bool = True, ) -> None: """Run a benchmark with shared boilerplate for setup and process management.""" + import agentscope + assert agentscope.__version__ == "1.0.7", "AgentScope has too many bugs across versions, please use version 1.0.7 for werewolves example." + workspace_dir = Path(__file__).resolve().parents[2] git_hash, req_txt = populate_test_env_metadata(str(workspace_dir)) diff --git a/tests/bench/benchmark_math/benchmark_math.yaml b/tests/bench/benchmark_math/benchmark_math.yaml index b7689c91..c974bf4b 100644 --- a/tests/bench/benchmark_math/benchmark_math.yaml +++ b/tests/bench/benchmark_math/benchmark_math.yaml @@ -48,7 +48,7 @@ ajet: total_epochs: 100 logger: swanlab nnodes: 1 - n_gpus_per_node: 8 + n_gpus_per_node: 4 From 8f59278b366c209c8c288ad55a851d03ac5e8020 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Wed, 18 Mar 2026 17:51:08 +0800 Subject: [PATCH 11/17] remove outdated argument --- tests/bench/benchmark_countdown/benchmark_countdown.yaml | 1 - tutorial/example_countdown/countdown.yaml | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/bench/benchmark_countdown/benchmark_countdown.yaml b/tests/bench/benchmark_countdown/benchmark_countdown.yaml index 53cdd902..5986495a 100644 --- a/tests/bench/benchmark_countdown/benchmark_countdown.yaml +++ b/tests/bench/benchmark_countdown/benchmark_countdown.yaml @@ -77,7 +77,6 @@ ajet: top_k: -1 top_p: 1.0 do_sample: False - num_repeat: 1 task_reader: diff --git a/tutorial/example_countdown/countdown.yaml b/tutorial/example_countdown/countdown.yaml index 6dcadf81..90a39ff2 100644 --- a/tutorial/example_countdown/countdown.yaml +++ b/tutorial/example_countdown/countdown.yaml @@ -77,7 +77,6 @@ ajet: top_k: -1 top_p: 1.0 do_sample: False - num_repeat: 1 task_reader: From 60acbdfe9cbf074feec1c3688536b47a4e30b38d Mon Sep 17 00:00:00 2001 From: Qingxu Fu Date: Wed, 18 Mar 2026 22:26:01 +0800 Subject: [PATCH 12/17] add a loss scaler --- ajet/backbone/verl/dp_actor.py | 5 ++++- ajet/default_config/ajet_default.yaml | 3 +++ ajet/default_config/verl/config_auto_convertion_verl.jsonc | 3 +++ ajet/default_config/verl/verl_default.yaml | 1 + tests/bench/README.md | 4 ++++ tests/bench/benchmark_appworld/benchmark_appworld.yaml | 3 +++ tests/bench/benchmark_countdown/benchmark_countdown.yaml | 3 ++- tests/bench/benchmark_frozenlake/benchmark_frozenlake.yaml | 3 ++- tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml | 3 ++- tests/bench/benchmark_math/benchmark_math.yaml | 5 +++-- 10 files changed, 27 insertions(+), 6 deletions(-) diff --git a/ajet/backbone/verl/dp_actor.py b/ajet/backbone/verl/dp_actor.py index 2e9dc194..5ab28d9b 100644 --- a/ajet/backbone/verl/dp_actor.py +++ b/ajet/backbone/verl/dp_actor.py @@ -131,10 +131,13 @@ def update_policy(self, data: DataProto): calculate_entropy = self.config.calculate_entropy or (entropy_coeff != 0) - if self.config.use_dynamic_bsz: + if self.config.override_ppo_mini_batch_num > 0: + loss_scale_factor = response_mask.shape[0] / mini_batch_split_size + elif self.config.use_dynamic_bsz: loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size else: loss_scale_factor = 1 / self.gradient_accumulation + loss_scale_factor *= self.config.loss_extra_scale_ratio # [AJET] Extra scaling for loss if needed # all return: (bsz, response_length) outputs = self._forward_micro_batch( diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index 4017d1b2..e4ac8e69 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -288,6 +288,9 @@ ajet: kl_loss_coef: 0.002 kl_loss_type: low_var_kl + # loss = loss * loss_extra_scale_ratio + loss_extra_scale_ratio: 1.0 + # Ulysses specific configs ulysses_sequence_parallel_size: 1 diff --git a/ajet/default_config/verl/config_auto_convertion_verl.jsonc b/ajet/default_config/verl/config_auto_convertion_verl.jsonc index 90abdaef..7380e750 100644 --- a/ajet/default_config/verl/config_auto_convertion_verl.jsonc +++ b/ajet/default_config/verl/config_auto_convertion_verl.jsonc @@ -15,6 +15,7 @@ "ajet.trainer_common.kl_loss_coef": "actor_rollout_ref.actor.kl_loss_coef", "ajet.trainer_common.kl_loss_type": "actor_rollout_ref.actor.kl_loss_type", "ajet.trainer_common.ulysses_sequence_parallel_size": "actor_rollout_ref.actor.ulysses_sequence_parallel_size", + "ajet.trainer_common.loss_extra_scale_ratio": "actor_rollout_ref.actor.loss_extra_scale_ratio", "ajet.trainer_common.save_freq": "trainer.save_freq", "ajet.trainer_common.test_freq": "trainer.test_freq", @@ -30,6 +31,8 @@ "actor_rollout_ref.ref.log_prob_max_token_len_per_gpu" ], + "ajet.rollout.max_num_seqs": "actor_rollout_ref.rollout.max_num_seqs", + "ajet.rollout.temperature": "actor_rollout_ref.rollout.temperature", "ajet.rollout.multi_turn": "actor_rollout_ref.rollout.multi_turn", "ajet.rollout.val_kwargs": "actor_rollout_ref.rollout.val_kwargs", "ajet.rollout.num_repeat": [ diff --git a/ajet/default_config/verl/verl_default.yaml b/ajet/default_config/verl/verl_default.yaml index 8d1c3bea..7cb2ce91 100644 --- a/ajet/default_config/verl/verl_default.yaml +++ b/ajet/default_config/verl/verl_default.yaml @@ -70,6 +70,7 @@ actor_rollout_ref: rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} strategy: fsdp ppo_mini_batch_size: 256 + loss_extra_scale_ratio: 1.0 override_ppo_mini_batch_num: 1 # special in agentjet ppo_micro_batch_size: null ppo_micro_batch_size_per_gpu: null diff --git a/tests/bench/README.md b/tests/bench/README.md index c9ada294..2097a0b0 100644 --- a/tests/bench/README.md +++ b/tests/bench/README.md @@ -25,9 +25,13 @@ VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_count VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py::TestBenchmarkLearnToAsk::test_01_begin_verl VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py::TestBenchmarkFrozenLake::test_01_begin_verl +python -m ajet.launcher --conf tests/bench/benchmark_math/benchmark_math.yaml --autokill --db="UPP" export APPWORLD_PATH="/dev/shm/pack_all_in_one" export APPWORLD_SCRIPT="bash EnvService/env_sandbox/appworld.sh" python -m ajet.launcher --conf tests/bench/benchmark_appworld/benchmark_appworld.yaml --with-appworld --backbone=debug --autokill python -m ajet.launcher --conf tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml --with-appworld --autokill --db="EXT" ``` + + +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_math/execute_benchmark_math.py::TestBenchmarkMath::test_01_begin_verl diff --git a/tests/bench/benchmark_appworld/benchmark_appworld.yaml b/tests/bench/benchmark_appworld/benchmark_appworld.yaml index 3622ed1b..730a128f 100644 --- a/tests/bench/benchmark_appworld/benchmark_appworld.yaml +++ b/tests/bench/benchmark_appworld/benchmark_appworld.yaml @@ -47,12 +47,15 @@ ajet: max_prompt_length: 3000 max_response_length: 15000 + # trainer common configurations trainer_common: save_freq: 99999 test_freq: 99999 total_epochs: 99999 nnodes: 1 n_gpus_per_node: 8 + # loss = loss * loss_extra_scale_ratio + loss_extra_scale_ratio: 10.0 execute_test: True # DO NOT EDIT, THIS IS FOR TEST ROBOT execute_testing_lambda: "tests/bench/benchmark_appworld/benchmark_appworld.py->TestProbe" # diff --git a/tests/bench/benchmark_countdown/benchmark_countdown.yaml b/tests/bench/benchmark_countdown/benchmark_countdown.yaml index 5986495a..defc7eaa 100644 --- a/tests/bench/benchmark_countdown/benchmark_countdown.yaml +++ b/tests/bench/benchmark_countdown/benchmark_countdown.yaml @@ -116,7 +116,8 @@ ajet: kl_loss_coef: 0.002 kl_loss_type: low_var_kl ulysses_sequence_parallel_size: 1 - + # loss = loss * loss_extra_scale_ratio + loss_extra_scale_ratio: 10.0 # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. execute_test: True # FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. diff --git a/tests/bench/benchmark_frozenlake/benchmark_frozenlake.yaml b/tests/bench/benchmark_frozenlake/benchmark_frozenlake.yaml index 1e08d03c..e82c657b 100644 --- a/tests/bench/benchmark_frozenlake/benchmark_frozenlake.yaml +++ b/tests/bench/benchmark_frozenlake/benchmark_frozenlake.yaml @@ -69,7 +69,8 @@ ajet: nnodes: 1 n_gpus_per_node: 8 logger: swanlab - + # loss = loss * loss_extra_scale_ratio + loss_extra_scale_ratio: 10.0 execute_test: True execute_testing_lambda: "tests/bench/benchmark_frozenlake/benchmark_frozenlake.py->TestProbe" diff --git a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml index b435a0ac..efe6ce17 100644 --- a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml +++ b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml @@ -45,7 +45,8 @@ ajet: test_freq: 100 total_epochs: 100 logger: swanlab - + # loss = loss * loss_extra_scale_ratio + loss_extra_scale_ratio: 10.0 execute_test: True # DO NOT EDIT, THIS IS FOR TEST ROBOT execute_testing_lambda: "tests/bench/benchmark_learn2ask/benchmark_learn2ask.py->TestProbe" # DO NOT EDIT, THIS IS FOR TEST ROBOT diff --git a/tests/bench/benchmark_math/benchmark_math.yaml b/tests/bench/benchmark_math/benchmark_math.yaml index c974bf4b..831f2775 100644 --- a/tests/bench/benchmark_math/benchmark_math.yaml +++ b/tests/bench/benchmark_math/benchmark_math.yaml @@ -21,7 +21,7 @@ ajet: user_workflow: "tutorial.example_math_agent.math_agent->ExampleMathLearn" # ✨✨✨✨ 编写并选择Agent temperature: 1.0 max_env_worker: 64 - max_num_seqs: 256 + max_num_seqs: 10 num_repeat: 6 agent_madness_reward: 0.0 tensor_model_parallel_size: 1 @@ -49,7 +49,8 @@ ajet: logger: swanlab nnodes: 1 n_gpus_per_node: 4 - + # loss = loss * loss_extra_scale_ratio + loss_extra_scale_ratio: 40.0 execute_test: True # DO NOT EDIT, THIS IS FOR TEST ROBOT From 563d5703f2cf73f1c12c352ac0f021206eb5f256 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Sat, 21 Mar 2026 23:55:32 +0800 Subject: [PATCH 13/17] fix broken arguments --- ajet/default_config/ajet_default.yaml | 2 +- ajet/default_config/verl/verl_default.yaml | 24 +++++++++---------- .../benchmark_learn2ask.yaml | 2 +- .../bench/benchmark_math/benchmark_math.yaml | 2 +- .../benchmark_math_oai_sdk.yaml | 2 +- .../benchmark_math_raw_http.yaml | 2 +- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index e4ac8e69..d52f4ec6 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -72,7 +72,7 @@ ajet: num_repeat: 4 # rollout kwargs - temperature: 1.0 + temperature: 0.9 top_p: 1.0 top_k: -1 diff --git a/ajet/default_config/verl/verl_default.yaml b/ajet/default_config/verl/verl_default.yaml index 7cb2ce91..0753478b 100644 --- a/ajet/default_config/verl/verl_default.yaml +++ b/ajet/default_config/verl/verl_default.yaml @@ -44,7 +44,8 @@ actor_rollout_ref: reshard_after_forward: true fsdp_size: -1 forward_prefetch: false - model_dtype: bfloat16 + model_dtype: fp32 # Model data type used to initialize the transformers model. default "fp32" + dtype: bfloat16 # dtype (str): Mixed precision training param dtype, default "bfloat16" use_orig_params: false seed: 42 full_determinism: false @@ -54,7 +55,6 @@ actor_rollout_ref: entropy_checkpointing: false forward_only: false strategy: fsdp - dtype: bfloat16 qat: _target_: verl.workers.config.QATEngineConfig enable: false @@ -69,7 +69,7 @@ actor_rollout_ref: _target_: verl.workers.config.FSDPActorConfig rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} strategy: fsdp - ppo_mini_batch_size: 256 + ppo_mini_batch_size: 16 loss_extra_scale_ratio: 1.0 override_ppo_mini_batch_num: 1 # special in agentjet ppo_micro_batch_size: null @@ -91,12 +91,12 @@ actor_rollout_ref: kl_cov_ratio: 0.0002 ppo_kl_coef: 0.1 clip_ratio_c: 3.0 - loss_agg_mode: token-mean + loss_agg_mode: seq-mean-token-mean loss_scale_factor: null global_batch_info: {} entropy_coeff: 0 calculate_entropy: false - use_kl_loss: false + use_kl_loss: true use_prefix_grouper: false use_torch_compile: true kl_loss_coef: 0.001 @@ -148,7 +148,7 @@ actor_rollout_ref: ulysses_sequence_parallel_size: 1 entropy_from_logits_with_chunking: false entropy_checkpointing: false - use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false} + use_remove_padding: true calculate_sum_pi_squared: false sum_pi_squared_checkpointing: false qat: @@ -209,7 +209,8 @@ actor_rollout_ref: reshard_after_forward: true fsdp_size: -1 forward_prefetch: false - model_dtype: bfloat16 + model_dtype: fp32 # model_dtype (str): Model data type used to initialize the transformers model. default "fp32" + dtype: bfloat16 # dtype (str): Mixed precision training param dtype, default "bfloat16" use_orig_params: false seed: 42 full_determinism: false @@ -219,7 +220,6 @@ actor_rollout_ref: entropy_checkpointing: false forward_only: true strategy: fsdp - dtype: bfloat16 qat: _target_: verl.workers.config.QATEngineConfig enable: false @@ -241,7 +241,7 @@ actor_rollout_ref: mode: async nnodes: 0 n_gpus_per_node: ${oc.select:trainer.n_gpus_per_node,8} - temperature: 1.0 + temperature: 0.9 top_k: -1 top_p: 1 prompt_length: ${oc.select:data.max_prompt_length,512} @@ -464,17 +464,17 @@ critic: reshard_after_forward: true fsdp_size: -1 forward_prefetch: false - model_dtype: bfloat16 + model_dtype: fp32 # model_dtype (str): Model data type used to initialize the transformers model. default "fp32" + dtype: bfloat16 # dtype (str): Mixed precision training param dtype, default "bfloat16" use_orig_params: false seed: 42 full_determinism: false ulysses_sequence_parallel_size: 1 entropy_from_logits_with_chunking: false use_torch_compile: true + strategy: fsdp entropy_checkpointing: false forward_only: false - strategy: fsdp - dtype: bfloat16 qat: _target_: verl.workers.config.QATEngineConfig enable: false diff --git a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml index efe6ce17..5db93f9a 100644 --- a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml +++ b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml @@ -15,7 +15,7 @@ ajet: rollout: user_workflow: "tutorial.example_learn2ask.learn2ask->ExampleLearn2Ask" force_disable_toolcalls: True - temperature: 1.0 + temperature: 0.9 max_env_worker: 64 num_repeat: 6 tensor_model_parallel_size: 1 diff --git a/tests/bench/benchmark_math/benchmark_math.yaml b/tests/bench/benchmark_math/benchmark_math.yaml index 831f2775..9024364c 100644 --- a/tests/bench/benchmark_math/benchmark_math.yaml +++ b/tests/bench/benchmark_math/benchmark_math.yaml @@ -19,7 +19,7 @@ ajet: rollout: user_workflow: "tutorial.example_math_agent.math_agent->ExampleMathLearn" # ✨✨✨✨ 编写并选择Agent - temperature: 1.0 + temperature: 0.9 max_env_worker: 64 max_num_seqs: 10 num_repeat: 6 diff --git a/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml b/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml index e7bf0aba..c994ce83 100644 --- a/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml +++ b/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml @@ -18,7 +18,7 @@ ajet: rollout: user_workflow: "tutorial.example_math_agent.math_agent_oai_sdk->ExampleMathLearn" # ✨✨✨✨ 编写并选择Agent - temperature: 1.0 + temperature: 0.9 max_env_worker: 64 max_num_seqs: 256 num_repeat: 6 diff --git a/tests/bench/benchmark_math/benchmark_math_raw_http.yaml b/tests/bench/benchmark_math/benchmark_math_raw_http.yaml index 88c9aa15..1fe812cb 100644 --- a/tests/bench/benchmark_math/benchmark_math_raw_http.yaml +++ b/tests/bench/benchmark_math/benchmark_math_raw_http.yaml @@ -18,7 +18,7 @@ ajet: rollout: user_workflow: "tutorial.example_math_agent.math_agent_raw_http->ExampleMathLearn" # ✨✨✨✨ 编写并选择Agent - temperature: 1.0 + temperature: 0.9 max_env_worker: 64 max_num_seqs: 256 num_repeat: 6 From 3a92c5384df02f333dc3501073cb7b3448f5a2de Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Sun, 22 Mar 2026 00:49:37 +0800 Subject: [PATCH 14/17] rename source code --- ajet/backbone/trainer_verl.py | 29 +++++++++++++++++++ ajet/copilot/job.py | 8 ++--- ajet/copilot/train-complex-blackbox/SKILL.md | 2 +- ajet/copilot/write-swarm-client/SKILL.md | 2 +- ...{ajet_default.py => ajet_config_schema.py} | 0 ...s_default.yaml => ajet_swarm_default.yaml} | 0 ajet/launcher.py | 2 +- ajet/swarm_cli.py | 2 +- docs/en/swarm_best_practice.md | 2 +- docs/en/tune_your_first_agent.md | 4 +-- .../bench/benchmark_math/benchmark_math.yaml | 2 +- .../trans_roll.py | 2 +- .../frozen_lake_roll.py | 2 +- .../frozen_lake_roll_2_models.py | 2 +- tutorial/example_math_swarm/math.py | 2 +- .../example_train_multi_model/trans_roll.py | 2 +- .../example_werewolves_swarm/agent_roll.py | 2 +- .../example_werewolves_swarm/convert_skill.md | 4 +-- .../agent_roll.py | 2 +- 19 files changed, 50 insertions(+), 21 deletions(-) rename ajet/default_config/{ajet_default.py => ajet_config_schema.py} (100%) rename ajet/default_config/{ajet_ts_default.yaml => ajet_swarm_default.yaml} (100%) diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 93696434..74f28a85 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -144,6 +144,35 @@ def union_gen_batch_via_task_id(tasks, batch: DataProto, gen_batch_output: DataP logger.info(f'task_id_counter: {task_id_counter}') return gen_batch_output +def import_or_export_data_proto(batch: DataProto, direction: str = "export", file: str = "./tmp.pkl") -> DataProto: + """Import or export a DataProto batch to/from a pickle file. + + Args: + batch: The DataProto batch object. Used when direction is "export"; + ignored (can be None) when direction is "import". + direction: "import" to load a batch from file, "export" to save the batch to file. + file: Path to the pickle file. Defaults to "./tmp.pkl". + + Returns: + The DataProto batch — either the one just loaded (import) or the one just saved (export). + + Raises: + ValueError: If direction is not "import" or "export". + FileNotFoundError: If direction is "import" and the file does not exist. + """ + import pickle + if direction == "export": + with open(file, "wb") as f: + pickle.dump(batch, f) + logger.info(f"[import_or_export_data_proto] Exported batch to {file}") + return batch + elif direction == "import": + with open(file, "rb") as f: + batch = pickle.load(f) + logger.info(f"[import_or_export_data_proto] Imported batch from {file}") + return batch + else: + raise ValueError(f"direction must be 'import' or 'export', got '{direction}'") def compute_advantage( data: DataProto, diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py index 86c07aeb..50116c04 100644 --- a/ajet/copilot/job.py +++ b/ajet/copilot/job.py @@ -13,7 +13,7 @@ from typing import Any, Callable, Union, cast from loguru import logger -from ajet.default_config.ajet_default import Config +from ajet.default_config.ajet_config_schema import Config from ajet.utils.config_utils import ( expand_ajet_hierarchical_config, read_ajet_hierarchical_config, @@ -42,7 +42,7 @@ class AgentJetJob: """Programmatic interface for configuring and launching AgentJet training jobs. Args: - base_yaml_config: Path to base YAML configuration file. If None, uses default config (at ./ajet/default_config/ajet_ts_default.yaml). + base_yaml_config: Path to base YAML configuration file. If None, uses default config (at ./ajet/default_config/ajet_swarm_default.yaml). experiment_dir: Directory where experiment outputs will be saved. project_name: Name of the project for organizing experiments. experiment_name: Unique name for this specific experiment run. @@ -86,7 +86,7 @@ def __init__( ) -> None: if base_yaml_config is None: - base_yaml_config = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml")) + base_yaml_config = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_swarm_default.yaml")) else: logger.warning(f"Reading config from {base_yaml_config}.") time.sleep(1) @@ -121,7 +121,7 @@ def __init__( self.max_model_len: int = cast(int, max_model_len) self.mini_batch_num: int = cast(int, mini_batch_num) - # see `ajet/default_config/ajet_ts_default.yaml` + # see `ajet/default_config/ajet_swarm_default.yaml` overrides = { # left: [yaml key navigation] right: [AgentJetJob self attr] "ajet.experiment_dir": "experiment_dir", diff --git a/ajet/copilot/train-complex-blackbox/SKILL.md b/ajet/copilot/train-complex-blackbox/SKILL.md index c33b527c..a56d8bec 100644 --- a/ajet/copilot/train-complex-blackbox/SKILL.md +++ b/ajet/copilot/train-complex-blackbox/SKILL.md @@ -54,7 +54,7 @@ from ajet.copilot.job import AgentJetJob from ajet.task_reader import RouterTaskReader from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey -from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo from ajet.tuner_lib.experimental.swarm_client import SwarmClient # python -m tutorial.example_math_swarm.math diff --git a/ajet/copilot/write-swarm-client/SKILL.md b/ajet/copilot/write-swarm-client/SKILL.md index 3fd6fc4c..516b4fc1 100644 --- a/ajet/copilot/write-swarm-client/SKILL.md +++ b/ajet/copilot/write-swarm-client/SKILL.md @@ -365,7 +365,7 @@ Below are some reference materials. ```python from ajet.copilot.job import AgentJetJob from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete - from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo + from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo from ajet.task_reader import RouterTaskReader from tutorial.example_academic_trans_swarm.trans import execute_agent diff --git a/ajet/default_config/ajet_default.py b/ajet/default_config/ajet_config_schema.py similarity index 100% rename from ajet/default_config/ajet_default.py rename to ajet/default_config/ajet_config_schema.py diff --git a/ajet/default_config/ajet_ts_default.yaml b/ajet/default_config/ajet_swarm_default.yaml similarity index 100% rename from ajet/default_config/ajet_ts_default.yaml rename to ajet/default_config/ajet_swarm_default.yaml diff --git a/ajet/launcher.py b/ajet/launcher.py index 71d8683e..c30ec2e5 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -207,7 +207,7 @@ def main(): if args.swarm_server and (not args.conf): args.conf = os.path.abspath( os.path.join( - os.path.dirname(__file__), "default_config/ajet_ts_default.yaml" + os.path.dirname(__file__), "default_config/ajet_swarm_default.yaml" ) ) assert os.path.exists(args.conf), ( diff --git a/ajet/swarm_cli.py b/ajet/swarm_cli.py index 7f65e9f4..2716fb11 100644 --- a/ajet/swarm_cli.py +++ b/ajet/swarm_cli.py @@ -46,7 +46,7 @@ def cmd_start(args): if not args.conf: args.conf = os.path.abspath( os.path.join( - os.path.dirname(__file__), "default_config/ajet_ts_default.yaml" + os.path.dirname(__file__), "default_config/ajet_swarm_default.yaml" ) ) assert os.path.exists(args.conf), ( diff --git a/docs/en/swarm_best_practice.md b/docs/en/swarm_best_practice.md index 3557a110..4220baa2 100644 --- a/docs/en/swarm_best_practice.md +++ b/docs/en/swarm_best_practice.md @@ -133,7 +133,7 @@ Hint: you do not have to use `run_episodes_until_all_complete`, you are free to ```python from ajet.copilot.job import AgentJetJob from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete -from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo from ajet.task_reader import RouterTaskReader from tutorial.example_academic_trans_swarm.trans import execute_agent diff --git a/docs/en/tune_your_first_agent.md b/docs/en/tune_your_first_agent.md index 1103a11c..0ee70212 100644 --- a/docs/en/tune_your_first_agent.md +++ b/docs/en/tune_your_first_agent.md @@ -496,7 +496,7 @@ Create your client script. The client reads the dataset, runs the agent workflow from ajet.task_reader import RouterTaskReader from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey - from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo + from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo from ajet.tuner_lib.experimental.swarm_client import SwarmClient # Configuration @@ -649,7 +649,7 @@ The server handles gradient computation and model updates automatically. from ajet.task_reader import RouterTaskReader from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey - from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo + from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo from ajet.tuner_lib.experimental.swarm_client import SwarmClient GRPO_N = 4 # grpo group size diff --git a/tests/bench/benchmark_math/benchmark_math.yaml b/tests/bench/benchmark_math/benchmark_math.yaml index 9024364c..d4d8ba8a 100644 --- a/tests/bench/benchmark_math/benchmark_math.yaml +++ b/tests/bench/benchmark_math/benchmark_math.yaml @@ -50,7 +50,7 @@ ajet: nnodes: 1 n_gpus_per_node: 4 # loss = loss * loss_extra_scale_ratio - loss_extra_scale_ratio: 40.0 + loss_extra_scale_ratio: 10.0 execute_test: True # DO NOT EDIT, THIS IS FOR TEST ROBOT diff --git a/tutorial/example_academic_trans_swarm/trans_roll.py b/tutorial/example_academic_trans_swarm/trans_roll.py index c6a55cb1..d7165c4b 100644 --- a/tutorial/example_academic_trans_swarm/trans_roll.py +++ b/tutorial/example_academic_trans_swarm/trans_roll.py @@ -1,6 +1,6 @@ from ajet.copilot.job import AgentJetJob from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete -from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo from ajet.task_reader import RouterTaskReader from tutorial.example_academic_trans_swarm.trans import execute_agent diff --git a/tutorial/example_frozenlake_swarm/frozen_lake_roll.py b/tutorial/example_frozenlake_swarm/frozen_lake_roll.py index e3365f46..80f4bc62 100644 --- a/tutorial/example_frozenlake_swarm/frozen_lake_roll.py +++ b/tutorial/example_frozenlake_swarm/frozen_lake_roll.py @@ -1,6 +1,6 @@ from ajet.copilot.job import AgentJetJob from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete -from ajet.default_config.ajet_default import AjetTaskReader +from ajet.default_config.ajet_config_schema import AjetTaskReader from ajet.task_reader import RouterTaskReader from .frozenlake import FrozenLake diff --git a/tutorial/example_frozenlake_swarm/frozen_lake_roll_2_models.py b/tutorial/example_frozenlake_swarm/frozen_lake_roll_2_models.py index 151274e7..8270650f 100644 --- a/tutorial/example_frozenlake_swarm/frozen_lake_roll_2_models.py +++ b/tutorial/example_frozenlake_swarm/frozen_lake_roll_2_models.py @@ -1,6 +1,6 @@ from ajet.copilot.job import AgentJetJob from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete -from ajet.default_config.ajet_default import AjetTaskReader +from ajet.default_config.ajet_config_schema import AjetTaskReader from ajet.task_reader import RouterTaskReader from .frozenlake import FrozenLake diff --git a/tutorial/example_math_swarm/math.py b/tutorial/example_math_swarm/math.py index 2174076a..541de25d 100644 --- a/tutorial/example_math_swarm/math.py +++ b/tutorial/example_math_swarm/math.py @@ -9,7 +9,7 @@ from ajet.task_reader import RouterTaskReader from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey -from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo from ajet.tuner_lib.experimental.swarm_client import SwarmClient # python -m tutorial.example_math_swarm.math diff --git a/tutorial/example_train_multi_model/trans_roll.py b/tutorial/example_train_multi_model/trans_roll.py index 7e8e28a6..14fd96a3 100644 --- a/tutorial/example_train_multi_model/trans_roll.py +++ b/tutorial/example_train_multi_model/trans_roll.py @@ -1,6 +1,6 @@ from ajet.copilot.job import AgentJetJob from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete -from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo from ajet.task_reader import RouterTaskReader from tutorial.example_academic_trans_swarm.trans import execute_agent diff --git a/tutorial/example_werewolves_swarm/agent_roll.py b/tutorial/example_werewolves_swarm/agent_roll.py index 5107ab8d..fc7f1486 100644 --- a/tutorial/example_werewolves_swarm/agent_roll.py +++ b/tutorial/example_werewolves_swarm/agent_roll.py @@ -6,7 +6,7 @@ from ajet.task_reader import RouterTaskReader from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey -from ajet.default_config.ajet_default import AjetTaskReader +from ajet.default_config.ajet_config_schema import AjetTaskReader from ajet.tuner_lib.experimental.swarm_client import SwarmClient NUM_EPOCH = 10000 diff --git a/tutorial/example_werewolves_swarm/convert_skill.md b/tutorial/example_werewolves_swarm/convert_skill.md index 9ad3ff16..6c641f02 100644 --- a/tutorial/example_werewolves_swarm/convert_skill.md +++ b/tutorial/example_werewolves_swarm/convert_skill.md @@ -1,8 +1,8 @@ 训练复杂智能体的时候,推荐先从yaml配置出发 -首先,复制一份基础配置 ajet/default_config/ajet_ts_default.yaml +首先,复制一份基础配置 ajet/default_config/ajet_swarm_default.yaml -cp ajet/default_config/ajet_ts_default.yaml tutorial/example_werewolves_swarm/werewolves.yaml +cp ajet/default_config/ajet_swarm_default.yaml tutorial/example_werewolves_swarm/werewolves.yaml 然后对配置中的参数进行修改: diff --git a/tutorial/opencode_build_countdown_agent/agent_roll.py b/tutorial/opencode_build_countdown_agent/agent_roll.py index d6b7e092..d6e732ca 100644 --- a/tutorial/opencode_build_countdown_agent/agent_roll.py +++ b/tutorial/opencode_build_countdown_agent/agent_roll.py @@ -20,7 +20,7 @@ SwarmClient, run_episodes_until_all_complete, ) -from ajet.default_config.ajet_default import ( +from ajet.default_config.ajet_config_schema import ( AjetTaskReader, JsonlDatasetFile, JsonlTrainingFp, From aa85dfc4c172009ce8cc7a4fea7ba80d8e4e746f Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Wed, 25 Mar 2026 14:33:11 +0800 Subject: [PATCH 15/17] fix dp seq balance bugs --- ajet/backbone/trainer_verl.py | 10 +- ajet/backbone/verl/__init__.py | 3 - ajet/backbone/verl/actor_config.py | 47 +- ajet/backbone/verl/dp_actor.py | 89 +- ajet/backbone/verl/fsdp_workers.py | 6 +- ajet/backbone/verl/seqlen_balancing.py | 626 ++++++++++ ajet/default_config/ajet_default.yaml | 2 +- ajet/default_config/verl/verl_default.yaml | 101 +- .../verl/verl_default_expand.yaml | 1066 +++++++++++++++++ scripts/expand_config_targets.py | 377 ++++++ .../bench/benchmark_math/benchmark_math.yaml | 2 +- 11 files changed, 2237 insertions(+), 92 deletions(-) create mode 100644 ajet/backbone/verl/seqlen_balancing.py create mode 100644 ajet/default_config/verl/verl_default_expand.yaml create mode 100644 scripts/expand_config_targets.py diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 74f28a85..4d3b6b00 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -810,7 +810,7 @@ def _validate(self): # repeat test batch test_batch = test_batch.repeat( - repeat_times=self.config.ajet.rollout.val_kwargs.num_repeat, + repeat_times=self.config.ajet.trainer_common.val_pass_n, interleave=True, ) @@ -858,10 +858,10 @@ def _validate(self): logger.info(f"test_gen_batch meta info: {test_gen_batch.meta_info}") self.checkpoint_manager.update_weights(self.global_steps) - main_val_dataset = self.get_eval_dataset() + main_val_dataset = self.get_val_dataset() logger.info("Starting validate rollout") - context_tracker_arr, tasks, val_metrics = self.eval_dataset( + context_tracker_arr, tasks, val_metrics = self._rollout_val_dataset( target_dataset=main_val_dataset, target_dataset_name="main_val_dataset", mode="validate", @@ -920,7 +920,7 @@ def _validate(self): return metric_dict - def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch): + def _rollout_val_dataset(self, target_dataset, target_dataset_name, mode, epoch): """ Evaluate a dataset by running rollouts and computing task completion metrics. @@ -1005,7 +1005,7 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch): return ctx_trackers, tasks, val_metrics - def get_eval_dataset(self): + def get_val_dataset(self): from ajet.task_reader import RouterTaskReader task_reader = RouterTaskReader( diff --git a/ajet/backbone/verl/__init__.py b/ajet/backbone/verl/__init__.py index a92acb07..64aa84c6 100644 --- a/ajet/backbone/verl/__init__.py +++ b/ajet/backbone/verl/__init__.py @@ -1,11 +1,8 @@ from .fsdp_workers import AjetActorRolloutRefWorker, AjetAsyncActorRolloutRefWorker -from .actor_config import AjetActorConfig, AjetFSDPActorConfig from .dp_actor import AjetDataParallelPPOActor __all__ = [ "AjetActorRolloutRefWorker", "AjetAsyncActorRolloutRefWorker", - "AjetActorConfig", - "AjetFSDPActorConfig", "AjetDataParallelPPOActor", ] diff --git a/ajet/backbone/verl/actor_config.py b/ajet/backbone/verl/actor_config.py index 85315aba..297c722a 100644 --- a/ajet/backbone/verl/actor_config.py +++ b/ajet/backbone/verl/actor_config.py @@ -1,47 +1,8 @@ -# Copyright 2025 Alibaba Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Ajet extensions for verl ActorConfig. -Adds `override_ppo_mini_batch_num` field to control the number of optimizer steps per train-batch-step. -""" - +from verl.workers.config import FSDPActorConfig from dataclasses import dataclass, field -from typing import Optional - -from verl.workers.config.actor import ActorConfig, FSDPActorConfig - - -@dataclass -class AjetActorConfig(ActorConfig): - """ActorConfig extended with ajet-specific fields. - - Additional fields: - override_ppo_mini_batch_num (Optional[int]): If > 0, overrides ppo_mini_batch_size - by computing mini_batch_split_size = ceil(batch_size / override_ppo_mini_batch_num). - """ - - override_ppo_mini_batch_num: Optional[int] = None @dataclass -class AjetFSDPActorConfig(FSDPActorConfig): - """FSDPActorConfig extended with ajet-specific fields. - - Additional fields: - override_ppo_mini_batch_num (Optional[int]): If > 0, overrides ppo_mini_batch_size - by computing mini_batch_split_size = ceil(batch_size / override_ppo_mini_batch_num). - """ - - override_ppo_mini_batch_num: Optional[int] = None +class AgentJetFSDPActorConfig(FSDPActorConfig): + loss_extra_scale_ratio: float = 1.0 + override_ppo_mini_batch_num: int = 1 diff --git a/ajet/backbone/verl/dp_actor.py b/ajet/backbone/verl/dp_actor.py index 5ab28d9b..b4c85a9c 100644 --- a/ajet/backbone/verl/dp_actor.py +++ b/ajet/backbone/verl/dp_actor.py @@ -32,7 +32,8 @@ from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ from verl.utils.profiler import GPUMemoryLogger from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import prepare_dynamic_batch +# ajet/backbone/verl/seqlen_balancing.py +from ajet.backbone.verl.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch from verl.workers.actor.dp_actor import DataParallelPPOActor __all__ = ["AjetDataParallelPPOActor"] @@ -46,8 +47,94 @@ class AjetDataParallelPPOActor(DataParallelPPOActor): 1. Supports `override_ppo_mini_batch_num` to control the number of optimizer steps per train-batch-step. 2. Adds debug print for tensor shapes during training. + 3. Override `prepare_dynamic_batch` """ + @GPUMemoryLogger(role="dp actor", logger=logger) + def compute_log_prob(self, data: DataProto, calculate_entropy: bool = False) -> dict[str, torch.Tensor]: + """Compute the log probability of the responses given input_ids, attention_mask and position_ids + + Args: + data (DataProto): a DataProto containing keys + + ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the + concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. + + ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``responses``: tensor of shape [batch_size, response_length]. torch.int64. + + Returns: + dict[str, torch.Tensor]: a dict containing keys + - ``log_probs``: tensor of shape [batch_size, response_length]. torch.float32. + - ``entropys``: tensor of shape [batch_size, response_length]. torch.float32. + - ``sum_pi_squared``: tensor of shape [batch_size, response_length]. torch.float32. + """ + calculate_sum_pi_squared = self.config.get("calculate_sum_pi_squared", False) + self.actor_module.eval() + + micro_batch_size = data.meta_info["micro_batch_size"] + temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + pad_token_id = data.meta_info.get("pad_token_id", 0) + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] + if self.use_prefix_grouper: + select_keys += [k for k in ["prompts", "response_mask"] if k in data.batch] + if "uid" in data.non_tensor_batch: + non_tensor_select_keys.append("uid") + + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + if use_dynamic_bsz: + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) + else: + micro_batches = data.split(micro_batch_size) + + log_probs_lst = [] + entropy_lst = [] + sum_pi_squared_lst = [] + print(f"len(micro_batches) = {len(micro_batches)}") + for micro_batch in micro_batches: + micro_batch = micro_batch.to(get_device_id()) + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch, "pad_token_id": pad_token_id} + with torch.no_grad(): + outputs = self._forward_micro_batch( + model_inputs, temperature=temperature, calculate_entropy=calculate_entropy + ) + log_probs_lst.append(outputs["log_probs"]) + if calculate_entropy: + entropy_lst.append(outputs["entropys"]) + if calculate_sum_pi_squared: + sum_pi_squared_lst.append(outputs["sum_pi_squared"]) + + log_probs = torch.concat(log_probs_lst, dim=0) + if calculate_entropy: + entropys = torch.concat(entropy_lst, dim=0) + if calculate_sum_pi_squared: + sum_pi_squared = torch.concat(sum_pi_squared_lst, dim=0) + + if use_dynamic_bsz: + log_probs = restore_dynamic_batch(log_probs, batch_idx_list) + if calculate_entropy: + entropys = restore_dynamic_batch(entropys, batch_idx_list) + if calculate_sum_pi_squared: + sum_pi_squared = restore_dynamic_batch(sum_pi_squared, batch_idx_list) + + outputs = {"log_probs": log_probs} + if calculate_entropy: + outputs["entropys"] = entropys + if calculate_sum_pi_squared: + outputs["sum_pi_squared"] = sum_pi_squared + return outputs + + + @GPUMemoryLogger(role="dp actor", logger=logger) def update_policy(self, data: DataProto): # make sure we are in training mode diff --git a/ajet/backbone/verl/fsdp_workers.py b/ajet/backbone/verl/fsdp_workers.py index 826ad78c..0db436a9 100644 --- a/ajet/backbone/verl/fsdp_workers.py +++ b/ajet/backbone/verl/fsdp_workers.py @@ -283,7 +283,7 @@ def __init__(self, config: DictConfig, role: str, **kwargs): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): - # from verl.workers.actor import DataParallelPPOActor + # [AgentJet Change]: use the custom DataParallelPPOActor which supports FSDP and other features needed for ActorRolloutRefWorker from ajet.backbone.verl.dp_actor import AjetDataParallelPPOActor as DataParallelPPOActor # This is used to import external_lib into the huggingface systems @@ -347,7 +347,8 @@ def init_model(self): log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) if self._is_actor: - actor_cfg = self.config.actor + # [AgentJet Change]: use the custom DataParallelPPOActor which supports FSDP and other features needed for ActorRolloutRefWorker + actor_cfg = omega_conf_to_dataclass(self.config.actor) self.actor = DataParallelPPOActor( config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer ) @@ -422,7 +423,6 @@ def init_model(self): # Free cached GPU memory so colocated vLLM processes can see it via cudaMemGetInfo aggressive_empty_cache(force_sync=True) - # ================================= Async related workers ================================= class AjetAsyncActorRolloutRefWorker(AjetActorRolloutRefWorker): @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) diff --git a/ajet/backbone/verl/seqlen_balancing.py b/ajet/backbone/verl/seqlen_balancing.py new file mode 100644 index 00000000..46a61271 --- /dev/null +++ b/ajet/backbone/verl/seqlen_balancing.py @@ -0,0 +1,626 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import heapq +from itertools import chain + +import torch +from torch import distributed as dist + +from verl.protocol import DataProto +from verl.utils import tensordict_utils as tu +from verl.utils.device import get_device_name + + +def calculate_workload(seqlen_list: torch.Tensor) -> torch.Tensor: + """Calculate approximate computational workload for transformer attention. + + Estimates FLOPs for dense transformer blocks based on sequence length using + the formula: FLOPs ≈ 12 * hidden_size² * seqlen + 2 * hidden_size * seqlen² + + The constants are calibrated for a 7B model (hidden_size=4096), yielding: + workload ∝ 24576 * seqlen + seqlen² + + Args: + seqlen_list: Sequence lengths as a tensor. + + Returns: + torch.Tensor: Estimated workload values proportional to actual FLOPs. + + Note: + The returned values are relative workloads, not actual FLOP counts. + Useful for balancing computation across data parallel ranks. + """ + return 24576 * seqlen_list + seqlen_list**2 + + +def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool) -> list[list[int]]: + """Partition items into k groups using the Karmarkar-Karp differencing method. + + Implements the Largest Differencing Method (LDM) algorithm for balanced + multi-way number partitioning. This heuristic produces near-optimal partitions + by iteratively combining the sets with the largest difference. + + Args: + seqlen_list: Values to partition (typically sequence lengths or workloads). + k_partitions: Number of partitions to create. + equal_size: If True, each partition will have exactly len(seqlen_list) / k_partitions + items. If False, partitions may have different sizes. + + Returns: + list[list[int]]: List of k partitions, each containing indices into seqlen_list. + + See Also: + https://en.wikipedia.org/wiki/Largest_differencing_method + + Note: + When equal_size=True, len(seqlen_list) must be divisible by k_partitions. + """ + + # see: https://en.wikipedia.org/wiki/Largest_differencing_method + class Set: + def __init__(self) -> None: + self.sum = 0 + self.items = [] + + def add(self, idx: int, val: int): + self.items.append((idx, val)) + self.sum += val + + def merge(self, other): + for idx, val in other.items: + self.items.append((idx, val)) + self.sum += val + + def __lt__(self, other): + if self.sum != other.sum: + return self.sum < other.sum + if len(self.items) != len(other.items): + return len(self.items) < len(other.items) + return self.items < other.items + + class State: + def __init__(self, items: list[tuple[int, int]], k: int) -> None: + self.k = k + # sets should always be decreasing order + self.sets = [Set() for _ in range(k)] + assert len(items) in [1, k], f"{len(items)} not in [1, {k}]" + for i, (idx, seqlen) in enumerate(items): + self.sets[i].add(idx=idx, val=seqlen) + self.sets = sorted(self.sets, reverse=True) + + def get_partitions(self): + partitions = [] + for i in range(len(self.sets)): + cur_partition = [] + for idx, _ in self.sets[i].items: + cur_partition.append(idx) + partitions.append(cur_partition) + return partitions + + def merge(self, other): + for i in range(self.k): + self.sets[i].merge(other.sets[self.k - 1 - i]) + self.sets = sorted(self.sets, reverse=True) + + @property + def spread(self) -> int: + return self.sets[0].sum - self.sets[-1].sum + + def __lt__(self, other): + # least heap, let the state with largest spread to be popped first, + # if the spread is the same, let the state who has the largest set + # to be popped first. + if self.spread != other.spread: + return self.spread > other.spread + return self.sets[0] > other.sets[0] + + def __repr__(self) -> str: + repr_str = "[" + for i in range(self.k): + if i > 0: + repr_str += "," + repr_str += "{" + for j, (_, seqlen) in enumerate(self.sets[i].items): + if j > 0: + repr_str += "," + repr_str += str(seqlen) + repr_str += "}" + repr_str += "]" + return repr_str + + sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)]) + states_pq = [] + if equal_size: + assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0" + for offset in range(0, len(sorted_seqlen_list), k_partitions): + items = [] + for i in range(k_partitions): + seqlen, idx = sorted_seqlen_list[offset + i] + items.append((idx, seqlen)) + heapq.heappush(states_pq, State(items=items, k=k_partitions)) + else: + for seqlen, idx in sorted_seqlen_list: + heapq.heappush(states_pq, State(items=[(idx, seqlen)], k=k_partitions)) + + while len(states_pq) > 1: + state0 = heapq.heappop(states_pq) + state1 = heapq.heappop(states_pq) + # merge states + state0.merge(state1) + heapq.heappush(states_pq, state0) + + final_state = states_pq[0] + partitions = final_state.get_partitions() + if equal_size: + for i, partition in enumerate(partitions): + assert len(partition) * k_partitions == len(seqlen_list), ( + f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" + ) + return partitions + + +def greedy_partition(seqlen_list: list[int], k_partitions: int, equal_size: bool) -> list[list[int]]: + """Partition items into k groups using a greedy assignment strategy. + + Assigns each item to the partition with the smallest current sum, iterating + through items in order. Simpler but typically less optimal than Karmarkar-Karp. + + Args: + seqlen_list: Values to partition (typically sequence lengths or workloads). + k_partitions: Number of partitions to create. + equal_size: If True, adds a bias to ensure equal partition sizes. + Requires len(seqlen_list) to be divisible by k_partitions. + + Returns: + list[list[int]]: List of k partitions, each containing indices into seqlen_list. + + Note: + When equal_size=True, a large bias is added to encourage equal distribution + of items before considering the actual values. + """ + bias = sum(seqlen_list) + 1 if equal_size else 0 + sorted_seqlen = [(seqlen + bias, i) for i, seqlen in enumerate(seqlen_list)] + partitions = [[] for _ in range(k_partitions)] + partition_sums = [0 for _ in range(k_partitions)] + for seqlen, i in sorted_seqlen: + min_idx = None + for j in range(k_partitions): + if min_idx is None or partition_sums[j] < partition_sums[min_idx]: + min_idx = j + partitions[min_idx].append(i) + partition_sums[min_idx] += seqlen + if equal_size: + for i, partition in enumerate(partitions): + assert len(partition) * k_partitions == len(seqlen_list), ( + f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" + ) + return partitions + + +def get_seqlen_balanced_partitions(seqlen_list: list[int], k_partitions: int, equal_size: bool): + """ + Calculates partitions of indices from seqlen_list such that the sum of sequence lengths + in each partition is balanced. Uses the Karmarkar-Karp differencing method. + + This is useful for balancing workload across devices or batches, especially when + dealing with variable sequence lengths. + + Args: + seqlen_list (List[int]): A list of sequence lengths for each item. + k_partitions (int): The desired number of partitions. + equal_size (bool): If True, ensures that each partition has the same number of items. + Requires len(seqlen_list) to be divisible by k_partitions. + If False, partitions can have varying numbers of items, focusing + only on balancing the sum of sequence lengths. + + Returns: + List[List[int]]: A list containing k_partitions lists. Each inner list contains the + original indices of the items assigned to that partition. The indices + within each partition list are sorted. + + Raises: + AssertionError: If len(seqlen_list) < k_partitions. + AssertionError: If equal_size is True and len(seqlen_list) is not divisible by k_partitions. + AssertionError: If any resulting partition is empty. + """ + assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]" + + def _check_and_sort_partitions(partitions): + assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}" + seen_idx = set() + sorted_partitions = [None] * k_partitions + for i, partition in enumerate(partitions): + assert len(partition) > 0, f"the {i}-th partition is empty" + for idx in partition: + seen_idx.add(idx) + sorted_partitions[i] = sorted(partition) + assert seen_idx == set(range(len(seqlen_list))) + return sorted_partitions + + partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size) + return _check_and_sort_partitions(partitions) + + +def log_seqlen_unbalance(seqlen_list: list[int], partitions: list[list[int]], prefix): + """ + Calculate and log metrics related to sequence length imbalance before and after partitioning. + + Args: + seqlen_list (List[int]): A list of sequence lengths for each item. + partitions (List[List[int]]): A list of partitions, where each inner list contains indices + from seqlen_list assigned to that partition. + prefix (str): A prefix to be added to each metric key in the returned dictionary. + + Returns: + dict: A dictionary containing metrics related to sequence length imbalance. + """ + # Get the number of partitions + k_partition = len(partitions) + # assert len(seqlen_list) % k_partition == 0 + batch_size = len(seqlen_list) // k_partition + min_sum_seqlen = None + max_sum_seqlen = None + total_sum_seqlen = 0 + + # Iterate over each batch of sequence lengths + for offset in range(0, len(seqlen_list), batch_size): + cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size]) + if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen: + min_sum_seqlen = cur_sum_seqlen + if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen: + max_sum_seqlen = cur_sum_seqlen + total_sum_seqlen += cur_sum_seqlen + + balanced_sum_seqlen_list = [] + for partition in partitions: + cur_sum_seqlen_balanced = sum([seqlen_list[i] for i in partition]) + balanced_sum_seqlen_list.append(cur_sum_seqlen_balanced) + # print("balanced_sum_seqlen_list: ", balanced_sum_seqlen_list) + min_sum_seqlen_balanced = min(balanced_sum_seqlen_list) + max_sum_seqlen_balanced = max(balanced_sum_seqlen_list) + + return { + f"{prefix}/min": min_sum_seqlen, + f"{prefix}/max": max_sum_seqlen, + f"{prefix}/minmax_diff": max_sum_seqlen - min_sum_seqlen, + f"{prefix}/balanced_min": min_sum_seqlen_balanced, + f"{prefix}/balanced_max": max_sum_seqlen_balanced, + f"{prefix}/mean": total_sum_seqlen / len(partitions), + } + + +def ceildiv(a: int, b: int) -> int: + """Compute ceiling division of a by b. + + Returns the smallest integer greater than or equal to a/b. + Uses the identity: ceil(a/b) = floor((a + b - 1) / b) = -(-a // b) + + Args: + a: Dividend (numerator). + b: Divisor (denominator), must be non-zero. + + Returns: + int: Ceiling of a divided by b. + + Example: + >>> ceildiv(7, 3) # ceil(7/3) = ceil(2.33) = 3 + 3 + >>> ceildiv(6, 3) # ceil(6/3) = ceil(2.0) = 2 + 2 + """ + return -(a // -b) + + +def roundup_divisible(a: int, b: int) -> int: + """Round up a to the nearest multiple of b. + + Returns the smallest multiple of b that is >= a. + + Args: + a: Value to round up. + b: Divisor to round to (must be positive). + + Returns: + int: Smallest multiple of b that is >= a. + + Example: + >>> roundup_divisible(7, 4) # nearest multiple of 4 >= 7 is 8 + 8 + >>> roundup_divisible(8, 4) # 8 is already a multiple of 4 + 8 + """ + return ((a + b - 1) // b) * b + + +def rearrange_micro_batches( + batch, + max_token_len, + dp_group=None, + num_batches_divided_by=None, + same_micro_num_in_dp=True, + min_num_micro_batch=None, + use_dynamic_bsz_balance=True, + force_group_size=1, +): + """ + Split a batch into micro-batches by total token count, with optional DP sync and padding. + + Args: + batch (TensorDict): must include "attention_mask" (B*S); other fields are sliced similarly. + max_token_len (int): max sum of attention_mask per micro-batch. + dp_group (optional): torch.distributed group for data-parallel sync. + num_batches_divided_by (optional): virtual pipeline parallel size, for megatron. + same_micro_num_in_dp (bool): if True and dp_group set, pad all ranks to the same count. + min_num_micro_batch (int, optional): force at least this many splits (pads empty ones). + use_dynamic_bsz_balance (bool, optional): balance the computational workload between micro-batches + force_group_size (int, optional): force consecutive samples to be in the same micro-batch (for RM training). + + Returns: + List[TensorDict]: the micro-batches. + List[List[int]]: index lists mapping each micro-batch back to original positions. + """ + # this is per local micro_bsz + input_ids = batch["input_ids"] + if input_ids.is_nested: + seq_len_effective: torch.Tensor = input_ids.offsets().diff() + max_seq_len = max(seq_len_effective) + else: + max_seq_len = batch["attention_mask"].shape[-1] + seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1) + + assert max_token_len >= max_seq_len, ( + f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}" + ) + + # Validate force_group_size + batch_size = len(seq_len_effective) + assert batch_size % force_group_size == 0, ( + f"Batch size {batch_size} must be divisible by force_group_size {force_group_size}" + ) + + total_seqlen = seq_len_effective.sum().item() + # NOTE: num_microbatches <= batch_size, so take the min of this two. + # When force_group_size > 1, we work with groups instead of individual samples + num_groups = batch_size // force_group_size + num_micro_batches = min(num_groups, ceildiv(total_seqlen, max_token_len)) + if min_num_micro_batch is not None: + # used to support pp + num_micro_batches = max(min_num_micro_batch, num_micro_batches) + if dist.is_initialized() and same_micro_num_in_dp: + num_micro_batches = torch.tensor([num_micro_batches], device=get_device_name()) + dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group) + num_micro_batches = num_micro_batches.cpu().item() + if num_batches_divided_by is not None: + num_micro_batches = roundup_divisible(num_micro_batches, num_batches_divided_by) + + assert num_micro_batches <= num_groups + + # upcast to int64 to avoid potential overflow im `calculate_workload` computation. + seq_len_effective = seq_len_effective.long() + + # When force_group_size > 1, aggregate workloads by groups + if force_group_size > 1: + # Calculate workload for each group (sum of workloads of samples in the group) + workloads_per_sample = calculate_workload(seq_len_effective) + workloads_per_sample_grouped = workloads_per_sample.view(num_groups, force_group_size) + group_workloads = workloads_per_sample_grouped.sum(dim=1).cpu().tolist() + + # Partition groups instead of individual samples + micro_bsz_group_idx = get_seqlen_balanced_partitions(group_workloads, num_micro_batches, equal_size=False) + + # Convert group indices back to sample indices + micro_bsz_idx = [] + for group_partition in micro_bsz_group_idx: + sample_partition = [] + for group_idx in group_partition: + start_idx = group_idx * force_group_size + sample_partition.extend(range(start_idx, start_idx + force_group_size)) + micro_bsz_idx.append(sample_partition) + + workloads = group_workloads + else: + # Original logic for force_group_size == 1 + # note that seq_len_effective is a GPU tensor. We need to make it a list to avoid D2H! + workloads = calculate_workload(seq_len_effective).cpu().tolist() + micro_bsz_idx = get_seqlen_balanced_partitions(workloads, num_micro_batches, equal_size=False) + + if use_dynamic_bsz_balance: + # Use the sum of squared sequence lengths to approximate attention computation workload + if force_group_size > 1: + # For grouped samples, use group workloads for sorting + micro_bsz_idx.sort( + key=lambda partition: ( + sum(workloads[idx // force_group_size] for idx in partition[::force_group_size]), + partition[0] if partition else 0, + ), + reverse=True, + ) + else: + micro_bsz_idx.sort( + key=lambda partition: ( + sum(workloads[idx] for idx in partition), + partition[0] if partition else 0, + ), + reverse=True, + ) + # Place smaller micro-batches at both ends to reduce the bubbles exposed during the warm-up and cool-down. + micro_bsz_idx = micro_bsz_idx[::2][::-1] + micro_bsz_idx[1::2] + + micro_batches = [] + + for partition in micro_bsz_idx: + curr_micro_batch = tu.index_select_tensor_dict(batch, partition) + micro_batches.append(curr_micro_batch) + + return micro_batches, micro_bsz_idx + + +def get_reverse_idx(idx_map): + """ + Build the inverse of an index mapping. + + Args: + idx_map (Sequence[int]): Sequence where idx_map[i] = j. + + Returns: + List[int]: Inverse mapping list such that output[j] = i for each i. + """ + reverse_idx_map = copy.deepcopy(idx_map) + + for i, idx in enumerate(idx_map): + reverse_idx_map[idx] = i + + return reverse_idx_map + + +def prepare_dynamic_batch( + data: DataProto, + max_token_len: int, + dp_group=None, + num_batches_divided_by=None, + same_micro_num_in_dp=True, + min_num_micro_batch=None, + use_dynamic_bsz_balance=True, +) -> tuple[list[DataProto], list[list[int]]]: + """ + Prepare a batch for dynamic batching. + + Args: + data (DataProto): The input data. + max_token_len (int): The maximum token length for dynamic batching. + + Returns: + Tuple[List[DataProto], List[List[int]]]: A tuple containing a list of DataProto objects + and a list of index lists. + """ + batch, batch_idx_list = rearrange_micro_batches( + data.batch, + max_token_len=max_token_len, + dp_group=dp_group, + num_batches_divided_by=num_batches_divided_by, + same_micro_num_in_dp=same_micro_num_in_dp, + min_num_micro_batch=min_num_micro_batch, + use_dynamic_bsz_balance=use_dynamic_bsz_balance, + ) + micro_batches = [] + for i, batch_idx in enumerate(batch_idx_list): + tensors = dict(batch[i]) + non_tensors = {key: value[batch_idx] for key, value in data.non_tensor_batch.items()} + meta_info = copy.deepcopy(data.meta_info) + micro_batches.append(DataProto.from_dict(tensors, non_tensors, meta_info=meta_info)) + + return micro_batches, batch_idx_list + + +def restore_dynamic_batch(data: torch.Tensor, batch_idx_list: list[list[int]]) -> torch.Tensor: + """ + Restore a batch from dynamic batching. + + Args: + data (torch.Tensor): The input data. + batch_idx_list (List[List[int]]): The list of index lists. + + Returns: + torch.Tensor: The restored data. + """ + indices = list(chain.from_iterable(batch_idx_list)) + batch_size = data.shape[0] + assert len(indices) == batch_size, f"{len(indices)} vs. {batch_size}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + + if data.is_nested: + data_lst = data.unbind() + tensors = [data_lst[i] for i in revert_indices] + reverted_data = torch.nested.as_nested_tensor(tensors, layout=torch.jagged) + else: + reverted_data = data[revert_indices] + + return reverted_data + + +def get_group_balanced_partitions( + seqlen_list: list[int], + uid_list: list, + k_partitions: int, +) -> list[list[int]]: + """ + Partition samples into k groups while keeping samples with the same uid together. + + Args: + seqlen_list: List of sequence lengths for each sample. + uid_list: List of uids identifying which samples share the same prefix. + Samples with the same uid will be kept together. + k_partitions: Number of partitions (typically world_size). + + Returns: + List of k lists, each containing sample indices assigned to that partition. + Samples with the same uid are guaranteed to be in the same partition. + """ + assert len(seqlen_list) == len(uid_list), "seqlen_list and uid_list must have same length" + + # Build groups: each group contains indices of samples with the same uid + # Assumes samples with same uid are contiguous + groups = [] # List of (group_indices, group_total_seqlen) + current_uid = None + current_indices = [] + current_seqlen = 0 + + for i, (seqlen, uid) in enumerate(zip(seqlen_list, uid_list, strict=False)): + if uid != current_uid: + if current_indices: + groups.append((current_indices, current_seqlen)) + current_uid = uid + current_indices = [i] + current_seqlen = seqlen + else: + current_indices.append(i) + current_seqlen += seqlen + + # Don't forget the last group + if current_indices: + groups.append((current_indices, current_seqlen)) + + num_groups = len(groups) + assert num_groups >= k_partitions, ( + f"Number of uid groups ({num_groups}) must be >= k_partitions ({k_partitions}). " + f"Consider reducing world_size or increasing batch_size." + ) + + # Calculate workload for each group (as integers for partitioning) + group_workloads = [] + for indices, total_seqlen in groups: + # Use sum of individual workloads for more accurate estimation + workload = sum(int(calculate_workload(torch.tensor([seqlen_list[i]])).item()) for i in indices) + group_workloads.append(workload) + + # Use Karmarkar-Karp to partition groups + # equal_size=True ensures each partition gets the same number of groups, + # which is required when each group has the same number of samples (rollout.n) + group_partitions = get_seqlen_balanced_partitions( + seqlen_list=group_workloads, + k_partitions=k_partitions, + equal_size=True, + ) + + # Convert group partitions to sample partitions + sample_partitions = [] + for group_partition in group_partitions: + sample_indices = [] + for group_idx in group_partition: + sample_indices.extend(groups[group_idx][0]) + sample_partitions.append(sorted(sample_indices)) + + return sample_partitions diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index d52f4ec6..2c90c594 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -228,7 +228,7 @@ ajet: debug_tensor_parallel_size: 4 - # trainer common configurations + # (ajet.trainer_common) trainer common configurations trainer_common: # validation before training diff --git a/ajet/default_config/verl/verl_default.yaml b/ajet/default_config/verl/verl_default.yaml index 0753478b..94106f66 100644 --- a/ajet/default_config/verl/verl_default.yaml +++ b/ajet/default_config/verl/verl_default.yaml @@ -1,11 +1,6 @@ # coyp from verl's: # verl/trainer/config/_generated_ppo_trainer.yaml -# DO NOT EDIT MANUALLY -# DO NOT EDIT MANUALLY -# DO NOT EDIT MANUALLY -# DO NOT EDIT MANUALLY -# DO NOT EDIT MANUALLY # DO NOT EDIT MANUALLY # This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh' @@ -27,13 +22,13 @@ actor_rollout_ref: betas: - 0.9 - 0.999 - clip_grad: 1.0 min_lr_ratio: 0.0 num_cycles: 0.5 lr_scheduler_type: constant zero_indexed_step: true warmup_style: null override_optimizer_config: null + grad_clip: null fsdp_config: _target_: verl.workers.config.FSDPEngineConfig wrap_policy: @@ -66,7 +61,7 @@ actor_rollout_ref: - re:.*mlp.gate$ activation_observer: static_minmax quantization_config_path: null - _target_: verl.workers.config.FSDPActorConfig + _target_: ajet.backbone.verl.actor_config.AgentJetFSDPActorConfig rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} strategy: fsdp ppo_mini_batch_size: 16 @@ -125,20 +120,25 @@ actor_rollout_ref: nsys: _target_: verl.utils.profiler.config.NsightToolConfig discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + name: nsight npu: _target_: verl.utils.profiler.config.NPUToolConfig contents: [] level: level0 analysis: true discrete: false + name: npu torch: _target_: verl.utils.profiler.config.TorchProfilerToolConfig contents: [] discrete: false + name: torch torch_memory: _target_: verl.utils.profiler.config.TorchMemoryToolConfig trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + name: torch_memory + global_tool_config: null router_replay: _target_: verl.workers.config.RouterReplayConfig mode: disabled @@ -165,10 +165,17 @@ actor_rollout_ref: rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} strategy: ${actor_rollout_ref.actor.strategy} use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: true + log_prob_micro_batch_size: null - log_prob_micro_batch_size_per_gpu: null + log_prob_micro_batch_size_per_gpu: 1 log_prob_use_dynamic_bsz: true log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + + ppo_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} profiler: _target_: verl.utils.profiler.ProfilerConfig tool: ${oc.select:global_profiler.tool,null} @@ -180,20 +187,25 @@ actor_rollout_ref: nsys: _target_: verl.utils.profiler.config.NsightToolConfig discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + name: nsight npu: _target_: verl.utils.profiler.config.NPUToolConfig contents: [] level: level0 analysis: true discrete: false + name: npu torch: _target_: verl.utils.profiler.config.TorchProfilerToolConfig contents: [] discrete: false + name: torch torch_memory: _target_: verl.utils.profiler.config.TorchMemoryToolConfig trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + name: torch_memory + global_tool_config: null router_replay: _target_: verl.workers.config.RouterReplayConfig mode: disabled @@ -299,6 +311,8 @@ actor_rollout_ref: tokenization_sanity_check_mode: strict format: hermes num_repeat_rollouts: null + max_sample_per_task: 30 + max_steps: 30 calculate_log_probs: false agent: _target_: verl.workers.config.AgentLoopConfig @@ -309,6 +323,7 @@ actor_rollout_ref: _target_: verl.workers.config.CustomAsyncServerConfig path: null name: null + agent_loop_manager_class: null checkpoint_engine: _target_: verl.workers.config.CheckpointEngineConfig backend: naive @@ -339,10 +354,13 @@ actor_rollout_ref: level: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.level,level0} analysis: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.analysis,false} discrete: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.discrete,false} + name: npu torch: _target_: verl.utils.profiler.config.TorchProfilerToolConfig contents: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.torch.contents,[]} discrete: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.torch.discrete,false} + name: torch + global_tool_config: null prometheus: _target_: verl.workers.config.PrometheusConfig enable: false @@ -354,6 +372,20 @@ actor_rollout_ref: mtp: ${oc.select:actor_rollout_ref.model.mtp, null} qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,null} layered_summon: false + enable_sleep_mode: true + custom: null + sglang_engine_mode: local + moe_tensor_parallel_size: 1 + limit_images: null + repetition_penalty: 1.0 + layer_name_map: {} + server: + _target_: verl.workers.config.rollout.ServerConfig + timeout: 60.0 + max_attempts: 3 + retry_delay: 2.0 + max_start_wait_time: 300.0 + max_connections: 1000 model: _target_: verl.workers.config.HFModelConfig path: ~/models/deepseek-llm-7b-chat @@ -392,6 +424,16 @@ actor_rollout_ref: speculative_num_draft_tokens: 4 method: mtp num_speculative_tokens: 1 + tokenizer: null + load_tokenizer: true + hf_config: null + local_tokenizer_path: null + processor: null + architectures: null + local_path: null + generation_config: null + local_hf_config_path: null + target_parameters: null hybrid_engine: true nccl_timeout: 600 data: @@ -453,6 +495,7 @@ critic: zero_indexed_step: true warmup_style: null override_optimizer_config: null + grad_clip: null model: fsdp_config: _target_: verl.workers.config.FSDPEngineConfig @@ -537,24 +580,34 @@ critic: nsys: _target_: verl.utils.profiler.config.NsightToolConfig discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + name: nsight npu: _target_: verl.utils.profiler.config.NPUToolConfig contents: [] level: level0 analysis: true discrete: false + name: npu torch: _target_: verl.utils.profiler.config.TorchProfilerToolConfig contents: [] discrete: false + name: torch torch_memory: _target_: verl.utils.profiler.config.TorchMemoryToolConfig trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + name: torch_memory + global_tool_config: null forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null} forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null} ulysses_sequence_parallel_size: 1 grad_clip: 1.0 + ppo_infer_max_token_len_per_gpu: 32768 + engine: + _target_: verl.base_config.BaseConfig + model_config: null + ppo_infer_micro_batch_size_per_gpu: null custom_reward_function: path: null name: null @@ -599,12 +652,12 @@ sandbox_fusion: max_concurrent: null memory_limit_mb: null reward: - num_workers: 8 + num_workers: 0 custom_reward_function: path: null name: compute_score reward_manager: - _target_: verl.workers.config.reward_model.RewardManagerConfig + _target_: verl.workers.config.reward.RewardManagerConfig source: register name: naive module: @@ -670,6 +723,9 @@ algorithm: pf_ppo: reweight_method: pow weight_pow: 2.0 + filter_groups: null + gdpo_reward_keys: null + gdpo_reward_weights: null trainer: balance_batch: true total_epochs: 30 @@ -703,31 +759,6 @@ trainer: use_legacy_worker_impl: auto global_profiler: _target_: verl.utils.profiler.ProfilerConfig - tool: null - steps: null - profile_continuous_steps: false - save_path: outputs/profile - global_tool_config: - nsys: - _target_: verl.utils.profiler.config.NsightToolConfig - discrete: false - controller_nsight_options: - trace: cuda,nvtx,cublas,ucx - cuda-memory-usage: 'true' - cuda-graph-trace: graph - worker_nsight_options: - trace: cuda,nvtx,cublas,ucx - cuda-memory-usage: 'true' - cuda-graph-trace: graph - capture-range: cudaProfilerApi - capture-range-end: null - kill: none - torch_memory: - trace_alloc_max_entries: 100000 - stack_depth: 32 - context: all - stacks: all - kw_args: {} transfer_queue: enable: false ray_kwargs: diff --git a/ajet/default_config/verl/verl_default_expand.yaml b/ajet/default_config/verl/verl_default_expand.yaml new file mode 100644 index 00000000..4f8cef43 --- /dev/null +++ b/ajet/default_config/verl/verl_default_expand.yaml @@ -0,0 +1,1066 @@ +actor_rollout_ref: + actor: + optim: # [auto-convert] + _target_: verl.workers.config.FSDPOptimizerConfig + optimizer: AdamW + optimizer_impl: torch.optim + lr: 1.0e-06 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + min_lr_ratio: 0.0 + num_cycles: 0.5 + lr_scheduler_type: constant + zero_indexed_step: true + warmup_style: null + override_optimizer_config: null + grad_clip: null + clip_grad: 1.0 + fsdp_config: # [auto-convert] + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: true + optimizer_offload: true + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: fp32 + dtype: bfloat16 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: false + strategy: fsdp + qat: + _target_: verl.workers.config.QATEngineConfig + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null + router_replay: + _target_: verl.workers.config.engine.EngineRouterReplayConfig + replay_file: null + mode: disabled + record_file: null + micro_batch_size_per_gpu: null + infer_max_token_len_per_gpu: null + infer_micro_batch_size_per_gpu: null + grad_offload: false + use_fused_kernels: false + mixed_precision: null + use_dynamic_bsz: true + use_remove_padding: true + max_token_len_per_gpu: null + _target_: ajet.backbone.verl.actor_config.AgentJetFSDPActorConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} # [auto-convert] + strategy: fsdp + ppo_mini_batch_size: 16 + loss_extra_scale_ratio: 1.0 # [auto-convert] + override_ppo_mini_batch_num: 1 # [auto-convert] + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: true + ppo_max_token_len_per_gpu: 16384 # [auto-convert] + clip_ratio: 0.2 + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + tau_pos: 1.0 + tau_neg: 1.05 + freeze_vision_tower: false + policy_loss: + _target_: verl.workers.config.PolicyLossConfig + loss_mode: vanilla + clip_cov_ratio: 0.0002 + clip_cov_lb: 1.0 + clip_cov_ub: 5.0 + kl_cov_ratio: 0.0002 + ppo_kl_coef: 0.1 + rollout_correction: + _target_: verl.trainer.config.algorithm.RolloutCorrectionConfig + rollout_is_threshold: 2.0 + rollout_is_batch_normalize: false + rollout_is: sequence + rollout_rs_threshold: null + loss_type: ppo_clip + bypass_mode: false + rollout_rs: null + clip_ratio_c: 3.0 + loss_agg_mode: seq-mean-token-mean + loss_scale_factor: null + global_batch_info: {} + entropy_coeff: 0 + calculate_entropy: false + use_kl_loss: true # [auto-convert] + use_prefix_grouper: false + use_torch_compile: true + kl_loss_coef: 0.001 # [auto-convert] + kl_loss_type: low_var_kl # [auto-convert] + ppo_epochs: 1 + shuffle: false + data_loader_seed: 42 + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + mbridge_config: {} + use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + name: nsight + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + name: npu + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + name: torch + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + name: torch_memory + global_tool_config: null + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + grad_clip: 1.0 + ulysses_sequence_parallel_size: 1 # [auto-convert] + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + use_remove_padding: true + calculate_sum_pi_squared: false + sum_pi_squared_checkpointing: false + qat: + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null + ppo_infer_max_token_len_per_gpu: 16384 + model_config: + _target_: verl.workers.config.model.HFModelConfig + local_tokenizer_path: null + use_fused_kernels: false + trust_remote_code: false + override_config: {} + local_hf_config_path: null + external_lib: null + use_remove_padding: true + tokenizer: null + hf_config: null + processor: null + hf_config_path: null + target_modules: all-linear + use_liger: false + use_shm: false + tokenizer_path: null + exclude_modules: null + mtp: + _target_: verl.workers.config.model.MtpConfig + speculative_num_steps: 3 + speculative_num_draft_tokens: 4 + num_speculative_tokens: 1 + enable_rollout: false + mtp_loss_scaling_factor: 0.1 + speculative_algorithm: EAGLE + enable: false + enable_train: false + detach_encoder: false + method: mtp + speculative_eagle_topk: 1 + architectures: null + local_path: null + path: ??? + generation_config: null + enable_activation_offload: false + fused_kernel_options: {} + custom_chat_template: null + load_tokenizer: true + lora_rank: 0 + tiled_mlp: {} + enable_gradient_checkpointing: true + lora_alpha: 16 + lora_adapter_path: null + target_parameters: null + use_rollout_log_probs: false + ppo_infer_micro_batch_size_per_gpu: null + engine: + _target_: verl.base_config.BaseConfig + ref: + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: ${actor_rollout_ref.actor.strategy} + use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: true + ppo_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + name: nsight + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + name: npu + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + name: torch + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + name: torch_memory + global_tool_config: null + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: true + optimizer_offload: true + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: fp32 + dtype: bfloat16 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: true + strategy: fsdp + qat: + _target_: verl.workers.config.QATEngineConfig + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null + router_replay: + _target_: verl.workers.config.engine.EngineRouterReplayConfig + replay_file: null + mode: disabled + record_file: null + micro_batch_size_per_gpu: null + infer_max_token_len_per_gpu: null + infer_micro_batch_size_per_gpu: null + grad_offload: false + use_fused_kernels: false + mixed_precision: null + use_dynamic_bsz: true + use_remove_padding: true + max_token_len_per_gpu: null + _target_: verl.workers.config.FSDPActorConfig + ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1} + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + clip_ratio: 0.2 + ppo_epochs: 1 + use_fused_kernels: false + freeze_vision_tower: false + optim: + _target_: verl.workers.config.optimizer.OptimizerConfig + lr_warmup_steps: -1 + clip_grad: 1.0 + total_training_steps: -1 + grad_clip: null + betas: !!python/tuple + - 0.9 + - 0.999 + lr: 0.001 + lr_warmup_steps_ratio: 0.0 + weight_decay: 0.01 + entropy_coeff: 0 + use_remove_padding: false + shuffle: false + kl_loss_coef: 0.001 + ppo_infer_max_token_len_per_gpu: 16384 + global_batch_info: {} + clip_ratio_high: 0.2 + data_loader_seed: 1 + ppo_infer_micro_batch_size_per_gpu: null + ppo_mini_batch_size: 256 + use_kl_loss: false + kl_loss_type: low_var_kl + tau_neg: 1.05 + checkpoint: + _target_: verl.trainer.config.config.CheckpointConfig + async_save: false + engine: + _target_: verl.base_config.BaseConfig + use_prefix_grouper: false + model_config: + _target_: verl.workers.config.model.HFModelConfig + local_tokenizer_path: null + use_fused_kernels: false + trust_remote_code: false + override_config: {} + local_hf_config_path: null + external_lib: null + use_remove_padding: true + tokenizer: null + hf_config: null + processor: null + hf_config_path: null + target_modules: all-linear + use_liger: false + use_shm: false + tokenizer_path: null + exclude_modules: null + mtp: + _target_: verl.workers.config.model.MtpConfig + speculative_num_steps: 3 + speculative_num_draft_tokens: 4 + num_speculative_tokens: 1 + enable_rollout: false + mtp_loss_scaling_factor: 0.1 + speculative_algorithm: EAGLE + enable: false + enable_train: false + detach_encoder: false + method: mtp + speculative_eagle_topk: 1 + architectures: null + local_path: null + path: ??? + generation_config: null + enable_activation_offload: false + fused_kernel_options: {} + custom_chat_template: null + load_tokenizer: true + lora_rank: 0 + tiled_mlp: {} + enable_gradient_checkpointing: true + lora_alpha: 16 + lora_adapter_path: null + target_parameters: null + calculate_entropy: false + loss_scale_factor: null + calculate_sum_pi_squared: false + clip_ratio_c: 3.0 + loss_agg_mode: token-mean + tau_pos: 1.0 + clip_ratio_low: 0.2 + grad_clip: 1.0 + policy_loss: + _target_: verl.workers.config.actor.PolicyLossConfig + ppo_kl_coef: 0.1 + clip_cov_lb: 1.0 + clip_cov_ratio: 0.0002 + kl_cov_ratio: 0.0002 + rollout_correction: + _target_: verl.trainer.config.algorithm.RolloutCorrectionConfig + rollout_is_threshold: 2.0 + rollout_is_batch_normalize: false + rollout_is: sequence + rollout_rs_threshold: null + loss_type: ppo_clip + bypass_mode: false + rollout_rs: null + loss_mode: vanilla + clip_cov_ub: 5.0 + qat: + _target_: verl.utils.qat.core.QATConfig + mode: w4a16 + group_size: 16 + enable: false + activation_observer: static_minmax + quantization_config_path: null + use_rollout_log_probs: false + sum_pi_squared_checkpointing: false + rollout: + _target_: verl.workers.config.RolloutConfig + name: vllm + mode: async + nnodes: 0 + n_gpus_per_node: ${oc.select:trainer.n_gpus_per_node,8} + temperature: 0.9 # [auto-convert] + top_k: -1 + top_p: 1 + prompt_length: ${oc.select:data.max_prompt_length,512} + response_length: ${oc.select:data.max_response_length,512} + dtype: bfloat16 + gpu_memory_utilization: 0.85 + ignore_eos: false + enforce_eager: true + cudagraph_capture_sizes: null + free_cache_engine: true + tensor_model_parallel_size: 1 + data_parallel_size: 1 + expert_parallel_size: 1 + pipeline_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null # [auto-convert] + max_num_seqs: 10 # [auto-convert] + enable_chunked_prefill: false + enable_prefix_caching: false + logprobs_mode: processed_logprobs + scheduling_policy: fcfs + load_format: auto + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: 1 + log_prob_use_dynamic_bsz: true + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} # [auto-convert] + disable_log_stats: true + do_sample: true + n: 1 # [auto-convert] + over_sample_rate: 0 + multi_stage_wake_up: false + engine_kwargs: + vllm: {} + sglang: {} + trtllm: {} + val_kwargs: # [auto-convert] + _target_: verl.workers.config.SamplingConfig + top_k: -1 + top_p: 1.0 + temperature: 0 + n: 1 + do_sample: false + multi_turn: # [auto-convert] + _target_: ajet.default_config.verl.config_schema_rollout.AjetMultiTurnConfig + enable: false + max_assistant_turns: null + tool_config_path: null + max_user_turns: null + max_parallel_calls: 1 + expected_steps: 1 + max_tool_response_length: 256 + tool_response_truncate_side: middle + interaction_config_path: null + use_inference_chat_template: false + tokenization_sanity_check_mode: strict + format: hermes + num_repeat_rollouts: null + max_sample_per_task: 30 + max_steps: 30 + calculate_log_probs: false + agent: + _target_: verl.workers.config.AgentLoopConfig + num_workers: 1 + default_agent_loop: single_turn_agent + agent_loop_config_path: null + custom_async_server: + _target_: verl.workers.config.CustomAsyncServerConfig + path: null + name: null + agent_loop_manager_class: null + checkpoint_engine: + _target_: verl.workers.config.CheckpointEngineConfig + backend: naive + update_weights_bucket_megabytes: 4096 + engine_kwargs: {} + trace: + _target_: verl.workers.config.TraceConfig + project_name: ${oc.select:trainer.project_name,null} + experiment_name: ${oc.select:trainer.experiment_name,null} + backend: null + token2text: false + max_samples_per_step_per_worker: null + skip_rollout: false + skip_dump_dir: /tmp/rollout_dump + skip_tokenizer_init: true + enable_rollout_routing_replay: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} + ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.contents,[]} + level: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.level,level0} + analysis: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.analysis,false} + discrete: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.npu.discrete,false} + name: npu + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.torch.contents,[]} + discrete: ${oc.select:actor_rollout_ref.actor.profiler.tool_config.torch.discrete,false} + name: torch + global_tool_config: null + prometheus: + _target_: verl.workers.config.PrometheusConfig + enable: false + port: 9090 + file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml + served_model_name: ${oc.select:actor_rollout_ref.model.path,null} + quantization: null + quantization_config_file: null + mtp: ${oc.select:actor_rollout_ref.model.mtp, null} + qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,null} + layered_summon: false + enable_sleep_mode: true + custom: null + sglang_engine_mode: local + moe_tensor_parallel_size: 1 + limit_images: null + repetition_penalty: 1.0 + layer_name_map: {} + server: + _target_: verl.workers.config.rollout.ServerConfig + timeout: 60.0 + max_attempts: 3 + retry_delay: 2.0 + max_start_wait_time: 300.0 + max_connections: 1000 + model: + _target_: verl.workers.config.HFModelConfig + path: ~/models/deepseek-llm-7b-chat # [auto-convert] + hf_config_path: null + tokenizer_path: null + use_shm: false + trust_remote_code: false + custom_chat_template: null + external_lib: null + override_config: {} + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: true + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + exclude_modules: null + lora_adapter_path: null + use_liger: false + use_fused_kernels: false + fused_kernel_options: + impl_backend: torch + tiled_mlp: + enabled: false + num_shards: 4 + mtp: + _target_: verl.workers.config.MtpConfig + enable: false + enable_train: false + enable_rollout: false + detach_encoder: false + mtp_loss_scaling_factor: 0.1 + speculative_algorithm: EAGLE + speculative_num_steps: 3 + speculative_eagle_topk: 1 + speculative_num_draft_tokens: 4 + method: mtp + num_speculative_tokens: 1 + tokenizer: null + load_tokenizer: true + hf_config: null + local_tokenizer_path: null + processor: null + architectures: null + local_path: null + generation_config: null + local_hf_config_path: null + target_parameters: null + hybrid_engine: true + nccl_timeout: 600 +data: + tokenizer: null + use_shm: false + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + train_max_samples: -1 + val_max_samples: -1 + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 512 # [auto-convert] + max_response_length: 512 # [auto-convert] + train_batch_size: 1024 # [auto-convert] + val_batch_size: null + tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path,null} + return_raw_input_ids: false + return_raw_chat: true + return_full_prompt: false + shuffle: true + seed: null + dataloader_num_workers: 8 + image_patch_size: 14 + validation_shuffle: false + filter_overlong_prompts: false + filter_overlong_prompts_workers: 1 + truncation: error + image_key: images + video_key: videos + trust_remote_code: false + custom_cls: + path: null + name: null + return_multi_modal_inputs: true + sampler: + class_path: null + class_name: null + datagen: + path: null + name: null + apply_chat_template_kwargs: {} +critic: + optim: + _target_: verl.workers.config.FSDPOptimizerConfig + optimizer: AdamW + optimizer_impl: torch.optim + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + min_lr_ratio: 0.0 + num_cycles: 0.5 + lr_scheduler_type: constant + zero_indexed_step: true + warmup_style: null + override_optimizer_config: null + grad_clip: null + model: + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: fp32 + dtype: bfloat16 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + strategy: fsdp + entropy_checkpointing: false + forward_only: false + qat: + _target_: verl.workers.config.QATEngineConfig + enable: false + mode: w4a16 + group_size: 16 + ignore_patterns: + - lm_head + - embed_tokens + - re:.*mlp.gate$ + activation_observer: static_minmax + quantization_config_path: null + router_replay: + _target_: verl.workers.config.engine.EngineRouterReplayConfig + replay_file: null + mode: disabled + record_file: null + micro_batch_size_per_gpu: null + infer_max_token_len_per_gpu: null + infer_micro_batch_size_per_gpu: null + grad_offload: false + use_fused_kernels: false + mixed_precision: null + use_dynamic_bsz: true + use_remove_padding: true + max_token_len_per_gpu: null + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"} + override_config: {} + external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null} + trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false} + _target_: verl.workers.config.FSDPCriticModelCfg + use_shm: false + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: false + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + tiled_mlp: + enabled: false + num_shards: 4 + _target_: verl.workers.config.FSDPCriticConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: fsdp + enable: null + ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256} + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null} + use_dynamic_bsz: true + ppo_max_token_len_per_gpu: 32768 + forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1} + shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false} + data_loader_seed: 42 + cliprange_value: 0.5 + loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean} + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + mbridge_config: {} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + name: nsight + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + name: npu + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + contents: [] + discrete: false + name: torch + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + name: torch_memory + global_tool_config: null + forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null} + forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null} + ulysses_sequence_parallel_size: 1 + grad_clip: 1.0 + ppo_infer_max_token_len_per_gpu: 32768 + engine: + _target_: verl.base_config.BaseConfig + model_config: null + ppo_infer_micro_batch_size_per_gpu: null +custom_reward_function: + path: null + name: null +reward_model: + num_workers: null + reward_manager: null + enable: null + enable_resource_pool: null + n_gpus_per_node: null + nnodes: null + reward_loop_source: null + reward_loop_module_path: null + reward_loop_class_name: null + model: + path: null + external_lib: null + trust_remote_code: null + rollout: + name: null + dtype: null + gpu_memory_utilization: null + enforce_eager: null + cudagraph_capture_sizes: null + free_cache_engine: null + data_parallel_size: null + expert_parallel_size: null + tensor_model_parallel_size: null + max_num_batched_tokens: null + max_model_len: null + max_num_seqs: null + load_format: null + engine_kwargs: null + limit_images: null + enable_chunked_prefill: null + enable_prefix_caching: null + disable_log_stats: null + skip_tokenizer_init: null + prompt_length: null + response_length: null +sandbox_fusion: + url: null + max_concurrent: null + memory_limit_mb: null +reward: + num_workers: 8 + custom_reward_function: + path: null + name: compute_score + reward_manager: + _target_: verl.workers.config.reward.RewardManagerConfig + source: register + name: naive + module: + _target_: verl.trainer.config.config.ModuleConfig + path: null + name: custom_reward_manager + reward_model: + enable: false + enable_resource_pool: false + n_gpus_per_node: 8 + nnodes: 0 + model_path: null + rollout: + _target_: verl.workers.config.RolloutConfig + name: vllm + dtype: bfloat16 + gpu_memory_utilization: 0.85 + enforce_eager: true + cudagraph_capture_sizes: null + free_cache_engine: true + data_parallel_size: 1 + expert_parallel_size: 1 + tensor_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + load_format: auto + engine_kwargs: {} + limit_images: null + enable_chunked_prefill: true + enable_prefix_caching: true + disable_log_stats: true + skip_tokenizer_init: false + prompt_length: 2048 + response_length: 2048 + do_sample: true + quantization_config_file: null + skip_dump_dir: /tmp/rollout_dump + prometheus: + _target_: verl.workers.config.rollout.PrometheusConfig + served_model_name: null + port: 9090 + enable: false + file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml + profiler: null + ignore_eos: false + checkpoint_engine: + _target_: verl.workers.config.rollout.CheckpointEngineConfig + backend: naive + engine_kwargs: {} + update_weights_bucket_megabytes: 2048 + temperature: 1.0 + layer_name_map: {} + calculate_log_probs: false + quantization: null + over_sample_rate: 0.0 + custom: null + sglang_engine_mode: local + scheduling_policy: fcfs + logprobs_mode: processed_logprobs + skip_rollout: false + nnodes: 0 + server: + _target_: verl.workers.config.rollout.ServerConfig + timeout: 60.0 + max_attempts: 3 + retry_delay: 2.0 + max_start_wait_time: 300.0 + max_connections: 1000 + trace: + _target_: verl.workers.config.rollout.TraceConfig + project_name: null + token2text: false + experiment_name: null + max_samples_per_step_per_worker: null + backend: null + enable_sleep_mode: true + mode: async + enable_rollout_routing_replay: false + pipeline_model_parallel_size: 1 + mtp: + _target_: verl.workers.config.model.MtpConfig + speculative_num_steps: 3 + speculative_num_draft_tokens: 4 + num_speculative_tokens: 1 + enable_rollout: false + mtp_loss_scaling_factor: 0.1 + speculative_algorithm: EAGLE + enable: false + enable_train: false + detach_encoder: false + method: mtp + speculative_eagle_topk: 1 + val_kwargs: + _target_: verl.workers.config.rollout.SamplingConfig + do_sample: true + n: 1 + top_k: -1 + top_p: 1.0 + temperature: 1.0 + log_prob_micro_batch_size: null + top_p: 1.0 + multi_turn: + _target_: verl.workers.config.rollout.MultiTurnConfig + interaction_config_path: null + max_user_turns: null + format: hermes + tool_config_path: null + max_assistant_turns: null + max_parallel_calls: 1 + use_inference_chat_template: false + enable: false + max_tool_response_length: 256 + tokenization_sanity_check_mode: strict + tool_response_truncate_side: middle + num_repeat_rollouts: null + repetition_penalty: 1.0 + moe_tensor_parallel_size: 1 + log_prob_micro_batch_size_per_gpu: null + n: 1 + n_gpus_per_node: 8 + top_k: -1 + log_prob_use_dynamic_bsz: false + agent: + _target_: verl.workers.config.rollout.AgentLoopConfig + agent_loop_manager_class: null + num_workers: 8 + default_agent_loop: single_turn_agent + agent_loop_config_path: null + custom_async_server: + _target_: verl.workers.config.rollout.CustomAsyncServerConfig + name: null + path: null + layered_summon: false + qat: null + multi_stage_wake_up: false + log_prob_max_token_len_per_gpu: 16384 + sandbox_fusion: + url: null + max_concurrent: 64 + memory_limit_mb: 1024 +algorithm: + rollout_correction: + rollout_is: null + rollout_is_threshold: 2.0 + rollout_rs: null + rollout_rs_threshold: null + bypass_mode: false + loss_type: ppo_clip + rollout_is_batch_normalize: false + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: grpo # [auto-convert] + norm_adv_by_std_in_grpo: true + use_kl_in_reward: false # [auto-convert] + kl_penalty: kl + kl_ctrl: + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: false + pf_ppo: + reweight_method: pow + weight_pow: 2.0 + filter_groups: null + gdpo_reward_keys: null + gdpo_reward_weights: null +trainer: + balance_batch: true + total_epochs: 30 # [auto-convert] + total_training_steps: null + project_name: verl_examples # [auto-convert] + experiment_name: gsm8k # [auto-convert] + logger: # [auto-convert] + - console + - wandb + log_val_generations: 0 + rollout_data_dir: null + validation_data_dir: null + nnodes: 1 # [auto-convert] + n_gpus_per_node: 8 # [auto-convert] + save_freq: -1 # [auto-convert] + esi_redundant_time: 0 + resume_mode: auto + resume_from_path: null + val_before_train: true # [auto-convert] + val_only: false + test_freq: -1 # [auto-convert] + critic_warmup: 0 + default_hdfs_dir: null + del_local_ckpt_after_load: false + checkpoint_base_dir: checkpoints # [auto-convert] + default_local_dir: ${trainer.checkpoint_base_dir}/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + ray_wait_register_center_timeout: 300 + device: cuda + use_legacy_worker_impl: auto +transfer_queue: + enable: false +ray_kwargs: + ray_init: + num_cpus: null + timeline_json_file: null diff --git a/scripts/expand_config_targets.py b/scripts/expand_config_targets.py new file mode 100644 index 00000000..a9ac9f06 --- /dev/null +++ b/scripts/expand_config_targets.py @@ -0,0 +1,377 @@ +#!/usr/bin/env python3 +""" +Script to remove _target_ fields from YAML configuration files. +Before removing, validates that all parameters in the target class exist in the YAML config. +Also adds comments to entries that appear in the auto-conversion config. +""" + +import yaml +import importlib +import inspect +import json +import re +from loguru import logger +from typing import Any, Dict, Set, List +from pathlib import Path +import sys + + +class TargetRemovalError(Exception): + """Custom exception for target removal validation errors.""" + pass + + +def get_class_from_target(target_path: str): + """ + Import and return a class from a dotted path like 'verl.workers.config.FSDPOptimizerConfig'. + + Args: + target_path: Dotted path to the class + + Returns: + The imported class object + """ + module_path, class_name = target_path.rsplit('.', 1) + try: + module = importlib.import_module(module_path) + return getattr(module, class_name) + except (ImportError, ModuleNotFoundError, AttributeError) as e: + raise TargetRemovalError(f"Failed to import {target_path}: {e}") + + +def get_class_parameters(cls) -> Set[str]: + """ + Get all parameter names from a class's __init__ method. + + Args: + cls: The class to inspect + + Returns: + Set of parameter names (excluding 'self') + """ + try: + sig = inspect.signature(cls.__init__) + params = set(sig.parameters.keys()) + params.discard('self') + # Also check for dataclass fields + if hasattr(cls, '__dataclass_fields__'): + params.update(cls.__dataclass_fields__.keys()) + return params + except Exception as e: + raise TargetRemovalError(f"Failed to inspect class {cls.__name__}: {e}") + + +def get_config_keys(config: Any) -> Set[str]: + """ + Get all keys from a configuration dict, excluding special keys like _target_. + + Args: + config: Configuration dict or value + + Returns: + Set of configuration keys + """ + if not isinstance(config, dict): + return set() + + keys = set(config.keys()) + keys.discard('_target_') + return keys + + +def parse_jsonc(file_path: str) -> Dict[str, Any]: + """Parse a JSONC file (JSON with comments).""" + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + # Remove comments (both // and /* */ style) + content = re.sub(r'//.*?$', '', content, flags=re.MULTILINE) + content = re.sub(r'/\*.*?\*/', '', content, flags=re.DOTALL) + + return json.loads(content) + + +def get_all_yaml_paths(data: Any, prefix: str = "") -> Set[str]: + """ + Recursively extract all key paths from a YAML structure. + Returns paths like 'actor_rollout_ref.actor.optim.lr' + """ + paths = set() + + if isinstance(data, dict): + for key, value in data.items(): + if key == '_target_': # Skip _target_ keys + continue + + current_path = f"{prefix}.{key}" if prefix else key + paths.add(current_path) + + # Recursively process nested structures + if isinstance(value, (dict, list)): + paths.update(get_all_yaml_paths(value, current_path)) + + elif isinstance(data, list): + for i, item in enumerate(data): + if isinstance(item, (dict, list)): + paths.update(get_all_yaml_paths(item, prefix)) + + return paths + + +def get_conversion_targets(conversion_config: Dict[str, Any]) -> Set[str]: + """ + Extract all target paths from the conversion config. + Returns a set of paths like 'actor_rollout_ref.actor.optim' + """ + targets = set() + + for key, value in conversion_config.items(): + if isinstance(value, str): + targets.add(value) + elif isinstance(value, list): + targets.update(value) + + return targets + + +def add_comments_to_yaml_lines(yaml_lines: List[str], yaml_data: Dict, conversion_targets: Set[str]) -> List[str]: + """ + Add comments to YAML lines that match conversion targets. + """ + result_lines = [] + path_stack = [] + indent_stack = [] + + for line in yaml_lines: + # Calculate indentation + stripped = line.lstrip() + if not stripped or stripped.startswith('#'): + result_lines.append(line) + continue + + indent = len(line) - len(stripped) + + # Update path stack based on indentation + while indent_stack and indent <= indent_stack[-1]: + indent_stack.pop() + if path_stack: + path_stack.pop() + + # Extract key from line + if ':' in stripped: + key = stripped.split(':')[0].strip() + + # Skip special keys + if key in ['_target_', '-']: + result_lines.append(line) + continue + + # Build current path + path_stack.append(key) + indent_stack.append(indent) + + current_path = '.'.join(path_stack) + + # Check if this path is in conversion targets + if current_path in conversion_targets: + # Add comment if not already present + if '# [auto-convert]' not in line: + line = line.rstrip() + ' # [auto-convert]\n' + + result_lines.append(line) + else: + result_lines.append(line) + + return result_lines + +def validate_and_remove_targets(data: Any, path: str = "root") -> Any: + + if isinstance(data, dict): + data = {key: validate_and_remove_targets(value, f"{path}.{key}") for key, value in data.items()} + + if '_target_' not in data: + return data + target_path = data['_target_'] + + # Import the target class + target_class = get_class_from_target(target_path) + + # Get parameters from the class + class_params = get_class_parameters(target_class) + + # Get keys from current config (excluding _target_) + config_keys = get_config_keys(data) + + # Check if there are any class parameters missing in config + # Note: It's OK for config to have extra keys (like nested configs) + # We're checking if class has required params that aren't in config + missing_in_config = class_params - config_keys - {"_target_"} + extra_in_config = config_keys - class_params - {"_target_"} + + if extra_in_config.__len__ != 0: + for k in extra_in_config: + logger.error(f"Error: discovered unidentified config: {path}.{k}") + + sig = inspect.signature(target_class.__init__) + params_with_defaults = { + name: param for name, param in sig.parameters.items() + if param.default != inspect.Parameter.empty + } + for key in missing_in_config: + if str(params_with_defaults[key].default) != '': + # add to data + print(f"[{path}] add {key} with default value {params_with_defaults[key]} from class {target_class}") + data[key] = params_with_defaults[key].default + else: + # str(params_with_defaults['engine']._annotation) + # target_instance = target_class(**data) + # if isinstance(getattr(target_instance, key), dict): + # data[key] = getattr(target_instance, key) + # else: + # data[key] = { + # '_target_': str(getattr(target_instance, key).__class__).split("'")[1] + # } + if " Date: Wed, 25 Mar 2026 15:16:47 +0800 Subject: [PATCH 16/17] remove debug message --- ajet/backbone/verl/dp_actor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ajet/backbone/verl/dp_actor.py b/ajet/backbone/verl/dp_actor.py index b4c85a9c..8617f682 100644 --- a/ajet/backbone/verl/dp_actor.py +++ b/ajet/backbone/verl/dp_actor.py @@ -99,7 +99,7 @@ def compute_log_prob(self, data: DataProto, calculate_entropy: bool = False) -> log_probs_lst = [] entropy_lst = [] sum_pi_squared_lst = [] - print(f"len(micro_batches) = {len(micro_batches)}") + # print(f"len(micro_batches) = {len(micro_batches)}") for micro_batch in micro_batches: micro_batch = micro_batch.to(get_device_id()) model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch, "pad_token_id": pad_token_id} @@ -211,7 +211,7 @@ def update_policy(self, data: DataProto): advantages = model_inputs["advantages"] # [AJET] Debug logging for tensor shapes input_ids = model_inputs["input_ids"] - print(f'-> Current tensor shape, input_ids {input_ids.shape}, response {response_mask.shape}') + print(f'[Update Policy] -> Micro batch shape, input_ids {input_ids.shape}, response {response_mask.shape}') entropy_coeff = self.config.entropy_coeff loss_agg_mode = self.config.loss_agg_mode From 694752ed59ae490cb571638286bb677403ef44c2 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Wed, 25 Mar 2026 18:35:56 +0800 Subject: [PATCH 17/17] ignore benchmark report errs --- ajet/utils/testing_utils.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/ajet/utils/testing_utils.py b/ajet/utils/testing_utils.py index 7b71b736..5afe3b5d 100644 --- a/ajet/utils/testing_utils.py +++ b/ajet/utils/testing_utils.py @@ -102,13 +102,17 @@ def send_test_result( "append_log": append_log or "", "data_dashboard_url": data_dashboard_url or "", } - resp = requests.post( - r"https://benchmark-report.agent-matrix.com/report_test_result", - json=payload, - timeout=timeout, - ) - resp.raise_for_status() - return resp.json() + try: + resp = requests.post( + r"https://benchmark-report.agent-matrix.com/report_test_result", + json=payload, + timeout=timeout, + ) + resp.raise_for_status() + return resp.json() + except: + logger.error("Unable to report to benchmark server.") + return {} def populate_test_env_metadata(workspace_dir: str) -> tuple[str, str]: