Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion js/plugins/compat-oai/src/audio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import type {
GenerateRequest,
GenerateResponseData,
Expand Down Expand Up @@ -41,7 +42,7 @@ export type TranscriptionRequestBuilder = (
params: TranscriptionCreateParams
) => void;

export const TRANSCRIPTION_MODEL_INFO = {
export const TRANSCRIPTION_MODEL_INFO: ModelInfo = {
supports: {
media: true,
output: ['text', 'json'],
Expand Down
6 changes: 6 additions & 0 deletions js/plugins/compat-oai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ export {
openAIModelRunner,
type ModelRequestBuilder,
} from './model.js';
export {
TranslationConfigSchema,
compatOaiTranslationModelRef,
defineCompatOpenAITranslationModel,
type TranslationRequestBuilder,
} from './translate.js';

export interface PluginOptions extends Partial<Omit<ClientOptions, 'apiKey'>> {
apiKey?: ClientOptions['apiKey'] | false;
Expand Down
60 changes: 47 additions & 13 deletions js/plugins/compat-oai/src/openai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ import {
} from './gpt.js';
import { openAITranscriptionModelRef, SUPPORTED_STT_MODELS } from './stt.js';
import { openAISpeechModelRef, SUPPORTED_TTS_MODELS } from './tts.js';
import {
defineOpenAIWhisperModel,
openAIWhisperModelRef,
SUPPORTED_WHISPER_MODELS,
WhisperConfigSchema,
} from './whisper.js';

export type OpenAIPluginOptions = Omit<PluginOptions, 'name' | 'baseURL'>;

Expand Down Expand Up @@ -88,10 +94,15 @@ function createResolver(pluginOptions: PluginOptions) {
pluginOptions,
modelRef,
});
} else if (
actionName.includes('whisper') ||
actionName.includes('transcribe')
) {
} else if (actionName.includes('whisper')) {
const modelRef = openAIWhisperModelRef({ name: actionName });
return defineOpenAIWhisperModel({
name: modelRef.name,
client,
pluginOptions,
modelRef,
});
} else if (actionName.includes('transcribe')) {
const modelRef = openAITranscriptionModelRef({
name: actionName,
});
Expand Down Expand Up @@ -147,10 +158,16 @@ const listActions = async (client: OpenAI): Promise<ActionMetadata[]> => {
info: modelRef.info,
configSchema: modelRef.configSchema,
});
} else if (
model.id.includes('whisper') ||
model.id.includes('transcribe')
) {
} else if (model.id.includes('whisper')) {
const modelRef =
SUPPORTED_WHISPER_MODELS[model.id] ??
openAIWhisperModelRef({ name: model.id });
return modelActionMetadata({
name: modelRef.name,
info: modelRef.info,
configSchema: modelRef.configSchema,
});
} else if (model.id.includes('transcribe')) {
const modelRef =
SUPPORTED_STT_MODELS[model.id] ??
openAITranscriptionModelRef({ name: model.id });
Expand Down Expand Up @@ -209,6 +226,16 @@ export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPluginV2 {
})
)
);
models.push(
...Object.values(SUPPORTED_WHISPER_MODELS).map((modelRef) =>
defineOpenAIWhisperModel({
name: modelRef.name,
client,
pluginOptions,
modelRef,
})
)
);
models.push(
...Object.values(SUPPORTED_STT_MODELS).map((modelRef) =>
defineCompatOpenAITranscriptionModel({
Expand Down Expand Up @@ -256,10 +283,11 @@ export type OpenAIPlugin = {
config?: z.infer<typeof SpeechConfigSchema>
): ModelReference<typeof SpeechConfigSchema>;
model(
name:
| keyof typeof SUPPORTED_STT_MODELS
| (`whisper-${string}` & {})
| (`${string}-transcribe` & {}),
name: keyof typeof SUPPORTED_WHISPER_MODELS | (`whisper-${string}` & {}),
config?: z.infer<typeof WhisperConfigSchema>
): ModelReference<typeof WhisperConfigSchema>;
model(
name: keyof typeof SUPPORTED_STT_MODELS | (`${string}-transcribe` & {}),
config?: z.infer<typeof TranscriptionConfigSchema>
): ModelReference<typeof TranscriptionConfigSchema>;
model(
Expand Down Expand Up @@ -292,7 +320,13 @@ const model = ((name: string, config?: any): ModelReference<z.ZodTypeAny> => {
config,
});
}
if (name.includes('whisper') || name.includes('transcribe')) {
if (name.includes('whisper')) {
return openAIWhisperModelRef({
name,
config,
});
}
if (name.includes('transcribe')) {
return openAITranscriptionModelRef({
name,
config,
Expand Down
3 changes: 0 additions & 3 deletions js/plugins/compat-oai/src/openai/stt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,4 @@ export const SUPPORTED_STT_MODELS = {
'gpt-4o-mini-transcribe': openAITranscriptionModelRef({
name: 'gpt-4o-mini-transcribe',
}),
'whisper-1': openAITranscriptionModelRef({
name: 'whisper-1',
}),
};
147 changes: 147 additions & 0 deletions js/plugins/compat-oai/src/openai/whisper.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/**
* Copyright 2024 The Fire Company
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import type { ModelReference } from 'genkit';
import { modelRef, z } from 'genkit';
import type { ModelAction, ModelInfo } from 'genkit/model';
import { model } from 'genkit/plugin';
import OpenAI from 'openai';
import {
TranscriptionConfigSchema,
toSttRequest,
transcriptionToGenerateResponse,
} from '../audio.js';
import type { PluginOptions } from '../index.js';
import {
toTranslationRequest,
translationToGenerateResponse,
} from '../translate.js';
import { maybeCreateRequestScopedOpenAIClient, toModelName } from '../utils.js';

/** Capability metadata shared by all Whisper audio models. */
export const WHISPER_MODEL_INFO: ModelInfo = {
  supports: {
    multiturn: false,
    systemRole: false,
    tools: false,
    media: true,
    output: ['text', 'json'],
  },
};

/**
* Config schema for Whisper models. Extends the transcription config with
* a `translate` flag that switches between transcription and translation APIs.
*/
export const WhisperConfigSchema = TranscriptionConfigSchema.extend({
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a combination of TranscriptionConfigSchema AND TranslationConfigSchema?

We can use discriminated unions here... (off the translate field)
https://v3.zod.dev/?id=discriminated-unions

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you try out discriminated unions, can you please verify it works well with the Dev UI (Model Runner -> Select whisper-1 -> Check the Model config panel looks/works)

image

If it is unusable, let us just stick to simple union + extension (translate: boolean) for now. I will revisit this to see if it can be improved later. Thanks!

/** When true, uses Translation API instead of Transcription. Default: false */
translate: z.boolean().optional().default(false),
});

/**
* Method to define an OpenAI Whisper model that can perform both transcription and
* translation based on the `translate` config flag.
*
* @param params.ai The Genkit AI instance.
* @param params.name The name of the model.
* @param params.client The OpenAI client instance.
* @param params.modelRef Optional reference to the model's configuration and
* custom options.
*
* @returns the created {@link ModelAction}
*/
/**
 * Method to define an OpenAI Whisper model that can perform both transcription and
 * translation based on the `translate` config flag.
 *
 * @param params.name The name of the model.
 * @param params.client The OpenAI client instance used when no request-scoped
 *   client is configured.
 * @param params.modelRef Optional reference to the model's configuration and
 *   custom options.
 * @param params.pluginOptions Optional plugin options (namespace prefix,
 *   request-scoped client configuration).
 *
 * @returns the created {@link ModelAction}
 */
export function defineOpenAIWhisperModel<
  CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(params: {
  name: string;
  client: OpenAI;
  modelRef?: ModelReference<CustomOptions>;
  pluginOptions?: PluginOptions;
}): ModelAction {
  const { name, client: defaultClient, pluginOptions, modelRef } = params;
  const modelName = toModelName(name, pluginOptions?.name);
  const actionName =
    modelRef?.name ?? `${pluginOptions?.name ?? 'openai'}/${modelName}`;

  return model(
    {
      name: actionName,
      ...modelRef?.info,
      configSchema: modelRef?.configSchema,
    },
    async (request, { abortSignal }) => {
      // Strip the plugin-only `translate` flag so it is not forwarded to the
      // OpenAI API as an unknown request parameter.
      const { translate, ...cleanConfig } = (request.config ?? {}) as Record<
        string,
        unknown
      >;
      const cleanRequest = { ...request, config: cleanConfig };
      const client = maybeCreateRequestScopedOpenAIClient(
        pluginOptions,
        request,
        defaultClient
      );

      if (translate === true) {
        // Translation API path. (Renamed from `params` to avoid shadowing the
        // function's `params` argument.)
        const translationParams = toTranslationRequest(modelName, cleanRequest);
        const result = await client.audio.translations.create(
          translationParams,
          { signal: abortSignal }
        );
        return translationToGenerateResponse(result);
      } else {
        const transcriptionParams = toSttRequest(modelName, cleanRequest);
        // Explicitly setting stream to false ensures we use the non-streaming overload
        const result = await client.audio.transcriptions.create(
          {
            ...transcriptionParams,
            stream: false,
          },
          { signal: abortSignal }
        );
        return transcriptionToGenerateResponse(result);
      }
    }
  );
}

/**
 * OpenAI whisper ModelRef helper.
 *
 * @param params.name Model name (e.g. `whisper-1`).
 * @param params.info Optional capability metadata; defaults to
 *   {@link WHISPER_MODEL_INFO}.
 * @param params.configSchema Optional custom config schema; defaults to
 *   {@link WhisperConfigSchema}.
 * @param params.config Optional default config baked into the reference.
 */
export function openAIWhisperModelRef<
  CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(params: {
  name: string;
  info?: ModelInfo;
  configSchema?: CustomOptions;
  config?: any;
}) {
  const { name, info = WHISPER_MODEL_INFO, configSchema, config } = params;
  return modelRef({
    name,
    // `??` (not `||`) expresses "default only when absent"; a zod schema is
    // always truthy so behavior is unchanged, but the intent is precise.
    configSchema: configSchema ?? (WhisperConfigSchema as any),
    info,
    config,
    namespace: 'openai',
  });
}

/**
 * Whisper model references registered eagerly when the plugin initializes.
 * Keys are the OpenAI model IDs.
 */
export const SUPPORTED_WHISPER_MODELS = {
  'whisper-1': openAIWhisperModelRef({ name: 'whisper-1' }),
};
Loading