
Conversation


@shyampathak22 shyampathak22 commented Dec 23, 2025

Description

Disabled gradient tracking for entropy in RLTrainer. The entropy is only used for logging, so tracking it added a wasteful autograd graph on top of the forward/backward pass.

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update
  • Test improvement

Testing

  • All existing tests pass when running uv run pytest locally.
  • New tests have been added to cover the changes

Checklist

  • My code follows the style guidelines of this project as outlined in AGENTS.md
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • Any dependent changes have been merged and published

Additional Notes

In verifiers/rl/trainer:306, entropies = entropy_from_logits(logits) calls into verifiers/rl/trainer/utils.py, where the following expression is evaluated:

chunk_entropy = -(torch.exp(logps) * logps).sum(-1)

This kept multiple massive logit/entropy tensors alive in the computational graph, purely for a logging metric.
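A quick standalone check (toy tensor sizes, not part of the PR) makes the difference visible: entropy computed from the tracked logits carries a grad_fn, so its softmax/exp intermediates are retained for backward, while entropy computed from detached logits does not.

import torch
import torch.nn.functional as F

logits = torch.randn(2, 16, 32, requires_grad=True)  # toy (batch, seq, vocab)
logps_tracked = F.log_softmax(logits, dim=-1)
logps_detached = F.log_softmax(logits.detach(), dim=-1)

entropy_tracked = -(torch.exp(logps_tracked) * logps_tracked).sum(-1)
entropy_detached = -(torch.exp(logps_detached) * logps_detached).sum(-1)

print(entropy_tracked.grad_fn is not None)   # True  -> graph (and intermediates) retained
print(entropy_detached.grad_fn is not None)  # False -> nothing kept beyond the result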

On my local hardware (2x RTX 5060 Ti 16GB), Qwen2.5-0.5B-Instruct with seq_len=2048 and micro_batch_size=8 hit an OOM error, while with gradient tracking disabled for entropy I saw ~34% less peak memory:

[image: peak-memory comparison with and without gradient tracking on entropy]

The result can be reproduced with the command below (it didn't fit naturally into any existing test folder, so apologies for the long inline script):

uv run python - <<'PY'
import torch
import torch.nn.functional as F

def selective_log_softmax(logits, targets):
    # Log-probabilities of the target tokens only.
    logps = F.log_softmax(logits, dim=-1)
    return logps.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)

def entropy_from_logits(logits, chunk_size=128):
    # Chunked per-token entropy, mirroring verifiers/rl/trainer/utils.py.
    original_shape = logits.shape[:-1]
    num_classes = logits.shape[-1]
    flat_logits = logits.reshape(-1, num_classes)
    entropies = []
    for chunk in flat_logits.split(chunk_size, dim=0):
        logps = F.log_softmax(chunk, dim=-1)
        entropies.append(-(torch.exp(logps) * logps).sum(-1))
    return torch.cat(entropies, dim=0).reshape(original_shape)

device = "cuda"
if not torch.cuda.is_available():
    raise RuntimeError("CUDA required")

batch, seq, vocab = 4, 2048, 151_936
chunk_size = 128
dtype = torch.bfloat16

# Warm up the CUDA context so it doesn't skew the peak-memory numbers.
torch.randn(1, device=device)
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()

def run(detach):
    torch.cuda.reset_peak_memory_stats()
    logits = torch.randn(batch, seq, vocab, device=device, dtype=dtype, requires_grad=True)
    targets = torch.randint(0, vocab, (batch, seq), device=device)
    logprobs = selective_log_softmax(logits, targets)
    entropy = entropy_from_logits(logits.detach() if detach else logits, chunk_size=chunk_size)
    loss = -logprobs.sum()
    loss.backward()
    torch.cuda.synchronize()
    peak = torch.cuda.max_memory_allocated() / 1e9
    del logits, targets, logprobs, entropy, loss
    torch.cuda.empty_cache()
    return peak

peak_no_detach = run(False)
peak_with_detach = run(True)

print(f"Without detach: {peak_no_detach:.2f} GB")
print(f"With detach: {peak_with_detach:.2f} GB")
saved = peak_no_detach - peak_with_detach
print(f"Saved: {saved:.2f} GB ({100 * saved / peak_no_detach:.1f}%)")
PY


Note

Disables gradient tracking for entropy logging to cut unnecessary autograd memory.

  • In RLTrainer.get_logprobs, call entropy_from_logits(logits.detach()) instead of using logits directly, preventing entropy tensors from entering the backward graph (a sketch follows after this note)

Written by Cursor Bugbot for commit 1b3333c. This will update automatically on new commits.
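For illustration, here is a minimal sketch of what the changed call looks like in context. The get_logprobs_sketch function and everything around the entropy call are assumptions that mirror the repro script above, not the actual RLTrainer code.

import torch
import torch.nn.functional as F

def entropy_from_logits(logits):
    # Per-token entropy (chunking from the real helper omitted for brevity).
    logps = F.log_softmax(logits, dim=-1)
    return -(torch.exp(logps) * logps).sum(-1)

def get_logprobs_sketch(logits, targets):
    # Log-probs of the sampled tokens stay in the graph (the RL loss backprops through them).
    logprobs = F.log_softmax(logits, dim=-1).gather(
        dim=-1, index=targets.unsqueeze(-1)
    ).squeeze(-1)
    # Entropy is only logged, so compute it from detached logits;
    # the softmax/exp/mul intermediates are freed instead of being kept for backward.
    entropies = entropy_from_logits(logits.detach())
    return logprobs, entropies

logits = torch.randn(2, 8, 32, requires_grad=True)
targets = torch.randint(0, 32, (2, 8))
logprobs, entropies = get_logprobs_sketch(logits, targets)
print(logprobs.requires_grad, entropies.requires_grad)  # True False

Wrapping the entropy computation in torch.no_grad() would have the same effect; detaching the logits is simply the most localized change.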


CLAassistant commented Dec 23, 2025

CLA assistant check
All committers have signed the CLA.

@shyampathak22 shyampathak22 marked this pull request as ready for review December 23, 2025 09:59
@ADharaUTEXAS123007

Doesn't completely solve the issue; I still get the error below when running the gsm8k TOML at the 16th step.

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 14.52 GiB. GPU 0 has a total capacity of 95.00 GiB of which 16.98 GiB is free. Including non-PyTorch memory, this process has 78.02 GiB memory in use. Of the allocated memory 71.27 GiB is allocated by PyTorch, and 4.07 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
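For reference, the allocator setting suggested in that error message can be tried by exporting the variable before launching the same run (this is a stock PyTorch option quoted from the message, not something introduced by this PR):

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# then re-run the gsm8k training command as before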

