
Commit d494be2

Create utils.py
1 parent 6cc34d8 commit d494be2

File tree

1 file changed: +234 −0 lines changed


9_rlhf/wip/container/utils.py

Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
import random
import numpy as np
from tqdm.notebook import tqdm
from omegaconf import DictConfig
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import Iterable, Sequence, List

from torchtyping import TensorType

import transformers
from transformers import DataCollatorWithPadding
from transformers import pipeline, AutoTokenizer

from datasets import load_dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR

class PromptPipeline():
    """Tokenizes prompts, keeps the last max_prompt_length tokens of each,
    and serves them through a DataLoader with padding."""

    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer):
        super().__init__()

        prompts = tokenizer(prompts).input_ids

        self.tokenizer = tokenizer
        # keep only the last max_prompt_length tokens of each prompt
        self.prompts = [prompt[-max_prompt_length:] for prompt in prompts]
        self.prompts = [{"input_ids": prompt, "attention_mask": [1] * len(prompt)} for prompt in self.prompts]

    def __getitem__(self, ix: int):
        return self.prompts[ix]

    def __len__(self) -> int:
        return len(self.prompts)

    def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:
        collate_fn = DataCollatorWithPadding(self.tokenizer)
        return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle)

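# A minimal usage sketch (hypothetical names; assumes a GPT-2 style tokenizer
# with a pad token set):
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   tokenizer.pad_token = tokenizer.eos_token
#   prompt_pipeline = PromptPipeline(["Once upon a time", "The quick brown fox"],
#                                    max_prompt_length=32, tokenizer=tokenizer)
#   loader = prompt_pipeline.create_loader(batch_size=2)
#   batch = next(iter(loader))   # dict with padded "input_ids" / "attention_mask"
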
@dataclass
class PPORLElement:
    query_tensor: TensorType["query_size"]
    response_tensor: TensorType["response_size"]
    logprobs: TensorType["response_size", "vocab_size"]
    values: TensorType["response_size"]
    rewards: TensorType["response_size"]


@dataclass
class PPORLBatch:
    query_tensors: TensorType["batch_size", "query_size"]
    response_tensors: TensorType["batch_size", "response_size"]
    logprobs: TensorType["batch_size", "response_size", "vocab_size"]
    values: TensorType["batch_size", "response_size"]
    rewards: TensorType["batch_size", "response_size"]

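# PPORLElement holds a single rollout (one query/response pair with its
# per-token logprobs, values, and rewards); PPORLBatch is the padded, batched
# form produced by the collate_fn in PPORolloutStorage.create_loader below.
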
class PPORolloutStorage():
    def __init__(self, pad_token_id):
        super().__init__()
        self.pad_token_id = pad_token_id
        # history starts with a single placeholder entry; call clear_history() before pushing rollouts
        self.history: Iterable[PPORLElement] = [None]

    def push(self, exps: Iterable[PPORLElement]):
        self.history += exps

    def clear_history(self):
        self.history = []

    def __getitem__(self, index: int) -> PPORLElement:
        return self.history[index]

    def __len__(self) -> int:
        return len(self.history)

    def create_loader(self, batch_size: int, shuffle: bool) -> DataLoader:
        def collate_fn(elems: Iterable[PPORLElement]):
            return PPORLBatch(
                # queries are flipped, right-padded, then flipped back, i.e. left-padded
                pad_sequence(
                    [elem.query_tensor.flip(0) for elem in elems],
                    padding_value=self.pad_token_id,
                    batch_first=True,
                ).flip(1),
                # responses are right-padded
                pad_sequence(
                    [elem.response_tensor for elem in elems],
                    padding_value=self.pad_token_id,
                    batch_first=True,
                ),
                pad_sequence(
                    [elem.logprobs for elem in elems],
                    padding_value=0.0,
                    batch_first=True,
                ),
                pad_sequence(
                    [elem.values for elem in elems],
                    padding_value=0.0,
                    batch_first=True,
                ),
                pad_sequence(
                    [elem.rewards for elem in elems],
                    padding_value=0.0,
                    batch_first=True,
                ),
            )

        return DataLoader(self, batch_size, shuffle=shuffle, collate_fn=collate_fn)

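# A minimal usage sketch (hypothetical values; rollouts come from
# Actor.make_experience below):
#
#   store = PPORolloutStorage(pad_token_id=tokenizer.pad_token_id)
#   store.clear_history()
#   store.push(ppo_rl_elements)              # Iterable[PPORLElement]
#   loader = store.create_loader(batch_size=8, shuffle=True)
#   for minibatch in loader:                 # each minibatch is a PPORLBatch
#       ...                                  # compute the PPO loss here
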
class Actor():

    def __init__(
            self,
            prompt_pipeline,
            tokenizer,
            chunk_size=128):

        self.prompt_pipeline = prompt_pipeline
        self.chunk_size = chunk_size

        self.prompt_pipeline_loader = self.prompt_pipeline.create_loader(self.chunk_size, shuffle=True)
        self.prompt_pipeline_iterator = iter(self.prompt_pipeline_loader)

        # `config` is expected to be a module-level DictConfig defined elsewhere
        self.ref_model = Agent(config.model.model_path)
        self.ref_model_device = config.train.ref_model_device
        self.ref_model = self.ref_model.to(self.ref_model_device)

        self.tokenizer = tokenizer

    def make_experience(self, model, num_rollouts=128):
        model_device = next(model.parameters()).device

        ppo_rl_elements = []
        while len(ppo_rl_elements) < num_rollouts:
            try:
                batch = next(self.prompt_pipeline_iterator)
            except StopIteration:
                # restart the prompt loader once it is exhausted
                self.prompt_pipeline_iterator = iter(self.prompt_pipeline_loader)
                batch = next(self.prompt_pipeline_iterator)

            # `generate`, `get_model_inputs`, `logprobs_from_logits` and `reward_fn`
            # are helpers expected to be defined elsewhere in the project
            trajectories = generate(model, self.tokenizer, **batch.to(model_device))

            query_tensors = batch.input_ids
            response_tensors = trajectories[:, query_tensors.shape[1]:]

            all_tokens, attention_mask, position_ids = get_model_inputs(
                query_tensors.to(response_tensors.device), response_tensors, self.tokenizer.pad_token_id)

            with torch.no_grad():
                logits, values = model(
                    all_tokens,
                    attention_mask=attention_mask,
                    position_ids=position_ids)
                ref_logits, _ = self.ref_model(
                    all_tokens.to(self.ref_model_device),
                    attention_mask=attention_mask.to(self.ref_model_device),
                    position_ids=position_ids.to(self.ref_model_device))

            all_tokens = all_tokens.cpu()
            logits = logits.cpu()
            ref_logits = ref_logits.cpu()

            # per-token log-probabilities of the sampled tokens under the policy and the reference model
            logprobs = logprobs_from_logits(logits[:, :-1, :], all_tokens[:, 1:])
            ref_logprobs = logprobs_from_logits(ref_logits[:, :-1, :], all_tokens[:, 1:])

            n = trajectories.shape[0]
            values = values.cpu()[:, :-1]
            query_tensors = query_tensors.cpu()
            response_tensors = response_tensors.cpu()

            start = query_tensors.shape[1] - 1
            ends = start + attention_mask[:, start:].sum(1)
            all_values = [values[i, start:ends[i]] for i in range(n)]
            all_logprobs = [logprobs[i, start:ends[i]] for i in range(n)]

            texts = self.tokenizer.batch_decode(trajectories, skip_special_tokens=True)
            scores = torch.tensor(reward_fn(texts), device='cpu', dtype=torch.float)

            # per-token KL penalty; the sequence-level score replaces the reward on the final token
            rewards = -config.method.kl_coef * (logprobs - ref_logprobs)
            all_rewards = [None] * n
            for i in range(n):
                rs = rewards[i][start:ends[i]]
                rs[-1] = scores[i]
                all_rewards[i] = rs

            new_ppo_rl_elements = [
                PPORLElement(
                    query_tensor=query_tensors[i],
                    response_tensor=response_tensors[i],
                    logprobs=all_logprobs[i],
                    values=all_values[i],
                    rewards=all_rewards[i],
                )
                for i in range(n)
            ]

            ppo_rl_elements += new_ppo_rl_elements

        return ppo_rl_elements, scores.mean().item()

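# The helpers used above are assumed to live elsewhere in the project. As a
# point of reference, logprobs_from_logits is presumably the usual per-token
# gather; a minimal sketch under that assumption:
#
#   def logprobs_from_logits(logits, labels):
#       logp = F.log_softmax(logits, dim=-1)
#       return torch.gather(logp, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
#
# With that, response token i receives reward -kl_coef * (logprob_i - ref_logprob_i),
# and the final token's reward is overwritten with the sequence score from reward_fn.
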
class Agent(nn.Module):
    def __init__(self, model_path, num_layers_unfrozen=0):
        super().__init__()

        self.base_model = transformers.AutoModelForCausalLM.from_pretrained(model_path, cache_dir="./models")

        # reuse the pretrained LM head for logits
        self.logit_head = self.base_model.get_output_embeddings()

        # separate value head on top of the last hidden state
        n_embd = self.base_model.lm_head.in_features
        self.value_head = nn.Sequential(
            nn.Linear(n_embd, n_embd * 2),
            nn.ReLU(),
            nn.Linear(n_embd * 2, 1))

        # helper expected to be defined elsewhere in the project
        freeze_bottom_causal_layers(self.base_model, num_layers_unfrozen)

    def generate(self, input_ids, **x):
        return self.base_model.generate(input_ids, **x)

    def forward(self, input_ids, attention_mask, position_ids):
        # assumes a GPT-2 style model that exposes its decoder stack as .transformer
        transformer_outputs = self.base_model.transformer(input_ids=input_ids,
                                                          attention_mask=attention_mask,
                                                          position_ids=position_ids)

        last_hidden_state = transformer_outputs.last_hidden_state
        lm_logits = self.logit_head(last_hidden_state)
        value = self.value_head(last_hidden_state).squeeze(-1)

        return lm_logits, value
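
# A minimal end-to-end sketch (hypothetical config keys; `config` and `reward_fn`
# are assumed to be defined elsewhere, e.g. in the training notebook):
#
#   tokenizer = AutoTokenizer.from_pretrained(config.model.model_path)
#   tokenizer.pad_token = tokenizer.eos_token
#   prompt_pipeline = PromptPipeline(prompts, max_prompt_length=32, tokenizer=tokenizer)
#   model = Agent(config.model.model_path).to("cuda")
#   actor = Actor(prompt_pipeline, tokenizer, chunk_size=128)
#
#   store = PPORolloutStorage(tokenizer.pad_token_id)
#   store.clear_history()
#   elements, mean_score = actor.make_experience(model, num_rollouts=128)
#   store.push(elements)
#   loader = store.create_loader(batch_size=8, shuffle=True)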
