Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 56 additions & 27 deletions py-src/data_formulator/agents/client_utils.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,64 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import openai
import os
import sys

from litellm import completion
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

def get_client(endpoint, key=None, model_name=None):
"""
Returns a LiteLLM client configured for the specified endpoint and model.
Supports OpenAI, Azure, Ollama, and other providers via LiteLLM.
"""
# Set default endpoint
endpoint = os.getenv("ENDPOINT", endpoint) if endpoint == "default" else endpoint

def get_client(endpoint, key):

endpoint = os.getenv("ENDPOINT") if endpoint == "default" else endpoint
if model_name is None:
if endpoint == "openai":
model_name = "gpt-4" # Default
elif "azure" in endpoint.lower():
model_name = "azure-gpt-4"
elif "ollama" in endpoint.lower():
model_name = "llama2"
else:
model_name = ""

if key is None or key == "":
# using azure keyless access method
token_provider = get_bearer_token_provider(
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)
print(token_provider)
print(endpoint)
client = openai.AzureOpenAI(
api_version="2024-02-15-preview",
azure_endpoint=endpoint,
azure_ad_token_provider=token_provider
)
elif endpoint == 'openai':
client = openai.OpenAI(api_key=key)
else:
client = openai.AzureOpenAI(
azure_endpoint = endpoint,
api_key=key,
api_version="2024-02-15-preview"
)
return client
# Configure LiteLLM
if endpoint == "openai":
return completion(
model=model_name,
api_key=key,
custom_llm_provider="openai"
)
elif "azure" in endpoint.lower():
if key is None or key == "":
token_provider = get_bearer_token_provider(
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)
return completion(
model=model_name,
api_base=endpoint,
api_version="2024-02-15-preview",
azure_ad_token_provider=token_provider,
custom_llm_provider="azure"
)
else:
return completion(
model=model_name,
api_base=endpoint,
api_key=key,
api_version="2024-02-15-preview",
custom_llm_provider="azure"
)
elif "ollama" in endpoint.lower():
return completion(
model=f"ollama/{model_name}",
api_base=endpoint,
custom_llm_provider="ollama"
)
else:
return completion(
model=model_name,
api_base=endpoint,
api_key=key
)