Support megatron generate for vlm #773
base: main
Conversation
📝 Walkthrough

Introduces vision-input aware forwarding in megatron_generate.py and megatron_prefill.py. Detects vision-language model (VLM) inputs via `pixel_values`, `image_grid_thw`, and `image_sizes`, constructing appropriate position IDs and attention masks. Implements dual-path handling that routes VLM inputs through specialized forward logic while maintaining the existing text-only paths.
Sequence Diagram

```mermaid
sequenceDiagram
    participant DataIterator as Data Iterator
    participant ForwardStep as Forward Step Func
    participant VisionDetector as Vision Input Detector
    participant Model as Model
    DataIterator->>ForwardStep: data_dict (tokens + optional vision inputs)
    ForwardStep->>VisionDetector: Detect VLM inputs
    alt Vision Inputs Present
        VisionDetector->>ForwardStep: VLM detected
        ForwardStep->>ForwardStep: Construct vlm_position_ids<br/>Construct vlm_attention_mask<br/>Build forward_args with vision inputs
        ForwardStep->>Model: model(**forward_args)
    else Text-Only Path
        VisionDetector->>ForwardStep: No vision inputs
        ForwardStep->>Model: Original text-only call
    end
    Model-->>ForwardStep: Output
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 3 passed
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/utils/plugins/megatron_generate.py (1)
339-339: Potential division by zero when only one token is generated.

If the first generated token is an EOS token, `output_ids.shape[-1]` equals 1, causing division by zero.

Proposed fix:

```diff
- "tps": time_remaining_outputs / (output_ids.shape[-1] - 1),
+ "tps": time_remaining_outputs / max(output_ids.shape[-1] - 1, 1),
```
🤖 Fix all issues with AI agents
In @modelopt/torch/utils/plugins/megatron_generate.py:
- Around line 218-243: The VLM branch currently omits passing inference_context
and always builds vlm_position_ids starting at 0, which disables or breaks
KV-cache during decoding; fix by: when has_vision_inputs is true and an
inference_context is provided (and enable_kv_cache is true), include
inference_context in forward_args (forward_args["inference_context"] =
inference_context) and compute vlm_position_ids by adding the decode offset from
the context (e.g., base = getattr(inference_context, "position_offset",
getattr(inference_context, "curr_seq_len", 0)); vlm_position_ids =
torch.arange(base, base + seq_len, dtype=torch.long,
device=device).unsqueeze(0).expand(batch_size, -1)); alternatively, if you
prefer to disallow KV-cache for VLMs, explicitly set
forward_args["inference_context"] = None (or skip passing it) and ensure
enable_kv_cache is treated as disabled when has_vision_inputs is true.
- Around line 286-292: The vision inputs (pixel_values, image_grid_thw,
image_sizes) are being added to data_dict on every decode step; change the logic
so these keys are only added during the prefill/first generation step (e.g.,
when step == 0 or when an is_prefill flag is true). Locate the block building
data_dict (symbols: data_dict, tokens, pixel_values, image_grid_thw,
image_sizes) inside the generation loop/function and wrap the conditional
additions of pixel_values, image_grid_thw, and image_sizes so they execute only
for the initial prefill step.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/utils/plugins/megatron_generate.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (1)
modelopt/torch/utils/plugins/megatron_generate.py (1)
62-87: `megatron_prefill` accepts vision parameters but does not use them.

The function signature at lines 44-46 accepts `pixel_values`, `image_grid_thw`, and `image_sizes`, but the inner `_forward_step_func` doesn't handle vision inputs and these parameters aren't passed to `data_iterator`. If VLM support for prefill is intended, the same vision-aware logic from `megatron_generate` should be applied here.
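A minimal sketch of how that logic could be mirrored in `megatron_prefill`'s inner forward step. The `forward_args` keys follow the `megatron_generate` diff excerpt later in this review; the real `_forward_step_func` signature may differ, and `_text_only_forward` is a hypothetical stand-in for the existing text-only path:

```python
import torch


def _forward_step_func(data, model):
    tokens = data["tokens"]
    batch_size, seq_len = tokens.shape
    device = tokens.device

    has_vision_inputs = any(
        data.get(key) is not None
        for key in ("pixel_values", "image_grid_thw", "image_sizes")
    )
    if not has_vision_inputs:
        # Hypothetical helper standing in for the existing text-only path.
        return _text_only_forward(data, model)

    # VLM path: 2D position ids and a 1D boolean mask, mirroring the diff.
    forward_args = {
        "input_ids": tokens,
        "position_ids": torch.arange(seq_len, dtype=torch.long, device=device)
        .unsqueeze(0)
        .expand(batch_size, -1),
        "attention_mask": torch.ones(
            (batch_size, seq_len), dtype=torch.bool, device=device
        ),
        "runtime_gather_output": True,
    }
    for key in ("pixel_values", "image_grid_thw", "image_sizes"):
        if data.get(key) is not None:
            forward_args[key] = data[key]
    return model(**forward_args)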
```python
if has_vision_inputs:
    # For VLM models:
    # - position_ids: [batch, seq_len] (required for RoPE with multi-modal positions)
    # - attention_mask: [batch, seq_len] (simple 1D boolean mask, not 4D causal)
    vlm_position_ids = (
        torch.arange(seq_len, dtype=torch.long, device=device)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    vlm_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)

    forward_args = {
        "input_ids": data["tokens"],
        "position_ids": vlm_position_ids,
        "attention_mask": vlm_attention_mask,
        "runtime_gather_output": True,
    }
    # Add vision inputs
    if "pixel_values" in data and data["pixel_values"] is not None:
        forward_args["pixel_values"] = data["pixel_values"]
    if "image_grid_thw" in data and data["image_grid_thw"] is not None:
        forward_args["image_grid_thw"] = data["image_grid_thw"]
    if "image_sizes" in data and data["image_sizes"] is not None:
        forward_args["image_sizes"] = data["image_sizes"]

    output_tensor = model(**forward_args)
```
VLM path does not support KV-cache decoding.
The VLM branch omits `inference_context`, so KV-cache is silently disabled for vision-language models even when `enable_kv_cache=True`. Additionally, `vlm_position_ids` always starts from 0, which would be incorrect during decode steps if KV-cache were used.

Consider either:

- Passing `inference_context` and computing the correct position offset during decode (see the sketch after this list), or
- Explicitly disabling KV-cache when vision inputs are detected (similar to lines 172-174 for sequence parallelism).
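A sketch of the first option, under stated assumptions: the `position_offset` / `curr_seq_len` attribute names on `inference_context` come from the AI-agent prompt above and are not a verified Megatron-Core API.

```python
import torch


def build_vlm_forward_args(data, inference_context, enable_kv_cache):
    tokens = data["tokens"]
    batch_size, seq_len = tokens.shape
    device = tokens.device

    # Offset positions by the number of tokens already held by the context so
    # decode steps don't restart RoPE positions at 0. Attribute names here are
    # assumptions, not a confirmed API.
    base = 0
    if inference_context is not None:
        base = getattr(
            inference_context,
            "position_offset",
            getattr(inference_context, "curr_seq_len", 0),
        )

    forward_args = {
        "input_ids": tokens,
        "position_ids": torch.arange(
            base, base + seq_len, dtype=torch.long, device=device
        )
        .unsqueeze(0)
        .expand(batch_size, -1),
        "attention_mask": torch.ones(
            (batch_size, seq_len), dtype=torch.bool, device=device
        ),
        "runtime_gather_output": True,
    }
    # Keep the KV-cache alive on the VLM path instead of silently dropping it.
    if enable_kv_cache and inference_context is not None:
        forward_args["inference_context"] = inference_context
    return forward_args
```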
```python
data_dict = {"tokens": tokens}
if pixel_values is not None:
    data_dict["pixel_values"] = pixel_values
if image_grid_thw is not None:
    data_dict["image_grid_thw"] = image_grid_thw
if image_sizes is not None:
    data_dict["image_sizes"] = image_sizes
```
Vision inputs are passed on every generation step instead of just prefill.
Vision inputs should only be processed during the prefill phase (step 0). Passing them on every decode step is wasteful and may cause unexpected behavior in some VLM architectures.
Proposed fix: only include vision inputs on the first step.

```diff
  data_dict = {"tokens": tokens}
- if pixel_values is not None:
-     data_dict["pixel_values"] = pixel_values
- if image_grid_thw is not None:
-     data_dict["image_grid_thw"] = image_grid_thw
- if image_sizes is not None:
-     data_dict["image_sizes"] = image_sizes
+ # Vision inputs should only be processed during prefill (step 0)
+ if step == 0:
+     if pixel_values is not None:
+         data_dict["pixel_values"] = pixel_values
+     if image_grid_thw is not None:
+         data_dict["image_grid_thw"] = image_grid_thw
+     if image_sizes is not None:
+         data_dict["image_sizes"] = image_sizes
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
data_dict = {"tokens": tokens}
# Vision inputs should only be processed during prefill (step 0)
if step == 0:
    if pixel_values is not None:
        data_dict["pixel_values"] = pixel_values
    if image_grid_thw is not None:
        data_dict["image_grid_thw"] = image_grid_thw
    if image_sizes is not None:
        data_dict["image_sizes"] = image_sizes
```
```python
# Check if this is a VLM model (has vision inputs)
has_vision_inputs = (
    ("pixel_values" in data and data["pixel_values"] is not None)
    or ("image_grid_thw" in data and data["image_grid_thw"] is not None)
    or ("image_sizes" in data and data["image_sizes"] is not None)
)
```
Suggested simplification using `data.get`:

```python
# Check if this is a VLM model (has vision inputs)
_has_pixel_values = data.get("pixel_values") is not None
_has_image_grid_thw = data.get("image_grid_thw") is not None
_has_image_sizes = data.get("image_sizes") is not None
has_vision_inputs = _has_pixel_values or _has_image_grid_thw or _has_image_sizes
```
| if "pixel_values" in data and data["pixel_values"] is not None: | ||
| forward_args["pixel_values"] = data["pixel_values"] | ||
| if "image_grid_thw" in data and data["image_grid_thw"] is not None: | ||
| forward_args["image_grid_thw"] = data["image_grid_thw"] | ||
| if "image_sizes" in data and data["image_sizes"] is not None: |
| if "pixel_values" in data and data["pixel_values"] is not None: | |
| forward_args["pixel_values"] = data["pixel_values"] | |
| if "image_grid_thw" in data and data["image_grid_thw"] is not None: | |
| forward_args["image_grid_thw"] = data["image_grid_thw"] | |
| if "image_sizes" in data and data["image_sizes"] is not None: | |
| if _has_pixel_values: | |
| forward_args["pixel_values"] = data["pixel_values"] | |
| if _has_image_grid_thw: | |
| forward_args["image_grid_thw"] = data["image_grid_thw"] | |
| if _has_image_sizes: |
What does this PR do?
Type of change: New feature

Overview: This PR adds VLM generation support to `megatron_generate`.
Usage
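A hypothetical invocation, inferred from the review diffs above; the exact `megatron_generate` signature and the preprocessing of `tokens` / `pixel_values` are assumptions, not taken from the PR:

```python
# Hypothetical usage; keyword names follow the vision inputs handled in the
# diff (pixel_values, image_grid_thw), not a confirmed function signature.
from modelopt.torch.utils.plugins.megatron_generate import megatron_generate

output_ids = megatron_generate(
    model,                          # Megatron VLM model (assumed to exist)
    tokens,                         # [batch, seq_len] prompt token ids
    pixel_values=pixel_values,      # preprocessed image tensor
    image_grid_thw=image_grid_thw,  # per-image (t, h, w) grid, e.g. Qwen2-VL
)
```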
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features

- Vision-language model (VLM) inputs (`pixel_values`, `image_grid_thw`, `image_sizes`) are now supported in megatron generate.

Improvements

- Existing text-only generation paths are preserved unchanged.