class LatexOCRProcessor(Preprocessor):
    """Turn LaTeX-OCR dataset rows into two-message chat trajectories.

    Each row is expected to provide an ``'image'`` entry and a ``'text'``
    entry; the image is attached to the user prompt and the text becomes the
    assistant reply.
    """

    def __call__(self, rows):
        """Convert a batch of rows into trajectories.

        The batch is passed through ``self.map_col_to_row``, each resulting
        row is converted by :meth:`preprocess`, and the converted rows are
        passed back through ``self.map_row_to_col``.
        NOTE(review): both helpers are presumably inherited from
        ``Preprocessor`` and convert between column- and row-oriented
        batches -- confirm against the base class.
        """
        trajectories = []
        for sample in self.map_col_to_row(rows):
            trajectories.append(self.preprocess(sample))
        return self.map_row_to_col(trajectories)

    def preprocess(self, row) -> Trajectory:
        """Build the user/assistant message pair for a single row."""
        question = Message(
            role='user',
            content='Using LaTeX to perform OCR on the image.',
            images=[row['image']],
        )
        answer = Message(role='assistant', content=row['text'])
        return Trajectory(messages=[question, answer])
def _build_dataloader(num_samples, batch_size):
    """Build an encoded LaTeX-OCR dataloader.

    Args:
        num_samples: Number of leading dataset rows to keep;
            ``range(num_samples)`` is passed as ``data_slice``.
        batch_size: Batch size handed to ``DataLoader``.

    Returns:
        A ``DataLoader`` over the templated, preprocessed and encoded dataset.
    """
    dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(num_samples)))
    # Set template to prepare encoding
    dataset.set_template('Gemma4Template', model_id=MODEL_PATH)
    # Preprocess the dataset to the standard message format
    dataset.map(preprocess_func=LatexOCRProcessor)
    # Encode dataset
    dataset.encode()
    return DataLoader(dataset=dataset, batch_size=batch_size)


def _load_debug_config():
    """Load the model config and shrink it for debugging.

    For a ``Gemma4Config`` the text and vision towers are cut down to
    ``TEXT_NUM_LAYERS`` / ``VISION_NUM_LAYERS`` hidden layers (when those
    constants are not ``None``). ``use_cache`` is switched off when the config
    exposes it.

    Returns:
        The (possibly modified) config object.
    """
    # NOTE(review): MODEL_PATH carries an 'ms://' prefix -- confirm that
    # AutoConfig can resolve it in this environment, or download it first.
    config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True)

    if isinstance(config, Gemma4Config):
        # Reduce the number of layers so the model is small enough to debug.
        text_config = config.text_config
        vision_config = config.vision_config
        if TEXT_NUM_LAYERS is not None and hasattr(text_config, 'num_hidden_layers'):
            text_config.num_hidden_layers = TEXT_NUM_LAYERS
            logger.info(f" modify > text_config.num_hidden_layers = {text_config.num_hidden_layers}")
        if VISION_NUM_LAYERS is not None and hasattr(vision_config, 'num_hidden_layers'):
            vision_config.num_hidden_layers = VISION_NUM_LAYERS
            logger.info(f" modify > vision_config.num_hidden_layers = {vision_config.num_hidden_layers}")

    if hasattr(config, 'use_cache'):
        config.use_cache = False
    return config


def train():
    """Run LoRA fine-tuning of Gemma4 on LaTeX-OCR with periodic evaluation.

    Steps, in order: build the training dataloader, load the (layer-reduced)
    config, construct the ``TransformersModel``, attach a LoRA adapter with
    optimizer and LR scheduler, build the evaluation dataloader, then loop
    over the training batches. Train metrics are logged every ``METRIC_STEP``
    steps; every ``SAVE_STEP`` steps the model is evaluated and the best
    evaluation loss seen so far is tracked and logged.
    """
    ### prepare dataset and dataloader
    dataloader = _build_dataloader(TRAIN_LEN, BATCH_SIZE)
    total_steps = len(dataloader)

    config = _load_debug_config()

    # Use a TransformersModel
    model = TransformersModel(
        model_id=MODEL_PATH,
        config=config,
        device_mesh=device_mesh,
        strategy="accelerate",  # author's note lists: native_fsdp, accelerate
        ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
        fsdp_config={
            'reshard_after_forward': True,
            'expert_parallel': {
                'enabled': True,
                'router_dtype': 'fp32',
                'keep_router_logits': False,
            },
        },
    )
    # Author's note: if no config is passed, it is read automatically via
    # AutoConfig.from_pretrained.

    # model.model._no_split_modules = {'Gemma4VisionEncoderLayer', 'Gemma4TextDecoderLayer', 'Gemma4AudioLayer'}  # for 3 modalities (2B, 4B)
    # model.model._no_split_modules = {'Gemma4VisionEncoderLayer', 'Gemma4TextDecoderLayer'}  # for 2 modalities (26B-A4B, 31B)

    lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')

    # Add a lora to model, with name `default`
    # Comment this to use full-parameter training
    model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
    # Add Optimizer for lora `default`
    model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
    # Add LRScheduler for lora `default`
    model.set_lr_scheduler(
        scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=total_steps)

    logger.info(get_device_placement())
    # Print the training config
    logger.info(model.get_train_configs())
    logger.info(f'Total steps: {total_steps}')
    best_eval_loss = float('inf')
    # Author's memory notes -- lora: 8G * 8, full: 18G * 8

    ### eval dataset and dataloader
    # NOTE(review): the eval slice range(eval_length) is a subset of the
    # training slice range(TRAIN_LEN), so the eval loss is measured on seen
    # data -- confirm this is intended for the demo.
    eval_length = 100
    eval_batch_size = 8
    eval_dataloader = _build_dataloader(eval_length, eval_batch_size)

    for step, batch in tqdm(enumerate(dataloader), total=total_steps):
        # Do forward and backward
        model.forward_backward(inputs=batch)
        # Step
        model.clip_grad_and_step()

        if step % METRIC_STEP == 0:
            # Print metric
            metric = model.calculate_metric(is_training=True)
            logger.info(f'Current is step {step} of {total_steps}, Train metric: {metric}')

        if step % SAVE_STEP == 0:
            metrics = eval(model, eval_dataloader)
            metrics['step'] = step
            eval_loss = float(metrics['loss'])
            if eval_loss < best_eval_loss:
                # Checkpoint saving is disabled in this demo:
                # model.save(f'checkpoint-{step}')
                best_eval_loss = eval_loss
            metrics['best_eval_loss'] = best_eval_loss
            logger.info(f'Current is step {step} of {total_steps}, Eval metric: {metrics}')

    # model.save(f'last-checkpoint')
"_to_standard_reasoning_content", + "_build_standard_messages", ] - self.post_pipeline: List[Callable[[InputFeature], List[InputFeature]]] = [ - self._check_max_length, # Check and split input_features - self._add_attention_fields, # Add useful fields - self._roll_labels, # roll labels + + self.post_pipeline_names: List[str] = [ + "_check_max_length", + "_add_attention_fields", + "_roll_labels", ] @property @@ -140,7 +143,8 @@ def preprocess_audios(self, audios: List[AudioInput]) -> List[np.ndarray]: def _invoke_pre_pipeline(self, trajectories: List[Trajectory]) -> List[Trajectory]: current = trajectories - for pipeline in self.pre_pipeline: + for pipeline_name in self.pre_pipeline_names: + pipeline: Callable[[Trajectory], List[Trajectory]] = getattr(self, pipeline_name) next_batch = [] for trajectory in current: next_batch.extend(pipeline(trajectory)) @@ -149,7 +153,8 @@ def _invoke_pre_pipeline(self, trajectories: List[Trajectory]) -> List[Trajector def _invoke_post_pipeline(self, input_features: List[InputFeature]) -> List[InputFeature]: current = input_features - for pipeline in self.post_pipeline: + for pipeline_name in self.post_pipeline_names: + pipeline: Callable[[InputFeature], List[InputFeature]] = getattr(self, pipeline_name) next_batch = [] for input_feature in current: next_batch.extend(pipeline(input_feature)) @@ -436,9 +441,21 @@ def _process_mm_string_format(self, messages: List, images: List, videos: List, def _build_standard_messages(self, trajectory: Trajectory) -> List[Trajectory]: # Extract trajectory-level media - images = self.preprocess_images(trajectory.pop('images', None) or []) - videos = self.preprocess_videos(trajectory.pop('videos', None) or []) - audios = self.preprocess_audios(trajectory.pop('audios', None) or []) + extracted_images = trajectory.pop('images', None) or [ + img for msg in trajectory['messages'] + for img in msg.get('images', []) or [] + ] + extracted_videos = trajectory.pop('videos', None) or [ + video for msg in 
@remote_class()
class Gemma4Template(Template):
    """Template for the Google Gemma4 series.

    Adds no behavior of its own: the constructor forwards every positional
    and keyword argument to ``Template`` unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the base Template constructor.
        super().__init__(*args, **kwargs)
        # use original Template