Skip to content

MalisLoss: pass eroded gt_seg through data pipeline (skip per-step CC, fix crop topology)#211

Open
donglaiw wants to merge 1 commit into
masterfrom
feat/malis-gt-passthrough
Open

MalisLoss: pass eroded gt_seg through data pipeline (skip per-step CC, fix crop topology)#211
donglaiw wants to merge 1 commit into
masterfrom
feat/malis-gt-passthrough

Conversation

@donglaiw
Copy link
Copy Markdown
Collaborator

Summary

Passes the eroded GT segmentation from the data pipeline through to MalisLoss so it can:

  1. Skip its internal connected_components_affgraph(gt_affs, nhood) call every training step (~5–10 % of MALIS step cost; small).
  2. Preserve global instance IDs under malis_crop_size (the primary motivation; fixes a correctness artifact in the prior crop PR).

Opt-in via label_transform.emit_gt_seg: true on the YAML side; default off → bit-for-bit identical to current behaviour for configs without it.

Why

With malis_crop_size enabled (#crop PR), the current MALIS path runs CC on the cropped gt_affs. When a single GT instance spans the crop boundary so that two of its pieces appear inside the crop but are not connected within that window, CC labels them as distinct components. MALIS then injects spurious negative-constraint edges inside one true instance.

Passing the eroded gt_seg from upstream (where global instance IDs are preserved) and cropping it with the same origin fixes this.

Speedup story (measured + documented)

Production config: MedNeXt-L, batch 2, 128³ patch on L40S.

Configuration it/s sec/step h/epoch (5000 steps) Speedup vs original
BCE only (no MALIS) ~0.71 ~1.4 ~1.95 — (reference)
Full-volume MALIS (original) ~0.17 ~5.9 ~7.3 1.0×
MALIS + malis_crop_size: 64 ~0.78 ~1.3 ~1.78 ~4.6×
MALIS + crop + emit_gt_seg: true ~0.78+ ~1.3 ~1.78 ~4.6× plus a few %
  • Crop alone gives ~4.6× speedup vs the original full-volume MALIS (slurm 2505814 vs 2487040).
  • emit_gt_seg adds a small additional speedup AND is primarily a correctness fix for the cropped case.

See docs/source/notes/malis.rst (new) and the MalisLoss class docstring for the full table.

Implementation

Data-pipeline boundary

  • LabelTransformConfig.emit_gt_seg: bool = False (strict dataclass field).
  • CopyItemsd(keys="label", names="gt_seg") inserted in connectomics/data/augmentation/build.py immediately after SegErosionInstanced (both train and val), gated by label_cfg.emit_gt_seg. Makes "post-augment, post-erode" the canonical snapshot point.
  • MultiTaskLabelTransformd is untouched.

Loss orchestrator

  • LossMetadata.gt_seg_arg: Optional[str]. Set to "gt_seg" for MalisLoss; all other losses unaffected.
  • compute_standard_loss(..., gt_seg=None) plumbs the batch's gt_seg to any term whose metadata declares the arg, via the existing extra_loss_kwargs extension point.
  • compute_deep_supervision_loss(..., gt_seg=None) forces gt_seg = None for every head — DS lower heads work on downsampled targets that gt_seg can't match label-correctly. MalisLoss + DS falls back to the legacy CC-recompute path. Test-pinned.

MalisLoss

  • forward(pred, target, mask=None, gt_seg=None) accepts the optional kwarg.
  • _prepare_gt_seg normalizes shape (accepts [B, Z, Y, X] and [B, 1, Z, Y, X]), validates against pred spatial dims.
  • _apply_crop_if_configured now returns a 4-tuple and crops gt_seg at the same origin as pred/target/mask.
  • _compute_malis_weights(..., *, gt_seg=None) uses the supplied seg per sample when provided; falls back to connected_components_affgraph otherwise.

YAML (tutorials/neuron_nisb/base_banis+_malis.yaml)

default:
  data:
    label_transform:
      emit_gt_seg: true   # opt-in; pairs with MalisLoss
  ...
            malis_crop_size: 64

Tests

python -m pytest tests/unit/test_malis_loss.py tests/unit/test_data_factory.py tests/unit/test_loss_orchestrator.py -q76 passed, 1 skipped.

  • test_malis_loss.py — 6 new tests:
    • Metadata declares gt_seg_arg.
    • Uncropped CC equivalence (gt_seg supplied vs reconstructed match within rtol=1e-5).
    • gt_seg=None strict-equality preservation of the legacy path.
    • Cropped-instance-fragmentation bug fix: Path A spies on connected_components_affgraph and asserts the CC labels fragment; Path B asserts CC was NOT called and the supplied gt_seg retains a single instance label; losses differ.
    • Shape validation (ValueError on mismatch).
    • No grad flows through gt_seg.
  • test_data_factory.py — 2 new tests for the CopyItemsd insertion (train + val symmetry, plus a runtime end-to-end equivalence to a manual erosion).
  • test_loss_orchestrator.py — 2 new tests pinning the standard-loss gt_seg plumbing and the DS=legacy fallback.

Docs

  • New docs/source/notes/malis.rst covering MALIS performance and correctness knobs (linked from docs/source/index.rst "Get Started" toctree).
  • MalisLoss class docstring extended with a Performance section citing the measured speedup.

Backward compatibility

  • Default behaviour (no emit_gt_seg) is bit-for-bit identical to current MalisLoss.
  • Configs without MALIS are completely unaffected.
  • Deep-supervision configs are unaffected (orchestrator forces legacy CC fallback when DS is on).

CCC design history

Under .agent/features/malis_gt_passthrough/ in the worktree (gitignored). Plan rounds: 2 (plan_v0 NEEDS_CHANGES → plan_v1 APPROVE_WITH_MINOR_COMMENTS). Code rounds: 1 (code_v0 → review_v0 APPROVE_WITH_MINOR_COMMENTS). 3 minor review_v0 findings were applied as small follow-up commits or kept as observational (see review_v0 artifact).

🤖 Generated with Claude Code

…, fix crop topology)

Adds an opt-in `gt_seg` plumbing path that pipes the eroded GT
segmentation from the data pipeline through to MalisLoss, so MALIS
can skip its internal `connected_components_affgraph(gt_affs, nhood)`
call and use the supplied seg directly.

Primary motivation (correctness): with `malis_crop_size` enabled
(landed in commit 5b0451f), the current MALIS path runs CC on the
*cropped* gt_affs. When a single GT instance spans the crop boundary
so that two of its pieces appear inside the crop but are not
connected within that window, CC labels them as distinct components,
and MALIS then injects spurious negative-constraint edges inside one
true instance. Passing the eroded gt_seg from upstream (where global
instance IDs are preserved) and cropping it with the same origin
fixes this.

Secondary motivation (small speedup): removes the per-step CC inside
`_compute_malis_weights`. Roughly 5-10% of the remaining MALIS step
cost on top of cropping.

Pipeline boundary
- `LabelTransformConfig.emit_gt_seg: bool = False` (strict dataclass
  field, defaults to off).
- `MultiTaskLabelTransformd` is left untouched. Instead, a
  `CopyItemsd(keys="label", names="gt_seg")` MONAI transform is
  inserted in `data/augmentation/build.py` immediately after the
  existing `SegErosionInstanced` step (both train and val), gated by
  `label_cfg.emit_gt_seg`. This makes "post-augment, post-erode" the
  canonical snapshot point.
- Default behavior (no `emit_gt_seg`) is bit-for-bit identical to
  pre-PR `MalisLoss`.

Loss orchestrator
- New `LossMetadata.gt_seg_arg: Optional[str]`. Set to `"gt_seg"`
  for MalisLoss; all other losses unaffected.
- `compute_standard_loss(..., gt_seg=None)` plumbs the batch's
  `gt_seg` to any term whose metadata declares the arg, via the
  existing `extra_loss_kwargs` extension point.
- `compute_deep_supervision_loss(..., gt_seg=None)` forces
  `gt_seg=None` for every head (DS lower heads work on downsampled
  targets that gt_seg can't match label-correctly). MalisLoss + DS
  falls back to the legacy CC-recompute path. Pinned by a unit
  test.

MalisLoss
- `forward(pred, target, mask=None, gt_seg=None)` accepts the
  optional kwarg.
- `_prepare_gt_seg` normalizes shape to `[B, Z, Y, X]` (accepts
  both `[B, Z, Y, X]` and `[B, 1, Z, Y, X]`), validates against
  pred spatial dims.
- `_apply_crop_if_configured` now returns a 4-tuple and crops
  gt_seg at the same origin as pred/target/mask.
- `_compute_malis_weights(..., *, gt_seg=None)` uses the supplied
  seg per sample when provided; falls back to
  `connected_components_affgraph` otherwise.

Speedup story (measured + documented)

Production config: MedNeXt-L, batch 2, 128^3 patch on L40S.

| Config                              | it/s | hours/epoch (5000 steps) |
|-------------------------------------|------|--------------------------|
| BCE only (no MALIS)                 | ~0.71| ~1.95                    |
| Full-volume MALIS (original)        | ~0.17| ~7.3                     |
| MALIS + malis_crop_size=64          | ~0.78| ~1.78                    |
| MALIS + crop=64 + emit_gt_seg=true  | ~0.78+| ~1.78                   |

Crop alone gives ~4.6x speedup vs the original full-volume MALIS
(slurm 2505814 vs 2487040). emit_gt_seg adds a small additional
speedup on top, and is primarily a correctness fix for the cropped
case. See `docs/source/notes/malis.rst` and `MalisLoss` class
docstring.

YAML opt-in (`tutorials/neuron_nisb/base_banis+_malis.yaml`)
- `default.data.label_transform.emit_gt_seg: true` enables the
  passthrough path.
- `malis_crop_size: 64` uncommented to enable the crop in
  production.

Tests
- `tests/unit/test_malis_loss.py` — 6 new tests covering metadata,
  uncropped CC equivalence, `gt_seg=None` legacy preservation, the
  cropped-instance-fragmentation bug fix (Path A spy asserts CC was
  called and labels fragment; Path B asserts CC was NOT called and
  the supplied gt_seg retains a single instance label), shape
  validation, and that no grad flows through gt_seg.
- `tests/unit/test_data_factory.py` — 2 new tests for the
  CopyItemsd insertion (train + val symmetry, and a runtime
  end-to-end equivalence to a manual erosion).
- `tests/unit/test_loss_orchestrator.py` — 2 new tests pinning
  the standard-loss gt_seg plumbing and the DS=legacy fallback.
- All existing test_malis_loss cases unchanged.
- `python -m pytest tests/unit/test_malis_loss.py tests/unit/test_data_factory.py tests/unit/test_loss_orchestrator.py -q`
  -> 76 passed, 1 skipped.

Docs
- New `docs/source/notes/malis.rst` covering MALIS speedup and
  correctness knobs (linked from `docs/source/index.rst` Get
  Started toctree).
- `MalisLoss` class docstring extended with a Performance section
  citing the measured speedup.

CCC design history under `.agent/features/malis_gt_passthrough/`
(gitignored).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant