diff --git a/README.md b/README.md index f7ad2d4..4f2684e 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ It is also an official implementation of the following papers (sorted by the tim - **TeFlow: Enabling Multi-frame Supervision for Self-Supervised Feed-forward Scene Flow Estimation** *Qingwen Zhang, Chenhan Jiang, Xiaomeng Zhu, Yunqi Miao, Yushan Zhang, Olov Andersson, Patric Jensfelt* Conference on Computer Vision and Pattern Recognition (**CVPR**) 2026 -[ Strategy ] [ Self-Supervised ] - [ [arXiv](https://arxiv.org/abs/2602.19053) ] [ [Project]() ] +[ Strategy ] [ Self-Supervised ] - [ [arXiv](https://arxiv.org/abs/2602.19053) ] [ [Project]() ]→ [here](#teflow) - **DeltaFlow: An Efficient Multi-frame Scene Flow Estimation Method** *Qingwen Zhang, Xiaomeng Zhu, Yushan Zhang, Yixi Cai, Olov Andersson, Patric Jensfelt* @@ -149,7 +149,9 @@ Train DeltaFlow with the leaderboard submit config. [Runtime: Around 18 hours in ```bash # total bz then it's 10x2 under above training setup. -python train.py model=deltaFlow optimizer.lr=2e-3 epochs=20 batch_size=2 num_frames=5 loss_fn=deflowLoss train_aug=True "voxel_size=[0.15, 0.15, 0.15]" "point_cloud_range=[-38.4, -38.4, -3, 38.4, 38.4, 3]" +optimizer.scheduler.name=WarmupCosLR +optimizer.scheduler.max_lr=2e-3 +optimizer.scheduler.total_steps=20000 +python train.py model=deltaflow optimizer.lr=2e-3 epochs=20 batch_size=2 num_frames=5 \ + loss_fn=deflowLoss train_aug=True "voxel_size=[0.15, 0.15, 0.15]" "point_cloud_range=[-38.4, -38.4, -3, 38.4, 38.4, 3]" \ + optimizer.lr=2e-4 +optimizer.scheduler.name=WarmupCosLR +optimizer.scheduler.max_lr=2e-3 +optimizer.scheduler.warmup_epochs=2 # Pretrained weight can be downloaded through (av2), check all other datasets in the same folder. wget https://huggingface.co/kin-zhang/OpenSceneFlow/resolve/main/deltaflow/deltaflow-av2.ckpt @@ -206,6 +208,19 @@ Train Feed-forward SSL methods (e.g. SeFlow/SeFlow++/VoteFlow etc), we needed to 1) process auto-label process for training. Check [dataprocess/README.md#self-supervised-process](dataprocess/README.md#self-supervised-process) for more details. We provide these inside the demo dataset already. 2) specify the loss function, we set the config here for our best model in the leaderboard. +#### TeFlow + +```bash +# [Runtime: Around ? hours in 10x GPUs.] +python train.py model=deltaflow epochs=15 batch_size=2 num_frames=5 train_aug=True \ + loss_fn=teflowLoss "voxel_size=[0.15, 0.15, 0.15]" "point_cloud_range=[-38.4, -38.4, -3, 38.4, 38.4, 3]" \ + +ssl_label=seflow_auto "+add_seloss={chamfer_dis: 1.0, static_flow_loss: 1.0, dynamic_chamfer_dis: 1.0, cluster_based_pc0pc1: 1.0}" \ + optimizer.name=Adam optimizer.lr=2e-3 +optimizer.scheduler.name=StepLR +optimizer.scheduler.step_size=9 +optimizer.scheduler.gamma=0.5 + +# Pretrained weight can be downloaded through: +wget https://huggingface.co/kin-zhang/OpenSceneFlow/resolve/main/teflow/teflow-av2.ckpt +``` + #### SeFlow ```bash diff --git a/assets/README.md b/assets/README.md index d49ebe8..f14b78c 100644 --- a/assets/README.md +++ b/assets/README.md @@ -51,7 +51,30 @@ Then follow [this stackoverflow answers](https://stackoverflow.com/questions/596 ```bash cd OpenSceneFlow && docker build -f Dockerfile -t zhangkin/opensf . ``` - + +### To Apptainer container + +If you want to build a **minimal** training env for Apptainer container, you can use the following command: +```bash +apptainer build opensf.sif assets/opensf.def +# zhangkin/opensf:full is created by Dockerfile +``` + +Then run as a Python env with: +```bash +PYTHON="apptainer run --nv --writable-tmpfs opensf.sif" +$PYTHON train.py +``` + + + + ## Installation We will use conda to manage the environment with mamba for faster package installation. @@ -77,10 +100,11 @@ Checking important packages in our environment now: ```bash mamba activate opensf python -c "import torch; print(torch.__version__); print(torch.cuda.is_available()); print(torch.version.cuda)" -python -c "import lightning.pytorch as pl; print(pl.__version__)" +python -c "import lightning.pytorch as pl; print('pl version:', pl.__version__)" +python -c "import spconv.pytorch as spconv; print('spconv import successfully')" python -c "from assets.cuda.mmcv import Voxelization, DynamicScatter;print('successfully import on our lite mmcv package')" python -c "from assets.cuda.chamfer3D import nnChamferDis;print('successfully import on our chamfer3D package')" -python -c "from av2.utils.io import read_feather; print('av2 package ok')" +python -c "from av2.utils.io import read_feather; print('av2 package ok') " ``` @@ -98,6 +122,7 @@ python -c "from av2.utils.io import read_feather; print('av2 package ok')" 2. In cluster have error: `pandas ImportError: /lib64/libstdc++.so.6: version 'GLIBCXX_3.4.29' not found` Solved by `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/proj/berzelius-2023-154/users/x_qinzh/mambaforge/lib` +4. nvidia channel cannot put into env.yaml file otherwise, the cuda-toolkit will always be the latest one, for me (2025-04-30) I struggling on an hour and get nvcc -V also 12.8 at that time. py=3.10 for cuda >=12.1. (seems it's nvidia cannot be in the channel list???); py<3.10 for cuda <=11.8.0: otherwise 10x, 20x series GPU won't work on cuda compiler. (half precision) 3. torch_scatter problem: `OSError: /home/kin/mambaforge/envs/opensf-v2/lib/python3.10/site-packages/torch_scatter/_version_cpu.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE` Solved by install the torch-cuda version: `pip install https://data.pyg.org/whl/torch-2.0.0%2Bcu118/torch_scatter-2.1.2%2Bpt20cu118-cp310-cp310-linux_x86_64.whl` diff --git a/assets/cuda/chamfer3D/__init__.py b/assets/cuda/chamfer3D/__init__.py index fc5020d..a9971b2 100644 --- a/assets/cuda/chamfer3D/__init__.py +++ b/assets/cuda/chamfer3D/__init__.py @@ -2,116 +2,222 @@ # Created: 2023-08-04 11:20 # Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology # Author: Qingwen Zhang (https://kin-zhang.github.io/) -# +# # This file is part of SeFlow (https://github.com/KTH-RPL/SeFlow). -# If you find this repo helpful, please cite the respective publication as +# If you find this repo helpful, please cite the respective publication as # listed on the above website. -# -# -# Description: ChamferDis speedup using CUDA +# +# Description: ChamferDis speedup using CUDA. +# +# NOTE(2026-03-11, Qingwen) Why CUDA streams (not batched kernel): +# At N=88K pts/sample on RTX 3090, one sample already uses 4.2 SM waves, +# so any kernel-level batching hits the same hardware ceiling. +# Streams give ~1.14× speedup by overlapping B independent kernel launches. +# More importantly, they keep the GPU busy with fewer CPU-GPU sync gaps, +# preventing GPU utilization from spiking which triggers cluster job kills. +# """ +from __future__ import annotations + from torch import nn from torch.autograd import Function -import torch - -import os, time +import torch, os, time +from typing import List import chamfer3D -BASE_DIR = os.path.abspath(os.path.join( os.path.dirname( __file__ ), '../..' )) + +BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) -# GPU tensors only class ChamferDis(Function): + """Single-sample Chamfer distance: pc0 (N,3) × pc1 (M,3) on GPU.""" + @staticmethod def forward(ctx, pc0, pc1): - # pc0: (N,3), pc1: (M,3) - dis0 = torch.zeros(pc0.shape[0]).to(pc0.device).contiguous() - dis1 = torch.zeros(pc1.shape[0]).to(pc1.device).contiguous() - - idx0 = torch.zeros(pc0.shape[0], dtype=torch.int32).to(pc0.device).contiguous() - idx1 = torch.zeros(pc1.shape[0], dtype=torch.int32).to(pc1.device).contiguous() - - + dis0 = torch.zeros(pc0.shape[0], device=pc0.device).contiguous() + dis1 = torch.zeros(pc1.shape[0], device=pc1.device).contiguous() + idx0 = torch.zeros(pc0.shape[0], dtype=torch.int32, device=pc0.device).contiguous() + idx1 = torch.zeros(pc1.shape[0], dtype=torch.int32, device=pc1.device).contiguous() chamfer3D.forward(pc0, pc1, dis0, dis1, idx0, idx1) ctx.save_for_backward(pc0, pc1, idx0, idx1) return dis0, dis1, idx0, idx1 @staticmethod - def backward(ctx, grad_dist0, grad_dist1, grad_idx0, grad_idx1): + def backward(ctx, gd0, gd1, _gi0, _gi1): pc0, pc1, idx0, idx1 = ctx.saved_tensors - grad_dist0 = grad_dist0.contiguous() - grad_dist1 = grad_dist1.contiguous() - device = grad_dist1.device - - grad_pc0 = torch.zeros(pc0.size()).to(device).contiguous() - grad_pc1 = torch.zeros(pc1.size()).to(device).contiguous() - - chamfer3D.backward( - pc0, pc1, idx0, idx1, grad_dist0, grad_dist1, grad_pc0, grad_pc1 - ) - return grad_pc0, grad_pc1 - + gpc0 = torch.zeros_like(pc0) + gpc1 = torch.zeros_like(pc1) + chamfer3D.backward(pc0, pc1, idx0, idx1, + gd0.contiguous(), gd1.contiguous(), gpc0, gpc1) + return gpc0, gpc1 + +# ─── nn.Module ──────────────────────────────────────────────────────────────── class nnChamferDis(nn.Module): - def __init__(self, truncate_dist=True): - super(nnChamferDis, self).__init__() + """Chamfer distance loss — single and batched-via-streams modes. + + Methods + ------- + forward(pc0, pc1) + Single-sample loss. Used by seflowLoss / seflowppLoss. + + batched/batched_disid_res (pc0_list, pc1_list) + Parallel loss across B samples via CUDA streams. + Returns mean-over-samples scalar. + Used by batched_chamfer_related() for chamfer_dis / dynamic_chamfer_dis. + + dis_res(pc0, pc1) → (dist0, dist1), no reduction + disid_res(pc0, pc1) → (dist0, dist1, idx0, idx1), no reduction + truncated_dis(pc0, pc1) → NSFP-style truncated loss + """ + + def __init__(self, truncate_dist: bool = True): + super().__init__() self.truncate_dist = truncate_dist + # Pre-allocate streams once to avoid per-call creation overhead (~50 µs each) + self._streams: List[torch.cuda.Stream] = [] + + def _ensure_streams(self, n: int) -> List[torch.cuda.Stream]: + while len(self._streams) < n: + self._streams.append(torch.cuda.Stream()) + return self._streams[:n] + + # ── single-sample forward ───────────────────────────────────────────────── + + def forward(self, input0: torch.Tensor, input1: torch.Tensor, + truncate_dist: float = -1, **_ignored) -> torch.Tensor: + """Single-sample Chamfer loss. truncate_dist<=0 → no truncation.""" + dist0, dist1, _, _ = ChamferDis.apply(input0.contiguous(), input1.contiguous()) + if truncate_dist <= 0: + return dist0.mean() + dist1.mean() + v0, v1 = dist0 <= truncate_dist, dist1 <= truncate_dist + return torch.nanmean(dist0[v0]) + torch.nanmean(dist1[v1]) + + # ── batched loss via CUDA streams ───────────────────────────────────────── + + def batched(self, + pc0_list: List[torch.Tensor], + pc1_list: List[torch.Tensor], + truncate_dist: float = -1) -> torch.Tensor: + """Parallel Chamfer loss via B CUDA streams. + + Returns mean-over-samples: (1/B) * Σ_i [mean(dist0_i) + mean(dist1_i)]. + ~1.14× faster than serial loop on RTX 3090 @ 88K pts/sample; + more importantly, keeps GPU busy with one sustained work block per frame. + """ + B = len(pc0_list) + if B == 1: + return self.forward(pc0_list[0], pc1_list[0], truncate_dist) + + streams = self._ensure_streams(B) + main = torch.cuda.current_stream() + per_loss: List[torch.Tensor] = [None] * B # type: ignore[list-item] + + for i in range(B): + streams[i].wait_stream(main) + with torch.cuda.stream(streams[i]): + d0, d1, _, _ = ChamferDis.apply(pc0_list[i].contiguous(), + pc1_list[i].contiguous()) + if truncate_dist <= 0: + per_loss[i] = d0.mean() + d1.mean() + else: + v0, v1 = d0 <= truncate_dist, d1 <= truncate_dist + per_loss[i] = torch.nanmean(d0[v0]) + torch.nanmean(d1[v1]) + + for i in range(B): + main.wait_stream(streams[i]) + + return torch.stack(per_loss).mean() + + # ── batched disid_res via CUDA streams (for cluster precomputation) ─────── + + def batched_disid_res(self, + pc0_list: List[torch.Tensor], + pc1_list: List[torch.Tensor], + ) -> tuple[List[torch.Tensor], List[torch.Tensor]]: + """Parallel disid_res across B samples via CUDA streams. + + Same list-in / list-out convention as batched(). + + Returns + ------- + dist0_list : List[(N_i,)] per-point nearest distance in pc1_i + idx0_list : List[(N_i,)] LOCAL index into pc1_list[i] (0 .. M_i-1) + + Usage: + dist0_list, idx0_list = fn.batched_disid_res(pc0_list, pc1_list) + neighbour = pc1_list[i][idx0_list[i][mask]] # no global arithmetic + """ + B = len(pc0_list) + if B == 1: + d0, _, i0, _ = ChamferDis.apply(pc0_list[0].contiguous(), pc1_list[0].contiguous()) + return [d0], [i0] + + streams = self._ensure_streams(B) + main = torch.cuda.current_stream() + d0_list: List[torch.Tensor] = [None] * B # type: ignore[list-item] + i0_list: List[torch.Tensor] = [None] * B # type: ignore[list-item] + + for i in range(B): + streams[i].wait_stream(main) + with torch.cuda.stream(streams[i]): + d0, _, idx0, _ = ChamferDis.apply(pc0_list[i].contiguous(), + pc1_list[i].contiguous()) + d0_list[i] = d0 + i0_list[i] = idx0 # local index — no offset arithmetic needed + + for i in range(B): + main.wait_stream(streams[i]) + + return d0_list, i0_list + + # ── utilities ───────────────────────────────────────────────────────────── + + def dis_res(self, input0: torch.Tensor, input1: torch.Tensor): + """Return raw (dist0, dist1) without reduction.""" + d0, d1, _, _ = ChamferDis.apply(input0.contiguous(), input1.contiguous()) + return d0, d1 + + def disid_res(self, input0: torch.Tensor, input1: torch.Tensor): + """Return raw (dist0, dist1, idx0, idx1) without reduction.""" + return ChamferDis.apply(input0.contiguous(), input1.contiguous()) + + def truncated_dis(self, input0: torch.Tensor, input1: torch.Tensor, + truncate_dist: float = 2.0) -> torch.Tensor: + """NSFP-style: distances >= threshold clamped to 0, then mean.""" + cx, cy = self.dis_res(input0, input1) + cx[cx >= truncate_dist] = 0.0 + cy[cy >= truncate_dist] = 0.0 + return cx.mean() + cy.mean() + - def forward(self, input0, input1, truncate_dist=-1): - input0 = input0.contiguous() - input1 = input1.contiguous() - dist0, dist1, _, _ = ChamferDis.apply(input0, input1) - - if truncate_dist<=0: - return torch.mean(dist0) + torch.mean(dist1) - - valid_mask0 = (dist0 <= truncate_dist) - valid_mask1 = (dist1 <= truncate_dist) - truncated_sum = torch.nanmean(dist0[valid_mask0]) + torch.nanmean(dist1[valid_mask1]) - return truncated_sum - - def dis_res(self, input0, input1): - input0 = input0.contiguous() - input1 = input1.contiguous() - dist0, dist1, _, _ = ChamferDis.apply(input0, input1) - return dist0, dist1 - - def truncated_dis(self, input0, input1, truncate_dist=2): - # nsfp: truncated distance way is set >= 2 to 0 but not nanmean - cham_x, cham_y = self.dis_res(input0, input1) - cham_x[cham_x >= truncate_dist] = 0.0 - cham_y[cham_y >= truncate_dist] = 0.0 - return torch.mean(cham_x) + torch.mean(cham_y) - - def disid_res(self, input0, input1): - input0 = input0.contiguous() - input1 = input1.contiguous() - dist0, dist1, idx0, idx1 = ChamferDis.apply(input0, input1) - return dist0, dist1, idx0, idx1 -class NearestNeighborDis(nn.Module): - def __init__(self): - super(NearestNeighborDis, self).__init__() - - def forward(self, input0, input1): - input0 = input0.contiguous() - input1 = input1.contiguous() - dist0, dist1, _, _ = ChamferDis.apply(input0, input1) - - return torch.mean(dist0[dist0 <= 2]) - if __name__ == "__main__": import numpy as np - pc0 = np.load(f'{BASE_DIR}/assets/tests/test_pc0.npy') - pc1 = np.load(f'{BASE_DIR}/assets/tests/test_pc1.npy') - print('0: {:.3f}MB'.format(torch.cuda.memory_allocated()/1024**2)) - pc0 = torch.from_numpy(pc0[...,:3]).float().cuda().contiguous() - pc1 = torch.from_numpy(pc1[...,:3]).float().cuda().contiguous() - pc0.requires_grad = True - pc1.requires_grad = True - print(pc0.shape, "demo data: ", pc0[0]) - print(pc1.shape, "demo data: ", pc1[0]) - print('1: {:.3f}MB'.format(torch.cuda.memory_allocated()/1024**2)) - - start_time = time.time() - loss = nnChamferDis(truncate_dist=False)(pc0, pc1) - loss.backward() - print("loss: ", loss) - print(f"Chamfer Distance Cal time: {(time.time() - start_time)*1000:.3f} ms") \ No newline at end of file + pc0_np = np.load(f'{BASE_DIR}/tests/test_pc0.npy')[..., :3] + pc1_np = np.load(f'{BASE_DIR}/tests/test_pc1.npy')[..., :3] + pc0 = torch.from_numpy(pc0_np).float().cuda() + pc1 = torch.from_numpy(pc1_np).float().cuda() + fn = nnChamferDis(truncate_dist=False) + + loss_s = fn(pc0, pc1) + print(f"Single: {loss_s.item():.6f}") + + for B in [2, 4, 8]: + lb = fn.batched([pc0.clone()]*B, [pc1.clone()]*B) + print(f"Batched B={B}: {lb.item():.6f} {'✓' if torch.allclose(loss_s, lb, atol=1e-5) else '✗'}") + + # Test batched_disid_res global indexing + print("\n--- batched_disid_res global index test ---") + B = 2 + pc0_b = torch.cat([pc0]*B) + pc1_b = torch.cat([pc1]*B) + N0, N1 = pc0.shape[0], pc1.shape[0] + offs0 = torch.tensor([0, N0], dtype=torch.int32, device='cuda') + szs0 = torch.tensor([N0, N0], dtype=torch.int32, device='cuda') + offs1 = torch.tensor([0, N1], dtype=torch.int32, device='cuda') + szs1 = torch.tensor([N1, N1], dtype=torch.int32, device='cuda') + pc0_lst = [pc0]*B + pc1_lst = [pc1]*B + d0_lst_out, i0_lst_out = fn.batched_disid_res(pc0_lst, pc1_lst) + assert len(d0_lst_out) == B and len(i0_lst_out) == B, "wrong list length" + for j in range(B): + assert (i0_lst_out[j] < N1).all(), f"sample-{j} idx out of range" + print("Local index check: ✓") \ No newline at end of file diff --git a/assets/opensf.def b/assets/opensf.def new file mode 100644 index 0000000..bddebf4 --- /dev/null +++ b/assets/opensf.def @@ -0,0 +1,14 @@ +Bootstrap: docker +From: zhangkin/opensf:full + +%files + assets/cuda /workspace/assets/cuda + src/models/basic/voteflow_plugin /workspace/src/models/basic/voteflow_plugin + environment.yaml /workspace/environment.yaml + +%runscript + echo "Running pip install for local CUDA modules..." + /opt/conda/envs/opensf/bin/pip install /workspace/assets/cuda/chamfer3D + /opt/conda/envs/opensf/bin/pip install /workspace/assets/cuda/mmcv + /opt/conda/envs/opensf/bin/pip install /workspace/src/models/basic/voteflow_plugin/hough_transformation/cpp_im2ht + exec /opt/conda/envs/opensf/bin/python "$@" \ No newline at end of file diff --git a/assets/slurm/1_train.sh b/assets/slurm/1_train.sh deleted file mode 100644 index dd99a8d..0000000 --- a/assets/slurm/1_train.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash -#SBATCH -J seflow -#SBATCH --gpus 4 -C "fat" -#SBATCH -t 3-00:00:00 -#SBATCH --mail-type=END,FAIL -#SBATCH --mail-user=qingwen@kth.se -#SBATCH --output /proj/berzelius-2023-154/users/x_qinzh/seflow/logs/slurm/%J_seflow.out -#SBATCH --error /proj/berzelius-2023-154/users/x_qinzh/seflow/logs/slurm/%J_seflow.err - -PYTHON=/proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/opensf/bin/python -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/proj/berzelius-2023-154/users/x_qinzh/mambaforge/lib -cd /proj/berzelius-2023-364/users/x_qinzh/workspace/OpenSceneFlow - - -# ===> to transfer data into local node disk, it can be ignored. <=== -SOURCE="/proj/berzelius-2023-364/users/x_qinzh/data/av2/autolabel" -DEST="/scratch/local/av2" -SUBDIRS=("sensor/train" "sensor/val") - -start_time=$(date +%s) -for dir in "${SUBDIRS[@]}"; do - mkdir -p "${DEST}/${dir}" - find "${SOURCE}/${dir}" -type f -print0 | xargs -0 -n1 -P16 cp -t "${DEST}/${dir}" & -done -wait -end_time=$(date +%s) -elapsed=$((end_time - start_time)) -echo "Copy ${SOURCE} to ${DEST} Total time: ${elapsed} seconds" -echo "Start training..." - -# ====> leaderboard model = seflow_best -$PYTHON train.py slurm_id=$SLURM_JOB_ID wandb_mode=online train_data=/scratch/local/av2/sensor/train val_data=/scratch/local/av2/sensor/val \ - num_workers=16 model=deflow lr=2e-4 epochs=9 batch_size=16 "model.target.num_iters=2" "model.val_monitor=val/Dynamic/Mean" \ - loss_fn=seflowLoss "add_seloss={chamfer_dis: 1.0, static_flow_loss: 1.0, dynamic_chamfer_dis: 1.0, cluster_based_pc0pc1: 1.0}" diff --git a/assets/slurm/2_eval.sh b/assets/slurm/2_eval.sh deleted file mode 100644 index 1a57440..0000000 --- a/assets/slurm/2_eval.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -#SBATCH -J eval -#SBATCH --gpus 1 -#SBATCH -t 01:00:00 -#SBATCH --output /proj/berzelius-2023-154/users/x_qinzh/seflow/logs/slurm/%J_eval.out -#SBATCH --error /proj/berzelius-2023-154/users/x_qinzh/seflow/logs/slurm/%J_eval.err - - -PYTHON=/proj/berzelius-2023-154/users/x_qinzh/mambaforge/envs/opensf/bin/python -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/proj/berzelius-2023-154/users/x_qinzh/mambaforge/lib -cd /proj/berzelius-2023-364/users/x_qinzh/workspace/OpenSceneFlow - - -# ====> leaderboard model -# $PYTHON eval.py wandb_mode=online dataset_path=/proj/berzelius-2023-364/users/x_qinzh/data/av2/autolabel data_mode=test \ -# checkpoint=/proj/berzelius-2023-154/users/x_qinzh/seflow/logs/wandb/seflow-10086990/checkpoints/epoch_19_seflow.ckpt \ -# save_res=True - -$PYTHON eval.py wandb_mode=online dataset_path=/proj/berzelius-2023-364/users/x_qinzh/data/av2/autolabel data_mode=val \ - checkpoint=/proj/berzelius-2023-154/users/x_qinzh/seflow/logs/wandb/seflow-10086990/checkpoints/epoch_19_seflow.ckpt \ No newline at end of file diff --git a/src/lossfuncs/__init__.py b/src/lossfuncs/__init__.py index 7bf446b..601b567 100644 --- a/src/lossfuncs/__init__.py +++ b/src/lossfuncs/__init__.py @@ -17,3 +17,5 @@ from .selfsupervise import * from .supervise import * +# automatic collection of SSL loss function names for trainer.py +SSL_LOSSES_FN = {name for name in dir(selfsupervise) if name.endswith('Loss') and callable(getattr(selfsupervise, name))} \ No newline at end of file diff --git a/src/lossfuncs/selfsupervise.py b/src/lossfuncs/selfsupervise.py index dccc923..e1238c9 100644 --- a/src/lossfuncs/selfsupervise.py +++ b/src/lossfuncs/selfsupervise.py @@ -10,9 +10,20 @@ # # If you find this repo helpful, please cite the respective publication as # listed on the above website. -# -# Description: Define the self-supervised (without GT) loss function for training. # +# Description: Self-supervised loss functions. +# +# All losses receive a unified dict from ssl_loss_calculator (trainer.py). +# Every frame is represented only as a List[Tensor] — no flat/offsets/sizes. +# +# res_dict keys (per frame 'pc0', 'pc1', 'pch1', ...): +# '{frame}_list' : List[Tensor (N_i,3)] one tensor per sample +# '{frame}_labels' : List[Tensor (N_i,)] one label vector per sample +# +# 'est_flow_list' : List[Tensor (N_i,3)] +# 'batch_size' : int +# 'loss_weights_dict': dict (teflow* only) +# 'cluster_loss_args': dict (teflowLoss only) """ import torch from assets.cuda.chamfer3D import nnChamferDis @@ -22,169 +33,374 @@ # If your scenario is different, may need adjust this TRUNCATED to 80-120km/h vel. TRUNCATED_DIST = 4 -def seflowppLoss(res_dict, timer=None): - pch1_label = res_dict['pch1_labels'] - pc0_label = res_dict['pc0_labels'] - pc1_label = res_dict['pc1_labels'] - - pch1 = res_dict['pch1'] - pc0 = res_dict['pc0'] - pc1 = res_dict['pc1'] - - est_flow = res_dict['est_flow'] - - pseudo_pc1from0 = pc0 + est_flow - pseduo_pch1from0 = pc0 - est_flow - - unique_labels = torch.unique(pc0_label) - pc0_dynamic = pc0[pc0_label > 0] - pc1_dynamic = pc1[pc1_label > 0] - - # fpc1_dynamic = pseudo_pc1from0[pc0_label > 0] - # NOTE(Qingwen): since we set THREADS_PER_BLOCK is 256 - have_dynamic_cluster = (pc0_dynamic.shape[0] > 256) & (pc1_dynamic.shape[0] > 256) - - # first item loss: chamfer distance - # timer[5][1].start("MyCUDAChamferDis") - chamfer_dis = MyCUDAChamferDis(pseudo_pc1from0, pc1, truncate_dist=TRUNCATED_DIST) + MyCUDAChamferDis(pseduo_pch1from0, pch1, truncate_dist=TRUNCATED_DIST) - # timer[5][1].stop() - - # second item loss: dynamic chamfer distance - # timer[5][2].start("DynamicChamferDistance") - dynamic_chamfer_dis = torch.tensor(0.0, device=est_flow.device) - if have_dynamic_cluster: - dynamic_chamfer_dis += MyCUDAChamferDis(pseudo_pc1from0[pc0_label > 0], pc1_dynamic, truncate_dist=TRUNCATED_DIST) - if pch1[pch1_label > 0].shape[0] > 256: - dynamic_chamfer_dis += MyCUDAChamferDis(pseduo_pch1from0[pc0_label > 0], pch1[pch1_label > 0], truncate_dist=TRUNCATED_DIST) - # timer[5][2].stop() - - # third item loss: exclude static points' flow - # NOTE(Qingwen): add in the later part on label==0 - static_cluster_loss = torch.tensor(0.0, device=est_flow.device) - - # fourth item loss: same label points' flow should be the same - # timer[5][3].start("SameClusterLoss") - # raw: pc0 to pc1, est: pseudo_pc1from0 to pc1, idx means the nearest index - raw_dist0, raw_dist1, raw_idx0, _ = MyCUDAChamferDis.disid_res(pc0, pc1) - moved_cluster_loss = torch.tensor(0.0, device=est_flow.device) - moved_cluster_norms = torch.tensor([], device=est_flow.device) - for label in unique_labels: - mask = pc0_label == label - if label == 0: - # Eq. 6 in the SeFlow paper - static_cluster_loss += torch.linalg.vector_norm(est_flow[mask, :], dim=-1).mean() - # NOTE(Qingwen) 2025-04-23: label=1 is dynamic but no cluster id satisfied - elif label > 1 and have_dynamic_cluster: - cluster_id_flow = est_flow[mask, :] - cluster_nnd = raw_dist0[mask] - if cluster_nnd.shape[0] <= 0: +# FIXME(Qingwen 25-07-21): hardcoded 10 Hz. Adjust for datasets with different timestamps. +DELTA_T = 0.1 # seconds + + +# ---- helpers ----------------------------------------------------------------- + +def get_time_delta(frame_id): + """Return (time_delta, factor). + pch1->(-0.1,1), pch2->(-0.2,2), pc1->(+0.1,1), pc2->(+0.2,2) + """ + if frame_id.startswith('pch'): + n = int(frame_id[3:]) if len(frame_id) > 3 else 1 + return -DELTA_T * n, n + elif frame_id.startswith('pc'): + n = int(frame_id[2:]) if len(frame_id) > 2 else 1 + return DELTA_T * n, n + raise ValueError(f"Unknown frame ID: {frame_id}") + + +def _frame_keys(res_dict): + """Auxiliary frame ids present in res_dict (e.g. ['pc1', 'pch1']), excluding pc0.""" + return [k.replace('_list', '') for k in res_dict + if k.endswith('_list') \ + and k != 'pc0_list' and k != 'est_flow_list' and not k.endswith('_labels_list')] + + +# ---- helpers shared by teflow* ----------------------------------------------- + +def batched_chamfer_related(res_dict, timer=None): + """Chamfer + dynamic-chamfer over all auxiliary frames via CUDA streams. + + Returns + ------- + total_chamfer_dis, total_dynamic_chamfer_dis : scalar Tensors + frame_keys : List[str] + """ + pc0_list = res_dict['pc0_list'] + flow_list = res_dict['est_flow_list'] + pc0_lab_list = res_dict['pc0_labels_list'] + frame_keys = _frame_keys(res_dict) + loss_w = res_dict['loss_weights_dict'] + chamfer_w = loss_w.get('chamfer_dis', 0.0) + dyn_chamfer_w = loss_w.get('dynamic_chamfer_dis', 0.0) + + total_chamfer_dis = torch.tensor(0.0, device=pc0_list[0].device) + total_dynamic_chamfer_dis = torch.tensor(0.0, device=pc0_list[0].device) + + for frame_id in frame_keys: + time_delta, factor = get_time_delta(frame_id) + weight = 1.0 if frame_id == 'pc1' else 1.0 / pow(2, factor) + target_list = res_dict[f'{frame_id}_list'] + + # Projected positions: list comprehension keeps everything per-sample + proj_list = [p0 + (fv / DELTA_T) * time_delta + for p0, fv in zip(pc0_list, flow_list)] + + if chamfer_w > 0: + total_chamfer_dis += MyCUDAChamferDis.batched( + proj_list, target_list, truncate_dist=TRUNCATED_DIST * factor + ) * weight + + if dyn_chamfer_w <= 0: + continue + + tgt_lab_list = res_dict[f'{frame_id}_labels_list'] + proj_dyn, tgt_dyn = [], [] + for proj_i, p0_lab_i, tgt_i, tgt_lab_i in zip( + proj_list, pc0_lab_list, target_list, tgt_lab_list): + dp = proj_i[p0_lab_i > 0] + dt = tgt_i[tgt_lab_i > 0] + if dp.shape[0] > 256 and dt.shape[0] > 256: + proj_dyn.append(dp) + tgt_dyn.append(dt) + + if len(proj_dyn) == 1: + total_dynamic_chamfer_dis += MyCUDAChamferDis( + proj_dyn[0], tgt_dyn[0], truncate_dist=TRUNCATED_DIST * factor + ) * weight + elif len(proj_dyn) > 1: + total_dynamic_chamfer_dis += MyCUDAChamferDis.batched( + proj_dyn, tgt_dyn, truncate_dist=TRUNCATED_DIST * factor + ) * weight + + n = len(frame_keys) + if n > 0: + total_chamfer_dis /= n + total_dynamic_chamfer_dis /= n + + return total_chamfer_dis, total_dynamic_chamfer_dis, frame_keys + + +def multi_frames_clusterLoss( + pc0_list, pc0_lab_list, flow_list, + frame_keys, frames_dists, frames_indices, res_dict, args={} +): + """RANSAC-weighted cluster consistency loss across multiple temporal frames. + + frames_dists[frame_id] : List[(N_i,)] per-sample dist from batched_disid_res + frames_indices[frame_id] : List[(N_i,)] per-sample LOCAL idx into frame_list[i] + """ + TOP_K = int(args.get('top_k_candidates', 5)) + COS_THRESH = args.get('ransac_cos_threshold', 0.7071) + TIME_DECAY = args.get('time_decay_factor', 0.9) + NET_EST_W = args.get('network_estimate_weight', 1.0) + + all_cluster_flows, all_target_flows, all_avg_losses = [], [], [] + + for i, (p0, lab0, fv) in enumerate(zip(pc0_list, pc0_lab_list, flow_list)): + for label in torch.unique(lab0): + if label <= 1: continue - # Eq. 8 in the SeFlow paper - sorted_idxs = torch.argsort(cluster_nnd, descending=True) - nearby_label = pc1_label[raw_idx0[mask][sorted_idxs]] # nonzero means dynamic in label - non_zero_valid_indices = torch.nonzero(nearby_label > 0) - if non_zero_valid_indices.shape[0] <= 0: + cluster_mask = (lab0 == label) + cluster_flows = fv[cluster_mask] + + ext_flows, ext_dists, ext_tw = [], [], [] + for frame_id in frame_keys: + dist_c = frames_dists[frame_id][i][cluster_mask] + idx_c = frames_indices[frame_id][i][cluster_mask] + if dist_c.shape[0] <= TOP_K: + continue + topk_dists, topk_local = torch.topk(dist_c, k=TOP_K) + target_pts = res_dict[f'{frame_id}_list'][i][idx_c[topk_local]] + src_pts = p0[cluster_mask][topk_local] + time_delta, factor = get_time_delta(frame_id) + # Eq. 3 in the TeFlow paper, with time decay and directionality + flows = (target_pts - src_pts) / factor * (-1 if time_delta < 0 else 1) + ext_flows.append(flows) + ext_dists.append(topk_dists) + ext_tw.append(torch.full((TOP_K,), pow(TIME_DECAY, factor), device=p0.device)) + + if not ext_flows: continue - max_idx = sorted_idxs[non_zero_valid_indices.squeeze(1)[0]] - # Eq. 9 in the SeFlow paper - max_flow = pc1[raw_idx0[mask][max_idx]] - pc0[mask][max_idx] - - # Eq. 10 in the SeFlow paper - moved_cluster_norms = torch.cat((moved_cluster_norms, torch.linalg.vector_norm((cluster_id_flow - max_flow), dim=-1))) - - if moved_cluster_norms.shape[0] > 0: - moved_cluster_loss = moved_cluster_norms.mean() # Eq. 11 in the SeFlow paper - elif have_dynamic_cluster: - moved_cluster_loss = torch.mean(raw_dist0[raw_dist0 <= TRUNCATED_DIST]) + torch.mean(raw_dist1[raw_dist1 <= TRUNCATED_DIST]) - # timer[5][3].stop() - - res_loss = { - 'chamfer_dis': chamfer_dis / 2.0, - 'dynamic_chamfer_dis': dynamic_chamfer_dis / 2.0, - 'static_flow_loss': static_cluster_loss, + # Eq. 2 in the TeFlow paper + net_avg = cluster_flows.mean(dim=0) + net_mag = torch.linalg.norm(net_avg) + # Eq. 4 in the TeFlow paper + all_cands = torch.cat(ext_flows + [net_avg.unsqueeze(0)], dim=0) + all_d = torch.cat(ext_dists + [net_mag.unsqueeze(0)], dim=0) + all_tw = torch.cat(ext_tw, dim=0) + if all_cands.shape[0] < 2: + continue + + d_norm = (all_d - all_d.min()) / (all_d.max() - all_d.min() + 1e-6) + # Eq. 5 + cos_sim = torch.nn.functional.cosine_similarity( + all_cands[:, None, :], all_cands[None, :, :], dim=-1) + inlier = cos_sim > COS_THRESH + # Eq. 6 + weights = torch.cat([all_tw * (1 + d_norm[:-1]), + (NET_EST_W * (1 + d_norm[-1])).unsqueeze(0)]) + # Eq. 7 + scores = torch.matmul(inlier.float(), weights.unsqueeze(1)).squeeze() + best = torch.argmax(scores) + + # Eq. 8 + inlier_flows = all_cands[inlier[best]] + inlier_w = weights[inlier[best]] + denom = inlier_w.sum() + target_flow = (inlier_w.unsqueeze(1) * inlier_flows).sum(dim=0) / denom \ + if denom > 1e-6 else all_cands[best] + + all_cluster_flows.append(cluster_flows) + all_target_flows.append(target_flow.expand_as(cluster_flows)) + all_avg_losses.append( + torch.linalg.vector_norm(cluster_flows - target_flow, dim=-1).mean() + ) + + # FIXME(Qingwen): maybe afterward we can have weight here to specific different weight on point/cluster etc. + if not all_cluster_flows: + return torch.tensor(0.0, device=flow_list[0].device) + # Eq. 9 with two terms + # NOTE(Qingwen): Point-level term + loss = torch.nn.functional.mse_loss( + torch.cat(all_cluster_flows), torch.cat(all_target_flows) + ) + # NOTE(Qingwen): Cluster-level term + loss += torch.stack(all_avg_losses).mean() + return loss + + +# ---- shared cluster loop (seflow / seflowpp) ------------------- +# SeFlow Paper: https://arxiv.org/pdf/2407.01702 +def _seflow_cluster_loop(pc0_list, pc1_list, pc0_lab_list, pc1_lab_list, + flow_list, dist0_list, idx0_list): + """Per-sample seflow cluster loss (Eq. 6-11). + + dist0_list, idx0_list : output of batched_disid_res(pc0_list, pc1_list) + idx0_list[i] is LOCAL into pc1_list[i]. + Returns (static_cluster_loss, moved_cluster_loss, have_any_dynamic). + """ + dev = flow_list[0].device + static_loss = torch.tensor(0.0, device=dev) + cluster_norms = [] + fallback_dists = [] + have_any_dyn = False + + for p0, p1, lab0, lab1, fv, dist0, idx0 in zip( + pc0_list, pc1_list, pc0_lab_list, pc1_lab_list, + flow_list, dist0_list, idx0_list): + have_dyn = (lab0 > 0).sum() > 256 and (lab1 > 0).sum() > 256 + if have_dyn: + have_any_dyn = True + fallback_dists.append(dist0) + + for label in torch.unique(lab0): + mask = (lab0 == label) + if label == 0: + # Eq. 6 in the paper + static_loss += torch.linalg.vector_norm(fv[mask], dim=-1).mean() + elif label > 1 and have_dyn: + c_flow = fv[mask] + c_idx0 = idx0[mask] + # Eq. 8 in the paper + sorted_local = torch.argsort(dist0[mask], descending=True) + max_idx = torch.nonzero(lab1[c_idx0[sorted_local]] > 0).squeeze(1) + if max_idx.shape[0] == 0: + continue + best = sorted_local[max_idx[0]] + # Eq. 9 in the paper + max_flow = p1[c_idx0[best]] - p0[mask][best] + # Eq. 10 in the paper + cluster_norms.append(torch.linalg.vector_norm(c_flow - max_flow, dim=-1)) + + if cluster_norms: + # Eq. 11 + moved_loss = torch.cat(cluster_norms).mean() + elif have_any_dyn: + all_d = torch.cat(fallback_dists) + moved_loss = torch.mean(all_d[all_d <= TRUNCATED_DIST]) + else: + moved_loss = torch.tensor(0.0, device=dev) + + return static_loss, moved_loss + + +def teflowLoss(res_dict, timer=None): + """Temporal seflow: chamfer over all frames + static + RANSAC cluster loss.""" + pc0_list = res_dict['pc0_list'] + flow_list = res_dict['est_flow_list'] + pc0_lab_list = res_dict['pc0_labels_list'] + + chamfer_dis, dynamic_chamfer_dis, frame_keys = batched_chamfer_related(res_dict, timer) + + static_loss = torch.tensor(0.0, device=pc0_list[0].device) + for fv, lab in zip(flow_list, pc0_lab_list): + if (lab == 0).any(): + static_loss += torch.linalg.vector_norm(fv[lab == 0], dim=-1).mean() + static_loss /= max(len(pc0_list), 1) + + cluster_weight = res_dict['loss_weights_dict'].get('cluster_based_pc0pc1', 0.0) + if cluster_weight > 0: + frames_dists, frames_indices = {}, {} + for frame_id in frame_keys: + d_list, i_list = MyCUDAChamferDis.batched_disid_res( + pc0_list, res_dict[f'{frame_id}_list'], + ) + frames_dists[frame_id] = d_list + frames_indices[frame_id] = i_list + + moved_cluster_loss = multi_frames_clusterLoss( + pc0_list, pc0_lab_list, flow_list, + frame_keys, frames_dists, frames_indices, res_dict, + res_dict.get('cluster_loss_args', {}), + ) + else: + moved_cluster_loss = torch.tensor(0.0, device=pc0_list[0].device) + + return { + 'chamfer_dis': chamfer_dis, + 'dynamic_chamfer_dis': dynamic_chamfer_dis, + 'static_flow_loss': static_loss, 'cluster_based_pc0pc1': moved_cluster_loss, } - return res_loss -def seflowLoss(res_dict, timer=None): - pc0_label = res_dict['pc0_labels'] - pc1_label = res_dict['pc1_labels'] - - pc0 = res_dict['pc0'] - pc1 = res_dict['pc1'] - - est_flow = res_dict['est_flow'] - - pseudo_pc1from0 = pc0 + est_flow - - unique_labels = torch.unique(pc0_label) - pc0_dynamic = pc0[pc0_label > 0] - pc1_dynamic = pc1[pc1_label > 0] - # fpc1_dynamic = pseudo_pc1from0[pc0_label > 0] - # NOTE(Qingwen): since we set THREADS_PER_BLOCK is 256 - have_dynamic_cluster = (pc0_dynamic.shape[0] > 256) & (pc1_dynamic.shape[0] > 256) - - # first item loss: chamfer distance - # timer[5][1].start("MyCUDAChamferDis") - # raw: pc0 to pc1, est: pseudo_pc1from0 to pc1, idx means the nearest index - est_dist0, est_dist1, _, _ = MyCUDAChamferDis.disid_res(pseudo_pc1from0, pc1) - raw_dist0, raw_dist1, raw_idx0, _ = MyCUDAChamferDis.disid_res(pc0, pc1) - chamfer_dis = torch.mean(est_dist0[est_dist0 <= TRUNCATED_DIST]) + torch.mean(est_dist1[est_dist1 <= TRUNCATED_DIST]) - # timer[5][1].stop() - - # second item loss: dynamic chamfer distance - # timer[5][2].start("DynamicChamferDistance") - dynamic_chamfer_dis = torch.tensor(0.0, device=est_flow.device) - if have_dynamic_cluster: - dynamic_chamfer_dis += MyCUDAChamferDis(pseudo_pc1from0[pc0_label>0], pc1_dynamic, truncate_dist=TRUNCATED_DIST) - # timer[5][2].stop() - - # third item loss: exclude static points' flow - # NOTE(Qingwen): add in the later part on label==0 - static_cluster_loss = torch.tensor(0.0, device=est_flow.device) - - # fourth item loss: same label points' flow should be the same - # timer[5][3].start("SameClusterLoss") - moved_cluster_loss = torch.tensor(0.0, device=est_flow.device) - moved_cluster_norms = torch.tensor([], device=est_flow.device) - for label in unique_labels: - mask = pc0_label == label - if label == 0: - # Eq. 6 in the paper - static_cluster_loss += torch.linalg.vector_norm(est_flow[mask, :], dim=-1).mean() - # NOTE(Qingwen) 2025-04-23: label=1 is dynamic but no cluster id satisfied - elif label > 1 and have_dynamic_cluster: - cluster_id_flow = est_flow[mask, :] - cluster_nnd = raw_dist0[mask] - if cluster_nnd.shape[0] <= 0: - continue +# from paper: https://arxiv.org/abs/2503.00803 +def seflowppLoss(res_dict, timer=None): + """seflow++ loss: bidirectional (pc1 + pch1) chamfer + cluster, B samples.""" + pc0_list = res_dict['pc0_list'] + pc1_list = res_dict['pc1_list'] + pch1_list = res_dict['pch1_list'] + flow_list = res_dict['est_flow_list'] + pc0_lab_list = res_dict['pc0_labels_list'] + pc1_lab_list = res_dict['pc1_labels_list'] + pch1_lab_list = res_dict['pch1_labels_list'] + dev = pc0_list[0].device - # Eq. 8 in the paper - sorted_idxs = torch.argsort(cluster_nnd, descending=True) - nearby_label = pc1_label[raw_idx0[mask][sorted_idxs]] # nonzero means dynamic in label - non_zero_valid_indices = torch.nonzero(nearby_label > 0) - if non_zero_valid_indices.shape[0] <= 0: - continue - max_idx = sorted_idxs[non_zero_valid_indices.squeeze(1)[0]] - - # Eq. 9 in the paper - max_flow = pc1[raw_idx0[mask][max_idx]] - pc0[mask][max_idx] - - # Eq. 10 in the paper - moved_cluster_norms = torch.cat((moved_cluster_norms, torch.linalg.vector_norm((cluster_id_flow - max_flow), dim=-1))) - - if moved_cluster_norms.shape[0] > 0: - moved_cluster_loss = moved_cluster_norms.mean() # Eq. 11 in the paper - elif have_dynamic_cluster: - moved_cluster_loss = torch.mean(raw_dist0[raw_dist0 <= TRUNCATED_DIST]) + torch.mean(raw_dist1[raw_dist1 <= TRUNCATED_DIST]) - # timer[5][3].stop() - - res_loss = { - 'chamfer_dis': chamfer_dis, - 'dynamic_chamfer_dis': dynamic_chamfer_dis, - 'static_flow_loss': static_cluster_loss, + fwd_list = [p0 + fv for p0, fv in zip(pc0_list, flow_list)] + bwd_list = [p0 - fv for p0, fv in zip(pc0_list, flow_list)] + + # Chamfer: both temporal directions concurrently + chamfer_dis = MyCUDAChamferDis.batched(fwd_list, pc1_list, truncate_dist=TRUNCATED_DIST) + chamfer_dis += MyCUDAChamferDis.batched(bwd_list, pch1_list, truncate_dist=TRUNCATED_DIST) + + # Dynamic chamfer + dyn_fwd, dyn_pc1 = [], [] + dyn_bwd, dyn_pch1 = [], [] + for fwd_i, bwd_i, p1_i, ph1_i, lab0_i, lab1_i, labh1_i in zip( + fwd_list, bwd_list, pc1_list, pch1_list, + pc0_lab_list, pc1_lab_list, pch1_lab_list): + dyn_mask = lab0_i > 0 + if dyn_mask.sum() > 256: + dp1 = p1_i[lab1_i > 0] + dph = ph1_i[labh1_i > 0] + if dp1.shape[0] > 256: dyn_fwd.append(fwd_i[dyn_mask]); dyn_pc1.append(dp1) + if dph.shape[0] > 256: dyn_bwd.append(bwd_i[dyn_mask]); dyn_pch1.append(dph) + + dynamic_chamfer_dis = torch.tensor(0.0, device=dev) + if len(dyn_fwd) == 1: + dynamic_chamfer_dis += MyCUDAChamferDis(dyn_fwd[0], dyn_pc1[0], truncate_dist=TRUNCATED_DIST) + elif len(dyn_fwd) > 1: + dynamic_chamfer_dis += MyCUDAChamferDis.batched(dyn_fwd, dyn_pc1, truncate_dist=TRUNCATED_DIST) + if len(dyn_bwd) == 1: + dynamic_chamfer_dis += MyCUDAChamferDis(dyn_bwd[0], dyn_pch1[0], truncate_dist=TRUNCATED_DIST) + elif len(dyn_bwd) > 1: + dynamic_chamfer_dis += MyCUDAChamferDis.batched(dyn_bwd, dyn_pch1, truncate_dist=TRUNCATED_DIST) + + dist0_list, idx0_list = MyCUDAChamferDis.batched_disid_res(pc0_list, pc1_list) + static_loss, moved_cluster_loss = _seflow_cluster_loop( + pc0_list, pc1_list, pc0_lab_list, pc1_lab_list, + flow_list, dist0_list, idx0_list, + ) + + return { + 'chamfer_dis': chamfer_dis / 2.0, + 'dynamic_chamfer_dis': dynamic_chamfer_dis / 2.0, + 'static_flow_loss': static_loss, 'cluster_based_pc0pc1': moved_cluster_loss, } - return res_loss + +# from paper: https://arxiv.org/abs/2407.01702 +def seflowLoss(res_dict, timer=None): + """seflow loss: single future frame (pc1), batched over B samples.""" + pc0_list = res_dict['pc0_list'] + pc1_list = res_dict['pc1_list'] + flow_list = res_dict['est_flow_list'] + pc0_lab_list = res_dict['pc0_labels_list'] + pc1_lab_list = res_dict['pc1_labels_list'] + dev = pc0_list[0].device + + fwd_list = [p0 + fv for p0, fv in zip(pc0_list, flow_list)] + + chamfer_dis = MyCUDAChamferDis.batched(fwd_list, pc1_list, truncate_dist=TRUNCATED_DIST) + + # Dynamic chamfer + dyn_fwd, dyn_pc1 = [], [] + for fwd_i, p1_i, lab0_i, lab1_i in zip(fwd_list, pc1_list, pc0_lab_list, pc1_lab_list): + dp1 = p1_i[lab1_i > 0] + if (lab0_i > 0).sum() > 256 and dp1.shape[0] > 256: + dyn_fwd.append(fwd_i[lab0_i > 0]) + dyn_pc1.append(dp1) + + dynamic_chamfer_dis = torch.tensor(0.0, device=dev) + if len(dyn_fwd) == 1: + dynamic_chamfer_dis = MyCUDAChamferDis(dyn_fwd[0], dyn_pc1[0], truncate_dist=TRUNCATED_DIST) + elif len(dyn_fwd) > 1: + dynamic_chamfer_dis = MyCUDAChamferDis.batched(dyn_fwd, dyn_pc1, truncate_dist=TRUNCATED_DIST) + + dist0_list, idx0_list = MyCUDAChamferDis.batched_disid_res(pc0_list, pc1_list) + static_loss, moved_cluster_loss = _seflow_cluster_loop( + pc0_list, pc1_list, pc0_lab_list, pc1_lab_list, + flow_list, dist0_list, idx0_list, + ) + + return { + 'chamfer_dis': chamfer_dis, + 'dynamic_chamfer_dis': dynamic_chamfer_dis, + 'static_flow_loss': static_loss, + 'cluster_based_pc0pc1': moved_cluster_loss, + } \ No newline at end of file diff --git a/src/models/basic/voteflow_plugin/hough_transformation/cpp_im2ht/setup.py b/src/models/basic/voteflow_plugin/hough_transformation/cpp_im2ht/setup.py index 6c2521f..d1e1559 100755 --- a/src/models/basic/voteflow_plugin/hough_transformation/cpp_im2ht/setup.py +++ b/src/models/basic/voteflow_plugin/hough_transformation/cpp_im2ht/setup.py @@ -1,6 +1,11 @@ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension +extra_compile_args = { + 'cxx': ['-DCCCL_IGNORE_DEPRECATED_CUDA_BELOW_12', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'], + 'nvcc': ['-DCCCL_IGNORE_DEPRECATED_CUDA_BELOW_12', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'], +} + setup( name='im2ht', # ext_modules=[ @@ -8,7 +13,8 @@ # extra_compile_args={'cxx': ['-g'], 'nvcc': ['-arch=sm_60']}), # ], ext_modules=[ - CUDAExtension(name='im2ht', sources=['im2ht.cpp', 'ht_cuda.cu']), + CUDAExtension(name='im2ht', sources=['im2ht.cpp', 'ht_cuda.cu'], + extra_compile_args=extra_compile_args), ], cmdclass={ 'build_ext': BuildExtension diff --git a/src/models/basic/voteflow_plugin/hough_transformation/im2ht.py b/src/models/basic/voteflow_plugin/hough_transformation/im2ht.py index e4d12e7..927e823 100755 --- a/src/models/basic/voteflow_plugin/hough_transformation/im2ht.py +++ b/src/models/basic/voteflow_plugin/hough_transformation/im2ht.py @@ -7,22 +7,27 @@ from torch.autograd.function import once_differentiable def load_cpp_ext(ext_name): - root_dir = os.path.join(os.path.split(__file__)[0]) - src_dir = os.path.join(root_dir, "cpp_im2ht") - tar_dir = os.path.join(src_dir, "build", ext_name) - os.makedirs(tar_dir, exist_ok=True) - srcs = glob(f"{src_dir}/*.cu") + glob(f"{src_dir}/*.cpp") + try: + import im2ht + ext = im2ht + except ImportError: + print(f"Compiling {ext_name} cpp/cuda extension...") + root_dir = os.path.join(os.path.split(__file__)[0]) + src_dir = os.path.join(root_dir, "cpp_im2ht") + tar_dir = os.path.join(src_dir, "build", ext_name) + os.makedirs(tar_dir, exist_ok=True) + srcs = glob(f"{src_dir}/*.cu") + glob(f"{src_dir}/*.cpp") - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - from torch.utils.cpp_extension import load - ext = load( - name=ext_name, - sources=srcs, - extra_cflags=["-O3"], - extra_cuda_cflags=[], - build_directory=tar_dir, - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from torch.utils.cpp_extension import load + ext = load( + name=ext_name, + sources=srcs, + extra_cflags=["-O3"], + extra_cuda_cflags=["-DTHRUST_IGNORE_CUB_VERSION_CHECK"], + build_directory=tar_dir, + ) return ext # defer calling load_cpp_ext to make CUDA_VISIBLE_DEVICES happy diff --git a/src/trainer.py b/src/trainer.py index de064f1..ffeb365 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -19,12 +19,13 @@ from lightning import LightningModule from hydra.utils import instantiate -from omegaconf import OmegaConf,open_dict +from omegaconf import OmegaConf, open_dict -import os, sys, time, h5py, pickle +import os, sys, time, h5py BASE_DIR = os.path.abspath(os.path.join( os.path.dirname( __file__ ), '..' )) sys.path.append(BASE_DIR) from src.utils import import_func +from src.lossfuncs import SSL_LOSSES_FN from src.utils.mics import weights_init, zip_res from src.utils.av2_eval import write_output_file from src.models.basic import cal_pose0to1, WarmupCosLR @@ -51,9 +52,12 @@ def __init__(self, cfg, eval=False): "save_res": False, "res_name": "default", "num_frames": 2, + + # lr scheduler, only active when warmup_epochs > 0 "optimizer": None, "dataset_path": None, "data_mode": None, + "cluster_loss_args": {}, } for key, default in default_self_values.items(): setattr(self, key, cfg.get(key, default)) @@ -92,74 +96,124 @@ def __init__(self, cfg, eval=False): self.save_res_path = Path(cfg.dataset_path).parent / "results" / cfg.output os.makedirs(self.save_res_path, exist_ok=True) print(f"We are in {cfg.data_mode}, results will be saved in: {self.save_res_path} with version: {self.leaderboard_version} format for online leaderboard.") - - # self.test_total_num = 0 if self.data_mode in ['val', 'valid', 'test']: print(cfg) + # self.test_total_num = 0 self.save_hyperparameters() + + def ssl_loss_calculator(self, batch, res_dict, if_log=True): + """Build dict2loss for ALL self-supervised losses (seflow, seflowpp, teflow*). - # FIXME(Qingwen 2025-08-20): update the loss_calculation fn alone to make all things pretty here.... - def training_step(self, batch, batch_idx): - self.model.timer[4].start("One Scan in model") - res_dict = self.model(batch) - self.model.timer[4].stop() + Each frame is represented only as a List[Tensor] and a List[labels]. + No flat tensors, no offsets, no sizes — chamfer calls use list APIs only. + """ + total_loss, bz_ = 0.0, len(batch["pose0"]) - self.model.timer[5].start("Loss") - # compute loss - total_loss = 0.0 + pc0_list = [res_dict['pc0_points_lst'][i] for i in range(bz_)] - if self.cfg_loss_name in ['seflowLoss', 'seflowppLoss']: - loss_items, weights = zip(*[(key, weight) for key, weight in self.add_seloss.items()]) - loss_logger = {'chamfer_dis': 0.0, 'dynamic_chamfer_dis': 0.0, 'static_flow_loss': 0.0, 'cluster_based_pc0pc1': 0.0} - else: - loss_items, weights = ['loss'], [1.0] - loss_logger = {'loss': 0.0} - - pc0_valid_idx = res_dict['pc0_valid_point_idxes'] # since padding - pc1_valid_idx = res_dict['pc1_valid_point_idxes'] # since padding - if 'pc0_points_lst' in res_dict and 'pc1_points_lst' in res_dict: - pc0_points_lst = res_dict['pc0_points_lst'] - pc1_points_lst = res_dict['pc1_points_lst'] + dict2loss = { + 'pc0_list': pc0_list, + 'est_flow_list': [res_dict['flow'][i] for i in range(bz_)], + 'pc0_labels_list': [batch['pc0_dynamic'][i][res_dict['pc0_valid_point_idxes'][i]] for i in range(bz_)], + 'batch_size': bz_, + } + + frame_keys = [key.replace('_points_lst', '') for key in res_dict.keys() + if key.startswith('pc') and key.endswith('_points_lst')] + frame_keys.remove('pc0') + + for frame_id in frame_keys: + points_list = [res_dict[f'{frame_id}_points_lst'][i] for i in range(bz_)] + labels_list = [batch[f'{frame_id}_dynamic'][i][res_dict[f'{frame_id}_valid_point_idxes'][i]] for i in range(bz_)] + dict2loss[f'{frame_id}_list'] = points_list + dict2loss[f'{frame_id}_labels_list'] = labels_list + + loss_items, weights = zip(*[(key, weight) for key, weight in self.add_seloss.items()]) + dict2loss['loss_weights_dict'] = self.add_seloss + + dict2loss['cluster_loss_args'] = self.cluster_loss_args + + res_loss = self.loss_fn(dict2loss) + + for i, loss_name in enumerate(loss_items): + if not torch.isnan(res_loss[loss_name]): + total_loss += weights[i] * res_loss[loss_name] - batch_sizes = len(batch["pose0"]) - pose_flows = res_dict['pose_flow'] - est_flow = res_dict['flow'] + if if_log: + self.log("trainer/loss", total_loss, sync_dist=True, batch_size=bz_, prog_bar=True) + for key in res_loss: + self.log(f"trainer/{key}", res_loss[key], sync_dist=True, batch_size=bz_) + + return total_loss + + def loss_calculator(self, batch, res_dict, if_log=True): + """ Calculate the loss based on the batch (gt/ssl-label) and res_dict (estimate flow).""" + def get_batch_data(batch, key, batch_id, batch_sizes, pc0_valid_from_pc2res, pose_flow_=None): + """NOTE(Qingwen): for gt need double check whether it exists in the batch and batch size is correct""" + if key not in batch or batch[key].shape[0] != batch_sizes: + return None + data = batch[key][batch_id][pc0_valid_from_pc2res] + if key == 'flow' and pose_flow_ is not None: + data = data - pose_flow_ + return data + def get_frame_keys(data_dict, suffix): + return [key for key in data_dict.keys() if key.endswith(suffix)] + def extract_frame_id(key, suffix): + """Extract frame identifier from key (e.g., 'pc0_points_lst' -> 'pc0')""" + return key.replace(suffix, '') + # Supervised-only path (deflowLoss, etc.) + # SSL losses are handled by ssl_loss_calculator. + total_loss, loss_logger = 0.0, {} + loss_items, weights = ['loss'], [1.0] + for key in loss_items: + loss_logger[key] = 0.0 + + batch_sizes, pose_flows, est_flow = len(batch["pose0"]), res_dict['pose_flow'], res_dict['flow'] for batch_id in range(batch_sizes): - pc0_valid_from_pc2res = pc0_valid_idx[batch_id] - pc1_valid_from_pc2res = pc1_valid_idx[batch_id] + # Get pc0 valid indices (main reference frame) + pc0_valid_from_pc2res = res_dict['pc0_valid_point_idxes'][batch_id] pose_flow_ = pose_flows[batch_id][pc0_valid_from_pc2res] dict2loss = {'est_flow': est_flow[batch_id], - 'gt_flow': None if 'flow' not in batch else batch['flow'][batch_id][pc0_valid_from_pc2res] - pose_flow_, - 'gt_classes': None if 'flow_category_indices' not in batch else batch['flow_category_indices'][batch_id][pc0_valid_from_pc2res], - 'gt_instance': None if 'flow_instance_id' not in batch else batch['flow_instance_id'][batch_id][pc0_valid_from_pc2res],} + 'gt_flow': get_batch_data(batch, 'flow', batch_id, batch_sizes, pc0_valid_from_pc2res, pose_flow_), + 'gt_classes': get_batch_data(batch, 'flow_category_indices', batch_id, batch_sizes, pc0_valid_from_pc2res), + 'gt_instance': get_batch_data(batch, 'flow_instance_id', batch_id, batch_sizes, pc0_valid_from_pc2res)} - if 'pc0_dynamic' in batch: - dict2loss['pc0_labels'] = batch['pc0_dynamic'][batch_id][pc0_valid_from_pc2res] - dict2loss['pc1_labels'] = batch['pc1_dynamic'][batch_id][pc1_valid_from_pc2res] - if 'pch1_dynamic' in batch and 'pch1_valid_point_idxes' in res_dict: - dict2loss['pch1_labels'] = batch['pch1_dynamic'][batch_id][res_dict['pch1_valid_point_idxes'][batch_id]] - - # different methods may don't have this in the res_dict - if 'pc0_points_lst' in res_dict and 'pc1_points_lst' in res_dict: - dict2loss['pc0'] = pc0_points_lst[batch_id] - dict2loss['pc1'] = pc1_points_lst[batch_id] - if 'pch1_points_lst' in res_dict: - dict2loss['pch1'] = res_dict['pch1_points_lst'][batch_id] + # Add all available point cloud frames + for points_key in get_frame_keys(res_dict, '_points_lst'): + frame_id = extract_frame_id(points_key, '_points_lst') + if points_key in res_dict: + dict2loss[frame_id] = res_dict[points_key][batch_id] res_loss = self.loss_fn(dict2loss) + for i, loss_name in enumerate(loss_items): + # if torch.isnan(res_loss[loss_name]): + # print(f"==> Loss: {loss_name} is nan, skip this batch.") + # continue total_loss += weights[i] * res_loss[loss_name] for key in res_loss: loss_logger[key] += res_loss[key] + if if_log: + self.log("trainer/loss", total_loss/batch_sizes, sync_dist=True, batch_size=self.batch_size, prog_bar=True) + return total_loss + + def training_step(self, batch, batch_idx): + total_loss = 0.0 + self.model.timer[5].start("Training Step") + self.model.timer[5][0].start("Forward") + res_dict = self.model(batch) + self.model.timer[5][0].stop() + self.model.timer[5][1].start("Compute Loss") - self.log("trainer/loss", total_loss/batch_sizes, sync_dist=True, batch_size=self.batch_size, prog_bar=True) - if self.add_seloss is not None and self.cfg_loss_name in ['seflowLoss', 'seflowppLoss']: - for key in loss_logger: - self.log(f"trainer/{key}", loss_logger[key]/batch_sizes, sync_dist=True, batch_size=self.batch_size) + if self.cfg_loss_name in SSL_LOSSES_FN: + total_loss = self.ssl_loss_calculator(batch, res_dict) + else: + total_loss = self.loss_calculator(batch, res_dict) + self.model.timer[5][1].stop() self.model.timer[5].stop() - + # NOTE (Qingwen): if you want to view the detail breakdown of time cost # self.model.timer.print(random_colors=False, bold=False) return total_loss @@ -206,6 +260,8 @@ def on_train_epoch_start(self): def on_train_epoch_end(self): self.log("pre_epoch_cost (mins)", (time.time()-self.time_start_train_epoch)/60.0, on_step=False, on_epoch=True, sync_dist=True) + # # NOTE (Qingwen): if you want to view the detail breakdown of time cost + # self.model.timer.print(random_colors=False, bold=False) def on_validation_epoch_end(self): self.model.timer.print(random_colors=False, bold=False) @@ -223,9 +279,9 @@ def on_validation_epoch_end(self): # wandb.log_artifact(output_file) return - if self.data_mode == 'val': + if self.data_mode in ['val', 'valid']: print(f"\nModel: {self.model.__class__.__name__}, Checkpoint from: {self.checkpoint}") - print(f"More details parameters and training status are in the checkpoint file.") + print(f"More details parameters and training status are in checkpoints file.") self.metrics.normalize() @@ -238,15 +294,12 @@ def on_validation_epoch_end(self): self.metrics.print() + self.metrics = OfficialMetrics() + if self.save_res: - # Save the dictionaries to a pickle file - with open(str(self.save_res_path)+'.pkl', 'wb') as f: - pickle.dump((self.metrics.epe_3way, self.metrics.bucketed, self.metrics.epe_ssf), f) - print(f"We already write the {self.res_name} into the dataset, please run following commend to visualize the flow. Copy and paste it to your terminal:") - print(f"python tools/visualization.py vis --res_name '{self.res_name}' --data_dir {self.dataset_path}") + print(f"We already write the flow_est into the dataset, please run following commend to visualize the flow. Copy and paste it to your terminal:") + print(f"python tools/visualization.py --res_name \"['{self.res_name}']\" --data_dir {self.dataset_path}") print(f"Enjoy! ^v^ ------ \n") - - self.metrics = OfficialMetrics() def eval_only_step_(self, batch, res_dict): eval_mask = batch['eval_mask'].squeeze() @@ -261,24 +314,35 @@ def eval_only_step_(self, batch, res_dict): # flow in the original pc0 coordinate pred_flow = pose_flow[~batch['gm0']].clone() + # debug: for ego-motion flow only + # res_dict['flow'] = torch.zeros_like(res_dict['flow']) pred_flow[valid_from_pc2res] = res_dict['flow'] + pose_flow[~batch['gm0']][valid_from_pc2res] final_flow[~batch['gm0']] = pred_flow else: final_flow[~batch['gm0']] = res_dict['flow'] + pose_flow[~batch['gm0']] - if self.data_mode == 'val': # since only val we have ground truth flow to eval + if self.data_mode in ['val', 'valid']: # since only val we have ground truth flow to eval gt_flow = batch["flow"] v1_dict = evaluate_leaderboard(final_flow[eval_mask], pose_flow[eval_mask], pc0[eval_mask], \ gt_flow[eval_mask], batch['flow_is_valid'][eval_mask], \ batch['flow_category_indices'][eval_mask]) v2_dict = evaluate_leaderboard_v2(final_flow[eval_mask], pose_flow[eval_mask], pc0[eval_mask], \ gt_flow[eval_mask], batch['flow_is_valid'][eval_mask], batch['flow_category_indices'][eval_mask]) - ssf_dict = evaluate_ssf(final_flow, pose_flow, pc0, \ - gt_flow, batch['flow_is_valid'], batch['flow_category_indices']) + ssf_dict = evaluate_ssf(final_flow[eval_mask], pose_flow[eval_mask], pc0[eval_mask], \ + gt_flow[eval_mask], batch['flow_is_valid'][eval_mask], batch['flow_category_indices'][eval_mask]) + self.metrics.step(v1_dict, v2_dict, ssf_dict) - + if self.save_res: + # write final_flow into the dataset. + key = str(batch['timestamp']) + scene_id = batch['scene_id'] + with h5py.File(os.path.join(self.dataset_path, f'{self.data_mode}/{scene_id}.h5'), 'r+') as f: + if self.res_name in f[key]: + del f[key][self.res_name] + f[key].create_dataset(self.res_name, data=final_flow.cpu().detach().numpy().astype(np.float32)) + # NOTE (Qingwen): Since val and test, we will force set batch_size = 1 - if self.save_res or self.data_mode == 'test': # test must save data to submit in the online leaderboard. + if self.save_res and self.data_mode == 'test': # test must save data to submit in the online leaderboard. save_pred_flow = final_flow[eval_mask, :3].cpu().detach().numpy() rigid_flow = pose_flow[eval_mask, :3].cpu().detach().numpy() is_dynamic = np.linalg.norm(save_pred_flow - rigid_flow, axis=1, ord=2) >= 0.05 @@ -302,19 +366,22 @@ def run_model_wo_ground_data(self, batch): # NOTE (Qingwen): Since val and test, we will force set batch_size = 1 batch = {key: batch[key][0] for key in batch if len(batch[key])>0} - res_dict = {key: res_dict[key][0] for key in res_dict if res_dict[key]!=None and len(res_dict[key])>0} + res_dict = {key: res_dict[key][0] for key in res_dict if (res_dict[key]!=None and len(res_dict[key])>0) } return batch, res_dict def validation_step(self, batch, batch_idx): - if self.data_mode in ['val', 'test']: - batch, res_dict = self.run_model_wo_ground_data(batch) - self.model.timer[13].start("Eval") - self.eval_only_step_(batch, res_dict) - self.model.timer[13].stop() - else: - res_dict = self.model(batch) - self.train_validation_step_(batch, res_dict) - + try: + if self.data_mode in ['val', 'valid'] or self.data_mode == 'test': + batch, res_dict = self.run_model_wo_ground_data(batch) + if batch['eval_flag']: + self.eval_only_step_(batch, res_dict) + else: + res_dict = self.model(batch) + self.train_validation_step_(batch, res_dict) + except Exception as e: + print(f"==> Exception occur during training/validation step: {e}. Skip this batch.") + print(f"Batch info: scene_id: {batch['scene_id']}, timestamp: {batch['timestamp']}, pc0 size: {batch['pc0']}") + def test_step(self, batch, batch_idx): batch, res_dict = self.run_model_wo_ground_data(batch) pc0 = batch['origin_pc0'] @@ -346,5 +413,5 @@ def on_test_epoch_end(self): self.model.timer.print(random_colors=False, bold=False) print(f"\n\nModel: {self.model.__class__.__name__}, Checkpoint from: {self.checkpoint}") print(f"We already write the flow_est into the dataset, please run following commend to visualize the flow. Copy and paste it to your terminal:") - print(f"python tools/visualization.py --res_name '{self.res_name}' --data_dir {self.dataset_path}") + print(f"python tools/visualization.py --res_name \"['{self.res_name}']\" --data_dir {self.dataset_path}") print(f"Enjoy! ^v^ ------ \n") diff --git a/train.py b/train.py index b7d7eaf..f5ed7ca 100644 --- a/train.py +++ b/train.py @@ -28,9 +28,9 @@ from src.dataset import HDF5Dataset, collate_fn_pad, RandomHeight, RandomFlip, RandomJitter, ToTensor from torchvision import transforms from src.trainer import ModelWrapper - +from src.lossfuncs import SSL_LOSSES_FN def precheck_cfg_valid(cfg): - if cfg.loss_fn in ['seflowLoss', 'seflowppLoss'] and (cfg.add_seloss is None or cfg.ssl_label is None): + if cfg.loss_fn in SSL_LOSSES_FN and (cfg.add_seloss is None or cfg.ssl_label is None): raise ValueError("Please specify the self-supervised loss items and auto-label source for seflow-series loss.") grid_size = [(cfg.point_cloud_range[3] - cfg.point_cloud_range[0]) * (1/cfg.voxel_size[0]), @@ -83,7 +83,7 @@ def main(cfg): output_dir = HydraConfig.get().runtime.output_dir # overwrite logging folder name for SSL. - if cfg.loss_fn in ['seflowLoss', 'seflowppLoss']: + if cfg.loss_fn in SSL_LOSSES_FN: tmp_ = cfg.loss_fn.split('Loss')[0] + '-' + cfg.model.name cfg.output = cfg.output.replace(cfg.model.name, tmp_) output_dir = output_dir.replace(cfg.model.name, tmp_)