Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ jobs:
shell: bash

- name: Run Model Benchmark Test
run: uv run dlc-live-test --nodisplay
run: uv run dlc-live-test

- name: Run DLC Live Unit Tests
run: uv run pytest
Expand Down
105 changes: 59 additions & 46 deletions dlclive/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import sys
import time
import warnings
from typing import TYPE_CHECKING
from pathlib import Path

import argparse
import os
import colorcet as cc
import cv2
import numpy as np
Expand All @@ -23,10 +25,12 @@

from dlclive import DLCLive
from dlclive import VERSION
from dlclive import __file__ as dlcfile
from dlclive.engine import Engine
from dlclive.utils import decode_fourcc

if TYPE_CHECKING:
import tensorflow # type: ignore


def download_benchmarking_data(
target_dir=".",
Expand All @@ -49,17 +53,20 @@ def download_benchmarking_data(
if os.path.exists(zip_path):
print(f"{zip_path} already exists. Skipping download.")
else:

def show_progress(count, block_size, total_size):
pbar.update(block_size)

print(f"Downloading the benchmarking data from {url} ...")
pbar = tqdm(unit="B", total=0, position=0, desc="Downloading")

filename, _ = urllib.request.urlretrieve(url, filename=zip_path, reporthook=show_progress)
filename, _ = urllib.request.urlretrieve(
url, filename=zip_path, reporthook=show_progress
)
pbar.close()

print(f"Extracting {zip_path} to {target_dir} ...")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(target_dir)


Expand All @@ -81,6 +88,7 @@ def benchmark_videos(
cmap="bmy",
save_poses=False,
save_video=False,
single_animal=False,
):
"""Analyze videos using DeepLabCut-live exported models.
Analyze multiple videos and/or multiple options for the size of the video
Expand Down Expand Up @@ -168,7 +176,7 @@ def benchmark_videos(
im_size_out = []

for i in range(len(resize)):
print(f"\nRun {i+1} / {len(resize)}\n")
print(f"\nRun {i + 1} / {len(resize)}\n")

this_inf_times, this_im_size, meta = benchmark(
model_path=model_path,
Expand All @@ -188,6 +196,7 @@ def benchmark_videos(
save_poses=save_poses,
save_video=save_video,
save_dir=output,
single_animal=single_animal,
)

inf_times.append(this_inf_times)
Expand Down Expand Up @@ -275,9 +284,7 @@ def get_system_info() -> dict:
}


def save_inf_times(
sys_info, inf_times, im_size, model=None, meta=None, output=None
):
def save_inf_times(sys_info, inf_times, im_size, model=None, meta=None, output=None):
"""Save inference time data collected using :function:`benchmark` with system information to a pickle file.
This is primarily used through :function:`benchmark_videos`

Expand Down Expand Up @@ -346,6 +353,7 @@ def save_inf_times(

return True


def benchmark(
model_path: str,
model_type: str,
Expand All @@ -357,8 +365,8 @@ def benchmark(
single_animal: bool = True,
cropping: list[int] | None = None,
dynamic: tuple[bool, float, int] = (False, 0.5, 10),
n_frames: int =1000,
print_rate: bool=False,
n_frames: int = 1000,
print_rate: bool = False,
precision: str = "FP32",
display: bool = True,
pcutoff: float = 0.5,
Expand Down Expand Up @@ -434,7 +442,10 @@ def benchmark(
if not cap.isOpened():
print(f"Error: Could not open video file {video_path}")
return
im_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
im_size = (
int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
)

if pixels is not None:
resize = np.sqrt(pixels / (im_size[0] * im_size[1]))
Expand Down Expand Up @@ -492,9 +503,7 @@ def benchmark(

total_n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
n_frames = int(
n_frames
if (n_frames > 0) and n_frames < total_n_frames
else total_n_frames
n_frames if (n_frames > 0) and n_frames < total_n_frames else total_n_frames
)
iterator = range(n_frames) if print_rate or display else tqdm(range(n_frames))
for _ in iterator:
Expand All @@ -510,7 +519,7 @@ def benchmark(

start_time = time.perf_counter()
if frame_index == 0:
pose = dlc_live.init_inference(frame) # Loads model
pose = dlc_live.init_inference(frame) # Loads model
else:
pose = dlc_live.get_pose(frame)

Expand All @@ -519,7 +528,9 @@ def benchmark(
times.append(inf_time)

if print_rate:
print("Inference rate = {:.3f} FPS".format(1 / inf_time), end="\r", flush=True)
print(
"Inference rate = {:.3f} FPS".format(1 / inf_time), end="\r", flush=True
)

if save_video:
draw_pose_and_write(
Expand All @@ -531,19 +542,17 @@ def benchmark(
pcutoff=pcutoff,
display_radius=display_radius,
draw_keypoint_names=draw_keypoint_names,
vwriter=vwriter
vwriter=vwriter,
)

frame_index += 1

if print_rate:
print("Mean inference rate: {:.3f} FPS".format(np.mean(1 / np.array(times)[1:])))
print(
"Mean inference rate: {:.3f} FPS".format(np.mean(1 / np.array(times)[1:]))
)

metadata = _get_metadata(
video_path=video_path,
cap=cap,
dlc_live=dlc_live
)
metadata = _get_metadata(video_path=video_path, cap=cap, dlc_live=dlc_live)

cap.release()

Expand All @@ -558,19 +567,21 @@ def benchmark(
else:
individuals = []
n_individuals = len(individuals) or 1
save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp=timestamp)
save_poses_to_files(
video_path, save_dir, n_individuals, bodyparts, poses, timestamp=timestamp
)

return times, im_size, metadata


def setup_video_writer(
video_path:str,
save_dir:str,
timestamp:str,
num_keypoints:int,
cmap:str,
fps:float,
frame_size:tuple[int, int],
video_path: str,
save_dir: str,
timestamp: str,
num_keypoints: int,
cmap: str,
fps: float,
frame_size: tuple[int, int],
):
# Set colors and convert to RGB
cmap_colors = getattr(cc, cmap)
Expand All @@ -582,7 +593,9 @@ def setup_video_writer(
# Define output video path
video_path = Path(video_path)
video_name = video_path.stem # filename without extension
output_video_path = Path(save_dir) / f"{video_name}_DLCLIVE_LABELLED_{timestamp}.mp4"
output_video_path = (
Path(save_dir) / f"{video_name}_DLCLIVE_LABELLED_{timestamp}.mp4"
)

# Get video writer setup
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
Expand All @@ -595,6 +608,7 @@ def setup_video_writer(

return colors, vwriter


def draw_pose_and_write(
frame: np.ndarray,
pose: np.ndarray,
Expand All @@ -611,7 +625,9 @@ def draw_pose_and_write(

if resize is not None and resize != 1.0:
# Resize the frame
frame = cv2.resize(frame, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
frame = cv2.resize(
frame, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR
)

# Scale pose coordinates
pose = pose.copy()
Expand Down Expand Up @@ -642,15 +658,10 @@ def draw_pose_and_write(
lineType=cv2.LINE_AA,
)


vwriter.write(image=frame)


def _get_metadata(
video_path: str,
cap: cv2.VideoCapture,
dlc_live: DLCLive
):
def _get_metadata(video_path: str, cap: cv2.VideoCapture, dlc_live: DLCLive):
try:
fourcc = decode_fourcc(cap.get(cv2.CAP_PROP_FOURCC))
except Exception:
Expand Down Expand Up @@ -687,7 +698,9 @@ def _get_metadata(
return meta


def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp):
def save_poses_to_files(
video_path, save_dir, n_individuals, bodyparts, poses, timestamp
):
"""
Saves the detected keypoint poses from the video to CSV and HDF5 files.

Expand Down Expand Up @@ -725,14 +738,16 @@ def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, t
else:
individuals = [f"individual_{i}" for i in range(n_individuals)]
pdindex = pd.MultiIndex.from_product(
[individuals, bodyparts, ["x", "y", "likelihood"]], names=["individuals", "bodyparts", "coords"]
[individuals, bodyparts, ["x", "y", "likelihood"]],
names=["individuals", "bodyparts", "coords"],
)

pose_df = pd.DataFrame(flattened_poses, columns=pdindex)

pose_df.to_hdf(h5_save_path, key="df_with_missing", mode="w")
pose_df.to_csv(csv_save_path, index=False)


def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
# Create numpy array with poses:
max_frame = max(p["frame"] for p in poses)
Expand All @@ -745,17 +760,15 @@ def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
if pose.ndim == 2:
pose = pose[np.newaxis, :, :]
padded_pose = np.full(pose_target_shape, np.nan)
slices = tuple(slice(0, min(pose.shape[i], pose_target_shape[i])) for i in range(3))
slices = tuple(
slice(0, min(pose.shape[i], pose_target_shape[i])) for i in range(3)
)
padded_pose[slices] = pose[slices]
poses_array[frame] = padded_pose

return poses_array


import argparse
import os


def main():
"""Provides a command line interface to benchmark_videos function."""
parser = argparse.ArgumentParser(
Expand Down
Loading