diff --git a/.gitignore b/.gitignore
index b6982ac0..74e91524 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,5 @@
-
-*openai-keys.env
+*api-keys.env
 **/*.ipynb_checkpoints/
 .DS_Store
diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
index 2b24bde7..9168e8ec 100644
--- a/DEVELOPMENT.md
+++ b/DEVELOPMENT.md
@@ -18,8 +18,13 @@ How to set up your local machine.
     ```bash
     pip install -r requirements.txt
     ```
+- **Configure environment variables (optional)**
+    - Copy `api-keys.env.template` to `api-keys.env` and add your API keys.
+    - Required fields differ between providers; refer to the [LiteLLM setup](https://docs.litellm.ai/docs#litellm-python-sdk) guide for details.
+    - Currently only `endpoint`, `model`, `api_key`, `api_base`, and `api_version` are supported.
+    - This lets Data Formulator load your API keys automatically when you run the app, so you don't need to set them in the app UI.
-- **Run the app**
+- **Run the app**
     - **Windows**
       ```bash
       .\local_server.bat
       ```
     - **Unix-based**
       ```bash
-      .\local_server.sh
+      ./local_server.sh
       ```
 
 ## Frontend (TypeScript)
diff --git a/README.md b/README.md
index 0369df05..ca91b704 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,12 @@ Transform data and create rich visualizations iteratively with AI 🪄. Try Data
 
 ## News 🔥🔥🔥
 
+- [02-12-2025] More models supported now! Powered by [LiteLLM](https://github.com/BerriAI/litellm)!
+    - Now supports OpenAI, Azure, Ollama, and Anthropic models (and more, via LiteLLM);
+    - Models with strong code generation capabilities are recommended (gpt-4o, claude-3-5-sonnet, etc.);
+    - You can store API keys in `api-keys.env` to avoid typing them every time (see the template `api-keys.env.template`).
+    - Let us know which models you have good/bad experiences with, and which models you would like to see supported! [[comment here]](https://github.com/microsoft/data-formulator/issues/49)
+
 - [11-07-2024] Minor fun update: data visualization challenges!
     - We added a few visualization challenges with the sample datasets. Can you complete them all? [[try them out!]](https://github.com/microsoft/data-formulator/issues/53#issue-2641841252)
     - Comment in the issue when you did, or share your results/questions with others! [[comment here]](https://github.com/microsoft/data-formulator/issues/53)
@@ -77,7 +83,7 @@ Play with Data Formulator with one of the following options:
 
 ## Using Data Formulator
 
-Once you’ve completed the setup using either option, follow these steps to start using Data Formulator:
+Once you've completed the setup using either option, follow these steps to start using Data Formulator:
 
 ### The basics of data visualization
 * Provide OpenAI keys and select a model (GPT-4o suggested) and choose a dataset.
diff --git a/api-keys.env.template b/api-keys.env.template
new file mode 100644
index 00000000..db8cba5d
--- /dev/null
+++ b/api-keys.env.template
@@ -0,0 +1,24 @@
+# OpenAI Configuration
+OPENAI_ENABLED=true
+OPENAI_API_KEY=#your-openai-api-key
+OPENAI_MODELS=gpt-4o,gpt-4o-mini # comma separated list of models
+
+# Azure OpenAI Configuration
+AZURE_ENABLED=true
+AZURE_API_KEY=#your-azure-openai-api-key
+AZURE_API_BASE=https://your-azure-openai-endpoint.openai.azure.com/
+AZURE_API_VERSION=2024-02-15-preview
+AZURE_MODELS=gpt-4o
+
+# Anthropic Configuration
+ANTHROPIC_ENABLED=true
+ANTHROPIC_API_KEY=#your-anthropic-api-key
+ANTHROPIC_MODELS=claude-3-5-sonnet-20241022,claude-3-5-haiku-20241022
+
+# Ollama Configuration
+OLLAMA_ENABLED=true
+OLLAMA_API_BASE=http://localhost:11434
+OLLAMA_MODELS=codellama:7b # models with good code generation capabilities recommended
+
+# To add other providers, set PROVIDER_ENABLED=true, PROVIDER_API_KEY=your-api-key, PROVIDER_MODELS=model1,model2, etc.
+# (replace PROVIDER with the provider name, e.g. GEMINI, ANTHROPIC, AZURE, OPENAI, or OLLAMA, as long as the provider is supported by LiteLLM)
\ No newline at end of file
diff --git a/py-src/data_formulator/agents/agent_code_explanation.py b/py-src/data_formulator/agents/agent_code_explanation.py
index 0053c69e..8d16a968 100644
--- a/py-src/data_formulator/agents/agent_code_explanation.py
+++ b/py-src/data_formulator/agents/agent_code_explanation.py
@@ -66,9 +66,8 @@ def transform_data(df_0):
 
 class CodeExplanationAgent(object):
 
-    def __init__(self, client, model):
+    def __init__(self, client):
         self.client = client
-        self.model = model
 
     def run(self, input_tables, code):
 
@@ -82,11 +81,8 @@ def run(self, input_tables, code):
             {"role":"user","content": user_query}]
 
         ###### the part that calls open_ai
-        response = self.client.chat.completions.create(
-            model=self.model, messages = messages, temperature=0.7, max_tokens=1200,
-            top_p=0.95, n=1, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
-        logger.info('\n=== explanation output ===>\n')
-        logger.info(response.choices[0].message.content)
+        logger.info(f"=== explanation output ===>\n{response.choices[0].message.content}\n")
 
         return response.choices[0].message.content
diff --git a/py-src/data_formulator/agents/agent_concept_derive.py b/py-src/data_formulator/agents/agent_concept_derive.py
index 69873b47..d7de8aab 100644
--- a/py-src/data_formulator/agents/agent_concept_derive.py
+++ b/py-src/data_formulator/agents/agent_concept_derive.py
@@ -167,9 +167,8 @@ class ConceptDeriveAgent(object):
 
-    def __init__(self, client, model):
+    def __init__(self, client):
         self.client = client
-        self.model = model
 
     def run(self, input_table, input_fields, output_field, description, n=1):
         """derive a new concept based on input table, input fields, and output field name, (and description)
@@ -190,9 +189,7 @@ def run(self, input_table, input_fields, output_field, description, n=1):
             {"role":"user","content": user_query}]
 
         ###### the part that calls open_ai
-        response = self.client.chat.completions.create(
-            model=self.model, messages = messages, temperature=0.7, max_tokens=1200,
-            top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
         #log = {'messages': messages, 'response': response.model_dump(mode='json')}
diff --git a/py-src/data_formulator/agents/agent_data_clean.py b/py-src/data_formulator/agents/agent_data_clean.py
index 8cecbeca..818131e0 100644
--- a/py-src/data_formulator/agents/agent_data_clean.py
+++ b/py-src/data_formulator/agents/agent_data_clean.py
@@ -78,8 +78,7 @@ class DataCleanAgent(object):
 
-    def __init__(self, client, model):
-        self.model = model
+    def __init__(self, client):
         self.client = client
 
     def run(self, content_type, raw_data, image_cleaning_instruction):
@@ -129,9 +128,7 @@ def run(self, content_type, raw_data, image_cleaning_instruction):
         messages = [system_message, user_prompt]
 
         ###### the part that calls open_ai
-        response = self.client.chat.completions.create(
-            model=self.model, messages = messages, temperature=0.7, max_tokens=1200,
-            top_p=0.95, n=1, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
         candidates = []
         for choice in response.choices:
diff --git a/py-src/data_formulator/agents/agent_data_filter.py b/py-src/data_formulator/agents/agent_data_filter.py
index 895f7663..ffb7b637 100644
--- a/py-src/data_formulator/agents/agent_data_filter.py
+++ b/py-src/data_formulator/agents/agent_data_filter.py
@@ -125,9 +125,8 @@ def filter_row(row, df):
 
 class DataFilterAgent(object):
 
-    def __init__(self, client, model):
+    def __init__(self, client):
         self.client = client
-        self.model = model
 
     def process_gpt_result(self, input_table, response, messages):
         #log = {'messages': messages, 'response': response.model_dump(mode='json')}
@@ -177,9 +176,7 @@ def run(self, input_table, description):
             {"role":"user","content": user_query}]
 
         ###### the part that calls open_ai
-        response = self.client.chat.completions.create(
-            model=self.model, messages = messages, temperature=0.7, max_tokens=1200,
-            top_p=0.95, n=1, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
         return self.process_gpt_result(input_table, response, messages)
 
@@ -190,8 +187,6 @@ def followup(self, input_table, dialog, new_instruction: str, n=1):
             "content": new_instruction + '\nupdate the filter function accordingly'}]
 
         ##### the part that calls open_ai
-        response = self.client.chat.completions.create(
-            model=self.model, messages=messages, temperature=0.7, max_tokens=1200,
-            top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
         return self.process_gpt_result(input_table, response, messages)
\ No newline at end of file
diff --git a/py-src/data_formulator/agents/agent_data_load.py b/py-src/data_formulator/agents/agent_data_load.py
index 86fb367d..8d534cf6 100644
--- a/py-src/data_formulator/agents/agent_data_load.py
+++ b/py-src/data_formulator/agents/agent_data_load.py
@@ -124,9 +124,8 @@ class DataLoadAgent(object):
 
-    def __init__(self, client, model):
+    def __init__(self, client):
         self.client = client
-        self.model = model
 
     def run(self, input_data, n=1):
 
@@ -140,9 +139,7 @@ def run(self, input_data, n=1):
             {"role":"user","content": user_query}]
 
         ###### the part that calls open_ai
-        response = self.client.chat.completions.create(
-            model=self.model, messages=messages, temperature=0.2, max_tokens=4096,
-            top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
         #log = {'messages': messages, 'response': response.model_dump(mode='json')}
diff --git a/py-src/data_formulator/agents/agent_data_rec.py b/py-src/data_formulator/agents/agent_data_rec.py
index 9c718273..0047e376 100644
--- a/py-src/data_formulator/agents/agent_data_rec.py
+++ b/py-src/data_formulator/agents/agent_data_rec.py
@@ -126,9 +126,8 @@ def transform_data(df):
 
 class DataRecAgent(object):
 
-    def __init__(self, client, model, system_prompt=None):
+    def __init__(self, client, system_prompt=None):
         self.client = client
-        self.model = model
         self.system_prompt = system_prompt if system_prompt is not None else SYSTEM_PROMPT
 
     def process_gpt_response(self, input_tables, messages, response):
@@ -171,7 +170,7 @@ def process_gpt_response(self, input_tables, messages, response):
                     logger.warning(error_message)
                     result = {'status': 'other error', 'code': code_str, 'content': f"Unexpected error: {error_message}"}
             else:
-                result = {'status': 'no transformation', 'code': "", 'content': input_tables[0]['rows']}
+                result = {'status': 'error', 'code': "", 'content': "No code block found in the response. The model is unable to generate code to complete the task."}
 
             result['dialog'] = [*messages, {"role": choice.message.role, "content": choice.message.content}]
             result['agent'] = 'DataRecAgent'
@@ -192,7 +191,7 @@ def run(self, input_tables, description, n=1):
         messages = [{"role":"system", "content": self.system_prompt},
                     {"role":"user","content": user_query}]
 
-        response = completion_response_wrapper(self.client, self.model, messages, n)
+        response = completion_response_wrapper(self.client, messages, n)
 
         return self.process_gpt_response(input_tables, messages, response)
 
@@ -204,7 +203,6 @@ def followup(self, input_tables, dialog, new_instruction: str, n=1):
 
         messages = [*dialog, {"role":"user", "content": f"Update: \n\n{new_instruction}"}]
 
-        ##### the part that calls open_ai
-        response = completion_response_wrapper(self.client, self.model, messages, n)
+        response = completion_response_wrapper(self.client, messages, n)
 
         return self.process_gpt_response(input_tables, messages, response)
\ No newline at end of file
diff --git a/py-src/data_formulator/agents/agent_data_transform_v2.py b/py-src/data_formulator/agents/agent_data_transform_v2.py
index d9e530ea..26d04c53 100644
--- a/py-src/data_formulator/agents/agent_data_transform_v2.py
+++ b/py-src/data_formulator/agents/agent_data_transform_v2.py
@@ -2,6 +2,7 @@
 # Licensed under the MIT License.
 
 import json
+import sys
 
 from data_formulator.agents.agent_utils import extract_json_objects, generate_data_summary, extract_code_from_gpt_response
 import data_formulator.py_sandbox as py_sandbox
@@ -10,6 +11,7 @@
 
 import logging
 
+# module-level logger; handlers and format are configured centrally in app.py
 logger = logging.getLogger(__name__)
 
 SYSTEM_PROMPT = '''You are a data scientist to help user to transform data that will be used for visualization.
@@ -178,12 +180,10 @@ def transform_data(df):
 ```
 '''
 
-def completion_response_wrapper(client, model, messages, n):
+def completion_response_wrapper(client, messages, n):
     ### wrapper for completion response, especially handling errors
     try:
-        response = client.chat.completions.create(
-            model=model, messages=messages, temperature=0.7, max_tokens=1200,
-            top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = client.get_completion(messages = messages)
     except Exception as e:
         response = e
 
@@ -192,9 +192,8 @@ class DataTransformationAgentV2(object):
 
-    def __init__(self, client, model, system_prompt=None):
+    def __init__(self, client, system_prompt=None):
         self.client = client
-        self.model = model
         self.system_prompt = system_prompt if system_prompt is not None else SYSTEM_PROMPT
 
     def process_gpt_response(self, input_tables, messages, response):
@@ -210,8 +209,8 @@ def process_gpt_response(self, input_tables, messages, response):
 
         candidates = []
         for choice in response.choices:
-            # logger.info("\n=== Data transformation result ===>\n")
-            # logger.info(choice.message.content + "\n")
+            logger.info("=== Data transformation result ===>")
+            logger.info(choice.message.content + "\n")
 
             json_blocks = extract_json_objects(choice.message.content + "\n")
             if len(json_blocks) > 0:
@@ -221,6 +220,9 @@ def process_gpt_response(self, input_tables, messages, response):
 
             code_blocks = extract_code_from_gpt_response(choice.message.content + "\n", "python")
 
+            logger.info("=== Code blocks ===>")
+            logger.info(code_blocks)
+
             if len(code_blocks) > 0:
                 code_str = code_blocks[-1]
 
@@ -237,15 +239,18 @@ def process_gpt_response(self, input_tables, messages, response):
                     logger.warning('Error occurred during code execution:')
                     error_message = f"An error occurred during code execution. Error type: {type(e).__name__}"
                     logger.warning(error_message)
-                    result = {'status': 'other error', 'code': code_str, 'content': error_message}
+                    result = {'status': 'error', 'code': code_str, 'content': error_message}
             else:
-                result = {'status': 'no transformation', 'code': "", 'content': input_tables[0]['rows']}
+                result = {'status': 'error', 'code': "", 'content': "No code block found in the response. The model is unable to generate code to complete the task."}
 
             result['dialog'] = [*messages, {"role": choice.message.role, "content": choice.message.content}]
             result['agent'] = 'DataTransformationAgent'
             result['refined_goal'] = refined_goal
             candidates.append(result)
 
+        logger.info("=== Candidates ===>")
+        logger.info(candidates)
+
         return candidates
 
@@ -265,7 +270,7 @@ def run(self, input_tables, description, expected_fields: list[str], n=1):
         messages = [{"role":"system", "content": self.system_prompt},
                     {"role":"user","content": user_query}]
 
-        response = completion_response_wrapper(self.client, self.model, messages, n)
+        response = completion_response_wrapper(self.client, messages, n)
 
         return self.process_gpt_response(input_tables, messages, response)
 
@@ -287,6 +292,6 @@ def followup(self, input_tables, dialog, output_fields: list[str], new_instructi
         messages = [*updated_dialog, {"role":"user",
                     "content": f"Update the code above based on the following instruction:\n\n{json.dumps(goal, indent=4)}"}]
 
-        response = completion_response_wrapper(self.client, self.model, messages, n)
+        response = completion_response_wrapper(self.client, messages, n)
 
         return self.process_gpt_response(input_tables, messages, response)
diff --git a/py-src/data_formulator/agents/agent_data_transformation.py b/py-src/data_formulator/agents/agent_data_transformation.py
index 37b337fb..16c6414b 100644
--- a/py-src/data_formulator/agents/agent_data_transformation.py
+++ b/py-src/data_formulator/agents/agent_data_transformation.py
@@ -122,9 +122,8 @@ def transform_data(df_0):
 
 class DataTransformationAgent(object):
 
-    def __init__(self, client, model):
+    def __init__(self, client):
         self.client = client
-        self.model = model
 
     def process_gpt_response(self, input_tables, messages, response):
         """process gpt response to handle execution"""
@@ -185,9 +184,7 @@ def run(self, input_tables, description, expected_fields: list[str], n=1, enrich
             {"role":"user","content": user_query}]
 
         ###### the part that calls open_ai
-        response = self.client.chat.completions.create(
-            model=self.model, messages = messages, temperature=0.7, max_tokens=1200,
-            top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
         return self.process_gpt_response(input_tables, messages, response)
 
@@ -207,9 +204,7 @@ def followup(self, input_tables, dialog, output_fields: list[str], new_instructi
             "content": "Update the code above based on the following instruction:\n\n" + new_instruction + output_fields_instr}]
 
         ##### the part that calls open_ai
-        response = self.client.chat.completions.create(
-            model=self.model, messages=messages, temperature=0.7, max_tokens=1200,
-            top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
         logger.info(response)
diff --git a/py-src/data_formulator/agents/agent_generic_py_concept.py b/py-src/data_formulator/agents/agent_generic_py_concept.py
index 6ec2bab5..c1f7b868 100644
--- a/py-src/data_formulator/agents/agent_generic_py_concept.py
+++ b/py-src/data_formulator/agents/agent_generic_py_concept.py
@@ -157,9 +157,8 @@ def derive(row, df):
 
 class GenericPyConceptDeriveAgent(object):
 
-    def __init__(self, client, model_version):
+    def __init__(self, client):
         self.client = client
-        self.model_version = model_version
 
     def process_gpt_response(self, input_table, output_field, response, messages):
         #log = {'messages': messages, 'response': response.model_dump(mode='json')}
@@ -220,10 +219,7 @@ def run(self, input_table, output_field, description):
             {"role":"user","content": user_query}]
 
         ###### the part that calls open_ai
-        response = self.client.chat.completions.create(
-            model=self.model_version,
-            messages = messages, temperature=0.7, max_tokens=1200,
-            top_p=0.95, n=1, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
         return self.process_gpt_response(input_table, output_field, response, messages)
 
@@ -234,9 +230,7 @@ def followup(self, input_table, dialog, output_field: str, new_instruction: str,
             "content": new_instruction + '\n update the function accordingly'}]
 
         ##### the part that calls open_ai
-        response = self.client.chat.completions.create(
-            model=self.model, messages=messages, temperature=0.7, max_tokens=1200,
-            top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
         candidates = self.process_gpt_response(input_table, output_field, response, messages)
diff --git a/py-src/data_formulator/agents/agent_py_concept_derive.py b/py-src/data_formulator/agents/agent_py_concept_derive.py
index 57299347..d6371d7a 100644
--- a/py-src/data_formulator/agents/agent_py_concept_derive.py
+++ b/py-src/data_formulator/agents/agent_py_concept_derive.py
@@ -131,8 +131,7 @@ def derive(writing, reading, math):
 
 class PyConceptDeriveAgent(object):
 
-    def __init__(self, client, model):
-        self.model = model
+    def __init__(self, client):
         self.client = client
 
     def run(self, input_table, input_fields, output_field, description):
@@ -163,9 +162,7 @@ def derive({arg_string}):
             {"role":"user","content": user_query}]
 
         ###### the part that calls open_ai
-        response = self.client.chat.completions.create(
-            model=self.model, messages = messages, temperature=0.7, max_tokens=1200,
-            top_p=0.95, n=1, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
         #log = {'messages': messages, 'response': response.model_dump(mode='json')}
diff --git a/py-src/data_formulator/agents/agent_sort_data.py b/py-src/data_formulator/agents/agent_sort_data.py
index aa13ed35..20a628a3 100644
--- a/py-src/data_formulator/agents/agent_sort_data.py
+++ b/py-src/data_formulator/agents/agent_sort_data.py
@@ -65,9 +65,8 @@ class SortDataAgent(object):
 
-    def __init__(self, client, model):
+    def __init__(self, client):
         self.client = client
-        self.model = model
 
     def run(self, name, values, n=1):
 
@@ -84,9 +83,7 @@ def run(self, name, values, n=1):
             {"role":"user","content": user_query}]
 
         ###### the part that calls open_ai
-        response = self.client.chat.completions.create(
-            model=self.model, messages = messages, temperature=0.2, max_tokens=2400,
-            top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
         #log = {'messages': messages, 'response': response.model_dump(mode='json')}
diff --git a/py-src/data_formulator/agents/agent_utils.py b/py-src/data_formulator/agents/agent_utils.py
index 802112f3..792b3d8a 100644
--- a/py-src/data_formulator/agents/agent_utils.py
+++ b/py-src/data_formulator/agents/agent_utils.py
@@ -66,6 +66,7 @@ def table_hash(table):
     frozen_table = tuple(sorted([tuple([hash(value_handling_func(r[key])) for key in schema]) for r in table]))
     return hash(frozen_table)
 
+
 def extract_code_from_gpt_response(code_raw, language):
     """search for matches and then look for pairs of ```...``` to extract code"""
diff --git a/py-src/data_formulator/agents/client_utils.py b/py-src/data_formulator/agents/client_utils.py
index fe854c7a..de83a43a 100644
--- a/py-src/data_formulator/agents/client_utils.py
+++ b/py-src/data_formulator/agents/client_utils.py
@@ -1,35 +1,61 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-import openai
 import os
-import sys
-
+from litellm import completion
 from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 
-def get_client(endpoint, key):
+class Client(object):
+    """
+    A completion client backed by LiteLLM, configured for a given endpoint and model.
+    Supports OpenAI, Azure, Ollama, and other providers via LiteLLM.
+    """
+    def __init__(self, endpoint, model, api_key=None, api_base=None, api_version=None):
+
+        self.endpoint = endpoint
+        self.model = model
+
+        # other params, including temperature, max_completion_tokens, api_base, api_version
+        self.params = {
+            "api_key": api_key,
+            "temperature": 0.7,
+            "max_completion_tokens": 1200,
+        }
 
-    endpoint = os.getenv("ENDPOINT") if endpoint == "default" else endpoint
+        if self.endpoint == "gemini":
+            if model.startswith("gemini/"):
+                self.model = model
+            else:
+                self.model = f"gemini/{model}"
+        elif self.endpoint == "anthropic":
+            if model.startswith("anthropic/"):
+                self.model = model
+            else:
+                self.model = f"anthropic/{model}"
+        elif self.endpoint == "azure":
+            self.params["api_base"] = api_base
+            self.params["api_version"] = api_version if api_version else "2024-02-15-preview"
+            if api_key is None or api_key == "":
+                token_provider = get_bearer_token_provider(
+                    DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
+                )
+                self.params["azure_ad_token_provider"] = token_provider
+            self.params["custom_llm_provider"] = "azure"
+        elif self.endpoint == "ollama":
+            self.params["api_base"] = api_base if api_base else "http://localhost:11434"
+            self.params["max_tokens"] = self.params["max_completion_tokens"]
+            if model.startswith("ollama/"):
+                self.model = model
+            else:
+                self.model = f"ollama/{model}"
 
-    if key is None or key == "":
-        # using azure keyless access method
-        token_provider = get_bearer_token_provider(
-            DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
-        )
-        print(token_provider)
-        print(endpoint)
-        client = openai.AzureOpenAI(
-            api_version="2024-02-15-preview",
-            azure_endpoint=endpoint,
-            azure_ad_token_provider=token_provider
-        )
-    elif endpoint == 'openai':
-        client = openai.OpenAI(api_key=key)
-    else:
-        client = openai.AzureOpenAI(
-            azure_endpoint = endpoint,
-            api_key=key,
-            api_version="2024-02-15-preview"
-        )
-    return client
\ No newline at end of file
+    def get_completion(self, messages):
+        """
+        Call LiteLLM's completion API with the configured model and parameters.
+        """
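+        # litellm.completion dispatches on the provider prefix set on self.model in
+        # __init__ (e.g. "ollama/<model>"); drop_params=True asks LiteLLM to silently
+        # drop any parameter the selected provider does not support.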
+ """ + # Configure LiteLLM + return completion( + model=self.model, + messages=messages, + drop_params=True, + **self.params + ) \ No newline at end of file diff --git a/py-src/data_formulator/app.py b/py-src/data_formulator/app.py index 1c89baf6..4120e6df 100644 --- a/py-src/data_formulator/app.py +++ b/py-src/data_formulator/app.py @@ -20,6 +20,8 @@ from flask_cors import CORS +import logging + import json import time from pathlib import Path @@ -35,24 +37,57 @@ from data_formulator.agents.agent_data_clean import DataCleanAgent from data_formulator.agents.agent_code_explanation import CodeExplanationAgent -from data_formulator.agents.client_utils import get_client +from data_formulator.agents.client_utils import Client from dotenv import load_dotenv APP_ROOT = Path(os.path.join(Path(__file__).parent)).absolute() -print(APP_ROOT) - -# try to look for stored openAI keys information from the ROOT dir, -# this file might be in one of the two locations -load_dotenv(os.path.join(APP_ROOT, "..", "..", 'openai-keys.env')) -load_dotenv(os.path.join(APP_ROOT, 'openai-keys.env')) - import os app = Flask(__name__, static_url_path='', static_folder=os.path.join(APP_ROOT, "dist")) CORS(app) +print(APP_ROOT) + +# Load the single environment file +load_dotenv(os.path.join(APP_ROOT, "..", "..", 'api-keys.env')) +load_dotenv(os.path.join(APP_ROOT, 'api-keys.env')) + +# Configure root logger for general application logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[logging.StreamHandler(sys.stdout)] +) + +# Get logger for this module +logger = logging.getLogger(__name__) + +# Configure Flask app logger to use the same settings +app.logger.handlers = [] +for handler in logging.getLogger().handlers: + app.logger.addHandler(handler) + +# Example usage: +logger.info("Application level log") # General application logging +app.logger.info("Flask specific log") # Web request related logging + + + +def get_client(model_config): + for key in model_config: + model_config[key] = model_config[key].strip() + + client = Client( + model_config["endpoint"], + model_config["model"], + model_config["api_key"] if "api_key" in model_config else None, + html.escape(model_config["api_base"]) if "api_base" in model_config else None, + model_config["api_version"] if "api_version" in model_config else None) + + return client + @app.route('/vega-datasets') def get_example_dataset_list(): dataset_names = vega_data.list_datasets() @@ -118,38 +153,56 @@ def get_datasets(path): @app.route('/check-available-models', methods=['GET', 'POST']) def check_available_models(): - results = [] + + # Define configurations for different providers + providers = ['openai', 'azure', 'anthropic', 'gemini', 'ollama'] - # dont need to check if it's empty - if os.getenv("ENDPOINT") is None: - return json.dumps(results) - - client = get_client(os.getenv("ENDPOINT"), "") - models = [model.strip() for model in os.getenv("MODELS").split(',')] - - for model in models: - try: - response = client.chat.completions.create( - model=model, - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Respond 'I can hear you.' if you can hear me. Do not say anything other than 'I can hear you.'"}, - ] - ) - - print(f"model: {model}") - print(f"welcome message: {response.choices[0].message.content}") - - if "I can hear you." 
-                results.append({
-                    "endpoint": "default",
-                    "key": "",
-                    "model": model
-                })
-        except:
-            pass
-
+    for provider in providers:
+        # Skip if provider is not enabled
+        if not os.getenv(f"{provider.upper()}_ENABLED", "").lower() == "true":
+            continue
+
+        api_key = os.getenv(f"{provider.upper()}_API_KEY", "")
+        api_base = os.getenv(f"{provider.upper()}_API_BASE", "")
+        api_version = os.getenv(f"{provider.upper()}_API_VERSION", "")
+        models = os.getenv(f"{provider.upper()}_MODELS", "")
+
+        if not (api_key or api_base):
+            continue
+
+        if not models:
+            continue
+
+        # Build config for each model
+        for model in models.split(","):
+            model = model.strip()
+            if not model:
+                continue
+
+            model_config = {
+                "id": f"{provider}-{model}-{api_key}-{api_base}-{api_version}",
+                "endpoint": provider,
+                "model": model,
+                "api_key": api_key,
+                "api_base": api_base,
+                "api_version": api_version
+            }
+
+            try:
+                client = get_client(model_config)
+                response = client.get_completion(
+                    messages=[
+                        {"role": "system", "content": "You are a helpful assistant."},
+                        {"role": "user", "content": "Respond 'I can hear you.' if you can hear me."},
+                    ]
+                )
+
+                if "I can hear you." in response.choices[0].message.content:
+                    results.append(model_config)
+            except Exception as e:
+                print(f"Error testing {provider} model {model}: {e}")
+
     return json.dumps(results)
 
 @app.route('/test-model', methods=['GET', 'POST'])
@@ -158,31 +211,27 @@ def test_model():
     if request.is_json:
         app.logger.info("# code query: ")
         content = request.get_json()
-        endpoint = html.escape(content['endpoint'].strip())
-        key = html.escape(f"{content['key']}".strip())
+        # content['model'] is a model config dict with endpoint, model, api_key, api_base, api_version
+        print("content------------------------------")
         print(content)
 
-        client = get_client(endpoint, key)
-        model = html.escape(content['model'].strip())
+        client = get_client(content['model'])
 
         try:
-            response = client.chat.completions.create(
-                model=model,
+            response = client.get_completion(
                 messages=[
                     {"role": "system", "content": "You are a helpful assistant."},
                     {"role": "user", "content": "Respond 'I can hear you.' if you can hear me. Do not say anything other than 'I can hear you.'"},
                 ]
             )
 
-            print(f"model: {model}")
+            print(f"model: {content['model']}")
             print(f"welcome message: {response.choices[0].message.content}")
 
             if "I can hear you." in response.choices[0].message.content:
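+                # a reply echoing the expected phrase means the configured model is reachable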
                 result = {
-                    "endpoint": endpoint,
-                    "key": key,
-                    "model": model,
+                    "model": content['model'],
                     "status": 'ok',
                     "message": ""
                 }
         except Exception as e:
             print(f"Error: {e}")
             error_message = str(e)
             result = {
-                "endpoint": endpoint,
-                "key": key,
-                "model": model,
+                "model": content['model'],
                 "status": 'error',
                 "message": error_message,
             }
     else:
-        {'status': 'error'}
+        result = {'status': 'error'}
 
     return json.dumps(result)
 
@@ -269,11 +316,11 @@ def process_data_on_load_request():
     content = request.get_json()
     token = content["token"]
 
-    client = get_client(content['model']['endpoint'], content['model']['key'])
-    model = content['model']['model']
+    client = get_client(content['model'])
+
     app.logger.info(f" model: {content['model']}")
 
-    agent = DataLoadAgent(client=client, model=model)
+    agent = DataLoadAgent(client=client)
     candidates = agent.run(content["input_data"])
 
     candidates = [c['content'] for c in candidates if c['status'] == 'ok']
@@ -294,11 +341,10 @@ def derive_concept_request():
     content = request.get_json()
     token = content["token"]
 
-    client = get_client(content['model']['endpoint'], content['model']['key'])
-    model = content['model']['model']
+    client = get_client(content['model'])
+
     app.logger.info(f" model: {content['model']}")
 
-
-    agent = ConceptDeriveAgent(client=client, model=model)
+    agent = ConceptDeriveAgent(client=client)
 
     #print(content["input_data"])
 
@@ -323,12 +369,11 @@ def clean_data_request():
     content = request.get_json()
     token = content["token"]
 
-    client = get_client(content['model']['endpoint'], content['model']['key'])
-    model = content['model']['model']
+    client = get_client(content['model'])
 
     app.logger.info(f" model: {content['model']}")
 
-    agent = DataCleanAgent(client=client, model=model)
+    agent = DataCleanAgent(client=client)
 
     candidates = agent.run(content['content_type'], content["raw_data"], content["image_cleaning_instruction"])
 
@@ -350,11 +395,9 @@ def sort_data_request():
     content = request.get_json()
     token = content["token"]
 
-    client = get_client(content['model']['endpoint'], content['model']['key'])
-    model = content['model']['model']
-    app.logger.info(f" model: {content['model']}")
+    client = get_client(content['model'])
 
-    agent = SortDataAgent(client=client, model=model)
+    agent = SortDataAgent(client=client)
     candidates = agent.run(content['field'], content['items'])
 
     #candidates, dialog = limbo_concept.call_codex_sort(content["items"], content["field"])
@@ -375,9 +418,7 @@ def derive_data():
     content = request.get_json()
     token = content["token"]
 
-    client = get_client(content['model']['endpoint'], content['model']['key'])
-    model = content['model']['model']
-    app.logger.info(f" model: {content['model']}")
+    client = get_client(content['model'])
 
     # each table is a dict with {"name": xxx, "rows": [...]}
     input_tables = content["input_tables"]
@@ -394,10 +435,10 @@ def derive_data():
 
     if mode == "recommendation":
         # now it's in recommendation mode
-        agent = DataRecAgent(client, model)
+        agent = DataRecAgent(client=client)
         results = agent.run(input_tables, instruction)
     else:
-        agent = DataTransformationAgentV2(client=client, model=model)
+        agent = DataTransformationAgentV2(client=client)
         results = agent.run(input_tables, instruction, [field['name'] for field in new_fields])
 
     repair_attempts = 0
@@ -429,9 +470,7 @@ def refine_data():
     content = request.get_json()
     token = content["token"]
 
-    client = get_client(content['model']['endpoint'], content['model']['key'])
-    model = content['model']['model']
app.logger.info(f" model: {content['model']}") + client = get_client(content['model']) # each table is a dict with {"name": xxx, "rows": [...]} input_tables = content["input_tables"] @@ -443,7 +482,7 @@ def refine_data(): print(dialog) # always resort to the data transform agent - agent = DataTransformationAgentV2(client, model=model) + agent = DataTransformationAgentV2(client=client) results = agent.followup(input_tables, dialog, [field['name'] for field in output_fields], new_instruction) repair_attempts = 0 @@ -469,15 +508,13 @@ def request_code_expl(): content = request.get_json() token = content["token"] - client = get_client(content['model']['endpoint'], content['model']['key']) - model = content['model']['model'] - app.logger.info(f" model: {content['model']}") + client = get_client(content['model']) # each table is a dict with {"name": xxx, "rows": [...]} input_tables = content["input_tables"] code = content["code"] - code_expl_agent = CodeExplanationAgent(client=client, model=model) + code_expl_agent = CodeExplanationAgent(client=client) expl = code_expl_agent.run(input_tables, code) else: expl = "" diff --git a/pyproject.toml b/pyproject.toml index b677c997..f1b234e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,8 @@ dependencies = [ "azure-identity", "azure-keyvault-secrets", "python-dotenv", - "vega_datasets" + "vega_datasets", + "litellm" ] [project.urls] diff --git a/requirements.txt b/requirements.txt index 8ba373ef..4f26a8e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ azure-identity azure-keyvault-secrets python-dotenv vega_datasets +litellm -e . #also need to install data formulator itself \ No newline at end of file diff --git a/src/app/dfSlice.tsx b/src/app/dfSlice.tsx index b306745c..fc48946b 100644 --- a/src/app/dfSlice.tsx +++ b/src/app/dfSlice.tsx @@ -26,12 +26,21 @@ export const generateFreshChart = (tableRef: string, chartType?: string) : Chart } } +export interface ModelConfig { + id: string; // unique identifier for the model / client combination + endpoint: string; + model: string; + api_key?: string; + api_base?: string; + api_version?: string; +} + // Define a type for the slice state export interface DataFormulatorState { - oaiModels: {endpoint: string, key: string, model: string }[]; - selectedModel: {endpoint: string, model: string} | undefined; - testedModels: {endpoint: string, model: string, status: 'ok' | 'error' | 'testing' | 'unknown', message: string}[]; + models: ModelConfig[]; + selectedModelId: string | undefined; + testedModels: {id: string, status: 'ok' | 'error' | 'testing' | 'unknown', message: string}[]; tables : DictTable[]; charts: Chart[]; @@ -63,8 +72,8 @@ export interface DataFormulatorState { // Define the initial state using that type const initialState: DataFormulatorState = { - oaiModels: [], - selectedModel: undefined, + models: [], + selectedModelId: undefined, testedModels: [], tables: [], @@ -194,7 +203,7 @@ export const fetchCodeExpl = createAsyncThunk( export const fetchAvailableModels = createAsyncThunk( "dataFormulatorSlice/fetchAvailableModels", async () => { - console.log(">>> call agent to infer semantic types <<<") + console.log(">>> call agent to fetch available models <<<") let message = { method: 'POST', headers: { 'Content-Type': 'application/json', }, @@ -222,7 +231,7 @@ export const dataFormulatorSlice = createSlice({ // avoid resetting inputted models // state.oaiModels = state.oaiModels.filter((m: any) => m.endpoint != 'default'); - state.selectedModel = 
+            state.selectedModelId = state.models.length > 0 ? state.models[0].id : undefined;
             state.testedModels = [];
 
             state.tables = [];
@@ -247,8 +256,8 @@ export const dataFormulatorSlice = createSlice({
 
             let savedState = action.payload;
 
-            state.oaiModels = savedState.oaiModels.filter((m: any) => m.endpoint != 'default');
-            state.selectedModel = state.oaiModels.length > 0 ? state.oaiModels[0] : undefined;
+            state.models = savedState.models;
+            state.selectedModelId = savedState.selectedModelId;
 
             state.testedModels = []; // models should be tested again
 
             //state.table = undefined;
@@ -274,25 +283,27 @@ export const dataFormulatorSlice = createSlice({
         toggleBetaMode: (state, action: PayloadAction<boolean>) => {
             state.betaMode = action.payload;
         },
-        selectModel: (state, action: PayloadAction<{model: string, endpoint: string}>) => {
-            state.selectedModel = action.payload;
+        selectModel: (state, action: PayloadAction<string>) => {
+            state.selectedModelId = action.payload;
         },
-        addModel: (state, action: PayloadAction<{model: string, key: string, endpoint: string}>) => {
-            state.oaiModels = [...state.oaiModels, action.payload];
+        addModel: (state, action: PayloadAction<ModelConfig>) => {
+            state.models = [...state.models, action.payload];
         },
-        removeModel: (state, action: PayloadAction<{model: string, endpoint: string}>) => {
-            let model = action.payload.model;
-            let endpoint = action.payload.endpoint;
-            state.oaiModels = state.oaiModels.filter(oaiModel => oaiModel.model != model || oaiModel.endpoint != endpoint );
-            state.testedModels = state.testedModels.filter(m => !(m.model == model && m.endpoint == endpoint));
+        removeModel: (state, action: PayloadAction<string>) => {
+            state.models = state.models.filter(model => model.id != action.payload);
+            if (state.selectedModelId == action.payload) {
+                state.selectedModelId = undefined;
+            }
         },
-        updateModelStatus: (state, action: PayloadAction<{model: string, endpoint: string, status: 'ok' | 'error' | 'testing' | 'unknown', message: string}>) => {
-            let model = action.payload.model;
-            let endpoint = action.payload.endpoint;
+        updateModelStatus: (state, action: PayloadAction<{id: string, status: 'ok' | 'error' | 'testing' | 'unknown', message: string}>) => {
+            let id = action.payload.id;
             let status = action.payload.status;
             let message = action.payload.message;
-            state.testedModels = [...state.testedModels.filter(t => !(t.model == model && t.endpoint == endpoint)), {model, endpoint, status, message} ]
+            state.testedModels = [
+                ...state.testedModels.filter(t => t.id != id),
+                {id: id, status, message}
+            ];
         },
         addTable: (state, action: PayloadAction<DictTable>) => {
             let table = action.payload;
@@ -640,21 +651,25 @@ export const dataFormulatorSlice = createSlice({
             })
             .addCase(fetchAvailableModels.fulfilled, (state, action) => {
                 let defaultModels = action.payload;
-                state.oaiModels = [...defaultModels, ...state.oaiModels.filter(e => !defaultModels.map((m: any) => m.endpoint).includes(e.endpoint))];
-                console.log(state.oaiModels)
+
+                state.models = [
+                    ...defaultModels,
+                    ...state.models.filter(e => !defaultModels.map((m: ModelConfig) => m.endpoint).includes(e.endpoint))
+                ];
 
-                state.testedModels = [...state.testedModels.filter(t => !(t.endpoint == 'default')),
-                    ...defaultModels.map((m: any) => {return {endpoint: m.endpoint, model: m.model, status: 'ok'}}) ]
+                state.testedModels = [
+                    ...defaultModels.map((m: ModelConfig) => {return {id: m.id, status: 'ok'}}),
+                    ...state.testedModels.filter(t => !defaultModels.map((m: ModelConfig) => m.id).includes(t.id))
+                ]
 
-                if (state.selectedModel == undefined && defaultModels.length > 0) {
-                    state.selectedModel = {
-                        model: defaultModels[0].model,
-                        endpoint: defaultModels[0].endpoint
-                    }
+                if (state.selectedModelId == undefined && defaultModels.length > 0) {
+                    state.selectedModelId = defaultModels[0].id;
                 }
-
-                console.log("fetched models");
-                console.log(action.payload);
+
+                console.log("load model complete");
+                console.log("state.models", state.models);
+                console.log("state.selectedModelId", state.selectedModelId);
+                console.log("state.testedModels", state.testedModels);
             })
             .addCase(fetchCodeExpl.fulfilled, (state, action) => {
                 let codeExpl = action.payload;
@@ -670,12 +685,11 @@ export const dataFormulatorSlice = createSlice({
 })
 
 export const dfSelectors = {
-    getActiveModel: (state: DataFormulatorState) => {
-        return state.oaiModels.find(m => m.endpoint == state.selectedModel?.endpoint && m.model == state.selectedModel.model) || {'endpoint': 'default', model: 'gpt-4o', key: ""}
+    getActiveModel: (state: DataFormulatorState) : ModelConfig => {
+        return state.models.find(m => m.id == state.selectedModelId) || state.models[0];
     }
 }
 
-
 // derived field: extra all field items from the table
 export const getDataFieldItems = (baseTable: DictTable): FieldItem[] => {
     return baseTable.names.map((name, index) => {
diff --git a/src/views/DataFormulator.tsx b/src/views/DataFormulator.tsx
index 91b07393..d9598212 100644
--- a/src/views/DataFormulator.tsx
+++ b/src/views/DataFormulator.tsx
@@ -54,7 +54,7 @@ export const DataFormulatorFC = ({ }) => {
     const displayPanelSize = useSelector((state: DataFormulatorState) => state.displayPanelSize);
     const visPaneSize = useSelector((state: DataFormulatorState) => state.visPaneSize);
     const tables = useSelector((state: DataFormulatorState) => state.tables);
-    const selectedModel = useSelector((state: DataFormulatorState) => state.selectedModel);
+    const selectedModelId = useSelector((state: DataFormulatorState) => state.selectedModelId);
 
     const dispatch = useDispatch();
 
@@ -173,15 +173,14 @@ Totals (7 entries) 5 5 5 15
             href="https://privacy.microsoft.com/en-US/data-privacy-notice">view data privacy notice</a>;
 
-    console.log("selected model?")
-    console.log(selectedModel)
+    console.log("selected model?")
+    console.log(selectedModelId)
+
+    return (
-        {selectedModel == undefined ? modelSelectionDialogBox : (tables.length > 0 ? fixedSplitPane : dataUploadRequestBox)}
+        {selectedModelId == undefined ? modelSelectionDialogBox : (tables.length > 0 ? fixedSplitPane : dataUploadRequestBox)}
-
-
     );
 }
\ No newline at end of file
diff --git a/src/views/EncodingBox.tsx b/src/views/EncodingBox.tsx
index ecd40d00..41bbfc1b 100644
--- a/src/views/EncodingBox.tsx
+++ b/src/views/EncodingBox.tsx
@@ -594,7 +594,7 @@ export const EncodingBox: FC = function EncodingBox({ channel,
                 }}
                 freeSolo
                 renderInput={(params) => (
-
                 )} />
diff --git a/src/views/EncodingShelfCard.tsx b/src/views/EncodingShelfCard.tsx
index 289cced6..6eea915c 100644
--- a/src/views/EncodingShelfCard.tsx
+++ b/src/views/EncodingShelfCard.tsx
@@ -328,7 +328,7 @@ export const EncodingShelfCard: FC = function ({ chartId
             dispatch(dfActions.addMessages({
                 "timestamp": Date.now(),
                 "type": "error",
-                "value": `Data formulation failed, please retry.`,
+                "value": `Data formulation failed, please try again.`,
                 "code": code,
                 "detail": errorMessage
             }));
diff --git a/src/views/EncodingShelfThread.tsx b/src/views/EncodingShelfThread.tsx
index e6e55623..2763db2f 100644
--- a/src/views/EncodingShelfThread.tsx
+++ b/src/views/EncodingShelfThread.tsx
@@ -405,7 +405,7 @@ export const EncodingShelfThread: FC = function ({ cha
 
     //let triggers = currentTable.derive.triggers;
     let tableList = activeTableThread.map((tableId) =>
-            <Dialog onClose={()=>{setModelDialogOpen(false)}} open={modelDialogOpen}>
+            <Dialog
+                onClose={(event, reason) => {
+                    if (reason !== 'backdropClick') {
+                        setModelDialogOpen(false);
+                    }
+                }}
+                open={modelDialogOpen}
+            >
                 <DialogTitle>Select Model</DialogTitle>
                 {modelTable}
@@ -341,13 +454,13 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
                         setShowKeys(!showKeys);}}> {showKeys ? 'hide' : 'show'} keys
-