Skip to content

[model] Support gemma4#56

Merged
Jintao-Huang merged 66 commits into
modelscope:mainfrom
Jintao-Huang:support_gemma4
May 19, 2026
Merged

[model] Support gemma4#56
Jintao-Huang merged 66 commits into
modelscope:mainfrom
Jintao-Huang:support_gemma4

Conversation

@Jintao-Huang
Copy link
Copy Markdown
Collaborator

@Jintao-Huang Jintao-Huang commented Apr 30, 2026

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for the Gemma4 model, refactors rotary positional embedding logic within GPTModel, and relocates the patch_deepcopy utility to a shared module. Feedback highlights a missing inference_context argument in a method call that could degrade inference performance, potential runtime crashes and dead code in the Gemma4-specific RoPE initialization, and the inclusion of commented-out code in the new model loader.

Comment thread src/mcore_bridge/model/gpt_model.py Outdated
Comment on lines +223 to +224
rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
decoder_input, position_ids, packed_seq_params=packed_seq_params)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The inference_context is not passed to the _get_rotary_pos_emb method. This will cause the method to skip critical inference-specific logic, such as utilizing the RoPE cache or correctly calculating the rotary sequence length for flash decoding, which can lead to performance degradation or incorrect results during inference.

Suggested change
rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
decoder_input, position_ids, packed_seq_params=packed_seq_params)
rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
decoder_input, position_ids, packed_seq_params=packed_seq_params, inference_context=inference_context)

Comment on lines +45 to +62
def _set_inv_freq(self):
rope_scaling = self.config.rope_scaling
self.config.rope_scaling = rope_scaling['sliding_attention']
new_inv_freq, attention_scaling = get_rope_inv_freq(self.config)
assert attention_scaling == 1, 'not support'
self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
# full
self.full_rotary_pos_emb = copy.copy(self.rotary_pos_emb)
self.config.rope_scaling = rope_scaling['full_attention']
kwargs = {}
if self.config.rope_scaling['rope_type'] == 'proportional':
kwargs['head_dim_key'] = 'global_head_dim'
new_inv_freq, attention_scaling = get_rope_inv_freq(self.config, **kwargs)
assert attention_scaling == 1, 'not support'
self.full_rotary_pos_emb.inv_freq = new_inv_freq
self.attention_scaling = attention_scaling

self.config.rope_scaling = rope_scaling
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The implementation of _set_inv_freq for Gemma4TextGPTModel has several issues:

  1. Potential Runtime Crash: Restoring self.config.rope_scaling to the original nested dictionary at line 62 will cause a KeyError in _get_rope_type (called via dynamic_rope_update during every forward pass) because that function expects a dictionary with a rope_type key at the top level, which the Gemma4 configuration lacks (it uses sliding_attention and full_attention as top-level keys).
  2. Dead Code: self.full_rotary_pos_emb is initialized but never utilized by the base GPTModel forward pass or RoPE application logic.
  3. Poor Error Messages: The assertion messages 'not support' at lines 49 and 58 are not descriptive. They should clearly state that attention scaling other than 1.0 is not supported for this model.

Comment on lines +71 to +76
# def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
# layer_specs = get_gpt_decoder_block_spec(
# self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage)
# for layer_spec in layer_specs.layer_specs:
# pass
# return layer_specs
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Please remove the commented-out code in Gemma4Loader to maintain code cleanliness and readability.

@Jintao-Huang
Copy link
Copy Markdown
Collaborator Author

/gemini review

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for the Gemma4 model, implementing its specific architecture including sliding window attention, shared KV layers, and multimodal components. It refactors the bridge and base GPT model classes to accommodate these features, such as handling dictionary-based rotary embeddings and dynamic attention head configurations. Review feedback identified critical issues in the Gemma4 implementation where tensor indexing and boolean evaluations could lead to IndexError or RuntimeError during inference, especially with short sequences. Additionally, a minor style improvement was suggested for slicing syntax in the bridge logic.

Comment on lines +530 to +531
if 'full_attention' in shared_kv_states:
flag[1] = 1
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The assignment flag[1] = 1 can lead to an IndexError if the sequence length is less than 2. This is a common scenario during inference, for example, when seq_len=1. It's safer to guard this assignment with a check on the sequence length.

Suggested change
if 'full_attention' in shared_kv_states:
flag[1] = 1
if 'full_attention' in shared_kv_states and flag.shape[0] > 1:
flag[1] = 1

Comment on lines +555 to +564
if flag[1] != 0:
input_tensor, full_states = input_tensor.split([input_tensor.shape[-1] - full_states_dim, full_states_dim],
dim=-1)
full_states = full_states.reshape(*full_states_shape)
shared_kv_states['full_attention'] = full_states.chunk(2, -1)
if flag[0] != 0:
input_tensor, sliding_states = input_tensor.split(
[input_tensor.shape[-1] - sliding_states_dim, sliding_states_dim], dim=-1)
sliding_states = sliding_states.reshape(*sliding_states_shape)
shared_kv_states['sliding_attention'] = sliding_states.chunk(2, -1)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The conditions if flag[1] != 0: and if flag[0] != 0: are likely to cause a RuntimeError: bool value of Tensor with more than one value is ambiguous. The flag tensor has a shape of (seq_len, batch_size, 1), which means flag[0] and flag[1] are slices, not scalars. You should check a specific element (e.g., flag[index, 0, 0]), assuming the flag is consistent across the batch. Additionally, you should guard against IndexError for small sequence lengths.

Suggested change
if flag[1] != 0:
input_tensor, full_states = input_tensor.split([input_tensor.shape[-1] - full_states_dim, full_states_dim],
dim=-1)
full_states = full_states.reshape(*full_states_shape)
shared_kv_states['full_attention'] = full_states.chunk(2, -1)
if flag[0] != 0:
input_tensor, sliding_states = input_tensor.split(
[input_tensor.shape[-1] - sliding_states_dim, sliding_states_dim], dim=-1)
sliding_states = sliding_states.reshape(*sliding_states_shape)
shared_kv_states['sliding_attention'] = sliding_states.chunk(2, -1)
if flag.shape[0] > 1 and flag[1, 0, 0] != 0:
input_tensor, full_states = input_tensor.split([input_tensor.shape[-1] - full_states_dim, full_states_dim],
dim=-1)
full_states = full_states.reshape(*full_states_shape)
shared_kv_states['full_attention'] = full_states.chunk(2, -1)
if flag.shape[0] > 0 and flag[0, 0, 0] != 0:
input_tensor, sliding_states = input_tensor.split(
[input_tensor.shape[-1] - sliding_states_dim, sliding_states_dim], dim=-1)
sliding_states = sliding_states.reshape(*sliding_states_shape)
shared_kv_states['sliding_attention'] = sliding_states.chunk(2, -1)

Comment thread src/mcore_bridge/bridge/gpt_bridge.py Outdated
@Jintao-Huang
Copy link
Copy Markdown
Collaborator Author

/gemini review

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for the Gemma 4 model, incorporating multimodal vision and audio components, sliding window attention, and MoE layers. The implementation includes refactoring the GPT bridge for modularity and updating the configuration parser. Review feedback identifies several critical issues, including a NameError in the MoE state logic, potential AttributeError and IndexError in the Gemma 4 model implementation, and a possible TypeError in the configuration parser.

Comment thread src/mcore_bridge/bridge/gpt_bridge.py Outdated
Comment on lines +106 to +114
video_features = self.get_video_features(
pixel_values_videos, video_position_ids, return_dict=True).pooler_output
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
video_mask_e = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds = inputs_embeds.masked_scatter(video_mask_e, video_features)

if (input_features is not None and input_features_mask is not None and self.audio_tower is not None):
with self.patch_hf_config():
audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

These calls to self.get_video_features and self.get_audio_features will likely raise an AttributeError as they are not defined in Gemma4Vit or its visible base classes. Following the pattern used for image features at line 98, these should probably be called via self.model_cls (e.g., self.model_cls.get_video_features(self, ...)).

Comment thread src/mcore_bridge/model/mm_gpts/gemma4.py
Comment thread src/mcore_bridge/config/parser.py
Comment thread src/mcore_bridge/model/mm_gpts/gemma4.py
@Jintao-Huang
Copy link
Copy Markdown
Collaborator Author

/gemini review

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for the Gemma 4 model series, implementing specialized components such as Gemma4Vit, Gemma4SelfAttention, and Gemma4MoELayer. The changes include refactoring the GPT bridge to support flexible attention configurations, updating the model parser for Gemma 4 specificities, and enhancing the transformer block to handle nested structures during activation checkpointing. Feedback identifies several critical issues in the new implementation, including inconsistent method calls for multimodal feature extraction, undefined variables in the attention module, and potential index errors in pipeline parallel configurations. Additionally, improvements are suggested to avoid hardcoded device strings, fix incorrect bias handling in the MLP, and prevent dangerous in-place dictionary modifications within the checkpointing utility functions.

Comment on lines +106 to +107
video_features = self.get_video_features(
pixel_values_videos, video_position_ids, return_dict=True).pooler_output
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Inconsistent method call. Based on line 98, these multimodal feature extraction methods should likely be called as class methods of self.model_cls passing self as the first argument, as Gemma4Vit might not have these methods directly defined or mixed in.

Suggested change
video_features = self.get_video_features(
pixel_values_videos, video_position_ids, return_dict=True).pooler_output
video_features = self.model_cls.get_video_features(
self, pixel_values_videos, video_position_ids, return_dict=True).pooler_output


if (input_features is not None and input_features_mask is not None and self.audio_tower is not None):
with self.patch_hf_config():
audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Inconsistent method call. Similar to the video features, this should likely be called via self.model_cls.

Suggested change
audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True)
audio_output = self.model_cls.get_audio_features(self, input_features, input_features_mask, return_dict=True)

packed_seq_params = kwargs.get('packed_seq_params')
attention_bias = kwargs.get('attention_bias')
mixed_qkv, _ = self.linear_qkv(hidden_states)
if getattr(self, 'world_size', None) is not None and self.config.num_query_groups < self.world_size:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

self.world_size is not defined in Gemma4SelfAttention. This will cause getattr(self, 'world_size', None) to return None, skipping the logic intended to handle cases where num_query_groups < TP_size. You should initialize self.world_size in __init__ using parallel_state.get_tensor_model_parallel_world_size().


def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_query_groups_per_partition = self.decoder.layers[0].self_attention.num_query_groups_per_partition
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Accessing self.decoder.layers[0] can raise an IndexError in pipeline parallel configurations where a specific rank has no transformer layers (e.g., a rank only handling embeddings).

Suggested change
self.num_query_groups_per_partition = self.decoder.layers[0].self_attention.num_query_groups_per_partition
self.num_query_groups_per_partition = (
self.decoder.layers[0].self_attention.num_query_groups_per_partition
if len(self.decoder.layers) > 0 else 0)

window_size = self.text_config.sliding_window - 1
seq_len = attention_mask.shape[-1]

window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device='cuda')
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Hardcoding device='cuda' prevents the model from running on other devices (e.g., CPU or other accelerators). Use attention_mask.device instead.

Suggested change
window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device='cuda')
window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_mask.device)

Comment on lines +731 to +742
mlp_output, bias = self.mlp(pre_mlp_layernorm_output, padding_mask=padding_mask)
if self.enable_moe_block:
mlp_output_1 = self.post_feedforward_layernorm_1(mlp_output)
mlp_output_2, bias = self.experts_mlp(residual, padding_mask=padding_mask)
mlp_output_2 = self.post_feedforward_layernorm_2(mlp_output_2)

# Combine mlp and moe outputs
mlp_output = mlp_output_1 + mlp_output_2

mlp_output = self.post_feedforward_layernorm(mlp_output)
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)((mlp_output, bias), residual,
self.hidden_dropout)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The bias handling in _forward_mlp is incorrect if add_bias_linear is enabled. Biases from self.mlp and self.experts_mlp are ignored before their respective normalization layers (post_feedforward_layernorm_1/2), and the bias from self.mlp is shadowed by the one from self.experts_mlp. If biases are used, they must be added to the outputs before normalization.

Comment on lines +41 to +45
elif isinstance(obj, dict):
# inplace (gemma4 shared_kv_states)
for k, v in obj.items():
obj[k] = _checkpoint_flatten(v, tensors)
return obj
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Modifying the obj dictionary in-place is a dangerous side effect. If the dictionary is reused elsewhere in the model or in subsequent iterations, its contents will be corrupted with _TensorIdx sentinels. It is safer to return a new dictionary.

Suggested change
elif isinstance(obj, dict):
# inplace (gemma4 shared_kv_states)
for k, v in obj.items():
obj[k] = _checkpoint_flatten(v, tensors)
return obj
elif isinstance(obj, dict):
return {k: _checkpoint_flatten(v, tensors) for k, v in obj.items()}

Comment on lines +56 to +60
elif isinstance(schema, dict):
# inplace (gemma4 shared_kv_states)
for k, v in schema.items():
schema[k] = _checkpoint_unflatten(v, tensors)
return schema
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similarly to _checkpoint_flatten, modifying the schema dictionary in-place should be avoided to prevent unexpected behavior with captured closure variables.

Suggested change
elif isinstance(schema, dict):
# inplace (gemma4 shared_kv_states)
for k, v in schema.items():
schema[k] = _checkpoint_unflatten(v, tensors)
return schema
elif isinstance(schema, dict):
return {k: _checkpoint_unflatten(v, tensors) for k, v in schema.items()}

@Jintao-Huang
Copy link
Copy Markdown
Collaborator Author

Jintao-Huang commented May 19, 2026

@Jintao-Huang
Copy link
Copy Markdown
Collaborator Author

/gemini review

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for the Gemma 4 model, involving significant updates to the README, configuration parser, and core model components. The changes include the implementation of Gemma 4-specific modules such as RMSNorm, Attention, and MoE layers, alongside refactoring the GPT bridge and transformer blocks to accommodate shared KV states and dynamic attention configurations. Feedback from the review highlights an opportunity to improve the maintainability of the Gemma 4 attention implementation by adding explanatory comments to complex logic.

Comment on lines +155 to +156
self.store_full_length_kv = not self.is_kv_shared_layer and layer_idx == len(
prev_layers) - 1 - prev_layers[::-1].index(text_config.layer_types[layer_idx])
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic is quite complex. Adding a comment to explain its purpose would improve readability and maintainability. For example, you could explain that this is determining if the current layer is the last of its type before the shared KV layers begin, which is used to decide whether to store the KV cache for subsequent layers.

@Jintao-Huang Jintao-Huang merged commit 28017e1 into modelscope:main May 19, 2026
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants