 from torch.utils.data import DataLoader, Dataset
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
-from utils import PromptPipeline, PPORLElement, PPORLBatch, PPORolloutStorage, Actor, Agent
+
+#####################
+## PyTorch Objects ##
+#####################
+
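+# PromptPipeline tokenizes the raw prompt strings, keeps at most the last
+# `max_prompt_length` tokens of each, and serves them through a DataLoader
+# whose batches are padded by DataCollatorWithPadding.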
+class PromptPipeline():
+    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer):
+        super().__init__()
+
+        prompts = tokenizer(prompts).input_ids
+
+        self.tokenizer = tokenizer
+        self.prompts = [prompt[-max_prompt_length:] for prompt in prompts]
+        self.prompts = [{"input_ids": prompt, "attention_mask": [1] * len(prompt)} for prompt in self.prompts]
+
+    def __getitem__(self, ix: int):
+        return self.prompts[ix]
+
+    def __len__(self) -> int:
+        return len(self.prompts)
+
+    def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:
+        collate_fn = DataCollatorWithPadding(self.tokenizer)
+        return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle)
+
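+# A PPORLElement holds one rollout: the prompt tokens, the generated response
+# tokens, and the log-probabilities, value estimates and rewards attached to
+# the response. PPORLBatch is the padded, batched counterpart that
+# PPORolloutStorage's collate_fn assembles below.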
+@dataclass
+class PPORLElement:
+    query_tensor: TensorType["query_size"]
+    response_tensor: TensorType["response_size"]
+    logprobs: TensorType["response_size", "vocab_size"]
+    values: TensorType["response_size"]
+    rewards: TensorType["response_size"]
+
+
+@dataclass
+class PPORLBatch:
+    query_tensors: TensorType["batch_size", "query_size"]
+    response_tensors: TensorType["batch_size", "response_size"]
+    logprobs: TensorType["batch_size", "response_size", "vocab_size"]
+    values: TensorType["batch_size", "response_size"]
+    rewards: TensorType["batch_size", "response_size"]
+
+
+class PPORolloutStorage():
+    def __init__(self, pad_token_id):
+        super().__init__()
+        self.pad_token_id = pad_token_id
+        self.history: Iterable[PPORLElement] = [None]
+
+    def push(self, exps: Iterable[PPORLElement]):
+        self.history += exps
+
+    def clear_history(self):
+        self.history = []
+
+    def __getitem__(self, index: int) -> PPORLElement:
+        return self.history[index]
+
+    def __len__(self) -> int:
+        return len(self.history)
+
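+    # The collate_fn below left-pads queries (flip, right-pad, flip back) so
+    # every prompt ends where its response begins, and right-pads responses,
+    # logprobs, values and rewards to a common length.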
+    def create_loader(self, batch_size: int, shuffle: bool) -> DataLoader:
+        def collate_fn(elems: Iterable[PPORLElement]):
+            return PPORLBatch(
+                pad_sequence(
+                    [elem.query_tensor.flip(0) for elem in elems],
+                    padding_value=self.pad_token_id,
+                    batch_first=True,
+                ).flip(1),
+                pad_sequence(
+                    [elem.response_tensor for elem in elems],
+                    padding_value=self.pad_token_id,
+                    batch_first=True,
+                ),
+                pad_sequence(
+                    [elem.logprobs for elem in elems],
+                    padding_value=0.0,
+                    batch_first=True,
+                ),
+                pad_sequence(
+                    [elem.values for elem in elems],
+                    padding_value=0.0,
+                    batch_first=True,
+                ),
+                pad_sequence(
+                    [elem.rewards for elem in elems],
+                    padding_value=0.0,
+                    batch_first=True,
+                ),
+            )
+
+        return DataLoader(self, batch_size, shuffle=shuffle, collate_fn=collate_fn)
+
+class Actor():
+
+    def __init__(
+        self,
+        prompt_pipeline,
+        tokenizer,
+        chunk_size=128):
+
+        self.prompt_pipeline = prompt_pipeline
+        self.chunk_size = chunk_size
+
+        self.prompt_pipeline_loader = self.prompt_pipeline.create_loader(self.chunk_size, shuffle=True)
+        self.prompt_pipeline_iterator = iter(self.prompt_pipeline_loader)
+
+        # Reference model for the KL penalty, kept on config.train.ref_model_device.
+        self.ref_model = Agent(config.model.model_path)
+        self.ref_model_device = config.train.ref_model_device
+        self.ref_model = self.ref_model.to(self.ref_model_device)
+
+        self.tokenizer = tokenizer
+
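+    # make_experience: draw a chunk of prompts, generate continuations with the
+    # current policy, score the decoded texts with reward_fn, and package
+    # per-token logprobs, values and KL-shaped rewards into PPORLElements.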
+    def make_experience(self, model, num_rollouts=128):
+        model_device = next(model.parameters()).device
+
+        ppo_rl_elements = []
+        while len(ppo_rl_elements) < num_rollouts:
+            try:
+                batch = next(self.prompt_pipeline_iterator)
+            except StopIteration:
+                # Prompt loader exhausted: restart it and draw the next chunk.
+                self.prompt_pipeline_iterator = iter(self.prompt_pipeline_loader)
+                batch = next(self.prompt_pipeline_iterator)
+
+            trajectories = generate(model, self.tokenizer, **batch.to(model_device))
+
+            query_tensors = batch.input_ids
+            response_tensors = trajectories[:, query_tensors.shape[1]:]
+
+            all_tokens, attention_mask, position_ids = get_model_inputs(
+                query_tensors.to(response_tensors.device), response_tensors, self.tokenizer.pad_token_id)
+            with torch.no_grad():
+                logits, values = model(
+                    all_tokens,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids)
+                ref_logits, _ = self.ref_model(
+                    all_tokens.to(self.ref_model_device),
+                    attention_mask=attention_mask.to(self.ref_model_device),
+                    position_ids=position_ids.to(self.ref_model_device))
+
+            all_tokens = all_tokens.cpu()
+            logits = logits.cpu()
+            ref_logits = ref_logits.cpu()
+
+            logprobs = logprobs_from_logits(logits[:, :-1, :], all_tokens[:, 1:])
+            ref_logprobs = logprobs_from_logits(ref_logits[:, :-1, :], all_tokens[:, 1:])
+
+            n = trajectories.shape[0]
+            values = values.cpu()[:, :-1]
+            query_tensors = query_tensors.cpu()
+            response_tensors = response_tensors.cpu()
+
+            start = query_tensors.shape[1] - 1
+            ends = start + attention_mask[:, start:].sum(1)
+            all_values = [values[i, start:ends[i]] for i in range(n)]
+            all_logprobs = [logprobs[i, start:ends[i]] for i in range(n)]
+
+            texts = self.tokenizer.batch_decode(trajectories, skip_special_tokens=True)
+            scores = torch.tensor(reward_fn(texts), device='cpu', dtype=torch.float)
+
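+            # Per-token reward: a KL penalty toward the reference model,
+            # r_t = -kl_coef * (logprob_t - ref_logprob_t); the sentiment score
+            # then overwrites the reward at the final response token.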
+            rewards = -config.method.kl_coef * (logprobs - ref_logprobs)
+            all_rewards = [None] * n
+            for i in range(n):
+                rs = rewards[i][start:ends[i]]
+                rs[-1] = scores[i]
+                all_rewards[i] = rs
+
+            new_ppo_rl_elements = [
+                PPORLElement(
+                    query_tensor=query_tensors[i],
+                    response_tensor=response_tensors[i],
+                    logprobs=all_logprobs[i],
+                    values=all_values[i],
+                    rewards=all_rewards[i],
+                )
+                for i in range(n)
+            ]
+
+            ppo_rl_elements += new_ppo_rl_elements
+
+        return ppo_rl_elements, scores.mean().item()
+
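+# Agent wraps a causal LM: token logits come from the base model's own output
+# embedding matrix, while a small MLP value head on the last hidden state
+# yields a scalar value estimate per position. freeze_bottom_causal_layers
+# (defined below) presumably leaves only the top `num_layers_unfrozen`
+# transformer layers trainable.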
+class Agent(nn.Module):
+    def __init__(self, model_path, num_layers_unfrozen=0):
+        super().__init__()
+
+        self.base_model = transformers.AutoModelForCausalLM.from_pretrained(model_path, cache_dir="./models")
+
+        self.logit_head = self.base_model.get_output_embeddings()
+
+        n_embd = self.base_model.lm_head.in_features
+        self.value_head = nn.Sequential(
+            nn.Linear(n_embd, n_embd * 2),
+            nn.ReLU(),
+            nn.Linear(n_embd * 2, 1))
+
+        freeze_bottom_causal_layers(self.base_model, num_layers_unfrozen)
+
+
+    def generate(self, input_ids, **x):
+        return self.base_model.generate(input_ids, **x)
+
+    def forward(self, input_ids, attention_mask, position_ids):
+
+        transformer_outputs = self.base_model.transformer(input_ids=input_ids,
+                                                          attention_mask=attention_mask,
+                                                          position_ids=position_ids)
+
+        last_hidden_state = transformer_outputs.last_hidden_state
+        lm_logits = self.logit_head(last_hidden_state)
+        value = self.value_head(last_hidden_state).squeeze(-1)
+
+        return lm_logits, value
+
+#####################
+## Util Functions ##
+#####################
 
 def generate(model, tokenizer, input_ids, attention_mask=None, **kwargs):
 
@@ -63,20 +281,79 @@ def freeze_bottom_causal_layers(model: nn.Module, num_layers_unfrozen: int = 0):
     for layer in hidden_layers_to_freeze:
         layer.requires_grad_(False)
 
-sentiment_fn = pipeline(
-    model="lvwerra/distilbert-imdb",
-    top_k=2,
-    batch_size=config.method.num_rollouts,
-    device=config.train.reward_model_device,
-)
-
 def get_positive_score(scores):
     return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]
 
 def reward_fn(samples: List[str]) -> List[float]:
     sentiments = list(map(get_positive_score, sentiment_fn(samples)))
     return sentiments
 
+############################
+## Model and Data Configs ##
+############################
+
+config = {
+    'train': {
+        'seed': 2023,
+        'seq_length': 1024,
+        'epochs': 50,
+        'total_steps': 5000,
+        'batch_size': 64,
+        'eval_interval': 100,
+        'model_device': 'cuda:0',
+        'ref_model_device': 'cpu',
+        'reward_model_device': 'cpu'},
+    'model': {
+        'model_path': 'lvwerra/gpt2-imdb',  # 'edbeeching/gpt-neo-1.3B-imdb'
+        'tokenizer_path': 'lvwerra/gpt2-imdb',  # 'edbeeching/gpt-neo-1.3B-imdb'
+        'num_layers_unfrozen': 1},
+    'optimizer': {
+        'name': 'adamw',
+        'kwargs': {'lr': 0.0001,
+                   'betas': [0.9, 0.95],
+                   'eps': 1e-08,
+                   'weight_decay': 1e-06}},
+    'scheduler': {
+        'name': 'cosine_annealing',
+        'kwargs': {
+            'T_max': 10000, 'eta_min': 0.0001}},
+    'method': {
+        'use_whitening': True,
+        'prompt_size': 10,
+        'num_rollouts': 128,
+        'chunk_size': 128,
+        'ppo_epochs': 4,
+        'kl_coef': 0.05,
+        'horizon': 10000,
+        'gamma': 1,
+        'lam': 0.95,
+        'cliprange': 0.2,
+        'cliprange_value': 0.2,
+        'vf_coef': 1,
+        'scale_reward': False,
+        'ref_mean': None,
+        'ref_std': None,
+        'cliprange_reward': 10,
+        'gen_kwargs': {
+            'max_new_tokens': 60,
+            'top_k': 0,
+            'top_p': 1.0,
+            'do_sample': True}}}
+
+config = DictConfig(config)
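+# The 'optimizer' and 'scheduler' entries are presumably consumed in the
+# (unshown) training loop along these lines:
+#   opt = torch.optim.AdamW(model.parameters(), **config.optimizer.kwargs)
+#   scheduler = CosineAnnealingLR(opt, **config.scheduler.kwargs)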
+
+random.seed(config.train.seed)
+np.random.seed(config.train.seed)
+torch.manual_seed(config.train.seed)
+torch.cuda.manual_seed(config.train.seed)
+
+sentiment_fn = pipeline(
+    model="lvwerra/distilbert-imdb",
+    top_k=2,
+    batch_size=config.method.num_rollouts,
+    device=config.train.reward_model_device,
+)
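+# With top_k=2 the text-classification pipeline returns both class scores for
+# each sample, e.g. [{'label': 'POSITIVE', 'score': 0.94},
+# {'label': 'NEGATIVE', 'score': 0.06}], which get_positive_score above
+# reduces to the POSITIVE probability.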
+
 imdb = load_dataset("imdb", split="train+test")
 
 prompts = [" ".join(review.split()[:config.method.prompt_size]) for review in imdb["text"]]
@@ -108,6 +385,11 @@ def reward_fn(samples: List[str]) -> List[float]:
 
 reward_fn(generated_text)
 
+
+###############
+## Main loop ##
+###############
+
 prompt_pipeline = PromptPipeline(prompts, config.train.seq_length, tokenizer)
 
 actor = Actor(prompt_pipeline, tokenizer, chunk_size=config.method.chunk_size)
@@ -141,6 +423,11 @@ def reward_fn(samples: List[str]) -> List[float]:
 
     tbar.set_description(f"| score: {score:.3f} |")
 
+
+################
+## Eval steps ##
+################
+
 input_ids = tokenizer.batch_encode_plus(
     ["my feeling about the movie", "this is", "I can tell with certainty"],
     return_tensors='pt',
@@ -164,3 +451,4 @@ def reward_fn(samples: List[str]) -> List[float]:
 print(generated_text[1].replace('\n', ' ') + '\n', rewards[1])
 print(generated_text[2].replace('\n', ' ') + '\n', rewards[2])
 print('all rewards mean:', np.mean(rewards))
+