|
3 | 3 |
|
4 | 4 | from pathlib import Path |
5 | 5 | from typing import Type, List, Optional |
6 | | -from transformers import AutoModel |
| 6 | +from transformers.dynamic_module_utils import get_class_from_dynamic_module |
| 7 | +from huggingface_hub import hf_hub_download |
7 | 8 | from opentelemetry import trace |
8 | 9 | from loguru import logger |
9 | 10 |
|
@@ -68,21 +69,32 @@ def __init__( |
68 | 69 | model_id = _extract_model_id(model_path_str) |
69 | 70 |
|
70 | 71 | if model_id: |
71 | | - # Use model_id directly with AutoModel.from_pretrained |
72 | | - # This ensures: |
73 | | - # 1. All custom Python files (modeling_*.py) are downloaded |
74 | | - # 2. The correct XProvenceConfig is loaded via model class's config_class attribute |
75 | | - # 3. No config class mismatch (unlike passing config from AutoConfig.from_pretrained) |
76 | | - logger.info(f"XProvence: Loading {model_id} with trust_remote_code=True") |
77 | | - model = AutoModel.from_pretrained( |
| 72 | + # Import the custom model class directly to avoid the config class mismatch seen with AutoModel. |
| 73 | + # AutoModel.from_pretrained loads the config internally, which causes XLMRobertaConfig |
| 74 | + # to be registered, conflicting with the XProvenceConfig that the model expects. |
| 75 | + logger.info(f"XProvence: Loading custom model class for {model_id}") |
| 76 | + |
| 77 | + # Get the custom model class directly from the dynamic module |
| 78 | + model_class = get_class_from_dynamic_module( |
| 79 | + "modeling_xprovence_hf.XProvenceForSequenceClassification", |
| 80 | + model_id, |
| 81 | + cache_dir=cache_dir, |
| 82 | + ) |
| 83 | + |
| 84 | + # Load using the custom class directly - this uses the correct config_class |
| 85 | + model = model_class.from_pretrained( |
78 | 86 | model_id, |
79 | 87 | trust_remote_code=True, |
80 | 88 | cache_dir=cache_dir, |
81 | 89 | ) |
82 | 90 | else: |
83 | | - # Fallback for local paths not in HF cache format |
| 91 | + # Fallback for local paths: load the custom model class from the local path |
84 | 92 | logger.info(f"XProvence: Loading from local path {model_path}") |
85 | | - model = AutoModel.from_pretrained( |
| 93 | + model_class = get_class_from_dynamic_module( |
| 94 | + "modeling_xprovence_hf.XProvenceForSequenceClassification", |
| 95 | + model_path, |
| 96 | + ) |
| 97 | + model = model_class.from_pretrained( |
86 | 98 | model_path, |
87 | 99 | trust_remote_code=True, |
88 | 100 | ) |
|
0 commit comments