-
Notifications
You must be signed in to change notification settings - Fork 269
Add Adaptive Interruption (bargein) #1002
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4e31b27
f3c7430
07c5d71
f1a2114
c861f50
1862dc3
705ed33
8f53889
7d24bf0
c78cf58
d5b271c
b020180
dbad1e4
d882012
67e8f6c
96d6b57
2ee2748
62cd448
ec6d9bd
016e3a4
9a4939c
e28b1b1
4310baa
175e57b
0682f25
a820521
e175e7d
9cb0a29
5f088b9
6fbc417
9f2932d
b4a82ad
f2ac83a
ec26bb1
b5c541f
76bd4e8
245bc66
ea27278
63eccca
5bc7108
171fb98
bb93420
717e908
b935b0d
74e42fa
eaafc14
8d14806
cfe0362
ff277d5
b0dbbf5
1caed4f
7138bc9
dad12f8
9ae1e9a
f58454b
dbff0f3
44684ac
66a6b85
7ef1ecf
0956efc
db163f3
1cf1aa4
7cd3e9d
6450210
85068ac
e67caa8
f9739dd
c98313e
4458b81
fe1a82f
b8cad33
5371a6b
65391bc
ee5b957
ac6d18b
7a963de
696e89a
da585e0
98298c5
d2dab3d
aff175c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| --- | ||
| '@livekit/agents': minor | ||
| --- | ||
|
|
||
| - Add adaptive interruption handling | ||
| - Add remote session event handler |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| // SPDX-FileCopyrightText: 2026 LiveKit, Inc. | ||
| // | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| import type { ApiConnectOptions } from './interruption_stream.js'; | ||
| import type { InterruptionOptions } from './types.js'; | ||
|
|
||
| export const MIN_INTERRUPTION_DURATION_IN_S = 0.025 * 2; // 25ms per frame, 2 consecutive frames | ||
| export const THRESHOLD = 0.5; | ||
| export const MAX_AUDIO_DURATION_IN_S = 3.0; | ||
| export const AUDIO_PREFIX_DURATION_IN_S = 0.5; | ||
| export const DETECTION_INTERVAL_IN_S = 0.1; | ||
| export const REMOTE_INFERENCE_TIMEOUT_IN_S = 0.7; | ||
| export const SAMPLE_RATE = 16000; | ||
| export const FRAMES_PER_SECOND = 40; | ||
| export const FRAME_DURATION_IN_S = 0.025; // 25ms per frame | ||
|
|
||
| export const apiConnectDefaults: ApiConnectOptions = { | ||
| maxRetries: 3, | ||
| retryInterval: 2_000, | ||
| timeout: 10_000, | ||
| } as const; | ||
|
|
||
| /** | ||
| * Calculate the retry interval using exponential backoff with jitter. | ||
| * Matches the Python implementation's _interval_for_retry behavior. | ||
| */ | ||
| export function intervalForRetry( | ||
| attempt: number, | ||
| baseInterval: number = apiConnectDefaults.retryInterval, | ||
| ): number { | ||
| // Exponential backoff: baseInterval * 2^attempt with some jitter | ||
| const exponentialDelay = baseInterval * Math.pow(2, attempt); | ||
| // Add jitter (0-25% of the delay) | ||
| const jitter = exponentialDelay * Math.random() * 0.25; | ||
| return exponentialDelay + jitter; | ||
| } | ||
|
|
||
| // baseUrl and useProxy are resolved dynamically in the constructor | ||
| // to respect LIVEKIT_REMOTE_EOT_URL environment variable | ||
| export const interruptionOptionDefaults: Omit<InterruptionOptions, 'baseUrl' | 'useProxy'> = { | ||
| sampleRate: SAMPLE_RATE, | ||
| threshold: THRESHOLD, | ||
| minFrames: Math.ceil(MIN_INTERRUPTION_DURATION_IN_S * FRAMES_PER_SECOND), | ||
| maxAudioDurationInS: MAX_AUDIO_DURATION_IN_S, | ||
| audioPrefixDurationInS: AUDIO_PREFIX_DURATION_IN_S, | ||
| detectionIntervalInS: DETECTION_INTERVAL_IN_S, | ||
| inferenceTimeout: REMOTE_INFERENCE_TIMEOUT_IN_S * 1_000, | ||
| apiKey: process.env.LIVEKIT_API_KEY || '', | ||
| apiSecret: process.env.LIVEKIT_API_SECRET || '', | ||
|
Comment on lines
+48
to
+49
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 In lkApiKey = apiKey ?? process.env.LIVEKIT_INFERENCE_API_KEY ?? process.env.LIVEKIT_API_KEY ?? '';Since Was this helpful? React with 👍 or 👎 to provide feedback. |
||
| minInterruptionDurationInS: MIN_INTERRUPTION_DURATION_IN_S, | ||
|
toubatbrian marked this conversation as resolved.
|
||
| } as const; | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| // SPDX-FileCopyrightText: 2026 LiveKit, Inc. | ||
| // | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| /** | ||
| * Error thrown during interruption detection. | ||
| */ | ||
| export class InterruptionDetectionError extends Error { | ||
| readonly type = 'interruption_detection_error' as const; | ||
|
|
||
| readonly timestamp: number; | ||
| readonly label: string; | ||
| readonly recoverable: boolean; | ||
|
|
||
| constructor(message: string, timestamp: number, label: string, recoverable: boolean) { | ||
| super(message); | ||
| this.name = 'InterruptionDetectionError'; | ||
| this.timestamp = timestamp; | ||
| this.label = label; | ||
| this.recoverable = recoverable; | ||
| } | ||
|
|
||
| toString(): string { | ||
| return `${this.name}: ${this.message} (label=${this.label}, timestamp=${this.timestamp}, recoverable=${this.recoverable})`; | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,206 @@ | ||
| // SPDX-FileCopyrightText: 2026 LiveKit, Inc. | ||
| // | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| import { FetchError, ofetch } from 'ofetch'; | ||
| import { TransformStream } from 'stream/web'; | ||
| import { z } from 'zod'; | ||
| import { APIConnectionError, APIError, APIStatusError, isAPIError } from '../../_exceptions.js'; | ||
| import { log } from '../../log.js'; | ||
| import { createAccessToken } from '../utils.js'; | ||
| import { InterruptionCacheEntry } from './interruption_cache_entry.js'; | ||
| import type { OverlappingSpeechEvent } from './types.js'; | ||
| import type { BoundedCache } from './utils.js'; | ||
|
|
||
| export interface PostOptions { | ||
| baseUrl: string; | ||
| token: string; | ||
| signal?: AbortSignal; | ||
| timeout?: number; | ||
| maxRetries?: number; | ||
| } | ||
|
|
||
| export interface PredictOptions { | ||
| threshold: number; | ||
| minFrames: number; | ||
| } | ||
|
|
||
| export const predictEndpointResponseSchema = z.object({ | ||
| created_at: z.number(), | ||
| is_bargein: z.boolean(), | ||
| probabilities: z.array(z.number()), | ||
| }); | ||
|
|
||
| export type PredictEndpointResponse = z.infer<typeof predictEndpointResponseSchema>; | ||
|
|
||
| export interface PredictResponse { | ||
| createdAt: number; | ||
| isBargein: boolean; | ||
| probabilities: number[]; | ||
| predictionDurationInS: number; | ||
| } | ||
|
|
||
| export async function predictHTTP( | ||
| data: Int16Array, | ||
| predictOptions: PredictOptions, | ||
| options: PostOptions, | ||
| ): Promise<PredictResponse> { | ||
| const createdAt = performance.now(); | ||
| const url = new URL(`/bargein`, options.baseUrl); | ||
| url.searchParams.append('threshold', predictOptions.threshold.toString()); | ||
| url.searchParams.append('min_frames', predictOptions.minFrames.toFixed()); | ||
| url.searchParams.append('created_at', createdAt.toFixed()); | ||
|
|
||
| try { | ||
| const response = await ofetch(url.toString(), { | ||
| retry: 0, | ||
| headers: { | ||
| 'Content-Type': 'application/octet-stream', | ||
| Authorization: `Bearer ${options.token}`, | ||
| }, | ||
| signal: options.signal, | ||
| timeout: options.timeout, | ||
| method: 'POST', | ||
| body: data, | ||
| }); | ||
| const { created_at, is_bargein, probabilities } = predictEndpointResponseSchema.parse(response); | ||
|
|
||
| return { | ||
| createdAt: created_at, | ||
| isBargein: is_bargein, | ||
| probabilities, | ||
| predictionDurationInS: (performance.now() - createdAt) / 1000, | ||
| }; | ||
| } catch (err) { | ||
| if (isAPIError(err)) throw err; | ||
| if (err instanceof FetchError) { | ||
| if (err.statusCode) { | ||
| throw new APIStatusError({ | ||
| message: `error during interruption prediction: ${err.message}`, | ||
| options: { statusCode: err.statusCode, body: err.data }, | ||
| }); | ||
| } | ||
| if ( | ||
| err.cause instanceof Error && | ||
| (err.cause.name === 'TimeoutError' || err.cause.name === 'AbortError') | ||
| ) { | ||
| throw new APIStatusError({ | ||
| message: `interruption inference timeout: ${err.message}`, | ||
| options: { statusCode: 408, retryable: false }, | ||
| }); | ||
| } | ||
| throw new APIConnectionError({ | ||
| message: `interruption inference connection error: ${err.message}`, | ||
| }); | ||
| } | ||
| throw new APIError(`error during interruption prediction: ${err}`); | ||
| } | ||
| } | ||
|
|
||
| export interface HttpTransportOptions { | ||
| baseUrl: string; | ||
| apiKey: string; | ||
| apiSecret: string; | ||
| threshold: number; | ||
| minFrames: number; | ||
| timeout: number; | ||
| maxRetries?: number; | ||
| } | ||
|
|
||
| export interface HttpTransportState { | ||
| overlapSpeechStarted: boolean; | ||
| overlapSpeechStartedAt: number | undefined; | ||
| cache: BoundedCache<number, InterruptionCacheEntry>; | ||
| } | ||
|
|
||
| /** | ||
| * Creates an HTTP transport TransformStream for interruption detection. | ||
| * | ||
| * This transport receives Int16Array audio slices and outputs InterruptionEvents. | ||
| * Each audio slice triggers an HTTP POST request. | ||
| * | ||
| * @param options - Transport options object. This is read on each request, so mutations | ||
| * to threshold/minFrames will be picked up dynamically. | ||
| */ | ||
| export function createHttpTransport( | ||
| options: HttpTransportOptions, | ||
| getState: () => HttpTransportState, | ||
| setState: (partial: Partial<HttpTransportState>) => void, | ||
| updateUserSpeakingSpan?: (entry: InterruptionCacheEntry) => void, | ||
| getAndResetNumRequests?: () => number, | ||
| ): TransformStream<Int16Array | OverlappingSpeechEvent, OverlappingSpeechEvent> { | ||
| const logger = log(); | ||
|
|
||
| return new TransformStream<Int16Array | OverlappingSpeechEvent, OverlappingSpeechEvent>( | ||
| { | ||
| async transform(chunk, controller) { | ||
| if (!(chunk instanceof Int16Array)) { | ||
| controller.enqueue(chunk); | ||
| return; | ||
| } | ||
|
|
||
| const state = getState(); | ||
| const overlapSpeechStartedAt = state.overlapSpeechStartedAt; | ||
| if (overlapSpeechStartedAt === undefined || !state.overlapSpeechStarted) return; | ||
|
|
||
| try { | ||
| const resp = await predictHTTP( | ||
| chunk, | ||
| { threshold: options.threshold, minFrames: options.minFrames }, | ||
| { | ||
| baseUrl: options.baseUrl, | ||
| timeout: options.timeout, | ||
| maxRetries: options.maxRetries, | ||
| token: await createAccessToken(options.apiKey, options.apiSecret), | ||
| }, | ||
| ); | ||
|
|
||
| const { createdAt, isBargein, probabilities, predictionDurationInS } = resp; | ||
| const entry = state.cache.setOrUpdate( | ||
| createdAt, | ||
| () => new InterruptionCacheEntry({ createdAt }), | ||
| { | ||
| probabilities, | ||
| isInterruption: isBargein, | ||
| speechInput: chunk, | ||
| totalDurationInS: (performance.now() - createdAt) / 1000, | ||
| detectionDelayInS: (Date.now() - overlapSpeechStartedAt) / 1000, | ||
| predictionDurationInS, | ||
| }, | ||
| ); | ||
|
|
||
| if (state.overlapSpeechStarted && entry.isInterruption) { | ||
| if (updateUserSpeakingSpan) { | ||
| updateUserSpeakingSpan(entry); | ||
| } | ||
| const event: OverlappingSpeechEvent = { | ||
| type: 'overlapping_speech', | ||
| detectedAt: Date.now(), | ||
| overlapStartedAt: overlapSpeechStartedAt, | ||
| isInterruption: entry.isInterruption, | ||
| speechInput: entry.speechInput, | ||
| probabilities: entry.probabilities, | ||
| totalDurationInS: entry.totalDurationInS, | ||
| predictionDurationInS: entry.predictionDurationInS, | ||
| detectionDelayInS: entry.detectionDelayInS, | ||
| probability: entry.probability, | ||
| numRequests: getAndResetNumRequests?.() ?? 0, | ||
| }; | ||
| logger.debug( | ||
| { | ||
| detectionDelayInS: entry.detectionDelayInS, | ||
| totalDurationInS: entry.totalDurationInS, | ||
| }, | ||
| 'interruption detected', | ||
| ); | ||
| setState({ overlapSpeechStarted: false }); | ||
| controller.enqueue(event); | ||
| } | ||
| } catch (err) { | ||
| controller.error(err); | ||
| } | ||
| }, | ||
| }, | ||
| { highWaterMark: 2 }, | ||
| { highWaterMark: 2 }, | ||
| ); | ||
| } |
Uh oh!
There was an error while loading. Please reload this page.