Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ DEFINE_int32(
DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");

// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
Expand Down Expand Up @@ -347,7 +348,7 @@ int main(int argc, char *argv[]) {
google::InitGoogleLogging(argv[0]);

nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
FLAGS_pipeline_parallel);
FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel);

LOG(INFO) << nn::parallel::global::ProcessGroupOverview();

Expand Down
240 changes: 180 additions & 60 deletions example/gpt2/net.cc

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions example/gpt2/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/tensor.h"

namespace infini_train::nn {
class ModuleList;
}

struct GPT2Config {
int64_t block_size = 1024;
int64_t vocab_size = 50304;
Expand Down Expand Up @@ -71,6 +75,20 @@ class Block : public infini_train::nn::CloneableModule<Block> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
};

class GPT2Chunk {
public:
bool has_wte() const { return wte_ != nullptr; }
bool has_wpe() const { return wpe_ != nullptr; }
bool has_norm() const { return norm_ != nullptr; }
bool has_head() const { return head_ != nullptr; }

std::shared_ptr<infini_train::nn::Module> wte_ = nullptr;
std::shared_ptr<infini_train::nn::Module> wpe_ = nullptr;
std::shared_ptr<infini_train::nn::ModuleList> blocks_ = nullptr;
std::shared_ptr<infini_train::nn::Module> norm_ = nullptr;
std::shared_ptr<infini_train::nn::Module> head_ = nullptr;
};

class GPT2 : public infini_train::nn::CloneableModule<GPT2> {
public:
static constexpr char kWTELayerName[] = "wte";
Expand All @@ -92,9 +110,14 @@ class GPT2 : public infini_train::nn::CloneableModule<GPT2> {
std::vector<std::shared_ptr<infini_train::Tensor>>
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

void BuildChunks();
std::vector<std::shared_ptr<infini_train::Tensor>>
ForwardChunk(int chunk_idx, const std::vector<std::shared_ptr<infini_train::Tensor>> &input) override;

static std::shared_ptr<GPT2> FromPretrained(ModelType model_type);
static std::shared_ptr<GPT2> FromLLMC(const std::string &filepath);

private:
GPT2Config config_;
std::vector<GPT2Chunk> chunks_;
};
5 changes: 3 additions & 2 deletions example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ DEFINE_int32(
"When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices.");
DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, , specified the number of PP stages.");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");

Expand Down Expand Up @@ -325,7 +326,7 @@ int main(int argc, char *argv[]) {
google::InitGoogleLogging(argv[0]);

nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
FLAGS_pipeline_parallel);
FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel);

LOG(INFO) << nn::parallel::global::ProcessGroupOverview();

Expand Down
199 changes: 153 additions & 46 deletions example/llama3/net.cc

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions example/llama3/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/tensor.h"

namespace infini_train::nn {
class ModuleList;
}

struct LLaMA3Config {
// ref: https://huggingface.co/meta-llama/Llama-3.2-1B
// Model basic config
Expand Down Expand Up @@ -108,6 +112,18 @@ class Block : public infini_train::nn::CloneableModule<Block> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
};

class LLaMA3Chunk {
public:
bool has_embedding() const { return embedding_ != nullptr; }
bool has_norm() const { return norm_ != nullptr; }
bool has_head() const { return head_ != nullptr; }

std::shared_ptr<infini_train::nn::Module> embedding_ = nullptr;
std::shared_ptr<infini_train::nn::ModuleList> blocks_ = nullptr;
std::shared_ptr<infini_train::nn::Module> norm_ = nullptr;
std::shared_ptr<infini_train::nn::Module> head_ = nullptr;
};

class LLaMA3 : public infini_train::nn::CloneableModule<LLaMA3> {
public:
static constexpr char kWTELayerName[] = "wte";
Expand All @@ -132,9 +148,14 @@ class LLaMA3 : public infini_train::nn::CloneableModule<LLaMA3> {
std::vector<std::shared_ptr<infini_train::Tensor>>
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

void BuildChunks();
std::vector<std::shared_ptr<infini_train::Tensor>>
ForwardChunk(int chunk_idx, const std::vector<std::shared_ptr<infini_train::Tensor>> &input) override;

static std::shared_ptr<LLaMA3> FromPretrained(ModelType model_type);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath);

private:
LLaMA3Config config_;
std::vector<LLaMA3Chunk> chunks_;
};
3 changes: 3 additions & 0 deletions infini_train/include/nn/modules/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class Module : public std::enable_shared_from_this<Module> {
return 0.0f;
};

virtual std::vector<std::shared_ptr<Tensor>>
ForwardChunk(int chunk_idx, const std::vector<std::shared_ptr<Tensor>> &input_tensors);

virtual void To(const Device *device);

virtual void To(DataType dtype);
Expand Down
10 changes: 7 additions & 3 deletions infini_train/include/nn/parallel/global.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class GlobalEnv {
static GlobalEnv &Instance();

void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled,
int pipeline_parallel_size);
int pipeline_parallel_size, int virtual_pipeline_parallel_size);

int world_size() const;

Expand All @@ -47,6 +47,8 @@ class GlobalEnv {

int pipeline_parallel_size() const;

int virtual_pipeline_parallel_size() const;

Layout layout() const;

private:
Expand All @@ -69,6 +71,7 @@ class GlobalEnv {
int data_parallel_size_ = 1;

int pipeline_parallel_size_ = 1;
int virtual_pipeline_parallel_size_ = 1;

mutable std::mutex mutex_;
bool initialized_ = false;
Expand All @@ -77,9 +80,9 @@ class GlobalEnv {
};

inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled,
int pipeline_parallel_size) {
int pipeline_parallel_size, int virtual_pipeline_parallel) {
GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled,
pipeline_parallel_size);
pipeline_parallel_size, virtual_pipeline_parallel);
}

inline int GetWorldSize() { return GlobalEnv::Instance().world_size(); }
Expand All @@ -92,6 +95,7 @@ inline int GetTensorParallelSize() { return GlobalEnv::Instance().tensor_paralle
inline bool GetSequenceParallelEnabled() { return GlobalEnv::Instance().sequence_parallel_enabled(); }
inline int GetDataParallelSize() { return GlobalEnv::Instance().data_parallel_size(); }
inline int GetPipelineParallelSize() { return GlobalEnv::Instance().pipeline_parallel_size(); }
inline int GetVirtualPipelineParallelSize() { return GlobalEnv::Instance().virtual_pipeline_parallel_size(); }

// Layout Helper Functions
inline int GetRankOf(int dp, int tp, int pp) { return GlobalEnv::Instance().layout().RankOf(dp, tp, pp); }
Expand Down
3 changes: 2 additions & 1 deletion infini_train/include/nn/parallel/pp/pipeline_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class PipelineParallel : public Module {
float TrainStep(const std::vector<std::shared_ptr<Tensor>> &input,
const std::vector<std::shared_ptr<Tensor>> &target, const std::shared_ptr<nn::Module> &loss_fn);

static std::tuple<bool, bool, int, int> GetStageInfo(int total_layers, int pp_size);
static std::tuple<bool, bool, std::vector<std::pair<int, int>>> GetStageInfo(int total_layers, int pp_size,
int chunks_per_stage = 1);

private:
int num_stages_ = -1;
Expand Down
29 changes: 19 additions & 10 deletions infini_train/include/nn/parallel/pp/pipeline_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,35 @@ class PipelineSchedule {

virtual float StepMicroBatches(const std::vector<std::shared_ptr<Tensor>> &arg_mbs,
const std::vector<std::shared_ptr<Tensor>> &target_mbs,
const std::shared_ptr<nn::Module> &loss_fn)
= 0;
const std::shared_ptr<nn::Module> &loss_fn);

std::vector<std::shared_ptr<Tensor>> ReceiveFromPrev();
std::vector<std::shared_ptr<Tensor>> SendToNext(const std::vector<std::shared_ptr<Tensor>> &tensors);
std::vector<std::shared_ptr<Tensor>> ReceiveFromPrev(int peer_rank);
std::vector<std::shared_ptr<Tensor>> SendToNext(const std::vector<std::shared_ptr<Tensor>> &tensors, int peer_rank);

protected:
int num_micro_batches_ = -1;
int stage_index_ = -1;
std::shared_ptr<PipelineStage> stage_ = nullptr;
};

class ScheduleGPipe : public PipelineSchedule {
class PipelineParallelScheduler {
public:
ScheduleGPipe(std::shared_ptr<PipelineStage> stage, int num_stages, int num_micro_batches, int stage_index)
: PipelineSchedule(std::move(stage), num_stages, num_micro_batches, stage_index){};
struct Task {
int step;
int microbatch_id;
int global_chunk_id;
int local_chunk_idx;
bool is_forward;
int stage_id;
bool is_first_chunk;
bool is_last_chunk;
};

float StepMicroBatches(const std::vector<std::shared_ptr<Tensor>> &arg_mbs,
const std::vector<std::shared_ptr<Tensor>> &target_mbs,
const std::shared_ptr<nn::Module> &loss_fn) override;
static Task CreateTask(int step, int mb, int global_chunk, int num_stages, int total_chunks, bool is_forward);

static std::vector<Task> GenerateGPipeSchedule(int n, int num_stages, int vpp_size);

static std::vector<Task> GenerateInterleaved1F1BSchedule(int n, int num_stages, int vpp_size);
};

class Schedule1F1B : public PipelineSchedule {
Expand Down
3 changes: 2 additions & 1 deletion infini_train/include/nn/parallel/pp/pipeline_stage.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class PipelineStage {
PipelineStage(const std::shared_ptr<nn::Module> &model, int stage_index, int num_stages,
const std::vector<std::vector<int64_t>> &recv_shape, std::shared_ptr<Optimizer> optimizer);

std::vector<std::shared_ptr<Tensor>> ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs);
std::vector<std::shared_ptr<Tensor>> ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs,
int chunk_idx = 0);

bool IsFirstStage() const { return stage_index_ == 0; }
bool IsLastStage() const { return stage_index_ == num_stages_ - 1; }
Expand Down
4 changes: 4 additions & 0 deletions infini_train/src/autograd/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ void Function::BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int g
const auto *device = grad_output->GetDevice();
device->SetDevice();

// 添加日志:当前执行的反向函数
// std::cout << "[Backward] Function: " << typeid(*this).name() << ", grad_output_idx: " << grad_output_idx
// << ", dependencies_reached: " << dependencies_reached_ << "/" << dependencies_number_ << std::endl;

// NOTE(dcj): The accumulate autograd function has no grad_outputs.
// Temporarily resize the vector to hold one nullptr as a buffer.
if (grad_outputs_.empty()) {
Expand Down
6 changes: 6 additions & 0 deletions infini_train/src/nn/modules/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ std::vector<std::shared_ptr<Tensor>> Module::Forward(const std::vector<std::shar
return {};
}

std::vector<std::shared_ptr<Tensor>> Module::ForwardChunk(int chunk_idx,
const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
LOG(FATAL) << "ForwardChunk function not implemented for this module";
return {};
}

void Module::To(const Device *device) {
CHECK_NOTNULL(device);
if (device == device_) {
Expand Down
8 changes: 7 additions & 1 deletion infini_train/src/nn/parallel/global.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ GlobalEnv &GlobalEnv::Instance() {
}

void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled,
int pipeline_parallel_size) {
int pipeline_parallel_size, int virtual_pipeline_parallel_size) {
std::lock_guard<std::mutex> lock(mutex_);

CHECK(!initialized_) << "Repeated initialization of GlobalEnv!";
Expand All @@ -100,6 +100,7 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq
tensor_parallel_size_ = tensor_parallel_size;
sequence_parallel_enabled_ = sequence_parallel_enabled;
pipeline_parallel_size_ = pipeline_parallel_size;
virtual_pipeline_parallel_size_ = virtual_pipeline_parallel_size;
data_parallel_size_ = world_size_ / tensor_parallel_size_ / pipeline_parallel_size_;

layout_.sizes[DP] = data_parallel_size_;
Expand Down Expand Up @@ -156,6 +157,11 @@ int GlobalEnv::pipeline_parallel_size() const {
return pipeline_parallel_size_;
}

int GlobalEnv::virtual_pipeline_parallel_size() const {
CHECK(initialized_) << "GlobalEnv is not initialized!";
return virtual_pipeline_parallel_size_;
}

Layout GlobalEnv::layout() const {
CHECK(initialized_) << "GlobalEnv is not initialized!";
return layout_;
Expand Down
45 changes: 33 additions & 12 deletions infini_train/src/nn/parallel/pp/pipeline_parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ void PipelineParallel::BuildPipelineStage(const std::shared_ptr<Module> &module,
}

void PipelineParallel::SetupSchedule(int num_micro_batches) {
schedule_ = std::make_shared<ScheduleGPipe>(pipeline_stage_, num_stages_, num_micro_batches, rank_);
// schedule_ = std::make_shared<Schedule1F1B>(pipeline_stage_, num_stages_, num_micro_batches, rank_);
schedule_ = std::make_shared<PipelineSchedule>(pipeline_stage_, num_stages_, num_micro_batches, rank_);
}

float PipelineParallel::TrainStep(const std::vector<std::shared_ptr<Tensor>> &input,
Expand All @@ -39,23 +40,43 @@ float PipelineParallel::TrainStep(const std::vector<std::shared_ptr<Tensor>> &in
return schedule_->Step(stage_input, stage_target, loss_fn);
}

std::tuple<bool, bool, int, int> PipelineParallel::GetStageInfo(int total_layers, int pp_size) {
std::tuple<bool, bool, std::vector<std::pair<int, int>>> PipelineParallel::GetStageInfo(int total_layers, int pp_size,
int chunks_per_stage) {
int rank = pp_rank;
bool is_first_stage = (pp_rank == 0);
bool is_last_stage = (pp_rank == pp_size - 1);

int layers_per_stage = total_layers / pp_size;
int remainder = total_layers % pp_size;
int start_layer, end_layer;
if (pp_rank < remainder) {
start_layer = pp_rank * (layers_per_stage + 1);
end_layer = start_layer + layers_per_stage + 1;
} else {
start_layer = pp_rank * layers_per_stage + remainder;
end_layer = start_layer + layers_per_stage;
std::vector<std::pair<int, int>> layer_chunks;

int layers_per_chunk = total_layers / (pp_size * chunks_per_stage);
int remainder = total_layers % (pp_size * chunks_per_stage);

for (int chunk_idx = 0; chunk_idx < chunks_per_stage; ++chunk_idx) {
int global_chunk_idx = chunk_idx * pp_size + rank;

if (global_chunk_idx * layers_per_chunk >= total_layers) {
break;
}

int chunk_start = global_chunk_idx * layers_per_chunk;
int chunk_end = chunk_start + layers_per_chunk;

if (global_chunk_idx < remainder) {
// Assign an additional layer to each of the first remainder chunks
chunk_start = global_chunk_idx * (layers_per_chunk + 1);
chunk_end = chunk_start + (layers_per_chunk + 1);
} else {
chunk_start = remainder * (layers_per_chunk + 1) + (global_chunk_idx - remainder) * layers_per_chunk;
chunk_end = chunk_start + layers_per_chunk;
}

chunk_end = std::min(chunk_end, total_layers);
if (chunk_start < chunk_end) {
layer_chunks.push_back({chunk_start, chunk_end});
}
}

return {is_first_stage, is_last_stage, start_layer, end_layer};
return {is_first_stage, is_last_stage, layer_chunks};
}

PipelineParallel::PipelineParallel(const std::shared_ptr<Module> module, int num_stages, int num_micro_batches,
Expand Down
Loading