Skip to content

Commit 9dee06a

Browse files
committed
修改函数参数顺序 (Reorder function parameters)
1 parent 6615743 commit 9dee06a

16 files changed

Lines changed: 48 additions & 53 deletions

csrc/config/global_config.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#pragma once
22

3-
// #include "infinicore/nn/quantization.hpp"
43
#include "infinicore/nn/rope.hpp"
54
#include "infinicore/ops.hpp"
65
#include "quant_config.hpp"

csrc/engine/rank_worker.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ void RankWorker::thread_loop() {
175175
infinicore::context::setDevice(rank_info_.device);
176176

177177
// Create model using factory (may be expensive)
178-
model_ = InfinilmModelFactory::createModel(rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr, global_config_);
178+
model_ = InfinilmModelFactory::createModel(global_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr);
179179
if (!model_) {
180180
throw std::runtime_error("Failed to create model");
181181
}

csrc/models/llama/llama.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
#include "../../config/global_config.hpp"
2020
#include "llama_attention.hpp"
21-
#include "llama_config.hpp"
2221
#include "llama_decoder_layer.hpp"
2322
#include "llama_for_causal_lm.hpp"
2423
#include "llama_mlp.hpp"

csrc/models/llama/llama_attention.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
namespace infinilm::models::llama {
1919

20-
LlamaAttention::LlamaAttention(const infinicore::Device &device,
20+
LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config,
21+
const infinicore::Device &device,
2122
size_t layer_idx,
22-
engine::distributed::RankInfo rank_info,
23-
std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config)
24-
: layer_idx_(layer_idx),
23+
engine::distributed::RankInfo rank_info)
24+
: global_config_(global_config),
25+
layer_idx_(layer_idx),
2526
hidden_size_(global_config->get<size_t>("hidden_size")),
2627
num_attention_heads_(global_config->get<size_t>("num_attention_heads")),
2728
num_key_value_heads_(global_config->get<size_t>("num_key_value_heads")),
@@ -30,8 +31,7 @@ LlamaAttention::LlamaAttention(const infinicore::Device &device,
3031
use_bias_(global_config->get_or<bool>("attention_bias", true)),
3132
use_output_bias_(global_config->get_or<bool>("attention_output_bias", false)),
3233
max_position_embeddings_(global_config->get<size_t>("max_position_embeddings")),
33-
rank_info_(rank_info),
34-
global_config_(global_config) {
34+
rank_info_(rank_info) {
3535
const auto &dtype{global_config_->get_dtype()};
3636

3737
int tp_rank = rank_info.tp_rank;
@@ -54,8 +54,6 @@ LlamaAttention::LlamaAttention(const infinicore::Device &device,
5454
INFINILM_QKV_LINEAR_W8A8_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, global_config_->get<size_t>("num_attention_heads"), global_config_->get<size_t>("num_key_value_heads"), use_bias_,
5555
dtype, device, rank_info, quant_scheme);
5656

57-
// INFINICORE_NN_MODULE_INIT(o_proj, hidden_size_, hidden_size_, use_output_bias_,
58-
// dtype, device, tp_rank, tp_size, rank_info.comm, quant_scheme);
5957
INFINICORE_NN_MODULE_INIT(o_proj, global_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, use_output_bias_,
6058
dtype, device, tp_rank, tp_size, rank_info.comm, quant_scheme);
6159
break;

csrc/models/llama/llama_attention.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ class LlamaAttention : public infinicore::nn::Module {
3737
* @param layer_idx Layer index for cache access
3838
* @param dtype Optional data type for model parameters (defaults to F32)
3939
*/
40-
LlamaAttention(const infinicore::Device &device,
40+
LlamaAttention(std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config,
41+
const infinicore::Device &device,
4142
size_t layer_idx,
42-
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
43-
std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config = nullptr);
43+
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
4444

4545
/**
4646
* @brief Forward pass: compute attention
@@ -102,6 +102,7 @@ class LlamaAttention : public infinicore::nn::Module {
102102
std::shared_ptr<infinicore::nn::RoPE> rotary_emb_;
103103

104104
private:
105+
std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config_;
105106
size_t layer_idx_; // Layer index for cache access
106107
size_t hidden_size_;
107108
size_t num_attention_heads_;
@@ -113,7 +114,6 @@ class LlamaAttention : public infinicore::nn::Module {
113114
size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)
114115

115116
float scaling_;
116-
std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config_;
117117
};
118118

119119
} // namespace infinilm::models::llama

csrc/models/llama/llama_decoder_layer.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66

77
namespace infinilm::models::llama {
88

9-
LlamaDecoderLayer::LlamaDecoderLayer(const infinicore::Device &device,
9+
LlamaDecoderLayer::LlamaDecoderLayer(std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config,
10+
const infinicore::Device &device,
1011
size_t layer_idx,
11-
engine::distributed::RankInfo rank_info,
12-
std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config) : layer_idx_(layer_idx), rank_info_(rank_info), global_config_(global_config) {
12+
engine::distributed::RankInfo rank_info) : global_config_(global_config), layer_idx_(layer_idx), rank_info_(rank_info) {
1313
const auto &dtype{global_config_->get_dtype()};
1414

1515
// Initialize layer normalization layers
@@ -19,8 +19,8 @@ LlamaDecoderLayer::LlamaDecoderLayer(const infinicore::Device &device,
1919
dtype, device);
2020

2121
// Initialize attention and MLP modules
22-
INFINICORE_NN_MODULE_INIT(self_attn, device, layer_idx, rank_info_, global_config);
23-
INFINICORE_NN_MODULE_INIT(mlp, device, rank_info_, global_config);
22+
INFINICORE_NN_MODULE_INIT(self_attn, global_config, device, layer_idx, rank_info_);
23+
INFINICORE_NN_MODULE_INIT(mlp, global_config, device, rank_info_);
2424
}
2525

2626
infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states,

csrc/models/llama/llama_decoder_layer.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ class LlamaDecoderLayer : public infinicore::nn::Module {
3333
* @param layer_idx Layer index for cache management and debugging
3434
* @param dtype Optional data type for model parameters (defaults to F32)
3535
*/
36-
LlamaDecoderLayer(const infinicore::Device &device,
36+
LlamaDecoderLayer(std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config,
37+
const infinicore::Device &device,
3738
size_t layer_idx,
38-
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
39-
std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config = nullptr);
39+
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
4040

4141
/**
4242
* @brief Forward pass: process one decoder layer

csrc/models/llama/llama_for_causal_lm.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@
66

77
namespace infinilm::models::llama {
88

9-
LlamaForCausalLM::LlamaForCausalLM(const infinicore::Device &device,
10-
engine::distributed::RankInfo rank_info,
11-
std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config) {
9+
LlamaForCausalLM::LlamaForCausalLM(std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config,
10+
const infinicore::Device &device,
11+
engine::distributed::RankInfo rank_info) {
1212

1313
// Initialize module's device_ member
1414
device_ = device;
1515

1616
const auto &dtype{global_config->get_dtype()};
1717

1818
// Initialize base model
19-
INFINICORE_NN_MODULE_INIT(model, device, rank_info, global_config);
19+
INFINICORE_NN_MODULE_INIT(model, global_config, device, rank_info);
2020

2121
// Initialize language modeling head
2222
// Note: If tie_word_embeddings is true, we would share weights with embed_tokens

csrc/models/llama/llama_for_causal_lm.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ class LlamaForCausalLM : public InfinilmModel {
2828
* @param config Model configuration
2929
* @param device Device to create tensors on
3030
*/
31-
LlamaForCausalLM(const infinicore::Device &device,
32-
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
33-
std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config = nullptr);
31+
LlamaForCausalLM(std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config,
32+
const infinicore::Device &device,
33+
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
3434

3535
/**
3636
* @brief Forward pass: compute language modeling logits
@@ -43,7 +43,6 @@ class LlamaForCausalLM : public InfinilmModel {
4343
void reset_cache(const cache::CacheConfig *cache_config) override;
4444

4545
// Module information
46-
// const LlamaConfig &config() const { return model_->config(); }
4746
LlamaModel &model() { return *model_; }
4847
const LlamaModel &model() const { return *model_; }
4948

csrc/models/llama/llama_mlp.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
namespace infinilm::models::llama {
77

8-
LlamaMLP::LlamaMLP(const infinicore::Device &device,
9-
engine::distributed::RankInfo rank_info,
10-
std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config)
11-
: hidden_size_(global_config->get<size_t>("hidden_size")),
8+
LlamaMLP::LlamaMLP(std::shared_ptr<infinilm::config::global_config::GlobalConfig> global_config,
9+
const infinicore::Device &device,
10+
engine::distributed::RankInfo rank_info)
11+
: global_config_(global_config), hidden_size_(global_config->get<size_t>("hidden_size")),
1212
intermediate_size_(global_config->get<size_t>("intermediate_size")),
13-
use_bias_(global_config->get_or<bool>("mlp_bias", false)), rank_info_(rank_info), global_config_(global_config) {
13+
use_bias_(global_config->get_or<bool>("mlp_bias", false)), rank_info_(rank_info) {
1414
const auto &dtype{global_config_->get_dtype()};
1515

1616
int tp_rank = rank_info.tp_rank;

0 commit comments

Comments (0)