Commit 0ea4749

Onprem Compatibility Change
1 parent 161778c

28 files changed · 910 additions & 158 deletions

charts/model-engine/templates/_helpers.tpl

Lines changed: 23 additions & 1 deletion
@@ -256,6 +256,10 @@ env:
   - name: ABS_CONTAINER_NAME
     value: {{ .Values.azure.abs_container_name }}
   {{- end }}
+  {{- if .Values.s3EndpointUrl }}
+  - name: S3_ENDPOINT_URL
+    value: {{ .Values.s3EndpointUrl | quote }}
+  {{- end }}
 {{- end }}

 {{- define "modelEngine.syncForwarderTemplateEnv" -}}
@@ -342,9 +346,27 @@ env:
     value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml"
   {{- end }}
   - name: CELERY_ELASTICACHE_ENABLED
-    value: "true"
+    value: {{ .Values.celeryElasticacheEnabled | default true | quote }}
   - name: LAUNCH_SERVICE_TEMPLATE_FOLDER
     value: "/workspace/model-engine/model_engine_server/infra/gateways/resources/templates"
+  {{- if .Values.s3EndpointUrl }}
+  - name: S3_ENDPOINT_URL
+    value: {{ .Values.s3EndpointUrl | quote }}
+  {{- end }}
+  {{- if .Values.redisHost }}
+  - name: REDIS_HOST
+    value: {{ .Values.redisHost | quote }}
+  - name: REDIS_PORT
+    value: {{ .Values.redisPort | default "6379" | quote }}
+  {{- end }}
+  {{- if .Values.celeryBrokerUrl }}
+  - name: CELERY_BROKER_URL
+    value: {{ .Values.celeryBrokerUrl | quote }}
+  {{- end }}
+  {{- if .Values.celeryResultBackend }}
+  - name: CELERY_RESULT_BACKEND
+    value: {{ .Values.celeryResultBackend | quote }}
+  {{- end }}
   {{- if .Values.redis.auth}}
   - name: REDIS_AUTH_TOKEN
     value: {{ .Values.redis.auth }}

model-engine/model_engine_server/api/dependencies.py

Lines changed: 6 additions & 3 deletions
@@ -225,7 +225,8 @@ def _get_external_interfaces(
     )

     queue_delegate: QueueEndpointResourceDelegate
-    if CIRCLECI:
+    if CIRCLECI or infra_config().cloud_provider == "onprem":
+        # On-prem uses fake queue delegate (no SQS/ServiceBus)
         queue_delegate = FakeQueueEndpointResourceDelegate()
     elif infra_config().cloud_provider == "azure":
         queue_delegate = ASBQueueEndpointResourceDelegate()
@@ -238,7 +239,8 @@ def _get_external_interfaces(

     inference_task_queue_gateway: TaskQueueGateway
     infra_task_queue_gateway: TaskQueueGateway
-    if CIRCLECI:
+    if CIRCLECI or infra_config().cloud_provider == "onprem":
+        # On-prem uses Redis-based task queues
         inference_task_queue_gateway = redis_24h_task_queue_gateway
         infra_task_queue_gateway = redis_task_queue_gateway
     elif infra_config().cloud_provider == "azure":
@@ -366,7 +368,8 @@ def _get_external_interfaces(
         file_storage_gateway = S3FileStorageGateway()

     docker_repository: DockerRepository
-    if CIRCLECI:
+    if CIRCLECI or infra_config().cloud_provider == "onprem":
+        # On-prem uses fake docker repository (no ECR/ACR validation)
         docker_repository = FakeDockerRepository()
     elif infra_config().cloud_provider == "azure":
         docker_repository = ACRDockerRepository()
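
Taken together, the on-prem branch simply reuses the CIRCLECI fakes and the Redis-backed task-queue gateways. A minimal, self-contained sketch of the dispatch pattern above (the delegate classes here are stand-ins for illustration, not the real imports):

# Illustrative sketch only: stand-in classes model the dispatch in the diff.
class FakeQueueDelegate: ...
class ASBQueueDelegate: ...
class SQSQueueDelegate: ...

def pick_queue_delegate(circleci: bool, cloud_provider: str):
    # On-prem has no SQS/ServiceBus, so it takes the same path as CI
    if circleci or cloud_provider == "onprem":
        return FakeQueueDelegate()
    if cloud_provider == "azure":
        return ASBQueueDelegate()
    return SQSQueueDelegate()  # AWS default (assumed)

assert isinstance(pick_queue_delegate(False, "onprem"), FakeQueueDelegate)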

model-engine/model_engine_server/common/config.py

Lines changed: 7 additions & 1 deletion
@@ -70,12 +70,13 @@ class HostedModelInferenceServiceConfig:
     user_inference_tensorflow_repository: str
     docker_image_layer_cache_repository: str
     sensitive_log_mode: bool
-    # Exactly one of the following three must be specified
+    # Exactly one of the following must be specified for Redis cache
     cache_redis_aws_url: Optional[str] = None  # also using this to store sync autoscaling metrics
     cache_redis_azure_host: Optional[str] = None
     cache_redis_aws_secret_name: Optional[str] = (
         None  # Not an env var because the redis cache info is already here
     )
+    cache_redis_onprem_url: Optional[str] = None  # For on-prem Redis (e.g., redis://redis:6379/0)
     sglang_repository: Optional[str] = None

     @classmethod
@@ -90,8 +91,13 @@ def from_yaml(cls, yaml_path):

     @property
     def cache_redis_url(self) -> str:
+        # On-prem Redis support (explicit URL, no cloud provider dependency)
+        if self.cache_redis_onprem_url:
+            return self.cache_redis_onprem_url
+
         cloud_provider = infra_config().cloud_provider

+        # On-prem: support REDIS_HOST env var fallback
         if cloud_provider == "onprem":
             if self.cache_redis_aws_url:
                 logger.info("On-prem deployment using cache_redis_aws_url")
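
The resolution order for the Redis cache URL therefore becomes: explicit cache_redis_onprem_url first, then the provider-specific fields. A small standalone sketch of that precedence (field names mirror the diff; the real class has more branches than shown here):

from dataclasses import dataclass
from typing import Optional

@dataclass
class RedisCacheConfigSketch:
    # Field names mirror HostedModelInferenceServiceConfig in the diff.
    cache_redis_onprem_url: Optional[str] = None
    cache_redis_aws_url: Optional[str] = None
    cache_redis_azure_host: Optional[str] = None

    def cache_redis_url(self, cloud_provider: str) -> str:
        # 1. Explicit on-prem URL wins, regardless of cloud provider.
        if self.cache_redis_onprem_url:
            return self.cache_redis_onprem_url
        # 2. On-prem can still fall back to cache_redis_aws_url (per the diff).
        if cloud_provider == "onprem" and self.cache_redis_aws_url:
            return self.cache_redis_aws_url
        # 3. Otherwise the provider-specific fields apply (AWS/Azure branches elided).
        raise ValueError("no Redis cache URL configured")

print(RedisCacheConfigSketch(cache_redis_onprem_url="redis://redis:6379/0").cache_redis_url("onprem"))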

model-engine/model_engine_server/common/io.py

Lines changed: 6 additions & 4 deletions
@@ -3,17 +3,19 @@
 import os
 from typing import Any

+import boto3
 import smart_open
 from model_engine_server.core.config import infra_config


 def open_wrapper(uri: str, mode: str = "rt", **kwargs):
     client: Any
+    cloud_provider: str
+    # This follows the 5.1.0 smart_open API
     try:
         cloud_provider = infra_config().cloud_provider
     except Exception:
         cloud_provider = "aws"
-
     if cloud_provider == "azure":
         from azure.identity import DefaultAzureCredential
         from azure.storage.blob import BlobServiceClient
@@ -23,9 +25,9 @@ def open_wrapper(uri: str, mode: str = "rt", **kwargs):
             DefaultAzureCredential(),
         )
     else:
-        from model_engine_server.infra.gateways.s3_utils import get_s3_client
-
-        client = get_s3_client(kwargs)
+        profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE"))
+        session = boto3.Session(profile_name=profile_name)
+        client = session.client("s3")

     transport_params = {"client": client}
     return smart_open.open(uri, mode, transport_params=transport_params)
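
After this change the S3 branch builds its client from a boto3 session (profile from aws_profile/AWS_PROFILE, otherwise the default credential chain) and hands it to smart_open. A hedged usage sketch of the same pattern (the bucket and key are made up):

# Usage sketch for the rewritten S3 branch; URI and profile are illustrative.
import os

import boto3
import smart_open

profile_name = os.getenv("AWS_PROFILE")  # None falls back to env-var credentials
client = boto3.Session(profile_name=profile_name).client("s3")
with smart_open.open("s3://example-bucket/config.yaml", "rt", transport_params={"client": client}) as f:
    print(f.read())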

model-engine/model_engine_server/core/aws/roles.py

Lines changed: 10 additions & 1 deletion
@@ -119,12 +119,21 @@ def session(role: Optional[str], session_type: SessionT = Session) -> SessionT:

     :param:`session_type` defines the type of session to return. Most users will use
     the default boto3 type. Some users required a special type (e.g aioboto3 session).
+
+    For on-prem deployments without AWS profiles, pass role=None or role=""
+    to use default credentials from environment variables (AWS_ACCESS_KEY_ID, etc).
     """
     # Do not assume roles in CIRCLECI
     if os.getenv("CIRCLECI"):
         logger.warning(f"In circleci, not assuming role (ignoring: {role})")
         role = None
-    sesh: SessionT = session_type(profile_name=role)
+
+    # Use profile-based auth only if role is specified
+    # For on-prem with MinIO, role will be None or empty - use env var credentials
+    if role:
+        sesh: SessionT = session_type(profile_name=role)
+    else:
+        sesh: SessionT = session_type()  # Uses default credential chain (env vars)
     return sesh
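
The practical effect: with a named profile the session behaves as before, and with an empty role boto3 falls back to the standard credential chain (environment variables, shared config), which is what an on-prem MinIO deployment relies on. A short sketch under that assumption:

# Sketch of the two credential paths in session(); values come from the environment.
import os

import boto3

role = os.getenv("AWS_PROFILE")  # e.g. an ML-worker profile in cloud, unset/empty on-prem
if role:
    sesh = boto3.Session(profile_name=role)  # profile-based auth, as before
else:
    sesh = boto3.Session()  # default chain: AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY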

model-engine/model_engine_server/core/aws/storage_client.py

Lines changed: 5 additions & 0 deletions
@@ -1,3 +1,4 @@
+import os
 import time
 from typing import IO, Callable, Iterable, Optional, Sequence

@@ -20,6 +21,10 @@


 def sync_storage_client(**kwargs) -> BaseClient:
+    # Support for MinIO/on-prem S3-compatible storage
+    endpoint_url = os.getenv("S3_ENDPOINT_URL")
+    if endpoint_url and "endpoint_url" not in kwargs:
+        kwargs["endpoint_url"] = endpoint_url
     return session(infra_config().profile_ml_worker).client("s3", **kwargs)  # type: ignore
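
With this hook, exporting S3_ENDPOINT_URL is enough to point the shared storage client at an S3-compatible service such as MinIO, while an explicit endpoint_url kwarg still takes precedence. A standalone sketch of the same precedence (the MinIO URL is an assumed example, and make_s3_client is a hypothetical helper, not repo code):

# Standalone sketch of the env-var hook above (boto3 only; endpoint is an example).
import os

import boto3

os.environ.setdefault("S3_ENDPOINT_URL", "http://minio.minio.svc.cluster.local:9000")  # assumed

def make_s3_client(**kwargs):
    # Same precedence as sync_storage_client: explicit kwarg wins over the env var.
    endpoint_url = os.getenv("S3_ENDPOINT_URL")
    if endpoint_url and "endpoint_url" not in kwargs:
        kwargs["endpoint_url"] = endpoint_url
    return boto3.Session().client("s3", **kwargs)

minio_client = make_s3_client()  # talks to the MinIO endpoint
aws_client = make_s3_client(endpoint_url="https://s3.amazonaws.com")  # explicit override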

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 103 additions & 38 deletions
@@ -61,6 +61,7 @@
 from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus
 from model_engine_server.common.resource_limits import validate_resource_requests
 from model_engine_server.core.auth.authentication_repository import User
+from model_engine_server.core.config import infra_config
 from model_engine_server.core.configmap import read_config_map
 from model_engine_server.core.loggers import (
     LoggerTagKey,
@@ -369,6 +370,10 @@ def __init__(
     def check_docker_image_exists_for_image_tag(
         self, framework_image_tag: str, repository_name: str
     ):
+        # Skip ECR validation for on-prem deployments - images are in local registry
+        if infra_config().cloud_provider == "onprem":
+            return
+
         if not self.docker_repository.image_exists(
             image_tag=framework_image_tag,
             repository_name=repository_name,
@@ -640,8 +645,13 @@ def load_model_weights_sub_commands_s3(
         file_selection_str = '--include "*.model" --include "*.model.v*" --include "*.json" --include "*.safetensors" --include "*.txt" --exclude "optimizer*"'
         if trust_remote_code:
             file_selection_str += ' --include "*.py"'
+
+        # Support for MinIO/on-prem S3-compatible storage via S3_ENDPOINT_URL env var
+        endpoint_flag = (
+            '$(if [ -n "$S3_ENDPOINT_URL" ]; then echo "--endpoint-url $S3_ENDPOINT_URL"; fi)'
+        )
         subcommands.append(
-            f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
+            f"{s5cmd} {endpoint_flag} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
         )
         return subcommands
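
The endpoint flag is deferred to the shell: the $(...) fragment expands to --endpoint-url only when S3_ENDPOINT_URL is set in the pod, so the same command works against AWS S3 and MinIO. A standalone sketch of the command construction (paths and folder names are made up; the flags come from the diff):

# Standalone illustration of the generated copy command.
import os

endpoint_flag = '$(if [ -n "$S3_ENDPOINT_URL" ]; then echo "--endpoint-url $S3_ENDPOINT_URL"; fi)'
checkpoint_path = "s3://example-bucket/llama-weights"  # illustrative
final_weights_folder = "model_files"  # illustrative

cmd = (
    f"./s5cmd {endpoint_flag} --numworkers 512 cp --concurrency 10 "
    f"{os.path.join(checkpoint_path, '*')} {final_weights_folder}"
)
print(cmd)
# With S3_ENDPOINT_URL=http://minio:9000 the shell expands the fragment to
# "--endpoint-url http://minio:9000"; when it is unset, the fragment is empty.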

@@ -693,8 +703,12 @@ def load_model_files_sub_commands_trt_llm(
     and llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt
     """
     if checkpoint_path.startswith("s3://"):
+        # Support for MinIO/on-prem S3-compatible storage via S3_ENDPOINT_URL env var
+        endpoint_flag = (
+            '$(if [ -n "$S3_ENDPOINT_URL" ]; then echo "--endpoint-url $S3_ENDPOINT_URL"; fi)'
+        )
         subcommands = [
-            f"./s5cmd --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./"
+            f"./s5cmd {endpoint_flag} --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./"
         ]
     else:
         subcommands.extend(
@@ -1053,8 +1067,9 @@ async def create_vllm_bundle(
                 protocol="http",
                 readiness_initial_delay_seconds=10,
                 healthcheck_route="/health",
-                predict_route="/predict",
-                streaming_predict_route="/stream",
+                # vLLM 0.5+ uses OpenAI-compatible endpoints
+                predict_route=OPENAI_COMPLETION_PATH,  # "/v1/completions"
+                streaming_predict_route=OPENAI_COMPLETION_PATH,  # "/v1/completions" (streaming via same endpoint)
                 routes=[
                     OPENAI_CHAT_COMPLETION_PATH,
                     OPENAI_COMPLETION_PATH,
@@ -1135,8 +1150,9 @@ async def create_vllm_multinode_bundle(
                 protocol="http",
                 readiness_initial_delay_seconds=10,
                 healthcheck_route="/health",
-                predict_route="/predict",
-                streaming_predict_route="/stream",
+                # vLLM 0.5+ uses OpenAI-compatible endpoints
+                predict_route=OPENAI_COMPLETION_PATH,  # "/v1/completions"
+                streaming_predict_route=OPENAI_COMPLETION_PATH,  # "/v1/completions" (streaming via same endpoint)
                 routes=[OPENAI_CHAT_COMPLETION_PATH, OPENAI_COMPLETION_PATH],
                 env=common_vllm_envs,
                 worker_command=worker_command,
@@ -1937,18 +1953,42 @@ def model_output_to_completion_output(

         elif model_content.inference_framework == LLMInferenceFramework.VLLM:
             tokens = None
-            if with_token_probs:
-                tokens = [
-                    TokenOutput(
-                        token=model_output["tokens"][index],
-                        log_prob=list(t.values())[0],
-                    )
-                    for index, t in enumerate(model_output["log_probs"])
-                ]
+            # Handle OpenAI-compatible format (vLLM 0.5+) vs legacy format
+            if "choices" in model_output and model_output["choices"]:
+                # OpenAI-compatible format: {"choices": [{"text": "...", ...}], "usage": {...}}
+                choice = model_output["choices"][0]
+                text = choice.get("text", "")
+                usage = model_output.get("usage", {})
+                num_prompt_tokens = usage.get("prompt_tokens", 0)
+                num_completion_tokens = usage.get("completion_tokens", 0)
+                # OpenAI format logprobs are in choice.logprobs
+                if with_token_probs and choice.get("logprobs"):
+                    logprobs = choice["logprobs"]
+                    if logprobs.get("tokens") and logprobs.get("token_logprobs"):
+                        tokens = [
+                            TokenOutput(
+                                token=logprobs["tokens"][i],
+                                log_prob=logprobs["token_logprobs"][i] or 0.0,
+                            )
+                            for i in range(len(logprobs["tokens"]))
+                        ]
+            else:
+                # Legacy format: {"text": "...", "count_prompt_tokens": ..., ...}
+                text = model_output["text"]
+                num_prompt_tokens = model_output["count_prompt_tokens"]
+                num_completion_tokens = model_output["count_output_tokens"]
+                if with_token_probs and model_output.get("log_probs"):
+                    tokens = [
+                        TokenOutput(
+                            token=model_output["tokens"][index],
+                            log_prob=list(t.values())[0],
+                        )
+                        for index, t in enumerate(model_output["log_probs"])
+                    ]
             return CompletionOutput(
-                text=model_output["text"],
-                num_prompt_tokens=model_output["count_prompt_tokens"],
-                num_completion_tokens=model_output["count_output_tokens"],
+                text=text,
+                num_prompt_tokens=num_prompt_tokens,
+                num_completion_tokens=num_completion_tokens,
                 tokens=tokens,
             )
         elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM:
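
Concretely, the OpenAI-compatible branch reads text from choices[0] and token counts from usage, while the legacy branch keeps the old top-level keys. A small self-contained sketch with made-up payloads:

# Made-up example payloads showing the two response shapes handled above.
openai_style = {
    "choices": [{"text": " world", "logprobs": None}],
    "usage": {"prompt_tokens": 5, "completion_tokens": 2},
}
legacy_style = {"text": " world", "count_prompt_tokens": 5, "count_output_tokens": 2}

def extract(model_output: dict) -> tuple:
    if "choices" in model_output and model_output["choices"]:
        choice = model_output["choices"][0]
        usage = model_output.get("usage", {})
        return choice.get("text", ""), usage.get("prompt_tokens", 0), usage.get("completion_tokens", 0)
    return (
        model_output["text"],
        model_output["count_prompt_tokens"],
        model_output["count_output_tokens"],
    )

assert extract(openai_style) == extract(legacy_style) == (" world", 5, 2)
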
@@ -2688,20 +2728,43 @@ async def _response_chunk_generator(
         # VLLM
         elif model_content.inference_framework == LLMInferenceFramework.VLLM:
             token = None
-            if request.return_token_log_probs:
-                token = TokenOutput(
-                    token=result["result"]["text"],
-                    log_prob=list(result["result"]["log_probs"].values())[0],
-                )
-            finished = result["result"]["finished"]
-            num_prompt_tokens = result["result"]["count_prompt_tokens"]
+            vllm_output: dict = result["result"]
+            # Handle OpenAI-compatible streaming format (vLLM 0.5+) vs legacy format
+            if "choices" in vllm_output and vllm_output["choices"]:
+                # OpenAI streaming format: {"choices": [{"text": "...", "finish_reason": ...}], ...}
+                choice = vllm_output["choices"][0]
+                text = choice.get("text", "")
+                finished = choice.get("finish_reason") is not None
+                usage = vllm_output.get("usage", {})
+                num_prompt_tokens = usage.get("prompt_tokens", 0)
+                num_completion_tokens = usage.get("completion_tokens", 0)
+                if request.return_token_log_probs and choice.get("logprobs"):
+                    logprobs = choice["logprobs"]
+                    if logprobs.get("tokens") and logprobs.get("token_logprobs"):
+                        # Get the last token from the logprobs
+                        idx = len(logprobs["tokens"]) - 1
+                        token = TokenOutput(
+                            token=logprobs["tokens"][idx],
+                            log_prob=logprobs["token_logprobs"][idx] or 0.0,
+                        )
+            else:
+                # Legacy format: {"text": "...", "finished": ..., ...}
+                text = vllm_output["text"]
+                finished = vllm_output["finished"]
+                num_prompt_tokens = vllm_output["count_prompt_tokens"]
+                num_completion_tokens = vllm_output["count_output_tokens"]
+                if request.return_token_log_probs and vllm_output.get("log_probs"):
+                    token = TokenOutput(
+                        token=vllm_output["text"],
+                        log_prob=list(vllm_output["log_probs"].values())[0],
+                    )
             yield CompletionStreamV1Response(
                 request_id=request_id,
                 output=CompletionStreamOutput(
-                    text=result["result"]["text"],
+                    text=text,
                     finished=finished,
                     num_prompt_tokens=num_prompt_tokens if finished else None,
-                    num_completion_tokens=result["result"]["count_output_tokens"],
+                    num_completion_tokens=num_completion_tokens,
                     token=token,
                 ),
             )
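
For streaming, the OpenAI-style chunks signal completion via finish_reason rather than a finished boolean; a minimal sketch of that mapping (chunk contents are made up):

# Made-up streaming chunks: completion is signalled by a non-null finish_reason.
chunks = [
    {"choices": [{"text": "Hel", "finish_reason": None}]},
    {"choices": [{"text": "lo", "finish_reason": "stop"}], "usage": {"prompt_tokens": 3, "completion_tokens": 2}},
]
for chunk in chunks:
    choice = chunk["choices"][0]
    finished = choice.get("finish_reason") is not None
    print(choice.get("text", ""), finished)
# Prints "Hel False" then "lo True"; legacy chunks carry an explicit "finished" flag instead.
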
@@ -2750,12 +2813,14 @@ def validate_endpoint_supports_openai_completion(
             f"The endpoint's inference framework ({endpoint_content.inference_framework}) does not support openai compatible completion."
         )

-    if not isinstance(
-        endpoint.record.current_model_bundle.flavor, RunnableImageLike
-    ) or OPENAI_COMPLETION_PATH not in (
-        endpoint.record.current_model_bundle.flavor.extra_routes
-        + endpoint.record.current_model_bundle.flavor.routes
-    ):
+    if not isinstance(endpoint.record.current_model_bundle.flavor, RunnableImageLike):
+        raise EndpointUnsupportedRequestException(
+            "Endpoint does not support v2 openai compatible completion"
+        )
+
+    flavor = endpoint.record.current_model_bundle.flavor
+    all_routes = flavor.extra_routes + flavor.routes
+    if OPENAI_COMPLETION_PATH not in all_routes:
         raise EndpointUnsupportedRequestException(
             "Endpoint does not support v2 openai compatible completion"
         )
@@ -3042,12 +3107,12 @@ def validate_endpoint_supports_chat_completion(
             f"The endpoint's inference framework ({endpoint_content.inference_framework}) does not support chat completion."
         )

-    if not isinstance(
-        endpoint.record.current_model_bundle.flavor, RunnableImageLike
-    ) or OPENAI_CHAT_COMPLETION_PATH not in (
-        endpoint.record.current_model_bundle.flavor.extra_routes
-        + endpoint.record.current_model_bundle.flavor.routes
-    ):
+    if not isinstance(endpoint.record.current_model_bundle.flavor, RunnableImageLike):
+        raise EndpointUnsupportedRequestException("Endpoint does not support chat completion")
+
+    flavor = endpoint.record.current_model_bundle.flavor
+    all_routes = flavor.extra_routes + flavor.routes
+    if OPENAI_CHAT_COMPLETION_PATH not in all_routes:
         raise EndpointUnsupportedRequestException("Endpoint does not support chat completion")