Commit 42654a4

fix: optimize XProvence batch processing with broadcasting and warnings
- Add broadcasting support: 1 query → N texts (common reranking pattern)
- Replace silent fallback with explicit warning on dimension mismatch
- Use torch.inference_mode() around entire batch for better performance
- Reduce per-item overhead by batching dtype handling and TQDM_DISABLE
- Add per-item error handling with graceful fallback to 0.0 score

Performance improvements:
- Single dtype context switch instead of per-item
- Single inference_mode context for entire batch
- Reduced logging overhead with debug level for per-item details
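The broadcast-and-validate step described in the first two bullets reduces to a small pure function. A minimal sketch of that logic follows; the standalone helper name normalize_raw_inputs is illustrative only, not a name from the diff:

from typing import List, Optional, Tuple

def normalize_raw_inputs(
    raw_queries: List[str], raw_texts: List[str], batch_size: int
) -> Optional[Tuple[List[str], List[str]]]:
    # 1 query → N texts: repeat the single query once per candidate text
    if len(raw_queries) == 1 and len(raw_texts) == batch_size and batch_size > 1:
        raw_queries = raw_queries * batch_size
    # Any remaining mismatch is unrecoverable; the caller warns and falls back
    if len(raw_queries) != batch_size or len(raw_texts) != batch_size:
        return None
    return raw_queries, raw_texts

# Common reranking pattern: one query scored against three passages
assert normalize_raw_inputs(["q"], ["a", "b", "c"], 3) == (["q", "q", "q"], ["a", "b", "c"])
# A 2-query / 3-text batch cannot be reconciled → standard-inference fallback
assert normalize_raw_inputs(["q1", "q2"], ["a", "b", "c"], 3) is None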
1 parent 778caf7 commit 42654a4

File tree

1 file changed: +76 -32 lines

backends/python/server/text_embeddings_server/models/xprovence_model.py

Lines changed: 76 additions & 32 deletions
@@ -164,41 +164,85 @@ def predict(self, batch: PaddedBatch) -> List[Score]:
         Otherwise falls back to standard forward pass.
         """
         batch_size = len(batch)
+        raw_queries = batch.raw_queries or []
+        raw_texts = batch.raw_texts or []

-        # Check if we have raw data for the full batch
-        has_raw_data = (
-            batch.raw_queries is not None
-            and batch.raw_texts is not None
-            and len(batch.raw_queries) == batch_size
-            and len(batch.raw_texts) == batch_size
-        )
+        # Broadcasting: 1 query → N texts (common reranking pattern)
+        if len(raw_queries) == 1 and len(raw_texts) == batch_size and batch_size > 1:
+            logger.info(f"XProvence: Broadcasting single query to {batch_size} texts")
+            raw_queries = raw_queries * batch_size

-        logger.info(
-            f"XProvence predict: batch_size={batch_size}, "
-            f"has_raw_queries={batch.raw_queries is not None}, "
-            f"has_raw_texts={batch.raw_texts is not None}, "
-            f"has_full_raw_data={has_raw_data}"
-        )
+        # Check for dimension mismatch with explicit warning
+        if len(raw_queries) != batch_size or len(raw_texts) != batch_size:
+            if raw_queries or raw_texts:
+                logger.warning(
+                    f"XProvence: Dimension mismatch - batch_size={batch_size}, "
+                    f"raw_queries={len(raw_queries)}, raw_texts={len(raw_texts)}. "
+                    f"Falling back to standard inference (no pruned_text)."
+                )
+            return self._predict_standard(batch)
+
+        # Process batch with pruning (optimized)
+        logger.info(f"XProvence: Processing {batch_size} pairs with pruning")
+        return self._predict_batch_with_pruning(raw_queries, raw_texts)
+
+    def _predict_batch_with_pruning(
+        self, raw_queries: List[str], raw_texts: List[str]
+    ) -> List[Score]:
+        """
+        Optimized batch processing with pruning.
+
+        Uses inference_mode and batched dtype handling to reduce per-item overhead.
+        Note: XProvence process() is inherently per-pair for sentence-level analysis.
+        """
+        batch_size = len(raw_queries)
+        results = []
+
+        # Suppress progress bars once for entire batch
+        os.environ["TQDM_DISABLE"] = "1"
+
+        # Use inference_mode for better performance (no grad tracking)
+        with torch.inference_mode():
+            original_dtype = torch.get_default_dtype()
+            torch.set_default_dtype(torch.float32)
+
+            try:
+                for i in range(batch_size):
+                    query = raw_queries[i]
+                    text = raw_texts[i]
+
+                    if not query or not text:
+                        logger.warning(
+                            f"XProvence: Empty query/text at index {i}, score=0.0"
+                        )
+                        results.append(Score(values=[0.0], pruned_text=None))
+                        continue
+
+                    try:
+                        output = self.model.process(
+                            query,
+                            text,
+                            threshold=self.threshold,
+                            always_select_title=self.always_select_title,
+                        )
+
+                        score = float(output["reranking_score"])
+                        pruned = output["pruned_context"]
+
+                        logger.debug(
+                            f"XProvence [{i}]: score={score:.4f}, "
+                            f"len={len(text)}→{len(pruned)}"
+                        )
+                        results.append(Score(values=[score], pruned_text=pruned))
+
+                    except Exception as e:
+                        logger.error(f"XProvence process() failed at index {i}: {e}")
+                        results.append(Score(values=[0.0], pruned_text=None))
+
+            finally:
+                torch.set_default_dtype(original_dtype)

-        if has_raw_data:
-            logger.info(f"XProvence: Processing batch of {batch_size} with pruning")
-            results = []
-            for i in range(batch_size):
-                query = batch.raw_queries[i]
-                text = batch.raw_texts[i]
-
-                # Verify we have valid strings (not empty)
-                if query and text:
-                    scores = self._predict_with_pruning(query, text)
-                    results.append(scores[0])
-                else:
-                    # Empty string fallback - use standard forward pass result
-                    logger.warning(f"XProvence: Empty query/text at index {i}, using 0.0")
-                    results.append(Score(values=[0.0], pruned_text=None))
-            return results
-
-        logger.info("XProvence: Using standard forward pass (no raw_queries/raw_texts)")
-        return self._predict_standard(batch)
+        return results

     def _predict_with_pruning(self, raw_query: str, raw_text: str) -> List[Score]:
         """
