@@ -164,41 +164,85 @@ def predict(self, batch: PaddedBatch) -> List[Score]:
164164 Otherwise falls back to standard forward pass.
165165 """
166166 batch_size = len (batch )
167+ raw_queries = batch .raw_queries or []
168+ raw_texts = batch .raw_texts or []
167169
168- # Check if we have raw data for the full batch
169- has_raw_data = (
170- batch .raw_queries is not None
171- and batch .raw_texts is not None
172- and len (batch .raw_queries ) == batch_size
173- and len (batch .raw_texts ) == batch_size
174- )
170+ # Broadcasting: 1 query → N texts (common reranking pattern)
171+ if len (raw_queries ) == 1 and len (raw_texts ) == batch_size and batch_size > 1 :
172+ logger .info (f"XProvence: Broadcasting single query to { batch_size } texts" )
173+ raw_queries = raw_queries * batch_size
175174
176- logger .info (
177- f"XProvence predict: batch_size={ batch_size } , "
178- f"has_raw_queries={ batch .raw_queries is not None } , "
179- f"has_raw_texts={ batch .raw_texts is not None } , "
180- f"has_full_raw_data={ has_raw_data } "
181- )
175+ # Check for dimension mismatch with explicit warning
176+ if len (raw_queries ) != batch_size or len (raw_texts ) != batch_size :
177+ if raw_queries or raw_texts :
178+ logger .warning (
179+ f"XProvence: Dimension mismatch - batch_size={ batch_size } , "
180+ f"raw_queries={ len (raw_queries )} , raw_texts={ len (raw_texts )} . "
181+ f"Falling back to standard inference (no pruned_text)."
182+ )
183+ return self ._predict_standard (batch )
184+
185+ # Process batch with pruning (optimized)
186+ logger .info (f"XProvence: Processing { batch_size } pairs with pruning" )
187+ return self ._predict_batch_with_pruning (raw_queries , raw_texts )
188+
189+ def _predict_batch_with_pruning (
190+ self , raw_queries : List [str ], raw_texts : List [str ]
191+ ) -> List [Score ]:
192+ """
193+ Optimized batch processing with pruning.
194+
195+ Uses inference_mode and batched dtype handling to reduce per-item overhead.
196+ Note: XProvence process() is inherently per-pair for sentence-level analysis.
197+ """
198+ batch_size = len (raw_queries )
199+ results = []
200+
201+ # Suppress progress bars once for entire batch
202+ os .environ ["TQDM_DISABLE" ] = "1"
203+
204+ # Use inference_mode for better performance (no grad tracking)
205+ with torch .inference_mode ():
206+ original_dtype = torch .get_default_dtype ()
207+ torch .set_default_dtype (torch .float32 )
208+
209+ try :
210+ for i in range (batch_size ):
211+ query = raw_queries [i ]
212+ text = raw_texts [i ]
213+
214+ if not query or not text :
215+ logger .warning (
216+ f"XProvence: Empty query/text at index { i } , score=0.0"
217+ )
218+ results .append (Score (values = [0.0 ], pruned_text = None ))
219+ continue
220+
221+ try :
222+ output = self .model .process (
223+ query ,
224+ text ,
225+ threshold = self .threshold ,
226+ always_select_title = self .always_select_title ,
227+ )
228+
229+ score = float (output ["reranking_score" ])
230+ pruned = output ["pruned_context" ]
231+
232+ logger .debug (
233+ f"XProvence [{ i } ]: score={ score :.4f} , "
234+ f"len={ len (text )} →{ len (pruned )} "
235+ )
236+ results .append (Score (values = [score ], pruned_text = pruned ))
237+
238+ except Exception as e :
239+ logger .error (f"XProvence process() failed at index { i } : { e } " )
240+ results .append (Score (values = [0.0 ], pruned_text = None ))
241+
242+ finally :
243+ torch .set_default_dtype (original_dtype )
182244
183- if has_raw_data :
184- logger .info (f"XProvence: Processing batch of { batch_size } with pruning" )
185- results = []
186- for i in range (batch_size ):
187- query = batch .raw_queries [i ]
188- text = batch .raw_texts [i ]
189-
190- # Verify we have valid strings (not empty)
191- if query and text :
192- scores = self ._predict_with_pruning (query , text )
193- results .append (scores [0 ])
194- else :
195- # Empty string fallback - use standard forward pass result
196- logger .warning (f"XProvence: Empty query/text at index { i } , using 0.0" )
197- results .append (Score (values = [0.0 ], pruned_text = None ))
198- return results
199-
200- logger .info ("XProvence: Using standard forward pass (no raw_queries/raw_texts)" )
201- return self ._predict_standard (batch )
245+ return results
202246
203247 def _predict_with_pruning (self , raw_query : str , raw_text : str ) -> List [Score ]:
204248 """
0 commit comments