diff --git a/backends/infinity/model.py b/backends/infinity/model.py
index c131e3cb..72fa8eb9 100644
--- a/backends/infinity/model.py
+++ b/backends/infinity/model.py
@@ -25,20 +25,68 @@ def __init__(self, model_directory: pathlib.Path):
async def load(self, **kwargs):
# Use cpu by default
device = unwrap(kwargs.get("embeddings_device"), "cpu")
+
+ # Extract device ID if specified
+ device_id = kwargs.get("embeddings_device_id", [])
+
+ # Handle mixed device types (CPU/CUDA conflict)
+ if device == "cpu" and device_id:
+ logger.warning("embeddings_device is set to 'cpu' but embeddings_device_id is specified. Ignoring device_id and using CPU.")
+ device_id = []
+
+ # Validate device ID if using CUDA
+ if device == "cuda" and device_id:
+ if not isinstance(device_id, list):
+ device_id = [device_id]
+
+ # Validate GPU exists
+ available_gpus = torch.cuda.device_count()
+ valid_device_ids = []
+
+ for gpu_id in device_id:
+ if gpu_id >= available_gpus:
+ logger.error(f"GPU {gpu_id} not found. Available GPUs: 0-{available_gpus-1}")
+ continue # Skip invalid GPU but continue checking others
+ else:
+ valid_device_ids.append(gpu_id)
+
+ # Use only valid device IDs
+ device_id = valid_device_ids
+
+ # Handle multiple device IDs with infinity_emb compatibility
+ if len(device_id) > 1:
+ logger.warning("infinity_emb may not support multiple GPU IDs. Using first valid GPU: {device_id[0]}")
+ device_id = [device_id[0]] # Use only first GPU
- engine_args = EngineArgs(
- model_name_or_path=str(self.model_dir),
- engine="torch",
- device=device,
- bettertransformer=False,
- model_warmup=False,
- )
+ try:
+ engine_args = EngineArgs(
+ model_name_or_path=str(self.model_dir),
+ engine="torch",
+ device=device,
+ device_id=device_id, # Pass device ID to infinity_emb
+ bettertransformer=False,
+ model_warmup=False,
+ )
- self.engine = AsyncEmbeddingEngine.from_args(engine_args)
- await self.engine.astart()
+ self.engine = AsyncEmbeddingEngine.from_args(engine_args)
+ await self.engine.astart()
- self.loaded = True
- logger.info("Embedding model successfully loaded.")
+ self.loaded = True
+ gpu_info = f" on GPU {device_id}" if device_id else ""
+ logger.info(f"Embedding model successfully loaded{gpu_info}.")
+
+ except RuntimeError as e:
+ if "out of memory" in str(e).lower():
+ logger.error(f"GPU {device_id} has insufficient memory for embedding model. Error: {str(e)}")
+ logger.error("Try using a different GPU or loading the model on CPU.")
+ raise
+ elif "cuda" in str(e).lower() or "device" in str(e).lower():
+ logger.error(f"Failed to load embedding model on GPU {device_id}. Error: {str(e)}")
+ logger.error("The GPU may be busy or unavailable. Try using a different GPU or CPU.")
+ raise
+ else:
+ logger.error(f"Unexpected error loading embedding model: {str(e)}")
+ raise
async def unload(self):
await self.engine.astop()
diff --git a/common/config_models.py b/common/config_models.py
index 0e71734c..95afcfc5 100644
--- a/common/config_models.py
+++ b/common/config_models.py
@@ -420,11 +420,39 @@ class EmbeddingsConfig(BaseConfigModel):
"If using an AMD GPU, set this value to 'cuda'."
),
)
+ embeddings_device_id: Optional[List[int]] = Field(
+ [],
+ description=(
+ "Specific GPU device IDs for embedding models (default: []).\n"
+ "Empty list for auto-select.\n"
+ "Only applies when embeddings_device is 'cuda'."
+ ),
+ )
embedding_model_name: Optional[str] = Field(
None,
description=("An initial embedding model to load on the infinity backend."),
)
+ @field_validator("embeddings_device_id", mode="before")
+ @classmethod
+ def validate_embeddings_device_id(cls, v, info):
+ # Only validate if CUDA is selected
+ if info.data.get("embeddings_device") == "cuda" and v:
+ # Check if torch is available
+ try:
+ import torch
+ available_gpus = torch.cuda.device_count()
+ for gpu_id in v:
+ if gpu_id >= available_gpus:
+ raise ValueError(
+ f"GPU {gpu_id} not found. Available GPUs: 0-{available_gpus-1}"
+ )
+ except ImportError:
+ # If torch is not available, we can't validate now
+ # This will be caught later during model loading
+ pass
+ return v
+
class DeveloperConfig(BaseConfigModel):
"""Options for development and experimentation"""
diff --git a/config_sample.yml b/config_sample.yml
index 0b65f9e8..a3139b82 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -218,6 +218,12 @@ embeddings:
# If using an AMD GPU, set this value to 'cuda'.
embeddings_device: cpu
+ # Specific GPU device IDs for embedding models (default: []).
+ # Empty list for auto-select.
+ # Only applies when embeddings_device is 'cuda'.
+ # Example: [0] for first GPU, [1] for second GPU
+ embeddings_device_id: []
+
# An initial embedding model to load on the infinity backend.
embedding_model_name:
diff --git a/docs/02.-Server-options.md b/docs/02.-Server-options.md
index 98cee556..860ef9a8 100644
--- a/docs/02.-Server-options.md
+++ b/docs/02.-Server-options.md
@@ -107,4 +107,12 @@ Note: Most of the options here will only apply on initial embedding model load/s
| -------------------- | ----------------- | ---------------------------------------------------------------------------------------------------------------------------- |
| embedding_model_dir | String ("models") | Directory to look for embedding models.
Note: Persisted across subsequent load requests |
| embeddings_device | String ("cpu") | Device to load an embedding model on.
Options: cpu, cuda, auto
Note: Persisted across subsequent load requests |
+| embeddings_device_id | List[int] ([]) | Specific GPU device IDs for embedding models.
Empty list for auto-select.
Only applies when embeddings_device is "cuda".
Note: If multiple GPUs are specified, only the first valid GPU will be used. |
| embedding_model_name | String (None) | Folder name of an embedding model to load using infinity-emb. |
+
+#### Troubleshooting
+
+- **GPU not found**: If you see "GPU X not found" error, check your GPU IDs against `nvidia-smi` output
+- **Out of memory**: If GPU runs out of memory, try using a different GPU or set `embeddings_device` to "cpu"
+- **GPU busy**: If model loading fails with CUDA errors, the GPU may be busy with other processes
+- **Mixed device types**: If `embeddings_device` is "cpu" but `embeddings_device_id` is set, the device ID will be ignored
diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py
index 84229294..c121df59 100644
--- a/endpoints/core/types/model.py
+++ b/endpoints/core/types/model.py
@@ -123,6 +123,7 @@ class EmbeddingModelLoadRequest(BaseModel):
# Set default from the config
embeddings_device: Optional[str] = Field(config.embeddings.embeddings_device)
+ embeddings_device_id: Optional[List[int]] = Field(config.embeddings.embeddings_device_id)
class ModelLoadResponse(BaseModel):