
fix: prevent F.linear from saving dequantized weights in MatMul4Bit/MatMul8bitLt to save ~13GB VRAM and prevent OOM errors#1935

Open
butterwecksolutions wants to merge 1 commit into bitsandbytes-foundation:main from butterwecksolutions:main

Conversation


@butterwecksolutions butterwecksolutions commented May 2, 2026

Problem

During QLoRA training, MatMul4Bit.forward dequantizes NF4 weights
to bf16, then passes them to F.linear(). F.linear internally
creates an autograd node that saves the bf16 weight (~0.5 GB per
decoder layer for 9B models) for its backward pass. These saved
weights accumulate across all layers during forward, consuming
~15 GB VRAM on a 32-layer model.
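
One way to observe what an op stashes for backward is torch.autograd.graph.saved_tensors_hooks. The standalone sketch below (not part of this PR; dummy tensors stand in for real NF4 weights) prints every tensor autograd saves around a plain F.linear call:

```python
import torch
import torch.nn.functional as F
from torch.autograd.graph import saved_tensors_hooks

def report(t):
    # Called once for every tensor autograd saves for the backward pass.
    mib = t.numel() * t.element_size() / 2**20
    print(f"saved for backward: shape={tuple(t.shape)}, dtype={t.dtype}, {mib:.1f} MiB")
    return t

x = torch.randn(8, 4096, requires_grad=True)
w = torch.randn(4096, 4096)  # stand-in for a dequantized weight

with saved_tensors_hooks(report, lambda t: t):
    out = F.linear(x, w)  # the full-size weight shows up among the saved tensors
```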

Root Cause

F.linear() is a standard PyTorch op that saves its inputs for
backward via save_for_backward. The saved bf16 weight is redundant:
MatMul4Bit.backward re-dequantizes the weight from the stored NF4
quantization state anyway (F.dequantize_4bit(B, ctx.state)).
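
For context, a simplified sketch of that backward path, paraphrased from the call quoted above (not the exact bitsandbytes source; attribute names and return arity are abbreviated):

```python
import torch
import bitsandbytes.functional as F  # assumption: the "F" alias used in the quote above

# Simplified sketch of the existing backward path: grad_A is computed by
# re-dequantizing B from the stored NF4 state, so the bf16 copy that
# torch's F.linear saves during forward is never consumed.
def matmul_4bit_backward_sketch(ctx, grad_output):
    A, B = ctx.tensors
    grad_A = None
    if ctx.needs_input_grad[0]:
        grad_A = torch.matmul(
            grad_output,
            F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t(),
        )
    # Gradients for B, out, bias, quant_state omitted in this sketch.
    return grad_A, None, None, None, None
```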

Fix

Wrap the F.linear call in torch.no_grad() inside both
MatMul4Bit.forward and MatMul8bitLt.forward. This prevents the
intermediate autograd node from saving the dequantized weight.
Backward is unchanged — already re-dequantizes from NF4.

Before:

```python
output = F.linear(A, dequantize_4bit(B, state).to(A.dtype).t(), bias)
```

After:

```python
with torch.no_grad():
    output = F.linear(A, dequantize_4bit(B, state).to(A.dtype).t(), bias)
```

Verification

Memory snapshot analysis (torch.cuda.memory._record_memory_history)
confirmed that F.linear saved the dequantized bf16 weights across all
decoder layers. After the fix, only the NF4-packed weight is stored
(ctx.B = uint8, ~0.06 GB per layer).
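
The measurement setup looked roughly like this (a sketch; the training step itself is a placeholder):

```python
import torch

# Record allocator history, run one step, then dump a snapshot for inspection.
torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run one QLoRA forward/backward step here ...

torch.cuda.memory._dump_snapshot("snapshot.pickle")      # inspect at pytorch.org/memory_viz
torch.cuda.memory._record_memory_history(enabled=None)   # stop recording
```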

| | Before | After |
|---|---|---|
| Saved weight (per layer) | 0.48 GB (bf16) | 0.06 GB (NF4 uint8) |
| 32 layers total | 15.4 GB | 1.9 GB |

9B model training now runs on 12 GB GPUs (was 24 GB minimum).
Zero quality impact — backward re-dequantizes to identical values.

Impact: Unlocked Trainable Model Sizes per GPU

A single-line change in the autograd path enables the next tier of model sizes on every GPU class:

| GPU | VRAM | Text before | Text after | VL before | VL after |
|---|---|---|---|---|---|
| RTX 3060 | 12 GB | 4B (9B OOM) | 9B | | |
| RTX 4060 | 12 GB | 4B (9B OOM) | 9B | | |
| RTX 3080 Ti | 12 GB | 4B (9B OOM) | 9B | | |
| RTX 4070 | 12 GB | 4B (9B OOM) | 9B | | |
| RTX 5070 | 12 GB | 4B (9B OOM) | 9B | | |
| RTX 4060 Ti | 16 GB | 7B (9B OOM) | 13B | | 9B |
| RTX 4080 | 16 GB | 7B (9B OOM) | 13B | | 9B |
| RTX 5080 | 16 GB | 7B (9B OOM) | 13B | | 9B |
| RTX 3090 | 24 GB | 9B (13B OOM) | 20B | 9B (13B OOM) | 13B |
| RTX 4090 | 24 GB | 9B (13B OOM) | 20B | 9B (13B OOM) | 13B |
| RTX 5090 | 32 GB | 13B (20B OOM) | 20B | 13B (20B OOM) | 20B |
| A10 | 24 GB | 9B (13B OOM) | 20B | 9B (13B OOM) | 13B |
| L40S | 48 GB | 20B (34B OOM) | 34B | 20B (34B OOM) | 34B |
| A100 | 80 GB | 34B (50B OOM) | 50B | 34B (40B OOM) | 40B |
| H100 | 80 GB | 34B (50B OOM) | 50B | 34B (40B OOM) | 40B |

@butterwecksolutions butterwecksolutions force-pushed the main branch 2 times, most recently from f23de9b to 007410d Compare May 2, 2026 17:19
MatMul4Bit and MatMul8bitLt dequantize NF4/INT8 weights to bf16,
then call F.linear which internally saves the bf16 weight for
backward. But both backprop functions re-dequantize from the stored
quantization state anyway. The saved bf16 weight (~0.5 GB per layer
for 9B models) accumulates across all layers during forward, causing
excessive VRAM usage during QLoRA training.

Fix: wrap F.linear in torch.no_grad() to prevent intermediate
autograd node creation. Backward path unchanged — already correct.

Result: ~0.5 GB per layer VRAM reduction with zero quality impact.
@butterwecksolutions butterwecksolutions changed the title important vram saving fix: prevent F.linear from saving dequantized weights in MatMul4Bit/MatMul8bitLt inportant fix: prevent F.linear from saving dequantized weights in MatMul4Bit/MatMul8bitLt May 2, 2026
@butterwecksolutions butterwecksolutions changed the title inportant fix: prevent F.linear from saving dequantized weights in MatMul4Bit/MatMul8bitLt fix: prevent F.linear from saving dequantized weights in MatMul4Bit/MatMul8bitLt May 2, 2026
@butterwecksolutions butterwecksolutions changed the title fix: prevent F.linear from saving dequantized weights in MatMul4Bit/MatMul8bitLt fix: prevent F.linear from saving dequantized weights in MatMul4Bit/MatMul8bitLt to save ~13GB VRAM May 3, 2026
@butterwecksolutions butterwecksolutions changed the title fix: prevent F.linear from saving dequantized weights in MatMul4Bit/MatMul8bitLt to save ~13GB VRAM fix: prevent F.linear from saving dequantized weights in MatMul4Bit/MatMul8bitLt to save ~13GB VRAM and prevent OOM errors May 3, 2026
@matthewdouglas
Member

Hi @butterwecksolutions,

Thanks for the report and the PR. I am having trouble reproducing this issue. Can you share a minimal reproducer example that shows how you're measuring memory usage, and what you're doing exactly? It may also help to know versions of PyTorch, bitsandbytes, any other relevant libraries, etc.

The way I understand it is that an autograd.Function runs with gradient tracking disabled already, so I'm not sure I understand how this changes anything. Perhaps you're using this with torch.compile and it sees the full graph with forward+backward and decides to "optimize" by reusing the dequantized weights instead of recomputing? That's the best guess I've got here.

As a side-note, I'm currently actually in the middle of doing some work around the MatMul4Bit autograd function, in a couple phases, and might end up turning the whole thing into a PyTorch custom op with register_autograd as opposed to implementing autograd.Function. It will be useful to have a reproducer to see if that changes anything as well.
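
For reference, the grad-mode behavior referred to here can be checked with a small standalone probe (illustrative sketch, unrelated to bitsandbytes):

```python
import torch

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # PyTorch runs a custom Function's forward with grad mode disabled,
        # so ops called in here do not record their own autograd nodes.
        print("grad enabled inside forward:", torch.is_grad_enabled())
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

x = torch.randn(3, requires_grad=True)
Probe.apply(x).sum().backward()  # prints: grad enabled inside forward: False
```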

@butterwecksolutions
Author

OK, this request is the result of a debugging marathon on my training pipeline. I'm currently building a reproducer script with all my fixes and will update this within the next few days. Please wait for my results before investigating further.
