Skip to content
This repository was archived by the owner on Nov 2, 2024. It is now read-only.

Commit 5de6a31

Browse files
committed
perf(svm): 添加列表add_negative_list,保存后续添加的负样本边界框,判断是否需要添加
1 parent 474d6f9 commit 5de6a31

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

py/linear_svm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,13 @@ def hinge_loss(outputs, labels):
9595

9696
def add_hard_negatives(hard_negative_list, negative_list, add_negative_list):
9797
for item in hard_negative_list:
98-
if add_negative_list is None:
98+
if len(add_negative_list) == 0:
9999
# 第一次添加负样本
100100
negative_list.append(item)
101-
add_negative_list.append(item['rect'])
102-
if item['rect'] not in add_negative_list:
101+
add_negative_list.append(list(item['rect']))
102+
if list(item['rect']) not in add_negative_list:
103103
negative_list.append(item)
104-
add_negative_list.append(item['rect'])
104+
add_negative_list.append(list(item['rect']))
105105

106106

107107
def get_hard_negatives(preds, cache_dicts):
@@ -199,7 +199,7 @@ def train_model(data_loaders, model, criterion, optimizer, lr_scheduler, num_epo
199199
# 获取训练数据集的负样本集
200200
negative_list = train_dataset.get_negatives()
201201
# 记录后续增加的负样本
202-
add_negative_list = data_loaders['add_negative']
202+
add_negative_list = data_loaders.get('add_negative', [])
203203

204204
running_corrects = 0
205205
# Iterate over data.

0 commit comments

Comments
 (0)