Skip to content

Commit 76409b9

Browse files
committed
perf(train): 更新类别数和类别列表
1 parent 1e9610a commit 76409b9

File tree

4 files changed

+22
-10
lines changed

4 files changed

+22
-10
lines changed

py/lib/models/location_dataset.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,20 @@ def __len__(self):
128128

129129

130130
if __name__ == '__main__':
131-
root_dir = '../../data/location_dataset/'
131+
# root_dir = '../../data/location_dataset/'
132+
root_dir = '../../data/VOC_dataset/'
132133
transform = transforms.Compose([
133134
transforms.ToPILImage(),
134135
transforms.Resize((448, 448)),
135136
transforms.ToTensor(),
136137
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
137138
])
138139

139-
cate_list = ['cucumber', 'eggplant', 'mushroom']
140-
data_set = LocationDataset(root_dir, cate_list, transform, 7, 2, 3)
140+
# cate_list = ['cucumber', 'eggplant', 'mushroom']
141+
# data_set = LocationDataset(root_dir, cate_list, transform, 7, 2, 3)
142+
cate_list = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
143+
'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
144+
data_set = LocationDataset(root_dir, cate_list, transform, 7, 2, 20)
141145
print(data_set)
142146
print(len(data_set))
143147

py/lib/models/multi_part_loss.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -329,15 +329,19 @@ def load_data(data_root_dir, cate_list, S=7, B=2, C=20):
329329
if __name__ == '__main__':
330330
S = 7
331331
B = 2
332-
C = 3
333-
cate_list = ['cucumber', 'eggplant', 'mushroom']
332+
# C = 3
333+
# cate_list = ['cucumber', 'eggplant', 'mushroom']
334+
C = 20
335+
cate_list = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
336+
'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
334337

335-
criterion = MultiPartLoss(448, 448, S=7, B=2, C=3)
338+
criterion = MultiPartLoss(448, 448, S=S, B=B, C=C)
336339
# preds = torch.arange(637).reshape(1, 7 * 7, 13) * 0.01
337340
# targets = torch.ones((1, 7 * 7, 13)) * 0.01
338341
# loss = criterion(preds, targets)
339342
# print(loss)
340-
data_loader = load_data('../../data/location_dataset', cate_list, S=S, B=B, C=C)
343+
# data_loader = load_data('../../data/location_dataset', cate_list, S=S, B=B, C=C)
344+
data_loader = load_data('../../data/VOC_dataset', cate_list, S=S, B=B, C=C)
341345
model = YOLO_v1(S=S, B=B, C=C)
342346

343347
for inputs, labels in data_loader:

py/lib/models/yolo_v1.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ def forward(self, x):
8686
if __name__ == '__main__':
8787
# data = torch.randn((1, 3, 448, 448))
8888
data = torch.randn((1, 3, 224, 224))
89-
model = YOLO_v1(7, 2, 3)
89+
# model = YOLO_v1(7, 2, 3)
90+
model = YOLO_v1(7, 2, 20)
9091

9192
outputs = model(data)
9293
print(outputs.shape)

py/lib/train.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121

2222
S = 7
2323
B = 2
24-
C = 3
2524

26-
cate_list = ['cucumber', 'eggplant', 'mushroom']
25+
# C = 3
26+
# cate_list = ['cucumber', 'eggplant', 'mushroom']
27+
C = 20
28+
cate_list = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
29+
'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
2730

2831

2932
def load_data(data_root_dir, cate_list, S=7, B=2, C=20):

0 commit comments

Comments
 (0)