@@ -87,56 +87,6 @@ def parse_data(img_path, xml_path, transform):
8787 return img , data_dict
8888
8989
def load_model(device):
    """
    Load the trained YOLO v1 checkpoint and prepare it for inference.

    :param device: device the weights are mapped to and the model is moved to
    :return: YOLO_v1 model in eval mode with all parameters frozen
    """
    model_path = './models/checkpoint_yolo_v1.pth'
    model = YOLO_v1(S=7, B=2, C=3)
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved on a GPU can still be restored on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    # Inference only: no gradients are needed for any parameter.
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)

    return model
100-
101-
def deform_bboxs(pred_bboxs, data_dict):
    """
    Map per-cell box predictions back onto the source image.

    :param pred_bboxs: [S*S, 4] rows of (x_center, y_center, w, h); the centre
        is used as an offset inside its grid cell and w/h as fractions of the
        scaled image
    :param data_dict: provides 'scale_size', 'ratio' and 'src_size', each
        unpacked as a (height, width) pair
    :return: integer array of boxes after conversion by
        util.bbox_center_to_corner, rescaled by 'ratio' and clamped against
        'src_size'
    """
    scaled_h, scaled_w = data_dict['scale_size']
    cell_w = scaled_w / S
    cell_h = scaled_h / S

    boxes = np.zeros(pred_bboxs.shape)
    for idx in range(S * S):
        # Cells are numbered row-major over the S x S grid.
        grid_row, grid_col = divmod(idx, S)

        cx, cy, rel_w, rel_h = pred_bboxs[idx]
        boxes[idx, 0] = (grid_col + cx) * cell_w
        boxes[idx, 1] = (grid_row + cy) * cell_h
        boxes[idx, 2] = rel_w * scaled_w
        boxes[idx, 3] = rel_h * scaled_h
    # (x_center, y_center, w, h) -> (xmin, ymin, xmax, ymax)
    boxes = util.bbox_center_to_corner(boxes)

    # Undo the resize: x columns use the width ratio, y columns the height ratio.
    ratio_h, ratio_w = data_dict['ratio']
    for column, ratio in enumerate((ratio_w, ratio_h, ratio_w, ratio_h)):
        boxes[:, column] /= ratio

    # Clamp: mins are floored at 0, maxes are capped at the source size.
    src_h, src_w = data_dict['src_size']
    for column in (0, 1):
        boxes[:, column] = np.maximum(boxes[:, column], 0)
    boxes[:, 2] = np.minimum(boxes[:, 2], src_w)
    boxes[:, 3] = np.minimum(boxes[:, 3], src_h)

    return boxes.astype(int)
138-
139-
14090def save_data (img_name , img , target_cates , target_bboxs , pred_cates , pred_probs , pred_bboxs ):
14191 """
14292 保存检测结果
@@ -174,7 +124,7 @@ def save_data(img_name, img, target_cates, target_bboxs, pred_cates, pred_probs,
174124if __name__ == '__main__' :
175125 # device = util.get_device()
176126 device = "cpu"
177- model = load_model (device )
127+ model = file . load_model (device , S , B , C )
178128
179129 transform = get_transform ()
180130 img_path_list , annotation_path_list = load_data ('./data/location_dataset' )
@@ -212,7 +162,7 @@ def save_data(img_name, img, target_cates, target_bboxs, pred_cates, pred_probs,
212162 pred_cate_bboxs [:, 3 ] = pred_bboxs [range (S * S ), pred_confidences_idxs * 4 + 3 ]
213163
214164 # 预测边界框的缩放,回到原始图像
215- pred_bboxs = deform_bboxs (pred_cate_bboxs , data_dict )
165+ pred_bboxs = util . deform_bboxs (pred_cate_bboxs , data_dict )
216166
217167 # 保存图像/标注边界框/预测边界框
218168 img_name = os .path .splitext (os .path .basename (img_path ))[0 ]
0 commit comments