Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions cookbook/mm/fsdp2_gemma4_mm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import os
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoConfig
from transformers import (
Gemma4Config,
)

import twinkle
from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
# from twinkle.preprocessor import SelfCognitionProcessor, LatexOCRProcessor

logger = get_logger()

########## Construct a device_mesh ##########
device_mesh = DeviceMesh.from_sizes(
# fsdp_size=2,
# dp_size=1,
# ep_size=2,
device_type=Platform.get_platform().device_prefix(),
)
# use torchrun mode
twinkle.initialize(mode='local', global_device_mesh=device_mesh)

########## hyperparameters ##########
IGNORE_MISMATCHED_SIZES = True
MODEL_PATH = 'ms://google/gemma-4-26b-a4b'
DATASET_PATH = 'ms://AI-ModelScope/LaTeX_OCR'
TRAIN_LEN = 2000
BATCH_SIZE = 4
METRIC_STEP = 10
SAVE_STEP = 10

### reduce model layers for debug
TEXT_NUM_LAYERS = 3
VISION_NUM_LAYERS = 3


from twinkle.preprocessor import Preprocessor
from twinkle.data_format import Message, Trajectory
class LatexOCRProcessor(Preprocessor):

def __call__(self, rows):
rows = self.map_col_to_row(rows)
rows = [self.preprocess(row) for row in rows]
col = self.map_row_to_col(rows)
return col

def preprocess(self, row) -> Trajectory:
return Trajectory(
messages=[
Message(role='user', content='<image>Using LaTeX to perform OCR on the image.', images=[row['image']]),
Message(role='assistant', content=row['text']),
]
)

def eval(model, eval_dataloader):
for step, batch in tqdm(enumerate(eval_dataloader)):
model.forward_only(inputs=batch)
model.calculate_loss()
metrics = model.calculate_metric(is_training=False)
return metrics

def train():

### prepare dataset and dataloader
dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(TRAIN_LEN)))
# Set template to prepare encoding
dataset.set_template('Gemma4Template', model_id=MODEL_PATH)
# Preprocess the dataset to standard format
# dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
dataset.map(preprocess_func=LatexOCRProcessor)
# Encode dataset
dataset.encode()
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)

config, kwargs = AutoConfig.from_pretrained(
MODEL_PATH,
trust_remote_code=True,
return_unused_kwargs=True,
# code_revision=code_revision,
# _commit_hash=commit_hash,
# **hub_kwargs,
# **kwargs,
)

if isinstance(config, Gemma4Config): # 减层
text_config = config.text_config
vision_config = config.vision_config
if TEXT_NUM_LAYERS is not None and hasattr(text_config, 'num_hidden_layers'):
text_config.num_hidden_layers = TEXT_NUM_LAYERS
logger.info(f" modify > text_config.num_hidden_layers = {text_config.num_hidden_layers}")
if VISION_NUM_LAYERS is not None and hasattr(vision_config, 'num_hidden_layers'):
vision_config.num_hidden_layers = VISION_NUM_LAYERS
logger.info(f" modify > vision_config.num_hidden_layers = {vision_config.num_hidden_layers}")
if hasattr(config, 'use_cache'):
config.use_cache = False

# Use a TransformersModel
model = TransformersModel(
model_id=MODEL_PATH,
config=config,
device_mesh=device_mesh,
strategy="accelerate", # native_fsdp、 accelerate
ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
fsdp_config={
'reshard_after_forward': True,
'expert_parallel': {
'enabled': True,
'router_dtype': 'fp32',
'keep_router_logits': False,
}
},
)
# 若未传入config, 则自动通过 AutoConfig.from_pretrained 读取config

# model.model._no_split_modules = {'Gemma4VisionEncoderLayer', 'Gemma4TextDecoderLayer', 'Gemma4AudioLayer'} # for 3 modalities(2B、4B)
# model.model._no_split_modules = {'Gemma4VisionEncoderLayer', 'Gemma4TextDecoderLayer'} # for 2 modalities(26B-A4B、31B)

lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')

# Add a lora to model, with name `default`
# Comment this to use full-parameter training
model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
# Add Optimizer for lora `default`
model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)

# Add LRScheduler for lora `default`
model.set_lr_scheduler(
scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader))

logger.info(get_device_placement())
# Print the training config
logger.info(model.get_train_configs())
logger.info(f'Total steps: {len(dataloader)}')
best_eval_loss = float('inf')
# lora: 8G * 8
# full: 18G * 8

### eval dataset and dataloader
EVAL_LENGTH = 100
eval_dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(EVAL_LENGTH)))
eval_dataset.set_template('Gemma4Template', model_id=MODEL_PATH)
# eval_dataset.map(preprocess_func=SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
eval_dataset.map(preprocess_func=LatexOCRProcessor)
eval_dataset.encode()
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=8)
for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
# Do forward and backward
model.forward_backward(inputs=batch)
# Step
model.clip_grad_and_step()

if step % METRIC_STEP == 0:
# Print metric
metric = model.calculate_metric(is_training=True)
logger.info(f'Current is step {step} of {len(dataloader)}, Train metric: {metric}')

if step % SAVE_STEP == 0:
metrics = eval(model, eval_dataloader)
metrics['step'] = step
if float(metrics['loss']) < best_eval_loss:
# model.save(f'checkpoint-{step}')
best_eval_loss = float(metrics['loss'])
metrics['best_eval_loss'] = best_eval_loss
logger.info(f'Current is step {step} of {len(dataloader)}, Eval metric: {metrics}')

# model.save(f'last-checkpoint')


if __name__ == '__main__':
train()

3 changes: 3 additions & 0 deletions cookbook/mm/fsdp2_gemma4_mm.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
export CUDA_VISIBLE_DEVICES=0,1

torchrun --nnodes=1 --nproc_per_node=2 fsdp2_gemma4_mm.py
1 change: 1 addition & 0 deletions src/twinkle/template/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from .base import Template
from .qwen3_5_vl import Qwen3_5Template
from .gemma4 import Gemma4Template
45 changes: 31 additions & 14 deletions src/twinkle/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def __init__(self,
enable_thinking: bool = True,
**kwargs):
model_id = HubOperation.download_model(model_id, ignore_model=True)
if os.path.exists(os.path.join(model_id, 'preprocessor_config.json')):
if os.path.exists(os.path.join(model_id, 'preprocessor_config.json')) or os.path.exists(
os.path.join(model_id, 'processor_config.json')):
from transformers import AutoProcessor
self.processor = AutoProcessor.from_pretrained(model_id, **kwargs)
else:
Expand All @@ -52,15 +53,17 @@ def __init__(self,
self.truncation_strategy = truncation_strategy
self.default_system = default_system
self._test_support_assistant_tokens_mask()
self.pre_pipeline: List[Callable[[Trajectory], List[Trajectory]]] = [
self._add_default_system, # Add a default system field
self._to_standard_reasoning_content, # Convert thinking to standard field
self._build_standard_messages, # turn to standard mm messages

self.pre_pipeline_names: List[str] = [
"_add_default_system",
"_to_standard_reasoning_content",
"_build_standard_messages",
]
self.post_pipeline: List[Callable[[InputFeature], List[InputFeature]]] = [
self._check_max_length, # Check and split input_features
self._add_attention_fields, # Add useful fields
self._roll_labels, # roll labels

self.post_pipeline_names: List[str] = [
"_check_max_length",
"_add_attention_fields",
"_roll_labels",
]

@property
Expand Down Expand Up @@ -140,7 +143,8 @@ def preprocess_audios(self, audios: List[AudioInput]) -> List[np.ndarray]:

def _invoke_pre_pipeline(self, trajectories: List[Trajectory]) -> List[Trajectory]:
current = trajectories
for pipeline in self.pre_pipeline:
for pipeline_name in self.pre_pipeline_names:
pipeline: Callable[[Trajectory], List[Trajectory]] = getattr(self, pipeline_name)
next_batch = []
for trajectory in current:
next_batch.extend(pipeline(trajectory))
Expand All @@ -149,7 +153,8 @@ def _invoke_pre_pipeline(self, trajectories: List[Trajectory]) -> List[Trajector

def _invoke_post_pipeline(self, input_features: List[InputFeature]) -> List[InputFeature]:
current = input_features
for pipeline in self.post_pipeline:
for pipeline_name in self.post_pipeline_names:
pipeline: Callable[[InputFeature], List[InputFeature]] = getattr(self, pipeline_name)
next_batch = []
for input_feature in current:
next_batch.extend(pipeline(input_feature))
Expand Down Expand Up @@ -436,9 +441,21 @@ def _process_mm_string_format(self, messages: List, images: List, videos: List,

def _build_standard_messages(self, trajectory: Trajectory) -> List[Trajectory]:
# Extract trajectory-level media
images = self.preprocess_images(trajectory.pop('images', None) or [])
videos = self.preprocess_videos(trajectory.pop('videos', None) or [])
audios = self.preprocess_audios(trajectory.pop('audios', None) or [])
extracted_images = trajectory.pop('images', None) or [
img for msg in trajectory['messages']
for img in msg.get('images', []) or []
]
extracted_videos = trajectory.pop('videos', None) or [
video for msg in trajectory['messages']
for video in msg.get('videos', []) or []
]
extracted_audios = trajectory.pop('audios', None) or [
audio for msg in trajectory['messages']
for audio in msg.get('audios', []) or []
]
images = self.preprocess_images(extracted_images)
videos = self.preprocess_videos(extracted_videos)
audios = self.preprocess_audios(extracted_audios)

trajectory['messages'] = self._process_mm_messages(trajectory['messages'], images, videos, audios)
if not self.is_mm:
Expand Down
15 changes: 15 additions & 0 deletions src/twinkle/template/gemma4.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个类是否有存在的必要

Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

from typing import List
from twinkle import remote_class
from twinkle.template import Template
from twinkle.data_format import Trajectory

@remote_class()
class Gemma4Template(Template):
"""Processor for Google Gemma4 series."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# use original Template