Skip to content

Fix mRoPE position ID crash on Qwen2-VL prompt truncation#482

Open
Mr-Neutr0n wants to merge 2 commits into
microsoft:mainfrom
Mr-Neutr0n:fix/qwen-vl-mrope-truncation
Open

Fix mRoPE position ID crash on Qwen2-VL prompt truncation#482
Mr-Neutr0n wants to merge 2 commits into
microsoft:mainfrom
Mr-Neutr0n:fix/qwen-vl-mrope-truncation

Conversation

@Mr-Neutr0n
Copy link
Copy Markdown

Summary

Fixes #441

When training Qwen2.5-VL with agent-lightning + verl, the model crashes in get_rope_index with a shape mismatch:

position_ids[..., attention_mask == 1] = llm_positions

fails because llm_positions length differs from the attention mask true-count.

Root cause: In get_train_data_batch, prompt truncation (prompt_ids[:max_prompt_length]) changes the token count, potentially removing image placeholder tokens. However, image_grid_thw is computed from the original (untruncated) image_urls list. When get_rope_index processes the truncated sequence, it finds fewer <|vision_start|><|image_pad|> regions than image_grid_thw entries, causing the position ID length to diverge from the attention mask count.

Fix: After prompt truncation, count the remaining image regions in the truncated token sequence using the same vision_start_token_id + image_token_id pattern that get_rope_index uses, and slice image_urls to match before computing image_grid_thw.

  • Added _count_images_in_tokens() helper method to detect image regions in token sequences
  • Modified the transition-level mRoPE code path to reconcile image_urls with truncated prompts

Test plan

  • Verify Qwen2.5-VL training with prompts that exceed max_prompt_length and contain images no longer crashes in get_rope_index
  • Verify Qwen2.5-VL training with prompts shorter than max_prompt_length is unaffected (no truncation, all images retained)
  • Verify non-VL model training paths are unaffected (_use_mrope is False)

When training Qwen2.5-VL with agent-lightning + verl, prompt truncation
changes the token count but image_grid_thw is computed from the original
(untruncated) image_urls. This causes get_rope_index to fail with a
shape mismatch because it finds fewer image tokens in the truncated
input_ids than entries in image_grid_thw.

After prompt truncation, count remaining image regions in the truncated
token sequence and slice image_urls to match before computing
image_grid_thw, ensuring consistency between the token content and the
mRoPE spatial metadata.

Fixes microsoft#441
@Mr-Neutr0n Mr-Neutr0n force-pushed the fix/qwen-vl-mrope-truncation branch from bdd1c8d to ca0be5a Compare February 9, 2026 22:01
@Mr-Neutr0n
Copy link
Copy Markdown
Author

Friendly bump! Let me know if there's anything I should update or improve to help move this forward.

raise ValueError(f"Relative path '{path}' requires 'image_base_dir' to be set.")
return os.path.join(self.image_base_dir, path)

def _count_images_in_tokens(self, token_ids: List[int]) -> int:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't review this PR. @totoluo

Copy link
Copy Markdown
Contributor

@totoluo totoluo Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see the PR tried to match the how image is counted for image_grid_thw w.r.t get_rope_index. However, the current mechanism will still fail in a corner case of an image is truncated in the middle. Count index will increment by 1 and get_rope_index fail at the same place.
IMO, we should simply leverage the existing is_dropped_list and put dummy pos_ids and skip _compute_mrope_position_ids for those samples. They should get treated the same way as those exceeded length text prompts. So what we should do is:

      if self._use_mrope and is_drop_list[i]:
          # Don't call get_rope_index — it would crash on truncated images.
          # is_drop_mask will remove this sample in the trainer.
          position_ids_list.append(torch.zeros(4, seq_len, dtype=torch.long, device=device))
      else:
          pos_ids = self._compute_mrope_position_ids(...)
          position_ids_list.append(pos_ids)

There is no harm putting the current code in place, but it's not a fix for all. Thoughts @Mr-Neutr0n?

@totoluo (2026-03-03) noted that the current count-based fix still
fails for a corner case where an image is truncated in the middle of
the prompt — the count will increment by 1 and get_rope_index crashes
at the same place.

This change uses the existing is_drop_list to skip _compute_mrope_position_ids
for samples that will be dropped by is_drop_mask downstream, substituting
a zero placeholder. Those samples are removed by the trainer, so the
placeholder pos_ids are never consumed.

This is a strict superset of the previous fix.
Copilot AI review requested due to automatic review settings June 1, 2026 22:03
@Mr-Neutr0n
Copy link
Copy Markdown
Author

@totoluo — done. Pushed as bf8be4f on the same branch. Implemented your suggestion: when iterating to build position_ids_list, we now check is_drop_list[i] and append a zero torch.zeros(4, seq_len, dtype=torch.long, device=device) placeholder for dropped samples, skipping _compute_mrope_position_ids entirely. Those samples are filtered by is_drop_mask downstream in the trainer, so the placeholder is never used in the loss computation.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR hardens Qwen2-VL/Qwen2.5-VL mRoPE preprocessing against prompt truncation by reconciling image_urls with the images that remain in the truncated token stream and avoiding crashes when samples are dropped.

Changes:

  • Add _count_images_in_tokens() to infer how many complete image regions remain after prompt truncation.
  • Truncate image_urls to match the inferred image count before computing image_grid_thw.
  • Skip mRoPE position-id computation for dropped samples by inserting placeholder position_ids.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +331 to +344
image_token_id = getattr(self.processor, "image_token_id", None)
if image_token_id is None and hasattr(self.tokenizer, "convert_tokens_to_ids"):
image_token_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>")

# Resolve vision_start_token_id -- not stored on the processor, so we
# try the tokenizer first and fall back to the well-known default.
vision_start_token_id = None
if hasattr(self.tokenizer, "convert_tokens_to_ids"):
vision_start_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_start|>")
if vision_start_token_id is None:
vision_start_token_id = 151652 # Qwen2-VL / Qwen2.5-VL default

if image_token_id is None:
return -1
Comment on lines +340 to +341
if vision_start_token_id is None:
vision_start_token_id = 151652 # Qwen2-VL / Qwen2.5-VL default
Comment on lines 954 to 960
if self._use_mrope:
image_urls = trace.get("image_urls", [])
if image_urls:
n_images_in_tokens = self._count_images_in_tokens(prompt_ids)
if n_images_in_tokens >= 0 and n_images_in_tokens < len(image_urls):
image_urls = image_urls[:n_images_in_tokens]
image_grid_thw_list.append(self._get_image_grid_thw(image_urls))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

4 participants