diff --git a/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py b/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py
index 80abf7cb..af013842 100644
--- a/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py
+++ b/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py
@@ -45,21 +45,23 @@ def __init__(
             logger.info('vision_config is None. Initializing the InternVisionConfig with default values.')
 
         if llm_config is None:
-            # TODO: There might still be a bug in transformers version 4.44 and above.
-            llm_config = {'architectures': ['']}
+            llm_config = {}
             logger.info('llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).')
 
         self.vision_config = InternVisionConfig(**vision_config)
-        if llm_config['architectures'][0] == 'LlamaForCausalLM':
+        arch = llm_config.get('architectures', [''])[0]
+        if arch == 'LlamaForCausalLM':
             self.llm_config = LlamaConfig(**llm_config)
-        elif llm_config['architectures'][0] == 'InternLM2ForCausalLM':
+        elif arch == 'InternLM2ForCausalLM':
             self.llm_config = InternLM2Config(**llm_config)
-        elif llm_config['architectures'][0] == 'Phi3ForCausalLM':
+        elif arch == 'Phi3ForCausalLM':
             self.llm_config = Phi3Config(**llm_config)
-        elif llm_config['architectures'][0] == 'Qwen2ForCausalLM':
+        elif arch == 'Qwen2ForCausalLM':
             self.llm_config = Qwen2Config(**llm_config)
+        elif arch:
+            self.llm_config = AutoConfig.for_model(llm_config.get('model_type', 'llama'), **llm_config)
         else:
-            raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0]))
+            self.llm_config = LlamaConfig(**llm_config)
 
         self.use_backbone_lora = use_backbone_lora
         self.use_llm_lora = use_llm_lora
         self.pad2square = pad2square
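
Not part of the diff itself, but a minimal sketch of how the reworked dispatch is expected to behave. The Mistral case below is purely illustrative (any architecture without a dedicated branch would take the `AutoConfig.for_model` path), and it assumes `AutoConfig` is imported alongside the other config classes in the module:

```python
from transformers import AutoConfig, LlamaConfig, Qwen2Config

# Explicitly handled architecture: the if/elif chain still picks the
# dedicated config class, here Qwen2Config.
llm_config = {'architectures': ['Qwen2ForCausalLM'], 'hidden_size': 1024}
arch = llm_config.get('architectures', [''])[0]
assert arch == 'Qwen2ForCausalLM'
cfg = Qwen2Config(**llm_config)
print(type(cfg).__name__, cfg.hidden_size)  # Qwen2Config 1024

# Architecture not in the chain (illustrative example): the new `elif arch:`
# branch delegates to AutoConfig.for_model, which resolves the config class
# by `model_type`, falling back to 'llama' when the key is absent.
llm_config = {'architectures': ['MistralForCausalLM'], 'model_type': 'mistral'}
cfg = AutoConfig.for_model(llm_config.get('model_type', 'llama'), **llm_config)
print(type(cfg).__name__)  # MistralConfig

# No llm_config at all (now initialized to {}): `arch` is '', so the final
# else branch builds a default LlamaConfig instead of raising ValueError.
llm_config = {}
arch = llm_config.get('architectures', [''])[0]
assert arch == ''
cfg = LlamaConfig(**llm_config)
print(type(cfg).__name__)  # LlamaConfig
```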