-
Notifications
You must be signed in to change notification settings - Fork 74
Refactor realtime_ws to use built-in silero VAD #39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6f01e93
d4e74eb
16ac56d
53c226e
d425b33
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,7 +5,7 @@ use axum::{ | |||||||||||||||||
| response::IntoResponse, | ||||||||||||||||||
| }; | ||||||||||||||||||
| use base64::Engine; | ||||||||||||||||||
| use bytes::{BufMut, Bytes, BytesMut}; | ||||||||||||||||||
| use bytes::{BufMut, BytesMut}; | ||||||||||||||||||
| use futures_util::{ | ||||||||||||||||||
| sink::SinkExt, | ||||||||||||||||||
| stream::{SplitStream, StreamExt}, | ||||||||||||||||||
|
|
@@ -15,13 +15,7 @@ use tokio::sync::mpsc; | |||||||||||||||||
| use uuid::Uuid; | ||||||||||||||||||
|
|
||||||||||||||||||
| use crate::{ | ||||||||||||||||||
| ai::{ | ||||||||||||||||||
| ChatSession, | ||||||||||||||||||
| bailian::cosyvoice, | ||||||||||||||||||
| elevenlabs, | ||||||||||||||||||
| openai::realtime::*, | ||||||||||||||||||
| vad::{VadRealtimeClient, VadRealtimeEvent}, | ||||||||||||||||||
| }, | ||||||||||||||||||
| ai::{ChatSession, bailian::cosyvoice, elevenlabs, openai::realtime::*, vad::VadSession}, | ||||||||||||||||||
| config::*, | ||||||||||||||||||
| }; | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -44,11 +38,13 @@ pub struct RealtimeSession { | |||||||||||||||||
| pub input_audio_buffer: BytesMut, | ||||||||||||||||||
| pub triggered: bool, | ||||||||||||||||||
| pub is_generating: bool, | ||||||||||||||||||
| pub vad_realtime_client: Option<VadRealtimeClient>, | ||||||||||||||||||
| pub vad_session: Option<VadSession>, | ||||||||||||||||||
| /// Cumulative audio duration in milliseconds (for 24kHz PCM16 input) | ||||||||||||||||||
| pub audio_position_ms: u32, | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| impl RealtimeSession { | ||||||||||||||||||
| pub fn new(chat_session: ChatSession) -> Self { | ||||||||||||||||||
| pub fn new(chat_session: ChatSession, vad_session: Option<VadSession>) -> Self { | ||||||||||||||||||
| Self { | ||||||||||||||||||
| client: reqwest::Client::new(), | ||||||||||||||||||
| chat_session, | ||||||||||||||||||
|
|
@@ -58,7 +54,8 @@ impl RealtimeSession { | |||||||||||||||||
| input_audio_buffer: BytesMut::new(), | ||||||||||||||||||
| triggered: false, | ||||||||||||||||||
| is_generating: false, | ||||||||||||||||||
| vad_realtime_client: None, | ||||||||||||||||||
| vad_session, | ||||||||||||||||||
| audio_position_ms: 0, | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
@@ -72,7 +69,6 @@ pub struct StableRealtimeConfig { | |||||||||||||||||
|
|
||||||||||||||||||
| enum RealtimeEvent { | ||||||||||||||||||
| ClientEvent(ClientEvent), | ||||||||||||||||||
| VadEvent(VadRealtimeEvent), | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| pub async fn ws_handler( | ||||||||||||||||||
|
|
@@ -100,24 +96,36 @@ async fn handle_socket(config: Arc<StableRealtimeConfig>, socket: WebSocket) { | |||||||||||||||||
| chat_session.system_prompts = parts.sys_prompts; | ||||||||||||||||||
| chat_session.messages = parts.dynamic_prompts; | ||||||||||||||||||
|
|
||||||||||||||||||
| // 创建新的 Realtime 会话 | ||||||||||||||||||
| let mut session = RealtimeSession::new(chat_session); | ||||||||||||||||||
| let mut realtime_rx: Option<_> = None; | ||||||||||||||||||
|
|
||||||||||||||||||
| if let Some(vad_realtime_url) = &config.asr.vad_realtime_url { | ||||||||||||||||||
| match crate::ai::vad::vad_realtime_client(&session.client, vad_realtime_url.clone()).await { | ||||||||||||||||||
| Ok((client, rx)) => { | ||||||||||||||||||
| session.vad_realtime_client = Some(client); | ||||||||||||||||||
| realtime_rx = Some(rx); | ||||||||||||||||||
| log::info!("Connected to VAD realtime service at {}", vad_realtime_url); | ||||||||||||||||||
| } | ||||||||||||||||||
| Err(e) => { | ||||||||||||||||||
| log::error!("Failed to connect to VAD realtime service: {}", e); | ||||||||||||||||||
| // Initialize built-in silero VAD session | ||||||||||||||||||
| let device = burn::backend::ndarray::NdArrayDevice::default(); | ||||||||||||||||||
| let vad_session = match silero_vad_burn::SileroVAD6Model::new(&device) { | ||||||||||||||||||
| Ok(vad_model) => { | ||||||||||||||||||
| match crate::ai::vad::VadSession::new(&config.asr.vad, Box::new(vad_model), device) { | ||||||||||||||||||
| Ok(session) => { | ||||||||||||||||||
| log::info!("Initialized built-in silero VAD session"); | ||||||||||||||||||
| Some(session) | ||||||||||||||||||
| } | ||||||||||||||||||
| Err(e) => { | ||||||||||||||||||
| log::error!("Failed to create VAD session: {}", e); | ||||||||||||||||||
| None | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
| Err(e) => { | ||||||||||||||||||
| log::error!( | ||||||||||||||||||
| "Failed to load silero VAD model: {}. \ | ||||||||||||||||||
| This may be due to missing model files or insufficient memory.", | ||||||||||||||||||
| e | ||||||||||||||||||
| ); | ||||||||||||||||||
| None | ||||||||||||||||||
| } | ||||||||||||||||||
| }; | ||||||||||||||||||
|
|
||||||||||||||||||
| let turn_detection = if realtime_rx.is_some() { | ||||||||||||||||||
| // 创建新的 Realtime 会话 | ||||||||||||||||||
| let has_vad = vad_session.is_some(); | ||||||||||||||||||
| let mut session = RealtimeSession::new(chat_session, vad_session); | ||||||||||||||||||
|
|
||||||||||||||||||
| let turn_detection = if has_vad { | ||||||||||||||||||
| TurnDetection::server_vad() | ||||||||||||||||||
| } else { | ||||||||||||||||||
| TurnDetection::none() | ||||||||||||||||||
|
|
@@ -244,33 +252,10 @@ async fn handle_socket(config: Arc<StableRealtimeConfig>, socket: WebSocket) { | |||||||||||||||||
| None | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| async fn select_event( | ||||||||||||||||||
| socket: &mut SplitStream<WebSocket>, | ||||||||||||||||||
| realtime_rx: &mut Option<crate::ai::vad::VadRealtimeRx>, | ||||||||||||||||||
| ) -> Option<RealtimeEvent> { | ||||||||||||||||||
| if let Some(rx) = realtime_rx { | ||||||||||||||||||
| tokio::select! { | ||||||||||||||||||
| client_event = recv_client_event(socket) => { | ||||||||||||||||||
| client_event.map(RealtimeEvent::ClientEvent) | ||||||||||||||||||
| } | ||||||||||||||||||
| vad_event = rx.next_event() => { | ||||||||||||||||||
| match vad_event { | ||||||||||||||||||
| Ok(event) => Some(RealtimeEvent::VadEvent(event)), | ||||||||||||||||||
| Err(e) => { | ||||||||||||||||||
| log::error!("Failed to receive VAD event: {}", e); | ||||||||||||||||||
| None | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
| } else { | ||||||||||||||||||
| recv_client_event(socket) | ||||||||||||||||||
| .await | ||||||||||||||||||
| .map(RealtimeEvent::ClientEvent) | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| while let Some(event) = select_event(&mut receiver, &mut realtime_rx).await { | ||||||||||||||||||
| while let Some(event) = recv_client_event(&mut receiver) | ||||||||||||||||||
| .await | ||||||||||||||||||
| .map(RealtimeEvent::ClientEvent) | ||||||||||||||||||
| { | ||||||||||||||||||
| if let Err(e) = handle_client_message( | ||||||||||||||||||
| event, | ||||||||||||||||||
| &mut session, | ||||||||||||||||||
|
|
@@ -360,14 +345,14 @@ async fn handle_client_message( | |||||||||||||||||
| return Ok(()); | ||||||||||||||||||
| } | ||||||||||||||||||
| if turn_detection.turn_type == TurnDetectionType::ServerVad | ||||||||||||||||||
| && session.vad_realtime_client.is_none() | ||||||||||||||||||
| && session.vad_session.is_none() | ||||||||||||||||||
| { | ||||||||||||||||||
| let error_event = ServerEvent::Error { | ||||||||||||||||||
| event_id: Uuid::new_v4().to_string(), | ||||||||||||||||||
| error: ErrorDetails { | ||||||||||||||||||
| error_type: "invalid_request_error".to_string(), | ||||||||||||||||||
| code: Some("vad_realtime_not_connected".to_string()), | ||||||||||||||||||
| message: "VAD realtime service is not connected".to_string(), | ||||||||||||||||||
| code: Some("vad_not_available".to_string()), | ||||||||||||||||||
| message: "VAD session is not available".to_string(), | ||||||||||||||||||
| param: Some("turn_detection.type".to_string()), | ||||||||||||||||||
| event_id: None, | ||||||||||||||||||
| }, | ||||||||||||||||||
|
|
@@ -433,21 +418,25 @@ async fn handle_client_message( | |||||||||||||||||
| .as_ref() | ||||||||||||||||||
| .map(|t| t.turn_type == TurnDetectionType::ServerVad) | ||||||||||||||||||
| .unwrap_or_default() | ||||||||||||||||||
| && session.vad_realtime_client.is_some(); | ||||||||||||||||||
| && session.vad_session.is_some(); | ||||||||||||||||||
|
|
||||||||||||||||||
| log::debug!( | ||||||||||||||||||
| "Server VAD status: {} {:?}", | ||||||||||||||||||
| session.vad_realtime_client.is_some(), | ||||||||||||||||||
| session.vad_session.is_some(), | ||||||||||||||||||
| session.config.turn_detection | ||||||||||||||||||
| ); | ||||||||||||||||||
|
|
||||||||||||||||||
| // Calculate audio duration: 24kHz PCM16 = 48 bytes per ms | ||||||||||||||||||
| let chunk_duration_ms = (audio_data.len() / 48) as u32; | ||||||||||||||||||
|
|
||||||||||||||||||
| if !server_vad || session.triggered { | ||||||||||||||||||
| log::debug!( | ||||||||||||||||||
| "Appending audio chunk to input buffer, length: {}, server VAD: {}", | ||||||||||||||||||
| audio_data.len(), | ||||||||||||||||||
| server_vad | ||||||||||||||||||
| ); | ||||||||||||||||||
| session.input_audio_buffer.extend_from_slice(&audio_data); | ||||||||||||||||||
| session.audio_position_ms += chunk_duration_ms; | ||||||||||||||||||
|
||||||||||||||||||
| } else { | ||||||||||||||||||
| log::debug!( | ||||||||||||||||||
| "Audio chunk received but not triggered, length: {}, server VAD: {}", | ||||||||||||||||||
|
|
@@ -469,30 +458,63 @@ async fn handle_client_message( | |||||||||||||||||
| session.input_audio_buffer.extend_from_slice(&audio_data); | ||||||||||||||||||
| } else { | ||||||||||||||||||
| session.input_audio_buffer.clear(); | ||||||||||||||||||
| session.audio_position_ms = 0; | ||||||||||||||||||
| session.input_audio_buffer.extend_from_slice(&audio_data); | ||||||||||||||||||
| } | ||||||||||||||||||
| session.audio_position_ms += chunk_duration_ms; | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| // Process audio through built-in silero VAD | ||||||||||||||||||
| if server_vad { | ||||||||||||||||||
| let samples_24k = audio_data | ||||||||||||||||||
| .chunks_exact(2) | ||||||||||||||||||
| .map(|chunk| { | ||||||||||||||||||
| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32 | ||||||||||||||||||
| }) | ||||||||||||||||||
| .collect::<Vec<f32>>(); | ||||||||||||||||||
| let sample_16k = wav_io::resample::linear(samples_24k, 1, 24000, 16000); | ||||||||||||||||||
|
|
||||||||||||||||||
| let sample_16k = crate::util::convert_samples_f32_to_i16_bytes(&sample_16k); | ||||||||||||||||||
| log::debug!( | ||||||||||||||||||
| "Sending audio chunk to VAD realtime service, length: {}", | ||||||||||||||||||
| sample_16k.len() | ||||||||||||||||||
| ); | ||||||||||||||||||
| session | ||||||||||||||||||
| .vad_realtime_client | ||||||||||||||||||
| .as_mut() | ||||||||||||||||||
| .unwrap() | ||||||||||||||||||
| .push_audio_16k_chunk(Bytes::from(sample_16k)) | ||||||||||||||||||
| .await?; | ||||||||||||||||||
| if let Some(vad_session) = session.vad_session.as_mut() { | ||||||||||||||||||
| // Convert 24kHz PCM16 to 16kHz f32 for VAD | ||||||||||||||||||
| let samples_24k: Vec<f32> = audio_data | ||||||||||||||||||
| .chunks_exact(2) | ||||||||||||||||||
| .map(|chunk| { | ||||||||||||||||||
| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32 | ||||||||||||||||||
| }) | ||||||||||||||||||
| .collect(); | ||||||||||||||||||
|
Comment on lines
+471
to
+476
|
||||||||||||||||||
| let samples_16k = | ||||||||||||||||||
| wav_io::resample::linear(samples_24k, 1, 24000, 16000); | ||||||||||||||||||
|
|
||||||||||||||||||
| // Process through VAD in chunks, collecting state transitions | ||||||||||||||||||
| // We collect first to release the vad_session borrow before | ||||||||||||||||||
| // calling functions that need mutable access to session | ||||||||||||||||||
| let chunk_size = VadSession::vad_chunk_size(); | ||||||||||||||||||
| let vad_events: Vec<bool> = samples_16k | ||||||||||||||||||
| .chunks(chunk_size) | ||||||||||||||||||
| .filter_map(|chunk| vad_session.detect(chunk).ok()) | ||||||||||||||||||
|
||||||||||||||||||
| .filter_map(|chunk| vad_session.detect(chunk).ok()) | |
| .filter_map(|chunk| match vad_session.detect(chunk) { | |
| Ok(is_speech) => Some(is_speech), | |
| Err(e) => { | |
| log::error!("VAD detection error: {}", e); | |
| None | |
| } | |
| }) |
Copilot
AI
Jan 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The audio_start_ms uses the current audio_position_ms, but this represents the end of the current chunk. For accurate speech start timing, it should use the position at the beginning of the chunk where speech was first detected. Consider subtracting chunk_duration_ms or tracking the position before processing the chunk.
Copilot
AI
Jan 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The state transition from speech to non-speech triggers immediately on the first non-speech chunk. This may cause premature cutoff if there are brief pauses during speech. Consider implementing a debounce mechanism or silence threshold (e.g., requiring multiple consecutive non-speech chunks) before triggering the speech end event.
Copilot
AI
Jan 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The VAD session state is reset before validating speech in the committed buffer. This reset occurs regardless of whether the real-time VAD is currently active (session.triggered). If speech was ongoing and commit is called, resetting the state could cause inconsistency between the real-time VAD state and the validation check. Consider only resetting when appropriate or maintaining separate VAD instances for real-time and validation.
| vad_session.reset_state(); | |
| // Only reset VAD state when real-time VAD is not currently triggered | |
| if !session.triggered { | |
| vad_session.reset_state(); | |
| } |
Copilot
AI
Jan 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The VAD processing logic (conversion of 24kHz PCM16 to 16kHz f32, chunking, and detection) is duplicated between the inline audio processing (lines 454-500) and the commit handler (lines 668-688). Consider extracting this into a helper function to improve maintainability and ensure consistent behavior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The calculation assumes PCM16 (2 bytes per sample) at 24kHz, which would be 48,000 bytes per second or 48 bytes per millisecond. However, this only works for mono audio. If the audio is stereo, the calculation would be incorrect. The code should explicitly document the assumption about mono audio or handle both cases.