MalisLoss: pass eroded gt_seg through data pipeline (skip per-step CC, fix crop topology) #211
Open
donglaiw wants to merge 1 commit into
…, fix crop topology)

Adds an opt-in `gt_seg` plumbing path that pipes the eroded GT segmentation from the data pipeline through to MalisLoss, so MALIS can skip its internal `connected_components_affgraph(gt_affs, nhood)` call and use the supplied seg directly.

Primary motivation (correctness): with `malis_crop_size` enabled (landed in commit 5b0451f), the current MALIS path runs CC on the *cropped* gt_affs. When a single GT instance spans the crop boundary so that two of its pieces appear inside the crop but are not connected within that window, CC labels them as distinct components, and MALIS then injects spurious negative-constraint edges inside one true instance. Passing the eroded gt_seg from upstream (where global instance IDs are preserved) and cropping it with the same origin fixes this.

Secondary motivation (small speedup): removes the per-step CC inside `_compute_malis_weights`. Roughly 5-10% of the remaining MALIS step cost on top of cropping.

Pipeline boundary
- `LabelTransformConfig.emit_gt_seg: bool = False` (strict dataclass field, defaults to off).
- `MultiTaskLabelTransformd` is left untouched. Instead, a `CopyItemsd(keys="label", names="gt_seg")` MONAI transform is inserted in `data/augmentation/build.py` immediately after the existing `SegErosionInstanced` step (both train and val), gated by `label_cfg.emit_gt_seg`. This makes "post-augment, post-erode" the canonical snapshot point.
- Default behavior (no `emit_gt_seg`) is bit-for-bit identical to pre-PR `MalisLoss`.

Loss orchestrator
- New `LossMetadata.gt_seg_arg: Optional[str]`. Set to `"gt_seg"` for MalisLoss; all other losses unaffected.
- `compute_standard_loss(..., gt_seg=None)` plumbs the batch's `gt_seg` to any term whose metadata declares the arg, via the existing `extra_loss_kwargs` extension point.
- `compute_deep_supervision_loss(..., gt_seg=None)` forces `gt_seg=None` for every head (DS lower heads work on downsampled targets that gt_seg can't match label-correctly). MalisLoss + DS falls back to the legacy CC-recompute path. Pinned by a unit test.

MalisLoss
- `forward(pred, target, mask=None, gt_seg=None)` accepts the optional kwarg.
- `_prepare_gt_seg` normalizes shape to `[B, Z, Y, X]` (accepts both `[B, Z, Y, X]` and `[B, 1, Z, Y, X]`), validates against pred spatial dims.
- `_apply_crop_if_configured` now returns a 4-tuple and crops gt_seg at the same origin as pred/target/mask.
- `_compute_malis_weights(..., *, gt_seg=None)` uses the supplied seg per sample when provided; falls back to `connected_components_affgraph` otherwise.

Speedup story (measured + documented)

Production config: MedNeXt-L, batch 2, 128^3 patch on L40S.

| Config                              | it/s   | hours/epoch (5000 steps) |
|-------------------------------------|--------|--------------------------|
| BCE only (no MALIS)                 | ~0.71  | ~1.95                    |
| Full-volume MALIS (original)        | ~0.17  | ~7.3                     |
| MALIS + malis_crop_size=64          | ~0.78  | ~1.78                    |
| MALIS + crop=64 + emit_gt_seg=true  | ~0.78+ | ~1.78                    |

Crop alone gives ~4.6x speedup vs the original full-volume MALIS (slurm 2505814 vs 2487040). emit_gt_seg adds a small additional speedup on top, and is primarily a correctness fix for the cropped case. See `docs/source/notes/malis.rst` and `MalisLoss` class docstring.

YAML opt-in (`tutorials/neuron_nisb/base_banis+_malis.yaml`)
- `default.data.label_transform.emit_gt_seg: true` enables the passthrough path.
- `malis_crop_size: 64` uncommented to enable the crop in production.

Tests
- `tests/unit/test_malis_loss.py`: 6 new tests covering metadata, uncropped CC equivalence, `gt_seg=None` legacy preservation, the cropped-instance-fragmentation bug fix (Path A spy asserts CC was called and labels fragment; Path B asserts CC was NOT called and the supplied gt_seg retains a single instance label), shape validation, and that no grad flows through gt_seg.
- `tests/unit/test_data_factory.py`: 2 new tests for the CopyItemsd insertion (train + val symmetry, and a runtime end-to-end equivalence to a manual erosion).
- `tests/unit/test_loss_orchestrator.py`: 2 new tests pinning the standard-loss gt_seg plumbing and the DS=legacy fallback.
- All existing test_malis_loss cases unchanged.
- `python -m pytest tests/unit/test_malis_loss.py tests/unit/test_data_factory.py tests/unit/test_loss_orchestrator.py -q` -> 76 passed, 1 skipped.

Docs
- New `docs/source/notes/malis.rst` covering MALIS speedup and correctness knobs (linked from `docs/source/index.rst` Get Started toctree).
- `MalisLoss` class docstring extended with a Performance section citing the measured speedup.

CCC design history under `.agent/features/malis_gt_passthrough/` (gitignored).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
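The crop-fragmentation artifact from the primary motivation can be reproduced on a toy grid. The sketch below is illustrative only and makes several assumptions: it works in 2D on plain Python lists, and the hypothetical helper `label_components` (a simple 4-connected flood fill on a foreground mask) stands in for `connected_components_affgraph`, which operates on affinities in the real code. The helper `crop` and the instance ID `7` are likewise made up for the example.

```python
from collections import deque


def label_components(mask):
    """4-connected component labelling of a 2D boolean grid (stand-in for CC)."""
    h, w = len(mask), len(mask[0])
    labels = [[0] * w for _ in range(h)]
    next_id = 0
    for y in range(h):
        for x in range(w):
            if not mask[y][x] or labels[y][x]:
                continue
            next_id += 1
            labels[y][x] = next_id
            queue = deque([(y, x)])
            while queue:
                cy, cx = queue.popleft()
                for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                    if 0 <= ny < h and 0 <= nx < w and mask[ny][nx] and not labels[ny][nx]:
                        labels[ny][nx] = next_id
                        queue.append((ny, nx))
    return labels


def crop(grid, y0, x0, size):
    """Square crop with origin (y0, x0); the same origin must be used for every tensor."""
    return [row[x0:x0 + size] for row in grid[y0:y0 + size]]


def count_components(grid):
    labels = label_components([[v != 0 for v in row] for row in grid])
    return max(max(row) for row in labels)


# One U-shaped instance (global ID 7); its two arms only join in the last row.
gt_seg = [
    [7, 0, 7, 0],
    [7, 0, 7, 0],
    [7, 0, 7, 0],
    [7, 7, 7, 0],
]

full_count = count_components(gt_seg)           # whole volume: one instance
cropped = crop(gt_seg, 0, 0, 3)                 # crop excludes the joining row

# Legacy path: recompute CC inside the crop, so the arms become two components.
n_recomputed = count_components(cropped)

# Passthrough path: crop the upstream seg, so the global ID survives.
passthrough_ids = {v for row in cropped for v in row if v != 0}

print(full_count, n_recomputed, passthrough_ids)
```

In this toy case the recomputed labelling splits the single instance in two, while the cropped upstream segmentation still carries one ID, which is the behaviour the Path A / Path B test is described as pinning.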
Summary

Passes the eroded GT segmentation from the data pipeline through to `MalisLoss` so it can:

- Skip the `connected_components_affgraph(gt_affs, nhood)` call every training step (~5–10% of MALIS step cost; small).
- Preserve instance topology under `malis_crop_size` (the primary motivation; fixes a correctness artifact in the prior crop PR).

Opt-in via `label_transform.emit_gt_seg: true` on the YAML side; default off → bit-for-bit identical to current behaviour for configs without it.

Why
With `malis_crop_size` enabled (#crop PR), the current MALIS path runs CC on the cropped `gt_affs`. When a single GT instance spans the crop boundary so that two of its pieces appear inside the crop but are not connected within that window, CC labels them as distinct components. MALIS then injects spurious negative-constraint edges inside one true instance.

Passing the eroded `gt_seg` from upstream (where global instance IDs are preserved) and cropping it with the same origin fixes this.

Speedup story (measured + documented)
Production config: MedNeXt-L, batch 2, 128³ patch on L40S.
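| Config | it/s | hours/epoch (5000 steps) |
|---|---|---|
| BCE only (no MALIS) | ~0.71 | ~1.95 |
| Full-volume MALIS (original) | ~0.17 | ~7.3 |
| MALIS + `malis_crop_size: 64` | ~0.78 | ~1.78 |
| MALIS + crop + `emit_gt_seg: true` | ~0.78+ | ~1.78 |

`emit_gt_seg` adds a small additional speedup but is primarily a correctness fix for the cropped case.

See `docs/source/notes/malis.rst` (new) and the `MalisLoss` class docstring for the full table.

Implementation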
Data-pipeline boundary

- `LabelTransformConfig.emit_gt_seg: bool = False` (strict dataclass field).
- `CopyItemsd(keys="label", names="gt_seg")` inserted in `connectomics/data/augmentation/build.py` immediately after `SegErosionInstanced` (both train and val), gated by `label_cfg.emit_gt_seg`. Makes "post-augment, post-erode" the canonical snapshot point.
- `MultiTaskLabelTransformd` is untouched.

Loss orchestrator

- New `LossMetadata.gt_seg_arg: Optional[str]`. Set to `"gt_seg"` for `MalisLoss`; all other losses unaffected.
- `compute_standard_loss(..., gt_seg=None)` plumbs the batch's `gt_seg` to any term whose metadata declares the arg, via the existing `extra_loss_kwargs` extension point.
- `compute_deep_supervision_loss(..., gt_seg=None)` forces `gt_seg = None` for every head: DS lower heads work on downsampled targets that `gt_seg` can't match label-correctly. MalisLoss + DS falls back to the legacy CC-recompute path. Test-pinned.

MalisLoss

- `forward(pred, target, mask=None, gt_seg=None)` accepts the optional kwarg.
- `_prepare_gt_seg` normalizes shape (accepts `[B, Z, Y, X]` and `[B, 1, Z, Y, X]`), validates against pred spatial dims.
- `_apply_crop_if_configured` now returns a 4-tuple and crops `gt_seg` at the same origin as pred/target/mask.
- `_compute_malis_weights(..., *, gt_seg=None)` uses the supplied seg per sample when provided; falls back to `connected_components_affgraph` otherwise.

YAML (`tutorials/neuron_nisb/base_banis+_malis.yaml`)

- `default.data.label_transform.emit_gt_seg: true` enables the passthrough path.
- `malis_crop_size: 64` uncommented to enable the crop in production.

Tests
`python -m pytest tests/unit/test_malis_loss.py tests/unit/test_data_factory.py tests/unit/test_loss_orchestrator.py -q` → 76 passed, 1 skipped.

`test_malis_loss.py` (6 new tests):

- Metadata declares `gt_seg_arg`.
- Uncropped CC equivalence (`gt_seg` supplied vs reconstructed match within `rtol=1e-5`).
- `gt_seg=None` strict-equality preservation of the legacy path.
- Cropped-instance-fragmentation bug fix: Path A spies on `connected_components_affgraph` and asserts the CC labels fragment; Path B asserts CC was NOT called and the supplied gt_seg retains a single instance label; losses differ.
- Shape validation (`ValueError` on mismatch).
- No grad flows through `gt_seg`.

`test_data_factory.py`: 2 new tests for the `CopyItemsd` insertion (train + val symmetry, plus a runtime end-to-end equivalence to a manual erosion).

`test_loss_orchestrator.py`: 2 new tests pinning the standard-loss `gt_seg` plumbing and the DS=legacy fallback.

Docs

- New `docs/source/notes/malis.rst` covering MALIS performance and correctness knobs (linked from `docs/source/index.rst` "Get Started" toctree).
- `MalisLoss` class docstring extended with a Performance section citing the measured speedup.

Backward compatibility
- Default behaviour (no `emit_gt_seg`) is bit-for-bit identical to current `MalisLoss`.

CCC design history

Under `.agent/features/malis_gt_passthrough/` in the worktree (gitignored). Plan rounds: 2 (plan_v0 NEEDS_CHANGES → plan_v1 APPROVE_WITH_MINOR_COMMENTS). Code rounds: 1 (code_v0 → review_v0 APPROVE_WITH_MINOR_COMMENTS). 3 minor review_v0 findings were applied as small follow-up commits or kept as observational (see review_v0 artifact).

🤖 Generated with Claude Code
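For reviewers, the loss-orchestrator routing described under Implementation can be sketched roughly as follows. This is a simplified illustration, not the code in this PR: the real `compute_standard_loss` and `compute_deep_supervision_loss` take additional arguments (elided as `...` above), and `LossTerm`, `toy_bce`, and `toy_malis` are hypothetical stand-ins invented for the example. Only `LossMetadata.gt_seg_arg`, the `extra_loss_kwargs` idea, and the "DS forces `gt_seg=None`" rule come from the description.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional


@dataclass
class LossMetadata:
    name: str
    # Name of the kwarg the loss expects for the GT segmentation, or None.
    gt_seg_arg: Optional[str] = None


@dataclass
class LossTerm:  # hypothetical container, for illustration only
    fn: Callable[..., float]
    meta: LossMetadata
    weight: float = 1.0


def compute_standard_loss(terms: List[LossTerm], pred: Any, target: Any, gt_seg: Any = None) -> float:
    total = 0.0
    for term in terms:
        extra_loss_kwargs = {}
        # Only terms that declare the arg ever see gt_seg.
        if term.meta.gt_seg_arg is not None and gt_seg is not None:
            extra_loss_kwargs[term.meta.gt_seg_arg] = gt_seg
        total += term.weight * term.fn(pred, target, **extra_loss_kwargs)
    return total


def compute_deep_supervision_loss(terms, preds, targets, gt_seg: Any = None) -> float:
    # gt_seg is deliberately dropped: downsampled heads cannot be matched to it,
    # so a MALIS-style term falls back to its legacy path on every head.
    return sum(compute_standard_loss(terms, p, t, gt_seg=None) for p, t in zip(preds, targets))


calls = []  # records which gt_seg each toy loss received


def toy_bce(pred, target):
    calls.append(("bce", None))
    return 1.0


def toy_malis(pred, target, gt_seg=None):
    calls.append(("malis", gt_seg))
    return 2.0


terms = [
    LossTerm(toy_bce, LossMetadata("bce")),
    LossTerm(toy_malis, LossMetadata("malis", gt_seg_arg="gt_seg")),
]

standard_total = compute_standard_loss(terms, "pred", "target", gt_seg="SEG")
standard_calls = list(calls)

calls.clear()
ds_total = compute_deep_supervision_loss(terms, ["p0", "p1"], ["t0", "t1"], gt_seg="SEG")
ds_calls = list(calls)
```

Declaring the kwarg name in metadata keeps the orchestrator free of loss-specific branches: a loss that does not set `gt_seg_arg` is called exactly as before, which is consistent with the backward-compatibility claim above.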