@@ -182,39 +182,40 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
182182 print ('remian_negative_list: %d' % (len (remain_negative_list )))
183183 # 如果剩余的负样本集小于96个,那么结束hard negative mining
184184 if len (remain_negative_list ) > batch_negative :
185- remain_dataset = CustomHardNegativeMiningDataset (remain_negative_list , jpeg_images , transform = transform )
186- remain_data_loader = DataLoader (remain_dataset , batch_size = batch_total , num_workers = 8 , drop_last = True )
185+ with torch .set_grad_enabled (False ):
186+ remain_dataset = CustomHardNegativeMiningDataset (remain_negative_list , jpeg_images , transform = transform )
187+ remain_data_loader = DataLoader (remain_dataset , batch_size = batch_total , num_workers = 8 , drop_last = True )
187188
188- # 获取训练数据集的负样本集
189- negative_list = train_dataset .get_negatives ()
190- res_negative_list = list ()
191- # Iterate over data.
192- for inputs , labels , cache_dicts in remain_data_loader :
193- inputs = inputs .to (device )
194- labels = labels .to (device )
189+ # 获取训练数据集的负样本集
190+ negative_list = train_dataset .get_negatives ()
191+ res_negative_list = list ()
192+ # Iterate over data.
193+ for inputs , labels , cache_dicts in remain_data_loader :
194+ inputs = inputs .to (device )
195+ labels = labels .to (device )
195196
196- # zero the parameter gradients
197- optimizer .zero_grad ()
197+ # zero the parameter gradients
198+ optimizer .zero_grad ()
199+
200+ outputs = model (inputs )
201+ # print(outputs.shape)
202+ _ , preds = torch .max (outputs , 1 )
198203
199- outputs = model (inputs )
200- # print(outputs.shape)
201- _ , preds = torch .max (outputs , 1 )
202-
203- hard_negative_list , easy_neagtive_list = add_hard_negatives (preds .cpu ().numpy (), cache_dicts )
204-
205- negative_list .extend (hard_negative_list )
206- res_negative_list .extend (easy_neagtive_list )
207-
208- # 训练完成后,重置负样本,进行hard negatives mining
209- train_dataset .set_negative_list (negative_list )
210- tmp_sampler = CustomBatchSampler (train_dataset .get_positive_num (), train_dataset .get_negative_num (),
211- batch_positive , batch_negative )
212- data_loaders ['train' ] = DataLoader (train_dataset , batch_size = batch_total , sampler = tmp_sampler ,
213- num_workers = 8 , drop_last = True )
214- # 重置数据集大小
215- data_sizes ['train' ] = len (tmp_sampler )
216- # 保存剩余的负样本集
217- data_loaders ['remain' ] = res_negative_list
204+ hard_negative_list , easy_neagtive_list = add_hard_negatives (preds .cpu ().numpy (), cache_dicts )
205+
206+ negative_list .extend (hard_negative_list )
207+ res_negative_list .extend (easy_neagtive_list )
208+
209+ # 训练完成后,重置负样本,进行hard negatives mining
210+ train_dataset .set_negative_list (negative_list )
211+ tmp_sampler = CustomBatchSampler (train_dataset .get_positive_num (), train_dataset .get_negative_num (),
212+ batch_positive , batch_negative )
213+ data_loaders ['train' ] = DataLoader (train_dataset , batch_size = batch_total , sampler = tmp_sampler ,
214+ num_workers = 8 , drop_last = True )
215+ # 重置数据集大小
216+ data_sizes ['train' ] = len (tmp_sampler )
217+ # 保存剩余的负样本集
218+ data_loaders ['remain' ] = res_negative_list
218219
219220 # 每训练一轮就保存
220221 save_model (model , 'models/linear_svm_alexnet_car_%d.pth' % epoch )
0 commit comments