Skip to content

Commit 09e8491

Browse files
committed
fix: use get_class_from_dynamic_module to avoid config mismatch
AutoModel.from_pretrained internally calls AutoConfig, which returns an XLMRobertaConfig that conflicts with the model's expected XProvenceConfig. Solution: use transformers.dynamic_module_utils.get_class_from_dynamic_module to import the custom XProvenceForSequenceClassification class directly, then call from_pretrained on that class, which loads the model via its own config_class.
1 parent 7b63634 commit 09e8491

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

backends/python/server/text_embeddings_server/models/xprovence_model.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
from pathlib import Path
55
from typing import Type, List, Optional
6-
from transformers import AutoModel
6+
from transformers.dynamic_module_utils import get_class_from_dynamic_module
7+
from huggingface_hub import hf_hub_download
78
from opentelemetry import trace
89
from loguru import logger
910

@@ -68,21 +69,32 @@ def __init__(
6869
model_id = _extract_model_id(model_path_str)
6970

7071
if model_id:
71-
# Use model_id directly with AutoModel.from_pretrained
72-
# This ensures:
73-
# 1. All custom Python files (modeling_*.py) are downloaded
74-
# 2. The correct XProvenceConfig is loaded via model class's config_class attribute
75-
# 3. No config class mismatch (unlike passing config from AutoConfig.from_pretrained)
76-
logger.info(f"XProvence: Loading {model_id} with trust_remote_code=True")
77-
model = AutoModel.from_pretrained(
72+
# Directly import the custom model class to avoid AutoModel's config class mismatch
73+
# AutoModel.from_pretrained internally loads config which causes XLMRobertaConfig
74+
# to be registered, conflicting with the model's expected XProvenceConfig
75+
logger.info(f"XProvence: Loading custom model class for {model_id}")
76+
77+
# Get the custom model class directly from the dynamic module
78+
model_class = get_class_from_dynamic_module(
79+
"modeling_xprovence_hf.XProvenceForSequenceClassification",
80+
model_id,
81+
cache_dir=cache_dir,
82+
)
83+
84+
# Load using the custom class directly - this uses the correct config_class
85+
model = model_class.from_pretrained(
7886
model_id,
7987
trust_remote_code=True,
8088
cache_dir=cache_dir,
8189
)
8290
else:
83-
# Fallback for local paths not in HF cache format
91+
# Fallback for local paths - try to import from local path
8492
logger.info(f"XProvence: Loading from local path {model_path}")
85-
model = AutoModel.from_pretrained(
93+
model_class = get_class_from_dynamic_module(
94+
"modeling_xprovence_hf.XProvenceForSequenceClassification",
95+
model_path,
96+
)
97+
model = model_class.from_pretrained(
8698
model_path,
8799
trust_remote_code=True,
88100
)

0 commit comments

Comments
 (0)