Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 106 additions & 43 deletions src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import inspect
import math
import torch
import torch.distributed as dist
Expand Down Expand Up @@ -28,6 +29,38 @@ def is_qwen3_omni(model):
return 'qwen3_omni' in mt


def _call_with_supported_kwargs(fn, *args, **kwargs):
signature = inspect.signature(fn)
if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters}
return fn(*args, **kwargs)
Comment on lines +32 to +36
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Calling inspect.signature is a relatively expensive operation. Since _call_with_supported_kwargs is used within the model's forward pass (e.g., during mask creation), calling it repeatedly can introduce significant overhead. Consider caching the signature of the function to improve performance.

Suggested change
def _call_with_supported_kwargs(fn, *args, **kwargs):
signature = inspect.signature(fn)
if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters}
return fn(*args, **kwargs)
def _call_with_supported_kwargs(fn, *args, **kwargs):
if not hasattr(_call_with_supported_kwargs, '_cache'):
_call_with_supported_kwargs._cache = {}
sig = _call_with_supported_kwargs._cache.get(fn)
if sig is None:
sig = _call_with_supported_kwargs._cache[fn] = inspect.signature(fn)
if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values()):
kwargs = {key: value for key, value in kwargs.items() if key in sig.parameters}
return fn(*args, **kwargs)



def _call_create_causal_mask(fn, config, input_embeds, attention_mask, cache_position_or_past_key_values, *args,
                             **kwargs):
    """Invoke a ``create_causal_mask``-style callable, adapting the fourth positional argument.

    The fourth positional slot is forwarded whenever ``fn`` declares a
    ``cache_position`` parameter. When it does not, the slot is omitted only if
    it is ``None`` and the caller already supplied ``past_key_values`` as a
    keyword; in every other case it is forwarded unchanged.

    Keyword arguments are filtered through ``_call_with_supported_kwargs``.

    Args:
        fn: The mask-building callable to invoke.
        config: Forwarded as the first positional argument.
        input_embeds: Forwarded as the second positional argument.
        attention_mask: Forwarded as the third positional argument.
        cache_position_or_past_key_values: Candidate fourth positional argument.
        *args: Extra positional arguments appended after the leading ones.
        **kwargs: Keyword arguments, filtered against the signature of ``fn``.

    Returns:
        Whatever ``fn`` returns.
    """
    declares_cache_position = 'cache_position' in inspect.signature(fn).parameters
    leading = [config, input_embeds, attention_mask]

    # Skip the slot only when it carries nothing and the keyword form is already present.
    redundant_slot = (not declares_cache_position and cache_position_or_past_key_values is None
                      and 'past_key_values' in kwargs)
    if not redundant_slot:
        leading.append(cache_position_or_past_key_values)

    return _call_with_supported_kwargs(fn, *leading, *args, **kwargs)


# main content copied from ms-swift
class SequenceParallel:

Expand Down Expand Up @@ -77,59 +110,89 @@ def _prepare_flash_attn(self, base_model: torch.nn.Module):
try:
from transformers import masking_utils

_origin_flash_attention_mask = masking_utils.flash_attention_mask

# Patch attention masks for SP: avoid masking when full sequence is reconstructed.
def flash_attention_mask(batch_size,
cache_position,
kv_length,
kv_offset=0,
mask_function=masking_utils.causal_mask_function,
attention_mask=None,
**kwargs):
if self.world_size == 1:
return _origin_flash_attention_mask(batch_size, cache_position, kv_length, kv_offset, mask_function,
attention_mask, **kwargs)
if attention_mask is not None:
if attention_mask.all():
attention_mask = None

return attention_mask

masking_utils.flash_attention_mask = flash_attention_mask
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['flash_attention_2'] = flash_attention_mask

def sdpa_mask(batch_size, cache_position, kv_length, *args, **kwargs):
if self.world_size == 1:
return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size,
cache_position,
kv_length, *args,
**kwargs)
device = cache_position.device
cache_position = self.real_position_ids[0]
cache_position = self.pad(cache_position, padding_value=-1, position_ids=self.real_position_ids, dim=0)
cache_position = torch.arange(0, cache_position.shape[0], device=device)
kv_length = cache_position.shape[0]
return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size,
cache_position,
kv_length, *args,
**kwargs)
def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs):
origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin']
origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters
Comment on lines +113 to +115
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The signature inspection and the lookup of the original SDPA function should be performed once outside of the sdpa_mask definition. This avoids redundant computation during every forward pass of the model.

Suggested change
def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs):
origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin']
origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters
origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa']
origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters
def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs):

q_length = q_length if q_length is not None else kwargs.pop('cache_position', None)
device = q_length.device if torch.is_tensor(q_length) else kwargs.pop('device', None)
if device is None:
device = self.real_position_ids.device

cache_position = None
if self.world_size > 1:
padded_position_ids = self.pad(
self.real_position_ids[0],
padding_value=-1,
position_ids=self.real_position_ids,
dim=0,
)
global_length = padded_position_ids.shape[0]
if origin_uses_cache_position:
cache_position = torch.arange(0, global_length, device=device)
kv_length = global_length
else:
# Newer Transformers passes q_length/q_offset instead of cache_position. In SP training,
# create_causal_mask may still see the local shard length, so restore the global query length
# only for the no-cache prefill path; cache/sliding paths keep their upstream offsets.
q_offset = kwargs.get('q_offset', 0)
kv_offset = kwargs.get('kv_offset', 0)
no_cache_offsets = ((not torch.is_tensor(q_offset) and q_offset == 0)
and (not torch.is_tensor(kv_offset) and kv_offset == 0))
if no_cache_offsets:
q_length = global_length
attention_mask = kwargs.get('attention_mask')
if attention_mask is not None and torch.is_tensor(attention_mask):
kv_length = attention_mask.shape[-1]
else:
kv_length = global_length

if origin_uses_cache_position:
if cache_position is None:
cache_position = q_length if torch.is_tensor(q_length) else torch.arange(
q_length, device=device)
return origin_sdpa(batch_size, cache_position, kv_length, *args, **kwargs)

return origin_sdpa(batch_size, q_length, kv_length, *args, device=device, **kwargs)
Comment on lines +153 to +155
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The calls to origin_sdpa should use _call_with_supported_kwargs to ensure compatibility across different Transformers versions. This is particularly important for the device argument, which may not be accepted by older versions of the SDPA mask function.

Suggested change
return origin_sdpa(batch_size, cache_position, kv_length, *args, **kwargs)
return origin_sdpa(batch_size, q_length, kv_length, *args, device=device, **kwargs)
return _call_with_supported_kwargs(origin_sdpa, batch_size, cache_position, kv_length, *args, **kwargs)
return _call_with_supported_kwargs(origin_sdpa, batch_size, q_length, kv_length, *args, device=device, **kwargs)


masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[
'sdpa_origin'] = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa']
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask

def create_causal_mask(config, input_embeds, attention_mask, cache_position, *args, **kwargs):
def create_causal_mask(config,
input_embeds,
attention_mask,
cache_position_or_past_key_values=None,
*args,
**kwargs):
if self.world_size == 1:
return masking_utils.origin_create_causal_mask(config, input_embeds, attention_mask, cache_position,
*args, **kwargs)
return _call_create_causal_mask(
masking_utils.origin_create_causal_mask,
config,
input_embeds,
attention_mask,
cache_position_or_past_key_values,
*args,
**kwargs,
)
input_embeds = torch.ones(
(input_embeds.shape[0], input_embeds.shape[1] * self.sp_world_size, input_embeds.shape[2]),
dtype=input_embeds.dtype,
device=input_embeds.device)
cache_position = torch.arange(0, input_embeds.shape[1], device=input_embeds.device)
return masking_utils.origin_create_causal_mask(config, input_embeds, attention_mask, cache_position,
*args, **kwargs)
if 'cache_position' in inspect.signature(masking_utils.origin_create_causal_mask).parameters:
cache_position_or_past_key_values = torch.arange(
0,
input_embeds.shape[1],
device=input_embeds.device,
)
Comment on lines +181 to +186
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the optimization suggested for sdpa_mask, the signature of masking_utils.origin_create_causal_mask should be inspected once outside the create_causal_mask function to avoid overhead in the forward pass.

return _call_create_causal_mask(
masking_utils.origin_create_causal_mask,
config,
input_embeds,
attention_mask,
cache_position_or_past_key_values,
*args,
**kwargs,
)

masking_utils.origin_create_causal_mask = masking_utils.create_causal_mask
masking_utils.create_causal_mask = create_causal_mask
Expand Down
24 changes: 21 additions & 3 deletions src/twinkle/patch/gdn_padding_free.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import inspect
import torch
from packaging.version import Version
from transformers import __version__ as transformers_version
from transformers.utils.import_utils import is_flash_linear_attention_available
from typing import Optional

Expand Down Expand Up @@ -33,6 +36,17 @@ def _get_flash_linear_attention_kernels():
return causal_conv1d, chunk_gated_delta_rule


def _call_with_supported_kwargs(fn, *args, **kwargs):
signature = inspect.signature(fn)
if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters}
return fn(*args, **kwargs)
Comment on lines +39 to +43
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Caching the signature in _call_with_supported_kwargs is recommended here as well, as this function is called within the forward pass of the GatedDeltaNet layers.

Suggested change
def _call_with_supported_kwargs(fn, *args, **kwargs):
signature = inspect.signature(fn)
if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters}
return fn(*args, **kwargs)
def _call_with_supported_kwargs(fn, *args, **kwargs):
if not hasattr(_call_with_supported_kwargs, '_cache'):
_call_with_supported_kwargs._cache = {}
sig = _call_with_supported_kwargs._cache.get(fn)
if sig is None:
sig = _call_with_supported_kwargs._cache[fn] = inspect.signature(fn)
if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values()):
kwargs = {key: value for key, value in kwargs.items() if key in sig.parameters}
return fn(*args, **kwargs)



def _supports_native_padding_free() -> bool:
    """Report whether the installed Transformers base version is at least 5.9.0.

    The comparison uses ``base_version`` of the installed release rather than
    the raw version string.

    Returns:
        ``True`` when the installed base version is ``5.9.0`` or newer,
        ``False`` otherwise.
    """
    installed_base = Version(Version(transformers_version).base_version)
    minimum_required = Version('5.9.0')
    return installed_base >= minimum_required


def _patch_gdn_kernels_for_cu_seqlens(
mod: torch.nn.Module,
*,
Expand Down Expand Up @@ -64,7 +78,7 @@ def chunk_gated_delta_rule_wrapper(query, key, value, **kwargs):
mod.causal_conv1d_fn = causal_conv1d_wrapper
mod.chunk_gated_delta_rule = chunk_gated_delta_rule_wrapper
try:
return origin_forward(mod, *forward_args, **forward_kwargs)
return _call_with_supported_kwargs(origin_forward, mod, *forward_args, **forward_kwargs)
finally:
mod.causal_conv1d_fn = old_conv_fn
mod.chunk_gated_delta_rule = old_chunk_rule
Expand All @@ -84,6 +98,8 @@ def __call__(self, module, *args, **kwargs):
if getattr(Qwen3_5GatedDeltaNet, '_twinkle_sp_linear_patched', False):
return
module._twinkle_gdn_padding_free_patched = True
if _supports_native_padding_free():
return

if not getattr(Qwen3_5DecoderLayer, '_twinkle_padding_free_cu_seqlens_patched', False):
origin_decoder_forward = Qwen3_5DecoderLayer.forward
Expand All @@ -99,7 +115,8 @@ def decoder_forward(
**extra_kwargs,
):
if getattr(layer, 'layer_type', None) != 'linear_attention':
return origin_decoder_forward(
return _call_with_supported_kwargs(
origin_decoder_forward,
layer,
hidden_states=hidden_states,
position_embeddings=position_embeddings,
Expand Down Expand Up @@ -147,7 +164,8 @@ def forward(
**extra_kwargs,
):
if cu_seq_lens_q is None:
return origin_forward(
return _call_with_supported_kwargs(
origin_forward,
mod,
hidden_states,
cache_params=cache_params,
Expand Down
Loading