Skip to content

Fix GPT3 attention missing KV cache initialization and handling#3927

Merged
copybara-service[bot] merged 1 commit into
mainfrom
shralex_test_5
May 19, 2026
Merged

Fix GPT3 attention missing KV cache initialization and handling#3927
copybara-service[bot] merged 1 commit into
mainfrom
shralex_test_5

Conversation

@shralex
Copy link
Copy Markdown
Collaborator

@shralex shralex commented May 17, 2026

This pull request resolves an issue where Gpt3MultiHeadAttention called AttentionOp without passing cached_values, causing decoding to fail with an AssertionError: assert prefill_kv_cache.

FIXES: b/452778717

Tests

Updated GPT-3 tests. Verified that b/452778717 is fixed on a TPU VM.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 17, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

@shralex shralex force-pushed the shralex_test_5 branch 2 times, most recently from ddaccf2 to 12bc21b Compare May 17, 2026 15:58
Comment thread tests/unit/gpt3_test.py
self.rng = jax.random.PRNGKey(1234)

devices_array = maxtext_utils.create_device_mesh(self.cfg)
devices_array = maxtext_utils.create_device_mesh(self.cfg, devices=[jax.devices()[0]])
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we only use one device for testing? No sharding involved?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for the KV cache test. Running everything on a single device eliminates communication and ensures the test is fully deterministic.

Comment thread tests/unit/gpt3_test.py
enable_checkpointing=False,
model_name="gpt3-52k",
dtype="float32",
per_device_batch_size=1.0 / jax.device_count(),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the purpose?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its a way to set global batch size to 1 regardless of device count.

@copybara-service copybara-service Bot merged commit 4ebab2c into main May 19, 2026
35 checks passed
@copybara-service copybara-service Bot deleted the shralex_test_5 branch May 19, 2026 16:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants