Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion dlclive/modelzoo/pytorch_model_zoo_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ def _load_model_weights(model_name: str, super_animal: str = super_animal) -> Or
checkpoint: Path = download_super_animal_snapshot(dataset=super_animal, model_name=model_name)
return torch.load(checkpoint, map_location="cpu", weights_only=True)["model"]

# Skip the detector weight download for humanbody models: those weights are not hosted on Hugging Face
skip_detector_download = (detector_name is None) or (super_animal == "superanimal_humanbody")
export_dict = {
"config": model_cfg,
"pose": _load_model_weights(model_name),
"detector": _load_model_weights(detector_name) if detector_name is not None else None,
"detector": None if skip_detector_download else _load_model_weights(detector_name),
}
torch.save(export_dict, export_path)

Expand Down
22 changes: 16 additions & 6 deletions dlclive/modelzoo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ruamel.yaml import YAML

from dlclive.modelzoo.resolve_config import update_config
from dlclive.pose_estimation_pytorch.models.detectors.torchvision import SUPPORTED_TORCHVISION_DETECTORS

# Directory containing this module; computed from this file's own location.
_MODELZOO_PATH = Path(__file__).parent

Expand Down Expand Up @@ -129,12 +130,21 @@ def load_super_animal_config(
model_config["method"] = "BU"
else:
model_config["method"] = "TD"
if super_animal != "superanimal_humanbody":
detector_cfg_path = get_super_animal_model_config_path(
model_name=detector_name
)
detector_cfg = read_config_as_dict(detector_cfg_path)
model_config["detector"] = detector_cfg
detector_cfg_path = get_super_animal_model_config_path(
model_name=detector_name
)
detector_cfg = read_config_as_dict(detector_cfg_path)
model_config["detector"] = detector_cfg
if super_animal == "superanimal_humanbody":
# Apply specific updates required to run the torchvision detector with pretrained weights
assert detector_name in SUPPORTED_TORCHVISION_DETECTORS
model_config["detector"]['model']= {
"type": "TorchvisionDetectorAdaptor",
"model": detector_name,
"weights": "COCO_V1",
"num_classes": None,
"box_score_thresh": 0.6,
}
return model_config


Expand Down
3 changes: 3 additions & 0 deletions dlclive/pose_estimation_pytorch/models/detectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,8 @@
DETECTORS,
BaseDetector,
)
from dlclive.pose_estimation_pytorch.models.detectors.torchvision import (
TorchvisionDetectorAdaptor,
)
from dlclive.pose_estimation_pytorch.models.detectors.fasterRCNN import FasterRCNN
from dlclive.pose_estimation_pytorch.models.detectors.ssd import SSDLite
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
import torch
import torchvision.models.detection as detection

from dlclive.pose_estimation_pytorch.models.detectors.base import BaseDetector
from dlclive.pose_estimation_pytorch.models.detectors.base import DETECTORS, BaseDetector

# Names of the torchvision detection models listed as supported.
# Add one entry per line so future additions produce minimal diffs.
SUPPORTED_TORCHVISION_DETECTORS = [
    "fasterrcnn_mobilenet_v3_large_fpn",
]


@DETECTORS.register_module
class TorchvisionDetectorAdaptor(BaseDetector):
"""An adaptor for torchvision detectors

Expand All @@ -26,8 +29,8 @@ class TorchvisionDetectorAdaptor(BaseDetector):
- fasterrcnn_mobilenet_v3_large_fpn
- fasterrcnn_resnet50_fpn_v2

This class should not be used out-of-the-box. Subclasses (such as FasterRCNN or
SSDLite) should be used instead.
This class can be used directly (e.g. with pre-trained COCO weights) or through its
subclasses (FasterRCNN or SSDLite) which adapt the model for DLC's 2-class detection.

The torchvision implementation cannot return both predictions and losses
in a single forward pass. Therefore, during evaluation only bounding box metrics
Expand Down
21 changes: 18 additions & 3 deletions dlclive/pose_estimation_pytorch/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,24 @@ def load_model(self) -> None:
self.model = self.model.half()

self.detector = None
if self.dynamic is None and raw_data.get("detector") is not None:
detector_cfg = self.cfg.get("detector")
has_detector_weights = raw_data.get("detector") is not None
if detector_cfg is not None:
detector_model_cfg = detector_cfg["model"]
uses_pretrained = (
detector_model_cfg.get("pretrained", False)
or detector_model_cfg.get("weights") is not None
)
else:
uses_pretrained = False

if self.dynamic is None and (has_detector_weights or uses_pretrained):
self.detector = models.DETECTORS.build(self.cfg["detector"]["model"])
self.detector.to(self.device)
self.detector.load_state_dict(raw_data["detector"])

if has_detector_weights:
self.detector.load_state_dict(raw_data["detector"])

self.detector.eval()
if self.precision == "FP16":
self.detector = self.detector.half()
Expand All @@ -281,7 +295,8 @@ def load_model(self) -> None:
self.top_down_config.read_config(self.cfg)

detector_transforms = [v2.ToDtype(torch.float32, scale=True)]
if self.cfg["detector"]["data"]["inference"].get("normalize_images", False):
detector_data_cfg = detector_cfg.get("data", {}).get("inference", {})
if detector_data_cfg.get("normalize_images", False):
detector_transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
self.detector_transform = v2.Compose(detector_transforms)

Expand Down