Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions main/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import cv2
from tqdm import tqdm
import json
from typing import Literal, Union
from mmdet.apis import init_detector, inference_detector
from utils.inference_utils import process_mmdet_results, non_max_suppression

Expand All @@ -36,6 +35,7 @@ def parse_args():
parser.add_argument('--multi_person', action="store_true")
parser.add_argument('--iou_thr', type=float, default=0.5)
parser.add_argument('--bbox_thr', type=int, default=50)
parser.add_argument('--min_score', type=float, default=0.3)
args = parser.parse_args()
return args

Expand Down Expand Up @@ -82,7 +82,14 @@ def main():

## mmdet inference
mmdet_results = inference_detector(model, img_path)
mmdet_box = process_mmdet_results(mmdet_results, cat_id=0, multi_person=True)
person_dets = [d for d in mmdet_results[0] if (len(d) >= 5 and d[4] >= args.min_score)]
if not multi_person and person_dets:
# Sort by score
person_dets_sorted = sorted(person_dets, key=lambda d: -d[4])
mmdet_box = [person_dets_sorted]
else:
mmdet_results = [np.asarray(person_dets)] + list(mmdet_results[1:])
mmdet_box = process_mmdet_results(mmdet_results, cat_id=0, multi_person=True)

# save original image if no bbox
if len(mmdet_box[0])<1:
Expand Down