
feat: add mock VLM dataset and Gemma4 pretokenize support #1682

Open
HuiyingLi wants to merge 14 commits into main from huiyingl/vlm-mock-dataset

Conversation

@HuiyingLi (Contributor) commented Apr 4, 2026

Summary

  • Add build_mock_vlm_dataset for VLM benchmarking/testing without real data downloads — generates random PIL images + dummy text in conversation format
  • Add truncate mode to PreTokenizedDatasetWrapper (labels built before truncation) for fixed-length sequence training
  • Add Gemma4 tensor support (image_position_ids, mm_token_type_ids) to PreTokenizedDatasetWrapper and pad_collate_fn
  • Auto-enable pretokenize/truncate when max_length is set on dataset config
  • Gemma4's tokenizer defaults to padding_side='left'; changed to padding_side='right' for training.
  • Changed the attention implementation to eager as a workaround (WAR) based on the following issues:
  1. Issue on gemma4 attn backend: [Gemma 4] Support per-layer FlashAttention: FA2 for sliding layers, SDPA for global layers huggingface/transformers#45201

  2. Issue on gemma4 attn backend: Fix gemma4 has flash-attention incompatible head-dim=512 huggingface/transformers#45202
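As an illustration of the first bullet, here is a minimal plain-Python sketch of what a mock VLM sample could look like. This is not the actual `build_mock_vlm_dataset` implementation — the function name `build_mock_vlm_sample`, its parameters, and the use of nested pixel lists in place of a PIL image are assumptions for illustration only:

```python
# Illustrative sketch (NOT the real implementation): pair a random
# placeholder "image" with dummy conversation-format text, mirroring
# what build_mock_vlm_dataset is described to produce.
import random

def build_mock_vlm_sample(height=32, width=32, seed=0):
    rng = random.Random(seed)
    # Stand-in for a random PIL RGB image: rows of (r, g, b) tuples.
    image = [[(rng.randrange(256), rng.randrange(256), rng.randrange(256))
              for _ in range(width)] for _ in range(height)]
    conversation = [
        {"role": "user", "content": [{"type": "image"},
                                     {"type": "text", "text": "Describe the image."}]},
        {"role": "assistant", "content": [{"type": "text", "text": "A mock image."}]},
    ]
    return {"images": [image], "conversations": conversation}

sample = build_mock_vlm_sample()
```

The real builder returns PIL images; the point here is only the shape of the conversation record.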

Test plan

  • Unit tests: pytest tests/unit_tests/datasets/vlm/test_mock.py (11 tests pass)
  • Gemma4 4B mock training: 26 steps, loss 17.08→3.97, no NaN
  • Gemma4 26B MoE mock training: 26 steps on 8 GPUs, loss 11.79→3.96, no NaN
  • Verified per-sample loss identical between default_collate_fn and pad_collate_fn paths

🤖 Generated with Claude Code

@copy-pr-bot bot commented Apr 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HuiyingLi force-pushed the huiyingl/vlm-mock-dataset branch from 499b11c to e7d5d7d on April 4, 2026 at 22:48
@HuiyingLi (Contributor, Author)

/ok to test e7d5d7d

@HuiyingLi (Contributor, Author)

/ok to test 5255f6e

@HuiyingLi (Contributor, Author)

/claude review

Comment on lines +976 to +979

```python
    k: v[:, :ml] if isinstance(v, torch.Tensor) and v.dim() == 2 and v.shape[1] == seq_len else v
    for k, v in result.items()
}
seq_len = ml
```
Contributor:

The truncation comprehension only slices 2D tensors whose second dimension matches seq_len. If a processor returns mm_token_type_ids as a 1D tensor (seq_len,) instead of (1, seq_len), it will pass through untruncated, causing a shape mismatch when it reaches pad_collate_fn.

Consider also handling 1D tensors:

Suggested change — replace

```python
    k: v[:, :ml] if isinstance(v, torch.Tensor) and v.dim() == 2 and v.shape[1] == seq_len else v
    for k, v in result.items()
}
seq_len = ml
```

with

```python
result = {
    k: (
        v[:, :ml]
        if isinstance(v, torch.Tensor) and v.dim() == 2 and v.shape[1] == seq_len
        else v[:ml]
        if isinstance(v, torch.Tensor) and v.dim() == 1 and v.shape[0] == seq_len
        else v
    )
    for k, v in result.items()
}
```

@@ -0,0 +1,130 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
Contributor:

The tests cover build_mock_vlm_dataset well, but there's no test coverage for the new truncate mode in PreTokenizedDatasetWrapper. This is the more complex and bug-prone part of the PR — it would be good to add at least a basic test that:

  1. Wraps a mock dataset with PreTokenizedDatasetWrapper(ds, processor, max_length=N, truncate=True)
  2. Verifies that input_ids, labels, and attention_mask all have length exactly N
  3. Verifies that labels are not all -100 (i.e., label building before truncation worked correctly)
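The steps above can be sketched without torch or the real wrapper. The function below is a hypothetical stand-in for the wrapper's truncate path (plain Python lists in place of tensors; `pretokenize_with_truncation` and its parameters are invented for illustration) — the key property is that labels are built from the full sequence before slicing:

```python
# Hypothetical stand-in for the wrapper's truncate path: labels are built
# from the full sequence first, then everything is sliced to max_length.
def pretokenize_with_truncation(input_ids, label_start, max_length):
    labels = [-100] * label_start + input_ids[label_start:]  # labels before truncation
    attention_mask = [1] * len(input_ids)
    return {
        "input_ids": input_ids[:max_length],
        "labels": labels[:max_length],
        "attention_mask": attention_mask[:max_length],
    }

N = 8
out = pretokenize_with_truncation(list(range(16)), label_start=4, max_length=N)
assert all(len(out[k]) == N for k in ("input_ids", "labels", "attention_mask"))
assert any(tok != -100 for tok in out["labels"])  # labels survived truncation
```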

@HuiyingLi (Contributor, Author)

/ok to test cd3e17b

Gemma4 cannot use flash_attention_2 (global layers have head_dim=512,
FA2 max is 256 — huggingface/transformers#45202). SDPA has a known
NaN bug with sliding window + padding (huggingface/transformers#32390)
that was fixed for Gemma2 but not ported to Gemma3/4. Eager attention
handles all padding correctly.

Also add padding_side: right to the 4 original configs that were
missing it (the mock configs already had it).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
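The config change described in this commit message can be sketched as follows. The exact YAML structure and field names are assumptions — this is only a reading of the commit message, not the actual config files:

```yaml
# Sketch of the fields described above (structure and names assumed).
model:
  attn_implementation: eager   # WAR: FA2 head_dim limit; SDPA sliding-window NaN bug
tokenizer:
  padding_side: right          # Gemma4 defaults to left; right is needed for training
```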
@HuiyingLi (Contributor, Author)

/ok to test 57dee89

```yaml
tp_size: 1
cp_size: 1
sequence_parallel: false
activation_checkpointing: true
```
Contributor:

Same question as above

athitten previously approved these changes Apr 7, 2026

@athitten (Contributor) left a comment:
Some minor comments. Overall LGTM, thank you @HuiyingLi !

…nals

The original gemma4_4b.yaml and gemma4_26b_a4b_moe.yaml configs do not
use activation checkpointing, so the mock variants should not either.
Also bump 4B mock nproc-per-node from 2 to 8 (needed without AC at
seq_len=2048).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi (Contributor, Author)

/ok to test 1fae9f9

@HuiyingLi (Contributor, Author)

/ok to test c8511b0

@HuiyingLi (Contributor, Author)

/claude review

```python
import pytest
from PIL import Image

from nemo_automodel.components.datasets.vlm.mock import build_mock_vlm_dataset
```
Contributor:

The tests here thoroughly cover build_mock_vlm_dataset, but this PR also introduces non-trivial logic in:

  • PreTokenizedDatasetWrapper truncation mode (label-before-truncate ordering, the dict comprehension that selectively truncates tensors by shape)
  • pad_collate_fn handling of mm_token_type_ids (2D→1D squeeze, padding, interaction with the autoregressive shift trim at L1230-1232)
  • finetune.py auto-enable of pretokenize/truncate when max_length is set

These are the most bug-prone parts of the change and would benefit from at least a few unit tests — e.g., a test that feeds a mock sample through PreTokenizedDatasetWrapper with truncate=True and asserts the output shapes are all max_length, and a test that verifies mm_token_type_ids survives pad_collate_fn with the correct shape.
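The pad-then-trim behavior referenced above can be illustrated with a plain-Python sketch. This does not use `pad_collate_fn`'s real signature — `pad_and_shift_trim` and its arguments are hypothetical, showing only the shape handling:

```python
# Hypothetical sketch: pad variable-length per-sample id lists to a common
# length, then drop the last position to mirror an autoregressive shift.
def pad_and_shift_trim(samples, pad_value=0):
    max_len = max(len(s) for s in samples)
    padded = [s + [pad_value] * (max_len - len(s)) for s in samples]
    # Autoregressive shift: inputs drop the last token.
    return [row[:-1] for row in padded]

batch = pad_and_shift_trim([[1, 1, 0, 0], [1, 1, 1]])
assert all(len(row) == 3 for row in batch)
```

A tensor like mm_token_type_ids must go through the same padding and trim so its length matches input_ids after the shift.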

@claude bot (Contributor) left a comment:

Code looks solid overall — the label-before-truncate design is clean and the collate_fn extensions are consistent with the existing patterns.

One concern: test coverage only covers the mock dataset builder, not the more complex changes (truncation in PreTokenizedDatasetWrapper, mm_token_type_ids/image_position_ids handling in pad_collate_fn, auto-enable logic in finetune.py). See inline comment for details.

…d_collate_fn mm_token_type_ids

- Test that truncate=True produces exact max_length shapes for input_ids,
  attention_mask, and labels
- Test that labels are not all -100 after truncation
- Test that mm_token_type_ids (1D) is truncated correctly
- Test that pad_collate_fn pads and trims mm_token_type_ids with
  autoregressive shift for both 1D and 2D inputs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi (Contributor, Author)

/claude review

@HuiyingLi (Contributor, Author)

/ok to test df5584b

Comment on lines +203 to +204

```python
wrapper = PreTokenizedDatasetWrapper(raw_ds, _StubProcessor(), max_length=ml, truncate=True)
```
Contributor:

Bug: this if guard means the test passes vacuously if mm_token_type_ids is ever accidentally dropped from the output. Since the stub processor always returns it, this key should always be present — assert that directly.

Suggested change

```python
wrapper = PreTokenizedDatasetWrapper(raw_ds, _StubProcessor(), max_length=ml, truncate=True)
assert "mm_token_type_ids" in sample, "mm_token_type_ids missing from output"
assert sample["mm_token_type_ids"].shape[0] <= ml
```

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi (Contributor, Author)

/claude review

@HuiyingLi (Contributor, Author)

/ok to test ba2e547

The wrapper's __getitem__ calls _preload_media, _conversation_has_media,
build_labels_from_template, etc. which need a full processor. Mock these
internals so the tests can run without a real HF processor/model.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi (Contributor, Author)

/ok to test f282996

@HuiyingLi (Contributor, Author)

/claude review

Comment on lines +185 to +190

```python
def _mock_extract_media(conversations):
    return [Image.new("RGB", (4, 4))], []


def _mock_build_labels(input_ids, conversations, processor):
    # Mark last half of tokens as label tokens (not -100).
```
Contributor:

Bug: This test should fail. The mock _mock_build_labels sets non-(-100) labels only in the second half (labels[:, seq // 2:]), i.e. indices 32–63 when seq=64. Truncation to ml=32 keeps only indices 0–31, which are all -100. So the assertion (sample["labels"] != -100).any() would fail.

Fix the mock to place some real labels in the first portion of the sequence so they survive truncation:

Suggested change — replace

```python
def _mock_extract_media(conversations):
    return [Image.new("RGB", (4, 4))], []


def _mock_build_labels(input_ids, conversations, processor):
    # Mark last half of tokens as label tokens (not -100).
```

with

```python
def _mock_build_labels(input_ids, conversations, processor):
    # Mark last three-quarters of tokens as label tokens (not -100).
    seq = input_ids.shape[1]
    labels = torch.full_like(input_ids, -100)
    labels[:, seq // 4:] = input_ids[:, seq // 4:]
    return labels
```
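The index arithmetic behind this bug can be checked directly with a plain-Python sketch (lists in place of tensors):

```python
# Numeric check of the bug described above: with seq=64 and labels only in
# the second half, truncating to ml=32 leaves nothing but -100.
seq, ml = 64, 32
labels_second_half = [-100] * (seq // 2) + list(range(seq // 2, seq))
assert all(x == -100 for x in labels_second_half[:ml])  # all ignored: test would fail

# With labels starting at seq // 4, some survive truncation to ml.
labels_from_quarter = [-100] * (seq // 4) + list(range(seq // 4, seq))
assert any(x != -100 for x in labels_from_quarter[:ml])
```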

@claude bot commented Apr 11, 2026

Light review — one bug found, one minor note.

Bug: test_pretokenized_wrapper_truncate_labels_not_all_ignored should fail — see inline comment. The mock labels only populate indices 32+ but truncation keeps indices 0–31, so all labels are -100 post-truncation.

Nit: PR description references tests/unit_tests/datasets/vlm/test_mock.py but the actual file is test_mock_vlm.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi (Contributor, Author)

/ok to test 29e1fb5

@HuiyingLi (Contributor, Author)

/claude review

claude bot previously approved these changes Apr 11, 2026

@claude bot left a comment:

LGTM

…ncached decode

The Mamba mixer uses different CUDA kernels for cached (chunk_scan +
selective_state_update) vs uncached (split_conv1d_scan_combined) paths.
These are mathematically equivalent but not bit-identical in bf16,
causing token divergence after the first step. Compare against
generate(input_ids=...) instead, which uses the same cached kernel path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi (Contributor, Author)

/ok to test 4c4282e
