diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py index 51a28015..08f5b990 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import inspect import math import torch import torch.distributed as dist @@ -28,6 +29,38 @@ def is_qwen3_omni(model): return 'qwen3_omni' in mt +def _call_with_supported_kwargs(fn, *args, **kwargs): + signature = inspect.signature(fn) + if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): + kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters} + return fn(*args, **kwargs) + + +def _call_create_causal_mask(fn, config, input_embeds, attention_mask, cache_position_or_past_key_values, *args, + **kwargs): + if 'cache_position' in inspect.signature(fn).parameters: + return _call_with_supported_kwargs( + fn, + config, + input_embeds, + attention_mask, + cache_position_or_past_key_values, + *args, + **kwargs, + ) + if cache_position_or_past_key_values is None and 'past_key_values' in kwargs: + return _call_with_supported_kwargs(fn, config, input_embeds, attention_mask, *args, **kwargs) + return _call_with_supported_kwargs( + fn, + config, + input_embeds, + attention_mask, + cache_position_or_past_key_values, + *args, + **kwargs, + ) + + # main content copied from ms-swift class SequenceParallel: @@ -77,59 +110,89 @@ def _prepare_flash_attn(self, base_model: torch.nn.Module): try: from transformers import masking_utils - _origin_flash_attention_mask = masking_utils.flash_attention_mask - - # Patch attention masks for SP: avoid masking when full sequence is reconstructed. - def flash_attention_mask(batch_size, - cache_position, - kv_length, - kv_offset=0, - mask_function=masking_utils.causal_mask_function, - attention_mask=None, - **kwargs): - if self.world_size == 1: - return _origin_flash_attention_mask(batch_size, cache_position, kv_length, kv_offset, mask_function, - attention_mask, **kwargs) - if attention_mask is not None: - if attention_mask.all(): - attention_mask = None - - return attention_mask - - masking_utils.flash_attention_mask = flash_attention_mask - masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['flash_attention_2'] = flash_attention_mask - - def sdpa_mask(batch_size, cache_position, kv_length, *args, **kwargs): - if self.world_size == 1: - return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size, - cache_position, - kv_length, *args, - **kwargs) - device = cache_position.device - cache_position = self.real_position_ids[0] - cache_position = self.pad(cache_position, padding_value=-1, position_ids=self.real_position_ids, dim=0) - cache_position = torch.arange(0, cache_position.shape[0], device=device) - kv_length = cache_position.shape[0] - return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size, - cache_position, - kv_length, *args, - **kwargs) + def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs): + origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'] + origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters + q_length = q_length if q_length is not None else kwargs.pop('cache_position', None) + device = q_length.device if torch.is_tensor(q_length) else kwargs.pop('device', None) + if device is None: + device = self.real_position_ids.device + + cache_position = None + if self.world_size > 1: + padded_position_ids = self.pad( + self.real_position_ids[0], + padding_value=-1, + position_ids=self.real_position_ids, + dim=0, + ) + global_length = padded_position_ids.shape[0] + if origin_uses_cache_position: + cache_position = torch.arange(0, global_length, device=device) + kv_length = global_length + else: + # Newer Transformers passes q_length/q_offset instead of cache_position. In SP training, + # create_causal_mask may still see the local shard length, so restore the global query length + # only for the no-cache prefill path; cache/sliding paths keep their upstream offsets. + q_offset = kwargs.get('q_offset', 0) + kv_offset = kwargs.get('kv_offset', 0) + no_cache_offsets = ((not torch.is_tensor(q_offset) and q_offset == 0) + and (not torch.is_tensor(kv_offset) and kv_offset == 0)) + if no_cache_offsets: + q_length = global_length + attention_mask = kwargs.get('attention_mask') + if attention_mask is not None and torch.is_tensor(attention_mask): + kv_length = attention_mask.shape[-1] + else: + kv_length = global_length + + if origin_uses_cache_position: + if cache_position is None: + cache_position = q_length if torch.is_tensor(q_length) else torch.arange( + q_length, device=device) + return origin_sdpa(batch_size, cache_position, kv_length, *args, **kwargs) + + return origin_sdpa(batch_size, q_length, kv_length, *args, device=device, **kwargs) masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[ 'sdpa_origin'] = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask - def create_causal_mask(config, input_embeds, attention_mask, cache_position, *args, **kwargs): + def create_causal_mask(config, + input_embeds, + attention_mask, + cache_position_or_past_key_values=None, + *args, + **kwargs): if self.world_size == 1: - return masking_utils.origin_create_causal_mask(config, input_embeds, attention_mask, cache_position, - *args, **kwargs) + return _call_create_causal_mask( + masking_utils.origin_create_causal_mask, + config, + input_embeds, + attention_mask, + cache_position_or_past_key_values, + *args, + **kwargs, + ) input_embeds = torch.ones( (input_embeds.shape[0], input_embeds.shape[1] * self.sp_world_size, input_embeds.shape[2]), dtype=input_embeds.dtype, device=input_embeds.device) - cache_position = torch.arange(0, input_embeds.shape[1], device=input_embeds.device) - return masking_utils.origin_create_causal_mask(config, input_embeds, attention_mask, cache_position, - *args, **kwargs) + if 'cache_position' in inspect.signature(masking_utils.origin_create_causal_mask).parameters: + cache_position_or_past_key_values = torch.arange( + 0, + input_embeds.shape[1], + device=input_embeds.device, + ) + return _call_create_causal_mask( + masking_utils.origin_create_causal_mask, + config, + input_embeds, + attention_mask, + cache_position_or_past_key_values, + *args, + **kwargs, + ) masking_utils.origin_create_causal_mask = masking_utils.create_causal_mask masking_utils.create_causal_mask = create_causal_mask diff --git a/src/twinkle/patch/gdn_padding_free.py b/src/twinkle/patch/gdn_padding_free.py index d34b7ec7..1026a175 100644 --- a/src/twinkle/patch/gdn_padding_free.py +++ b/src/twinkle/patch/gdn_padding_free.py @@ -1,4 +1,7 @@ +import inspect import torch +from packaging.version import Version +from transformers import __version__ as transformers_version from transformers.utils.import_utils import is_flash_linear_attention_available from typing import Optional @@ -33,6 +36,17 @@ def _get_flash_linear_attention_kernels(): return causal_conv1d, chunk_gated_delta_rule +def _call_with_supported_kwargs(fn, *args, **kwargs): + signature = inspect.signature(fn) + if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): + kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters} + return fn(*args, **kwargs) + + +def _supports_native_padding_free() -> bool: + return Version(Version(transformers_version).base_version) >= Version('5.9.0') + + def _patch_gdn_kernels_for_cu_seqlens( mod: torch.nn.Module, *, @@ -64,7 +78,7 @@ def chunk_gated_delta_rule_wrapper(query, key, value, **kwargs): mod.causal_conv1d_fn = causal_conv1d_wrapper mod.chunk_gated_delta_rule = chunk_gated_delta_rule_wrapper try: - return origin_forward(mod, *forward_args, **forward_kwargs) + return _call_with_supported_kwargs(origin_forward, mod, *forward_args, **forward_kwargs) finally: mod.causal_conv1d_fn = old_conv_fn mod.chunk_gated_delta_rule = old_chunk_rule @@ -84,6 +98,8 @@ def __call__(self, module, *args, **kwargs): if getattr(Qwen3_5GatedDeltaNet, '_twinkle_sp_linear_patched', False): return module._twinkle_gdn_padding_free_patched = True + if _supports_native_padding_free(): + return if not getattr(Qwen3_5DecoderLayer, '_twinkle_padding_free_cu_seqlens_patched', False): origin_decoder_forward = Qwen3_5DecoderLayer.forward @@ -99,7 +115,8 @@ def decoder_forward( **extra_kwargs, ): if getattr(layer, 'layer_type', None) != 'linear_attention': - return origin_decoder_forward( + return _call_with_supported_kwargs( + origin_decoder_forward, layer, hidden_states=hidden_states, position_embeddings=position_embeddings, @@ -147,7 +164,8 @@ def forward( **extra_kwargs, ): if cu_seq_lens_q is None: - return origin_forward( + return _call_with_supported_kwargs( + origin_forward, mod, hidden_states, cache_params=cache_params,