diff --git a/js/plugins/compat-oai/src/audio.ts b/js/plugins/compat-oai/src/audio.ts index 7e0e9c52e5..0a19989aaf 100644 --- a/js/plugins/compat-oai/src/audio.ts +++ b/js/plugins/compat-oai/src/audio.ts @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + import type { GenerateRequest, GenerateResponseData, @@ -41,7 +42,7 @@ export type TranscriptionRequestBuilder = ( params: TranscriptionCreateParams ) => void; -export const TRANSCRIPTION_MODEL_INFO = { +export const TRANSCRIPTION_MODEL_INFO: ModelInfo = { supports: { media: true, output: ['text', 'json'], diff --git a/js/plugins/compat-oai/src/index.ts b/js/plugins/compat-oai/src/index.ts index faf8e17b01..f474ebbf9d 100644 --- a/js/plugins/compat-oai/src/index.ts +++ b/js/plugins/compat-oai/src/index.ts @@ -45,6 +45,12 @@ export { openAIModelRunner, type ModelRequestBuilder, } from './model.js'; +export { + TranslationConfigSchema, + compatOaiTranslationModelRef, + defineCompatOpenAITranslationModel, + type TranslationRequestBuilder, +} from './translate.js'; export interface PluginOptions extends Partial> { apiKey?: ClientOptions['apiKey'] | false; diff --git a/js/plugins/compat-oai/src/openai/index.ts b/js/plugins/compat-oai/src/openai/index.ts index 38d6f80ced..15b4ec7671 100644 --- a/js/plugins/compat-oai/src/openai/index.ts +++ b/js/plugins/compat-oai/src/openai/index.ts @@ -56,6 +56,12 @@ import { } from './gpt.js'; import { openAITranscriptionModelRef, SUPPORTED_STT_MODELS } from './stt.js'; import { openAISpeechModelRef, SUPPORTED_TTS_MODELS } from './tts.js'; +import { + defineOpenAIWhisperModel, + openAIWhisperModelRef, + SUPPORTED_WHISPER_MODELS, + WhisperConfigSchema, +} from './whisper.js'; export type OpenAIPluginOptions = Omit; @@ -88,10 +94,15 @@ function createResolver(pluginOptions: PluginOptions) { pluginOptions, modelRef, }); - } else if ( - actionName.includes('whisper') || - actionName.includes('transcribe') - ) { + } else if (actionName.includes('whisper')) { + const modelRef = openAIWhisperModelRef({ name: actionName }); + return defineOpenAIWhisperModel({ + name: modelRef.name, + client, + pluginOptions, + modelRef, + }); + } else if (actionName.includes('transcribe')) { const modelRef = openAITranscriptionModelRef({ name: actionName, }); @@ -147,10 +158,16 @@ const listActions = async (client: OpenAI): Promise => { info: modelRef.info, configSchema: modelRef.configSchema, }); - } else if ( - model.id.includes('whisper') || - model.id.includes('transcribe') - ) { + } else if (model.id.includes('whisper')) { + const modelRef = + SUPPORTED_WHISPER_MODELS[model.id] ?? + openAIWhisperModelRef({ name: model.id }); + return modelActionMetadata({ + name: modelRef.name, + info: modelRef.info, + configSchema: modelRef.configSchema, + }); + } else if (model.id.includes('transcribe')) { const modelRef = SUPPORTED_STT_MODELS[model.id] ?? openAITranscriptionModelRef({ name: model.id }); @@ -209,6 +226,16 @@ export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPluginV2 { }) ) ); + models.push( + ...Object.values(SUPPORTED_WHISPER_MODELS).map((modelRef) => + defineOpenAIWhisperModel({ + name: modelRef.name, + client, + pluginOptions, + modelRef, + }) + ) + ); models.push( ...Object.values(SUPPORTED_STT_MODELS).map((modelRef) => defineCompatOpenAITranscriptionModel({ @@ -256,10 +283,11 @@ export type OpenAIPlugin = { config?: z.infer ): ModelReference; model( - name: - | keyof typeof SUPPORTED_STT_MODELS - | (`whisper-${string}` & {}) - | (`${string}-transcribe` & {}), + name: keyof typeof SUPPORTED_WHISPER_MODELS | (`whisper-${string}` & {}), + config?: z.infer + ): ModelReference; + model( + name: keyof typeof SUPPORTED_STT_MODELS | (`${string}-transcribe` & {}), config?: z.infer ): ModelReference; model( @@ -292,7 +320,13 @@ const model = ((name: string, config?: any): ModelReference => { config, }); } - if (name.includes('whisper') || name.includes('transcribe')) { + if (name.includes('whisper')) { + return openAIWhisperModelRef({ + name, + config, + }); + } + if (name.includes('transcribe')) { return openAITranscriptionModelRef({ name, config, diff --git a/js/plugins/compat-oai/src/openai/stt.ts b/js/plugins/compat-oai/src/openai/stt.ts index 1833abaf27..081678cd34 100644 --- a/js/plugins/compat-oai/src/openai/stt.ts +++ b/js/plugins/compat-oai/src/openai/stt.ts @@ -38,7 +38,4 @@ export const SUPPORTED_STT_MODELS = { 'gpt-4o-mini-transcribe': openAITranscriptionModelRef({ name: 'gpt-4o-mini-transcribe', }), - 'whisper-1': openAITranscriptionModelRef({ - name: 'whisper-1', - }), }; diff --git a/js/plugins/compat-oai/src/openai/whisper.ts b/js/plugins/compat-oai/src/openai/whisper.ts new file mode 100644 index 0000000000..91d2ba67a6 --- /dev/null +++ b/js/plugins/compat-oai/src/openai/whisper.ts @@ -0,0 +1,147 @@ +/** + * Copyright 2024 The Fire Company + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import type { ModelReference } from 'genkit'; +import { modelRef, z } from 'genkit'; +import type { ModelAction, ModelInfo } from 'genkit/model'; +import { model } from 'genkit/plugin'; +import OpenAI from 'openai'; +import { + TranscriptionConfigSchema, + toSttRequest, + transcriptionToGenerateResponse, +} from '../audio.js'; +import type { PluginOptions } from '../index.js'; +import { + toTranslationRequest, + translationToGenerateResponse, +} from '../translate.js'; +import { maybeCreateRequestScopedOpenAIClient, toModelName } from '../utils.js'; + +export const WHISPER_MODEL_INFO: ModelInfo = { + supports: { + media: true, + output: ['text', 'json'], + multiturn: false, + systemRole: false, + tools: false, + }, +}; + +/** + * Config schema for Whisper models. Extends the transcription config with + * a `translate` flag that switches between transcription and translation APIs. + */ +export const WhisperConfigSchema = TranscriptionConfigSchema.extend({ + /** When true, uses Translation API instead of Transcription. Default: false */ + translate: z.boolean().optional().default(false), +}); + +/** + * Method to define an OpenAI Whisper model that can perform both transcription and + * translation based on the `translate` config flag. + * + * @param params.ai The Genkit AI instance. + * @param params.name The name of the model. + * @param params.client The OpenAI client instance. + * @param params.modelRef Optional reference to the model's configuration and + * custom options. + * + * @returns the created {@link ModelAction} + */ +export function defineOpenAIWhisperModel< + CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, +>(params: { + name: string; + client: OpenAI; + modelRef?: ModelReference; + pluginOptions?: PluginOptions; +}): ModelAction { + const { name, client: defaultClient, pluginOptions, modelRef } = params; + const modelName = toModelName(name, pluginOptions?.name); + const actionName = + modelRef?.name ?? `${pluginOptions?.name ?? 'openai'}/${modelName}`; + + return model( + { + name: actionName, + ...modelRef?.info, + configSchema: modelRef?.configSchema, + }, + async (request, { abortSignal }) => { + const { translate, ...cleanConfig } = (request.config ?? {}) as Record< + string, + unknown + >; + const cleanRequest = { ...request, config: cleanConfig }; + const client = maybeCreateRequestScopedOpenAIClient( + pluginOptions, + request, + defaultClient + ); + + if (translate === true) { + const params = toTranslationRequest(modelName, cleanRequest); + const result = await client.audio.translations.create(params, { + signal: abortSignal, + }); + return translationToGenerateResponse(result); + } else { + const params = toSttRequest(modelName, cleanRequest); + // Explicitly setting stream to false ensures we use the non-streaming overload + const result = await client.audio.transcriptions.create( + { + ...params, + stream: false, + }, + { signal: abortSignal } + ); + return transcriptionToGenerateResponse(result); + } + } + ); +} + +/** OpenAI whisper ModelRef helper. */ +export function openAIWhisperModelRef< + CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, +>(params: { + name: string; + info?: ModelInfo; + configSchema?: CustomOptions; + config?: any; +}) { + const { + name, + info = WHISPER_MODEL_INFO, + configSchema, + config = undefined, + } = params; + return modelRef({ + name, + configSchema: configSchema || (WhisperConfigSchema as any), + info, + config, + namespace: 'openai', + }); +} + +export const SUPPORTED_WHISPER_MODELS = { + 'whisper-1': openAIWhisperModelRef({ + name: 'whisper-1', + }), +}; diff --git a/js/plugins/compat-oai/src/translate.ts b/js/plugins/compat-oai/src/translate.ts new file mode 100644 index 0000000000..2200b5e119 --- /dev/null +++ b/js/plugins/compat-oai/src/translate.ts @@ -0,0 +1,223 @@ +/** + * Copyright 2024 The Fire Company + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import type { + GenerateRequest, + GenerateResponseData, + ModelReference, +} from 'genkit'; +import { GenerationCommonConfigSchema, Message, modelRef, z } from 'genkit'; +import type { ModelAction, ModelInfo } from 'genkit/model'; +import { model } from 'genkit/plugin'; +import OpenAI from 'openai'; +import type { + TranslationCreateParams, + TranslationCreateResponse, +} from 'openai/resources/audio/index.mjs'; +import { PluginOptions } from './index.js'; +import { maybeCreateRequestScopedOpenAIClient, toModelName } from './utils.js'; + +export type TranslationRequestBuilder = ( + req: GenerateRequest, + params: TranslationCreateParams +) => void; + +export const TRANSLATION_MODEL_INFO: ModelInfo = { + supports: { + media: true, + output: ['text', 'json'], + multiturn: false, + systemRole: false, + tools: false, + }, +}; + +export const TranslationConfigSchema = GenerationCommonConfigSchema.pick({ + temperature: true, +}).extend({ + response_format: z + .enum(['json', 'text', 'srt', 'verbose_json', 'vtt']) + .optional(), +}); + +export function toTranslationRequest( + modelName: string, + request: GenerateRequest, + requestBuilder?: TranslationRequestBuilder +): TranslationCreateParams { + const message = new Message(request.messages[0]); + const media = message.media; + if (!media?.url) { + throw new Error('No media found in the request'); + } + const mediaBuffer = Buffer.from( + media.url.slice(media.url.indexOf(',') + 1), + 'base64' + ); + const mediaFile = new File([mediaBuffer], 'input', { + type: + media.contentType ?? + media.url.slice('data:'.length, media.url.indexOf(';')), + }); + const { + temperature, + version: modelVersion, + maxOutputTokens, + stopSequences, + topK, + topP, + ...restOfConfig + } = request.config ?? {}; + + let options: TranslationCreateParams = { + model: modelVersion ?? modelName, + file: mediaFile, + prompt: message.text, + temperature, + }; + if (requestBuilder) { + requestBuilder(request, options); + } else { + options = { + ...options, + ...restOfConfig, // passthrough rest of the config + }; + } + const outputFormat = request.output?.format as 'json' | 'text' | 'media'; + const customFormat = request.config?.response_format; + if (outputFormat && customFormat) { + if ( + outputFormat === 'json' && + customFormat !== 'json' && + customFormat !== 'verbose_json' + ) { + throw new Error( + `Custom response format ${customFormat} is not compatible with output format ${outputFormat}` + ); + } + } + if (outputFormat === 'media') { + throw new Error(`Output format ${outputFormat} is not supported.`); + } + options.response_format = customFormat || outputFormat || 'text'; + for (const k in options) { + if (options[k] === undefined) { + delete options[k]; + } + } + return options; +} + +export function translationToGenerateResponse( + result: TranslationCreateResponse | string +): GenerateResponseData { + return { + message: { + role: 'model', + content: [ + { + text: typeof result === 'string' ? result : result.text, + }, + ], + }, + finishReason: 'stop', + raw: result, + }; +} + +/** + * Method to define a new Genkit Model that is compatible with Open AI + * Translation API. + * + * These models are to be used to translate audio to text. + * + * @param params An object containing parameters for defining the OpenAI + * translation model. + * @param params.ai The Genkit AI instance. + * @param params.name The name of the model. + * @param params.client The OpenAI client instance. + * @param params.modelRef Optional reference to the model's configuration and + * custom options. + * + * @returns the created {@link ModelAction} + */ +export function defineCompatOpenAITranslationModel< + CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, +>(params: { + name: string; + client: OpenAI; + pluginOptions?: PluginOptions; + modelRef?: ModelReference; + requestBuilder?: TranslationRequestBuilder; +}) { + const { + name, + client: defaultClient, + pluginOptions, + modelRef, + requestBuilder, + } = params; + const modelName = toModelName(name, pluginOptions?.name); + const actionName = `${pluginOptions?.name ?? 'compat-oai'}/${modelName}`; + + return model( + { + name: actionName, + ...modelRef?.info, + configSchema: modelRef?.configSchema, + }, + async (request, { abortSignal }) => { + const params = toTranslationRequest(modelName, request, requestBuilder); + const client = maybeCreateRequestScopedOpenAIClient( + pluginOptions, + request, + defaultClient + ); + const result = await client.audio.translations.create(params, { + signal: abortSignal, + }); + return translationToGenerateResponse(result); + } + ); +} + +/** Translation ModelRef helper, with reasonable defaults for + * OpenAI-compatible providers */ +export function compatOaiTranslationModelRef< + CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, +>(params: { + name: string; + info?: ModelInfo; + configSchema?: CustomOptions; + config?: any; + namespace?: string; +}) { + const { + name, + info = TRANSLATION_MODEL_INFO, + configSchema, + config = undefined, + namespace, + } = params; + return modelRef({ + name, + configSchema: configSchema || (TranslationConfigSchema as any), + info, + config, + namespace, + }); +} diff --git a/js/plugins/compat-oai/tests/compat_oai_translate_test.ts b/js/plugins/compat-oai/tests/compat_oai_translate_test.ts new file mode 100644 index 0000000000..371bbdc8c0 --- /dev/null +++ b/js/plugins/compat-oai/tests/compat_oai_translate_test.ts @@ -0,0 +1,277 @@ +/** + * Copyright 2024 The Fire Company + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { describe, expect, it, jest } from '@jest/globals'; +import { GenerateRequest } from 'genkit'; +import OpenAI from 'openai'; +import { Translation } from 'openai/resources/audio/translations.mjs'; +import { + defineCompatOpenAITranslationModel, + toTranslationRequest, + translationToGenerateResponse, +} from '../src/translate'; + +jest.mock('genkit/model', () => { + const originalModule = + jest.requireActual('genkit/model'); + return { + ...originalModule, + defineModel: jest.fn((_, runner) => { + return runner; + }), + }; +}); + +describe('toTranslationRequest', () => { + it('should create translation request from base64 audio', () => { + const request = { + messages: [ + { + role: 'user', + content: [ + { + media: { + contentType: 'audio/wav', + url: 'data:audio/wav;base64,aGVsbG8=', + }, + }, + ], + }, + ], + output: { format: 'text' }, + } as GenerateRequest; + + const actualOutput = toTranslationRequest('whisper-1', request); + expect(actualOutput).toStrictEqual({ + model: 'whisper-1', + file: expect.any(File), + prompt: '', + response_format: 'text', + }); + }); + + it('should allow verbose_json when output.format is json', () => { + const request = { + messages: [ + { + role: 'user', + content: [ + { + media: { + contentType: 'audio/wav', + url: 'data:audio/wav;base64,aGVsbG8=', + }, + }, + ], + }, + ], + output: { format: 'json' }, + config: { response_format: 'verbose_json' }, + } as GenerateRequest; + + const actualOutput = toTranslationRequest('whisper-1', request); + expect(actualOutput).toStrictEqual({ + model: 'whisper-1', + file: expect.any(File), + prompt: '', + response_format: 'verbose_json', + }); + }); + + it('should throw error when media.url is missing', () => { + const request = { + messages: [ + { + role: 'user', + content: [ + { + media: { + contentType: 'audio/wav', + }, + }, + ], + }, + ], + output: { format: 'text' }, + } as GenerateRequest; + + expect(() => toTranslationRequest('whisper-1', request)).toThrowError( + 'No media found in the request' + ); + }); + + it('should throw error when output.format is json but custom format is incompatible', () => { + const request = { + messages: [ + { + role: 'user', + content: [ + { + media: { + contentType: 'audio/wav', + url: 'data:audio/wav;base64,aGVsbG8=', + }, + }, + ], + }, + ], + output: { format: 'json' }, + config: { response_format: 'srt' }, + } as GenerateRequest; + + expect(() => toTranslationRequest('whisper-1', request)).toThrowError( + 'Custom response format srt is not compatible with output format json' + ); + }); + + it('should throw error when output.format is media', () => { + const request = { + messages: [ + { + role: 'user', + content: [ + { + media: { + contentType: 'audio/wav', + url: 'data:audio/wav;base64,aGVsbG8=', + }, + }, + ], + }, + ], + output: { format: 'media' }, + } as GenerateRequest; + + expect(() => toTranslationRequest('whisper-1', request)).toThrow( + 'Output format media is not supported.' + ); + }); + + it('should run with requestBuilder', () => { + const requestBuilder = jest.fn((_, params) => { + (params as any).foo = 'bar'; + }); + + const request = { + messages: [ + { + role: 'user', + content: [ + { + media: { + contentType: 'audio/wav', + url: 'data:audio/wav;base64,aGVsbG8=', + }, + }, + ], + }, + ], + output: { format: 'text' }, + } as GenerateRequest; + + const actualOutput = toTranslationRequest( + 'whisper-1', + request, + requestBuilder + ); + + expect(requestBuilder).toHaveBeenCalledTimes(1); + expect(actualOutput).toHaveProperty('foo', 'bar'); + }); +}); + +describe('translationToGenerateResponse', () => { + it('should transform translation result correctly when result is Translation object', () => { + const result: Translation = { + text: 'Hello', + }; + + const actualOutput = translationToGenerateResponse(result); + expect(actualOutput).toStrictEqual({ + message: { + role: 'model', + content: [{ text: 'Hello' }], + }, + finishReason: 'stop', + raw: result, + }); + }); + + it('should transform translation result correctly when result is string', () => { + const result = 'Hello'; + + const actualOutput = translationToGenerateResponse(result); + expect(actualOutput).toStrictEqual({ + message: { + role: 'model', + content: [{ text: 'Hello' }], + }, + finishReason: 'stop', + raw: result, + }); + }); +}); + +describe('defineCompatOpenAITranslationModel runner', () => { + it('should correctly run Translation requests', async () => { + const result: Translation = { + text: 'Hello', + }; + + const openaiClient = { + audio: { + translations: { + create: jest.fn(async () => result), + }, + }, + }; + const abortSignal = jest.fn(); + const runner = defineCompatOpenAITranslationModel({ + name: 'whisper-1', + client: openaiClient as unknown as OpenAI, + }); + await runner( + { + messages: [ + { + role: 'user', + content: [ + { + media: { + url: 'data:audio/wav;base64,aGVsbG8=', + contentType: 'audio/wav', + }, + }, + ], + }, + ], + }, + { + abortSignal: abortSignal as unknown as AbortSignal, + } + ); + expect(openaiClient.audio.translations.create).toHaveBeenCalledWith( + { + model: 'whisper-1', + file: expect.any(File), + prompt: '', + response_format: 'text', + }, + { signal: abortSignal } + ); + }); +}); diff --git a/js/plugins/compat-oai/tests/openai_whisper_test.ts b/js/plugins/compat-oai/tests/openai_whisper_test.ts new file mode 100644 index 0000000000..ff727c13f7 --- /dev/null +++ b/js/plugins/compat-oai/tests/openai_whisper_test.ts @@ -0,0 +1,205 @@ +/** + * Copyright 2024 The Fire Company + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { describe, expect, it, jest } from '@jest/globals'; +import OpenAI from 'openai'; +import { Transcription } from 'openai/resources/audio/transcriptions.mjs'; +import { Translation } from 'openai/resources/audio/translations.mjs'; +import { defineOpenAIWhisperModel } from '../src/openai/whisper'; + +jest.mock('genkit/model', () => { + const originalModule = + jest.requireActual('genkit/model'); + return { + ...originalModule, + defineModel: jest.fn((_, runner) => { + return runner; + }), + }; +}); + +describe('defineOpenAIWhisperModel runner — transcription (default)', () => { + it('should call transcriptions.create when translate is not set', async () => { + const result: Transcription = { + text: 'Hello world', + }; + + const openaiClient = { + audio: { + transcriptions: { + create: jest.fn(async () => result), + }, + translations: { + create: jest.fn(async () => ({ text: 'should not be called' })), + }, + }, + }; + const abortSignal = jest.fn(); + + const runner = defineOpenAIWhisperModel({ + name: 'whisper-1', + client: openaiClient as unknown as OpenAI, + }); + + await runner( + { + messages: [ + { + role: 'user', + content: [ + { + media: { + url: 'data:audio/wav;base64,aGVsbG8=', + contentType: 'audio/wav', + }, + }, + ], + }, + ], + }, + { + abortSignal: abortSignal as unknown as AbortSignal, + } + ); + + expect(openaiClient.audio.transcriptions.create).toHaveBeenCalledWith( + { + model: 'whisper-1', + file: expect.any(File), + prompt: '', + response_format: 'text', + stream: false, + }, + { signal: abortSignal } + ); + expect(openaiClient.audio.translations.create).not.toHaveBeenCalled(); + }); + + it('should call transcriptions.create when translate is explicitly false', async () => { + const result: Transcription = { + text: 'transcribed text', + }; + + const openaiClient = { + audio: { + transcriptions: { + create: jest.fn(async () => result), + }, + translations: { + create: jest.fn(async () => ({ text: 'should not be called' })), + }, + }, + }; + const abortSignal = jest.fn(); + + const runner = defineOpenAIWhisperModel({ + name: 'whisper-1', + client: openaiClient as unknown as OpenAI, + }); + + await runner( + { + messages: [ + { + role: 'user', + content: [ + { + media: { + url: 'data:audio/wav;base64,aGVsbG8=', + contentType: 'audio/wav', + }, + }, + ], + }, + ], + config: { translate: false }, + }, + { + abortSignal: abortSignal as unknown as AbortSignal, + } + ); + expect(openaiClient.audio.transcriptions.create).toHaveBeenCalledWith( + { + model: 'whisper-1', + file: expect.any(File), + prompt: '', + response_format: 'text', + stream: false, + }, + { signal: abortSignal } + ); + expect(openaiClient.audio.translations.create).not.toHaveBeenCalled(); + }); +}); + +describe('defineOpenAIWhisperModel runner — translation (translate: true)', () => { + it('should call translations.create when translate is true', async () => { + const result: Translation = { + text: 'Hello in English', + }; + + const openaiClient = { + audio: { + transcriptions: { + create: jest.fn(async () => ({ text: 'should not be called' })), + }, + translations: { + create: jest.fn(async () => result), + }, + }, + }; + const abortSignal = jest.fn(); + + const runner = defineOpenAIWhisperModel({ + name: 'whisper-1', + client: openaiClient as unknown as OpenAI, + }); + + await runner( + { + messages: [ + { + role: 'user', + content: [ + { + media: { + url: 'data:audio/wav;base64,aGVsbG8=', + contentType: 'audio/wav', + }, + }, + ], + }, + ], + config: { translate: true }, + }, + { + abortSignal: abortSignal as unknown as AbortSignal, + } + ); + + expect(openaiClient.audio.translations.create).toHaveBeenCalledWith( + { + model: 'whisper-1', + file: expect.any(File), + prompt: '', + response_format: 'text', + }, + { signal: abortSignal } + ); + expect(openaiClient.audio.transcriptions.create).not.toHaveBeenCalled(); + }); +}); diff --git a/js/testapps/compat-oai/audio-korean.mp3 b/js/testapps/compat-oai/audio-korean.mp3 new file mode 100644 index 0000000000..c605c9716f Binary files /dev/null and b/js/testapps/compat-oai/audio-korean.mp3 differ diff --git a/js/testapps/compat-oai/src/index.ts b/js/testapps/compat-oai/src/index.ts index d5ce6ca8da..6f195838f2 100644 --- a/js/testapps/compat-oai/src/index.ts +++ b/js/testapps/compat-oai/src/index.ts @@ -377,6 +377,28 @@ ai.defineFlow('transcribe', async () => { return text; }); +// translation sample +ai.defineFlow('translate', async () => { + const audioFile = fs.readFileSync('audio-korean.mp3'); + + const { text } = await ai.generate({ + model: openAI.model('whisper-1', { + translate: true, + temperature: 0.5, + }), + prompt: [ + { + media: { + contentType: 'audio/mp3', + url: `data:audio/mp3;base64,${audioFile.toString('base64')}`, + }, + }, + ], + }); + + return text; +}); + // PDF file input example ai.defineFlow( {