From 3ba12a70250cfbea57484e97eab91137620a9a8a Mon Sep 17 00:00:00 2001 From: Kin Date: Sat, 10 Jan 2026 18:40:25 +0100 Subject: [PATCH 1/6] chore(visualization): refactor the open3d visualization, merge fn together. * reset fire class uage directly. * add save screenshot easily with multi-view. * sync view point through diff windows. * visual lidar center tf if set slc to True. --- README.md | 14 +- src/dataset.py | 2 +- src/trainer.py | 2 +- src/utils/mics.py | 72 ++++++ src/utils/o3d_view.py | 502 ++++++++++++++++++++++++-------------- tools/README.md | 28 ++- tools/visualization.py | 537 +++++++++++++++++++++++++++-------------- 7 files changed, 779 insertions(+), 378 deletions(-) diff --git a/README.md b/README.md index 3f3d59f..3603a46 100644 --- a/README.md +++ b/README.md @@ -301,7 +301,7 @@ pip install "evalai" evalai set-token # Step 3: Copy the command pop above and submit to leaderboard -evalai challenge 2010 phase 4018 submit --file av2_submit.zip --large --private +# evalai challenge 2010 phase 4018 submit --file av2_submit.zip --large --private evalai challenge 2210 phase 4396 submit --file av2_submit_v2.zip --large --private ``` @@ -318,21 +318,24 @@ python eval.py model=nsfp dataset_path=/home/kin/data/av2/h5py/demo/val # The output of above command will be like: Model: DeFlow, Checkpoint from: /home/kin/model_zoo/v2/seflow_best.ckpt We already write the flow_est into the dataset, please run following commend to visualize the flow. Copy and paste it to your terminal: -python tools/visualization.py --res_name 'seflow_best' --data_dir /home/kin/data/av2/preprocess_v2/sensor/vis +python tools/visualization.py vis --res_name 'seflow_best' --data_dir /home/kin/data/av2/preprocess_v2/sensor/vis Enjoy! 
^v^ ------ # Then run the command in the terminal: -python tools/visualization.py --res_name 'seflow_best' --data_dir /home/kin/data/av2/preprocess_v2/sensor/vis +python tools/visualization.py vis --res_name 'seflow_best' --data_dir /home/kin/data/av2/preprocess_v2/sensor/vis ``` https://github.com/user-attachments/assets/f031d1a2-2d2f-4947-a01f-834ed1c146e6 For exporting easy comparsion with ground truth and other methods, we also provided multi-visulization open3d window: ```bash -python tools/visualization.py --mode mul --res_name "['flow', 'seflow_best']" --data_dir /home/kin/data/av2/preprocess_v2/sensor/vis +python tools/visualization.py vis --res_name "['flow', 'seflow_best']" --data_dir /home/kin/data/av2/preprocess_v2/sensor/vis ``` -Or another way to interact with [rerun](https://github.com/rerun-io/rerun) but please only vis scene by scene, not all at once. +**Tips**: To quickly create qualitative results for all methods, you can use multiple results comparison mode, select a good viewpoint and then save screenshots for all frames by pressing `P` key. You will found all methods' results are saved in the output folder (default is `logs/imgs`). Enjoy it! + + +_Rerun_: Another way to interact with [rerun](https://github.com/rerun-io/rerun) but please only vis scene by scene, not all at once. ```bash python tools/visualization_rerun.py --data_dir /home/kin/data/av2/h5py/demo/train --res_name "['flow', 'deflow']" @@ -340,7 +343,6 @@ python tools/visualization_rerun.py --data_dir /home/kin/data/av2/h5py/demo/trai https://github.com/user-attachments/assets/07e8d430-a867-42b7-900a-11755949de21 - ## Cite Us [*OpenSceneFlow*](https://github.com/KTH-RPL/OpenSceneFlow) is originally designed by [Qingwen Zhang](https://kin-zhang.github.io/) from DeFlow and SeFlow. 
diff --git a/src/dataset.py b/src/dataset.py index dd04995..9d431b9 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -347,7 +347,7 @@ def __getitem__(self, index_): data_dict[f'gmh{i+1}'] = past_gm data_dict[f'poseh{i+1}'] = past_pose - for data_key in self.vis_name + ['ego_motion', 'lidar_dt', + for data_key in self.vis_name + ['ego_motion', 'lidar_dt', 'lidar_center', # ground truth information: 'flow', 'flow_is_valid', 'flow_category_indices', 'flow_instance_id', 'dufo']: if data_key in f[key]: diff --git a/src/trainer.py b/src/trainer.py index 84b22c5..de064f1 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -243,7 +243,7 @@ def on_validation_epoch_end(self): with open(str(self.save_res_path)+'.pkl', 'wb') as f: pickle.dump((self.metrics.epe_3way, self.metrics.bucketed, self.metrics.epe_ssf), f) print(f"We already write the {self.res_name} into the dataset, please run following commend to visualize the flow. Copy and paste it to your terminal:") - print(f"python tools/visualization.py --res_name '{self.res_name}' --data_dir {self.dataset_path}") + print(f"python tools/visualization.py vis --res_name '{self.res_name}' --data_dir {self.dataset_path}") print(f"Enjoy! ^v^ ------ \n") self.metrics = OfficialMetrics() diff --git a/src/utils/mics.py b/src/utils/mics.py index 2ee5c65..8ba1a8e 100644 --- a/src/utils/mics.py +++ b/src/utils/mics.py @@ -172,6 +172,78 @@ def make_colorwheel(transitions: tuple=DEFAULT_TRANSITIONS) -> np.ndarray: return colorwheel +def error_to_color( + error_magnitude: np.ndarray, + max_error: float, + color_map: str = "jet" +) -> np.ndarray: + """ + Convert flow error to RGB color visualization. 
+ Args: + color_map: Color map to use for visualization ("jet" recommended for error visualization) + + Returns: + RGB color representation of the error of shape (..., 3) + """ + if max_error > 0: + normalized_error = np.clip(error_magnitude / max_error, 0, 1) + else: + normalized_error = np.zeros_like(error_magnitude) + + # Create colormap + if color_map == "jet": + # Simple jet colormap implementation + colors = np.zeros((*normalized_error.shape, 3), dtype=np.uint8) + + # Blue to cyan to green to yellow to red + # Blue (low error) + idx = normalized_error < 0.25 + colors[idx, 2] = 255 + colors[idx, 0] = 0 + colors[idx, 1] = np.uint8(255 * normalized_error[idx] * 4) + + # Cyan to green + idx = (normalized_error >= 0.25) & (normalized_error < 0.5) + colors[idx, 1] = 255 + colors[idx, 0] = 0 + colors[idx, 2] = np.uint8(255 * (1 - (normalized_error[idx] - 0.25) * 4)) + + # Green to yellow + idx = (normalized_error >= 0.5) & (normalized_error < 0.75) + colors[idx, 1] = 255 + colors[idx, 2] = 0 + colors[idx, 0] = np.uint8(255 * (normalized_error[idx] - 0.5) * 4) + + # Yellow to red (high error) + idx = normalized_error >= 0.75 + colors[idx, 0] = 255 + colors[idx, 2] = 0 + colors[idx, 1] = np.uint8(255 * (1 - (normalized_error[idx] - 0.75) * 4)) + + elif color_map == "hot": + # Hot colormap: black -> red -> yellow -> white + colors = np.zeros((*normalized_error.shape, 3), dtype=np.uint8) + + # Black to red + idx = normalized_error < 0.33 + colors[idx, 0] = np.uint8(255 * normalized_error[idx] * 3) + + # Red to yellow + idx = (normalized_error >= 0.33) & (normalized_error < 0.67) + colors[idx, 0] = 255 + colors[idx, 1] = np.uint8(255 * (normalized_error[idx] - 0.33) * 3) + + # Yellow to white + idx = normalized_error >= 0.67 + colors[idx, 0] = 255 + colors[idx, 1] = 255 + colors[idx, 2] = np.uint8(255 * (normalized_error[idx] - 0.67) * 3) + + else: + raise ValueError(f"Unsupported color map: {color_map}. 
Use 'jet' or 'hot'.") + + return colors + def flow_to_rgb( flow: np.ndarray, flow_max_radius: Optional[float]=None, diff --git a/src/utils/o3d_view.py b/src/utils/o3d_view.py index e5de457..ea2316f 100644 --- a/src/utils/o3d_view.py +++ b/src/utils/o3d_view.py @@ -1,219 +1,349 @@ -''' -# @date: 2023-1-26 16:38 -# @author: Qingwen Zhang (https://kin-zhang.github.io/), Ajinkya Khoche (https://ajinkyakhoche.github.io/) -# Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology -# @detail: -# 1. Play the data you want in open3d, and save the view control to json file. -# 2. Use the json file to view the data again. -# 3. Save the screen shot and view file for later check and animation. -# -# code gits: https://gist.github.com/Kin-Zhang/77e8aa77a998f1a4f7495357843f24ef -# -# CHANGELOG: -# 2024-08-23 21:41(Qingwen): remove totally on view setting from scratch but use open3d>=0.18.0 version for set_view from json text func. -# 2024-04-15 12:06(Qingwen): show a example json text. add hex_to_rgb, color_map_hex, color_map (for color points if needed) -# 2024-01-27 0:41(Qingwen): update MyVisualizer class, reference from kiss-icp: https://github.com/PRBonn/kiss-icp/blob/main/python/kiss_icp/tools/visualizer.py -# 2024-09-10 (Ajinkya): Add MyMultiVisualizer class to view multiple windows at once, allow forward and backward playback, create bev square for giving a sense of metric scale. -''' +""" +Open3D Visualizer for Scene Flow +================================ +@date: 2023-1-26 16:38 +@author: Qingwen Zhang (https://kin-zhang.github.io/), Ajinkya Khoche (https://ajinkyakhoche.github.io/) +Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology + +# This file is part of OpenSceneFlow (https://github.com/KTH-RPL/OpenSceneFlow). +# If you find this repo helpful, please cite the respective publication as +# listed on the above website. 
+ +Features: + - Single or multi-window visualization + - Viewpoint sync across windows (press S) + - Forward/backward playback + - Screenshot and viewpoint save + +CHANGELOG: +2026-01-10 (Qingwen): Unified single/multi visualizer, added viewpoint sync with S key +2024-09-10 (Ajinkya): Add multi-window support, forward/backward playback +2024-08-23 (Qingwen): Use open3d>=0.18.0 set_view_status API +""" import open3d as o3d -import os, time -from typing import List, Callable +import os +import time +from typing import List, Callable, Union from functools import partial import numpy as np -def hex_to_rgb(hex_color): + +def hex_to_rgb(hex_color: str) -> tuple: + """Convert hex color string to RGB tuple (0-1 range).""" hex_color = hex_color.lstrip("#") return tuple(int(hex_color[i:i + 2], 16) / 255.0 for i in (0, 2, 4)) -color_map_hex = ['#a6cee3', '#de2d26', '#1f78b4','#b2df8a','#33a02c','#fb9a99','#e31a1c','#fdbf6f','#ff7f00',\ - '#cab2d6','#6a3d9a','#ffff99','#b15928', '#8dd3c7','#ffffb3','#bebada','#fb8072','#80b1d3',\ - '#fdb462','#b3de69','#fccde5','#d9d9d9','#bc80bd','#ccebc5','#ffed6f'] - -color_map = [hex_to_rgb(color) for color in color_map_hex] +COLOR_MAP_HEX = [ + '#a6cee3', '#de2d26', '#1f78b4', '#b2df8a', '#33a02c', '#fb9a99', '#e31a1c', + '#fdbf6f', '#ff7f00', '#cab2d6', '#6a3d9a', '#ffff99', '#b15928', '#8dd3c7', + '#ffffb3', '#bebada', '#fb8072', '#80b1d3', '#fdb462', '#b3de69', '#fccde5', + '#d9d9d9', '#bc80bd', '#ccebc5', '#ffed6f' +] +color_map = [hex_to_rgb(c) for c in COLOR_MAP_HEX] + + +class O3DVisualizer: + """ + Unified Open3D visualizer supporting single or multiple windows. 
+ + Args: + view_file: Path to JSON file with saved viewpoint + res_name: Single name or list of names for windows + save_folder: Folder to save screenshots + point_size: Point size for rendering + bg_color: Background color as RGB tuple (0-1 range) + screen_width: Screen width for multi-window layout + screen_height: Screen height for multi-window layout + + Usage: + # Single window + viz = O3DVisualizer(res_name="flow") + + # Multiple windows + viz = O3DVisualizer(res_name=["flow", "flow_est"]) + """ + + def __init__( + self, + view_file: str = None, + res_name: Union[str, List[str]] = "flow", + save_folder: str = "logs/imgs", + point_size: float = 3.0, + bg_color: tuple = (80/255, 90/255, 110/255), + screen_width: int = 1375, + screen_height: int = 2500, + ): + # Normalize res_name to list + self.res_names = [res_name] if isinstance(res_name, str) else list(res_name) + self.num_windows = len(self.res_names) -class MyVisualizer: - def __init__(self, view_file=None, window_title="Default", save_folder="logs/imgs"): - self.params = None - self.vis = o3d.visualization.VisualizerWithKeyCallback() - self.vis.create_window(window_name=window_title) self.view_file = view_file - + self.save_folder = save_folder + self.point_size = point_size + self.bg_color = np.asarray(bg_color) + + os.makedirs(self.save_folder, exist_ok=True) + + # State self.block_vis = True self.play_crun = False self.reset_bounding_box = True - self.save_img_folder = save_folder - os.makedirs(self.save_img_folder, exist_ok=True) + self.playback_direction = 1 # 1: forward, -1: backward + self.curr_index = -1 + self.tmp_value = None + self._should_save = False + self._should_sync = False + self._sync_source_idx = 0 + + # Create windows + self.vis: List[o3d.visualization.VisualizerWithKeyCallback] = [] + self._create_windows(screen_width, screen_height) + self._setup_render_options() + self._register_callbacks() + self._print_help() + + def _create_windows(self, screen_width: int, screen_height: int): 
+ """Create visualizer windows.""" + if self.num_windows == 1: + v = o3d.visualization.VisualizerWithKeyCallback() + title = self._window_title(self.res_names[0]) + v.create_window(window_name=title) + self.vis.append(v) + else: + window_width = screen_width // 2 + window_height = screen_height // 4 + epsilon = 150 + positions = [ + (0, 0), + (screen_width - window_width + epsilon, 0), + (0, screen_height - window_height + epsilon), + (screen_width - window_width + epsilon, screen_height - window_height + epsilon), + ] + for i, name in enumerate(self.res_names): + v = o3d.visualization.VisualizerWithKeyCallback() + title = self._window_title(name) + pos = positions[i % len(positions)] + v.create_window(window_name=title, width=window_width, height=window_height, + left=pos[0], top=pos[1]) + self.vis.append(v) + + def _window_title(self, name: str) -> str: + label = "ground truth flow" if name == "flow" else name + return f"View {label} | SPACE: play/pause" + + def _setup_render_options(self): + """Configure render options for all windows.""" + for v in self.vis: + opt = v.get_render_option() + opt.background_color = self.bg_color + opt.point_size = self.point_size + + def _register_callbacks(self): + """Register keyboard callbacks for all windows.""" + callbacks = [ + (["Ā", "Q", "\x1b"], self._quit), + ([" "], self._start_stop), + (["D"], self._next_frame), + (["A"], self._prev_frame), + (["P"], self._save_screen), + (["E"], self._save_error_bar), + (["S"], self._sync_viewpoint), + ] + for keys, callback in callbacks: + for key in keys: + for idx, v in enumerate(self.vis): + v.register_key_callback(ord(str(key)), partial(callback, src_idx=idx)) + + def _print_help(self): + sync_hint = "[S] sync viewpoint across windows\n" if self.num_windows > 1 else "" print( - f"\n{window_title.capitalize()} initialized. 
Press:\n" - "\t[SPACE] to pause/start\n" - "\t[ESC/Q] to exit\n" - "\t [P] to save screen and viewpoint\n" - "\t [D] to step next\n" + f"\nVisualizer initialized ({self.num_windows} window(s)). Keys:\n" + f" [SPACE] play/pause [D] next frame [A] prev frame\n" + f" [P] save screenshot [E] save error bar\n" + f" {sync_hint}" + f" [ESC/Q] quit\n" ) - self._register_key_callback(["Ā", "Q", "\x1b"], self._quit) - self._register_key_callback(["P"], self._save_screen) - self._register_key_callback([" "], self._start_stop) - self._register_key_callback(["D"], self._next_frame) - - def show(self, assets: List): - self.vis.clear_geometries() - - for asset in assets: - self.vis.add_geometry(asset) - if self.view_file is not None: - self.vis.set_view_status(open(self.view_file).read()) - - self.vis.update_renderer() - self.vis.poll_events() - self.vis.run() - self.vis.destroy_window() - - def update(self, assets: List, clear: bool = True): - if clear: - self.vis.clear_geometries() - - for asset in assets: - self.vis.add_geometry(asset, reset_bounding_box=False) - self.vis.update_geometry(asset) + # ------------------------------------------------------------------------- + # Public API + # ------------------------------------------------------------------------- + + def update(self, assets: Union[List, List[List]], index: int = -1, value: float = None): + """ + Update visualizer with new assets. 
+ + Args: + assets: For single window - list of geometries + For multi window - list of lists of geometries + index: Current frame index (for screenshot naming) + value: Optional value (e.g., max error for colorbar) + """ + self.curr_index = index + self.tmp_value = value + + # Normalize to list of lists + if self.num_windows == 1: + assets_list = [assets] if not self._is_nested_list(assets) else assets + else: + assets_list = assets + + # Clear and add geometries + for v in self.vis: + v.clear_geometries() + + for i, window_assets in enumerate(assets_list): + if i >= len(self.vis): + break + for asset in window_assets: + self.vis[i].add_geometry(asset, reset_bounding_box=False) + self.vis[i].update_geometry(asset) + + # Reset view on first frame if self.reset_bounding_box: - self.vis.reset_view_point(True) - if self.view_file is not None: - self.vis.set_view_status(open(self.view_file).read()) + for v in self.vis: + v.reset_view_point(True) + if self.view_file is not None: + v.set_view_status(open(self.view_file).read()) self.reset_bounding_box = False - - self.vis.update_renderer() + + # Render and wait + for v in self.vis: + v.update_renderer() + while self.block_vis: - self.vis.poll_events() + for v in self.vis: + v.poll_events() + if self._should_sync: + self._do_sync_viewpoint() + if self._should_save: + self._do_save_screen() if self.play_crun: break + self.block_vis = not self.block_vis - def _register_key_callback(self, keys: List, callback: Callable): - for key in keys: - self.vis.register_key_callback(ord(str(key)), partial(callback)) - - def _next_frame(self, vis): - self.block_vis = not self.block_vis - - def _start_stop(self, vis): - self.play_crun = not self.play_crun - - def _quit(self, vis): + def show(self, assets: List): + """Show assets and run visualization loop (blocking).""" + for v in self.vis: + v.clear_geometries() + for asset in assets: + for v in self.vis: + v.add_geometry(asset) + if self.view_file is not None: + 
v.set_view_status(open(self.view_file).read()) + for v in self.vis: + v.update_renderer() + v.poll_events() + self.vis[0].run() + for v in self.vis: + v.destroy_window() + + def _is_nested_list(self, obj) -> bool: + """Check if obj is a list of lists.""" + return isinstance(obj, list) and len(obj) > 0 and isinstance(obj[0], list) + + # ------------------------------------------------------------------------- + # Callbacks + # ------------------------------------------------------------------------- + + def _quit(self, vis, src_idx=0): print("Destroying Visualizer. Thanks for using ^v^.") - vis.destroy_window() + for v in self.vis: + v.destroy_window() os._exit(0) - def _save_screen(self, vis): - timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) - png_file = f"{self.save_img_folder}/ScreenShot_{timestamp}.png" - view_json_file = f"{self.save_img_folder}/ScreenView_{timestamp}.json" - with open(view_json_file, 'w') as f: - f.write(vis.get_view_status()) - vis.capture_screen_image(png_file) - print(f"ScreenShot saved to: {png_file}, Please check it.") - - -def create_bev_square(size=409.6, color=[68/255,114/255,196/255]): - # Create the vertices of the square - half_size = size / 2.0 - vertices = np.array([ - [-half_size, -half_size, 0], - [half_size, -half_size, 0], - [half_size, half_size, 0], - [-half_size, half_size, 0] - ]) - - # Define the square as a LineSet for visualization - lines = [[0, 1], [1, 2], [2, 3], [3, 0]] - colors = [color for _ in lines] - - line_set = o3d.geometry.LineSet( - points=o3d.utility.Vector3dVector(vertices), - lines=o3d.utility.Vector2iVector(lines) - ) - line_set.colors = o3d.utility.Vector3dVector(colors) - - return line_set - -class MyMultiVisualizer(MyVisualizer): - def __init__(self, view_file=None, flow_mode=['flow'], screen_width=2500, screen_height = 1375): - self.params = None - self.view_file = view_file - self.block_vis = True - self.play_crun = False - self.reset_bounding_box = True - self.playback_direction 
= 1 # 1:forward, -1:backward - - self.vis = [] - # self.o3d_vctrl = [] - - # Define width and height for each window - window_width = screen_width // 2 - window_height = screen_height // 2 - # Define positions for the four windows - epsilon = 150 - positions = [ - (0, 0), # Top-left - (screen_width - window_width + epsilon, 0), # Top-right - (0, screen_height - window_height + epsilon), # Bottom-left - (screen_width - window_width + epsilon, screen_height - window_height + epsilon) # Bottom-right - ] - - for i, mode in enumerate(flow_mode): - window_title = f"view {'ground truth flow' if mode == 'flow' else f'{mode} flow'}, `SPACE` start/stop" - v = o3d.visualization.VisualizerWithKeyCallback() - v.create_window(window_name=window_title, width=window_width, height=window_height, left=positions[i%len(positions)][0], top=positions[i%len(positions)][1]) - # self.o3d_vctrl.append(ViewControl(v.get_view_control(), view_file=view_file)) - self.vis.append(v) - - self._register_key_callback(["Ā", "Q", "\x1b"], self._quit) - self._register_key_callback([" "], self._start_stop) - self._register_key_callback(["D"], self._next_frame) - self._register_key_callback(["A"], self._prev_frame) - print( - f"\n{window_title.capitalize()} initialized. 
Press:\n" - "\t[SPACE] to pause/start\n" - "\t[ESC/Q] to exit\n" - "\t [P] to save screen and viewpoint\n" - "\t [D] to step next\n" - "\t [A] to step previous\n" - ) - - def update(self, assets_list: List, clear: bool = True): - if clear: - [v.clear_geometries() for v in self.vis] - - for i, assets in enumerate(assets_list): - [self.vis[i].add_geometry(asset, reset_bounding_box=False) for asset in assets] - self.vis[i].update_geometry(assets[-1]) - - if self.reset_bounding_box: - [v.reset_view_point(True) for v in self.vis] - if self.view_file is not None: - # [o.read_viewTfile(self.view_file) for o in self.o3d_vctrl] - [v.set_view_status(open(self.view_file).read()) for v in self.vis] - self.reset_bounding_box = False - - [v.update_renderer() for v in self.vis] - while self.block_vis: - [v.poll_events() for v in self.vis] - if self.play_crun: - break - self.block_vis = not self.block_vis + def _start_stop(self, vis, src_idx=0): + self.play_crun = not self.play_crun - def _register_key_callback(self, keys: List, callback: Callable): - for key in keys: - [v.register_key_callback(ord(str(key)), partial(callback)) for v in self.vis] - def _next_frame(self, vis): + def _next_frame(self, vis, src_idx=0): self.block_vis = not self.block_vis self.playback_direction = 1 - def _prev_frame(self, vis): + + def _prev_frame(self, vis, src_idx=0): self.block_vis = not self.block_vis self.playback_direction = -1 + def _save_screen(self, vis, src_idx=0): + self._should_save = True + # NOTE: sync viewpoint before saving + self._should_sync = True + return False + + def _sync_viewpoint(self, vis, src_idx=0): + """Sync all windows to the viewpoint of the source window.""" + self._should_sync = True + self._sync_source_idx = src_idx + return False + + def _do_sync_viewpoint(self): + """Actually perform viewpoint sync (called from main loop).""" + if self.num_windows <= 1: + self._should_sync = False + return + + source_view = self.vis[self._sync_source_idx].get_view_status() + for i, 
v in enumerate(self.vis): + if i != self._sync_source_idx: + v.set_view_status(source_view) + v.update_renderer() + print(f"Synced viewpoint from window {self._sync_source_idx} to all windows.") + self._should_sync = False + + def _do_save_screen(self): + """Save screenshots from all windows.""" + timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + + for i, v in enumerate(self.vis): + v.poll_events() + v.update_renderer() + + name = self.res_names[i] if i < len(self.res_names) else f"window{i}" + prefix = f"{self.curr_index}_{name}" if self.curr_index != -1 else name + png_file = f"{self.save_folder}/{prefix}_{timestamp}.png" + v.capture_screen_image(png_file) + + if i == 0: + view_file = f"{self.save_folder}/{prefix}_{timestamp}.json" + with open(view_file, 'w') as f: + f.write(v.get_view_status()) + + print(f"Screenshots saved to {self.save_folder}/") + self._should_save = False + + def _save_error_bar(self, vis, src_idx=0): + """Save error colorbar as image.""" + if self.tmp_value is None: + print("No error value set, skipping error bar save.") + return + + import matplotlib.pyplot as plt + import matplotlib as mpl + + timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + prefix = f"{self.curr_index}_error" if self.curr_index != -1 else "error" + png_file = f"{self.save_folder}/{prefix}_{timestamp}.png" + + fig, ax = plt.subplots(figsize=(10, 1)) + max_val = self.tmp_value * 100 + norm = mpl.colors.Normalize(vmin=0, vmax=max_val) + cb = mpl.colorbar.ColorbarBase(ax, cmap=plt.cm.hot, norm=norm, orientation='horizontal') + + ticks = np.linspace(0, max_val, 5) + cb.set_ticks(ticks) + cb.set_ticklabels([f"{t:.1f}" for t in ticks]) + cb.set_label('Error Magnitude (cm)') + + plt.savefig(png_file, bbox_inches='tight') + plt.close() + print(f"Error bar saved to: {png_file}") + + +# Backward compatibility aliases; FIXME: remove in near future +MyVisualizer = O3DVisualizer +MyMultiVisualizer = O3DVisualizer + if __name__ == "__main__": 
json_content = """{ @@ -236,12 +366,12 @@ def _prev_frame(self, vis): "version_minor" : 0 } """ - # write to json file view_json_file = "view.json" with open(view_json_file, 'w') as f: f.write(json_content) + sample_ply_data = o3d.data.PLYPointCloud() pcd = o3d.io.read_point_cloud(sample_ply_data.path) - - viz = MyVisualizer(view_json_file, window_title="Qingwen's View") + + viz = O3DVisualizer(view_json_file, res_name="Demo") viz.show([pcd]) \ No newline at end of file diff --git a/tools/README.md b/tools/README.md index a479138..9ccdb44 100644 --- a/tools/README.md +++ b/tools/README.md @@ -11,18 +11,36 @@ Here we introduce some tools to help you: run `tools/visualization.py` to view the scene flow dataset with ground truth flow. Note the color wheel in under world coordinate. ```bash -# view gt flow -python3 tools/visualization.py --data_dir /home/kin/data/av2/preprocess/sensor/mini --res_name flow +# Visualize flow with color coding +python tools/visualization.py vis --data_dir /path/to/data --res_name flow -# view est flow -python3 tools/visualization.py --data_dir /home/kin/data/av2/preprocess/sensor/mini --res_name deflow_best -python3 tools/visualization.py --data_dir /home/kin/data/av2/preprocess/sensor/mini --res_name seflow_best +# Compare multiple results side-by-side +python tools/visualization.py vis --data_dir /path/to/data --res_name "[flow, deflow, deltaflow, ssf]" + +# Show flow as vector lines +python tools/visualization.py vector --data_dir /path/to/data + +# Check flow with pc0, pc1, and flowed pc0 +python tools/visualization.py check --data_dir /path/to/data + +# Show error heatmap +python tools/visualization.py error --data_dir /path/to/data --res_name "[flow, deflow, deltaflow, ssf]" ``` Demo Effect (press `SPACE` to stop and start in the visualization window): https://github.com/user-attachments/assets/f031d1a2-2d2f-4947-a01f-834ed1c146e6 +**Tips**: To quickly create qualitative results for all methods, you can use multiple results 
comparison mode, select a good viewpoint and then save screenshots for all frames by pressing `P` key. You will found all methods' results are saved in the output folder (default is `logs/imgs`). + +## Quick Read .h5 Files + +You can quickly read all keys and shapes in a .h5 file by: + +```bash +python tools/read_h5.py --file_path /path/to/file.h5 +``` + ## Conversion run `tools/zero2ours.py` to convert the ZeroFlow pretrained model to our codebase. diff --git a/tools/visualization.py b/tools/visualization.py index 84c30e2..98b1ad6 100644 --- a/tools/visualization.py +++ b/tools/visualization.py @@ -1,204 +1,383 @@ """ -# Created: 2023-11-29 21:22 -# Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology -# Author: Qingwen Zhang (https://kin-zhang.github.io/), Ajinkya Khoche (https://ajinkyakhoche.github.io/) -# -# This file is part of OpenSceneFlow (https://github.com/KTH-RPL/OpenSceneFlow). -# If you find this repo helpful, please cite the respective publication as -# listed on the above website. -# -# Description: view scene flow dataset after preprocess. +Scene Flow Visualization Tool +============================= +Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology +Author: Qingwen Zhang (https://kin-zhang.github.io/), Ajinkya Khoche (https://ajinkyakhoche.github.io/) -# CHANGELOG: -# 2024-09-10 (Ajinkya): Add vis_multiple(), to visualize multiple flow modes at once. - -# Usage: (flow is ground truth flow, `other_name` is the estimated flow from the model) -* python tools/visualization.py --data_dir /home/kin/data/av2/h5py/demo/train --res_name 'flow' --mode vis -* python tools/visualization.py --data_dir /home/kin/data/av2/h5py/demo/train --res_name "['flow', 'deflow' , 'ssf']" --mode mul +Part of OpenSceneFlow (https://github.com/KTH-RPL/OpenSceneFlow). 
+Usage (Fire class-based): + # Visualize flow with color coding + python tools/visualization.py vis --data_dir /path/to/data --res_name flow + + # Compare multiple results side-by-side + python tools/visualization.py vis --data_dir /path/to/data --res_name "[flow, flow_est]" + + # Show flow as vector lines + python tools/visualization.py vector --data_dir /path/to/data + + # Check flow with pc0, pc1, and flowed pc0 + python tools/visualization.py check --data_dir /path/to/data + + # Show error heatmap + python tools/visualization.py error --data_dir /path/to/data --res_name "[raw, flow_est]" + +Keys: + [SPACE] play/pause [D] next frame [A] prev frame + [P] save screenshot [E] save error bar + [S] sync viewpoint across windows (multi-window mode) + [ESC/Q] quit """ import numpy as np -import fire, time +import fire +import time from tqdm import tqdm - import open3d as o3d -import os, sys -BASE_DIR = os.path.abspath(os.path.join( os.path.dirname( __file__ ), '..' )) +import os +import sys + +BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) sys.path.append(BASE_DIR) -from src.utils.mics import HDF5Data, flow_to_rgb -from src.utils.o3d_view import MyVisualizer, MyMultiVisualizer, color_map, create_bev_square +from src.utils.mics import flow_to_rgb, error_to_color +from src.utils import npcal_pose0to1 +from src.utils.o3d_view import O3DVisualizer, color_map +from src.dataset import HDF5Dataset VIEW_FILE = f"{BASE_DIR}/assets/view/demo.json" +NO_COLOR = [1, 1, 1] -def check_flow( - data_dir: str ="/home/kin/data/av2/preprocess/sensor/mini", - res_name: str = "flow", # "flow", "flow_est" - start_id: int = 0, - point_size: float = 3.0, -): - dataset = HDF5Data(data_dir, vis_name=res_name, flow_view=True) - o3d_vis = MyVisualizer(view_file=VIEW_FILE, window_title=f"view {'ground truth flow' if res_name == 'flow' else f'{res_name} flow'}, `SPACE` start/stop") - - opt = o3d_vis.vis.get_render_option() - opt.background_color = np.asarray([80/255, 
90/255, 110/255]) - opt.point_size = point_size - - for data_id in (pbar := tqdm(range(start_id, len(dataset)))): - data = dataset[data_id] - now_scene_id = data['scene_id'] - pbar.set_description(f"id: {data_id}, scene_id: {now_scene_id}, timestamp: {data['timestamp']}") - gm0 = data['gm0'] - pc0 = data['pc0'][~gm0] - - pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(pc0[:, :3]) - pcd.paint_uniform_color([1.0, 0.0, 0.0]) # red: pc0 - - pc1 = data['pc1'] - pcd1 = o3d.geometry.PointCloud() - pcd1.points = o3d.utility.Vector3dVector(pc1[:, :3][~data['gm1']]) - pcd1.paint_uniform_color([0.0, 1.0, 0.0]) # green: pc1 - - pcd2 = o3d.geometry.PointCloud() - # pcd2.points = o3d.utility.Vector3dVector(pc0[:, :3] + pose_flow) # if you want to check pose_flow - pcd2.points = o3d.utility.Vector3dVector(pc0[:, :3] + data[res_name][~gm0]) - pcd2.paint_uniform_color([0.0, 0.0, 1.0]) # blue: pc0 + flow - o3d_vis.update([pcd, pcd1, pcd2, o3d.geometry.TriangleMesh.create_coordinate_frame(size=2)]) -def vis( - data_dir: str ="/home/kin/data/av2/h5py/demo/val", - res_name: str = "flow", # any res_name we write before in HDF5Data - start_id: int = 0, - point_size: float = 2.0, - mode: str = "vis", -): - if mode != "vis": - return - dataset = HDF5Data(data_dir, vis_name=res_name, flow_view=True) - o3d_vis = MyVisualizer(view_file=VIEW_FILE, window_title=f"view {'ground truth flow' if res_name == 'flow' else f'{res_name} flow'}, `SPACE` start/stop") +def _ensure_list(val): + """Ensure value is a list.""" + if val is None: + return [] + return val if isinstance(val, list) else [val] - opt = o3d_vis.vis.get_render_option() - # opt.background_color = np.asarray([216, 216, 216]) / 255.0 - opt.background_color = np.asarray([80/255, 90/255, 110/255]) - # opt.background_color = np.asarray([1, 1, 1]) - opt.point_size = point_size - for data_id in (pbar := tqdm(range(start_id, len(dataset)))): - data = dataset[data_id] - now_scene_id = data['scene_id'] - 
pbar.set_description(f"id: {data_id}, scene_id: {now_scene_id}, timestamp: {data['timestamp']}") +class SceneFlowVisualizer: + """ + Open3D-based Scene Flow Visualizer. + + Supports multiple visualization modes as class methods, + compatible with python-fire for CLI usage. + """ + + def __init__( + self, + data_dir: str = "/home/kin/data/av2/preprocess/sensor/mini", + res_name: str = "flow", + start_id: int = 0, + num_frames: int = 2, + rgm: bool = True, # remove ground mask + slc: bool = False, # show lidar centers + point_size: float = 3.0, + max_distance: float = 50.0, + bg_color: tuple = (80/255, 90/255, 110/255), + ): + """ + Initialize the visualizer. + + Args: + data_dir: Path to HDF5 dataset directory + res_name: Result name(s) to visualize (string or list) + start_id: Starting frame index + point_size: Point size for rendering + rgm: Remove ground mask if True + slc: Show LiDAR sensor centers if True + num_frames: Number of frames for history mode + max_distance: Maximum distance filter for points + bg_color: Background color as RGB tuple (0-1 range) + """ + self.data_dir = data_dir + self.res_names = _ensure_list(res_name) + self.start_id = start_id + self.point_size = point_size + self.rgm = rgm + self.num_frames = num_frames + self.max_distance = max_distance + self.bg_color = bg_color + self.show_lidar_centers = slc + + def _load_dataset(self, vis_name=None, n_frames=2): + """Load HDF5 dataset.""" + vis_name = vis_name or self.res_names + return HDF5Dataset(self.data_dir, vis_name=vis_name, n_frames=n_frames) + + def _create_visualizer(self, res_name=None): + """Create O3DVisualizer instance.""" + res_name = res_name or self.res_names + return O3DVisualizer( + view_file=VIEW_FILE, + res_name=res_name, + point_size=self.point_size, + bg_color=self.bg_color, + ) + + def _filter_ground_and_distance(self, pc, gm): + """Apply ground mask and distance filter.""" + if not self.rgm: + gm = np.zeros_like(gm) + distance = np.linalg.norm(pc[:, :3], axis=1) + 
return gm | (distance > self.max_distance) + + def _compute_pose_flow(self, pc0, pose0, pose1): + """Compute ego-motion flow.""" + ego_pose = npcal_pose0to1(pose0, pose1) + return pc0[:, :3] @ ego_pose[:3, :3].T + ego_pose[:3, 3] - pc0[:, :3] + + # ------------------------------------------------------------------------- + # Visualization Modes (Fire subcommands) + # ------------------------------------------------------------------------- + + def vis(self): + """ + Visualize scene flow with color-coded dynamic motion. + + Supports single or multiple result names for side-by-side comparison. + Colors represent flow direction (after ego-motion compensation). + """ + dataset = self._load_dataset() + o3d_vis = self._create_visualizer() + + data_id = self.start_id + pbar = tqdm(range(self.start_id, len(dataset))) + + while 0 <= data_id < len(dataset): + data = dataset[data_id] + pbar.set_description(f"id: {data_id}, scene: {data['scene_id']}, ts: {data['timestamp']}") + + pc0 = data['pc0'] + gm0 = self._filter_ground_and_distance(pc0, data['gm0']) + pose_flow = self._compute_pose_flow(pc0, data['pose0'], data['pose1']) + + if self.rgm: + pc0 = pc0[~gm0] + pose_flow = pose_flow[~gm0] + + pcd_list = [] + for single_res in self.res_names: + pcd = o3d.geometry.PointCloud() + + # Instance/cluster visualization + if single_res in ['dufo', 'cluster', 'dufocluster', 'flow_instance_id', + 'ground_mask', 'pc0_dynamic'] and single_res in data: + labels = data[single_res][~gm0] if self.rgm else data[single_res] + pcd = self._color_by_labels(pc0, labels) + + # Flow visualization + elif single_res in data: + pcd.points = o3d.utility.Vector3dVector(pc0[:, :3]) + flow = (data[single_res][~gm0] if self.rgm else data[single_res]) - pose_flow + flow_color = flow_to_rgb(flow) / 255.0 + is_dynamic = np.linalg.norm(flow, axis=1) > 0.08 + flow_color[~is_dynamic] = NO_COLOR + if not self.rgm: + flow_color[gm0] = NO_COLOR + pcd.colors = o3d.utility.Vector3dVector(flow_color) + + # Raw point 
cloud + elif single_res == 'raw': + pcd.points = o3d.utility.Vector3dVector(pc0[:, :3]) + + pcd_list.append([pcd, o3d.geometry.TriangleMesh.create_coordinate_frame(size=2)]) + + # show lidar centers + if self.show_lidar_centers and 'lidar_center' in data: + lidar_center = data['lidar_center'] + for lidar_num in range(lidar_center.shape[0]): + pcd_list[-1].append( + o3d.geometry.TriangleMesh.create_coordinate_frame(size=1).transform( + lidar_center[lidar_num] + ) + ) + + o3d_vis.update(pcd_list, index=data_id) + data_id += o3d_vis.playback_direction + pbar.update(o3d_vis.playback_direction) + + def check(self): + """ + Check flow by showing pc0 (red), pc1 (green), and pc0+flow (blue). + + Useful for verifying flow correctness. + """ + res_name = self.res_names[0] if self.res_names else "flow" + dataset = self._load_dataset(vis_name=res_name) + o3d_vis = self._create_visualizer(res_name=res_name) + + data_id = self.start_id + pbar = tqdm(range(self.start_id, len(dataset))) + + while 0 <= data_id < len(dataset): + data = dataset[data_id] + pbar.set_description(f"id: {data_id}, scene: {data['scene_id']}, ts: {data['timestamp']}") + + if res_name not in dataset[data_id]: + print(f"'{res_name}' not in dataset, skipping id {data_id}") + data_id += 1 + continue + + data = dataset[data_id] + pbar.set_description(f"id: {data_id}, scene: {data['scene_id']}, ts: {data['timestamp']}") + + pc0, pc1 = data['pc0'], data['pc1'] + if self.rgm: + pc0 = pc0[~data['gm0']] + pc1 = pc1[~data['gm1']] + + # Red: pc0 + pcd0 = o3d.geometry.PointCloud() + pcd0.points = o3d.utility.Vector3dVector(pc0[:, :3]) + pcd0.paint_uniform_color([1.0, 0.0, 0.0]) + + # Green: pc1 + pcd1 = o3d.geometry.PointCloud() + pcd1.points = o3d.utility.Vector3dVector(pc1[:, :3]) + pcd1.paint_uniform_color([0.0, 1.0, 0.0]) + + # Blue: pc0 + flow + res_flow = data[res_name][~data['gm0']] if self.rgm else data[res_name] + pcd2 = o3d.geometry.PointCloud() + pcd2.points = o3d.utility.Vector3dVector(pc0[:, :3] + res_flow) 
+ pcd2.paint_uniform_color([0.0, 0.0, 1.0]) + + o3d_vis.update([pcd0, pcd1, pcd2, o3d.geometry.TriangleMesh.create_coordinate_frame(size=2)]) + data_id += o3d_vis.playback_direction + pbar.update(o3d_vis.playback_direction) - pc0 = data['pc0'] - gm0 = data['gm0'] - pose0 = data['pose0'] - pose1 = data['pose1'] - ego_pose = np.linalg.inv(pose1) @ pose0 + def vector(self): + """ + Visualize flow as red vector lines from source to target. + + Shows pc0 (green), pc1 (blue), and flow vectors (red lines). + """ + res_name = self.res_names[0] if self.res_names else "flow" + dataset = self._load_dataset(vis_name=res_name) + o3d_vis = O3DVisualizer( + view_file=VIEW_FILE, + res_name=res_name, + point_size=self.point_size, + bg_color=(1, 1, 1), # White background for vector mode + ) + + data_id = self.start_id + pbar = tqdm(range(self.start_id, len(dataset))) + + while 0 <= data_id < len(dataset): + data = dataset[data_id] + pbar.set_description(f"id: {data_id}, scene: {data['scene_id']}, ts: {data['timestamp']}") + + if res_name not in dataset[data_id]: + print(f"'{res_name}' not in dataset, skipping id {data_id}") + data_id += 1 + continue - pose_flow = pc0[:, :3] @ ego_pose[:3, :3].T + ego_pose[:3, 3] - pc0[:, :3] + pc0 = data['pc0'] + gm0 = self._filter_ground_and_distance(pc0, data['gm0']) + + ego_pose = np.linalg.inv(data['pose1']) @ data['pose0'] + pose_flow = pc0[:, :3] @ ego_pose[:3, :3].T + ego_pose[:3, 3] - pc0[:, :3] + flow = data[res_name] - pose_flow + + # Green: pc0 transformed + vis_pc = pc0[:, :3][~gm0] + pose_flow[~gm0] + pcd0 = o3d.geometry.PointCloud() + pcd0.points = o3d.utility.Vector3dVector(vis_pc) + pcd0.paint_uniform_color([0, 1, 0]) + + # Blue: pc1 + pcd1 = o3d.geometry.PointCloud() + pcd1.points = o3d.utility.Vector3dVector(data['pc1'][:, :3][~data['gm1']]) + pcd1.paint_uniform_color([0.0, 0.0, 1]) + + # Red: flow vectors + line_set = self._create_flow_lines(vis_pc, flow[~gm0], color=[1, 0, 0]) + + o3d_vis.update([pcd0, pcd1, line_set, 
o3d.geometry.TriangleMesh.create_coordinate_frame(size=2)], + index=data_id) + data_id += o3d_vis.playback_direction + pbar.update(o3d_vis.playback_direction) + + def error(self, max_error: float = 0.35): + """ + Visualize flow error as heatmap (hot colormap). + + Args: + max_error: Maximum error for color scaling (meters) + """ + dataset = self._load_dataset() + o3d_vis = self._create_visualizer() + o3d_vis.bg_color = np.asarray([216, 216, 216]) / 255.0 # Off-white + data_id = self.start_id + pbar = tqdm(range(0, len(dataset))) + + while 0 <= data_id < len(dataset): + data = dataset[data_id] + pbar.set_description(f"id: {data_id}, scene: {data['scene_id']}, ts: {data['timestamp']}") + + pc0 = data['pc0'] + gm0 = self._filter_ground_and_distance(pc0, data['gm0']) + + gt_flow = data["flow"][~gm0] if self.rgm else data["flow"] + if self.rgm: + pc0 = pc0[~gm0] + + pcd_list = [] + for single_res in self.res_names: + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(pc0[:, :3]) + + res_flow = None + if single_res in data: + res_flow = data[single_res][~gm0] if self.rgm else data[single_res] + elif single_res == 'raw': + res_flow = self._compute_pose_flow(pc0, data['pose0'], data['pose1']) + + if res_flow is not None: + error_mag = np.linalg.norm(gt_flow - res_flow, axis=-1) + error_mag[error_mag < 0.05] = 0 + error_color = error_to_color(error_mag, max_error=max_error, color_map="hot") / 255.0 + if not self.rgm: + error_color[gm0] = NO_COLOR + pcd.colors = o3d.utility.Vector3dVector(error_color) + pcd_list.append([pcd, o3d.geometry.TriangleMesh.create_coordinate_frame(size=2)]) + + o3d_vis.update(pcd_list, index=data_id, value=max_error) + data_id += o3d_vis.playback_direction + pbar.update(o3d_vis.playback_direction) + + # ------------------------------------------------------------------------- + # Helper Methods + # ------------------------------------------------------------------------- + + def _color_by_labels(self, pc, labels): + """Create 
point cloud colored by instance labels.""" pcd = o3d.geometry.PointCloud() - if res_name == 'raw': # no result, only show **raw point cloud** - pcd.points = o3d.utility.Vector3dVector(pc0[:, :3]) - pcd.paint_uniform_color([1.0, 1.0, 1.0]) - elif res_name in ['dufo', 'label']: - labels = data[res_name] + for label_i in np.unique(labels): pcd_i = o3d.geometry.PointCloud() - for label_i in np.unique(labels): - pcd_i.points = o3d.utility.Vector3dVector(pc0[labels == label_i][:, :3]) - if label_i <= 0: - pcd_i.paint_uniform_color([1.0, 1.0, 1.0]) - else: - pcd_i.paint_uniform_color(color_map[label_i % len(color_map)]) - pcd += pcd_i - elif res_name in data: - pcd.points = o3d.utility.Vector3dVector(pc0[:, :3]) - flow = data[res_name] - pose_flow # ego motion compensation here. - flow_color = flow_to_rgb(flow) / 255.0 - is_dynamic = np.linalg.norm(flow, axis=1) > 0.1 - flow_color[~is_dynamic] = [1, 1, 1] - flow_color[gm0] = [1, 1, 1] - pcd.colors = o3d.utility.Vector3dVector(flow_color) - o3d_vis.update([pcd, o3d.geometry.TriangleMesh.create_coordinate_frame(size=2)]) - - -def vis_multiple( - data_dir: str ="/home/kin/data/av2/h5py/demo/val", - res_name: list = ["flow"], - start_id: int = 0, - point_size: float = 3.0, - tone: str = 'dark', - mode: str = "mul", -): - if mode != "mul": - return - assert isinstance(res_name, list), "vis_multiple() needs a list as flow_mode" - dataset = HDF5Data(data_dir, vis_name=res_name, flow_view=True) - o3d_vis = MyMultiVisualizer(view_file=VIEW_FILE, flow_mode=res_name) - - for v in o3d_vis.vis: - opt = v.get_render_option() - if tone == 'bright': - background_color = np.asarray([216, 216, 216]) / 255.0 # offwhite - # background_color = np.asarray([1, 1, 1]) - pcd_color = [0.25, 0.25, 0.25] - elif tone == 'dark': - background_color = np.asarray([80/255, 90/255, 110/255]) # dark - pcd_color = [1., 1., 1.] 
- - opt.background_color = background_color - opt.point_size = point_size - - data_id = start_id - pbar = tqdm(range(0, len(dataset))) - - while data_id >= 0 and data_id < len(dataset): - data = dataset[data_id] - now_scene_id = data['scene_id'] - pbar.set_description(f"id: {data_id}, scene_id: {now_scene_id}, timestamp: {data['timestamp']}") - - pc0 = data['pc0'] - gm0 = data['gm0'] - pose0 = data['pose0'] - pose1 = data['pose1'] - ego_pose = np.linalg.inv(pose1) @ pose0 - - pose_flow = pc0[:, :3] @ ego_pose[:3, :3].T + ego_pose[:3, 3] - pc0[:, :3] - - pcd_list = [] - for mode in res_name: - pcd = o3d.geometry.PointCloud() - if mode in ['dufo', 'label']: - labels = data[mode] - pcd_i = o3d.geometry.PointCloud() - for label_i in np.unique(labels): - pcd_i.points = o3d.utility.Vector3dVector(pc0[labels == label_i][:, :3]) - if label_i <= 0: - pcd_i.paint_uniform_color([1.0, 1.0, 1.0]) - else: - pcd_i.paint_uniform_color(color_map[label_i % len(color_map)]) - pcd += pcd_i - elif mode in data: - pcd.points = o3d.utility.Vector3dVector(pc0[:, :3]) - flow = data[mode] - pose_flow # ego motion compensation here. 
- flow_color = flow_to_rgb(flow) / 255.0 - is_dynamic = np.linalg.norm(flow, axis=1) > 0.1 - flow_color[~is_dynamic] = pcd_color - flow_color[gm0] = pcd_color - pcd.colors = o3d.utility.Vector3dVector(flow_color) - pcd_list.append([pcd, create_bev_square(), - create_bev_square(size=204.8, color=[195/255,86/255,89/255]), - o3d.geometry.TriangleMesh.create_coordinate_frame(size=2)]) - o3d_vis.update(pcd_list) - - data_id += o3d_vis.playback_direction - pbar.update(o3d_vis.playback_direction) - + pcd_i.points = o3d.utility.Vector3dVector(pc[labels == label_i][:, :3]) + if label_i <= 0: + pcd_i.paint_uniform_color(NO_COLOR) + else: + pcd_i.paint_uniform_color(color_map[label_i % len(color_map)]) + pcd += pcd_i + return pcd + + def _create_flow_lines(self, source_pts, flow, color=[1, 0, 0]): + """Create line set for flow visualization.""" + line_set = o3d.geometry.LineSet() + target_pts = source_pts + flow + line_set_points = np.concatenate([source_pts, target_pts], axis=0) + lines = np.array([[i, i + len(source_pts)] for i in range(len(source_pts))]) + line_set.points = o3d.utility.Vector3dVector(line_set_points) + line_set.lines = o3d.utility.Vector2iVector(lines) + line_set.paint_uniform_color(color) + return line_set if __name__ == '__main__': start_time = time.time() - # fire.Fire(check_flow) - fire.Fire(vis) - fire.Fire(vis_multiple) + fire.Fire(SceneFlowVisualizer) print(f"Time used: {time.time() - start_time:.2f} s") \ No newline at end of file From cff7ce80937b13502e49d5f17d4661112389a701 Mon Sep 17 00:00:00 2001 From: Kin Date: Mon, 12 Jan 2026 12:31:39 +0100 Subject: [PATCH 2/6] fix(flow): add index_flow for 2hz gt view etc. 
--- src/dataset.py | 5 +++-- tools/visualization.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/dataset.py b/src/dataset.py index 9d431b9..13cbc87 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -187,7 +187,7 @@ class HDF5Dataset(Dataset): def __init__(self, directory, \ transform=None, n_frames=2, ssl_label=None, \ eval = False, leaderboard_version=1, \ - vis_name=''): + vis_name='', index_flow=False): ''' Args: directory: the directory of the dataset, the folder should contain some .h5 file and index_total.pkl. @@ -199,6 +199,7 @@ def __init__(self, directory, \ * eval: if True, use the eval index (only used it for leaderboard evaluation) * leaderboard_version: 1st or 2nd, default is 1. If '2', we will use the index_eval_v2.pkl from assets/docs. * vis_name: the data of the visualization, default is ''. + * index_flow: if True, use the flow index for training or visualization. ''' super(HDF5Dataset, self).__init__() self.directory = directory @@ -247,7 +248,7 @@ def __init__(self, directory, \ # for some dataset that annotated HZ is different.... like truckscene and nuscene etc. self.train_index = None - if not eval and ssl_label is None and transform is not None: # transform indicates whether we are in training mode. + if (not eval and ssl_label is None and transform is not None) or index_flow: # transform indicates whether we are in training mode. # check if train seq all have gt. 
one_scene_id = list(self.scene_id_bounds.keys())[0] check_flow_exist = True diff --git a/tools/visualization.py b/tools/visualization.py index 98b1ad6..859259b 100644 --- a/tools/visualization.py +++ b/tools/visualization.py @@ -103,7 +103,7 @@ def __init__( def _load_dataset(self, vis_name=None, n_frames=2): """Load HDF5 dataset.""" vis_name = vis_name or self.res_names - return HDF5Dataset(self.data_dir, vis_name=vis_name, n_frames=n_frames) + return HDF5Dataset(self.data_dir, vis_name=vis_name, n_frames=n_frames, index_flow='flow' in vis_name) def _create_visualizer(self, res_name=None): """Create O3DVisualizer instance.""" @@ -289,7 +289,8 @@ def vector(self): # Blue: pc1 pcd1 = o3d.geometry.PointCloud() - pcd1.points = o3d.utility.Vector3dVector(data['pc1'][:, :3][~data['gm1']]) + gm1 = self._filter_ground_and_distance(data['pc1'], data['gm1']) + pcd1.points = o3d.utility.Vector3dVector(data['pc1'][:, :3][~gm1]) pcd1.paint_uniform_color([0.0, 0.0, 1]) # Red: flow vectors From fafe30e0a479297484590d08ca22d4fe1bc3427b Mon Sep 17 00:00:00 2001 From: Kin Date: Sat, 17 Jan 2026 15:04:53 +0100 Subject: [PATCH 3/6] hotfix: voteflow cuda lib skip compile if pre-install already. 
--- .../hough_transformation/cpp_im2ht/setup.py | 8 ++++- .../hough_transformation/im2ht.py | 35 +++++++++++-------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/models/basic/voteflow_plugin/hough_transformation/cpp_im2ht/setup.py b/src/models/basic/voteflow_plugin/hough_transformation/cpp_im2ht/setup.py index 6c2521f..d1e1559 100755 --- a/src/models/basic/voteflow_plugin/hough_transformation/cpp_im2ht/setup.py +++ b/src/models/basic/voteflow_plugin/hough_transformation/cpp_im2ht/setup.py @@ -1,6 +1,11 @@ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension +extra_compile_args = { + 'cxx': ['-DCCCL_IGNORE_DEPRECATED_CUDA_BELOW_12', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'], + 'nvcc': ['-DCCCL_IGNORE_DEPRECATED_CUDA_BELOW_12', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'], +} + setup( name='im2ht', # ext_modules=[ @@ -8,7 +13,8 @@ # extra_compile_args={'cxx': ['-g'], 'nvcc': ['-arch=sm_60']}), # ], ext_modules=[ - CUDAExtension(name='im2ht', sources=['im2ht.cpp', 'ht_cuda.cu']), + CUDAExtension(name='im2ht', sources=['im2ht.cpp', 'ht_cuda.cu'], + extra_compile_args=extra_compile_args), ], cmdclass={ 'build_ext': BuildExtension diff --git a/src/models/basic/voteflow_plugin/hough_transformation/im2ht.py b/src/models/basic/voteflow_plugin/hough_transformation/im2ht.py index e4d12e7..927e823 100755 --- a/src/models/basic/voteflow_plugin/hough_transformation/im2ht.py +++ b/src/models/basic/voteflow_plugin/hough_transformation/im2ht.py @@ -7,22 +7,27 @@ from torch.autograd.function import once_differentiable def load_cpp_ext(ext_name): - root_dir = os.path.join(os.path.split(__file__)[0]) - src_dir = os.path.join(root_dir, "cpp_im2ht") - tar_dir = os.path.join(src_dir, "build", ext_name) - os.makedirs(tar_dir, exist_ok=True) - srcs = glob(f"{src_dir}/*.cu") + glob(f"{src_dir}/*.cpp") + try: + import im2ht + ext = im2ht + except ImportError: + print(f"Compiling {ext_name} cpp/cuda extension...") + root_dir = 
os.path.join(os.path.split(__file__)[0]) + src_dir = os.path.join(root_dir, "cpp_im2ht") + tar_dir = os.path.join(src_dir, "build", ext_name) + os.makedirs(tar_dir, exist_ok=True) + srcs = glob(f"{src_dir}/*.cu") + glob(f"{src_dir}/*.cpp") - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - from torch.utils.cpp_extension import load - ext = load( - name=ext_name, - sources=srcs, - extra_cflags=["-O3"], - extra_cuda_cflags=[], - build_directory=tar_dir, - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from torch.utils.cpp_extension import load + ext = load( + name=ext_name, + sources=srcs, + extra_cflags=["-O3"], + extra_cuda_cflags=["-DTHRUST_IGNORE_CUB_VERSION_CHECK"], + build_directory=tar_dir, + ) return ext # defer calling load_cpp_ext to make CUDA_VISIBLE_DEVICES happy From a8b4856bf044934e35cb31238d2b48f2887e4cf6 Mon Sep 17 00:00:00 2001 From: Kin Date: Wed, 11 Mar 2026 17:36:04 +0100 Subject: [PATCH 4/6] !big changes on loss caculators. * Add teflowLoss into the codebase * update chamfer3D with CUDA stream-style batch busy compute. AI summary: - Added automatic collection of self-supervised loss function names in `src/lossfuncs/__init__.py`. - Improved documentation and structure of self-supervised loss functions in `src/lossfuncs/selfsupervise.py`. - Refactored loss calculation logic in `src/trainer.py` to support new self-supervised loss functions. - Introduced `ssl_loss_calculator` method for handling self-supervised losses. - Updated training step to differentiate between self-supervised and supervised loss calculations. - Enhanced error handling during training and validation steps to skip problematic batches. 
--- README.md | 19 +- assets/cuda/chamfer3D/__init__.py | 288 +++++++++++----- src/lossfuncs/__init__.py | 2 + src/lossfuncs/selfsupervise.py | 532 +++++++++++++++++++++--------- src/trainer.py | 215 +++++++----- 5 files changed, 731 insertions(+), 325 deletions(-) diff --git a/README.md b/README.md index f7ad2d4..983a28e 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ It is also an official implementation of the following papers (sorted by the tim - **TeFlow: Enabling Multi-frame Supervision for Self-Supervised Feed-forward Scene Flow Estimation** *Qingwen Zhang, Chenhan Jiang, Xiaomeng Zhu, Yunqi Miao, Yushan Zhang, Olov Andersson, Patric Jensfelt* Conference on Computer Vision and Pattern Recognition (**CVPR**) 2026 -[ Strategy ] [ Self-Supervised ] - [ [arXiv](https://arxiv.org/abs/2602.19053) ] [ [Project]() ] +[ Strategy ] [ Self-Supervised ] - [ [arXiv](https://arxiv.org/abs/2602.19053) ] [ [Project]() ]→ [here](#teflow) - **DeltaFlow: An Efficient Multi-frame Scene Flow Estimation Method** *Qingwen Zhang, Xiaomeng Zhu, Yushan Zhang, Yixi Cai, Olov Andersson, Patric Jensfelt* @@ -149,7 +149,9 @@ Train DeltaFlow with the leaderboard submit config. [Runtime: Around 18 hours in ```bash # total bz then it's 10x2 under above training setup. 
-python train.py model=deltaFlow optimizer.lr=2e-3 epochs=20 batch_size=2 num_frames=5 loss_fn=deflowLoss train_aug=True "voxel_size=[0.15, 0.15, 0.15]" "point_cloud_range=[-38.4, -38.4, -3, 38.4, 38.4, 3]" +optimizer.scheduler.name=WarmupCosLR +optimizer.scheduler.max_lr=2e-3 +optimizer.scheduler.total_steps=20000 +python train.py model=deltaflow optimizer.lr=2e-3 epochs=20 batch_size=2 num_frames=5 \ + loss_fn=deflowLoss train_aug=True "voxel_size=[0.15, 0.15, 0.15]" "point_cloud_range=[-38.4, -38.4, -3, 38.4, 38.4, 3]" \ + optimizer.lr=2e-4 +optimizer.scheduler.name=WarmupCosLR +optimizer.scheduler.max_lr=2e-3 +optimizer.scheduler.warmup_epochs=2 # Pretrained weight can be downloaded through (av2), check all other datasets in the same folder. wget https://huggingface.co/kin-zhang/OpenSceneFlow/resolve/main/deltaflow/deltaflow-av2.ckpt @@ -206,6 +208,19 @@ Train Feed-forward SSL methods (e.g. SeFlow/SeFlow++/VoteFlow etc), we needed to 1) process auto-label process for training. Check [dataprocess/README.md#self-supervised-process](dataprocess/README.md#self-supervised-process) for more details. We provide these inside the demo dataset already. 2) specify the loss function, we set the config here for our best model in the leaderboard. +#### TeFlow + +```bash +# [Runtime: Around ? hours in 10x GPUs.] 
+python train.py model=deltaflow epochs=21 batch_size=2 num_frames=5 train_aug=True \ + loss_fn=teflowLoss "voxel_size=[0.15, 0.15, 0.15]" "point_cloud_range=[-38.4, -38.4, -3, 38.4, 38.4, 3]" \ + +ssl_label=seflow_auto "+add_seloss={chamfer_dis: 1.0, static_flow_loss: 1.0, dynamic_chamfer_dis: 1.0, cluster_based_pc0pc1: 1.0}" \ + optimizer.name=Adam optimizer.lr=1e-4 +optimizer.scheduler.name=StepLR +optimizer.scheduler.step_size=9 +optimizer.scheduler.gamma=0.5 + +# Pretrained weight can be downloaded through: +wget https://huggingface.co/kin-zhang/OpenSceneFlow/resolve/main/teflow/teflow-av2.ckpt +``` + #### SeFlow ```bash diff --git a/assets/cuda/chamfer3D/__init__.py b/assets/cuda/chamfer3D/__init__.py index fc5020d..a9971b2 100644 --- a/assets/cuda/chamfer3D/__init__.py +++ b/assets/cuda/chamfer3D/__init__.py @@ -2,116 +2,222 @@ # Created: 2023-08-04 11:20 # Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology # Author: Qingwen Zhang (https://kin-zhang.github.io/) -# +# # This file is part of SeFlow (https://github.com/KTH-RPL/SeFlow). -# If you find this repo helpful, please cite the respective publication as +# If you find this repo helpful, please cite the respective publication as # listed on the above website. -# -# -# Description: ChamferDis speedup using CUDA +# +# Description: ChamferDis speedup using CUDA. +# +# NOTE(2026-03-11, Qingwen) Why CUDA streams (not batched kernel): +# At N=88K pts/sample on RTX 3090, one sample already uses 4.2 SM waves, +# so any kernel-level batching hits the same hardware ceiling. +# Streams give ~1.14× speedup by overlapping B independent kernel launches. +# More importantly, they keep the GPU busy with fewer CPU-GPU sync gaps, +# preventing GPU utilization from spiking which triggers cluster job kills. 
+# """ +from __future__ import annotations + from torch import nn from torch.autograd import Function -import torch - -import os, time +import torch, os, time +from typing import List import chamfer3D -BASE_DIR = os.path.abspath(os.path.join( os.path.dirname( __file__ ), '../..' )) + +BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) -# GPU tensors only class ChamferDis(Function): + """Single-sample Chamfer distance: pc0 (N,3) × pc1 (M,3) on GPU.""" + @staticmethod def forward(ctx, pc0, pc1): - # pc0: (N,3), pc1: (M,3) - dis0 = torch.zeros(pc0.shape[0]).to(pc0.device).contiguous() - dis1 = torch.zeros(pc1.shape[0]).to(pc1.device).contiguous() - - idx0 = torch.zeros(pc0.shape[0], dtype=torch.int32).to(pc0.device).contiguous() - idx1 = torch.zeros(pc1.shape[0], dtype=torch.int32).to(pc1.device).contiguous() - - + dis0 = torch.zeros(pc0.shape[0], device=pc0.device).contiguous() + dis1 = torch.zeros(pc1.shape[0], device=pc1.device).contiguous() + idx0 = torch.zeros(pc0.shape[0], dtype=torch.int32, device=pc0.device).contiguous() + idx1 = torch.zeros(pc1.shape[0], dtype=torch.int32, device=pc1.device).contiguous() chamfer3D.forward(pc0, pc1, dis0, dis1, idx0, idx1) ctx.save_for_backward(pc0, pc1, idx0, idx1) return dis0, dis1, idx0, idx1 @staticmethod - def backward(ctx, grad_dist0, grad_dist1, grad_idx0, grad_idx1): + def backward(ctx, gd0, gd1, _gi0, _gi1): pc0, pc1, idx0, idx1 = ctx.saved_tensors - grad_dist0 = grad_dist0.contiguous() - grad_dist1 = grad_dist1.contiguous() - device = grad_dist1.device - - grad_pc0 = torch.zeros(pc0.size()).to(device).contiguous() - grad_pc1 = torch.zeros(pc1.size()).to(device).contiguous() - - chamfer3D.backward( - pc0, pc1, idx0, idx1, grad_dist0, grad_dist1, grad_pc0, grad_pc1 - ) - return grad_pc0, grad_pc1 - + gpc0 = torch.zeros_like(pc0) + gpc1 = torch.zeros_like(pc1) + chamfer3D.backward(pc0, pc1, idx0, idx1, + gd0.contiguous(), gd1.contiguous(), gpc0, gpc1) + return gpc0, gpc1 + +# ─── nn.Module 
──────────────────────────────────────────────────────────────── class nnChamferDis(nn.Module): - def __init__(self, truncate_dist=True): - super(nnChamferDis, self).__init__() + """Chamfer distance loss — single and batched-via-streams modes. + + Methods + ------- + forward(pc0, pc1) + Single-sample loss. Used by seflowLoss / seflowppLoss. + + batched/batched_disid_res (pc0_list, pc1_list) + Parallel loss across B samples via CUDA streams. + Returns mean-over-samples scalar. + Used by batched_chamfer_related() for chamfer_dis / dynamic_chamfer_dis. + + dis_res(pc0, pc1) → (dist0, dist1), no reduction + disid_res(pc0, pc1) → (dist0, dist1, idx0, idx1), no reduction + truncated_dis(pc0, pc1) → NSFP-style truncated loss + """ + + def __init__(self, truncate_dist: bool = True): + super().__init__() self.truncate_dist = truncate_dist + # Pre-allocate streams once to avoid per-call creation overhead (~50 µs each) + self._streams: List[torch.cuda.Stream] = [] + + def _ensure_streams(self, n: int) -> List[torch.cuda.Stream]: + while len(self._streams) < n: + self._streams.append(torch.cuda.Stream()) + return self._streams[:n] + + # ── single-sample forward ───────────────────────────────────────────────── + + def forward(self, input0: torch.Tensor, input1: torch.Tensor, + truncate_dist: float = -1, **_ignored) -> torch.Tensor: + """Single-sample Chamfer loss. truncate_dist<=0 → no truncation.""" + dist0, dist1, _, _ = ChamferDis.apply(input0.contiguous(), input1.contiguous()) + if truncate_dist <= 0: + return dist0.mean() + dist1.mean() + v0, v1 = dist0 <= truncate_dist, dist1 <= truncate_dist + return torch.nanmean(dist0[v0]) + torch.nanmean(dist1[v1]) + + # ── batched loss via CUDA streams ───────────────────────────────────────── + + def batched(self, + pc0_list: List[torch.Tensor], + pc1_list: List[torch.Tensor], + truncate_dist: float = -1) -> torch.Tensor: + """Parallel Chamfer loss via B CUDA streams. 
+ + Returns mean-over-samples: (1/B) * Σ_i [mean(dist0_i) + mean(dist1_i)]. + ~1.14× faster than serial loop on RTX 3090 @ 88K pts/sample; + more importantly, keeps GPU busy with one sustained work block per frame. + """ + B = len(pc0_list) + if B == 1: + return self.forward(pc0_list[0], pc1_list[0], truncate_dist) + + streams = self._ensure_streams(B) + main = torch.cuda.current_stream() + per_loss: List[torch.Tensor] = [None] * B # type: ignore[list-item] + + for i in range(B): + streams[i].wait_stream(main) + with torch.cuda.stream(streams[i]): + d0, d1, _, _ = ChamferDis.apply(pc0_list[i].contiguous(), + pc1_list[i].contiguous()) + if truncate_dist <= 0: + per_loss[i] = d0.mean() + d1.mean() + else: + v0, v1 = d0 <= truncate_dist, d1 <= truncate_dist + per_loss[i] = torch.nanmean(d0[v0]) + torch.nanmean(d1[v1]) + + for i in range(B): + main.wait_stream(streams[i]) + + return torch.stack(per_loss).mean() + + # ── batched disid_res via CUDA streams (for cluster precomputation) ─────── + + def batched_disid_res(self, + pc0_list: List[torch.Tensor], + pc1_list: List[torch.Tensor], + ) -> tuple[List[torch.Tensor], List[torch.Tensor]]: + """Parallel disid_res across B samples via CUDA streams. + + Same list-in / list-out convention as batched(). + + Returns + ------- + dist0_list : List[(N_i,)] per-point nearest distance in pc1_i + idx0_list : List[(N_i,)] LOCAL index into pc1_list[i] (0 .. 
M_i-1) + + Usage: + dist0_list, idx0_list = fn.batched_disid_res(pc0_list, pc1_list) + neighbour = pc1_list[i][idx0_list[i][mask]] # no global arithmetic + """ + B = len(pc0_list) + if B == 1: + d0, _, i0, _ = ChamferDis.apply(pc0_list[0].contiguous(), pc1_list[0].contiguous()) + return [d0], [i0] + + streams = self._ensure_streams(B) + main = torch.cuda.current_stream() + d0_list: List[torch.Tensor] = [None] * B # type: ignore[list-item] + i0_list: List[torch.Tensor] = [None] * B # type: ignore[list-item] + + for i in range(B): + streams[i].wait_stream(main) + with torch.cuda.stream(streams[i]): + d0, _, idx0, _ = ChamferDis.apply(pc0_list[i].contiguous(), + pc1_list[i].contiguous()) + d0_list[i] = d0 + i0_list[i] = idx0 # local index — no offset arithmetic needed + + for i in range(B): + main.wait_stream(streams[i]) + + return d0_list, i0_list + + # ── utilities ───────────────────────────────────────────────────────────── + + def dis_res(self, input0: torch.Tensor, input1: torch.Tensor): + """Return raw (dist0, dist1) without reduction.""" + d0, d1, _, _ = ChamferDis.apply(input0.contiguous(), input1.contiguous()) + return d0, d1 + + def disid_res(self, input0: torch.Tensor, input1: torch.Tensor): + """Return raw (dist0, dist1, idx0, idx1) without reduction.""" + return ChamferDis.apply(input0.contiguous(), input1.contiguous()) + + def truncated_dis(self, input0: torch.Tensor, input1: torch.Tensor, + truncate_dist: float = 2.0) -> torch.Tensor: + """NSFP-style: distances >= threshold clamped to 0, then mean.""" + cx, cy = self.dis_res(input0, input1) + cx[cx >= truncate_dist] = 0.0 + cy[cy >= truncate_dist] = 0.0 + return cx.mean() + cy.mean() + - def forward(self, input0, input1, truncate_dist=-1): - input0 = input0.contiguous() - input1 = input1.contiguous() - dist0, dist1, _, _ = ChamferDis.apply(input0, input1) - - if truncate_dist<=0: - return torch.mean(dist0) + torch.mean(dist1) - - valid_mask0 = (dist0 <= truncate_dist) - valid_mask1 = (dist1 <= 
truncate_dist) - truncated_sum = torch.nanmean(dist0[valid_mask0]) + torch.nanmean(dist1[valid_mask1]) - return truncated_sum - - def dis_res(self, input0, input1): - input0 = input0.contiguous() - input1 = input1.contiguous() - dist0, dist1, _, _ = ChamferDis.apply(input0, input1) - return dist0, dist1 - - def truncated_dis(self, input0, input1, truncate_dist=2): - # nsfp: truncated distance way is set >= 2 to 0 but not nanmean - cham_x, cham_y = self.dis_res(input0, input1) - cham_x[cham_x >= truncate_dist] = 0.0 - cham_y[cham_y >= truncate_dist] = 0.0 - return torch.mean(cham_x) + torch.mean(cham_y) - - def disid_res(self, input0, input1): - input0 = input0.contiguous() - input1 = input1.contiguous() - dist0, dist1, idx0, idx1 = ChamferDis.apply(input0, input1) - return dist0, dist1, idx0, idx1 -class NearestNeighborDis(nn.Module): - def __init__(self): - super(NearestNeighborDis, self).__init__() - - def forward(self, input0, input1): - input0 = input0.contiguous() - input1 = input1.contiguous() - dist0, dist1, _, _ = ChamferDis.apply(input0, input1) - - return torch.mean(dist0[dist0 <= 2]) - if __name__ == "__main__": import numpy as np - pc0 = np.load(f'{BASE_DIR}/assets/tests/test_pc0.npy') - pc1 = np.load(f'{BASE_DIR}/assets/tests/test_pc1.npy') - print('0: {:.3f}MB'.format(torch.cuda.memory_allocated()/1024**2)) - pc0 = torch.from_numpy(pc0[...,:3]).float().cuda().contiguous() - pc1 = torch.from_numpy(pc1[...,:3]).float().cuda().contiguous() - pc0.requires_grad = True - pc1.requires_grad = True - print(pc0.shape, "demo data: ", pc0[0]) - print(pc1.shape, "demo data: ", pc1[0]) - print('1: {:.3f}MB'.format(torch.cuda.memory_allocated()/1024**2)) - - start_time = time.time() - loss = nnChamferDis(truncate_dist=False)(pc0, pc1) - loss.backward() - print("loss: ", loss) - print(f"Chamfer Distance Cal time: {(time.time() - start_time)*1000:.3f} ms") \ No newline at end of file + pc0_np = np.load(f'{BASE_DIR}/tests/test_pc0.npy')[..., :3] + pc1_np = 
np.load(f'{BASE_DIR}/tests/test_pc1.npy')[..., :3] + pc0 = torch.from_numpy(pc0_np).float().cuda() + pc1 = torch.from_numpy(pc1_np).float().cuda() + fn = nnChamferDis(truncate_dist=False) + + loss_s = fn(pc0, pc1) + print(f"Single: {loss_s.item():.6f}") + + for B in [2, 4, 8]: + lb = fn.batched([pc0.clone()]*B, [pc1.clone()]*B) + print(f"Batched B={B}: {lb.item():.6f} {'✓' if torch.allclose(loss_s, lb, atol=1e-5) else '✗'}") + + # Test batched_disid_res global indexing + print("\n--- batched_disid_res global index test ---") + B = 2 + pc0_b = torch.cat([pc0]*B) + pc1_b = torch.cat([pc1]*B) + N0, N1 = pc0.shape[0], pc1.shape[0] + offs0 = torch.tensor([0, N0], dtype=torch.int32, device='cuda') + szs0 = torch.tensor([N0, N0], dtype=torch.int32, device='cuda') + offs1 = torch.tensor([0, N1], dtype=torch.int32, device='cuda') + szs1 = torch.tensor([N1, N1], dtype=torch.int32, device='cuda') + pc0_lst = [pc0]*B + pc1_lst = [pc1]*B + d0_lst_out, i0_lst_out = fn.batched_disid_res(pc0_lst, pc1_lst) + assert len(d0_lst_out) == B and len(i0_lst_out) == B, "wrong list length" + for j in range(B): + assert (i0_lst_out[j] < N1).all(), f"sample-{j} idx out of range" + print("Local index check: ✓") \ No newline at end of file diff --git a/src/lossfuncs/__init__.py b/src/lossfuncs/__init__.py index 7bf446b..601b567 100644 --- a/src/lossfuncs/__init__.py +++ b/src/lossfuncs/__init__.py @@ -17,3 +17,5 @@ from .selfsupervise import * from .supervise import * +# automatic collection of SSL loss function names for trainer.py +SSL_LOSSES_FN = {name for name in dir(selfsupervise) if name.endswith('Loss') and callable(getattr(selfsupervise, name))} \ No newline at end of file diff --git a/src/lossfuncs/selfsupervise.py b/src/lossfuncs/selfsupervise.py index dccc923..e1238c9 100644 --- a/src/lossfuncs/selfsupervise.py +++ b/src/lossfuncs/selfsupervise.py @@ -10,9 +10,20 @@ # # If you find this repo helpful, please cite the respective publication as # listed on the above website. 
-# -# Description: Define the self-supervised (without GT) loss function for training. # +# Description: Self-supervised loss functions. +# +# All losses receive a unified dict from ssl_loss_calculator (trainer.py). +# Every frame is represented only as a List[Tensor] — no flat/offsets/sizes. +# +# res_dict keys (per frame 'pc0', 'pc1', 'pch1', ...): +# '{frame}_list' : List[Tensor (N_i,3)] one tensor per sample +# '{frame}_labels' : List[Tensor (N_i,)] one label vector per sample +# +# 'est_flow_list' : List[Tensor (N_i,3)] +# 'batch_size' : int +# 'loss_weights_dict': dict (teflow* only) +# 'cluster_loss_args': dict (teflowLoss only) """ import torch from assets.cuda.chamfer3D import nnChamferDis @@ -22,169 +33,374 @@ # If your scenario is different, may need adjust this TRUNCATED to 80-120km/h vel. TRUNCATED_DIST = 4 -def seflowppLoss(res_dict, timer=None): - pch1_label = res_dict['pch1_labels'] - pc0_label = res_dict['pc0_labels'] - pc1_label = res_dict['pc1_labels'] - - pch1 = res_dict['pch1'] - pc0 = res_dict['pc0'] - pc1 = res_dict['pc1'] - - est_flow = res_dict['est_flow'] - - pseudo_pc1from0 = pc0 + est_flow - pseduo_pch1from0 = pc0 - est_flow - - unique_labels = torch.unique(pc0_label) - pc0_dynamic = pc0[pc0_label > 0] - pc1_dynamic = pc1[pc1_label > 0] - - # fpc1_dynamic = pseudo_pc1from0[pc0_label > 0] - # NOTE(Qingwen): since we set THREADS_PER_BLOCK is 256 - have_dynamic_cluster = (pc0_dynamic.shape[0] > 256) & (pc1_dynamic.shape[0] > 256) - - # first item loss: chamfer distance - # timer[5][1].start("MyCUDAChamferDis") - chamfer_dis = MyCUDAChamferDis(pseudo_pc1from0, pc1, truncate_dist=TRUNCATED_DIST) + MyCUDAChamferDis(pseduo_pch1from0, pch1, truncate_dist=TRUNCATED_DIST) - # timer[5][1].stop() - - # second item loss: dynamic chamfer distance - # timer[5][2].start("DynamicChamferDistance") - dynamic_chamfer_dis = torch.tensor(0.0, device=est_flow.device) - if have_dynamic_cluster: - dynamic_chamfer_dis += MyCUDAChamferDis(pseudo_pc1from0[pc0_label 
> 0], pc1_dynamic, truncate_dist=TRUNCATED_DIST) - if pch1[pch1_label > 0].shape[0] > 256: - dynamic_chamfer_dis += MyCUDAChamferDis(pseduo_pch1from0[pc0_label > 0], pch1[pch1_label > 0], truncate_dist=TRUNCATED_DIST) - # timer[5][2].stop() - - # third item loss: exclude static points' flow - # NOTE(Qingwen): add in the later part on label==0 - static_cluster_loss = torch.tensor(0.0, device=est_flow.device) - - # fourth item loss: same label points' flow should be the same - # timer[5][3].start("SameClusterLoss") - # raw: pc0 to pc1, est: pseudo_pc1from0 to pc1, idx means the nearest index - raw_dist0, raw_dist1, raw_idx0, _ = MyCUDAChamferDis.disid_res(pc0, pc1) - moved_cluster_loss = torch.tensor(0.0, device=est_flow.device) - moved_cluster_norms = torch.tensor([], device=est_flow.device) - for label in unique_labels: - mask = pc0_label == label - if label == 0: - # Eq. 6 in the SeFlow paper - static_cluster_loss += torch.linalg.vector_norm(est_flow[mask, :], dim=-1).mean() - # NOTE(Qingwen) 2025-04-23: label=1 is dynamic but no cluster id satisfied - elif label > 1 and have_dynamic_cluster: - cluster_id_flow = est_flow[mask, :] - cluster_nnd = raw_dist0[mask] - if cluster_nnd.shape[0] <= 0: +# FIXME(Qingwen 25-07-21): hardcoded 10 Hz. Adjust for datasets with different timestamps. +DELTA_T = 0.1 # seconds + + +# ---- helpers ----------------------------------------------------------------- + +def get_time_delta(frame_id): + """Return (time_delta, factor). + pch1->(-0.1,1), pch2->(-0.2,2), pc1->(+0.1,1), pc2->(+0.2,2) + """ + if frame_id.startswith('pch'): + n = int(frame_id[3:]) if len(frame_id) > 3 else 1 + return -DELTA_T * n, n + elif frame_id.startswith('pc'): + n = int(frame_id[2:]) if len(frame_id) > 2 else 1 + return DELTA_T * n, n + raise ValueError(f"Unknown frame ID: {frame_id}") + + +def _frame_keys(res_dict): + """Auxiliary frame ids present in res_dict (e.g. 
['pc1', 'pch1']), excluding pc0.""" + return [k.replace('_list', '') for k in res_dict + if k.endswith('_list') \ + and k != 'pc0_list' and k != 'est_flow_list' and not k.endswith('_labels_list')] + + +# ---- helpers shared by teflow* ----------------------------------------------- + +def batched_chamfer_related(res_dict, timer=None): + """Chamfer + dynamic-chamfer over all auxiliary frames via CUDA streams. + + Returns + ------- + total_chamfer_dis, total_dynamic_chamfer_dis : scalar Tensors + frame_keys : List[str] + """ + pc0_list = res_dict['pc0_list'] + flow_list = res_dict['est_flow_list'] + pc0_lab_list = res_dict['pc0_labels_list'] + frame_keys = _frame_keys(res_dict) + loss_w = res_dict['loss_weights_dict'] + chamfer_w = loss_w.get('chamfer_dis', 0.0) + dyn_chamfer_w = loss_w.get('dynamic_chamfer_dis', 0.0) + + total_chamfer_dis = torch.tensor(0.0, device=pc0_list[0].device) + total_dynamic_chamfer_dis = torch.tensor(0.0, device=pc0_list[0].device) + + for frame_id in frame_keys: + time_delta, factor = get_time_delta(frame_id) + weight = 1.0 if frame_id == 'pc1' else 1.0 / pow(2, factor) + target_list = res_dict[f'{frame_id}_list'] + + # Projected positions: list comprehension keeps everything per-sample + proj_list = [p0 + (fv / DELTA_T) * time_delta + for p0, fv in zip(pc0_list, flow_list)] + + if chamfer_w > 0: + total_chamfer_dis += MyCUDAChamferDis.batched( + proj_list, target_list, truncate_dist=TRUNCATED_DIST * factor + ) * weight + + if dyn_chamfer_w <= 0: + continue + + tgt_lab_list = res_dict[f'{frame_id}_labels_list'] + proj_dyn, tgt_dyn = [], [] + for proj_i, p0_lab_i, tgt_i, tgt_lab_i in zip( + proj_list, pc0_lab_list, target_list, tgt_lab_list): + dp = proj_i[p0_lab_i > 0] + dt = tgt_i[tgt_lab_i > 0] + if dp.shape[0] > 256 and dt.shape[0] > 256: + proj_dyn.append(dp) + tgt_dyn.append(dt) + + if len(proj_dyn) == 1: + total_dynamic_chamfer_dis += MyCUDAChamferDis( + proj_dyn[0], tgt_dyn[0], truncate_dist=TRUNCATED_DIST * factor + ) * weight + 
elif len(proj_dyn) > 1: + total_dynamic_chamfer_dis += MyCUDAChamferDis.batched( + proj_dyn, tgt_dyn, truncate_dist=TRUNCATED_DIST * factor + ) * weight + + n = len(frame_keys) + if n > 0: + total_chamfer_dis /= n + total_dynamic_chamfer_dis /= n + + return total_chamfer_dis, total_dynamic_chamfer_dis, frame_keys + + +def multi_frames_clusterLoss( + pc0_list, pc0_lab_list, flow_list, + frame_keys, frames_dists, frames_indices, res_dict, args={} +): + """RANSAC-weighted cluster consistency loss across multiple temporal frames. + + frames_dists[frame_id] : List[(N_i,)] per-sample dist from batched_disid_res + frames_indices[frame_id] : List[(N_i,)] per-sample LOCAL idx into frame_list[i] + """ + TOP_K = int(args.get('top_k_candidates', 5)) + COS_THRESH = args.get('ransac_cos_threshold', 0.7071) + TIME_DECAY = args.get('time_decay_factor', 0.9) + NET_EST_W = args.get('network_estimate_weight', 1.0) + + all_cluster_flows, all_target_flows, all_avg_losses = [], [], [] + + for i, (p0, lab0, fv) in enumerate(zip(pc0_list, pc0_lab_list, flow_list)): + for label in torch.unique(lab0): + if label <= 1: continue - # Eq. 8 in the SeFlow paper - sorted_idxs = torch.argsort(cluster_nnd, descending=True) - nearby_label = pc1_label[raw_idx0[mask][sorted_idxs]] # nonzero means dynamic in label - non_zero_valid_indices = torch.nonzero(nearby_label > 0) - if non_zero_valid_indices.shape[0] <= 0: + cluster_mask = (lab0 == label) + cluster_flows = fv[cluster_mask] + + ext_flows, ext_dists, ext_tw = [], [], [] + for frame_id in frame_keys: + dist_c = frames_dists[frame_id][i][cluster_mask] + idx_c = frames_indices[frame_id][i][cluster_mask] + if dist_c.shape[0] <= TOP_K: + continue + topk_dists, topk_local = torch.topk(dist_c, k=TOP_K) + target_pts = res_dict[f'{frame_id}_list'][i][idx_c[topk_local]] + src_pts = p0[cluster_mask][topk_local] + time_delta, factor = get_time_delta(frame_id) + # Eq. 
3 in the TeFlow paper, with time decay and directionality + flows = (target_pts - src_pts) / factor * (-1 if time_delta < 0 else 1) + ext_flows.append(flows) + ext_dists.append(topk_dists) + ext_tw.append(torch.full((TOP_K,), pow(TIME_DECAY, factor), device=p0.device)) + + if not ext_flows: continue - max_idx = sorted_idxs[non_zero_valid_indices.squeeze(1)[0]] - # Eq. 9 in the SeFlow paper - max_flow = pc1[raw_idx0[mask][max_idx]] - pc0[mask][max_idx] - - # Eq. 10 in the SeFlow paper - moved_cluster_norms = torch.cat((moved_cluster_norms, torch.linalg.vector_norm((cluster_id_flow - max_flow), dim=-1))) - - if moved_cluster_norms.shape[0] > 0: - moved_cluster_loss = moved_cluster_norms.mean() # Eq. 11 in the SeFlow paper - elif have_dynamic_cluster: - moved_cluster_loss = torch.mean(raw_dist0[raw_dist0 <= TRUNCATED_DIST]) + torch.mean(raw_dist1[raw_dist1 <= TRUNCATED_DIST]) - # timer[5][3].stop() - - res_loss = { - 'chamfer_dis': chamfer_dis / 2.0, - 'dynamic_chamfer_dis': dynamic_chamfer_dis / 2.0, - 'static_flow_loss': static_cluster_loss, + # Eq. 2 in the TeFlow paper + net_avg = cluster_flows.mean(dim=0) + net_mag = torch.linalg.norm(net_avg) + # Eq. 4 in the TeFlow paper + all_cands = torch.cat(ext_flows + [net_avg.unsqueeze(0)], dim=0) + all_d = torch.cat(ext_dists + [net_mag.unsqueeze(0)], dim=0) + all_tw = torch.cat(ext_tw, dim=0) + if all_cands.shape[0] < 2: + continue + + d_norm = (all_d - all_d.min()) / (all_d.max() - all_d.min() + 1e-6) + # Eq. 5 + cos_sim = torch.nn.functional.cosine_similarity( + all_cands[:, None, :], all_cands[None, :, :], dim=-1) + inlier = cos_sim > COS_THRESH + # Eq. 6 + weights = torch.cat([all_tw * (1 + d_norm[:-1]), + (NET_EST_W * (1 + d_norm[-1])).unsqueeze(0)]) + # Eq. 7 + scores = torch.matmul(inlier.float(), weights.unsqueeze(1)).squeeze() + best = torch.argmax(scores) + + # Eq. 
8 + inlier_flows = all_cands[inlier[best]] + inlier_w = weights[inlier[best]] + denom = inlier_w.sum() + target_flow = (inlier_w.unsqueeze(1) * inlier_flows).sum(dim=0) / denom \ + if denom > 1e-6 else all_cands[best] + + all_cluster_flows.append(cluster_flows) + all_target_flows.append(target_flow.expand_as(cluster_flows)) + all_avg_losses.append( + torch.linalg.vector_norm(cluster_flows - target_flow, dim=-1).mean() + ) + + # FIXME(Qingwen): maybe afterward we can have weight here to specific different weight on point/cluster etc. + if not all_cluster_flows: + return torch.tensor(0.0, device=flow_list[0].device) + # Eq. 9 with two terms + # NOTE(Qingwen): Point-level term + loss = torch.nn.functional.mse_loss( + torch.cat(all_cluster_flows), torch.cat(all_target_flows) + ) + # NOTE(Qingwen): Cluster-level term + loss += torch.stack(all_avg_losses).mean() + return loss + + +# ---- shared cluster loop (seflow / seflowpp) ------------------- +# SeFlow Paper: https://arxiv.org/pdf/2407.01702 +def _seflow_cluster_loop(pc0_list, pc1_list, pc0_lab_list, pc1_lab_list, + flow_list, dist0_list, idx0_list): + """Per-sample seflow cluster loss (Eq. 6-11). + + dist0_list, idx0_list : output of batched_disid_res(pc0_list, pc1_list) + idx0_list[i] is LOCAL into pc1_list[i]. + Returns (static_cluster_loss, moved_cluster_loss, have_any_dynamic). + """ + dev = flow_list[0].device + static_loss = torch.tensor(0.0, device=dev) + cluster_norms = [] + fallback_dists = [] + have_any_dyn = False + + for p0, p1, lab0, lab1, fv, dist0, idx0 in zip( + pc0_list, pc1_list, pc0_lab_list, pc1_lab_list, + flow_list, dist0_list, idx0_list): + have_dyn = (lab0 > 0).sum() > 256 and (lab1 > 0).sum() > 256 + if have_dyn: + have_any_dyn = True + fallback_dists.append(dist0) + + for label in torch.unique(lab0): + mask = (lab0 == label) + if label == 0: + # Eq. 
6 in the paper + static_loss += torch.linalg.vector_norm(fv[mask], dim=-1).mean() + elif label > 1 and have_dyn: + c_flow = fv[mask] + c_idx0 = idx0[mask] + # Eq. 8 in the paper + sorted_local = torch.argsort(dist0[mask], descending=True) + max_idx = torch.nonzero(lab1[c_idx0[sorted_local]] > 0).squeeze(1) + if max_idx.shape[0] == 0: + continue + best = sorted_local[max_idx[0]] + # Eq. 9 in the paper + max_flow = p1[c_idx0[best]] - p0[mask][best] + # Eq. 10 in the paper + cluster_norms.append(torch.linalg.vector_norm(c_flow - max_flow, dim=-1)) + + if cluster_norms: + # Eq. 11 + moved_loss = torch.cat(cluster_norms).mean() + elif have_any_dyn: + all_d = torch.cat(fallback_dists) + moved_loss = torch.mean(all_d[all_d <= TRUNCATED_DIST]) + else: + moved_loss = torch.tensor(0.0, device=dev) + + return static_loss, moved_loss + + +def teflowLoss(res_dict, timer=None): + """Temporal seflow: chamfer over all frames + static + RANSAC cluster loss.""" + pc0_list = res_dict['pc0_list'] + flow_list = res_dict['est_flow_list'] + pc0_lab_list = res_dict['pc0_labels_list'] + + chamfer_dis, dynamic_chamfer_dis, frame_keys = batched_chamfer_related(res_dict, timer) + + static_loss = torch.tensor(0.0, device=pc0_list[0].device) + for fv, lab in zip(flow_list, pc0_lab_list): + if (lab == 0).any(): + static_loss += torch.linalg.vector_norm(fv[lab == 0], dim=-1).mean() + static_loss /= max(len(pc0_list), 1) + + cluster_weight = res_dict['loss_weights_dict'].get('cluster_based_pc0pc1', 0.0) + if cluster_weight > 0: + frames_dists, frames_indices = {}, {} + for frame_id in frame_keys: + d_list, i_list = MyCUDAChamferDis.batched_disid_res( + pc0_list, res_dict[f'{frame_id}_list'], + ) + frames_dists[frame_id] = d_list + frames_indices[frame_id] = i_list + + moved_cluster_loss = multi_frames_clusterLoss( + pc0_list, pc0_lab_list, flow_list, + frame_keys, frames_dists, frames_indices, res_dict, + res_dict.get('cluster_loss_args', {}), + ) + else: + moved_cluster_loss = torch.tensor(0.0, 
device=pc0_list[0].device) + + return { + 'chamfer_dis': chamfer_dis, + 'dynamic_chamfer_dis': dynamic_chamfer_dis, + 'static_flow_loss': static_loss, 'cluster_based_pc0pc1': moved_cluster_loss, } - return res_loss -def seflowLoss(res_dict, timer=None): - pc0_label = res_dict['pc0_labels'] - pc1_label = res_dict['pc1_labels'] - - pc0 = res_dict['pc0'] - pc1 = res_dict['pc1'] - - est_flow = res_dict['est_flow'] - - pseudo_pc1from0 = pc0 + est_flow - - unique_labels = torch.unique(pc0_label) - pc0_dynamic = pc0[pc0_label > 0] - pc1_dynamic = pc1[pc1_label > 0] - # fpc1_dynamic = pseudo_pc1from0[pc0_label > 0] - # NOTE(Qingwen): since we set THREADS_PER_BLOCK is 256 - have_dynamic_cluster = (pc0_dynamic.shape[0] > 256) & (pc1_dynamic.shape[0] > 256) - - # first item loss: chamfer distance - # timer[5][1].start("MyCUDAChamferDis") - # raw: pc0 to pc1, est: pseudo_pc1from0 to pc1, idx means the nearest index - est_dist0, est_dist1, _, _ = MyCUDAChamferDis.disid_res(pseudo_pc1from0, pc1) - raw_dist0, raw_dist1, raw_idx0, _ = MyCUDAChamferDis.disid_res(pc0, pc1) - chamfer_dis = torch.mean(est_dist0[est_dist0 <= TRUNCATED_DIST]) + torch.mean(est_dist1[est_dist1 <= TRUNCATED_DIST]) - # timer[5][1].stop() - - # second item loss: dynamic chamfer distance - # timer[5][2].start("DynamicChamferDistance") - dynamic_chamfer_dis = torch.tensor(0.0, device=est_flow.device) - if have_dynamic_cluster: - dynamic_chamfer_dis += MyCUDAChamferDis(pseudo_pc1from0[pc0_label>0], pc1_dynamic, truncate_dist=TRUNCATED_DIST) - # timer[5][2].stop() - - # third item loss: exclude static points' flow - # NOTE(Qingwen): add in the later part on label==0 - static_cluster_loss = torch.tensor(0.0, device=est_flow.device) - - # fourth item loss: same label points' flow should be the same - # timer[5][3].start("SameClusterLoss") - moved_cluster_loss = torch.tensor(0.0, device=est_flow.device) - moved_cluster_norms = torch.tensor([], device=est_flow.device) - for label in unique_labels: - mask = pc0_label 
== label - if label == 0: - # Eq. 6 in the paper - static_cluster_loss += torch.linalg.vector_norm(est_flow[mask, :], dim=-1).mean() - # NOTE(Qingwen) 2025-04-23: label=1 is dynamic but no cluster id satisfied - elif label > 1 and have_dynamic_cluster: - cluster_id_flow = est_flow[mask, :] - cluster_nnd = raw_dist0[mask] - if cluster_nnd.shape[0] <= 0: - continue +# from paper: https://arxiv.org/abs/2503.00803 +def seflowppLoss(res_dict, timer=None): + """seflow++ loss: bidirectional (pc1 + pch1) chamfer + cluster, B samples.""" + pc0_list = res_dict['pc0_list'] + pc1_list = res_dict['pc1_list'] + pch1_list = res_dict['pch1_list'] + flow_list = res_dict['est_flow_list'] + pc0_lab_list = res_dict['pc0_labels_list'] + pc1_lab_list = res_dict['pc1_labels_list'] + pch1_lab_list = res_dict['pch1_labels_list'] + dev = pc0_list[0].device - # Eq. 8 in the paper - sorted_idxs = torch.argsort(cluster_nnd, descending=True) - nearby_label = pc1_label[raw_idx0[mask][sorted_idxs]] # nonzero means dynamic in label - non_zero_valid_indices = torch.nonzero(nearby_label > 0) - if non_zero_valid_indices.shape[0] <= 0: - continue - max_idx = sorted_idxs[non_zero_valid_indices.squeeze(1)[0]] - - # Eq. 9 in the paper - max_flow = pc1[raw_idx0[mask][max_idx]] - pc0[mask][max_idx] - - # Eq. 10 in the paper - moved_cluster_norms = torch.cat((moved_cluster_norms, torch.linalg.vector_norm((cluster_id_flow - max_flow), dim=-1))) - - if moved_cluster_norms.shape[0] > 0: - moved_cluster_loss = moved_cluster_norms.mean() # Eq. 
11 in the paper - elif have_dynamic_cluster: - moved_cluster_loss = torch.mean(raw_dist0[raw_dist0 <= TRUNCATED_DIST]) + torch.mean(raw_dist1[raw_dist1 <= TRUNCATED_DIST]) - # timer[5][3].stop() - - res_loss = { - 'chamfer_dis': chamfer_dis, - 'dynamic_chamfer_dis': dynamic_chamfer_dis, - 'static_flow_loss': static_cluster_loss, + fwd_list = [p0 + fv for p0, fv in zip(pc0_list, flow_list)] + bwd_list = [p0 - fv for p0, fv in zip(pc0_list, flow_list)] + + # Chamfer: both temporal directions concurrently + chamfer_dis = MyCUDAChamferDis.batched(fwd_list, pc1_list, truncate_dist=TRUNCATED_DIST) + chamfer_dis += MyCUDAChamferDis.batched(bwd_list, pch1_list, truncate_dist=TRUNCATED_DIST) + + # Dynamic chamfer + dyn_fwd, dyn_pc1 = [], [] + dyn_bwd, dyn_pch1 = [], [] + for fwd_i, bwd_i, p1_i, ph1_i, lab0_i, lab1_i, labh1_i in zip( + fwd_list, bwd_list, pc1_list, pch1_list, + pc0_lab_list, pc1_lab_list, pch1_lab_list): + dyn_mask = lab0_i > 0 + if dyn_mask.sum() > 256: + dp1 = p1_i[lab1_i > 0] + dph = ph1_i[labh1_i > 0] + if dp1.shape[0] > 256: dyn_fwd.append(fwd_i[dyn_mask]); dyn_pc1.append(dp1) + if dph.shape[0] > 256: dyn_bwd.append(bwd_i[dyn_mask]); dyn_pch1.append(dph) + + dynamic_chamfer_dis = torch.tensor(0.0, device=dev) + if len(dyn_fwd) == 1: + dynamic_chamfer_dis += MyCUDAChamferDis(dyn_fwd[0], dyn_pc1[0], truncate_dist=TRUNCATED_DIST) + elif len(dyn_fwd) > 1: + dynamic_chamfer_dis += MyCUDAChamferDis.batched(dyn_fwd, dyn_pc1, truncate_dist=TRUNCATED_DIST) + if len(dyn_bwd) == 1: + dynamic_chamfer_dis += MyCUDAChamferDis(dyn_bwd[0], dyn_pch1[0], truncate_dist=TRUNCATED_DIST) + elif len(dyn_bwd) > 1: + dynamic_chamfer_dis += MyCUDAChamferDis.batched(dyn_bwd, dyn_pch1, truncate_dist=TRUNCATED_DIST) + + dist0_list, idx0_list = MyCUDAChamferDis.batched_disid_res(pc0_list, pc1_list) + static_loss, moved_cluster_loss = _seflow_cluster_loop( + pc0_list, pc1_list, pc0_lab_list, pc1_lab_list, + flow_list, dist0_list, idx0_list, + ) + + return { + 'chamfer_dis': 
chamfer_dis / 2.0, + 'dynamic_chamfer_dis': dynamic_chamfer_dis / 2.0, + 'static_flow_loss': static_loss, 'cluster_based_pc0pc1': moved_cluster_loss, } - return res_loss + +# from paper: https://arxiv.org/abs/2407.01702 +def seflowLoss(res_dict, timer=None): + """seflow loss: single future frame (pc1), batched over B samples.""" + pc0_list = res_dict['pc0_list'] + pc1_list = res_dict['pc1_list'] + flow_list = res_dict['est_flow_list'] + pc0_lab_list = res_dict['pc0_labels_list'] + pc1_lab_list = res_dict['pc1_labels_list'] + dev = pc0_list[0].device + + fwd_list = [p0 + fv for p0, fv in zip(pc0_list, flow_list)] + + chamfer_dis = MyCUDAChamferDis.batched(fwd_list, pc1_list, truncate_dist=TRUNCATED_DIST) + + # Dynamic chamfer + dyn_fwd, dyn_pc1 = [], [] + for fwd_i, p1_i, lab0_i, lab1_i in zip(fwd_list, pc1_list, pc0_lab_list, pc1_lab_list): + dp1 = p1_i[lab1_i > 0] + if (lab0_i > 0).sum() > 256 and dp1.shape[0] > 256: + dyn_fwd.append(fwd_i[lab0_i > 0]) + dyn_pc1.append(dp1) + + dynamic_chamfer_dis = torch.tensor(0.0, device=dev) + if len(dyn_fwd) == 1: + dynamic_chamfer_dis = MyCUDAChamferDis(dyn_fwd[0], dyn_pc1[0], truncate_dist=TRUNCATED_DIST) + elif len(dyn_fwd) > 1: + dynamic_chamfer_dis = MyCUDAChamferDis.batched(dyn_fwd, dyn_pc1, truncate_dist=TRUNCATED_DIST) + + dist0_list, idx0_list = MyCUDAChamferDis.batched_disid_res(pc0_list, pc1_list) + static_loss, moved_cluster_loss = _seflow_cluster_loop( + pc0_list, pc1_list, pc0_lab_list, pc1_lab_list, + flow_list, dist0_list, idx0_list, + ) + + return { + 'chamfer_dis': chamfer_dis, + 'dynamic_chamfer_dis': dynamic_chamfer_dis, + 'static_flow_loss': static_loss, + 'cluster_based_pc0pc1': moved_cluster_loss, + } \ No newline at end of file diff --git a/src/trainer.py b/src/trainer.py index de064f1..ffeb365 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -19,12 +19,13 @@ from lightning import LightningModule from hydra.utils import instantiate -from omegaconf import OmegaConf,open_dict +from omegaconf import 
OmegaConf, open_dict -import os, sys, time, h5py, pickle +import os, sys, time, h5py BASE_DIR = os.path.abspath(os.path.join( os.path.dirname( __file__ ), '..' )) sys.path.append(BASE_DIR) from src.utils import import_func +from src.lossfuncs import SSL_LOSSES_FN from src.utils.mics import weights_init, zip_res from src.utils.av2_eval import write_output_file from src.models.basic import cal_pose0to1, WarmupCosLR @@ -51,9 +52,12 @@ def __init__(self, cfg, eval=False): "save_res": False, "res_name": "default", "num_frames": 2, + + # lr scheduler, only active when warmup_epochs > 0 "optimizer": None, "dataset_path": None, "data_mode": None, + "cluster_loss_args": {}, } for key, default in default_self_values.items(): setattr(self, key, cfg.get(key, default)) @@ -92,74 +96,124 @@ def __init__(self, cfg, eval=False): self.save_res_path = Path(cfg.dataset_path).parent / "results" / cfg.output os.makedirs(self.save_res_path, exist_ok=True) print(f"We are in {cfg.data_mode}, results will be saved in: {self.save_res_path} with version: {self.leaderboard_version} format for online leaderboard.") - - # self.test_total_num = 0 if self.data_mode in ['val', 'valid', 'test']: print(cfg) + # self.test_total_num = 0 self.save_hyperparameters() + + def ssl_loss_calculator(self, batch, res_dict, if_log=True): + """Build dict2loss for ALL self-supervised losses (seflow, seflowpp, teflow*). - # FIXME(Qingwen 2025-08-20): update the loss_calculation fn alone to make all things pretty here.... - def training_step(self, batch, batch_idx): - self.model.timer[4].start("One Scan in model") - res_dict = self.model(batch) - self.model.timer[4].stop() + Each frame is represented only as a List[Tensor] and a List[labels]. + No flat tensors, no offsets, no sizes — chamfer calls use list APIs only. 
+ """ + total_loss, bz_ = 0.0, len(batch["pose0"]) - self.model.timer[5].start("Loss") - # compute loss - total_loss = 0.0 + pc0_list = [res_dict['pc0_points_lst'][i] for i in range(bz_)] - if self.cfg_loss_name in ['seflowLoss', 'seflowppLoss']: - loss_items, weights = zip(*[(key, weight) for key, weight in self.add_seloss.items()]) - loss_logger = {'chamfer_dis': 0.0, 'dynamic_chamfer_dis': 0.0, 'static_flow_loss': 0.0, 'cluster_based_pc0pc1': 0.0} - else: - loss_items, weights = ['loss'], [1.0] - loss_logger = {'loss': 0.0} - - pc0_valid_idx = res_dict['pc0_valid_point_idxes'] # since padding - pc1_valid_idx = res_dict['pc1_valid_point_idxes'] # since padding - if 'pc0_points_lst' in res_dict and 'pc1_points_lst' in res_dict: - pc0_points_lst = res_dict['pc0_points_lst'] - pc1_points_lst = res_dict['pc1_points_lst'] + dict2loss = { + 'pc0_list': pc0_list, + 'est_flow_list': [res_dict['flow'][i] for i in range(bz_)], + 'pc0_labels_list': [batch['pc0_dynamic'][i][res_dict['pc0_valid_point_idxes'][i]] for i in range(bz_)], + 'batch_size': bz_, + } + + frame_keys = [key.replace('_points_lst', '') for key in res_dict.keys() + if key.startswith('pc') and key.endswith('_points_lst')] + frame_keys.remove('pc0') + + for frame_id in frame_keys: + points_list = [res_dict[f'{frame_id}_points_lst'][i] for i in range(bz_)] + labels_list = [batch[f'{frame_id}_dynamic'][i][res_dict[f'{frame_id}_valid_point_idxes'][i]] for i in range(bz_)] + dict2loss[f'{frame_id}_list'] = points_list + dict2loss[f'{frame_id}_labels_list'] = labels_list + + loss_items, weights = zip(*[(key, weight) for key, weight in self.add_seloss.items()]) + dict2loss['loss_weights_dict'] = self.add_seloss + + dict2loss['cluster_loss_args'] = self.cluster_loss_args + + res_loss = self.loss_fn(dict2loss) + + for i, loss_name in enumerate(loss_items): + if not torch.isnan(res_loss[loss_name]): + total_loss += weights[i] * res_loss[loss_name] - batch_sizes = len(batch["pose0"]) - pose_flows = 
res_dict['pose_flow'] - est_flow = res_dict['flow'] + if if_log: + self.log("trainer/loss", total_loss, sync_dist=True, batch_size=bz_, prog_bar=True) + for key in res_loss: + self.log(f"trainer/{key}", res_loss[key], sync_dist=True, batch_size=bz_) + + return total_loss + + def loss_calculator(self, batch, res_dict, if_log=True): + """ Calculate the loss based on the batch (gt/ssl-label) and res_dict (estimate flow).""" + def get_batch_data(batch, key, batch_id, batch_sizes, pc0_valid_from_pc2res, pose_flow_=None): + """NOTE(Qingwen): for gt need double check whether it exists in the batch and batch size is correct""" + if key not in batch or batch[key].shape[0] != batch_sizes: + return None + data = batch[key][batch_id][pc0_valid_from_pc2res] + if key == 'flow' and pose_flow_ is not None: + data = data - pose_flow_ + return data + def get_frame_keys(data_dict, suffix): + return [key for key in data_dict.keys() if key.endswith(suffix)] + def extract_frame_id(key, suffix): + """Extract frame identifier from key (e.g., 'pc0_points_lst' -> 'pc0')""" + return key.replace(suffix, '') + # Supervised-only path (deflowLoss, etc.) + # SSL losses are handled by ssl_loss_calculator. 
+ total_loss, loss_logger = 0.0, {} + loss_items, weights = ['loss'], [1.0] + for key in loss_items: + loss_logger[key] = 0.0 + + batch_sizes, pose_flows, est_flow = len(batch["pose0"]), res_dict['pose_flow'], res_dict['flow'] for batch_id in range(batch_sizes): - pc0_valid_from_pc2res = pc0_valid_idx[batch_id] - pc1_valid_from_pc2res = pc1_valid_idx[batch_id] + # Get pc0 valid indices (main reference frame) + pc0_valid_from_pc2res = res_dict['pc0_valid_point_idxes'][batch_id] pose_flow_ = pose_flows[batch_id][pc0_valid_from_pc2res] dict2loss = {'est_flow': est_flow[batch_id], - 'gt_flow': None if 'flow' not in batch else batch['flow'][batch_id][pc0_valid_from_pc2res] - pose_flow_, - 'gt_classes': None if 'flow_category_indices' not in batch else batch['flow_category_indices'][batch_id][pc0_valid_from_pc2res], - 'gt_instance': None if 'flow_instance_id' not in batch else batch['flow_instance_id'][batch_id][pc0_valid_from_pc2res],} + 'gt_flow': get_batch_data(batch, 'flow', batch_id, batch_sizes, pc0_valid_from_pc2res, pose_flow_), + 'gt_classes': get_batch_data(batch, 'flow_category_indices', batch_id, batch_sizes, pc0_valid_from_pc2res), + 'gt_instance': get_batch_data(batch, 'flow_instance_id', batch_id, batch_sizes, pc0_valid_from_pc2res)} - if 'pc0_dynamic' in batch: - dict2loss['pc0_labels'] = batch['pc0_dynamic'][batch_id][pc0_valid_from_pc2res] - dict2loss['pc1_labels'] = batch['pc1_dynamic'][batch_id][pc1_valid_from_pc2res] - if 'pch1_dynamic' in batch and 'pch1_valid_point_idxes' in res_dict: - dict2loss['pch1_labels'] = batch['pch1_dynamic'][batch_id][res_dict['pch1_valid_point_idxes'][batch_id]] - - # different methods may don't have this in the res_dict - if 'pc0_points_lst' in res_dict and 'pc1_points_lst' in res_dict: - dict2loss['pc0'] = pc0_points_lst[batch_id] - dict2loss['pc1'] = pc1_points_lst[batch_id] - if 'pch1_points_lst' in res_dict: - dict2loss['pch1'] = res_dict['pch1_points_lst'][batch_id] + # Add all available point cloud frames + for 
points_key in get_frame_keys(res_dict, '_points_lst'): + frame_id = extract_frame_id(points_key, '_points_lst') + if points_key in res_dict: + dict2loss[frame_id] = res_dict[points_key][batch_id] res_loss = self.loss_fn(dict2loss) + for i, loss_name in enumerate(loss_items): + # if torch.isnan(res_loss[loss_name]): + # print(f"==> Loss: {loss_name} is nan, skip this batch.") + # continue total_loss += weights[i] * res_loss[loss_name] for key in res_loss: loss_logger[key] += res_loss[key] + if if_log: + self.log("trainer/loss", total_loss/batch_sizes, sync_dist=True, batch_size=self.batch_size, prog_bar=True) + return total_loss + + def training_step(self, batch, batch_idx): + total_loss = 0.0 + self.model.timer[5].start("Training Step") + self.model.timer[5][0].start("Forward") + res_dict = self.model(batch) + self.model.timer[5][0].stop() + self.model.timer[5][1].start("Compute Loss") - self.log("trainer/loss", total_loss/batch_sizes, sync_dist=True, batch_size=self.batch_size, prog_bar=True) - if self.add_seloss is not None and self.cfg_loss_name in ['seflowLoss', 'seflowppLoss']: - for key in loss_logger: - self.log(f"trainer/{key}", loss_logger[key]/batch_sizes, sync_dist=True, batch_size=self.batch_size) + if self.cfg_loss_name in SSL_LOSSES_FN: + total_loss = self.ssl_loss_calculator(batch, res_dict) + else: + total_loss = self.loss_calculator(batch, res_dict) + self.model.timer[5][1].stop() self.model.timer[5].stop() - + # NOTE (Qingwen): if you want to view the detail breakdown of time cost # self.model.timer.print(random_colors=False, bold=False) return total_loss @@ -206,6 +260,8 @@ def on_train_epoch_start(self): def on_train_epoch_end(self): self.log("pre_epoch_cost (mins)", (time.time()-self.time_start_train_epoch)/60.0, on_step=False, on_epoch=True, sync_dist=True) + # # NOTE (Qingwen): if you want to view the detail breakdown of time cost + # self.model.timer.print(random_colors=False, bold=False) def on_validation_epoch_end(self): 
self.model.timer.print(random_colors=False, bold=False) @@ -223,9 +279,9 @@ def on_validation_epoch_end(self): # wandb.log_artifact(output_file) return - if self.data_mode == 'val': + if self.data_mode in ['val', 'valid']: print(f"\nModel: {self.model.__class__.__name__}, Checkpoint from: {self.checkpoint}") - print(f"More details parameters and training status are in the checkpoint file.") + print(f"More details parameters and training status are in checkpoints file.") self.metrics.normalize() @@ -238,15 +294,12 @@ def on_validation_epoch_end(self): self.metrics.print() + self.metrics = OfficialMetrics() + if self.save_res: - # Save the dictionaries to a pickle file - with open(str(self.save_res_path)+'.pkl', 'wb') as f: - pickle.dump((self.metrics.epe_3way, self.metrics.bucketed, self.metrics.epe_ssf), f) - print(f"We already write the {self.res_name} into the dataset, please run following commend to visualize the flow. Copy and paste it to your terminal:") - print(f"python tools/visualization.py vis --res_name '{self.res_name}' --data_dir {self.dataset_path}") + print(f"We already write the flow_est into the dataset, please run following commend to visualize the flow. Copy and paste it to your terminal:") + print(f"python tools/visualization.py --res_name \"['{self.res_name}']\" --data_dir {self.dataset_path}") print(f"Enjoy! 
^v^ ------ \n") - - self.metrics = OfficialMetrics() def eval_only_step_(self, batch, res_dict): eval_mask = batch['eval_mask'].squeeze() @@ -261,24 +314,35 @@ def eval_only_step_(self, batch, res_dict): # flow in the original pc0 coordinate pred_flow = pose_flow[~batch['gm0']].clone() + # debug: for ego-motion flow only + # res_dict['flow'] = torch.zeros_like(res_dict['flow']) pred_flow[valid_from_pc2res] = res_dict['flow'] + pose_flow[~batch['gm0']][valid_from_pc2res] final_flow[~batch['gm0']] = pred_flow else: final_flow[~batch['gm0']] = res_dict['flow'] + pose_flow[~batch['gm0']] - if self.data_mode == 'val': # since only val we have ground truth flow to eval + if self.data_mode in ['val', 'valid']: # since only val we have ground truth flow to eval gt_flow = batch["flow"] v1_dict = evaluate_leaderboard(final_flow[eval_mask], pose_flow[eval_mask], pc0[eval_mask], \ gt_flow[eval_mask], batch['flow_is_valid'][eval_mask], \ batch['flow_category_indices'][eval_mask]) v2_dict = evaluate_leaderboard_v2(final_flow[eval_mask], pose_flow[eval_mask], pc0[eval_mask], \ gt_flow[eval_mask], batch['flow_is_valid'][eval_mask], batch['flow_category_indices'][eval_mask]) - ssf_dict = evaluate_ssf(final_flow, pose_flow, pc0, \ - gt_flow, batch['flow_is_valid'], batch['flow_category_indices']) + ssf_dict = evaluate_ssf(final_flow[eval_mask], pose_flow[eval_mask], pc0[eval_mask], \ + gt_flow[eval_mask], batch['flow_is_valid'][eval_mask], batch['flow_category_indices'][eval_mask]) + self.metrics.step(v1_dict, v2_dict, ssf_dict) - + if self.save_res: + # write final_flow into the dataset. 
+ key = str(batch['timestamp']) + scene_id = batch['scene_id'] + with h5py.File(os.path.join(self.dataset_path, f'{self.data_mode}/{scene_id}.h5'), 'r+') as f: + if self.res_name in f[key]: + del f[key][self.res_name] + f[key].create_dataset(self.res_name, data=final_flow.cpu().detach().numpy().astype(np.float32)) + # NOTE (Qingwen): Since val and test, we will force set batch_size = 1 - if self.save_res or self.data_mode == 'test': # test must save data to submit in the online leaderboard. + if self.save_res and self.data_mode == 'test': # test must save data to submit in the online leaderboard. save_pred_flow = final_flow[eval_mask, :3].cpu().detach().numpy() rigid_flow = pose_flow[eval_mask, :3].cpu().detach().numpy() is_dynamic = np.linalg.norm(save_pred_flow - rigid_flow, axis=1, ord=2) >= 0.05 @@ -302,19 +366,22 @@ def run_model_wo_ground_data(self, batch): # NOTE (Qingwen): Since val and test, we will force set batch_size = 1 batch = {key: batch[key][0] for key in batch if len(batch[key])>0} - res_dict = {key: res_dict[key][0] for key in res_dict if res_dict[key]!=None and len(res_dict[key])>0} + res_dict = {key: res_dict[key][0] for key in res_dict if (res_dict[key]!=None and len(res_dict[key])>0) } return batch, res_dict def validation_step(self, batch, batch_idx): - if self.data_mode in ['val', 'test']: - batch, res_dict = self.run_model_wo_ground_data(batch) - self.model.timer[13].start("Eval") - self.eval_only_step_(batch, res_dict) - self.model.timer[13].stop() - else: - res_dict = self.model(batch) - self.train_validation_step_(batch, res_dict) - + try: + if self.data_mode in ['val', 'valid'] or self.data_mode == 'test': + batch, res_dict = self.run_model_wo_ground_data(batch) + if batch['eval_flag']: + self.eval_only_step_(batch, res_dict) + else: + res_dict = self.model(batch) + self.train_validation_step_(batch, res_dict) + except Exception as e: + print(f"==> Exception occur during training/validation step: {e}. 
Skip this batch.") + print(f"Batch info: scene_id: {batch['scene_id']}, timestamp: {batch['timestamp']}, pc0 size: {batch['pc0']}") + def test_step(self, batch, batch_idx): batch, res_dict = self.run_model_wo_ground_data(batch) pc0 = batch['origin_pc0'] @@ -346,5 +413,5 @@ def on_test_epoch_end(self): self.model.timer.print(random_colors=False, bold=False) print(f"\n\nModel: {self.model.__class__.__name__}, Checkpoint from: {self.checkpoint}") print(f"We already write the flow_est into the dataset, please run following commend to visualize the flow. Copy and paste it to your terminal:") - print(f"python tools/visualization.py --res_name '{self.res_name}' --data_dir {self.dataset_path}") + print(f"python tools/visualization.py --res_name \"['{self.res_name}']\" --data_dir {self.dataset_path}") print(f"Enjoy! ^v^ ------ \n") From 52a88441940625fd006d4305f912bce3368c6ff6 Mon Sep 17 00:00:00 2001 From: Kin Date: Wed, 11 Mar 2026 17:55:15 +0100 Subject: [PATCH 5/6] docs(apptainer): update apptainer env for diff cluster env. update slurm and command for teflow --- README.md | 4 ++-- assets/README.md | 31 ++++++++++++++++++++++++++++--- assets/opensf.def | 14 ++++++++++++++ assets/slurm/1_train.sh | 34 ---------------------------------- assets/slurm/2_eval.sh | 20 -------------------- assets/slurm/train-teflow.sh | 25 +++++++++++++++++++++++++ 6 files changed, 69 insertions(+), 59 deletions(-) create mode 100644 assets/opensf.def delete mode 100644 assets/slurm/1_train.sh delete mode 100644 assets/slurm/2_eval.sh create mode 100644 assets/slurm/train-teflow.sh diff --git a/README.md b/README.md index 983a28e..4f2684e 100644 --- a/README.md +++ b/README.md @@ -212,10 +212,10 @@ Train Feed-forward SSL methods (e.g. SeFlow/SeFlow++/VoteFlow etc), we needed to ```bash # [Runtime: Around ? hours in 10x GPUs.] 
-python train.py model=deltaflow epochs=21 batch_size=2 num_frames=5 train_aug=True \ +python train.py model=deltaflow epochs=15 batch_size=2 num_frames=5 train_aug=True \ loss_fn=teflowLoss "voxel_size=[0.15, 0.15, 0.15]" "point_cloud_range=[-38.4, -38.4, -3, 38.4, 38.4, 3]" \ +ssl_label=seflow_auto "+add_seloss={chamfer_dis: 1.0, static_flow_loss: 1.0, dynamic_chamfer_dis: 1.0, cluster_based_pc0pc1: 1.0}" \ - optimizer.name=Adam optimizer.lr=1e-4 +optimizer.scheduler.name=StepLR +optimizer.scheduler.step_size=9 +optimizer.scheduler.gamma=0.5 + optimizer.name=Adam optimizer.lr=2e-3 +optimizer.scheduler.name=StepLR +optimizer.scheduler.step_size=9 +optimizer.scheduler.gamma=0.5 # Pretrained weight can be downloaded through: wget https://huggingface.co/kin-zhang/OpenSceneFlow/resolve/main/teflow/teflow-av2.ckpt diff --git a/assets/README.md b/assets/README.md index d49ebe8..f14b78c 100644 --- a/assets/README.md +++ b/assets/README.md @@ -51,7 +51,30 @@ Then follow [this stackoverflow answers](https://stackoverflow.com/questions/596 ```bash cd OpenSceneFlow && docker build -f Dockerfile -t zhangkin/opensf . ``` - + +### To Apptainer container + +If you want to build a **minimal** training env for Apptainer container, you can use the following command: +```bash +apptainer build opensf.sif assets/opensf.def +# zhangkin/opensf:full is created by Dockerfile +``` + +Then run as a Python env with: +```bash +PYTHON="apptainer run --nv --writable-tmpfs opensf.sif" +$PYTHON train.py +``` + + + + ## Installation We will use conda to manage the environment with mamba for faster package installation. 
@@ -77,10 +100,11 @@ Checking important packages in our environment now: ```bash mamba activate opensf python -c "import torch; print(torch.__version__); print(torch.cuda.is_available()); print(torch.version.cuda)" -python -c "import lightning.pytorch as pl; print(pl.__version__)" +python -c "import lightning.pytorch as pl; print('pl version:', pl.__version__)" +python -c "import spconv.pytorch as spconv; print('spconv import successfully')" python -c "from assets.cuda.mmcv import Voxelization, DynamicScatter;print('successfully import on our lite mmcv package')" python -c "from assets.cuda.chamfer3D import nnChamferDis;print('successfully import on our chamfer3D package')" -python -c "from av2.utils.io import read_feather; print('av2 package ok')" +python -c "from av2.utils.io import read_feather; print('av2 package ok') " ``` @@ -98,6 +122,7 @@ python -c "from av2.utils.io import read_feather; print('av2 package ok')" 2. In cluster have error: `pandas ImportError: /lib64/libstdc++.so.6: version 'GLIBCXX_3.4.29' not found` Solved by `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/proj/berzelius-2023-154/users/x_qinzh/mambaforge/lib` +4. The nvidia channel cannot be put into the env.yaml file, otherwise the cuda-toolkit will always resolve to the latest version — for me (2025-04-30) I struggled for an hour and `nvcc -V` still reported 12.8. Use py=3.10 for cuda >=12.1 (it seems nvidia cannot appear in the channel list); use py<3.10 for cuda <=11.8.0, otherwise 10x/20x series GPUs won't work with the cuda compiler (half precision). 3. 
torch_scatter problem: `OSError: /home/kin/mambaforge/envs/opensf-v2/lib/python3.10/site-packages/torch_scatter/_version_cpu.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE` Solved by install the torch-cuda version: `pip install https://data.pyg.org/whl/torch-2.0.0%2Bcu118/torch_scatter-2.1.2%2Bpt20cu118-cp310-cp310-linux_x86_64.whl` diff --git a/assets/opensf.def b/assets/opensf.def new file mode 100644 index 0000000..bddebf4 --- /dev/null +++ b/assets/opensf.def @@ -0,0 +1,14 @@ +Bootstrap: docker +From: zhangkin/opensf:full + +%files + assets/cuda /workspace/assets/cuda + src/models/basic/voteflow_plugin /workspace/src/models/basic/voteflow_plugin + environment.yaml /workspace/environment.yaml + +%runscript + echo "Running pip install for local CUDA modules..." + /opt/conda/envs/opensf/bin/pip install /workspace/assets/cuda/chamfer3D + /opt/conda/envs/opensf/bin/pip install /workspace/assets/cuda/mmcv + /opt/conda/envs/opensf/bin/pip install /workspace/src/models/basic/voteflow_plugin/hough_transformation/cpp_im2ht + exec /opt/conda/envs/opensf/bin/python "$@" \ No newline at end of file diff --git a/assets/slurm/1_train.sh b/assets/slurm/1_train.sh deleted file mode 100644 index dd99a8d..0000000 --- a/assets/slurm/1_train.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash -#SBATCH -J seflow -#SBATCH --gpus 4 -C "fat" -#SBATCH -t 3-00:00:00 -#SBATCH --mail-type=END,FAIL -#SBATCH --mail-user=qingwen@kth.se -#SBATCH --output /proj/berzelius-2023-154/users/x_qinzh/seflow/logs/slurm/%J_seflow.out -#SBATCH --error /proj/berzelius-2023-154/users/x_qinzh/seflow/logs/slurm/%J_seflow.err - -PYTHON=/proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/opensf/bin/python -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/proj/berzelius-2023-154/users/x_qinzh/mambaforge/lib -cd /proj/berzelius-2023-364/users/x_qinzh/workspace/OpenSceneFlow - - -# ===> to transfer data into local node disk, it can be ignored. 
<=== -SOURCE="/proj/berzelius-2023-364/users/x_qinzh/data/av2/autolabel" -DEST="/scratch/local/av2" -SUBDIRS=("sensor/train" "sensor/val") - -start_time=$(date +%s) -for dir in "${SUBDIRS[@]}"; do - mkdir -p "${DEST}/${dir}" - find "${SOURCE}/${dir}" -type f -print0 | xargs -0 -n1 -P16 cp -t "${DEST}/${dir}" & -done -wait -end_time=$(date +%s) -elapsed=$((end_time - start_time)) -echo "Copy ${SOURCE} to ${DEST} Total time: ${elapsed} seconds" -echo "Start training..." - -# ====> leaderboard model = seflow_best -$PYTHON train.py slurm_id=$SLURM_JOB_ID wandb_mode=online train_data=/scratch/local/av2/sensor/train val_data=/scratch/local/av2/sensor/val \ - num_workers=16 model=deflow lr=2e-4 epochs=9 batch_size=16 "model.target.num_iters=2" "model.val_monitor=val/Dynamic/Mean" \ - loss_fn=seflowLoss "add_seloss={chamfer_dis: 1.0, static_flow_loss: 1.0, dynamic_chamfer_dis: 1.0, cluster_based_pc0pc1: 1.0}" diff --git a/assets/slurm/2_eval.sh b/assets/slurm/2_eval.sh deleted file mode 100644 index 1a57440..0000000 --- a/assets/slurm/2_eval.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -#SBATCH -J eval -#SBATCH --gpus 1 -#SBATCH -t 01:00:00 -#SBATCH --output /proj/berzelius-2023-154/users/x_qinzh/seflow/logs/slurm/%J_eval.out -#SBATCH --error /proj/berzelius-2023-154/users/x_qinzh/seflow/logs/slurm/%J_eval.err - - -PYTHON=/proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/opensf/bin/python -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/proj/berzelius-2023-154/users/x_qinzh/mambaforge/lib -cd /proj/berzelius-2023-364/users/x_qinzh/workspace/OpenSceneFlow - - -# ====> leaderboard model -# $PYTHON eval.py wandb_mode=online dataset_path=/proj/berzelius-2023-364/users/x_qinzh/data/av2/autolabel data_mode=test \ -# checkpoint=/proj/berzelius-2023-154/users/x_qinzh/seflow/logs/wandb/seflow-10086990/checkpoints/epoch_19_seflow.ckpt \ -# save_res=True - -$PYTHON eval.py wandb_mode=online dataset_path=/proj/berzelius-2023-364/users/x_qinzh/data/av2/autolabel data_mode=val \ - 
checkpoint=/proj/berzelius-2023-154/users/x_qinzh/seflow/logs/wandb/seflow-10086990/checkpoints/epoch_19_seflow.ckpt \ No newline at end of file diff --git a/assets/slurm/train-teflow.sh b/assets/slurm/train-teflow.sh new file mode 100644 index 0000000..be7888c --- /dev/null +++ b/assets/slurm/train-teflow.sh @@ -0,0 +1,25 @@ +#!/bin/bash +#SBATCH -J teflow +#SBATCH -A NAISS2026-3-96 -p alvis +#SBATCH -N 1 --gpus-per-node=T4:8 +#SBATCH -t 5-00:00:00 +#SBATCH --mail-type=END,FAIL +#SBATCH --mail-user=qingwen@kth.se +#SBATCH --output /cephyr/users/qingwenz/Alvis/workspace/OpenSceneFlow/logs/slurm/%J.out +#SBATCH --error /cephyr/users/qingwenz/Alvis/workspace/OpenSceneFlow/logs/slurm/%J.err + +cd /cephyr/users/qingwenz/Alvis/workspace/OpenSceneFlow +PYTHON="apptainer run --nv --writable-tmpfs /mimer/NOBACKUP/groups/kthrpl_patric/data/apptainer/opensf-full.sif" + +# sometimes diff gpus have different CUDA capability, and the compile package may not working +# PYTHON="/mimer/NOBACKUP/groups/kthrpl_patric/users/qingwenz/miniforge3/envs/opensf/bin/python" + +# ===> Need change it data path changed <=== +SOURCE="/mimer/NOBACKUP/groups/kthrpl_patric/data/h5py/av2" + +# ========================= TeFlow num_frame=5 ========================= +$PYTHON train.py slurm_id=$SLURM_JOB_ID wandb_mode=online wandb_project_name=teflow train_data=${SOURCE}/sensor/train val_data=${SOURCE}/sensor/val \ + model=deltaflow save_top_model=2 val_every=3 train_aug=True "voxel_size=[0.15, 0.15, 0.15]" "point_cloud_range=[-38.4, -38.4, -3, 38.4, 38.4, 3]" \ + num_workers=16 epochs=15 batch_size=2 num_frames=5 "+add_seloss={chamfer_dis: 1.0, static_flow_loss: 1.0, dynamic_chamfer_dis: 1.0, cluster_based_pc0pc1: 1.0}" \ + +ssl_label=seflow_auto loss_fn=teflowLoss optimizer.name=Adam optimizer.lr=2e-3 +optimizer.scheduler.name=StepLR +optimizer.scheduler.step_size=9 +optimizer.scheduler.gamma=0.5 + From 235b80129e825ad63707d76627f8f922a46b568d Mon Sep 17 00:00:00 2001 From: Kin Date: Thu, 12 Mar 2026 
09:14:31 +0100 Subject: [PATCH 6/6] update train with rename jobid if it's self-supervised loss. --- assets/slurm/train-teflow.sh | 25 ------------------------- train.py | 6 +++--- 2 files changed, 3 insertions(+), 28 deletions(-) delete mode 100644 assets/slurm/train-teflow.sh diff --git a/assets/slurm/train-teflow.sh b/assets/slurm/train-teflow.sh deleted file mode 100644 index be7888c..0000000 --- a/assets/slurm/train-teflow.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash -#SBATCH -J teflow -#SBATCH -A NAISS2026-3-96 -p alvis -#SBATCH -N 1 --gpus-per-node=T4:8 -#SBATCH -t 5-00:00:00 -#SBATCH --mail-type=END,FAIL -#SBATCH --mail-user=qingwen@kth.se -#SBATCH --output /cephyr/users/qingwenz/Alvis/workspace/OpenSceneFlow/logs/slurm/%J.out -#SBATCH --error /cephyr/users/qingwenz/Alvis/workspace/OpenSceneFlow/logs/slurm/%J.err - -cd /cephyr/users/qingwenz/Alvis/workspace/OpenSceneFlow -PYTHON="apptainer run --nv --writable-tmpfs /mimer/NOBACKUP/groups/kthrpl_patric/data/apptainer/opensf-full.sif" - -# sometimes diff gpus have different CUDA capability, and the compile package may not working -# PYTHON="/mimer/NOBACKUP/groups/kthrpl_patric/users/qingwenz/miniforge3/envs/opensf/bin/python" - -# ===> Need change it data path changed <=== -SOURCE="/mimer/NOBACKUP/groups/kthrpl_patric/data/h5py/av2" - -# ========================= TeFlow num_frame=5 ========================= -$PYTHON train.py slurm_id=$SLURM_JOB_ID wandb_mode=online wandb_project_name=teflow train_data=${SOURCE}/sensor/train val_data=${SOURCE}/sensor/val \ - model=deltaflow save_top_model=2 val_every=3 train_aug=True "voxel_size=[0.15, 0.15, 0.15]" "point_cloud_range=[-38.4, -38.4, -3, 38.4, 38.4, 3]" \ - num_workers=16 epochs=15 batch_size=2 num_frames=5 "+add_seloss={chamfer_dis: 1.0, static_flow_loss: 1.0, dynamic_chamfer_dis: 1.0, cluster_based_pc0pc1: 1.0}" \ - +ssl_label=seflow_auto loss_fn=teflowLoss optimizer.name=Adam optimizer.lr=2e-3 +optimizer.scheduler.name=StepLR 
+optimizer.scheduler.step_size=9 +optimizer.scheduler.gamma=0.5 - diff --git a/train.py b/train.py index b7d7eaf..f5ed7ca 100644 --- a/train.py +++ b/train.py @@ -28,9 +28,9 @@ from src.dataset import HDF5Dataset, collate_fn_pad, RandomHeight, RandomFlip, RandomJitter, ToTensor from torchvision import transforms from src.trainer import ModelWrapper - +from src.lossfuncs import SSL_LOSSES_FN def precheck_cfg_valid(cfg): - if cfg.loss_fn in ['seflowLoss', 'seflowppLoss'] and (cfg.add_seloss is None or cfg.ssl_label is None): + if cfg.loss_fn in SSL_LOSSES_FN and (cfg.add_seloss is None or cfg.ssl_label is None): raise ValueError("Please specify the self-supervised loss items and auto-label source for seflow-series loss.") grid_size = [(cfg.point_cloud_range[3] - cfg.point_cloud_range[0]) * (1/cfg.voxel_size[0]), @@ -83,7 +83,7 @@ def main(cfg): output_dir = HydraConfig.get().runtime.output_dir # overwrite logging folder name for SSL. - if cfg.loss_fn in ['seflowLoss', 'seflowppLoss']: + if cfg.loss_fn in SSL_LOSSES_FN: tmp_ = cfg.loss_fn.split('Loss')[0] + '-' + cfg.model.name cfg.output = cfg.output.replace(cfg.model.name, tmp_) output_dir = output_dir.replace(cfg.model.name, tmp_)