diff --git a/README.md b/README.md index d07ce40..22cd0be 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ -**This is the code for the paper [HYDRA: A Hyper Agent for Dynamic Compositional Visual Reasoning](https://link.springer.com/chapter/10.1007/978-3-031-72661-3_8), accepted by ECCV 2024 \[[Project Page](https://hydra-vl4ai.github.io)\].** +**This is the code for the paper [HYDRA: A Hyper Agent for Dynamic Compositional Visual Reasoning](https://link.springer.com/chapter/10.1007/978-3-031-72661-3_8), accepted by ECCV 2024 \[[Project Page](https://hydra-vl4ai.github.io)\]. We released the code that uses Reinforcement Learning (DQN) to fine-tune the LLM🔥🔥🔥** ## Release @@ -50,7 +50,8 @@ We also notice the embedding model is updated by OpenAI as shown in this [link]( - [x] LLaMA3.1 (ollama) replacement. - [x] Gradio Demo - [x] GPT-4o Version. -- [ ] HYDRA with RL +- [x] HYDRA with RL(DQN). +- [ ] HYDRA with Deepseek R1. ## Installation @@ -71,7 +72,7 @@ git clone https://github.com/ControlNet/HYDRA Option 1: Using [pixi](https://prefix.dev/) (recommended): ```Bash -pixi run install +pixi install pixi shell ``` @@ -86,8 +87,8 @@ If you meet errors, please consider going through the `build_env.sh` file and in Edit the file `.env` or setup in CLI to configure the environment variables. ``` -OPENAI_API_KEY=your-api-key -OLLAMA_HOST=http://ollama.server:11434 +OPENAI_API_KEY=your-api-key # if you want to use OpenAI LLMs +OLLAMA_HOST=http://ollama.server:11434 # if you want to use your OLLaMA server for llama or deepseek # do not change this TORCH_HOME variable TORCH_HOME=./pretrained_models ``` @@ -126,7 +127,9 @@ python demo_gradio.py \ --base_config \ --model_config ``` +https://github.com/user-attachments/assets/39a897ab-d457-49d2-8527-0d6fe3a3b922 +--- ### Inference dataset ```Bash @@ -150,6 +153,23 @@ For example, python evaluate.py result/result_okvqa.jsonl okvqa ``` +## Training Controller with RL(DQN) + +```Bash +python train.py \ + --data_root \ + --base_config \ + --model_config \ + --dqn_config +``` +For example, +```Bash +python train.py \ + --data_root ../coco2014 \ + --base_config ./config/okvqa.yaml\ + --model_config ./config/model_config_1gpu.yaml \ + --dqn_config ./config/dqn_debug.yaml +``` ## Citation ```bibtex diff --git a/config/dqn_debug.yaml b/config/dqn_debug.yaml new file mode 100644 index 0000000..71b3998 --- /dev/null +++ b/config/dqn_debug.yaml @@ -0,0 +1,15 @@ +model_name: dataset_nameaokvqa_learn_starts300 +llm_embedding_dim: 1536 # small, 3072 for large +mlp_hidden_dim: 512 +critic_layer_num: 4 +critic_lr: 0.0001 +train_log_interval: 1 +batch_size: 1 +update_times: 2 +save_interval: 1 +learn_starts: 2 +dqn_explore_epsilon: 0.2 +dqn_explore_epsilon_decay_rate: 0.02 +dqn_explore_epsilon_decay_interval: 100 +buffer_size: 100000 +training_epoch: 1 \ No newline at end of file diff --git a/config/dqn_train_config.yaml b/config/dqn_train_config.yaml new file mode 100644 index 0000000..af1b147 --- /dev/null +++ b/config/dqn_train_config.yaml @@ -0,0 +1,15 @@ +model_name: dqn_model_example +llm_embedding_dim: 1536 # small, 3072 for large +mlp_hidden_dim: 512 +critic_layer_num: 4 +critic_lr: 0.0001 +train_log_interval: 100 +batch_size: 128 +update_times: 4 +save_interval: 100 +learn_starts: 1000 +dqn_explore_epsilon: 0.2 +dqn_explore_epsilon_decay_rate: 0.02 +dqn_explore_epsilon_decay_interval: 100 +buffer_size: 100000 +training_epoch: 4 \ No newline at end of file diff --git a/evaluate.py b/evaluate.py index db41125..e02b68a 100644 --- a/evaluate.py +++ 
b/evaluate.py @@ -1,10 +1,10 @@ -import sys -from evaluation.vqa_eval import GQAeval -from evaluation.grounding_eval import batch_iou_2d -from hydra_vl4ai.util.console import console, logger -import tensorneko_util as N import argparse +import tensorneko_util as N + +from hydra_vl4ai.evaluation.grounding_eval import batch_iou_2d +from hydra_vl4ai.evaluation.vqa_eval import GQAeval +from hydra_vl4ai.util.console import console, logger if __name__ == '__main__': @@ -20,7 +20,8 @@ match args.dataset: case "okvqa": - score = evaluator.accuracy_one_set([each["result"] for each in result], [each["ground_truth"] for each in result]) + score = evaluator.accuracy_one_set([each["result"] for each in result], + [each["ground_truth"] for each in result]) case "refcoco" | "refcoco+": score = batch_iou_2d(result) case _: diff --git a/hydra_vl4ai/agent/controller.py b/hydra_vl4ai/agent/controller.py index 24595c1..651645d 100644 --- a/hydra_vl4ai/agent/controller.py +++ b/hydra_vl4ai/agent/controller.py @@ -1,16 +1,161 @@ import abc +import pickle +import random +from collections import deque import numpy as np +import torch + +from hydra_vl4ai.util.console import logger +from .llm import llm_embedding +from .rl_dqn import DQN_EmbeddingViaLLM, ReplayBuffer +from .smb.state_memory_bank import StateMemoryBank +from ..util.config import Config +from ..util.misc import get_hydra_root_folder class Controller(abc.ABC): @abc.abstractmethod - def __call__(self, instructions: list[str], probs: np.ndarray) -> str: + def __call__(self, *args, **kwargs) -> str: pass class ControllerLLM(Controller): + """ + This is the function for not using RL controller + but directly use the LLM score to return optimal instruction + """ def __call__(self, instructions: list[str], probs: np.ndarray) -> str: return instructions[np.argmax(probs)] + + +class ControllerDQN(Controller): + + def __init__(self, + embedding_prompt_base: str, + task_description_for_instruction: str, + instruction_example: str, + training: bool = False + ): + super().__init__() + self.instruction_example = instruction_example + self.embedding_prompt_base = embedding_prompt_base + self.model_name = Config.dqn_config["model_name"] + self.model_save_path = get_hydra_root_folder().parent / "ckpt" / self.model_name + self.task_description_for_instruction = task_description_for_instruction + self.training = training + + self.rl_agent_model = DQN_EmbeddingViaLLM( + device=torch.device('cuda:0'), + llm_embedding_dim_concat=Config.dqn_config["llm_embedding_dim"], + mlp_hidden_dim=Config.dqn_config["mlp_hidden_dim"], + action_dim=Config.base_config["num_actions"] + 1, + critic_layer_num=Config.dqn_config["critic_layer_num"], + critic_lr=float(Config.dqn_config["critic_lr"]) + ) + # load model + self.model_full_path = self.model_save_path / "critic.pt" + self.buffer_path = self.model_save_path / "buffer.pickle" + self.model_save_path.mkdir(parents=True, exist_ok=True) + + if self.model_full_path.exists(): + self.rl_agent_model.load_model(str(self.model_full_path)) + logger.info(f"Load Model Done from file: {str(self.model_full_path)}") + elif not self.training: # for inference, if no model, raise error + raise RuntimeError(f"Model is not found: {self.model_full_path}") + + if self.training: # for training + self.rl_agent_model.train_mode() + self.train_log_interval = Config.dqn_config["train_log_interval"] + self.reward_window = deque(maxlen=self.train_log_interval) + self.obs_no = 0 + self.batch_size = Config.dqn_config["batch_size"] + self.update_times = 
Config.dqn_config["update_times"] + self.save_interval = Config.dqn_config["save_interval"] + self.save_model_obs_num = 0 # accumulate + self.best_cum_reward = -100 # TODO:MODIFY + self.best_score = 0 + self.learn_starts = Config.dqn_config["learn_starts"] + self.dqn_explore_epsilon = Config.dqn_config["dqn_explore_epsilon"] + self.dqn_explore_epsilon_decay_rate = Config.dqn_config["dqn_explore_epsilon_decay_rate"] + self.dqn_explore_epsilon_decay_interval = Config.dqn_config["dqn_explore_epsilon_decay_interval"] + self.dqn_explore_threshold = self.dqn_explore_epsilon - self.dqn_explore_epsilon_decay_rate \ + * (self.obs_no / self.dqn_explore_epsilon_decay_interval) + + # load buffer + if self.buffer_path.exists(): + with open(self.buffer_path, "rb") as reward_buffer_container: + self.replay_buffer = pickle.load(reward_buffer_container) + reward_buffer_container.close() + else: + self.replay_buffer = ReplayBuffer(capacity=Config.dqn_config["buffer_size"]) + + else: + self.rl_agent_model.eval_mode() + + def save(self): + self.rl_agent_model.save_model(self.model_full_path) + with open(self.buffer_path, "wb") as f: + pickle.dump(self.replay_buffer, f) + + def load(self): + self.rl_agent_model.load_model(self.model_full_path) + if self.training and self.buffer_path.exists(): + with open(self.buffer_path, "rb") as f: + self.replay_buffer = pickle.load(f) + + async def __call__(self, query: str, current_step_index: int, instructions: list[str], probs: np.ndarray, + state_memory_bank: StateMemoryBank + ) -> tuple[str, np.ndarray, int]: + prompt = self.build_prompt(query, current_step_index, instructions, probs, state_memory_bank) + + # get embedding from llm + response_emb = await llm_embedding(Config.base_config["embedding_model"], prompt) + + affordance_value_array = self.rl_agent_model.get_action(obs=response_emb) + + selected_idx = np.argmax(affordance_value_array) + + # random exploration in the beginning. + if self.training: + # if it is in the beginning phase, do random exploration! 
+ if self.obs_no <= self.learn_starts or np.random.random() <= self.dqn_explore_threshold: + selected_idx = random.choice(range(len(affordance_value_array))) + + if selected_idx != len(instructions): + selected_instruction = instructions[selected_idx] + else: + selected_instruction = "REJECT" + return selected_instruction, response_emb, selected_idx + + def build_prompt(self, query: str, current_step_index: int, instructions: list[str], probs: np.ndarray, + state_memory_bank: StateMemoryBank + ): + """Getting prompt based on template""" + # prompt-for-each-query + prompt = self.embedding_prompt_base.replace('[INSERT_QUERY_HERE]', query) # query insert + prompt = prompt.replace('[INSERT_CURRENT_STEP_NO]', str(current_step_index)) # step number insert + + # prompt-for-query-type-about-the-dataset + prompt = prompt.replace('[INSERT_QUERY_TYPE_HERE]', self.task_description_for_instruction) # query type + prompt = prompt.replace('[EXAMPLE_HERE]', self.instruction_example) # query type demo/ exps + + # previous instruction + prompt = prompt.replace('[NEED_TO_PROVIDE_PREVIOUS_INSTRUCTION]', + state_memory_bank.instructions_prompt) # previous code insert + + # previous executed code + prompt = prompt.replace('[MORE_CODE_WAITING]', state_memory_bank.codes_prompt) # previous code insert + prompt = prompt.replace('[CURRENTLY_RESULT_WAITING]', + state_memory_bank.feedbacks_prompt) # result description insert + + # variable details + prompt = prompt.replace('[VARIABLE_AND_DETAILS]', state_memory_bank.variables_prompt) + + # current instructions/probs + prompt = prompt.replace('[CURRENT_OPTION]', str(instructions)) # instruction options + prompt = prompt.replace('[CURRENT_OPTION_PROBABILITY]', str(probs)) # probs of instruction options + + return prompt diff --git a/hydra_vl4ai/agent/hydra.py b/hydra_vl4ai/agent/hydra.py index 2f92f0e..83377ce 100644 --- a/hydra_vl4ai/agent/hydra.py +++ b/hydra_vl4ai/agent/hydra.py @@ -1,15 +1,17 @@ import tensorneko_util as N +import torchvision.ops.boxes as bops import websockets from websockets import WebSocketClientProtocol +from .controller import ControllerLLM, ControllerDQN +from .planner import Planner from .reasoner import Reasoner from .smb import StateMemoryBank -from .controller import ControllerLLM -from .planner import Planner from .summarizer import Summarizer +from ..evaluation.vqa_eval import GQAeval +from ..util.config import Config from ..util.console import logger from ..util.misc import get_hydra_root_folder -from ..util.config import Config class Hydra: @@ -125,7 +127,7 @@ async def _call_vqa(self, image: bytes, query: str, state_memory_bank: StateMemo with N.util.Timer(verbose=False) as timer: # initial perception result, code = await self.reasoner.initial_run(image, query, websocket) - t = timer.time(timer_msg := f"Reasoner in Step 0") + t = timer.time(timer_msg := "Reasoner in Step 0") logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") if result.type == "error": return None @@ -173,6 +175,358 @@ async def _call_vqa(self, image: bytes, query: str, state_memory_bank: StateMemo case "final": # ------------------ Summarizer ------------------ final_result = await self.summarizer.final_guess(query, guesses) - t = timer.time(timer_msg := f"Summarizer in Step Final") + t = timer.time(timer_msg := "Summarizer in Step Final") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + return final_result + + +class HydraWithRL(Hydra): + + def __init__(self, training: bool = False): + super().__init__() + # load config + self.dataset = 
Config.base_config["dataset"] + prompt_type = Config.base_config["prompt"] + self.max_iterations = Config.base_config["max_iterations"] + self._debug = Config.base_config["debug"] + self.task = Config.base_config["task"] + num_actions = Config.base_config["num_actions"] + reasoner_max_retry = Config.base_config["reasoner_max_retry"] + + # load prompts + prompt_path = get_hydra_root_folder() / "agent" / "prompt" / prompt_type + if not prompt_path.exists(): + raise NotImplementedError( + f"Prompt for {prompt_type} on {self.dataset} is not implemented in {prompt_path}.") + + instruction_prompt_base = N.io.read.text(str(prompt_path / "instruction.prompt")) + task_description_for_instruction = N.io.read.text(str(prompt_path / "task_description_for_instruction.prompt")) + instruction_examples = N.io.read.text(str(prompt_path / "instruction_example.prompt")) + + code_prompt_base = N.io.read.text(str(prompt_path / "code.prompt")) + task_description_for_code = N.io.read.text(str(prompt_path / "task_description_for_code.prompt")) + code_example = N.io.read.text(str(prompt_path / "code_example.prompt")) + + self.planner = Planner( + instruction_prompt_base, + task_description_for_instruction, + instruction_examples, + num_actions, + Config.base_config["planner_max_retry"]) + + embedding_prompt_base = N.io.read.text(str(prompt_path / "embedding.prompt")) + self.controller = ControllerDQN( + embedding_prompt_base, + task_description_for_instruction, + instruction_examples, + training + ) + + self.reasoner = Reasoner( + code_prompt_base, + task_description_for_code, + code_example, + num_actions, + reasoner_max_retry + ) + + match self.task: + case "grounding": + self.summarizer = None + case "vqa": + summarize_prompt_base = N.io.read.text(str(prompt_path / "summarize.prompt")) + guess_answer_prompt_base = N.io.read.text(str(prompt_path / "guess_answer.prompt")) + self.summarizer = Summarizer( + summarize_prompt_base, + guess_answer_prompt_base, + task_description_for_instruction + ) + + self.evaluator = GQAeval() + + async def __call__(self, image: bytes, query: str) -> str: + state_memory_bank = StateMemoryBank() + async with websockets.connect(f"ws://localhost:{Config.base_config['executor_port']}/ws/") as websocket: + match self.task: + case "grounding": + return await self._call_grounding(image, query, state_memory_bank, websocket) + + case "vqa": + return await self._call_vqa(image, query, state_memory_bank, websocket) + + async def _call_grounding(self, image: bytes, query: str, state_memory_bank: StateMemoryBank, + websocket: WebSocketClientProtocol + ) -> str: + with N.util.Timer(verbose=False) as timer: + for current_step_index in range(1, self.max_iterations + 1): + # ----------------- Planner ----------------- + instructions, probs = await self.planner(query, current_step_index, state_memory_bank) + t = timer.time(timer_msg := f"Planner in Step {current_step_index}") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + + if not instructions: + break + + # ----------------- Controller ----------------- + instruction, response_emb, selected_idx = await self.controller(query, current_step_index, instructions, + probs, state_memory_bank) # TODO:MODIFY + t = timer.time(timer_msg := f"Controller in Step {current_step_index}") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + + if instruction == "REJECT": # TODO:MODIFY + continue + # ------------------ Reasoner ------------------ + result, code = await self.reasoner(image, query, instruction, current_step_index, state_memory_bank, + websocket) 
+ t = timer.time(timer_msg := f"Reasoner in Step {current_step_index}") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + + if result.type != "error": + assert type(code) is str + state_memory_bank.extend_memory( + result.feedbacks, [code], [instruction], result.variables, result.variable_names + ) + + match result.type: + case "error": + return None + case "final": + return result.final_result + case "continue": + continue + + async def _call_vqa(self, image: bytes, query: str, state_memory_bank: StateMemoryBank, + websocket: WebSocketClientProtocol + ) -> str: + with N.util.Timer(verbose=False) as timer: + # initial perception + result, code = await self.reasoner.initial_run(image, query, websocket) + t = timer.time(timer_msg := "Reasoner in Step 0") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + if result.type == "error": + return None + state_memory_bank.extend_memory( + result.feedbacks, [], [], result.variables, result.variable_names + ) + + guesses = [] + + for current_step_index in range(1, self.max_iterations + 1): + # ----------------- Planner ----------------- + instructions, probs = await self.planner(query, current_step_index, state_memory_bank) + t = timer.time(timer_msg := f"Planner in Step {current_step_index}") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + + if not instructions: + break + + # ----------------- Controller ----------------- + instruction, response_emb, selected_idx = await self.controller(query, current_step_index, instructions, + probs, state_memory_bank) # TODO:MODIFY + t = timer.time(timer_msg := f"Controller in Step {current_step_index}") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + + if instruction == "REJECT": # TODO:MODIFY + continue + # ------------------ Reasoner ------------------ + result, code = await self.reasoner(image, query, instruction, current_step_index, state_memory_bank, + websocket) + t = timer.time(timer_msg := f"Reasoner in Step {current_step_index}") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + + if result.type != "error": + assert type(code) is str + state_memory_bank.extend_memory( + result.feedbacks, [code], [instruction], result.variables, result.variable_names + ) + # ------------------ Summarizer ------------------ + guesses.append(await self.summarizer(query, state_memory_bank)) + t = timer.time(timer_msg := f"Summarizer in Step {current_step_index}") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + + match result.type: + case "continue": + continue + case "error": + return None + case "final": + # ------------------ Summarizer ------------------ + final_result = await self.summarizer.final_guess(query, guesses) + t = timer.time(timer_msg := "Summarizer in Step Final") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + return final_result + + async def train_step(self, image: bytes, query: str, ground_true=None) -> str: + state_memory_bank = StateMemoryBank() + async with websockets.connect(f"ws://localhost:{Config.base_config['executor_port']}/ws/") as websocket: + match self.task: + case "grounding": + return await self._train_step_grounding(image, query, state_memory_bank, websocket, ground_true) + + case "vqa": + return await self._train_step_vqa(image, query, state_memory_bank, websocket, ground_true) + + async def _train_step_grounding(self, image: bytes, query: str, state_memory_bank: StateMemoryBank, + websocket: WebSocketClientProtocol, ground_true=None + ) -> str: + with N.util.Timer(verbose=False) as timer: + sub_reward = 10 + pre_obs_emb = None + for current_step_index 
in range(1, self.max_iterations + 1): + # ----------------- Planner ----------------- + instructions, probs = await self.planner(query, current_step_index, state_memory_bank) + t = timer.time(timer_msg := f"Planner in Step {current_step_index}") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + + if not instructions: + break + + # ----------------- Controller ----------------- + instruction, response_emb, selected_idx = await self.controller(query, current_step_index, instructions, + probs, state_memory_bank) + t = timer.time(timer_msg := f"Controller in Step {current_step_index}") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + + if instruction == "REJECT": + continue + # ------------------ Reasoner ------------------ + result, code = await self.reasoner(image, query, instruction, current_step_index, state_memory_bank, + websocket) + t = timer.time(timer_msg := f"Reasoner in Step {current_step_index}") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + + if result.type != "error": + assert type(code) is str + state_memory_bank.extend_memory( + result.feedbacks, [code], [instruction], result.variables, result.variable_names + ) + + else: + sub_reward -= 100 + + if result.type == "final": + sub_reward += float(bops.box_iou(result.final_result, ground_true)[0][0]) * 100 + + # Calculate reward + sub_reward -= current_step_index + self.controller.save_model_obs_num += 1 + + # update buffer and model + if current_step_index > 1: + # store transition + self.controller.replay_buffer.push(pre_obs_emb, selected_idx, sub_reward, response_emb, done=False) + self.controller.reward_window.append(sub_reward) + self.controller.obs_no += 1 + # update model each step when buffer size bigger than batch_size. + if len(self.controller.replay_buffer) > self.controller.batch_size: + for i in range(self.controller.update_times): + self.controller.rl_agent_model.update(replay_buffer=self.controller.replay_buffer, + batch_size=self.controller.batch_size) + + pre_obs_emb = response_emb # reserve current emb as previous emb + self.controller.dqn_explore_threshold = \ + self.controller.dqn_explore_epsilon - self.controller.dqn_explore_epsilon_decay_rate \ + * (self.controller.obs_no / self.controller.dqn_explore_epsilon_decay_interval) + + match result.type: + case "error": + return None + case "final": + return result.final_result + case "continue": + continue + + async def _train_step_vqa(self, image: bytes, query: str, state_memory_bank: StateMemoryBank, + websocket: WebSocketClientProtocol, ground_true=None + ) -> str: + with (N.util.Timer(verbose=False) as timer): + # initial perception + result, code = await self.reasoner.initial_run(image, query, websocket) + t = timer.time(timer_msg := "Reasoner in Step 0") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + if result.type == "error": + return None + state_memory_bank.extend_memory( + result.feedbacks, [], [], result.variables, result.variable_names + ) + + guesses = [] + + # for training + sub_reward = 0 + pre_obs_emb = None + + for current_step_index in range(1, self.max_iterations + 1): + # ----------------- Planner ----------------- + instructions, probs = await self.planner(query, current_step_index, state_memory_bank) + t = timer.time(timer_msg := f"Planner in Step {current_step_index}") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + + if not instructions: + break + + # ----------------- Controller ----------------- + instruction, response_emb, selected_idx = await self.controller(query, current_step_index, instructions, + probs, 
state_memory_bank) + t = timer.time(timer_msg := f"Controller in Step {current_step_index}") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + + if instruction == "REJECT": + continue + # ------------------ Reasoner ------------------ + result, code = await self.reasoner(image, query, instruction, current_step_index, state_memory_bank, + websocket) + t = timer.time(timer_msg := f"Reasoner in Step {current_step_index}") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + + if result.type != "error": + assert type(code) is str + state_memory_bank.extend_memory( + result.feedbacks, [code], [instruction], result.variables, result.variable_names + ) + # ------------------ Summarizer ------------------ + guesses.append(await self.summarizer(query, state_memory_bank)) + t = timer.time(timer_msg := f"Summarizer in Step {current_step_index}") + logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") + + else: + sub_reward -= 100 + + if result.type == "final": + if 'okvqa' in self.dataset: + sub_reward += self.evaluator.accuracy_one_set([result.final_result], [ground_true]) + else: + sub_reward += self.evaluator.accuracy_one_one([result.final_result], [ground_true]) + + # Calculate reward + sub_reward -= current_step_index + self.controller.save_model_obs_num += 1 + + # update buffer and model + if current_step_index > 1: + # store transition + self.controller.replay_buffer.push(pre_obs_emb, selected_idx, sub_reward, response_emb, done=False) + self.controller.reward_window.append(sub_reward) + self.controller.obs_no += 1 + # update model each step when buffer size bigger than batch_size. + if len(self.controller.replay_buffer) > self.controller.batch_size: + for i in range(self.controller.update_times): + self.controller.rl_agent_model.update(replay_buffer=self.controller.replay_buffer, + batch_size=self.controller.batch_size) + + pre_obs_emb = response_emb + self.controller.dqn_explore_threshold = \ + self.controller.dqn_explore_epsilon - self.controller.dqn_explore_epsilon_decay_rate \ + * (self.controller.obs_no / self.controller.dqn_explore_epsilon_decay_interval) + + match result.type: + case "continue": + continue + case "error": + return None + case "final": + # ------------------ Summarizer ------------------ + final_result = await self.summarizer.final_guess(query, guesses) + t = timer.time(timer_msg := "Summarizer in Step Final") logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") return final_result diff --git a/hydra_vl4ai/agent/llm.py b/hydra_vl4ai/agent/llm.py index 49dbfbf..2cfdae8 100644 --- a/hydra_vl4ai/agent/llm.py +++ b/hydra_vl4ai/agent/llm.py @@ -1,12 +1,13 @@ import asyncio -from functools import wraps import os +import time +from functools import wraps + import httpx import numpy as np -from openai import AsyncOpenAI -from ollama import AsyncClient, Client import openai -import time +from ollama import AsyncClient, Client +from openai import AsyncOpenAI from ..util.config import Config from ..util.console import logger @@ -40,21 +41,22 @@ async def wrapper(*args, **kwargs): for _ in range(max_trial): try: return await func(*args, **kwargs) - except openai.APITimeoutError as e: + except openai.APITimeoutError: pass - except openai.APIConnectionError as e: + except openai.APIConnectionError: pass - except openai.RateLimitError as e: + except openai.RateLimitError: time.sleep(1) pass except openai.BadRequestError as e: # maybe exceed the length, should raise directly - raise + raise e except openai.APIStatusError as e: # server side problem, should raise directly - raise + raise 
e except Exception as e: - raise + raise e + return wrapper @@ -66,14 +68,15 @@ async def wrapper(*args, **kwargs): for _ in range(max_trial): try: return await func(*args, **kwargs) - except httpx.ConnectError as e: + except httpx.ConnectError: pass - except httpx.ConnectTimeout as e: + except httpx.ConnectTimeout: pass - except httpx.TimeoutException as e: + except httpx.TimeoutException: pass except Exception as e: - raise + raise e + return wrapper @@ -86,10 +89,10 @@ async def chatgpt(model_name: str, prompt: str): @handle_openai_exceptions -async def gpt3_embedding(prompt: str): +async def gpt3_embedding(model_name: str, prompt: str): async with _semaphore: - response = (await openai_client.embeddings.create(input = [prompt], - model=Config.base_config["embedding_model"])).data[0].embedding + response = (await openai_client.embeddings.create(input=[prompt], + model=model_name)).data[0].embedding response = np.array(response) return response @@ -97,7 +100,7 @@ async def gpt3_embedding(prompt: str): @handle_ollama_exceptions async def ollama(model_name: str, prompt: str): async with _semaphore: - response = await ollama_client.chat(model=model_name, + response = await ollama_client.chat(model=model_name, messages=[{"role": "user", "content": prompt}], stream=False, ) return response["message"]["content"] @@ -109,3 +112,10 @@ async def llm(model_name: str, prompt: str): return await ollama(model_name, prompt) else: raise ValueError(f"Model {model_name} is not supported.") + + +async def llm_embedding(model_name: str, prompt: str): + if model_name in ("text-embedding-3-small", "text-embedding-3-large"): + return await gpt3_embedding(model_name, prompt) + else: + raise ValueError(f"Model {model_name} is not supported.") diff --git a/hydra_vl4ai/agent/planner.py b/hydra_vl4ai/agent/planner.py index 3351051..d3feb94 100644 --- a/hydra_vl4ai/agent/planner.py +++ b/hydra_vl4ai/agent/planner.py @@ -2,9 +2,9 @@ import numpy as np -from ..util.config import Config -from .smb.state_memory_bank import StateMemoryBank from .llm import llm +from .smb.state_memory_bank import StateMemoryBank +from ..util.config import Config from ..util.console import logger diff --git a/hydra_vl4ai/agent/reasoner.py b/hydra_vl4ai/agent/reasoner.py index 48e1c5f..6333844 100644 --- a/hydra_vl4ai/agent/reasoner.py +++ b/hydra_vl4ai/agent/reasoner.py @@ -1,13 +1,14 @@ import io import json + import websockets from PIL import Image +from .llm import llm +from .smb.state_memory_bank import StateMemoryBank from ..util.config import Config from ..util.console import logger from ..util.message import ExecutionRequest, ExecutionResult -from .llm import llm -from .smb.state_memory_bank import StateMemoryBank class Reasoner: diff --git a/hydra_vl4ai/agent/rl_dqn.py b/hydra_vl4ai/agent/rl_dqn.py new file mode 100644 index 0000000..58b2607 --- /dev/null +++ b/hydra_vl4ai/agent/rl_dqn.py @@ -0,0 +1,156 @@ +import random + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.nn.utils import clip_grad_norm_ + + +# MLP head +class MLPHead(nn.Module): + + def __init__( + self, + input_dim, + mlp_hidden_dim, + output_dim, + layer_num + ) -> None: + super().__init__() + + mlp_head = [] + for idx in range(layer_num): + if idx == 0: + i_dim = input_dim + else: + i_dim = mlp_hidden_dim + + if idx == layer_num - 1: + o_dim = output_dim + else: + o_dim = mlp_hidden_dim + + mlp_head.append(nn.Linear(i_dim, o_dim)) + + # if idx != layer_num -1: + mlp_head.append(nn.Sigmoid()) + + self.mlp_head = 
nn.Sequential(*mlp_head) + + def forward(self, x): + return self.mlp_head(x) + + +class ReplayBuffer: + + def __init__(self, capacity) -> None: + self.capacity = capacity + self.buffer = [] + self.pos = 0 + + def push(self, state, action, reward, next_state, done): + if len(self.buffer) < self.capacity: + self.buffer.append(None) + self.buffer[self.pos] = (state, action, reward, next_state, done) + self.pos = int((self.pos + 1) % self.capacity) # as a ring buffer + + def sample(self, batch_size): + batch = random.sample(self.buffer, batch_size) + state, action, reward, next_state, done = map(np.stack, zip(*batch)) # stack for each element + + return state, action, reward, next_state, done + + def __len__(self): + return len(self.buffer) + + +class DQN_EmbeddingViaLLM: + + def __init__( + self, + device, + llm_embedding_dim_concat, + mlp_hidden_dim, + action_dim, + critic_layer_num, + critic_lr, + ) -> None: + self.device = device + + self.critic_head = MLPHead(input_dim=llm_embedding_dim_concat, mlp_hidden_dim=mlp_hidden_dim, + output_dim=action_dim, layer_num=critic_layer_num).to(device) + self.tar_critic_head = MLPHead(input_dim=llm_embedding_dim_concat, mlp_hidden_dim=mlp_hidden_dim, + output_dim=action_dim, layer_num=critic_layer_num).to(device) + for tar_param, param in zip(self.tar_critic_head.parameters(), self.critic_head.parameters()): + tar_param.data.copy_(param.data) + + self.critic_optim = optim.Adam(self.critic_head.parameters(), lr=critic_lr) + + def get_action(self, obs, batch_input=False): + if not batch_input: + obs = torch.FloatTensor(obs).to(self.device).unsqueeze(0) + else: + obs = torch.FloatTensor(obs).to(self.device) + + qval = self.critic_head.forward(x=obs) + + return qval.detach().cpu().numpy().flatten() + # if not batch_input: return qval.argmax().detach().cpu().numpy().flatten() + # else: return qval.argmax(dim=-1).detach().cpu().numpy().flatten() + + def update(self, replay_buffer, batch_size, reward_scale=10., gamma=0.99, soft_tau=1e-2, is_clip_gradient=True, + clip_gradient_val=40 + ): + obs, action, reward, next_obs, done = replay_buffer.sample(batch_size) + + obs = torch.FloatTensor(obs).to(self.device) # obs.size = (batch_size, 1+seq_dim) + next_obs = torch.FloatTensor(next_obs).to(self.device) + action = torch.FloatTensor(action).to(self.device) + reward = torch.FloatTensor(reward).unsqueeze(1).to( + self.device) # reward is single value, unsqueeze() to add one dim to be [reward] at the sample dim; + reward = reward_scale * (reward - reward.mean(dim=0)) / (reward.std( + dim=0) + 1e-6) # normalize with batch mean and std; plus a small number to prevent numerical problem + done = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(self.device) + # print(f'self.critic_head.forward(x=obs) {self.critic_head.forward(x=obs).shape}') + # print(f'action.long() {action.long().shape}') + # qval = self.critic_head.forward(x=obs).gather(1, action.long()) + val_value_ = self.critic_head.forward(x=obs) + qval = val_value_.gather(1, action.long().reshape(val_value_.size(0), 1)) + + with torch.no_grad(): + max_next_qval = self.tar_critic_head.forward(x=next_obs).max(dim=-1)[0].unsqueeze(-1) + tar_qval = reward + gamma * (1 - done) * max_next_qval + + loss_func = nn.MSELoss() + qloss = loss_func(qval, tar_qval.detach()) + + self.critic_optim.zero_grad() + qloss.backward() + if is_clip_gradient: clip_grad_norm_(self.critic_head.parameters(), clip_gradient_val) + self.critic_optim.step() + + for tar_param, param in zip(self.tar_critic_head.parameters(), 
self.critic_head.parameters()): + tar_param.data.copy_(param.data * soft_tau + tar_param.data * (1 - soft_tau)) + + return qloss.detach().cpu().item() + + def save_model(self, path: str): + torch.save(self.critic_head.state_dict(), path) + + def load_model(self, path: str): + self.critic_head.load_state_dict(torch.load(path)) + + for tar_param, param in zip(self.tar_critic_head.parameters(), self.critic_head.parameters()): + tar_param.data.copy_(param.data) + + self.critic_head.eval() + self.tar_critic_head.eval() + + def eval_mode(self): + self.critic_head.eval() + self.tar_critic_head.eval() + + def train_mode(self): + self.critic_head.train() + self.tar_critic_head.train() diff --git a/hydra_vl4ai/agent/smb/state_memory_bank.py b/hydra_vl4ai/agent/smb/state_memory_bank.py index 103c4f6..d0e3661 100644 --- a/hydra_vl4ai/agent/smb/state_memory_bank.py +++ b/hydra_vl4ai/agent/smb/state_memory_bank.py @@ -25,24 +25,24 @@ def reset(self): @property def instructions_prompt(self): return "\n" + "\n".join(self.instructions) - + @property def feedbacks_prompt(self): return "\n" + "\n".join(self.feedbacks) - + @property def codes_prompt(self): return "\n" + "\n".join(self.codes) - + @property def variables_prompt(self): return "\n" + "\n".join(self.variables) - def extend_memory(self, - other_feedbacks: list[str], - other_codes: list[str], - other_instructions: list[str], - other_variables: list[str], + def extend_memory(self, + other_feedbacks: list[str], + other_codes: list[str], + other_instructions: list[str], + other_variables: list[str], other_variable_names: list[str] ): self.feedbacks.extend(other_feedbacks) @@ -125,7 +125,8 @@ def get_sorted_patches_left_to_right_message_save(self, name): f"\nThe patches list has been sorted from left to right (horizontal). Now, the first patch in the list corresponds to the leftest position, while the last one corresponds to the rightest position") def get_sorted_patches_bottom_to_top_message_save(self, name): - self.feedbacks.append(f"\nThe patches list has been sorted from bottom to top (vertical). Now, the first patch in the list corresponds to the bottom/low/below position, while the last one corresponds to the top/up/above position.") + self.feedbacks.append( + f"\nThe patches list has been sorted from bottom to top (vertical). 
Now, the first patch in the list corresponds to the bottom/low/below position, while the last one corresponds to the top/up/above position.") def get_sorted_patches_front_to_back_message_save(self, name): self.feedbacks.append( diff --git a/hydra_vl4ai/agent/summarizer.py b/hydra_vl4ai/agent/summarizer.py index cbbbe3e..cc2783e 100644 --- a/hydra_vl4ai/agent/summarizer.py +++ b/hydra_vl4ai/agent/summarizer.py @@ -1,5 +1,5 @@ -from .smb.state_memory_bank import StateMemoryBank from .llm import llm +from .smb.state_memory_bank import StateMemoryBank from ..util.config import Config diff --git a/hydra_vl4ai/agent/webui.py b/hydra_vl4ai/agent/webui.py index 80b7143..176e6f2 100644 --- a/hydra_vl4ai/agent/webui.py +++ b/hydra_vl4ai/agent/webui.py @@ -1,16 +1,16 @@ +import gradio as gr import tensorneko_util as N import websockets from websockets import WebSocketClientProtocol -from .smb import StateMemoryBank -from ..util.console import logger from .hydra import HydraNoRL +from .smb import StateMemoryBank from ..util.config import Config - -import gradio as gr +from ..util.console import logger class HydraNoRLWeb(HydraNoRL): + def __init__(self): super().__init__() @@ -34,7 +34,8 @@ async def _gradio_call(self, image_path: str, query: str): messages = [gr.ChatMessage('user', [image_path, query])] yield messages, self._format_state_memory() - async with websockets.connect(f"ws://localhost:{Config.base_config['executor_port']}/ws/") as ws: # type: ignore + async with websockets.connect( + f"ws://localhost:{Config.base_config['executor_port']}/ws/") as ws: # type: ignore async for chunk in self._call_vqa(image_bytes, query, self.bank, ws): if chunk[0] == "start": continue @@ -82,7 +83,7 @@ async def _call_vqa( result, code = await self.reasoner.initial_run(image, query, websocket) t = timer.time(timer_msg := "Reasoner in Step 0") logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") - + if result.type == "error": yield "panic", "Reasoning", result return @@ -111,7 +112,8 @@ async def _call_vqa( t = timer.time(timer_msg := f"Controller in Step {current_step_index}") logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") - yield "stage-completed", "Planning", current_step_index, sorted(zip(instructions, probs), key=lambda x: x[1], reverse=True), instruction + yield "stage-completed", "Planning", current_step_index, sorted(zip(instructions, probs), + key=lambda x: x[1], reverse=True), instruction # ------------------ Reasoner ------------------ yield "start", "Reasoning" @@ -124,7 +126,7 @@ async def _call_vqa( if result.type == "error": yield "error", "Reasoning", result continue - + assert type(code) is str state_memory_bank.extend_memory( result.feedbacks, [code], [instruction], result.variables, result.variable_names @@ -148,6 +150,6 @@ async def _call_vqa( final_result = await self.summarizer.final_guess(query, guesses) t = timer.time(timer_msg := f"Summarizer in Step Final") logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec") - + yield "final", final_result return diff --git a/evaluation/grounding_eval.py b/hydra_vl4ai/evaluation/grounding_eval.py similarity index 100% rename from evaluation/grounding_eval.py rename to hydra_vl4ai/evaluation/grounding_eval.py index fe58e0b..54500f4 100644 --- a/evaluation/grounding_eval.py +++ b/hydra_vl4ai/evaluation/grounding_eval.py @@ -5,6 +5,7 @@ from numpy import ndarray from torch import Tensor + def iou_2d(proposal: Union[Tensor, ndarray], target: Union[Tensor, ndarray]) -> Tensor: """ Calculate 2D IOU for M proposals with N targets. 
@@ -55,7 +56,6 @@ def process_grounding_result(x: str): return eval(x)["final_answer"][0][:4] - def batch_iou_2d(result): # we follow ViperGPT to filter out the None results remove_index = [] diff --git a/evaluation/vqa_eval.py b/hydra_vl4ai/evaluation/vqa_eval.py similarity index 100% rename from evaluation/vqa_eval.py rename to hydra_vl4ai/evaluation/vqa_eval.py diff --git a/hydra_vl4ai/execution/image_patch.py b/hydra_vl4ai/execution/image_patch.py index 8857bc8..46aa162 100644 --- a/hydra_vl4ai/execution/image_patch.py +++ b/hydra_vl4ai/execution/image_patch.py @@ -13,7 +13,7 @@ import tensorneko_util as N from .toolbox import forward -from ..agent.llm import chatgpt +from ..agent.llm import llm from ..agent.smb import StateMemoryBank from ..util.misc import get_hydra_root_folder, load_json from ..util.config import Config @@ -558,10 +558,10 @@ def llm_query(query, context=None, long_answer=True, state_memory_bank=None): prompt_ += f'Could you help me answer the question: {query}.' if not long_answer: - prompt_ += f'Please provide only a few-word answer. Be very concise, no ranges, no doubt.' + prompt_ += 'Please provide only a few-word answer. Be very concise, no ranges, no doubt.' try: - return_answer = asyncio.run(chatgpt(prompt_)) or "" - except: + return_answer = asyncio.run(llm(Config.base_config["llm_model"], prompt_)) or "" + except Exception: return_answer = 'not answer from gpt' # get global description. diff --git a/hydra_vl4ai/tool/llava.py b/hydra_vl4ai/tool/llava.py index ff82581..c846e6a 100644 --- a/hydra_vl4ai/tool/llava.py +++ b/hydra_vl4ai/tool/llava.py @@ -1,11 +1,11 @@ -import os.path +import re import torch +from huggingface_hub import snapshot_download from llava.conversation import conv_templates +from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path from llava.model.builder import load_pretrained_model from llava.utils import disable_torch_init -from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path -from huggingface_hub import snapshot_download from ._base import BaseModel, module_registry from ..util.misc import get_root_folder diff --git a/hydra_vl4ai/util/config.py b/hydra_vl4ai/util/config.py index 7f59f4e..ff366e0 100644 --- a/hydra_vl4ai/util/config.py +++ b/hydra_vl4ai/util/config.py @@ -2,13 +2,16 @@ from tensorneko_util.util import Singleton import tensorneko_util as N + @Singleton class Config: def __init__(self): self.model_config_path: str | None = None self._base_config_path: str | None = None + self._dqn_config_path: str | None = None self.base_config: dict[str, Any] = dict() + self.dqn_config: dict[str, Any] = dict() @property def base_config_path(self): @@ -19,3 +22,13 @@ def base_config_path(self, value): if value is not None: self._base_config_path = value self.base_config = N.read(value) + + @property + def dqn_config_path(self): + return self._dqn_config_path + + @dqn_config_path.setter + def dqn_config_path(self, value): + if value is not None: + self._dqn_config_path = value + self.dqn_config = N.read(value) diff --git a/inference_bash.sh b/inference_bash.sh new file mode 100755 index 0000000..197eca2 --- /dev/null +++ b/inference_bash.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +# preprocessing +python main.py \ + --data_root '/home/mai/fke/fkee/coco2014' \ + --base_config './config/okvqa.yaml' \ + --model_config './config/model_config_1gpu.yaml' \ + --dqn_config './config/dqn_testing.yaml' +echo "preprocessing done" diff --git a/main.py b/main.py index f0b3967..634c6fa 
100644 --- a/main.py +++ b/main.py @@ -1,35 +1,44 @@ import asyncio import json import os -import requests -import time -import tensorneko_util as N from pathlib import Path + +import tensorneko_util as N from dotenv import load_dotenv + load_dotenv() import argparse + parser = argparse.ArgumentParser() parser.add_argument("--data_root", type=str, required=True) parser.add_argument("--base_config", type=str, required=True) parser.add_argument("--model_config", type=str, required=True) parser.add_argument("--result_folder", type=str, default="./result") +parser.add_argument("--dqn_config", type=str) args = parser.parse_args() from hydra_vl4ai.util.config import Config + Config.base_config_path = args.base_config +if args.dqn_config is not None: + Config.dqn_config_path = args.dqn_config Config.model_config_path = args.model_config -from hydra_vl4ai.agent.hydra import HydraNoRL +from hydra_vl4ai.agent.hydra import HydraNoRL, HydraWithRL from hydra_vl4ai.util.console import logger, console from hydra_vl4ai.util.misc import wait_until_loaded import exp_datasets async def main(): - with console.status("[bold green]Connect to HYDRA executor...") as status: + with console.status("[bold green]Connect to HYDRA executor..."): wait_until_loaded(f"http://localhost:{Config.base_config['executor_port']}") - hydra = HydraNoRL() + + if args.dqn_config is None: + hydra = HydraNoRL() + else: + hydra = HydraWithRL() match Config.base_config["dataset"]: case "gqa": @@ -53,22 +62,22 @@ async def main(): dataset = exp_datasets.Refcoco(args.data_root) case _: raise ValueError("Invalid dataset") - + # output path Path(args.result_folder).mkdir(parents=True, exist_ok=True) save_path = Path(args.result_folder) / f"result_{Config.base_config['dataset']}.jsonl" - + # resume if the file exists completed = [] if os.path.exists(save_path): prev_results = N.io.read.json(str(save_path)) completed = [result["datum_id"] for result in prev_results] - + for i, (image_path, datum_id, query, ground_truth) in enumerate(dataset): if datum_id in completed: - logger.info(f"Skipping {i+1}/{len(dataset)}") + logger.info(f"Skipping {i + 1}/{len(dataset)}") continue - logger.info(f"Processing {i+1}/{len(dataset)}") + logger.info(f"Processing {i + 1}/{len(dataset)}") with open(image_path, "rb") as f: image_buffer = f.read() result = await hydra(image_buffer, query) @@ -83,5 +92,6 @@ async def main(): }) + "\n") f.flush() + if __name__ == "__main__": asyncio.run(main=main()) diff --git a/media/gradio_demo.mov b/media/gradio_demo.mov new file mode 100644 index 0000000..702f4a4 Binary files /dev/null and b/media/gradio_demo.mov differ diff --git a/module_repos/GLIP/pyproject.toml b/module_repos/GLIP/pyproject.toml new file mode 100644 index 0000000..8ca2555 --- /dev/null +++ b/module_repos/GLIP/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=61"] +build-backend = "setuptools.build_meta" diff --git a/pixi.toml b/pixi.toml index 595c368..cf5dd8e 100644 --- a/pixi.toml +++ b/pixi.toml @@ -3,36 +3,36 @@ authors = ["ControlNet "] channels = ["pytorch", "nvidia/label/cuda-11.8.0", "anaconda", "conda-forge"] description = "Official implementation for HYDRA." 
name = "hydra_vl4ai" -platforms = ["linux-64"] +platforms = ["linux-64", "win-64"] version = "0.0.0" channel-priority = "disabled" [tasks] -install_glip = { cmd = "python setup.py clean --all build develop --user", cwd = "module_repos/GLIP", env = { CUDA_HOME = "$CONDA_PREFIX", AM_I_DOCKER = "False", BUILD_WITH_CUDA = "True" } } -install_sam = { cmd = "uv pip install -e .", cwd = "module_repos/Grounded-Segment-Anything/segment_anything", env = { CUDA_HOME = "$CONDA_PREFIX", AM_I_DOCKER = "False", BUILD_WITH_CUDA = "True" } } -install_groundingdino = { cmd = "python setup.py clean --all build develop --user", cwd = "module_repos/Grounded-Segment-Anything/GroundingDINO", env = { CUDA_HOME = "$CONDA_PREFIX", AM_I_DOCKER = "False", BUILD_WITH_CUDA = "True" } } -install_llava = { cmd = "uv pip install -e .", cwd = "module_repos/LLaVA", env = { CUDA_HOME = "$CONDA_PREFIX", AM_I_DOCKER = "False", BUILD_WITH_CUDA = "True" } } -install = { cmd = "uv pip install -e . && echo Depedencies installation finished!", depends-on = ["install_glip", "install_sam", "install_groundingdino", "install_llava"], env = { CUDA_HOME = "$CONDA_PREFIX", AM_I_DOCKER = "False", BUILD_WITH_CUDA = "True" } } -executor = "python -m hydra_vl4ai.executor --base_config config/gqa.yaml --model_config config/model_config_1gpu.yaml" -download_model = "python -m hydra_vl4ai.download_model --base_config config/gqa.yaml --model_config config/model_config_1gpu.yaml" +executor = "python -m hydra_vl4ai.executor --base_config config/okvqa.yaml --model_config config/model_config_1gpu.yaml" +download_model = "python -m hydra_vl4ai.download_model --base_config config/okvqa.yaml --model_config config/model_config_1gpu.yaml" [build-dependencies] setuptools = "*" cmake = "*" ninja = "*" +[pypi-options] +no-build-isolation = ["maskrcnn_benchmark", "segment_anything", "groundingdino", "llava", "hydra_vl4ai"] + [dependencies] python = { version = "3.11.*", channel = "anaconda" } pytorch = { version = "==2.1.2", channel = "pytorch" } torchvision = { version = "==0.16.2", channel = "pytorch" } torchaudio = { version = "==2.1.2", channel = "pytorch" } pytorch-cuda = { version = "11.8.*", channel = "pytorch" } -cuda = { version = "11.8.0", channel = "nvidia/label/cuda-11.8.0" } -cuda-libraries-dev = { version = "11.8.0", channel = "nvidia/label/cuda-11.8.0" } -cuda-version = "11.8" +cuda = { version = "==11.8.0", channel = "nvidia/label/cuda-11.8.0" } +cuda-libraries-dev = { version = "==11.8.0", channel = "nvidia/label/cuda-11.8.0" } +cuda-version = "==11.8" +markupsafe = ">2.0,<3.0" numpy = "<2.0" ipywidgets = ">=8.1.5,<9" ipykernel = ">=6.29.5,<7" +pandas = ">=2.2.3,<3" [pypi-dependencies] fastapi = "*" @@ -66,3 +66,9 @@ gdown = "==5.2.0" rich = "==13.7.1" ollama = "~=0.3.0" protobuf = "~=3.19.0" +gradio = ">=5.0,<5.12" +maskrcnn_benchmark = { path = "module_repos/GLIP" } +segment_anything = { path = "module_repos/Grounded-Segment-Anything/segment_anything" } +groundingdino = { path = "module_repos/Grounded-Segment-Anything/GroundingDINO" } +llava = { path = "module_repos/LLaVA" } +hydra_vl4ai = { path = ".", editable = true } diff --git a/requirements.txt b/requirements.txt index e601ea9..167e6d2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,4 @@ scipy==1.14.1 gdown==5.2.0 rich==13.7.1 ollama==0.3.* -gradio +gradio>=5.0,<5.12 diff --git a/train.py b/train.py new file mode 100644 index 0000000..9466a88 --- /dev/null +++ b/train.py @@ -0,0 +1,104 @@ +import asyncio +import json +import os +from pathlib import Path + +from 
dotenv import load_dotenv + +load_dotenv() + +import numpy as np + +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--data_root", type=str, required=True) +parser.add_argument("--base_config", type=str, required=True) +parser.add_argument("--model_config", type=str, required=True) +parser.add_argument("--result_folder", type=str, default="./result") +parser.add_argument("--dqn_config", type=str, required=True) +args = parser.parse_args() + +from hydra_vl4ai.util.config import Config + +Config.base_config_path = args.base_config +Config.dqn_config_path = args.dqn_config +Config.model_config_path = args.model_config + +from hydra_vl4ai.agent.hydra import HydraWithRL +from hydra_vl4ai.util.console import logger, console +from hydra_vl4ai.util.misc import wait_until_loaded +import exp_datasets + + +async def main(): + with console.status("[bold green]Connect to HYDRA executor...") as status: + wait_until_loaded(f"http://localhost:{Config.base_config['executor_port']}") + hydra = HydraWithRL(training=True) + + match Config.base_config["dataset"]: + case "gqa": + dataset = exp_datasets.GQA( + args.data_root + ) + case "okvqa": + dataset = exp_datasets.OKVQA( + args.data_root + ) + case "aokvqa": + # TODO: Not tested yet + # dataset = exp_datasets.AOKVQA( + # f"{args.data_root}/aokvqa", + # "val", f"{args.data_root}/coco", version="v1p0" + # ) + raise NotImplementedError("AOKVQA is not implemented yet") + case "refcoco": + dataset = exp_datasets.Refcoco(args.data_root) + case "refcoco+": + dataset = exp_datasets.Refcoco(args.data_root) + case _: + raise ValueError("Invalid dataset") + + # output path + Path(args.result_folder).mkdir(parents=True, exist_ok=True) + save_path = Path(args.result_folder) / f"result_{Config.base_config['dataset']}.jsonl" + + cum_reward = 0 # TODO:modify + + for epoch_idx_ in range(Config.dqn_config["training_epoch"]): + for i, (image_path, datum_id, query, ground_truth) in enumerate(dataset): + + logger.info(f"Processing {i + 1}/{len(dataset)}") + with open(image_path, "rb") as f: + image_buffer = f.read() + result = await hydra.train_step(image_buffer, query, ground_truth) + logger.info(f"Query: {query} Answer: {result}") + + with open(save_path, "a") as f: + f.write(json.dumps({ + "datum_id": datum_id, + "query": query, + "ground_truth": ground_truth, + "result": result + }) + "\n") + f.flush() + + # training log info + if hydra.controller.obs_no % hydra.controller.train_log_interval == 0: + mean_reward = np.mean(hydra.controller.reward_window) + cum_reward = 0.99 * cum_reward + 0.01 * mean_reward + logger.info('---Current step:{}-----Mean Reward:{:.2f}----Cumulative Reward:{:.2f}'.format( + hydra.controller.obs_no, mean_reward, cum_reward)) + + if hydra.controller.save_model_obs_num > hydra.controller.save_interval \ + and hydra.controller.best_cum_reward < cum_reward: + # update best cumulated reward + hydra.controller.best_cum_reward = cum_reward + + # save model + hydra.controller.save() + hydra.controller.save_model_obs_num = 0 + + +if __name__ == "__main__": + asyncio.run(main=main())
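
The RL pieces introduced by this patch can be exercised in isolation. Below is a minimal, self-contained sketch (not part of the patch) of how the new `ReplayBuffer` and `DQN_EmbeddingViaLLM` critic from `hydra_vl4ai/agent/rl_dqn.py` interact in one push/sample/update cycle, the same cycle that `ControllerDQN` and `HydraWithRL.train_step` drive during training. The random embeddings, rewards, dimensions, and the CPU device are placeholders for the LLM prompt embeddings, task rewards, and `cuda:0` device used in the real loop.

```python
# Standalone dry run of the DQN components added by this patch.
# Assumes the patched package is installed; all inputs below are random placeholders.
import numpy as np
import torch

from hydra_vl4ai.agent.rl_dqn import DQN_EmbeddingViaLLM, ReplayBuffer

EMB_DIM = 1536      # llm_embedding_dim for text-embedding-3-small (3072 for -large)
ACTION_DIM = 4      # corresponds to num_actions + 1 in the patch; the extra slot is REJECT
BATCH_SIZE = 8

agent = DQN_EmbeddingViaLLM(
    device=torch.device("cpu"),   # the patch uses cuda:0; cpu is enough for a dry run
    llm_embedding_dim_concat=EMB_DIM,
    mlp_hidden_dim=512,
    action_dim=ACTION_DIM,
    critic_layer_num=4,
    critic_lr=1e-4,
)
buffer = ReplayBuffer(capacity=1000)

# Fill the buffer with fake (state, action, reward, next_state, done) transitions.
for _ in range(BATCH_SIZE + 1):
    obs = np.random.randn(EMB_DIM).astype(np.float32)
    next_obs = np.random.randn(EMB_DIM).astype(np.float32)
    q_values = agent.get_action(obs)           # one Q-value per action
    action = int(np.argmax(q_values))          # ControllerDQN layers epsilon-greedy exploration on top
    buffer.push(obs, action, float(np.random.rand()), next_obs, done=False)

# One critic update; HydraWithRL runs this `update_times` times per step
# once the buffer holds more than `batch_size` transitions.
loss = agent.update(replay_buffer=buffer, batch_size=BATCH_SIZE)
print(f"TD loss: {loss:.4f}")
```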
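On top of the greedy choice above, `ControllerDQN` picks a random action whenever `obs_no <= learn_starts` or a uniform draw falls below `dqn_explore_threshold = dqn_explore_epsilon - dqn_explore_epsilon_decay_rate * (obs_no / dqn_explore_epsilon_decay_interval)`; with the values in `config/dqn_train_config.yaml` (0.2, 0.02, 100) the threshold anneals linearly from 0.2 to 0 over the first 1000 controller decisions and stays at no-exploration afterwards. The per-episode reward assembled in `HydraWithRL._train_step_*` combines the task score at the final step (IoU scaled by 100 for grounding, VQA accuracy for OK-VQA), a -100 penalty for execution errors, and a step cost (`sub_reward -= current_step_index`) that discourages long reasoning chains.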