[model] Support gemma4#56
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces support for the Gemma4 model, refactors rotary positional embedding logic within GPTModel, and relocates the patch_deepcopy utility to a shared module. Feedback highlights a missing inference_context argument in a method call that could degrade inference performance, potential runtime crashes and dead code in the Gemma4-specific RoPE initialization, and the inclusion of commented-out code in the new model loader.
| rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb( | ||
| decoder_input, position_ids, packed_seq_params=packed_seq_params) |
There was a problem hiding this comment.
The inference_context is not passed to the _get_rotary_pos_emb method. This will cause the method to skip critical inference-specific logic, such as utilizing the RoPE cache or correctly calculating the rotary sequence length for flash decoding, which can lead to performance degradation or incorrect results during inference.
| rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb( | |
| decoder_input, position_ids, packed_seq_params=packed_seq_params) | |
| rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb( | |
| decoder_input, position_ids, packed_seq_params=packed_seq_params, inference_context=inference_context) |
| def _set_inv_freq(self): | ||
| rope_scaling = self.config.rope_scaling | ||
| self.config.rope_scaling = rope_scaling['sliding_attention'] | ||
| new_inv_freq, attention_scaling = get_rope_inv_freq(self.config) | ||
| assert attention_scaling == 1, 'not support' | ||
| self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) | ||
| # full | ||
| self.full_rotary_pos_emb = copy.copy(self.rotary_pos_emb) | ||
| self.config.rope_scaling = rope_scaling['full_attention'] | ||
| kwargs = {} | ||
| if self.config.rope_scaling['rope_type'] == 'proportional': | ||
| kwargs['head_dim_key'] = 'global_head_dim' | ||
| new_inv_freq, attention_scaling = get_rope_inv_freq(self.config, **kwargs) | ||
| assert attention_scaling == 1, 'not support' | ||
| self.full_rotary_pos_emb.inv_freq = new_inv_freq | ||
| self.attention_scaling = attention_scaling | ||
|
|
||
| self.config.rope_scaling = rope_scaling |
There was a problem hiding this comment.
The implementation of _set_inv_freq for Gemma4TextGPTModel has several issues:
- Potential Runtime Crash: Restoring
self.config.rope_scalingto the original nested dictionary at line 62 will cause aKeyErrorin_get_rope_type(called viadynamic_rope_updateduring every forward pass) because that function expects a dictionary with arope_typekey at the top level, which the Gemma4 configuration lacks (it usessliding_attentionandfull_attentionas top-level keys). - Dead Code:
self.full_rotary_pos_embis initialized but never utilized by the baseGPTModelforward pass or RoPE application logic. - Poor Error Messages: The assertion messages
'not support'at lines 49 and 58 are not descriptive. They should clearly state that attention scaling other than 1.0 is not supported for this model.
| # def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): | ||
| # layer_specs = get_gpt_decoder_block_spec( | ||
| # self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) | ||
| # for layer_spec in layer_specs.layer_specs: | ||
| # pass | ||
| # return layer_specs |
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces support for the Gemma4 model, implementing its specific architecture including sliding window attention, shared KV layers, and multimodal components. It refactors the bridge and base GPT model classes to accommodate these features, such as handling dictionary-based rotary embeddings and dynamic attention head configurations. Review feedback identified critical issues in the Gemma4 implementation where tensor indexing and boolean evaluations could lead to IndexError or RuntimeError during inference, especially with short sequences. Additionally, a minor style improvement was suggested for slicing syntax in the bridge logic.
| if 'full_attention' in shared_kv_states: | ||
| flag[1] = 1 |
There was a problem hiding this comment.
The assignment flag[1] = 1 can lead to an IndexError if the sequence length is less than 2. This is a common scenario during inference, for example, when seq_len=1. It's safer to guard this assignment with a check on the sequence length.
| if 'full_attention' in shared_kv_states: | |
| flag[1] = 1 | |
| if 'full_attention' in shared_kv_states and flag.shape[0] > 1: | |
| flag[1] = 1 |
| if flag[1] != 0: | ||
| input_tensor, full_states = input_tensor.split([input_tensor.shape[-1] - full_states_dim, full_states_dim], | ||
| dim=-1) | ||
| full_states = full_states.reshape(*full_states_shape) | ||
| shared_kv_states['full_attention'] = full_states.chunk(2, -1) | ||
| if flag[0] != 0: | ||
| input_tensor, sliding_states = input_tensor.split( | ||
| [input_tensor.shape[-1] - sliding_states_dim, sliding_states_dim], dim=-1) | ||
| sliding_states = sliding_states.reshape(*sliding_states_shape) | ||
| shared_kv_states['sliding_attention'] = sliding_states.chunk(2, -1) |
There was a problem hiding this comment.
The conditions if flag[1] != 0: and if flag[0] != 0: are likely to cause a RuntimeError: bool value of Tensor with more than one value is ambiguous. The flag tensor has a shape of (seq_len, batch_size, 1), which means flag[0] and flag[1] are slices, not scalars. You should check a specific element (e.g., flag[index, 0, 0]), assuming the flag is consistent across the batch. Additionally, you should guard against IndexError for small sequence lengths.
| if flag[1] != 0: | |
| input_tensor, full_states = input_tensor.split([input_tensor.shape[-1] - full_states_dim, full_states_dim], | |
| dim=-1) | |
| full_states = full_states.reshape(*full_states_shape) | |
| shared_kv_states['full_attention'] = full_states.chunk(2, -1) | |
| if flag[0] != 0: | |
| input_tensor, sliding_states = input_tensor.split( | |
| [input_tensor.shape[-1] - sliding_states_dim, sliding_states_dim], dim=-1) | |
| sliding_states = sliding_states.reshape(*sliding_states_shape) | |
| shared_kv_states['sliding_attention'] = sliding_states.chunk(2, -1) | |
| if flag.shape[0] > 1 and flag[1, 0, 0] != 0: | |
| input_tensor, full_states = input_tensor.split([input_tensor.shape[-1] - full_states_dim, full_states_dim], | |
| dim=-1) | |
| full_states = full_states.reshape(*full_states_shape) | |
| shared_kv_states['full_attention'] = full_states.chunk(2, -1) | |
| if flag.shape[0] > 0 and flag[0, 0, 0] != 0: | |
| input_tensor, sliding_states = input_tensor.split( | |
| [input_tensor.shape[-1] - sliding_states_dim, sliding_states_dim], dim=-1) | |
| sliding_states = sliding_states.reshape(*sliding_states_shape) | |
| shared_kv_states['sliding_attention'] = sliding_states.chunk(2, -1) |
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces support for the Gemma 4 model, incorporating multimodal vision and audio components, sliding window attention, and MoE layers. The implementation includes refactoring the GPT bridge for modularity and updating the configuration parser. Review feedback identifies several critical issues, including a NameError in the MoE state logic, potential AttributeError and IndexError in the Gemma 4 model implementation, and a possible TypeError in the configuration parser.
| video_features = self.get_video_features( | ||
| pixel_values_videos, video_position_ids, return_dict=True).pooler_output | ||
| video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) | ||
| video_mask_e = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) | ||
| inputs_embeds = inputs_embeds.masked_scatter(video_mask_e, video_features) | ||
|
|
||
| if (input_features is not None and input_features_mask is not None and self.audio_tower is not None): | ||
| with self.patch_hf_config(): | ||
| audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True) |
There was a problem hiding this comment.
These calls to self.get_video_features and self.get_audio_features will likely raise an AttributeError as they are not defined in Gemma4Vit or its visible base classes. Following the pattern used for image features at line 98, these should probably be called via self.model_cls (e.g., self.model_cls.get_video_features(self, ...)).
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces support for the Gemma 4 model series, implementing specialized components such as Gemma4Vit, Gemma4SelfAttention, and Gemma4MoELayer. The changes include refactoring the GPT bridge to support flexible attention configurations, updating the model parser for Gemma 4 specificities, and enhancing the transformer block to handle nested structures during activation checkpointing. Feedback identifies several critical issues in the new implementation, including inconsistent method calls for multimodal feature extraction, undefined variables in the attention module, and potential index errors in pipeline parallel configurations. Additionally, improvements are suggested to avoid hardcoded device strings, fix incorrect bias handling in the MLP, and prevent dangerous in-place dictionary modifications within the checkpointing utility functions.
| video_features = self.get_video_features( | ||
| pixel_values_videos, video_position_ids, return_dict=True).pooler_output |
There was a problem hiding this comment.
Inconsistent method call. Based on line 98, these multimodal feature extraction methods should likely be called as class methods of self.model_cls passing self as the first argument, as Gemma4Vit might not have these methods directly defined or mixed in.
| video_features = self.get_video_features( | |
| pixel_values_videos, video_position_ids, return_dict=True).pooler_output | |
| video_features = self.model_cls.get_video_features( | |
| self, pixel_values_videos, video_position_ids, return_dict=True).pooler_output |
|
|
||
| if (input_features is not None and input_features_mask is not None and self.audio_tower is not None): | ||
| with self.patch_hf_config(): | ||
| audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True) |
There was a problem hiding this comment.
Inconsistent method call. Similar to the video features, this should likely be called via self.model_cls.
| audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True) | |
| audio_output = self.model_cls.get_audio_features(self, input_features, input_features_mask, return_dict=True) |
| packed_seq_params = kwargs.get('packed_seq_params') | ||
| attention_bias = kwargs.get('attention_bias') | ||
| mixed_qkv, _ = self.linear_qkv(hidden_states) | ||
| if getattr(self, 'world_size', None) is not None and self.config.num_query_groups < self.world_size: |
There was a problem hiding this comment.
self.world_size is not defined in Gemma4SelfAttention. This will cause getattr(self, 'world_size', None) to return None, skipping the logic intended to handle cases where num_query_groups < TP_size. You should initialize self.world_size in __init__ using parallel_state.get_tensor_model_parallel_world_size().
|
|
||
| def __init__(self, *args, **kwargs): | ||
| super().__init__(*args, **kwargs) | ||
| self.num_query_groups_per_partition = self.decoder.layers[0].self_attention.num_query_groups_per_partition |
There was a problem hiding this comment.
Accessing self.decoder.layers[0] can raise an IndexError in pipeline parallel configurations where a specific rank has no transformer layers (e.g., a rank only handling embeddings).
| self.num_query_groups_per_partition = self.decoder.layers[0].self_attention.num_query_groups_per_partition | |
| self.num_query_groups_per_partition = ( | |
| self.decoder.layers[0].self_attention.num_query_groups_per_partition | |
| if len(self.decoder.layers) > 0 else 0) |
| window_size = self.text_config.sliding_window - 1 | ||
| seq_len = attention_mask.shape[-1] | ||
|
|
||
| window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device='cuda') |
There was a problem hiding this comment.
Hardcoding device='cuda' prevents the model from running on other devices (e.g., CPU or other accelerators). Use attention_mask.device instead.
| window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device='cuda') | |
| window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_mask.device) |
| mlp_output, bias = self.mlp(pre_mlp_layernorm_output, padding_mask=padding_mask) | ||
| if self.enable_moe_block: | ||
| mlp_output_1 = self.post_feedforward_layernorm_1(mlp_output) | ||
| mlp_output_2, bias = self.experts_mlp(residual, padding_mask=padding_mask) | ||
| mlp_output_2 = self.post_feedforward_layernorm_2(mlp_output_2) | ||
|
|
||
| # Combine mlp and moe outputs | ||
| mlp_output = mlp_output_1 + mlp_output_2 | ||
|
|
||
| mlp_output = self.post_feedforward_layernorm(mlp_output) | ||
| hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)((mlp_output, bias), residual, | ||
| self.hidden_dropout) |
There was a problem hiding this comment.
The bias handling in _forward_mlp is incorrect if add_bias_linear is enabled. Biases from self.mlp and self.experts_mlp are ignored before their respective normalization layers (post_feedforward_layernorm_1/2), and the bias from self.mlp is shadowed by the one from self.experts_mlp. If biases are used, they must be added to the outputs before normalization.
| elif isinstance(obj, dict): | ||
| # inplace (gemma4 shared_kv_states) | ||
| for k, v in obj.items(): | ||
| obj[k] = _checkpoint_flatten(v, tensors) | ||
| return obj |
There was a problem hiding this comment.
Modifying the obj dictionary in-place is a dangerous side effect. If the dictionary is reused elsewhere in the model or in subsequent iterations, its contents will be corrupted with _TensorIdx sentinels. It is safer to return a new dictionary.
| elif isinstance(obj, dict): | |
| # inplace (gemma4 shared_kv_states) | |
| for k, v in obj.items(): | |
| obj[k] = _checkpoint_flatten(v, tensors) | |
| return obj | |
| elif isinstance(obj, dict): | |
| return {k: _checkpoint_flatten(v, tensors) for k, v in obj.items()} |
| elif isinstance(schema, dict): | ||
| # inplace (gemma4 shared_kv_states) | ||
| for k, v in schema.items(): | ||
| schema[k] = _checkpoint_unflatten(v, tensors) | ||
| return schema |
There was a problem hiding this comment.
Similarly to _checkpoint_flatten, modifying the schema dictionary in-place should be avoided to prevent unexpected behavior with captured closure variables.
| elif isinstance(schema, dict): | |
| # inplace (gemma4 shared_kv_states) | |
| for k, v in schema.items(): | |
| schema[k] = _checkpoint_unflatten(v, tensors) | |
| return schema | |
| elif isinstance(schema, dict): | |
| return {k: _checkpoint_unflatten(v, tensors) for k, v in schema.items()} |
|
|
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces support for the Gemma 4 model, involving significant updates to the README, configuration parser, and core model components. The changes include the implementation of Gemma 4-specific modules such as RMSNorm, Attention, and MoE layers, alongside refactoring the GPT bridge and transformer blocks to accommodate shared KV states and dynamic attention configurations. Feedback from the review highlights an opportunity to improve the maintainability of the Gemma 4 attention implementation by adding explanatory comments to complex logic.
| self.store_full_length_kv = not self.is_kv_shared_layer and layer_idx == len( | ||
| prev_layers) - 1 - prev_layers[::-1].index(text_config.layer_types[layer_idx]) |
There was a problem hiding this comment.
This logic is quite complex. Adding a comment to explain its purpose would improve readability and maintainability. For example, you could explain that this is determining if the current layer is the last of its type before the shared KV layers begin, which is used to decide whether to store the KV cache for subsequent layers.
modelscope/ms-swift#9296