@@ -190,47 +190,45 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
190190 jpeg_images = train_dataset .get_jpeg_images ()
191191 transform = train_dataset .get_transform ()
192192
193- # 如果剩余的负样本集小于96个,那么结束hard negative mining
194- if len (remain_negative_list ) > batch_negative :
195- with torch .set_grad_enabled (False ):
196- remain_dataset = CustomHardNegativeMiningDataset (remain_negative_list , jpeg_images , transform = transform )
197- remain_data_loader = DataLoader (remain_dataset , batch_size = batch_total , num_workers = 8 , drop_last = True )
198-
199- # 获取训练数据集的负样本集
200- negative_list = train_dataset .get_negatives ()
201- # 记录后续增加的负样本
202- add_negative_list = data_loaders .get ('add_negative' , [])
203-
204- running_corrects = 0
205- # Iterate over data.
206- for inputs , labels , cache_dicts in remain_data_loader :
207- inputs = inputs .to (device )
208- labels = labels .to (device )
209-
210- # zero the parameter gradients
211- optimizer .zero_grad ()
193+ with torch .set_grad_enabled (False ):
194+ remain_dataset = CustomHardNegativeMiningDataset (remain_negative_list , jpeg_images , transform = transform )
195+ remain_data_loader = DataLoader (remain_dataset , batch_size = batch_total , num_workers = 8 , drop_last = True )
212196
213- outputs = model (inputs )
214- # print(outputs.shape)
215- _ , preds = torch .max (outputs , 1 )
197+ # 获取训练数据集的负样本集
198+ negative_list = train_dataset .get_negatives ()
199+ # 记录后续增加的负样本
200+ add_negative_list = data_loaders .get ('add_negative' , [])
216201
217- running_corrects += torch .sum (preds == labels .data )
202+ running_corrects = 0
203+ # Iterate over data.
204+ for inputs , labels , cache_dicts in remain_data_loader :
205+ inputs = inputs .to (device )
206+ labels = labels .to (device )
218207
219- hard_negative_list , easy_neagtive_list = get_hard_negatives ( preds . cpu (). numpy (), cache_dicts )
220- add_hard_negatives ( hard_negative_list , negative_list , add_negative_list )
208+ # zero the parameter gradients
209+ optimizer . zero_grad ( )
221210
222- remain_acc = running_corrects .double () / data_sizes [phase ]
223- print ('remiam negative size: {}, acc: {:.4f}' .format (len (remain_negative_list ), remain_acc ))
211+ outputs = model (inputs )
212+ # print(outputs.shape)
213+ _ , preds = torch .max (outputs , 1 )
214+
215+ running_corrects += torch .sum (preds == labels .data )
224216
225- # 训练完成后,重置负样本,进行hard negatives mining
226- train_dataset .set_negative_list (negative_list )
227- tmp_sampler = CustomBatchSampler (train_dataset .get_positive_num (), train_dataset .get_negative_num (),
228- batch_positive , batch_negative )
229- data_loaders ['train' ] = DataLoader (train_dataset , batch_size = batch_total , sampler = tmp_sampler ,
230- num_workers = 8 , drop_last = True )
231- data_loaders ['add_negative' ] = add_negative_list
232- # 重置数据集大小
233- data_sizes ['train' ] = len (tmp_sampler )
217+ hard_negative_list , easy_neagtive_list = get_hard_negatives (preds .cpu ().numpy (), cache_dicts )
218+ add_hard_negatives (hard_negative_list , negative_list , add_negative_list )
219+
220+ remain_acc = running_corrects .double () / len (remain_negative_list )
221+ print ('remiam negative size: {}, acc: {:.4f}' .format (len (remain_negative_list ), remain_acc ))
222+
223+ # 训练完成后,重置负样本,进行hard negatives mining
224+ train_dataset .set_negative_list (negative_list )
225+ tmp_sampler = CustomBatchSampler (train_dataset .get_positive_num (), train_dataset .get_negative_num (),
226+ batch_positive , batch_negative )
227+ data_loaders ['train' ] = DataLoader (train_dataset , batch_size = batch_total , sampler = tmp_sampler ,
228+ num_workers = 8 , drop_last = True )
229+ data_loaders ['add_negative' ] = add_negative_list
230+ # 重置数据集大小
231+ data_sizes ['train' ] = len (tmp_sampler )
234232
235233 # 每训练一轮就保存
236234 save_model (model , 'models/linear_svm_alexnet_car_%d.pth' % epoch )
0 commit comments