30 changes: 25 additions & 5 deletions README.md
@@ -34,7 +34,7 @@
<a href="https://www.python.org/"><img src="https://img.shields.io/pypi/pyversions/hydra-vl4ai?style=flat-square"></a>
</div>

**This is the code for the paper [HYDRA: A Hyper Agent for Dynamic Compositional Visual Reasoning](https://link.springer.com/chapter/10.1007/978-3-031-72661-3_8), accepted by ECCV 2024 \[[Project Page](https://hydra-vl4ai.github.io)\].**
**This is the code for the paper [HYDRA: A Hyper Agent for Dynamic Compositional Visual Reasoning](https://link.springer.com/chapter/10.1007/978-3-031-72661-3_8), accepted by ECCV 2024 \[[Project Page](https://hydra-vl4ai.github.io)\]. We released the code that <span style="color: #FF6347; font-weight: bold;">uses Reinforcement Learning (DQN) to fine-tune the LLM</span>🔥🔥🔥**

## Release

@@ -50,7 +50,8 @@ We also notice the embedding model is updated by OpenAI as shown in this [link](
- [x] LLaMA3.1 (ollama) replacement.
- [x] Gradio Demo
- [x] GPT-4o Version.
- [ ] HYDRA with RL
- [x] HYDRA with RL (DQN).
- [ ] HYDRA with DeepSeek R1.


## Installation
@@ -71,7 +72,7 @@ git clone https://github.com/ControlNet/HYDRA

Option 1: Using [pixi](https://prefix.dev/) (recommended):
```Bash
pixi run install
pixi install
pixi shell
```

@@ -86,8 +87,8 @@ If you meet errors, please consider going through the `build_env.sh` file and in
Edit the file `.env` or setup in CLI to configure the environment variables.

```
OPENAI_API_KEY=your-api-key
OLLAMA_HOST=http://ollama.server:11434
OPENAI_API_KEY=your-api-key # if you want to use OpenAI LLMs
OLLAMA_HOST=http://ollama.server:11434 # if you want to use your Ollama server for LLaMA or DeepSeek
# do not change this TORCH_HOME variable
TORCH_HOME=./pretrained_models
```
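
A minimal sketch (not part of the repository) of how these variables might be consumed from Python, assuming they are exported in the shell or loaded from `.env` via `python-dotenv`; the variable names match the template above, everything else is illustrative:

```python
import os

from dotenv import load_dotenv  # assumption: python-dotenv is installed

load_dotenv()  # picks up the `.env` file in the working directory, if present

openai_key = os.environ.get("OPENAI_API_KEY")   # only needed for OpenAI LLMs
ollama_host = os.environ.get("OLLAMA_HOST")     # only needed for an Ollama server
torch_home = os.environ.get("TORCH_HOME", "./pretrained_models")

if openai_key is None and ollama_host is None:
    raise RuntimeError("Set OPENAI_API_KEY or OLLAMA_HOST before running.")
print(f"Pretrained models are cached under {torch_home}")
```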
@@ -126,7 +127,9 @@ python demo_gradio.py \
--base_config <YOUR-CONFIG-DIR> \
--model_config <MODEL-PATH>
```
https://github.com/user-attachments/assets/39a897ab-d457-49d2-8527-0d6fe3a3b922

---
### Inference dataset

```Bash
@@ -150,6 +153,23 @@ For example,
python evaluate.py result/result_okvqa.jsonl okvqa
```

## Training Controller with RL (DQN)

```Bash
python train.py \
--data_root <IMAGE_PATH> \
--base_config <YOUR-CONFIG-DIR> \
--model_config <MODEL-PATH> \
--dqn_config <YOUR-DQN-CONFIG-DIR>
```
For example,
```Bash
python train.py \
--data_root ../coco2014 \
--base_config ./config/okvqa.yaml \
--model_config ./config/model_config_1gpu.yaml \
--dqn_config ./config/dqn_debug.yaml
```

## Citation
```bibtex
15 changes: 15 additions & 0 deletions config/dqn_debug.yaml
@@ -0,0 +1,15 @@
model_name: dataset_nameaokvqa_learn_starts300
llm_embedding_dim: 1536 # small, 3072 for large
mlp_hidden_dim: 512
critic_layer_num: 4
critic_lr: 0.0001
train_log_interval: 1
batch_size: 1
update_times: 2
save_interval: 1
learn_starts: 2
dqn_explore_epsilon: 0.2
dqn_explore_epsilon_decay_rate: 0.02
dqn_explore_epsilon_decay_interval: 100
buffer_size: 100000
training_epoch: 1
15 changes: 15 additions & 0 deletions config/dqn_train_config.yaml
@@ -0,0 +1,15 @@
model_name: dqn_model_example
llm_embedding_dim: 1536 # small, 3072 for large
mlp_hidden_dim: 512
critic_layer_num: 4
critic_lr: 0.0001
train_log_interval: 100
batch_size: 128
update_times: 4
save_interval: 100
learn_starts: 1000
dqn_explore_epsilon: 0.2
dqn_explore_epsilon_decay_rate: 0.02
dqn_explore_epsilon_decay_interval: 100
buffer_size: 100000
training_epoch: 4
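
The exploration keys above define a linear epsilon-greedy schedule; here is a small illustrative sketch (not the project's code) of how they combine into the exploration threshold used by `ControllerDQN` in the diff below:

```python
import random

def should_explore(obs_no: int,
                   learn_starts: int = 1000,
                   epsilon: float = 0.2,
                   decay_rate: float = 0.02,
                   decay_interval: int = 100) -> bool:
    """Return True when a random action should be taken instead of argmax(Q)."""
    # linear decay: the threshold shrinks by `decay_rate` every `decay_interval` observations
    threshold = epsilon - decay_rate * (obs_no / decay_interval)
    # always explore until `learn_starts` observations have been collected
    return obs_no <= learn_starts or random.random() <= threshold

# e.g. at observation 1500: threshold = 0.2 - 0.02 * 15 = -0.1, so random exploration has stopped
```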
13 changes: 7 additions & 6 deletions evaluate.py
@@ -1,10 +1,10 @@
import sys
from evaluation.vqa_eval import GQAeval
from evaluation.grounding_eval import batch_iou_2d
from hydra_vl4ai.util.console import console, logger
import tensorneko_util as N
import argparse

import tensorneko_util as N

from hydra_vl4ai.evaluation.grounding_eval import batch_iou_2d
from hydra_vl4ai.evaluation.vqa_eval import GQAeval
from hydra_vl4ai.util.console import console, logger

if __name__ == '__main__':

@@ -20,7 +20,8 @@

match args.dataset:
case "okvqa":
score = evaluator.accuracy_one_set([each["result"] for each in result], [each["ground_truth"] for each in result])
score = evaluator.accuracy_one_set([each["result"] for each in result],
[each["ground_truth"] for each in result])
case "refcoco" | "refcoco+":
score = batch_iou_2d(result)
case _:
147 changes: 146 additions & 1 deletion hydra_vl4ai/agent/controller.py
@@ -1,16 +1,161 @@
import abc
import pickle
import random
from collections import deque

import numpy as np
import torch

from hydra_vl4ai.util.console import logger
from .llm import llm_embedding
from .rl_dqn import DQN_EmbeddingViaLLM, ReplayBuffer
from .smb.state_memory_bank import StateMemoryBank
from ..util.config import Config
from ..util.misc import get_hydra_root_folder


class Controller(abc.ABC):

@abc.abstractmethod
def __call__(self, instructions: list[str], probs: np.ndarray) -> str:
def __call__(self, *args, **kwargs) -> str:
pass


class ControllerLLM(Controller):
"""
Controller that does not use RL; it directly uses the LLM scores
to select the optimal instruction.
"""

def __call__(self, instructions: list[str], probs: np.ndarray) -> str:
return instructions[np.argmax(probs)]


class ControllerDQN(Controller):

def __init__(self,
embedding_prompt_base: str,
task_description_for_instruction: str,
instruction_example: str,
training: bool = False
):
super().__init__()
self.instruction_example = instruction_example
self.embedding_prompt_base = embedding_prompt_base
self.model_name = Config.dqn_config["model_name"]
self.model_save_path = get_hydra_root_folder().parent / "ckpt" / self.model_name
self.task_description_for_instruction = task_description_for_instruction
self.training = training

self.rl_agent_model = DQN_EmbeddingViaLLM(
device=torch.device('cuda:0'),
llm_embedding_dim_concat=Config.dqn_config["llm_embedding_dim"],
mlp_hidden_dim=Config.dqn_config["mlp_hidden_dim"],
action_dim=Config.base_config["num_actions"] + 1,
critic_layer_num=Config.dqn_config["critic_layer_num"],
critic_lr=float(Config.dqn_config["critic_lr"])
)
# load model
self.model_full_path = self.model_save_path / "critic.pt"
self.buffer_path = self.model_save_path / "buffer.pickle"
self.model_save_path.mkdir(parents=True, exist_ok=True)

if self.model_full_path.exists():
self.rl_agent_model.load_model(str(self.model_full_path))
logger.info(f"Load Model Done from file: {str(self.model_full_path)}")
elif not self.training: # for inference, if no model, raise error
raise RuntimeError(f"Model is not found: {self.model_full_path}")

if self.training: # for training
self.rl_agent_model.train_mode()
self.train_log_interval = Config.dqn_config["train_log_interval"]
self.reward_window = deque(maxlen=self.train_log_interval)
self.obs_no = 0
self.batch_size = Config.dqn_config["batch_size"]
self.update_times = Config.dqn_config["update_times"]
self.save_interval = Config.dqn_config["save_interval"]
self.save_model_obs_num = 0 # accumulate
self.best_cum_reward = -100 # TODO:MODIFY
self.best_score = 0
self.learn_starts = Config.dqn_config["learn_starts"]
self.dqn_explore_epsilon = Config.dqn_config["dqn_explore_epsilon"]
self.dqn_explore_epsilon_decay_rate = Config.dqn_config["dqn_explore_epsilon_decay_rate"]
self.dqn_explore_epsilon_decay_interval = Config.dqn_config["dqn_explore_epsilon_decay_interval"]
self.dqn_explore_threshold = self.dqn_explore_epsilon - self.dqn_explore_epsilon_decay_rate \
* (self.obs_no / self.dqn_explore_epsilon_decay_interval)

# load buffer
if self.buffer_path.exists():
with open(self.buffer_path, "rb") as reward_buffer_container:
self.replay_buffer = pickle.load(reward_buffer_container)
reward_buffer_container.close()
else:
self.replay_buffer = ReplayBuffer(capacity=Config.dqn_config["buffer_size"])

else:
self.rl_agent_model.eval_mode()

def save(self):
self.rl_agent_model.save_model(self.model_full_path)
with open(self.buffer_path, "wb") as f:
pickle.dump(self.replay_buffer, f)

def load(self):
self.rl_agent_model.load_model(self.model_full_path)
if self.training and self.buffer_path.exists():
with open(self.buffer_path, "rb") as f:
self.replay_buffer = pickle.load(f)

async def __call__(self, query: str, current_step_index: int, instructions: list[str], probs: np.ndarray,
state_memory_bank: StateMemoryBank
) -> tuple[str, np.ndarray, int]:
prompt = self.build_prompt(query, current_step_index, instructions, probs, state_memory_bank)

# get embedding from llm
response_emb = await llm_embedding(Config.base_config["embedding_model"], prompt)

affordance_value_array = self.rl_agent_model.get_action(obs=response_emb)

selected_idx = np.argmax(affordance_value_array)

# random exploration in the beginning.
if self.training:
# if it is in the beginning phase, do random exploration!
if self.obs_no <= self.learn_starts or np.random.random() <= self.dqn_explore_threshold:
selected_idx = random.choice(range(len(affordance_value_array)))

if selected_idx != len(instructions):
selected_instruction = instructions[selected_idx]
else:
selected_instruction = "REJECT"
return selected_instruction, response_emb, selected_idx

def build_prompt(self, query: str, current_step_index: int, instructions: list[str], probs: np.ndarray,
state_memory_bank: StateMemoryBank
):
"""Getting prompt based on template"""
# prompt-for-each-query
prompt = self.embedding_prompt_base.replace('[INSERT_QUERY_HERE]', query) # query insert
prompt = prompt.replace('[INSERT_CURRENT_STEP_NO]', str(current_step_index)) # step number insert

# prompt-for-query-type-about-the-dataset
prompt = prompt.replace('[INSERT_QUERY_TYPE_HERE]', self.task_description_for_instruction) # query type
prompt = prompt.replace('[EXAMPLE_HERE]', self.instruction_example) # examples for the query type

# previous instruction
prompt = prompt.replace('[NEED_TO_PROVIDE_PREVIOUS_INSTRUCTION]',
state_memory_bank.instructions_prompt) # previous instructions insert

# previous executed code
prompt = prompt.replace('[MORE_CODE_WAITING]', state_memory_bank.codes_prompt) # previous code insert
prompt = prompt.replace('[CURRENTLY_RESULT_WAITING]',
state_memory_bank.feedbacks_prompt) # result description insert

# variable details
prompt = prompt.replace('[VARIABLE_AND_DETAILS]', state_memory_bank.variables_prompt)

# current instructions/probs
prompt = prompt.replace('[CURRENT_OPTION]', str(instructions)) # instruction options
prompt = prompt.replace('[CURRENT_OPTION_PROBABILITY]', str(probs)) # probs of instruction options

return prompt
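
For reference, a usage sketch of `ControllerDQN` based only on the signatures above. It assumes the package is installed as in the Installation section, that `Config.base_config` / `Config.dqn_config` have already been loaded by the entry point (as `train.py` does), that a trained checkpoint exists under `ckpt/<model_name>/critic.pt`, and that `StateMemoryBank` can be constructed without arguments; the prompt strings are placeholders.

```python
import asyncio

import numpy as np

from hydra_vl4ai.agent.controller import ControllerDQN
from hydra_vl4ai.agent.smb.state_memory_bank import StateMemoryBank


async def main():
    controller = ControllerDQN(
        embedding_prompt_base="<EMBEDDING-PROMPT-TEMPLATE>",       # placeholder
        task_description_for_instruction="<TASK-DESCRIPTION>",     # placeholder
        instruction_example="<INSTRUCTION-EXAMPLE>",               # placeholder
        training=False,  # inference mode: requires ckpt/<model_name>/critic.pt
    )
    instructions = ["step 1: ...", "step 2: ..."]
    probs = np.array([0.6, 0.4])
    instruction, embedding, idx = await controller(
        "What is the man holding?", 1, instructions, probs, StateMemoryBank()
    )
    # idx == len(instructions) means the extra action was selected and "REJECT" is returned
    print(instruction)


asyncio.run(main())
```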