From 2df1bb7c7689d928e149e27dea9a8250152d274d Mon Sep 17 00:00:00 2001 From: LZH Date: Mon, 18 May 2026 16:22:21 +0800 Subject: [PATCH 1/4] feat: gemma4 support initial commit (WIP) --- cookbook/mm/fsdp2_gemma4_mm.py | 198 +++++++++++++++++++++++++++++++ cookbook/mm/fsdp2_gemma4_mm.sh | 3 + src/twinkle/template/__init__.py | 1 + src/twinkle/template/gemma4.py | 26 ++++ 4 files changed, 228 insertions(+) create mode 100644 cookbook/mm/fsdp2_gemma4_mm.py create mode 100644 cookbook/mm/fsdp2_gemma4_mm.sh create mode 100644 src/twinkle/template/gemma4.py diff --git a/cookbook/mm/fsdp2_gemma4_mm.py b/cookbook/mm/fsdp2_gemma4_mm.py new file mode 100644 index 00000000..85a4762f --- /dev/null +++ b/cookbook/mm/fsdp2_gemma4_mm.py @@ -0,0 +1,198 @@ +import os +from peft import LoraConfig +from tqdm import tqdm +from transformers import AutoConfig +from transformers import ( + Gemma4AudioConfig, + Gemma4AudioFeatureExtractor, + Gemma4Config, + Gemma4ForCausalLM, + Gemma4ForConditionalGeneration, + Gemma4ImageProcessor, + Gemma4Processor, + Gemma4TextConfig, + Gemma4VideoProcessor, + Gemma4VisionConfig, + GemmaTokenizer, + GenerationConfig, + RopeParameters, +) + +import twinkle +from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor, LatexOCRProcessor + +logger = get_logger() + +########## Construct a device_mesh ########## +device_mesh = DeviceMesh.from_sizes( + # fsdp_size=2, + # dp_size=1, + # ep_size=2, + device_type=Platform.get_platform().device_prefix(), +) +# use torchrun mode +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + +########## 超参数 ########## +IGNORE_MISMATCHED_SIZES = True +# MODEL_PATH = '/nas/disk1/MiniMax-M2.5' +# MODEL_PATH = '/nas/disk1/Qwen3-30B-A3B' +# MODEL_PATH = '/nas/disk1/gemma-4-E2B-it' +# MODEL_PATH = '/nas/disk1/gemma-4-26B-A4B' +MODEL_PATH = r"C:\Users\wliuu\.cache\modelscope\hub\models\google\gemma-4-E2B-it" + +# DATASET_PATH = '/model/lzh/train/datasets/self-cognition.jsonl' +DATASET_PATH = 'ms://AI-ModelScope/LaTeX_OCR' +# DATASET_PATH = r'C:\Users\wliuu\.cache\modelscope\hub\datasets\LaTex_OCR_train.json' + +### 注意: gemma4-26b-a4b: text layers=30, vision layers=27 +TEXT_NUM_LAYERS = 3 +VISION_NUM_LAYERS = 3 + +TRAIN_LEN = 200 +BATCH_SIZE = 4 + +from twinkle.preprocessor import Preprocessor +from twinkle.data_format import Message, Trajectory +class LatexOCRProcessor(Preprocessor): + + def __call__(self, rows): # 输入的rows 是pyarrow.Table处理后的inputs + rows = self.map_col_to_row(rows) # 构建出一行一行的数据 + rows = [self.preprocess(row) for row in rows] # 每行构建为Trajectory + col = self.map_row_to_col(rows) # 变成列的形式, col['messages']下有BATCH_LEN条 Trajectory数据 + return col + + def preprocess(self, row) -> Trajectory: + return Trajectory( # + messages=[ + Message(role='user', content='Using LaTeX to perform OCR on the image.', images=[row['image']]), + Message(role='assistant', content=row['text']), # 相当于有监督学习的label + ] + ) + +def eval(model, eval_dataloader): + for step, batch in tqdm(enumerate(eval_dataloader)): + model.forward_only(inputs=batch) + model.calculate_loss() + metrics = model.calculate_metric(is_training=False) + return metrics + +def train(): + + ### prepare dataset and dataloader + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(TRAIN_LEN))) + # Set template to prepare encoding + # dataset.set_template('Template', model_id=MODEL_PATH) # 指定的是Template实例名称 + dataset.set_template('Gemma4Template', model_id=MODEL_PATH) # + # Preprocess the dataset to standard format + # dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + dataset.map(preprocess_func=LatexOCRProcessor) + # Encode dataset + dataset.encode() # 2B可以用Template,26B会报错 + # Global batch size = 8, for GPUs, so 1 sample per GPU + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + + # config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + config, kwargs = AutoConfig.from_pretrained( + MODEL_PATH, + trust_remote_code=True, + return_unused_kwargs=True, + # code_revision=code_revision, + # _commit_hash=commit_hash, + # **hub_kwargs, + # **kwargs, + ) + + if isinstance(config, Gemma4Config): # 减层 + text_config = config.text_config + vision_config = config.vision_config + if TEXT_NUM_LAYERS is not None and hasattr(text_config, 'num_hidden_layers'): + text_config.num_hidden_layers = TEXT_NUM_LAYERS + print(f" modify > text_config.num_hidden_layers = {text_config.num_hidden_layers}") + if VISION_NUM_LAYERS is not None and hasattr(vision_config, 'num_hidden_layers'): + vision_config.num_hidden_layers = VISION_NUM_LAYERS + print(f" modify > vision_config.num_hidden_layers = {vision_config.num_hidden_layers}") + if hasattr(config, 'use_cache'): + config.use_cache = False + + # Use a TransformersModel + model = TransformersModel( + model_id=MODEL_PATH, + config=config, + device_mesh=device_mesh, + strategy="accelerate", # native_fsdp、 accelerate + ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES, + fsdp_config={ + 'reshard_after_forward': True, + 'expert_parallel': { + 'enabled': True, + 'router_dtype': 'fp32', + 'keep_router_logits': False, + } + }, + ) # 会直接load weights + # type(model): + # type(model.model): + # 若未传入config, 则自动通过 AutoConfig.from_pretrained 读取config + + # FSDP 不进行切分: + # model.model._no_split_modules = {'Gemma4VisionEncoderLayer', 'Gemma4TextDecoderLayer', 'Gemma4AudioLayer'} # for 3 modalities(2B、4B) + model.model._no_split_modules = {'Gemma4VisionEncoderLayer', 'Gemma4TextDecoderLayer'} # for 2 modalities(26B-A4B、31B) + + lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') + + # Add a lora to model, with name `default` + # Comment this to use full-parameter training + model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) # gradient_accumulation_steps? + # Add Optimizer for lora `default` + model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) + + # Add LRScheduler for lora `default` + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) + + logger.info(get_device_placement()) + # Print the training config + logger.info(model.get_train_configs()) + logger.info(f'Total steps: {len(dataloader)}') + best_eval_loss = 99.0 + # lora: 8G * 8 + # full: 18G * 8 + + ### eval dataset and dataloader + # EVAL_LENGTH = 100 + # eval_dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(EVAL_LENGTH))) + # eval_dataset.set_template('Gemma4Template', model_id=MODEL_PATH) + # eval_dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + # eval_dataset.encode() + # eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=8) + for step, batch in enumerate(dataloader): + # Do forward and backward + model.forward_backward(inputs=batch) + # Step + model.clip_grad_and_step() + + if step % 10 == 0: + # Print metric + metric = model.calculate_metric(is_training=True) # 若npu显存不足, + logger.info(f'Current is step {step} of {len(dataloader)}, Train metric: {metric}') + + # if step % 10 == 0: + # metrics = eval(model, eval_dataloader) + # metrics['step'] = step + # if float(metrics['loss']) < best_eval_loss: + # # model.save(f'checkpoint-{step}') + # best_eval_loss = float(metrics['loss']) + # metrics['best_eval_loss'] = best_eval_loss + # logger.info(f'Current is step {step} of {len(dataloader)}, Eval metric: {metrics}') + + # model.save(f'last-checkpoint') + + +if __name__ == '__main__': + train() + diff --git a/cookbook/mm/fsdp2_gemma4_mm.sh b/cookbook/mm/fsdp2_gemma4_mm.sh new file mode 100644 index 00000000..5a790def --- /dev/null +++ b/cookbook/mm/fsdp2_gemma4_mm.sh @@ -0,0 +1,3 @@ +export CUDA_VISIBLE_DEVICES=0,1 + +torchrun --nnodes=1 --nproc_per_node=2 fsdp2_gemma4_mm.py diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py index 324ce7ac..3760f52d 100644 --- a/src/twinkle/template/__init__.py +++ b/src/twinkle/template/__init__.py @@ -1,3 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .base import Template from .qwen3_5_vl import Qwen3_5Template +from .gemma4 import Gemma4Template diff --git a/src/twinkle/template/gemma4.py b/src/twinkle/template/gemma4.py new file mode 100644 index 00000000..7fc292f8 --- /dev/null +++ b/src/twinkle/template/gemma4.py @@ -0,0 +1,26 @@ +import json +import torch +import torch.nn.functional as F +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union +from dataclasses import dataclass, field, fields +from PIL import Image +from copy import deepcopy + +from twinkle import remote_class, requires +from twinkle.template import Template + +Tool = Dict[str, Union[str, Dict]] +History = List[Union[Tuple[str, str], List[str]]] +Message = Dict[str, Union[str, List[Dict[str, Any]], List[int], None]] +Messages = List[Message] +Prompt = List[Union[str, List[int], List[str]]] +Word = Union[str, List[int]] +Context = Word + +@remote_class() +class Gemma4Template(Template): + """Processor for Google Gemma4 series.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # use original Template \ No newline at end of file From 854a8673e5f5376a279fc8ac11126d427a1f35e9 Mon Sep 17 00:00:00 2001 From: LZH Date: Wed, 20 May 2026 15:34:49 +0800 Subject: [PATCH 2/4] fix: resolve multimodal data handling bug in Template messages --- cookbook/mm/fsdp2_gemma4_mm.py | 96 +++++++++++++--------------------- src/twinkle/template/base.py | 36 +++++++++---- src/twinkle/template/gemma4.py | 35 +++++++++---- 3 files changed, 88 insertions(+), 79 deletions(-) diff --git a/cookbook/mm/fsdp2_gemma4_mm.py b/cookbook/mm/fsdp2_gemma4_mm.py index 85a4762f..128fa29f 100644 --- a/cookbook/mm/fsdp2_gemma4_mm.py +++ b/cookbook/mm/fsdp2_gemma4_mm.py @@ -3,19 +3,7 @@ from tqdm import tqdm from transformers import AutoConfig from transformers import ( - Gemma4AudioConfig, - Gemma4AudioFeatureExtractor, Gemma4Config, - Gemma4ForCausalLM, - Gemma4ForConditionalGeneration, - Gemma4ImageProcessor, - Gemma4Processor, - Gemma4TextConfig, - Gemma4VideoProcessor, - Gemma4VisionConfig, - GemmaTokenizer, - GenerationConfig, - RopeParameters, ) import twinkle @@ -23,7 +11,7 @@ from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor, LatexOCRProcessor +# from twinkle.preprocessor import SelfCognitionProcessor, LatexOCRProcessor logger = get_logger() @@ -37,40 +25,35 @@ # use torchrun mode twinkle.initialize(mode='local', global_device_mesh=device_mesh) -########## 超参数 ########## +########## hyperparameters ########## IGNORE_MISMATCHED_SIZES = True -# MODEL_PATH = '/nas/disk1/MiniMax-M2.5' -# MODEL_PATH = '/nas/disk1/Qwen3-30B-A3B' -# MODEL_PATH = '/nas/disk1/gemma-4-E2B-it' -# MODEL_PATH = '/nas/disk1/gemma-4-26B-A4B' -MODEL_PATH = r"C:\Users\wliuu\.cache\modelscope\hub\models\google\gemma-4-E2B-it" - -# DATASET_PATH = '/model/lzh/train/datasets/self-cognition.jsonl' +MODEL_PATH = 'ms://google/gemma-4-26b-a4b' DATASET_PATH = 'ms://AI-ModelScope/LaTeX_OCR' -# DATASET_PATH = r'C:\Users\wliuu\.cache\modelscope\hub\datasets\LaTex_OCR_train.json' +TRAIN_LEN = 2000 +BATCH_SIZE = 4 +METRIC_STEP = 10 +SAVE_STEP = 10 -### 注意: gemma4-26b-a4b: text layers=30, vision layers=27 +### reduce model layers for debug TEXT_NUM_LAYERS = 3 VISION_NUM_LAYERS = 3 -TRAIN_LEN = 200 -BATCH_SIZE = 4 from twinkle.preprocessor import Preprocessor from twinkle.data_format import Message, Trajectory class LatexOCRProcessor(Preprocessor): - def __call__(self, rows): # 输入的rows 是pyarrow.Table处理后的inputs - rows = self.map_col_to_row(rows) # 构建出一行一行的数据 - rows = [self.preprocess(row) for row in rows] # 每行构建为Trajectory - col = self.map_row_to_col(rows) # 变成列的形式, col['messages']下有BATCH_LEN条 Trajectory数据 + def __call__(self, rows): + rows = self.map_col_to_row(rows) + rows = [self.preprocess(row) for row in rows] + col = self.map_row_to_col(rows) return col def preprocess(self, row) -> Trajectory: - return Trajectory( # + return Trajectory( messages=[ Message(role='user', content='Using LaTeX to perform OCR on the image.', images=[row['image']]), - Message(role='assistant', content=row['text']), # 相当于有监督学习的label + Message(role='assistant', content=row['text']), ] ) @@ -86,17 +69,14 @@ def train(): ### prepare dataset and dataloader dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(TRAIN_LEN))) # Set template to prepare encoding - # dataset.set_template('Template', model_id=MODEL_PATH) # 指定的是Template实例名称 - dataset.set_template('Gemma4Template', model_id=MODEL_PATH) # + dataset.set_template('Gemma4Template', model_id=MODEL_PATH) # Preprocess the dataset to standard format - # dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + # dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.map(preprocess_func=LatexOCRProcessor) # Encode dataset - dataset.encode() # 2B可以用Template,26B会报错 - # Global batch size = 8, for GPUs, so 1 sample per GPU + dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) - # config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) config, kwargs = AutoConfig.from_pretrained( MODEL_PATH, trust_remote_code=True, @@ -134,20 +114,17 @@ def train(): 'keep_router_logits': False, } }, - ) # 会直接load weights - # type(model): - # type(model.model): + ) # 若未传入config, 则自动通过 AutoConfig.from_pretrained 读取config - # FSDP 不进行切分: # model.model._no_split_modules = {'Gemma4VisionEncoderLayer', 'Gemma4TextDecoderLayer', 'Gemma4AudioLayer'} # for 3 modalities(2B、4B) - model.model._no_split_modules = {'Gemma4VisionEncoderLayer', 'Gemma4TextDecoderLayer'} # for 2 modalities(26B-A4B、31B) + # model.model._no_split_modules = {'Gemma4VisionEncoderLayer', 'Gemma4TextDecoderLayer'} # for 2 modalities(26B-A4B、31B) lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') # Add a lora to model, with name `default` # Comment this to use full-parameter training - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) # gradient_accumulation_steps? + model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) # Add Optimizer for lora `default` model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) @@ -164,31 +141,32 @@ def train(): # full: 18G * 8 ### eval dataset and dataloader - # EVAL_LENGTH = 100 - # eval_dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(EVAL_LENGTH))) - # eval_dataset.set_template('Gemma4Template', model_id=MODEL_PATH) - # eval_dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) - # eval_dataset.encode() - # eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=8) + EVAL_LENGTH = 100 + eval_dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(EVAL_LENGTH))) + eval_dataset.set_template('Gemma4Template', model_id=MODEL_PATH) + # eval_dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + dataset.map(preprocess_func=LatexOCRProcessor) + eval_dataset.encode() + eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=8) for step, batch in enumerate(dataloader): # Do forward and backward model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() - if step % 10 == 0: + if step % METRIC_STEP == 0: # Print metric - metric = model.calculate_metric(is_training=True) # 若npu显存不足, + metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {step} of {len(dataloader)}, Train metric: {metric}') - # if step % 10 == 0: - # metrics = eval(model, eval_dataloader) - # metrics['step'] = step - # if float(metrics['loss']) < best_eval_loss: - # # model.save(f'checkpoint-{step}') - # best_eval_loss = float(metrics['loss']) - # metrics['best_eval_loss'] = best_eval_loss - # logger.info(f'Current is step {step} of {len(dataloader)}, Eval metric: {metrics}') + if step % SAVE_STEP == 0: + metrics = eval(model, eval_dataloader) + metrics['step'] = step + if float(metrics['loss']) < best_eval_loss: + # model.save(f'checkpoint-{step}') + best_eval_loss = float(metrics['loss']) + metrics['best_eval_loss'] = best_eval_loss + logger.info(f'Current is step {step} of {len(dataloader)}, Eval metric: {metrics}') # model.save(f'last-checkpoint') diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 5784ddae..1575e0f3 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -37,7 +37,8 @@ def __init__(self, enable_thinking: bool = True, **kwargs): model_id = HubOperation.download_model(model_id, ignore_model=True) - if os.path.exists(os.path.join(model_id, 'preprocessor_config.json')): + if os.path.exists(os.path.join(model_id, 'preprocessor_config.json')) or os.path.exists( + os.path.join(model_id, 'processor_config.json')): from transformers import AutoProcessor self.processor = AutoProcessor.from_pretrained(model_id, **kwargs) else: @@ -52,15 +53,26 @@ def __init__(self, self.truncation_strategy = truncation_strategy self.default_system = default_system self._test_support_assistant_tokens_mask() - self.pre_pipeline: List[Callable[[Trajectory], List[Trajectory]]] = [ - self._add_default_system, # Add a default system field - self._to_standard_reasoning_content, # Convert thinking to standard field - self._build_standard_messages, # turn to standard mm messages + # self.pre_pipeline: List[Callable[[Trajectory], List[Trajectory]]] = [ + # self._add_default_system, # Add a default system field + # self._to_standard_reasoning_content, # Convert thinking to standard field + # self._build_standard_messages, # turn to standard mm messages + # ] + # self.post_pipeline: List[Callable[[InputFeature], List[InputFeature]]] = [ + # self._check_max_length, # Check and split input_features + # self._add_attention_fields, # Add useful fields + # self._roll_labels, # roll labels + # ] + self.pre_pipeline_names: List[str] = [ + "_add_default_system", + "_to_standard_reasoning_content", + "_build_standard_messages", ] - self.post_pipeline: List[Callable[[InputFeature], List[InputFeature]]] = [ - self._check_max_length, # Check and split input_features - self._add_attention_fields, # Add useful fields - self._roll_labels, # roll labels + + self.post_pipeline_names: List[str] = [ + "_check_max_length", + "_add_attention_fields", + "_roll_labels", ] @property @@ -140,7 +152,8 @@ def preprocess_audios(self, audios: List[AudioInput]) -> List[np.ndarray]: def _invoke_pre_pipeline(self, trajectories: List[Trajectory]) -> List[Trajectory]: current = trajectories - for pipeline in self.pre_pipeline: + for pipeline_name in self.pre_pipeline_names: + pipeline: Callable[[Trajectory], List[Trajectory]] = getattr(self, pipeline_name) next_batch = [] for trajectory in current: next_batch.extend(pipeline(trajectory)) @@ -149,7 +162,8 @@ def _invoke_pre_pipeline(self, trajectories: List[Trajectory]) -> List[Trajector def _invoke_post_pipeline(self, input_features: List[InputFeature]) -> List[InputFeature]: current = input_features - for pipeline in self.post_pipeline: + for pipeline_name in self.post_pipeline_names: + pipeline: Callable[[InputFeature], List[InputFeature]] = getattr(self, pipeline_name) next_batch = [] for input_feature in current: next_batch.extend(pipeline(input_feature)) diff --git a/src/twinkle/template/gemma4.py b/src/twinkle/template/gemma4.py index 7fc292f8..fd96f930 100644 --- a/src/twinkle/template/gemma4.py +++ b/src/twinkle/template/gemma4.py @@ -8,14 +8,7 @@ from twinkle import remote_class, requires from twinkle.template import Template - -Tool = Dict[str, Union[str, Dict]] -History = List[Union[Tuple[str, str], List[str]]] -Message = Dict[str, Union[str, List[Dict[str, Any]], List[int], None]] -Messages = List[Message] -Prompt = List[Union[str, List[int], List[str]]] -Word = Union[str, List[int]] -Context = Word +from twinkle.data_format import InputFeature, Message, Trajectory @remote_class() class Gemma4Template(Template): @@ -23,4 +16,28 @@ class Gemma4Template(Template): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # use original Template \ No newline at end of file + # use original Template + + def _build_standard_messages(self, trajectory: Trajectory) -> List[Trajectory]: + # Extract trajectory-level media + extracted_images = trajectory.pop('images', None) or [ + img for msg in trajectory['messages'] + for img in msg.get('images', []) or [] + ] + extracted_videos = trajectory.pop('videos', None) or [ + video for msg in trajectory['messages'] + for video in msg.get('videos', []) or [] + ] + extracted_audios = trajectory.pop('audios', None) or [ + audio for msg in trajectory['messages'] + for audio in msg.get('audios', []) or [] + ] + images = self.preprocess_images(extracted_images) + videos = self.preprocess_videos(extracted_videos) + audios = self.preprocess_audios(extracted_audios) + + trajectory['messages'] = self._process_mm_messages(trajectory['messages'], images, videos, audios) + if not self.is_mm: + for message in trajectory['messages']: + message['content'] = message['content'][0]['text'] + return [trajectory] \ No newline at end of file From dccd0c5c10b4de949a5b6f494144308125a489d2 Mon Sep 17 00:00:00 2001 From: LZH Date: Wed, 20 May 2026 15:50:34 +0800 Subject: [PATCH 3/4] fix: correct some spelling errors --- cookbook/mm/fsdp2_gemma4_mm.py | 10 +++++----- src/twinkle/template/gemma4.py | 16 ++++++---------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/cookbook/mm/fsdp2_gemma4_mm.py b/cookbook/mm/fsdp2_gemma4_mm.py index 128fa29f..f4fa597c 100644 --- a/cookbook/mm/fsdp2_gemma4_mm.py +++ b/cookbook/mm/fsdp2_gemma4_mm.py @@ -92,10 +92,10 @@ def train(): vision_config = config.vision_config if TEXT_NUM_LAYERS is not None and hasattr(text_config, 'num_hidden_layers'): text_config.num_hidden_layers = TEXT_NUM_LAYERS - print(f" modify > text_config.num_hidden_layers = {text_config.num_hidden_layers}") + logger.info(f" modify > text_config.num_hidden_layers = {text_config.num_hidden_layers}") if VISION_NUM_LAYERS is not None and hasattr(vision_config, 'num_hidden_layers'): vision_config.num_hidden_layers = VISION_NUM_LAYERS - print(f" modify > vision_config.num_hidden_layers = {vision_config.num_hidden_layers}") + logger.info(f" modify > vision_config.num_hidden_layers = {vision_config.num_hidden_layers}") if hasattr(config, 'use_cache'): config.use_cache = False @@ -136,7 +136,7 @@ def train(): # Print the training config logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') - best_eval_loss = 99.0 + best_eval_loss = float('inf') # lora: 8G * 8 # full: 18G * 8 @@ -145,10 +145,10 @@ def train(): eval_dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(EVAL_LENGTH))) eval_dataset.set_template('Gemma4Template', model_id=MODEL_PATH) # eval_dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) - dataset.map(preprocess_func=LatexOCRProcessor) + eval_dataset.map(preprocess_func=LatexOCRProcessor) eval_dataset.encode() eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=8) - for step, batch in enumerate(dataloader): + for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)): # Do forward and backward model.forward_backward(inputs=batch) # Step diff --git a/src/twinkle/template/gemma4.py b/src/twinkle/template/gemma4.py index fd96f930..b629f6d4 100644 --- a/src/twinkle/template/gemma4.py +++ b/src/twinkle/template/gemma4.py @@ -1,14 +1,8 @@ -import json -import torch -import torch.nn.functional as F -from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union -from dataclasses import dataclass, field, fields -from PIL import Image -from copy import deepcopy -from twinkle import remote_class, requires +from typing import List +from twinkle import remote_class from twinkle.template import Template -from twinkle.data_format import InputFeature, Message, Trajectory +from twinkle.data_format import Trajectory @remote_class() class Gemma4Template(Template): @@ -40,4 +34,6 @@ def _build_standard_messages(self, trajectory: Trajectory) -> List[Trajectory]: if not self.is_mm: for message in trajectory['messages']: message['content'] = message['content'][0]['text'] - return [trajectory] \ No newline at end of file + return [trajectory] + + \ No newline at end of file From e967689bd711d94ba5d5178688efcc55634f0ea7 Mon Sep 17 00:00:00 2001 From: LZH Date: Wed, 20 May 2026 17:05:07 +0800 Subject: [PATCH 4/4] refactor: move _build_standard_messages to base Template --- src/twinkle/template/base.py | 29 ++++++++++++++++------------- src/twinkle/template/gemma4.py | 24 ------------------------ 2 files changed, 16 insertions(+), 37 deletions(-) diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 1575e0f3..bf110096 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -53,16 +53,7 @@ def __init__(self, self.truncation_strategy = truncation_strategy self.default_system = default_system self._test_support_assistant_tokens_mask() - # self.pre_pipeline: List[Callable[[Trajectory], List[Trajectory]]] = [ - # self._add_default_system, # Add a default system field - # self._to_standard_reasoning_content, # Convert thinking to standard field - # self._build_standard_messages, # turn to standard mm messages - # ] - # self.post_pipeline: List[Callable[[InputFeature], List[InputFeature]]] = [ - # self._check_max_length, # Check and split input_features - # self._add_attention_fields, # Add useful fields - # self._roll_labels, # roll labels - # ] + self.pre_pipeline_names: List[str] = [ "_add_default_system", "_to_standard_reasoning_content", @@ -450,9 +441,21 @@ def _process_mm_string_format(self, messages: List, images: List, videos: List, def _build_standard_messages(self, trajectory: Trajectory) -> List[Trajectory]: # Extract trajectory-level media - images = self.preprocess_images(trajectory.pop('images', None) or []) - videos = self.preprocess_videos(trajectory.pop('videos', None) or []) - audios = self.preprocess_audios(trajectory.pop('audios', None) or []) + extracted_images = trajectory.pop('images', None) or [ + img for msg in trajectory['messages'] + for img in msg.get('images', []) or [] + ] + extracted_videos = trajectory.pop('videos', None) or [ + video for msg in trajectory['messages'] + for video in msg.get('videos', []) or [] + ] + extracted_audios = trajectory.pop('audios', None) or [ + audio for msg in trajectory['messages'] + for audio in msg.get('audios', []) or [] + ] + images = self.preprocess_images(extracted_images) + videos = self.preprocess_videos(extracted_videos) + audios = self.preprocess_audios(extracted_audios) trajectory['messages'] = self._process_mm_messages(trajectory['messages'], images, videos, audios) if not self.is_mm: diff --git a/src/twinkle/template/gemma4.py b/src/twinkle/template/gemma4.py index b629f6d4..872515df 100644 --- a/src/twinkle/template/gemma4.py +++ b/src/twinkle/template/gemma4.py @@ -12,28 +12,4 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # use original Template - def _build_standard_messages(self, trajectory: Trajectory) -> List[Trajectory]: - # Extract trajectory-level media - extracted_images = trajectory.pop('images', None) or [ - img for msg in trajectory['messages'] - for img in msg.get('images', []) or [] - ] - extracted_videos = trajectory.pop('videos', None) or [ - video for msg in trajectory['messages'] - for video in msg.get('videos', []) or [] - ] - extracted_audios = trajectory.pop('audios', None) or [ - audio for msg in trajectory['messages'] - for audio in msg.get('audios', []) or [] - ] - images = self.preprocess_images(extracted_images) - videos = self.preprocess_videos(extracted_videos) - audios = self.preprocess_audios(extracted_audios) - - trajectory['messages'] = self._process_mm_messages(trajectory['messages'], images, videos, audios) - if not self.is_mm: - for message in trajectory['messages']: - message['content'] = message['content'][0]['text'] - return [trajectory] - \ No newline at end of file