Skip to content
This repository was archived by the owner on Nov 2, 2024. It is now read-only.

Commit 59ee041

Browse files
committed
perf(svm): 计算剩余负样本集的检测精度
1 parent 97690ce commit 59ee041

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

py/linear_svm.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
196196

197197
# 获取训练数据集的负样本集
198198
negative_list = train_dataset.get_negatives()
199-
res_negative_list = list()
199+
200+
running_corrects = 0
200201
# Iterate over data.
201202
for inputs, labels, cache_dicts in remain_data_loader:
202203
inputs = inputs.to(device)
@@ -209,9 +210,14 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
209210
# print(outputs.shape)
210211
_, preds = torch.max(outputs, 1)
211212

213+
running_corrects += torch.sum(preds == labels.data)
214+
212215
hard_negative_list, easy_neagtive_list = get_hard_negatives(preds.cpu().numpy(), cache_dicts)
213216
negative_list = add_hard_negatives(hard_negative_list, negative_list)
214217

218+
remain_acc = running_corrects.double() / data_sizes[phase]
219+
print('remain acc: {:.4f}'.format(remain_acc))
220+
215221
# 训练完成后,重置负样本,进行hard negatives mining
216222
train_dataset.set_negative_list(negative_list)
217223
tmp_sampler = CustomBatchSampler(train_dataset.get_positive_num(), train_dataset.get_negative_num(),

0 commit comments

Comments
 (0)