
Commit 7b63634

fix: check XProvence before AutoConfig to prevent registry pollution
The previous fix still failed because __init__.py called AutoConfig.from_pretrained
before XProvenceModel was created. This polluted transformers' internal config
registry with XLMRobertaConfig, causing conflicts when XProvenceModel tried to
load the custom XProvenceConfig.

Solution:
- Add _is_xprovence_model() helper that reads config.json directly
- Check for XProvence BEFORE calling AutoConfig.from_pretrained
- This prevents transformers from caching the wrong config class
1 parent cda8d79 commit 7b63634
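
In effect, the commit swaps the order of the two checks in get_model. A condensed sketch of the new control flow, with the other model branches elided (all names are taken from the diff below; nothing here is new API):

    def get_model(model_path: Path, dtype: Optional[str], pool: str):
        ...
        # Decide on XProvence from the raw config.json first; AutoConfig is
        # never consulted for this, so transformers' config registry cannot
        # be primed with XLMRobertaConfig before XProvenceModel loads.
        if _is_xprovence_model(model_path):
            logger.info("Detected XProvence model for context pruning")
            return XProvenceModel(model_path, device, datatype, trust_remote=True)

        # Only non-XProvence checkpoints reach AutoConfig, so caching the
        # resolved config class here is harmless.
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
        ...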

File tree

1 file changed: +25 −7
  • backends/python/server/text_embeddings_server/models

backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 25 additions & 7 deletions
@@ -1,4 +1,5 @@
 import os
+import json
 import torch
 
 from loguru import logger
@@ -14,6 +15,25 @@
 from text_embeddings_server.models.xprovence_model import XProvenceModel
 from text_embeddings_server.utils.device import get_device, use_ipex
 
+
+def _is_xprovence_model(model_path: Path) -> bool:
+    """Check if model is XProvence by reading config.json directly.
+
+    This avoids calling AutoConfig.from_pretrained which can pollute
+    transformers' internal registry and cause config class conflicts.
+    """
+    config_path = model_path / "config.json"
+    if not config_path.exists():
+        return False
+
+    try:
+        with open(config_path, "r") as f:
+            config = json.load(f)
+        architectures = config.get("architectures", [])
+        return any("XProvence" in arch for arch in architectures)
+    except Exception:
+        return False
+
 FlashJinaBert = None
 FlashMistral = None
 FlashQwen3 = None
@@ -81,16 +101,14 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
     device = get_device()
     logger.info(f"backend device: {device}")
 
-    config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
-
-    if (
-        hasattr(config, "architectures")
-        and config.architectures
-        and "XProvence" in config.architectures[0]
-    ):
+    # Check for XProvence BEFORE calling AutoConfig.from_pretrained
+    # to avoid polluting transformers' internal config registry
+    if _is_xprovence_model(model_path):
         logger.info("Detected XProvence model for context pruning")
         return XProvenceModel(model_path, device, datatype, trust_remote=True)
 
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
+
     if (
         FlashJinaBert is not None
         and hasattr(config, "auto_map")
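
As a quick sanity check, the new helper can be exercised against a throwaway directory. Note the architecture string below is illustrative; it is not necessarily the exact name an XProvence checkpoint ships in its config.json:

    import json
    import tempfile
    from pathlib import Path

    with tempfile.TemporaryDirectory() as tmp:
        model_path = Path(tmp)
        # Minimal config.json carrying an XProvence-style architecture entry
        (model_path / "config.json").write_text(
            json.dumps({"architectures": ["XProvenceForSequenceClassification"]})
        )
        print(_is_xprovence_model(model_path))              # True
        print(_is_xprovence_model(model_path / "missing"))  # False: no config.json

Because the helper returns False on a missing or malformed config.json rather than raising, such checkpoints simply fall through to the regular AutoConfig path.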
