114 changes: 55 additions & 59 deletions backends/exllamav2/model.py
@@ -32,6 +32,7 @@

from ruamel.yaml import YAML

from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig
from common.health import HealthManager

from backends.exllamav2.grammar import (
@@ -43,6 +44,7 @@
hardware_supports_flash_attn,
supports_paged_attn,
)
from common.tabby_config import config
from common.concurrency import iterate_in_threadpool
from common.gen_logging import (
log_generation_params,
@@ -103,7 +105,12 @@ class ExllamaV2Container:
load_condition: asyncio.Condition = asyncio.Condition()

@classmethod
async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
async def create(
cls,
model: ModelInstanceConfig,
draft: DraftModelInstanceConfig,
quiet=False,
):
"""
Primary asynchronous initializer for model container.

@@ -117,8 +124,15 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):

# Initialize config
self.config = ExLlamaV2Config()
self.model_dir = model_directory
self.config.model_dir = str(model_directory.resolve())

model_path = pathlib.Path(config.model.model_dir)
model_path = model_path / model.model_name
model_path = model_path.resolve()
if not model_path.exists():
raise FileNotFoundError(f"Model path {model_path} does not exist.")

self.model_dir = model_path
self.config.model_dir = str(model_path)
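
The new `create` signature takes typed config objects instead of `**kwargs`. As a rough orientation, here is a minimal sketch of what the pydantic models in `backends/exllamav2/types.py` might look like, with field names inferred from their usages later in this diff; the actual definitions may differ in types and defaults.

```python
from typing import List, Optional, Union

from pydantic import BaseModel


class ModelInstanceConfig(BaseModel):
    """Hypothetical shape of the per-model settings consumed by create()."""

    model_name: str
    cache_mode: str = "FP16"
    cache_size: Optional[int] = None
    max_seq_len: Optional[int] = None
    override_base_seq_len: Optional[int] = None
    rope_scale: Optional[float] = None
    rope_alpha: Optional[Union[float, str]] = None  # numeric value or the "auto" literal
    gpu_split_auto: bool = True
    gpu_split: Optional[List[float]] = None
    autosplit_reserve: List[int] = [96]
    tensor_parallel: bool = False
    max_batch_size: Optional[int] = None
    chunk_size: Optional[int] = 2048
    prompt_template: Optional[str] = None
    num_experts_per_token: Optional[int] = None


class DraftModelInstanceConfig(BaseModel):
    """Hypothetical shape of the draft-model settings."""

    draft_model_name: Optional[str] = None
    draft_rope_scale: Optional[float] = None
    draft_rope_alpha: Optional[Union[float, str]] = None
    draft_cache_mode: str = "FP16"


# Usage sketch: build the configs (e.g. from config.yml) and create the container.
# container = await ExllamaV2Container.create(model_config, draft_config)
```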

# Make the max seq len 4096 before preparing the config
# This is a better default than 2048
@@ -130,35 +144,23 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
self.config.arch_compat_overrides()

# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft"), {})
draft_model_name = draft_args.get("draft_model_name")
enable_draft = draft_args and draft_model_name

# Always disable draft if params are incorrectly configured
if draft_args and draft_model_name is None:
logger.warning(
"Draft model is disabled because a model name "
"wasn't provided. Please check your config.yml!"
)
enable_draft = False

if enable_draft:
if draft.draft_model_name:
self.draft_config = ExLlamaV2Config()
self.draft_config.no_flash_attn = self.config.no_flash_attn
draft_model_path = pathlib.Path(
unwrap(draft_args.get("draft_model_dir"), "models")

draft_model_path = (
config.draft_model.draft_model_dir / draft.draft_model_name
)
draft_model_path = draft_model_path / draft_model_name

self.draft_model_dir = draft_model_path
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()

# Create the hf_config
self.hf_config = await HuggingFaceConfig.from_file(model_directory)
self.hf_config = await HuggingFaceConfig.from_file(model_path)

# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
generation_config_path = model_path / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = await GenerationConfig.from_file(
@@ -171,18 +173,20 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
)

# Apply a model's config overrides while respecting user settings
kwargs = await self.set_model_overrides(**kwargs)

# FIXME: This is currently broken because kwargs no longer exist.
# Revisit once the model settings are fully pydantic.
# kwargs = await self.set_model_overrides(**kwargs)

# MARK: User configuration

# Get cache mode
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
self.cache_mode = model.cache_mode

# Turn off GPU split if the user is using 1 GPU
gpu_count = torch.cuda.device_count()
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
use_tp = unwrap(kwargs.get("tensor_parallel"), False)
gpu_split = kwargs.get("gpu_split")
gpu_split_auto = model.gpu_split_auto
gpu_device_list = list(range(0, gpu_count))

# Set GPU split options
@@ -191,16 +195,16 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
logger.info("Disabling GPU split because one GPU is in use.")
else:
# Set tensor parallel
if use_tp:
if model.tensor_parallel:
self.use_tp = True

# TP has its own autosplit loader
self.gpu_split_auto = False

# Enable manual GPU split if provided
if gpu_split:
if model.gpu_split:
self.gpu_split_auto = False
self.gpu_split = gpu_split
self.gpu_split = model.gpu_split

gpu_device_list = [
device_idx
@@ -211,9 +215,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
# Otherwise fallback to autosplit settings
self.gpu_split_auto = gpu_split_auto

autosplit_reserve_megabytes = unwrap(
kwargs.get("autosplit_reserve"), [96]
)
autosplit_reserve_megabytes = model.autosplit_reserve

# Reserve VRAM for each GPU
self.autosplit_reserve = [
@@ -225,37 +227,34 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
self.config.max_output_len = 16

# Then override the base_seq_len if present
override_base_seq_len = kwargs.get("override_base_seq_len")
if override_base_seq_len:
self.config.max_seq_len = override_base_seq_len
if model.override_base_seq_len:
self.config.max_seq_len = model.override_base_seq_len

# Grab the base model's sequence length before overrides for
# rope calculations
base_seq_len = self.config.max_seq_len

# Set the target seq len if present
target_max_seq_len = kwargs.get("max_seq_len")
target_max_seq_len = model.max_seq_len
if target_max_seq_len:
self.config.max_seq_len = target_max_seq_len

# Set the rope scale
self.config.scale_pos_emb = unwrap(
kwargs.get("rope_scale"), self.config.scale_pos_emb
)
self.config.scale_pos_emb = unwrap(model.rope_scale, self.config.scale_pos_emb)

# Sets rope alpha value.
# Automatically calculate if unset or defined as an "auto" literal.
rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
rope_alpha = unwrap(model.rope_alpha, "auto")
if rope_alpha == "auto":
self.config.scale_alpha_value = self.calculate_rope_alpha(base_seq_len)
else:
self.config.scale_alpha_value = rope_alpha
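
For context, `calculate_rope_alpha` derives an NTK-style alpha from the ratio between the target and base context lengths. A hedged sketch of the usual heuristic follows; the constants are an assumption drawn from common NTK-alpha fits, not from this diff.

```python
def calculate_rope_alpha_sketch(base_seq_len: int, target_seq_len: int) -> float:
    """Approximate an NTK-aware rope alpha from the context-length ratio."""
    ratio = target_seq_len / base_seq_len
    # No scaling needed when the target context fits in the base context;
    # otherwise apply a fitted quadratic curve (assumed constants).
    return 1.0 if ratio <= 1.0 else -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
```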

# Enable fasttensors loading if present
self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
self.config.fasttensors = config.model.fasttensors

# Set max batch size to the config override
self.max_batch_size = unwrap(kwargs.get("max_batch_size"))
self.max_batch_size = model.max_batch_size

# Check whether the user's configuration supports flash/paged attention
# Also check if exl2 has disabled flash attention
@@ -272,7 +271,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
# Set k/v cache size
# cache_size is only relevant when paged mode is enabled
if self.paged:
cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len)
cache_size = unwrap(model.cache_size, self.config.max_seq_len)

if cache_size < self.config.max_seq_len:
logger.warning(
@@ -314,7 +313,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):

# Try to set prompt template
self.prompt_template = await self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
model.prompt_template, model.model_name
)

# Catch all for template lookup errors
Expand All @@ -329,29 +328,26 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
)

# Set num of experts per token if provided
num_experts_override = kwargs.get("num_experts_per_token")
if num_experts_override:
self.config.num_experts_per_token = kwargs.get("num_experts_per_token")
if model.num_experts_per_token:
self.config.num_experts_per_token = model.num_experts_per_token

# Make sure chunk size is >= 16 and <= max seq length
user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048)
user_chunk_size = unwrap(model.chunk_size, 2048)
chunk_size = sorted((16, user_chunk_size, self.config.max_seq_len))[1]
self.config.max_input_len = chunk_size
self.config.max_attention_size = chunk_size**2
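
The `sorted(...)[1]` idiom clamps the user's chunk size into the [16, max_seq_len] range by taking the median of the three values:

```python
>>> sorted((16, 8192, 4096))[1]  # above the ceiling: clamped down to max_seq_len
4096
>>> sorted((16, 8, 4096))[1]     # below the floor: clamped up to 16
16
>>> sorted((16, 2048, 4096))[1]  # in range: passes through unchanged
2048
```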

# Set user-configured draft model values
if enable_draft:
# Fetch from the updated kwargs
draft_args = unwrap(kwargs.get("draft"), {})
if draft.draft_model_name:

self.draft_config.max_seq_len = self.config.max_seq_len

self.draft_config.scale_pos_emb = unwrap(
draft_args.get("draft_rope_scale"), 1.0
draft.draft_rope_scale, 1.0
)

# Set draft rope alpha. Follows same behavior as model rope alpha.
draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
draft_rope_alpha = unwrap(draft.draft_rope_alpha, "auto")
if draft_rope_alpha == "auto":
self.draft_config.scale_alpha_value = self.calculate_rope_alpha(
self.draft_config.max_seq_len
@@ -360,7 +356,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
self.draft_config.scale_alpha_value = draft_rope_alpha

# Set draft cache mode
self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")
self.draft_cache_mode = draft.draft_cache_mode

if chunk_size:
self.draft_config.max_input_len = chunk_size
@@ -524,7 +520,7 @@ def progress(loaded_modules: int, total_modules: int)
async for _ in self.load_gen(progress_callback):
pass

async def load_gen(self, progress_callback=None, **kwargs):
async def load_gen(self, progress_callback=None, skip_wait=False):
"""Loads a model and streams progress via a generator."""

# Indicate that model load has started
@@ -534,7 +530,7 @@ async def load_gen(self, progress_callback=None, **kwargs):
self.model_is_loading = True

# Wait for existing generation jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))
await self.wait_for_jobs(skip_wait)

# Streaming gen for model load progress
model_load_generator = self.load_model_sync(progress_callback)
@@ -1130,19 +1126,19 @@ async def generate_gen(
grammar_handler = ExLlamaV2Grammar()

# Add JSON schema filter if it exists
json_schema = unwrap(kwargs.get("json_schema"))
json_schema = kwargs.get("json_schema")
if json_schema:
grammar_handler.add_json_schema_filter(
json_schema, self.model, self.tokenizer
)

# Add regex filter if it exists
regex_pattern = unwrap(kwargs.get("regex_pattern"))
regex_pattern = kwargs.get("regex_pattern")
if regex_pattern:
grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)

# Add EBNF filter if it exists
grammar_string = unwrap(kwargs.get("grammar_string"))
grammar_string = kwargs.get("grammar_string")
if grammar_string:
grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
