@@ -93,12 +93,15 @@ def hinge_loss(outputs, labels):
9393 return loss
9494
9595
96- def add_hard_negatives (target_list , negative_list ):
97- for item in target_list :
98- if item not in negative_list :
96+ def add_hard_negatives (hard_negative_list , negative_list , add_negative_list ):
97+ for item in hard_negative_list :
98+ if add_negative_list is None :
99+ # 第一次添加负样本
99100 negative_list .append (item )
100-
101- return negative_list
101+ add_negative_list .append (list (item ['rect' ]))
102+ if item ['rect' ] not in add_negative_list :
103+ negative_list .append (item )
104+ add_negative_list .append (list (item ['rect' ])
102105
103106
104107def get_hard_negatives (preds , cache_dicts ):
@@ -195,6 +198,8 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
195198
196199 # 获取训练数据集的负样本集
197200 negative_list = train_dataset .get_negatives ()
201+ # 记录后续增加的负样本
202+ add_negative_list = data_loaders ['add_negative' ]
198203
199204 running_corrects = 0
200205 # Iterate over data.
@@ -212,7 +217,7 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
212217 running_corrects += torch .sum (preds == labels .data )
213218
214219 hard_negative_list , easy_neagtive_list = get_hard_negatives (preds .cpu ().numpy (), cache_dicts )
215- negative_list = add_hard_negatives (hard_negative_list , negative_list )
220+ add_hard_negatives (hard_negative_list , negative_list , add_negative_list )
216221
217222 remain_acc = running_corrects .double () / data_sizes [phase ]
218223 print ('remiam negative size: {}, acc: {:.4f}' .format (len (remain_negative_list ), remain_acc ))
@@ -223,6 +228,7 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
223228 batch_positive , batch_negative )
224229 data_loaders ['train' ] = DataLoader (train_dataset , batch_size = batch_total , sampler = tmp_sampler ,
225230 num_workers = 8 , drop_last = True )
231+ data_loaders ['add_negative' ] = add_negative_list
226232 # 重置数据集大小
227233 data_sizes ['train' ] = len (tmp_sampler )
228234
0 commit comments