
Commit 6973570

Merge branch 'main' into main
2 parents: 857952d + b4475b0

File tree: 21 files changed, +486 −940 lines


.github/workflows/general.yml

Lines changed: 3 additions & 3 deletions

@@ -150,10 +150,10 @@ jobs:
         continue-on-error: ${{ github.event.pull_request.head.repo.full_name != github.repository || github.actor == 'dependabot[bot]' }}

       # We deliberately install our MSRV here (rather than 'stable') to ensure that everything compiles with that version
-      - name: Install Rust 1.85.0
+      - name: Install Rust 1.86.0
         run: |
-          rustup install 1.85.0 --component clippy,rustfmt
-          rustup default 1.85.0
+          rustup install 1.86.0 --component clippy,rustfmt
+          rustup default 1.86.0

       - name: Print Rust version
         run: rustc --version

Cargo.toml

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ resolver = "2"

 [workspace.package]
 version = "2025.7.4"
-rust-version = "1.85.0"
+rust-version = "1.86.0"
 license = "Apache-2.0"

 [workspace.dependencies]

evaluations/src/lib.rs

Lines changed: 1 addition & 1 deletion

@@ -465,7 +465,7 @@ impl ThrottledTensorZeroClient {
     }

     async fn inference(&self, params: ClientInferenceParams) -> Result<InferenceOutput> {
-        let _permit = self.semaphore.acquire().await;
+        let _permit = self.semaphore.acquire().await?;
        let inference_output = self.client.inference(params).await?;
        Ok(inference_output)
    }
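
The one-character change above matters because `tokio::sync::Semaphore::acquire` returns a `Result`: binding it without `?` still held the permit (it lived inside the `Ok`), but silently discarded a possible `AcquireError` if the semaphore was ever closed. A minimal, self-contained sketch of the throttling pattern (illustrative names, not TensorZero's actual client):

    use std::sync::Arc;
    use tokio::sync::{AcquireError, Semaphore};

    // Hypothetical stand-in for the rate-limited work; not part of this PR.
    async fn throttled_call(semaphore: &Semaphore) -> Result<u32, AcquireError> {
        // The permit caps concurrency for as long as it is held;
        // `?` surfaces AcquireError instead of leaving an uninspected Result.
        let _permit = semaphore.acquire().await?;
        Ok(42) // ... perform the throttled request here ...
    }

    #[tokio::main]
    async fn main() {
        let semaphore = Arc::new(Semaphore::new(8)); // at most 8 concurrent calls
        assert_eq!(throttled_call(&semaphore).await.unwrap(), 42);
    }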

tensorzero-core/src/clickhouse/migration_manager/migrations/migration_0034.rs

Lines changed: 186 additions & 0 deletions

@@ -0,0 +1,186 @@
+use super::check_table_exists;
+use crate::clickhouse::migration_manager::migration_trait::Migration;
+use crate::clickhouse::ClickHouseConnectionInfo;
+use crate::error::{Error, ErrorDetails};
+use crate::serde_util::deserialize_u64;
+use async_trait::async_trait;
+use serde::Deserialize;
+use std::time::Duration;
+
+/// This migration adds a `CumulativeUsage` table and a `CumulativeUsageView` materialized view.
+/// These allow the sums of tokens in the `ModelInference` table to be amortized and
+/// looked up as needed.
+pub struct Migration0034<'a> {
+    pub clickhouse: &'a ClickHouseConnectionInfo,
+}
+
+const MIGRATION_ID: &str = "0034";
+
+#[async_trait]
+impl Migration for Migration0034<'_> {
+    async fn can_apply(&self) -> Result<(), Error> {
+        if !check_table_exists(self.clickhouse, "ModelInference", MIGRATION_ID).await? {
+            return Err(Error::new(ErrorDetails::ClickHouseMigration {
+                id: MIGRATION_ID.to_string(),
+                message: "ModelInference table does not exist".to_string(),
+            }));
+        }
+        Ok(())
+    }
+
+    async fn should_apply(&self) -> Result<bool, Error> {
+        // If either the CumulativeUsage table or the CumulativeUsageView view doesn't exist, we need to create it
+        if !check_table_exists(self.clickhouse, "CumulativeUsage", MIGRATION_ID).await? {
+            return Ok(true);
+        }
+        if !check_table_exists(self.clickhouse, "CumulativeUsageView", MIGRATION_ID).await? {
+            return Ok(true);
+        }
+        Ok(false)
+    }
+
+    async fn apply(&self, clean_start: bool) -> Result<(), Error> {
+        let view_offset = Duration::from_secs(15);
+        let view_timestamp_nanos = (std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .map_err(|e| {
+                Error::new(ErrorDetails::ClickHouseMigration {
+                    id: MIGRATION_ID.to_string(),
+                    message: e.to_string(),
+                })
+            })?
+            + view_offset)
+            .as_nanos();
+        self.clickhouse
+            .run_query_synchronous_no_params(
+                r"CREATE TABLE IF NOT EXISTS CumulativeUsage (
+                    type LowCardinality(String),
+                    count UInt64,
+                )
+                ENGINE = SummingMergeTree
+                ORDER BY type;"
+                    .to_string(),
+            )
+            .await?;
+
+        // Create the materialized view for the CumulativeUsage table from ModelInference.
+        // If we are not doing a clean start, we need to add a WHERE clause to the view to only include rows created
+        // after the view_timestamp.
+        let view_where_clause = if clean_start {
+            String::new()
+        } else {
+            format!("AND UUIDv7ToDateTime(id) >= fromUnixTimestamp64Nano({view_timestamp_nanos})")
+        };
+        let query = format!(
+            r"
+            CREATE MATERIALIZED VIEW IF NOT EXISTS CumulativeUsageView
+            TO CumulativeUsage
+            AS
+                SELECT
+                    tupleElement(t, 1) AS type,
+                    tupleElement(t, 2) AS count
+                FROM (
+                    SELECT
+                        arrayJoin([
+                            tuple('input_tokens', input_tokens),
+                            tuple('output_tokens', output_tokens),
+                            tuple('model_inferences', 1)
+                        ]) AS t
+                    FROM ModelInference
+                    WHERE input_tokens IS NOT NULL
+                    {view_where_clause}
+                )
+            "
+        );
+        let _ = self
+            .clickhouse
+            .run_query_synchronous_no_params(query)
+            .await?;
+
+        // If we are not clean starting, we must backfill this table
+        if !clean_start {
+            tokio::time::sleep(view_offset).await;
+            // Check that the materialized view we wrote is still the one installed.
+            // If it is, we should compute the backfilled sums and add them to the table.
+            // Otherwise, we warn that our view was not written (probably because a concurrent client got there first)
+            // and conclude the migration.
+            let create_table = self
+                .clickhouse
+                .run_query_synchronous_no_params(
+                    "SHOW CREATE TABLE CumulativeUsageView".to_string(),
+                )
+                .await?
+                .response;
+            let view_timestamp_nanos_string = view_timestamp_nanos.to_string();
+            if !create_table.contains(&view_timestamp_nanos_string) {
+                tracing::warn!("Materialized view `CumulativeUsageView` was not written because it was recently created. This is likely due to a concurrent migration. Unless the other migration failed, no action is required.");
+                return Ok(());
+            }
+
+            tracing::info!("Running backfill of CumulativeUsage");
+            let query = format!(
+                r"
+                SELECT
+                    sum(ifNull(input_tokens, 0)) as total_input_tokens,
+                    sum(ifNull(output_tokens, 0)) as total_output_tokens,
+                    COUNT(input_tokens) as total_count
+                FROM ModelInference
+                WHERE UUIDv7ToDateTime(id) < fromUnixTimestamp64Nano({view_timestamp_nanos})
+                FORMAT JsonEachRow;
+                "
+            );
+            let response = self
+                .clickhouse
+                .run_query_synchronous_no_params(query)
+                .await?;
+            let trimmed_response = response.response.trim();
+            let parsed_response =
+                serde_json::from_str::<CountResponse>(trimmed_response).map_err(|e| {
+                    Error::new(ErrorDetails::ClickHouseDeserialization {
+                        message: format!("Failed to deserialize count query: {e}"),
+                    })
+                })?;
+            let CountResponse {
+                total_input_tokens,
+                total_output_tokens,
+                total_count,
+            } = parsed_response;
+
+            let write_query = format!(
+                r"
+                INSERT INTO CumulativeUsage (type, count) VALUES
+                    ('input_tokens', {total_input_tokens}),
+                    ('output_tokens', {total_output_tokens}),
+                    ('model_inferences', {total_count})
+                "
+            );
+            self.clickhouse
+                .run_query_synchronous_no_params(write_query)
+                .await?;
+        }
+
+        Ok(())
+    }
+
+    fn rollback_instructions(&self) -> String {
+        r"
+        DROP TABLE CumulativeUsageView;
+        DROP TABLE CumulativeUsage;"
+            .to_string()
+    }
+
+    async fn has_succeeded(&self) -> Result<bool, Error> {
+        let should_apply = self.should_apply().await?;
+        Ok(!should_apply)
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct CountResponse {
+    #[serde(deserialize_with = "deserialize_u64")]
+    total_input_tokens: u64,
+    #[serde(deserialize_with = "deserialize_u64")]
+    total_output_tokens: u64,
+    #[serde(deserialize_with = "deserialize_u64")]
+    total_count: u64,
+}
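
Since `SummingMergeTree` only collapses rows sharing an `ORDER BY` key during background merges, readers of `CumulativeUsage` should still aggregate at query time. A hypothetical read-path sketch in the same style as the migration's queries (this query is not part of the commit; `clickhouse` is assumed to be a `ClickHouseConnectionInfo` as above):

    // Sum any not-yet-merged rows per usage type at read time.
    let usage_query = r"
        SELECT type, sum(count) AS count
        FROM CumulativeUsage
        GROUP BY type
        FORMAT JsonEachRow;"
        .to_string();
    let totals = clickhouse
        .run_query_synchronous_no_params(usage_query)
        .await?
        .response;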

tensorzero-core/src/clickhouse/migration_manager/migrations/mod.rs

Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@ pub mod migration_0030;
 pub mod migration_0031;
 pub mod migration_0032;
 pub mod migration_0033;
+pub mod migration_0034;

 /// Returns true if the table exists, false if it does not
 /// Errors if the query fails

tensorzero-core/src/clickhouse/migration_manager/mod.rs

Lines changed: 3 additions & 1 deletion

@@ -36,11 +36,12 @@ use migrations::migration_0030::Migration0030;
 use migrations::migration_0031::Migration0031;
 use migrations::migration_0032::Migration0032;
 use migrations::migration_0033::Migration0033;
+use migrations::migration_0034::Migration0034;
 use serde::{Deserialize, Serialize};

 /// This must match the number of migrations returned by `make_all_migrations` - the tests
 /// will panic if they don't match.
-pub const NUM_MIGRATIONS: usize = 27;
+pub const NUM_MIGRATIONS: usize = 28;

 /// Constructs (but does not run) a vector of all our database migrations.
 /// This is the single source of truth for all migrations - it's used during startup to migrate
@@ -89,6 +90,7 @@ pub fn make_all_migrations<'a>(
         Box::new(Migration0031 { clickhouse }),
         Box::new(Migration0032 { clickhouse }),
         Box::new(Migration0033 { clickhouse }),
+        Box::new(Migration0034 { clickhouse }),
     ];
     assert_eq!(
         migrations.len(),

tensorzero-core/src/clickhouse/mod.rs

Lines changed: 1 addition & 0 deletions

@@ -702,6 +702,7 @@ pub struct ExternalDataInfo {
     pub data: String, // Must be valid ClickHouse data in the given format
 }

+#[derive(Debug)]
 pub struct ClickHouseResponse {
     pub response: String,
     pub metadata: ClickHouseResponseMetadata,
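
Deriving `Debug` lets callers capture the whole response in logs. A hypothetical usage (not part of this commit), relying on `tracing`'s `?`-sigil, which formats a field with its `Debug` impl:

    let response = clickhouse.run_query_synchronous_no_params(query).await?;
    tracing::debug!(?response, "ClickHouse replied"); // needs #[derive(Debug)] on ClickHouseResponse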

tensorzero-core/src/config_parser/mod.rs

Lines changed: 16 additions & 12 deletions

@@ -472,25 +472,29 @@ impl Config {
             .map(|(name, config)| config.load(name.clone()).map(|c| (name, Arc::new(c))))
             .collect::<Result<HashMap<String, Arc<StaticToolConfig>>, Error>>()?;

-        let models = uninitialized_config
-            .models
-            .into_iter()
-            .map(|(name, config)| {
+        let models = try_join_all(uninitialized_config.models.into_iter().map(
+            |(name, config)| async {
                 config
                     .load(&name, &uninitialized_config.provider_types)
+                    .await
                     .map(|c| (name, c))
-            })
-            .collect::<Result<HashMap<_, _>, _>>()?;
+            },
+        ))
+        .await?
+        .into_iter()
+        .collect::<HashMap<_, _>>();

-        let embedding_models = uninitialized_config
-            .embedding_models
-            .into_iter()
-            .map(|(name, config)| {
+        let embedding_models = try_join_all(uninitialized_config.embedding_models.into_iter().map(
+            |(name, config)| async {
                 config
                     .load(&uninitialized_config.provider_types)
+                    .await
                     .map(|c| (name, c))
-            })
-            .collect::<Result<HashMap<_, _>, _>>()?;
+            },
+        ))
+        .await?
+        .into_iter()
+        .collect::<HashMap<_, _>>();

         let object_store_info = ObjectStoreInfo::new(uninitialized_config.object_storage)?;
         let optimizers = try_join_all(
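
This file (and `tensorzero-core/src/embeddings.rs` below) swaps a sequential `collect::<Result<...>>` for `futures::future::try_join_all`, so the now-async `load` calls run concurrently and the first error aborts the whole batch. A minimal, self-contained sketch of the pattern with illustrative names (`load_one` and `load_all` are not TensorZero functions):

    use futures::future::try_join_all;
    use std::collections::HashMap;

    // Stand-in for an async, fallible `config.load(...)`.
    async fn load_one(name: String, raw: u32) -> Result<(String, u32), String> {
        Ok((name, raw * 2))
    }

    async fn load_all(configs: HashMap<String, u32>) -> Result<HashMap<String, u32>, String> {
        let loaded = try_join_all(configs.into_iter().map(|(name, raw)| load_one(name, raw)))
            .await? // drives all futures concurrently; the first Err propagates
            .into_iter()
            .collect::<HashMap<_, _>>();
        Ok(loaded)
    }

    #[tokio::main]
    async fn main() {
        let configs = HashMap::from([("a".to_string(), 1), ("b".to_string(), 2)]);
        let loaded = load_all(configs).await.expect("a load failed");
        assert_eq!(loaded["b"], 4);
    }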

tensorzero-core/src/embeddings.rs

Lines changed: 14 additions & 11 deletions

@@ -20,6 +20,7 @@ use crate::{
     model::ProviderConfig,
     providers::openai::OpenAIProvider,
 };
+use futures::future::try_join_all;
 use reqwest::Client;
 use serde::{Deserialize, Serialize};
 use tracing::instrument;
@@ -68,15 +69,17 @@ pub struct UninitializedEmbeddingModelConfig {
 }

 impl UninitializedEmbeddingModelConfig {
-    pub fn load(self, provider_types: &ProviderTypesConfig) -> Result<EmbeddingModelConfig, Error> {
-        let providers = self
-            .providers
-            .into_iter()
-            .map(|(name, config)| {
-                let provider_config = config.load(provider_types)?;
-                Ok((name, provider_config))
-            })
-            .collect::<Result<HashMap<_, _>, Error>>()?;
+    pub async fn load(
+        self,
+        provider_types: &ProviderTypesConfig,
+    ) -> Result<EmbeddingModelConfig, Error> {
+        let providers = try_join_all(self.providers.into_iter().map(|(name, config)| async {
+            let provider_config = config.load(provider_types).await?;
+            Ok::<_, Error>((name, provider_config))
+        }))
+        .await?
+        .into_iter()
+        .collect::<HashMap<_, _>>();
         Ok(EmbeddingModelConfig {
             routing: self.routing,
             providers,
@@ -318,11 +321,11 @@ pub struct UninitializedEmbeddingProviderConfig {
 }

 impl UninitializedEmbeddingProviderConfig {
-    pub fn load(
+    pub async fn load(
         self,
         provider_types: &ProviderTypesConfig,
     ) -> Result<EmbeddingProviderConfig, Error> {
-        let provider_config = self.config.load(provider_types)?;
+        let provider_config = self.config.load(provider_types).await?;
         Ok(match provider_config {
             ProviderConfig::OpenAI(provider) => EmbeddingProviderConfig::OpenAI(provider),
             _ => {
