-
Notifications
You must be signed in to change notification settings - Fork 32
feat: gemma4 support initial commit (WIP) #191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
EvineR666
wants to merge
1
commit into
modelscope:main
Choose a base branch
from
EvineR666:feat/add-gemma4-support
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+228
−0
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,198 @@ | ||
| import os | ||
| from peft import LoraConfig | ||
| from tqdm import tqdm | ||
| from transformers import AutoConfig | ||
| from transformers import ( | ||
| Gemma4AudioConfig, | ||
| Gemma4AudioFeatureExtractor, | ||
| Gemma4Config, | ||
| Gemma4ForCausalLM, | ||
| Gemma4ForConditionalGeneration, | ||
| Gemma4ImageProcessor, | ||
| Gemma4Processor, | ||
| Gemma4TextConfig, | ||
| Gemma4VideoProcessor, | ||
| Gemma4VisionConfig, | ||
| GemmaTokenizer, | ||
| GenerationConfig, | ||
| RopeParameters, | ||
| ) | ||
|
|
||
| import twinkle | ||
| from twinkle import DeviceMesh, Platform, get_device_placement, get_logger | ||
| from twinkle.dataloader import DataLoader | ||
| from twinkle.dataset import Dataset, DatasetMeta | ||
| from twinkle.model import TransformersModel | ||
| from twinkle.preprocessor import SelfCognitionProcessor, LatexOCRProcessor | ||
|
|
||
| logger = get_logger() | ||
|
|
||
| ########## Construct a device_mesh ########## | ||
| device_mesh = DeviceMesh.from_sizes( | ||
| # fsdp_size=2, | ||
| # dp_size=1, | ||
| # ep_size=2, | ||
| device_type=Platform.get_platform().device_prefix(), | ||
| ) | ||
| # use torchrun mode | ||
| twinkle.initialize(mode='local', global_device_mesh=device_mesh) | ||
|
|
||
| ########## 超参数 ########## | ||
| IGNORE_MISMATCHED_SIZES = True | ||
| # MODEL_PATH = '/nas/disk1/MiniMax-M2.5' | ||
| # MODEL_PATH = '/nas/disk1/Qwen3-30B-A3B' | ||
| # MODEL_PATH = '/nas/disk1/gemma-4-E2B-it' | ||
| # MODEL_PATH = '/nas/disk1/gemma-4-26B-A4B' | ||
| MODEL_PATH = r"C:\Users\wliuu\.cache\modelscope\hub\models\google\gemma-4-E2B-it" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| # DATASET_PATH = '/model/lzh/train/datasets/self-cognition.jsonl' | ||
| DATASET_PATH = 'ms://AI-ModelScope/LaTeX_OCR' | ||
| # DATASET_PATH = r'C:\Users\wliuu\.cache\modelscope\hub\datasets\LaTex_OCR_train.json' | ||
|
|
||
| ### 注意: gemma4-26b-a4b: text layers=30, vision layers=27 | ||
| TEXT_NUM_LAYERS = 3 | ||
| VISION_NUM_LAYERS = 3 | ||
|
|
||
| TRAIN_LEN = 200 | ||
| BATCH_SIZE = 4 | ||
|
|
||
| from twinkle.preprocessor import Preprocessor | ||
| from twinkle.data_format import Message, Trajectory | ||
| class LatexOCRProcessor(Preprocessor): | ||
|
|
||
| def __call__(self, rows): # 输入的rows 是pyarrow.Table处理后的inputs | ||
| rows = self.map_col_to_row(rows) # 构建出一行一行的数据 | ||
| rows = [self.preprocess(row) for row in rows] # 每行构建为Trajectory | ||
| col = self.map_row_to_col(rows) # 变成列的形式, col['messages']下有BATCH_LEN条 Trajectory数据 | ||
| return col | ||
|
|
||
| def preprocess(self, row) -> Trajectory: | ||
| return Trajectory( # <class 'twinkle.data_format.trajectory.Trajectory'> | ||
| messages=[ | ||
| Message(role='user', content='<image>Using LaTeX to perform OCR on the image.', images=[row['image']]), | ||
| Message(role='assistant', content=row['text']), # 相当于有监督学习的label | ||
| ] | ||
| ) | ||
|
|
||
| def eval(model, eval_dataloader): | ||
| for step, batch in tqdm(enumerate(eval_dataloader)): | ||
| model.forward_only(inputs=batch) | ||
| model.calculate_loss() | ||
| metrics = model.calculate_metric(is_training=False) | ||
| return metrics | ||
|
|
||
| def train(): | ||
|
|
||
| ### prepare dataset and dataloader | ||
| dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(TRAIN_LEN))) | ||
| # Set template to prepare encoding | ||
| # dataset.set_template('Template', model_id=MODEL_PATH) # 指定的是Template实例名称 | ||
| dataset.set_template('Gemma4Template', model_id=MODEL_PATH) # | ||
| # Preprocess the dataset to standard format | ||
| # dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) | ||
| dataset.map(preprocess_func=LatexOCRProcessor) | ||
| # Encode dataset | ||
| dataset.encode() # 2B可以用Template,26B会报错 | ||
| # Global batch size = 8, for GPUs, so 1 sample per GPU | ||
| dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) | ||
|
|
||
| # config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) | ||
| config, kwargs = AutoConfig.from_pretrained( | ||
| MODEL_PATH, | ||
| trust_remote_code=True, | ||
| return_unused_kwargs=True, | ||
| # code_revision=code_revision, | ||
| # _commit_hash=commit_hash, | ||
| # **hub_kwargs, | ||
| # **kwargs, | ||
| ) | ||
|
|
||
| if isinstance(config, Gemma4Config): # 减层 | ||
| text_config = config.text_config | ||
| vision_config = config.vision_config | ||
| if TEXT_NUM_LAYERS is not None and hasattr(text_config, 'num_hidden_layers'): | ||
| text_config.num_hidden_layers = TEXT_NUM_LAYERS | ||
| print(f" modify > text_config.num_hidden_layers = {text_config.num_hidden_layers}") | ||
| if VISION_NUM_LAYERS is not None and hasattr(vision_config, 'num_hidden_layers'): | ||
| vision_config.num_hidden_layers = VISION_NUM_LAYERS | ||
| print(f" modify > vision_config.num_hidden_layers = {vision_config.num_hidden_layers}") | ||
| if hasattr(config, 'use_cache'): | ||
| config.use_cache = False | ||
|
|
||
| # Use a TransformersModel | ||
| model = TransformersModel( | ||
| model_id=MODEL_PATH, | ||
| config=config, | ||
| device_mesh=device_mesh, | ||
| strategy="accelerate", # native_fsdp、 accelerate | ||
| ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES, | ||
| fsdp_config={ | ||
| 'reshard_after_forward': True, | ||
| 'expert_parallel': { | ||
| 'enabled': True, | ||
| 'router_dtype': 'fp32', | ||
| 'keep_router_logits': False, | ||
| } | ||
| }, | ||
| ) # 会直接load weights | ||
| # type(model): <class 'twinkle.model.transformers.transformers.TransformersModel'> | ||
| # type(model.model): <class 'transformers.models.gemma4.modeling_gemma4.Gemma4ForConditionalGeneration'> | ||
| # 若未传入config, 则自动通过 AutoConfig.from_pretrained 读取config | ||
|
|
||
| # FSDP 不进行切分: | ||
| # model.model._no_split_modules = {'Gemma4VisionEncoderLayer', 'Gemma4TextDecoderLayer', 'Gemma4AudioLayer'} # for 3 modalities(2B、4B) | ||
| model.model._no_split_modules = {'Gemma4VisionEncoderLayer', 'Gemma4TextDecoderLayer'} # for 2 modalities(26B-A4B、31B) | ||
|
|
||
| lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') | ||
|
|
||
| # Add a lora to model, with name `default` | ||
| # Comment this to use full-parameter training | ||
| model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) # gradient_accumulation_steps? | ||
| # Add Optimizer for lora `default` | ||
| model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) | ||
|
|
||
| # Add LRScheduler for lora `default` | ||
| model.set_lr_scheduler( | ||
| scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) | ||
|
|
||
| logger.info(get_device_placement()) | ||
| # Print the training config | ||
| logger.info(model.get_train_configs()) | ||
| logger.info(f'Total steps: {len(dataloader)}') | ||
| best_eval_loss = 99.0 | ||
| # lora: 8G * 8 | ||
| # full: 18G * 8 | ||
|
|
||
| ### eval dataset and dataloader | ||
| # EVAL_LENGTH = 100 | ||
| # eval_dataset = Dataset(dataset_meta=DatasetMeta(DATASET_PATH, data_slice=range(EVAL_LENGTH))) | ||
| # eval_dataset.set_template('Gemma4Template', model_id=MODEL_PATH) | ||
| # eval_dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) | ||
| # eval_dataset.encode() | ||
| # eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=8) | ||
| for step, batch in enumerate(dataloader): | ||
| # Do forward and backward | ||
| model.forward_backward(inputs=batch) | ||
| # Step | ||
| model.clip_grad_and_step() | ||
|
|
||
| if step % 10 == 0: | ||
| # Print metric | ||
| metric = model.calculate_metric(is_training=True) # 若npu显存不足, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| logger.info(f'Current is step {step} of {len(dataloader)}, Train metric: {metric}') | ||
|
|
||
| # if step % 10 == 0: | ||
| # metrics = eval(model, eval_dataloader) | ||
| # metrics['step'] = step | ||
| # if float(metrics['loss']) < best_eval_loss: | ||
| # # model.save(f'checkpoint-{step}') | ||
| # best_eval_loss = float(metrics['loss']) | ||
| # metrics['best_eval_loss'] = best_eval_loss | ||
| # logger.info(f'Current is step {step} of {len(dataloader)}, Eval metric: {metrics}') | ||
|
|
||
| # model.save(f'last-checkpoint') | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| train() | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| export CUDA_VISIBLE_DEVICES=0,1 | ||
|
|
||
| torchrun --nnodes=1 --nproc_per_node=2 fsdp2_gemma4_mm.py |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| # Copyright (c) ModelScope Contributors. All rights reserved. | ||
| from .base import Template | ||
| from .qwen3_5_vl import Qwen3_5Template | ||
| from .gemma4 import Gemma4Template |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| import json | ||
| import torch | ||
| import torch.nn.functional as F | ||
| from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union | ||
| from dataclasses import dataclass, field, fields | ||
| from PIL import Image | ||
| from copy import deepcopy | ||
|
|
||
| from twinkle import remote_class, requires | ||
| from twinkle.template import Template | ||
|
|
||
| Tool = Dict[str, Union[str, Dict]] | ||
| History = List[Union[Tuple[str, str], List[str]]] | ||
| Message = Dict[str, Union[str, List[Dict[str, Any]], List[int], None]] | ||
| Messages = List[Message] | ||
| Prompt = List[Union[str, List[int], List[str]]] | ||
| Word = Union[str, List[int]] | ||
| Context = Word | ||
|
Comment on lines
+1
to
+18
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| @remote_class() | ||
| class Gemma4Template(Template): | ||
| """Processor for Google Gemma4 series.""" | ||
|
|
||
| def __init__(self, *args, **kwargs): | ||
| super().__init__(*args, **kwargs) | ||
| # use original Template | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LatexOCRProcessoris imported here but then redefined on line 61. This is redundant and can lead to confusion. If the local definition is intended for this cookbook, please remove the import ofLatexOCRProcessorfrom line 26.