99
1010import copy
1111import cv2
12+ import numpy as np
1213import torch
1314import torch .nn as nn
1415from torchvision .models import alexnet
1516import torchvision .transforms as transforms
1617import selectivesearch
17- from utils .util import parse_xml
18+
19+ import utils .util as util
1820
1921
2022def get_device ():
@@ -39,7 +41,6 @@ def get_model(device=None):
3941 num_classes = 2
4042 num_features = model .classifier [6 ].in_features
4143 model .classifier [6 ] = nn .Linear (num_features , num_classes )
42- # model.load_state_dict(torch.load('./models/linear_svm_alexnet_car_4.pth'))
4344 model .load_state_dict (torch .load ('./models/best_linear_svm_alexnet_car.pth' ))
4445 model .eval ()
4546
@@ -68,6 +69,45 @@ def draw_box_with_text(img, rect_list, score_list):
6869 cv2 .putText (img , "{:.3f}" .format (score ), (xmin , ymin ), cv2 .FONT_HERSHEY_SIMPLEX , 0.5 , (255 , 255 , 255 ), 1 )
6970
7071
72+ def nms (rect_list , score_list ):
73+ """
74+ 非最大抑制
75+ """
76+ nms_rects = list ()
77+ nms_scores = list ()
78+
79+ rect_array = np .array (rect_list )
80+ score_array = np .array (score_list )
81+
82+ # 一次排序后即可
83+ # 按分类概率从大到小排序
84+ idxs = np .argsort (score_array )[::- 1 ]
85+ rect_array = rect_array [idxs ]
86+ score_array = score_array [idxs ]
87+
88+ thresh = 0.3
89+ while len (score_array ) > 0 :
90+ # 添加分类概率最大的边界框
91+ nms_rects .append (rect_array [0 ])
92+ nms_scores .append (score_array [0 ])
93+ rect_array = rect_array [1 :]
94+ score_array = score_array [1 :]
95+
96+ length = len (score_array )
97+ if length <= 0 :
98+ break
99+
100+ # 计算IoU
101+ iou_scores = util .iou (np .array (nms_rects [len (nms_rects ) - 1 ]), rect_array )
102+ # print(iou_scores)
103+ # 去除重叠率大于等于thresh的边界框
104+ idxs = np .where (iou_scores < thresh )[0 ]
105+ rect_array = rect_array [idxs ]
106+ score_array = score_array [idxs ]
107+
108+ return nms_rects , nms_scores
109+
110+
71111if __name__ == '__main__' :
72112 device = get_device ()
73113 transform = get_transform ()
@@ -76,13 +116,15 @@ def draw_box_with_text(img, rect_list, score_list):
76116 # 创建selectivesearch对象
77117 gs = selectivesearch .get_selective_search ()
78118
79- test_img_path = './data/voc_car/val/JPEGImages/000007.jpg'
80- test_xml_path = './data/voc_car/val/Annotations/000007.xml'
119+ # test_img_path = '../imgs/000007.jpg'
120+ # test_xml_path = '../imgs/000007.xml'
121+ test_img_path = '../imgs/000012.jpg'
122+ test_xml_path = '../imgs/000012.xml'
81123
82124 img = cv2 .imread (test_img_path )
83125 dst = copy .deepcopy (img )
84126
85- bndboxs = parse_xml (test_xml_path )
127+ bndboxs = util . parse_xml (test_xml_path )
86128 for bndbox in bndboxs :
87129 xmin , ymin , xmax , ymax = bndbox
88130 cv2 .rectangle (dst , (xmin , ymin ), (xmax , ymax ), color = (0 , 255 , 0 ), thickness = 1 )
@@ -94,6 +136,8 @@ def draw_box_with_text(img, rect_list, score_list):
94136
95137 # softmax = torch.softmax()
96138
139+ svm_thresh = 0.60
140+
97141 # 保存正样本边界框以及
98142 score_list = list ()
99143 positive_list = list ()
@@ -110,12 +154,16 @@ def draw_box_with_text(img, rect_list, score_list):
110154 """
111155 probs = torch .softmax (output , dim = 0 ).cpu ().numpy ()
112156
113- score_list .append (probs [1 ])
114- positive_list .append (rect )
115- # cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), color=(0, 0, 255), thickness=2)
116- print (rect , output , probs )
157+ if probs [1 ] >= svm_thresh :
158+ score_list .append (probs [1 ])
159+ positive_list .append (rect )
160+ # cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), color=(0, 0, 255), thickness=2)
161+ print (rect , output , probs )
117162
118- draw_box_with_text (dst , positive_list , score_list )
163+ nms_rects , nms_scores = nms (positive_list , score_list )
164+ print (nms_rects )
165+ print (nms_scores )
166+ draw_box_with_text (dst , nms_rects , nms_scores )
119167
120168 cv2 .imshow ('img' , dst )
121169 cv2 .waitKey (0 )
0 commit comments