
Conversation


@yueshen2016 yueshen2016 commented Jan 13, 2026

What does this PR do?

Type of change: New feature

Overview: This PR adds vision language model (VLM) generation support to megatron_generate.

Usage

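A minimal usage sketch, since the PR leaves this section blank. It assumes a Hugging Face-style processor and that megatron_generate takes the prompt tokens positionally plus the vision keyword arguments this PR adds (pixel_values, image_grid_thw, image_sizes); the exact signature is an assumption, not confirmed by the PR.

# Hypothetical sketch: `model`, `processor`, `image`, and `prompt` come from
# the user's own setup; only the vision kwargs are taken from this PR.
from modelopt.torch.utils.plugins.megatron_generate import megatron_generate

inputs = processor(images=image, text=prompt, return_tensors="pt")
output_ids = megatron_generate(
    model,                                       # Megatron VLM model instance
    inputs["input_ids"].cuda(),                  # prompt tokens
    pixel_values=inputs["pixel_values"].cuda(),  # vision input (new in this PR)
    image_grid_thw=inputs.get("image_grid_thw"),
)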

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added Vision Language Model support to the text generation pipeline, enabling simultaneous processing of image and text inputs during both generation and prefill operations.
  • Improvements

    • Enhanced data flow to properly route multimodal inputs (images and text tokens) through generation paths with automatic detection and handling of vision-enabled model architectures.


@yueshen2016 yueshen2016 requested a review from a team as a code owner January 13, 2026 17:16
@yueshen2016 yueshen2016 requested a review from AAnoosheh January 13, 2026 17:16

coderabbitai bot commented Jan 13, 2026

📝 Walkthrough

Introduces vision-input aware forwarding in megatron_generate.py and megatron_prefill.py. Detects vision language model (VLM) inputs via pixel_values, image_grid_thw, and image_sizes, constructing appropriate position IDs and attention masks. Implements dual-path handling that routes VLM inputs through specialized forward logic while maintaining existing text-only paths.

Changes

Cohort / File(s): Vision-aware VLM forwarding in generate/prefill
(modelopt/torch/utils/plugins/megatron_generate.py, modelopt/torch/utils/plugins/megatron_prefill.py)
Summary: Added vision input detection and dual-path forwarding logic. For VLMs: constructs vlm_position_ids and vlm_attention_mask, and conditionally injects pixel_values/image_grid_thw/image_sizes into forward_args. The text-only path is retained as a fallback. Updated the data flow to pass a data_dict with tokens and optional vision inputs to the forward step.

Sequence Diagram

sequenceDiagram
    participant DataIterator as Data Iterator
    participant ForwardStep as Forward Step Func
    participant VisionDetector as Vision Input Detector
    participant Model as Model
    
    DataIterator->>ForwardStep: data_dict (tokens + optional vision inputs)
    ForwardStep->>VisionDetector: Detect VLM inputs
    alt Vision Inputs Present
        VisionDetector->>ForwardStep: VLM detected
        ForwardStep->>ForwardStep: Construct vlm_position_ids<br/>Construct vlm_attention_mask<br/>Build forward_args with vision inputs
        ForwardStep->>Model: model(**forward_args)
    else Text-Only Path
        VisionDetector->>ForwardStep: No vision inputs
        ForwardStep->>Model: Original text-only call
    end
    Model-->>ForwardStep: Output

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 3

✅ Passed checks (3 passed)
  • Description check: Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check: Passed. The title 'Support megatron generate for vlm' directly addresses the main change: adding VLM support to megatron_generate functionality.
  • Docstring coverage: Passed. Docstring coverage is 100.00%, which meets the required threshold of 80.00%.



@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/utils/plugins/megatron_generate.py (1)

339-339: Potential division by zero when only one token is generated.

If the first generated token is an EOS token, output_ids.shape[-1] equals 1, causing division by zero.

Proposed fix
-            "tps": time_remaining_outputs / (output_ids.shape[-1] - 1),
+            "tps": time_remaining_outputs / max(output_ids.shape[-1] - 1, 1),
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5e0d365 and 54c18a0.

📒 Files selected for processing (1)
  • modelopt/torch/utils/plugins/megatron_generate.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (1)
modelopt/torch/utils/plugins/megatron_generate.py (1)

62-87: megatron_prefill accepts vision parameters but does not use them.

The function signature at lines 44-46 accepts pixel_values, image_grid_thw, and image_sizes, but the inner _forward_step_func doesn't handle vision inputs and these parameters aren't passed to data_iterator. If VLM support for prefill is intended, the same vision-aware logic from megatron_generate should be applied here.
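As a rough illustration, a vision-aware detour in megatron_prefill's forward step could mirror the generate path. This is a hypothetical sketch: the data-dict keys come from this PR, but the surrounding data/forward_args construction in megatron_prefill is assumed.

# Hypothetical sketch for megatron_prefill's _forward_step_func; `data` and
# `forward_args` are assumed to exist as in megatron_generate's forward step.
vision_keys = ("pixel_values", "image_grid_thw", "image_sizes")
has_vision_inputs = any(data.get(k) is not None for k in vision_keys)
if has_vision_inputs:
    for k in vision_keys:
        if data.get(k) is not None:
            forward_args[k] = data[k]  # route vision tensors into the model call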

Comment on lines +218 to +243
if has_vision_inputs:
    # For VLM models:
    # - position_ids: [batch, seq_len] (required for RoPE with multi-modal positions)
    # - attention_mask: [batch, seq_len] (simple 1D boolean mask, not 4D causal)
    vlm_position_ids = (
        torch.arange(seq_len, dtype=torch.long, device=device)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    vlm_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)

    forward_args = {
        "input_ids": data["tokens"],
        "position_ids": vlm_position_ids,
        "attention_mask": vlm_attention_mask,
        "runtime_gather_output": True,
    }
    # Add vision inputs
    if "pixel_values" in data and data["pixel_values"] is not None:
        forward_args["pixel_values"] = data["pixel_values"]
    if "image_grid_thw" in data and data["image_grid_thw"] is not None:
        forward_args["image_grid_thw"] = data["image_grid_thw"]
    if "image_sizes" in data and data["image_sizes"] is not None:
        forward_args["image_sizes"] = data["image_sizes"]

    output_tensor = model(**forward_args)
⚠️ Potential issue | 🟠 Major

VLM path does not support KV-cache decoding.

The VLM branch omits inference_context, so KV-cache is silently disabled for vision-language models even when enable_kv_cache=True. Additionally, vlm_position_ids always starts from 0, which would be incorrect during decode steps if KV-cache were used.

Consider either:

  1. Passing inference_context and computing the correct position offset during decode, or
  2. Explicitly disabling KV-cache when vision inputs are detected (similar to lines 172-174 for sequence parallelism).
🤖 Prompt for AI Agents
In @modelopt/torch/utils/plugins/megatron_generate.py around lines 218-243:
The VLM branch currently omits passing inference_context and always builds
vlm_position_ids starting at 0, which disables or breaks KV-cache during
decoding. Fix by: when has_vision_inputs is true and an inference_context is
provided (and enable_kv_cache is true), include inference_context in
forward_args (forward_args["inference_context"] = inference_context) and
compute vlm_position_ids with the decode offset from the context (e.g., base =
getattr(inference_context, "position_offset", getattr(inference_context,
"curr_seq_len", 0)); vlm_position_ids = torch.arange(base, base + seq_len,
dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1)).
Alternatively, if you prefer to disallow KV-cache for VLMs, explicitly set
forward_args["inference_context"] = None (or skip passing it) and ensure
enable_kv_cache is treated as disabled when has_vision_inputs is true.
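
For concreteness, a minimal sketch of option 1 under the prompt's assumptions (position_offset and curr_seq_len are guessed attribute names, not a confirmed Megatron API):

# Hypothetical sketch of offset-aware VLM decoding. `inference_context`,
# `enable_kv_cache`, `seq_len`, `batch_size`, and `device` come from the
# surrounding function; the offset attribute names are assumptions.
if has_vision_inputs and enable_kv_cache and inference_context is not None:
    forward_args["inference_context"] = inference_context
    base = getattr(
        inference_context,
        "position_offset",
        getattr(inference_context, "curr_seq_len", 0),
    )
    # Positions continue from the cached prefix instead of restarting at 0.
    forward_args["position_ids"] = (
        torch.arange(base, base + seq_len, dtype=torch.long, device=device)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )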

Comment on lines +286 to +292
data_dict = {"tokens": tokens}
if pixel_values is not None:
data_dict["pixel_values"] = pixel_values
if image_grid_thw is not None:
data_dict["image_grid_thw"] = image_grid_thw
if image_sizes is not None:
data_dict["image_sizes"] = image_sizes
⚠️ Potential issue | 🟠 Major

Vision inputs are passed on every generation step instead of just prefill.

Vision inputs should only be processed during the prefill phase (step 0). Passing them on every decode step is wasteful and may cause unexpected behavior in some VLM architectures.

Proposed fix: Only include vision inputs on the first step
         data_dict = {"tokens": tokens}
-        if pixel_values is not None:
-            data_dict["pixel_values"] = pixel_values
-        if image_grid_thw is not None:
-            data_dict["image_grid_thw"] = image_grid_thw
-        if image_sizes is not None:
-            data_dict["image_sizes"] = image_sizes
+        # Vision inputs should only be processed during prefill (step 0)
+        if step == 0:
+            if pixel_values is not None:
+                data_dict["pixel_values"] = pixel_values
+            if image_grid_thw is not None:
+                data_dict["image_grid_thw"] = image_grid_thw
+            if image_sizes is not None:
+                data_dict["image_sizes"] = image_sizes
🤖 Prompt for AI Agents
In @modelopt/torch/utils/plugins/megatron_generate.py around lines 286-292:
The vision inputs (pixel_values, image_grid_thw, image_sizes) are being added
to data_dict on every decode step; change the logic so these keys are only
added during the prefill/first generation step (e.g., when step == 0 or when
an is_prefill flag is true). Locate the block building data_dict (symbols:
data_dict, tokens, pixel_values, image_grid_thw, image_sizes) inside the
generation loop/function and wrap the conditional additions of pixel_values,
image_grid_thw, and image_sizes so they execute only for the initial prefill
step.

Comment on lines +211 to +215
# Check if this is a VLM model (has vision inputs)
has_vision_inputs = (
    ("pixel_values" in data and data["pixel_values"] is not None)
    or ("image_grid_thw" in data and data["image_grid_thw"] is not None)
    or ("image_sizes" in data and data["image_sizes"] is not None)
Suggested change
-# Check if this is a VLM model (has vision inputs)
-has_vision_inputs = (
-    ("pixel_values" in data and data["pixel_values"] is not None)
-    or ("image_grid_thw" in data and data["image_grid_thw"] is not None)
-    or ("image_sizes" in data and data["image_sizes"] is not None)
+# Check if this is a VLM model (has vision inputs)
+_has_pixel_values = data.get("pixel_values") is not None
+_has_image_grid_thw = data.get("image_grid_thw") is not None
+_has_image_sizes = data.get("image_sizes") is not None
+has_vision_inputs = _has_pixel_values or _has_image_grid_thw or _has_image_sizes

Comment on lines +236 to +240
if "pixel_values" in data and data["pixel_values"] is not None:
forward_args["pixel_values"] = data["pixel_values"]
if "image_grid_thw" in data and data["image_grid_thw"] is not None:
forward_args["image_grid_thw"] = data["image_grid_thw"]
if "image_sizes" in data and data["image_sizes"] is not None:
@AAnoosheh AAnoosheh commented Jan 13, 2026
Suggested change
-if "pixel_values" in data and data["pixel_values"] is not None:
+if _has_pixel_values:
     forward_args["pixel_values"] = data["pixel_values"]
-if "image_grid_thw" in data and data["image_grid_thw"] is not None:
+if _has_image_grid_thw:
     forward_args["image_grid_thw"] = data["image_grid_thw"]
-if "image_sizes" in data and data["image_sizes"] is not None:
+if _has_image_sizes:
