@@ -93,7 +93,15 @@ def hinge_loss(outputs, labels):
9393 return loss
9494
9595
96- def add_hard_negatives (preds , cache_dicts ):
def add_hard_negatives(target_list, negative_list):
    """Merge new hard negatives into the running negative-sample list.

    Each candidate from ``target_list`` is appended to ``negative_list``
    only if it is not already present there.  Because membership is checked
    against the growing list, duplicates inside ``target_list`` itself are
    also collapsed to a single entry.

    Note: ``negative_list`` is mutated in place; it is also returned for
    caller convenience.
    """
    for candidate in target_list:
        # Linear scan is intentional: the entries are dicts (unhashable),
        # so a set-based membership test is not an option here.
        if candidate in negative_list:
            continue
        negative_list.append(candidate)

    return negative_list
102+
103+
104+ def get_hard_negatives (preds , cache_dicts ):
97105 fp_mask = preds == 1
98106 tn_mask = preds == 0
99107
@@ -201,10 +209,8 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
201209 # print(outputs.shape)
202210 _ , preds = torch .max (outputs , 1 )
203211
204- hard_negative_list , easy_neagtive_list = add_hard_negatives (preds .cpu ().numpy (), cache_dicts )
205-
206- negative_list .extend (hard_negative_list )
207- res_negative_list .extend (easy_neagtive_list )
212+ hard_negative_list , easy_neagtive_list = get_hard_negatives (preds .cpu ().numpy (), cache_dicts )
213+ negative_list = add_hard_negatives (hard_negative_list , negative_list )
208214
209215 # 训练完成后,重置负样本,进行hard negatives mining
210216 train_dataset .set_negative_list (negative_list )
@@ -214,8 +220,6 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
214220 num_workers = 8 , drop_last = True )
215221 # 重置数据集大小
216222 data_sizes ['train' ] = len (tmp_sampler )
217- # 保存剩余的负样本集
218- data_loaders ['remain' ] = res_negative_list
219223
220224 # 每训练一轮就保存
221225 save_model (model , 'models/linear_svm_alexnet_car_%d.pth' % epoch )
0 commit comments