From eae429b9ea2aa04a5c54f5114e9763b2413517c6 Mon Sep 17 00:00:00 2001 From: qq_30035749 Date: Thu, 21 May 2026 17:05:49 +0800 Subject: [PATCH 1/3] feat: add support for forward methods with incompatible kwargs Add `_call_with_supported_kwargs` utility to filter out unsupported keyword arguments when calling forward methods, preventing errors from incompatible function signatures. This fixes issues where `origin_forward` methods may not accept all passed kwargs. --- src/twinkle/patch/gdn_padding_free.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/twinkle/patch/gdn_padding_free.py b/src/twinkle/patch/gdn_padding_free.py index d34b7ec7..ccd4c9b2 100644 --- a/src/twinkle/patch/gdn_padding_free.py +++ b/src/twinkle/patch/gdn_padding_free.py @@ -1,6 +1,8 @@ +import inspect +from typing import Optional + import torch from transformers.utils.import_utils import is_flash_linear_attention_available -from typing import Optional from twinkle.patch import Patch @@ -33,6 +35,13 @@ def _get_flash_linear_attention_kernels(): return causal_conv1d, chunk_gated_delta_rule +def _call_with_supported_kwargs(fn, *args, **kwargs): + signature = inspect.signature(fn) + if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): + kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters} + return fn(*args, **kwargs) + + def _patch_gdn_kernels_for_cu_seqlens( mod: torch.nn.Module, *, @@ -64,7 +73,7 @@ def chunk_gated_delta_rule_wrapper(query, key, value, **kwargs): mod.causal_conv1d_fn = causal_conv1d_wrapper mod.chunk_gated_delta_rule = chunk_gated_delta_rule_wrapper try: - return origin_forward(mod, *forward_args, **forward_kwargs) + return _call_with_supported_kwargs(origin_forward, mod, *forward_args, **forward_kwargs) finally: mod.causal_conv1d_fn = old_conv_fn mod.chunk_gated_delta_rule = old_chunk_rule @@ -147,7 +156,8 @@ def forward( **extra_kwargs, ): if cu_seq_lens_q is None: - return origin_forward( + return _call_with_supported_kwargs( + origin_forward, mod, hidden_states, cache_params=cache_params, From aa9bf73b54e0f1f1af16c2b06e34fa0e5ce1bf6e Mon Sep 17 00:00:00 2001 From: qq_30035749 Date: Fri, 22 May 2026 00:24:29 +0800 Subject: [PATCH 2/3] fix: support native padding-free in GatedDeltaNet and improve kwargs handling - Add `_call_with_supported_kwargs` and `_call_create_causal_mask` helpers to filter unsupported kwargs - Rename `cache_position` parameter to `q_length` in flash_attention_mask and sdpa_mask for clarity - Fix device detection in sdpa_mask when `q_length` is not a tensor - Ensure compatibility with models that don't accept `cache_position` in causal mask functions --- .../strategy/sequence_parallel/__init__.py | 128 ++++++++++++------ src/twinkle/patch/gdn_padding_free.py | 15 +- 2 files changed, 99 insertions(+), 44 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py index 51a28015..bbe06873 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import inspect import math import torch import torch.distributed as dist @@ -28,6 +29,38 @@ def is_qwen3_omni(model): return 'qwen3_omni' in mt +def _call_with_supported_kwargs(fn, *args, **kwargs): + signature = inspect.signature(fn) + if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): + kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters} + return fn(*args, **kwargs) + + +def _call_create_causal_mask(fn, config, input_embeds, attention_mask, cache_position_or_past_key_values, *args, + **kwargs): + if 'cache_position' in inspect.signature(fn).parameters: + return _call_with_supported_kwargs( + fn, + config, + input_embeds, + attention_mask, + cache_position_or_past_key_values, + *args, + **kwargs, + ) + if cache_position_or_past_key_values is None and 'past_key_values' in kwargs: + return _call_with_supported_kwargs(fn, config, input_embeds, attention_mask, *args, **kwargs) + return _call_with_supported_kwargs( + fn, + config, + input_embeds, + attention_mask, + cache_position_or_past_key_values, + *args, + **kwargs, + ) + + # main content copied from ms-swift class SequenceParallel: @@ -77,59 +110,72 @@ def _prepare_flash_attn(self, base_model: torch.nn.Module): try: from transformers import masking_utils - _origin_flash_attention_mask = masking_utils.flash_attention_mask - - # Patch attention masks for SP: avoid masking when full sequence is reconstructed. - def flash_attention_mask(batch_size, - cache_position, - kv_length, - kv_offset=0, - mask_function=masking_utils.causal_mask_function, - attention_mask=None, - **kwargs): - if self.world_size == 1: - return _origin_flash_attention_mask(batch_size, cache_position, kv_length, kv_offset, mask_function, - attention_mask, **kwargs) - if attention_mask is not None: - if attention_mask.all(): - attention_mask = None - - return attention_mask + def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs): + origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'] + origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters + q_length = q_length if q_length is not None else kwargs.pop('cache_position', None) + device = q_length.device if torch.is_tensor(q_length) else kwargs.get('device') + if device is None: + device = self.real_position_ids.device + + cache_position = None + if self.world_size > 1 and origin_uses_cache_position: + padded_position_ids = self.pad( + self.real_position_ids[0], + padding_value=-1, + position_ids=self.real_position_ids, + dim=0, + ) + cache_position = torch.arange(0, padded_position_ids.shape[0], device=device) + kv_length = cache_position.shape[0] - masking_utils.flash_attention_mask = flash_attention_mask - masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['flash_attention_2'] = flash_attention_mask + if origin_uses_cache_position: + if cache_position is None: + cache_position = q_length if torch.is_tensor(q_length) else torch.arange( + q_length, device=device) + return origin_sdpa(batch_size, cache_position, kv_length, *args, **kwargs) - def sdpa_mask(batch_size, cache_position, kv_length, *args, **kwargs): - if self.world_size == 1: - return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size, - cache_position, - kv_length, *args, - **kwargs) - device = cache_position.device - cache_position = self.real_position_ids[0] - cache_position = self.pad(cache_position, padding_value=-1, position_ids=self.real_position_ids, dim=0) - cache_position = torch.arange(0, cache_position.shape[0], device=device) - kv_length = cache_position.shape[0] - return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size, - cache_position, - kv_length, *args, - **kwargs) + return origin_sdpa(batch_size, q_length, kv_length, *args, device=device, **kwargs) masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[ 'sdpa_origin'] = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask - def create_causal_mask(config, input_embeds, attention_mask, cache_position, *args, **kwargs): + def create_causal_mask(config, + input_embeds, + attention_mask, + cache_position_or_past_key_values=None, + *args, + **kwargs): if self.world_size == 1: - return masking_utils.origin_create_causal_mask(config, input_embeds, attention_mask, cache_position, - *args, **kwargs) + return _call_create_causal_mask( + masking_utils.origin_create_causal_mask, + config, + input_embeds, + attention_mask, + cache_position_or_past_key_values, + *args, + **kwargs, + ) input_embeds = torch.ones( (input_embeds.shape[0], input_embeds.shape[1] * self.sp_world_size, input_embeds.shape[2]), dtype=input_embeds.dtype, device=input_embeds.device) - cache_position = torch.arange(0, input_embeds.shape[1], device=input_embeds.device) - return masking_utils.origin_create_causal_mask(config, input_embeds, attention_mask, cache_position, - *args, **kwargs) + if 'cache_position' in inspect.signature(masking_utils.origin_create_causal_mask).parameters: + cache_position_or_past_key_values = torch.arange( + 0, + input_embeds.shape[1], + device=input_embeds.device, + ) + return _call_create_causal_mask( + masking_utils.origin_create_causal_mask, + config, + input_embeds, + attention_mask, + cache_position_or_past_key_values, + *args, + **kwargs, + ) masking_utils.origin_create_causal_mask = masking_utils.create_causal_mask masking_utils.create_causal_mask = create_causal_mask diff --git a/src/twinkle/patch/gdn_padding_free.py b/src/twinkle/patch/gdn_padding_free.py index ccd4c9b2..dde8bdf7 100644 --- a/src/twinkle/patch/gdn_padding_free.py +++ b/src/twinkle/patch/gdn_padding_free.py @@ -1,8 +1,7 @@ import inspect -from typing import Optional - import torch from transformers.utils.import_utils import is_flash_linear_attention_available +from typing import Optional from twinkle.patch import Patch @@ -42,6 +41,13 @@ def _call_with_supported_kwargs(fn, *args, **kwargs): return fn(*args, **kwargs) +def _supports_native_padding_free(Qwen3_5GatedDeltaNet) -> bool: + try: + return 'cu_seq_lens_q' in inspect.getsource(Qwen3_5GatedDeltaNet.forward) + except (OSError, TypeError): + return False + + def _patch_gdn_kernels_for_cu_seqlens( mod: torch.nn.Module, *, @@ -93,6 +99,8 @@ def __call__(self, module, *args, **kwargs): if getattr(Qwen3_5GatedDeltaNet, '_twinkle_sp_linear_patched', False): return module._twinkle_gdn_padding_free_patched = True + if _supports_native_padding_free(Qwen3_5GatedDeltaNet): + return if not getattr(Qwen3_5DecoderLayer, '_twinkle_padding_free_cu_seqlens_patched', False): origin_decoder_forward = Qwen3_5DecoderLayer.forward @@ -108,7 +116,8 @@ def decoder_forward( **extra_kwargs, ): if getattr(layer, 'layer_type', None) != 'linear_attention': - return origin_decoder_forward( + return _call_with_supported_kwargs( + origin_decoder_forward, layer, hidden_states=hidden_states, position_embeddings=position_embeddings, From 584ad663d7912ea127e3b74277aa998c80702d8a Mon Sep 17 00:00:00 2001 From: qq_30035749 Date: Fri, 22 May 2026 18:56:38 +0800 Subject: [PATCH 3/3] fix(sequence_parallel): restore global query length for no-cache prefill path In sequence parallel training, when newer Transformers versions pass q_length/q_offset instead of cache_position, the causal mask creation may still see the local shard length. This change restores the global query length for the no-cache prefill path while keeping cache/sliding paths with their upstream offsets. Also refactor GDN padding-free detection to use transformers version check instead of source inspection, supporting transformers >= 5.9.0. --- .../strategy/sequence_parallel/__init__.py | 25 ++++++++++++++++--- src/twinkle/patch/gdn_padding_free.py | 11 ++++---- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py index bbe06873..08f5b990 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py @@ -114,20 +114,37 @@ def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs): origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'] origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters q_length = q_length if q_length is not None else kwargs.pop('cache_position', None) - device = q_length.device if torch.is_tensor(q_length) else kwargs.get('device') + device = q_length.device if torch.is_tensor(q_length) else kwargs.pop('device', None) if device is None: device = self.real_position_ids.device cache_position = None - if self.world_size > 1 and origin_uses_cache_position: + if self.world_size > 1: padded_position_ids = self.pad( self.real_position_ids[0], padding_value=-1, position_ids=self.real_position_ids, dim=0, ) - cache_position = torch.arange(0, padded_position_ids.shape[0], device=device) - kv_length = cache_position.shape[0] + global_length = padded_position_ids.shape[0] + if origin_uses_cache_position: + cache_position = torch.arange(0, global_length, device=device) + kv_length = global_length + else: + # Newer Transformers passes q_length/q_offset instead of cache_position. In SP training, + # create_causal_mask may still see the local shard length, so restore the global query length + # only for the no-cache prefill path; cache/sliding paths keep their upstream offsets. + q_offset = kwargs.get('q_offset', 0) + kv_offset = kwargs.get('kv_offset', 0) + no_cache_offsets = ((not torch.is_tensor(q_offset) and q_offset == 0) + and (not torch.is_tensor(kv_offset) and kv_offset == 0)) + if no_cache_offsets: + q_length = global_length + attention_mask = kwargs.get('attention_mask') + if attention_mask is not None and torch.is_tensor(attention_mask): + kv_length = attention_mask.shape[-1] + else: + kv_length = global_length if origin_uses_cache_position: if cache_position is None: diff --git a/src/twinkle/patch/gdn_padding_free.py b/src/twinkle/patch/gdn_padding_free.py index dde8bdf7..1026a175 100644 --- a/src/twinkle/patch/gdn_padding_free.py +++ b/src/twinkle/patch/gdn_padding_free.py @@ -1,5 +1,7 @@ import inspect import torch +from packaging.version import Version +from transformers import __version__ as transformers_version from transformers.utils.import_utils import is_flash_linear_attention_available from typing import Optional @@ -41,11 +43,8 @@ def _call_with_supported_kwargs(fn, *args, **kwargs): return fn(*args, **kwargs) -def _supports_native_padding_free(Qwen3_5GatedDeltaNet) -> bool: - try: - return 'cu_seq_lens_q' in inspect.getsource(Qwen3_5GatedDeltaNet.forward) - except (OSError, TypeError): - return False +def _supports_native_padding_free() -> bool: + return Version(Version(transformers_version).base_version) >= Version('5.9.0') def _patch_gdn_kernels_for_cu_seqlens( @@ -99,7 +98,7 @@ def __call__(self, module, *args, **kwargs): if getattr(Qwen3_5GatedDeltaNet, '_twinkle_sp_linear_patched', False): return module._twinkle_gdn_padding_free_patched = True - if _supports_native_padding_free(Qwen3_5GatedDeltaNet): + if _supports_native_padding_free(): return if not getattr(Qwen3_5DecoderLayer, '_twinkle_padding_free_cu_seqlens_patched', False):