diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index d0d64a2..de24ebc 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -64,6 +64,7 @@ DEFINE_int32( DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages."); +DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); @@ -356,7 +357,7 @@ int main(int argc, char *argv[]) { google::InitGoogleLogging(argv[0]); nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, - FLAGS_pipeline_parallel); + FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel); LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); diff --git a/example/gpt2/net.cc b/example/gpt2/net.cc index df7fbea..5b22912 100644 --- a/example/gpt2/net.cc +++ b/example/gpt2/net.cc @@ -177,8 +177,9 @@ Block::Forward(const std::vector> &x) { GPT2::GPT2(const GPT2Config &config) : config_(config) { int pp_size = nn::parallel::global::GetPipelineParallelSize(); - auto [is_first_stage, is_last_stage, start_layer, end_layer] - = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, nn::parallel::pp_rank); + int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); + auto [is_first_stage, is_last_stage, layer_chunks] + = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, vpp_size); auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); @@ -196,7 +197,10 @@ GPT2::GPT2(const GPT2Config &config) : config_(config) { { std::vector> h; - for (int64_t i = start_layer; i < end_layer; ++i) { h.push_back(std::make_shared(config_)); } + for (const auto &[start_layer, end_layer] : layer_chunks) { + for (int64_t i = start_layer; i < end_layer; ++i) { h.push_back(std::make_shared(config)); } + } + transformer[kHLayerName] = std::make_shared(std::move(h)); if (is_last_stage) { transformer[kLnFLayerName] = std::make_shared(std::vector{config_.n_embd}); @@ -204,6 +208,7 @@ GPT2::GPT2(const GPT2Config &config) : config_(config) { modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); } + if (is_last_stage) { // don't init this one, we will tie weights modules_[kLMHeadLayerName] = std::make_shared( @@ -224,27 +229,77 @@ GPT2::GPT2(const GPT2Config &config) : config_(config) { ->mutable_parameter(nn::parallel::VocabParallelEmbedding::kParamWeightName) = module(kLMHeadLayerName).parameter(nn::parallel::ColumnParallelLinear::kParamWeightName); } + + if (pp_size > 1) { + BuildChunks(); + } } } -std::vector> -GPT2::Forward(const std::vector> &x) { - int pp_rank = nn::parallel::pp_rank; +void GPT2::BuildChunks() { int pp_size = nn::parallel::global::GetPipelineParallelSize(); - bool is_first_stage = (pp_rank == 0); - bool is_last_stage = (pp_rank == pp_size - 1); + int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); + auto [is_first_stage, is_last_stage, layer_chunks] + = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, vpp_size); + + auto &transformer = modules_[kTransformerLayerName]; + + auto h_layers = std::dynamic_pointer_cast(transformer->mutable_module(kHLayerName)); + + int layer_offset = 0; + for (size_t chunk_idx = 0; chunk_idx < layer_chunks.size(); ++chunk_idx) { + GPT2Chunk chunk; + if (chunk_idx == 0 && 
is_first_stage) { + chunk.wte_ = transformer->mutable_module(kWTELayerName); + chunk.wpe_ = transformer->mutable_module(kWPELayerName); + } + + const auto &[start, end] = layer_chunks[chunk_idx]; + int chunk_layer_count = end - start; + + std::vector> chunk_blocks; + int current_index = 0; + for (auto it = h_layers->begin(); it != h_layers->end(); ++it, ++current_index) { + if (current_index >= layer_offset && current_index < layer_offset + chunk_layer_count) { + chunk_blocks.push_back(*it); + } + } + layer_offset += chunk_layer_count; + + chunk.blocks_ = std::make_shared(std::move(chunk_blocks)); + + if (chunk_idx == layer_chunks.size() - 1 && is_last_stage) { + chunk.norm_ = transformer->mutable_module(kLnFLayerName); + chunk.head_ = modules_[kLMHeadLayerName]; + } + + chunks_.push_back(std::move(chunk)); + } +} + +std::vector> +GPT2::ForwardChunk(int chunk_idx, const std::vector> &input) { + if (chunks_.size() <= chunk_idx) { + LOG(FATAL) << "chunk must be built!"; + } + auto &chunk = chunks_[chunk_idx]; // (B, T) - auto x1 = x[0]; + auto x1 = input[0]; const auto device = x1->GetDevice(); + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); + auto [is_first_stage, is_last_stage, layer_chunks] + = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, vpp_size); + const auto t = x1->Dims()[1] * (is_first_stage ? 1 : nn::parallel::global::GetSequenceParallelSize()); // full_seq_len CHECK_LE(t, config_.block_size) << "Cannot forward sequence of length " << t << ", block size is only " << config_.block_size; + // forward the GPT2 model itself - auto &transformer = modules_[kTransformerLayerName]; - if (is_first_stage) { + if (chunk.has_wte() && chunk.has_wpe()) { // (T_local) // NOTE(zbl): Slice pos sequence when SP is enabled auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); @@ -260,33 +315,82 @@ GPT2::Forward(const std::vector> &x) { auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device); // (B, T) -> Embedding(V_local, C) -> (B, T, C) - auto tok_emb = transformer->mutable_module(kWTELayerName)->Forward({x1})[0]; + auto tok_emb = chunk.wte_->Forward({x1})[0]; + // (T) -> Embedding(T_max, C) -> (T, C) - auto pos_emb = transformer->mutable_module(kWPELayerName)->Forward({pos})[0]; + auto pos_emb = chunk.wpe_->Forward({pos})[0]; // (B, T, C) x1 = tok_emb + pos_emb; } // (B, T, C) -> transformer -> (B, T, C) - auto h_modules = transformer->mutable_module(kHLayerName); - CHECK_EQ(h_modules->type(), nn::ModuleList::kType) << "Failed to get ModuleList from transformer"; - auto h_layers = std::dynamic_pointer_cast(h_modules); // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) - for (auto &h : *h_layers) { x1 = h->Forward({x1})[0]; } + for (auto &block : *chunk.blocks_) { x1 = block->Forward({x1})[0]; } - if (is_last_stage) { - // (B, T, C) -> Layernorm -> (B, T, C) - auto x3 = transformer->mutable_module(kLnFLayerName)->Forward({x1}); + // (B, T, C) -> Layernorm -> (B, T, C) + if (chunk.has_norm() && chunk.has_head()) { + auto x2 = chunk.norm_->Forward({x1}); // TODO(dcj): add inference-time mini-optimization // (B, T, C) -> Linear(C, V) -> (B, T, V) - auto logits = modules_[kLMHeadLayerName]->Forward(x3); - // (B, T, V_original) + auto logits = chunk.head_->Forward(x2); + return logits; } + return {x1}; } +std::vector> +GPT2::Forward(const std::vector> &x) { + // (B, T) + auto x1 = x[0]; + const auto device = x1->GetDevice(); + 
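+    // NOTE: this Forward() is the single-stage (no pipeline parallel) path over the full model.
+    // When pp_size > 1 the constructor calls BuildChunks() and the pipeline schedule is expected
+    // to drive ForwardChunk(chunk_idx, ...) on this stage's GPT2Chunk objects instead.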
+ const auto t = x1->Dims()[1]; // T + CHECK_LE(t, config_.block_size) << "Cannot forward sequence of length " << t << ", block size is only " + << config_.block_size; + // forward the GPT2 model itself + auto &transformer = modules_[kTransformerLayerName]; + + // (T_local) + // NOTE(zbl): Slice pos sequence when SP is enabled + auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); + auto sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled(); + int tp_rank = 0; + if (tp_world_size > 1) { + auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get( + nn::parallel::GetTensorParallelProcessGroupName(device->rank().thread_rank())); + tp_rank = tp_group->GetGroupRank(device->rank().thread_rank()); + } + int64_t t_local = sequence_parallel_enabled ? (t / tp_world_size) : t; + int64_t start = sequence_parallel_enabled ? tp_rank * t_local : 0; + auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device); + + // (B, T) -> Embedding(V_local, C) -> (B, T, C) + auto tok_emb = transformer->mutable_module(kWTELayerName)->Forward({x1})[0]; + // (T) -> Embedding(T_max, C) -> (T, C) + auto pos_emb = transformer->mutable_module(kWPELayerName)->Forward({pos})[0]; + // (B, T, C) + x1 = tok_emb + pos_emb; + + // (B, T, C) -> transformer -> (B, T, C) + auto h_modules = transformer->mutable_module(kHLayerName); + CHECK_EQ(h_modules->type(), nn::ModuleList::kType) << "Failed to get ModuleList from transformer"; + auto h_layers = std::dynamic_pointer_cast(h_modules); + // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) + for (auto &h : *h_layers) { x1 = h->Forward({x1})[0]; } + + // (B, T, C) -> Layernorm -> (B, T, C) + auto x3 = transformer->mutable_module(kLnFLayerName)->Forward({x1}); + + // TODO(dcj): add inference-time mini-optimization + // (B, T, C) -> Linear(C, V) -> (B, T, V) + auto logits = modules_[kLMHeadLayerName]->Forward(x3); + // (B, T, V_original) + return logits; +} + std::shared_ptr GPT2::FromPretrained(ModelType model_type) { // TODO(dcj): implement this later LOG(FATAL) << "Not implemented yet"; @@ -351,9 +455,16 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { CHECK_EQ(n_embd % n_head, 0) << "n_embd must be divisible by n_head."; CHECK_EQ(n_head % tp_size, 0) << "n_head must be divisible by TP world size."; + // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ========== int pp_size = nn::parallel::global::GetPipelineParallelSize(); - auto [is_first_stage, is_last_stage, start_layer, end_layer] - = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, nn::parallel::pp_rank); + int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); + auto [is_first_stage, is_last_stage, layer_chunks] + = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, vpp_size); + // ========== layer to chunk ========== + std::unordered_map owned_layers; + for (const auto &[start, end] : layer_chunks) { + for (int i = start; i < end; ++i) { owned_layers[i] = true; } + } auto tp_rank = nn::parallel::tp_rank; // calculate xx_size_per_partition @@ -409,13 +520,14 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.ln_1.weight + int local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { + if (owned_layers.find(idx) != owned_layers.end()) { auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - 
std::to_string(idx - start_layer), Block::kLn1LayerName, + std::to_string(local_layer_index), Block::kLn1LayerName, nn::LayerNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; } else { size_t ln_1_w_bytes = n_embd * sizeof(float); ifs.seekg(ln_1_w_bytes, std::ios::cur); @@ -423,13 +535,14 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.ln_1.bias + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { + if (owned_layers.find(idx) != owned_layers.end()) { auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx - start_layer), Block::kLn1LayerName, + std::to_string(local_layer_index), Block::kLn1LayerName, nn::LayerNorm::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; } else { size_t ln_1_b_bytes = n_embd * sizeof(float); ifs.seekg(ln_1_b_bytes, std::ios::cur); @@ -437,11 +550,11 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.attn.c_attn.weight (ColumnParallelLinear, but actually applies on "rows") + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { + if (owned_layers.find(idx) != owned_layers.end()) { auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx - start_layer), Block::kAttnLayerName, + std::to_string(local_layer_index), Block::kAttnLayerName, CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; // NOTE(zbl): In the .bin model file, Q/K/V is concated along last dim, @@ -471,6 +584,8 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { /*dst=*/dst + (2 * local_C) * cols_all, /*rows=*/rows_all, /*cols=*/cols_all, /*row_start=*/2 * n_embd + tp_rank * local_C, /*row_cnt=*/local_C); + + ++local_layer_index; } else { size_t c_attn_w_bytes = qkv_out * n_embd * sizeof(float); ifs.seekg(c_attn_w_bytes, std::ios::cur); @@ -478,11 +593,11 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.attn.c_attn.bias (ColumnParallelLinear) + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { + if (owned_layers.find(idx) != owned_layers.end()) { auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx - start_layer), Block::kAttnLayerName, + std::to_string(local_layer_index), Block::kAttnLayerName, CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)]; // NOTE(zbl): Same as c_attn.weight, the bias for Q/K/V is concated @@ -511,6 +626,8 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { /*dst=*/dst + (2 * local_C), /*len=*/len_all, /*start=*/2 * n_embd + tp_rank * local_C, /*cnt=*/local_C); + + ++local_layer_index; } else { size_t c_attn_b_bytes = qkv_out * sizeof(float); ifs.seekg(c_attn_b_bytes, std::ios::cur); @@ -518,15 +635,16 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.attn.c_proj.weight (RowParallelLinear, but actually applies on "columns") + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - 
if (owned) { + if (owned_layers.find(idx) != owned_layers.end()) { auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx - start_layer), Block::kAttnLayerName, + std::to_string(local_layer_index), Block::kAttnLayerName, CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, n_embd, tp_rank * in_pp, in_pp); + ++local_layer_index; } else { size_t c_proj_w_bytes = n_embd * n_embd * sizeof(float); ifs.seekg(c_proj_w_bytes, std::ios::cur); @@ -534,14 +652,15 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.attn.c_proj.bias (RowParallelLinear, no shard on bias) + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { + if (owned_layers.find(idx) != owned_layers.end()) { auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx - start_layer), Block::kAttnLayerName, + std::to_string(local_layer_index), Block::kAttnLayerName, CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; } else { size_t c_proj_b_bytes = n_embd * sizeof(float); ifs.seekg(c_proj_b_bytes, std::ios::cur); @@ -549,13 +668,14 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.ln_2.weight + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { + if (owned_layers.find(idx) != owned_layers.end()) { auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx - start_layer), Block::kLn2LayerName, + std::to_string(local_layer_index), Block::kLn2LayerName, nn::LayerNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; } else { size_t ln_2_w_bytes = n_embd * sizeof(float); ifs.seekg(ln_2_w_bytes, std::ios::cur); @@ -563,13 +683,14 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.ln_2.bias + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { + if (owned_layers.find(idx) != owned_layers.end()) { auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx - start_layer), Block::kLn2LayerName, + std::to_string(local_layer_index), Block::kLn2LayerName, nn::LayerNorm::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; } else { size_t ln_2_b_bytes = n_embd * sizeof(float); ifs.seekg(ln_2_b_bytes, std::ios::cur); @@ -577,13 +698,14 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.mlp.c_fc.weight (ColumnParallelLinear, but actually applies on "rows") + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { + if (owned_layers.find(idx) != owned_layers.end()) { auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, 
GPT2::kHLayerName, std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, n_embd, fc_start, fc_pp); + ++local_layer_index; } else { size_t c_fc_w_bytes = fc_out * n_embd * sizeof(float); ifs.seekg(c_fc_w_bytes, std::ios::cur); @@ -591,13 +713,14 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.mlp.c_fc.bias (ColumnParallelLinear) + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { + if (owned_layers.find(idx) != owned_layers.end()) { auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)]; ReadVectorShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, fc_start, fc_pp); + ++local_layer_index; } else { size_t c_fc_b_bytes = fc_out * sizeof(float); ifs.seekg(c_fc_b_bytes, std::ios::cur); @@ -605,14 +728,15 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.mlp.c_proj.weight (RowParallelLinear, but actually applies on "columns") + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { + if (owned_layers.find(idx) != owned_layers.end()) { auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, fc_out, tp_rank * in4_pp, in4_pp); + ++local_layer_index; } else { size_t c_proj_w_bytes = fc_out * n_embd * sizeof(float); ifs.seekg(c_proj_w_bytes, std::ios::cur); @@ -620,13 +744,14 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.mlp.c_proj.bias (RowParallelLinear, no shard on bias) + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { + if (owned_layers.find(idx) != owned_layers.end()) { auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), + "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; } else { size_t c_proj_b_bytes = n_embd * sizeof(float); ifs.seekg(c_proj_b_bytes, std::ios::cur); diff --git a/example/gpt2/net.h b/example/gpt2/net.h index dba5cdf..704327b 100644 --- a/example/gpt2/net.h +++ b/example/gpt2/net.h @@ -11,6 +11,10 @@ #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/tensor.h" +namespace infini_train::nn { +class ModuleList; +} + struct GPT2Config { int64_t block_size = 1024; int64_t vocab_size = 50304; @@ -71,6 +75,20 @@ class Block : public infini_train::nn::CloneableModule { 
Forward(const std::vector> &x) override; }; +class GPT2Chunk { +public: + bool has_wte() const { return wte_ != nullptr; } + bool has_wpe() const { return wpe_ != nullptr; } + bool has_norm() const { return norm_ != nullptr; } + bool has_head() const { return head_ != nullptr; } + + std::shared_ptr wte_ = nullptr; + std::shared_ptr wpe_ = nullptr; + std::shared_ptr blocks_ = nullptr; + std::shared_ptr norm_ = nullptr; + std::shared_ptr head_ = nullptr; +}; + class GPT2 : public infini_train::nn::CloneableModule { public: static constexpr char kWTELayerName[] = "wte"; @@ -92,9 +110,14 @@ class GPT2 : public infini_train::nn::CloneableModule { std::vector> Forward(const std::vector> &x) override; + void BuildChunks(); + std::vector> + ForwardChunk(int chunk_idx, const std::vector> &input) override; + static std::shared_ptr FromPretrained(ModelType model_type); static std::shared_ptr FromLLMC(const std::string &filepath); private: GPT2Config config_; + std::vector chunks_; }; diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 09a6a17..0b3b60e 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -62,7 +62,8 @@ DEFINE_int32( "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices."); DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); -DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, , specified the number of PP stages."); +DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages."); +DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); @@ -333,7 +334,7 @@ int main(int argc, char *argv[]) { google::InitGoogleLogging(argv[0]); nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, - FLAGS_pipeline_parallel); + FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel); LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); diff --git a/example/llama3/net.cc b/example/llama3/net.cc index ad7ad4c..c0c232c 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -326,8 +326,9 @@ std::vector> Block::Forward(const std::vector> transformer; if (is_first_stage) { @@ -336,7 +337,9 @@ LLaMA3::LLaMA3(const LLaMA3Config &config) : config_(config) { } std::vector> h_local; - for (int64_t i = start_layer; i < end_layer; ++i) { h_local.push_back(std::make_shared(config)); } + for (const auto &[start_layer, end_layer] : layer_chunks) { + for (int64_t i = start_layer; i < end_layer; ++i) { h_local.push_back(std::make_shared(config)); } + } transformer[kHLayerName] = std::make_shared(std::move(h_local)); if (is_last_stage) { @@ -352,19 +355,112 @@ LLaMA3::LLaMA3(const LLaMA3Config &config) : config_(config) { /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); } modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); + + if (pp_size > 1) { + BuildChunks(); + } } -std::vector> LLaMA3::Forward(const std::vector> &x) { - int pp_rank = nn::parallel::pp_rank; +void LLaMA3::BuildChunks() { int pp_size = nn::parallel::global::GetPipelineParallelSize(); - bool is_first_stage = (pp_rank == 0); - bool is_last_stage = (pp_rank == pp_size - 1); + int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); + auto [is_first_stage, is_last_stage, layer_chunks] + = 
nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, vpp_size); - // (bs, seq_len) - auto x1 = x[0]; + auto &transformer = modules_[kTransformerLayerName]; + + auto h_layers = std::dynamic_pointer_cast(transformer->mutable_module(kHLayerName)); + + int layer_offset = 0; + for (size_t chunk_idx = 0; chunk_idx < layer_chunks.size(); ++chunk_idx) { + LLaMA3Chunk chunk; + if (chunk_idx == 0 && is_first_stage) { + chunk.embedding_ = transformer->mutable_module(kWTELayerName); + } + + const auto &[start, end] = layer_chunks[chunk_idx]; + int chunk_layer_count = end - start; + + std::vector> chunk_blocks; + int current_index = 0; + for (auto it = h_layers->begin(); it != h_layers->end(); ++it, ++current_index) { + if (current_index >= layer_offset && current_index < layer_offset + chunk_layer_count) { + chunk_blocks.push_back(*it); + } + } + layer_offset += chunk_layer_count; + + chunk.blocks_ = std::make_shared(std::move(chunk_blocks)); + + if (chunk_idx == layer_chunks.size() - 1 && is_last_stage) { + chunk.norm_ = transformer->mutable_module(kLnFLayerName); + chunk.head_ = modules_[kLMHeadLayerName]; + } + + chunks_.push_back(std::move(chunk)); + } +} + +std::vector> LLaMA3::ForwardChunk(int chunk_idx, + const std::vector> &input) { + // printf("[stage %d] LLaMA3::ForwardChunk entry\n", nn::parallel::pp_rank); + if (chunks_.size() <= chunk_idx) { + LOG(FATAL) << "chunk must be built!"; + } + auto &chunk = chunks_[chunk_idx]; + auto x1 = input[0]; const auto device = x1->GetDevice(); + + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); + auto [is_first_stage, is_last_stage, layer_chunks] + = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, vpp_size); + const auto t = x1->Dims()[1] * (is_first_stage ? 
1 : nn::parallel::global::GetSequenceParallelSize()); // full_seq_len + + if (chunk.has_embedding()) { + x1 = chunk.embedding_->Forward({x1})[0]; + } + + // Init freqs_cis on device only once + // TODO(zbl): consider moving this to model construction + if (buffers_[kFreqsCisName] == nullptr) { + buffers_[kFreqsCisName] = PrecomputeFreqsCis(config_.n_embd / config_.n_head, config_.block_size * 2, + config_.rope_theta, config_.use_scaled_rope, device); + } + + // TODO(zbl): dynamic start_pos + int64_t start_pos = 0; + auto freqs_view = buffers_[kFreqsCisName]->Slice(0, start_pos, start_pos + t, 1); + + // TODO(lzm): add dtype support for nn::function::Ones later + std::shared_ptr ones = std::make_shared(nn::function::Ones({t, t})->To(x1->GetDevice())); + std::shared_ptr mask = nn::function::Triu(ones, 1)->View({1, 1, t, t}); + + std::shared_ptr start_pos_ptr = nullptr; + + for (auto &block : *chunk.blocks_) { x1 = block->Forward({x1, freqs_view, start_pos_ptr, mask})[0]; } + + if (chunk.has_norm() && chunk.has_head()) { + // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) + auto x2 = chunk.norm_->Forward({x1}); + + // TODO(zbl): add inference-time mini-optimization + // (bs, seq_len, n_embd) -> Linear(n_embd, vocab_size) -> (bs, seq_len, vocab_size) + auto logits = chunk.head_->Forward(x2); + + return logits; + } + + return {x1}; +} + +std::vector> LLaMA3::Forward(const std::vector> &x) { + // (bs, seq_len) + auto x1 = x[0]; + const auto device = x1->GetDevice(); + const auto t = x1->Dims()[1]; // full_seq_len CHECK_LE(t, config_.block_size) << "Cannot forward sequence of length " << t << ", block size is only " << config_.block_size; @@ -378,10 +474,8 @@ std::vector> LLaMA3::Forward(const std::vector Embedding(vocab_size, n_embd) -> (bs, seq_len, n_embd) - x1 = transformer->mutable_module(kWTELayerName)->Forward({x1})[0]; - } + // (bs, seq_len) -> Embedding(vocab_size, n_embd) -> (bs, seq_len, n_embd) + x1 = transformer->mutable_module(kWTELayerName)->Forward({x1})[0]; // TODO(zbl): dynamic start_pos int64_t start_pos = 0; @@ -399,19 +493,15 @@ std::vector> LLaMA3::Forward(const std::vector transformer -> (bs, seq_len, n_embd) for (auto &h : *h_layers) { x1 = h->Forward({x1, freqs_view, start_pos_ptr, mask})[0]; } - if (is_last_stage) { - // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) - auto x2 = transformer->mutable_module(kLnFLayerName)->Forward({x1}); - - // TODO(zbl): add inference-time mini-optimization - // (bs, seq_len, n_embd) -> Linear(n_embd, vocab_size) -> (bs, seq_len, vocab_size) - auto logits = modules_[kLMHeadLayerName]->Forward(x2); + // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) + auto x2 = transformer->mutable_module(kLnFLayerName)->Forward({x1}); - // (bs, seq_len, vocab_size) - return logits; - } + // TODO(zbl): add inference-time mini-optimization + // (bs, seq_len, n_embd) -> Linear(n_embd, vocab_size) -> (bs, seq_len, vocab_size) + auto logits = modules_[kLMHeadLayerName]->Forward(x2); - return {x1}; + // (bs, seq_len, vocab_size) + return logits; } std::shared_ptr LLaMA3::FromPretrained(ModelType model_type) { @@ -465,9 +555,17 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { .use_scaled_rope = static_cast(use_scaled_rope), .norm_eps = norm_eps, .max_gen_batch_size = max_gen_bs}); + + // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ========== int pp_size = nn::parallel::global::GetPipelineParallelSize(); - auto [is_first_stage, is_last_stage, start_layer, end_layer] - = 
nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, nn::parallel::pp_rank); + int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); + auto [is_first_stage, is_last_stage, layer_chunks] + = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, vpp_size); + // ========== layer to chunk ========== + std::unordered_map owned_layers; + for (const auto &[start, end] : layer_chunks) { + for (int i = start; i < end; ++i) { owned_layers[i] = true; } + } const int tp_size = nn::parallel::global::GetTensorParallelSize(); const int tp_rank = nn::parallel::tp_rank; @@ -493,6 +591,11 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { LOG(INFO) << " max_gen_bs = " << max_gen_bs; LOG(INFO) << " version_major = " << version_major; LOG(INFO) << " version_minor = " << version_minor; + + LOG(INFO) << "Pipeline Parallel Chunks:"; + for (size_t i = 0; i < layer_chunks.size(); ++i) { + LOG(INFO) << " Chunk " << i << ": layers " << layer_chunks[i].first << " to " << layer_chunks[i].second; + } } const int64_t head_dim = static_cast(n_embd) / static_cast(n_head); @@ -548,14 +651,14 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.ln_1.weight : Full version RMSNorm + int local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - bool owned = (i >= start_layer && i < end_layer); - - if (owned) { + if (owned_layers.find(i) != owned_layers.end()) { auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, - std::to_string(i - start_layer), Block::kLn1LayerName, + std::to_string(local_layer_index), Block::kLn1LayerName, RMSNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; } else { size_t ln_1_bytes = n_embd * sizeof(float); ifs.seekg(ln_1_bytes, std::ios::cur); @@ -564,11 +667,11 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { // transformer.h.{i}.attn.c_attn.weight : ColumnParallelLinear, but actually applies on "rows" // W-qkv should be [Q(=n_embd) | K(=n_kv_head*head_dim) | V(=n_kv_head*head_dim)] × n_embd + local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - bool owned = (i >= start_layer && i < end_layer); - if (owned) { + if (owned_layers.find(i) != owned_layers.end()) { auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, - std::to_string(i - start_layer), Block::kAttnLayerName, + std::to_string(local_layer_index), Block::kAttnLayerName, CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; @@ -596,6 +699,7 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { /*rows=*/attn_rows_all, /*cols=*/attn_cols, /*row_start=*/q_out_rows + kv_out_rows + tp_rank * kv_local_rows, /*row_cnt=*/kv_local_rows); + ++local_layer_index; } else { size_t qkv_bytes = static_cast(attn_rows_all) * attn_cols * sizeof(float); ifs.seekg(qkv_bytes, std::ios::cur); @@ -603,16 +707,17 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.attn.c_proj.weight : RowParallelLinear, but actually applies on "columns" + local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - bool owned = (i >= start_layer && i < end_layer); - if (owned) { + if (owned_layers.find(i) != owned_layers.end()) { auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, - std::to_string(i - start_layer), Block::kAttnLayerName, + 
std::to_string(local_layer_index), Block::kAttnLayerName, CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), /*rows=*/n_embd, /*cols=*/n_embd, /*col_start=*/tp_rank * in_pp, /*col_cnt=*/in_pp); + ++local_layer_index; } else { size_t c_proj_bytes = static_cast(n_embd) * n_embd * sizeof(float); ifs.seekg(c_proj_bytes, std::ios::cur); @@ -620,13 +725,15 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.ln_2.weight : Full version RMSNorm + local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - bool owned = (i >= start_layer && i < end_layer); - if (owned) { + + if (owned_layers.find(i) != owned_layers.end()) { auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, - std::to_string(i - start_layer), Block::kLn2LayerName, + std::to_string(local_layer_index), Block::kLn2LayerName, RMSNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; } else { size_t ln_2_bytes = static_cast(n_embd) * sizeof(float); ifs.seekg(ln_2_bytes, std::ios::cur); @@ -634,15 +741,17 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.mlp.c_fc.weight : ColumnParallelLinear, but actually applies on "rows" + local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - bool owned = (i >= start_layer && i < end_layer); - if (owned) { + + if (owned_layers.find(i) != owned_layers.end()) { auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i - start_layer), + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), /*rows=*/fc_out, /*cols=*/n_embd, /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); + ++local_layer_index; } else { size_t fc_bytes = static_cast(ffn_hidden) * n_embd * sizeof(float); ifs.seekg(fc_bytes, std::ios::cur); @@ -650,15 +759,17 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.mlp.c_fc2.weight : ColumnParallelLinear, but actually applies on "rows" + local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - bool owned = (i >= start_layer && i < end_layer); - if (owned) { + + if (owned_layers.find(i) != owned_layers.end()) { auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i - start_layer), + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCFc2LayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), /*rows=*/fc_out, /*cols=*/n_embd, /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); + ++local_layer_index; } else { size_t fc2_bytes = static_cast(ffn_hidden) * n_embd * sizeof(float); ifs.seekg(fc2_bytes, std::ios::cur); @@ -666,15 +777,17 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.mlp.c_proj.weight : RowParallelLinear, but actually applies on "columns" + local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - bool owned = (i >= start_layer && i < end_layer); - if (owned) { + + if (owned_layers.find(i) != owned_layers.end()) { auto 
&tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i - start_layer), + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), /*rows=*/n_embd, /*cols=*/fc_out, /*col_start=*/tp_rank * in_fc_pp, /*col_cnt=*/in_fc_pp); + ++local_layer_index; } else { size_t c_proj_bytes = static_cast(n_embd) * ffn_hidden * sizeof(float); ifs.seekg(c_proj_bytes, std::ios::cur); diff --git a/example/llama3/net.h b/example/llama3/net.h index ec0199a..051e427 100644 --- a/example/llama3/net.h +++ b/example/llama3/net.h @@ -12,6 +12,10 @@ #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/tensor.h" +namespace infini_train::nn { +class ModuleList; +} + struct LLaMA3Config { // ref: https://huggingface.co/meta-llama/Llama-3.2-1B // Model basic config @@ -108,6 +112,18 @@ class Block : public infini_train::nn::CloneableModule { Forward(const std::vector> &x) override; }; +class LLaMA3Chunk { +public: + bool has_embedding() const { return embedding_ != nullptr; } + bool has_norm() const { return norm_ != nullptr; } + bool has_head() const { return head_ != nullptr; } + + std::shared_ptr embedding_ = nullptr; + std::shared_ptr blocks_ = nullptr; + std::shared_ptr norm_ = nullptr; + std::shared_ptr head_ = nullptr; +}; + class LLaMA3 : public infini_train::nn::CloneableModule { public: static constexpr char kWTELayerName[] = "wte"; @@ -132,9 +148,14 @@ class LLaMA3 : public infini_train::nn::CloneableModule { std::vector> Forward(const std::vector> &x) override; + void BuildChunks(); + std::vector> + ForwardChunk(int chunk_idx, const std::vector> &input) override; + static std::shared_ptr FromPretrained(ModelType model_type); static std::shared_ptr FromLLMC(const std::string &filepath); private: LLaMA3Config config_; + std::vector chunks_; }; diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index 08c736e..ae4ca7b 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -54,6 +54,9 @@ class Module : public std::enable_shared_from_this { return 0.0f; }; + virtual std::vector> + ForwardChunk(int chunk_idx, const std::vector> &input_tensors); + virtual void To(const Device *device); virtual void To(DataType dtype); diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h index f177b9d..480c128 100644 --- a/infini_train/include/nn/parallel/global.h +++ b/infini_train/include/nn/parallel/global.h @@ -27,7 +27,7 @@ class GlobalEnv { static GlobalEnv &Instance(); void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, - int pipeline_parallel_size); + int pipeline_parallel_size, int virtual_pipeline_parallel_size); int nnodes() const; @@ -51,6 +51,8 @@ class GlobalEnv { int pipeline_parallel_size() const; + int virtual_pipeline_parallel_size() const; + Layout layout() const; private: @@ -75,6 +77,7 @@ class GlobalEnv { int data_parallel_size_ = 1; int pipeline_parallel_size_ = 1; + int virtual_pipeline_parallel_size_ = 1; mutable std::mutex mutex_; bool initialized_ = false; @@ -83,9 +86,9 @@ class GlobalEnv { }; inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, - int pipeline_parallel_size) { + int pipeline_parallel_size, 
int virtual_pipeline_parallel) { GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled, - pipeline_parallel_size); + pipeline_parallel_size, virtual_pipeline_parallel); } inline int GetNnodes() { return GlobalEnv::Instance().nnodes(); } @@ -100,6 +103,7 @@ inline int GetSequenceParallelSize() { return GlobalEnv::Instance().sequence_par inline bool GetSequenceParallelEnabled() { return GlobalEnv::Instance().sequence_parallel_enabled(); } inline int GetDataParallelSize() { return GlobalEnv::Instance().data_parallel_size(); } inline int GetPipelineParallelSize() { return GlobalEnv::Instance().pipeline_parallel_size(); } +inline int GetVirtualPipelineParallelSize() { return GlobalEnv::Instance().virtual_pipeline_parallel_size(); } // ========================= // Layout Helper Functions diff --git a/infini_train/include/nn/parallel/pp/pipeline_parallel.h b/infini_train/include/nn/parallel/pp/pipeline_parallel.h index c58f1da..78debb3 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_parallel.h +++ b/infini_train/include/nn/parallel/pp/pipeline_parallel.h @@ -28,7 +28,8 @@ class PipelineParallel : public Module { const std::vector> &target, const std::shared_ptr &loss_fn, DataType dtype); - static std::tuple GetStageInfo(int total_layers, int pp_size, int pp_rank); + static std::tuple>> GetStageInfo(int total_layers, int pp_size, + int chunks_per_stage = 1); private: int num_stages_ = -1; diff --git a/infini_train/include/nn/parallel/pp/pipeline_schedule.h b/infini_train/include/nn/parallel/pp/pipeline_schedule.h index 4174716..33e4200 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_schedule.h +++ b/infini_train/include/nn/parallel/pp/pipeline_schedule.h @@ -28,25 +28,34 @@ class PipelineSchedule { virtual float StepMicroBatches(const std::vector> &arg_mbs, const std::vector> &target_mbs, - const std::shared_ptr &loss_fn, DataType dtype) - = 0; + const std::shared_ptr &loss_fn, DataType dtype); - std::vector> ReceiveFromPrev(); - std::vector> SendToNext(const std::vector> &tensors); + std::vector> ReceiveFromPrev(int peer_rank); + std::vector> SendToNext(const std::vector> &tensors, int peer_rank); protected: int num_micro_batches_ = -1; std::shared_ptr stage_ = nullptr; }; -class ScheduleGPipe : public PipelineSchedule { +class PipelineParallelScheduler { public: - ScheduleGPipe(std::shared_ptr stage, int num_stages, int num_micro_batches) - : PipelineSchedule(std::move(stage), num_stages, num_micro_batches){}; - - float StepMicroBatches(const std::vector> &arg_mbs, - const std::vector> &target_mbs, - const std::shared_ptr &loss_fn, DataType dtype) override; + struct Task { + int step; + int microbatch_id; + int global_chunk_id; + int local_chunk_idx; + bool is_forward; + int stage_id; + bool is_first_chunk; + bool is_last_chunk; + }; + + static Task CreateTask(int step, int mb, int global_chunk, int num_stages, int total_chunks, bool is_forward); + + static std::vector GenerateGPipeSchedule(int n, int num_stages, int vpp_size); + + static std::vector GenerateInterleaved1F1BSchedule(int n, int num_stages, int vpp_size); }; class Schedule1F1B : public PipelineSchedule { diff --git a/infini_train/include/nn/parallel/pp/pipeline_stage.h b/infini_train/include/nn/parallel/pp/pipeline_stage.h index b7679d8..3011d7f 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_stage.h +++ b/infini_train/include/nn/parallel/pp/pipeline_stage.h @@ -20,7 +20,8 @@ class PipelineStage { const std::vector> &recv_shape, std::shared_ptr optimizer, int 
device_id); - std::vector> ForwardOneChunk(const std::vector> &inputs); + std::vector> ForwardOneChunk(const std::vector> &inputs, + int chunk_idx = 0); bool IsFirstStage() const { return stage_index_ == 0; } bool IsLastStage() const { return stage_index_ == num_stages_ - 1; } diff --git a/infini_train/src/autograd/function.cc b/infini_train/src/autograd/function.cc index 48ad02a..4a94172 100644 --- a/infini_train/src/autograd/function.cc +++ b/infini_train/src/autograd/function.cc @@ -63,6 +63,10 @@ void Function::BackwardPartial(const std::shared_ptr &grad_output, int g const auto *device = grad_output->GetDevice(); device->SetDevice(); + // Debug logging: which backward Function is currently executing
+ // std::cout << "[Backward] Function: " << typeid(*this).name() << ", grad_output_idx: " << grad_output_idx + // << ", dependencies_reached: " << dependencies_reached_ << "/" << dependencies_number_ << std::endl; + // NOTE(dcj): The accumulate autograd function has no grad_outputs. // Temporarily resize the vector to hold one nullptr as a buffer. if (grad_outputs_.empty()) { diff --git a/infini_train/src/nn/modules/module.cc b/infini_train/src/nn/modules/module.cc index 3f757fa..8e1f070 100644 --- a/infini_train/src/nn/modules/module.cc +++ b/infini_train/src/nn/modules/module.cc @@ -110,6 +110,12 @@ std::vector> Module::Forward(const std::vector> Module::ForwardChunk(int chunk_idx, + const std::vector> &input_tensors) { + LOG(FATAL) << "ForwardChunk function not implemented for this module"; + return {}; +} + void Module::To(const Device *device) { CHECK_NOTNULL(device); if (device == device_) { diff --git a/infini_train/src/nn/parallel/global.cc b/infini_train/src/nn/parallel/global.cc index 4c00da4..39cd95d 100644 --- a/infini_train/src/nn/parallel/global.cc +++ b/infini_train/src/nn/parallel/global.cc @@ -90,7 +90,7 @@ GlobalEnv &GlobalEnv::Instance() { } void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, - int pipeline_parallel_size) { + int pipeline_parallel_size, int virtual_pipeline_parallel_size) { std::lock_guard lock(mutex_); CHECK(!initialized_) << "Repeated initialization of GlobalEnv!"; @@ -106,6 +106,7 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq tensor_parallel_size_ = tensor_parallel_size; sequence_parallel_enabled_ = sequence_parallel_enabled; pipeline_parallel_size_ = pipeline_parallel_size; + virtual_pipeline_parallel_size_ = virtual_pipeline_parallel_size; data_parallel_size_ = world_size_ / tensor_parallel_size_ / pipeline_parallel_size_; layout_.sizes[DP] = data_parallel_size_; @@ -171,6 +172,11 @@ int GlobalEnv::pipeline_parallel_size() const { return pipeline_parallel_size_; } +int GlobalEnv::virtual_pipeline_parallel_size() const { + CHECK(initialized_) << "GlobalEnv is not initialized!"; + return virtual_pipeline_parallel_size_; +} + Layout GlobalEnv::layout() const { CHECK(initialized_) << "GlobalEnv is not initialized!"; return layout_; diff --git a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc index 6da3b1e..51df53b 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc @@ -24,7 +24,8 @@ void PipelineParallel::BuildPipelineStage(const std::shared_ptr &module, } void PipelineParallel::SetupSchedule(int num_micro_batches) { - schedule_ = std::make_shared(pipeline_stage_, num_stages_, num_micro_batches); + // schedule_ = std::make_shared(pipeline_stage_, num_stages_, num_micro_batches); 
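+    // Build a 1F1B schedule: once the pipeline is full, each stage alternates one backward with
+    // every forward, instead of running all forwards before any backward (GPipe, commented out above).
+    //
+    // Illustrative mapping for GetStageInfo() below (assumed sizes, not taken from this patch):
+    //   total_layers = 8, pp_size = 2, chunks_per_stage (vpp) = 2
+    //     rank 0 -> global chunks 0 and 2 -> layer ranges [0, 2) and [4, 6)
+    //     rank 1 -> global chunks 1 and 3 -> layer ranges [2, 4) and [6, 8)
+    //   i.e. global_chunk_idx = chunk_idx * pp_size + rank, and any remainder layers go one
+    //   apiece to the lowest-numbered global chunks.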
+ schedule_ = std::make_shared(pipeline_stage_, num_stages_, num_micro_batches); } float PipelineParallel::TrainStep(const std::vector> &input, @@ -39,22 +40,43 @@ float PipelineParallel::TrainStep(const std::vector> &in return schedule_->Step(stage_input, stage_target, loss_fn, dtype); } -std::tuple PipelineParallel::GetStageInfo(int total_layers, int pp_size, int pp_rank) { +std::tuple>> PipelineParallel::GetStageInfo(int total_layers, int pp_size, + int chunks_per_stage) { + int rank = pp_rank; bool is_first_stage = (pp_rank == 0); bool is_last_stage = (pp_rank == pp_size - 1); - int layers_per_stage = total_layers / pp_size; - int remainder = total_layers % pp_size; - int start_layer, end_layer; - if (pp_rank < remainder) { - start_layer = pp_rank * (layers_per_stage + 1); - end_layer = start_layer + layers_per_stage + 1; - } else { - start_layer = pp_rank * layers_per_stage + remainder; - end_layer = start_layer + layers_per_stage; + std::vector> layer_chunks; + + int layers_per_chunk = total_layers / (pp_size * chunks_per_stage); + int remainder = total_layers % (pp_size * chunks_per_stage); + + for (int chunk_idx = 0; chunk_idx < chunks_per_stage; ++chunk_idx) { + int global_chunk_idx = chunk_idx * pp_size + rank; + + if (global_chunk_idx * layers_per_chunk >= total_layers) { + break; + } + + int chunk_start = global_chunk_idx * layers_per_chunk; + int chunk_end = chunk_start + layers_per_chunk; + + if (global_chunk_idx < remainder) { + // Assign an additional layer to each of the first remainder chunks + chunk_start = global_chunk_idx * (layers_per_chunk + 1); + chunk_end = chunk_start + (layers_per_chunk + 1); + } else { + chunk_start = remainder * (layers_per_chunk + 1) + (global_chunk_idx - remainder) * layers_per_chunk; + chunk_end = chunk_start + layers_per_chunk; + } + + chunk_end = std::min(chunk_end, total_layers); + if (chunk_start < chunk_end) { + layer_chunks.push_back({chunk_start, chunk_end}); + } } - return {is_first_stage, is_last_stage, start_layer, end_layer}; + return {is_first_stage, is_last_stage, layer_chunks}; } PipelineParallel::PipelineParallel(const std::shared_ptr module, int num_stages, int num_micro_batches, diff --git a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc index d570224..2286505 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc @@ -14,6 +14,7 @@ #include "infini_train/include/device.h" #include "infini_train/include/nn/init.h" #include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/pp/pipeline_stage.h" #include "infini_train/include/nn/parallel/pp/send_recv.h" #include "infini_train/include/optimizer.h" @@ -21,30 +22,143 @@ namespace infini_train::nn::parallel { -float PipelineSchedule::Step(std::shared_ptr input, std::shared_ptr target, - const std::shared_ptr &loss_fn, DataType dtype) { - std::vector> micro_batches(num_micro_batches_); - std::vector> target_mbs(num_micro_batches_); - if (stage_->IsFirstStage()) { - micro_batches = input->Split(input->Dims()[0] / num_micro_batches_); +float Schedule1F1B::StepMicroBatches(const std::vector> µbatch_inputs, + const std::vector> µbatch_targets, + const std::shared_ptr &loss_fn, DataType dtype) { + const int n = num_micro_batches_; + if (n == 0) { + return 0.0f; } - if (stage_->IsLastStage()) { - target_mbs = target->Split(target->Dims()[0] / num_micro_batches_); + float 
total_loss = 0.0f; + const int num_stages = stage_->num_stages(); + const int stage_index = stage_->stage_index(); + + const int warmup_steps = num_stages; + const int cooldown_steps = num_stages; + const int total_steps = num_stages + n - 1; + + std::vector>> activations(n); + + int mb_forward_i; // forward micro_batch index + int mb_backward_i; // backward micro_batch index + printf("[stage %d] warmup_steps start\n", stage_index); + // warmup_steps + for (mb_forward_i = 0, mb_backward_i = 0; mb_forward_i < std::min(n, warmup_steps); + ++mb_forward_i, ++mb_backward_i) { + std::vector> inputs; + if (stage_->IsFirstStage()) { + inputs = {microbatch_inputs[mb_forward_i]}; + } else { + inputs = ReceiveFromPrev(stage_->prev_rank()); + } + + activations[mb_forward_i] = stage_->ForwardOneChunk(inputs); + + if (!stage_->IsLastStage()) { + SendToNext(activations[mb_forward_i], stage_->next_rank()); + } else { + auto target = microbatch_targets[mb_backward_i]; + auto output = activations[mb_backward_i][0]; + auto target_on_device = target->To(output->GetDevice()); + auto loss = loss_fn->Forward({output, std::make_shared(target_on_device)})[0]; + loss = loss / n; + + auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); + total_loss += static_cast(loss_cpu.DataPtr())[0]; + + printf("warmup_steps start Backward\n"); + loss->Backward(); + } } - const auto &optimizer = stage_->optimizer(); + if (!stage_->IsLastStage()) { + for (mb_backward_i = 0; mb_backward_i <= stage_index && mb_backward_i < n; ++mb_backward_i) { + auto out_tensor = activations[mb_backward_i][0]; - optimizer->ZeroGrad(); + auto gradient = std::make_shared(out_tensor->Dims(), out_tensor->Dtype(), out_tensor->GetDevice()); - float lossf = StepMicroBatches(micro_batches, target_mbs, loss_fn, dtype); + out_tensor->Backward(gradient); + } + } - optimizer->Step(); + printf("[stage %d] steady_steps start\n", stage_index); + // steady_steps + for (; mb_forward_i < n; ++mb_forward_i, ++mb_backward_i) { + // Forward + // printf("[stage %d] steady_steps mb_forward_i %d\n", stage_index_, mb_forward_i); + std::vector> inputs; + if (stage_->IsFirstStage()) { + inputs = {microbatch_inputs[mb_forward_i]}; + } else { + inputs = ReceiveFromPrev(stage_->prev_rank()); + } - return lossf; + activations[mb_forward_i] = stage_->ForwardOneChunk(inputs); + + printf("[stage %d] steady_steps starting backward mb_forward_i: %d mb_backward_i: %d\n", stage_index, mb_forward_i, + mb_backward_i); + // Backward + if (!stage_->IsLastStage()) { + SendToNext(activations[mb_forward_i], stage_->next_rank()); + + auto out_tensor = activations[mb_backward_i][0]; + + auto gradient = std::make_shared(out_tensor->Dims(), out_tensor->Dtype(), out_tensor->GetDevice()); + + out_tensor->Backward(gradient); + } else { + auto target = microbatch_targets[mb_backward_i]; + auto output = activations[mb_backward_i][0]; + auto target_on_device = target->To(output->GetDevice()); + auto loss = loss_fn->Forward({output, std::make_shared(target_on_device)})[0]; + loss = loss / n; + + auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); + total_loss += static_cast(loss_cpu.DataPtr())[0]; + + loss->Backward(); + } + } + + printf("[stage %d] cooldown_steps start\n", stage_index); + // cooldown_steps + if (!stage_->IsLastStage()) { + for (; mb_backward_i < n; ++mb_backward_i) { + auto out_tensor = activations[mb_backward_i][0]; + + auto gradient = std::make_shared(out_tensor->Dims(), out_tensor->Dtype(), out_tensor->GetDevice()); + + out_tensor->Backward(gradient); + } 
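+        // The cooldown loop above runs the backward passes still outstanding for microbatches
+        // whose forwards were already sent downstream. Note that total_loss is accumulated only
+        // on the last stage (each microbatch loss is scaled by 1/n); intermediate stages return 0.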
+ } + + return total_loss; } -std::vector> PipelineSchedule::ReceiveFromPrev() { +void PrintScheduleTable(int n, int num_stages, int vpp_size) { + int total_global_chunks = num_stages * vpp_size; + int total_steps = n + total_global_chunks - 1; + + // auto schedule = PipelineParallelScheduler::GenerateGPipeSchedule(n, num_stages, vpp_size); + auto schedule = PipelineParallelScheduler::GenerateInterleaved1F1BSchedule(n, num_stages, vpp_size); + + printf("=== 1F1B Interleaved Schedule Table ===\n"); + printf("n=%d, stages=%d, vpp=%d, total_chunks=%d \n\n", n, num_stages, vpp_size, total_global_chunks); + + printf("Step | Type | Microbatch | Global Chunk | Local Chunk | Stage\n"); + printf("-----|-------|------------|--------------|-------------|-------\n"); + + for (const auto &task : schedule) { + int owning_stage = task.global_chunk_id % num_stages; + int local_chunk = task.global_chunk_id / num_stages; + + printf("%4d | %-6s| %-11d| %-13d| %-12d| %d\n", task.step, task.is_forward ? "Forward" : "Backward", + task.microbatch_id, task.global_chunk_id, local_chunk, owning_stage); + } +} + +std::vector> PipelineSchedule::ReceiveFromPrev(int peer_rank) { std::vector> recv_tensors; auto &shapes = stage_->recv_shape(); for (size_t i = 0; i < shapes.size(); ++i) { @@ -54,81 +168,247 @@ std::vector> PipelineSchedule::ReceiveFromPrev() { recv_tensors.push_back(tensor); } - return IRecv(recv_tensors, stage_->device(), stage_->stage_index(), stage_->prev_rank()); + return IRecv(recv_tensors, stage_->device(), stage_->stage_index(), peer_rank); } -std::vector> PipelineSchedule::SendToNext(const std::vector> &tensors) { - return ISend(tensors, stage_->device(), stage_->stage_index(), stage_->next_rank(), stage_->recv_shape()); +std::vector> PipelineSchedule::SendToNext(const std::vector> &tensors, + int peer_rank) { + return ISend(tensors, stage_->device(), stage_->stage_index(), peer_rank, stage_->recv_shape()); } -float ScheduleGPipe::StepMicroBatches(const std::vector> µbatch_inputs, - const std::vector> µbatch_targets, - const std::shared_ptr &loss_fn, DataType dtype) { - const auto n = num_micro_batches_; - if (n == 0) { - return 0.0f; - } +PipelineParallelScheduler::Task PipelineParallelScheduler::CreateTask(int step, int mb, int global_chunk, + int num_stages, int total_chunks, + bool is_forward) { + PipelineParallelScheduler::Task task; + task.step = step; + task.microbatch_id = mb; + task.global_chunk_id = global_chunk; + task.local_chunk_idx = global_chunk / num_stages; + task.is_forward = is_forward; + task.stage_id = global_chunk % num_stages; + task.is_last_chunk = (global_chunk == total_chunks - 1); + task.is_first_chunk = (global_chunk == 0); + return task; +} - std::vector>> outputs(n); +std::vector PipelineParallelScheduler::GenerateGPipeSchedule(int n, int num_stages, + int vpp_size) { + std::vector schedule; + int total_global_chunks = num_stages * vpp_size; + int total_steps = n + total_global_chunks - 1; // ======== Forward Pass ======== - for (int mb = 0; mb < n; ++mb) { - infini_train::AutocastGuard autocast_guard(stage_->device()->Type(), dtype); + for (int step = 0; step < total_steps; ++step) { + for (int mb = 0; mb < n; ++mb) { + int global_chunk_id = step - mb; + if (global_chunk_id >= 0 && global_chunk_id < total_global_chunks) { + auto is_forward = true; + auto task = CreateTask(step, mb, global_chunk_id, num_stages, total_global_chunks, is_forward); + schedule.push_back(task); + } + } + } - std::vector> inputs; - if (stage_->IsFirstStage()) { - inputs = 
-            inputs = {microbatch_inputs[mb]};
-        } else {
-            inputs = ReceiveFromPrev();
+    // ======== Backward Pass ========
+    for (int step = 0; step < total_steps; ++step) {
+        for (int mb = 0; mb < n; ++mb) {
+            int global_chunk_id = (total_steps - 1 - step) - mb;
+            if (global_chunk_id >= 0 && global_chunk_id < total_global_chunks) {
+                auto is_forward = false;
+                auto task
+                    = CreateTask(step + total_steps, mb, global_chunk_id, num_stages, total_global_chunks, is_forward);
+                schedule.push_back(task);
+            }
         }
-        outputs[mb] = stage_->ForwardOneChunk(inputs);
+    }
-        if (!stage_->IsLastStage()) {
-            outputs[mb] = SendToNext(outputs[mb]);
+    // sort by step, then by local_chunk_idx
+    std::sort(schedule.begin(), schedule.end(), [](const Task &a, const Task &b) {
+        if (a.step != b.step) {
+            return a.step < b.step;
+        }
-
+        return a.local_chunk_idx < b.local_chunk_idx;
+    });
+
+    return schedule;
+}
+
+std::vector<PipelineParallelScheduler::Task>
+PipelineParallelScheduler::GenerateInterleaved1F1BSchedule(int n, int num_stages, int vpp_size) {
+    std::vector<Task> schedule;
+
+    if (n <= 0 || num_stages <= 0 || vpp_size <= 0) {
+        return schedule;
+    }
+
+    int total_global_chunks = num_stages * vpp_size;
+
+    int warmup_steps = total_global_chunks - 1;
+    int total_steps = 2 * warmup_steps + n;
+
+    // std::cout << "Interleaved 1F1B Parameters:" << std::endl;
+    // std::cout << "  n = " << n << ", num_stages = " << num_stages << ", vpp_size = " << vpp_size << std::endl;
+    // std::cout << "  total_virtual_stages = " << total_global_chunks << std::endl;
+    // std::cout << "  warmup_steps = " << warmup_steps << std::endl;
+    // std::cout << "  total_steps = " << total_steps << std::endl;
+
+    // ================ Warm-up ================
+    for (int step = 0; step < warmup_steps; ++step) {
+        for (int mb = 0; mb < n; ++mb) {
+            int forward_global_chunk = step - mb;
+            if (forward_global_chunk >= 0 && forward_global_chunk < total_global_chunks) {
+                auto is_forward = true;
+                auto task = CreateTask(step, mb, forward_global_chunk, num_stages, total_global_chunks, is_forward);
+                schedule.push_back(task);
+            }
         }
     }
-    // ======== Backward Pass ========
-    float total_loss = 0.0f;
-    if (!stage_->IsLastStage()) {
+    // ================ Steady ================
+    for (int step = warmup_steps; step < warmup_steps + n; ++step) {
+        int stable_step = step - warmup_steps; // step index within the steady phase
+        for (int mb = 0; mb < n; ++mb) {
-            auto out_tensor = outputs[mb][0];
+            // Forward
+            int forward_global_chunk = step - mb;
+            if (forward_global_chunk >= 0 && forward_global_chunk < total_global_chunks) {
+                auto is_forward = true;
+                auto task = CreateTask(step, mb, forward_global_chunk, num_stages, total_global_chunks, is_forward);
+                schedule.push_back(task);
+            }
-            auto dummy_gradient
-                = std::make_shared<Tensor>(out_tensor->Dims(), out_tensor->Dtype(), out_tensor->GetDevice());
+            // Backward
+            int backward_global_chunk = (total_global_chunks - 1) - (stable_step - mb);
-            out_tensor->Backward(dummy_gradient);
+            if (backward_global_chunk >= 0 && backward_global_chunk < total_global_chunks) {
+                auto is_forward = false;
+                auto task = CreateTask(step, mb, backward_global_chunk, num_stages, total_global_chunks, is_forward);
+                schedule.push_back(task);
+            }
         }
-    } else {
+    }
+
+    // ================ Cool-down ================
+    for (int step = warmup_steps + n; step < total_steps; ++step) {
         for (int mb = 0; mb < n; ++mb) {
-            auto target = microbatch_targets[mb];
-            auto output = outputs[mb][0];
+            int backward_step = step - (warmup_steps);
+            int backward_global_chunk = (total_global_chunks - 1) - (backward_step - mb);
+            if (backward_global_chunk >= 0 && backward_global_chunk < total_global_chunks) {
+                auto is_forward = false;
+                auto task = CreateTask(step, mb, backward_global_chunk, num_stages, total_global_chunks, is_forward);
+                schedule.push_back(task);
+            }
+        }
+    }
+    // std::cout << "Cool-down phase OK" << std::endl;
+
+    return schedule;
+}
+
+float PipelineSchedule::StepMicroBatches(const std::vector<std::shared_ptr<Tensor>> &microbatch_inputs,
+                                         const std::vector<std::shared_ptr<Tensor>> &microbatch_targets,
+                                         const std::shared_ptr<Module> &loss_fn, DataType dtype) {
+    int n = num_micro_batches_;
+    int num_stages = stage_->num_stages();
+    int stage_idx = stage_->stage_index();
+    int vpp_size = global::GetVirtualPipelineParallelSize();
+
+    auto schedule = PipelineParallelScheduler::GenerateInterleaved1F1BSchedule(n, num_stages, vpp_size);
+
+    // if (stage_idx == 0) {
+    //     PrintScheduleTable(n, num_stages, vpp_size);
+    // }
+    float total_loss = 0.0f;
+    // printf("[stage %d] Schedule has %lu tasks\n", stage_idx, schedule.size());
-            if (!target || !output) {
-                LOG(FATAL) << "Output or target is null at mb=" << mb;
+    std::vector<std::vector<std::vector<std::shared_ptr<Tensor>>>> activations(
+        vpp_size, std::vector<std::vector<std::shared_ptr<Tensor>>>(n));
+
+    for (size_t i = 0; i < schedule.size(); ++i) {
+        const auto &task = schedule[i];
+        if (task.stage_id != stage_idx) {
+            continue;
+        }
+
+        int mb = task.microbatch_id;
+        if (task.is_forward) {
+            infini_train::AutocastGuard autocast_guard(stage_->device()->Type(), dtype);
+
+            std::vector<std::shared_ptr<Tensor>> inputs;
+
+            if (task.is_first_chunk) {
+                inputs = {microbatch_inputs[mb]};
+            } else {
+                if (stage_->IsFirstStage()) {
+                    inputs = ReceiveFromPrev(num_stages - 1);
+                } else {
+                    inputs = ReceiveFromPrev(stage_->prev_rank());
+                }
            }
-        std::shared_ptr<Tensor> loss;
-        {
-            infini_train::AutocastGuard autocast_guard(stage_->device()->Type(), dtype);
+            activations[task.local_chunk_idx][mb] = stage_->ForwardOneChunk(inputs, task.local_chunk_idx);
-            auto target_on_device = target->To(output->GetDevice());
-            loss = loss_fn->Forward({output, std::make_shared<Tensor>(target_on_device)})[0];
-            if (!loss) {
-                LOG(FATAL) << "[ERROR] loss is nullptr at mb = " << mb;
+            if (!task.is_last_chunk) {
+                if (stage_->IsLastStage()) {
+                    SendToNext(activations[task.local_chunk_idx][mb], 0);
+                } else {
+                    SendToNext(activations[task.local_chunk_idx][mb], stage_->next_rank());
                }
-            loss = loss / n;
            }
+        } else {
+            if (task.is_last_chunk) {
+                auto target = microbatch_targets[mb];
+                std::shared_ptr<Tensor> loss;
+                {
+                    infini_train::AutocastGuard autocast_guard(stage_->device()->Type(), dtype);
-            auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
-            total_loss += static_cast<float *>(loss_cpu.DataPtr())[0];
+                    auto target_on_device = target->To(activations[task.local_chunk_idx][mb][0]->GetDevice());
+                    loss = loss_fn->Forward(
+                        {activations[task.local_chunk_idx][mb][0], std::make_shared<Tensor>(target_on_device)})[0];
+                    loss = loss / n;
+                }
+                total_loss
+                    += static_cast<float *>(loss->To(DeviceManager::Instance()->GetDefaultDevice()).DataPtr())[0];
-            loss->Backward();
+                // printf("[stage %d][chunk %d] backward for the last chunk, microbatch %d\n",
+                //        stage_idx, task.local_chunk_idx, mb);
+                loss->Backward();
+            } else {
+                auto out_tensor = activations[task.local_chunk_idx][mb][0];
+
+                auto dummy_gradient
+                    = std::make_shared<Tensor>(out_tensor->Dims(), out_tensor->Dtype(), out_tensor->GetDevice());
+
+                out_tensor->Backward(dummy_gradient);
+            }
        }
    }

    return total_loss;
 }

+float PipelineSchedule::Step(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> target,
+                             const std::shared_ptr<Module> &loss_fn, DataType dtype) {
+    std::vector<std::shared_ptr<Tensor>> micro_batches(num_micro_batches_);
+    std::vector<std::shared_ptr<Tensor>> target_mbs(num_micro_batches_);
+    if (stage_->IsFirstStage()) {
+        micro_batches = input->Split(input->Dims()[0] / num_micro_batches_);
+    }
+
+    if (stage_->IsLastStage()) {
+        target_mbs = target->Split(target->Dims()[0] / num_micro_batches_);
+    }
+
+    const auto &optimizer = stage_->optimizer();
+
+    optimizer->ZeroGrad();
+
+    float lossf = StepMicroBatches(micro_batches, target_mbs, loss_fn, dtype);
+
+    optimizer->Step();
+
+    return lossf;
+}
+
 } // namespace infini_train::nn::parallel
diff --git a/infini_train/src/nn/parallel/pp/pipeline_stage.cc b/infini_train/src/nn/parallel/pp/pipeline_stage.cc
index bdb8cba..7fb1cfa 100644
--- a/infini_train/src/nn/parallel/pp/pipeline_stage.cc
+++ b/infini_train/src/nn/parallel/pp/pipeline_stage.cc
@@ -19,9 +19,9 @@ PipelineStage::PipelineStage(const std::shared_ptr<Module> &model, int stage_ind
       optimizer_(std::move(optimizer)),
       device_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA).at(device_id)) {}
 
-std::vector<std::shared_ptr<Tensor>>
-PipelineStage::ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs) {
-    return model_->Forward(inputs);
+std::vector<std::shared_ptr<Tensor>> PipelineStage::ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs,
+                                                                    int chunk_idx) {
+    return model_->ForwardChunk(chunk_idx, inputs);
 }
 
 } // namespace infini_train::nn::parallel
diff --git a/infini_train/src/nn/parallel/pp/send_recv.cc b/infini_train/src/nn/parallel/pp/send_recv.cc
index 6a24a0e..65a6758 100644
--- a/infini_train/src/nn/parallel/pp/send_recv.cc
+++ b/infini_train/src/nn/parallel/pp/send_recv.cc
@@ -69,6 +69,7 @@ std::vector<std::shared_ptr<Tensor>> ISend::Forward(const std::vector<std::shared_ptr<Tensor>>
 
 std::vector<std::shared_ptr<Tensor>> ISend::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
+    // printf("[stage %d] ISend::Backward!!!!\n", cur_rank_);
     std::vector<std::shared_ptr<Tensor>> recv_tensors;
     for (int shape_i = 0; shape_i < shapes_.size(); ++shape_i) {
         // FIXME(jym): The data type between stages is not float32, which will cause a crash
@@ -100,6 +101,7 @@ void IRecv::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tenso
 }
 
 std::vector<std::shared_ptr<Tensor>> IRecv::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
+    // printf("[stage %d] IRecv::Backward!!!! peer_rank_: %d\n", cur_rank_, peer_rank_);
     auto pp_group = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(cur_device_->rank().GlobalRank()));
     return pp_group->NcclSend(grad_outputs, peer_rank_);
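
For reference only (not part of the patch): a minimal standalone sketch of the round-robin chunk-to-stage mapping that CreateTask and PrintScheduleTable above rely on (stage = global_chunk_id % num_stages, local chunk = global_chunk_id / num_stages). The values chosen for num_stages and vpp_size are arbitrary example inputs, not taken from the diff.

    // Illustration: how global chunk ids map to pipeline stages and local chunks,
    // mirroring task.stage_id and task.local_chunk_idx as set in CreateTask().
    #include <cstdio>

    int main() {
        const int num_stages = 2; // example values only
        const int vpp_size = 2;
        const int total_global_chunks = num_stages * vpp_size;
        for (int global_chunk_id = 0; global_chunk_id < total_global_chunks; ++global_chunk_id) {
            const int owning_stage = global_chunk_id % num_stages; // which stage runs this chunk
            const int local_chunk = global_chunk_id / num_stages;  // index of the chunk within that stage
            std::printf("global chunk %d -> stage %d, local chunk %d\n", global_chunk_id, owning_stage, local_chunk);
        }
        return 0;
    }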