diff --git a/src/agentlab/agents/visual_agent/agent_configs.py b/src/agentlab/agents/visual_agent/agent_configs.py new file mode 100644 index 00000000..404afaec --- /dev/null +++ b/src/agentlab/agents/visual_agent/agent_configs.py @@ -0,0 +1,44 @@ +from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT + +from .visual_agent import VisualAgentArgs +from .visual_agent_prompts import PromptFlags +import agentlab.agents.dynamic_prompting as dp +import bgym + +# the other flags are ignored for this agent. +DEFAULT_OBS_FLAGS = dp.ObsFlags( + use_tabs=True, # will be overridden by the benchmark when set_benchmark is called after initalizing the agent + use_error_logs=True, + use_past_error_logs=False, + use_screenshot=True, + use_som=False, + openai_vision_detail="auto", +) + +DEFAULT_ACTION_FLAGS = dp.ActionFlags( + action_set=bgym.HighLevelActionSetArgs(subsets=["coord"]), + long_description=True, + individual_examples=False, +) + + +DEFAULT_PROMPT_FLAGS = PromptFlags( + obs=DEFAULT_OBS_FLAGS, + action=DEFAULT_ACTION_FLAGS, + use_thinking=True, + use_concrete_example=False, + use_abstract_example=True, + enable_chat=False, + extra_instructions=None, +) + +VISUAL_AGENT_4o = VisualAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-2024-05-13"], + flags=DEFAULT_PROMPT_FLAGS, +) + + +VISUAL_AGENT_CLAUDE_3_5 = VisualAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.5-sonnet:beta"], + flags=DEFAULT_PROMPT_FLAGS, +) diff --git a/src/agentlab/agents/visual_agent/visual_agent.py b/src/agentlab/agents/visual_agent/visual_agent.py new file mode 100644 index 00000000..8efee11d --- /dev/null +++ b/src/agentlab/agents/visual_agent/visual_agent.py @@ -0,0 +1,129 @@ +""" +GenericAgent implementation for AgentLab + +This module defines a `GenericAgent` class and its associated arguments for use in the AgentLab framework. \ +The `GenericAgent` class is designed to interact with a chat-based model to determine actions based on \ +observations. It includes methods for preprocessing observations, generating actions, and managing internal \ +state such as plans, memories, and thoughts. The `GenericAgentArgs` class provides configuration options for \ +the agent, including model arguments and flags for various behaviors. +""" + +from dataclasses import asdict, dataclass + +import bgym +from browsergym.experiments.agent import Agent, AgentInfo + +from agentlab.agents import dynamic_prompting as dp +from agentlab.agents.agent_args import AgentArgs +from agentlab.llm.chat_api import BaseModelArgs +from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry +from agentlab.llm.tracking import cost_tracker_decorator + +from .visual_agent_prompts import PromptFlags, MainPrompt + + +@dataclass +class VisualAgentArgs(AgentArgs): + chat_model_args: BaseModelArgs = None + flags: PromptFlags = None + max_retry: int = 4 + + def __post_init__(self): + try: # some attributes might be missing temporarily due to args.CrossProd for hyperparameter generation + self.agent_name = f"VisualAgent-{self.chat_model_args.model_name}".replace("/", "_") + except AttributeError: + pass + + def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode): + """Override Some flags based on the benchmark.""" + self.flags.obs.use_tabs = benchmark.is_multi_tab + + def set_reproducibility_mode(self): + self.chat_model_args.temperature = 0 + + def prepare(self): + return self.chat_model_args.prepare_server() + + def close(self): + return self.chat_model_args.close_server() + + def make_agent(self): + return VisualAgent( + chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry + ) + + +class VisualAgent(Agent): + + def __init__( + self, + chat_model_args: BaseModelArgs, + flags: PromptFlags, + max_retry: int = 4, + ): + + self.chat_llm = chat_model_args.make_model() + self.chat_model_args = chat_model_args + self.max_retry = max_retry + + self.flags = flags + self.action_set = self.flags.action.action_set.make_action_set() + self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs) + + self.reset(seed=None) + + def obs_preprocessor(self, obs: dict) -> dict: + return self._obs_preprocessor(obs) + + @cost_tracker_decorator + def get_action(self, obs): + + main_prompt = MainPrompt( + action_set=self.action_set, + obs=obs, + actions=self.actions, + thoughts=self.thoughts, + flags=self.flags, + ) + + system_prompt = SystemMessage(dp.SystemPrompt().prompt) + try: + # TODO, we would need to further shrink the prompt if the retry + # cause it to be too long + + chat_messages = Discussion([system_prompt, main_prompt.prompt]) + ans_dict = retry( + self.chat_llm, + chat_messages, + n_retry=self.max_retry, + parser=main_prompt._parse_answer, + ) + ans_dict["busted_retry"] = 0 + # inferring the number of retries, TODO: make this less hacky + ans_dict["n_retry"] = (len(chat_messages) - 3) / 2 + except ParseError: + ans_dict = dict( + action=None, + n_retry=self.max_retry + 1, + busted_retry=1, + ) + + stats = self.chat_llm.get_stats() + stats["n_retry"] = ans_dict["n_retry"] + stats["busted_retry"] = ans_dict["busted_retry"] + + self.actions.append(ans_dict["action"]) + self.thoughts.append(ans_dict.get("think", None)) + + agent_info = AgentInfo( + think=ans_dict.get("think", None), + chat_messages=chat_messages, + stats=stats, + extra_info={"chat_model_args": asdict(self.chat_model_args)}, + ) + return ans_dict["action"], agent_info + + def reset(self, seed=None): + self.seed = seed + self.thoughts = [] + self.actions = [] diff --git a/src/agentlab/agents/visual_agent/visual_agent_prompts.py b/src/agentlab/agents/visual_agent/visual_agent_prompts.py new file mode 100644 index 00000000..383923f0 --- /dev/null +++ b/src/agentlab/agents/visual_agent/visual_agent_prompts.py @@ -0,0 +1,185 @@ +""" +Prompt builder for GenericAgent + +It is based on the dynamic_prompting module from the agentlab package. +""" + +import logging +from dataclasses import dataclass +import bgym + +from browsergym.core.action.base import AbstractActionSet + +from agentlab.agents import dynamic_prompting as dp +from agentlab.llm.llm_utils import BaseMessage, HumanMessage, image_to_jpg_base64_url + + +@dataclass +class PromptFlags(dp.Flags): + """ + A class to represent various flags used to control features in an application. + """ + + obs: dp.ObsFlags = None + action: dp.ActionFlags = None + use_thinking: bool = True + use_concrete_example: bool = False + use_abstract_example: bool = True + enable_chat: bool = False + extra_instructions: str | None = None + + +class SystemPrompt(dp.PromptElement): + _prompt = """\ +You are an agent trying to solve a web task based on the content of the page and +user instructions. You can interact with the page and explore, and send messages to the user. Each time you +submit an action it will be sent to the browser and you will receive a new page.""" + + +def make_instructions(obs: dict, from_chat: bool, extra_instructions: str | None): + """Convenient wrapper to extract instructions from either goal or chat""" + if from_chat: + instructions = dp.ChatInstructions( + obs["chat_messages"], extra_instructions=extra_instructions + ) + else: + if sum([msg["role"] == "user" for msg in obs.get("chat_messages", [])]) > 1: + logging.warning( + "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`." + ) + instructions = dp.GoalInstructions( + obs["goal_object"], extra_instructions=extra_instructions + ) + return instructions + + +class History(dp.PromptElement): + """ + Format the actions and thoughts of previous steps.""" + + def __init__(self, actions, thoughts) -> None: + super().__init__() + prompt_elements = [] + for i, (action, thought) in enumerate(zip(actions, thoughts)): + prompt_elements.append( + f""" +## Step {i} +### Thoughts: +{thought} +### Action: +{action} +""" + ) + self._prompt = "\n".join(prompt_elements) + "\n" + + +class Observation(dp.PromptElement): + """Observation of the current step. + + Contains the html, the accessibility tree and the error logs. + """ + + def __init__(self, obs, flags: dp.ObsFlags) -> None: + super().__init__() + self.flags = flags + self.obs = obs + + # for a multi-tab browser, we need to show the current tab + self.tabs = dp.Tabs( + obs, + visible=lambda: flags.use_tabs, + prefix="## ", + ) + + # if an error is present, we need to show it + self.error = dp.Error( + obs["last_action_error"], + visible=lambda: flags.use_error_logs and obs["last_action_error"], + prefix="## ", + ) + + @property + def _prompt(self) -> str: + return f""" +# Observation of current step: +{self.tabs.prompt}{self.error.prompt} + +""" + + def add_screenshot(self, prompt: BaseMessage) -> BaseMessage: + if self.flags.use_screenshot: + if self.flags.use_som: + screenshot = self.obs["screenshot_som"] + prompt.add_text( + "\n## Screenshot:\nHere is a screenshot of the page, it is annotated with bounding boxes and corresponding bids:" + ) + else: + screenshot = self.obs["screenshot"] + prompt.add_text("\n## Screenshot:\nHere is a screenshot of the page:") + img_url = image_to_jpg_base64_url(screenshot) + prompt.add_image(img_url, detail=self.flags.openai_vision_detail) + return prompt + + +class MainPrompt(dp.PromptElement): + + def __init__( + self, + action_set: AbstractActionSet, + obs: dict, + actions: list[str], + thoughts: list[str], + flags: PromptFlags, + ) -> None: + super().__init__() + self.flags = flags + self.history = History(actions, thoughts) + self.instructions = make_instructions(obs, flags.enable_chat, flags.extra_instructions) + self.obs = Observation(obs, self.flags.obs) + + self.action_prompt = dp.ActionPrompt(action_set, action_flags=flags.action) + self.think = dp.Think(visible=lambda: flags.use_thinking) + + @property + def _prompt(self) -> HumanMessage: + prompt = HumanMessage(self.instructions.prompt) + prompt.add_text( + f"""\ +{self.obs.prompt}\ +{self.history.prompt}\ +{self.action_prompt.prompt}\ +{self.think.prompt}\ +""" + ) + + if self.flags.use_abstract_example: + prompt.add_text( + f""" +# Abstract Example + +Here is an abstract version of the answer with description of the content of +each tag. Make sure you follow this structure, but replace the content with your +answer: +{self.think.abstract_ex}\ +{self.action_prompt.abstract_ex}\ +""" + ) + + if self.flags.use_concrete_example: + prompt.add_text( + f""" +# Concrete Example + +Here is a concrete example of how to format your answer. +Make sure to follow the template with proper tags: +{self.think.concrete_ex}\ +{self.action_prompt.concrete_ex}\ +""" + ) + return self.obs.add_screenshot(prompt) + + def _parse_answer(self, text_answer): + ans_dict = {} + ans_dict.update(self.think.parse_answer(text_answer)) + ans_dict.update(self.action_prompt.parse_answer(text_answer)) + return ans_dict diff --git a/src/agentlab/experiments/list_openai_models.py b/src/agentlab/experiments/list_openai_models.py index 0c301926..9314e7ef 100644 --- a/src/agentlab/experiments/list_openai_models.py +++ b/src/agentlab/experiments/list_openai_models.py @@ -6,10 +6,12 @@ df = pd.DataFrame([dict(model) for model in models.data]) # Filter GPT models or o1 models - df = df[df["id"].str.contains("gpt") | df["id"].str.contains("o1")] + # df = df[df["id"].str.contains("gpt") | df["id"].str.contains("o1")] # Convert Unix timestamps to dates (YYYY-MM-DD) and remove time df["created"] = pd.to_datetime(df["created"], unit="s").dt.date df.sort_values(by="created", inplace=True) # Print all entries - print(df) + + # print all entries + print(df.to_string(index=False))