Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
e8106fe
Handle PyTorch weights-only checkpoint validation
lexiyutou May 21, 2026
08fd52f
add s3 inference ckpt
lexiyutou May 21, 2026
71c1e2c
add egl and osmesa for renderer
lexiyutou May 21, 2026
9249315
move to gpu device
lexiyutou May 21, 2026
c3a3e5f
update s3 inference ckpt
lexiyutou May 21, 2026
60dcea8
update pyproject'
lexiyutou May 21, 2026
9bee30d
fix comment for example checkpoint path in demo_tta.sh
xiu-cs May 21, 2026
39bde99
fix example checkpoint path in demo.sh
xiu-cs May 21, 2026
5a44971
update OpenGL platform preference to use EGL with surfaceless option
xiu-cs May 21, 2026
ffb58f4
update OpenGL platform preference to use EGL with surfaceless option
xiu-cs May 21, 2026
fd920ae
refactor: enhance progress reporting in animal detection and TTA process
xiu-cs May 21, 2026
91b7910
refactor: update deployment script to sync only tracked files and enh…
xiu-cs May 21, 2026
b8ccb0d
refactor: improve reporting for animal detection process in _collect_…
xiu-cs May 28, 2026
d0e0839
Merge branch 'main' into ti_dev
lexiyutou May 28, 2026
b753959
Resolve ti_dev merge follow-ups
xiu-cs May 28, 2026
ca27a04
fix(deploy): remove Detectron2 from Space requirements for fallback t…
xiu-cs May 28, 2026
4a52c35
fix(demo): update description to clarify fallback to full-image crop …
xiu-cs May 28, 2026
c40d922
fix(readme): update detector information and clarify fallback mechani…
xiu-cs May 28, 2026
e3bc151
feat(demo): add server name and port configuration for Gradio interface
xiu-cs May 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ Options:
- `PRIMA_VENV=.venv ./scripts/clean_install_local.sh --skip-data` — skip the large `setup_demo_data` download if `data/` is already populated.
- `./scripts/clean_install_local.sh --wipe-data --force-data` — delete downloaded `data/` assets and redownload.
- `./scripts/clean_install_local.sh --no-editable` — only `requirements.txt` (no `pip install -e .`); use if editable install fails and you will install the training stack via conda as in the PyPI section above. You still need **Python 3.10+** for Gradio 5.1+. The smoke test sets `PYTHONPATH` to the repo root so `import prima` works without an editable install.
- **`requirements.txt` pins `deeplabcut==3.0.0rc14`** (SuperAnimal PyTorch API). On macOS, `clean_install_local.sh` installs a PyTables wheel first, then DLC 3.x. Full check: `./scripts/test_local_full.sh`.
- **macOS / DeepLabCut:** `requirements.txt` pins `deeplabcut==3.0.0rc14`
for the SuperAnimal PyTorch API. On macOS, `clean_install_local.sh` installs
it separately after a compatible PyTables wheel (`tables>=3.9.2,<3.11`) to
avoid Apple Silicon build issues. Validate the local setup with
`./scripts/test_local_full.sh`.

After `requirements.txt`, the script runs **`pip install --no-deps -e .`** so the `prima` package is registered without re-resolving `pyproject.toml` (which would pull **Detectron2** from git again). Install Detectron2 separately if needed: `pip install 'git+https://github.com/facebookresearch/detectron2.git'`.

Expand Down Expand Up @@ -159,7 +163,7 @@ The `s1ckpt_inference.ckpt` checkpoint is downloaded automatically if missing.
| | **Local** (`python app.py`) | **Hugging Face Space** |
|--|--|--|
| PRIMA device | GPU if available, else CPU | CPU only |
| Detectron2 | X-101-FPN | R50-FPN (lighter) |
| Detector | Detectron2 X-101-FPN | full-image crop fallback |
| Default TTA iterations | 30 | 30 (set to 0 for PRIMA-only) |
| Save `.obj` meshes | on | off |
| Preload checkpoint at startup | off | on |
Expand All @@ -186,8 +190,11 @@ Then from a clean checkout with LFS files present, redeploy the Space (same as `
./scripts/clean_redeploy_hf_space.sh
```

The script rsyncs the working tree (not `git archive`) so image files are materialized
before `git add` turns them into LFS blobs.
The script rsyncs only the Git-tracked files needed by the Space from the
working tree (not `git archive`) so image files are materialized before
`git add` turns them into LFS blobs.
During deployment, `detectron2` is removed from the Space `requirements.txt`;
the app uses its full-image crop fallback on the CPU Space.

---

Expand Down
144 changes: 114 additions & 30 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,17 @@
"""

import argparse
import concurrent.futures
import os
import queue
import sys
import tempfile
import time
import traceback
from dataclasses import dataclass
from functools import lru_cache
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from pathlib import Path

# macOS: PyRender (OpenGL) and DeepLabCut/pyglet must run on the main thread.
Expand Down Expand Up @@ -65,6 +68,8 @@

# Output folder for rendered images/meshes and keypoints
DEFAULT_OUT_FOLDER = "demo_out_tta_gradio"
# Default host/interface for the Gradio server (used as the --server_name default).
# Override with the PRIMA_GRADIO_HOST environment variable; falls back to "0.0.0.0".
DEFAULT_SERVER_NAME = os.environ.get("PRIMA_GRADIO_HOST", "0.0.0.0")
# Default port for the Gradio server (used as the --server_port default).
# Override with PRIMA_GRADIO_PORT; int() raises ValueError at import time if the
# variable is set to something that is not an integer.
DEFAULT_SERVER_PORT = int(os.environ.get("PRIMA_GRADIO_PORT", "7860"))

_D2_R50_CFG = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
_D2_R50_URL = (
Expand Down Expand Up @@ -125,8 +130,8 @@ def resolve_detectron_device(self) -> str:
("demo_data/000000015956_horse.png", 1e-6, 30, 0.7, 0.1, False, True),
("demo_data/n02412080_12159.png", 1e-6, 30, 0.7, 0.1, False, True),
("demo_data/000000315905_zebra.jpg", 1e-6, 30, 0.7, 0.1, False, True),
("demo_data/beagle.jpg", 1e-6, 0, 0.7, 0.1, False, True),
("demo_data/shepherd_hati.jpg", 1e-6, 0, 0.7, 0.1, False, True),
("demo_data/beagle.jpg", 1e-6, 30, 0.7, 0.1, False, True),
("demo_data/shepherd_hati.jpg", 1e-6, 30, 0.7, 0.1, False, True),
),
description=(
"**Local demo** — full pipeline on your machine (GPU when available).\n\n"
Expand All @@ -145,20 +150,23 @@ def resolve_detectron_device(self) -> str:
detectron_config_yaml=_D2_R50_CFG,
detectron_weights_url=_D2_R50_URL,
detectron_device="cpu",
default_tta_iters=0,
default_tta_iters=30,
max_tta_iters=30,
default_save_mesh=False,
default_side_view=False,
preload_assets=True,
example_rows=(
("demo_data/beagle.jpg", 1e-6, 0, 0.7, 0.1, False, False),
("demo_data/000000015956_horse.png", 1e-6, 0, 0.7, 0.1, False, False),
("demo_data/000000315905_zebra.jpg", 1e-6, 0, 0.7, 0.1, False, False),
("demo_data/000000015956_horse.png", 1e-6, 30, 0.7, 0.1, False, False),
("demo_data/n02412080_12159.png", 1e-6, 30, 0.7, 0.1, False, False),
("demo_data/000000315905_zebra.jpg", 1e-6, 30, 0.7, 0.1, False, False),
("demo_data/beagle.jpg", 1e-6, 30, 0.7, 0.1, False, False),
("demo_data/shepherd_hati.jpg", 1e-6, 30, 0.7, 0.1, False, False),
),
description=(
"**Hugging Face Space (cpu-basic)** — lightweight demo: **CPU-only**, Detectron2 **R50-FPN**, "
"PRIMA inference. TTA is optional (0 by default; increases runtime). Mesh `.obj` export is off "
"by default to save time and disk."
"**Hugging Face Space (cpu-basic)** — lightweight demo: **CPU-only** PRIMA inference. "
"The Space build skips Detectron2 and uses a full-image crop fallback. TTA is optional "
"(30 iterations by default, matching the local demo; set to 0 to skip). Mesh `.obj` export "
"is off by default to save time and disk."
),
Comment thread
lexiyutou marked this conversation as resolved.
interface_title="PRIMA on Hugging Face — lightweight CPU demo",
)
Expand Down Expand Up @@ -364,6 +372,7 @@ def _collect_animal_results(
side_view: bool,
save_mesh: bool,
boxes: Optional[np.ndarray] = None,
progress_callback: Optional[Callable[[str], None]] = None,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], str | None, str | None]:
"""Run detection + PRIMA + SuperAnimal + TTA on a single RGB image.

Expand All @@ -384,15 +393,25 @@ def _collect_animal_results(
tta_optimize,
)

def report(message: str) -> None:
    """Forward a progress message to ``progress_callback`` when one was supplied.

    Does nothing when the enclosing call received no callback.
    """
    if progress_callback is None:
        return
    progress_callback(message)

if int(tta_num_iters) > 0 and not SUPER_ANIMAL_ARGS.saved_2d_model_path:
report("Resolving SuperAnimal weights...")
SUPER_ANIMAL_ARGS.saved_2d_model_path = resolve_sa_weights_path("")

img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
if boxes is None:
if detector is None:
report("Detectron2 unavailable; using full-image crop...")
else:
report("Detecting animals with Detectron2...")
boxes = _detect_animal_boxes(detector, img_bgr, det_thresh)
if boxes is None:
return [], [], [], None, None

report(f"Detected {len(boxes)} animal(s). Preparing crops...")
dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

Expand All @@ -404,9 +423,11 @@ def _collect_animal_results(

img_token = next(tempfile._get_candidate_names())

for batch in dataloader:
total_batches = len(dataloader)
for batch_idx, batch in enumerate(dataloader, start=1):
batch = recursive_to(batch, device)

report(f"Animal {batch_idx}/{total_batches}: running PRIMA...")
with torch.no_grad():
out_before = model(batch)

Expand All @@ -416,6 +437,7 @@ def _collect_animal_results(
img_fn = f"{img_token}"
from demo_tta import render_and_save # imported lazily to avoid circular issues

report(f"Animal {batch_idx}/{total_batches}: rendering before TTA...")
render_and_save(
renderer,
cam_crop_to_full_fn,
Expand All @@ -441,6 +463,7 @@ def _collect_animal_results(
before_mesh_paths.append(before_obj_path)

if int(tta_num_iters) <= 0:
report(f"Animal {batch_idx}/{total_batches}: rendering final output...")
render_and_save(
renderer,
cam_crop_to_full_fn,
Expand All @@ -467,6 +490,7 @@ def _collect_animal_results(
continue

# Prepare patch for SuperAnimal
report(f"Animal {batch_idx}/{total_batches}: running SuperAnimal keypoints...")
patch_rgb = denorm_patch_to_rgb(batch["img"][0])
with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir:
bodyparts_xyc = run_superanimal_on_patch(patch_rgb, SUPER_ANIMAL_ARGS, tmp_dir)
Expand Down Expand Up @@ -497,6 +521,7 @@ def _collect_animal_results(
gt_kpts_norm = torch.from_numpy(kpts_norm[None]).to(device=device, dtype=batch["img"].dtype)

# Run TTA
report(f"Animal {batch_idx}/{total_batches}: running TTA ({int(tta_num_iters)} iterations)...")
out_after = tta_optimize(
model,
batch,
Expand All @@ -505,6 +530,7 @@ def _collect_animal_results(
lr=float(tta_lr),
)

report(f"Animal {batch_idx}/{total_batches}: rendering after TTA...")
render_and_save(
renderer,
cam_crop_to_full_fn,
Expand Down Expand Up @@ -532,6 +558,7 @@ def _collect_animal_results(
first_before_mesh = before_mesh_paths[0] if before_mesh_paths else None
first_after_mesh = after_mesh_paths[0] if after_mesh_paths else None

report("Collecting outputs...")
return before_imgs, after_imgs, kpt_imgs, first_before_mesh, first_after_mesh


Expand Down Expand Up @@ -634,25 +661,65 @@ def gradio_inference(
None,
None,
None,
f"Detected {len(boxes)} animal region(s). Running PRIMA (+ SuperAnimal/TTA if enabled)…",
)
before_imgs, after_imgs, kpt_imgs, mesh_before, mesh_after = _collect_animal_results(
runtime_cache["model"],
runtime_cache["model_cfg"],
runtime_cache["renderer"],
runtime_cache["cam_crop_to_full_fn"],
runtime_cache["device"],
runtime_cache["detector"],
out_folder,
img_rgb,
tta_lr=tta_lr,
tta_num_iters=tta_num_iters,
det_thresh=det_thresh,
kp_conf_thresh=kp_conf_thresh,
side_view=side_view,
save_mesh=save_mesh,
boxes=boxes,
f"Detected {len(boxes)} animal region(s). Running PRIMA (+ SuperAnimal/TTA if enabled)...",
)

def run_collect(progress_callback: Optional[Callable[[str], None]] = None):
    """Invoke ``_collect_animal_results`` with the state captured for this request.

    Args:
        progress_callback: Optional callable that receives stage messages;
            passed through unchanged.

    Returns:
        Whatever ``_collect_animal_results`` returns (unpacked by the caller as
        before images, after images, keypoint images, and the two mesh paths).
    """
    # Cached runtime objects, in the positional order the pipeline expects.
    cached_keys = (
        "model",
        "model_cfg",
        "renderer",
        "cam_crop_to_full_fn",
        "device",
        "detector",
    )
    cached_args = [runtime_cache[key] for key in cached_keys]

    # Per-request settings forwarded by keyword.
    options = {
        "tta_lr": tta_lr,
        "tta_num_iters": tta_num_iters,
        "det_thresh": det_thresh,
        "kp_conf_thresh": kp_conf_thresh,
        "side_view": side_view,
        "save_mesh": save_mesh,
        "boxes": boxes,
        "progress_callback": progress_callback,
    }
    return _collect_animal_results(*cached_args, out_folder, img_rgb, **options)

# When the Gradio queue is in use, run the pipeline on a worker thread so this
# generator can keep yielding status text while inference is in progress.
if _should_use_gradio_queue(profile):
    # Channel carrying stage messages from the worker thread to this generator.
    stage_updates: queue.Queue[str] = queue.Queue()

    def report_stage(message: str) -> None:
        """Enqueue a stage message for the polling loop below to yield."""
        stage_updates.put(message)

    # A single worker: exactly one run_collect call is submitted.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        fut = pool.submit(
            run_collect,
            report_stage,
        )
        t0 = time.monotonic()
        latest_stage = "Starting inference..."
        while True:
            # Drain every queued stage message, yielding one status update per
            # message; stop as soon as the queue is empty.
            while True:
                try:
                    latest_stage = stage_updates.get_nowait()
                except queue.Empty:
                    break
                else:
                    elapsed = int(time.monotonic() - t0)
                    yield None, None, None, f"{latest_stage}\nElapsed: {elapsed}s"
            # Wait up to one second for the worker; on success leave the loop
            # with the results, otherwise emit a heartbeat and poll again.
            try:
                before_imgs, after_imgs, kpt_imgs, mesh_before, mesh_after = fut.result(
                    timeout=1.0
                )
                break
            except concurrent.futures.TimeoutError:
                elapsed = int(time.monotonic() - t0)
                yield None, None, None, (
                    f"{latest_stage}\n"
                    f"Elapsed: {elapsed}s\n"
                    "CPU inference can take several minutes."
                )
else:
    # No queue: run the pipeline inline, without progress reporting.
    before_imgs, after_imgs, kpt_imgs, mesh_before, mesh_after = run_collect()
except Exception:
yield None, None, None, f"Inference failed:\n{traceback.format_exc()}"
return
Expand Down Expand Up @@ -744,6 +811,18 @@ def parse_args() -> argparse.Namespace:
default=DEFAULT_OUT_FOLDER,
help="Folder used to save rendered outputs and meshes",
)
parser.add_argument(
"--server_name",
type=str,
default=DEFAULT_SERVER_NAME,
help="Host/interface used by Gradio. Use 0.0.0.0 for Run:AI port-forward.",
)
parser.add_argument(
"--server_port",
type=int,
default=DEFAULT_SERVER_PORT,
help="Port used by Gradio.",
)
return parser.parse_args()


Expand All @@ -764,4 +843,9 @@ def parse_args() -> argparse.Namespace:
out_folder=args.out_folder,
runtime_cache=runtime_cache,
)
demo.launch(inbrowser=False)
demo.launch(
inbrowser=False,
ssr_mode=False,
server_name=args.server_name,
server_port=args.server_port,
)
2 changes: 1 addition & 1 deletion demo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
# If this local file is missing, it will be downloaded from the PRIMA Hugging Face repo.
# To use another local checkpoint instead, update this path.
# For example: checkpoint='data/PRIMAS3/checkpoints/s3ckpt.ckpt'
# For example: checkpoint='data/PRIMAS3/checkpoints/s3ckpt_inference.ckpt'
checkpoint='data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt'

python demo.py \
Expand Down
2 changes: 1 addition & 1 deletion demo_tta.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#
# This standard path is auto-downloaded from the PRIMA Hugging Face repo if missing.
# To use another local checkpoint instead, update this path.
# For example: checkpoint='data/PRIMAS3/checkpoints/s3ckpt.ckpt'
# For example: checkpoint='data/PRIMAS3/checkpoints/s3ckpt_inference.ckpt'
checkpoint='data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt'

python3 demo_tta.py \
Expand Down
7 changes: 7 additions & 0 deletions packages.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
libosmesa6
libgl1
libgl1-mesa-dri
libegl-mesa0
libegl1
libglx-mesa0
libgles2
6 changes: 4 additions & 2 deletions prima/utils/mesh_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from ctypes.util import find_library

if 'PYOPENGL_PLATFORM' not in os.environ and os.uname().sysname != 'Darwin':
# Prefer OSMesa; fall back to EGL where available.
os.environ['PYOPENGL_PLATFORM'] = 'osmesa' if find_library('OSMesa') else 'egl'
# Prefer EGL; PyOpenGL's OSMesa bindings can lack symbols required by pyrender.
os.environ['PYOPENGL_PLATFORM'] = 'egl' if find_library('EGL') else 'osmesa'
if os.environ['PYOPENGL_PLATFORM'] == 'egl':
os.environ.setdefault('EGL_PLATFORM', 'surfaceless')
import torch
from torchvision.utils import make_grid
import numpy as np
Expand Down
8 changes: 4 additions & 4 deletions prima/utils/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
from ctypes.util import find_library

if 'PYOPENGL_PLATFORM' not in os.environ and os.uname().sysname != 'Darwin':
# Prefer OSMesa; fall back to EGL where available.
os.environ['PYOPENGL_PLATFORM'] = 'osmesa' if find_library('OSMesa') else 'egl'
# Prefer EGL; PyOpenGL's OSMesa bindings can lack symbols required by pyrender.
os.environ['PYOPENGL_PLATFORM'] = 'egl' if find_library('EGL') else 'osmesa'
if os.environ['PYOPENGL_PLATFORM'] == 'egl':
os.environ.setdefault('EGL_PLATFORM', 'surfaceless')
Comment thread
lexiyutou marked this conversation as resolved.
import torch
import numpy as np
import pyrender
Expand Down Expand Up @@ -438,5 +440,3 @@ def add_point_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0):
if scene.has_node(node):
continue
scene.add_node(node)


Loading
Loading