fix: prevent F.linear from saving dequantized weights in MatMul4Bit/MatMul8bitLt to save ~13GB VRAM and prevent OOM errors #1935
Conversation
MatMul4Bit and MatMul8bitLt dequantize NF4/INT8 weights to bf16, then call F.linear which internally saves the bf16 weight for backward. But both backprop functions re-dequantize from the stored quantization state anyway. The saved bf16 weight (~0.5 GB per layer for 9B models) accumulates across all layers during forward, causing excessive VRAM usage during QLoRA training. Fix: wrap F.linear in torch.no_grad() to prevent intermediate autograd node creation. Backward path unchanged — already correct. Result: ~0.5 GB per layer VRAM reduction with zero quality impact.
Thanks for the report and the PR. I am having trouble reproducing this issue. Can you share a minimal reproducer example that shows how you're measuring memory usage, and what you're doing exactly? It may also help to know versions of PyTorch, bitsandbytes, any other relevant libraries, etc. The way I understand it is that an …

As a side-note, I'm currently actually in the middle of doing some work around the …
OK, this request is the result of a debugging marathon on my training pipeline. I'm currently building a reproducer script with all my fixes and will update this within the next few days. Please wait for my results until further investigation.
Problem
During QLoRA training, `MatMul4Bit.forward` dequantizes the NF4 weight to bf16, then passes it to `F.linear()`. `F.linear` internally creates an autograd node that saves the bf16 weight (~0.5 GB per decoder layer for 9B models) for its backward pass. These saved weights accumulate across all layers during the forward pass, consuming ~15 GB of VRAM on a 32-layer model.
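A minimal sketch of the effect, not taken from the PR: a freshly allocated bf16 tensor per layer stands in for the dequantized weight, and peak VRAM is compared with and without `torch.no_grad()` around `F.linear`. All names and sizes here are illustrative.

```python
# Illustrative sketch (not from the PR): per-layer tensors stand in for freshly
# dequantized NF4 weights. Requires a CUDA device.
import torch
import torch.nn.functional as F

def run(use_no_grad: bool, layers: int = 32, dim: int = 4096) -> float:
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # The input requires grad, as it would during LoRA training.
    h = torch.randn(8, dim, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(layers):
        # Stand-in for the bf16 weight produced by dequantization each forward.
        w = torch.randn(dim, dim, device="cuda", dtype=torch.bfloat16)
        if use_no_grad:
            with torch.no_grad():
                h = F.linear(h, w)
        else:
            # Under grad mode, autograd keeps every per-layer `w` alive for backward.
            h = F.linear(h, w)
    return torch.cuda.max_memory_allocated() / 2**30

print(f"grad mode : {run(False):.2f} GiB peak")
print(f"no_grad   : {run(True):.2f} GiB peak")
```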
Root Cause
`F.linear()` is a standard PyTorch op that saves its inputs for backward via `save_for_backward`. The saved bf16 weight is redundant: `MatMul4Bit.backward` re-dequantizes the weight from the stored NF4 quantization state anyway (`F.dequantize_4bit(B, ctx.state)`).
Fix
Wrap the `F.linear` call in `torch.no_grad()` inside both `MatMul4Bit.forward` and `MatMul8bitLt.forward`. This prevents the intermediate autograd node from saving the dequantized weight. Backward is unchanged: it already re-dequantizes from NF4.
Before:
output = F.linear(A, dequantize_4bit(B, state).to(A.dtype).t(), bias)
After:
with torch.no_grad():
    output = F.linear(A, dequantize_4bit(B, state).to(A.dtype).t(), bias)
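For context, here is a simplified, hypothetical stand-in for the pattern used by `MatMul4Bit`, not the actual bitsandbytes implementation: `dequantize` is a placeholder for `F.dequantize_4bit`, and weight/bias gradients are omitted since the quantized base weights stay frozen during QLoRA.

```python
# Simplified sketch of the pattern, NOT the real bitsandbytes MatMul4Bit.
import torch
import torch.nn.functional as F

def dequantize(packed: torch.Tensor, state, dtype: torch.dtype) -> torch.Tensor:
    # Placeholder for bitsandbytes' dequantize_4bit(B, state); here we just cast.
    return packed.to(dtype)

class ToyMatMul4Bit(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B_packed, state, bias=None):
        ctx.state = state
        ctx.save_for_backward(B_packed)          # keep only the packed uint8 weight
        with torch.no_grad():                    # the fix: no intermediate autograd node
            W = dequantize(B_packed, state, A.dtype)
            out = F.linear(A, W, bias)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        (B_packed,) = ctx.saved_tensors
        # Re-dequantize from the packed state, exactly as the backward already does.
        W = dequantize(B_packed, ctx.state, grad_out.dtype)
        grad_A = grad_out @ W                    # gradient w.r.t. the activations only
        return grad_A, None, None, None
```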
Verification
- Memory snapshot analysis (`torch.cuda.memory._record_memory_history`) confirmed that `F.linear` saved bf16 weights across all decoder layers. After the fix, only the NF4-packed weight is stored (`ctx.B`, uint8, ~0.06 GB per layer); see the snapshot sketch after this list.
- 9B model training now runs on 12 GB GPUs (previously 24 GB minimum).
- Zero quality impact: backward re-dequantizes to identical values.
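A sketch of how such a snapshot can be captured, using the private `torch.cuda.memory._record_memory_history` API mentioned above; argument names vary across PyTorch versions, and the output path is hypothetical.

```python
# Sketch of the memory-snapshot workflow referenced above; the underscore-prefixed
# APIs are private and may change between PyTorch releases.
import torch

torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run a few QLoRA training steps here ...

torch.cuda.memory._dump_snapshot("qlora_step.pickle")   # hypothetical output path
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording

# Open the .pickle file at https://pytorch.org/memory_viz to inspect which
# allocations (e.g. tensors saved by F.linear) survive across decoder layers.
```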
Impact: Unlocked Trainable Model Sizes per GPU
A single-line change in the autograd path enables the next tier of model sizes on every GPU class: