From d5c118801632402cfd87af172ab1328c84921e90 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Tue, 24 Sep 2024 15:17:49 +0100 Subject: [PATCH 1/9] migrate various minor systems to use pydantic - improve validators for some types - convert various systems to access config directly - lots of file path changes --- common/auth.py | 14 ++--- common/config_models.py | 114 +++++++++++++++------------------- common/downloader.py | 4 +- common/logger.py | 2 +- common/tabby_config.py | 7 ++- common/utils.py | 4 +- endpoints/core/router.py | 13 ++-- endpoints/core/utils/model.py | 17 +++-- main.py | 2 +- 9 files changed, 81 insertions(+), 96 deletions(-) diff --git a/common/auth.py b/common/auth.py index b02cdd02..773b59b3 100644 --- a/common/auth.py +++ b/common/auth.py @@ -13,6 +13,7 @@ from typing import Optional from common.utils import coalesce +from common.tabby_config import config class AuthKeys(BaseModel): @@ -39,17 +40,14 @@ def verify_key(self, test_key: str, key_type: str): # Global auth constants AUTH_KEYS: Optional[AuthKeys] = None -DISABLE_AUTH: bool = False -async def load_auth_keys(disable_from_config: bool): +async def load_auth_keys(): """Load the authentication keys from api_tokens.yml. If the file does not exist, generate new keys and save them to api_tokens.yml.""" global AUTH_KEYS - global DISABLE_AUTH - DISABLE_AUTH = disable_from_config - if disable_from_config: + if config.network.disable_auth: logger.warning( "Disabling authentication makes your instance vulnerable. " "Set the `disable_auth` flag to False in config.yml if you " @@ -94,7 +92,7 @@ def get_key_permission(request: Request): """ # Give full admin permissions if auth is disabled - if DISABLE_AUTH: + if config.network.disable_auth: return "admin" # Hyphens are okay here @@ -124,7 +122,7 @@ async def check_api_key( """Check if the API key is valid.""" # Allow request if auth is disabled - if DISABLE_AUTH: + if config.network.disable_auth: return if x_api_key: @@ -152,7 +150,7 @@ async def check_admin_key( """Check if the admin key is valid.""" # Allow request if auth is disabled - if DISABLE_AUTH: + if config.network.disable_auth: return if x_admin_key: diff --git a/common/config_models.py b/common/config_models.py index 79d774fa..a7fea31e 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -1,8 +1,9 @@ -from pathlib import Path from pydantic import ( BaseModel, ConfigDict, + DirectoryPath, Field, + FilePath, PrivateAttr, field_validator, ) @@ -15,7 +16,9 @@ class Metadata(BaseModel): """metadata model for config options""" - include_in_config: Optional[bool] = Field(True) + include_in_config: bool = Field( + True, description="if the model is included by the config file generator" + ) class BaseConfigModel(BaseModel): @@ -27,8 +30,7 @@ class BaseConfigModel(BaseModel): class ConfigOverrideConfig(BaseConfigModel): """Model for overriding a provided config file.""" - # TODO: convert this to a pathlib.path? 
- config: Optional[str] = Field( + config: Optional[FilePath] = Field( None, description=("Path to an overriding config.yml file") ) @@ -39,18 +41,14 @@ class UtilityActions(BaseConfigModel): """Model used for arg actions.""" # YAML export options - export_config: Optional[str] = Field( - None, description="generate a template config file" - ) - config_export_path: Optional[Path] = Field( + export_config: bool = Field(False, description="generate a template config file") + config_export_path: FilePath = Field( "config_sample.yml", description="path to export configuration file to" ) # OpenAPI JSON export options - export_openapi: Optional[bool] = Field( - False, description="export openapi schema files" - ) - openapi_export_path: Optional[Path] = Field( + export_openapi: bool = Field(False, description="export openapi schema files") + openapi_export_path: FilePath = Field( "openapi.json", description="path to export openapi schema to" ) @@ -60,17 +58,16 @@ class UtilityActions(BaseConfigModel): class NetworkConfig(BaseConfigModel): """Options for networking""" - host: Optional[str] = Field( + # TODO: convert to IPvAnyAddress? + host: str = Field( "127.0.0.1", description=( "The IP to host on (default: 127.0.0.1).\n" "Use 0.0.0.0 to expose on all network adapters." ), ) - port: Optional[int] = Field( - 5000, description=("The port to host on (default: 5000).") - ) - disable_auth: Optional[bool] = Field( + port: int = Field(5000, description=("The port to host on (default: 5000).")) + disable_auth: bool = Field( False, description=( "Disable HTTP token authentication with requests.\n" @@ -78,14 +75,14 @@ class NetworkConfig(BaseConfigModel): "Turn on this option if you are ONLY connecting from localhost." ), ) - send_tracebacks: Optional[bool] = Field( + send_tracebacks: bool = Field( False, description=( "Send tracebacks over the API (default: False).\n" "NOTE: Only enable this for debug purposes." ), ) - api_servers: Optional[List[Literal["oai", "kobold"]]] = Field( + api_servers: List[Literal["oai", "kobold"]] = Field( ["OAI"], description=( 'Select API servers to enable (default: ["OAI"]).\n' @@ -105,15 +102,15 @@ def api_server_validator(cls, api_servers): class LoggingConfig(BaseConfigModel): """Options for logging""" - log_prompt: Optional[bool] = Field( + log_prompt: bool = Field( False, description=("Enable prompt logging (default: False)."), ) - log_generation_params: Optional[bool] = Field( + log_generation_params: bool = Field( False, description=("Enable generation parameter logging (default: False)."), ) - log_requests: Optional[bool] = Field( + log_requests: bool = Field( False, description=( "Enable request logging (default: False).\n" @@ -129,22 +126,21 @@ class ModelConfig(BaseConfigModel): between initial and API loads """ - # TODO: convert this to a pathlib.path? - model_dir: str = Field( + model_dir: DirectoryPath = Field( "models", description=( "Directory to look for models (default: models).\n" "Windows users, do NOT put this path in quotes!" ), ) - inline_model_loading: Optional[bool] = Field( + inline_model_loading: bool = Field( False, description=( "Allow direct loading of models " "from a completion or chat completion request (default: False)." 
), ) - use_dummy_models: Optional[bool] = Field( + use_dummy_models: bool = Field( False, description=( "Sends dummy model names when the models endpoint is queried.\n" @@ -186,7 +182,7 @@ class ModelConfig(BaseConfigModel): ), ge=0, ) - tensor_parallel: Optional[bool] = Field( + tensor_parallel: bool = Field( False, description=( "Load model with tensor parallelism.\n" @@ -194,7 +190,7 @@ class ModelConfig(BaseConfigModel): "This ignores the gpu_split_auto value." ), ) - gpu_split_auto: Optional[bool] = Field( + gpu_split_auto: bool = Field( True, description=( "Automatically allocate resources to GPUs (default: True).\n" @@ -215,7 +211,7 @@ class ModelConfig(BaseConfigModel): "Used with tensor parallelism." ), ) - rope_scale: Optional[float] = Field( + rope_scale: float = Field( 1.0, description=( "Rope scale (default: 1.0).\n" @@ -233,7 +229,7 @@ class ModelConfig(BaseConfigModel): "or auto-calculate." ), ) - cache_mode: Optional[CACHE_SIZES] = Field( + cache_mode: CACHE_SIZES = Field( "FP16", description=( "Enable different cache modes for VRAM savings (default: FP16).\n" @@ -250,7 +246,7 @@ class ModelConfig(BaseConfigModel): multiple_of=256, gt=0, ) - chunk_size: Optional[int] = Field( + chunk_size: int = Field( 2048, description=( "Chunk size for prompt ingestion (default: 2048).\n" @@ -290,7 +286,7 @@ class ModelConfig(BaseConfigModel): ), ge=1, ) - fasttensors: Optional[bool] = Field( + fasttensors: bool = Field( False, description=( "Enables fasttensors to possibly increase model loading speeds " @@ -308,8 +304,7 @@ class DraftModelConfig(BaseConfigModel): This will use more VRAM! """ - # TODO: convert this to a pathlib.path? - draft_model_dir: Optional[str] = Field( + draft_model_dir: DirectoryPath = Field( "models", description=("Directory to look for draft models (default: models)"), ) @@ -320,7 +315,7 @@ class DraftModelConfig(BaseConfigModel): "Ensure the model is in the model directory." ), ) - draft_rope_scale: Optional[float] = Field( + draft_rope_scale: float = Field( 1.0, description=( "Rope scale for draft models (default: 1.0).\n" @@ -337,7 +332,7 @@ class DraftModelConfig(BaseConfigModel): "or auto-calculate." ), ) - draft_cache_mode: Optional[CACHE_SIZES] = Field( + draft_cache_mode: CACHE_SIZES = Field( "FP16", description=( "Cache mode for draft models to save VRAM (default: FP16).\n" @@ -357,7 +352,7 @@ class LoraConfig(BaseConfigModel): """Options for Loras""" # TODO: convert this to a pathlib.path? - lora_dir: Optional[str] = Field( + lora_dir: DirectoryPath = Field( "loras", description=("Directory to look for LoRAs (default: loras).") ) loras: Optional[List[LoraInstanceModel]] = Field( @@ -379,12 +374,11 @@ class EmbeddingsConfig(BaseConfigModel): Install it via "pip install .[extras]" """ - # TODO: convert this to a pathlib.path? 
- embedding_model_dir: Optional[str] = Field( + embedding_model_dir: DirectoryPath = Field( "models", description=("Directory to look for embedding models (default: models)."), ) - embeddings_device: Optional[Literal["cpu", "auto", "cuda"]] = Field( + embeddings_device: Literal["cpu", "auto", "cuda"] = Field( "cpu", description=( "Device to load embedding models on (default: cpu).\n" @@ -416,7 +410,7 @@ class SamplingConfig(BaseConfigModel): class DeveloperConfig(BaseConfigModel): """Options for development and experimentation""" - unsafe_launch: Optional[bool] = Field( + unsafe_launch: bool = Field( False, description=( "Skip Exllamav2 version check (default: False).\n" @@ -424,13 +418,13 @@ class DeveloperConfig(BaseConfigModel): "than enabling this flag." ), ) - disable_request_streaming: Optional[bool] = Field( + disable_request_streaming: bool = Field( False, description=("Disable API request streaming (default: False).") ) - cuda_malloc_backend: Optional[bool] = Field( + cuda_malloc_backend: bool = Field( False, description=("Enable the torch CUDA malloc backend (default: False).") ) - uvloop: Optional[bool] = Field( + uvloop: bool = Field( False, description=( "Run asyncio using Uvloop or Winloop which can improve performance.\n" @@ -438,7 +432,7 @@ class DeveloperConfig(BaseConfigModel): "turn this off." ), ) - realtime_process_priority: Optional[bool] = Field( + realtime_process_priority: bool = Field( False, description=( "Set process to use a higher priority.\n" @@ -451,31 +445,21 @@ class DeveloperConfig(BaseConfigModel): class TabbyConfigModel(BaseModel): """Base model for a TabbyConfig.""" - config: Optional[ConfigOverrideConfig] = Field( + config: ConfigOverrideConfig = Field( default_factory=ConfigOverrideConfig.model_construct ) - network: Optional[NetworkConfig] = Field( - default_factory=NetworkConfig.model_construct - ) - logging: Optional[LoggingConfig] = Field( - default_factory=LoggingConfig.model_construct - ) - model: Optional[ModelConfig] = Field(default_factory=ModelConfig.model_construct) - draft_model: Optional[DraftModelConfig] = Field( + network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct) + logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct) + model: ModelConfig = Field(default_factory=ModelConfig.model_construct) + draft_model: DraftModelConfig = Field( default_factory=DraftModelConfig.model_construct ) - lora: Optional[LoraConfig] = Field(default_factory=LoraConfig.model_construct) - embeddings: Optional[EmbeddingsConfig] = Field( + lora: LoraConfig = Field(default_factory=LoraConfig.model_construct) + embeddings: EmbeddingsConfig = Field( default_factory=EmbeddingsConfig.model_construct ) - sampling: Optional[SamplingConfig] = Field( - default_factory=SamplingConfig.model_construct - ) - developer: Optional[DeveloperConfig] = Field( - default_factory=DeveloperConfig.model_construct - ) - actions: Optional[UtilityActions] = Field( - default_factory=UtilityActions.model_construct - ) + sampling: SamplingConfig = Field(default_factory=SamplingConfig.model_construct) + developer: DeveloperConfig = Field(default_factory=DeveloperConfig.model_construct) + actions: UtilityActions = Field(default_factory=UtilityActions.model_construct) model_config = ConfigDict(validate_assignment=True, protected_namespaces=()) diff --git a/common/downloader.py b/common/downloader.py index 6813e0d8..29158427 100644 --- a/common/downloader.py +++ b/common/downloader.py @@ -76,9 +76,9 @@ def _get_download_folder(repo_id: str, 
repo_type: str, folder_name: Optional[str """Gets the download folder for the repo.""" if repo_type == "lora": - download_path = pathlib.Path(config.lora.lora_dir) + download_path = config.lora.lora_dir else: - download_path = pathlib.Path(config.model.model_dir) + download_path = config.model.model_dir download_path = download_path / (folder_name or repo_id.split("/")[-1]) return download_path diff --git a/common/logger.py b/common/logger.py index f21ab098..c4d5795d 100644 --- a/common/logger.py +++ b/common/logger.py @@ -61,7 +61,7 @@ def _log_formatter(record: dict): message = unwrap(record.get("message"), "") # Replace once loguru allows for turning off str.format - message = message.replace("{", "{{").replace("}", "}}").replace("<", "\<") + message = message.replace(r"{", r"{{").replace(r"}", r"}}").replace(r"<", r"\<") # Escape markup tags from Rich message = escape(message) diff --git a/common/tabby_config.py b/common/tabby_config.py index d41cc640..14356c4e 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -175,11 +175,12 @@ def _from_environment(self): def generate_config_file( - model: BaseModel = None, - filename: str = "config_sample.yml", + model: Optional[BaseModel] = None, + filename: Optional[pathlib.Path] = None, ) -> None: """Creates a config.yml file from Pydantic models.""" + file = unwrap(filename, "config_sample.yml") schema = unwrap(model, TabbyConfigModel()) preamble = """ # Sample YAML file for configuration. @@ -193,7 +194,7 @@ def generate_config_file( yaml_content = pydantic_model_to_yaml(schema) - with open(filename, "w") as f: + with open(file, "w") as f: f.write(dedent(preamble).lstrip()) yaml.dump(yaml_content, f) diff --git a/common/utils.py b/common/utils.py index 97ecaf70..8593170f 100644 --- a/common/utils.py +++ b/common/utils.py @@ -1,12 +1,12 @@ """Common utility functions""" from types import NoneType -from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar +from typing import Dict, Type, Union, get_args, get_origin, TypeVar T = TypeVar("T") -def unwrap(wrapped: Optional[T], default: T = None) -> T: +def unwrap(wrapped: T, default: T) -> T: """Unwrap function for Optionals.""" if wrapped is None: return default diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 2c60cd77..28cab98d 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -71,13 +71,10 @@ async def list_models(request: Request) -> ModelList: Requires an admin key to see all models. 
""" - model_dir = config.model.model_dir - model_path = pathlib.Path(model_dir) - - draft_model_dir = config.draft_model.draft_model_dir - if get_key_permission(request) == "admin": - models = get_model_list(model_path.resolve(), draft_model_dir) + models = get_model_list( + config.model.model_dir, config.draft_model.draft_model_dir + ) else: models = await get_current_model_list() @@ -110,7 +107,7 @@ async def list_draft_models(request: Request) -> ModelList: draft_model_dir = config.draft_model.draft_model_dir draft_model_path = pathlib.Path(draft_model_dir) - models = get_model_list(draft_model_path.resolve()) + models = get_model_list(draft_model_path) else: models = await get_current_model_list(model_type="draft") @@ -278,7 +275,7 @@ async def list_embedding_models(request: Request) -> ModelList: embedding_model_dir = config.embeddings.embedding_model_dir embedding_model_path = pathlib.Path(embedding_model_dir) - models = get_model_list(embedding_model_path.resolve()) + models = get_model_list(embedding_model_path) else: models = await get_current_model_list(model_type="embedding") diff --git a/endpoints/core/utils/model.py b/endpoints/core/utils/model.py index d151fdd1..917b87cb 100644 --- a/endpoints/core/utils/model.py +++ b/endpoints/core/utils/model.py @@ -6,6 +6,7 @@ from common.networking import get_generator_error, handle_request_disconnect from common.tabby_config import config from common.utils import unwrap +from common.model import ModelType from endpoints.core.types.model import ( ModelCard, ModelCardParameters, @@ -15,13 +16,17 @@ ) -def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None): +def get_model_list( + model_path: pathlib.Path, draft_model_path: Optional[pathlib.Path] = None +): """Get the list of models from the provided path.""" # Convert the provided draft model path to a pathlib path for # equality comparisons + if model_path: + model_path = model_path.resolve() if draft_model_path: - draft_model_path = pathlib.Path(draft_model_path).resolve() + draft_model_path = draft_model_path.resolve() model_card_list = ModelList() for path in model_path.iterdir(): @@ -33,7 +38,7 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N return model_card_list -async def get_current_model_list(model_type: str = "model"): +async def get_current_model_list(model_type: ModelType): """ Gets the current model in list format and with path only. 
@@ -45,13 +50,13 @@ async def get_current_model_list(model_type: str = "model"): # Make sure the model container exists match model_type: - case "model": + case ModelType.MODEL: if model.container: model_path = model.container.model_dir - case "draft": + case ModelType.DRAFT: if model.container: model_path = model.container.draft_model_dir - case "embedding": + case ModelType.EMBEDDING: if model.embeddings_container: model_path = model.embeddings_container.model_dir diff --git a/main.py b/main.py index 06db5d53..b01ce6a9 100644 --- a/main.py +++ b/main.py @@ -47,7 +47,7 @@ async def entrypoint_async(): port = fallback_port # Initialize auth keys - await load_auth_keys(config.network.disable_auth) + await load_auth_keys() gen_logging.broadcast_status() From 565980c100f64e5f95731d6bd395ff752e3f987a Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Tue, 24 Sep 2024 15:35:12 +0100 Subject: [PATCH 2/9] fix path types to not require existing files --- common/config_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/common/config_models.py b/common/config_models.py index a7fea31e..0abde8c7 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -8,7 +8,7 @@ field_validator, ) from typing import List, Literal, Optional, Union - +from pathlib import Path CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] @@ -42,13 +42,13 @@ class UtilityActions(BaseConfigModel): # YAML export options export_config: bool = Field(False, description="generate a template config file") - config_export_path: FilePath = Field( + config_export_path: Path = Field( "config_sample.yml", description="path to export configuration file to" ) # OpenAPI JSON export options export_openapi: bool = Field(False, description="export openapi schema files") - openapi_export_path: FilePath = Field( + openapi_export_path: Path = Field( "openapi.json", description="path to export openapi schema to" ) From 6b8151c81c26ea3da166e7936f166a67048a2652 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Wed, 25 Sep 2024 08:31:46 +0100 Subject: [PATCH 3/9] progress on conversion to pydantic --- backends/exllamav2/model.py | 67 +++++++++----- backends/exllamav2/types.py | 177 ++++++++++++++++++++++++++++++++++++ common/auth.py | 11 +-- common/config_models.py | 170 +--------------------------------- main.py | 4 +- 5 files changed, 230 insertions(+), 199 deletions(-) create mode 100644 backends/exllamav2/types.py diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 8582b76f..6466ba9d 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -103,7 +103,27 @@ class ExllamaV2Container: load_condition: asyncio.Condition = asyncio.Condition() @classmethod - async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): + async def create( + cls, + model_directory: pathlib.Path, + quiet=False, + draft=None, + cache_mode="FP16", + gpu_split_auto=True, + tensor_parallel=False, + gpu_split=None, + autosplit_reserve=None, + override_base_seq_len=None, + max_seq_len=None, + rope_scale=None, + rope_alpha="auto", + fasttensors=False, + max_batch_size=None, + cache_size=None, + prompt_template=None, + num_experts_per_token=None, + chunk_size=2048, + ): """ Primary asynchronous initializer for model container. 
@@ -130,7 +150,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): self.config.arch_compat_overrides() # Prepare the draft model config if necessary - draft_args = unwrap(kwargs.get("draft"), {}) + draft_args = unwrap(draft, dict()) draft_model_name = draft_args.get("draft_model_name") enable_draft = draft_args and draft_model_name @@ -171,18 +191,21 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): ) # Apply a model's config overrides while respecting user settings - kwargs = await self.set_model_overrides(**kwargs) + + # FIXME: THIS IS BROKEN!!! + # kwargs do not exist now + # should be investigated after the models have pydantic stuff + # kwargs = await self.set_model_overrides(**kwargs) # MARK: User configuration # Get cache mode - self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16") + self.cache_mode = unwrap(cache_mode, "FP16") # Turn off GPU split if the user is using 1 GPU gpu_count = torch.cuda.device_count() - gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True) - use_tp = unwrap(kwargs.get("tensor_parallel"), False) - gpu_split = kwargs.get("gpu_split") + gpu_split_auto = unwrap(gpu_split_auto, True) + use_tp = unwrap(tensor_parallel, False) gpu_device_list = list(range(0, gpu_count)) # Set GPU split options @@ -211,9 +234,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): # Otherwise fallback to autosplit settings self.gpu_split_auto = gpu_split_auto - autosplit_reserve_megabytes = unwrap( - kwargs.get("autosplit_reserve"), [96] - ) + autosplit_reserve_megabytes = unwrap(autosplit_reserve, [96]) # Reserve VRAM for each GPU self.autosplit_reserve = [ @@ -225,7 +246,6 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): self.config.max_output_len = 16 # Then override the base_seq_len if present - override_base_seq_len = kwargs.get("override_base_seq_len") if override_base_seq_len: self.config.max_seq_len = override_base_seq_len @@ -234,28 +254,26 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): base_seq_len = self.config.max_seq_len # Set the target seq len if present - target_max_seq_len = kwargs.get("max_seq_len") + target_max_seq_len = max_seq_len if target_max_seq_len: self.config.max_seq_len = target_max_seq_len # Set the rope scale - self.config.scale_pos_emb = unwrap( - kwargs.get("rope_scale"), self.config.scale_pos_emb - ) + self.config.scale_pos_emb = unwrap(rope_scale, self.config.scale_pos_emb) # Sets rope alpha value. # Automatically calculate if unset or defined as an "auto" literal. 
- rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto") + rope_alpha = unwrap(rope_alpha, "auto") if rope_alpha == "auto": self.config.scale_alpha_value = self.calculate_rope_alpha(base_seq_len) else: self.config.scale_alpha_value = rope_alpha # Enable fasttensors loading if present - self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False) + self.config.fasttensors = unwrap(fasttensors, False) # Set max batch size to the config override - self.max_batch_size = unwrap(kwargs.get("max_batch_size")) + self.max_batch_size = max_batch_size # Check whether the user's configuration supports flash/paged attention # Also check if exl2 has disabled flash attention @@ -272,7 +290,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): # Set k/v cache size # cache_size is only relevant when paged mode is enabled if self.paged: - cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len) + cache_size = unwrap(cache_size, self.config.max_seq_len) if cache_size < self.config.max_seq_len: logger.warning( @@ -314,7 +332,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): # Try to set prompt template self.prompt_template = await self.find_prompt_template( - kwargs.get("prompt_template"), model_directory + prompt_template, model_directory ) # Catch all for template lookup errors @@ -329,12 +347,11 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): ) # Set num of experts per token if provided - num_experts_override = kwargs.get("num_experts_per_token") - if num_experts_override: - self.config.num_experts_per_token = kwargs.get("num_experts_per_token") + if num_experts_per_token: + self.config.num_experts_per_token = num_experts_per_token # Make sure chunk size is >= 16 and <= max seq length - user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048) + user_chunk_size = unwrap(chunk_size, 2048) chunk_size = sorted((16, user_chunk_size, self.config.max_seq_len))[1] self.config.max_input_len = chunk_size self.config.max_attention_size = chunk_size**2 @@ -342,7 +359,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): # Set user-configured draft model values if enable_draft: # Fetch from the updated kwargs - draft_args = unwrap(kwargs.get("draft"), {}) + draft_args = unwrap(draft, {}) self.draft_config.max_seq_len = self.config.max_seq_len diff --git a/backends/exllamav2/types.py b/backends/exllamav2/types.py new file mode 100644 index 00000000..18ceea22 --- /dev/null +++ b/backends/exllamav2/types.py @@ -0,0 +1,177 @@ +from typing import List, Literal, Optional, Union +from pydantic import BaseModel, ConfigDict, Field + +CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] + +class DraftModelInstanceConfig(BaseModel): + draft_model_name: Optional[str] = Field( + None, + description=( + "An initial draft model to load.\n" + "Ensure the model is in the model directory." + ), + ) + draft_rope_scale: float = Field( + 1.0, + description=( + "Rope scale for draft models (default: 1.0).\n" + "Same as compress_pos_emb.\n" + "Use if the draft model was trained on long context with rope." + ), + ) + draft_rope_alpha: Optional[float] = Field( + None, + description=( + "Rope alpha for draft models (default: None).\n" + 'Same as alpha_value. Set to "auto" to auto-calculate.\n' + "Leaving this value blank will either pull from the model " + "or auto-calculate." 
+ ), + ) + draft_cache_mode: CACHE_SIZES = Field( + "FP16", + description=( + "Cache mode for draft models to save VRAM (default: FP16).\n" + f"Possible values: {str(CACHE_SIZES)[15:-1]}." + ), + ) + +class ModelInstanceConfig(BaseModel): + """ + Options for model overrides and loading + Please read the comments to understand how arguments are handled + between initial and API loads + """ + + model_name: Optional[str] = Field( + None, + description=( + "An initial model to load.\n" + "Make sure the model is located in the model directory!\n" + "REQUIRED: This must be filled out to load a model on startup." + ), + ) + max_seq_len: Optional[int] = Field( + None, + description=( + "Max sequence length (default: Empty).\n" + "Fetched from the model's base sequence length in config.json by default." + ), + ge=0, + ) + override_base_seq_len: Optional[int] = Field( + None, + description=( + "Overrides base model context length (default: Empty).\n" + "WARNING: Don't set this unless you know what you're doing!\n" + "Again, do NOT use this for configuring context length, " + "use max_seq_len above ^" + ), + ge=0, + ) + tensor_parallel: bool = Field( + False, + description=( + "Load model with tensor parallelism.\n" + "Falls back to autosplit if GPU split isn't provided.\n" + "This ignores the gpu_split_auto value." + ), + ) + gpu_split_auto: bool = Field( + True, + description=( + "Automatically allocate resources to GPUs (default: True).\n" + "Not parsed for single GPU users." + ), + ) + autosplit_reserve: List[int] = Field( + [96], + description=( + "Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0).\n" + "Represented as an array of MB per GPU." + ), + ) + gpu_split: List[float] = Field( + default_factory=list, + description=( + "An integer array of GBs of VRAM to split between GPUs (default: []).\n" + "Used with tensor parallelism." + ), + ) + rope_scale: float = Field( + 1.0, + description=( + "Rope scale (default: 1.0).\n" + "Same as compress_pos_emb.\n" + "Use if the model was trained on long context with rope.\n" + "Leave blank to pull the value from the model." + ), + ) + rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( + None, + description=( + "Rope alpha (default: None).\n" + 'Same as alpha_value. Set to "auto" to auto-calculate.\n' + "Leaving this value blank will either pull from the model " + "or auto-calculate." + ), + ) + cache_mode: CACHE_SIZES = Field( + "FP16", + description=( + "Enable different cache modes for VRAM savings (default: FP16).\n" + f"Possible values: {str(CACHE_SIZES)[15:-1]}." + ), + ) + cache_size: Optional[int] = Field( + None, + description=( + "Size of the prompt cache to allocate (default: max_seq_len).\n" + "Must be a multiple of 256 and can't be less than max_seq_len.\n" + "For CFG, set this to 2 * max_seq_len." + ), + multiple_of=256, + gt=0, + ) + chunk_size: int = Field( + 2048, + description=( + "Chunk size for prompt ingestion (default: 2048).\n" + "A lower value reduces VRAM usage but decreases ingestion speed.\n" + "NOTE: Effects vary depending on the model.\n" + "An ideal value is between 512 and 4096." + ), + gt=0, + ) + max_batch_size: Optional[int] = Field( + None, + description=( + "Set the maximum number of prompts to process at one time " + "(default: None/Automatic).\n" + "Automatically calculated if left blank.\n" + "NOTE: Only available for Nvidia ampere (30 series) and above GPUs." + ), + ge=1, + ) + prompt_template: Optional[str] = Field( + None, + description=( + "Set the prompt template for this model. 
(default: None)\n" + "If empty, attempts to look for the model's chat template.\n" + "If a model contains multiple templates in its tokenizer_config.json,\n" + "set prompt_template to the name of the template you want to use.\n" + "NOTE: Only works with chat completion message lists!" + ), + ) + num_experts_per_token: Optional[int] = Field( + None, + description=( + "Number of experts to use per token.\n" + "Fetched from the model's config.json if empty.\n" + "NOTE: For MoE models only.\n" + "WARNING: Don't set this unless you know what you're doing!" + ), + ge=1, + ) + + model_config = ConfigDict(protected_namespaces=()) \ No newline at end of file diff --git a/common/auth.py b/common/auth.py index 773b59b3..a4945483 100644 --- a/common/auth.py +++ b/common/auth.py @@ -3,12 +3,13 @@ application, it should be fine. """ +from functools import partial import aiofiles import io import secrets from ruamel.yaml import YAML from fastapi import Header, HTTPException, Request -from pydantic import BaseModel +from pydantic import BaseModel, Field, SecretStr from loguru import logger from typing import Optional @@ -25,8 +26,8 @@ class AuthKeys(BaseModel): to verify if a given key matches the stored 'api_key' or 'admin_key'. """ - api_key: str - admin_key: str + api_key: SecretStr = Field(default_factory=partial(secrets.token_hex, 16)) + admin_key: SecretStr = Field(default_factory=partial(secrets.token_hex, 16)) def verify_key(self, test_key: str, key_type: str): """Verify if a given key matches the stored key.""" @@ -65,9 +66,7 @@ async def load_auth_keys(): auth_keys_dict = yaml.load(contents) AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict) except FileNotFoundError: - new_auth_keys = AuthKeys( - api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16) - ) + new_auth_keys = AuthKeys() AUTH_KEYS = new_auth_keys async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file: diff --git a/common/config_models.py b/common/config_models.py index 0abde8c7..6b3f55e3 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -7,11 +7,10 @@ PrivateAttr, field_validator, ) -from typing import List, Literal, Optional, Union +from typing import List, Literal, Optional from pathlib import Path -CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] - +from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig class Metadata(BaseModel): """metadata model for config options""" @@ -119,7 +118,7 @@ class LoggingConfig(BaseConfigModel): ) -class ModelConfig(BaseConfigModel): +class ModelConfig(BaseConfigModel, ModelInstanceConfig): """ Options for model overrides and loading Please read the comments to understand how arguments are handled @@ -147,14 +146,6 @@ class ModelConfig(BaseConfigModel): "Enable this if the client is looking for specific OAI models." ), ) - model_name: Optional[str] = Field( - None, - description=( - "An initial model to load.\n" - "Make sure the model is located in the model directory!\n" - "REQUIRED: This must be filled out to load a model on startup." - ), - ) use_as_default: List[str] = Field( default_factory=list, description=( @@ -164,128 +155,6 @@ class ModelConfig(BaseConfigModel): "Example: ['max_seq_len', 'cache_mode']." ), ) - max_seq_len: Optional[int] = Field( - None, - description=( - "Max sequence length (default: Empty).\n" - "Fetched from the model's base sequence length in config.json by default." 
- ), - ge=0, - ) - override_base_seq_len: Optional[int] = Field( - None, - description=( - "Overrides base model context length (default: Empty).\n" - "WARNING: Don't set this unless you know what you're doing!\n" - "Again, do NOT use this for configuring context length, " - "use max_seq_len above ^" - ), - ge=0, - ) - tensor_parallel: bool = Field( - False, - description=( - "Load model with tensor parallelism.\n" - "Falls back to autosplit if GPU split isn't provided.\n" - "This ignores the gpu_split_auto value." - ), - ) - gpu_split_auto: bool = Field( - True, - description=( - "Automatically allocate resources to GPUs (default: True).\n" - "Not parsed for single GPU users." - ), - ) - autosplit_reserve: List[int] = Field( - [96], - description=( - "Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0).\n" - "Represented as an array of MB per GPU." - ), - ) - gpu_split: List[float] = Field( - default_factory=list, - description=( - "An integer array of GBs of VRAM to split between GPUs (default: []).\n" - "Used with tensor parallelism." - ), - ) - rope_scale: float = Field( - 1.0, - description=( - "Rope scale (default: 1.0).\n" - "Same as compress_pos_emb.\n" - "Use if the model was trained on long context with rope.\n" - "Leave blank to pull the value from the model." - ), - ) - rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( - None, - description=( - "Rope alpha (default: None).\n" - 'Same as alpha_value. Set to "auto" to auto-calculate.\n' - "Leaving this value blank will either pull from the model " - "or auto-calculate." - ), - ) - cache_mode: CACHE_SIZES = Field( - "FP16", - description=( - "Enable different cache modes for VRAM savings (default: FP16).\n" - f"Possible values: {str(CACHE_SIZES)[15:-1]}." - ), - ) - cache_size: Optional[int] = Field( - None, - description=( - "Size of the prompt cache to allocate (default: max_seq_len).\n" - "Must be a multiple of 256 and can't be less than max_seq_len.\n" - "For CFG, set this to 2 * max_seq_len." - ), - multiple_of=256, - gt=0, - ) - chunk_size: int = Field( - 2048, - description=( - "Chunk size for prompt ingestion (default: 2048).\n" - "A lower value reduces VRAM usage but decreases ingestion speed.\n" - "NOTE: Effects vary depending on the model.\n" - "An ideal value is between 512 and 4096." - ), - gt=0, - ) - max_batch_size: Optional[int] = Field( - None, - description=( - "Set the maximum number of prompts to process at one time " - "(default: None/Automatic).\n" - "Automatically calculated if left blank.\n" - "NOTE: Only available for Nvidia ampere (30 series) and above GPUs." - ), - ge=1, - ) - prompt_template: Optional[str] = Field( - None, - description=( - "Set the prompt template for this model. (default: None)\n" - "If empty, attempts to look for the model's chat template.\n" - "If a model contains multiple templates in its tokenizer_config.json,\n" - "set prompt_template to the name of the template you want to use.\n" - "NOTE: Only works with chat completion message lists!" - ), - ) - num_experts_per_token: Optional[int] = Field( - None, - description=( - "Number of experts to use per token.\n" - "Fetched from the model's config.json if empty.\n" - "NOTE: For MoE models only.\n" - "WARNING: Don't set this unless you know what you're doing!" 
- ), - ge=1, - ) fasttensors: bool = Field( False, description=( @@ -298,7 +167,7 @@ class ModelConfig(BaseConfigModel): model_config = ConfigDict(protected_namespaces=()) -class DraftModelConfig(BaseConfigModel): +class DraftModelConfig(BaseConfigModel, DraftModelInstanceConfig): """ Options for draft models (speculative decoding) This will use more VRAM! @@ -308,37 +177,6 @@ class DraftModelConfig(BaseConfigModel): "models", description=("Directory to look for draft models (default: models)"), ) - draft_model_name: Optional[str] = Field( - None, - description=( - "An initial draft model to load.\n" - "Ensure the model is in the model directory." - ), - ) - draft_rope_scale: float = Field( - 1.0, - description=( - "Rope scale for draft models (default: 1.0).\n" - "Same as compress_pos_emb.\n" - "Use if the draft model was trained on long context with rope." - ), - ) - draft_rope_alpha: Optional[float] = Field( - None, - description=( - "Rope alpha for draft models (default: None).\n" - 'Same as alpha_value. Set to "auto" to auto-calculate.\n' - "Leaving this value blank will either pull from the model " - "or auto-calculate." - ), - ) - draft_cache_mode: CACHE_SIZES = Field( - "FP16", - description=( - "Cache mode for draft models to save VRAM (default: FP16).\n" - f"Possible values: {str(CACHE_SIZES)[15:-1]}." - ), - ) class LoraInstanceModel(BaseConfigModel): diff --git a/main.py b/main.py index b01ce6a9..eb4857d8 100644 --- a/main.py +++ b/main.py @@ -19,7 +19,7 @@ from endpoints.server import start_api from backends.exllamav2.version import check_exllama_version - +from backends.exllamav2.types import ModelInstanceConfig async def entrypoint_async(): """Async entry function for program startup""" @@ -69,7 +69,7 @@ async def entrypoint_async(): # TODO: remove model_dump() await model.load_model( model_path.resolve(), - **config.model.model_dump(), + **ModelInstanceConfig.model_validate(**config.model.model_dump()).model_dump(), draft=config.draft_model.model_dump(), ) From e3308cc91a89dee7787576a6c44b0887cf05d816 Mon Sep 17 00:00:00 2001 From: Jake <84923604+SecretiveShell@users.noreply.github.com> Date: Wed, 25 Sep 2024 12:31:11 +0100 Subject: [PATCH 4/9] fix loading files --- common/config_models.py | 20 ++++++++++---------- common/tabby_config.py | 6 ++++-- common/utils.py | 8 ++++++-- main.py | 3 ++- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/common/config_models.py b/common/config_models.py index 6b3f55e3..311ce776 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -284,20 +284,20 @@ class TabbyConfigModel(BaseModel): """Base model for a TabbyConfig.""" config: ConfigOverrideConfig = Field( - default_factory=ConfigOverrideConfig.model_construct + default_factory=ConfigOverrideConfig ) - network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct) - logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct) - model: ModelConfig = Field(default_factory=ModelConfig.model_construct) + network: NetworkConfig = Field(default_factory=NetworkConfig) + logging: LoggingConfig = Field(default_factory=LoggingConfig) + model: ModelConfig = Field(default_factory=ModelConfig) draft_model: DraftModelConfig = Field( - default_factory=DraftModelConfig.model_construct + default_factory=DraftModelConfig ) - lora: LoraConfig = Field(default_factory=LoraConfig.model_construct) + lora: LoraConfig = Field(default_factory=LoraConfig) embeddings: EmbeddingsConfig = Field( - default_factory=EmbeddingsConfig.model_construct + 
default_factory=EmbeddingsConfig ) - sampling: SamplingConfig = Field(default_factory=SamplingConfig.model_construct) - developer: DeveloperConfig = Field(default_factory=DeveloperConfig.model_construct) - actions: UtilityActions = Field(default_factory=UtilityActions.model_construct) + sampling: SamplingConfig = Field(default_factory=SamplingConfig) + developer: DeveloperConfig = Field(default_factory=DeveloperConfig) + actions: UtilityActions = Field(default_factory=UtilityActions) model_config = ConfigDict(validate_assignment=True, protected_namespaces=()) diff --git a/common/tabby_config.py b/common/tabby_config.py index 14356c4e..bca08025 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -37,9 +37,10 @@ def load(self, arguments: Optional[dict] = None): # This should be less expensive than pruning the entire merged dictionary configs = filter_none_values(configs) merged_config = merge_dicts(*configs) + merged_config = filter_none_values(merged_config) # validate and update config - merged_config_model = TabbyConfigModel.model_validate(merged_config) + merged_config_model = TabbyConfigModel(**merged_config) for field in TabbyConfigModel.model_fields.keys(): value = getattr(merged_config_model, field) setattr(self, field, value) @@ -106,7 +107,8 @@ def _from_file(self, config_path: pathlib.Path): ) # Create a temporary base config model - new_cfg = TabbyConfigModel.model_validate(cfg) + cfg = filter_none_values(cfg) + new_cfg = TabbyConfigModel(**cfg) try: config_path.rename(f"{config_path}.bak") diff --git a/common/utils.py b/common/utils.py index 8593170f..111b9290 100644 --- a/common/utils.py +++ b/common/utils.py @@ -2,11 +2,12 @@ from types import NoneType from typing import Dict, Type, Union, get_args, get_origin, TypeVar +from pydantic import BaseModel T = TypeVar("T") +M = TypeVar("M", bound=BaseModel) - -def unwrap(wrapped: T, default: T) -> T: +def unwrap(wrapped: Type[T], default: Type[T]) -> T: """Unwrap function for Optionals.""" if wrapped is None: return default @@ -85,3 +86,6 @@ def unwrap_optional_type(type_hint) -> Type: return arg return type_hint + +def cast_model(model: BaseModel, new: Type[M]) -> M: + return new(**model.model_dump()) \ No newline at end of file diff --git a/main.py b/main.py index eb4857d8..b10acb23 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,7 @@ from common.networking import is_port_in_use from common.signals import signal_handler from common.tabby_config import config +from common.utils import cast_model from endpoints.server import start_api from backends.exllamav2.version import check_exllama_version @@ -69,7 +70,7 @@ async def entrypoint_async(): # TODO: remove model_dump() await model.load_model( model_path.resolve(), - **ModelInstanceConfig.model_validate(**config.model.model_dump()).model_dump(), + cast_model(config.model, ModelInstanceConfig).model_dump(), draft=config.draft_model.model_dump(), ) From ba4613393ea3ff1563d441b0cc2725c4a7abe4f3 Mon Sep 17 00:00:00 2001 From: Jake <84923604+SecretiveShell@users.noreply.github.com> Date: Wed, 25 Sep 2024 12:33:46 +0100 Subject: [PATCH 5/9] fix print api keys --- common/auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/auth.py b/common/auth.py index a4945483..a876797b 100644 --- a/common/auth.py +++ b/common/auth.py @@ -76,8 +76,8 @@ async def load_auth_keys(): await auth_file.write(string_stream.getvalue()) logger.info( - f"Your API key is: {AUTH_KEYS.api_key}\n" - f"Your admin key is: {AUTH_KEYS.admin_key}\n\n" + f"Your API key is: 
{AUTH_KEYS.api_key.get_secret_value()}\n" + f"Your admin key is: {AUTH_KEYS.admin_key.get_secret_value()}\n\n" "If these keys get compromised, make sure to delete api_tokens.yml " "and restart the server. Have fun!" ) From dda6b3a41a3e983df7c266d44f26c35abf9b1880 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Thu, 26 Sep 2024 14:16:29 +0100 Subject: [PATCH 6/9] remove model dump for model load --- backends/exllamav2/model.py | 30 +++++++----- backends/exllamav2/types.py | 12 ++++- common/auth.py | 8 +++- common/config_models.py | 13 ++--- common/model.py | 26 +++++++--- common/utils.py | 4 +- endpoints/OAI/utils/completion.py | 17 ++++--- endpoints/core/router.py | 21 ++++---- endpoints/core/types/model.py | 80 ++++++++----------------------- endpoints/core/utils/model.py | 16 +++---- main.py | 15 ++---- 11 files changed, 108 insertions(+), 134 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index dffa05db..aa2c1eb6 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -43,6 +43,7 @@ hardware_supports_flash_attn, supports_paged_attn, ) +from common.tabby_config import config from common.concurrency import iterate_in_threadpool from common.gen_logging import ( log_generation_params, @@ -105,7 +106,7 @@ class ExllamaV2Container: @classmethod async def create( cls, - model_directory: pathlib.Path, + model_name: str, quiet=False, draft=None, cache_mode="FP16", @@ -137,8 +138,15 @@ async def create( # Initialize config self.config = ExLlamaV2Config() - self.model_dir = model_directory - self.config.model_dir = str(model_directory.resolve()) + + model_path = pathlib.Path(config.model.model_dir) + model_path = model_path / model_name + model_path = model_path.resolve() + if not model_path.exists(): + raise FileNotFoundError(f"Model path {model_path} does not exist.") + + self.model_dir = model_path + self.config.model_dir = str(model_path) # Make the max seq len 4096 before preparing the config # This is a better default than 2048 @@ -175,10 +183,10 @@ async def create( self.draft_config.prepare() # Create the hf_config - self.hf_config = await HuggingFaceConfig.from_file(model_directory) + self.hf_config = await HuggingFaceConfig.from_file(model_path) # Load generation config overrides - generation_config_path = model_directory / "generation_config.json" + generation_config_path = model_path / "generation_config.json" if generation_config_path.exists(): try: self.generation_config = await GenerationConfig.from_file( @@ -332,7 +340,7 @@ async def create( # Try to set prompt template self.prompt_template = await self.find_prompt_template( - prompt_template, model_directory + prompt_template, model_name ) # Catch all for template lookup errors @@ -541,7 +549,7 @@ def progress(loaded_modules: int, total_modules: int) async for _ in self.load_gen(progress_callback): pass - async def load_gen(self, progress_callback=None, **kwargs): + async def load_gen(self, progress_callback=None, skip_wait = False): """Loads a model and streams progress via a generator.""" # Indicate that model load has started @@ -551,7 +559,7 @@ async def load_gen(self, progress_callback=None, **kwargs): self.model_is_loading = True # Wait for existing generation jobs to finish - await self.wait_for_jobs(kwargs.get("skip_wait")) + await self.wait_for_jobs(skip_wait) # Streaming gen for model load progress model_load_generator = self.load_model_sync(progress_callback) @@ -1147,19 +1155,19 @@ async def generate_gen( 
grammar_handler = ExLlamaV2Grammar() # Add JSON schema filter if it exists - json_schema = unwrap(kwargs.get("json_schema")) + json_schema = kwargs.get("json_schema") if json_schema: grammar_handler.add_json_schema_filter( json_schema, self.model, self.tokenizer ) # Add regex filter if it exists - regex_pattern = unwrap(kwargs.get("regex_pattern")) + regex_pattern = kwargs.get("regex_pattern") if regex_pattern: grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer) # Add EBNF filter if it exists - grammar_string = unwrap(kwargs.get("grammar_string")) + grammar_string = kwargs.get("grammar_string") if grammar_string: grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer) diff --git a/backends/exllamav2/types.py b/backends/exllamav2/types.py index 18ceea22..3ba44c33 100644 --- a/backends/exllamav2/types.py +++ b/backends/exllamav2/types.py @@ -3,6 +3,7 @@ CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] + class DraftModelInstanceConfig(BaseModel): draft_model_name: Optional[str] = Field( None, @@ -19,7 +20,7 @@ class DraftModelInstanceConfig(BaseModel): "Use if the draft model was trained on long context with rope." ), ) - draft_rope_alpha: Optional[float] = Field( + draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( None, description=( "Rope alpha for draft models (default: None).\n" @@ -27,6 +28,7 @@ class DraftModelInstanceConfig(BaseModel): "Leaving this value blank will either pull from the model " "or auto-calculate." ), + examples=[1.0], ) draft_cache_mode: CACHE_SIZES = Field( "FP16", @@ -36,6 +38,7 @@ class DraftModelInstanceConfig(BaseModel): ), ) + class ModelInstanceConfig(BaseModel): """ Options for model overrides and loading @@ -58,6 +61,7 @@ class ModelInstanceConfig(BaseModel): "Fetched from the model's base sequence length in config.json by default." ), ge=0, + examples=[16384, 4096, 2048], ) override_base_seq_len: Optional[int] = Field( None, @@ -68,6 +72,7 @@ class ModelInstanceConfig(BaseModel): "use max_seq_len above ^" ), ge=0, + examples=[4096], ) tensor_parallel: bool = Field( False, @@ -106,6 +111,7 @@ class ModelInstanceConfig(BaseModel): "Use if the model was trained on long context with rope.\n" "Leave blank to pull the value from the model." ), + examples=[1.0], ) rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( None, @@ -115,6 +121,7 @@ class ModelInstanceConfig(BaseModel): "Leaving this value blank will either pull from the model " "or auto-calculate." 
), + examples=["auto", 1.0], ) cache_mode: CACHE_SIZES = Field( "FP16", @@ -132,6 +139,7 @@ class ModelInstanceConfig(BaseModel): ), multiple_of=256, gt=0, + examples=[4096], ) chunk_size: int = Field( 2048, @@ -174,4 +182,4 @@ class ModelInstanceConfig(BaseModel): ge=1, ) - model_config = ConfigDict(protected_namespaces=()) \ No newline at end of file + model_config = ConfigDict(protected_namespaces=()) diff --git a/common/auth.py b/common/auth.py index a876797b..3a1f26ef 100644 --- a/common/auth.py +++ b/common/auth.py @@ -31,11 +31,15 @@ class AuthKeys(BaseModel): def verify_key(self, test_key: str, key_type: str): """Verify if a given key matches the stored key.""" + if key_type == "admin_key": - return test_key == self.admin_key + return test_key == self.admin_key.get_secret_value() if key_type == "api_key": # Admin keys are valid for all API calls - return test_key == self.api_key or test_key == self.admin_key + return ( + test_key == self.api_key.get_secret_value() + or test_key == self.admin_key.get_secret_value() + ) return False diff --git a/common/config_models.py b/common/config_models.py index 311ce776..13d6de3d 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -12,6 +12,7 @@ from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig + class Metadata(BaseModel): """metadata model for config options""" @@ -283,19 +284,13 @@ class DeveloperConfig(BaseConfigModel): class TabbyConfigModel(BaseModel): """Base model for a TabbyConfig.""" - config: ConfigOverrideConfig = Field( - default_factory=ConfigOverrideConfig - ) + config: ConfigOverrideConfig = Field(default_factory=ConfigOverrideConfig) network: NetworkConfig = Field(default_factory=NetworkConfig) logging: LoggingConfig = Field(default_factory=LoggingConfig) model: ModelConfig = Field(default_factory=ModelConfig) - draft_model: DraftModelConfig = Field( - default_factory=DraftModelConfig - ) + draft_model: DraftModelConfig = Field(default_factory=DraftModelConfig) lora: LoraConfig = Field(default_factory=LoraConfig) - embeddings: EmbeddingsConfig = Field( - default_factory=EmbeddingsConfig - ) + embeddings: EmbeddingsConfig = Field(default_factory=EmbeddingsConfig) sampling: SamplingConfig = Field(default_factory=SamplingConfig) developer: DeveloperConfig = Field(default_factory=DeveloperConfig) actions: UtilityActions = Field(default_factory=UtilityActions) diff --git a/common/model.py b/common/model.py index 87b06adf..8ef7a4ad 100644 --- a/common/model.py +++ b/common/model.py @@ -10,6 +10,7 @@ from loguru import logger from typing import Optional +from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig from common.logger import get_loading_progress_bar from common.networking import handle_request_error from common.tabby_config import config @@ -48,7 +49,11 @@ async def unload_model(skip_wait: bool = False, shutdown: bool = False): container = None -async def load_model_gen(model_path: pathlib.Path, **kwargs): +async def load_model_gen( + model: ModelInstanceConfig, + draft: Optional[DraftModelInstanceConfig] = None, + skip_wait: bool = False, +): """Generator to load a model""" global container @@ -56,7 +61,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): if container and container.model: loaded_model_name = container.model_dir.name - if loaded_model_name == model_path.name and container.model_loaded: + if loaded_model_name == model.model_name and container.model_loaded: raise ValueError( f'Model "{loaded_model_name}" is already 
loaded! Aborting.' ) @@ -65,13 +70,18 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): await unload_model() # Merge with config defaults - kwargs = {**config.model_defaults, **kwargs} + # FIXME: KWARGS DO NOT EXIST NOW + # kwargs = {**config.model_defaults, **kwargs} # Create a new container - container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs) + draft = draft or DraftModelInstanceConfig() + + container = await ExllamaV2Container.create( + **model.model_dump(), quiet=False, draft=draft.model_dump() + ) model_type = "draft" if container.draft_config else "model" - load_status = container.load_gen(load_progress, **kwargs) + load_status = container.load_gen(load_progress, skip_wait) progress = get_loading_progress_bar() progress.start() @@ -97,8 +107,10 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): progress.stop() -async def load_model(model_path: pathlib.Path, **kwargs): - async for _ in load_model_gen(model_path, **kwargs): +async def load_model( + model: ModelInstanceConfig, draft: Optional[DraftModelInstanceConfig] = None +): + async for _ in load_model_gen(model=model, draft=draft): pass diff --git a/common/utils.py b/common/utils.py index 111b9290..17c74b5c 100644 --- a/common/utils.py +++ b/common/utils.py @@ -7,6 +7,7 @@ T = TypeVar("T") M = TypeVar("M", bound=BaseModel) + def unwrap(wrapped: Type[T], default: Type[T]) -> T: """Unwrap function for Optionals.""" if wrapped is None: @@ -87,5 +88,6 @@ def unwrap_optional_type(type_hint) -> Type: return type_hint + def cast_model(model: BaseModel, new: Type[M]) -> M: - return new(**model.model_dump()) \ No newline at end of file + return new(**model.model_dump()) diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index c8b02c81..8ec83af0 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -13,6 +13,7 @@ from loguru import logger +from backends.exllamav2.types import ModelInstanceConfig from common import model from common.auth import get_key_permission from common.networking import ( @@ -138,19 +139,17 @@ async def load_inline_model(model_name: str, request: Request): return - model_path = pathlib.Path(config.model.model_dir) - model_path = model_path / model_name - # Model path doesn't exist - if not model_path.exists(): - logger.warning( - f"Could not find model path {str(model_path)}. Skipping inline model load." - ) + # if not model_path.exists(): + # logger.warning( + # f"Could not find model path {str(model_path)}." + + # "Skipping inline model load." + # ) - return + # return # Load the model - await model.load_model(model_path) + await model.load_model(ModelInstanceConfig(model_name=model_name)) async def stream_generate_completion( diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 28cab98d..8ee00e27 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -120,7 +120,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: """Loads a model into the model container. 
This returns an SSE stream.""" # Verify request parameters - if not data.name: + if not data.model_name: error_message = handle_request_error( "A model name was not provided for load.", exc_info=False, @@ -128,10 +128,6 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: raise HTTPException(400, error_message) - model_path = pathlib.Path(config.model.model_dir) - model_path = model_path / data.name - - draft_model_path = None if data.draft: if not data.draft.draft_model_name: error_message = handle_request_error( @@ -141,18 +137,17 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: raise HTTPException(400, error_message) - draft_model_path = config.draft_model.draft_model_dir - if not model_path.exists(): - error_message = handle_request_error( - "Could not find the model path for load. Check model name or config.yml?", - exc_info=False, - ).error.message + # if not model_path.exists(): + # error_message = handle_request_error( + # "Could not find the model path for load. Check model name or config.yml?", + # exc_info=False, + # ).error.message - raise HTTPException(400, error_message) + # raise HTTPException(400, error_message) return EventSourceResponse( - stream_model_load(data, model_path, draft_model_path), ping=maxsize + stream_model_load(data), ping=maxsize ) diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index b169162d..2fd9bdf3 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -1,9 +1,10 @@ """Contains model card types.""" -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, Field, ConfigDict, model_validator from time import time from typing import List, Literal, Optional, Union +from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig from common.config_models import LoggingConfig from common.tabby_config import config @@ -44,74 +45,31 @@ class ModelList(BaseModel): data: List[ModelCard] = Field(default_factory=list) -class DraftModelLoadRequest(BaseModel): - """Represents a draft model load request.""" - - # Required - draft_model_name: str - - # Config arguments - draft_rope_scale: Optional[float] = None - draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( - description='Automatically calculated if set to "auto"', - default=None, - examples=[1.0], - ) - draft_cache_mode: Optional[str] = None - - -class ModelLoadRequest(BaseModel): +class ModelLoadRequest(ModelInstanceConfig): """Represents a model load request.""" - # Required - name: str - - # Config arguments - - max_seq_len: Optional[int] = Field( - description="Leave this blank to use the model's base sequence length", - default=None, - examples=[4096], + # These Fields only exist to stop a breaking change + name: Optional[str] = Field( + None, description="model name to load", deprecated="Use model_name instead" ) - override_base_seq_len: Optional[int] = Field( - description=( - "Overrides the model's base sequence length. 
" "Leave blank if unsure" - ), - default=None, - examples=[4096], + fasttensors: Optional[bool] = Field( + None, + description="ignored, set globally from config.yml", + deprecated="Use model config instead", ) - cache_size: Optional[int] = Field( - description=("Number in tokens, must be greater than or equal to max_seq_len"), - default=None, - examples=[4096], - ) - tensor_parallel: Optional[bool] = None - gpu_split_auto: Optional[bool] = None - autosplit_reserve: Optional[List[float]] = None - gpu_split: Optional[List[float]] = Field( - default=None, - examples=[[24.0, 20.0]], - ) - rope_scale: Optional[float] = Field( - description="Automatically pulled from the model's config if not present", - default=None, - examples=[1.0], - ) - rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( - description='Automatically calculated if set to "auto"', - default=None, - examples=[1.0], - ) - cache_mode: Optional[str] = None - chunk_size: Optional[int] = None - prompt_template: Optional[str] = None - num_experts_per_token: Optional[int] = None - fasttensors: Optional[bool] = None # Non-config arguments - draft: Optional[DraftModelLoadRequest] = None + draft: Optional[DraftModelInstanceConfig] = None skip_queue: Optional[bool] = False + # for the name value + @model_validator(mode="after") + def set_model_name(self): + """Sets the model name.""" + if self.name and self.model_name is None: + self.model_name = self.name + return self + class EmbeddingModelLoadRequest(BaseModel): name: str diff --git a/endpoints/core/utils/model.py b/endpoints/core/utils/model.py index 917b87cb..014b6fb2 100644 --- a/endpoints/core/utils/model.py +++ b/endpoints/core/utils/model.py @@ -2,10 +2,11 @@ from asyncio import CancelledError from typing import Optional +from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig from common import model from common.networking import get_generator_error, handle_request_disconnect from common.tabby_config import config -from common.utils import unwrap +from common.utils import cast_model, unwrap from common.model import ModelType from endpoints.core.types.model import ( ModelCard, @@ -99,20 +100,17 @@ def get_current_model(): async def stream_model_load( data: ModelLoadRequest, - model_path: pathlib.Path, - draft_model_path: str, ): """Request generation wrapper for the loading process.""" - # Get trimmed load data - load_data = data.model_dump(exclude_none=True) + load_config = cast_model(data, ModelInstanceConfig) - # Set the draft model path if it exists - if draft_model_path: - load_data["draft"]["draft_model_dir"] = draft_model_path + draft_load_config = ( + cast_model(data.draft, DraftModelInstanceConfig) if data.draft else None + ) load_status = model.load_model_gen( - model_path, skip_wait=data.skip_queue, **load_data + model=load_config, draft=draft_load_config, skip_wait=data.skip_queue ) try: async for module, modules, model_type in load_status: diff --git a/main.py b/main.py index b10acb23..a25a7384 100644 --- a/main.py +++ b/main.py @@ -20,7 +20,8 @@ from endpoints.server import start_api from backends.exllamav2.version import check_exllama_version -from backends.exllamav2.types import ModelInstanceConfig +from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig + async def entrypoint_async(): """Async entry function for program startup""" @@ -62,16 +63,10 @@ async def entrypoint_async(): # If an initial model name is specified, create a container # and load the model - model_name = config.model.model_name - 
if model_name: - model_path = pathlib.Path(config.model.model_dir) - model_path = model_path / model_name - - # TODO: remove model_dump() + if config.model.model_name: await model.load_model( - model_path.resolve(), - cast_model(config.model, ModelInstanceConfig).model_dump(), - draft=config.draft_model.model_dump(), + model=cast_model(config.model, ModelInstanceConfig), + draft=cast_model(config.draft_model, DraftModelInstanceConfig), ) # Load loras after loading the model From 73c6534075c1d6f5c2b7eea3b12e8fc3ab968794 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Thu, 26 Sep 2024 14:17:51 +0100 Subject: [PATCH 7/9] Tree: Format --- backends/exllamav2/model.py | 2 +- common/model.py | 1 - endpoints/core/router.py | 5 +---- endpoints/core/types/model.py | 2 +- 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index aa2c1eb6..822b129d 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -549,7 +549,7 @@ def progress(loaded_modules: int, total_modules: int) async for _ in self.load_gen(progress_callback): pass - async def load_gen(self, progress_callback=None, skip_wait = False): + async def load_gen(self, progress_callback=None, skip_wait=False): """Loads a model and streams progress via a generator.""" # Indicate that model load has started diff --git a/common/model.py b/common/model.py index 8ef7a4ad..286cc2b5 100644 --- a/common/model.py +++ b/common/model.py @@ -13,7 +13,6 @@ from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig from common.logger import get_loading_progress_bar from common.networking import handle_request_error -from common.tabby_config import config from common.optional_dependencies import dependencies if dependencies.exllamav2: diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 8ee00e27..94de9143 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -137,7 +137,6 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: raise HTTPException(400, error_message) - # if not model_path.exists(): # error_message = handle_request_error( # "Could not find the model path for load. 
Check model name or config.yml?", @@ -146,9 +145,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: # raise HTTPException(400, error_message) - return EventSourceResponse( - stream_model_load(data), ping=maxsize - ) + return EventSourceResponse(stream_model_load(data), ping=maxsize) # Unload model endpoint diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 2fd9bdf3..a887c834 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field, ConfigDict, model_validator from time import time -from typing import List, Literal, Optional, Union +from typing import List, Optional from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig from common.config_models import LoggingConfig From 55cf8b62f5484d9b19aed11e62231a6ec86f7594 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Thu, 26 Sep 2024 14:42:00 +0100 Subject: [PATCH 8/9] convert exl2 container to pydantic inputs --- backends/exllamav2/model.py | 87 +++++++++++++------------------------ backends/exllamav2/types.py | 2 +- common/model.py | 2 +- 3 files changed, 32 insertions(+), 59 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 822b129d..d58c2125 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -32,6 +32,7 @@ from ruamel.yaml import YAML +from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig from common.health import HealthManager from backends.exllamav2.grammar import ( @@ -106,24 +107,9 @@ class ExllamaV2Container: @classmethod async def create( cls, - model_name: str, + model: ModelInstanceConfig, + draft: DraftModelInstanceConfig, quiet=False, - draft=None, - cache_mode="FP16", - gpu_split_auto=True, - tensor_parallel=False, - gpu_split=None, - autosplit_reserve=None, - override_base_seq_len=None, - max_seq_len=None, - rope_scale=None, - rope_alpha="auto", - fasttensors=False, - max_batch_size=None, - cache_size=None, - prompt_template=None, - num_experts_per_token=None, - chunk_size=2048, ): """ Primary asynchronous initializer for model container. @@ -140,7 +126,7 @@ async def create( self.config = ExLlamaV2Config() model_path = pathlib.Path(config.model.model_dir) - model_path = model_path / model_name + model_path = model_path / model.model_name model_path = model_path.resolve() if not model_path.exists(): raise FileNotFoundError(f"Model path {model_path} does not exist.") @@ -158,25 +144,13 @@ async def create( self.config.arch_compat_overrides() # Prepare the draft model config if necessary - draft_args = unwrap(draft, dict()) - draft_model_name = draft_args.get("draft_model_name") - enable_draft = draft_args and draft_model_name - - # Always disable draft if params are incorrectly configured - if draft_args and draft_model_name is None: - logger.warning( - "Draft model is disabled because a model name " - "wasn't provided. Please check your config.yml!" 
- ) - enable_draft = False - - if enable_draft: + if draft.draft_model_name: self.draft_config = ExLlamaV2Config() self.draft_config.no_flash_attn = self.config.no_flash_attn - draft_model_path = pathlib.Path( - unwrap(draft_args.get("draft_model_dir"), "models") + + draft_model_path = ( + config.draft_model.draft_model_dir / draft.draft_model_name ) - draft_model_path = draft_model_path / draft_model_name self.draft_model_dir = draft_model_path self.draft_config.model_dir = str(draft_model_path.resolve()) @@ -208,12 +182,11 @@ async def create( # MARK: User configuration # Get cache mode - self.cache_mode = unwrap(cache_mode, "FP16") + self.cache_mode = model.cache_mode # Turn off GPU split if the user is using 1 GPU gpu_count = torch.cuda.device_count() - gpu_split_auto = unwrap(gpu_split_auto, True) - use_tp = unwrap(tensor_parallel, False) + gpu_split_auto = model.gpu_split_auto gpu_device_list = list(range(0, gpu_count)) # Set GPU split options @@ -222,16 +195,16 @@ async def create( logger.info("Disabling GPU split because one GPU is in use.") else: # Set tensor parallel - if use_tp: + if model.tensor_parallel: self.use_tp = True # TP has its own autosplit loader self.gpu_split_auto = False # Enable manual GPU split if provided - if gpu_split: + if model.gpu_split: self.gpu_split_auto = False - self.gpu_split = gpu_split + self.gpu_split = model.gpu_split gpu_device_list = [ device_idx @@ -242,7 +215,7 @@ async def create( # Otherwise fallback to autosplit settings self.gpu_split_auto = gpu_split_auto - autosplit_reserve_megabytes = unwrap(autosplit_reserve, [96]) + autosplit_reserve_megabytes = model.autosplit_reserve # Reserve VRAM for each GPU self.autosplit_reserve = [ @@ -254,34 +227,34 @@ async def create( self.config.max_output_len = 16 # Then override the base_seq_len if present - if override_base_seq_len: - self.config.max_seq_len = override_base_seq_len + if model.override_base_seq_len: + self.config.max_seq_len = model.override_base_seq_len # Grab the base model's sequence length before overrides for # rope calculations base_seq_len = self.config.max_seq_len # Set the target seq len if present - target_max_seq_len = max_seq_len + target_max_seq_len = model.max_seq_len if target_max_seq_len: self.config.max_seq_len = target_max_seq_len # Set the rope scale - self.config.scale_pos_emb = unwrap(rope_scale, self.config.scale_pos_emb) + self.config.scale_pos_emb = unwrap(model.rope_scale, self.config.scale_pos_emb) # Sets rope alpha value. # Automatically calculate if unset or defined as an "auto" literal. 
- rope_alpha = unwrap(rope_alpha, "auto") + rope_alpha = unwrap(model.rope_alpha, "auto") if rope_alpha == "auto": self.config.scale_alpha_value = self.calculate_rope_alpha(base_seq_len) else: self.config.scale_alpha_value = rope_alpha # Enable fasttensors loading if present - self.config.fasttensors = unwrap(fasttensors, False) + self.config.fasttensors = config.model.fasttensors # Set max batch size to the config override - self.max_batch_size = max_batch_size + self.max_batch_size = model.max_batch_size # Check whether the user's configuration supports flash/paged attention # Also check if exl2 has disabled flash attention @@ -298,7 +271,7 @@ async def create( # Set k/v cache size # cache_size is only relevant when paged mode is enabled if self.paged: - cache_size = unwrap(cache_size, self.config.max_seq_len) + cache_size = unwrap(model.cache_size, self.config.max_seq_len) if cache_size < self.config.max_seq_len: logger.warning( @@ -340,7 +313,7 @@ async def create( # Try to set prompt template self.prompt_template = await self.find_prompt_template( - prompt_template, model_name + model.prompt_template, model.model_name ) # Catch all for template lookup errors @@ -355,28 +328,28 @@ async def create( ) # Set num of experts per token if provided - if num_experts_per_token: - self.config.num_experts_per_token = num_experts_per_token + if model.num_experts_per_token: + self.config.num_experts_per_token = model.num_experts_per_token # Make sure chunk size is >= 16 and <= max seq length - user_chunk_size = unwrap(chunk_size, 2048) + user_chunk_size = unwrap(model.chunk_size, 2048) chunk_size = sorted((16, user_chunk_size, self.config.max_seq_len))[1] self.config.max_input_len = chunk_size self.config.max_attention_size = chunk_size**2 # Set user-configured draft model values - if enable_draft: + if draft.draft_model_name: # Fetch from the updated kwargs draft_args = unwrap(draft, {}) self.draft_config.max_seq_len = self.config.max_seq_len self.draft_config.scale_pos_emb = unwrap( - draft_args.get("draft_rope_scale"), 1.0 + draft.draft_rope_scale, 1.0 ) # Set draft rope alpha. Follows same behavior as model rope alpha. - draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto") + draft_rope_alpha = unwrap(draft.draft_rope_alpha, "auto") if draft_rope_alpha == "auto": self.draft_config.scale_alpha_value = self.calculate_rope_alpha( self.draft_config.max_seq_len @@ -385,7 +358,7 @@ async def create( self.draft_config.scale_alpha_value = draft_rope_alpha # Set draft cache mode - self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16") + self.draft_cache_mode = draft.draft_cache_mode if chunk_size: self.draft_config.max_input_len = chunk_size diff --git a/backends/exllamav2/types.py b/backends/exllamav2/types.py index 3ba44c33..446fbdc1 100644 --- a/backends/exllamav2/types.py +++ b/backends/exllamav2/types.py @@ -114,7 +114,7 @@ class ModelInstanceConfig(BaseModel): examples=[1.0], ) rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( - None, + "auto", description=( "Rope alpha (default: None).\n" 'Same as alpha_value. 
Set to "auto" to auto-calculate.\n' diff --git a/common/model.py b/common/model.py index 286cc2b5..52a6f247 100644 --- a/common/model.py +++ b/common/model.py @@ -76,7 +76,7 @@ async def load_model_gen( draft = draft or DraftModelInstanceConfig() container = await ExllamaV2Container.create( - **model.model_dump(), quiet=False, draft=draft.model_dump() + model=model, draft=draft, quiet=False ) model_type = "draft" if container.draft_config else "model" From 7de70eca37a6627aed42549f88f9d9ed1ec1589a Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Thu, 26 Sep 2024 16:34:19 +0100 Subject: [PATCH 9/9] fix model overrides dict --- backends/exllamav2/model.py | 2 -- backends/exllamav2/types.py | 4 +++- common/model.py | 5 +++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index d58c2125..253e168b 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -339,8 +339,6 @@ async def create( # Set user-configured draft model values if draft.draft_model_name: - # Fetch from the updated kwargs - draft_args = unwrap(draft, {}) self.draft_config.max_seq_len = self.config.max_seq_len diff --git a/backends/exllamav2/types.py b/backends/exllamav2/types.py index 446fbdc1..10a1ee04 100644 --- a/backends/exllamav2/types.py +++ b/backends/exllamav2/types.py @@ -38,6 +38,8 @@ class DraftModelInstanceConfig(BaseModel): ), ) + model_config = ConfigDict(revalidate_instances="always") + class ModelInstanceConfig(BaseModel): """ @@ -182,4 +184,4 @@ class ModelInstanceConfig(BaseModel): ge=1, ) - model_config = ConfigDict(protected_namespaces=()) + model_config = ConfigDict(protected_namespaces=(), revalidate_instances="always") \ No newline at end of file diff --git a/common/model.py b/common/model.py index 52a6f247..9937537b 100644 --- a/common/model.py +++ b/common/model.py @@ -14,6 +14,7 @@ from common.logger import get_loading_progress_bar from common.networking import handle_request_error from common.optional_dependencies import dependencies +from common.tabby_config import config if dependencies.exllamav2: from backends.exllamav2.model import ExllamaV2Container @@ -69,8 +70,8 @@ async def load_model_gen( await unload_model() # Merge with config defaults - # FIXME: KWARGS DO NOT EXIST NOW - # kwargs = {**config.model_defaults, **kwargs} + model = model.model_copy(update=config.model_defaults) + model.model_validate(model, strict=True) # Create a new container draft = draft or DraftModelInstanceConfig()