Add Gemma 4 E2B / E4B (text) support to MaxText#3904
Conversation
|
🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
|
🤖 I'm sorry @gagika, but I was unable to process your request. Please see the logs for more details. |
aireenmei
left a comment
There was a problem hiding this comment.
Thanks for the rapid implementation! I wonder if you have test results from forward_pass_logit_checker? Also do you add some unit tests for comparison with torch on the new modules such as Gemma4SmallPLE, Gemma4SmallAttention, Gemma4SmallDecoderLayer? https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/unit/gemma4_layers_test.py This was added recently.
There was a problem hiding this comment.
Thanks! I agree with @aireenmei that it would be good if we could reuse rope/attention, along with component-wise unit tests (potentially as follow-up). Some minor comments.
|
🤖 Hi @shuningjin, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
|
🤖 I'm sorry @shuningjin, but I was unable to process your request. Please see the logs for more details. |
Added those unit test, PTL for forward logits test, yes, I have done for both models. https://paste.googleplex.com/5790253452492800 |
|
🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
|
🤖 I'm sorry @gagika, but I was unable to process your request. Please see the logs for more details. |
Description
Adds Gemma 4 small variants — E2B and E4B (text-only) — to MaxText.
These are the smaller members of the Gemma 4 family. They share the
broader Gemma 4 attention / norm structure but introduce two new features
that drive their parameter efficiency:
slice of an extra embedding tensor injected by a new
Gemma4SmallPLEblock. Controlled by
hidden_size_per_layer_input/vocab_size_per_layer_input.num_kv_shared_layersdecoder layers reuseK / V from the most recent non-shared layer of the same attention type
(sliding↔sliding, full↔full). E2B additionally widens the MLP on those
shared layers (
use_double_wide_mlp: true) to compensate for themissing parameters.
Both features carry per-layer state that is not expressible inside
nn.scan, so a newGEMMA4_SMALLDecoderBlockTypeis added with itsown non-scanned execution path (
Decoder._apply_gemma4_small_layers).The model validator enforces
scan_layers=Falsefor these variants.What's included
src/maxtext/models/gemma4_small.py(PLE + attentionwith optional KV sharing + decoder layer).
configs/models/gemma4-e2b.ymlandgemma4-e4b.yml.hf_model_configs.py,hf_shape.py,param_mapping.pyupdated to handle PLE params, KV-shared layers, andthe (optional) double-wide MLP.
calculate_gemma4_small_tflops_training_per_device.DecoderBlockType.GEMMA4_SMALL, four newAttentionfields inconfigs/types.py, base.yml defaults, andvalidation that rejects
scan_layers=true/use_multimodal=trueforE2B / E4B.
Out of scope
MaxText support for the gemma4-small vision encoder (clipped linears
in particular) is not in this PR.
use_multimodal=trueis rejected bythe validator with a clear error.
nn.scan; rejected by the validator.Tests
tests/unit/gemma4_small_test.py— attention-pattern dispatch,layer-type tuples, KV donor/shared-layer mapping for both variants.
tests/unit/flop_calculation_test.py::test_calculate_gemma4_small_tflops_*—closed-form TFLOP accounting matching the layer/donor structure.
tests/unit/configs_test.py— E2B / E4B yml configs are loaded bythe existing config-instantiation sweep.
tests/end_to_end/tpu/gemma4/e2b/{convert_gemma4,convert_gemma4_pt}.shtests/end_to_end/tpu/gemma4/e4b/{convert_gemma4,convert_gemma4_pt}.shto_maxtext, then runsforward_pass_logit_checkeragainst the HF model with--max_kl_div=0.03. This is the recommended smoke test aftertouching the model code, param map, or either YAML.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.