diff --git a/py-src/data_formulator/agents/client_utils.py b/py-src/data_formulator/agents/client_utils.py index fe854c7a..983170ad 100644 --- a/py-src/data_formulator/agents/client_utils.py +++ b/py-src/data_formulator/agents/client_utils.py @@ -1,35 +1,64 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import openai import os -import sys - +from litellm import completion from azure.identity import DefaultAzureCredential, get_bearer_token_provider +def get_client(endpoint, key=None, model_name=None): + """ + Returns a LiteLLM client configured for the specified endpoint and model. + Supports OpenAI, Azure, Ollama, and other providers via LiteLLM. + """ + # Set default endpoint + endpoint = os.getenv("ENDPOINT", endpoint) if endpoint == "default" else endpoint -def get_client(endpoint, key): - - endpoint = os.getenv("ENDPOINT") if endpoint == "default" else endpoint + if model_name is None: + if endpoint == "openai": + model_name = "gpt-4" # Default + elif "azure" in endpoint.lower(): + model_name = "azure-gpt-4" + elif "ollama" in endpoint.lower(): + model_name = "llama2" + else: + model_name = "" - if key is None or key == "": - # using azure keyless access method - token_provider = get_bearer_token_provider( - DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" - ) - print(token_provider) - print(endpoint) - client = openai.AzureOpenAI( - api_version="2024-02-15-preview", - azure_endpoint=endpoint, - azure_ad_token_provider=token_provider - ) - elif endpoint == 'openai': - client = openai.OpenAI(api_key=key) - else: - client = openai.AzureOpenAI( - azure_endpoint = endpoint, - api_key=key, - api_version="2024-02-15-preview" - ) - return client \ No newline at end of file + # Configure LiteLLM + if endpoint == "openai": + return completion( + model=model_name, + api_key=key, + custom_llm_provider="openai" + ) + elif "azure" in endpoint.lower(): + if key is None or key == "": + token_provider = get_bearer_token_provider( + DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" + ) + return completion( + model=model_name, + api_base=endpoint, + api_version="2024-02-15-preview", + azure_ad_token_provider=token_provider, + custom_llm_provider="azure" + ) + else: + return completion( + model=model_name, + api_base=endpoint, + api_key=key, + api_version="2024-02-15-preview", + custom_llm_provider="azure" + ) + elif "ollama" in endpoint.lower(): + return completion( + model=f"ollama/{model_name}", + api_base=endpoint, + custom_llm_provider="ollama" + ) + else: + return completion( + model=model_name, + api_base=endpoint, + api_key=key + )