Skip to content
This repository was archived by the owner on Nov 2, 2024. It is now read-only.

Commit 1195632

Browse files
committed
fix(svm): 统计剩余负样本集检测精度
1 parent 5de6a31 commit 1195632

File tree

1 file changed

+34
-36
lines changed

1 file changed

+34
-36
lines changed

py/linear_svm.py

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -190,47 +190,45 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
190190
jpeg_images = train_dataset.get_jpeg_images()
191191
transform = train_dataset.get_transform()
192192

193-
# 如果剩余的负样本集小于96个,那么结束hard negative mining
194-
if len(remain_negative_list) > batch_negative:
195-
with torch.set_grad_enabled(False):
196-
remain_dataset = CustomHardNegativeMiningDataset(remain_negative_list, jpeg_images, transform=transform)
197-
remain_data_loader = DataLoader(remain_dataset, batch_size=batch_total, num_workers=8, drop_last=True)
198-
199-
# 获取训练数据集的负样本集
200-
negative_list = train_dataset.get_negatives()
201-
# 记录后续增加的负样本
202-
add_negative_list = data_loaders.get('add_negative', [])
203-
204-
running_corrects = 0
205-
# Iterate over data.
206-
for inputs, labels, cache_dicts in remain_data_loader:
207-
inputs = inputs.to(device)
208-
labels = labels.to(device)
209-
210-
# zero the parameter gradients
211-
optimizer.zero_grad()
193+
with torch.set_grad_enabled(False):
194+
remain_dataset = CustomHardNegativeMiningDataset(remain_negative_list, jpeg_images, transform=transform)
195+
remain_data_loader = DataLoader(remain_dataset, batch_size=batch_total, num_workers=8, drop_last=True)
212196

213-
outputs = model(inputs)
214-
# print(outputs.shape)
215-
_, preds = torch.max(outputs, 1)
197+
# 获取训练数据集的负样本集
198+
negative_list = train_dataset.get_negatives()
199+
# 记录后续增加的负样本
200+
add_negative_list = data_loaders.get('add_negative', [])
216201

217-
running_corrects += torch.sum(preds == labels.data)
202+
running_corrects = 0
203+
# Iterate over data.
204+
for inputs, labels, cache_dicts in remain_data_loader:
205+
inputs = inputs.to(device)
206+
labels = labels.to(device)
218207

219-
hard_negative_list, easy_neagtive_list = get_hard_negatives(preds.cpu().numpy(), cache_dicts)
220-
add_hard_negatives(hard_negative_list, negative_list, add_negative_list)
208+
# zero the parameter gradients
209+
optimizer.zero_grad()
221210

222-
remain_acc = running_corrects.double() / data_sizes[phase]
223-
print('remiam negative size: {}, acc: {:.4f}'.format(len(remain_negative_list), remain_acc))
211+
outputs = model(inputs)
212+
# print(outputs.shape)
213+
_, preds = torch.max(outputs, 1)
214+
215+
running_corrects += torch.sum(preds == labels.data)
224216

225-
# 训练完成后,重置负样本,进行hard negatives mining
226-
train_dataset.set_negative_list(negative_list)
227-
tmp_sampler = CustomBatchSampler(train_dataset.get_positive_num(), train_dataset.get_negative_num(),
228-
batch_positive, batch_negative)
229-
data_loaders['train'] = DataLoader(train_dataset, batch_size=batch_total, sampler=tmp_sampler,
230-
num_workers=8, drop_last=True)
231-
data_loaders['add_negative'] = add_negative_list
232-
# 重置数据集大小
233-
data_sizes['train'] = len(tmp_sampler)
217+
hard_negative_list, easy_neagtive_list = get_hard_negatives(preds.cpu().numpy(), cache_dicts)
218+
add_hard_negatives(hard_negative_list, negative_list, add_negative_list)
219+
220+
remain_acc = running_corrects.double() / len(remain_negative_list)
221+
print('remiam negative size: {}, acc: {:.4f}'.format(len(remain_negative_list), remain_acc))
222+
223+
# 训练完成后,重置负样本,进行hard negatives mining
224+
train_dataset.set_negative_list(negative_list)
225+
tmp_sampler = CustomBatchSampler(train_dataset.get_positive_num(), train_dataset.get_negative_num(),
226+
batch_positive, batch_negative)
227+
data_loaders['train'] = DataLoader(train_dataset, batch_size=batch_total, sampler=tmp_sampler,
228+
num_workers=8, drop_last=True)
229+
data_loaders['add_negative'] = add_negative_list
230+
# 重置数据集大小
231+
data_sizes['train'] = len(tmp_sampler)
234232

235233
# 每训练一轮就保存
236234
save_model(model, 'models/linear_svm_alexnet_car_%d.pth' % epoch)

0 commit comments

Comments
 (0)