fix: disable gradient tracking for entropy logging in RLTrainer #662
Description
Disabled gradient tracking for entropy in RLTrainer: the entropy values are only used for logging, so tracking them added a wasteful autograd graph on top of the forward/backward pass.
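A minimal sketch of the change, assuming the call site in `RLTrainer.get_logprobs` looks roughly like this (the exact surrounding code may differ):

```python
# Hypothetical sketch of the call site; only the .detach() is the actual
# change described in this PR.

# before: entropy tensors join the autograd graph built for the policy loss
entropies = entropy_from_logits(logits)

# after: detach first, since entropy is only logged and never backpropagated
entropies = entropy_from_logits(logits.detach())
```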
Type of Change
Testing
Ran `uv run pytest` locally.

Checklist
Additional Notes
In `verifiers/rl/trainer:306`, `entropies = entropy_from_logits(logits)` calls into `verifiers/rl/trainer/utils.py`, where the following expression is evaluated: `chunk_entropy = -(torch.exp(logps) * logps).sum(-1)`. This kept multiple massive logit/entropy tensors alive in the computational graph.
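For a rough sense of scale (my back-of-envelope numbers, not from the repo): a single logits-sized tensor at the settings below is already ~5 GB, and every extra intermediate that autograd saves for the unused entropy backward costs that much again.

```python
# Size of one logits-sized tensor at the settings mentioned in this PR
# (micro_batch_size=8, seq_len=2048, Qwen2.5 vocab size, bf16):
batch, seq, vocab = 8, 2048, 151_936
bytes_per_bf16 = 2
gb = batch * seq * vocab * bytes_per_bf16 / 1e9
print(f"{gb:.2f} GB per logits-sized tensor")  # ~4.98 GB
# log_softmax saves its output for backward, so keeping entropy in the
# graph retains additional tensors of this scale until they are freed.
```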
On my local hardware (2x RTX 5060 Ti 16GB), `Qwen2.5-0.5B-Instruct` with `seq_len=2048`, `micro_batch_size=8` would yield an OOM error, while with gradient tracking disabled on entropy I saw ~34% less peak memory. The test can be reproduced with the command below (it didn't seem to fit into any existing folder, so apologies for the long command):
```
uv run python - <<'PY'
import torch
import torch.nn.functional as F

# Mirrors verifiers/rl/trainer/utils.py: gather per-token logprobs.
def selective_log_softmax(logits, targets):
    logps = F.log_softmax(logits, dim=-1)
    return logps.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)

# Mirrors the chunked entropy computation from utils.py.
def entropy_from_logits(logits, chunk_size=128):
    original_shape = logits.shape[:-1]
    num_classes = logits.shape[-1]
    flat_logits = logits.reshape(-1, num_classes)
    entropies = []
    for chunk in flat_logits.split(chunk_size, dim=0):
        logps = F.log_softmax(chunk, dim=-1)
        entropies.append(-(torch.exp(logps) * logps).sum(-1))
    return torch.cat(entropies, dim=0).reshape(original_shape)

if not torch.cuda.is_available():
    raise RuntimeError("CUDA required")
device = "cuda"

batch, seq, vocab = 4, 2048, 151_936  # vocab matches Qwen2.5
chunk_size = 128
dtype = torch.bfloat16

torch.randn(1, device=device)  # warm up the CUDA context
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()

def run(detach):
    torch.cuda.reset_peak_memory_stats()
    logits = torch.randn(batch, seq, vocab, device=device, dtype=dtype, requires_grad=True)
    targets = torch.randint(0, vocab, (batch, seq), device=device)
    logprobs = selective_log_softmax(logits, targets)
    # With detach=False, the entropy computation is recorded by autograd
    # even though it is never backpropagated.
    entropy = entropy_from_logits(logits.detach() if detach else logits, chunk_size=chunk_size)
    loss = -logprobs.sum()
    loss.backward()
    torch.cuda.synchronize()
    peak = torch.cuda.max_memory_allocated() / 1e9
    del logits, targets, logprobs, entropy, loss
    torch.cuda.empty_cache()
    return peak

peak_no_detach = run(False)
peak_with_detach = run(True)
print(f"Without detach: {peak_no_detach:.2f} GB")
print(f"With detach: {peak_with_detach:.2f} GB")
saved = peak_no_detach - peak_with_detach
print(f"Saved: {saved:.2f} GB ({100 * saved / peak_no_detach:.1f}%)")
PY
```
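For what it's worth, wrapping the call in `torch.no_grad()` is an equivalent way to get the same saving (a sketch, not the code this PR actually uses):

```python
# Equivalent alternative to .detach(): suppress graph construction entirely.
with torch.no_grad():
    entropies = entropy_from_logits(logits)
```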
Note
Disables gradient tracking for entropy logging to cut unnecessary autograd memory.
In `RLTrainer.get_logprobs`, call `entropy_from_logits(logits.detach())` instead of using `logits` directly, preventing entropy tensors from entering the backward graph.

Written by Cursor Bugbot for commit 1b3333c.