Commit a5cd892

Create main.py
1 parent d494be2 commit a5cd892


9_rlhf/wip/container/main.py

Lines changed: 166 additions & 0 deletions
import random
import numpy as np
from tqdm.auto import tqdm  # tqdm.auto works both in notebooks and plain scripts
from omegaconf import DictConfig
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import Iterable, Sequence, List

from torchtyping import TensorType

import transformers
from transformers import DataCollatorWithPadding
from transformers import pipeline, AutoTokenizer

from datasets import load_dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR

from utils import PromptPipeline, PPORLElement, PPORLBatch, PPORolloutStorage, Actor, Agent

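# NOTE: `config` is used throughout but is not defined in this file; it is
# assumed to be an OmegaConf DictConfig (with train/method/model/optimizer/
# scheduler sections) provided elsewhere, e.g. by the entrypoint that runs
# this script.

# Sampling helper: wraps model.generate() with the generation defaults from
# config.method.gen_kwargs, padding with the eos token.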
def generate(model, tokenizer, input_ids, attention_mask=None, **kwargs):

    generate_kwargs = dict(
        config.method.gen_kwargs,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id)

    kwargs = dict(generate_kwargs, **kwargs)

    with torch.no_grad():
        generated_results = model.generate(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

    return generated_results

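# Concatenate query and response tokens, truncate to the training sequence
# length, and build the attention mask and position ids (padding positions
# get position id 0).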
def get_model_inputs(query_tensors, response_tensors, pad_token_id):
    tokens = torch.cat((query_tensors, response_tensors), dim=1)[:, -config.train.seq_length :]
    attention_mask = tokens.not_equal(pad_token_id).long().to(tokens.device)
    position_ids = attention_mask.cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask.eq(0), 0)
    return tokens, attention_mask, position_ids

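# Gather the log-probability of each label token from the model logits.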
def logprobs_from_logits(logits, labels):
    logprobs = F.log_softmax(logits, dim=-1)
    logprobs_labels = torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1))
    return logprobs_labels.squeeze(-1)

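# Freeze all but the top `num_layers_unfrozen` transformer blocks so training
# only updates the upper layers (a negative value leaves everything trainable).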
def freeze_bottom_causal_layers(model: nn.Module, num_layers_unfrozen: int = 0):
    hidden_layers = model.transformer.h
    if num_layers_unfrozen == 0:
        hidden_layers_to_freeze = list(hidden_layers)
    elif num_layers_unfrozen > 0:
        hidden_layers_to_freeze = list(hidden_layers)[:-num_layers_unfrozen]
    else:
        hidden_layers_to_freeze = []
    for layer in hidden_layers_to_freeze:
        layer.requires_grad_(False)

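# Reward model: a sentiment classifier over IMDB reviews; the reward for a
# sample is the probability assigned to the "POSITIVE" class.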
sentiment_fn = pipeline(
    model="lvwerra/distilbert-imdb",
    top_k=2,
    batch_size=config.method.num_rollouts,
    device=config.train.reward_model_device,
)


def get_positive_score(scores):
    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


def reward_fn(samples: List[str]) -> List[float]:
    sentiments = list(map(get_positive_score, sentiment_fn(samples)))
    return sentiments

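# Build prompts from the first `prompt_size` words of each IMDB review, and set
# up a left-padded tokenizer whose pad token is the eos token.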
imdb = load_dataset("imdb", split="train+test")

prompts = [" ".join(review.split()[:config.method.prompt_size]) for review in imdb["text"]]

tokenizer = AutoTokenizer.from_pretrained(config.model.tokenizer_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
pad_token_id = 50256  # GPT-2's <|endoftext|> id

max_prompt_length = (config.train.seq_length - config.method.gen_kwargs["max_new_tokens"])
test_prompt_pipeline = PromptPipeline(prompts, max_prompt_length, tokenizer)

model = Agent(config.model.model_path, config.model.num_layers_unfrozen).to(config.train.model_device)

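# Quick sanity check: generate continuations for a few hand-written prompts
# before PPO training and score them with the reward model.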
input_ids = tokenizer.batch_encode_plus(
    ["my feeling about the movie", "this is", "I can tell with certainty"],
    return_tensors='pt',
    padding=True)['input_ids']

print(input_ids)

model_device = next(model.parameters()).device
output_ids = generate(model, tokenizer, input_ids.to(model_device), max_new_tokens=config.method.gen_kwargs["max_new_tokens"])

generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

print(generated_text)

print(reward_fn(generated_text))

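# PPO setup: a rollout pipeline over the prompts, an Actor that collects
# experience, rollout storage keyed by the pad token, and the optimizer /
# LR schedule for the policy.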
prompt_pipeline = PromptPipeline(prompts, config.train.seq_length, tokenizer)

actor = Actor(prompt_pipeline, tokenizer, chunk_size=config.method.chunk_size)

store = PPORolloutStorage(tokenizer.pad_token_id)

opt = torch.optim.Adam(model.parameters(), **config.optimizer.kwargs)
scheduler = CosineAnnealingLR(opt, **config.scheduler.kwargs)

n_updates_per_batch = config.method.ppo_epochs
total_steps = 400  # TODO: fix this

tbar = tqdm(initial=0, total=total_steps)

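# PPO training loop: each epoch collects fresh rollouts, pushes them into the
# storage, then runs `ppo_epochs` gradient updates over every stored batch.
# NOTE: `loss_fn` is not defined or imported in this file; it is assumed to be
# the PPO loss (policy + value terms) provided elsewhere.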
for _ in range(config.train.epochs):

    store.clear_history()
    rollouts, score = actor.make_experience(model, config.method.num_rollouts)
    store.push(rollouts)
    train_dataloader = store.create_loader(config.train.batch_size, shuffle=True)

    for batch in train_dataloader:
        for _ in range(n_updates_per_batch):

            loss, reward = loss_fn(batch)
            loss.backward()
            opt.step()
            opt.zero_grad()
            scheduler.step()
            tbar.update()

    tbar.set_description(f"| score: {score:.3f} |")

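# Final check: sample from the tuned policy on the same prompts and report
# the sentiment rewards.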
input_ids = tokenizer.batch_encode_plus(
    ["my feeling about the movie", "this is", "I can tell with certainty"],
    return_tensors='pt',
    padding=True)['input_ids']
print(input_ids)

model_device = next(model.parameters()).device
output_ids = generate(
    model,
    tokenizer,
    input_ids.to(model_device),
    # min_length=20,
    # max_new_tokens=100,
    # do_sample=True,
    # top_p=0.92,
    # top_k=0
)
generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
rewards = reward_fn(generated_text)
print(generated_text[0].replace('\n', ' ') + '\n', rewards[0])
print(generated_text[1].replace('\n', ' ') + '\n', rewards[1])
print(generated_text[2].replace('\n', ' ') + '\n', rewards[2])
print('all rewards mean:', np.mean(rewards))
