feat: add mock VLM dataset and Gemma4 pretokenize support#1682
/ok to test e7d5d7d

/ok to test 5255f6e

/claude review
```python
    k: v[:, :ml] if isinstance(v, torch.Tensor) and v.dim() == 2 and v.shape[1] == seq_len else v
    for k, v in result.items()
}
seq_len = ml
```
The truncation comprehension only slices 2D tensors whose second dimension matches seq_len. If a processor returns mm_token_type_ids as a 1D tensor (seq_len,) instead of (1, seq_len), it will pass through untruncated, causing a shape mismatch when it reaches pad_collate_fn.
Consider also handling 1D tensors:
```python
result = {
    k: (
        v[:, :ml]
        if isinstance(v, torch.Tensor) and v.dim() == 2 and v.shape[1] == seq_len
        else v[:ml]
        if isinstance(v, torch.Tensor) and v.dim() == 1 and v.shape[0] == seq_len
        else v
    )
    for k, v in result.items()
}
```
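As a standalone sanity check of the suggested comprehension, here is a minimal sketch; the dict contents are illustrative stand-ins for a processor's output (the key names mirror the PR, the values are made up):

```python
import torch

seq_len, ml = 8, 5
result = {
    "input_ids": torch.arange(seq_len).unsqueeze(0),  # (1, 8): 2D case
    "mm_token_type_ids": torch.zeros(seq_len),        # (8,): the 1D case from the comment
    "pixel_values": torch.randn(3, 4, 4),             # wrong shape: must pass through untouched
}
result = {
    k: (
        v[:, :ml]
        if isinstance(v, torch.Tensor) and v.dim() == 2 and v.shape[1] == seq_len
        else v[:ml]
        if isinstance(v, torch.Tensor) and v.dim() == 1 and v.shape[0] == seq_len
        else v
    )
    for k, v in result.items()
}
assert result["input_ids"].shape == (1, ml)
assert result["mm_token_type_ids"].shape == (ml,)
assert result["pixel_values"].shape == (3, 4, 4)
```

Both the 2D and 1D branches truncate to `ml`, while tensors whose trailing dimension does not equal `seq_len` are left alone.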
```diff
@@ -0,0 +1,130 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
```
The tests cover build_mock_vlm_dataset well, but there's no test coverage for the new truncate mode in PreTokenizedDatasetWrapper. This is the more complex and bug-prone part of the PR — it would be good to add at least a basic test that:

- Wraps a mock dataset with `PreTokenizedDatasetWrapper(ds, processor, max_length=N, truncate=True)`
- Verifies that `input_ids`, `labels`, and `attention_mask` all have length exactly `N`
- Verifies that labels are not all `-100` (i.e., label building before truncation worked correctly)
/ok to test cd3e17b
Gemma4 cannot use flash_attention_2 (global layers have head_dim=512, FA2 max is 256 — huggingface/transformers#45202). SDPA has a known NaN bug with sliding window + padding (huggingface/transformers#32390) that was fixed for Gemma2 but not ported to Gemma3/4. Eager attention handles all padding correctly. Also add padding_side: right to the 4 original configs that were missing it (the mock configs already had it).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
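In config terms, this commit amounts to something like the following fragment. This is a sketch, not the actual recipe files: the surrounding structure and section names are assumptions, though `attn_implementation` is the standard Hugging Face knob and `padding_side: right` is the key the commit names.

```yaml
model:
  attn_implementation: eager   # FA2 rejects head_dim=512; SDPA NaNs with sliding window + padding
dataset:
  padding_side: right          # Gemma4 processor defaults to left; training wants right padding
```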
/ok to test 57dee89
```yaml
tp_size: 1
cp_size: 1
sequence_parallel: false
activation_checkpointing: true
```
athitten left a comment:
Some minor comments. Overall LGTM, thank you @HuiyingLi !
…nals

The original gemma4_4b.yaml and gemma4_26b_a4b_moe.yaml configs do not use activation checkpointing, so the mock variants should not either. Also bump 4B mock nproc-per-node from 2 to 8 (needed without AC at seq_len=2048).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
/ok to test 1fae9f9

/ok to test c8511b0

/claude review
```python
import pytest
from PIL import Image

from nemo_automodel.components.datasets.vlm.mock import build_mock_vlm_dataset
```
The tests here thoroughly cover build_mock_vlm_dataset, but this PR also introduces non-trivial logic in:

- `PreTokenizedDatasetWrapper` truncation mode (label-before-truncate ordering, the dict comprehension that selectively truncates tensors by shape)
- `pad_collate_fn` handling of `mm_token_type_ids` (2D→1D squeeze, padding, interaction with the autoregressive shift trim at L1230-1232)
- `finetune.py` auto-enable of `pretokenize`/`truncate` when `max_length` is set

These are the most bug-prone parts of the change and would benefit from at least a few unit tests — e.g., a test that feeds a mock sample through PreTokenizedDatasetWrapper with truncate=True and asserts the output shapes are all max_length, and a test that verifies mm_token_type_ids survives pad_collate_fn with the correct shape.
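The pad-then-trim behavior described for `mm_token_type_ids` can be illustrated generically. This is a standalone sketch of the pattern, not the repo's actual `pad_collate_fn`; the variable names and padding value are illustrative:

```python
import torch
import torch.nn.functional as F

# Two per-sample 1D mm_token_type_ids of different lengths, as a collate_fn
# would receive them after the 2D -> 1D squeeze.
samples = [torch.tensor([1, 1, 0, 0]), torch.tensor([1, 0])]

# Right-pad each sample to the batch max length, then stack into a batch.
max_len = max(t.shape[0] for t in samples)
padded = torch.stack(
    [F.pad(t, (0, max_len - t.shape[0]), value=0) for t in samples]
)

# Trim the final position to mirror the autoregressive shift applied to
# input_ids/labels, so all per-token tensors keep matching lengths.
padded = padded[:, :-1]
assert padded.shape == (2, max_len - 1)
```

The point of the trim is that once `input_ids` and `labels` are shifted for next-token prediction, every per-token auxiliary tensor must shrink by one position too, or shapes diverge downstream.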
Code looks solid overall — the label-before-truncate design is clean and the collate_fn extensions are consistent with the existing patterns.
One concern: the tests cover only the mock dataset builder, not the more complex changes (truncation in PreTokenizedDatasetWrapper, mm_token_type_ids/image_position_ids handling in pad_collate_fn, auto-enable logic in finetune.py). See inline comment for details.
…d_collate_fn mm_token_type_ids

- Test that truncate=True produces exact max_length shapes for input_ids, attention_mask, and labels
- Test that labels are not all -100 after truncation
- Test that mm_token_type_ids (1D) is truncated correctly
- Test that pad_collate_fn pads and trims mm_token_type_ids with autoregressive shift for both 1D and 2D inputs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
/claude review

/ok to test df5584b
```python
wrapper = PreTokenizedDatasetWrapper(raw_ds, _StubProcessor(), max_length=ml, truncate=True)
```
Bug: this if guard means the test passes vacuously if mm_token_type_ids is ever accidentally dropped from the output. Since the stub processor always returns it, this key should always be present — assert that directly.
```python
wrapper = PreTokenizedDatasetWrapper(raw_ds, _StubProcessor(), max_length=ml, truncate=True)
assert "mm_token_type_ids" in sample, "mm_token_type_ids missing from output"
assert sample["mm_token_type_ids"].shape[0] <= ml
```
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
/claude review

/ok to test ba2e547
The wrapper's __getitem__ calls _preload_media, _conversation_has_media, build_labels_from_template, etc. which need a full processor. Mock these internals so the tests can run without a real HF processor/model. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
/ok to test f282996

/claude review
```python
def _mock_extract_media(conversations):
    return [Image.new("RGB", (4, 4))], []


def _mock_build_labels(input_ids, conversations, processor):
    # Mark last half of tokens as label tokens (not -100).
```
Bug: This test should fail. The mock _mock_build_labels sets non-(-100) labels only in the second half (labels[:, seq // 2:]), i.e. indices 32–63 when seq=64. Truncation to ml=32 keeps only indices 0–31, which are all -100. So the assertion (sample["labels"] != -100).any() would fail.
Fix the mock to place some real labels in the first portion of the sequence so they survive truncation:
```python
def _mock_extract_media(conversations):
    return [Image.new("RGB", (4, 4))], []


def _mock_build_labels(input_ids, conversations, processor):
    # Mark last three-quarters of tokens as label tokens (not -100).
    seq = input_ids.shape[1]
    labels = torch.full_like(input_ids, -100)
    labels[:, seq // 4:] = input_ids[:, seq // 4:]
    return labels
```
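A quick arithmetic check of the claim, using the seq=64, ml=32 values from the comment (standalone sketch, independent of the wrapper):

```python
import torch

seq, ml = 64, 32
input_ids = torch.arange(seq).unsqueeze(0)  # (1, 64), values 0..63

# Original mock: labels only in the second half (indices 32..63).
old = torch.full_like(input_ids, -100)
old[:, seq // 2:] = input_ids[:, seq // 2:]
# Truncation to ml=32 keeps indices 0..31, which are all -100.
assert not (old[:, :ml] != -100).any()

# Fixed mock: labels from seq // 4 = 16 onward, so indices 16..31 survive.
fixed = torch.full_like(input_ids, -100)
fixed[:, seq // 4:] = input_ids[:, seq // 4:]
assert (fixed[:, :ml] != -100).any()
```

So the original mock makes the `(sample["labels"] != -100).any()` assertion vacuously fail, and the fix restores real labels within the truncated window.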
Light review — one bug found, one minor note.

Bug: see the inline comment on _mock_build_labels above.
Nit: PR description references …
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
/ok to test 29e1fb5

/claude review
…ncached decode

The Mamba mixer uses different CUDA kernels for cached (chunk_scan + selective_state_update) vs uncached (split_conv1d_scan_combined) paths. These are mathematically equivalent but not bit-identical in bf16, causing token divergence after the first step. Compare against generate(input_ids=...) instead, which uses the same cached kernel path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
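The "mathematically equivalent but not bit-identical" point comes down to floating-point addition not being associative, which bites hard at bf16's 8-bit mantissa. A minimal sketch (values chosen to expose the rounding, not taken from the Mamba kernels):

```python
import torch

# bf16 has ~2-3 decimal digits of precision, so the order in which small
# terms are accumulated changes the rounded result.
one = torch.tensor(1.0, dtype=torch.bfloat16)
eps = torch.tensor(5e-3, dtype=torch.bfloat16)

left = (one + eps) + eps   # accumulate one term at a time
right = one + (eps + eps)  # pre-combine the small terms first

# The two orders round differently, so the results differ in the last bit.
assert left != right
```

Two kernels that reduce in different orders therefore produce logits that differ by an ulp or so; after an argmax and a few decode steps, the token streams diverge even though both paths are "correct".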
/ok to test 4c4282e
Summary

- `build_mock_vlm_dataset` for VLM benchmarking/testing without real data downloads — generates random PIL images + dummy text in conversation format
- `truncate` mode added to `PreTokenizedDatasetWrapper` (labels built before truncation) for fixed-length sequence training
- New keys (`image_position_ids`, `mm_token_type_ids`) supported in `PreTokenizedDatasetWrapper` and `pad_collate_fn`
- Auto-enable `pretokenize`/`truncate` when `max_length` is set on the dataset config
- Gemma4 processor defaults to `padding_side='left'`; changed to right padding for training
- Issue on gemma4 attn backend: [Gemma 4] Support per-layer FlashAttention: FA2 for sliding layers, SDPA for global layers huggingface/transformers#45201
- Issue on gemma4 attn backend: Fix gemma4 has flash-attention incompatible head-dim=512 huggingface/transformers#45202

Test plan

- `pytest tests/unit_tests/datasets/vlm/test_mock.py` (11 tests pass)
- Both `default_collate_fn` and `pad_collate_fn` paths

🤖 Generated with Claude Code