Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion agentlightning/verl/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,45 @@ def _resolve_image_path(self, path: str) -> str:
raise ValueError(f"Relative path '{path}' requires 'image_base_dir' to be set.")
return os.path.join(self.image_base_dir, path)

def _count_images_in_tokens(self, token_ids: List[int]) -> int:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't review this PR. @totoluo

Copy link
Contributor

@totoluo totoluo Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see the PR tried to match the how image is counted for image_grid_thw w.r.t get_rope_index. However, the current mechanism will still fail in a corner case of an image is truncated in the middle. Count index will increment by 1 and get_rope_index fail at the same place.
IMO, we should simply leverage the existing is_dropped_list and put dummy pos_ids and skip _compute_mrope_position_ids for those samples. They should get treated the same way as those exceeded length text prompts. So what we should do is:

      if self._use_mrope and is_drop_list[i]:
          # Don't call get_rope_index — it would crash on truncated images.
          # is_drop_mask will remove this sample in the trainer.
          position_ids_list.append(torch.zeros(4, seq_len, dtype=torch.long, device=device))
      else:
          pos_ids = self._compute_mrope_position_ids(...)
          position_ids_list.append(pos_ids)

There is no harm putting the current code in place, but it's not a fix for all. Thoughts @Mr-Neutr0n?

"""Count the number of complete image regions in a token ID sequence.

Image regions are identified by finding ``vision_start_token_id``
followed by ``image_token_id``, matching the detection logic used by
``get_rope_index`` in the Qwen2-VL / Qwen2.5-VL model implementation.
This is needed to reconcile ``image_grid_thw`` with truncated prompts
so that mRoPE position IDs are computed correctly.

Args:
token_ids: List of token IDs (possibly truncated).

Returns:
Number of image regions found in the token sequence, or ``-1`` if
the required special-token IDs could not be resolved (in which case
the caller should fall back to the original image count).
"""
# Resolve image_token_id from the processor (set during __init__)
image_token_id = getattr(self.processor, "image_token_id", None)
if image_token_id is None and hasattr(self.tokenizer, "convert_tokens_to_ids"):
image_token_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>")

# Resolve vision_start_token_id -- not stored on the processor, so we
# try the tokenizer first and fall back to the well-known default.
vision_start_token_id = None
if hasattr(self.tokenizer, "convert_tokens_to_ids"):
vision_start_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_start|>")
if vision_start_token_id is None:
vision_start_token_id = 151652 # Qwen2-VL / Qwen2.5-VL default

if image_token_id is None:
return -1

count = 0
for i in range(len(token_ids) - 1):
if token_ids[i] == vision_start_token_id and token_ids[i + 1] == image_token_id:
count += 1
return count

def _get_image_grid_thw(self, image_urls: List[str]) -> Optional[torch.Tensor]:
"""Compute image_grid_thw from image URLs for M-RoPE computation.

Expand Down Expand Up @@ -907,9 +946,17 @@ def get_train_data_batch(
rollout_id_list.append(rollout_id)
turn_index_list.append(turn_index)

# Compute image_grid_thw for this triplet using image_urls from prompt
# Compute image_grid_thw for this triplet using image_urls from prompt.
# After prompt truncation, some image tokens may have been removed,
# so we must reconcile image_urls with the actual images remaining
# in the (possibly truncated) prompt to avoid shape mismatches in
# get_rope_index when computing mRoPE position IDs.
if self._use_mrope:
image_urls = trace.get("image_urls", [])
if image_urls:
n_images_in_tokens = self._count_images_in_tokens(prompt_ids)
if n_images_in_tokens >= 0 and n_images_in_tokens < len(image_urls):
image_urls = image_urls[:n_images_in_tokens]
image_grid_thw_list.append(self._get_image_grid_thw(image_urls))

elif self.trace_aggregator.get("level", "transition") == "trajectory":
Expand Down