diff --git a/README.md b/README.md
index d07ce40..22cd0be 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@
-**This is the code for the paper [HYDRA: A Hyper Agent for Dynamic Compositional Visual Reasoning](https://link.springer.com/chapter/10.1007/978-3-031-72661-3_8), accepted by ECCV 2024 \[[Project Page](https://hydra-vl4ai.github.io)\].**
+**This is the code for the paper [HYDRA: A Hyper Agent for Dynamic Compositional Visual Reasoning](https://link.springer.com/chapter/10.1007/978-3-031-72661-3_8), accepted by ECCV 2024 \[[Project Page](https://hydra-vl4ai.github.io)\]. We released the code that uses Reinforcement Learning (DQN) to fine-tune the LLM🔥🔥🔥**
## Release
@@ -50,7 +50,8 @@ We also notice the embedding model is updated by OpenAI as shown in this [link](
- [x] LLaMA3.1 (ollama) replacement.
- [x] Gradio Demo
- [x] GPT-4o Version.
-- [ ] HYDRA with RL
+- [x] HYDRA with RL(DQN).
+- [ ] HYDRA with Deepseek R1.
## Installation
@@ -71,7 +72,7 @@ git clone https://github.com/ControlNet/HYDRA
Option 1: Using [pixi](https://prefix.dev/) (recommended):
```Bash
-pixi run install
+pixi install
pixi shell
```
@@ -86,8 +87,8 @@ If you meet errors, please consider going through the `build_env.sh` file and in
Edit the file `.env` or setup in CLI to configure the environment variables.
```
-OPENAI_API_KEY=your-api-key
-OLLAMA_HOST=http://ollama.server:11434
+OPENAI_API_KEY=your-api-key # if you want to use OpenAI LLMs
+OLLAMA_HOST=http://ollama.server:11434 # if you want to use your OLLaMA server for llama or deepseek
# do not change this TORCH_HOME variable
TORCH_HOME=./pretrained_models
```
@@ -126,7 +127,9 @@ python demo_gradio.py \
--base_config \
--model_config
```
+https://github.com/user-attachments/assets/39a897ab-d457-49d2-8527-0d6fe3a3b922
+---
### Inference dataset
```Bash
@@ -150,6 +153,23 @@ For example,
python evaluate.py result/result_okvqa.jsonl okvqa
```
+## Training Controller with RL(DQN)
+
+```Bash
+python train.py \
+ --data_root \
+ --base_config \
+ --model_config \
+ --dqn_config
+```
+For example,
+```Bash
+python train.py \
+ --data_root ../coco2014 \
+ --base_config ./config/okvqa.yaml\
+ --model_config ./config/model_config_1gpu.yaml \
+ --dqn_config ./config/dqn_debug.yaml
+```
## Citation
```bibtex
diff --git a/config/dqn_debug.yaml b/config/dqn_debug.yaml
new file mode 100644
index 0000000..71b3998
--- /dev/null
+++ b/config/dqn_debug.yaml
@@ -0,0 +1,15 @@
+model_name: dataset_nameaokvqa_learn_starts300
+llm_embedding_dim: 1536 # small, 3072 for large
+mlp_hidden_dim: 512
+critic_layer_num: 4
+critic_lr: 0.0001
+train_log_interval: 1
+batch_size: 1
+update_times: 2
+save_interval: 1
+learn_starts: 2
+dqn_explore_epsilon: 0.2
+dqn_explore_epsilon_decay_rate: 0.02
+dqn_explore_epsilon_decay_interval: 100
+buffer_size: 100000
+training_epoch: 1
\ No newline at end of file
diff --git a/config/dqn_train_config.yaml b/config/dqn_train_config.yaml
new file mode 100644
index 0000000..af1b147
--- /dev/null
+++ b/config/dqn_train_config.yaml
@@ -0,0 +1,15 @@
+model_name: dqn_model_example
+llm_embedding_dim: 1536 # small, 3072 for large
+mlp_hidden_dim: 512
+critic_layer_num: 4
+critic_lr: 0.0001
+train_log_interval: 100
+batch_size: 128
+update_times: 4
+save_interval: 100
+learn_starts: 1000
+dqn_explore_epsilon: 0.2
+dqn_explore_epsilon_decay_rate: 0.02
+dqn_explore_epsilon_decay_interval: 100
+buffer_size: 100000
+training_epoch: 4
\ No newline at end of file
diff --git a/evaluate.py b/evaluate.py
index db41125..e02b68a 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -1,10 +1,10 @@
-import sys
-from evaluation.vqa_eval import GQAeval
-from evaluation.grounding_eval import batch_iou_2d
-from hydra_vl4ai.util.console import console, logger
-import tensorneko_util as N
import argparse
+import tensorneko_util as N
+
+from hydra_vl4ai.evaluation.grounding_eval import batch_iou_2d
+from hydra_vl4ai.evaluation.vqa_eval import GQAeval
+from hydra_vl4ai.util.console import console, logger
if __name__ == '__main__':
@@ -20,7 +20,8 @@
match args.dataset:
case "okvqa":
- score = evaluator.accuracy_one_set([each["result"] for each in result], [each["ground_truth"] for each in result])
+ score = evaluator.accuracy_one_set([each["result"] for each in result],
+ [each["ground_truth"] for each in result])
case "refcoco" | "refcoco+":
score = batch_iou_2d(result)
case _:
diff --git a/hydra_vl4ai/agent/controller.py b/hydra_vl4ai/agent/controller.py
index 24595c1..651645d 100644
--- a/hydra_vl4ai/agent/controller.py
+++ b/hydra_vl4ai/agent/controller.py
@@ -1,16 +1,161 @@
import abc
+import pickle
+import random
+from collections import deque
import numpy as np
+import torch
+
+from hydra_vl4ai.util.console import logger
+from .llm import llm_embedding
+from .rl_dqn import DQN_EmbeddingViaLLM, ReplayBuffer
+from .smb.state_memory_bank import StateMemoryBank
+from ..util.config import Config
+from ..util.misc import get_hydra_root_folder
class Controller(abc.ABC):
@abc.abstractmethod
- def __call__(self, instructions: list[str], probs: np.ndarray) -> str:
+ def __call__(self, *args, **kwargs) -> str:
pass
class ControllerLLM(Controller):
+ """
+ This is the function for not using RL controller
+ but directly use the LLM score to return optimal instruction
+ """
def __call__(self, instructions: list[str], probs: np.ndarray) -> str:
return instructions[np.argmax(probs)]
+
+
+class ControllerDQN(Controller):
+
+ def __init__(self,
+ embedding_prompt_base: str,
+ task_description_for_instruction: str,
+ instruction_example: str,
+ training: bool = False
+ ):
+ super().__init__()
+ self.instruction_example = instruction_example
+ self.embedding_prompt_base = embedding_prompt_base
+ self.model_name = Config.dqn_config["model_name"]
+ self.model_save_path = get_hydra_root_folder().parent / "ckpt" / self.model_name
+ self.task_description_for_instruction = task_description_for_instruction
+ self.training = training
+
+ self.rl_agent_model = DQN_EmbeddingViaLLM(
+ device=torch.device('cuda:0'),
+ llm_embedding_dim_concat=Config.dqn_config["llm_embedding_dim"],
+ mlp_hidden_dim=Config.dqn_config["mlp_hidden_dim"],
+ action_dim=Config.base_config["num_actions"] + 1,
+ critic_layer_num=Config.dqn_config["critic_layer_num"],
+ critic_lr=float(Config.dqn_config["critic_lr"])
+ )
+ # load model
+ self.model_full_path = self.model_save_path / "critic.pt"
+ self.buffer_path = self.model_save_path / "buffer.pickle"
+ self.model_save_path.mkdir(parents=True, exist_ok=True)
+
+ if self.model_full_path.exists():
+ self.rl_agent_model.load_model(str(self.model_full_path))
+ logger.info(f"Load Model Done from file: {str(self.model_full_path)}")
+ elif not self.training: # for inference, if no model, raise error
+ raise RuntimeError(f"Model is not found: {self.model_full_path}")
+
+ if self.training: # for training
+ self.rl_agent_model.train_mode()
+ self.train_log_interval = Config.dqn_config["train_log_interval"]
+ self.reward_window = deque(maxlen=self.train_log_interval)
+ self.obs_no = 0
+ self.batch_size = Config.dqn_config["batch_size"]
+ self.update_times = Config.dqn_config["update_times"]
+ self.save_interval = Config.dqn_config["save_interval"]
+ self.save_model_obs_num = 0 # accumulate
+ self.best_cum_reward = -100 # TODO:MODIFY
+ self.best_score = 0
+ self.learn_starts = Config.dqn_config["learn_starts"]
+ self.dqn_explore_epsilon = Config.dqn_config["dqn_explore_epsilon"]
+ self.dqn_explore_epsilon_decay_rate = Config.dqn_config["dqn_explore_epsilon_decay_rate"]
+ self.dqn_explore_epsilon_decay_interval = Config.dqn_config["dqn_explore_epsilon_decay_interval"]
+ self.dqn_explore_threshold = self.dqn_explore_epsilon - self.dqn_explore_epsilon_decay_rate \
+ * (self.obs_no / self.dqn_explore_epsilon_decay_interval)
+
+ # load buffer
+ if self.buffer_path.exists():
+ with open(self.buffer_path, "rb") as reward_buffer_container:
+ self.replay_buffer = pickle.load(reward_buffer_container)
+ reward_buffer_container.close()
+ else:
+ self.replay_buffer = ReplayBuffer(capacity=Config.dqn_config["buffer_size"])
+
+ else:
+ self.rl_agent_model.eval_mode()
+
+ def save(self):
+ self.rl_agent_model.save_model(self.model_full_path)
+ with open(self.buffer_path, "wb") as f:
+ pickle.dump(self.replay_buffer, f)
+
+ def load(self):
+ self.rl_agent_model.load_model(self.model_full_path)
+ if self.training and self.buffer_path.exists():
+ with open(self.buffer_path, "rb") as f:
+ self.replay_buffer = pickle.load(f)
+
+ async def __call__(self, query: str, current_step_index: int, instructions: list[str], probs: np.ndarray,
+ state_memory_bank: StateMemoryBank
+ ) -> tuple[str, np.ndarray, int]:
+ prompt = self.build_prompt(query, current_step_index, instructions, probs, state_memory_bank)
+
+ # get embedding from llm
+ response_emb = await llm_embedding(Config.base_config["embedding_model"], prompt)
+
+ affordance_value_array = self.rl_agent_model.get_action(obs=response_emb)
+
+ selected_idx = np.argmax(affordance_value_array)
+
+ # random exploration in the beginning.
+ if self.training:
+ # if it is in the beginning phase, do random exploration!
+ if self.obs_no <= self.learn_starts or np.random.random() <= self.dqn_explore_threshold:
+ selected_idx = random.choice(range(len(affordance_value_array)))
+
+ if selected_idx != len(instructions):
+ selected_instruction = instructions[selected_idx]
+ else:
+ selected_instruction = "REJECT"
+ return selected_instruction, response_emb, selected_idx
+
+ def build_prompt(self, query: str, current_step_index: int, instructions: list[str], probs: np.ndarray,
+ state_memory_bank: StateMemoryBank
+ ):
+ """Getting prompt based on template"""
+ # prompt-for-each-query
+ prompt = self.embedding_prompt_base.replace('[INSERT_QUERY_HERE]', query) # query insert
+ prompt = prompt.replace('[INSERT_CURRENT_STEP_NO]', str(current_step_index)) # step number insert
+
+ # prompt-for-query-type-about-the-dataset
+ prompt = prompt.replace('[INSERT_QUERY_TYPE_HERE]', self.task_description_for_instruction) # query type
+ prompt = prompt.replace('[EXAMPLE_HERE]', self.instruction_example) # query type demo/ exps
+
+ # previous instruction
+ prompt = prompt.replace('[NEED_TO_PROVIDE_PREVIOUS_INSTRUCTION]',
+ state_memory_bank.instructions_prompt) # previous code insert
+
+ # previous executed code
+ prompt = prompt.replace('[MORE_CODE_WAITING]', state_memory_bank.codes_prompt) # previous code insert
+ prompt = prompt.replace('[CURRENTLY_RESULT_WAITING]',
+ state_memory_bank.feedbacks_prompt) # result description insert
+
+ # variable details
+ prompt = prompt.replace('[VARIABLE_AND_DETAILS]', state_memory_bank.variables_prompt)
+
+ # current instructions/probs
+ prompt = prompt.replace('[CURRENT_OPTION]', str(instructions)) # instruction options
+ prompt = prompt.replace('[CURRENT_OPTION_PROBABILITY]', str(probs)) # probs of instruction options
+
+ return prompt
diff --git a/hydra_vl4ai/agent/hydra.py b/hydra_vl4ai/agent/hydra.py
index 2f92f0e..83377ce 100644
--- a/hydra_vl4ai/agent/hydra.py
+++ b/hydra_vl4ai/agent/hydra.py
@@ -1,15 +1,17 @@
import tensorneko_util as N
+import torchvision.ops.boxes as bops
import websockets
from websockets import WebSocketClientProtocol
+from .controller import ControllerLLM, ControllerDQN
+from .planner import Planner
from .reasoner import Reasoner
from .smb import StateMemoryBank
-from .controller import ControllerLLM
-from .planner import Planner
from .summarizer import Summarizer
+from ..evaluation.vqa_eval import GQAeval
+from ..util.config import Config
from ..util.console import logger
from ..util.misc import get_hydra_root_folder
-from ..util.config import Config
class Hydra:
@@ -125,7 +127,7 @@ async def _call_vqa(self, image: bytes, query: str, state_memory_bank: StateMemo
with N.util.Timer(verbose=False) as timer:
# initial perception
result, code = await self.reasoner.initial_run(image, query, websocket)
- t = timer.time(timer_msg := f"Reasoner in Step 0")
+ t = timer.time(timer_msg := "Reasoner in Step 0")
logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
if result.type == "error":
return None
@@ -173,6 +175,358 @@ async def _call_vqa(self, image: bytes, query: str, state_memory_bank: StateMemo
case "final":
# ------------------ Summarizer ------------------
final_result = await self.summarizer.final_guess(query, guesses)
- t = timer.time(timer_msg := f"Summarizer in Step Final")
+ t = timer.time(timer_msg := "Summarizer in Step Final")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+ return final_result
+
+
+class HydraWithRL(Hydra):
+
+ def __init__(self, training: bool = False):
+ super().__init__()
+ # load config
+ self.dataset = Config.base_config["dataset"]
+ prompt_type = Config.base_config["prompt"]
+ self.max_iterations = Config.base_config["max_iterations"]
+ self._debug = Config.base_config["debug"]
+ self.task = Config.base_config["task"]
+ num_actions = Config.base_config["num_actions"]
+ reasoner_max_retry = Config.base_config["reasoner_max_retry"]
+
+ # load prompts
+ prompt_path = get_hydra_root_folder() / "agent" / "prompt" / prompt_type
+ if not prompt_path.exists():
+ raise NotImplementedError(
+ f"Prompt for {prompt_type} on {self.dataset} is not implemented in {prompt_path}.")
+
+ instruction_prompt_base = N.io.read.text(str(prompt_path / "instruction.prompt"))
+ task_description_for_instruction = N.io.read.text(str(prompt_path / "task_description_for_instruction.prompt"))
+ instruction_examples = N.io.read.text(str(prompt_path / "instruction_example.prompt"))
+
+ code_prompt_base = N.io.read.text(str(prompt_path / "code.prompt"))
+ task_description_for_code = N.io.read.text(str(prompt_path / "task_description_for_code.prompt"))
+ code_example = N.io.read.text(str(prompt_path / "code_example.prompt"))
+
+ self.planner = Planner(
+ instruction_prompt_base,
+ task_description_for_instruction,
+ instruction_examples,
+ num_actions,
+ Config.base_config["planner_max_retry"])
+
+ embedding_prompt_base = N.io.read.text(str(prompt_path / "embedding.prompt"))
+ self.controller = ControllerDQN(
+ embedding_prompt_base,
+ task_description_for_instruction,
+ instruction_examples,
+ training
+ )
+
+ self.reasoner = Reasoner(
+ code_prompt_base,
+ task_description_for_code,
+ code_example,
+ num_actions,
+ reasoner_max_retry
+ )
+
+ match self.task:
+ case "grounding":
+ self.summarizer = None
+ case "vqa":
+ summarize_prompt_base = N.io.read.text(str(prompt_path / "summarize.prompt"))
+ guess_answer_prompt_base = N.io.read.text(str(prompt_path / "guess_answer.prompt"))
+ self.summarizer = Summarizer(
+ summarize_prompt_base,
+ guess_answer_prompt_base,
+ task_description_for_instruction
+ )
+
+ self.evaluator = GQAeval()
+
+ async def __call__(self, image: bytes, query: str) -> str:
+ state_memory_bank = StateMemoryBank()
+ async with websockets.connect(f"ws://localhost:{Config.base_config['executor_port']}/ws/") as websocket:
+ match self.task:
+ case "grounding":
+ return await self._call_grounding(image, query, state_memory_bank, websocket)
+
+ case "vqa":
+ return await self._call_vqa(image, query, state_memory_bank, websocket)
+
+ async def _call_grounding(self, image: bytes, query: str, state_memory_bank: StateMemoryBank,
+ websocket: WebSocketClientProtocol
+ ) -> str:
+ with N.util.Timer(verbose=False) as timer:
+ for current_step_index in range(1, self.max_iterations + 1):
+ # ----------------- Planner -----------------
+ instructions, probs = await self.planner(query, current_step_index, state_memory_bank)
+ t = timer.time(timer_msg := f"Planner in Step {current_step_index}")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+
+ if not instructions:
+ break
+
+ # ----------------- Controller -----------------
+ instruction, response_emb, selected_idx = await self.controller(query, current_step_index, instructions,
+ probs, state_memory_bank) # TODO:MODIFY
+ t = timer.time(timer_msg := f"Controller in Step {current_step_index}")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+
+ if instruction == "REJECT": # TODO:MODIFY
+ continue
+ # ------------------ Reasoner ------------------
+ result, code = await self.reasoner(image, query, instruction, current_step_index, state_memory_bank,
+ websocket)
+ t = timer.time(timer_msg := f"Reasoner in Step {current_step_index}")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+
+ if result.type != "error":
+ assert type(code) is str
+ state_memory_bank.extend_memory(
+ result.feedbacks, [code], [instruction], result.variables, result.variable_names
+ )
+
+ match result.type:
+ case "error":
+ return None
+ case "final":
+ return result.final_result
+ case "continue":
+ continue
+
+ async def _call_vqa(self, image: bytes, query: str, state_memory_bank: StateMemoryBank,
+ websocket: WebSocketClientProtocol
+ ) -> str:
+ with N.util.Timer(verbose=False) as timer:
+ # initial perception
+ result, code = await self.reasoner.initial_run(image, query, websocket)
+ t = timer.time(timer_msg := "Reasoner in Step 0")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+ if result.type == "error":
+ return None
+ state_memory_bank.extend_memory(
+ result.feedbacks, [], [], result.variables, result.variable_names
+ )
+
+ guesses = []
+
+ for current_step_index in range(1, self.max_iterations + 1):
+ # ----------------- Planner -----------------
+ instructions, probs = await self.planner(query, current_step_index, state_memory_bank)
+ t = timer.time(timer_msg := f"Planner in Step {current_step_index}")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+
+ if not instructions:
+ break
+
+ # ----------------- Controller -----------------
+ instruction, response_emb, selected_idx = await self.controller(query, current_step_index, instructions,
+ probs, state_memory_bank) # TODO:MODIFY
+ t = timer.time(timer_msg := f"Controller in Step {current_step_index}")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+
+ if instruction == "REJECT": # TODO:MODIFY
+ continue
+ # ------------------ Reasoner ------------------
+ result, code = await self.reasoner(image, query, instruction, current_step_index, state_memory_bank,
+ websocket)
+ t = timer.time(timer_msg := f"Reasoner in Step {current_step_index}")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+
+ if result.type != "error":
+ assert type(code) is str
+ state_memory_bank.extend_memory(
+ result.feedbacks, [code], [instruction], result.variables, result.variable_names
+ )
+ # ------------------ Summarizer ------------------
+ guesses.append(await self.summarizer(query, state_memory_bank))
+ t = timer.time(timer_msg := f"Summarizer in Step {current_step_index}")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+
+ match result.type:
+ case "continue":
+ continue
+ case "error":
+ return None
+ case "final":
+ # ------------------ Summarizer ------------------
+ final_result = await self.summarizer.final_guess(query, guesses)
+ t = timer.time(timer_msg := "Summarizer in Step Final")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+ return final_result
+
+ async def train_step(self, image: bytes, query: str, ground_true=None) -> str:
+ state_memory_bank = StateMemoryBank()
+ async with websockets.connect(f"ws://localhost:{Config.base_config['executor_port']}/ws/") as websocket:
+ match self.task:
+ case "grounding":
+ return await self._train_step_grounding(image, query, state_memory_bank, websocket, ground_true)
+
+ case "vqa":
+ return await self._train_step_vqa(image, query, state_memory_bank, websocket, ground_true)
+
+ async def _train_step_grounding(self, image: bytes, query: str, state_memory_bank: StateMemoryBank,
+ websocket: WebSocketClientProtocol, ground_true=None
+ ) -> str:
+ with N.util.Timer(verbose=False) as timer:
+ sub_reward = 10
+ pre_obs_emb = None
+ for current_step_index in range(1, self.max_iterations + 1):
+ # ----------------- Planner -----------------
+ instructions, probs = await self.planner(query, current_step_index, state_memory_bank)
+ t = timer.time(timer_msg := f"Planner in Step {current_step_index}")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+
+ if not instructions:
+ break
+
+ # ----------------- Controller -----------------
+ instruction, response_emb, selected_idx = await self.controller(query, current_step_index, instructions,
+ probs, state_memory_bank)
+ t = timer.time(timer_msg := f"Controller in Step {current_step_index}")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+
+ if instruction == "REJECT":
+ continue
+ # ------------------ Reasoner ------------------
+ result, code = await self.reasoner(image, query, instruction, current_step_index, state_memory_bank,
+ websocket)
+ t = timer.time(timer_msg := f"Reasoner in Step {current_step_index}")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+
+ if result.type != "error":
+ assert type(code) is str
+ state_memory_bank.extend_memory(
+ result.feedbacks, [code], [instruction], result.variables, result.variable_names
+ )
+
+ else:
+ sub_reward -= 100
+
+ if result.type == "final":
+ sub_reward += float(bops.box_iou(result.final_result, ground_true)[0][0]) * 100
+
+ # Calculate reward
+ sub_reward -= current_step_index
+ self.controller.save_model_obs_num += 1
+
+ # update buffer and model
+ if current_step_index > 1:
+ # store transition
+ self.controller.replay_buffer.push(pre_obs_emb, selected_idx, sub_reward, response_emb, done=False)
+ self.controller.reward_window.append(sub_reward)
+ self.controller.obs_no += 1
+ # update model each step when buffer size bigger than batch_size.
+ if len(self.controller.replay_buffer) > self.controller.batch_size:
+ for i in range(self.controller.update_times):
+ self.controller.rl_agent_model.update(replay_buffer=self.controller.replay_buffer,
+ batch_size=self.controller.batch_size)
+
+ pre_obs_emb = response_emb # reserve current emb as previous emb
+ self.controller.dqn_explore_threshold = \
+ self.controller.dqn_explore_epsilon - self.controller.dqn_explore_epsilon_decay_rate \
+ * (self.controller.obs_no / self.controller.dqn_explore_epsilon_decay_interval)
+
+ match result.type:
+ case "error":
+ return None
+ case "final":
+ return result.final_result
+ case "continue":
+ continue
+
+ async def _train_step_vqa(self, image: bytes, query: str, state_memory_bank: StateMemoryBank,
+ websocket: WebSocketClientProtocol, ground_true=None
+ ) -> str:
+ with (N.util.Timer(verbose=False) as timer):
+ # initial perception
+ result, code = await self.reasoner.initial_run(image, query, websocket)
+ t = timer.time(timer_msg := "Reasoner in Step 0")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+ if result.type == "error":
+ return None
+ state_memory_bank.extend_memory(
+ result.feedbacks, [], [], result.variables, result.variable_names
+ )
+
+ guesses = []
+
+ # for training
+ sub_reward = 0
+ pre_obs_emb = None
+
+ for current_step_index in range(1, self.max_iterations + 1):
+ # ----------------- Planner -----------------
+ instructions, probs = await self.planner(query, current_step_index, state_memory_bank)
+ t = timer.time(timer_msg := f"Planner in Step {current_step_index}")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+
+ if not instructions:
+ break
+
+ # ----------------- Controller -----------------
+ instruction, response_emb, selected_idx = await self.controller(query, current_step_index, instructions,
+ probs, state_memory_bank)
+ t = timer.time(timer_msg := f"Controller in Step {current_step_index}")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+
+ if instruction == "REJECT":
+ continue
+ # ------------------ Reasoner ------------------
+ result, code = await self.reasoner(image, query, instruction, current_step_index, state_memory_bank,
+ websocket)
+ t = timer.time(timer_msg := f"Reasoner in Step {current_step_index}")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+
+ if result.type != "error":
+ assert type(code) is str
+ state_memory_bank.extend_memory(
+ result.feedbacks, [code], [instruction], result.variables, result.variable_names
+ )
+ # ------------------ Summarizer ------------------
+ guesses.append(await self.summarizer(query, state_memory_bank))
+ t = timer.time(timer_msg := f"Summarizer in Step {current_step_index}")
+ logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
+
+ else:
+ sub_reward -= 100
+
+ if result.type == "final":
+ if 'okvqa' in self.dataset:
+ sub_reward += self.evaluator.accuracy_one_set([result.final_result], [ground_true])
+ else:
+ sub_reward += self.evaluator.accuracy_one_one([result.final_result], [ground_true])
+
+ # Calculate reward
+ sub_reward -= current_step_index
+ self.controller.save_model_obs_num += 1
+
+ # update buffer and model
+ if current_step_index > 1:
+ # store transition
+ self.controller.replay_buffer.push(pre_obs_emb, selected_idx, sub_reward, response_emb, done=False)
+ self.controller.reward_window.append(sub_reward)
+ self.controller.obs_no += 1
+ # update model each step when buffer size bigger than batch_size.
+ if len(self.controller.replay_buffer) > self.controller.batch_size:
+ for i in range(self.controller.update_times):
+ self.controller.rl_agent_model.update(replay_buffer=self.controller.replay_buffer,
+ batch_size=self.controller.batch_size)
+
+ pre_obs_emb = response_emb
+ self.controller.dqn_explore_threshold = \
+ self.controller.dqn_explore_epsilon - self.controller.dqn_explore_epsilon_decay_rate \
+ * (self.controller.obs_no / self.controller.dqn_explore_epsilon_decay_interval)
+
+ match result.type:
+ case "continue":
+ continue
+ case "error":
+ return None
+ case "final":
+ # ------------------ Summarizer ------------------
+ final_result = await self.summarizer.final_guess(query, guesses)
+ t = timer.time(timer_msg := "Summarizer in Step Final")
logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
return final_result
diff --git a/hydra_vl4ai/agent/llm.py b/hydra_vl4ai/agent/llm.py
index 49dbfbf..2cfdae8 100644
--- a/hydra_vl4ai/agent/llm.py
+++ b/hydra_vl4ai/agent/llm.py
@@ -1,12 +1,13 @@
import asyncio
-from functools import wraps
import os
+import time
+from functools import wraps
+
import httpx
import numpy as np
-from openai import AsyncOpenAI
-from ollama import AsyncClient, Client
import openai
-import time
+from ollama import AsyncClient, Client
+from openai import AsyncOpenAI
from ..util.config import Config
from ..util.console import logger
@@ -40,21 +41,22 @@ async def wrapper(*args, **kwargs):
for _ in range(max_trial):
try:
return await func(*args, **kwargs)
- except openai.APITimeoutError as e:
+ except openai.APITimeoutError:
pass
- except openai.APIConnectionError as e:
+ except openai.APIConnectionError:
pass
- except openai.RateLimitError as e:
+ except openai.RateLimitError:
time.sleep(1)
pass
except openai.BadRequestError as e:
# maybe exceed the length, should raise directly
- raise
+ raise e
except openai.APIStatusError as e:
# server side problem, should raise directly
- raise
+ raise e
except Exception as e:
- raise
+ raise e
+
return wrapper
@@ -66,14 +68,15 @@ async def wrapper(*args, **kwargs):
for _ in range(max_trial):
try:
return await func(*args, **kwargs)
- except httpx.ConnectError as e:
+ except httpx.ConnectError:
pass
- except httpx.ConnectTimeout as e:
+ except httpx.ConnectTimeout:
pass
- except httpx.TimeoutException as e:
+ except httpx.TimeoutException:
pass
except Exception as e:
- raise
+ raise e
+
return wrapper
@@ -86,10 +89,10 @@ async def chatgpt(model_name: str, prompt: str):
@handle_openai_exceptions
-async def gpt3_embedding(prompt: str):
+async def gpt3_embedding(model_name: str, prompt: str):
async with _semaphore:
- response = (await openai_client.embeddings.create(input = [prompt],
- model=Config.base_config["embedding_model"])).data[0].embedding
+ response = (await openai_client.embeddings.create(input=[prompt],
+ model=model_name)).data[0].embedding
response = np.array(response)
return response
@@ -97,7 +100,7 @@ async def gpt3_embedding(prompt: str):
@handle_ollama_exceptions
async def ollama(model_name: str, prompt: str):
async with _semaphore:
- response = await ollama_client.chat(model=model_name,
+ response = await ollama_client.chat(model=model_name,
messages=[{"role": "user", "content": prompt}], stream=False, )
return response["message"]["content"]
@@ -109,3 +112,10 @@ async def llm(model_name: str, prompt: str):
return await ollama(model_name, prompt)
else:
raise ValueError(f"Model {model_name} is not supported.")
+
+
+async def llm_embedding(model_name: str, prompt: str):
+ if model_name in ("text-embedding-3-small", "text-embedding-3-large"):
+ return await gpt3_embedding(model_name, prompt)
+ else:
+ raise ValueError(f"Model {model_name} is not supported.")
diff --git a/hydra_vl4ai/agent/planner.py b/hydra_vl4ai/agent/planner.py
index 3351051..d3feb94 100644
--- a/hydra_vl4ai/agent/planner.py
+++ b/hydra_vl4ai/agent/planner.py
@@ -2,9 +2,9 @@
import numpy as np
-from ..util.config import Config
-from .smb.state_memory_bank import StateMemoryBank
from .llm import llm
+from .smb.state_memory_bank import StateMemoryBank
+from ..util.config import Config
from ..util.console import logger
diff --git a/hydra_vl4ai/agent/reasoner.py b/hydra_vl4ai/agent/reasoner.py
index 48e1c5f..6333844 100644
--- a/hydra_vl4ai/agent/reasoner.py
+++ b/hydra_vl4ai/agent/reasoner.py
@@ -1,13 +1,14 @@
import io
import json
+
import websockets
from PIL import Image
+from .llm import llm
+from .smb.state_memory_bank import StateMemoryBank
from ..util.config import Config
from ..util.console import logger
from ..util.message import ExecutionRequest, ExecutionResult
-from .llm import llm
-from .smb.state_memory_bank import StateMemoryBank
class Reasoner:
diff --git a/hydra_vl4ai/agent/rl_dqn.py b/hydra_vl4ai/agent/rl_dqn.py
new file mode 100644
index 0000000..58b2607
--- /dev/null
+++ b/hydra_vl4ai/agent/rl_dqn.py
@@ -0,0 +1,156 @@
+import random
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.nn.utils import clip_grad_norm_
+
+
+# MLP head
+class MLPHead(nn.Module):
+
+ def __init__(
+ self,
+ input_dim,
+ mlp_hidden_dim,
+ output_dim,
+ layer_num
+ ) -> None:
+ super().__init__()
+
+ mlp_head = []
+ for idx in range(layer_num):
+ if idx == 0:
+ i_dim = input_dim
+ else:
+ i_dim = mlp_hidden_dim
+
+ if idx == layer_num - 1:
+ o_dim = output_dim
+ else:
+ o_dim = mlp_hidden_dim
+
+ mlp_head.append(nn.Linear(i_dim, o_dim))
+
+ # if idx != layer_num -1:
+ mlp_head.append(nn.Sigmoid())
+
+ self.mlp_head = nn.Sequential(*mlp_head)
+
+ def forward(self, x):
+ return self.mlp_head(x)
+
+
+class ReplayBuffer:
+
+ def __init__(self, capacity) -> None:
+ self.capacity = capacity
+ self.buffer = []
+ self.pos = 0
+
+ def push(self, state, action, reward, next_state, done):
+ if len(self.buffer) < self.capacity:
+ self.buffer.append(None)
+ self.buffer[self.pos] = (state, action, reward, next_state, done)
+ self.pos = int((self.pos + 1) % self.capacity) # as a ring buffer
+
+ def sample(self, batch_size):
+ batch = random.sample(self.buffer, batch_size)
+ state, action, reward, next_state, done = map(np.stack, zip(*batch)) # stack for each element
+
+ return state, action, reward, next_state, done
+
+ def __len__(self):
+ return len(self.buffer)
+
+
+class DQN_EmbeddingViaLLM:
+
+ def __init__(
+ self,
+ device,
+ llm_embedding_dim_concat,
+ mlp_hidden_dim,
+ action_dim,
+ critic_layer_num,
+ critic_lr,
+ ) -> None:
+ self.device = device
+
+ self.critic_head = MLPHead(input_dim=llm_embedding_dim_concat, mlp_hidden_dim=mlp_hidden_dim,
+ output_dim=action_dim, layer_num=critic_layer_num).to(device)
+ self.tar_critic_head = MLPHead(input_dim=llm_embedding_dim_concat, mlp_hidden_dim=mlp_hidden_dim,
+ output_dim=action_dim, layer_num=critic_layer_num).to(device)
+ for tar_param, param in zip(self.tar_critic_head.parameters(), self.critic_head.parameters()):
+ tar_param.data.copy_(param.data)
+
+ self.critic_optim = optim.Adam(self.critic_head.parameters(), lr=critic_lr)
+
+ def get_action(self, obs, batch_input=False):
+ if not batch_input:
+ obs = torch.FloatTensor(obs).to(self.device).unsqueeze(0)
+ else:
+ obs = torch.FloatTensor(obs).to(self.device)
+
+ qval = self.critic_head.forward(x=obs)
+
+ return qval.detach().cpu().numpy().flatten()
+ # if not batch_input: return qval.argmax().detach().cpu().numpy().flatten()
+ # else: return qval.argmax(dim=-1).detach().cpu().numpy().flatten()
+
+ def update(self, replay_buffer, batch_size, reward_scale=10., gamma=0.99, soft_tau=1e-2, is_clip_gradient=True,
+ clip_gradient_val=40
+ ):
+ obs, action, reward, next_obs, done = replay_buffer.sample(batch_size)
+
+ obs = torch.FloatTensor(obs).to(self.device) # obs.size = (batch_size, 1+seq_dim)
+ next_obs = torch.FloatTensor(next_obs).to(self.device)
+ action = torch.FloatTensor(action).to(self.device)
+ reward = torch.FloatTensor(reward).unsqueeze(1).to(
+ self.device) # reward is single value, unsqueeze() to add one dim to be [reward] at the sample dim;
+ reward = reward_scale * (reward - reward.mean(dim=0)) / (reward.std(
+ dim=0) + 1e-6) # normalize with batch mean and std; plus a small number to prevent numerical problem
+ done = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(self.device)
+ # print(f'self.critic_head.forward(x=obs) {self.critic_head.forward(x=obs).shape}')
+ # print(f'action.long() {action.long().shape}')
+ # qval = self.critic_head.forward(x=obs).gather(1, action.long())
+ val_value_ = self.critic_head.forward(x=obs)
+ qval = val_value_.gather(1, action.long().reshape(val_value_.size(0), 1))
+
+ with torch.no_grad():
+ max_next_qval = self.tar_critic_head.forward(x=next_obs).max(dim=-1)[0].unsqueeze(-1)
+ tar_qval = reward + gamma * (1 - done) * max_next_qval
+
+ loss_func = nn.MSELoss()
+ qloss = loss_func(qval, tar_qval.detach())
+
+ self.critic_optim.zero_grad()
+ qloss.backward()
+ if is_clip_gradient: clip_grad_norm_(self.critic_head.parameters(), clip_gradient_val)
+ self.critic_optim.step()
+
+ for tar_param, param in zip(self.tar_critic_head.parameters(), self.critic_head.parameters()):
+ tar_param.data.copy_(param.data * soft_tau + tar_param.data * (1 - soft_tau))
+
+ return qloss.detach().cpu().item()
+
+ def save_model(self, path: str):
+ torch.save(self.critic_head.state_dict(), path)
+
+ def load_model(self, path: str):
+ self.critic_head.load_state_dict(torch.load(path))
+
+ for tar_param, param in zip(self.tar_critic_head.parameters(), self.critic_head.parameters()):
+ tar_param.data.copy_(param.data)
+
+ self.critic_head.eval()
+ self.tar_critic_head.eval()
+
+ def eval_mode(self):
+ self.critic_head.eval()
+ self.tar_critic_head.eval()
+
+ def train_mode(self):
+ self.critic_head.train()
+ self.tar_critic_head.train()
diff --git a/hydra_vl4ai/agent/smb/state_memory_bank.py b/hydra_vl4ai/agent/smb/state_memory_bank.py
index 103c4f6..d0e3661 100644
--- a/hydra_vl4ai/agent/smb/state_memory_bank.py
+++ b/hydra_vl4ai/agent/smb/state_memory_bank.py
@@ -25,24 +25,24 @@ def reset(self):
@property
def instructions_prompt(self):
return "\n" + "\n".join(self.instructions)
-
+
@property
def feedbacks_prompt(self):
return "\n" + "\n".join(self.feedbacks)
-
+
@property
def codes_prompt(self):
return "\n" + "\n".join(self.codes)
-
+
@property
def variables_prompt(self):
return "\n" + "\n".join(self.variables)
- def extend_memory(self,
- other_feedbacks: list[str],
- other_codes: list[str],
- other_instructions: list[str],
- other_variables: list[str],
+ def extend_memory(self,
+ other_feedbacks: list[str],
+ other_codes: list[str],
+ other_instructions: list[str],
+ other_variables: list[str],
other_variable_names: list[str]
):
self.feedbacks.extend(other_feedbacks)
@@ -125,7 +125,8 @@ def get_sorted_patches_left_to_right_message_save(self, name):
f"\nThe patches list has been sorted from left to right (horizontal). Now, the first patch in the list corresponds to the leftest position, while the last one corresponds to the rightest position")
def get_sorted_patches_bottom_to_top_message_save(self, name):
- self.feedbacks.append(f"\nThe patches list has been sorted from bottom to top (vertical). Now, the first patch in the list corresponds to the bottom/low/below position, while the last one corresponds to the top/up/above position.")
+ self.feedbacks.append(
+ f"\nThe patches list has been sorted from bottom to top (vertical). Now, the first patch in the list corresponds to the bottom/low/below position, while the last one corresponds to the top/up/above position.")
def get_sorted_patches_front_to_back_message_save(self, name):
self.feedbacks.append(
diff --git a/hydra_vl4ai/agent/summarizer.py b/hydra_vl4ai/agent/summarizer.py
index cbbbe3e..cc2783e 100644
--- a/hydra_vl4ai/agent/summarizer.py
+++ b/hydra_vl4ai/agent/summarizer.py
@@ -1,5 +1,5 @@
-from .smb.state_memory_bank import StateMemoryBank
from .llm import llm
+from .smb.state_memory_bank import StateMemoryBank
from ..util.config import Config
diff --git a/hydra_vl4ai/agent/webui.py b/hydra_vl4ai/agent/webui.py
index 80b7143..176e6f2 100644
--- a/hydra_vl4ai/agent/webui.py
+++ b/hydra_vl4ai/agent/webui.py
@@ -1,16 +1,16 @@
+import gradio as gr
import tensorneko_util as N
import websockets
from websockets import WebSocketClientProtocol
-from .smb import StateMemoryBank
-from ..util.console import logger
from .hydra import HydraNoRL
+from .smb import StateMemoryBank
from ..util.config import Config
-
-import gradio as gr
+from ..util.console import logger
class HydraNoRLWeb(HydraNoRL):
+
def __init__(self):
super().__init__()
@@ -34,7 +34,8 @@ async def _gradio_call(self, image_path: str, query: str):
messages = [gr.ChatMessage('user', [image_path, query])]
yield messages, self._format_state_memory()
- async with websockets.connect(f"ws://localhost:{Config.base_config['executor_port']}/ws/") as ws: # type: ignore
+ async with websockets.connect(
+ f"ws://localhost:{Config.base_config['executor_port']}/ws/") as ws: # type: ignore
async for chunk in self._call_vqa(image_bytes, query, self.bank, ws):
if chunk[0] == "start":
continue
@@ -82,7 +83,7 @@ async def _call_vqa(
result, code = await self.reasoner.initial_run(image, query, websocket)
t = timer.time(timer_msg := "Reasoner in Step 0")
logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
-
+
if result.type == "error":
yield "panic", "Reasoning", result
return
@@ -111,7 +112,8 @@ async def _call_vqa(
t = timer.time(timer_msg := f"Controller in Step {current_step_index}")
logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
- yield "stage-completed", "Planning", current_step_index, sorted(zip(instructions, probs), key=lambda x: x[1], reverse=True), instruction
+ yield "stage-completed", "Planning", current_step_index, sorted(zip(instructions, probs),
+ key=lambda x: x[1], reverse=True), instruction
# ------------------ Reasoner ------------------
yield "start", "Reasoning"
@@ -124,7 +126,7 @@ async def _call_vqa(
if result.type == "error":
yield "error", "Reasoning", result
continue
-
+
assert type(code) is str
state_memory_bank.extend_memory(
result.feedbacks, [code], [instruction], result.variables, result.variable_names
@@ -148,6 +150,6 @@ async def _call_vqa(
final_result = await self.summarizer.final_guess(query, guesses)
t = timer.time(timer_msg := f"Summarizer in Step Final")
logger.debug(f"[Timer] {timer_msg}: {t:.4f} sec")
-
+
yield "final", final_result
return
diff --git a/evaluation/grounding_eval.py b/hydra_vl4ai/evaluation/grounding_eval.py
similarity index 100%
rename from evaluation/grounding_eval.py
rename to hydra_vl4ai/evaluation/grounding_eval.py
index fe58e0b..54500f4 100644
--- a/evaluation/grounding_eval.py
+++ b/hydra_vl4ai/evaluation/grounding_eval.py
@@ -5,6 +5,7 @@
from numpy import ndarray
from torch import Tensor
+
def iou_2d(proposal: Union[Tensor, ndarray], target: Union[Tensor, ndarray]) -> Tensor:
"""
Calculate 2D IOU for M proposals with N targets.
@@ -55,7 +56,6 @@ def process_grounding_result(x: str):
return eval(x)["final_answer"][0][:4]
-
def batch_iou_2d(result):
# we follow ViperGPT to filter out the None results
remove_index = []
diff --git a/evaluation/vqa_eval.py b/hydra_vl4ai/evaluation/vqa_eval.py
similarity index 100%
rename from evaluation/vqa_eval.py
rename to hydra_vl4ai/evaluation/vqa_eval.py
diff --git a/hydra_vl4ai/execution/image_patch.py b/hydra_vl4ai/execution/image_patch.py
index 8857bc8..46aa162 100644
--- a/hydra_vl4ai/execution/image_patch.py
+++ b/hydra_vl4ai/execution/image_patch.py
@@ -13,7 +13,7 @@
import tensorneko_util as N
from .toolbox import forward
-from ..agent.llm import chatgpt
+from ..agent.llm import llm
from ..agent.smb import StateMemoryBank
from ..util.misc import get_hydra_root_folder, load_json
from ..util.config import Config
@@ -558,10 +558,10 @@ def llm_query(query, context=None, long_answer=True, state_memory_bank=None):
prompt_ += f'Could you help me answer the question: {query}.'
if not long_answer:
- prompt_ += f'Please provide only a few-word answer. Be very concise, no ranges, no doubt.'
+ prompt_ += 'Please provide only a few-word answer. Be very concise, no ranges, no doubt.'
try:
- return_answer = asyncio.run(chatgpt(prompt_)) or ""
- except:
+ return_answer = asyncio.run(llm(Config.base_config["llm_model"], prompt_)) or ""
+ except Exception:
return_answer = 'not answer from gpt'
# get global description.
diff --git a/hydra_vl4ai/tool/llava.py b/hydra_vl4ai/tool/llava.py
index ff82581..c846e6a 100644
--- a/hydra_vl4ai/tool/llava.py
+++ b/hydra_vl4ai/tool/llava.py
@@ -1,11 +1,11 @@
-import os.path
+import re
import torch
+from huggingface_hub import snapshot_download
from llava.conversation import conv_templates
+from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
-from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
-from huggingface_hub import snapshot_download
from ._base import BaseModel, module_registry
from ..util.misc import get_root_folder
diff --git a/hydra_vl4ai/util/config.py b/hydra_vl4ai/util/config.py
index 7f59f4e..ff366e0 100644
--- a/hydra_vl4ai/util/config.py
+++ b/hydra_vl4ai/util/config.py
@@ -2,13 +2,16 @@
from tensorneko_util.util import Singleton
import tensorneko_util as N
+
@Singleton
class Config:
def __init__(self):
self.model_config_path: str | None = None
self._base_config_path: str | None = None
+ self._dqn_config_path: str | None = None
self.base_config: dict[str, Any] = dict()
+ self.dqn_config: dict[str, Any] = dict()
@property
def base_config_path(self):
@@ -19,3 +22,13 @@ def base_config_path(self, value):
if value is not None:
self._base_config_path = value
self.base_config = N.read(value)
+
+ @property
+ def dqn_config_path(self):
+ return self._dqn_config_path
+
+ @dqn_config_path.setter
+ def dqn_config_path(self, value):
+ if value is not None:
+ self._dqn_config_path = value
+ self.dqn_config = N.read(value)
diff --git a/inference_bash.sh b/inference_bash.sh
new file mode 100755
index 0000000..197eca2
--- /dev/null
+++ b/inference_bash.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+# preprocessing
+python main.py \
+ --data_root '/home/mai/fke/fkee/coco2014' \
+ --base_config './config/okvqa.yaml' \
+ --model_config './config/model_config_1gpu.yaml' \
+ --dqn_config './config/dqn_testing.yaml'
+echo "preprocessing done"
diff --git a/main.py b/main.py
index f0b3967..634c6fa 100644
--- a/main.py
+++ b/main.py
@@ -1,35 +1,44 @@
import asyncio
import json
import os
-import requests
-import time
-import tensorneko_util as N
from pathlib import Path
+
+import tensorneko_util as N
from dotenv import load_dotenv
+
load_dotenv()
import argparse
+
parser = argparse.ArgumentParser()
parser.add_argument("--data_root", type=str, required=True)
parser.add_argument("--base_config", type=str, required=True)
parser.add_argument("--model_config", type=str, required=True)
parser.add_argument("--result_folder", type=str, default="./result")
+parser.add_argument("--dqn_config", type=str)
args = parser.parse_args()
from hydra_vl4ai.util.config import Config
+
Config.base_config_path = args.base_config
+if args.dqn_config is not None:
+ Config.dqn_config_path = args.dqn_config
Config.model_config_path = args.model_config
-from hydra_vl4ai.agent.hydra import HydraNoRL
+from hydra_vl4ai.agent.hydra import HydraNoRL, HydraWithRL
from hydra_vl4ai.util.console import logger, console
from hydra_vl4ai.util.misc import wait_until_loaded
import exp_datasets
async def main():
- with console.status("[bold green]Connect to HYDRA executor...") as status:
+ with console.status("[bold green]Connect to HYDRA executor..."):
wait_until_loaded(f"http://localhost:{Config.base_config['executor_port']}")
- hydra = HydraNoRL()
+
+ if args.dqn_config is None:
+ hydra = HydraNoRL()
+ else:
+ hydra = HydraWithRL()
match Config.base_config["dataset"]:
case "gqa":
@@ -53,22 +62,22 @@ async def main():
dataset = exp_datasets.Refcoco(args.data_root)
case _:
raise ValueError("Invalid dataset")
-
+
# output path
Path(args.result_folder).mkdir(parents=True, exist_ok=True)
save_path = Path(args.result_folder) / f"result_{Config.base_config['dataset']}.jsonl"
-
+
# resume if the file exists
completed = []
if os.path.exists(save_path):
prev_results = N.io.read.json(str(save_path))
completed = [result["datum_id"] for result in prev_results]
-
+
for i, (image_path, datum_id, query, ground_truth) in enumerate(dataset):
if datum_id in completed:
- logger.info(f"Skipping {i+1}/{len(dataset)}")
+ logger.info(f"Skipping {i + 1}/{len(dataset)}")
continue
- logger.info(f"Processing {i+1}/{len(dataset)}")
+ logger.info(f"Processing {i + 1}/{len(dataset)}")
with open(image_path, "rb") as f:
image_buffer = f.read()
result = await hydra(image_buffer, query)
@@ -83,5 +92,6 @@ async def main():
}) + "\n")
f.flush()
+
if __name__ == "__main__":
asyncio.run(main=main())
diff --git a/media/gradio_demo.mov b/media/gradio_demo.mov
new file mode 100644
index 0000000..702f4a4
Binary files /dev/null and b/media/gradio_demo.mov differ
diff --git a/module_repos/GLIP/pyproject.toml b/module_repos/GLIP/pyproject.toml
new file mode 100644
index 0000000..8ca2555
--- /dev/null
+++ b/module_repos/GLIP/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools>=61"]
+build-backend = "setuptools.build_meta"
diff --git a/pixi.toml b/pixi.toml
index 595c368..cf5dd8e 100644
--- a/pixi.toml
+++ b/pixi.toml
@@ -3,36 +3,36 @@ authors = ["ControlNet "]
channels = ["pytorch", "nvidia/label/cuda-11.8.0", "anaconda", "conda-forge"]
description = "Official implementation for HYDRA."
name = "hydra_vl4ai"
-platforms = ["linux-64"]
+platforms = ["linux-64", "win-64"]
version = "0.0.0"
channel-priority = "disabled"
[tasks]
-install_glip = { cmd = "python setup.py clean --all build develop --user", cwd = "module_repos/GLIP", env = { CUDA_HOME = "$CONDA_PREFIX", AM_I_DOCKER = "False", BUILD_WITH_CUDA = "True" } }
-install_sam = { cmd = "uv pip install -e .", cwd = "module_repos/Grounded-Segment-Anything/segment_anything", env = { CUDA_HOME = "$CONDA_PREFIX", AM_I_DOCKER = "False", BUILD_WITH_CUDA = "True" } }
-install_groundingdino = { cmd = "python setup.py clean --all build develop --user", cwd = "module_repos/Grounded-Segment-Anything/GroundingDINO", env = { CUDA_HOME = "$CONDA_PREFIX", AM_I_DOCKER = "False", BUILD_WITH_CUDA = "True" } }
-install_llava = { cmd = "uv pip install -e .", cwd = "module_repos/LLaVA", env = { CUDA_HOME = "$CONDA_PREFIX", AM_I_DOCKER = "False", BUILD_WITH_CUDA = "True" } }
-install = { cmd = "uv pip install -e . && echo Depedencies installation finished!", depends-on = ["install_glip", "install_sam", "install_groundingdino", "install_llava"], env = { CUDA_HOME = "$CONDA_PREFIX", AM_I_DOCKER = "False", BUILD_WITH_CUDA = "True" } }
-executor = "python -m hydra_vl4ai.executor --base_config config/gqa.yaml --model_config config/model_config_1gpu.yaml"
-download_model = "python -m hydra_vl4ai.download_model --base_config config/gqa.yaml --model_config config/model_config_1gpu.yaml"
+executor = "python -m hydra_vl4ai.executor --base_config config/okvqa.yaml --model_config config/model_config_1gpu.yaml"
+download_model = "python -m hydra_vl4ai.download_model --base_config config/okvqa.yaml --model_config config/model_config_1gpu.yaml"
[build-dependencies]
setuptools = "*"
cmake = "*"
ninja = "*"
+[pypi-options]
+no-build-isolation = ["maskrcnn_benchmark", "segment_anything", "groundingdino", "llava", "hydra_vl4ai"]
+
[dependencies]
python = { version = "3.11.*", channel = "anaconda" }
pytorch = { version = "==2.1.2", channel = "pytorch" }
torchvision = { version = "==0.16.2", channel = "pytorch" }
torchaudio = { version = "==2.1.2", channel = "pytorch" }
pytorch-cuda = { version = "11.8.*", channel = "pytorch" }
-cuda = { version = "11.8.0", channel = "nvidia/label/cuda-11.8.0" }
-cuda-libraries-dev = { version = "11.8.0", channel = "nvidia/label/cuda-11.8.0" }
-cuda-version = "11.8"
+cuda = { version = "==11.8.0", channel = "nvidia/label/cuda-11.8.0" }
+cuda-libraries-dev = { version = "==11.8.0", channel = "nvidia/label/cuda-11.8.0" }
+cuda-version = "==11.8"
+markupsafe = ">2.0,<3.0"
numpy = "<2.0"
ipywidgets = ">=8.1.5,<9"
ipykernel = ">=6.29.5,<7"
+pandas = ">=2.2.3,<3"
[pypi-dependencies]
fastapi = "*"
@@ -66,3 +66,9 @@ gdown = "==5.2.0"
rich = "==13.7.1"
ollama = "~=0.3.0"
protobuf = "~=3.19.0"
+gradio = ">=5.0,<5.12"
+maskrcnn_benchmark = { path = "module_repos/GLIP" }
+segment_anything = { path = "module_repos/Grounded-Segment-Anything/segment_anything" }
+groundingdino = { path = "module_repos/Grounded-Segment-Anything/GroundingDINO" }
+llava = { path = "module_repos/LLaVA" }
+hydra_vl4ai = { path = ".", editable = true }
diff --git a/requirements.txt b/requirements.txt
index e601ea9..167e6d2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,4 +28,4 @@ scipy==1.14.1
gdown==5.2.0
rich==13.7.1
ollama==0.3.*
-gradio
+gradio>=5.0,<5.12
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..9466a88
--- /dev/null
+++ b/train.py
@@ -0,0 +1,104 @@
+import asyncio
+import json
+import os
+from pathlib import Path
+
+from dotenv import load_dotenv
+
+load_dotenv()
+
+import numpy as np
+
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--data_root", type=str, required=True)
+parser.add_argument("--base_config", type=str, required=True)
+parser.add_argument("--model_config", type=str, required=True)
+parser.add_argument("--result_folder", type=str, default="./result")
+parser.add_argument("--dqn_config", type=str, required=True)
+args = parser.parse_args()
+
+from hydra_vl4ai.util.config import Config
+
+Config.base_config_path = args.base_config
+Config.dqn_config_path = args.dqn_config
+Config.model_config_path = args.model_config
+
+from hydra_vl4ai.agent.hydra import HydraWithRL
+from hydra_vl4ai.util.console import logger, console
+from hydra_vl4ai.util.misc import wait_until_loaded
+import exp_datasets
+
+
+async def main():
+ with console.status("[bold green]Connect to HYDRA executor...") as status:
+ wait_until_loaded(f"http://localhost:{Config.base_config['executor_port']}")
+ hydra = HydraWithRL(training=True)
+
+ match Config.base_config["dataset"]:
+ case "gqa":
+ dataset = exp_datasets.GQA(
+ args.data_root
+ )
+ case "okvqa":
+ dataset = exp_datasets.OKVQA(
+ args.data_root
+ )
+ case "aokvqa":
+ # TODO: Not tested yet
+ # dataset = exp_datasets.AOKVQA(
+ # f"{args.data_root}/aokvqa",
+ # "val", f"{args.data_root}/coco", version="v1p0"
+ # )
+ raise NotImplementedError("AOKVQA is not implemented yet")
+ case "refcoco":
+ dataset = exp_datasets.Refcoco(args.data_root)
+ case "refcoco+":
+ dataset = exp_datasets.Refcoco(args.data_root)
+ case _:
+ raise ValueError("Invalid dataset")
+
+ # output path
+ Path(args.result_folder).mkdir(parents=True, exist_ok=True)
+ save_path = Path(args.result_folder) / f"result_{Config.base_config['dataset']}.jsonl"
+
+ cum_reward = 0 # TODO:modify
+
+ for epoch_idx_ in range(Config.dqn_config["training_epoch"]):
+ for i, (image_path, datum_id, query, ground_truth) in enumerate(dataset):
+
+ logger.info(f"Processing {i + 1}/{len(dataset)}")
+ with open(image_path, "rb") as f:
+ image_buffer = f.read()
+ result = await hydra.train_step(image_buffer, query, ground_truth)
+ logger.info(f"Query: {query} Answer: {result}")
+
+ with open(save_path, "a") as f:
+ f.write(json.dumps({
+ "datum_id": datum_id,
+ "query": query,
+ "ground_truth": ground_truth,
+ "result": result
+ }) + "\n")
+ f.flush()
+
+ # training log info
+ if hydra.controller.obs_no % hydra.controller.train_log_interval == 0:
+ mean_reward = np.mean(hydra.controller.reward_window)
+ cum_reward = 0.99 * cum_reward + 0.01 * mean_reward
+ logger.info('---Current step:{}-----Mean Reward:{:.2f}----Cumulative Reward:{:.2f}'.format(
+ hydra.controller.obs_no, mean_reward, cum_reward))
+
+ if hydra.controller.save_model_obs_num > hydra.controller.save_interval \
+ and hydra.controller.best_cum_reward < cum_reward:
+ # update best cumulated reward
+ hydra.controller.best_cum_reward = cum_reward
+
+ # save model
+ hydra.controller.save()
+ hydra.controller.save_model_obs_num = 0
+
+
+if __name__ == "__main__":
+ asyncio.run(main=main())