[MaxEngine] Fix TypeError in prefill() during batched inference#3063
Open
jaisong123 wants to merge 1 commit into AI-Hypercomputer:main from
Conversation
Currently, `prefill()` assumes `true_length` is a scalar. This patch enables vector `true_length` (batch_size > 1) by using vectorized indexing instead of `dynamic_slice`.
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
📝 Title
[MaxEngine] Fix TypeError in prefill() during batched inference (dynamic_slice scalar check)

📋 Description

The Issue

Currently, `MaxEngine.prefill()` assumes that `true_length` is always a scalar (0-D tensor). This works fine for single-example inference (`batch_size=1`).

However, when running batched inference (e.g., `batch_size=32` for offline processing or high-throughput serving), `true_length` becomes a 1-D array of shape `(batch_size,)`.

When this 1-D array is passed to `jax.lax.dynamic_slice` as a start index, JAX throws a `TypeError` because `dynamic_slice` requires scalar start indices.

The Fix
This PR updates the logit gathering logic in `_prefill_jit` to handle both scalar and vector `true_length`:

- Uses vectorized indexing when `true_length` has >0 dimensions.
- Keeps `dynamic_slice` for scalar inputs to maintain backward compatibility.

Benefits
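The scalar/vector dispatch described under "The Fix" can be sketched as follows. This is an illustrative standalone snippet, not the actual MaxText diff: the helper name `gather_final_logits` and the `(batch, seq_len, vocab)` shapes are assumptions for the example.

```python
import jax
import jax.numpy as jnp

def gather_final_logits(logits, true_length):
    """Select the logits at the last valid position (true_length - 1).

    logits: (batch, seq_len, vocab) array.
    true_length: 0-D scalar or (batch,) array of valid prompt lengths.
    """
    if jnp.ndim(true_length) > 0:
        # Vectorized path: one gather index per batch element along the
        # sequence axis; indices broadcast against the vocab dimension.
        idx = (true_length - 1).astype(jnp.int32)[:, None, None]  # (batch, 1, 1)
        return jnp.take_along_axis(logits, idx, axis=1)           # (batch, 1, vocab)
    # Scalar path: preserve the original dynamic_slice behavior.
    batch, _, vocab = logits.shape
    return jax.lax.dynamic_slice(logits, (0, true_length - 1, 0), (batch, 1, vocab))
```

Because `jnp.ndim(true_length)` is a static property, the branch is resolved at trace time and both paths remain jit-compatible.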
💻 Code Changes
File: `src/MaxText/maxengine.py`

🧪 Test Plan
Verified on TPU v5e-8 (slice) and v5litepod-32:
- `python3 decode.py ...` (single prompt) -> PASS (uses `dynamic_slice` path).
- `python3 batch_decode.py ...` (batch size 32, mixed lengths) -> PASS (uses vectorized path), with no `TypeError`.
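For reference, the failure mode this PR addresses can be reproduced with a minimal standalone snippet. The array names and shapes below are illustrative, not taken from MaxText:

```python
import jax
import jax.numpy as jnp

logits = jnp.zeros((8, 128, 256))  # (batch, seq_len, vocab); shapes illustrative
true_length = jnp.full((8,), 5)    # 1-D (batch,) lengths, as in batched prefill

# dynamic_slice requires every start index to be a scalar, so passing the
# 1-D true_length array as the sequence-axis start index raises a TypeError.
try:
    jax.lax.dynamic_slice(logits, (0, true_length - 1, 0), (8, 1, 256))
except TypeError as err:
    print("TypeError:", err)
```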