Skip to content
This repository was archived by the owner on Nov 2, 2024. It is now read-only.

Commit 97690ce

Browse files
committed
perf(svm): 更新hard negative mining实现
1 parent 6c4821b commit 97690ce

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

py/linear_svm.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,15 @@ def hinge_loss(outputs, labels):
9393
return loss
9494

9595

96-
def add_hard_negatives(preds, cache_dicts):
96+
def add_hard_negatives(target_list, negative_list):
97+
for item in target_list:
98+
if item not in negative_list:
99+
negative_list.append(item)
100+
101+
return negative_list
102+
103+
104+
def get_hard_negatives(preds, cache_dicts):
97105
fp_mask = preds == 1
98106
tn_mask = preds == 0
99107

@@ -201,10 +209,8 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
201209
# print(outputs.shape)
202210
_, preds = torch.max(outputs, 1)
203211

204-
hard_negative_list, easy_neagtive_list = add_hard_negatives(preds.cpu().numpy(), cache_dicts)
205-
206-
negative_list.extend(hard_negative_list)
207-
res_negative_list.extend(easy_neagtive_list)
212+
hard_negative_list, easy_neagtive_list = get_hard_negatives(preds.cpu().numpy(), cache_dicts)
213+
negative_list = add_hard_negatives(hard_negative_list, negative_list)
208214

209215
# 训练完成后,重置负样本,进行hard negatives mining
210216
train_dataset.set_negative_list(negative_list)
@@ -214,8 +220,6 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
214220
num_workers=8, drop_last=True)
215221
# 重置数据集大小
216222
data_sizes['train'] = len(tmp_sampler)
217-
# 保存剩余的负样本集
218-
data_loaders['remain'] = res_negative_list
219223

220224
# 每训练一轮就保存
221225
save_model(model, 'models/linear_svm_alexnet_car_%d.pth' % epoch)

0 commit comments

Comments
 (0)