Skip to content

Commit c09b8d3

Browse files
authored
Allow variant configuration to be passed dynamically @ inference (tensorzero#2931)
* refactored InferenceConfig & sample_variant * wip * factored preparation of variants into a helper & added parameter to inference handler * everything but conversion is implemented * fixed container build thing * wip * huh? * a file * wip * compiler errors pass * removed some unnecessary Cows and added typescript derives * check for dryrun * added e2e tests for dynamic variant * fixed issues with merge * built bindings * updated name of dynamic variant * improved test * fixed test * fixed bad e2e test
1 parent 8f74825 commit c09b8d3

39 files changed

+822
-311
lines changed

clients/python/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@ impl BaseTensorZeroGateway {
498498
include_original_response,
499499
extra_body,
500500
extra_headers,
501+
internal_dynamic_variant_config: None,
501502
})
502503
}
503504
}

clients/rust/src/client_inference_params.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ use serde::{Deserialize, Serialize};
55
use serde_json::Value;
66
use tensorzero_core::{
77
cache::CacheParamsOptions,
8+
config_parser::UninitializedVariantInfo,
89
endpoints::inference::{InferenceParams, Params},
910
error::Error,
10-
inference::types::extra_body::UnfilteredInferenceExtraBody,
11-
inference::types::extra_headers::UnfilteredInferenceExtraHeaders,
12-
inference::types::{Input, InputMessage, InputMessageContent},
11+
inference::types::{
12+
extra_body::UnfilteredInferenceExtraBody, extra_headers::UnfilteredInferenceExtraHeaders,
13+
Input, InputMessage, InputMessageContent,
14+
},
1315
tool::DynamicToolParams,
1416
};
1517
use uuid::Uuid;
@@ -70,6 +72,7 @@ pub struct ClientInferenceParams {
7072
pub extra_body: UnfilteredInferenceExtraBody,
7173
#[serde(default)]
7274
pub extra_headers: UnfilteredInferenceExtraHeaders,
75+
pub internal_dynamic_variant_config: Option<UninitializedVariantInfo>,
7376
}
7477

7578
impl TryFrom<ClientInferenceParams> for Params {
@@ -112,6 +115,7 @@ impl TryFrom<ClientInferenceParams> for Params {
112115
include_original_response: this.include_original_response,
113116
extra_body: this.extra_body,
114117
extra_headers: this.extra_headers,
118+
internal_dynamic_variant_config: this.internal_dynamic_variant_config,
115119
})
116120
}
117121
}
@@ -139,6 +143,7 @@ fn assert_params_match(client_params: ClientInferenceParams) {
139143
include_original_response,
140144
extra_body,
141145
extra_headers,
146+
internal_dynamic_variant_config,
142147
} = client_params;
143148
let _ = Params {
144149
function_name,
@@ -158,6 +163,7 @@ fn assert_params_match(client_params: ClientInferenceParams) {
158163
include_original_response,
159164
extra_body,
160165
extra_headers,
166+
internal_dynamic_variant_config,
161167
};
162168
}
163169

evaluations/src/evaluators/llm_judge/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ pub async fn run_llm_judge_evaluator(
123123
cache_options: get_cache_options(inference_cache),
124124
extra_body: Default::default(),
125125
extra_headers: Default::default(),
126+
internal_dynamic_variant_config: None,
126127
};
127128
let result = clients.tensorzero_client.inference(params).await?;
128129
let response = match result {

evaluations/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ async fn infer_datapoint(params: InferDatapointParams<'_>) -> Result<InferenceRe
416416
internal: true,
417417
extra_body: Default::default(),
418418
extra_headers: Default::default(),
419+
internal_dynamic_variant_config: None,
419420
};
420421
debug!("Making inference request");
421422
let inference_result = clients.tensorzero_client.inference(params).await?;

internal/tensorzero-node/lib/bindings/TomlRelativePath.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,10 @@
66
* all paths (e.g. `system_schema`) as `TomlRelativePath`s, which will
77
* track the original `.toml` file in order to perform correct relative path resolution.
88
*/
9-
export type TomlRelativePath = { __tensorzero_remapped_path: string };
9+
export type TomlRelativePath = {
10+
__tensorzero_remapped_path: string;
11+
/**
12+
* This should be set for dynamic variants to indicate what the file contents would have been at this remapped path.
13+
*/
14+
__data: string | null;
15+
};
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
2+
import type { ExtraBodyConfig } from "./ExtraBodyConfig";
3+
import type { ExtraHeadersConfig } from "./ExtraHeadersConfig";
4+
import type { JsonMode } from "./JsonMode";
5+
import type { RetryConfig } from "./RetryConfig";
6+
import type { TomlRelativePath } from "./TomlRelativePath";
7+
8+
export type UninitializedBestOfNEvaluatorConfig = {
9+
weight: number | null;
10+
model: string;
11+
system_template: TomlRelativePath | null;
12+
user_template: TomlRelativePath | null;
13+
assistant_template: TomlRelativePath | null;
14+
temperature: number | null;
15+
top_p: number | null;
16+
max_tokens: number | null;
17+
presence_penalty: number | null;
18+
frequency_penalty: number | null;
19+
seed: number | null;
20+
stop_sequences: Array<string> | null;
21+
json_mode: JsonMode | null;
22+
retries: RetryConfig;
23+
extra_body: ExtraBodyConfig | null;
24+
extra_headers: ExtraHeadersConfig | null;
25+
};
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
2+
import type { UninitializedBestOfNEvaluatorConfig } from "./UninitializedBestOfNEvaluatorConfig";
3+
4+
export type UninitializedBestOfNSamplingConfig = {
5+
weight: number | null;
6+
timeout_s: number;
7+
candidates: Array<string>;
8+
evaluator: UninitializedBestOfNEvaluatorConfig;
9+
};
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
2+
import type { ExtraBodyConfig } from "./ExtraBodyConfig";
3+
import type { ExtraHeadersConfig } from "./ExtraHeadersConfig";
4+
import type { JsonMode } from "./JsonMode";
5+
import type { RetryConfig } from "./RetryConfig";
6+
import type { TomlRelativePath } from "./TomlRelativePath";
7+
8+
export type UninitializedChainOfThoughtConfig = {
9+
weight: number | null;
10+
model: string;
11+
system_template: TomlRelativePath | null;
12+
user_template: TomlRelativePath | null;
13+
assistant_template: TomlRelativePath | null;
14+
temperature: number | null;
15+
top_p: number | null;
16+
max_tokens: number | null;
17+
presence_penalty: number | null;
18+
frequency_penalty: number | null;
19+
seed: number | null;
20+
stop_sequences: Array<string> | null;
21+
json_mode: JsonMode | null;
22+
retries: RetryConfig;
23+
extra_body: ExtraBodyConfig | null;
24+
extra_headers: ExtraHeadersConfig | null;
25+
};
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
2+
import type { ExtraBodyConfig } from "./ExtraBodyConfig";
3+
import type { ExtraHeadersConfig } from "./ExtraHeadersConfig";
4+
import type { JsonMode } from "./JsonMode";
5+
import type { RetryConfig } from "./RetryConfig";
6+
import type { TomlRelativePath } from "./TomlRelativePath";
7+
8+
export type UninitializedChatCompletionConfig = {
9+
weight: number | null;
10+
model: string;
11+
system_template: TomlRelativePath | null;
12+
user_template: TomlRelativePath | null;
13+
assistant_template: TomlRelativePath | null;
14+
temperature: number | null;
15+
top_p: number | null;
16+
max_tokens: number | null;
17+
presence_penalty: number | null;
18+
frequency_penalty: number | null;
19+
seed: number | null;
20+
stop_sequences: Array<string> | null;
21+
json_mode: JsonMode | null;
22+
retries: RetryConfig;
23+
extra_body: ExtraBodyConfig | null;
24+
extra_headers: ExtraHeadersConfig | null;
25+
};
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
2+
import type { ExtraBodyConfig } from "./ExtraBodyConfig";
3+
import type { ExtraHeadersConfig } from "./ExtraHeadersConfig";
4+
import type { JsonMode } from "./JsonMode";
5+
import type { RetryConfig } from "./RetryConfig";
6+
import type { TomlRelativePath } from "./TomlRelativePath";
7+
8+
export type UninitializedDiclConfig = {
9+
weight: number | null;
10+
embedding_model: string;
11+
k: number;
12+
model: string;
13+
system_instructions: TomlRelativePath | null;
14+
temperature: number | null;
15+
top_p: number | null;
16+
stop_sequences: Array<string> | null;
17+
presence_penalty: number | null;
18+
frequency_penalty: number | null;
19+
max_tokens: number | null;
20+
seed: number | null;
21+
json_mode: JsonMode | null;
22+
extra_body: ExtraBodyConfig | null;
23+
retries: RetryConfig;
24+
extra_headers: ExtraHeadersConfig | null;
25+
};

0 commit comments

Comments
 (0)