Commit 778caf7
feat: support batch processing with pruned_text for multiple texts
Previously, only the first raw_query/raw_text was sent to the Python backend, so process() was only called when batch_size == 1. Now all pairs are sent.

Changes:
- embed.proto: change to repeated string raw_queries/raw_texts
- grpc-client: accept Vec<String> instead of Option<String>
- backends/python/src/lib.rs: send all raw_queries/raw_texts from the batch
- types.py: extract lists from the proto repeated fields
- xprovence_model.py: iterate over the batch and call process() for each pair

/rerank with multiple texts now returns pruned_text for each result.
1 parent cc0b4e5 commit 778caf7
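As a usage sketch of what this enables on the HTTP side (assumptions: a local deployment on port 8080, the standard TEI /rerank request shape, and pruned_text as the per-result field this fork adds):

# Not part of the commit: a minimal client-side check that every text in a
# multi-text /rerank request now comes back with its own pruned_text.
import requests

resp = requests.post(
    "http://localhost:8080/rerank",  # assumed local deployment
    json={
        "query": "What is the capital of France?",
        "texts": [
            "Paris is the capital of France. The Eiffel Tower is there.",
            "Berlin is the capital of Germany. Bread is a staple food.",
        ],
    },
)
for item in resp.json():
    # Before this commit, only a batch of size 1 got a pruned_text back.
    print(item["index"], item["score"], item.get("pruned_text"))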

5 files changed: +64, -45

backends/grpc-client/src/client.rs (+6, -6)

@@ -59,8 +59,8 @@ impl Client {
             position_ids,
             max_length,
             cu_seq_lengths,
-            raw_query: None,
-            raw_text: None,
+            raw_queries: vec![],
+            raw_texts: vec![],
         })
         .inject_context();
         let response = self.stub.embed(request).await?.into_inner();
@@ -75,17 +75,17 @@ impl Client {
         position_ids: Vec<u32>,
         cu_seq_lengths: Vec<u32>,
         max_length: u32,
-        raw_query: Option<String>,
-        raw_text: Option<String>,
+        raw_queries: Vec<String>,
+        raw_texts: Vec<String>,
     ) -> Result<Vec<Score>> {
         let request = tonic::Request::new(EmbedRequest {
             input_ids,
             token_type_ids,
             position_ids,
             max_length,
             cu_seq_lengths,
-            raw_query,
-            raw_text,
+            raw_queries,
+            raw_texts,
         })
         .inject_context();
         let response = self.stub.predict(request).await?.into_inner();

backends/proto/embed.proto (+4, -4)

@@ -21,10 +21,10 @@ message EmbedRequest {
     repeated uint32 cu_seq_lengths = 4;
     /// Length of the longest request
     uint32 max_length = 5;
-    /// XProvence: raw query text for context pruning
-    optional string raw_query = 6;
-    /// XProvence: raw context text for context pruning
-    optional string raw_text = 7;
+    /// XProvence: raw query texts for context pruning (one per batch item)
+    repeated string raw_queries = 6;
+    /// XProvence: raw context texts for context pruning (one per batch item)
+    repeated string raw_texts = 7;
 }
 
 message Embedding {
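A side note on the field change (an observation about proto3 encoding, not something this commit verifies): because raw_queries and raw_texts reuse field numbers 6 and 7, and a singular string shares the length-delimited wire type with a non-packed repeated string, a single raw_query written by an older client still parses as a one-element raw_queries list.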

backends/python/server/text_embeddings_server/models/types.py (+10, -19)

@@ -3,7 +3,8 @@
 import torch
 
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+from typing import List, Optional
 from opentelemetry import trace
 
 from text_embeddings_server.pb import embed_pb2
@@ -36,9 +37,9 @@ class PaddedBatch(Batch):
     token_type_ids: torch.Tensor
     position_ids: torch.Tensor
     attention_mask: torch.Tensor
-    # XProvence: raw text for context pruning
-    raw_query: str = None
-    raw_text: str = None
+    # XProvence: raw texts for context pruning (one per batch item)
+    raw_queries: Optional[List[str]] = None
+    raw_texts: Optional[List[str]] = None
 
     @classmethod
     @tracer.start_as_current_span("from_pb")
@@ -80,27 +81,17 @@ def from_pb(
         # Move padded tensors all at once
         all_tensors = all_tensors.to(device)
 
-        # XProvence: Extract raw text if present in proto
-        # Use HasField for proto3 optional fields to properly detect if they were set
-        raw_query = None
-        raw_text = None
-        if hasattr(pb, 'HasField'):
-            if pb.HasField('raw_query'):
-                raw_query = pb.raw_query
-            if pb.HasField('raw_text'):
-                raw_text = pb.raw_text
-        else:
-            # Fallback for older proto versions
-            raw_query = pb.raw_query if pb.raw_query else None
-            raw_text = pb.raw_text if pb.raw_text else None
+        # XProvence: Extract repeated raw_queries/raw_texts from proto
+        raw_queries = list(pb.raw_queries) if pb.raw_queries else None
+        raw_texts = list(pb.raw_texts) if pb.raw_texts else None
 
         return PaddedBatch(
             input_ids=all_tensors[0],
             token_type_ids=all_tensors[1],
             position_ids=all_tensors[2],
             attention_mask=all_tensors[3],
-            raw_query=raw_query,
-            raw_text=raw_text,
+            raw_queries=raw_queries,
+            raw_texts=raw_texts,
         )
 
     def __len__(self):
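A small sanity sketch of what from_pb now receives, assuming the regenerated embed_pb2 stubs are importable; every other EmbedRequest field is elided:

# Illustrative only: exercising the new repeated fields on the generated stubs.
from text_embeddings_server.pb import embed_pb2

pb = embed_pb2.EmbedRequest(
    raw_queries=["capital of france?", "capital of germany?"],
    raw_texts=["Paris is the capital of France.", "Berlin is the capital of Germany."],
)

# An unset repeated field is an empty, falsy container, so from_pb maps it to
# None; a populated one becomes a plain Python list, one entry per batch item.
assert list(pb.raw_queries) == ["capital of france?", "capital of germany?"]
assert list(embed_pb2.EmbedRequest().raw_texts) == []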

backends/python/server/text_embeddings_server/models/xprovence_model.py (+31, -11)

@@ -159,25 +159,45 @@ def predict(self, batch: PaddedBatch) -> List[Score]:
         """
         XProvence prediction with context pruning support.
 
-        For single-item batches with raw_query/raw_text available,
-        uses XProvence's process() method for sentence-level pruning.
+        For batches with raw_queries/raw_texts available (one per item),
+        uses XProvence's process() method for sentence-level pruning on each pair.
         Otherwise falls back to standard forward pass.
         """
         batch_size = len(batch)
 
-        # Debug: log raw_query/raw_text availability
-        has_query = batch.raw_query is not None
-        has_text = batch.raw_text is not None
+        # Check if we have raw data for the full batch
+        has_raw_data = (
+            batch.raw_queries is not None
+            and batch.raw_texts is not None
+            and len(batch.raw_queries) == batch_size
+            and len(batch.raw_texts) == batch_size
+        )
+
         logger.info(
             f"XProvence predict: batch_size={batch_size}, "
-            f"has_raw_query={has_query}, has_raw_text={has_text}"
+            f"has_raw_queries={batch.raw_queries is not None}, "
+            f"has_raw_texts={batch.raw_texts is not None}, "
+            f"has_full_raw_data={has_raw_data}"
         )
 
-        if batch_size == 1 and batch.raw_query and batch.raw_text:
-            logger.info("XProvence: Using process() for context pruning")
-            return self._predict_with_pruning(batch.raw_query, batch.raw_text)
-
-        logger.info("XProvence: Using standard forward pass (no raw_query/raw_text)")
+        if has_raw_data:
+            logger.info(f"XProvence: Processing batch of {batch_size} with pruning")
+            results = []
+            for i in range(batch_size):
+                query = batch.raw_queries[i]
+                text = batch.raw_texts[i]
+
+                # Verify we have valid strings (not empty)
+                if query and text:
+                    scores = self._predict_with_pruning(query, text)
+                    results.append(scores[0])
+                else:
+                    # Empty string fallback - use standard forward pass result
+                    logger.warning(f"XProvence: Empty query/text at index {i}, using 0.0")
+                    results.append(Score(values=[0.0], pruned_text=None))
+            return results
+
+        logger.info("XProvence: Using standard forward pass (no raw_queries/raw_texts)")
         return self._predict_standard(batch)
 
     def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]:
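The diff only shows the call site of _predict_with_pruning; below is a plausible sketch of its per-pair behavior, assuming XProvence exposes a Provence-style process(question, context) that returns a pruned context and a reranking score (the self.model attribute and the return keys are assumptions). This would also explain why the batch loop above takes scores[0].

# Hypothetical sketch, not part of this commit: one (query, text) pair in,
# one Score out.
def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]:
    out = self.model.process(raw_query, raw_text)  # assumed XProvence API
    return [
        Score(
            values=[float(out["reranking_score"])],
            pruned_text=out["pruned_context"],
        )
    ]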

backends/python/src/lib.rs (+13, -5)

@@ -109,9 +109,17 @@ impl Backend for PythonBackend {
         }
         let batch_size = batch.len();
 
-        // XProvence: Get first raw query/text from batch (for single request)
-        let raw_query = batch.raw_queries.first().cloned().flatten();
-        let raw_text = batch.raw_texts.first().cloned().flatten();
+        // XProvence: Collect all raw queries/texts for the batch (one per item)
+        let raw_queries: Vec<String> = batch
+            .raw_queries
+            .into_iter()
+            .map(|q| q.unwrap_or_default())
+            .collect();
+        let raw_texts: Vec<String> = batch
+            .raw_texts
+            .into_iter()
+            .map(|t| t.unwrap_or_default())
+            .collect();
 
         let results = self
             .tokio_runtime
@@ -121,8 +129,8 @@ impl Backend for PythonBackend {
                 batch.position_ids,
                 batch.cumulative_seq_lengths,
                 batch.max_length,
-                raw_query,
-                raw_text,
+                raw_queries,
+                raw_texts,
             ))
             .map_err(|err| BackendError::Inference(err.to_string()))?;
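Note the interplay with the Python side above: unwrap_or_default() turns a missing per-item query or text into an empty string, and the predict() loop then skips pruning for that index, logging a warning and returning Score(values=[0.0], pruned_text=None) for it.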
