1717
1818namespace infinilm ::models::llama {
1919
20- LlamaAttention::LlamaAttention (const infinicore::Device &device,
20+ LlamaAttention::LlamaAttention (std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config,
21+ const infinicore::Device &device,
2122 size_t layer_idx,
22- engine::distributed::RankInfo rank_info,
23- std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config)
24- : layer_idx_(layer_idx),
23+ engine::distributed::RankInfo rank_info)
24+ : global_config_( global_config),
25+ layer_idx_ (layer_idx),
2526 hidden_size_(global_config->get<size_t >(" hidden_size" )),
2627 num_attention_heads_(global_config->get<size_t >(" num_attention_heads" )),
2728 num_key_value_heads_(global_config->get<size_t >(" num_key_value_heads" )),
@@ -30,8 +31,7 @@ LlamaAttention::LlamaAttention(const infinicore::Device &device,
3031 use_bias_(global_config->get_or<bool >(" attention_bias" , true )),
3132 use_output_bias_(global_config->get_or<bool >(" attention_output_bias" , false )),
3233 max_position_embeddings_(global_config->get<size_t >(" max_position_embeddings" )),
33- rank_info_(rank_info),
34- global_config_(global_config) {
34+ rank_info_(rank_info) {
3535 const auto &dtype{global_config_->get_dtype ()};
3636
3737 int tp_rank = rank_info.tp_rank ;
@@ -54,8 +54,6 @@ LlamaAttention::LlamaAttention(const infinicore::Device &device,
5454 INFINILM_QKV_LINEAR_W8A8_INIT (qkv_proj, " q_proj" , " k_proj" , " v_proj" , hidden_size_, head_dim_, global_config_->get <size_t >(" num_attention_heads" ), global_config_->get <size_t >(" num_key_value_heads" ), use_bias_,
5555 dtype, device, rank_info, quant_scheme);
5656
57- // INFINICORE_NN_MODULE_INIT(o_proj, hidden_size_, hidden_size_, use_output_bias_,
58- // dtype, device, tp_rank, tp_size, rank_info.comm, quant_scheme);
5957 INFINICORE_NN_MODULE_INIT (o_proj, global_config_->get <size_t >(" num_attention_heads" ) * head_dim_, hidden_size_, use_output_bias_,
6058 dtype, device, tp_rank, tp_size, rank_info.comm , quant_scheme);
6159 break ;
0 commit comments