diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
new file mode 100644
index 00000000..95450e58
--- /dev/null
+++ b/.github/workflows/format.yml
@@ -0,0 +1,61 @@
+name: pre-commit (PR only on changed files)
+
+on:
+  pull_request:
+    types: [opened, synchronize, reopened]
+
+jobs:
+  detect_changes:
+    runs-on: ubuntu-latest
+    outputs:
+      changed: ${{ steps.changed_files.outputs.changed }}
+
+    steps:
+      - name: Checkout full history
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Detect changed files
+        id: changed_files
+        run: |
+          git fetch origin ${{ github.base_ref }}
+          CHANGED_FILES=$(git diff --name-only origin/${{ github.base_ref }}...HEAD)
+
+          {
+            echo "changed<<EOF"
+            echo "$CHANGED_FILES"
+            echo "EOF"
+          } >> "$GITHUB_OUTPUT"
+
+      - name: Show changed files
+        run: |
+          echo "Changed files:"
+          echo "${{ steps.changed_files.outputs.changed }}"
+
+  precommit:
+    needs: detect_changes
+    runs-on: ubuntu-latest
+    if: ${{ needs.detect_changes.outputs.changed != '' }}
+
+    steps:
+      - name: Checkout PR branch
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          ref: ${{ github.head_ref }}
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+
+      - name: Install pre-commit
+        run: pip install pre-commit
+
+      - name: Run pre-commit (CI check-only stage) on changed files
+        env:
+          CHANGED_FILES: ${{ needs.detect_changes.outputs.changed }}
+        run: |
+          mapfile -t files <<< "$CHANGED_FILES"
+          pre-commit run --hook-stage manual --files "${files[@]}" --show-diff-on-failure
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7e74fb5f..b69e80a1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,28 +1,61 @@
+default_stages: [pre-commit]
+
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v6.0.0
     hooks:
+      # These are safe to run in both local & CI (they don't require "fix vs check" split)
       - id: check-added-large-files
+        stages: [pre-commit, manual]
       - id: check-yaml
+        stages: [pre-commit, manual]
       - id: check-toml
+        stages: [pre-commit, manual]
+      - id: check-merge-conflict
+        stages: [pre-commit, manual]
+
+      # These modify files. Run locally only (pre-commit stage).
       - id: end-of-file-fixer
-      - id: name-tests-test
-        args: [--pytest-test-first]
+        stages: [pre-commit]
       - id: trailing-whitespace
-      - id: check-merge-conflict
+        stages: [pre-commit]
+
   - repo: https://github.com/tox-dev/pyproject-fmt
     rev: v2.15.2
     hooks:
       - id: pyproject-fmt
+        stages: [pre-commit] # modifies -> local only
+
   - repo: https://github.com/abravalheri/validate-pyproject
     rev: v0.25
     hooks:
       - id: validate-pyproject
+        stages: [pre-commit, manual]
+
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.15.0
     hooks:
-      # Run the formatter.
+      # --------------------------
+      # LOCAL AUTOFIX (developers)
+      # --------------------------
+      - id: ruff-check
+        name: ruff-check (fix)
+        args: [--fix, --unsafe-fixes]
+        stages: [pre-commit]
+
       - id: ruff-format
-      # Run the linter.
+ name: ruff-format (write) + stages: [pre-commit] + + # -------------------------- + # CI CHECK-ONLY (no writes) + # -------------------------- - id: ruff-check - args: [--fix,--unsafe-fixes] \ No newline at end of file + name: ruff-check (ci) + args: [--output-format=github] + stages: [manual] + + - id: ruff-format + name: ruff-format (ci) + args: [--check, --diff] + stages: [manual] diff --git a/MANIFEST.in b/MANIFEST.in index 6e2c2842..4d91463f 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,3 @@ include dlclive/check_install/* include dlclive/modelzoo/model_configs/*.yaml -include dlclive/modelzoo/project_configs/*.yaml \ No newline at end of file +include dlclive/modelzoo/project_configs/*.yaml diff --git a/README.md b/README.md index 05d8b752..505e4c51 100644 --- a/README.md +++ b/README.md @@ -16,12 +16,12 @@ pipeline for real-time applications that has minimal (software) dependencies. Th is as easy to install as possible (in particular, on atypical systems like [ NVIDIA Jetson boards](https://developer.nvidia.com/buy-jetson)). -If you've used DeepLabCut-Live with TensorFlow models and want to try the PyTorch +If you've used DeepLabCut-Live with TensorFlow models and want to try the PyTorch version, take a look at [_Switching from TensorFlow to PyTorch_]( #Switching-from-TensorFlow-to-PyTorch) -**Performance of TensorFlow models:** If you would like to see estimates on how your -model should perform given different video sizes, neural network type, and hardware, +**Performance of TensorFlow models:** If you would like to see estimates on how your +model should perform given different video sizes, neural network type, and hardware, please see: [deeplabcut.github.io/DLC-inferencespeed-benchmark/ ](https://deeplabcut.github.io/DLC-inferencespeed-benchmark/). **We're working on getting these benchmarks for PyTorch architectures as well.** @@ -29,13 +29,13 @@ getting these benchmarks for PyTorch architectures as well.** If you have different hardware, please consider [submitting your results too]( https://github.com/DeepLabCut/DLC-inferencespeed-benchmark)! -**What this SDK provides:** This package provides a `DLCLive` class which enables pose +**What this SDK provides:** This package provides a `DLCLive` class which enables pose estimation online to provide feedback. This object loads and prepares a DeepLabCut network for inference, and will return the predicted pose for single images. -To perform processing on poses (such as predicting the future pose of an animal given -its current pose, or to trigger external hardware like send TTL pulses to a laser for -optogenetic stimulation), this object takes in a `Processor` object. Processor objects +To perform processing on poses (such as predicting the future pose of an animal given +its current pose, or to trigger external hardware like send TTL pulses to a laser for +optogenetic stimulation), this object takes in a `Processor` object. Processor objects must contain two methods: `process` and `save`. - The `process` method takes in a pose, performs some processing, and returns processed @@ -44,20 +44,20 @@ pose. For more details and examples, see documentation [here](dlclive/processor/README.md). -**🔥🔥🔥🔥🔥 Note :: alone, this object does not record video or capture images from a +**🔥🔥🔥🔥🔥 Note :: alone, this object does not record video or capture images from a camera. This must be done separately, i.e. see our [DeepLabCut-live GUI]( https://github.com/DeepLabCut/DeepLabCut-live-GUI).🔥🔥🔥🔥🔥** ### News! 
- **WIP 2025**: DeepLabCut-Live is implemented for models trained with the PyTorch engine! -- March 2022: DeepLabCut-Live! 1.0.2 supports poetry installation `poetry install +- March 2022: DeepLabCut-Live! 1.0.2 supports poetry installation `poetry install deeplabcut-live`, thanks to PR #60. -- March 2021: DeepLabCut-Live! [**version 1.0** is released](https://pypi.org/project/deeplabcut-live/), with support for +- March 2021: DeepLabCut-Live! [**version 1.0** is released](https://pypi.org/project/deeplabcut-live/), with support for tensorflow 1 and tensorflow 2! - Feb 2021: DeepLabCut-Live! was featured in **Nature Methods**: ["Real-time behavioral analysis"](https://www.nature.com/articles/s41592-021-01072-z) -- Jan 2021: full **eLife** paper is published: ["Real-time, low-latency closed-loop +- Jan 2021: full **eLife** paper is published: ["Real-time, low-latency closed-loop feedback using markerless posture tracking"](https://elifesciences.org/articles/61909) - Dec 2020: we talked to **RTS Suisse Radio** about DLC-Live!: ["Capture animal movements in real time"]( @@ -65,27 +65,27 @@ https://www.rts.ch/play/radio/cqfd/audio/capturer-les-mouvements-des-animaux-en- ### Installation -DeepLabCut-live can be installed from PyPI with PyTorch or Tensorflow directly: +DeepLabCut-live can be installed from PyPI with PyTorch or Tensorflow directly: ```bash # With PyTorch (recommended) pip install deeplabcut-live[pytorch] - + # Or with TensorFlow pip install deeplabcut-live[tf] - + # Or using uv uv pip install deeplabcut-live[pytorch] # or [tf] ``` -Note: On **Windows**, the `deeplabcut-live[pytorch]` extra will not install the required CUDA-enabled wheels for PyTorch by default. For GPU support, install CUDA-enabled PyTorch first, then install `deeplabcut-live[pytorch]`. +Note: On **Windows**, the `deeplabcut-live[pytorch]` extra will not install the required CUDA-enabled wheels for PyTorch by default. For GPU support, install CUDA-enabled PyTorch first, then install `deeplabcut-live[pytorch]`. Please see our instruction manual for more elaborate information on how to install on a [Windows or Linux machine]( docs/install_desktop.md) or on a [NVIDIA Jetson Development Board]( -docs/install_jetson.md). +docs/install_jetson.md). This code works with PyTorch, TensorFlow 1 or TensorFlow -2 models, but whatever engine you exported your model with, you must import with the -same version (i.e., export a PyTorch model, then install PyTorch, export with TF1.13, +2 models, but whatever engine you exported your model with, you must import with the +same version (i.e., export a PyTorch model, then install PyTorch, export with TF1.13, then use TF1.13 with DlC-Live; export with TF2.3, then use TF2.3 with DLC-live). You can test your installation by running: @@ -139,7 +139,7 @@ dlc_live.get_pose() - `index 0` = use dynamic cropping, bool - `index 1` = detection threshold, float - `index 2` = margin (in pixels) around identified points, int -- `resize` = float, optional; factor by which to resize image (resize=0.5 downsizes +- `resize` = float, optional; factor by which to resize image (resize=0.5 downsizes both width and height of image by half). 
Can be used to downsize large images for faster inference - `processor` = dlc pose processor object, optional @@ -148,43 +148,43 @@ dlc_live.get_pose() `DLCLive` **inputs:** -- `` = - - For TensorFlow models: path to the folder that has the `.pb` files that you +- `` = + - For TensorFlow models: path to the folder that has the `.pb` files that you acquire after running `deeplabcut.export_model` - - For PyTorch models: path to the `.pt` file that is generated after running + - For PyTorch models: path to the `.pt` file that is generated after running `deeplabcut.export_model` - `` = is a numpy array of each frame #### DLCLive - PyTorch Specific Guide -This guide is for users who trained a model with the PyTorch engine with +This guide is for users who trained a model with the PyTorch engine with `DeepLabCut 3.0`. Once you've trained your model in [DeepLabCut](https://github.com/DeepLabCut/DeepLabCut) -and you are happy with its performance, you can export the model to be used for live +and you are happy with its performance, you can export the model to be used for live inference with DLCLive! ### Switching from TensorFlow to PyTorch -This section is for users who **have already used DeepLabCut-Live** with +This section is for users who **have already used DeepLabCut-Live** with TensorFlow models (through DeepLabCut 1.X or 2.X) and want to switch to using the PyTorch Engine. Some quick notes: - You may need to adapt your code slightly when creating the DLCLive instance. -- Processors that were created for TensorFlow models will function the same way with -PyTorch models. As multi-animal models can be used with PyTorch, the shape of the `pose` +- Processors that were created for TensorFlow models will function the same way with +PyTorch models. As multi-animal models can be used with PyTorch, the shape of the `pose` array given to the processor may be `(num_individuals, num_keypoints, 3)`. Just call `DLCLive(..., single_animal=True)` and it will work. ### Benchmarking/Analyzing your exported DeepLabCut models -DeepLabCut-live offers some analysis tools that allow users to perform the following +DeepLabCut-live offers some analysis tools that allow users to perform the following operations on videos, from python or from the command line: #### Test inference speed across a range of image sizes -Downsizing images can be done by specifying the `resize` or `pixels` parameter. Using -the `pixels` parameter will resize images to the desired number of `pixels`, without +Downsizing images can be done by specifying the `resize` or `pixels` parameter. Using +the `pixels` parameter will resize images to the desired number of `pixels`, without changing the aspect ratio. Results will be saved (along with system info) to a pickle file if you specify an output directory. @@ -192,7 +192,7 @@ Inside a **python** shell or script, you can run: ```python dlclive.benchmark_videos( - "/path/to/exported/model", + "/path/to/exported/model", ["/path/to/video1", "/path/to/video2"], output="/path/to/output", resize=[1.0, 0.75, '0.5'], @@ -211,7 +211,7 @@ Inside a **python** shell or script, you can run: ```python dlclive.benchmark_videos( - "/path/to/exported/model", + "/path/to/exported/model", "/path/to/video", resize=0.5, display=True, @@ -229,7 +229,7 @@ dlc-live-benchmark /path/to/exported/model /path/to/video -r 0.5 --display --pcu #### Analyze and create a labeled video using the exported model and desired resize parameters. 
-This option functions similar to `deeplabcut.benchmark_videos` and +This option functions similar to `deeplabcut.benchmark_videos` and `deeplabcut.create_labeled_video` (note, this is slow and only for testing purposes). Inside a **python** shell or script, you can run: @@ -255,9 +255,9 @@ dlc-live-benchmark /path/to/exported/model /path/to/video -r 0.5 --pcutoff 0.5 - ## License: -This project is licensed under the GNU AGPLv3. Note that the software is provided "as -is", without warranty of any kind, express or implied. If you use the code or data, we -ask that you please cite us! This software is available for licensing via the EPFL +This project is licensed under the GNU AGPLv3. Note that the software is provided "as +is", without warranty of any kind, express or implied. If you use the code or data, we +ask that you please cite us! This software is available for licensing via the EPFL Technology Transfer Office (https://tto.epfl.ch/, info.tto@epfl.ch). ## Community Support, Developers, & Help: @@ -270,9 +270,9 @@ https://github.com/DeepLabCut/DeepLabCut/blob/master/CONTRIBUTING.md), which is at the main repository of DeepLabCut. - We are a community partner on the [![Image.sc forum](https://img.shields.io/badge/dynamic/json.svg?label=forum&url=https%3A%2F%2Fforum.image.sc%2Ftags%2Fdeeplabcut.json&query=%24.topic_list.tags.0.topic_count&colorB=brightgreen&&suffix=%20topics&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAA4AAAAOCAYAAAAfSC3RAAABPklEQVR42m3SyyqFURTA8Y2BER0TDyExZ+aSPIKUlPIITFzKeQWXwhBlQrmFgUzMMFLKZeguBu5y+//17dP3nc5vuPdee6299gohUYYaDGOyyACq4JmQVoFujOMR77hNfOAGM+hBOQqB9TjHD36xhAa04RCuuXeKOvwHVWIKL9jCK2bRiV284QgL8MwEjAneeo9VNOEaBhzALGtoRy02cIcWhE34jj5YxgW+E5Z4iTPkMYpPLCNY3hdOYEfNbKYdmNngZ1jyEzw7h7AIb3fRTQ95OAZ6yQpGYHMMtOTgouktYwxuXsHgWLLl+4x++Kx1FJrjLTagA77bTPvYgw1rRqY56e+w7GNYsqX6JfPwi7aR+Y5SA+BXtKIRfkfJAYgj14tpOF6+I46c4/cAM3UhM3JxyKsxiOIhH0IO6SH/A1Kb1WBeUjbkAAAAAElFTkSuQmCC)](https://forum.image.sc/tags/deeplabcut). Please post help and support questions on the forum with the tag DeepLabCut. Check out their mission -statement [Scientific Community Image Forum: A discussion forum for scientific image +statement [Scientific Community Image Forum: A discussion forum for scientific image software](https://journals.plos.org/plosbiology/article?id=10.1371/journal.pbio.3000340). -- If you encounter a previously unreported bug/code issue, please post here (we +- If you encounter a previously unreported bug/code issue, please post here (we encourage you to search issues first): [github.com/DeepLabCut/DeepLabCut-live/issues]( https://github.com/DeepLabCut/DeepLabCut-live/issues) - For quick discussions here: [![Gitter]( @@ -281,7 +281,7 @@ https://gitter.im/DeepLabCut/community?utm_source=badge&utm_medium=badge&utm_cam ### Reference: -If you utilize our tool, please [cite Kane et al, eLife 2020](https://elifesciences.org/articles/61909). The preprint is +If you utilize our tool, please [cite Kane et al, eLife 2020](https://elifesciences.org/articles/61909). The preprint is available here: https://www.biorxiv.org/content/10.1101/2020.08.04.236422v2 ``` diff --git a/benchmarking/run_dlclive_benchmark.py b/benchmarking/run_dlclive_benchmark.py index 859843b0..35e9b14f 100644 --- a/benchmarking/run_dlclive_benchmark.py +++ b/benchmarking/run_dlclive_benchmark.py @@ -8,15 +8,14 @@ # Script for running the official benchmark from Kane et al, 2020. 
# Please share your results at https://github.com/DeepLabCut/DLC-inferencespeed-benchmark -import os, pathlib import glob +import os +import pathlib from dlclive import benchmark_videos, download_benchmarking_data from dlclive.engine import Engine -datafolder = os.path.join( - pathlib.Path(__file__).parent.absolute(), "Data-DLC-live-benchmark" -) +datafolder = os.path.join(pathlib.Path(__file__).parent.absolute(), "Data-DLC-live-benchmark") if not os.path.isdir(datafolder): # only download if data doesn't exist! # Downloading data.... this takes a while (see terminal) @@ -44,7 +43,7 @@ video_path=dog_video, output=out_dir, n_frames=n_frames, - pixels=pixels + pixels=pixels, ) for model_path in mouse_models: @@ -54,5 +53,5 @@ video_path=mouse_video, output=out_dir, n_frames=n_frames, - pixels=pixels + pixels=pixels, ) diff --git a/dlclive/__init__.py b/dlclive/__init__.py index c042b4f4..15d07319 100644 --- a/dlclive/__init__.py +++ b/dlclive/__init__.py @@ -8,10 +8,23 @@ # Check which backends are installed and get available backends # (Emits a warning if neither TensorFlow nor PyTorch is installed) from dlclive.utils import get_available_backends + _AVAILABLE_BACKENDS = get_available_backends() +from dlclive.benchmark import benchmark_videos, download_benchmarking_data from dlclive.display import Display from dlclive.dlclive import DLCLive +from dlclive.engine import Engine from dlclive.processor.processor import Processor from dlclive.version import VERSION, __version__ -from dlclive.benchmark import benchmark_videos, download_benchmarking_data \ No newline at end of file + +__all__ = [ + "DLCLive", + "Display", + "Processor", + "Engine", + "benchmark_videos", + "download_benchmarking_data", + "VERSION", + "__version__", +] diff --git a/dlclive/benchmark.py b/dlclive/benchmark.py index 8b4a0b45..fd6dbb3b 100644 --- a/dlclive/benchmark.py +++ b/dlclive/benchmark.py @@ -5,28 +5,34 @@ Licensed under GNU Lesser General Public License v3.0 """ +import argparse +import os +import pickle import platform import subprocess import sys import time import warnings from pathlib import Path +from typing import TYPE_CHECKING import colorcet as cc import cv2 import numpy as np -import pickle +import torch from PIL import ImageColor from pip._internal.operations import freeze -import torch from tqdm import tqdm -from dlclive import DLCLive -from dlclive import VERSION -from dlclive import __file__ as dlcfile from dlclive.engine import Engine from dlclive.utils import decode_fourcc +from .dlclive import DLCLive +from .version import VERSION + +if TYPE_CHECKING: + import tensorflow + def download_benchmarking_data( target_dir=".", @@ -49,6 +55,7 @@ def download_benchmarking_data( if os.path.exists(zip_path): print(f"{zip_path} already exists. Skipping download.") else: + def show_progress(count, block_size, total_size): pbar.update(block_size) @@ -59,7 +66,7 @@ def show_progress(count, block_size, total_size): pbar.close() print(f"Extracting {zip_path} to {target_dir} ...") - with zipfile.ZipFile(zip_path, 'r') as zip_ref: + with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(target_dir) @@ -106,15 +113,19 @@ def benchmark_videos( resize : int, optional resize factor. Can only use one of resize or pixels. If both are provided, will use pixels. by default None pixels : int, optional - downsize image to this number of pixels, maintaining aspect ratio. Can only use one of resize or pixels. If both are provided, will use pixels. 
by default None + downsize image to this number of pixels, maintaining aspect ratio. Can only use one of resize or pixels. + If both are provided, will use pixels. by default None cropping : list of int cropping parameters in pixel number: [x1, x2, y1, y2] dynamic: triple containing (state, detectiontreshold, margin) - If the state is true, then dynamic cropping will be performed. That means that if an object is detected (i.e. any body part > detectiontreshold), - then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is - expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost, i.e. detectiontreshold), + then object boundaries are computed according to the + smallest/largest x position and smallest/largest y position of all body parts. + This window is expanded by the margin and from then on only + the posture within this crop is analyzed (until the object is lost, i.e. `, by default "bmy" + a string indicating the :package:`colorcet` colormap, `options here `, + by default "bmy" save_poses : bool, optional - flag to save poses to an hdf5 file. If True, operates similar to :function:`DeepLabCut.benchmark_videos`, by default False + flag to save poses to an hdf5 file. If True, operates similar to :function:`DeepLabCut.benchmark_videos`, + by default False save_video : bool, optional - flag to save a labeled video. If True, operates similar to :function:`DeepLabCut.create_labeled_video`, by default False + flag to save a labeled video. + If True, operates similar to :function:`DeepLabCut.create_labeled_video`, by default False Example ------- @@ -138,13 +152,17 @@ def benchmark_videos( dlclive.benchmark_videos('/my/exported/model', 'my_video.avi', n_frames=10000) dlclive.benchmark_videos('/my/exported/model', ['my_video1.avi', 'my_video2.avi'], n_frames=10000) - Return a vector of inference times, testing full size and resizing images to half the width and height for inference, for two videos - dlclive.benchmark_videos('/my/exported/model', ['my_video1.avi', 'my_video2.avi'], n_frames=10000, resize=[1.0, 0.5]) + Return a vector of inference times, testing full size and resizing images + to half the width and height for inference, for two videos + dlclive.benchmark_videos( + '/my/exported/model', ['my_video1.avi', 'my_video2.avi'], n_frames=10000, resize=[1.0, 0.5] + ) Display keypoints to check the accuracy of an exported model dlclive.benchmark_videos('/my/exported/model', 'my_video.avi', display=True) - Analyze a video (save poses to hdf5) and create a labeled video, similar to :function:`DeepLabCut.benchmark_videos` and :function:`create_labeled_video` + Analyze a video (save poses to hdf5) and create a labeled video, + similar to :function:`DeepLabCut.benchmark_videos` and :function:`create_labeled_video` dlclive.benchmark_videos('/my/exported/model', 'my_video.avi', save_poses=True, save_video=True) """ # convert video_paths to list @@ -168,7 +186,7 @@ def benchmark_videos( im_size_out = [] for i in range(len(resize)): - print(f"\nRun {i+1} / {len(resize)}\n") + print(f"\nRun {i + 1} / {len(resize)}\n") this_inf_times, this_im_size, meta = benchmark( model_path=model_path, @@ -243,11 +261,7 @@ def get_system_info() -> dict: git_hash = None dlc_basedir = os.path.dirname(os.path.dirname(__file__)) try: - git_hash = ( - subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=dlc_basedir) - .decode("utf-8") - .strip() - ) + git_hash = 
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=dlc_basedir).decode("utf-8").strip() except subprocess.CalledProcessError: # Not installed from git repo, e.g., pypi pass @@ -275,9 +289,7 @@ def get_system_info() -> dict: } -def save_inf_times( - sys_info, inf_times, im_size, model=None, meta=None, output=None -): +def save_inf_times(sys_info, inf_times, im_size, model=None, meta=None, output=None): """Save inference time data collected using :function:`benchmark` with system information to a pickle file. This is primarily used through :function:`benchmark_videos` @@ -314,9 +326,7 @@ def save_inf_times( model_type = None fn_ind = 0 - base_name = ( - f"benchmark_{sys_info['host_name']}_{sys_info['device_type']}_{fn_ind}.pickle" - ) + base_name = f"benchmark_{sys_info['host_name']}_{sys_info['device_type']}_{fn_ind}.pickle" out_file = os.path.normpath(f"{output}/{base_name}") while os.path.isfile(out_file): fn_ind += 1 @@ -327,6 +337,7 @@ def save_inf_times( stats = zip( np.mean(inf_times, 1), np.std(inf_times, 1) * 1.0 / np.sqrt(np.shape(inf_times)[1]), + strict=False, ) data = { @@ -346,6 +357,7 @@ def save_inf_times( return True + def benchmark( model_path: str, model_type: str, @@ -357,8 +369,8 @@ def benchmark( single_animal: bool = True, cropping: list[int] | None = None, dynamic: tuple[bool, float, int] = (False, 0.5, 10), - n_frames: int =1000, - print_rate: bool=False, + n_frames: int = 1000, + print_rate: bool = False, precision: str = "FP32", display: bool = True, pcutoff: float = 0.5, @@ -370,7 +382,8 @@ def benchmark( draw_keypoint_names: bool = False, ): """ - Analyzes a video to track keypoints using a DeepLabCut model, and optionally saves the keypoint data and the labeled video. + Analyzes a video to track keypoints using a DeepLabCut model, + and optionally saves the keypoint data and the labeled video. Parameters ---------- @@ -387,7 +400,8 @@ def benchmark( device : str Pytorch only. Device to run the model on ('cpu' or 'cuda'). resize : float or None, optional - Resize dimensions for video frames. e.g. if resize = 0.5, the video will be processed in half the original size. If None, no resizing is applied. + Resize dimensions for video frames. e.g. if resize = 0.5, + the video will be processed in half the original size. If None, no resizing is applied. pixels : int, optional downsize image to this number of pixels, maintaining aspect ratio. Can only use one of resize or pixels. If both are provided, will use pixels. @@ -396,7 +410,14 @@ def benchmark( cropping : list of int or None, optional Cropping parameters [x1, x2, y1, y2] in pixels. If None, no cropping is applied. dynamic : tuple, optional, default=(False, 0.5, 10) (True/false), p cutoff, margin) - Parameters for dynamic cropping. If the state is true, then dynamic cropping will be performed. That means that if an object is detected (i.e. any body part > detectiontreshold), then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost, i.e. detectiontreshold), + then object boundaries are computed according to the smallest/largest x position and smallest/largest y + position of all body parts. This window is expanded by the margin and from then on only the posture within + this crop is analyzed (until the object is lost, i.e. 
0) and n_frames < total_n_frames - else total_n_frames - ) + n_frames = int(n_frames if (n_frames > 0) and n_frames < total_n_frames else total_n_frames) iterator = range(n_frames) if print_rate or display else tqdm(range(n_frames)) for _ in iterator: ret, frame = cap.read() if not ret: warnings.warn( - ( - "Did not complete {:d} frames. " - "There probably were not enough frames in the video {}." - ).format(n_frames, video_path) + f"Did not complete {n_frames:d} frames." + " There probably were not enough frames in the video {video_path}.", + stacklevel=2, ) break start_time = time.perf_counter() if frame_index == 0: - pose = dlc_live.init_inference(frame) # Loads model + pose = dlc_live.init_inference(frame) # Loads model else: pose = dlc_live.get_pose(frame) @@ -519,7 +535,7 @@ def benchmark( times.append(inf_time) if print_rate: - print("Inference rate = {:.3f} FPS".format(1 / inf_time), end="\r", flush=True) + print(f"Inference rate = {1 / inf_time:.3f} FPS", end="\r", flush=True) if save_video: draw_pose_and_write( @@ -531,19 +547,15 @@ def benchmark( pcutoff=pcutoff, display_radius=display_radius, draw_keypoint_names=draw_keypoint_names, - vwriter=vwriter + vwriter=vwriter, ) frame_index += 1 if print_rate: - print("Mean inference rate: {:.3f} FPS".format(np.mean(1 / np.array(times)[1:]))) + print(f"Mean inference rate: {np.mean(1 / np.array(times)[1:]):.3f} FPS") - metadata = _get_metadata( - video_path=video_path, - cap=cap, - dlc_live=dlc_live - ) + metadata = _get_metadata(video_path=video_path, cap=cap, dlc_live=dlc_live) cap.release() @@ -564,20 +576,17 @@ def benchmark( def setup_video_writer( - video_path:str, - save_dir:str, - timestamp:str, - num_keypoints:int, - cmap:str, - fps:float, - frame_size:tuple[int, int], + video_path: str, + save_dir: str, + timestamp: str, + num_keypoints: int, + cmap: str, + fps: float, + frame_size: tuple[int, int], ): # Set colors and convert to RGB cmap_colors = getattr(cc, cmap) - colors = [ - ImageColor.getrgb(color) - for color in cmap_colors[:: int(len(cmap_colors) / num_keypoints)] - ] + colors = [ImageColor.getrgb(color) for color in cmap_colors[:: int(len(cmap_colors) / num_keypoints)]] # Define output video path video_path = Path(video_path) @@ -595,6 +604,7 @@ def setup_video_writer( return colors, vwriter + def draw_pose_and_write( frame: np.ndarray, pose: np.ndarray, @@ -642,15 +652,10 @@ def draw_pose_and_write( lineType=cv2.LINE_AA, ) - vwriter.write(image=frame) -def _get_metadata( - video_path: str, - cap: cv2.VideoCapture, - dlc_live: DLCLive -): +def _get_metadata(video_path: str, cap: cv2.VideoCapture, dlc_live: DLCLive): try: fourcc = decode_fourcc(cap.get(cv2.CAP_PROP_FOURCC)) except Exception: @@ -708,7 +713,7 @@ def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, t ------- None """ - import pandas as pd + import pandas as pd # noqa E402 base_filename = Path(video_path).stem save_dir = Path(save_dir) @@ -719,9 +724,7 @@ def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, t flattened_poses = poses_array.reshape(poses_array.shape[0], -1) if n_individuals == 1: - pdindex = pd.MultiIndex.from_product( - [bodyparts, ["x", "y", "likelihood"]], names=["bodyparts", "coords"] - ) + pdindex = pd.MultiIndex.from_product([bodyparts, ["x", "y", "likelihood"]], names=["bodyparts", "coords"]) else: individuals = [f"individual_{i}" for i in range(n_individuals)] pdindex = pd.MultiIndex.from_product( @@ -733,6 +736,7 @@ def save_poses_to_files(video_path, save_dir, n_individuals, 
bodyparts, poses, t pose_df.to_hdf(h5_save_path, key="df_with_missing", mode="w") pose_df.to_csv(csv_save_path, index=False) + def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list): # Create numpy array with poses: max_frame = max(p["frame"] for p in poses) @@ -752,21 +756,13 @@ def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list): return poses_array -import argparse -import os - - def main(): """Provides a command line interface to benchmark_videos function.""" - parser = argparse.ArgumentParser( - description="Analyze a video using a DeepLabCut model and visualize keypoints." - ) + parser = argparse.ArgumentParser(description="Analyze a video using a DeepLabCut model and visualize keypoints.") parser.add_argument("model_path", type=str, help="Path to the model.") parser.add_argument("video_path", type=str, help="Path to the video file.") parser.add_argument("model_type", type=str, help="Type of the model (e.g., 'DLC').") - parser.add_argument( - "device", type=str, help="Device to run the model on (e.g., 'cuda' or 'cpu')." - ) + parser.add_argument("device", type=str, help="Device to run the model on (e.g., 'cuda' or 'cpu').") parser.add_argument( "-p", "--precision", @@ -774,9 +770,7 @@ def main(): default="FP32", help="Model precision (e.g., 'FP32', 'FP16').", ) - parser.add_argument( - "-d", "--display", action="store_true", help="Display keypoints on the video." - ) + parser.add_argument("-d", "--display", action="store_true", help="Display keypoints on the video.") parser.add_argument( "-c", "--pcutoff", @@ -814,9 +808,7 @@ def main(): default=[False, 0.5, 10], help="Dynamic cropping [flag, pcutoff, margin].", ) - parser.add_argument( - "--save-poses", action="store_true", help="Save the keypoint poses to files." - ) + parser.add_argument("--save-poses", action="store_true", help="Save the keypoint poses to files.") parser.add_argument( "--save-video", action="store_true", @@ -833,9 +825,7 @@ def main(): action="store_true", help="Draw keypoint names on the video.", ) - parser.add_argument( - "--cmap", type=str, default="bmy", help="Colormap for keypoints visualization." 
- ) + parser.add_argument("--cmap", type=str, default="bmy", help="Colormap for keypoints visualization.") parser.add_argument( "--no-sys-info", action="store_false", diff --git a/dlclive/check_install/check_install.py b/dlclive/check_install/check_install.py index ae3e5694..8ed8bb3e 100755 --- a/dlclive/check_install/check_install.py +++ b/dlclive/check_install/check_install.py @@ -7,16 +7,16 @@ import argparse import shutil +import urllib import warnings from pathlib import Path from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model import dlclive -from dlclive.utils import download_file from dlclive.benchmark import benchmark_videos from dlclive.engine import Engine -from dlclive.utils import get_available_backends +from dlclive.utils import download_file, get_available_backends MODEL_NAME = "superanimal_quadruped" SNAPSHOT_NAME = "snapshot-700000.pb" @@ -52,7 +52,7 @@ def main(): url_link = "https://raw.githubusercontent.com/DeepLabCut/DeepLabCut-live/master/check_install/dog_clip.avi" try: download_file(url_link, video_file) - except (urllib.error.URLError, IOError) as e: + except (OSError, urllib.error.URLError) as e: raise RuntimeError(f"Failed to download video file: {e}") from e else: print(f"Video file already exists at {video_file}, skipping download.") @@ -66,9 +66,7 @@ def main(): # assert these things exist so we can give informative error messages assert Path(video_file).exists(), f"Missing video file {video_file}" - assert Path( - model_dir / SNAPSHOT_NAME - ).exists(), f"Missing model file {model_dir / SNAPSHOT_NAME}" + assert Path(model_dir / SNAPSHOT_NAME).exists(), f"Missing model file {model_dir / SNAPSHOT_NAME}" # run benchmark videos print("\n Running inference...\n") @@ -78,7 +76,7 @@ def main(): video_path=video_file, display=display, resize=0.5, - pcutoff=0.25 + pcutoff=0.25, ) # deleting temporary files @@ -87,20 +85,20 @@ def main(): shutil.rmtree(tmp_dir) except PermissionError: warnings.warn( - f"Could not delete temporary directory {str(tmp_dir)} due to a permissions error, but otherwise dlc-live seems to be working fine!" + f"Could not delete temporary directory {str(tmp_dir)} due to a permissions error.", + stacklevel=2, ) print("\nDone!\n") if __name__ == "__main__": - # Get available backends (emits a warning if neither TensorFlow nor PyTorch is installed) available_backends: list[Engine] = get_available_backends() print(f"Available backends: {[b.value for b in available_backends]}") # TODO: JR add support for PyTorch in check_install.py (requires some exported pytorch model to be downloaded) - if not Engine.TENSORFLOW in available_backends: + if Engine.TENSORFLOW not in available_backends: raise NotImplementedError( "TensorFlow is not installed. Currently check_install.py only supports testing the TensorFlow installation." ) diff --git a/dlclive/core/config.py b/dlclive/core/config.py index 1305cf94..38024fcd 100644 --- a/dlclive/core/config.py +++ b/dlclive/core/config.py @@ -9,6 +9,7 @@ # Licensed under GNU Lesser General Public License v3.0 # """Helpers for configuration file IO""" + from pathlib import Path import ruamel.yaml @@ -22,7 +23,7 @@ def read_yaml(file_path: str | Path) -> dict: "was not found. 
Please check the path to the exported model directory" ) - with open(file_path, "r") as f: + with open(file_path) as f: cfg = ruamel.yaml.YAML(typ="safe", pure=True).load(f) return cfg diff --git a/dlclive/core/inferenceutils.py b/dlclive/core/inferenceutils.py index 7b76b271..8a67ff73 100644 --- a/dlclive/core/inferenceutils.py +++ b/dlclive/core/inferenceutils.py @@ -66,9 +66,7 @@ def __init__(self, j1, j2, affinity=1): self._length = sqrt((j1.pos[0] - j2.pos[0]) ** 2 + (j1.pos[1] - j2.pos[1]) ** 2) def __repr__(self): - return ( - f"Link {self.idx}, affinity={self.affinity:.2f}, length={self.length:.2f}" - ) + return f"Link {self.idx}, affinity={self.affinity:.2f}, length={self.length:.2f}" @property def confidence(self): @@ -350,9 +348,7 @@ def calibrate(self, train_data_file): pass n_bpts = len(df.columns.get_level_values("bodyparts").unique()) if n_bpts == 1: - warnings.warn( - "There is only one keypoint; skipping calibration...", stacklevel=2 - ) + warnings.warn("There is only one keypoint; skipping calibration...", stacklevel=2) return xy = df.to_numpy().reshape((-1, n_bpts, 2)) @@ -360,9 +356,7 @@ def calibrate(self, train_data_file): # Only keeps skeletons that are more than 90% complete xy = xy[frac_valid >= 0.9] if not xy.size: - warnings.warn( - "No complete poses were found. Skipping calibration...", stacklevel=2 - ) + warnings.warn("No complete poses were found. Skipping calibration...", stacklevel=2) return # TODO Normalize dists by longest length? @@ -383,9 +377,7 @@ def calibrate(self, train_data_file): stacklevel=2, ) - def calc_assembly_mahalanobis_dist( - self, assembly, return_proba=False, nan_policy="little" - ): + def calc_assembly_mahalanobis_dist(self, assembly, return_proba=False, nan_policy="little"): if self._kde is None: raise ValueError("Assembler should be calibrated first with training data.") @@ -439,9 +431,7 @@ def _flatten_detections(data_dict): ids = [np.ones(len(arr), dtype=int) * -1 for arr in confidence] else: ids = [arr.argmax(axis=1) for arr in ids] - for i, (coords, conf, id_) in enumerate( - zip(coordinates, confidence, ids, strict=False) - ): + for i, (coords, conf, id_) in enumerate(zip(coordinates, confidence, ids, strict=False)): if not np.any(coords): continue for xy, p, g in zip(coords, conf, id_, strict=False): @@ -466,9 +456,7 @@ def extract_best_links(self, joints_dict, costs, trees=None): aff[np.isnan(aff)] = 0 if trees: - vecs = np.vstack( - [[*det_s.pos, *det_t.pos] for det_s in dets_s for det_t in dets_t] - ) + vecs = np.vstack([[*det_s.pos, *det_t.pos] for det_s in dets_s for det_t in dets_t]) dists = [] for n, tree in enumerate(trees, start=1): d, _ = tree.query(vecs) @@ -477,15 +465,8 @@ def extract_best_links(self, joints_dict, costs, trees=None): aff *= w.reshape(aff.shape) if self.greedy: - conf = np.asarray( - [ - [det_s.confidence * det_t.confidence for det_t in dets_t] - for det_s in dets_s - ] - ) - rows, cols = np.where( - (conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity) - ) + conf = np.asarray([[det_s.confidence * det_t.confidence for det_t in dets_t] for det_s in dets_s]) + rows, cols = np.where((conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity)) candidates = sorted( zip(rows, cols, aff[rows, cols], lengths[rows, cols], strict=False), key=lambda x: x[2], @@ -501,18 +482,14 @@ def extract_best_links(self, joints_dict, costs, trees=None): if len(i_seen) == self.max_n_individuals: break else: # Optimal keypoint pairing - inds_s = sorted( - range(len(dets_s)), key=lambda x: 
dets_s[x].confidence, reverse=True - )[: self.max_n_individuals] - inds_t = sorted( - range(len(dets_t)), key=lambda x: dets_t[x].confidence, reverse=True - )[: self.max_n_individuals] - keep_s = [ - ind for ind in inds_s if dets_s[ind].confidence >= self.pcutoff + inds_s = sorted(range(len(dets_s)), key=lambda x: dets_s[x].confidence, reverse=True)[ + : self.max_n_individuals ] - keep_t = [ - ind for ind in inds_t if dets_t[ind].confidence >= self.pcutoff + inds_t = sorted(range(len(dets_t)), key=lambda x: dets_t[x].confidence, reverse=True)[ + : self.max_n_individuals ] + keep_s = [ind for ind in inds_s if dets_s[ind].confidence >= self.pcutoff] + keep_t = [ind for ind in inds_t if dets_t[ind].confidence >= self.pcutoff] aff = aff[np.ix_(keep_s, keep_t)] rows, cols = linear_sum_assignment(aff, maximize=True) for row, col in zip(rows, cols, strict=False): @@ -551,9 +528,7 @@ def push_to_stack(i): if new_ind in assembled: continue if safe_edge: - d_old = self.calc_assembly_mahalanobis_dist( - assembly, nan_policy=nan_policy - ) + d_old = self.calc_assembly_mahalanobis_dist(assembly, nan_policy=nan_policy) success = assembly.add_link(best, store_dict=True) if not success: assembly._dict = dict() @@ -606,9 +581,7 @@ def build_assemblies(self, links): continue assembly = Assembly(self.n_multibodyparts) assembly.add_link(link) - self._fill_assembly( - assembly, lookup, assembled, self.safe_edge, self.nan_policy - ) + self._fill_assembly(assembly, lookup, assembled, self.safe_edge, self.nan_policy) for assembly_link in assembly._links: i, j = assembly_link.idx lookup[i].pop(j) @@ -620,10 +593,7 @@ def build_assemblies(self, links): n_extra = len(assemblies) - self.max_n_individuals if n_extra > 0: if self.safe_edge: - ds_old = [ - self.calc_assembly_mahalanobis_dist(assembly) - for assembly in assemblies - ] + ds_old = [self.calc_assembly_mahalanobis_dist(assembly) for assembly in assemblies] while len(assemblies) > self.max_n_individuals: ds = [] for i, j in itertools.combinations(range(len(assemblies)), 2): @@ -755,10 +725,7 @@ def _assemble(self, data_dict, ind_frame): for _, group in groups: ass = Assembly(self.n_multibodyparts) for joint in sorted(group, key=lambda x: x.confidence, reverse=True): - if ( - joint.confidence >= self.pcutoff - and joint.label < self.n_multibodyparts - ): + if joint.confidence >= self.pcutoff and joint.label < self.n_multibodyparts: ass.add_joint(joint) if len(ass): assemblies.append(ass) @@ -787,11 +754,7 @@ def _assemble(self, data_dict, ind_frame): assembled.update(assembled_) # Remove invalid assemblies - discarded = set( - joint - for joint in joints - if joint.idx not in assembled and np.isfinite(joint.confidence) - ) + discarded = set(joint for joint in joints if joint.idx not in assembled and np.isfinite(joint.confidence)) for assembly in assemblies[::-1]: if 0 < assembly.n_links < self.min_n_links or not len(assembly): for link in assembly._links: @@ -799,9 +762,7 @@ def _assemble(self, data_dict, ind_frame): assemblies.remove(assembly) if 0 < self.max_overlap < 1: # Non-maximum pose suppression if self._kde is not None: - scores = [ - -self.calc_assembly_mahalanobis_dist(ass) for ass in assemblies - ] + scores = [-self.calc_assembly_mahalanobis_dist(ass) for ass in assemblies] else: scores = [ass._affinity for ass in assemblies] lst = list(zip(scores, assemblies, strict=False)) @@ -870,9 +831,7 @@ def wrapped(i): n_frames = len(self.metadata["imnames"]) with multiprocessing.Pool(n_processes) as p: with tqdm(total=n_frames) as pbar: - for i, 
(assemblies, unique) in p.imap_unordered( - wrapped, range(n_frames), chunksize=chunk_size - ): + for i, (assemblies, unique) in p.imap_unordered(wrapped, range(n_frames), chunksize=chunk_size): if assemblies: self.assemblies[i] = assemblies if unique is not None: @@ -891,9 +850,7 @@ def parse_metadata(data): params["joint_names"] = data["metadata"]["all_joints_names"] params["num_joints"] = len(params["joint_names"]) params["paf_graph"] = data["metadata"]["PAFgraph"] - params["paf"] = data["metadata"].get( - "PAFinds", np.arange(len(params["joint_names"])) - ) + params["paf"] = data["metadata"].get("PAFinds", np.arange(len(params["joint_names"]))) params["bpts"] = params["ibpts"] = range(params["num_joints"]) params["imnames"] = [fn for fn in list(data) if fn != "metadata"] return params @@ -983,11 +940,7 @@ def calc_object_keypoint_similarity( else: oks = [] xy_preds = [xy_pred] - combos = ( - pair - for l in range(len(symmetric_kpts)) - for pair in itertools.combinations(symmetric_kpts, l + 1) - ) + combos = (pair for l in range(len(symmetric_kpts)) for pair in itertools.combinations(symmetric_kpts, l + 1)) for pairs in combos: # Swap corresponding keypoints tmp = xy_pred.copy() @@ -1024,9 +977,7 @@ def match_assemblies( num_ground_truth = len(ground_truth) # Sort predictions by score - inds_pred = np.argsort( - [ins.affinity if ins.n_links else ins.confidence for ins in predictions] - )[::-1] + inds_pred = np.argsort([ins.affinity if ins.n_links else ins.confidence for ins in predictions])[::-1] predictions = np.asarray(predictions)[inds_pred] # indices of unmatched ground truth assemblies @@ -1133,9 +1084,7 @@ def find_outlier_assemblies(dict_of_assemblies, criterion="area", qs=(5, 95)): raise ValueError(f"Invalid criterion {criterion}.") if len(qs) != 2: - raise ValueError( - "Two percentiles (for lower and upper bounds) should be given." - ) + raise ValueError("Two percentiles (for lower and upper bounds) should be given.") tuples = [] for frame_ind, assemblies in dict_of_assemblies.items(): @@ -1239,9 +1188,7 @@ def evaluate_assembly_greedy( oks = np.asarray([match.oks for match in all_matched])[sorted_pred_indices] # Compute prediction and recall - p, r = _compute_precision_and_recall( - total_gt_assemblies, oks, oks_t, recall_thresholds - ) + p, r = _compute_precision_and_recall(total_gt_assemblies, oks, oks_t, recall_thresholds) precisions.append(p) recalls.append(r) @@ -1314,9 +1261,7 @@ def evaluate_assembly( precisions = [] recalls = [] for t in oks_thresholds: - p, r = _compute_precision_and_recall( - total_gt_assemblies, oks, t, recall_thresholds - ) + p, r = _compute_precision_and_recall(total_gt_assemblies, oks, t, recall_thresholds) precisions.append(p) recalls.append(r) diff --git a/dlclive/core/runner.py b/dlclive/core/runner.py index 878d5524..dda17281 100644 --- a/dlclive/core/runner.py +++ b/dlclive/core/runner.py @@ -9,6 +9,7 @@ # Licensed under GNU Lesser General Public License v3.0 # """Base runner for DeepLabCut-Live""" + import abc from pathlib import Path diff --git a/dlclive/display.py b/dlclive/display.py index 42abab41..d3f87b8e 100644 --- a/dlclive/display.py +++ b/dlclive/display.py @@ -35,9 +35,7 @@ class Display: def __init__(self, cmap="bmy", radius=3, pcutoff=0.5): if not _TKINTER_AVAILABLE: - raise ImportError( - "tkinter is not available. Display functionality requires tkinter. " - ) + raise ImportError("tkinter is not available. Display functionality requires tkinter. 
") self.cmap = cmap self.colors = None self.radius = radius @@ -100,9 +98,7 @@ def display_frame(self, frame, pose=None): y0 = max(0, pose[i, j, 1] - self.radius) y1 = min(im_size[1], pose[i, j, 1] + self.radius) coords = [x0, y0, x1, y1] - draw.ellipse( - coords, fill=self.colors[j], outline=self.colors[j] - ) + draw.ellipse(coords, fill=self.colors[j], outline=self.colors[j]) except Exception as e: print(e) img_tk = ImageTk.PhotoImage(image=img, master=self.window) diff --git a/dlclive/engine.py b/dlclive/engine.py index eed0af13..487d61ad 100644 --- a/dlclive/engine.py +++ b/dlclive/engine.py @@ -1,6 +1,7 @@ from enum import Enum from pathlib import Path + class Engine(Enum): TENSORFLOW = "tensorflow" PYTORCH = "pytorch" @@ -30,4 +31,4 @@ def from_model_path(cls, model_path: str | Path) -> "Engine": if path.suffix == ".pt": return cls.PYTORCH - raise ValueError(f"Could not determine engine from model path: {model_path}") \ No newline at end of file + raise ValueError(f"Could not determine engine from model path: {model_path}") diff --git a/dlclive/factory.py b/dlclive/factory.py index df70029c..9cf6462a 100644 --- a/dlclive/factory.py +++ b/dlclive/factory.py @@ -1,4 +1,5 @@ """Factory to build runners for DeepLabCut-Live inference""" + from __future__ import annotations from pathlib import Path diff --git a/dlclive/live_inference.py b/dlclive/live_inference.py index 6db75977..7a4ffd14 100644 --- a/dlclive/live_inference.py +++ b/dlclive/live_inference.py @@ -55,11 +55,7 @@ def get_system_info() -> dict: git_hash = None dlc_basedir = os.path.dirname(os.path.dirname(__file__)) try: - git_hash = ( - subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=dlc_basedir) - .decode("utf-8") - .strip() - ) + git_hash = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=dlc_basedir).decode("utf-8").strip() except subprocess.CalledProcessError: # Not installed from git repo, e.g., pypi pass @@ -109,7 +105,8 @@ def analyze_live_video( save_video=False, ): """ - Analyzes a video to track keypoints using a DeepLabCut model, and optionally saves the keypoint data and the labeled video. + Analyzes a video to track keypoints using a DeepLabCut model, + and optionally saves the keypoint data and the labeled video. Parameters ---------- @@ -132,11 +129,21 @@ def analyze_live_video( display_radius : int, optional, default=5 Radius of circles drawn for keypoints on video frames. resize : tuple of int (width, height) or None, optional - Resize dimensions for video frames. e.g. if resize = 0.5, the video will be processed in half the original size. If None, no resizing is applied. + Resize dimensions for video frames. e.g. if resize = 0.5, + the video will be processed in half the original size. + If None, no resizing is applied. cropping : list of int or None, optional Cropping parameters [x1, x2, y1, y2] in pixels. If None, no cropping is applied. dynamic : tuple, optional, default=(False, 0.5, 10) (True/false), p cutoff, margin) - Parameters for dynamic cropping. If the state is true, then dynamic cropping will be performed. That means that if an object is detected (i.e. any body part > detectiontreshold), then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost, i.e. 
detectiontreshold), + then object boundaries are computed according to the + smallest/largest x position and smallest/largest y position of all body parts. + This window is expanded by the margin and from then on only the posture within + this crop is analyzed (until the object is lost, i.e. Or """Download the model weights from huggingface and load them in torch state dict""" checkpoint: Path = download_super_animal_snapshot(dataset=super_animal, model_name=model_name) return torch.load(checkpoint, map_location="cpu", weights_only=True)["model"] - + export_dict = { "config": model_cfg, "pose": _load_model_weights(model_name), @@ -52,14 +52,14 @@ def _load_model_weights(model_name: str, super_animal: str = super_animal) -> Or if __name__ == "__main__": - """Example usage""" + """Example usage""" from utils import _MODELZOO_PATH - + model_name = "resnet_50" super_animal = "superanimal_quadruped" export_modelzoo_model( - export_path=_MODELZOO_PATH / 'exported_models' / f'exported_{super_animal}_{model_name}.pt', + export_path=_MODELZOO_PATH / "exported_models" / f"exported_{super_animal}_{model_name}.pt", super_animal=super_animal, model_name=model_name, ) diff --git a/dlclive/modelzoo/resolve_config.py b/dlclive/modelzoo/resolve_config.py index cf11f3b3..7eee5e7f 100644 --- a/dlclive/modelzoo/resolve_config.py +++ b/dlclive/modelzoo/resolve_config.py @@ -99,9 +99,7 @@ def get_updated_value(variable: str) -> int | list[int]: else: raise ValueError(f"Unknown operator for variable: {variable}") - raise ValueError( - f"Found {variable} in the configuration file, but cannot parse it." - ) + raise ValueError(f"Found {variable} in the configuration file, but cannot parse it.") updated_values = { "num_bodyparts": num_bodyparts, @@ -127,10 +125,7 @@ def get_updated_value(variable: str) -> int | list[int]: backbone_output_channels, **kwargs, ) - elif ( - isinstance(config[k], str) - and config[k].strip().split(" ")[0] in updated_values.keys() - ): + elif isinstance(config[k], str) and config[k].strip().split(" ")[0] in updated_values.keys(): config[k] = get_updated_value(config[k]) return config diff --git a/dlclive/modelzoo/utils.py b/dlclive/modelzoo/utils.py index f9bf2f71..548cd41e 100644 --- a/dlclive/modelzoo/utils.py +++ b/dlclive/modelzoo/utils.py @@ -8,8 +8,8 @@ import logging from pathlib import Path -from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model from dlclibrary.dlcmodelzoo.modelzoo_download import _load_model_names as huggingface_model_paths +from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model from ruamel.yaml import YAML from dlclive.modelzoo.resolve_config import update_config @@ -87,8 +87,7 @@ def add_metadata( config["metadata"] = { "project_path": project_config["project_path"], "pose_config_path": "", - "bodyparts": project_config.get("multianimalbodyparts") - or project_config["bodyparts"], + "bodyparts": project_config.get("multianimalbodyparts") or project_config["bodyparts"], "unique_bodyparts": project_config.get("uniquebodyparts", []), "individuals": project_config.get("individuals", ["animal"]), "with_identity": project_config.get("identity", False), @@ -130,9 +129,7 @@ def load_super_animal_config( else: model_config["method"] = "TD" if super_animal != "superanimal_humanbody": - detector_cfg_path = get_super_animal_model_config_path( - model_name=detector_name - ) + detector_cfg_path = get_super_animal_model_config_path(model_name=detector_name) detector_cfg = read_config_as_dict(detector_cfg_path) 
model_config["detector"] = detector_cfg return model_config @@ -161,17 +158,13 @@ def download_super_animal_snapshot(dataset: str, model_name: str) -> Path: return model_path try: - download_huggingface_model( - model_name, target_dir=str(snapshot_dir), rename_mapping=model_filename - ) + download_huggingface_model(model_name, target_dir=str(snapshot_dir), rename_mapping=model_filename) if not model_path.exists(): raise RuntimeError(f"Failed to download {model_name} to {model_path}") except Exception as e: - logging.error( - f"Failed to download superanimal snapshot {model_name} to {model_path}: {e}" - ) + logging.error(f"Failed to download superanimal snapshot {model_name} to {model_path}: {e}") raise e return model_path diff --git a/dlclive/pose_estimation_pytorch/data/__init__.py b/dlclive/pose_estimation_pytorch/data/__init__.py index 2fabb6ed..ad09d20c 100644 --- a/dlclive/pose_estimation_pytorch/data/__init__.py +++ b/dlclive/pose_estimation_pytorch/data/__init__.py @@ -4,6 +4,7 @@ Licensed under GNU Lesser General Public License v3.0 """ + from dlclive.pose_estimation_pytorch.data.image import ( top_down_crop, top_down_crop_torch, diff --git a/dlclive/pose_estimation_pytorch/data/image.py b/dlclive/pose_estimation_pytorch/data/image.py index 21187cc8..0c7880d3 100644 --- a/dlclive/pose_estimation_pytorch/data/image.py +++ b/dlclive/pose_estimation_pytorch/data/image.py @@ -101,9 +101,7 @@ def top_down_crop( out_w, out_h = output_size bbox = fix_bbox_aspect_ratio(bbox, margin, out_w, out_h) - x1, y1, x2, y2, pad_left, pad_top, pad_x, pad_y = crop_corners( - bbox, img_size, center_padding - ) + x1, y1, x2, y2, pad_left, pad_top, pad_x, pad_y = crop_corners(bbox, img_size, center_padding) w, h = x2 - x1, y2 - y1 crop_w, crop_h = w + pad_x, h + pad_y @@ -164,4 +162,4 @@ def forward(self, img: torch.Tensor) -> torch.Tensor: padding = (0, 0, pad_w, pad_h) # Warning: this method returns the batched image, regardless if its input was batched or not - return F.pad(img, padding, padding_mode="reflect") \ No newline at end of file + return F.pad(img, padding, padding_mode="reflect") diff --git a/dlclive/pose_estimation_pytorch/dynamic_cropping.py b/dlclive/pose_estimation_pytorch/dynamic_cropping.py index 45726348..16925c91 100644 --- a/dlclive/pose_estimation_pytorch/dynamic_cropping.py +++ b/dlclive/pose_estimation_pytorch/dynamic_cropping.py @@ -82,9 +82,7 @@ def crop(self, image: torch.Tensor) -> torch.Tensor: height. """ if len(image) != 1: - raise RuntimeError( - f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})" - ) + raise RuntimeError(f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})") if self._shape is None: self._shape = image.shape[3], image.shape[2] @@ -309,9 +307,7 @@ def crop(self, image: torch.Tensor) -> torch.Tensor: `crop` was previously called with an image of a different W or H. 
""" if len(image) != 1: - raise RuntimeError( - f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})" - ) + raise RuntimeError(f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})") if self._shape is None: self._shape = image.shape[3], image.shape[2] @@ -398,9 +394,7 @@ def update(self, pose: torch.Tensor) -> torch.Tensor: return pose - def _prepare_bounding_box( - self, x1: int, y1: int, x2: int, y2: int - ) -> tuple[int, int, int, int]: + def _prepare_bounding_box(self, x1: int, y1: int, x2: int, y2: int) -> tuple[int, int, int, int]: """Prepares the bounding box for cropping. Adds a margin around the bounding box, then transforms it into the target aspect @@ -497,12 +491,8 @@ def generate_patches(self) -> list[tuple[int, int, int, int]]: Returns: A list of patch coordinates as tuples (x0, y0, x1, y1). """ - patch_xs = self.split_array( - self._shape[0], self._patch_counts[0], self._patch_overlap - ) - patch_ys = self.split_array( - self._shape[1], self._patch_counts[1], self._patch_overlap - ) + patch_xs = self.split_array(self._shape[0], self._patch_counts[0], self._patch_overlap) + patch_ys = self.split_array(self._shape[1], self._patch_counts[1], self._patch_overlap) patches = [] for y0, y1 in patch_ys: diff --git a/dlclive/pose_estimation_pytorch/models/__init__.py b/dlclive/pose_estimation_pytorch/models/__init__.py index edd4e274..c54f4f75 100644 --- a/dlclive/pose_estimation_pytorch/models/__init__.py +++ b/dlclive/pose_estimation_pytorch/models/__init__.py @@ -5,5 +5,5 @@ Licensed under GNU Lesser General Public License v3.0 """ -from dlclive.pose_estimation_pytorch.models.model import PoseModel from dlclive.pose_estimation_pytorch.models.detectors import DETECTORS, BaseDetector +from dlclive.pose_estimation_pytorch.models.model import PoseModel diff --git a/dlclive/pose_estimation_pytorch/models/backbones/base.py b/dlclive/pose_estimation_pytorch/models/backbones/base.py index 4bc6b4f8..d4b30641 100644 --- a/dlclive/pose_estimation_pytorch/models/backbones/base.py +++ b/dlclive/pose_estimation_pytorch/models/backbones/base.py @@ -19,7 +19,7 @@ import torch.nn as nn from huggingface_hub import hf_hub_download -from dlclive.pose_estimation_pytorch.models.registry import build_from_cfg, Registry +from dlclive.pose_estimation_pytorch.models.registry import Registry, build_from_cfg BACKBONES = Registry("backbones", build_func=build_from_cfg) @@ -121,11 +121,7 @@ def download_weights(self, filename: str, force: bool = False) -> Path: logging.info(f"Downloading the pre-trained backbone to {model_path}") self.backbone_weight_folder.mkdir(exist_ok=True, parents=False) - output_path = Path( - hf_hub_download( - self.repo_id, filename, cache_dir=self.backbone_weight_folder - ) - ) + output_path = Path(hf_hub_download(self.repo_id, filename, cache_dir=self.backbone_weight_folder)) # resolve gets the actual path if the output path is a symlink output_path = output_path.resolve() diff --git a/dlclive/pose_estimation_pytorch/models/backbones/cspnext.py b/dlclive/pose_estimation_pytorch/models/backbones/cspnext.py index 681f2ba2..69a28b77 100644 --- a/dlclive/pose_estimation_pytorch/models/backbones/cspnext.py +++ b/dlclive/pose_estimation_pytorch/models/backbones/cspnext.py @@ -16,6 +16,7 @@ For more details about this architecture, see `RTMDet: An Empirical Study of Designing Real-Time Object Detectors`: https://arxiv.org/abs/1711.05101. 
""" + from dataclasses import dataclass import torch @@ -99,10 +100,7 @@ def __init__( ) -> None: super().__init__(stride=32, **kwargs) if arch not in self.ARCH: - raise ValueError( - f"Unknown `CSPNeXT` architecture: {arch}. Must be one of " - f"{self.ARCH.keys()}" - ) + raise ValueError(f"Unknown `CSPNeXT` architecture: {arch}. Must be one of {self.ARCH.keys()}") self.model_name = model_name self.layer_configs = self.ARCH[arch] diff --git a/dlclive/pose_estimation_pytorch/models/backbones/resnet.py b/dlclive/pose_estimation_pytorch/models/backbones/resnet.py index f6611596..d5dfb648 100644 --- a/dlclive/pose_estimation_pytorch/models/backbones/resnet.py +++ b/dlclive/pose_estimation_pytorch/models/backbones/resnet.py @@ -93,21 +93,11 @@ def __init__( self.interm_features = {} self.model.layer1[2].register_forward_hook(self._get_features("bank1")) self.model.layer2[2].register_forward_hook(self._get_features("bank2")) - self.conv_block1 = self._make_conv_block( - in_channels=512, out_channels=512, kernel_size=3, stride=2 - ) - self.conv_block2 = self._make_conv_block( - in_channels=512, out_channels=128, kernel_size=1, stride=1 - ) - self.conv_block3 = self._make_conv_block( - in_channels=256, out_channels=256, kernel_size=3, stride=2 - ) - self.conv_block4 = self._make_conv_block( - in_channels=256, out_channels=256, kernel_size=3, stride=2 - ) - self.conv_block5 = self._make_conv_block( - in_channels=256, out_channels=128, kernel_size=1, stride=1 - ) + self.conv_block1 = self._make_conv_block(in_channels=512, out_channels=512, kernel_size=3, stride=2) + self.conv_block2 = self._make_conv_block(in_channels=512, out_channels=128, kernel_size=1, stride=1) + self.conv_block3 = self._make_conv_block(in_channels=256, out_channels=256, kernel_size=3, stride=2) + self.conv_block4 = self._make_conv_block(in_channels=256, out_channels=256, kernel_size=3, stride=2) + self.conv_block5 = self._make_conv_block(in_channels=256, out_channels=128, kernel_size=1, stride=1) def _make_conv_block( self, @@ -118,9 +108,7 @@ def _make_conv_block( momentum: float = 0.001, # (1 - decay) ) -> torch.nn.Sequential: return nn.Sequential( - nn.Conv2d( - in_channels, out_channels, kernel_size=kernel_size, stride=stride - ), + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), ) diff --git a/dlclive/pose_estimation_pytorch/models/detectors/fasterRCNN.py b/dlclive/pose_estimation_pytorch/models/detectors/fasterRCNN.py index f250b9ac..f0b236c2 100644 --- a/dlclive/pose_estimation_pytorch/models/detectors/fasterRCNN.py +++ b/dlclive/pose_estimation_pytorch/models/detectors/fasterRCNN.py @@ -69,6 +69,4 @@ def __init__( # Modify the base predictor to output the correct number of classes num_classes = 2 in_features = self.model.roi_heads.box_predictor.cls_score.in_features - self.model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor( - in_features, num_classes - ) + self.model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes) diff --git a/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py b/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py index 72dd54b8..b44b4c39 100644 --- a/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py +++ b/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py @@ -9,6 +9,7 @@ # Licensed under GNU Lesser General Public License v3.0 # """Module to adapt torchvision detectors for DeepLabCut""" + from 
__future__ import annotations import torch diff --git a/dlclive/pose_estimation_pytorch/models/heads/dekr.py b/dlclive/pose_estimation_pytorch/models/heads/dekr.py index 1ef1ec14..bc429db7 100644 --- a/dlclive/pose_estimation_pytorch/models/heads/dekr.py +++ b/dlclive/pose_estimation_pytorch/models/heads/dekr.py @@ -88,21 +88,13 @@ def __init__( super().__init__() self.bn_momentum = 0.1 self.inp_channels = channels[0] - self.num_joints_with_center = channels[ - 2 - ] # Should account for the center being a joint + self.num_joints_with_center = channels[2] # Should account for the center being a joint self.final_conv_kernel = final_conv_kernel - self.transition_heatmap = self._make_transition_for_head( - self.inp_channels, channels[1] - ) - self.head_heatmap = self._make_heatmap_head( - block, num_blocks, channels[1], dilation_rate - ) + self.transition_heatmap = self._make_transition_for_head(self.inp_channels, channels[1]) + self.head_heatmap = self._make_heatmap_head(block, num_blocks, channels[1], dilation_rate) - def _make_transition_for_head( - self, in_channels: int, out_channels: int - ) -> nn.Sequential: + def _make_transition_for_head(self, in_channels: int, out_channels: int) -> nn.Sequential: """Summary: Construct the transition layer for the head. @@ -141,9 +133,7 @@ def _make_heatmap_head( """ heatmap_head_layers = [] - feature_conv = self._make_layer( - block, num_channels, num_channels, num_blocks, dilation=dilation_rate - ) + feature_conv = self._make_layer(block, num_channels, num_channels, num_blocks, dilation=dilation_rate) heatmap_head_layers.append(feature_conv) heatmap_conv = nn.Conv2d( @@ -190,14 +180,10 @@ def _make_layer( stride=stride, bias=False, ), - nn.BatchNorm2d( - out_channels * block.expansion, momentum=self.bn_momentum - ), + nn.BatchNorm2d(out_channels * block.expansion, momentum=self.bn_momentum), ) - layers = [ - block(in_channels, out_channels, stride, downsample, dilation=dilation) - ] + layers = [block(in_channels, out_channels, stride, downsample, dilation=dilation)] in_channels = out_channels * block.expansion for _ in range(1, num_blocks): layers.append(block(in_channels, out_channels, dilation=dilation)) @@ -251,9 +237,7 @@ def __init__( self.dilation_rate = dilation_rate self.final_conv_kernel = final_conv_kernel - self.transition_offset = self._make_transition_for_head( - self.inp_channels, self.offset_channels - ) + self.transition_offset = self._make_transition_for_head(self.inp_channels, self.offset_channels) ( self.offset_feature_layers, self.offset_final_layer, @@ -306,24 +290,18 @@ def _make_layer( stride=stride, bias=False, ), - nn.BatchNorm2d( - out_channels * block.expansion, momentum=self.bn_momentum - ), + nn.BatchNorm2d(out_channels * block.expansion, momentum=self.bn_momentum), ) layers = [] - layers.append( - block(in_channels, out_channels, stride, downsample, dilation=dilation) - ) + layers.append(block(in_channels, out_channels, stride, downsample, dilation=dilation)) in_channels = out_channels * block.expansion for _ in range(1, num_blocks): layers.append(block(in_channels, out_channels, dilation=dilation)) return nn.Sequential(*layers) - def _make_transition_for_head( - self, in_channels: int, out_channels: int - ) -> nn.Sequential: + def _make_transition_for_head(self, in_channels: int, out_channels: int) -> nn.Sequential: """Summary: Create a transition layer for the head. 
@@ -404,9 +382,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: final_offset.append( self.offset_final_layer[j]( self.offset_feature_layers[j]( - offset_feature[ - :, j * self.offset_perkpt : (j + 1) * self.offset_perkpt - ] + offset_feature[:, j * self.offset_perkpt : (j + 1) * self.offset_perkpt] ) ) ) diff --git a/dlclive/pose_estimation_pytorch/models/heads/dlcrnet.py b/dlclive/pose_estimation_pytorch/models/heads/dlcrnet.py index 8cfbd57a..b513b6dd 100644 --- a/dlclive/pose_estimation_pytorch/models/heads/dlcrnet.py +++ b/dlclive/pose_estimation_pytorch/models/heads/dlcrnet.py @@ -41,9 +41,7 @@ def __init__( num_limbs = paf_config["channels"][-1] # Already has the 2x multiplier in_refined_channels = features_dim + num_keypoints + num_limbs if num_stages > 0: - heatmap_config["channels"][0] = paf_config["channels"][ - 0 - ] = in_refined_channels + heatmap_config["channels"][0] = paf_config["channels"][0] = in_refined_channels locref_config["channels"][0] = locref_config["channels"][-1] super().__init__(predictor, heatmap_config, locref_config) @@ -52,35 +50,21 @@ def __init__( self.paf_head = DeconvModule(**paf_config) - self.convt1 = self._make_layer_same_padding( - in_channels=in_channels, out_channels=num_keypoints - ) - self.convt2 = self._make_layer_same_padding( - in_channels=in_channels, out_channels=locref_config["channels"][-1] - ) - self.convt3 = self._make_layer_same_padding( - in_channels=in_channels, out_channels=num_limbs - ) - self.convt4 = self._make_layer_same_padding( - in_channels=in_channels, out_channels=features_dim - ) + self.convt1 = self._make_layer_same_padding(in_channels=in_channels, out_channels=num_keypoints) + self.convt2 = self._make_layer_same_padding(in_channels=in_channels, out_channels=locref_config["channels"][-1]) + self.convt3 = self._make_layer_same_padding(in_channels=in_channels, out_channels=num_limbs) + self.convt4 = self._make_layer_same_padding(in_channels=in_channels, out_channels=features_dim) self.hm_ref_layers = nn.ModuleList() self.paf_ref_layers = nn.ModuleList() for _ in range(num_stages): self.hm_ref_layers.append( - self._make_refinement_layer( - in_channels=in_refined_channels, out_channels=num_keypoints - ) + self._make_refinement_layer(in_channels=in_refined_channels, out_channels=num_keypoints) ) self.paf_ref_layers.append( - self._make_refinement_layer( - in_channels=in_refined_channels, out_channels=num_limbs - ) + self._make_refinement_layer(in_channels=in_refined_channels, out_channels=num_limbs) ) - def _make_layer_same_padding( - self, in_channels: int, out_channels: int - ) -> nn.ConvTranspose2d: + def _make_layer_same_padding(self, in_channels: int, out_channels: int) -> nn.ConvTranspose2d: # FIXME There is no consensual solution to emulate TF behavior in pytorch # see https://github.com/pytorch/pytorch/issues/3867 return nn.ConvTranspose2d( @@ -103,9 +87,7 @@ def _make_refinement_layer(self, in_channels: int, out_channels: int) -> nn.Conv Returns: refinement_layer: the refinement layer. 
""" - return nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding="same" - ) + return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding="same") def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: if self.num_stages > 0: @@ -116,7 +98,7 @@ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: stage_in = stage2_in stage_hm_out = stage1_hm_out for i, (hm_ref_layer, paf_ref_layer) in enumerate( - zip(self.hm_ref_layers, self.paf_ref_layers) + zip(self.hm_ref_layers, self.paf_ref_layers, strict=False) ): pre_stage_hm_out = stage_hm_out stage_hm_out = hm_ref_layer(stage_in) diff --git a/dlclive/pose_estimation_pytorch/models/heads/rtmcc_head.py b/dlclive/pose_estimation_pytorch/models/heads/rtmcc_head.py index 53c112db..8aa0ed28 100644 --- a/dlclive/pose_estimation_pytorch/models/heads/rtmcc_head.py +++ b/dlclive/pose_estimation_pytorch/models/heads/rtmcc_head.py @@ -13,14 +13,15 @@ Based on the official ``mmpose`` RTMCC head implementation. For more information, see . """ + from __future__ import annotations import torch import torch.nn as nn from dlclive.pose_estimation_pytorch.models.heads.base import ( - BaseHead, HEADS, + BaseHead, ) from dlclive.pose_estimation_pytorch.models.modules import ( GatedAttentionUnit, diff --git a/dlclive/pose_estimation_pytorch/models/heads/simple_head.py b/dlclive/pose_estimation_pytorch/models/heads/simple_head.py index 545d8543..6bc19b18 100644 --- a/dlclive/pose_estimation_pytorch/models/heads/simple_head.py +++ b/dlclive/pose_estimation_pytorch/models/heads/simple_head.py @@ -125,9 +125,7 @@ def __init__( head_stride = 1 self.deconv_layers = nn.Identity() if len(kernel_size) > 0: - self.deconv_layers = nn.Sequential( - *self._make_layers(in_channels, channels[1:], kernel_size, strides) - ) + self.deconv_layers = nn.Sequential(*self._make_layers(in_channels, channels[1:], kernel_size, strides)) for s in strides: head_stride *= s @@ -161,12 +159,10 @@ def _make_layers( the deconvolutional layers """ layers = [] - for out_channels, k, s in zip(out_channels, kernel_sizes, strides): - layers.append( - nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=s) - ) + for out_chan, k, s in zip(out_channels, kernel_sizes, strides, strict=False): + layers.append(nn.ConvTranspose2d(in_channels, out_chan, kernel_size=k, stride=s)) layers.append(nn.ReLU()) - in_channels = out_channels + in_channels = out_chan return layers[:-1] def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/dlclive/pose_estimation_pytorch/models/model.py b/dlclive/pose_estimation_pytorch/models/model.py index 1d8fce83..05c77a0c 100644 --- a/dlclive/pose_estimation_pytorch/models/model.py +++ b/dlclive/pose_estimation_pytorch/models/model.py @@ -48,10 +48,7 @@ def __init__( self.heads = nn.ModuleDict(heads) self.neck = neck - self._strides = { - name: _model_stride(self.backbone.stride, head.stride) - for name, head in heads.items() - } + self._strides = {name: _model_stride(self.backbone.stride, head.stride) for name, head in heads.items()} def forward(self, x: torch.Tensor) -> dict[str, dict[str, torch.Tensor]]: """ @@ -83,13 +80,10 @@ def get_predictions(self, outputs: dict[str, dict[str, torch.Tensor]]) -> dict: Returns: A dictionary containing the predictions of each head group """ - return { - name: head.predictor(self._strides[name], outputs[name]) - for name, head in self.heads.items() - } + return {name: head.predictor(self._strides[name], outputs[name]) for name, head in self.heads.items()} @staticmethod 
- def build(cfg: dict) -> "PoseModel": + def build(cfg: dict) -> PoseModel: """ Args: cfg: The configuration of the model to build. diff --git a/dlclive/pose_estimation_pytorch/models/modules/conv_block.py b/dlclive/pose_estimation_pytorch/models/modules/conv_block.py index f3fbb02c..f0070fc1 100644 --- a/dlclive/pose_estimation_pytorch/models/modules/conv_block.py +++ b/dlclive/pose_estimation_pytorch/models/modules/conv_block.py @@ -9,6 +9,7 @@ # Licensed under GNU Lesser General Public License v3.0 # """The code is based on DEKR: https://github.com/HRNet/DEKR/tree/main""" + from __future__ import annotations from abc import ABC, abstractmethod @@ -19,14 +20,14 @@ from dlclive.pose_estimation_pytorch.models.registry import Registry, build_from_cfg - BLOCKS = Registry("blocks", build_func=build_from_cfg) class BaseBlock(ABC, nn.Module): """Abstract Base class for defining custom blocks. - This class defines an abstract base class for creating custom blocks used in the HigherHRNet for Human Pose Estimation. + This class defines an abstract base class for creating custom blocks + used in the HigherHRNet for Human Pose Estimation. Attributes: bn_momentum: Batch normalization momentum. @@ -88,7 +89,7 @@ def __init__( downsample: nn.Module | None = None, dilation: int = 1, ): - super(BasicBlock, self).__init__() + super().__init__() self.conv1 = nn.Conv2d( in_channels, out_channels, @@ -167,7 +168,7 @@ def __init__( downsample: nn.Module | None = None, dilation: int = 1, ): - super(Bottleneck, self).__init__() + super().__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(out_channels, momentum=self.bn_momentum) self.conv2 = nn.Conv2d( @@ -180,12 +181,8 @@ def __init__( dilation=dilation, ) self.bn2 = nn.BatchNorm2d(out_channels, momentum=self.bn_momentum) - self.conv3 = nn.Conv2d( - out_channels, out_channels * self.expansion, kernel_size=1, bias=False - ) - self.bn3 = nn.BatchNorm2d( - out_channels * self.expansion, momentum=self.bn_momentum - ) + self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(out_channels * self.expansion, momentum=self.bn_momentum) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride @@ -250,10 +247,8 @@ def __init__( dilation: int = 1, deformable_groups: int = 1, ): - super(AdaptBlock, self).__init__() - regular_matrix = torch.tensor( - [[-1, -1, -1, 0, 0, 0, 1, 1, 1], [-1, 0, 1, -1, 0, 1, -1, 0, 1]] - ) + super().__init__() + regular_matrix = torch.tensor([[-1, -1, -1, 0, 0, 0, 1, 1, 1], [-1, 0, 1, -1, 0, 1, -1, 0, 1]]) self.register_buffer("regular_matrix", regular_matrix.float()) self.downsample = downsample self.transform_matrix_conv = nn.Conv2d(in_channels, 4, 3, 1, 1, bias=True) @@ -284,9 +279,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: N, _, H, W = x.shape transform_matrix = self.transform_matrix_conv(x) - transform_matrix = transform_matrix.permute(0, 2, 3, 1).reshape( - (N * H * W, 2, 2) - ) + transform_matrix = transform_matrix.permute(0, 2, 3, 1).reshape((N * H * W, 2, 2)) offset = torch.matmul(transform_matrix, self.regular_matrix) offset = offset - self.regular_matrix offset = offset.transpose(1, 2).reshape((N, H, W, 18)).permute(0, 3, 1, 2) diff --git a/dlclive/pose_estimation_pytorch/models/modules/conv_module.py b/dlclive/pose_estimation_pytorch/models/modules/conv_module.py index 8f7241bb..76136543 100644 --- a/dlclive/pose_estimation_pytorch/models/modules/conv_module.py +++ 
b/dlclive/pose_estimation_pytorch/models/modules/conv_module.py @@ -9,14 +9,13 @@ # Licensed under GNU Lesser General Public License v3.0 # """The code is based on DEKR: https://github.com/HRNet/DEKR/tree/main""" + import logging -from typing import List import torch.nn as nn from dlclive.pose_estimation_pytorch.models.modules import BasicBlock - BN_MOMENTUM = 0.1 logger = logging.getLogger(__name__) @@ -46,10 +45,8 @@ def __init__( fuse_method: str, multi_scale_output: bool = True, ): - super(HighResolutionModule, self).__init__() - self._check_branches( - num_branches, block, num_blocks, num_inchannels, num_channels - ) + super().__init__() + self._check_branches(num_branches, block, num_blocks, num_inchannels, num_channels) self.num_inchannels = num_inchannels self.fuse_method = fuse_method @@ -57,9 +54,7 @@ def __init__( self.multi_scale_output = multi_scale_output - self.branches = self._make_branches( - num_branches, block, num_blocks, num_channels - ) + self.branches = self._make_branches(num_branches, block, num_blocks, num_channels) self.fuse_layers = self._make_fuse_layers() self.relu = nn.ReLU(True) @@ -72,23 +67,17 @@ def _check_branches( num_channels: int, ): if num_branches != len(num_blocks): - error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format( - num_branches, len(num_blocks) - ) + error_msg = f"NUM_BRANCHES({num_branches}) <> NUM_BLOCKS({len(num_blocks)})" logger.error(error_msg) raise ValueError(error_msg) if num_branches != len(num_channels): - error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format( - num_branches, len(num_channels) - ) + error_msg = f"NUM_BRANCHES({num_branches}) <> NUM_CHANNELS({len(num_channels)})" logger.error(error_msg) raise ValueError(error_msg) if num_branches != len(num_inchannels): - error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format( - num_branches, len(num_inchannels) - ) + error_msg = f"NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS({len(num_inchannels)})" logger.error(error_msg) raise ValueError(error_msg) @@ -101,11 +90,7 @@ def _make_one_branch( stride: int = 1, ) -> nn.Sequential: downsample = None - if ( - stride != 1 - or self.num_inchannels[branch_index] - != num_channels[branch_index] * block.expansion - ): + if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: downsample = nn.Sequential( nn.Conv2d( self.num_inchannels[branch_index], @@ -114,9 +99,7 @@ def _make_one_branch( stride=stride, bias=False, ), - nn.BatchNorm2d( - num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM - ), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM), ) layers = [] @@ -129,16 +112,12 @@ def _make_one_branch( ) ) self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion - for i in range(1, num_blocks[branch_index]): - layers.append( - block(self.num_inchannels[branch_index], num_channels[branch_index]) - ) + for _i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index])) return nn.Sequential(*layers) - def _make_branches( - self, num_branches: int, block: BasicBlock, num_blocks: int, num_channels: int - ) -> nn.ModuleList: + def _make_branches(self, num_branches: int, block: BasicBlock, num_blocks: int, num_channels: int) -> nn.ModuleList: branches = [] for i in range(num_branches): @@ -215,7 +194,7 @@ def _make_fuse_layers(self) -> nn.ModuleList: def get_num_inchannels(self) -> int: return self.num_inchannels - def forward(self, x) -> List: + def 
forward(self, x) -> list: """Forward pass through the HighResolutionModule. Args: diff --git a/dlclive/pose_estimation_pytorch/models/modules/csp.py b/dlclive/pose_estimation_pytorch/models/modules/csp.py index 3099eebe..49548cf0 100644 --- a/dlclive/pose_estimation_pytorch/models/modules/csp.py +++ b/dlclive/pose_estimation_pytorch/models/modules/csp.py @@ -13,6 +13,7 @@ Based on the building blocks used for the ``mmdetection`` CSPNeXt implementation. For more information, see . """ + import torch import torch.nn as nn @@ -23,9 +24,7 @@ def build_activation(activation_fn: str, *args, **kwargs) -> nn.Module: elif activation_fn == "ReLU": return nn.ReLU(*args, **kwargs) - raise NotImplementedError( - f"Unknown `CSPNeXT` activation: {activation_fn}. Must be one of 'SiLU', 'ReLU'" - ) + raise NotImplementedError(f"Unknown `CSPNeXT` activation: {activation_fn}. Must be one of 'SiLU', 'ReLU'") def build_norm(norm: str, *args, **kwargs) -> nn.Module: @@ -34,9 +33,7 @@ def build_norm(norm: str, *args, **kwargs) -> nn.Module: elif norm == "BN": return nn.BatchNorm2d(*args, **kwargs) - raise NotImplementedError( - f"Unknown `CSPNeXT` norm_layer: {norm}. Must be one of 'SyncBN', 'BN'" - ) + raise NotImplementedError(f"Unknown `CSPNeXT` norm_layer: {norm}. Must be one of 'SyncBN', 'BN'") class SPPBottleneck(nn.Module): @@ -69,12 +66,7 @@ def __init__( activation_fn=activation_fn, ) - self.poolings = nn.ModuleList( - [ - nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) - for ks in kernel_sizes - ] - ) + self.poolings = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes]) conv2_channels = mid_channels * (len(kernel_sizes) + 1) self.conv2 = CSPConvModule( conv2_channels, diff --git a/dlclive/pose_estimation_pytorch/models/modules/gated_attention_unit.py b/dlclive/pose_estimation_pytorch/models/modules/gated_attention_unit.py index f26aa200..c5c29531 100644 --- a/dlclive/pose_estimation_pytorch/models/modules/gated_attention_unit.py +++ b/dlclive/pose_estimation_pytorch/models/modules/gated_attention_unit.py @@ -13,14 +13,15 @@ Based on the building blocks used for the ``mmdetection`` CSPNeXt implementation. For more information, see . 
""" + from __future__ import annotations import math +import timm.layers as timm_layers import torch import torch.nn as nn import torch.nn.functional as F -import timm.layers as timm_layers from dlclive.pose_estimation_pytorch.models.modules.norm import ScaleNorm @@ -36,17 +37,13 @@ def rope(x, dim): for i in spatial_shape: total_len *= i - position = torch.reshape( - torch.arange(total_len, dtype=torch.int, device=x.device), spatial_shape - ) + position = torch.reshape(torch.arange(total_len, dtype=torch.int, device=x.device), spatial_shape) - for i in range(dim[-1] + 1, len(shape) - 1, 1): + for _i in range(dim[-1] + 1, len(shape) - 1, 1): position = torch.unsqueeze(position, dim=-1) half_size = shape[-1] // 2 - freq_seq = -torch.arange(half_size, dtype=torch.int, device=x.device) / float( - half_size - ) + freq_seq = -torch.arange(half_size, dtype=torch.int, device=x.device) / float(half_size) inv_freq = 10000**-freq_seq sinusoid = position[..., None] * inv_freq[None, None, :] @@ -94,7 +91,7 @@ def __init__( use_rel_bias=True, pos_enc=False, ): - super(GatedAttentionUnit, self).__init__() + super().__init__() self.s = s self.num_token = num_token self.use_rel_bias = use_rel_bias @@ -109,9 +106,7 @@ def __init__( self.e = int(in_token_dims * expansion_factor) if use_rel_bias: if attn_type == "self-attn": - self.w = nn.Parameter( - torch.rand([2 * num_token - 1], dtype=torch.float) - ) + self.w = nn.Parameter(torch.rand([2 * num_token - 1], dtype=torch.float)) else: self.a = nn.Parameter(torch.rand([1, s], dtype=torch.float)) self.b = nn.Parameter(torch.rand([1, s], dtype=torch.float)) diff --git a/dlclive/pose_estimation_pytorch/models/modules/norm.py b/dlclive/pose_estimation_pytorch/models/modules/norm.py index 9bf839b4..ecaa9454 100644 --- a/dlclive/pose_estimation_pytorch/models/modules/norm.py +++ b/dlclive/pose_estimation_pytorch/models/modules/norm.py @@ -9,6 +9,7 @@ # Licensed under GNU Lesser General Public License v3.0 # """Normalization layers""" + from __future__ import annotations import torch diff --git a/dlclive/pose_estimation_pytorch/models/necks/layers.py b/dlclive/pose_estimation_pytorch/models/necks/layers.py index 1dee7b10..7dbd1125 100644 --- a/dlclive/pose_estimation_pytorch/models/necks/layers.py +++ b/dlclive/pose_estimation_pytorch/models/necks/layers.py @@ -160,9 +160,7 @@ def __init__( self.scale = (dim // heads) ** -0.5 if scale_with_head else dim**-0.5 self.to_qkv = torch.nn.Linear(dim, dim * 3, bias=False) - self.to_out = torch.nn.Sequential( - torch.nn.Linear(dim, dim), torch.nn.Dropout(dropout) - ) + self.to_out = torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.Dropout(dropout)) self.num_keypoints = num_keypoints def forward(self, x: torch.Tensor, mask: torch.Tensor = None): @@ -175,7 +173,7 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor = None): Returns: Output tensor. """ - b, n, _, h = *x.shape, self.heads + _b, _n, _, h = *x.shape, self.heads qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv) @@ -259,16 +257,12 @@ def __init__( ), ) ), - Residual( - PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)) - ), + Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))), ] ) ) - def forward( - self, x: torch.Tensor, mask: torch.Tensor = None, pos: torch.Tensor = None - ): + def forward(self, x: torch.Tensor, mask: torch.Tensor = None, pos: torch.Tensor = None): """Forward pass through the TransformerLayer block. 
Args: diff --git a/dlclive/pose_estimation_pytorch/models/necks/transformer.py b/dlclive/pose_estimation_pytorch/models/necks/transformer.py index a6176943..970dac1f 100644 --- a/dlclive/pose_estimation_pytorch/models/necks/transformer.py +++ b/dlclive/pose_estimation_pytorch/models/necks/transformer.py @@ -8,7 +8,6 @@ # # Licensed under GNU Lesser General Public License v3.0 # -from typing import Tuple import torch from einops import rearrange, repeat @@ -28,7 +27,9 @@ class Transformer(BaseNeck): """Transformer Neck for pose estimation. title={TokenPose: Learning Keypoint Tokens for Human Pose Estimation}, - author={Yanjie Li and Shoukui Zhang and Zhicheng Wang and Sen Yang and Wankou Yang and Shu-Tao Xia and Erjin Zhou}, + author={ + Yanjie Li and Shoukui Zhang and Zhicheng Wang and Sen Yang + and Wankou Yang and Shu-Tao Xia and Erjin Zhou}, booktitle={IEEE/CVF International Conference on Computer Vision (ICCV)}, year={2021} @@ -79,25 +80,23 @@ class Transformer(BaseNeck): def __init__( self, *, - feature_size: Tuple[int, int], - patch_size: Tuple[int, int], + feature_size: tuple[int, int], + patch_size: tuple[int, int], num_keypoints: int, dim: int, depth: int, heads: int, mlp_dim: int = 3, apply_init: bool = False, - heatmap_size: Tuple[int, int] = (64, 64), + heatmap_size: tuple[int, int] = (64, 64), channels: int = 32, dropout: float = 0.0, emb_dropout: float = 0.0, - pos_embedding_type: str = "sine-full" + pos_embedding_type: str = "sine-full", ): super().__init__() - num_patches = (feature_size[0] // (patch_size[0])) * ( - feature_size[1] // (patch_size[1]) - ) + num_patches = (feature_size[0] // (patch_size[0])) * (feature_size[1] // (patch_size[1])) patch_dim = channels * patch_size[0] * patch_size[1] self.inplanes = 64 @@ -108,9 +107,7 @@ def __init__( self.pos_embedding_type = pos_embedding_type self.all_attn = self.pos_embedding_type == "sine-full" - self.keypoint_token = torch.nn.Parameter( - torch.zeros(1, self.num_keypoints, dim) - ) + self.keypoint_token = torch.nn.Parameter(torch.zeros(1, self.num_keypoints, dim)) h, w = ( feature_size[0] // (self.patch_size[0]), feature_size[1] // (self.patch_size[1]), @@ -156,9 +153,7 @@ def __init__( if apply_init: self.apply(self._init_weights) - def _make_position_embedding( - self, w: int, h: int, d_model: int, pe_type="learnable" - ): + def _make_position_embedding(self, w: int, h: int, d_model: int, pe_type="learnable"): """Create position embeddings for the transformer. Args: @@ -172,17 +167,11 @@ def _make_position_embedding( self.pe_h = h self.pe_w = w if pe_type != "learnable": - self.pos_embedding = torch.nn.Parameter( - make_sine_position_embedding(h, w, d_model), requires_grad=False - ) + self.pos_embedding = torch.nn.Parameter(make_sine_position_embedding(h, w, d_model), requires_grad=False) else: - self.pos_embedding = torch.nn.Parameter( - torch.zeros(1, self.num_patches + self.num_keypoints, d_model) - ) + self.pos_embedding = torch.nn.Parameter(torch.zeros(1, self.num_patches + self.num_keypoints, d_model)) - def _make_layer( - self, block: torch.nn.Module, planes: int, blocks: int, stride: int = 1 - ) -> torch.nn.Sequential: + def _make_layer(self, block: torch.nn.Module, planes: int, blocks: int, stride: int = 1) -> torch.nn.Sequential: """Create a layer of the transformer encoder. 
Args: @@ -210,7 +199,7 @@ def _make_layer( layers = [] layers.append(block(self.inplanes, planes, stride, downsample)) self.inplanes = planes * block.expansion - for i in range(1, blocks): + for _i in range(1, blocks): layers.append(block(self.inplanes, planes)) return torch.nn.Sequential(*layers) @@ -247,9 +236,7 @@ def forward(self, feature: torch.Tensor, mask=None) -> torch.Tensor: """ p = self.patch_size - x = rearrange( - feature, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p[0], p2=p[1] - ) + x = rearrange(feature, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p[0], p2=p[1]) x = self.patch_to_embedding(x) b, n, _ = x.shape diff --git a/dlclive/pose_estimation_pytorch/models/necks/utils.py b/dlclive/pose_estimation_pytorch/models/necks/utils.py index 028078b8..bbcf81a9 100644 --- a/dlclive/pose_estimation_pytorch/models/necks/utils.py +++ b/dlclive/pose_estimation_pytorch/models/necks/utils.py @@ -48,12 +48,8 @@ def make_sine_position_embedding( pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t - pos_x = torch.stack( - (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) - pos_y = torch.stack( - (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) pos = pos.flatten(2).permute(0, 2, 1) diff --git a/dlclive/pose_estimation_pytorch/models/predictors/__init__.py b/dlclive/pose_estimation_pytorch/models/predictors/__init__.py index 0662ffac..b449b898 100644 --- a/dlclive/pose_estimation_pytorch/models/predictors/__init__.py +++ b/dlclive/pose_estimation_pytorch/models/predictors/__init__.py @@ -15,10 +15,10 @@ from dlclive.pose_estimation_pytorch.models.predictors.dekr_predictor import ( DEKRPredictor, ) +from dlclive.pose_estimation_pytorch.models.predictors.paf_predictor import ( + PartAffinityFieldPredictor, +) from dlclive.pose_estimation_pytorch.models.predictors.sim_cc import SimCCPredictor from dlclive.pose_estimation_pytorch.models.predictors.single_predictor import ( HeatmapPredictor, ) -from dlclive.pose_estimation_pytorch.models.predictors.paf_predictor import ( - PartAffinityFieldPredictor, -) diff --git a/dlclive/pose_estimation_pytorch/models/predictors/base.py b/dlclive/pose_estimation_pytorch/models/predictors/base.py index f8b8b987..8e45b766 100644 --- a/dlclive/pose_estimation_pytorch/models/predictors/base.py +++ b/dlclive/pose_estimation_pytorch/models/predictors/base.py @@ -45,9 +45,7 @@ def __init__(self): self.num_animals = None @abstractmethod - def forward( - self, stride: float, outputs: dict[str, torch.Tensor] - ) -> dict[str, torch.Tensor]: + def forward(self, stride: float, outputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """Abstract method for the forward pass of the Predictor. 
Args: diff --git a/dlclive/pose_estimation_pytorch/models/predictors/dekr_predictor.py b/dlclive/pose_estimation_pytorch/models/predictors/dekr_predictor.py index 7f32967d..b243e99c 100644 --- a/dlclive/pose_estimation_pytorch/models/predictors/dekr_predictor.py +++ b/dlclive/pose_estimation_pytorch/models/predictors/dekr_predictor.py @@ -107,9 +107,7 @@ def __init__( self.nms_threshold = nms_threshold self.apply_pose_nms = apply_pose_nms - def forward( - self, stride: float, outputs: dict[str, torch.Tensor] - ) -> dict[str, torch.Tensor]: + def forward(self, stride: float, outputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """Forward pass of DEKRPredictor. Args: @@ -147,30 +145,18 @@ def forward( poses = self._update_pose_with_heatmaps(poses, heatmaps[:, :-1]) if self.keypoint_score_type == "center": - score = ( - ctr_scores.unsqueeze(-1) - .expand(batch_size, -1, num_joints) - .unsqueeze(-1) - ) + score = ctr_scores.unsqueeze(-1).expand(batch_size, -1, num_joints).unsqueeze(-1) elif self.keypoint_score_type == "heatmap": score = self.get_heat_value(poses, heatmaps).unsqueeze(-1) elif self.keypoint_score_type == "combined": - center_score = ( - ctr_scores.unsqueeze(-1) - .expand(batch_size, -1, num_joints) - .unsqueeze(-1) - ) + center_score = ctr_scores.unsqueeze(-1).expand(batch_size, -1, num_joints).unsqueeze(-1) htmp_score = self.get_heat_value(poses, heatmaps).unsqueeze(-1) score = center_score * htmp_score else: raise ValueError(f"Unknown keypoint score type: {self.keypoint_score_type}") - poses[:, :, :, 0] = ( - poses[:, :, :, 0] * scale_factors[1] + 0.5 * scale_factors[1] - ) - poses[:, :, :, 1] = ( - poses[:, :, :, 1] * scale_factors[0] + 0.5 * scale_factors[0] - ) + poses[:, :, :, 0] = poses[:, :, :, 0] * scale_factors[1] + 0.5 * scale_factors[1] + poses[:, :, :, 1] = poses[:, :, :, 1] * scale_factors[0] + 0.5 * scale_factors[0] if self.clip_scores: score = torch.clip(score, min=0, max=1) @@ -181,9 +167,7 @@ def forward( return {"poses": poses_w_scores} - def get_locations( - self, height: int, width: int, device: torch.device - ) -> torch.Tensor: + def get_locations(self, height: int, width: int, device: torch.device) -> torch.Tensor: """Get locations for offsets. 
Args: @@ -245,11 +229,7 @@ def offset_to_pose(self, offsets: torch.Tensor) -> torch.Tensor: num_joints = int(num_offset / 2) reg_poses = self.get_reg_poses(offsets, num_joints) - reg_poses = ( - reg_poses.contiguous() - .view(batch_size, h * w, 2 * num_joints) - .permute(0, 2, 1) - ) + reg_poses = reg_poses.contiguous().view(batch_size, h * w, 2 * num_joints).permute(0, 2, 1) reg_poses = reg_poses.contiguous().view(batch_size, -1, h, w).contiguous() return reg_poses @@ -267,19 +247,15 @@ def max_pool(self, heatmap: torch.Tensor) -> torch.Tensor: # Assuming you have 'heatmap' tensor max_pooled_heatmap = predictor.max_pool(heatmap) """ - pool1 = torch.nn.MaxPool2d(3, 1, 1) + torch.nn.MaxPool2d(3, 1, 1) pool2 = torch.nn.MaxPool2d(5, 1, 2) - pool3 = torch.nn.MaxPool2d(7, 1, 3) - map_size = (heatmap.shape[1] + heatmap.shape[2]) / 2.0 - maxm = pool2( - heatmap - ) # Here I think pool 2 is a good match for default 17 pos_dist_tresh + torch.nn.MaxPool2d(7, 1, 3) + (heatmap.shape[1] + heatmap.shape[2]) / 2.0 + maxm = pool2(heatmap) # Here I think pool 2 is a good match for default 17 pos_dist_tresh return maxm - def get_top_values( - self, heatmap: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + def get_top_values(self, heatmap: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Get top values from the heatmap. Args: @@ -303,9 +279,7 @@ def get_top_values( return pos_ind, scores - def _update_pose_with_heatmaps( - self, _poses: torch.Tensor, kpt_heatmaps: torch.Tensor - ): + def _update_pose_with_heatmaps(self, _poses: torch.Tensor, kpt_heatmaps: torch.Tensor): """If a heatmap center is close enough from the regressed point, the final prediction is the center of this heatmap @@ -323,9 +297,7 @@ def _update_pose_with_heatmaps( x = ind % w y = (ind / w).long() - heats_ind = torch.stack( - (x, y), dim=3 - ) # (batch_size, num_keypoints, num_animals, 2) + heats_ind = torch.stack((x, y), dim=3) # (batch_size, num_keypoints, num_animals, 2) # Calculate differences between all pose-heat pairs # (batch_size, num_animals, num_keypoints, 1, 2) - (batch_size, 1, num_keypoints, num_animals, 2) @@ -333,29 +305,21 @@ def _update_pose_with_heatmaps( 1 ) # (batch_size, num_animals, num_keypoints, num_animals, 2) - pose_heat_dist = torch.norm( - pose_heat_diff, dim=-1 - ) # (batch_size, num_animals, num_keypoints, num_animals) + pose_heat_dist = torch.norm(pose_heat_diff, dim=-1) # (batch_size, num_animals, num_keypoints, num_animals) # Find closest heat point for each pose - keep_ind = torch.argmin( - pose_heat_dist, dim=-1 - ) # (batch_size, num_animals, num_keypoints) + keep_ind = torch.argmin(pose_heat_dist, dim=-1) # (batch_size, num_animals, num_keypoints) # Get minimum distances for filtering min_distances = torch.gather(pose_heat_dist, 3, keep_ind.unsqueeze(-1)).squeeze( -1 ) # (batch_size, num_animals, num_keypoints) - absorb_mask = ( - min_distances < self.max_absorb_distance - ) # (batch_size, num_animals, num_keypoints) + absorb_mask = min_distances < self.max_absorb_distance # (batch_size, num_animals, num_keypoints) # Create indices for gathering the correct heat points batch_indices = torch.arange(batch_size, device=poses.device).view(-1, 1, 1) - keypoint_indices = torch.arange(num_keypoints, device=poses.device).view( - 1, 1, -1 - ) + keypoint_indices = torch.arange(num_keypoints, device=poses.device).view(1, 1, -1) selected_heat_points = heats_ind[ batch_indices, keypoint_indices, keep_ind @@ -365,9 +329,7 @@ def _update_pose_with_heatmaps( return poses - def get_heat_value( - self, 
pose_coords: torch.Tensor, heatmaps: torch.Tensor - ) -> torch.Tensor: + def get_heat_value(self, pose_coords: torch.Tensor, heatmaps: torch.Tensor) -> torch.Tensor: """Get heat values for pose coordinates and heatmaps. Args: @@ -382,9 +344,7 @@ def get_heat_value( heat_values = predictor.get_heat_value(pose_coords, heatmaps) """ h, w = heatmaps.shape[2:] - heatmaps_nocenter = heatmaps[:, :-1].flatten( - 2, 3 - ) # (batch_size, num_joints, h*w) + heatmaps_nocenter = heatmaps[:, :-1].flatten(2, 3) # (batch_size, num_joints, h*w) # Predicted poses based on the offset can be outside the image x = torch.clamp(torch.floor(pose_coords[:, :, :, 0]), 0, w - 1).long() @@ -413,11 +373,7 @@ def pose_nms(self, poses: torch.Tensor) -> torch.Tensor: w = xy[..., 0].max(dim=-1)[0] - xy[..., 0].min(dim=-1)[0] h = xy[..., 1].max(dim=-1)[0] - xy[..., 1].min(dim=-1)[0] area = torch.clamp((w * w) + (h * h), min=1) - area = ( - area.unsqueeze(1) - .unsqueeze(3) - .expand(batch_size, num_people, num_people, num_joints) - ) + area = area.unsqueeze(1).unsqueeze(3).expand(batch_size, num_people, num_people, num_joints) # compute the difference between keypoints pose_diff = xy.unsqueeze(2) - xy.unsqueeze(1) @@ -432,13 +388,9 @@ def pose_nms(self, poses: torch.Tensor) -> torch.Tensor: nms_pose = pose_dist > self.nms_threshold # shape (b, num_people, num_people) # Upper triangular mask matrix to avoid double processing - triu_mask = torch.triu( - torch.ones(num_people, num_people, device=device), diagonal=1 - ).bool() + triu_mask = torch.triu(torch.ones(num_people, num_people, device=device), diagonal=1).bool() - suppress_pairs = nms_pose & triu_mask.unsqueeze( - 0 - ) # (batch_size, num_people, num_people) + suppress_pairs = nms_pose & triu_mask.unsqueeze(0) # (batch_size, num_people, num_people) # For each batch, determine which poses to suppress suppressed = suppress_pairs.any(dim=1) # (batch_size, num_people) @@ -447,9 +399,7 @@ def pose_nms(self, poses: torch.Tensor) -> torch.Tensor: # Indices for reordering batch_indices = torch.arange(batch_size, device=device).unsqueeze(1) - people_indices = ( - torch.arange(num_people, device=device).unsqueeze(0).expand(batch_size, -1) - ) + people_indices = torch.arange(num_people, device=device).unsqueeze(0).expand(batch_size, -1) # non-suppressed first, then suppressed sort_keys = kept.float() + (people_indices.float() + 1) / (num_people + 1) @@ -461,4 +411,4 @@ def pose_nms(self, poses: torch.Tensor) -> torch.Tensor: # Re-order predictions so the non-suppressed ones are up top poses = poses[batch_indices, sort_indices] - return poses \ No newline at end of file + return poses diff --git a/dlclive/pose_estimation_pytorch/models/predictors/identity_predictor.py b/dlclive/pose_estimation_pytorch/models/predictors/identity_predictor.py index a4837c57..29a3a872 100644 --- a/dlclive/pose_estimation_pytorch/models/predictors/identity_predictor.py +++ b/dlclive/pose_estimation_pytorch/models/predictors/identity_predictor.py @@ -9,6 +9,7 @@ # Licensed under GNU Lesser General Public License v3.0 # """Predictor to generate identity maps from head outputs""" + import torch import torch.nn as nn import torchvision.transforms.functional as F @@ -36,9 +37,7 @@ def __init__(self, apply_sigmoid: bool = True): self.apply_sigmoid = apply_sigmoid self.sigmoid = nn.Sigmoid() - def forward( - self, stride: float, outputs: dict[str, torch.Tensor] - ) -> dict[str, torch.Tensor]: + def forward(self, stride: float, outputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Swaps the 
dimensions so the heatmap are (batch_size, h, w, num_individuals), optionally applies a sigmoid to the heatmaps, and rescales it to be the size diff --git a/dlclive/pose_estimation_pytorch/models/predictors/paf_predictor.py b/dlclive/pose_estimation_pytorch/models/predictors/paf_predictor.py index e172ad4c..1bc8c329 100644 --- a/dlclive/pose_estimation_pytorch/models/predictors/paf_predictor.py +++ b/dlclive/pose_estimation_pytorch/models/predictors/paf_predictor.py @@ -10,17 +10,18 @@ # from __future__ import annotations +from collections import defaultdict + import numpy as np import torch import torch.nn.functional as F from numpy.typing import NDArray -from collections import defaultdict +from dlclive.core import inferenceutils from dlclive.pose_estimation_pytorch.models.predictors.base import ( - BasePredictor, PREDICTORS, + BasePredictor, ) -from dlclive.core import inferenceutils Graph = list[tuple[int, int]] @@ -32,7 +33,8 @@ class PartAffinityFieldPredictor(BasePredictor): Args: num_animals: Number of animals in the project. num_multibodyparts: Number of animal's body parts (ignoring unique body parts). - num_uniquebodyparts: Number of unique body parts. # FIXME - should not be needed here if we separate the unique bodypart head + num_uniquebodyparts: Number of unique body parts. + # FIXME - should not be needed here if we separate the unique bodypart head graph: Part affinity field graph edges. edges_to_keep: List of indices in `graph` of the edges to keep. locref_stdev: Standard deviation for location refinement. @@ -109,9 +111,7 @@ def __init__( force_fusion=force_fusion, ) - def forward( - self, stride: float, outputs: dict[str, torch.Tensor] - ) -> dict[str, torch.Tensor]: + def forward(self, stride: float, outputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """Forward pass of PartAffinityFieldPredictor. Gets predictions from model output. 
Args: @@ -141,21 +141,15 @@ def forward( # Filter predicted heatmaps with a 2D Gaussian kernel as in: # https://openaccess.thecvf.com/content_CVPR_2020/papers/Huang_The_Devil_Is_in_the_Details_Delving_Into_Unbiased_Data_CVPR_2020_paper.pdf - kernel = self.make_2d_gaussian_kernel( - sigma=self.sigma, size=self.nms_radius * 2 + 1 - )[None, None] + kernel = self.make_2d_gaussian_kernel(sigma=self.sigma, size=self.nms_radius * 2 + 1)[None, None] kernel = kernel.repeat(n_channels, 1, 1, 1).to(heatmaps.device) - heatmaps = F.conv2d( - heatmaps, kernel, stride=1, padding="same", groups=n_channels - ) + heatmaps = F.conv2d(heatmaps, kernel, stride=1, padding="same", groups=n_channels) peaks = self.find_local_peak_indices_maxpool_nms( heatmaps, self.nms_radius, threshold=0.01 ) # (n_peaks, 4) -> columns: (batch, part, height, width) if ~torch.any(peaks): - poses = -torch.ones( - (batch_size, self.num_animals, self.num_multibodyparts, 5) - ) + poses = -torch.ones((batch_size, self.num_animals, self.num_multibodyparts, 5)) results = dict(poses=poses) if self.return_preds: results["preds"] = ([dict(coordinates=[[]], costs=[])],) @@ -163,12 +157,8 @@ def forward( return results locrefs = locrefs.reshape(batch_size, n_channels, 2, height, width) - locrefs = ( - locrefs * self.locref_stdev - ) # (batch_size, num_joints, 2, height, width) - pafs = pafs.reshape( - batch_size, -1, 2, height, width - ) # (batch_size, num_edges, 2, height, width) + locrefs = locrefs * self.locref_stdev # (batch_size, num_joints, 2, height, width) + pafs = pafs.reshape(batch_size, -1, 2, height, width) # (batch_size, num_edges, 2, height, width) # Use only the minimal tree edges for efficiency graph = [self.graph[ind] for ind in self.edges_to_keep] @@ -209,9 +199,7 @@ def forward( return out @staticmethod - def find_local_peak_indices_maxpool_nms( - input_: torch.Tensor, radius: int, threshold: float - ) -> torch.Tensor: + def find_local_peak_indices_maxpool_nms(input_: torch.Tensor, radius: int, threshold: float) -> torch.Tensor: pooled = F.max_pool2d(input_, kernel_size=radius, stride=1, padding=radius // 2) maxima = input_ * torch.eq(input_, pooled).float() peak_indices = torch.nonzero(maxima >= threshold, as_tuple=False) @@ -304,12 +292,8 @@ def compute_edge_costs( batch_bodyparts = peak_bodyparts[batch_mask] # Masks of peaks that match each edge's source/dest bodypart for this batch - src_mask = batch_bodyparts.unsqueeze(0) == src_bodypart_id.unsqueeze( - 1 - ) # (n_edges, n_batch_peaks) - dst_mask = batch_bodyparts.unsqueeze(0) == dst_bodypart_id.unsqueeze( - 1 - ) # (n_edges, n_batch_peaks) + src_mask = batch_bodyparts.unsqueeze(0) == src_bodypart_id.unsqueeze(1) # (n_edges, n_batch_peaks) + dst_mask = batch_bodyparts.unsqueeze(0) == dst_bodypart_id.unsqueeze(1) # (n_edges, n_batch_peaks) # Valid src/dst peaks for each edge in this batch: (n_edges, n_batch_peaks, n_batch_peaks) valid_pairs = src_mask.unsqueeze(2) & dst_mask.unsqueeze(1) @@ -339,12 +323,8 @@ def compute_edge_costs( edge_idx = paf_limb_inds[edge_idx] # Map back to original PAF indices # Gather coordinates - src_coords = torch.stack( - [peak_rows[src_idx], peak_cols[src_idx]], dim=1 - ) # (found_pairs, 2) - dst_coords = torch.stack( - [peak_rows[dst_idx], peak_cols[dst_idx]], dim=1 - ) # (found_pairs, 2) + src_coords = torch.stack([peak_rows[src_idx], peak_cols[src_idx]], dim=1) # (found_pairs, 2) + dst_coords = torch.stack([peak_rows[dst_idx], peak_cols[dst_idx]], dim=1) # (found_pairs, 2) vecs_s = src_coords.float() # (found_pairs, 2) vecs_t = 
dst_coords.float() # (found_pairs, 2) @@ -372,9 +352,7 @@ def compute_edge_costs( ] # Integrate PAF along segment using trapezoidal rule - xy_reversed = torch.flip( - xy.float(), dims=[-1] - ) + xy_reversed = torch.flip(xy.float(), dims=[-1]) integ = torch.trapz(y, xy_reversed, dim=1) # (n_edges, 2) affinities = torch.norm(integ, dim=1) # (n_edges,) affinities = affinities / lengths @@ -401,16 +379,14 @@ def compute_edge_costs( # Run-length encode on (batch, limb) boundaries where (batch, limb) changes change = np.empty(batch_inds.size, dtype=bool) change[0] = True - change[1:] = (batch_inds[1:] != batch_inds[:-1]) | ( - edge_idx[1:] != edge_idx[:-1] - ) + change[1:] = (batch_inds[1:] != batch_inds[:-1]) | (edge_idx[1:] != edge_idx[:-1]) group_starts = np.flatnonzero(change) # Add sentinel end group_ends = np.r_[group_starts[1:], batch_inds.size] # Build an index dict of slices for group lookup by (batch, limb) batch_groups = defaultdict(list) # (batch)->list of (limb, start, end) - for st, en in zip(group_starts, group_ends): + for st, en in zip(group_starts, group_ends, strict=False): b = batch_inds[st] k = edge_idx[st] batch_groups[b].append((k, st, en)) @@ -499,9 +475,7 @@ def compute_peaks_and_costs( batch_size, n_channels = heatmaps.shape[:2] n_bodyparts = n_channels - n_id_channels # Refine peak positions to input-image pixels - pos = self.calc_peak_locations( - locrefs, peak_inds_in_batch, strides - ) # (n_peaks, 2) + pos = self.calc_peak_locations(locrefs, peak_inds_in_batch, strides) # (n_peaks, 2) # Compute per-limb affinity matrices via PAF line integral costs = self.compute_edge_costs( @@ -551,4 +525,4 @@ def set_paf_edges_to_keep(self, edge_indices: list[int]) -> None: edge_indices: The indices of edges in the graph to keep. """ self.edges_to_keep = edge_indices - self.assembler.paf_inds = edge_indices \ No newline at end of file + self.assembler.paf_inds = edge_indices diff --git a/dlclive/pose_estimation_pytorch/models/predictors/sim_cc.py b/dlclive/pose_estimation_pytorch/models/predictors/sim_cc.py index 022afdf4..9b4bd5e2 100644 --- a/dlclive/pose_estimation_pytorch/models/predictors/sim_cc.py +++ b/dlclive/pose_estimation_pytorch/models/predictors/sim_cc.py @@ -13,14 +13,15 @@ Based on the official ``mmpose`` SimCC codec and RTMCC head implementation. For more information, see . 
""" + from __future__ import annotations import numpy as np import torch from dlclive.pose_estimation_pytorch.models.predictors.base import ( - BasePredictor, PREDICTORS, + BasePredictor, ) @@ -58,9 +59,7 @@ def __init__( self.sigma = np.array(sigma) self.decode_beta = decode_beta - def forward( - self, stride: float, outputs: dict[str, torch.Tensor] - ) -> dict[str, torch.Tensor]: + def forward(self, stride: float, outputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: x, y = outputs["x"].detach(), outputs["y"].detach() if self.normalize_outputs: @@ -70,9 +69,7 @@ def forward( x = x * (self.sigma[0] * self.decode_beta) y = y * (self.sigma[1] * self.decode_beta) - keypoints, scores = get_simcc_maximum( - x.cpu().numpy(), y.cpu().numpy(), self.apply_softmax - ) + keypoints, scores = get_simcc_maximum(x.cpu().numpy(), y.cpu().numpy(), self.apply_softmax) if keypoints.ndim == 2: keypoints = keypoints[None, :] diff --git a/dlclive/pose_estimation_pytorch/models/predictors/single_predictor.py b/dlclive/pose_estimation_pytorch/models/predictors/single_predictor.py index deeed64a..b1a2f16b 100644 --- a/dlclive/pose_estimation_pytorch/models/predictors/single_predictor.py +++ b/dlclive/pose_estimation_pytorch/models/predictors/single_predictor.py @@ -10,13 +10,11 @@ # from __future__ import annotations -from typing import Tuple - import torch from dlclive.pose_estimation_pytorch.models.predictors.base import ( - BasePredictor, PREDICTORS, + BasePredictor, ) @@ -55,9 +53,7 @@ def __init__( self.location_refinement = location_refinement self.locref_std = locref_std - def forward( - self, stride: float, outputs: dict[str, torch.Tensor] - ) -> dict[str, torch.Tensor]: + def forward(self, stride: float, outputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """Forward pass of SinglePredictor. Gets predictions from model output. Args: @@ -85,9 +81,7 @@ def forward( locrefs = None if self.location_refinement: locrefs = outputs["locref"] - locrefs = locrefs.permute(0, 2, 3, 1).reshape( - batch_size, height, width, num_joints, 2 - ) + locrefs = locrefs.permute(0, 2, 3, 1).reshape(batch_size, height, width, num_joints, 2) locrefs = locrefs * self.locref_std poses = self.get_pose_prediction(heatmaps, locrefs, scale_factors) @@ -97,9 +91,7 @@ def forward( return {"poses": poses} - def get_top_values( - self, heatmap: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def get_top_values(self, heatmap: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Get the top values from the heatmap. Args: @@ -119,9 +111,7 @@ def get_top_values( y, x = heatmap_top // nx, heatmap_top % nx return y, x - def get_pose_prediction( - self, heatmap: torch.Tensor, locref: torch.Tensor | None, scale_factors - ) -> torch.Tensor: + def get_pose_prediction(self, heatmap: torch.Tensor, locref: torch.Tensor | None, scale_factors) -> torch.Tensor: """Gets the pose prediction given the heatmaps and locref. Args: @@ -146,15 +136,11 @@ def get_pose_prediction( # Create batch and joint indices for indexing # batch_idx: [[0,0,0,...], [1,1,1,...], [2,2,2,...], ...] batch_idx = ( - torch.arange(batch_size, device=heatmap.device) - .unsqueeze(1) - .expand(-1, num_joints) + torch.arange(batch_size, device=heatmap.device).unsqueeze(1).expand(-1, num_joints) ) # (batch_size, num_joints) # joint_idx: [[0,1,2,...], [0,1,2,...], [0,1,2,...], ...] 
joint_idx = ( - torch.arange(num_joints, device=heatmap.device) - .unsqueeze(0) - .expand(batch_size, -1) + torch.arange(num_joints, device=heatmap.device).unsqueeze(0).expand(batch_size, -1) ) # (batch_size, num_joints) # Vectorized extraction of heatmap scores and locref offsets @@ -164,9 +150,7 @@ def get_pose_prediction( dz[:, 0, :, 2] = scores if locref is not None: - offsets = locref[ - batch_idx, y, x, joint_idx, : - ] # (batch_size, num_joints, 2) + offsets = locref[batch_idx, y, x, joint_idx, :] # (batch_size, num_joints, 2) dz[:, 0, :, :2] = offsets x, y = x.unsqueeze(1), y.unsqueeze(1) # x, y: (batch_size, 1, num_joints) @@ -174,8 +158,6 @@ def get_pose_prediction( x = x * scale_factors[1] + 0.5 * scale_factors[1] + dz[:, :, :, 0] y = y * scale_factors[0] + 0.5 * scale_factors[0] + dz[:, :, :, 1] - pose = torch.stack( - [x, y, dz[:, :, :, 2]], dim=-1 - ) # (batch_size, 1, num_joints, 3) + pose = torch.stack([x, y, dz[:, :, :, 2]], dim=-1) # (batch_size, 1, num_joints, 3) - return pose \ No newline at end of file + return pose diff --git a/dlclive/pose_estimation_pytorch/models/registry.py b/dlclive/pose_estimation_pytorch/models/registry.py index 45ed7354..baa79e2c 100644 --- a/dlclive/pose_estimation_pytorch/models/registry.py +++ b/dlclive/pose_estimation_pytorch/models/registry.py @@ -7,12 +7,10 @@ import inspect from functools import partial -from typing import Any, Dict, Optional +from typing import Any -def build_from_cfg( - cfg: Dict, registry: "Registry", default_args: Optional[Dict] = None -) -> Any: +def build_from_cfg(cfg: dict, registry: "Registry", default_args: dict | None = None) -> Any: """Builds a module from the configuration dictionary when it represents a class configuration, or call a function from the configuration dictionary when it represents a function configuration. @@ -57,7 +55,7 @@ def build_from_cfg( return obj_cls(**args) except Exception as e: # Normal TypeError does not print class name. 
- raise type(e)(f"{obj_cls.__name__}: {e}") + raise type(e)(f"{obj_cls.__name__}: {e}") from e class Registry: @@ -115,10 +113,7 @@ def __contains__(self, key): return self.get(key) is not None def __repr__(self): - format_str = ( - self.__class__.__name__ + f"(name={self._name}, " - f"items={self._module_dict})" - ) + format_str = self.__class__.__name__ + f"(name={self._name}, items={self._module_dict})" return format_str @staticmethod @@ -234,9 +229,7 @@ def _add_children(self, registry): """ assert isinstance(registry, Registry) assert registry.scope is not None - assert ( - registry.scope not in self.children - ), f"scope {registry.scope} exists in {self.name} registry" + assert registry.scope not in self.children, f"scope {registry.scope} exists in {self.name} registry" self.children[registry.scope] = registry def _register_module(self, module, module_name=None, force=False): @@ -261,9 +254,7 @@ def _register_module(self, module, module_name=None, force=False): >>> assert registry.get("Model") == Model """ if not inspect.isclass(module) and not inspect.isfunction(module): - raise TypeError( - "module must be a class or a function, " f"but got {type(module)}" - ) + raise TypeError(f"module must be a class or a function, but got {type(module)}") if module_name is None: module_name = module.__name__ @@ -271,7 +262,7 @@ def _register_module(self, module, module_name=None, force=False): module_name = [module_name] for name in module_name: if not force and name in self._module_dict: - raise KeyError(f"{name} is already registered " f"in {self.name}") + raise KeyError(f"{name} is already registered in {self.name}") self._module_dict[name] = module def deprecated_register_module(self, cls=None, force=False): diff --git a/dlclive/pose_estimation_pytorch/runner.py b/dlclive/pose_estimation_pytorch/runner.py index 2c59605f..0c6b1696 100644 --- a/dlclive/pose_estimation_pytorch/runner.py +++ b/dlclive/pose_estimation_pytorch/runner.py @@ -9,6 +9,7 @@ # Licensed under GNU Lesser General Public License v3.0 # """PyTorch and ONNX runners for DeepLabCut-Live""" + import copy from dataclasses import dataclass from pathlib import Path @@ -19,8 +20,8 @@ from torchvision.transforms import v2 import dlclive.pose_estimation_pytorch.data as data -import dlclive.pose_estimation_pytorch.models as models import dlclive.pose_estimation_pytorch.dynamic_cropping as dynamic_cropping +import dlclive.pose_estimation_pytorch.models as models from dlclive.core.runner import BaseRunner from dlclive.pose_estimation_pytorch.data.image import AutoPadToDivisor @@ -72,12 +73,8 @@ def update(self, pose: torch.Tensor, w: int, h: int) -> None: size = max(w, h) bboxes = torch.zeros((num_det, 4)) - bboxes[:, :2] = ( - torch.min(torch.nan_to_num(pose, size)[..., :2], dim=1)[0] - self.margin - ) - bboxes[:, 2:4] = ( - torch.max(torch.nan_to_num(pose, 0)[..., :2], dim=1)[0] + self.margin - ) + bboxes[:, :2] = torch.min(torch.nan_to_num(pose, size)[..., :2], dim=1)[0] - self.margin + bboxes[:, 2:4] = torch.max(torch.nan_to_num(pose, 0)[..., :2], dim=1)[0] + self.margin bboxes = torch.clip(bboxes, min=torch.zeros(4), max=torch.tensor([w, h, w, h])) self._detections = dict(boxes=bboxes, scores=torch.ones(num_det)) self._age += 1 @@ -174,7 +171,7 @@ def close(self) -> None: @torch.inference_mode() def get_pose(self, frame: np.ndarray) -> np.ndarray: c, h, w = frame.shape - tensor = torch.from_numpy(frame).permute(2, 0, 1) # CHW, still on CPU + tensor = torch.from_numpy(frame).permute(2, 0, 1) # CHW, still on CPU offsets_and_scales = 
None if self.detector is not None: @@ -191,7 +188,7 @@ def get_pose(self, frame: np.ndarray) -> np.ndarray: frame_batch, offsets_and_scales = self._prepare_top_down(tensor, detections) if len(frame_batch) == 0: - offsets_and_scales = [(0, 0), 1] + offsets_and_scales = [(0, 0), 1] tensor = frame_batch # still CHW, batched if self.dynamic is not None: @@ -290,14 +287,8 @@ def load_model(self) -> None: w, h = crop.get("width", 256), crop.get("height", 256) self.dynamic.top_down_crop_size = w, h - if ( - self.cfg["method"] == "td" - and self.detector is None - and self.dynamic is None - ): - raise ValueError( - "Top-down models must either use a detector or a TopDownDynamicCropper." - ) + if self.cfg["method"] == "td" and self.detector is None and self.dynamic is None: + raise ValueError("Top-down models must either use a detector or a TopDownDynamicCropper.") pose_transforms = [v2.ToDtype(torch.float32, scale=True)] auto_padding_cfg = self.cfg["data"]["inference"].get("auto_padding", None) @@ -320,9 +311,7 @@ def read_config(self) -> dict: raw_data = torch.load(self.path, map_location="cpu", weights_only=True) return raw_data["config"] - def _prepare_top_down( - self, frame: torch.Tensor, detections: dict[str, torch.Tensor] - ): + def _prepare_top_down(self, frame: torch.Tensor, detections: dict[str, torch.Tensor]): """Prepares a frame for top-down pose estimation.""" # Accept unbatched frame (C, H, W) or batched frame (1, C, H, W) if frame.dim() == 4: @@ -371,7 +360,7 @@ def _postprocess_top_down( return torch.zeros((0, bodyparts, coords)) poses = [] - for pose, (offset, scale) in zip(batch_pose, offsets_and_scales): + for pose, (offset, scale) in zip(batch_pose, offsets_and_scales, strict=False): poses.append( torch.cat( [ diff --git a/dlclive/pose_estimation_tensorflow/graph.py b/dlclive/pose_estimation_tensorflow/graph.py index 20c42031..80aa9d12 100644 --- a/dlclive/pose_estimation_tensorflow/graph.py +++ b/dlclive/pose_estimation_tensorflow/graph.py @@ -107,9 +107,7 @@ def get_input_tensor(graph): return input_tensor -def extract_graph( - graph, tf_config=None -) -> tuple[tf.Session, tf.Tensor, list[tf.Tensor]]: +def extract_graph(graph, tf_config=None) -> tuple[tf.Session, tf.Tensor, list[tf.Tensor]]: """ Initializes a tensorflow session with the specified graph and extracts the model's inputs and outputs diff --git a/dlclive/pose_estimation_tensorflow/pose.py b/dlclive/pose_estimation_tensorflow/pose.py index b7f9cd77..be19a8d8 100644 --- a/dlclive/pose_estimation_tensorflow/pose.py +++ b/dlclive/pose_estimation_tensorflow/pose.py @@ -69,9 +69,7 @@ def argmax_pose_predict(scmap, offmat, stride): num_joints = scmap.shape[2] pose = [] for joint_idx in range(num_joints): - maxloc = np.unravel_index( - np.argmax(scmap[:, :, joint_idx]), scmap[:, :, joint_idx].shape - ) + maxloc = np.unravel_index(np.argmax(scmap[:, :, joint_idx]), scmap[:, :, joint_idx].shape) offset = np.array(offmat[maxloc][joint_idx])[::-1] pos_f8 = np.array(maxloc).astype("float") * stride + 0.5 * stride + offset pose.append(np.hstack((pos_f8[::-1], [scmap[maxloc][joint_idx]]))) diff --git a/dlclive/pose_estimation_tensorflow/runner.py b/dlclive/pose_estimation_tensorflow/runner.py index fa05f8e0..620cdf69 100644 --- a/dlclive/pose_estimation_tensorflow/runner.py +++ b/dlclive/pose_estimation_tensorflow/runner.py @@ -9,6 +9,7 @@ # Licensed under GNU Lesser General Public License v3.0 # """TensorFlow runners for DeepLabCut-Live""" + import glob import os from pathlib import Path @@ -33,6 +34,12 @@ multi_pose_predict, ) 
+try: + # TensorFlow 1.x TensorRT integration + from tensorflow.contrib import tensorrt as trt # type: ignore[attr-defined] +except Exception: + trt = None + class TensorFlowRunner(BaseRunner): """TensorFlow runner for live pose estimation using DeepLabCut-Live.""" @@ -62,9 +69,7 @@ def close(self) -> None: def get_pose(self, frame: np.ndarray, **kwargs) -> np.ndarray: if self.model_type in ["base", "tensorrt"]: - pose_output = self.sess.run( - self.outputs, feed_dict={self.inputs: np.expand_dims(frame, axis=0)} - ) + pose_output = self.sess.run(self.outputs, feed_dict={self.inputs: np.expand_dims(frame, axis=0)}) elif self.model_type == "tflite": self.tflite_interpreter.set_tensor( @@ -79,14 +84,11 @@ def get_pose(self, frame: np.ndarray, **kwargs) -> np.ndarray: self.tflite_interpreter.get_tensor(self.outputs[1]["index"]), ] else: - pose_output = self.tflite_interpreter.get_tensor( - self.outputs[0]["index"] - ) + pose_output = self.tflite_interpreter.get_tensor(self.outputs[0]["index"]) else: raise DLCLiveError( - f"model_type={self.model_type} is not supported. model_type must be " - f"'base', 'tflite', or 'tensorrt'" + f"model_type={self.model_type} is not supported. model_type must be 'base', 'tflite', or 'tensorrt'" ) # check if using TFGPUinference flag @@ -95,9 +97,7 @@ def get_pose(self, frame: np.ndarray, **kwargs) -> np.ndarray: scmap, locref = extract_cnn_output(pose_output, self.cfg) num_outputs = self.cfg.get("num_outputs", 1) if num_outputs > 1: - pose = multi_pose_predict( - scmap, locref, self.cfg["stride"], num_outputs - ) + pose = multi_pose_predict(scmap, locref, self.cfg["stride"], num_outputs) else: pose = argmax_pose_predict(scmap, locref, self.cfg["stride"]) else: @@ -116,9 +116,7 @@ def init_inference(self, frame: np.ndarray, **kwargs) -> np.ndarray: if self.model_type == "base": graph_def = read_graph(model_file) graph = finalize_graph(graph_def) - self.sess, self.inputs, self.outputs = extract_graph( - graph, tf_config=self.tf_config - ) + self.sess, self.inputs, self.outputs = extract_graph(graph, tf_config=self.tf_config) elif self.model_type == "tflite": ### @@ -150,15 +148,13 @@ def init_inference(self, frame: np.ndarray, **kwargs) -> np.ndarray: try: tflite_model = converter.convert() - except Exception: + except Exception as e: raise DLCLiveError( - ( - "This model cannot be converted to tensorflow lite format. " - "To use tensorflow lite for live inference, " - "make sure to set TFGPUinference=False " - "when exporting the model from DeepLabCut" - ) - ) + "This model cannot be converted to tensorflow lite format. " + "To use tensorflow lite for live inference, " + "make sure to set TFGPUinference=False " + "when exporting the model from DeepLabCut" + ) from e self.tflite_interpreter = tf.lite.Interpreter(model_content=tflite_model) self.tflite_interpreter.allocate_tensors() @@ -166,6 +162,12 @@ def init_inference(self, frame: np.ndarray, **kwargs) -> np.ndarray: self.outputs = self.tflite_interpreter.get_output_details() elif self.model_type == "tensorrt": + if trt is None: + raise DLCLiveError( + "TensorRT integration requires tensorflow 1.x " + "and the tensorflow.contrib.tensorrt module," + " which is not available in your current environment." 
+ ) graph_def = read_graph(model_file) graph = finalize_graph(graph_def) output_tensors = get_output_tensors(graph) @@ -188,14 +190,11 @@ def init_inference(self, frame: np.ndarray, **kwargs) -> np.ndarray: ) graph = finalize_graph(graph_def) - self.sess, self.inputs, self.outputs = extract_graph( - graph, tf_config=self.tf_config - ) + self.sess, self.inputs, self.outputs = extract_graph(graph, tf_config=self.tf_config) else: raise DLCLiveError( - f"model_type={self.model_type} is not supported. model_type must be " - "'base', 'tflite', or 'tensorrt'" + f"model_type={self.model_type} is not supported. model_type must be 'base', 'tflite', or 'tensorrt'" ) return self.get_pose(frame, **kwargs) diff --git a/dlclive/predictor/base.py b/dlclive/predictor/base.py index f8b8b987..8e45b766 100644 --- a/dlclive/predictor/base.py +++ b/dlclive/predictor/base.py @@ -45,9 +45,7 @@ def __init__(self): self.num_animals = None @abstractmethod - def forward( - self, stride: float, outputs: dict[str, torch.Tensor] - ) -> dict[str, torch.Tensor]: + def forward(self, stride: float, outputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """Abstract method for the forward pass of the Predictor. Args: diff --git a/dlclive/predictor/single_predictor.py b/dlclive/predictor/single_predictor.py index c622cf94..d320701a 100644 --- a/dlclive/predictor/single_predictor.py +++ b/dlclive/predictor/single_predictor.py @@ -10,13 +10,11 @@ # from __future__ import annotations -from typing import Tuple - import torch from dlclive.pose_estimation_pytorch.models.predictors.base import ( - BasePredictor, PREDICTORS, + BasePredictor, ) @@ -55,9 +53,7 @@ def __init__( self.location_refinement = location_refinement self.locref_std = locref_std - def forward( - self, stride: float, outputs: dict[str, torch.Tensor] - ) -> dict[str, torch.Tensor]: + def forward(self, stride: float, outputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """Forward pass of SinglePredictor. Gets predictions from model output. Args: @@ -85,9 +81,7 @@ def forward( locrefs = None if self.location_refinement: locrefs = outputs["locref"] - locrefs = locrefs.permute(0, 2, 3, 1).reshape( - batch_size, height, width, num_joints, 2 - ) + locrefs = locrefs.permute(0, 2, 3, 1).reshape(batch_size, height, width, num_joints, 2) locrefs = locrefs * self.locref_std poses = self.get_pose_prediction(heatmaps, locrefs, scale_factors) @@ -97,9 +91,7 @@ def forward( return {"poses": poses} - def get_top_values( - self, heatmap: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def get_top_values(self, heatmap: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Get the top values from the heatmap. Args: @@ -119,9 +111,7 @@ def get_top_values( y, x = heatmap_top // nx, heatmap_top % nx return y, x - def get_pose_prediction( - self, heatmap: torch.Tensor, locref: torch.Tensor | None, scale_factors - ) -> torch.Tensor: + def get_pose_prediction(self, heatmap: torch.Tensor, locref: torch.Tensor | None, scale_factors) -> torch.Tensor: """Gets the pose prediction given the heatmaps and locref. 
Args: diff --git a/dlclive/processor/__init__.py b/dlclive/processor/__init__.py index 657b4053..5d459267 100644 --- a/dlclive/processor/__init__.py +++ b/dlclive/processor/__init__.py @@ -4,4 +4,9 @@ Licensed under GNU Lesser General Public License v3.0 """ + from dlclive.processor.processor import Processor + +__all__ = [ + "Processor", +] diff --git a/dlclive/processor/kalmanfilter.py b/dlclive/processor/kalmanfilter.py index ff468052..31f9476b 100644 --- a/dlclive/processor/kalmanfilter.py +++ b/dlclive/processor/kalmanfilter.py @@ -19,13 +19,15 @@ def __init__( forward=0.002, fps=30, nderiv=2, - priors=[10, 10], + priors=None, initial_var=5, process_var=5, dlc_var=20, lik_thresh=0, **kwargs, ): + if priors is None: + priors = [10, 10] super().__init__(**kwargs) self.adapt = adapt @@ -121,11 +123,7 @@ def process(self, pose, **kwargs): liks = self._get_state_likelihood(pose) self._update(liks) - forward_time = ( - (time.time() - kwargs["frame_time"] + self.forward) - if self.adapt - else self.forward - ) + forward_time = (time.time() - kwargs["frame_time"] + self.forward) if self.adapt else self.forward future_pose = self._get_future_pose(forward_time) future_pose = np.hstack((future_pose, pose[:, 2].reshape(self.bp, 1))) diff --git a/dlclive/utils.py b/dlclive/utils.py index b6981959..8a6acdee 100644 --- a/dlclive/utils.py +++ b/dlclive/utils.py @@ -5,10 +5,10 @@ Licensed under GNU Lesser General Public License v3.0 """ -import warnings -from pathlib import Path import urllib.error import urllib.request +import warnings +from pathlib import Path import cv2 import numpy as np @@ -72,6 +72,7 @@ def img_to_rgb(frame: np.ndarray) -> np.ndarray: warnings.warn( f"Image has {frame.ndim} dimensions. Must be 2 or 3 dimensions to convert to RGB", DLCLiveWarning, + stacklevel=2, ) return frame @@ -140,9 +141,7 @@ def _img_as_ubyte_np(frame: np.ndarray) -> np.ndarray: return frame.astype(np.uint8) else: - raise TypeError( - "image of type {} could not be converted to ubyte".format(im_type) - ) + raise TypeError(f"image of type {im_type} could not be converted to ubyte") def decode_fourcc(cc): @@ -162,7 +161,7 @@ def decode_fourcc(cc): """ try: decoded = "".join([chr((int(cc) >> 8 * i) & 0xFF) for i in range(4)]) - except: + except Exception: decoded = "" return decoded @@ -171,77 +170,80 @@ def decode_fourcc(cc): def get_available_backends() -> list[Engine]: """ Check which backends (TensorFlow or PyTorch) are installed. - + Returns: - list[str]: List of installed backends. Possible values: ["tensorflow"], ["pytorch"], + list[str]: List of installed backends. Possible values: ["tensorflow"], ["pytorch"], or ["tensorflow", "pytorch"]. Returns an empty list if neither is installed. - + Warns: DLCLiveWarning: If neither TensorFlow nor PyTorch is installed. """ backends = [] - + try: import tensorflow + backends.append(Engine.TENSORFLOW) except (ImportError, ModuleNotFoundError): pass - + try: import torch + backends.append(Engine.PYTORCH) except (ImportError, ModuleNotFoundError): pass - + if not backends: warnings.warn( "Neither TensorFlow nor PyTorch is installed. One of these is required to use DLCLive!" "Install with: pip install deeplabcut-live[tf] or pip install deeplabcut-live[pytorch]", DLCLiveWarning, + stacklevel=2, ) - + return backends def download_file(url: str, filepath: str, chunk_size: int = 8192) -> None: """ Download a file from a URL with progress bar and error handling. 
- + Args: url: URL to download from filepath: Local path to save the file chunk_size: Size of chunks to read (default: 8192 bytes) - + Raises: urllib.error.URLError: If the download fails IOError: If the file cannot be written """ filepath = Path(filepath) - + # Check if file already exists if filepath.exists(): print(f"File already exists at {filepath}, skipping download.") return - + # Ensure parent directory exists filepath.parent.mkdir(parents=True, exist_ok=True) - + try: # Open the URL with urllib.request.urlopen(url) as response: # Get file size if available - total_size = int(response.headers.get('Content-Length', 0)) - + total_size = int(response.headers.get("Content-Length", 0)) + # Create progress bar if file size is known if total_size > 0: - pbar = tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading") + pbar = tqdm(total=total_size, unit="B", unit_scale=True, desc="Downloading") else: pbar = None print("Downloading...") - + # Download and write file downloaded = 0 - with open(filepath, 'wb') as f: + with open(filepath, "wb") as f: while True: chunk = response.read(chunk_size) if not chunk: @@ -250,19 +252,19 @@ def download_file(url: str, filepath: str, chunk_size: int = 8192) -> None: downloaded += len(chunk) if pbar: pbar.update(len(chunk)) - + if pbar: pbar.close() - + # Verify file was written if not filepath.exists() or filepath.stat().st_size == 0: - raise IOError(f"Downloaded file is empty or was not written to {filepath}") - + raise OSError(f"Downloaded file is empty or was not written to {filepath}") + print(f"Successfully downloaded to {filepath}") - + except urllib.error.HTTPError as e: - raise urllib.error.URLError(f"HTTP error {e.code}: {e.reason} when downloading from {url}") + raise urllib.error.URLError(f"HTTP error {e.code}: {e.reason} when downloading from {url}") from e except urllib.error.URLError as e: - raise urllib.error.URLError(f"Failed to download from {url}: {e.reason}") - except IOError as e: - raise IOError(f"Failed to write file to {filepath}: {e}") \ No newline at end of file + raise urllib.error.URLError(f"Failed to download from {url}: {e.reason}") from e + except OSError as e: + raise OSError(f"Failed to write file to {filepath}") from e diff --git a/dlclive/version.py b/dlclive/version.py index 0c4c1d89..26c13a69 100644 --- a/dlclive/version.py +++ b/dlclive/version.py @@ -6,6 +6,5 @@ Licensed under GNU Lesser General Public License v3.0 """ - __version__ = "1.1.1rc1" VERSION = __version__ diff --git a/docs/DLC Live Benchmark.md b/docs/DLC Live Benchmark.md index 583e9f25..03dbc184 100755 --- a/docs/DLC Live Benchmark.md +++ b/docs/DLC Live Benchmark.md @@ -25,7 +25,7 @@ | Linux | ONNX | ONNX | CPU | FP16 | Pigeon | 36s - ~1k | 30 | (480, 270) | Resize=0.25 | `ResNet50 + SSDLite detector` (td) | 161.32 ms ± 18.29ms | 161.3 ms ± 18.29ms | 6 ± 1 | | ** **CUDA: NVIDIA GeForce RTX 3050 (6GB)** -** **CPU: 13th Gen Intel Core i7-13620H × 16** +** **CPU: 13th Gen Intel Core i7-13620H × 16** ** **Linux: Ubuntu 24.04 LTS** ^ *Startup time at inference for a TensorRT engine takes between 30 and 50 seconds, diff --git a/docs/install_desktop.md b/docs/install_desktop.md index 7ef4238f..2c686d1f 100755 --- a/docs/install_desktop.md +++ b/docs/install_desktop.md @@ -25,22 +25,22 @@ Install from PyPI with PyTorch or TensorFlow: ```bash # With PyTorch (recommended) pip install deeplabcut-live[pytorch] - + # Or with TensorFlow pip install deeplabcut-live[tf] - + # Or using uv uv pip install deeplabcut-live[pytorch] # or [tf] ``` - -### 
Windows-users with GPU: -On **Windows**, the `deeplabcut-live[pytorch]` extra will not install the required CUDA-enabled wheels for PyTorch by default. Windows users with a CUDA GPU should install CUDA-enabled PyTorch first: + +### Windows users with GPU: +On **Windows**, the `deeplabcut-live[pytorch]` extra will not install the required CUDA-enabled wheels for PyTorch by default. Windows users with a CUDA GPU should install CUDA-enabled PyTorch first: ```bash # First install PyTorch with CUDA support pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 - + # Then install DeepLabCut-live (it will use the existing GPU-enabled PyTorch) pip install deeplabcut-live[pytorch] ``` @@ -62,7 +62,7 @@ If you want to install from a local clone of the repository, follow these steps: # Using uv uv sync --extra pytorch --python 3.11 # or --extra tf for TensorFlow ``` - + **Option B - using pip in a conda environment** ``` conda create -n dlc-live python=3.11 @@ -75,7 +75,7 @@ If the above instructions do not work, or you want to use a specific Pytorch or ``` # Install Pytorch pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 - + # Or install TensorFlow instead pip install tensorflow ``` @@ -101,4 +101,4 @@ dlc-live-test If installed properly, this script will i) create a temporary folder ii) download the full_dog model from the [DeepLabCut Model Zoo]( http://www.mousemotorlab.org/dlc-modelzoo), iii) download a short video clip of -a dog, and iv) run inference while displaying keypoints. v) remove the temporary folder. \ No newline at end of file +a dog, and iv) run inference while displaying keypoints. v) remove the temporary folder. diff --git a/docs/install_jetson.md b/docs/install_jetson.md index 2db456e2..6506c783 100755 --- a/docs/install_jetson.md +++ b/docs/install_jetson.md @@ -49,9 +49,9 @@ pip install -U pip testresources setuptools #### Install DeepLabCut-live dependencies -First, install `python` dependencies to run `PyTorch` (from [NVIDIA instructions to +First, install `python` dependencies to run `PyTorch` (from [NVIDIA instructions to install PyTorch for Jetson Platform]( -https://docs.nvidia.com/deeplearning/frameworks/install-pytorch-jetson-platform/index.html)). +https://docs.nvidia.com/deeplearning/frameworks/install-pytorch-jetson-platform/index.html)). _This may take ~15-30 minutes._ ``` @@ -76,8 +76,8 @@ To install PyTorch >= 2.0 pip3 install --no-cache https://developer.download.nvidia.com/compute/redist/jp/v51/pytorch/ ``` -Currently, the only available PyTorch version that can be used is -`torch-2.0.0a0+8aa34602.nv23.03-cp38-cp38-linux_aarch64.whl`. +Currently, the only available PyTorch version that can be used is +`torch-2.0.0a0+8aa34602.nv23.03-cp38-cp38-linux_aarch64.whl`.
Lastly, copy the opencv-python bindings into your virtual environment: @@ -88,7 +88,7 @@ cp -r /usr/lib/python3.12/dist-packages ~/dlc-live/lib/python3.12/dist-packages #### Install the DeepLabCut-live package -Finally, please install DeepLabCut-live from PyPi (_this will take 3-5 mins_), then +Finally, please install DeepLabCut-live from PyPI (_this will take 3-5 mins_), then test the installation: ``` diff --git a/example_processors/DogJumpLED/__init__.py b/example_processors/DogJumpLED/__init__.py index 7706cbc1..64bf2778 100644 --- a/example_processors/DogJumpLED/__init__.py +++ b/example_processors/DogJumpLED/__init__.py @@ -5,5 +5,11 @@ Licensed under GNU Lesser General Public License v3.0 """ -from .izzy_jump import IzzyJump, IzzyJumpKF -from .izzy_jump import IzzyJumpOffline, IzzyJumpKFOffline +from .izzy_jump import IzzyJump, IzzyJumpKF, IzzyJumpKFOffline, IzzyJumpOffline + +__all__ = [ + "IzzyJump", + "IzzyJumpKF", + "IzzyJumpKFOffline", + "IzzyJumpOffline", +] diff --git a/example_processors/DogJumpLED/izzy_jump.py b/example_processors/DogJumpLED/izzy_jump.py index f24356dc..8bbb80ff 100644 --- a/example_processors/DogJumpLED/izzy_jump.py +++ b/example_processors/DogJumpLED/izzy_jump.py @@ -5,17 +5,16 @@ Licensed under GNU Lesser General Public License v3.0 """ - -import serial -import struct import time + import numpy as np +import serial -from dlclive.processor import Processor, KalmanFilterPredictor +from dlclive.processor import KalmanFilterPredictor, Processor class IzzyJump(Processor): - def __init__(self, com="", lik_thresh=0.5, baudrate=int(9600), **kwargs): + def __init__(self, com="", lik_thresh=0.5, baudrate=9600, **kwargs): super().__init__() self.ser = serial.Serial(com, baudrate, timeout=0) @@ -69,11 +68,7 @@ def process(self, pose, **kwargs): l_elbow = pose[12, 1] if pose[12, 2] > self.lik_thresh else None r_elbow = pose[13, 1] if pose[13, 2] > self.lik_thresh else None elbows = [l_elbow, r_elbow] - this_elbow = ( - min([e for e in elbows if e is not None]) - if any([e is not None for e in elbows]) - else None - ) + this_elbow = min([e for e in elbows if e is not None]) if any([e is not None for e in elbows]) else None withers = pose[6, 1] if pose[6, 2] > self.lik_thresh else None @@ -107,17 +102,19 @@ def __init__( self, com="", lik_thresh=0.5, - baudrate=int(9600), + baudrate=9600, adapt=True, forward=0.003, fps=30, nderiv=2, - priors=[1, 1], + priors=None, initial_var=1, process_var=1, dlc_var=4, ): + if priors is None: + priors = [1, 1] super().__init__( adapt=adapt, forward=forward, diff --git a/example_processors/DogJumpLED/izzy_jump_offline.py b/example_processors/DogJumpLED/izzy_jump_offline.py index 9e574ab9..a93811c9 100644 --- a/example_processors/DogJumpLED/izzy_jump_offline.py +++ b/example_processors/DogJumpLED/izzy_jump_offline.py @@ -5,12 +5,9 @@ Licensed under GNU Lesser General Public License v3.0 """ - -import struct -import time import numpy as np -from dlclive.processor import Processor, KalmanFilterPredictor +from dlclive.processor import KalmanFilterPredictor, Processor class IzzyJumpOffline(Processor): @@ -53,11 +50,7 @@ def process(self, pose, **kwargs): l_elbow = pose[12, 1] if pose[12, 2] > self.lik_thresh else None r_elbow = pose[13, 1] if pose[13, 2] > self.lik_thresh else None elbows = [l_elbow, r_elbow] - this_elbow = ( - min([e for e in elbows if e is not None]) - if any([e is not None for e in elbows]) - else None - ) + this_elbow = min([e for e in elbows if e is not None]) if any([e is not None for e in elbows]) else None withers
= pose[6, 1] if pose[6, 2] > self.lik_thresh else None @@ -94,12 +87,14 @@ def __init__( forward=0.003, fps=30, nderiv=2, - priors=[1, 1], + priors=None, initial_var=1, process_var=1, dlc_var=4, ): + if priors is None: + priors = [1, 1] super().__init__( adapt=adapt, forward=forward, diff --git a/example_processors/DogJumpLED/teensy_leds/teensy_leds.ino b/example_processors/DogJumpLED/teensy_leds/teensy_leds.ino index e4c1ebf5..049b0239 100644 --- a/example_processors/DogJumpLED/teensy_leds/teensy_leds.ino +++ b/example_processors/DogJumpLED/teensy_leds/teensy_leds.ino @@ -8,7 +8,7 @@ void blink() { Serial.flush(); noTone(IR); while (digitalRead(REC) == 0) {} - + } void setup() { @@ -24,26 +24,26 @@ void setup() { void loop() { unsigned int ser_avail = Serial.available(); - + while (ser_avail > 0) { - + unsigned int cmd = Serial.read(); if (cmd == 'L') { - + digitalWrite(LED, !digitalRead(LED)); - + } else if (cmd == 'R') { Serial.write(digitalRead(LED)); Serial.flush(); - + } else if (cmd == 'I') { tone(IR, 38000); - + } - + } - + } diff --git a/example_processors/MouseLickLED/__init__.py b/example_processors/MouseLickLED/__init__.py index 1f68284b..d262d83c 100644 --- a/example_processors/MouseLickLED/__init__.py +++ b/example_processors/MouseLickLED/__init__.py @@ -6,3 +6,7 @@ """ from .lick_led import MouseLickLED + +__all__ = [ + "MouseLickLED", +] diff --git a/example_processors/MouseLickLED/lick_led.py b/example_processors/MouseLickLED/lick_led.py index d7a1be77..51c70ac8 100644 --- a/example_processors/MouseLickLED/lick_led.py +++ b/example_processors/MouseLickLED/lick_led.py @@ -5,17 +5,16 @@ Licensed under GNU Lesser General Public License v3.0 """ - -import serial -import struct import time + import numpy as np +import serial from dlclive import Processor class MouseLickLED(Processor): - def __init__(self, com, lik_thresh=0.5, baudrate=int(9600)): + def __init__(self, com, lik_thresh=0.5, baudrate=9600): super().__init__() self.ser = serial.Serial(com, baudrate, timeout=0) @@ -75,9 +74,7 @@ def save(self, filename): in_time = np.array(self.in_time) frame_time = np.array(self.lick_frame_time) try: - np.savez( - filename, out_time=out_time, in_time=in_time, frame_time=frame_time - ) + np.savez(filename, out_time=out_time, in_time=in_time, frame_time=frame_time) save_code = True except Exception: save_code = False diff --git a/example_processors/MouseLickLED/teensy_leds/teensy_leds.ino b/example_processors/MouseLickLED/teensy_leds/teensy_leds.ino index e4c1ebf5..049b0239 100644 --- a/example_processors/MouseLickLED/teensy_leds/teensy_leds.ino +++ b/example_processors/MouseLickLED/teensy_leds/teensy_leds.ino @@ -8,7 +8,7 @@ void blink() { Serial.flush(); noTone(IR); while (digitalRead(REC) == 0) {} - + } void setup() { @@ -24,26 +24,26 @@ void setup() { void loop() { unsigned int ser_avail = Serial.available(); - + while (ser_avail > 0) { - + unsigned int cmd = Serial.read(); if (cmd == 'L') { - + digitalWrite(LED, !digitalRead(LED)); - + } else if (cmd == 'R') { Serial.write(digitalRead(LED)); Serial.flush(); - + } else if (cmd == 'I') { tone(IR, 38000); - + } - + } - + } diff --git a/example_processors/TeensyLaser/teensy_laser.py b/example_processors/TeensyLaser/teensy_laser.py index 07d6c9f2..795c9eef 100644 --- a/example_processors/TeensyLaser/teensy_laser.py +++ b/example_processors/TeensyLaser/teensy_laser.py @@ -5,26 +5,23 @@ Licensed under GNU Lesser General Public License v3.0 """ - -from dlclive.processor.processor import Processor -import serial -import struct import 
pickle +import struct import time +import serial + +from dlclive.processor.processor import Processor + class TeensyLaser(Processor): - def __init__( - self, com, baudrate=115200, pulse_freq=50, pulse_width=5, max_stim_dur=0 - ): + def __init__(self, com, baudrate=115200, pulse_freq=50, pulse_width=5, max_stim_dur=0): super().__init__() self.ser = serial.Serial(com, baudrate) self.pulse_freq = pulse_freq self.pulse_width = pulse_width - self.max_stim_dur = ( - max_stim_dur if (max_stim_dur >= 0) and (max_stim_dur < 65356) else 0 - ) + self.max_stim_dur = max_stim_dur if (max_stim_dur >= 0) and (max_stim_dur < 65356) else 0 self.stim_on = False self.stim_on_time = [] self.stim_off_time = [] @@ -35,14 +32,11 @@ def close_serial(self): def turn_stim_on(self): - # command to activate PWM signal to laser is the letter 'O' followed by three 16 bit integers -- pulse frequency, pulse width, and max stim duration + # command to activate PWM signal to laser is + # the letter 'O' followed by three 16 bit integers + # -- pulse frequency, pulse width, and max stim duration if not self.stim_on: - self.ser.write( - b"O" - + struct.pack( - "HHH", self.pulse_freq, self.pulse_width, self.max_stim_dur - ) - ) + self.ser.write(b"O" + struct.pack("HHH", self.pulse_freq, self.pulse_width, self.max_stim_dur)) self.stim_on = True self.stim_on_time.append(time.time()) diff --git a/example_processors/TeensyLaser/teensy_laser/teensy_laser.ino b/example_processors/TeensyLaser/teensy_laser/teensy_laser.ino index 76a470b3..6ce659c9 100644 --- a/example_processors/TeensyLaser/teensy_laser/teensy_laser.ino +++ b/example_processors/TeensyLaser/teensy_laser/teensy_laser.ino @@ -1,5 +1,5 @@ -/* - * Commands: +/* + * Commands: * O = opto on; command = O, frequency, width, duration * X = opto off * R = reboot @@ -14,7 +14,7 @@ unsigned int opto_start = 0, opto_dur = 0; unsigned int read_int16() { - union u_tag { + union u_tag { byte b[2]; unsigned int val; } par; @@ -35,13 +35,13 @@ void setup() { void loop() { unsigned int curr_time = millis(); - + while (Serial.available() > 0) { unsigned int cmd = Serial.read(); - + if(cmd == 'O') { - + opto_start = curr_time; opto_freq = read_int16(); opto_width = read_int16(); @@ -59,15 +59,15 @@ void loop() { Serial.print(opto_dur); Serial.print('\n'); Serial.flush(); - + } else if(cmd == 'X') { analogWrite(opto_pin, 0); - + } else if(cmd == 'R') { _reboot_Teensyduino_(); - + } } diff --git a/pyproject.toml b/pyproject.toml index 1762f256..e61b5aff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,107 +1,98 @@ [build-system] -requires = ["setuptools>=61.0", "wheel"] build-backend = "setuptools.build_meta" +requires = [ "setuptools>=61", "wheel" ] [project] name = "deeplabcut-live" version = "1.1.1rc1" description = "Class to load exported DeepLabCut networks and perform pose estimation on single frames (from a camera feed)" readme = "README.md" -requires-python = ">=3.10,<3.13" +keywords = [ "deep-learning", "deeplabcut", "pose-estimation", "real-time" ] license = { text = "GNU Affero General Public License v3 or later (AGPLv3+)" } -authors = [ - { name = "A. & M. 
Mathis Labs", email = "admin@deeplabcut.org" } -] - -keywords = ["deeplabcut", "pose-estimation", "real-time", "deep-learning"] - +requires-python = ">=3.10,<3.13" classifiers = [ - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)", - "Operating System :: OS Independent", + "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] - dependencies = [ - "numpy>=1.20,<2", - "ruamel.yaml>=0.17.20", - "colorcet>=3.0.0", - "einops>=0.6.1", - "Pillow>=8.0.0", - "py-cpuinfo>=5.0.0", - "tqdm>=4.62.3", - "pandas>=1.0.1,!=1.5.0", - "tables>=3.8", - "opencv-python-headless>=4.5", - "dlclibrary>=0.0.6", - "scipy>=1.9", - "pip", + "colorcet>=3", + "dlclibrary>=0.0.6", + "einops>=0.6.1", + "numpy>=1.20,<2", + "opencv-python-headless>=4.5", + "pandas>=1.0.1,!=1.5", + "pillow>=8", + "pip", + "py-cpuinfo>=5", + "ruamel-yaml>=0.17.20", + "scipy>=1.9", + "tables>=3.8", + "tqdm>=4.62.3", ] - +[[project.authors]] +name = "A. & M. Mathis Labs" +email = "admin@deeplabcut.org" [project.optional-dependencies] pytorch = [ - "timm>=1.0.7", - "torch>=2.0.0", - "torchvision>=0.15", + "timm>=1.0.7", + "torch>=2", + "torchvision>=0.15", ] - tf = [ - "tensorflow>=2.7.0,<2.12; platform_system == 'Linux' and python_version < '3.11'", - "tensorflow>=2.12.0; platform_system == 'Linux' and python_version >= '3.11'", - "tensorflow-macos>=2.7.0,<2.12; platform_system == 'Darwin' and python_version < '3.11'", - "tensorflow-macos>=2.12.0; platform_system == 'Darwin' and python_version >= '3.11'", - # Tensorflow is not supported on Windows with Python >= 3.11 - "tensorflow>=2.7,<=2.10; platform_system == 'Windows' and python_version < '3.11'", - "tensorflow-io-gcs-filesystem==0.27; platform_system == 'Windows' and python_version < '3.11'", - "tensorflow-io-gcs-filesystem; platform_system != 'Windows'", + # Tensorflow is not supported on Windows with Python >= 3.11 + "tensorflow>=2.7,<=2.10; platform_system=='Windows' and python_version<'3.11'", + "tensorflow>=2.7,<2.12; platform_system=='Linux' and python_version<'3.11'", + "tensorflow>=2.12; platform_system=='Linux' and python_version>='3.11'", + "tensorflow-io-gcs-filesystem; platform_system!='Windows'", + "tensorflow-io-gcs-filesystem==0.27; platform_system=='Windows' and python_version<'3.11'", + "tensorflow-macos>=2.7,<2.12; platform_system=='Darwin' and python_version<'3.11'", + "tensorflow-macos>=2.12; platform_system=='Darwin' and python_version>='3.11'", ] +[project.scripts] +dlc-live-test = "dlclive.check_install.check_install:main" +dlc-live-benchmark = "dlclive.benchmark:main" +[project.urls] +Homepage = "https://github.com/DeepLabCut/DeepLabCut-live" +Repository = "https://github.com/DeepLabCut/DeepLabCut-live" [dependency-groups] dev = [ - "pytest", - "pytest-cov", - "hypothesis", - "black", - "ruff", + "black", + "hypothesis", + "pytest", + "pytest-cov", + "ruff", ] # Keep only for backward compatibility with Poetry # (without this section, Poetry assumes the wrong root directory of the project) [tool.poetry] -packages = [ - { include = "dlclive" } -] - -[project.scripts] -dlc-live-test = 
"dlclive.check_install.check_install:main" -dlc-live-benchmark = "dlclive.benchmark:main" - -[project.urls] -Homepage = "https://github.com/DeepLabCut/DeepLabCut-live" -Repository = "https://github.com/DeepLabCut/DeepLabCut-live" +packages = [ { include = "dlclive" } ] [tool.setuptools] include-package-data = true - -[tool.setuptools.packages.find] -include = ["dlclive*"] - [tool.setuptools.package-data] dlclive = [ - "check_install/*", - "modelzoo/model_configs/*.yaml", - "modelzoo/project_configs/*.yaml", + "check_install/*", + "modelzoo/model_configs/*.yaml", + "modelzoo/project_configs/*.yaml", ] +[tool.setuptools.packages.find] +include = [ "dlclive*" ] [tool.ruff] -lint.select = ["E", "F", "B", "I", "UP"] -lint.ignore = ["E741"] target-version = "py310" -fix = true line-length = 120 - +fix = true +[tool.ruff.lint] +select = [ "E", "F", "B", "I", "UP" ] +ignore = [ "E741", "F403", "F401" ] +[tool.ruff.lint.per-file-ignores] +"dlclive/__init__.py" = [ "E402" ] [tool.ruff.lint.pydocstyle] convention = "google" @@ -109,4 +100,4 @@ convention = "google" max_supported_python = "3.12" generate_python_version_classifiers = true # Avoid collapsing tables to field.key = value format (less readable) -table_format = "long" \ No newline at end of file +table_format = "long" diff --git a/scripts/export.py b/scripts/export.py index 320ada0e..fd585a9a 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -1,4 +1,5 @@ """Exports DeepLabCut models for DeepLabCut-Live""" + import warnings from pathlib import Path @@ -14,8 +15,8 @@ def read_config_as_dict(config_path: str | Path) -> dict: Returns: The configuration file with pure Python classes """ - with open(config_path, "r") as f: - cfg = YAML(typ='safe', pure=True).load(f) + with open(config_path) as f: + cfg = YAML(typ="safe", pure=True).load(f) return cfg @@ -43,13 +44,14 @@ def export_dlc3_model( if model_cfg["method"].lower() == "td": warnings.warn( "The model is a top-down model but no detector snapshot was given." - "The configuration will be changed to run the model in bottom-up mode." + "The configuration will be changed to run the model in bottom-up mode.", + stacklevel=2, ) model_cfg["method"] = "bu" else: if model_cfg["method"].lower() == "bu": - raise ValueError(f"Cannot use a detector with a bottom-up model!") + raise ValueError("Cannot use a detector with a bottom-up model!") detector_weights = torch.load(detector_snapshot, **load_kwargs)["model"] torch.save( diff --git a/tests/test_config.py b/tests/test_config.py index 9751f281..c22b6c17 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,6 +3,7 @@ """ import pytest + from dlclive.core import config diff --git a/tests/test_display.py b/tests/test_display.py index aa6821a4..514397c4 100644 --- a/tests/test_display.py +++ b/tests/test_display.py @@ -39,9 +39,7 @@ def test_display_frame_creates_window_and_updates(headless_display_env): env.tk.update.assert_called_once_with() -def test_display_draws_only_points_above_cutoff_with_clamping( - headless_display_env, monkeypatch -): +def test_display_draws_only_points_above_cutoff_with_clamping(headless_display_env, monkeypatch): env = headless_display_env display_mod = env.mod disp = display_mod.Display(radius=3, pcutoff=0.5) diff --git a/tests/test_dlclive.py b/tests/test_dlclive.py index 6c152236..60b7bdf4 100644 --- a/tests/test_dlclive.py +++ b/tests/test_dlclive.py @@ -2,9 +2,11 @@ Tests for DLCLive core functionality - frame processing, cropping, etc. 
""" -import pytest -import numpy as np from unittest.mock import Mock, patch + +import numpy as np +import pytest + from dlclive import DLCLive from dlclive.exceptions import DLCLiveError @@ -40,7 +42,7 @@ def test_dlclive_initialization(self, mock_build_runner, mock_runner, tmp_path): assert dlc.path == model_path assert dlc.model_type == "pytorch" - assert dlc.is_initialized == False + assert not dlc.is_initialized assert dlc.cropping is None assert dlc.dynamic == (False, 0.5, 10) assert dlc.processor is None @@ -93,9 +95,7 @@ def test_dlclive_parameterization(self, mock_build_runner, mock_runner, tmp_path @patch("dlclive.factory.build_runner") @patch("dlclive.utils.img_to_rgb") - def test_process_frame_cropping( - self, mock_img_to_rgb, mock_build_runner, mock_runner, sample_frame, tmp_path - ): + def test_process_frame_cropping(self, mock_img_to_rgb, mock_build_runner, mock_runner, sample_frame, tmp_path): """Test frame processing with cropping""" model_path = tmp_path / "model.pt" model_path.write_text("test") @@ -134,9 +134,7 @@ def test_process_frame_resize( mock_resize.assert_called_once() @patch("dlclive.factory.build_runner") - def test_process_frame_dynamic_cropping( - self, mock_build_runner, mock_runner, sample_frame, tmp_path - ): + def test_process_frame_dynamic_cropping(self, mock_build_runner, mock_runner, sample_frame, tmp_path): """Test dynamic cropping functionality""" model_path = tmp_path / "model.pt" model_path.write_text("test") @@ -190,13 +188,11 @@ def test_close(self, mock_build_runner, mock_runner, tmp_path): dlc.is_initialized = True dlc.close() - assert dlc.is_initialized == False + assert not dlc.is_initialized mock_runner.close.assert_called_once() @patch("dlclive.factory.build_runner") - def test_post_process_pose_with_processor( - self, mock_build_runner, mock_runner, sample_frame, tmp_path - ): + def test_post_process_pose_with_processor(self, mock_build_runner, mock_runner, sample_frame, tmp_path): """Test pose post-processing with processor""" model_path = tmp_path / "model.pt" model_path.write_text("test") diff --git a/tests/test_engine.py b/tests/test_engine.py index b7a5b5a3..d5a6d2f9 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -3,6 +3,7 @@ """ import pytest + from dlclive.engine import Engine diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index eceb7f29..4e497a1d 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -3,6 +3,7 @@ """ import pytest + from dlclive.exceptions import DLCLiveError, DLCLiveWarning @@ -30,7 +31,7 @@ def test_dlclive_warning(self): with pytest.warns(DLCLiveWarning): import warnings - warnings.warn("Test warning", DLCLiveWarning) + warnings.warn("Test warning", DLCLiveWarning, stacklevel=2) def test_dlclive_warning_inheritance(self): """Test DLCLiveWarning is a Warning""" diff --git a/tests/test_factory.py b/tests/test_factory.py index cf68f3b1..b74a2497 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -2,8 +2,10 @@ Tests for the Factory module - runner building """ -import pytest from unittest.mock import Mock, patch + +import pytest + from dlclive import factory @@ -37,9 +39,7 @@ def test_build_runner_pytorch(self, mock_runner_class, tmp_path): mock_runner = Mock() mock_runner_class.return_value = mock_runner - runner = factory.build_runner( - "pytorch", model_path, device="cpu", precision="FP32" - ) + runner = factory.build_runner("pytorch", model_path, device="cpu", precision="FP32") mock_runner_class.assert_called_once() assert runner == mock_runner 
@@ -56,9 +56,7 @@ def test_build_runner_tensorflow(self, tmp_path): pytest.skip("TensorFlow runner module not available") # Patch the TensorFlowRunner class in the runner module - with patch.object( - runner, "TensorFlowRunner", autospec=True - ) as mock_runner_class: + with patch.object(runner, "TensorFlowRunner", autospec=True) as mock_runner_class: mock_runner = Mock() mock_runner_class.return_value = mock_runner @@ -70,9 +68,7 @@ def test_build_runner_tensorflow(self, tmp_path): importlib.reload(sys.modules["dlclive.factory"]) from dlclive import factory as factory_module - runner_instance = factory_module.build_runner( - "base", model_path, precision="FP32" - ) + runner_instance = factory_module.build_runner("base", model_path, precision="FP32") mock_runner_class.assert_called_once() assert runner_instance == mock_runner @@ -89,9 +85,7 @@ def test_build_runner_tensorflow_type_alias(self, tmp_path): pytest.skip("TensorFlow runner module not available") # Patch the TensorFlowRunner class in the runner module - with patch.object( - runner, "TensorFlowRunner", autospec=True - ) as mock_runner_class: + with patch.object(runner, "TensorFlowRunner", autospec=True) as mock_runner_class: mock_runner = Mock() mock_runner_class.return_value = mock_runner diff --git a/tests/test_modelzoo.py b/tests/test_modelzoo.py index c2a0d701..4411645b 100644 --- a/tests/test_modelzoo.py +++ b/tests/test_modelzoo.py @@ -9,9 +9,7 @@ from dlclive import modelzoo -@pytest.mark.parametrize( - "super_animal", ["superanimal_quadruped", "superanimal_topviewmouse"] -) +@pytest.mark.parametrize("super_animal", ["superanimal_quadruped", "superanimal_topviewmouse"]) @pytest.mark.parametrize("model_name", ["hrnet_w32"]) @pytest.mark.parametrize("detector_name", [None, "fasterrcnn_resnet50_fpn_v2"]) def test_get_config_model_paths(super_animal, model_name, detector_name): diff --git a/tests/test_processor.py b/tests/test_processor.py index 3189d196..340dc28f 100644 --- a/tests/test_processor.py +++ b/tests/test_processor.py @@ -3,6 +3,7 @@ """ import numpy as np + from dlclive.processor import Processor diff --git a/tests/test_utils.py b/tests/test_utils.py index e7ec4230..cd8c4721 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,9 +2,11 @@ Tests for utility functions - image processing and file operations """ -import pytest +from unittest.mock import MagicMock, patch + import numpy as np -from unittest.mock import patch, MagicMock +import pytest + from dlclive import utils