diff --git a/lightllm/models/qwen3_omni_moe_thinker/audio_process.py b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py
index 194914d45..58b223d57 100644
--- a/lightllm/models/qwen3_omni_moe_thinker/audio_process.py
+++ b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py
@@ -102,7 +102,7 @@ def zero_mean_unit_var_norm(
     def _preprocess(
         self,
         raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
-        truncation: bool = True,
+        truncation: bool = False,
         pad_to_multiple_of: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         return_attention_mask: Optional[bool] = None,
diff --git a/lightllm/models/qwen3_omni_moe_thinker/model.py b/lightllm/models/qwen3_omni_moe_thinker/model.py
index 1b8fa0110..bee15e3d2 100644
--- a/lightllm/models/qwen3_omni_moe_thinker/model.py
+++ b/lightllm/models/qwen3_omni_moe_thinker/model.py
@@ -59,12 +59,7 @@ def init_audioitem_extral_params(
         return
 
     def get_audio_token_length(self, audio: AudioItem):
-        # 这里得处理对应奖语音长度按照 30 进行限制,后续处理中,超过30的会被截断。
-        if audio.audio_length > self.n_samples:
-            logger.warning(f"audio length {audio.audio_length} exceed max length {self.n_samples}, will be truncated.")
-
-        length = min(audio.audio_length, int(self.n_samples))
-        token_num = self._caclu_audio_token_num(length)
+        token_num = self._caclu_audio_token_num(audio.audio_length)
         # print(f"token_num is {token_num} n_samples is {self.n_samples} hop_length is {self.hop_length}")
         return token_num
 
diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py
index 03c57126f..ff49ab160 100644
--- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py
+++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py
@@ -10,10 +10,15 @@
 from transformers.activations import ACT2FN
 
 from lightllm.server.multimodal_params import AudioItem
+from lightllm.utils.log_utils import init_logger
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.models.qwen3_omni_moe_thinker.audio_process import WhisperFeatureExtractor
 
+QWEN3_OMNI_CONV_CHUNKSIZE = int(os.getenv("LIGHTLLM_QWEN3_OMNI_CONV_CHUNKSIZE", 500))
+
+logger = init_logger(__name__)
+
 
 def _get_feat_extract_output_lengths(input_lengths):
     """
@@ -156,7 +161,7 @@ def __init__(
         activation_function="gelu",
         output_dim=2048,
         n_window_infer=800,
-        conv_chunksize=500,
+        conv_chunksize=QWEN3_OMNI_CONV_CHUNKSIZE,
         encoder_attention_heads=20,
         attention_dropout=0,
         activation_dropout=0,
@@ -259,6 +264,7 @@ def load_model(self, weight_dir, config):
 
         self.load_state_dict(weight_dict)
 
+    @torch.inference_mode()
     def forward(
         self,
         input_features,
@@ -327,6 +333,7 @@
         hidden_states = self.proj2(hidden_states)
         return hidden_states
 
+    @torch.inference_mode()
     def encode(self, audio_items: List[AudioItem]):
         uuids = []
         items: List[AudioItem] = []
@@ -363,3 +370,23 @@ def encode(self, audio_items: List[AudioItem]):
             all_embeds.append(cur_embed)
 
         return all_embeds, audio_items
+
+    @torch.inference_mode()
+    def check_long_audio_infer(self):
+        """Exercise forward with mel length chosen so the conv loop runs once with batch dim == conv_chunksize."""
+        params = next(self.parameters())
+        device = params.device
+        dtype = params.dtype
+        frame_len = self.conv_chunksize * (self.n_window * 2)
+        logger.info(
+            "check_long_audio_infer: start frame_len=%s conv_chunksize=%s n_window=%s device=%s dtype=%s",
+            frame_len,
+            self.conv_chunksize,
+            self.n_window,
+            device,
+            dtype,
+        )
+        input_features = torch.zeros(self.num_mel_bins, frame_len, device=device, dtype=dtype)
+        feature_lens = torch.tensor([frame_len], device=device, dtype=torch.long)
+        out = self.forward(input_features, feature_lens=feature_lens)
+        logger.info("check_long_audio_infer: done output_shape=%s", tuple(out.shape))
diff --git a/lightllm/models/whisper/whisper_audio.py b/lightllm/models/whisper/whisper_audio.py
index aaa29e1c7..8a984d29a 100644
--- a/lightllm/models/whisper/whisper_audio.py
+++ b/lightllm/models/whisper/whisper_audio.py
@@ -223,3 +223,6 @@ def encode(self, audio_items: List[AudioItem]):
             ans_embeds.append(cur_embed)
 
         return ans_embeds, audio_items
+
+    def check_long_audio_infer(self):
+        pass
diff --git a/lightllm/server/audioserver/model_infer/model_rpc.py b/lightllm/server/audioserver/model_infer/model_rpc.py
index 39a7e06ac..82919856d 100644
--- a/lightllm/server/audioserver/model_infer/model_rpc.py
+++ b/lightllm/server/audioserver/model_infer/model_rpc.py
@@ -51,6 +51,7 @@ def exposed_init_model(self, kvargs):
 
         self.model.load_model(weight_dir, model_cfg)
         self.model = self.model.cuda()
+        self.model.check_long_audio_infer()
 
         self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
         self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)