Add native FP8 model support. #448
base: main
Conversation
Add comprehensive FP8 quantized model support for models like Qwen3-FP8. This enables loading and running FP8 models with per-block scale factors.

Changes:

bumblebee.ex:
- Add :preserve_source_types option to load_model/2 to keep FP8 types

pytorch_params.ex:
- Pass preserve_source_types through param loading pipeline
- Modify ensure_type/3 to preserve FP8 types when option is set

layers.ex:
- Add fp8_aware_dense/3 layer that handles FP8 quantized weights
- Implements block-wise dequantization using scale_inv parameter
- Automatically falls back to identity scaling for non-FP8 models

layers/transformer.ex:
- Add :attention_dense option to blocks/2, block/2, multi_head_attention/4
- Allows custom dense function for Q, K, V, and output projections

text/qwen3.ex:
- Update decoder to use fp8_aware_dense for attention via attention_dense
- Update gated_ffn to use fp8_aware_dense for FFN layers
- Add scale_inv to params_mapping for all attention and FFN layers

The implementation supports both:
- Pre-dequantization: Convert FP8 -> F32 before loading
- Native FP8: Load FP8 weights directly, apply scale_inv at runtime

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
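The block-wise dequantization mentioned above can be pictured with a small Nx sketch. This is not the PR's fp8_aware_dense/3 implementation; the module name, function name, and plain 2-D weight layout are assumptions, with one scale_inv entry per 128x128 block:

```elixir
defmodule FP8Sketch do
  @block_size 128

  # weight: {out, in} tensor of type {:f8_e4m3fn, 8}
  # scale_inv: {ceil(out / 128), ceil(in / 128)} tensor of type {:f, 32}
  def block_dequantize(weight, scale_inv) do
    {rows, cols} = Nx.shape(weight)

    # For every element, compute the index of the 128x128 block it belongs to
    row_blocks = Nx.iota({rows, cols}, axis: 0) |> Nx.quotient(@block_size)
    col_blocks = Nx.iota({rows, cols}, axis: 1) |> Nx.quotient(@block_size)
    indices = Nx.stack([row_blocks, col_blocks], axis: -1)

    # Broadcast each block's inverse scale back to element granularity
    scales = Nx.gather(scale_inv, indices)

    weight
    |> Nx.as_type(:f32)
    |> Nx.multiply(scales)
  end
end
```

With identity scales, as in the tiny test model further down, the result is simply the FP8 weight cast to f32, which makes the behaviour easy to verify in tests.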
Update dependencies to use git versions of nx and safetensors which
have the new FP8 type representation: {:f8_e4m3fn, 8} instead of
{:f, 8, :e4m3fn}.
Changes:
- Update mix.exs to use git deps for nx, exla, torchx, and safetensors
- Update FP8 type detection pattern in pytorch_params.ex
- Add TODO comments noting deps should be switched back to hex when released
Tested with the Qwen/Qwen3-4B-Instruct-2507-FP8 model; it loads and generates
correctly with preserve_source_types: true.
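The type detection change amounts to matching on the new type tuple during parameter loading. A hedged sketch: the variable names and surrounding ensure_type/3 clause shape are assumptions, while the FP8 clause itself mirrors the diff excerpt quoted further down in this thread:

```elixir
case {expected_type, Nx.type(tensor), preserve_source_types?} do
  # Preserve FP8 E4M3FN types when preserve_source_types is enabled
  {_expected, {:f8_e4m3fn, 8}, true} -> tensor
  # Otherwise cast to the type the layer expects
  {expected, _actual, _preserve} -> Nx.as_type(tensor, expected)
end
```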
Add a new section demonstrating how to load and use FP8 quantized Qwen3 models with the preserve_source_types: true option. Update the introduction and summary to reflect the new capability.
To generate the FP8 tiny model:

"""Generate a tiny FP8 Qwen3 model for testing Bumblebee's FP8 support.

This creates a minimal model with:
- FP8 E4M3FN weights for linear layers
- Corresponding weight_scale_inv tensors (128x128 block scaling)
- Saved in safetensors format

Usage:
    python generate_fp8_qwen3.py
    # Then upload to HuggingFace: huggingface-cli upload roulis/tiny-fp8-qwen3 ./tiny-fp8-qwen3
"""
import torch
import json
import os
from safetensors.torch import save_file

# Tiny model config matching existing tiny-random-Qwen3ForCausalLM
CONFIG = {
    "architectures": ["Qwen3ForCausalLM"],
    "hidden_size": 32,
    "intermediate_size": 64,
    "num_attention_heads": 4,
    "num_hidden_layers": 2,
    "num_key_value_heads": 2,
    "vocab_size": 1024,
    "head_dim": 8,  # hidden_size / num_attention_heads
    "rms_norm_eps": 1e-6,
    "rope_theta": 1000000.0,
    "max_position_embeddings": 512,
    "torch_dtype": "float8_e4m3fn",
    "model_type": "qwen3",
    "use_qk_norm": True,
    "tie_word_embeddings": True,
    "quantization_config": {
        "quant_method": "fp8",
        "weight_block_size": [128, 128],
    },
}

BLOCK_SIZE = 128


def create_fp8_weight(shape, seed=42):
    """Create a random FP8 E4M3FN weight tensor."""
    torch.manual_seed(seed)
    # Create random values in valid FP8 E4M3FN range (-448 to 448)
    weight_f32 = torch.randn(shape) * 0.1
    weight_fp8 = weight_f32.to(torch.float8_e4m3fn)
    return weight_fp8


def create_scale_inv(weight_shape):
    """Create scale_inv tensor for block-wise dequantization.

    Shape: [ceil(out_features/128), ceil(in_features/128)]
    For testing, use scale of 1.0 (identity) so dequantized = original.
    """
    out_features, in_features = weight_shape
    out_blocks = (out_features + BLOCK_SIZE - 1) // BLOCK_SIZE
    in_blocks = (in_features + BLOCK_SIZE - 1) // BLOCK_SIZE
    # Use 1.0 for identity scaling (easier to verify in tests)
    return torch.ones(out_blocks, in_blocks, dtype=torch.float32)


def generate_model():
    hidden_size = CONFIG["hidden_size"]
    intermediate_size = CONFIG["intermediate_size"]
    num_heads = CONFIG["num_attention_heads"]
    num_kv_heads = CONFIG["num_key_value_heads"]
    head_dim = CONFIG["head_dim"]
    vocab_size = CONFIG["vocab_size"]
    num_layers = CONFIG["num_hidden_layers"]

    tensors = {}
    seed = 0

    # Embedding (not quantized)
    tensors["model.embed_tokens.weight"] = torch.randn(vocab_size, hidden_size)

    for layer_idx in range(num_layers):
        prefix = f"model.layers.{layer_idx}"

        # Self-attention projections (FP8 quantized)
        q_size = num_heads * head_dim
        kv_size = num_kv_heads * head_dim

        # Q projection
        tensors[f"{prefix}.self_attn.q_proj.weight"] = create_fp8_weight((q_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.q_proj.weight_scale_inv"] = create_scale_inv((q_size, hidden_size))

        # K projection
        tensors[f"{prefix}.self_attn.k_proj.weight"] = create_fp8_weight((kv_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.k_proj.weight_scale_inv"] = create_scale_inv((kv_size, hidden_size))

        # V projection
        tensors[f"{prefix}.self_attn.v_proj.weight"] = create_fp8_weight((kv_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.v_proj.weight_scale_inv"] = create_scale_inv((kv_size, hidden_size))

        # O projection
        tensors[f"{prefix}.self_attn.o_proj.weight"] = create_fp8_weight((hidden_size, q_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.o_proj.weight_scale_inv"] = create_scale_inv((hidden_size, q_size))

        # QK norms (not quantized)
        tensors[f"{prefix}.self_attn.q_norm.weight"] = torch.ones(head_dim)
        tensors[f"{prefix}.self_attn.k_norm.weight"] = torch.ones(head_dim)

        # MLP (FP8 quantized)
        tensors[f"{prefix}.mlp.gate_proj.weight"] = create_fp8_weight((intermediate_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.mlp.gate_proj.weight_scale_inv"] = create_scale_inv((intermediate_size, hidden_size))

        tensors[f"{prefix}.mlp.up_proj.weight"] = create_fp8_weight((intermediate_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.mlp.up_proj.weight_scale_inv"] = create_scale_inv((intermediate_size, hidden_size))

        tensors[f"{prefix}.mlp.down_proj.weight"] = create_fp8_weight((hidden_size, intermediate_size), seed)
        seed += 1
        tensors[f"{prefix}.mlp.down_proj.weight_scale_inv"] = create_scale_inv((hidden_size, intermediate_size))

        # Layer norms (not quantized)
        tensors[f"{prefix}.input_layernorm.weight"] = torch.ones(hidden_size)
        tensors[f"{prefix}.post_attention_layernorm.weight"] = torch.ones(hidden_size)

    # Final norm (not quantized)
    tensors["model.norm.weight"] = torch.ones(hidden_size)

    # LM head is tied to the embeddings (tie_word_embeddings=True), so no
    # separate (quantized) tensor is saved for it
    return tensors


def main():
    output_dir = "tiny-fp8-qwen3"
    os.makedirs(output_dir, exist_ok=True)

    # Generate model tensors
    tensors = generate_model()

    # Save as safetensors
    save_file(tensors, os.path.join(output_dir, "model.safetensors"))

    # Save config
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(CONFIG, f, indent=2)

    print(f"Model saved to {output_dir}/")
    print(f"Total tensors: {len(tensors)}")
    print("\nTo upload to HuggingFace:")
    print(f"  huggingface-cli upload roulis/tiny-fp8-qwen3 {output_dir}")


if __name__ == "__main__":
    main()
- Add fp8_aware_dense layer unit tests
- Add FP8 Qwen3 model loading test using roulis/tiny-fp8-qwen3
- Include Python script to generate tiny FP8 test models
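A minimal sketch of what the loading test could look like; the test name and assertion are illustrative, and the real test may check more (for example, parameter types or generated output):

```elixir
test "loads the tiny FP8 Qwen3 checkpoint with preserve_source_types" do
  assert {:ok, %{model: _model, params: _params, spec: _spec}} =
           Bumblebee.load_model({:hf, "roulis/tiny-fp8-qwen3"},
             preserve_source_types: true
           )
end
```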
# Preserve FP8 E4M3FN types when preserve_source_types is enabled
{_expected, {:f8_e4m3fn, 8}, true} -> tensor
We likely don't want to do this here, because Axon may cast and it can lead to inconsistent behaviour (see #311). Ideally we want to apply an Axon.MixedPrecision policy, but we cannot determine it upfront. Also, Axon policies apply per layer, but in this case we may have a layer where each param has a different type. I need to think about the best way to address this and the loading API we should have.
@seanmor5 do you have any thoughts on how to handle layers where parameters have different types, as part of Axon.MixedPrecision?
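For reference, a minimal sketch of how an Axon.MixedPrecision policy is normally created and applied; the type values are illustrative. Because a policy applies uniformly to a layer's parameters, a dense layer holding both an FP8 kernel and an f32 scale_inv does not map onto it directly:

```elixir
policy =
  Axon.MixedPrecision.create_policy(
    params: {:f, 32},
    compute: {:bf, 16},
    output: {:f, 32}
  )

# The policy is applied per layer, so all of a layer's params share one type
model = Axon.MixedPrecision.apply_policy(model, policy)
```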
Summary
This PR adds support for loading and running FP8 (8-bit floating point) quantized models natively in Bumblebee. FP8 models use approximately half the memory of BF16 models while maintaining good inference quality; for example, a 4B-parameter checkpoint needs roughly 4 GB of weights in FP8 versus roughly 8 GB in BF16, plus small per-block scale tensors.
Changes
Core FP8 Support
- `preserve_source_types` option to `Bumblebee.load_model/2` to keep FP8 weights in their native format
- `dequantize_kernel/3` function in `Bumblebee.Layers` for runtime FP8 → F32 conversion using scale_inv tensors
- Detection of the new `{:f8_e4m3fn, 8}` FP8 type representation during parameter loading

Qwen3 FP8 Integration
- `params_mapping` entries for the FP8 weight scales (`weight_scale_inv`) in the Qwen3 architecture

Dependencies
- Git versions of `nx`, `exla`, `torchx`, and `safetensors` for FP8 type support

Documentation
- New notebook section showing how to load and use FP8 quantized Qwen3 models
Usage
Loading an FP8 Model
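A sketch of the intended usage, assuming the new `:preserve_source_types` option is passed to `load_model/2` as described in the commits; the prompt and serving setup are illustrative:

```elixir
repo = {:hf, "Qwen/Qwen3-4B-Instruct-2507-FP8"}

{:ok, model_info} = Bumblebee.load_model(repo, preserve_source_types: true)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)

Nx.Serving.run(serving, "Explain FP8 quantization in one paragraph.")
```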
Supported FP8 Models
- Qwen3 FP8 checkpoints with block-wise `weight_scale_inv` scales, e.g. Qwen/Qwen3-4B-Instruct-2507-FP8 (tested)
- roulis/tiny-fp8-qwen3, a tiny test checkpoint used in the test suite
Notes