Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 125 additions & 103 deletions src/services/realtime_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use axum::{
response::IntoResponse,
};
use base64::Engine;
use bytes::{BufMut, Bytes, BytesMut};
use bytes::{BufMut, BytesMut};
use futures_util::{
sink::SinkExt,
stream::{SplitStream, StreamExt},
Expand All @@ -15,13 +15,7 @@ use tokio::sync::mpsc;
use uuid::Uuid;

use crate::{
ai::{
ChatSession,
bailian::cosyvoice,
elevenlabs,
openai::realtime::*,
vad::{VadRealtimeClient, VadRealtimeEvent},
},
ai::{ChatSession, bailian::cosyvoice, elevenlabs, openai::realtime::*, vad::VadSession},
config::*,
};

Expand All @@ -44,11 +38,13 @@ pub struct RealtimeSession {
pub input_audio_buffer: BytesMut,
pub triggered: bool,
pub is_generating: bool,
pub vad_realtime_client: Option<VadRealtimeClient>,
pub vad_session: Option<VadSession>,
/// Cumulative audio duration in milliseconds (for 24kHz PCM16 input)
pub audio_position_ms: u32,
}

impl RealtimeSession {
pub fn new(chat_session: ChatSession) -> Self {
pub fn new(chat_session: ChatSession, vad_session: Option<VadSession>) -> Self {
Self {
client: reqwest::Client::new(),
chat_session,
Expand All @@ -58,7 +54,8 @@ impl RealtimeSession {
input_audio_buffer: BytesMut::new(),
triggered: false,
is_generating: false,
vad_realtime_client: None,
vad_session,
audio_position_ms: 0,
}
}
}
Expand All @@ -72,7 +69,6 @@ pub struct StableRealtimeConfig {

enum RealtimeEvent {
ClientEvent(ClientEvent),
VadEvent(VadRealtimeEvent),
}

pub async fn ws_handler(
Expand Down Expand Up @@ -100,24 +96,36 @@ async fn handle_socket(config: Arc<StableRealtimeConfig>, socket: WebSocket) {
chat_session.system_prompts = parts.sys_prompts;
chat_session.messages = parts.dynamic_prompts;

// 创建新的 Realtime 会话
let mut session = RealtimeSession::new(chat_session);
let mut realtime_rx: Option<_> = None;

if let Some(vad_realtime_url) = &config.asr.vad_realtime_url {
match crate::ai::vad::vad_realtime_client(&session.client, vad_realtime_url.clone()).await {
Ok((client, rx)) => {
session.vad_realtime_client = Some(client);
realtime_rx = Some(rx);
log::info!("Connected to VAD realtime service at {}", vad_realtime_url);
}
Err(e) => {
log::error!("Failed to connect to VAD realtime service: {}", e);
// Initialize built-in silero VAD session
let device = burn::backend::ndarray::NdArrayDevice::default();
let vad_session = match silero_vad_burn::SileroVAD6Model::new(&device) {
Ok(vad_model) => {
match crate::ai::vad::VadSession::new(&config.asr.vad, Box::new(vad_model), device) {
Ok(session) => {
log::info!("Initialized built-in silero VAD session");
Some(session)
}
Err(e) => {
log::error!("Failed to create VAD session: {}", e);
None
}
}
}
}
Err(e) => {
log::error!(
"Failed to load silero VAD model: {}. \
This may be due to missing model files or insufficient memory.",
e
);
None
}
};

let turn_detection = if realtime_rx.is_some() {
// 创建新的 Realtime 会话
let has_vad = vad_session.is_some();
let mut session = RealtimeSession::new(chat_session, vad_session);

let turn_detection = if has_vad {
TurnDetection::server_vad()
} else {
TurnDetection::none()
Expand Down Expand Up @@ -244,33 +252,10 @@ async fn handle_socket(config: Arc<StableRealtimeConfig>, socket: WebSocket) {
None
}

async fn select_event(
socket: &mut SplitStream<WebSocket>,
realtime_rx: &mut Option<crate::ai::vad::VadRealtimeRx>,
) -> Option<RealtimeEvent> {
if let Some(rx) = realtime_rx {
tokio::select! {
client_event = recv_client_event(socket) => {
client_event.map(RealtimeEvent::ClientEvent)
}
vad_event = rx.next_event() => {
match vad_event {
Ok(event) => Some(RealtimeEvent::VadEvent(event)),
Err(e) => {
log::error!("Failed to receive VAD event: {}", e);
None
}
}
}
}
} else {
recv_client_event(socket)
.await
.map(RealtimeEvent::ClientEvent)
}
}

while let Some(event) = select_event(&mut receiver, &mut realtime_rx).await {
while let Some(event) = recv_client_event(&mut receiver)
.await
.map(RealtimeEvent::ClientEvent)
{
if let Err(e) = handle_client_message(
event,
&mut session,
Expand Down Expand Up @@ -360,14 +345,14 @@ async fn handle_client_message(
return Ok(());
}
if turn_detection.turn_type == TurnDetectionType::ServerVad
&& session.vad_realtime_client.is_none()
&& session.vad_session.is_none()
{
let error_event = ServerEvent::Error {
event_id: Uuid::new_v4().to_string(),
error: ErrorDetails {
error_type: "invalid_request_error".to_string(),
code: Some("vad_realtime_not_connected".to_string()),
message: "VAD realtime service is not connected".to_string(),
code: Some("vad_not_available".to_string()),
message: "VAD session is not available".to_string(),
param: Some("turn_detection.type".to_string()),
event_id: None,
},
Expand Down Expand Up @@ -433,21 +418,25 @@ async fn handle_client_message(
.as_ref()
.map(|t| t.turn_type == TurnDetectionType::ServerVad)
.unwrap_or_default()
&& session.vad_realtime_client.is_some();
&& session.vad_session.is_some();

log::debug!(
"Server VAD status: {} {:?}",
session.vad_realtime_client.is_some(),
session.vad_session.is_some(),
session.config.turn_detection
);

// Calculate audio duration: 24kHz PCM16 = 48 bytes per ms
let chunk_duration_ms = (audio_data.len() / 48) as u32;
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The calculation assumes PCM16 (2 bytes per sample) at 24kHz, which would be 48,000 bytes per second or 48 bytes per millisecond. However, this only works for mono audio. If the audio is stereo, the calculation would be incorrect. The code should explicitly document the assumption about mono audio or handle both cases.

Copilot uses AI. Check for mistakes.

if !server_vad || session.triggered {
log::debug!(
"Appending audio chunk to input buffer, length: {}, server VAD: {}",
audio_data.len(),
server_vad
);
session.input_audio_buffer.extend_from_slice(&audio_data);
session.audio_position_ms += chunk_duration_ms;
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

audio_position_ms is incremented after the buffer is extended when triggered, and it is also incremented in the else branch (line 464). However, when the buffer is cleared (line 460), audio_position_ms is reset to 0 (line 461) and then incremented by the current chunk duration (line 464). This means the position doesn't account for audio discarded before speech detection, so the timing may be inaccurate if audio was buffered and then cleared.

Copilot uses AI. Check for mistakes.
} else {
log::debug!(
"Audio chunk received but not triggered, length: {}, server VAD: {}",
Expand All @@ -469,30 +458,63 @@ async fn handle_client_message(
session.input_audio_buffer.extend_from_slice(&audio_data);
} else {
session.input_audio_buffer.clear();
session.audio_position_ms = 0;
session.input_audio_buffer.extend_from_slice(&audio_data);
}
session.audio_position_ms += chunk_duration_ms;
}

// Process audio through built-in silero VAD
if server_vad {
let samples_24k = audio_data
.chunks_exact(2)
.map(|chunk| {
i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32
})
.collect::<Vec<f32>>();
let sample_16k = wav_io::resample::linear(samples_24k, 1, 24000, 16000);

let sample_16k = crate::util::convert_samples_f32_to_i16_bytes(&sample_16k);
log::debug!(
"Sending audio chunk to VAD realtime service, length: {}",
sample_16k.len()
);
session
.vad_realtime_client
.as_mut()
.unwrap()
.push_audio_16k_chunk(Bytes::from(sample_16k))
.await?;
if let Some(vad_session) = session.vad_session.as_mut() {
// Convert 24kHz PCM16 to 16kHz f32 for VAD
let samples_24k: Vec<f32> = audio_data
.chunks_exact(2)
.map(|chunk| {
i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32
})
.collect();
Comment on lines +471 to +476
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This audio conversion logic (24kHz PCM16 to f32) is duplicated in three locations: it appears twice within the same InputAudioBufferAppend handler (including here, lines 462-467) and once in the handle_audio_buffer_commit function (lines 677-680). Consider extracting this into a helper function to improve maintainability and reduce duplication.

Copilot uses AI. Check for mistakes.
Comment on lines +471 to +476
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PCM16-to-f32 conversion is duplicated in both the real-time processing (lines 471-476) and the commit validation (lines 689-692). Consider extracting this conversion logic into a helper function to reduce code duplication and improve maintainability.

Copilot uses AI. Check for mistakes.
let samples_16k =
wav_io::resample::linear(samples_24k, 1, 24000, 16000);

// Process through VAD in chunks, collecting state transitions
// We collect first to release the vad_session borrow before
// calling functions that need mutable access to session
let chunk_size = VadSession::vad_chunk_size();
let vad_events: Vec<bool> = samples_16k
.chunks(chunk_size)
.filter_map(|chunk| vad_session.detect(chunk).ok())
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using filter_map with .ok() silently ignores VAD detection errors. If errors occur during processing, they won't be logged or handled, potentially causing missed speech detection. Consider logging errors or maintaining error state instead of silently discarding them.

Suggested change
.filter_map(|chunk| vad_session.detect(chunk).ok())
.filter_map(|chunk| match vad_session.detect(chunk) {
Ok(is_speech) => Some(is_speech),
Err(e) => {
log::error!("VAD detection error: {}", e);
None
}
})

Copilot uses AI. Check for mistakes.
.collect();

// Process VAD events and handle state transitions
for is_speech in vad_events {
if is_speech && !session.triggered {
// Speech started
log::info!(
"VAD detected speech start at {}ms",
session.audio_position_ms
);
session.triggered = true;
let event = ServerEvent::InputAudioBufferSpeechStarted {
event_id: Uuid::new_v4().to_string(),
audio_start_ms: session.audio_position_ms,
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The audio_start_ms uses the current audio_position_ms, but this represents the end of the current chunk. For accurate speech start timing, it should use the position at the beginning of the chunk where speech was first detected. Consider subtracting chunk_duration_ms or tracking the position before processing the chunk.

Copilot uses AI. Check for mistakes.
item_id: Uuid::new_v4().to_string(),
};
let _ = tx.send(event).await;
} else if !is_speech && session.triggered {
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The state transition from speech to non-speech triggers immediately on the first non-speech chunk. This may cause premature cutoff if there are brief pauses during speech. Consider implementing a debounce mechanism or silence threshold (e.g., requiring multiple consecutive non-speech chunks) before triggering the speech end event.

Copilot uses AI. Check for mistakes.
// Speech ended - trigger commit
log::info!("VAD detected speech end, triggering commit");
if handle_audio_buffer_commit(session, tx, None, asr).await? {
generate_response(session, tx, tts).await?;
}
session.triggered = false;
if let Some(vs) = session.vad_session.as_mut() {
vs.reset_state();
}
// Continue processing - new speech may start in remaining chunks
}
}
}
}
}

Expand All @@ -505,6 +527,7 @@ async fn handle_client_message(

ClientEvent::InputAudioBufferClear { event_id: _ } => {
session.input_audio_buffer.clear();
session.audio_position_ms = 0;

let event = ServerEvent::InputAudioBufferCleared {
event_id: Uuid::new_v4().to_string(),
Expand Down Expand Up @@ -630,28 +653,6 @@ async fn handle_client_message(
}
}
}
RealtimeEvent::VadEvent(vad_realtime_event) => match vad_realtime_event {
VadRealtimeEvent::Event { event } => match event.as_str() {
"speech_start" => {
log::debug!("VAD speech start detected");
session.triggered = true;
}
"speech_end" => {
log::debug!("VAD speech end detected");
session.triggered = false;
if handle_audio_buffer_commit(session, tx, None, asr).await? {
log::debug!("Audio buffer committed, generating response");
generate_response(session, tx, tts).await?;
}
}
_ => {
log::warn!("Unhandled VAD event: {}", event);
}
},
VadRealtimeEvent::Error { message, .. } => {
return Err(anyhow::anyhow!("VAD error: {}", message));
}
},
}

Ok(())
Expand Down Expand Up @@ -682,9 +683,30 @@ async fn handle_audio_buffer_commit(
};
let _ = tx.send(committed_event).await;

if let Some(vad_url) = &config.vad_url {
let vad = crate::ai::vad::vad_detect(&session.client, vad_url, wav_audio.clone()).await?;
if vad.timestamps.is_empty() {
// Check for speech using built-in silero VAD
if let Some(vad_session) = session.vad_session.as_mut() {
// Convert 24kHz PCM16 to 16kHz f32 for VAD
let samples_24k: Vec<f32> = audio_data
.chunks_exact(2)
.map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32)
.collect();
let samples_16k = wav_io::resample::linear(samples_24k, 1, 24000, 16000);

// Process through VAD to check if there's any speech
let chunk_size = VadSession::vad_chunk_size();
let mut has_speech = false;
vad_session.reset_state();
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The VAD session state is reset before validating speech in the committed buffer. This reset occurs regardless of whether the real-time VAD is currently active (session.triggered). If speech was ongoing and commit is called, resetting the state could cause inconsistency between the real-time VAD state and the validation check. Consider only resetting when appropriate or maintaining separate VAD instances for real-time and validation.

Suggested change
vad_session.reset_state();
// Only reset VAD state when real-time VAD is not currently triggered
if !session.triggered {
vad_session.reset_state();
}

Copilot uses AI. Check for mistakes.
for chunk in samples_16k.chunks(chunk_size) {
if let Ok(is_speech) = vad_session.detect(chunk) {
if is_speech {
has_speech = true;
break;
Comment on lines +696 to +703
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The VAD processing logic (conversion of 24kHz PCM16 to 16kHz f32, chunking, and detection) is duplicated between the inline audio processing (lines 454-500) and the commit handler (lines 668-688). Consider extracting this into a helper function to improve maintainability and ensure consistent behavior.

Copilot uses AI. Check for mistakes.
}
}
}

if !has_speech {
log::debug!("No speech detected in audio buffer, skipping ASR");
let transcription_completed =
ServerEvent::ConversationItemInputAudioTranscriptionCompleted {
event_id: Uuid::new_v4().to_string(),
Expand Down