-
Notifications
You must be signed in to change notification settings - Fork 32
Padding free bufix ljl #202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,4 +1,5 @@ | ||||||||||||||||
| # Copyright (c) ModelScope Contributors. All rights reserved. | ||||||||||||||||
| import inspect | ||||||||||||||||
| import math | ||||||||||||||||
| import torch | ||||||||||||||||
| import torch.distributed as dist | ||||||||||||||||
|
|
@@ -28,6 +29,38 @@ def is_qwen3_omni(model): | |||||||||||||||
| return 'qwen3_omni' in mt | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| def _call_with_supported_kwargs(fn, *args, **kwargs): | ||||||||||||||||
| signature = inspect.signature(fn) | ||||||||||||||||
| if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): | ||||||||||||||||
| kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters} | ||||||||||||||||
| return fn(*args, **kwargs) | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| def _call_create_causal_mask(fn, config, input_embeds, attention_mask, cache_position_or_past_key_values, *args, | ||||||||||||||||
| **kwargs): | ||||||||||||||||
| if 'cache_position' in inspect.signature(fn).parameters: | ||||||||||||||||
| return _call_with_supported_kwargs( | ||||||||||||||||
| fn, | ||||||||||||||||
| config, | ||||||||||||||||
| input_embeds, | ||||||||||||||||
| attention_mask, | ||||||||||||||||
| cache_position_or_past_key_values, | ||||||||||||||||
| *args, | ||||||||||||||||
| **kwargs, | ||||||||||||||||
| ) | ||||||||||||||||
| if cache_position_or_past_key_values is None and 'past_key_values' in kwargs: | ||||||||||||||||
| return _call_with_supported_kwargs(fn, config, input_embeds, attention_mask, *args, **kwargs) | ||||||||||||||||
| return _call_with_supported_kwargs( | ||||||||||||||||
| fn, | ||||||||||||||||
| config, | ||||||||||||||||
| input_embeds, | ||||||||||||||||
| attention_mask, | ||||||||||||||||
| cache_position_or_past_key_values, | ||||||||||||||||
| *args, | ||||||||||||||||
| **kwargs, | ||||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| # main content copied from ms-swift | ||||||||||||||||
| class SequenceParallel: | ||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -77,59 +110,89 @@ def _prepare_flash_attn(self, base_model: torch.nn.Module): | |||||||||||||||
| try: | ||||||||||||||||
| from transformers import masking_utils | ||||||||||||||||
|
|
||||||||||||||||
| _origin_flash_attention_mask = masking_utils.flash_attention_mask | ||||||||||||||||
|
|
||||||||||||||||
| # Patch attention masks for SP: avoid masking when full sequence is reconstructed. | ||||||||||||||||
| def flash_attention_mask(batch_size, | ||||||||||||||||
| cache_position, | ||||||||||||||||
| kv_length, | ||||||||||||||||
| kv_offset=0, | ||||||||||||||||
| mask_function=masking_utils.causal_mask_function, | ||||||||||||||||
| attention_mask=None, | ||||||||||||||||
| **kwargs): | ||||||||||||||||
| if self.world_size == 1: | ||||||||||||||||
| return _origin_flash_attention_mask(batch_size, cache_position, kv_length, kv_offset, mask_function, | ||||||||||||||||
| attention_mask, **kwargs) | ||||||||||||||||
| if attention_mask is not None: | ||||||||||||||||
| if attention_mask.all(): | ||||||||||||||||
| attention_mask = None | ||||||||||||||||
|
|
||||||||||||||||
| return attention_mask | ||||||||||||||||
|
|
||||||||||||||||
| masking_utils.flash_attention_mask = flash_attention_mask | ||||||||||||||||
| masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['flash_attention_2'] = flash_attention_mask | ||||||||||||||||
|
|
||||||||||||||||
| def sdpa_mask(batch_size, cache_position, kv_length, *args, **kwargs): | ||||||||||||||||
| if self.world_size == 1: | ||||||||||||||||
| return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size, | ||||||||||||||||
| cache_position, | ||||||||||||||||
| kv_length, *args, | ||||||||||||||||
| **kwargs) | ||||||||||||||||
| device = cache_position.device | ||||||||||||||||
| cache_position = self.real_position_ids[0] | ||||||||||||||||
| cache_position = self.pad(cache_position, padding_value=-1, position_ids=self.real_position_ids, dim=0) | ||||||||||||||||
| cache_position = torch.arange(0, cache_position.shape[0], device=device) | ||||||||||||||||
| kv_length = cache_position.shape[0] | ||||||||||||||||
| return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size, | ||||||||||||||||
| cache_position, | ||||||||||||||||
| kv_length, *args, | ||||||||||||||||
| **kwargs) | ||||||||||||||||
| def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs): | ||||||||||||||||
| origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'] | ||||||||||||||||
| origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters | ||||||||||||||||
|
Comment on lines
+113
to
+115
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The signature inspection and the lookup of the original SDPA function should be performed once outside of the
Suggested change
|
||||||||||||||||
| q_length = q_length if q_length is not None else kwargs.pop('cache_position', None) | ||||||||||||||||
| device = q_length.device if torch.is_tensor(q_length) else kwargs.pop('device', None) | ||||||||||||||||
| if device is None: | ||||||||||||||||
| device = self.real_position_ids.device | ||||||||||||||||
|
|
||||||||||||||||
| cache_position = None | ||||||||||||||||
| if self.world_size > 1: | ||||||||||||||||
| padded_position_ids = self.pad( | ||||||||||||||||
| self.real_position_ids[0], | ||||||||||||||||
| padding_value=-1, | ||||||||||||||||
| position_ids=self.real_position_ids, | ||||||||||||||||
| dim=0, | ||||||||||||||||
| ) | ||||||||||||||||
| global_length = padded_position_ids.shape[0] | ||||||||||||||||
| if origin_uses_cache_position: | ||||||||||||||||
| cache_position = torch.arange(0, global_length, device=device) | ||||||||||||||||
| kv_length = global_length | ||||||||||||||||
| else: | ||||||||||||||||
| # Newer Transformers passes q_length/q_offset instead of cache_position. In SP training, | ||||||||||||||||
| # create_causal_mask may still see the local shard length, so restore the global query length | ||||||||||||||||
| # only for the no-cache prefill path; cache/sliding paths keep their upstream offsets. | ||||||||||||||||
| q_offset = kwargs.get('q_offset', 0) | ||||||||||||||||
| kv_offset = kwargs.get('kv_offset', 0) | ||||||||||||||||
| no_cache_offsets = ((not torch.is_tensor(q_offset) and q_offset == 0) | ||||||||||||||||
| and (not torch.is_tensor(kv_offset) and kv_offset == 0)) | ||||||||||||||||
| if no_cache_offsets: | ||||||||||||||||
| q_length = global_length | ||||||||||||||||
| attention_mask = kwargs.get('attention_mask') | ||||||||||||||||
| if attention_mask is not None and torch.is_tensor(attention_mask): | ||||||||||||||||
| kv_length = attention_mask.shape[-1] | ||||||||||||||||
| else: | ||||||||||||||||
| kv_length = global_length | ||||||||||||||||
|
|
||||||||||||||||
| if origin_uses_cache_position: | ||||||||||||||||
| if cache_position is None: | ||||||||||||||||
| cache_position = q_length if torch.is_tensor(q_length) else torch.arange( | ||||||||||||||||
| q_length, device=device) | ||||||||||||||||
| return origin_sdpa(batch_size, cache_position, kv_length, *args, **kwargs) | ||||||||||||||||
|
|
||||||||||||||||
| return origin_sdpa(batch_size, q_length, kv_length, *args, device=device, **kwargs) | ||||||||||||||||
|
Comment on lines
+153
to
+155
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The calls to
Suggested change
|
||||||||||||||||
|
|
||||||||||||||||
| masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[ | ||||||||||||||||
| 'sdpa_origin'] = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] | ||||||||||||||||
| masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask | ||||||||||||||||
|
|
||||||||||||||||
| def create_causal_mask(config, input_embeds, attention_mask, cache_position, *args, **kwargs): | ||||||||||||||||
| def create_causal_mask(config, | ||||||||||||||||
| input_embeds, | ||||||||||||||||
| attention_mask, | ||||||||||||||||
| cache_position_or_past_key_values=None, | ||||||||||||||||
| *args, | ||||||||||||||||
| **kwargs): | ||||||||||||||||
| if self.world_size == 1: | ||||||||||||||||
| return masking_utils.origin_create_causal_mask(config, input_embeds, attention_mask, cache_position, | ||||||||||||||||
| *args, **kwargs) | ||||||||||||||||
| return _call_create_causal_mask( | ||||||||||||||||
| masking_utils.origin_create_causal_mask, | ||||||||||||||||
| config, | ||||||||||||||||
| input_embeds, | ||||||||||||||||
| attention_mask, | ||||||||||||||||
| cache_position_or_past_key_values, | ||||||||||||||||
| *args, | ||||||||||||||||
| **kwargs, | ||||||||||||||||
| ) | ||||||||||||||||
| input_embeds = torch.ones( | ||||||||||||||||
| (input_embeds.shape[0], input_embeds.shape[1] * self.sp_world_size, input_embeds.shape[2]), | ||||||||||||||||
| dtype=input_embeds.dtype, | ||||||||||||||||
| device=input_embeds.device) | ||||||||||||||||
| cache_position = torch.arange(0, input_embeds.shape[1], device=input_embeds.device) | ||||||||||||||||
| return masking_utils.origin_create_causal_mask(config, input_embeds, attention_mask, cache_position, | ||||||||||||||||
| *args, **kwargs) | ||||||||||||||||
| if 'cache_position' in inspect.signature(masking_utils.origin_create_causal_mask).parameters: | ||||||||||||||||
| cache_position_or_past_key_values = torch.arange( | ||||||||||||||||
| 0, | ||||||||||||||||
| input_embeds.shape[1], | ||||||||||||||||
| device=input_embeds.device, | ||||||||||||||||
| ) | ||||||||||||||||
|
Comment on lines
+181
to
+186
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||
| return _call_create_causal_mask( | ||||||||||||||||
| masking_utils.origin_create_causal_mask, | ||||||||||||||||
| config, | ||||||||||||||||
| input_embeds, | ||||||||||||||||
| attention_mask, | ||||||||||||||||
| cache_position_or_past_key_values, | ||||||||||||||||
| *args, | ||||||||||||||||
| **kwargs, | ||||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
| masking_utils.origin_create_causal_mask = masking_utils.create_causal_mask | ||||||||||||||||
| masking_utils.create_causal_mask = create_causal_mask | ||||||||||||||||
|
|
||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,4 +1,7 @@ | ||||||||||||||||||||||||||||||
| import inspect | ||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||
| from packaging.version import Version | ||||||||||||||||||||||||||||||
| from transformers import __version__ as transformers_version | ||||||||||||||||||||||||||||||
| from transformers.utils.import_utils import is_flash_linear_attention_available | ||||||||||||||||||||||||||||||
| from typing import Optional | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
@@ -33,6 +36,17 @@ def _get_flash_linear_attention_kernels(): | |||||||||||||||||||||||||||||
| return causal_conv1d, chunk_gated_delta_rule | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _call_with_supported_kwargs(fn, *args, **kwargs): | ||||||||||||||||||||||||||||||
| signature = inspect.signature(fn) | ||||||||||||||||||||||||||||||
| if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): | ||||||||||||||||||||||||||||||
| kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters} | ||||||||||||||||||||||||||||||
| return fn(*args, **kwargs) | ||||||||||||||||||||||||||||||
|
Comment on lines
+39
to
+43
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Caching the signature in
Suggested change
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _supports_native_padding_free() -> bool: | ||||||||||||||||||||||||||||||
| return Version(Version(transformers_version).base_version) >= Version('5.9.0') | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _patch_gdn_kernels_for_cu_seqlens( | ||||||||||||||||||||||||||||||
| mod: torch.nn.Module, | ||||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||||
|
|
@@ -64,7 +78,7 @@ def chunk_gated_delta_rule_wrapper(query, key, value, **kwargs): | |||||||||||||||||||||||||||||
| mod.causal_conv1d_fn = causal_conv1d_wrapper | ||||||||||||||||||||||||||||||
| mod.chunk_gated_delta_rule = chunk_gated_delta_rule_wrapper | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| return origin_forward(mod, *forward_args, **forward_kwargs) | ||||||||||||||||||||||||||||||
| return _call_with_supported_kwargs(origin_forward, mod, *forward_args, **forward_kwargs) | ||||||||||||||||||||||||||||||
| finally: | ||||||||||||||||||||||||||||||
| mod.causal_conv1d_fn = old_conv_fn | ||||||||||||||||||||||||||||||
| mod.chunk_gated_delta_rule = old_chunk_rule | ||||||||||||||||||||||||||||||
|
|
@@ -84,6 +98,8 @@ def __call__(self, module, *args, **kwargs): | |||||||||||||||||||||||||||||
| if getattr(Qwen3_5GatedDeltaNet, '_twinkle_sp_linear_patched', False): | ||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||
| module._twinkle_gdn_padding_free_patched = True | ||||||||||||||||||||||||||||||
| if _supports_native_padding_free(): | ||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if not getattr(Qwen3_5DecoderLayer, '_twinkle_padding_free_cu_seqlens_patched', False): | ||||||||||||||||||||||||||||||
| origin_decoder_forward = Qwen3_5DecoderLayer.forward | ||||||||||||||||||||||||||||||
|
|
@@ -99,7 +115,8 @@ def decoder_forward( | |||||||||||||||||||||||||||||
| **extra_kwargs, | ||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||
| if getattr(layer, 'layer_type', None) != 'linear_attention': | ||||||||||||||||||||||||||||||
| return origin_decoder_forward( | ||||||||||||||||||||||||||||||
| return _call_with_supported_kwargs( | ||||||||||||||||||||||||||||||
| origin_decoder_forward, | ||||||||||||||||||||||||||||||
| layer, | ||||||||||||||||||||||||||||||
| hidden_states=hidden_states, | ||||||||||||||||||||||||||||||
| position_embeddings=position_embeddings, | ||||||||||||||||||||||||||||||
|
|
@@ -147,7 +164,8 @@ def forward( | |||||||||||||||||||||||||||||
| **extra_kwargs, | ||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||
| if cu_seq_lens_q is None: | ||||||||||||||||||||||||||||||
| return origin_forward( | ||||||||||||||||||||||||||||||
| return _call_with_supported_kwargs( | ||||||||||||||||||||||||||||||
| origin_forward, | ||||||||||||||||||||||||||||||
| mod, | ||||||||||||||||||||||||||||||
| hidden_states, | ||||||||||||||||||||||||||||||
| cache_params=cache_params, | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Calling
inspect.signatureis a relatively expensive operation. Since_call_with_supported_kwargsis used within the model's forward pass (e.g., during mask creation), calling it repeatedly can introduce significant overhead. Consider caching the signature of the function to improve performance.