Skip to content

Commit a5c5f9c

Browse files
committed
feat: add option to augment prompt
Adds an option to the environment config that allows users to augment the resolved prompts before they're sent out.
1 parent e0bd8ed commit a5c5f9c

File tree

3 files changed

+97
-23
lines changed

3 files changed

+97
-23
lines changed

runner/configuration/environment-config.ts

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ import {
88
LocalExecutorConfig,
99
localExecutorConfigSchema,
1010
} from '../orchestration/executors/local-executor-config.js';
11-
import {RatingContextFilter, ReportContextFilter} from '../shared-interfaces.js';
11+
import {PromptDefinition, RatingContextFilter, ReportContextFilter} from '../shared-interfaces.js';
12+
import type {Environment} from './environment.js';
13+
import type {GenkitRunner} from '../codegen/genkit/genkit-runner.js';
1214

1315
export const environmentConfigSchema = z.object({
1416
/** Display name for the environment. */
@@ -118,6 +120,13 @@ export const environmentConfigSchema = z.object({
118120
}),
119121
)
120122
.optional(),
123+
124+
/**
125+
* Function that can be used to augment prompts before they're evaluated.
126+
*/
127+
augmentExecutablePrompt: z
128+
.function(z.tuple([z.custom<PromptAugmentationContext>()]), z.promise(z.string()))
129+
.optional(),
121130
});
122131

123132
/**
@@ -127,6 +136,16 @@ export const environmentConfigSchema = z.object({
127136
export type EnvironmentConfig = z.infer<typeof environmentConfigSchema> &
128137
Partial<LocalExecutorConfig>;
129138

139+
/** Context passed to the `augmentExecutablePrompt` function. */
140+
export interface PromptAugmentationContext {
141+
/** Definition being augmented. */
142+
promptDef: PromptDefinition;
143+
/** Environment running the evaluation. */
144+
environment: Environment;
145+
/** Runner that the user can use for augmentation. */
146+
runner: GenkitRunner;
147+
}
148+
130149
/** Asserts that the specified data is a valid environment config. */
131150
export function assertIsEnvironmentConfig(value: unknown): asserts value is EnvironmentConfig {
132151
const validationResult = environmentConfigSchema

runner/configuration/environment.ts

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@ import {
1414
import {UserFacingError} from '../utils/errors.js';
1515
import {generateId} from '../utils/id-generation.js';
1616
import {lazy} from '../utils/lazy-creation.js';
17-
import {EnvironmentConfig} from './environment-config.js';
17+
import {EnvironmentConfig, PromptAugmentationContext} from './environment-config.js';
1818
import {EvalPromptWithMetadata, MultiStepPrompt} from './prompts.js';
1919
import {renderPromptTemplate} from './prompt-templating.js';
2020
import {getSha256Hash} from '../utils/hashing.js';
2121
import {DEFAULT_SUMMARY_MODEL} from './constants.js';
22+
import type {GenkitRunner} from '../codegen/genkit/genkit-runner.js';
23+
import {getRunnerByName} from '../codegen/runner-creation.js';
2224

2325
interface CategoryConfig {
2426
name: string;
@@ -73,6 +75,14 @@ export class Environment {
7375
/** Ratings configured at the environment level. */
7476
private readonly ratings: Rating[];
7577

78+
/** User-configured function used to augment prompts. */
79+
private readonly augmentExecutablePrompt:
80+
| ((context: PromptAugmentationContext) => Promise<string>)
81+
| null;
82+
83+
/** Runner that user can use to access an LLM to augment prompts. */
84+
private augmentationRunner: GenkitRunner | null = null;
85+
7686
constructor(
7787
rootPath: string,
7888
private readonly config: EnvironmentConfig & Required<Pick<EnvironmentConfig, 'executor'>>,
@@ -103,26 +113,27 @@ export class Environment {
103113
this.ratings = this.resolveRatings(config);
104114
this.ratingHash = this.getRatingHash(this.ratings, this.ratingCategories);
105115
this.analysisPrompts = this.resolveAnalysisPrompts(config);
116+
this.augmentExecutablePrompt = config.augmentExecutablePrompt || null;
106117
this.validateRatingHash(this.ratingHash, config);
107118
}
108119

109120
/** Prompts that should be executed as a part of the evaluation. */
110-
executablePrompts = lazy(async () => {
121+
readonly executablePrompts = lazy(async () => {
111122
return this.resolveExecutablePrompts(this.config.executablePrompts);
112123
});
113124

114-
systemPromptGeneration = lazy(async () => {
125+
readonly systemPromptGeneration = lazy(async () => {
115126
return (await this.renderSystemPrompt(this.config.generationSystemPrompt)).result;
116127
});
117128

118-
systemPromptRepair = lazy(async () => {
129+
readonly systemPromptRepair = lazy(async () => {
119130
if (!this.config.repairSystemPrompt) {
120131
return 'Please fix the given errors and return the corrected code.';
121132
}
122133
return (await this.renderSystemPrompt(this.config.repairSystemPrompt)).result;
123134
});
124135

125-
systemPromptEditing = lazy(async () => {
136+
readonly systemPromptEditing = lazy(async () => {
126137
if (!this.config.editingSystemPrompt) {
127138
return this.systemPromptGeneration();
128139
}
@@ -180,6 +191,14 @@ export class Environment {
180191
});
181192
}
182193

194+
async destroy(): Promise<void> {
195+
await this.executor.destroy();
196+
197+
if (this.augmentationRunner) {
198+
await this.augmentationRunner.dispose();
199+
}
200+
}
201+
183202
/**
184203
* Gets the readable display name of a framework, based on its ID.
185204
* @param id ID to be resolved.
@@ -209,16 +228,16 @@ export class Environment {
209228
* @param config Configuration for the environment.
210229
*/
211230
private async resolveExecutablePrompts(
212-
prompts: EnvironmentConfig['executablePrompts'],
231+
definitions: EnvironmentConfig['executablePrompts'],
213232
): Promise<RootPromptDefinition[]> {
214-
const result: Promise<RootPromptDefinition>[] = [];
233+
const promptPromises: Promise<RootPromptDefinition>[] = [];
215234
const envRatings = this.ratings;
216235

217-
for (const def of prompts) {
236+
for (const def of definitions) {
218237
if (def instanceof MultiStepPrompt) {
219-
result.push(this.getMultiStepPrompt(def, envRatings));
238+
promptPromises.push(this.getMultiStepPrompt(def, envRatings));
220239
} else if (def instanceof EvalPromptWithMetadata) {
221-
result.push(
240+
promptPromises.push(
222241
Promise.resolve({
223242
name: def.name,
224243
kind: 'single',
@@ -243,10 +262,10 @@ export class Environment {
243262
name = def.name;
244263
}
245264

246-
result.push(
265+
promptPromises.push(
247266
...globSync(path, {cwd: this.rootPath}).map(
248267
async relativePath =>
249-
await this.getStepPromptDefinition(
268+
await this.getSinglePromptDefinition(
250269
name ?? basename(relativePath, extname(relativePath)),
251270
relativePath,
252271
ratings,
@@ -258,19 +277,47 @@ export class Environment {
258277
}
259278
}
260279

261-
return Promise.all(result);
280+
const prompts = await Promise.all(promptPromises);
281+
282+
if (this.augmentExecutablePrompt) {
283+
const augmentationPromises: Promise<unknown>[] = [];
284+
const updatePrompt = (promptDef: PromptDefinition) => {
285+
augmentationPromises.push(
286+
this.augmentExecutablePrompt!({
287+
promptDef,
288+
environment: this,
289+
runner: this.augmentationRunner!,
290+
}).then(text => (promptDef.prompt = text)),
291+
);
292+
};
293+
this.augmentationRunner ??= await getRunnerByName('genkit');
294+
295+
for (const rootPrompt of prompts) {
296+
if (rootPrompt.kind === 'multi-step') {
297+
for (const promptDef of rootPrompt.steps) {
298+
updatePrompt(promptDef);
299+
}
300+
} else {
301+
updatePrompt(rootPrompt);
302+
}
303+
}
304+
305+
await Promise.all(augmentationPromises);
306+
}
307+
308+
return prompts;
262309
}
263310

264311
/**
265-
* Creates a prompt definition for a given step.
312+
* Creates a prompt definition for a single prompt.
266313
*
267314
* @param name Name of the prompt.
268315
* @param rootPath Root path of the project.
269316
* @param relativePath Relative path to the prompt.
270317
* @param ratings Ratings to run against the definition.
271318
* @param isEditing Whether this is an editing or generation step.
272319
*/
273-
private async getStepPromptDefinition<Metadata>(
320+
private async getSinglePromptDefinition<Metadata>(
274321
name: string,
275322
relativePath: string,
276323
ratings: Rating[],
@@ -345,11 +392,11 @@ export class Environment {
345392
if (stepNum === 0) {
346393
throw new UserFacingError('Multi-step prompts start with `step-1`.');
347394
}
348-
const step = await this.getStepPromptDefinition(
395+
const step = await this.getSinglePromptDefinition(
349396
`${name}-step-${stepNum}`,
350397
join(def.directoryPath, current.name),
351398
ratings,
352-
/*isEditing */ stepNum !== 1,
399+
/* isEditing */ stepNum !== 1,
353400
stepMetadata,
354401
);
355402

runner/orchestration/generate.ts

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,24 @@ export async function generateCodeAndAssess(options: AssessmentConfig): Promise<
4848
const cleanup = async () => {
4949
// Clean-up should never interrupt a potentially passing completion.
5050
try {
51-
await env.executor.destroy();
52-
for (const cleanupFn of extraCleanupFns) {
53-
await cleanupFn();
54-
}
51+
await env.destroy();
5552
} catch (e) {
56-
console.error(`Failed to destroy executor: ${e}`);
53+
console.error(`Failed to destroy environment: ${e}`);
5754
if (e instanceof Error) {
5855
console.error(e.stack);
5956
}
6057
}
58+
59+
for (const cleanupFn of extraCleanupFns) {
60+
try {
61+
await cleanupFn();
62+
} catch (e) {
63+
console.error(`Failed cleanup: ${e}`);
64+
if (e instanceof Error) {
65+
console.error(e.stack);
66+
}
67+
}
68+
}
6169
};
6270

6371
// Ensure cleanup logic runs when the evaluation is aborted.

0 commit comments

Comments
 (0)