
Commit c3a13fc

Update main.py
1 parent b5aac23 commit c3a13fc

1 file changed: +296, -8 lines

9_rlhf/wip/container/main.py

Lines changed: 296 additions & 8 deletions
@@ -21,7 +21,225 @@
 from torch.utils.data import DataLoader, Dataset
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
-from utils import PromptPipeline, PPORLElement, PPORLBatch, PPORolloutStorage, Actor, Agent
+
+#####################
+## PyTorch Objects ##
+#####################
+
+class PromptPipeline():
+    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer):
+        super().__init__()
+
+        prompts = tokenizer(prompts).input_ids
+
+        self.tokenizer = tokenizer
+        self.prompts = [prompt[-max_prompt_length:] for prompt in prompts]
+        self.prompts = [{"input_ids": prompt, "attention_mask": [1] * len(prompt)} for prompt in self.prompts]
+
+    def __getitem__(self, ix: int):
+        return self.prompts[ix]
+
+    def __len__(self) -> int:
+        return len(self.prompts)
+
+    def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:
+        collate_fn = DataCollatorWithPadding(self.tokenizer)
+        return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle)
+
+@dataclass
+class PPORLElement:
+    query_tensor: TensorType["query_size"]
+    response_tensor: TensorType["response_size"]
+    logprobs: TensorType["response_size", "vocab_size"]
+    values: TensorType["response_size"]
+    rewards: TensorType["response_size"]
+
+
+@dataclass
+class PPORLBatch:
+    query_tensors: TensorType["batch_size", "query_size"]
+    response_tensors: TensorType["batch_size", "response_size"]
+    logprobs: TensorType["batch_size", "response_size", "vocab_size"]
+    values: TensorType["batch_size", "response_size"]
+    rewards: TensorType["batch_size", "response_size"]
+
+
+class PPORolloutStorage():
+    def __init__(self, pad_token_id):
+        super().__init__()
+        self.pad_token_id = pad_token_id
+        self.history: Iterable[PPORLElement] = [None]
+
+    def push(self, exps: Iterable[PPORLElement]):
+        self.history += exps
+
+    def clear_history(self):
+        self.history = []
+
+    def __getitem__(self, index: int) -> PPORLElement:
+        return self.history[index]
+
+    def __len__(self) -> int:
+        return len(self.history)
+
+    def create_loader(self, batch_size: int, shuffle: bool) -> DataLoader:
+        def collate_fn(elems: Iterable[PPORLElement]):
+            return PPORLBatch(
+                pad_sequence(
+                    [elem.query_tensor.flip(0) for elem in elems],
+                    padding_value=self.pad_token_id,
+                    batch_first=True,
+                ).flip(1),
+                pad_sequence(
+                    [elem.response_tensor for elem in elems],
+                    padding_value=self.pad_token_id,
+                    batch_first=True,
+                ),
+                pad_sequence(
+                    [elem.logprobs for elem in elems],
+                    padding_value=0.0,
+                    batch_first=True,
+                ),
+                pad_sequence(
+                    [elem.values for elem in elems],
+                    padding_value=0.0,
+                    batch_first=True
+                ),
+                pad_sequence(
+                    [elem.rewards for elem in elems],
+                    padding_value=0.0,
+                    batch_first=True,
+                ),
+            )
+
+        return DataLoader(self, batch_size, shuffle=shuffle, collate_fn=collate_fn)
+
+class Actor():
+
+    def __init__(
+        self,
+        prompt_pipeline,
+        tokenizer,
+        chunk_size = 128):
+
+        self.prompt_pipeline = prompt_pipeline
+        self.chunk_size = chunk_size
+
+        self.prompt_pipeline_loader = self.prompt_pipeline.create_loader(self.chunk_size, shuffle=True)
+        self.prompt_pipeline_iterator = iter(self.prompt_pipeline_loader)
+
+        self.ref_model = Agent(config.model.model_path)
+        self.ref_model_device = config.train.ref_model_device
+        self.ref_model = self.ref_model.to(self.ref_model_device)
+
+        self.tokenizer = tokenizer
+
+
+    def make_experience(self, model, num_rollouts = 128):
+        model_device = next(model.parameters()).device
+
+        ppo_rl_elements = []
+        while len(ppo_rl_elements) < num_rollouts:
+            try:
+                batch = next(self.prompt_pipeline_iterator)
+            except StopIteration:
+                self.prompt_pipeline_iterator = iter(self.prompt_pipeline_loader)
+                batch = next(self.prompt_pipeline_iterator)
+
+            trajectories = generate(model, self.tokenizer, **batch.to(model_device))
+
+            query_tensors = batch.input_ids
+            response_tensors = trajectories[:, query_tensors.shape[1] :]
+
+            all_tokens, attention_mask, position_ids = get_model_inputs(
+                query_tensors.to(response_tensors.device), response_tensors, self.tokenizer.pad_token_id)
+            with torch.no_grad():
+                logits, values = model(
+                    all_tokens,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids)
+                ref_logits, _ = self.ref_model(
+                    all_tokens.to(self.ref_model_device),
+                    attention_mask=attention_mask.to(self.ref_model_device),
+                    position_ids=position_ids.to(self.ref_model_device))
+
+            all_tokens = all_tokens.cpu()
+            logits = logits.cpu()
+            ref_logits = ref_logits.cpu()
+
+            logprobs = logprobs_from_logits(logits[:, :-1, :], all_tokens[:, 1:])
+            ref_logprobs = logprobs_from_logits(ref_logits[:, :-1, :], all_tokens[:, 1:])
+
+            n = trajectories.shape[0]
+            values = values.cpu()[:, :-1]
+            query_tensors = query_tensors.cpu()
+            response_tensors = response_tensors.cpu()
+
+            start = query_tensors.shape[1] - 1
+            ends = start + attention_mask[:, start:].sum(1)
+            all_values = [values[i, start : ends[i]] for i in range(n)]
+            all_logprobs = [logprobs[i, start : ends[i]] for i in range(n)]
+
+            texts = self.tokenizer.batch_decode(trajectories, skip_special_tokens=True)
+            scores = torch.tensor(reward_fn(texts), device='cpu', dtype=torch.float)
+
+            rewards = -config.method.kl_coef * (logprobs - ref_logprobs)
+            all_rewards = [None] * n
+            for i in range(n):
+                rs = rewards[i][start : ends[i]]
+                rs[-1] = scores[i]
+                all_rewards[i] = rs
+
+            new_ppo_rl_elements = [
+                PPORLElement(
+                    query_tensor=query_tensors[i],
+                    response_tensor=response_tensors[i],
+                    logprobs=all_logprobs[i],
+                    values=all_values[i],
+                    rewards=all_rewards[i],
+                )
+                for i in range(n)
+            ]
+
+            ppo_rl_elements += new_ppo_rl_elements
+
+        return ppo_rl_elements, scores.mean().item()
+
+class Agent(nn.Module):
+    def __init__(self, model_path, num_layers_unfrozen=0):
+        super().__init__()
+
+        self.base_model = transformers.AutoModelForCausalLM.from_pretrained(model_path, cache_dir="./models")
+
+        self.logit_head = self.base_model.get_output_embeddings()
+
+        n_embd = self.base_model.lm_head.in_features
+        self.value_head = nn.Sequential(
+            nn.Linear(n_embd, n_embd*2),
+            nn.ReLU(),
+            nn.Linear(n_embd*2, 1))
+
+        freeze_bottom_causal_layers(self.base_model, num_layers_unfrozen)
+
+
+    def generate(self, input_ids, **x):
+        return self.base_model.generate(input_ids, **x)
+
+    def forward(self, input_ids, attention_mask, position_ids):
+
+        transformer_outputs = self.base_model.transformer(input_ids=input_ids,
+                                                          attention_mask=attention_mask,
+                                                          position_ids=position_ids)
+
+        last_hidden_state = transformer_outputs.last_hidden_state
+        lm_logits = self.logit_head(last_hidden_state)
+        value = self.value_head(last_hidden_state).squeeze(-1)
+
+        return lm_logits, value
+
+#####################
+## Util Functions ##
+#####################
 
 def generate(model, tokenizer, input_ids, attention_mask=None, **kwargs):
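The query collation in PPORolloutStorage.create_loader above relies on a flip/pad/flip idiom: each query is reversed, the batch is right-padded, and the result is flipped back along the time axis, which left-pads the queries so that every query ends exactly where its response begins. A minimal standalone sketch of that idiom, with made-up token ids and 0 standing in for pad_token_id (illustration only, not part of the commit):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Two hypothetical queries of different lengths; 0 stands in for pad_token_id.
    queries = [torch.tensor([5, 6, 7]), torch.tensor([9])]

    left_padded = pad_sequence(
        [q.flip(0) for q in queries],   # reverse each query
        padding_value=0,
        batch_first=True,               # right-pad the reversed queries
    ).flip(1)                           # flip the time axis back: pads end up on the left

    print(left_padded)
    # tensor([[5, 6, 7],
    #         [0, 0, 9]])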

@@ -63,20 +281,79 @@ def freeze_bottom_causal_layers(model: nn.Module, num_layers_unfrozen: int = 0):
     for layer in hidden_layers_to_freeze:
         layer.requires_grad_(False)
 
-sentiment_fn = pipeline(
-    model = "lvwerra/distilbert-imdb",
-    top_k=2,
-    batch_size=config.method.num_rollouts,
-    device=config.train.reward_model_device,
-)
-
 def get_positive_score(scores):
     return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]
 
 def reward_fn(samples: List[str]) -> List[float]:
     sentiments = list(map(get_positive_score, sentiment_fn(samples)))
     return sentiments
 
+############################
+## Model and Data Configs ##
+############################
+
+config = {
+    'train': {
+        'seed': 2023,
+        'seq_length': 1024,
+        'epochs': 50,
+        'total_steps': 5000,
+        'batch_size': 64,
+        'eval_interval': 100,
+        'model_device':'cuda:0',
+        'ref_model_device':'cpu',
+        'reward_model_device':'cpu'},
+    'model': {
+        'model_path': 'lvwerra/gpt2-imdb', #'edbeeching/gpt-neo-1.3B-imdb',
+        'tokenizer_path': 'lvwerra/gpt2-imdb', #'edbeeching/gpt-neo-1.3B-imdb',
+        'num_layers_unfrozen': 1},
+    'optimizer': {
+        'name': 'adamw',
+        'kwargs': {'lr': 0.0001,
+                   'betas': [0.9, 0.95],
+                   'eps': 1e-08,
+                   'weight_decay': 1e-06}},
+    'scheduler': {
+        'name': 'cosine_annealing',
+        'kwargs': {
+            'T_max': 10000, 'eta_min': 0.0001}},
+    'method': {
+        'use_whitening': True,
+        'prompt_size': 10,
+        'num_rollouts': 128,
+        'chunk_size': 128,
+        'ppo_epochs': 4,
+        'kl_coef': 0.05,
+        'horizon': 10000,
+        'gamma': 1,
+        'lam': 0.95,
+        'cliprange': 0.2,
+        'cliprange_value': 0.2,
+        'vf_coef': 1,
+        'scale_reward': False,
+        'ref_mean': None,
+        'ref_std': None,
+        'cliprange_reward': 10,
+        'gen_kwargs': {
+            'max_new_tokens': 60,
+            'top_k': 0,
+            'top_p': 1.0,
+            'do_sample': True}}}
+
+config = DictConfig(config)
+
+random.seed(config.train.seed)
+np.random.seed(config.train.seed)
+torch.manual_seed(config.train.seed)
+torch.cuda.manual_seed(config.train.seed)
+
+sentiment_fn = pipeline(
+    model = "lvwerra/distilbert-imdb",
+    top_k=2,
+    batch_size=config.method.num_rollouts,
+    device=config.train.reward_model_device,
+)
+
 imdb = load_dataset("imdb", split="train+test")
 
 prompts = [" ".join(review.split()[:config.method.prompt_size]) for review in imdb["text"]]
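reward_fn above feeds each decoded sample through get_positive_score, which flattens the classifier output into a label-to-score dict. A small sketch of that parsing step, assuming the usual list of label/score dicts a transformers text-classification pipeline returns with top_k=2 (hard-coded here instead of calling the model):

    # Hypothetical output of sentiment_fn for one sample with top_k=2:
    scores = [{"label": "POSITIVE", "score": 0.91}, {"label": "NEGATIVE", "score": 0.09}]

    # tuple(x.values()) -> ("POSITIVE", 0.91), so dict(...) maps label -> score.
    positive = dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]
    print(positive)  # 0.91, the per-sample value reward_fn returns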
@@ -108,6 +385,11 @@ def reward_fn(samples: List[str]) -> List[float]:
 
 reward_fn(generated_text)
 
+
+###############
+## Main loop ##
+###############
+
 prompt_pipeline = PromptPipeline(prompts, config.train.seq_length, tokenizer)
 
 actor = Actor(prompt_pipeline, tokenizer, chunk_size=config.method.chunk_size)
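In the loop that follows, actor.make_experience (first hunk) turns each rollout into per-token rewards: a KL penalty of -kl_coef * (logprob - ref_logprob) on every response token, with the sentiment score written over the final token. A tiny numeric sketch with made-up log-probs (illustration only; kl_coef mirrors config.method.kl_coef):

    import torch

    kl_coef = 0.05                                   # mirrors config.method.kl_coef
    logprobs = torch.tensor([-1.2, -0.8, -2.0])      # hypothetical policy log-probs for one response
    ref_logprobs = torch.tensor([-1.0, -1.1, -1.8])  # hypothetical reference-model log-probs
    score = 0.9                                      # hypothetical sentiment score for the decoded text

    rewards = -kl_coef * (logprobs - ref_logprobs)   # per-token KL penalty
    rewards[-1] = score                              # the score replaces the last token's reward
    print(rewards)                                   # tensor([ 0.0100, -0.0150,  0.9000])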
@@ -141,6 +423,11 @@ def reward_fn(samples: List[str]) -> List[float]:
 
     tbar.set_description(f"| score: {score:.3f} |")
 
+
+################
+## Eval steps ##
+################
+
 input_ids = tokenizer.batch_encode_plus(
     ["my feeling about the movie", "this is", "I can tell with certainty"],
     return_tensors='pt',
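Both the rollouts and the PPO updates depend on Agent.forward (first hunk) returning token logits alongside a per-token value estimate: the value head maps each hidden state to a scalar, so logits come out as [batch, seq, vocab] while values are [batch, seq]. A shape-only sketch with stand-in modules rather than the real GPT-2 weights:

    import torch
    import torch.nn as nn

    batch, seq, n_embd, vocab = 2, 5, 8, 11
    hidden = torch.randn(batch, seq, n_embd)           # stand-in for last_hidden_state

    logit_head = nn.Linear(n_embd, vocab, bias=False)  # stand-in for the tied output embeddings
    value_head = nn.Sequential(nn.Linear(n_embd, n_embd * 2), nn.ReLU(), nn.Linear(n_embd * 2, 1))

    print(logit_head(hidden).shape)              # torch.Size([2, 5, 11])
    print(value_head(hidden).squeeze(-1).shape)  # torch.Size([2, 5])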
@@ -164,3 +451,4 @@ def reward_fn(samples: List[str]) -> List[float]:
 print(generated_text[1].replace('\n', ' ') + '\n', rewards[1])
 print(generated_text[2].replace('\n', ' ') + '\n', rewards[2])
 print('all rewards mean:',np.mean(rewards))
+