From bbedda412974e2274c5a03b4c75ad53342aa2e84 Mon Sep 17 00:00:00 2001
From: Chenglong Wang
Date: Tue, 11 Feb 2025 18:04:56 -0800
Subject: [PATCH 01/16] Support multiple LLM backends (requires a little more testing)

---
 .../agents/agent_code_explanation.py          |   7 +-
 .../agents/agent_concept_derive.py            |   7 +-
 .../agents/agent_data_clean.py                |   7 +-
 .../agents/agent_data_filter.py               |  11 +-
 .../data_formulator/agents/agent_data_load.py |   7 +-
 .../data_formulator/agents/agent_data_rec.py  |   8 +-
 .../agents/agent_data_transform_v2.py         |  13 +-
 .../agents/agent_data_transformation.py       |  11 +-
 .../agents/agent_generic_py_concept.py        |  12 +-
 .../agents/agent_py_concept_derive.py         |   7 +-
 .../data_formulator/agents/agent_sort_data.py |   7 +-
 py-src/data_formulator/agents/agent_utils.py  |   1 +
 py-src/data_formulator/agents/client_utils.py |  89 ++++--
 py-src/data_formulator/app.py                 |  99 +++---
 src/app/dfSlice.tsx                           |  78 +++--
 src/views/ModelSelectionDialog.tsx            | 295 ++++++++++++------
 16 files changed, 385 insertions(+), 274 deletions(-)

diff --git a/py-src/data_formulator/agents/agent_code_explanation.py b/py-src/data_formulator/agents/agent_code_explanation.py
index 0053c69e..39b1b4be 100644
--- a/py-src/data_formulator/agents/agent_code_explanation.py
+++ b/py-src/data_formulator/agents/agent_code_explanation.py
@@ -66,9 +66,8 @@ def transform_data(df_0):
 
 class CodeExplanationAgent(object):
 
-    def __init__(self, client, model):
+    def __init__(self, client):
         self.client = client
-        self.model = model
 
     def run(self, input_tables, code):
 
@@ -82,9 +81,7 @@ def run(self, input_tables, code):
                     {"role":"user","content": user_query}]
 
         ###### the part that calls open_ai
-        response = self.client.chat.completions.create(
-            model=self.model, messages = messages, temperature=0.7, max_tokens=1200,
-            top_p=0.95, n=1, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
         logger.info('\n=== explanation output ===>\n')
         logger.info(response.choices[0].message.content)
diff --git a/py-src/data_formulator/agents/agent_concept_derive.py b/py-src/data_formulator/agents/agent_concept_derive.py
index 69873b47..d7de8aab 100644
--- a/py-src/data_formulator/agents/agent_concept_derive.py
+++ b/py-src/data_formulator/agents/agent_concept_derive.py
@@ -167,9 +167,8 @@
 
 class ConceptDeriveAgent(object):
 
-    def __init__(self, client, model):
+    def __init__(self, client):
         self.client = client
-        self.model = model
 
     def run(self, input_table, input_fields, output_field, description, n=1):
         """derive a new concept based on input table, input fields, and output field name, (and description)
@@ -190,9 +189,7 @@ def run(self, input_table, input_fields, output_field, description, n=1):
                     {"role":"user","content": user_query}]
 
         ###### the part that calls open_ai
-        response = self.client.chat.completions.create(
-            model=self.model, messages = messages, temperature=0.7, max_tokens=1200,
-            top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
         #log = {'messages': messages, 'response': response.model_dump(mode='json')}
 
diff --git a/py-src/data_formulator/agents/agent_data_clean.py b/py-src/data_formulator/agents/agent_data_clean.py
index 8cecbeca..818131e0 100644
--- a/py-src/data_formulator/agents/agent_data_clean.py
+++ b/py-src/data_formulator/agents/agent_data_clean.py
@@ -78,8 +78,7 @@
 
 class DataCleanAgent(object):
 
-    def __init__(self, client, model):
-        self.model = model
+    def __init__(self, client):
         self.client = client
 
     def 
run(self, content_type, raw_data, image_cleaning_instruction): @@ -129,9 +128,7 @@ def run(self, content_type, raw_data, image_cleaning_instruction): messages = [system_message, user_prompt] ###### the part that calls open_ai - response = self.client.chat.completions.create( - model=self.model, messages = messages, temperature=0.7, max_tokens=1200, - top_p=0.95, n=1, frequency_penalty=0, presence_penalty=0, stop=None) + response = self.client.get_completion(messages = messages) candidates = [] for choice in response.choices: diff --git a/py-src/data_formulator/agents/agent_data_filter.py b/py-src/data_formulator/agents/agent_data_filter.py index 895f7663..ffb7b637 100644 --- a/py-src/data_formulator/agents/agent_data_filter.py +++ b/py-src/data_formulator/agents/agent_data_filter.py @@ -125,9 +125,8 @@ def filter_row(row, df): class DataFilterAgent(object): - def __init__(self, client, model): + def __init__(self, client): self.client = client - self.model = model def process_gpt_result(self, input_table, response, messages): #log = {'messages': messages, 'response': response.model_dump(mode='json')} @@ -177,9 +176,7 @@ def run(self, input_table, description): {"role":"user","content": user_query}] ###### the part that calls open_ai - response = self.client.chat.completions.create( - model=self.model, messages = messages, temperature=0.7, max_tokens=1200, - top_p=0.95, n=1, frequency_penalty=0, presence_penalty=0, stop=None) + response = self.client.get_completion(messages = messages) return self.process_gpt_result(input_table, response, messages) @@ -190,8 +187,6 @@ def followup(self, input_table, dialog, new_instruction: str, n=1): "content": new_instruction + '\nupdate the filter function accordingly'}] ##### the part that calls open_ai - response = self.client.chat.completions.create( - model=self.model, messages=messages, temperature=0.7, max_tokens=1200, - top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None) + response = self.client.get_completion(messages = messages) return self.process_gpt_result(input_table, response, messages) \ No newline at end of file diff --git a/py-src/data_formulator/agents/agent_data_load.py b/py-src/data_formulator/agents/agent_data_load.py index 86fb367d..8d534cf6 100644 --- a/py-src/data_formulator/agents/agent_data_load.py +++ b/py-src/data_formulator/agents/agent_data_load.py @@ -124,9 +124,8 @@ class DataLoadAgent(object): - def __init__(self, client, model): + def __init__(self, client): self.client = client - self.model = model def run(self, input_data, n=1): @@ -140,9 +139,7 @@ def run(self, input_data, n=1): {"role":"user","content": user_query}] ###### the part that calls open_ai - response = self.client.chat.completions.create( - model=self.model, messages=messages, temperature=0.2, max_tokens=4096, - top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None) + response = self.client.get_completion(messages = messages) #log = {'messages': messages, 'response': response.model_dump(mode='json')} diff --git a/py-src/data_formulator/agents/agent_data_rec.py b/py-src/data_formulator/agents/agent_data_rec.py index 9c718273..b9f3cb9e 100644 --- a/py-src/data_formulator/agents/agent_data_rec.py +++ b/py-src/data_formulator/agents/agent_data_rec.py @@ -126,9 +126,8 @@ def transform_data(df): class DataRecAgent(object): - def __init__(self, client, model, system_prompt=None): + def __init__(self, client, system_prompt=None): self.client = client - self.model = model self.system_prompt = system_prompt if system_prompt is not None 
else SYSTEM_PROMPT def process_gpt_response(self, input_tables, messages, response): @@ -192,7 +191,7 @@ def run(self, input_tables, description, n=1): messages = [{"role":"system", "content": self.system_prompt}, {"role":"user","content": user_query}] - response = completion_response_wrapper(self.client, self.model, messages, n) + response = completion_response_wrapper(self.client, messages, n) return self.process_gpt_response(input_tables, messages, response) @@ -204,7 +203,6 @@ def followup(self, input_tables, dialog, new_instruction: str, n=1): messages = [*dialog, {"role":"user", "content": f"Update: \n\n{new_instruction}"}] - ##### the part that calls open_ai - response = completion_response_wrapper(self.client, self.model, messages, n) + response = completion_response_wrapper(self.client, messages, n) return self.process_gpt_response(input_tables, messages, response) \ No newline at end of file diff --git a/py-src/data_formulator/agents/agent_data_transform_v2.py b/py-src/data_formulator/agents/agent_data_transform_v2.py index d9e530ea..20ec489d 100644 --- a/py-src/data_formulator/agents/agent_data_transform_v2.py +++ b/py-src/data_formulator/agents/agent_data_transform_v2.py @@ -178,12 +178,10 @@ def transform_data(df): ``` ''' -def completion_response_wrapper(client, model, messages, n): +def completion_response_wrapper(client, messages, n): ### wrapper for completion response, especially handling errors try: - response = client.chat.completions.create( - model=model, messages=messages, temperature=0.7, max_tokens=1200, - top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None) + response = client.get_completion(messages = messages) except Exception as e: response = e @@ -192,9 +190,8 @@ def completion_response_wrapper(client, model, messages, n): class DataTransformationAgentV2(object): - def __init__(self, client, model, system_prompt=None): + def __init__(self, client, system_prompt=None): self.client = client - self.model = model self.system_prompt = system_prompt if system_prompt is not None else SYSTEM_PROMPT def process_gpt_response(self, input_tables, messages, response): @@ -265,7 +262,7 @@ def run(self, input_tables, description, expected_fields: list[str], n=1): messages = [{"role":"system", "content": self.system_prompt}, {"role":"user","content": user_query}] - response = completion_response_wrapper(self.client, self.model, messages, n) + response = completion_response_wrapper(self.client, messages, n) return self.process_gpt_response(input_tables, messages, response) @@ -287,6 +284,6 @@ def followup(self, input_tables, dialog, output_fields: list[str], new_instructi messages = [*updated_dialog, {"role":"user", "content": f"Update the code above based on the following instruction:\n\n{json.dumps(goal, indent=4)}"}] - response = completion_response_wrapper(self.client, self.model, messages, n) + response = completion_response_wrapper(self.client, messages, n) return self.process_gpt_response(input_tables, messages, response) diff --git a/py-src/data_formulator/agents/agent_data_transformation.py b/py-src/data_formulator/agents/agent_data_transformation.py index 37b337fb..16c6414b 100644 --- a/py-src/data_formulator/agents/agent_data_transformation.py +++ b/py-src/data_formulator/agents/agent_data_transformation.py @@ -122,9 +122,8 @@ def transform_data(df_0): class DataTransformationAgent(object): - def __init__(self, client, model): + def __init__(self, client): self.client = client - self.model = model def process_gpt_response(self, input_tables, 
messages, response): """process gpt response to handle execution""" @@ -185,9 +184,7 @@ def run(self, input_tables, description, expected_fields: list[str], n=1, enrich {"role":"user","content": user_query}] ###### the part that calls open_ai - response = self.client.chat.completions.create( - model=self.model, messages = messages, temperature=0.7, max_tokens=1200, - top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None) + response = self.client.get_completion(messages = messages) return self.process_gpt_response(input_tables, messages, response) @@ -207,9 +204,7 @@ def followup(self, input_tables, dialog, output_fields: list[str], new_instructi "content": "Update the code above based on the following instruction:\n\n" + new_instruction + output_fields_instr}] ##### the part that calls open_ai - response = self.client.chat.completions.create( - model=self.model, messages=messages, temperature=0.7, max_tokens=1200, - top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None) + response = self.client.get_completion(messages = messages) logger.info(response) diff --git a/py-src/data_formulator/agents/agent_generic_py_concept.py b/py-src/data_formulator/agents/agent_generic_py_concept.py index 6ec2bab5..c1f7b868 100644 --- a/py-src/data_formulator/agents/agent_generic_py_concept.py +++ b/py-src/data_formulator/agents/agent_generic_py_concept.py @@ -157,9 +157,8 @@ def derive(row, df): class GenericPyConceptDeriveAgent(object): - def __init__(self, client, model_version): + def __init__(self, client): self.client = client - self.model_version = model_version def process_gpt_response(self, input_table, output_field, response, messages): #log = {'messages': messages, 'response': response.model_dump(mode='json')} @@ -220,10 +219,7 @@ def run(self, input_table, output_field, description): {"role":"user","content": user_query}] ###### the part that calls open_ai - response = self.client.chat.completions.create( - model=self.model_version, - messages = messages, temperature=0.7, max_tokens=1200, - top_p=0.95, n=1, frequency_penalty=0, presence_penalty=0, stop=None) + response = self.client.get_completion(messages = messages) return self.process_gpt_response(input_table, output_field, response, messages) @@ -234,9 +230,7 @@ def followup(self, input_table, dialog, output_field: str, new_instruction: str, "content": new_instruction + '\n update the function accordingly'}] ##### the part that calls open_ai - response = self.client.chat.completions.create( - model=self.model, messages=messages, temperature=0.7, max_tokens=1200, - top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None) + response = self.client.get_completion(messages = messages) candidates = self.process_gpt_response(input_table, output_field, response, messages) diff --git a/py-src/data_formulator/agents/agent_py_concept_derive.py b/py-src/data_formulator/agents/agent_py_concept_derive.py index 57299347..d6371d7a 100644 --- a/py-src/data_formulator/agents/agent_py_concept_derive.py +++ b/py-src/data_formulator/agents/agent_py_concept_derive.py @@ -131,8 +131,7 @@ def derive(writing, reading, math): class PyConceptDeriveAgent(object): - def __init__(self, client, model): - self.model = model + def __init__(self, client): self.client = client def run(self, input_table, input_fields, output_field, description): @@ -163,9 +162,7 @@ def derive({arg_string}): {"role":"user","content": user_query}] ###### the part that calls open_ai - response = self.client.chat.completions.create( - model=self.model, 
messages = messages, temperature=0.7, max_tokens=1200,
-            top_p=0.95, n=1, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
         #log = {'messages': messages, 'response': response.model_dump(mode='json')}
 
diff --git a/py-src/data_formulator/agents/agent_sort_data.py b/py-src/data_formulator/agents/agent_sort_data.py
index aa13ed35..20a628a3 100644
--- a/py-src/data_formulator/agents/agent_sort_data.py
+++ b/py-src/data_formulator/agents/agent_sort_data.py
@@ -65,9 +65,8 @@
 
 class SortDataAgent(object):
 
-    def __init__(self, client, model):
+    def __init__(self, client):
         self.client = client
-        self.model = model
 
     def run(self, name, values, n=1):
 
@@ -84,9 +83,7 @@ def run(self, name, values, n=1):
                     {"role":"user","content": user_query}]
 
         ###### the part that calls open_ai
-        response = self.client.chat.completions.create(
-            model=self.model, messages = messages, temperature=0.2, max_tokens=2400,
-            top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
+        response = self.client.get_completion(messages = messages)
 
         #log = {'messages': messages, 'response': response.model_dump(mode='json')}
 
diff --git a/py-src/data_formulator/agents/agent_utils.py b/py-src/data_formulator/agents/agent_utils.py
index 802112f3..792b3d8a 100644
--- a/py-src/data_formulator/agents/agent_utils.py
+++ b/py-src/data_formulator/agents/agent_utils.py
@@ -66,6 +66,7 @@ def table_hash(table):
     frozen_table = tuple(sorted([tuple([hash(value_handling_func(r[key])) for key in schema]) for r in table]))
     return hash(frozen_table)
 
+
 def extract_code_from_gpt_response(code_raw, language):
     """search for matches and then look for pairs of ```...``` to extract code"""
 
diff --git a/py-src/data_formulator/agents/client_utils.py b/py-src/data_formulator/agents/client_utils.py
index fe854c7a..6209deea 100644
--- a/py-src/data_formulator/agents/client_utils.py
+++ b/py-src/data_formulator/agents/client_utils.py
@@ -1,35 +1,66 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-import openai
 import os
-import sys
-
+from litellm import completion
 from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 
-def get_client(endpoint, key):
+class Client(object):
+    """
+    A thin wrapper that configures LiteLLM for the specified endpoint and model.
+    Supports OpenAI, Azure, Ollama, and other providers via LiteLLM.
+    """
+    def __init__(self, endpoint, model, api_key=None, api_base=None, api_version=None):
+
+        if endpoint == "default":
+            self.endpoint = os.getenv("ENDPOINT", "azure_openai")
+            self.model = model
+            api_base = os.getenv("API_BASE")
+        else:
+            self.endpoint = endpoint
+            self.model = model
+
+        # other params, including temperature, max_completion_tokens, api_base, api_version
+        self.params = {
+            "api_key": api_key,
+            "temperature": 0.7,
+            "max_completion_tokens": 1200,
+        }
 
-    endpoint = os.getenv("ENDPOINT") if endpoint == "default" else endpoint
+        if self.endpoint == "gemini":
+            if model.startswith("gemini/"):
+                self.model = model
+            else:
+                self.model = f"gemini/{model}"
+        elif self.endpoint == "anthropic":
+            if model.startswith("anthropic/"):
+                self.model = model
+            else:
+                self.model = f"anthropic/{model}"
+        elif self.endpoint == "azure_openai":
+            self.params["api_base"] = api_base
+            self.params["api_version"] = api_version if api_version else "2024-02-15-preview"
+            if api_key is None or api_key == "":
+                token_provider = get_bearer_token_provider(
+                    DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
+                )
+                self.params["azure_ad_token_provider"] = token_provider
+            self.params["custom_llm_provider"] = "azure"
+        elif self.endpoint == "ollama":
+            self.params["api_base"] = api_base if api_base else "http://localhost:11434"
+            self.params["max_tokens"] = self.params["max_completion_tokens"]
+            if model.startswith("ollama/"):
+                self.model = model
+            else:
+                self.model = f"ollama/{model}"
 
-    if key is None or key == "":
-        # using azure keyless access method
-        token_provider = get_bearer_token_provider(
-            DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
-        )
-        print(token_provider)
-        print(endpoint)
-        client = openai.AzureOpenAI(
-            api_version="2024-02-15-preview",
-            azure_endpoint=endpoint,
-            azure_ad_token_provider=token_provider
-        )
-    elif endpoint == 'openai':
-        client = openai.OpenAI(api_key=key)
-    else:
-        client = openai.AzureOpenAI(
-            azure_endpoint = endpoint,
-            api_key=key,
-            api_version="2024-02-15-preview"
-        )
-    return client
\ No newline at end of file
+    def get_completion(self, messages):
+        """
+        Call LiteLLM's chat completion API with the configured model and
+        parameters, and return the completion response.
+ """ + # Configure LiteLLM + return completion( + model=self.model, + messages=messages, + drop_params=True, + **self.params + ) \ No newline at end of file diff --git a/py-src/data_formulator/app.py b/py-src/data_formulator/app.py index 1c89baf6..89398b37 100644 --- a/py-src/data_formulator/app.py +++ b/py-src/data_formulator/app.py @@ -35,7 +35,7 @@ from data_formulator.agents.agent_data_clean import DataCleanAgent from data_formulator.agents.agent_code_explanation import CodeExplanationAgent -from data_formulator.agents.client_utils import get_client +from data_formulator.agents.client_utils import Client from dotenv import load_dotenv @@ -53,6 +53,19 @@ app = Flask(__name__, static_url_path='', static_folder=os.path.join(APP_ROOT, "dist")) CORS(app) +def get_client(model_config): + for key in model_config: + model_config[key] = html.escape(model_config[key].strip()) + + client = Client( + model_config["endpoint"], + model_config["model"], + model_config["api_key"] if "api_key" in model_config else None, + model_config["api_base"] if "api_base" in model_config else None, + model_config["api_version"] if "api_version" in model_config else None) + + return client + @app.route('/vega-datasets') def get_example_dataset_list(): dataset_names = vega_data.list_datasets() @@ -125,13 +138,18 @@ def check_available_models(): if os.getenv("ENDPOINT") is None: return json.dumps(results) - client = get_client(os.getenv("ENDPOINT"), "") + endpoint = os.getenv("ENDPOINT") models = [model.strip() for model in os.getenv("MODELS").split(',')] + api_base = os.getenv("API_BASE") + + print("endpoint", endpoint) + print("models", models) + print("api_base", api_base) for model in models: try: - response = client.chat.completions.create( - model=model, + client = Client(endpoint, model, api_key=None, api_base=api_base, api_version=None) + response = client.get_completion( messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Respond 'I can hear you.' if you can hear me. Do not say anything other than 'I can hear you.'"}, @@ -143,12 +161,15 @@ def check_available_models(): if "I can hear you." in response.choices[0].message.content: results.append({ + "id": f"default-{model}", "endpoint": "default", "key": "", "model": model }) - except: - pass + except Exception as e: + print(f"Error: {e}") + error_message = str(e) + return json.dumps(results) @@ -158,31 +179,27 @@ def test_model(): if request.is_json: app.logger.info("# code query: ") content = request.get_json() - endpoint = html.escape(content['endpoint'].strip()) - key = html.escape(f"{content['key']}".strip()) + # contains endpoint, key, model, api_base, api_version + print("content------------------------------") print(content) - client = get_client(endpoint, key) - model = html.escape(content['model'].strip()) + client = get_client(content['model']) try: - response = client.chat.completions.create( - model=model, + response = client.get_completion( messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Respond 'I can hear you.' if you can hear me. Do not say anything other than 'I can hear you.'"}, ] ) - print(f"model: {model}") + print(f"model: {content['model']}") print(f"welcome message: {response.choices[0].message.content}") if "I can hear you." 
in response.choices[0].message.content: result = { - "endpoint": endpoint, - "key": key, - "model": model, + "model": content['model'], "status": 'ok', "message": "" } @@ -190,14 +207,12 @@ def test_model(): print(f"Error: {e}") error_message = str(e) result = { - "endpoint": endpoint, - "key": key, - "model": model, + "model": content['model'], "status": 'error', "message": error_message, } else: - {'status': 'error'} + result = {'status': 'error'} return json.dumps(result) @@ -269,11 +284,11 @@ def process_data_on_load_request(): content = request.get_json() token = content["token"] - client = get_client(content['model']['endpoint'], content['model']['key']) - model = content['model']['model'] + client = get_client(content['model']) + app.logger.info(f" model: {content['model']}") - agent = DataLoadAgent(client=client, model=model) + agent = DataLoadAgent(client=client) candidates = agent.run(content["input_data"]) candidates = [c['content'] for c in candidates if c['status'] == 'ok'] @@ -294,11 +309,10 @@ def derive_concept_request(): content = request.get_json() token = content["token"] - client = get_client(content['model']['endpoint'], content['model']['key']) - model = content['model']['model'] + client = get_client(content['model']) + app.logger.info(f" model: {content['model']}") - - agent = ConceptDeriveAgent(client=client, model=model) + agent = ConceptDeriveAgent(client=client) #print(content["input_data"]) @@ -323,12 +337,11 @@ def clean_data_request(): content = request.get_json() token = content["token"] - client = get_client(content['model']['endpoint'], content['model']['key']) - model = content['model']['model'] + client = get_client(content['model']) app.logger.info(f" model: {content['model']}") - agent = DataCleanAgent(client=client, model=model) + agent = DataCleanAgent(client=client) candidates = agent.run(content['content_type'], content["raw_data"], content["image_cleaning_instruction"]) @@ -350,11 +363,9 @@ def sort_data_request(): content = request.get_json() token = content["token"] - client = get_client(content['model']['endpoint'], content['model']['key']) - model = content['model']['model'] - app.logger.info(f" model: {content['model']}") + client = get_client(content['model']) - agent = SortDataAgent(client=client, model=model) + agent = SortDataAgent(client=client) candidates = agent.run(content['field'], content['items']) #candidates, dialog = limbo_concept.call_codex_sort(content["items"], content["field"]) @@ -375,9 +386,7 @@ def derive_data(): content = request.get_json() token = content["token"] - client = get_client(content['model']['endpoint'], content['model']['key']) - model = content['model']['model'] - app.logger.info(f" model: {content['model']}") + client = get_client(content['model']) # each table is a dict with {"name": xxx, "rows": [...]} input_tables = content["input_tables"] @@ -394,10 +403,10 @@ def derive_data(): if mode == "recommendation": # now it's in recommendation mode - agent = DataRecAgent(client, model) + agent = DataRecAgent(client=client) results = agent.run(input_tables, instruction) else: - agent = DataTransformationAgentV2(client=client, model=model) + agent = DataTransformationAgentV2(client=client) results = agent.run(input_tables, instruction, [field['name'] for field in new_fields]) repair_attempts = 0 @@ -429,9 +438,7 @@ def refine_data(): content = request.get_json() token = content["token"] - client = get_client(content['model']['endpoint'], content['model']['key']) - model = content['model']['model'] - 
app.logger.info(f" model: {content['model']}") + client = get_client(content['model']) # each table is a dict with {"name": xxx, "rows": [...]} input_tables = content["input_tables"] @@ -443,7 +450,7 @@ def refine_data(): print(dialog) # always resort to the data transform agent - agent = DataTransformationAgentV2(client, model=model) + agent = DataTransformationAgentV2(client=client) results = agent.followup(input_tables, dialog, [field['name'] for field in output_fields], new_instruction) repair_attempts = 0 @@ -469,15 +476,13 @@ def request_code_expl(): content = request.get_json() token = content["token"] - client = get_client(content['model']['endpoint'], content['model']['key']) - model = content['model']['model'] - app.logger.info(f" model: {content['model']}") + client = get_client(content['model']) # each table is a dict with {"name": xxx, "rows": [...]} input_tables = content["input_tables"] code = content["code"] - code_expl_agent = CodeExplanationAgent(client=client, model=model) + code_expl_agent = CodeExplanationAgent(client=client) expl = code_expl_agent.run(input_tables, code) else: expl = "" diff --git a/src/app/dfSlice.tsx b/src/app/dfSlice.tsx index b306745c..409a4514 100644 --- a/src/app/dfSlice.tsx +++ b/src/app/dfSlice.tsx @@ -26,12 +26,21 @@ export const generateFreshChart = (tableRef: string, chartType?: string) : Chart } } +export interface ModelConfig { + id: string; // unique identifier for the model / client combination + endpoint: string; + model: string; + api_key?: string; + api_base?: string; + api_version?: string; +} + // Define a type for the slice state export interface DataFormulatorState { - oaiModels: {endpoint: string, key: string, model: string }[]; - selectedModel: {endpoint: string, model: string} | undefined; - testedModels: {endpoint: string, model: string, status: 'ok' | 'error' | 'testing' | 'unknown', message: string}[]; + models: ModelConfig[]; + selectedModelId: string | undefined; + testedModels: {id: string, status: 'ok' | 'error' | 'testing' | 'unknown', message: string}[]; tables : DictTable[]; charts: Chart[]; @@ -63,8 +72,8 @@ export interface DataFormulatorState { // Define the initial state using that type const initialState: DataFormulatorState = { - oaiModels: [], - selectedModel: undefined, + models: [], + selectedModelId: undefined, testedModels: [], tables: [], @@ -222,7 +231,7 @@ export const dataFormulatorSlice = createSlice({ // avoid resetting inputted models // state.oaiModels = state.oaiModels.filter((m: any) => m.endpoint != 'default'); - state.selectedModel = state.oaiModels.length > 0 ? state.oaiModels[0] : undefined; + state.selectedModelId = state.models.length > 0 ? state.models[0].id : undefined; state.testedModels = []; state.tables = []; @@ -247,8 +256,8 @@ export const dataFormulatorSlice = createSlice({ let savedState = action.payload; - state.oaiModels = savedState.oaiModels.filter((m: any) => m.endpoint != 'default'); - state.selectedModel = state.oaiModels.length > 0 ? state.oaiModels[0] : undefined; + state.models = savedState.models.filter((m: any) => m.endpoint != 'default'); + state.selectedModelId = state.models.length > 0 ? 
state.models[0].id : undefined; state.testedModels = []; // models should be tested again //state.table = undefined; @@ -274,25 +283,27 @@ export const dataFormulatorSlice = createSlice({ toggleBetaMode: (state, action: PayloadAction) => { state.betaMode = action.payload; }, - selectModel: (state, action: PayloadAction<{model: string, endpoint: string}>) => { - state.selectedModel = action.payload; + selectModel: (state, action: PayloadAction) => { + state.selectedModelId = action.payload; }, - addModel: (state, action: PayloadAction<{model: string, key: string, endpoint: string}>) => { - state.oaiModels = [...state.oaiModels, action.payload]; + addModel: (state, action: PayloadAction) => { + state.models = [...state.models, action.payload]; }, - removeModel: (state, action: PayloadAction<{model: string, endpoint: string}>) => { - let model = action.payload.model; - let endpoint = action.payload.endpoint; - state.oaiModels = state.oaiModels.filter(oaiModel => oaiModel.model != model || oaiModel.endpoint != endpoint ); - state.testedModels = state.testedModels.filter(m => !(m.model == model && m.endpoint == endpoint)); + removeModel: (state, action: PayloadAction) => { + state.models = state.models.filter(model => model.id != action.payload); + if (state.selectedModelId == action.payload) { + state.selectedModelId = undefined; + } }, - updateModelStatus: (state, action: PayloadAction<{model: string, endpoint: string, status: 'ok' | 'error' | 'testing' | 'unknown', message: string}>) => { - let model = action.payload.model; - let endpoint = action.payload.endpoint; + updateModelStatus: (state, action: PayloadAction<{id: string, status: 'ok' | 'error' | 'testing' | 'unknown', message: string}>) => { + let id = action.payload.id; let status = action.payload.status; let message = action.payload.message; - state.testedModels = [...state.testedModels.filter(t => !(t.model == model && t.endpoint == endpoint)), {model, endpoint, status, message} ] + state.testedModels = [ + ...state.testedModels.filter(t => t.id != id), + {id: id, status, message} + ]; }, addTable: (state, action: PayloadAction) => { let table = action.payload; @@ -640,17 +651,19 @@ export const dataFormulatorSlice = createSlice({ }) .addCase(fetchAvailableModels.fulfilled, (state, action) => { let defaultModels = action.payload; - state.oaiModels = [...defaultModels, ...state.oaiModels.filter(e => !defaultModels.map((m: any) => m.endpoint).includes(e.endpoint))]; - console.log(state.oaiModels) + state.models = [...defaultModels, ...state.models.filter(e => !defaultModels.map((m: any) => m.endpoint).includes(e.endpoint))]; - state.testedModels = [...state.testedModels.filter(t => !(t.endpoint == 'default')), - ...defaultModels.map((m: any) => {return {endpoint: m.endpoint, model: m.model, status: 'ok'}}) ] + console.log("defaultModels", defaultModels); + console.log("state.models", state.models); + console.log("state.testedModels", state.testedModels); - if (state.selectedModel == undefined && defaultModels.length > 0) { - state.selectedModel = { - model: defaultModels[0].model, - endpoint: defaultModels[0].endpoint - } + state.testedModels = [ + ...defaultModels.map((m: any) => {return {id: `default-${m.model}`, status: 'ok'}}) , + ...state.testedModels.filter(t => !defaultModels.map((m: any) => m.endpoint).includes(t.id)) + ] + + if (state.selectedModelId == undefined && defaultModels.length > 0) { + state.selectedModelId = defaultModels[0].id; } console.log("fetched models"); @@ -670,12 +683,11 @@ export const 
dataFormulatorSlice = createSlice({ }) export const dfSelectors = { - getActiveModel: (state: DataFormulatorState) => { - return state.oaiModels.find(m => m.endpoint == state.selectedModel?.endpoint && m.model == state.selectedModel.model) || {'endpoint': 'default', model: 'gpt-4o', key: ""} + getActiveModel: (state: DataFormulatorState) : ModelConfig => { + return state.models.find(m => m.id == state.selectedModelId) || {'endpoint': 'default', model: 'gpt-4o', id: 'default-gpt-4o'} } } - // derived field: extra all field items from the table export const getDataFieldItems = (baseTable: DictTable): FieldItem[] => { return baseTable.names.map((name, index) => { diff --git a/src/views/ModelSelectionDialog.tsx b/src/views/ModelSelectionDialog.tsx index c878cbd8..d03707c7 100644 --- a/src/views/ModelSelectionDialog.tsx +++ b/src/views/ModelSelectionDialog.tsx @@ -1,4 +1,3 @@ - // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. @@ -9,6 +8,7 @@ import { useDispatch, useSelector } from "react-redux"; import { DataFormulatorState, dfActions, + ModelConfig, } from '../app/dfSlice' import _ from 'lodash'; @@ -37,6 +37,7 @@ import { SelectChangeEvent, MenuItem, OutlinedInput, + Paper, } from '@mui/material'; @@ -55,7 +56,7 @@ import { getUrls } from '../app/utils'; export const GroupHeader = styled('div')(({ theme }) => ({ position: 'sticky', - padding: '8px 16px', + padding: '8px 8px', marginLeft: '-8px', color: "rgba(0, 0, 0, 0.6)", fontSize: "12px", @@ -68,132 +69,218 @@ export const GroupItems = styled('ul')({ export const ModelSelectionButton: React.FC<{}> = ({ }) => { const dispatch = useDispatch(); - const oaiModels = useSelector((state: DataFormulatorState) => state.oaiModels); - const selectedModel = useSelector((state: DataFormulatorState) => state.selectedModel); + const models = useSelector((state: DataFormulatorState) => state.models); + const selectedModelId = useSelector((state: DataFormulatorState) => state.selectedModelId); const testedModels = useSelector((state: DataFormulatorState) => state.testedModels); const [modelDialogOpen, setModelDialogOpen] = useState(false); const [showKeys, setShowKeys] = useState(false); - const [tempSelectedModel, setTempSelectedMode] = useState<{model: string, endpoint: string} | undefined >(selectedModel); + const [tempSelectedModelId, setTempSelectedModeId] = useState(selectedModelId); + + console.log("--------------------------------"); + console.log("models", models); + console.log("selectedModelId", selectedModelId); + console.log("tempSelectedModelId", tempSelectedModelId); + console.log("testedModels", testedModels); - let updateModelStatus = (model: string, endpoint: string, status: 'ok' | 'error' | 'testing' | 'unknown', message: string) => { - dispatch(dfActions.updateModelStatus({endpoint, model, status, message})); + let updateModelStatus = (model: ModelConfig, status: 'ok' | 'error' | 'testing' | 'unknown', message: string) => { + dispatch(dfActions.updateModelStatus({id: model.id, status, message})); } - let getStatus = (model: string, endpoint: string) => { - return testedModels.find(t => t.model == model && t.endpoint == endpoint)?.status || 'unknown'; + let getStatus = (id: string) => { + return testedModels.find(t => (t.id == id))?.status || 'unknown'; } - const [newKeyType, setNewKeyType] = useState("openai"); - const [newEndpoint, setNewEndpoint] = useState(""); - const [newKey, setNewKey] = useState(""); + const [newEndpoint, setNewEndpoint] = useState(""); // openai, azure_openai, ollama etc const 
[newModel, setNewModel] = useState(""); + const [newApiKey, setNewApiKey] = useState(undefined); + const [newApiBase, setNewApiBase] = useState(undefined); + const [newApiVersion, setNewApiVersion] = useState(undefined); + let disableApiKey = newEndpoint == "default" || newEndpoint == "" || newEndpoint == "ollama"; + let disableModel = newEndpoint == "default" || newEndpoint == ""; + let disableApiBase = newEndpoint != "azure_openai"; + let disableApiVersion = newEndpoint != "azure_openai"; - let modelExists = oaiModels.some(m => m.endpoint == newEndpoint && m.model == newModel); + let modelExists = models.some(m => m.endpoint == newEndpoint && m.model == newModel && m.api_base == newApiBase && m.api_key == newApiKey && m.api_version == newApiVersion); - let testModel = (endpoint: string, key: string, model: string) => { - updateModelStatus(model, endpoint, 'testing', ""); + let testModel = (model: ModelConfig) => { + updateModelStatus(model, 'testing', ""); let message = { method: 'POST', headers: { 'Content-Type': 'application/json', }, body: JSON.stringify({ model: model, - key: key, - endpoint: endpoint }), }; fetch(getUrls().TEST_MODEL, {...message }) .then((response) => response.json()) .then((data) => { let status = data["status"] || 'error'; - updateModelStatus(model, endpoint, status, data["message"] || ""); + updateModelStatus(model, status, data["message"] || ""); }).catch((error) => { - updateModelStatus(model, endpoint, 'error', error.message) + updateModelStatus(model, 'error', error.message) }); } + let readyToTest = false; + if (newEndpoint != "default") { + readyToTest = true; + } + if (newEndpoint == "openai") { + readyToTest = newModel != ""; + } + if (newEndpoint == "azure_openai") { + readyToTest = newModel != "" && newApiBase != ""; + } + if (newEndpoint == "ollama") { + readyToTest = newModel != ""; + } + let newModelEntry = {setTempSelectedMode(undefined)}} + onClick={() => {setTempSelectedModeId(undefined)}} > - + - - - - - - {newKeyType == "openai" ? 
N/A : { setNewEndpoint(event.target.value); }} - autoComplete='off'/>} + { + setNewEndpoint(newValue || ""); + if (newModel == "" && newValue == "openai") { + setNewModel("gpt-4o"); + } + if (!newApiVersion && newValue == "azure_openai") { + setNewApiVersion("2024-02-15"); + } + }} + options={['openai', 'azure_openai', 'ollama', 'gemini', 'anthropic']} + renderOption={(props, option) => ( + setNewEndpoint(option)} sx={{fontSize: "0.875rem"}}> + {option} + + )} + renderInput={(params) => ( + setNewEndpoint(event.target.value)} + /> + )} + ListboxProps={{ + style: { padding: 0 } + }} + PaperComponent={({ children }) => ( + + + suggestions + + {children} + + )} + /> { setNewKey(event.target.value); }} - autoComplete='off'/> + value={newApiKey} onChange={(event: any) => { setNewApiKey(event.target.value); }} + autoComplete='off' + disabled={disableApiKey} + /> { setNewModel(newValue || ""); }} value={newModel} - options={['gpt-35-turbo', 'gpt-4', 'gpt-4o']} + options={['gpt-35-turbo', 'gpt-4', 'gpt-4o', 'llama3.2']} renderOption={(props, option) => { return { setNewModel(option); }} sx={{fontSize: "small"}}>{option} }} renderInput={(params) => ( { setNewModel(event.target.value); }} /> - )}/> + )} + ListboxProps={{ + style: { padding: 0 } + }} + PaperComponent={({ children }) => ( + + + suggestions + + {children} + + )} + /> + + + { setNewApiBase(event.target.value); }} + autoComplete='off' + disabled={disableApiBase} + required={newEndpoint == "azure_openai"} + /> + + + { setNewApiVersion(event.target.value); }} + autoComplete='off' + disabled={disableApiVersion} + placeholder="api_version" + /> { if (modelExists) { return } - let endpoint = newKeyType == 'openai' ? 'openai' : newEndpoint; + let endpoint = newEndpoint; event.stopPropagation() - dispatch(dfActions.addModel({model: newModel, key: newKey, endpoint})); - dispatch(dfActions.selectModel({model: newModel, endpoint})); - setTempSelectedMode({endpoint, model: newModel}); + let id = `${endpoint}-${newModel}-${newApiKey}-${newApiBase}-${newApiVersion}`; + + let model = {endpoint, model: newModel, api_key: newApiKey, api_base: newApiBase, api_version: newApiVersion, id: id}; + + dispatch(dfActions.addModel(model)); + dispatch(dfActions.selectModel(id)); + setTempSelectedModeId(id); - testModel(endpoint, newKey, newModel); + testModel(model); - setNewKeyType('openai'); setNewEndpoint(""); - setNewKey(""); setNewModel(""); + + setNewApiKey(undefined); + setNewApiBase(undefined); + setNewApiVersion(undefined); }}> @@ -204,41 +291,48 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { { event.stopPropagation() - setNewEndpoint(""); - setNewKey(""); setNewModel(""); + setNewApiKey(undefined); + setNewApiBase(undefined); + setNewApiVersion(undefined); }}> + let modelTable = - +
- Key Type - Endpoint - Key - Model + endpoint + api_key + model + api_base + api_version Status Action - {oaiModels.map((oaiModel) => { - let isItemSelected = tempSelectedModel && - tempSelectedModel.endpoint == oaiModel.endpoint && - tempSelectedModel.model == oaiModel.model; - let status = getStatus(oaiModel.model, oaiModel.endpoint); + {models.map((model) => { + let isItemSelected = tempSelectedModelId != undefined && tempSelectedModelId == model.id; + let status = getStatus(model.id); + let statusIcon = status == "unknown" ? : ( status == 'testing' ? : (status == "ok" ? : )) - let message = status == "unknown" ? "Status unknown, click the status icon to test again." : - (testedModels.find(m => m.model === oaiModel.model && m.endpoint === oaiModel.endpoint)?.message || "Unknown error"); + let message = "the model is ready to use"; + if (status == "unknown") { + message = "Status unknown, click the status icon to test again."; + } else if (status == "error") { + message = testedModels.find(t => t.id == model.id)?.message || "Unknown error"; + } + const borderStyle = ['error', 'unknown'].includes(status) ? '1px dashed text.secondary' : undefined; const noBorderStyle = ['error', 'unknown'].includes(status) ? 'none' : undefined; @@ -246,30 +340,35 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { <> { setTempSelectedMode({model: oaiModel.model, endpoint: oaiModel.endpoint}) }} + key={`${model.id}`} + onClick={() => { setTempSelectedModeId(model.id) }} sx={{ cursor: 'pointer'}} > - {oaiModel.endpoint == 'openai' ? 'openai' : 'azure openai'} + {model.endpoint} - {oaiModel.endpoint} - - - {oaiModel.key != "" ? - (showKeys ? (oaiModel.key || N/A) : "************") : + {model.api_key != "" ? + (showKeys ? (model.api_key || N/A) : "************") : N/A } - {oaiModel.model} + + {model.model} + + + {model.api_base} + + + {model.api_version} + { testModel(oaiModel.endpoint, oaiModel.key, oaiModel.model) }} + onClick ={() => { testModel(model) }} > {statusIcon} @@ -277,19 +376,16 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { - { - dispatch(dfActions.removeModel({model: oaiModel.model, endpoint: oaiModel.endpoint})); - if ((tempSelectedModel) - && tempSelectedModel.endpoint == oaiModel.endpoint - && tempSelectedModel.model == oaiModel.model) { - if (oaiModels.length == 0) { - setTempSelectedMode(undefined); + dispatch(dfActions.removeModel(model.id)); + if ((tempSelectedModelId) + && tempSelectedModelId == model.id) { + if (models.length == 0) { + setTempSelectedModeId(undefined); } else { - let chosenModel = oaiModels[oaiModels.length - 1]; - setTempSelectedMode({ - model: chosenModel.model, endpoint: chosenModel.endpoint - }) + let chosenModel = models[models.length - 1]; + setTempSelectedModeId(chosenModel.id) } } }}> @@ -301,7 +397,7 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { {['error', 'unknown'].includes(status) && ( { setTempSelectedMode({model: oaiModel.model, endpoint: oaiModel.endpoint}) }} + onClick={() => { setTempSelectedModeId(model.id) }} sx={{ cursor: 'pointer', '&:hover': { @@ -321,14 +417,19 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { ) })} {newModelEntry} + + + model configuration based on LiteLLM, check out supported endpoint / model configurations here. + +
return <> - {setModelDialogOpen(false)}} open={modelDialogOpen}> @@ -341,13 +442,13 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { setShowKeys(!showKeys);}}> {showKeys ? 'hide' : 'show'} keys - From 5e856b0a5309bd21c329e285c401c038ef1198b8 Mon Sep 17 00:00:00 2001 From: Dan Marshall Date: Tue, 11 Feb 2025 19:01:45 -0800 Subject: [PATCH 02/16] fix: update key props for table list items and cards for better uniqueness --- src/views/EncodingShelfThread.tsx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/views/EncodingShelfThread.tsx b/src/views/EncodingShelfThread.tsx index e6e55623..2763db2f 100644 --- a/src/views/EncodingShelfThread.tsx +++ b/src/views/EncodingShelfThread.tsx @@ -405,7 +405,7 @@ export const EncodingShelfThread: FC = function ({ cha //let triggers = currentTable.derive.triggers; let tableList = activeTableThread.map((tableId) =>
; - console.log("selected model?") - console.log(selectedModel) + console.log("selected model?") + console.log(selectedModelId) + return ( - {selectedModel == undefined ? modelSelectionDialogBox : (tables.length > 0 ? fixedSplitPane : dataUploadRequestBox)} + {selectedModelId == undefined ? modelSelectionDialogBox : (tables.length > 0 ? fixedSplitPane : dataUploadRequestBox)} - - ); } \ No newline at end of file diff --git a/src/views/ModelSelectionDialog.tsx b/src/views/ModelSelectionDialog.tsx index d2a96a0b..156555f7 100644 --- a/src/views/ModelSelectionDialog.tsx +++ b/src/views/ModelSelectionDialog.tsx @@ -77,17 +77,11 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { const [showKeys, setShowKeys] = useState(false); const [tempSelectedModelId, setTempSelectedModeId] = useState(selectedModelId); - console.log("--------------------------------"); - console.log("models", models); - console.log("selectedModelId", selectedModelId); - console.log("tempSelectedModelId", tempSelectedModelId); - console.log("testedModels", testedModels); - let updateModelStatus = (model: ModelConfig, status: 'ok' | 'error' | 'testing' | 'unknown', message: string) => { dispatch(dfActions.updateModelStatus({id: model.id, status, message})); } - let getStatus = (id: string) => { - return testedModels.find(t => (t.id == id))?.status || 'unknown'; + let getStatus = (id: string | undefined) => { + return id != undefined ? (testedModels.find(t => (t.id == id))?.status || 'unknown') : 'unknown'; } const [newEndpoint, setNewEndpoint] = useState(""); // openai, azure, ollama etc @@ -97,7 +91,7 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { const [newApiVersion, setNewApiVersion] = useState(undefined); useEffect(() => { - if (newEndpoint == 'ollama' ) { + if (newEndpoint == 'ollama') { if (!newApiBase) { setNewApiBase('http://localhost:11434'); } @@ -142,7 +136,10 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { let newModelEntry = {setTempSelectedModeId(undefined)}} + onClick={(event) => { + event.stopPropagation(); + setTempSelectedModeId(undefined); + }} > @@ -205,7 +202,7 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { freeSolo onChange={(event: any, newValue: string | null) => { setNewModel(newValue || ""); }} value={newModel} - options={['gpt-4o-mini', 'gpt-4o', 'claude-3-5-sonnet-20241022', 'codellama']} + options={['gpt-4o-mini', 'gpt-4o', 'claude-3-5-sonnet-20241022']} renderOption={(props, option) => { return { setNewModel(option); }} sx={{fontSize: "small"}}>{option} }} @@ -260,28 +257,33 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { disabled={!readyToTest} sx={{cursor: modelExists ? 'help' : 'pointer'}} onClick={(event) => { - if (modelExists) { - return - } - let endpoint = newEndpoint; event.stopPropagation() + console.log("checkpont 1") + + let endpoint = newEndpoint; + let id = `${endpoint}-${newModel}-${newApiKey}-${newApiBase}-${newApiVersion}`; let model = {endpoint, model: newModel, api_key: newApiKey, api_base: newApiBase, api_version: newApiVersion, id: id}; + console.log("checkpont 2") + dispatch(dfActions.addModel(model)); dispatch(dfActions.selectModel(id)); setTempSelectedModeId(id); + console.log("checkpont 3") + testModel(model); setNewEndpoint(""); setNewModel(""); - setNewApiKey(undefined); setNewApiBase(undefined); setNewApiVersion(undefined); + + console.log("checkpont 4") }}> @@ -434,7 +436,15 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { {selectedModelId ? 
`Model: ${(models.find(m => m.id == selectedModelId) as any)?.model}` : 'Select A Model'} - {setModelDialogOpen(false)}} open={modelDialogOpen}> + { + if (reason !== 'backdropClick') { + setModelDialogOpen(false); + } + }} + > Select Model {modelTable} @@ -444,7 +454,7 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { setShowKeys(!showKeys);}}> {showKeys ? 'hide' : 'show'} keys -
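
For reviewers, a minimal usage sketch of the LiteLLM-backed Client introduced in patch 01. This is a sketch, not part of the series: it assumes the package is importable as data_formulator.agents.client_utils, and the model names and placeholder API key below are illustrative only.

# Sketch: exercising the new Client wrapper (assumptions noted above).
from data_formulator.agents.client_utils import Client

# OpenAI: bare model names pass through to LiteLLM unchanged.
openai_client = Client(endpoint="openai", model="gpt-4o", api_key="sk-...")

# Ollama: Client.__init__ prefixes the model with "ollama/" and defaults
# api_base to http://localhost:11434 when none is supplied.
ollama_client = Client(endpoint="ollama", model="llama3.2")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Respond 'I can hear you.' if you can hear me."},
]

# get_completion forwards temperature / max_completion_tokens from self.params;
# drop_params=True tells LiteLLM to discard parameters a provider does not accept.
response = ollama_client.get_completion(messages=messages)
print(response.choices[0].message.content)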