diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 9fba0163e..18ce0d22f 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -40,3 +40,7 @@ class TrainingArguments(transformers.TrainingArguments): default=False, metadata={"help": "Packing to be enabled in SFT Trainer, default is False"}, ) + profiling: bool = field( + default=False, + metadata={"help": "Whether to enable Profiling, default is False"}, + ) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index a19f1c593..ac2476172 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -28,6 +28,18 @@ def on_save(self, args, state, control, **kwargs): if "pytorch_model.bin" in os.listdir(checkpoint_path): os.remove(os.path.join(checkpoint_path, "pytorch_model.bin")) +class ProfilerCallback(TrainerCallback): + def __init__(self, prof): + self.prof = prof + + def on_step_end(self, args, state, control, **kwargs): + self.prof.step() + +def trace_handler(prof): + prof.export_chrome_trace(f"trace.json.gz") + prof.export_memory_timeline(f"mem.html") + prof.export_memory_timeline(f"mem.json") + prof.export_memory_timeline(f"mem.raw.json.gz") def train( model_args: configs.ModelArguments, @@ -153,6 +165,17 @@ def train( data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer, ignore_index=configs.IGNORE_INDEX) packing = False + if train_args.profiling: + logger.info("Enabling Profiling") + profiler = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + schedule=torch.profiler.schedule(wait=0, warmup=0, active=5, repeat=1), + record_shapes=True, + profile_memory=True, + with_stack=True, + on_trace_ready=trace_handler) + callbacks.append(ProfilerCallback(profiler)) + trainer = SFTTrainer( model=model, tokenizer=tokenizer, @@ -168,7 +191,12 @@ def train( if run_distributed and peft_config is not None: trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model) - trainer.train() + + if train_args.profiling: + with profiler as prof: + trainer.train() + else: + trainer.train() def main(**kwargs): parser = transformers.HfArgumentParser(dataclass_types=(configs.ModelArguments,