Skip to content

Commit cda8d79

Browse files
committed
fix: use model_id directly to avoid XProvenceConfig class mismatch
The previous fix (7ff382c) incorrectly passed the config returned by AutoConfig.from_pretrained to AutoModel.from_pretrained. Because XProvence's config.json lacks an auto_map entry for AutoConfig, AutoConfig returned XLMRobertaConfig while the model expected XProvenceConfig.

New approach:
- Extract model_id from the cache path (e.g., naver/xprovence-reranker-bgem3-v1).
- Pass model_id directly to AutoModel.from_pretrained(model_id, trust_remote_code=True).
- Let AutoModel resolve the config internally via the model class's config_class attribute.
- Remove the explicit config argument and the snapshot_download call (AutoModel handles downloads).
1 parent 7ff382c commit cda8d79

File tree

1 file changed

+43
-24
lines changed

1 file changed

+43
-24
lines changed

backends/python/server/text_embeddings_server/models/xprovence_model.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
import torch
33

44
from pathlib import Path
5-
from typing import Type, List
6-
from transformers import AutoModel, AutoConfig
7-
from huggingface_hub import snapshot_download
5+
from typing import Type, List, Optional
6+
from transformers import AutoModel
87
from opentelemetry import trace
98
from loguru import logger
109

@@ -19,6 +18,23 @@ def _parse_bool(value: str) -> bool:
1918
return str(value).lower() in ("true", "1", "t", "yes", "on")
2019

2120

21+
def _extract_model_id(model_path_str: str) -> Optional[str]:
22+
"""Extract model_id from HF cache path format.
23+
24+
Converts paths like '/data/models--naver--xprovence-reranker-bgem3-v1/snapshots/...'
25+
to 'naver/xprovence-reranker-bgem3-v1'
26+
"""
27+
if "/models--" not in model_path_str:
28+
return None
29+
30+
parts = model_path_str.split("/")
31+
for part in parts:
32+
if part.startswith("models--"):
33+
# models--naver--xprovence-reranker-bgem3-v1 -> naver/xprovence-reranker-bgem3-v1
34+
return part.replace("models--", "").replace("--", "/", 1)
35+
return None
36+
37+
2238
class XProvenceModel(Model):
2339
"""
2440
XProvence: Zero-cost context pruning model for RAG.
@@ -45,28 +61,31 @@ def __init__(
4561
pool: str = "cls",
4662
trust_remote: bool = True,
4763
):
48-
# Download all model files including custom Python files for trust_remote_code
49-
# The Rust router only downloads config/tokenizer/weights, but not custom modeling files
5064
model_path_str = str(model_path)
51-
if model_path_str.startswith("/data/models--"):
52-
# Extract model_id from HF cache path format: /data/models--org--name/...
53-
# Convert "models--naver--xprovence-reranker-bgem3-v1" to "naver/xprovence-reranker-bgem3-v1"
54-
parts = model_path_str.split("/")
55-
for part in parts:
56-
if part.startswith("models--"):
57-
model_id = part.replace("models--", "").replace("--", "/", 1)
58-
logger.info(f"XProvence: Downloading custom files for {model_id}")
59-
cache_dir = os.getenv("HUGGINGFACE_HUB_CACHE", "/data")
60-
snapshot_download(
61-
repo_id=model_id,
62-
cache_dir=cache_dir,
63-
local_files_only=False,
64-
)
65-
break
66-
67-
# Load config first with trust_remote_code to get the correct XProvenceConfig
68-
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
69-
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
65+
cache_dir = os.getenv("HUGGINGFACE_HUB_CACHE", "/data")
66+
67+
# Extract model_id from cache path for proper trust_remote_code handling
68+
model_id = _extract_model_id(model_path_str)
69+
70+
if model_id:
71+
# Use model_id directly with AutoModel.from_pretrained
72+
# This ensures:
73+
# 1. All custom Python files (modeling_*.py) are downloaded
74+
# 2. The correct XProvenceConfig is loaded via model class's config_class attribute
75+
# 3. No config class mismatch (unlike passing config from AutoConfig.from_pretrained)
76+
logger.info(f"XProvence: Loading {model_id} with trust_remote_code=True")
77+
model = AutoModel.from_pretrained(
78+
model_id,
79+
trust_remote_code=True,
80+
cache_dir=cache_dir,
81+
)
82+
else:
83+
# Fallback for local paths not in HF cache format
84+
logger.info(f"XProvence: Loading from local path {model_path}")
85+
model = AutoModel.from_pretrained(
86+
model_path,
87+
trust_remote_code=True,
88+
)
7089

7190
if dtype == torch.bfloat16:
7291
logger.info("XProvence: using float32 instead of bfloat16 for process() compatibility")

0 commit comments

Comments
 (0)