diff --git a/src/services/realtime_ws.rs b/src/services/realtime_ws.rs
index e648f0f..74878f4 100644
--- a/src/services/realtime_ws.rs
+++ b/src/services/realtime_ws.rs
@@ -5,7 +5,7 @@ use axum::{
     response::IntoResponse,
 };
 use base64::Engine;
-use bytes::{BufMut, Bytes, BytesMut};
+use bytes::{BufMut, BytesMut};
 use futures_util::{
     sink::SinkExt,
     stream::{SplitStream, StreamExt},
@@ -15,13 +15,7 @@ use tokio::sync::mpsc;
 use uuid::Uuid;
 
 use crate::{
-    ai::{
-        ChatSession,
-        bailian::cosyvoice,
-        elevenlabs,
-        openai::realtime::*,
-        vad::{VadRealtimeClient, VadRealtimeEvent},
-    },
+    ai::{ChatSession, bailian::cosyvoice, elevenlabs, openai::realtime::*, vad::VadSession},
     config::*,
 };
 
@@ -44,11 +38,13 @@ pub struct RealtimeSession {
     pub input_audio_buffer: BytesMut,
     pub triggered: bool,
     pub is_generating: bool,
-    pub vad_realtime_client: Option<VadRealtimeClient>,
+    pub vad_session: Option<VadSession>,
+    /// Cumulative audio duration in milliseconds (for 24kHz PCM16 input)
+    pub audio_position_ms: u32,
 }
 
 impl RealtimeSession {
-    pub fn new(chat_session: ChatSession) -> Self {
+    pub fn new(chat_session: ChatSession, vad_session: Option<VadSession>) -> Self {
         Self {
             client: reqwest::Client::new(),
             chat_session,
@@ -58,7 +54,8 @@ impl RealtimeSession {
             input_audio_buffer: BytesMut::new(),
             triggered: false,
             is_generating: false,
-            vad_realtime_client: None,
+            vad_session,
+            audio_position_ms: 0,
         }
     }
 }
@@ -72,7 +69,6 @@ pub struct StableRealtimeConfig {
 
 enum RealtimeEvent {
     ClientEvent(ClientEvent),
-    VadEvent(VadRealtimeEvent),
 }
 
 pub async fn ws_handler(
@@ -100,24 +96,36 @@ async fn handle_socket(config: Arc, socket: WebSocket) {
     chat_session.system_prompts = parts.sys_prompts;
     chat_session.messages = parts.dynamic_prompts;
 
-    // Create a new Realtime session
-    let mut session = RealtimeSession::new(chat_session);
-    let mut realtime_rx: Option<_> = None;
-
-    if let Some(vad_realtime_url) = &config.asr.vad_realtime_url {
-        match crate::ai::vad::vad_realtime_client(&session.client, vad_realtime_url.clone()).await {
-            Ok((client, rx)) => {
-                session.vad_realtime_client = Some(client);
-                realtime_rx = Some(rx);
-                log::info!("Connected to VAD realtime service at {}", vad_realtime_url);
-            }
-            Err(e) => {
-                log::error!("Failed to connect to VAD realtime service: {}", e);
+    // Initialize built-in silero VAD session
+    let device = burn::backend::ndarray::NdArrayDevice::default();
+    let vad_session = match silero_vad_burn::SileroVAD6Model::new(&device) {
+        Ok(vad_model) => {
+            match crate::ai::vad::VadSession::new(&config.asr.vad, Box::new(vad_model), device) {
+                Ok(session) => {
+                    log::info!("Initialized built-in silero VAD session");
+                    Some(session)
+                }
+                Err(e) => {
+                    log::error!("Failed to create VAD session: {}", e);
+                    None
+                }
             }
         }
-    }
+        Err(e) => {
+            log::error!(
+                "Failed to load silero VAD model: {}. \
+                 This may be due to missing model files or insufficient memory.",
+                e
+            );
+            None
+        }
+    };
 
-    let turn_detection = if realtime_rx.is_some() {
+    // Create a new Realtime session
+    let has_vad = vad_session.is_some();
+    let mut session = RealtimeSession::new(chat_session, vad_session);
+
+    let turn_detection = if has_vad {
         TurnDetection::server_vad()
     } else {
         TurnDetection::none()
@@ -244,33 +252,10 @@ async fn handle_socket(config: Arc, socket: WebSocket) {
         None
     }
 
-    async fn select_event(
-        socket: &mut SplitStream<WebSocket>,
-        realtime_rx: &mut Option,
-    ) -> Option<RealtimeEvent> {
-        if let Some(rx) = realtime_rx {
-            tokio::select! {
-                client_event = recv_client_event(socket) => {
-                    client_event.map(RealtimeEvent::ClientEvent)
-                }
-                vad_event = rx.next_event() => {
-                    match vad_event {
-                        Ok(event) => Some(RealtimeEvent::VadEvent(event)),
-                        Err(e) => {
-                            log::error!("Failed to receive VAD event: {}", e);
-                            None
-                        }
-                    }
-                }
-            }
-        } else {
-            recv_client_event(socket)
-                .await
-                .map(RealtimeEvent::ClientEvent)
-        }
-    }
-
-    while let Some(event) = select_event(&mut receiver, &mut realtime_rx).await {
+    while let Some(event) = recv_client_event(&mut receiver)
+        .await
+        .map(RealtimeEvent::ClientEvent)
+    {
         if let Err(e) = handle_client_message(
             event,
             &mut session,
@@ -360,14 +345,14 @@ async fn handle_client_message(
                     return Ok(());
                 }
                 if turn_detection.turn_type == TurnDetectionType::ServerVad
-                    && session.vad_realtime_client.is_none()
+                    && session.vad_session.is_none()
                 {
                     let error_event = ServerEvent::Error {
                         event_id: Uuid::new_v4().to_string(),
                         error: ErrorDetails {
                             error_type: "invalid_request_error".to_string(),
-                            code: Some("vad_realtime_not_connected".to_string()),
-                            message: "VAD realtime service is not connected".to_string(),
+                            code: Some("vad_not_available".to_string()),
+                            message: "VAD session is not available".to_string(),
                             param: Some("turn_detection.type".to_string()),
                             event_id: None,
                         },
@@ -433,14 +418,17 @@ async fn handle_client_message(
                     .as_ref()
                     .map(|t| t.turn_type == TurnDetectionType::ServerVad)
                     .unwrap_or_default()
-                    && session.vad_realtime_client.is_some();
+                    && session.vad_session.is_some();
 
                 log::debug!(
                     "Server VAD status: {} {:?}",
-                    session.vad_realtime_client.is_some(),
+                    session.vad_session.is_some(),
                     session.config.turn_detection
                 );
 
+                // Calculate audio duration: 24kHz PCM16 = 48 bytes per ms
+                let chunk_duration_ms = (audio_data.len() / 48) as u32;
+
                 if !server_vad || session.triggered {
                     log::debug!(
                         "Appending audio chunk to input buffer, length: {}, server VAD: {}",
@@ -448,6 +436,7 @@ async fn handle_client_message(
                         server_vad
                     );
                     session.input_audio_buffer.extend_from_slice(&audio_data);
+                    session.audio_position_ms += chunk_duration_ms;
                 } else {
                     log::debug!(
                         "Audio chunk received but not triggered, length: {}, server VAD: {}",
@@ -469,30 +458,63 @@ async fn handle_client_message(
                        session.input_audio_buffer.extend_from_slice(&audio_data);
                    } else {
                        session.input_audio_buffer.clear();
+                        session.audio_position_ms = 0;
                        session.input_audio_buffer.extend_from_slice(&audio_data);
                    }
+                    session.audio_position_ms += chunk_duration_ms;
                }
 
+                // Process audio through built-in silero VAD
                if server_vad {
-                    let samples_24k = audio_data
-                        .chunks_exact(2)
-                        .map(|chunk| {
-                            i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32
-                        })
-                        .collect::<Vec<f32>>();
-                    let sample_16k = wav_io::resample::linear(samples_24k, 1, 24000, 16000);
-
-                    let sample_16k = crate::util::convert_samples_f32_to_i16_bytes(&sample_16k);
-                    log::debug!(
-                        "Sending audio chunk to VAD realtime service, length: {}",
-                        sample_16k.len()
-                    );
-                    session
-                        .vad_realtime_client
-                        .as_mut()
-                        .unwrap()
-                        .push_audio_16k_chunk(Bytes::from(sample_16k))
-                        .await?;
+                    if let Some(vad_session) = session.vad_session.as_mut() {
+                        // Convert 24kHz PCM16 to 16kHz f32 for VAD
+                        let samples_24k: Vec<f32> = audio_data
+                            .chunks_exact(2)
+                            .map(|chunk| {
+                                i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32
+                            })
+                            .collect();
+                        let samples_16k =
+                            wav_io::resample::linear(samples_24k, 1, 24000, 16000);
+
+                        // Process through VAD in chunks, collecting state transitions
+                        // We collect first to release the vad_session borrow before
+                        // calling functions that need mutable access to session
+                        let chunk_size = VadSession::vad_chunk_size();
+                        let vad_events: Vec<bool> = samples_16k
+                            .chunks(chunk_size)
+                            .filter_map(|chunk| vad_session.detect(chunk).ok())
+                            .collect();
+
+                        // Process VAD events and handle state transitions
+                        for is_speech in vad_events {
+                            if is_speech && !session.triggered {
+                                // Speech started
+                                log::info!(
+                                    "VAD detected speech start at {}ms",
+                                    session.audio_position_ms
+                                );
+                                session.triggered = true;
+                                let event = ServerEvent::InputAudioBufferSpeechStarted {
+                                    event_id: Uuid::new_v4().to_string(),
+                                    audio_start_ms: session.audio_position_ms,
+                                    item_id: Uuid::new_v4().to_string(),
+                                };
+                                let _ = tx.send(event).await;
+                            } else if !is_speech && session.triggered {
+                                // Speech ended - trigger commit
+                                log::info!("VAD detected speech end, triggering commit");
+                                if handle_audio_buffer_commit(session, tx, None, asr).await? {
+                                    generate_response(session, tx, tts).await?;
+                                }
+                                session.triggered = false;
+                                if let Some(vs) = session.vad_session.as_mut() {
+                                    vs.reset_state();
+                                }
+                                // Continue processing - new speech may start in remaining chunks
                            }
+                        }
+                    }
                }
            }
 
@@ -505,6 +527,7 @@ async fn handle_client_message(
 
            ClientEvent::InputAudioBufferClear { event_id: _ } => {
                session.input_audio_buffer.clear();
+                session.audio_position_ms = 0;
 
                let event = ServerEvent::InputAudioBufferCleared {
                    event_id: Uuid::new_v4().to_string(),
@@ -630,28 +653,6 @@ async fn handle_client_message(
                }
            }
        }
-        RealtimeEvent::VadEvent(vad_realtime_event) => match vad_realtime_event {
-            VadRealtimeEvent::Event { event } => match event.as_str() {
-                "speech_start" => {
-                    log::debug!("VAD speech start detected");
-                    session.triggered = true;
-                }
-                "speech_end" => {
-                    log::debug!("VAD speech end detected");
-                    session.triggered = false;
-                    if handle_audio_buffer_commit(session, tx, None, asr).await? {
-                        log::debug!("Audio buffer committed, generating response");
-                        generate_response(session, tx, tts).await?;
-                    }
-                }
-                _ => {
-                    log::warn!("Unhandled VAD event: {}", event);
-                }
-            },
-            VadRealtimeEvent::Error { message, .. } => {
-                return Err(anyhow::anyhow!("VAD error: {}", message));
-            }
-        },
    }
 
    Ok(())
@@ -682,9 +683,30 @@ async fn handle_audio_buffer_commit(
    };
    let _ = tx.send(committed_event).await;
 
-    if let Some(vad_url) = &config.vad_url {
-        let vad = crate::ai::vad::vad_detect(&session.client, vad_url, wav_audio.clone()).await?;
-        if vad.timestamps.is_empty() {
+    // Check for speech using built-in silero VAD
+    if let Some(vad_session) = session.vad_session.as_mut() {
+        // Convert 24kHz PCM16 to 16kHz f32 for VAD
+        let samples_24k: Vec<f32> = audio_data
+            .chunks_exact(2)
+            .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32)
+            .collect();
+        let samples_16k = wav_io::resample::linear(samples_24k, 1, 24000, 16000);
+
+        // Process through VAD to check if there's any speech
+        let chunk_size = VadSession::vad_chunk_size();
+        let mut has_speech = false;
+        vad_session.reset_state();
+        for chunk in samples_16k.chunks(chunk_size) {
+            if let Ok(is_speech) = vad_session.detect(chunk) {
+                if is_speech {
+                    has_speech = true;
+                    break;
+                }
+            }
+        }
+
+        if !has_speech {
+            log::debug!("No speech detected in audio buffer, skipping ASR");
            let transcription_completed =
                ServerEvent::ConversationItemInputAudioTranscriptionCompleted {
                    event_id: Uuid::new_v4().to_string(),