
[MaxEngine] Fix TypeError in prefill() during batched inference #3063

Open

jaisong123 wants to merge 1 commit into AI-Hypercomputer:main from jaisong123:fix-batched-prefill

Conversation


jaisong123 commented Feb 2, 2026

📝 Title

[MaxEngine] Fix TypeError in prefill() during batched inference (dynamic_slice scalar check)

📋 Description

The Issue

Currently, MaxEngine.prefill() assumes that true_length is always a scalar (0-D tensor). This works fine for single-example inference (batch_size=1).

However, when running batched inference (e.g., batch_size=32 for offline processing or high-throughput serving), true_length becomes a 1-D array of shape (batch_size,).

When this 1-D array is passed to jax.lax.dynamic_slice as a start index, JAX raises a TypeError, because dynamic_slice requires scalar start indices:

TypeError: start_indices arguments to dynamic_slice must be scalars, got indices (ShapedArray(int32[]), ShapedArray(int32[32]), ShapedArray(int32[]))
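
A minimal standalone repro of the failure (a sketch assuming only JAX; the shapes are illustrative, not taken from MaxText):

    import jax
    import jax.numpy as jnp

    batch, seq_len, vocab = 32, 128, 256
    flat_logits = jnp.zeros((batch, seq_len, vocab))

    # Scalar start index: works.
    true_length = jnp.int32(64)
    jax.lax.dynamic_slice(flat_logits, (0, true_length - 1, 0), (batch, 1, vocab))

    # 1-D start index: raises the TypeError shown above.
    true_length = jnp.full((batch,), 64, dtype=jnp.int32)
    jax.lax.dynamic_slice(flat_logits, (0, true_length - 1, 0), (batch, 1, vocab))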

The Fix

This PR updates the logit gathering logic in _prefill_jit to handle both scalar and vector true_length:

  1. Checks whether true_length has ndim == 1 (i.e., is a vector rather than a scalar).
  2. If it is a vector (batch mode), uses vectorized (advanced) indexing to gather the correct token logits for each sequence in the batch.
  3. Falls back to the original dynamic_slice for scalar inputs to maintain backward compatibility.

Benefits

  • Enables batch decoding for variable-length padded sequences.
  • Unlocks high-throughput offline inference workloads on TPU pods (e.g., processing large datasets).

💻 Code Changes

File: src/MaxText/maxengine.py

    # ... inside _prefill_jit ...

    # [OLD CODE]
    # generated_tokens = jnp.zeros((1, 1), dtype=jnp.int32)
    # selected_logits = jax.lax.dynamic_slice(
    #     flat_logits,
    #     (0, true_length - 1, 0),
    #     (flat_logits.shape[0], 1, flat_logits.shape[2]),
    # )

    # [NEW CODE]
    # Support dynamic batch size for correct initialization
    batch_size = flat_logits.shape[0]
    generated_tokens = jnp.zeros((batch_size, 1), dtype=jnp.int32)

    if hasattr(true_length, 'ndim') and true_length.ndim == 1:
        # Vectorized gather for batched inputs
        # Gather the logit at (batch_idx, true_length[batch_idx]-1) for each item
        batch_indices = jnp.arange(batch_size)
        seq_indices = true_length - 1
        
        # Gather the last real token's logits for each row: [Batch, Vocab]
        selected_logits = flat_logits[batch_indices, seq_indices, :]
        # Restore the sequence axis: [Batch, Vocab] -> [Batch, 1, Vocab]
        selected_logits = selected_logits[:, None, :]
    else:
        # Legacy path for scalar true_length
        selected_logits = jax.lax.dynamic_slice(
            flat_logits,
            (0, true_length - 1, 0),
            (flat_logits.shape[0], 1, flat_logits.shape[2]),
        )
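
Note on jit-compatibility: since ndim is static shape metadata, the Python if resolves at trace time under jax.jit, so each input signature compiles a single branch. A minimal standalone sketch of that behavior (select is an illustrative name, not a MaxText function):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def select(flat_logits, true_length):
        # ndim is known while tracing, so this branch is chosen at trace time.
        if true_length.ndim == 1:
            b = jnp.arange(flat_logits.shape[0])
            return flat_logits[b, true_length - 1, :][:, None, :]
        return jax.lax.dynamic_slice(
            flat_logits, (0, true_length - 1, 0),
            (flat_logits.shape[0], 1, flat_logits.shape[2]))

    logits = jnp.zeros((2, 8, 4))
    print(select(logits, jnp.array([3, 8], dtype=jnp.int32)).shape)  # (2, 1, 4)
    print(select(logits, jnp.int32(5)).shape)                        # (2, 1, 4)

An equivalent formulation, if preferred, would be jnp.take_along_axis(flat_logits, (true_length - 1)[:, None, None], axis=1), which likewise avoids the scalar-index restriction.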

🧪 Test Plan

Verified on TPU v5e-8 (slice) and v5litepod-32:

  • Case 1 (Standard): python3 decode.py ... (Single prompt) -> PASS (Uses dynamic_slice path).
  • Case 2 (Batched): python3 batch_decode.py ... (Batch size 32, mixed lengths) -> PASS (Uses vectorized path).
    • Without fix: Crashes with TypeError.
    • With fix: Successfully extracts logits and generates tokens for all 32 items.
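
As a lighter-weight check than a full TPU run, the two paths can also be compared directly on CPU (a standalone sketch, not part of this PR's test suite):

    import jax
    import jax.numpy as jnp

    flat_logits = jax.random.normal(jax.random.PRNGKey(0), (4, 16, 8))
    true_length = jnp.array([5, 16, 1, 9], dtype=jnp.int32)

    # New vectorized path.
    b = jnp.arange(flat_logits.shape[0])
    vec = flat_logits[b, true_length - 1, :][:, None, :]

    # Legacy path, applied per example with a scalar true_length.
    for i in range(flat_logits.shape[0]):
        legacy = jax.lax.dynamic_slice(
            flat_logits[i:i + 1],
            (0, true_length[i] - 1, 0),
            (1, 1, flat_logits.shape[2]))
        assert jnp.allclose(vec[i:i + 1], legacy)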

Currently, prefill() assumes true_length is scalar. This patch enables vector true_length (batch_size > 1) by using vectorized indexing instead of dynamic_slice.

google-cla bot commented Feb 2, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.
