Skip to content
This repository was archived by the owner on Nov 2, 2024. It is now read-only.

Commit 0776efa

Browse files
committed
perf(detector): 实现车辆检测,绘制边界框
1 parent 41e4ed3 commit 0776efa

File tree

1 file changed

+51
-38
lines changed

1 file changed

+51
-38
lines changed

py/car_detector.py

Lines changed: 51 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,79 +2,92 @@
22

33
"""
44
@date: 2020/3/2 上午8:07
5-
@file: detector.py
5+
@file: car_detector.py
66
@author: zj
77
@description: 车辆类别检测器
88
"""
99

10-
import os
1110
import copy
1211
import cv2
1312
import torch
1413
import torch.nn as nn
1514
from torchvision.models import alexnet
1615
import torchvision.transforms as transforms
1716
import selectivesearch
17+
from utils.util import parse_xml
1818

19-
from utils.util import parse_car_csv
2019

21-
if __name__ == '__main__':
22-
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
20+
def get_device():
21+
return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
22+
23+
24+
def get_transform():
2325
# 数据转换
2426
transform = transforms.Compose([
2527
transforms.ToPILImage(),
2628
transforms.Resize((227, 227)),
29+
transforms.RandomHorizontalFlip(),
2730
transforms.ToTensor(),
28-
transforms.Normalize((0.5,), (0.5,))
31+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
2932
])
33+
return transform
34+
35+
36+
def get_model(device=None):
3037
# 加载CNN模型
3138
model = alexnet()
3239
num_classes = 2
3340
num_features = model.classifier[6].in_features
3441
model.classifier[6] = nn.Linear(num_features, num_classes)
35-
model.load_state_dict(torch.load('./models/linear_svm_alexnet_car.pth'))
42+
model.load_state_dict(torch.load('./models/linear_svm_alexnet_car_4.pth'))
3643
model.eval()
37-
# print(model)
38-
model = model.to(device)
44+
3945
# 取消梯度追踪
4046
for param in model.parameters():
4147
param.requires_grad = False
48+
if device:
49+
model = model.to(device)
50+
51+
return model
52+
53+
54+
if __name__ == '__main__':
55+
device = get_device()
56+
transform = get_transform()
57+
model = get_model(device=device)
58+
4259
# 创建selectivesearch对象
4360
gs = selectivesearch.get_selective_search()
4461

45-
car_root_dir = './data/voc_car/'
46-
val_root_dir = os.path.join(car_root_dir, 'val')
47-
samples = parse_car_csv(val_root_dir)
48-
49-
for sample_name in samples:
50-
jpeg_path = os.path.join(val_root_dir, 'JPEGImages', sample_name + ".jpg")
51-
annotation_path = os.path.join(val_root_dir, 'Annotations', sample_name + ".xml")
62+
test_img_path = './data/voc_car/val/JPEGImages/000007.jpg'
63+
test_xml_path = './data/voc_car/val/Annotations/000007.xml'
5264

53-
img = cv2.imread(jpeg_path)
54-
dst = copy.deepcopy(img)
65+
img = cv2.imread(test_img_path)
66+
dst = copy.deepcopy(img)
5567

56-
# 候选区域建议
57-
selectivesearch.config(gs, img, strategy='f')
58-
rects = selectivesearch.get_rects(gs)
59-
print('候选区域建议数目: %d' % len(rects))
68+
bndboxs = parse_xml(test_xml_path)
69+
for bndbox in bndboxs:
70+
xmin, ymin, xmax, ymax = bndbox
71+
cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), color=(0, 255, 0), thickness=2)
6072

61-
rects_transform = transform(rects)
62-
print(rects_transform.shape)
63-
exit(0)
73+
# 候选区域建议
74+
selectivesearch.config(gs, img, strategy='f')
75+
rects = selectivesearch.get_rects(gs)
76+
print('候选区域建议数目: %d' % len(rects))
6477

65-
for rect in rects:
66-
xmin, ymin, xmax, ymax = rect
67-
rect_img = img[ymin:ymax, xmin:xmax]
78+
for rect in rects:
79+
xmin, ymin, xmax, ymax = rect
80+
rect_img = img[ymin:ymax, xmin:xmax]
6881

69-
rect_transform = transform(rect_img).to(device)
70-
output = model(rect_transform.unsqueeze(0))[0]
82+
rect_transform = transform(rect_img).to(device)
83+
output = model(rect_transform.unsqueeze(0))[0]
7184

72-
if torch.argmax(output).item() == 1:
73-
"""
74-
预测为汽车
75-
"""
76-
cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), color=(0, 0, 255), thickness=1)
77-
print(rect, output)
85+
if torch.argmax(output).item() == 1:
86+
"""
87+
预测为汽车
88+
"""
89+
cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), color=(0, 0, 255), thickness=2)
90+
print(rect, output)
7891

79-
cv2.imshow('img', dst)
80-
cv2.waitKey(0)
92+
cv2.imshow('img', dst)
93+
cv2.waitKey(0)

0 commit comments

Comments
 (0)