-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexport_web_data.py
More file actions
157 lines (133 loc) · 5.9 KB
/
export_web_data.py
File metadata and controls
157 lines (133 loc) · 5.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
Export mesh data to binary files for the Three.js web viewer.
Default output (no args):
karate_output/web/faces.bin - int32 face indices
karate_output/web/frame_XXXXXX.bin - float32 vertices per frame
karate_output/web/manifest.json - metadata for the viewer
Usage:
# Default (karate pipeline output):
python export_web_data.py
# Custom paths (for uploaded videos processed by studio_server):
python export_web_data.py --mesh_dir uploads/abc123/output/mesh_data \
--video_path uploads/abc123/video.mp4 \
--output_dir uploads/abc123/output/web
"""
import argparse
import os, sys, json
import numpy as np
from tqdm import tqdm
# Make sibling modules (e.g. sam_3d_body) importable when run as a script.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# All default paths are resolved relative to this file's directory.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
# Default input: .npz mesh files produced by the karate pipeline.
DEFAULT_MESH_DIR = os.path.join(PROJECT_ROOT, "karate_output", "mesh_data")
# Default output: .bin + manifest.json consumed by the Three.js viewer.
DEFAULT_WEB_DIR = os.path.join(PROJECT_ROOT, "karate_output", "web")
# Model assets — only loaded when faces.bin does not exist yet.
DEFAULT_CHECKPOINT = os.path.join(PROJECT_ROOT, "checkpoints", "sam-3d-body-dinov3", "model.ckpt")
DEFAULT_MHR = os.path.join(PROJECT_ROOT, "checkpoints", "sam-3d-body-dinov3", "assets", "mhr_model.pt")
# Try multiple video paths for the default case
DEFAULT_VIDEO_CANDIDATES = [
    os.path.join(PROJECT_ROOT, "karate_heian_sandan.mp4"),
    os.path.join(PROJECT_ROOT, "input_video", "karate_video.mp4"),
]
def find_default_video():
    """Return the first default video candidate that exists on disk.

    Falls back to the first candidate path (even if missing) so callers
    always get a path they can report in a warning.
    """
    existing = (candidate for candidate in DEFAULT_VIDEO_CANDIDATES
                if os.path.isfile(candidate))
    return next(existing, DEFAULT_VIDEO_CANDIDATES[0])
def _parse_args():
    """Parse command-line options; every option has a pipeline default."""
    parser = argparse.ArgumentParser(description="Export mesh data to binary files for the Three.js web viewer")
    parser.add_argument("--mesh_dir", default=DEFAULT_MESH_DIR,
                        help="Directory containing .npz mesh data files")
    parser.add_argument("--video_path", default=None,
                        help="Path to source video (for resolution/fps metadata)")
    parser.add_argument("--output_dir", default=DEFAULT_WEB_DIR,
                        help="Output directory for web data (.bin + manifest.json)")
    parser.add_argument("--checkpoint_path", default=DEFAULT_CHECKPOINT,
                        help="Path to SAM 3D Body checkpoint (for face topology)")
    parser.add_argument("--mhr_path", default=DEFAULT_MHR,
                        help="Path to MHR model")
    return parser.parse_args()


def _export_faces(web_dir, checkpoint_path, mhr_path):
    """Write faces.bin (int32 triangle indices) and return the triangle count.

    The topology is identical for every frame, so if faces.bin already
    exists we only read it back to recover the count — this avoids the
    expensive model load on re-runs.
    """
    faces_path = os.path.join(web_dir, "faces.bin")
    if os.path.exists(faces_path):
        print("[1/3] faces.bin already exists, loading...")
        faces_data = np.fromfile(faces_path, dtype=np.int32)
        n_faces = len(faces_data) // 3
        print(f" Loaded {faces_path} ({n_faces} triangles)")
        return n_faces
    print("[1/3] Loading model to export face topology...")
    # Deferred import: only needed on the first run, and it pulls in torch.
    from sam_3d_body import load_sam_3d_body, SAM3DBodyEstimator
    model, cfg = load_sam_3d_body(
        checkpoint_path,
        device="cpu",
        mhr_path=mhr_path,
    )
    estimator = SAM3DBodyEstimator(sam_3d_body_model=model, model_cfg=cfg)
    faces = estimator.faces  # (F, 3) int array
    faces.astype(np.int32).tofile(faces_path)
    n_faces = faces.shape[0]
    print(f" Saved {faces_path} ({n_faces} triangles)")
    return n_faces


def _export_frames(mesh_dir, web_dir):
    """Export one float32 vertex .bin per frame.

    Returns (manifest_frames, n_verts): the per-frame manifest entries and
    the vertex count of the last frame that carried a mesh. Exits the
    process if mesh_dir is missing or contains no .npz files.
    """
    print("[2/3] Exporting per-frame vertex data...")
    if not os.path.isdir(mesh_dir):
        print(f" [!] Mesh directory not found: {mesh_dir}")
        sys.exit(1)
    files = sorted(f for f in os.listdir(mesh_dir) if f.endswith(".npz"))
    if not files:
        print(f" [!] No .npz files found in {mesh_dir}")
        sys.exit(1)
    manifest_frames = []
    n_verts_detected = 18439  # default
    for fn in tqdm(files, desc="Exporting"):
        key = fn.replace(".npz", "")
        # NpzFile keeps the archive open; close it promptly via `with`
        # instead of leaking a file handle per frame.
        with np.load(os.path.join(mesh_dir, fn), allow_pickle=True) as d:
            entry = {
                "file": key + ".bin",
                "frame_idx": int(d["frame_idx"]) if "frame_idx" in d else 0,
                "timestamp": float(d["timestamp"]) if "timestamp" in d else 0.0,
            }
            if "pose_label" in d:
                lbl = d["pose_label"]
                entry["label"] = str(lbl) if lbl is not None else ""
            if "vertices" in d:
                verts = d["vertices"].astype(np.float32)
                verts.tofile(os.path.join(web_dir, key + ".bin"))
                entry["has_mesh"] = True
                entry["n_verts"] = int(verts.shape[0])
                n_verts_detected = int(verts.shape[0])
            else:
                entry["has_mesh"] = False
        manifest_frames.append(entry)
    return manifest_frames, n_verts_detected


def _video_metadata(video_path):
    """Probe (width, height, fps) from the video, defaulting to 1280x720@30.

    The `or` fallbacks also cover a video that exists but cannot be opened,
    since cv2.VideoCapture.get returns 0 in that case.
    """
    vid_w, vid_h, vid_fps = 1280, 720, 30.0
    if os.path.isfile(video_path):
        import cv2
        cap = cv2.VideoCapture(video_path)
        vid_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 1280
        vid_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) or 720
        vid_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        cap.release()
    else:
        print(f" [WARN] Video not found: {video_path} — using defaults")
    return vid_w, vid_h, vid_fps


def main():
    """Export faces.bin, per-frame vertex .bin files, and manifest.json.

    Runs three stages (topology, per-frame vertices, manifest) and finally
    copies the karate transcript next to the manifest if it exists.
    """
    args = _parse_args()
    mesh_dir = args.mesh_dir
    web_dir = args.output_dir
    video_path = args.video_path or find_default_video()
    os.makedirs(web_dir, exist_ok=True)

    # 1. Export faces (only need to do this once)
    n_faces = _export_faces(web_dir, args.checkpoint_path, args.mhr_path)

    # 2. Export per-frame vertex positions
    manifest_frames, n_verts_detected = _export_frames(mesh_dir, web_dir)

    # 3. Write manifest
    print("[3/3] Writing manifest.json...")
    vid_w, vid_h, vid_fps = _video_metadata(video_path)
    manifest = {
        "n_faces": n_faces,
        "n_verts": n_verts_detected,
        "video_width": vid_w,
        "video_height": vid_h,
        "video_fps": vid_fps,
        "frames": manifest_frames,
    }
    with open(os.path.join(web_dir, "manifest.json"), "w") as f:
        json.dump(manifest, f)
    print(f" Manifest: {len(manifest_frames)} frames")

    # Also copy the transcript for the labeler (if available)
    transcript_path = os.path.join(PROJECT_ROOT, "karate_transcript.json")
    if os.path.exists(transcript_path):
        import shutil
        shutil.copy(transcript_path, os.path.join(web_dir, "transcript.json"))
    print(f"\n[DONE] Web data exported to {web_dir}/")
# Run the exporter only when invoked as a script, not on import.
if __name__ == "__main__":
    main()