Skip to content

Commit e49e0d1

Browse files
committed
fix(sqlite-provider): use caller-provided key column name
The SqliteLookupProvider previously hardcoded "row_idx" as the key column name in CREATE TABLE and WHERE clauses. This caused errors when callers used a different key column name (e.g. "_key"). Now derives the key column name from the first field in the provided schema, making the provider work with any key column name.
1 parent 3f303d3 commit e49e0d1

1 file changed

Lines changed: 19 additions & 6 deletions

File tree

src/sqlite_provider.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
//
33
// Stores all non-embedding columns in a local SQLite database (bundled libsqlite3).
44
// Scalar columns map to INTEGER/TEXT/REAL; list columns are serialised as JSON TEXT.
5-
// Lookups use `WHERE row_idx IN (?, ...)` against the INTEGER PRIMARY KEY B-tree.
5+
// Lookups use `WHERE <key_col> IN (?, ...)` against the INTEGER PRIMARY KEY B-tree.
66
//
7-
// Schema: row_idx INTEGER PRIMARY KEY, <col> TEXT/INTEGER/REAL, ...
7+
// Schema: <key_col> INTEGER PRIMARY KEY, <col> TEXT/INTEGER/REAL, ...
8+
//
9+
// The key column name is caller-provided (e.g. "_key") and must match the first
10+
// field in the schema passed to `open_or_build`.
811
//
912
// Persistence: the database is written once to the given path and reused on
1013
// subsequent runs. The first build reads all parquet files and inserts rows
@@ -42,6 +45,7 @@ use crate::lookup::PointLookupProvider;
4245
pub struct SqliteLookupProvider {
4346
schema: SchemaRef,
4447
table_name: String,
48+
key_col: String,
4549
pool: Arc<Mutex<Vec<Connection>>>,
4650
sem: Arc<Semaphore>,
4751
}
@@ -117,6 +121,8 @@ impl SqliteLookupProvider {
117121
schema: SchemaRef,
118122
parquet_col_indices: &[usize],
119123
) -> DFResult<Self> {
124+
// The first field in the schema is the key column (INTEGER PRIMARY KEY).
125+
let key_col = schema.field(0).name().clone();
120126
if pool_size == 0 {
121127
return Err(DataFusionError::Execution(
122128
"pool_size must be at least 1".into(),
@@ -167,6 +173,7 @@ impl SqliteLookupProvider {
167173
Ok(Self {
168174
schema,
169175
table_name: table_name.to_string(),
176+
key_col,
170177
pool: Arc::new(Mutex::new(conns)),
171178
sem: Arc::new(Semaphore::new(pool_size)),
172179
})
@@ -202,6 +209,7 @@ impl PointLookupProvider for SqliteLookupProvider {
202209
let keys_vec = keys.to_vec();
203210
let pool = self.pool.clone();
204211
let table_name = self.table_name.clone();
212+
let key_col = self.key_col.clone();
205213

206214
// Acquire a semaphore permit to bound concurrency to the pool size,
207215
// then run the synchronous SQLite query on a blocking thread.
@@ -227,6 +235,7 @@ impl PointLookupProvider for SqliteLookupProvider {
227235
&keys_vec,
228236
&out_schema,
229237
&table_name,
238+
&key_col,
230239
);
231240
drop(guard); // explicit but not required — Drop handles it
232241
res
@@ -243,6 +252,7 @@ fn execute_query_sync(
243252
keys: &[u64],
244253
out_schema: &SchemaRef,
245254
table_name: &str,
255+
key_col: &str,
246256
) -> DFResult<Vec<RecordBatch>> {
247257
let placeholders = keys.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
248258
// Select only the columns in out_schema (already projection-applied by the
@@ -253,8 +263,9 @@ fn execute_query_sync(
253263
.map(|f| quote_ident(f.name()))
254264
.collect::<Vec<_>>()
255265
.join(", ");
266+
let qk = quote_ident(key_col);
256267
let sql = format!(
257-
"SELECT {col_list} FROM {tn} WHERE row_idx IN ({placeholders}) ORDER BY row_idx",
268+
"SELECT {col_list} FROM {tn} WHERE {qk} IN ({placeholders}) ORDER BY {qk}",
258269
tn = quote_ident(table_name)
259270
);
260271

@@ -586,14 +597,16 @@ fn build_table(
586597
schema: &SchemaRef,
587598
parquet_col_indices: &[usize],
588599
) -> DFResult<()> {
600+
// The first field is the key column (INTEGER PRIMARY KEY).
601+
let key_col_name = schema.field(0).name();
589602
let col_defs = schema
590603
.fields()
591604
.iter()
592605
.map(|f| {
593-
let sql_type = arrow_type_to_sql(f.data_type());
594-
if f.name() == "row_idx" {
595-
"row_idx INTEGER PRIMARY KEY".to_string()
606+
if f.name() == key_col_name {
607+
format!("{} INTEGER PRIMARY KEY", quote_ident(f.name()))
596608
} else {
609+
let sql_type = arrow_type_to_sql(f.data_type());
597610
format!("{} {}", quote_ident(f.name()), sql_type)
598611
}
599612
})

0 commit comments

Comments
 (0)