Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b2c70db
tinkerscript-v1
binary-husky Jan 19, 2026
21f9bb8
improve tinkerscript
binary-husky Jan 19, 2026
2178481
feat(tinkerscript): Add comprehensive design blueprint and workflow d…
binary-husky Jan 19, 2026
0444632
fix mermaid
binary-husky Jan 19, 2026
adbeadc
Remove limitations and Chinese version from documentation
binary-husky Jan 19, 2026
4ebf055
Clarify relationship between TinkerScript and Tinker
binary-husky Jan 19, 2026
ba10f35
Update tinkerscript.md
binary-husky Jan 19, 2026
c4b86a8
remove trinity
binary-husky Jan 21, 2026
a04001a
Add AgentJet image to TinkerScript documentation
binary-husky Jan 21, 2026
81e2b7d
Merge branch 'main' into tinkerscript-v1
binary-husky Jan 28, 2026
ae13326
feat: implement TinkerScript server functionality and enhance configu…
binary-husky Jan 28, 2026
5cc7297
feat: enhance TinkerScript integration with improved engine status ha…
binary-husky Jan 28, 2026
968c2cf
feat: enhance TinkerScript functionality with improved engine status …
binary-husky Jan 28, 2026
a6c7e0e
stage eval code ( to be tested )
binary-husky Jan 28, 2026
7660007
union_gen_batch_via_task_id is to be tested
binary-husky Jan 29, 2026
f2f3b16
stage dataset io improvement
binary-husky Jan 30, 2026
920e4d5
stage academic translation agent
binary-husky Feb 2, 2026
8777bcb
stage swarm server
binary-husky Feb 3, 2026
3157658
fix state machine bugs
binary-husky Feb 4, 2026
175e259
rename to agentjet swarm
binary-husky Feb 4, 2026
f1edf19
update pro-academic-trans agent
binary-husky Feb 5, 2026
47812cb
revise pro-trans
binary-husky Feb 5, 2026
b15983a
make rollout more robust
binary-husky Feb 5, 2026
4cb513b
enhance error logging during tracker.tokenize() for better debugging
binary-husky Feb 6, 2026
5132c2b
improve readability
binary-husky Feb 9, 2026
98db2b7
delete exit message
binary-husky Feb 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions ajet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
from ajet.copilot.job import AgentJetJob
from ajet.schema.task import WorkflowOutput, WorkflowTask
from ajet.tuner import AjetTuner
from ajet.workflow import Workflow
from ajet.utils.vsdb import vscode_conditional_breakpoint as bp
__version__ = "0.1.0"

__all__ = [
"Workflow",
Expand All @@ -13,4 +9,29 @@
"bp"
]

__version__ = "0.1.0"
# Mapping: lazily-exported public name -> dotted module path that provides it.
_LAZY_IMPORTS = {
    "AjetTuner": "ajet.tuner",
    "AgentJetJob": "ajet.copilot.job",
    "WorkflowOutput": "ajet.schema.task",
    "WorkflowTask": "ajet.schema.task",
    "Workflow": "ajet.workflow",
    "bp": "ajet.utils.vsdb",
}

# Public aliases whose attribute name inside the source module differs from
# the name exported by this package.
_ATTR_MAPPING = {
    "bp": "vscode_conditional_breakpoint"
}

def __getattr__(name):
    """PEP 562 module-level hook: import submodules on first attribute access.

    Resolves ``name`` through ``_LAZY_IMPORTS`` (and ``_ATTR_MAPPING`` for
    renamed exports), caches the resolved object in ``globals()`` so later
    lookups bypass this hook, and raises ``AttributeError`` for any name
    that is not a declared lazy export.
    """
    if name not in _LAZY_IMPORTS:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    import importlib

    source_module = importlib.import_module(_LAZY_IMPORTS[name])
    resolved = getattr(source_module, _ATTR_MAPPING.get(name, name))  # type: ignore

    globals()[name] = resolved  # cache: next access hits the module dict directly
    return resolved
30 changes: 13 additions & 17 deletions ajet/backbone/main_verl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,33 @@
import hydra
import ray
from beast_logger import print_dict
from loguru import logger
from omegaconf import OmegaConf
from omegaconf import DictConfig, OmegaConf
from verl.trainer.ppo.reward import load_reward_manager
from verl.utils.device import is_cuda_available
from verl.utils.dataset.rl_dataset import collate_fn
from torch.utils.data import Dataset as TorchDataset

# Create training and validation datasets.
from ajet.task_reader import RouterTaskReader, task_to_standard_dataset
from ajet.utils.process_dataset import create_rl_sampler
from ajet.utils.core_env_vars import get_runtime_env
from ajet.utils.launch_utils import set_loguru_default_color

set_loguru_default_color()


@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
def main(config: DictConfig) -> None:
"""Main entry point for PPO training with Hydra configuration management.

Args:
config_dict: Hydra configuration dictionary containing training parameters.
config: Hydra configuration dictionary containing training parameters.
"""
run_ppo(config)


# Define a function to run the PPO-like training process
def run_ppo(config) -> None:
def run_ppo(config: DictConfig) -> None:
"""Initialize Ray cluster and run distributed PPO training process.

Args:
Expand All @@ -56,7 +60,6 @@ def run_ppo(config) -> None:
if not ray.is_initialized():
# this is for local ray cluster
runtime_env = get_runtime_env(config)
print_dict(runtime_env["env_vars"], "runtime_env")
ray.init(
runtime_env=runtime_env,
num_cpus=config.ray_init.num_cpus,
Expand Down Expand Up @@ -110,6 +113,7 @@ def run(self, config):
# Print the initial configuration. `resolve=True` will evaluate symbolic values.
from pprint import pprint

from loguru import logger
from omegaconf import OmegaConf
from verl.utils.fs import copy_to_local

Expand Down Expand Up @@ -227,21 +231,13 @@ def run(self, config):
resource_pool_spec=resource_pool_spec, mapping=mapping
)

from verl.utils.dataset.rl_dataset import collate_fn

# Create training and validation datasets.
from ajet.task_reader import (
RouterTaskReader,
task_to_standard_dataset,
)
from ajet.utils.process_dataset import create_rl_sampler

task_reader = RouterTaskReader(
config.ajet.task_reader.type,
config.ajet.task_reader,
)
val_dataset = task_to_standard_dataset(task_reader.get_validation_tasks())
train_dataset = task_to_standard_dataset(task_reader.get_training_tasks())

train_dataset: TorchDataset = task_to_standard_dataset(task_reader.generate_training_tasks) # type: ignore
val_dataset: TorchDataset = task_to_standard_dataset(task_reader.generate_validation_tasks) # type: ignore
train_sampler = create_rl_sampler(config.data, train_dataset)

from ajet.backbone.trainer_verl import AjetRayPPOTrainer
Expand Down
10 changes: 8 additions & 2 deletions ajet/backbone/main_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def run(config):
max_parallel = config.ajet.debug.debug_max_parallel
n_task = config.ajet.debug.debug_first_n_tasks
vllm_port = config.ajet.debug.debug_vllm_port
enable_swarm_mode = config.ajet.enable_swarm_mode

# --------- init ---------
async_rollout_manager = ChatCompletionScheduler(
Expand All @@ -166,8 +167,10 @@ def run(config):
tasks = task_reader.get_validation_tasks()
logger.info(tasks[:n_task])
ctx_tracker = parallel_env.rollout(
tasks=tasks[:n_task], mode="sample", epoch="1"
) # "sample" or "validate"
tasks=tasks[:n_task],
mode="sample" if not enable_swarm_mode else "sample-ts", # type: ignore
epoch="1"
)
_ = parallel_env.to_dataproto(ctx_tracker)


Expand All @@ -186,6 +189,9 @@ def main(config):
if config.ajet.enable_experimental_interchange_server:
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
start_interchange_server(config)
if config.ajet.enable_swarm_mode:
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status
http_change_engine_status(config, "ENGINE.ROLLING")

def companion_launch():
import torch
Expand Down
6 changes: 2 additions & 4 deletions ajet/backbone/trainer_trinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,9 @@ def __init__(self, config):

dataset_segments = []
if "train" in self.split:
dataset_segments.append(task_to_standard_dataset(task_reader.get_training_tasks()))
dataset_segments.append(task_to_standard_dataset(task_reader.generate_training_tasks)) # type: ignore
if "val" in self.split:
dataset_segments.append(
task_to_standard_dataset(task_reader.get_validation_tasks())
)
dataset_segments.append(task_to_standard_dataset(task_reader.generate_validation_tasks)) # type: ignore
if not dataset_segments:
raise ValueError(
f"Unsupported split '{self.split}'. Expected to contain 'train' or 'val'."
Expand Down
71 changes: 47 additions & 24 deletions ajet/backbone/trainer_verl.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,29 @@ def parse_reward_from_dataproto(data: DataProto, return_dict=False) -> dict | to
return reward_tensor


def union_gen_batch_via_task_id(tasks, batch: DataProto, gen_batch_output: DataProto):
def union_gen_batch_via_task_id(tasks, batch: DataProto, gen_batch_output: DataProto, discard_original_batch=False):
    """
    Union the gen_batch_output with the batch based on task_id.

    When ``discard_original_batch`` is False, rows of ``batch`` are re-selected
    so they align one-to-one with the rollouts in ``gen_batch_output`` (matched
    by task_id) and the two are merged. When True, ``batch`` is ignored and
    ``gen_batch_output`` is returned after its ``uid`` / ``rollout_ids``
    non-tensor fields are rewritten in place from ``task_ids``.
    """
    if discard_original_batch:
        task_ids = gen_batch_output.non_tensor_batch["task_ids"]
        # uid aliases the task_ids array so grouping by uid == grouping by task.
        gen_batch_output.non_tensor_batch['uid'] = task_ids
        task_id_counter = {}
        for i, tid in enumerate(task_ids):
            # Per-task rollout counter: first occurrence is R1, then R2, ...
            task_id_counter[tid] = task_id_counter.get(tid, 0) + 1
            gen_batch_output.non_tensor_batch['rollout_ids'][i] = f"T{tid}R{task_id_counter[tid]}"
        logger.info(f'task_id_counter: {task_id_counter}')
        return gen_batch_output

    # Align the original batch rows to the generated rollouts via task_id.
    index_by_task_id = {t.task_id: i for i, t in enumerate(tasks)}
    aligned_indices = [index_by_task_id[tid] for tid in gen_batch_output.non_tensor_batch["task_ids"]]
    aligned_batch = batch.select_idxs(aligned_indices)
    return aligned_batch.union(gen_batch_output)


def compute_advantage(
Expand Down Expand Up @@ -443,6 +456,12 @@ def init_workers(self):
tokenizer=self.tokenizer,
)

def _update_interchange_server_status_flag(self, status: str):
    """Push an engine status string to the interchange server.

    No-op unless both the experimental interchange server and swarm mode
    are enabled in the config. The import is deferred so the experimental
    module is only loaded on this path.
    """
    ajet_cfg = self.config.ajet
    if ajet_cfg.enable_experimental_interchange_server and ajet_cfg.enable_swarm_mode:
        from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status
        http_change_engine_status(self.config, status)

# #######################################
# training loop
# #######################################
Expand Down Expand Up @@ -474,7 +493,7 @@ def fit(self): # noqa: C901

# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
if (self.val_reward_fn is not None) and (self.config.trainer.get("val_before_train", True)) and (not self.config.ajet.enable_swarm_mode):
val_metrics = self._validate()
assert val_metrics, f"{val_metrics=}"
pprint(f"Initial validation metrics: {val_metrics}")
Expand Down Expand Up @@ -547,12 +566,13 @@ def fit(self): # noqa: C901

with marked_timer("step", timing_raw):
# generate a batch
logger.info("=== + rollout step begin ===")
logger.info("rollout step begin")
with marked_timer("gen", timing_raw, color="red"):
assert self.async_rollout_mode
logger.info("=== wake up begin ===")
logger.info("wake up begin")
self.async_rollout_manager.wake_up()
logger.info("=== wake up end ===")
self._update_interchange_server_status_flag("ENGINE.ROLLING")
logger.info("wake up end")
tasks: List[Task] = [
dict_to_ajet_task(dict(
task_id=gen_batch.non_tensor_batch["task_id"][i],
Expand All @@ -571,15 +591,17 @@ def fit(self): # noqa: C901
]
)
)
logger.info("=" * 10 + "start fit rollout" + "=" * 10)
logger.info("start fit rollout")
self.parallel_env.current_global_steps = self.global_steps
context_tracker_arr: List[BaseContextTracker] = self.parallel_env.rollout(
tasks, mode="sample", epoch=f"train.{epoch}"
)
logger.info("=" * 10 + "end fit rollout" + "=" * 10)
logger.info("begin to convert context_tracker_arr to dataproto")

# from ajet import bp; bp("BATCH")

logger.info("end fit rollout")
gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr)
logger.info("end convertion")
logger.info("end dataproto convertion")

success_rate = [
traj.reward_structure.success_rate for traj in context_tracker_arr
Expand Down Expand Up @@ -622,17 +644,17 @@ def fit(self): # noqa: C901
logger.info(
f"gen_batch_output.info batch.keys={gen_batch_output.batch.keys()}"
)
self._update_interchange_server_status_flag("ENGINE.WEIGHT_SYNCING")
self.async_rollout_manager.sleep()
logger.info("=== - rollout step end ===")
logger.info("rollout step end")

if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
raise NotImplementedError("REMAX is not supported in GRPO yet.")

batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(batch.batch))],
dtype=object,
)
batch = union_gen_batch_via_task_id(tasks, batch, gen_batch_output)
discard_original_batch = self.config.ajet.enable_swarm_mode
batch = union_gen_batch_via_task_id(tasks, batch, gen_batch_output, discard_original_batch)
batch.batch["response_mask"] = compute_response_mask(batch)

if "response_mask" not in batch.batch.keys():
Expand Down Expand Up @@ -666,7 +688,7 @@ def fit(self): # noqa: C901
)

# recompute old_log_probs
logger.info("=== + compute log_probs begin ===")
logger.info("+ compute log_probs begin")
with marked_timer("old_log_prob", timing_raw, color="blue"):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
entropys = old_log_prob.batch["entropys"]
Expand Down Expand Up @@ -764,6 +786,7 @@ def fit(self): # noqa: C901
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
and (not self.config.ajet.enable_swarm_mode)
):
with marked_timer("testing", timing_raw, color="green"):
val_metrics: dict = self._validate()
Expand Down Expand Up @@ -914,17 +937,16 @@ def _validate(self):
self.async_rollout_manager.wake_up()
main_val_dataset = self.get_eval_dataset()

logger.info("=" * 10 + "start validate rollout" + "=" * 10)
logger.info("Starting validate rollout")
context_tracker_arr, tasks, val_metrics = self.eval_dataset(
target_dataset=main_val_dataset,
target_dataset_name="main_val_dataset",
mode="validate",
epoch="test.1",
)
logger.info("=" * 10 + "end validate rollout" + "=" * 10)
logger.info("Completed validate rollout")
test_output_gen_batch = self.parallel_env.to_dataproto(context_tracker_arr)
self.async_rollout_manager.sleep()
logger.info("validation generation end")

# Store generated outputs
output_ids = test_output_gen_batch.batch["responses"]
Expand All @@ -938,7 +960,8 @@ def _validate(self):
dtype=object,
)
tasks = tasks[: len(main_val_dataset)]
test_batch = union_gen_batch_via_task_id(tasks, test_batch, test_output_gen_batch)
discard_original_batch = self.config.ajet.enable_swarm_mode
test_batch = union_gen_batch_via_task_id(tasks, test_batch, test_output_gen_batch, discard_original_batch)
# test_batch = test_batch.union(test_output_gen_batch)
test_batch.meta_info["validate"] = True

Expand Down
Loading
Loading