diff --git a/CMakeLists.txt b/CMakeLists.txt
index e9c6c5c..833e298 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -110,8 +110,8 @@ if (EXISTS ${FT_DIR})
 else()
   FetchContent_Declare(
     repo-ft
-    GIT_REPOSITORY https://github.com/NVIDIA/FasterTransformer.git
-    GIT_TAG main
+    GIT_REPOSITORY https://github.com/sfc-gh-zhwang/FasterTransformer
+    GIT_TAG e770ddf2bc66217034b6e9e3b0c3256ebf1c1b40
     GIT_SHALLOW ON
   )
 endif()
diff --git a/src/libfastertransformer.cc b/src/libfastertransformer.cc
index a870aa0..520e1da 100644
--- a/src/libfastertransformer.cc
+++ b/src/libfastertransformer.cc
@@ -49,10 +49,13 @@
 // FT's libraries have dependency with triton's lib
 #include "src/fastertransformer/triton_backend/bert/BertTritonModel.h"
+#include "src/fastertransformer/triton_backend/deberta/DebertaTritonModel.h"
 #include "src/fastertransformer/triton_backend/gptj/GptJTritonModel.h"
 #include "src/fastertransformer/triton_backend/gptj/GptJTritonModelInstance.h"
 #include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModel.h"
 #include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModelInstance.h"
+#include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h"
+#include "src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h"
 #include "src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h"
 #include "src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModelInstance.h"
 #include "src/fastertransformer/triton_backend/t5/T5TritonModel.h"
@@ -62,6 +65,7 @@
 #include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp"
 #include "src/fastertransformer/utils/Tensor.h"
 #include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
+#include "src/fastertransformer/utils/instance_comm.h"
 #include "src/fastertransformer/utils/mpi_utils.h"
 #include "src/fastertransformer/utils/nccl_utils.h"
@@ -329,6 +333,31 @@ std::shared_ptr<AbstractTransformerModel> ModelState::ModelFactory(
           tp, pp, custom_ar, model_dir, int8_mode, is_sparse, remove_padding);
 #endif
     }
+  } else if (model_type == "deberta") {
+    const int is_sparse = param_get_bool(param, "is_sparse", false);
+    const int remove_padding = param_get_bool(param, "is_remove_padding", false);
+
+    if (data_type == "fp16") {
+      ft_model = std::make_shared<DebertaTritonModel<half>>(
+          tp, pp, custom_ar, model_dir, is_sparse, remove_padding);
+    } else if (data_type == "fp32") {
+      ft_model = std::make_shared<DebertaTritonModel<float>>(
+          tp, pp, custom_ar, model_dir, is_sparse, remove_padding);
+#ifdef ENABLE_BF16
+    } else if (data_type == "bf16") {
+      ft_model = std::make_shared<DebertaTritonModel<__nv_bfloat16>>(
+          tp, pp, custom_ar, model_dir, is_sparse, remove_padding);
+#endif
+    }
+  } else if (model_type == "Llama") {
+    const int is_sparse = param_get_bool(param, "is_sparse", false);
+    const int remove_padding = param_get_bool(param, "is_remove_padding", false);
+
+    if (data_type == "fp16") {
+      ft_model = std::make_shared<LlamaTritonModel<half>>(tp, pp, custom_ar, model_dir);
+    } else if (data_type == "fp32") {
+      ft_model = std::make_shared<LlamaTritonModel<float>>(tp, pp, custom_ar, model_dir);
+    }
   } else {
     THROW_IF_BACKEND_MODEL_ERROR(TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_UNSUPPORTED,
         ("Unknown model \"" + model_type + "\"").c_str()));
@@ -696,6 +725,8 @@ class ModelInstanceState : public BackendModelInstance {
   std::vector<std::unique_ptr<AbstractTransformerModelInstance>> ft_model_instance_;
 
+  std::unique_ptr<ft::AbstractInstanceComm> instance_comm_;
+
   // inter-node broadcast buffer
   std::vector bcast_buffers;
@@ -821,6 +852,8 @@ ModelInstanceState::ModelInstanceState(
     t.join();
   }
 
+  instance_comm_ = shared_ft_model->createInstanceComm(tp_pp_size_);
+
   LOG_MESSAGE(
       TRITONSERVER_LOG_INFO,
       (std::string("Model instance is created on GPU ") + model_instance_gpu_ids).c_str());
@@ -1246,6 +1279,7 @@
 ThreadForward(
     std::unique_ptr<AbstractTransformerModelInstance>* ft_model_instance,
     std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>* input_tensors,
     std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>* output_tensors,
+    ft::AbstractInstanceComm* instance_comm,
     std::exception_ptr* exception_ptr,
     const int device_id, const int use_stream_cb, stream_callback_ctx_t* context)
@@ -1259,7 +1293,7 @@ ThreadForward(
   if (use_stream_cb) {
     (*ft_model_instance)->registerCallback(streaming_callback, (void*)context);
   }
-  *output_tensors = (*ft_model_instance)->forward(*input_tensors);
+  *output_tensors = (*ft_model_instance)->forward(*input_tensors, instance_comm);
   if (use_stream_cb) {
     (*ft_model_instance)->unRegisterCallback();
   }
@@ -1418,7 +1452,7 @@ ModelInstanceState::Execute(
             .c_str());
     threads.push_back(std::thread(
         ThreadForward, &ft_model_instance_[instance_local_id], &input_tensors,
-        &output_tensors_list[instance_local_id], &exception_ptr[instance_local_id], gid,
+        &output_tensors_list[instance_local_id], instance_comm_.get(), &exception_ptr[instance_local_id], gid,
         is_decoupled_ && gid == model_instance_device_id_start_, context));
     LOG_MESSAGE(
         TRITONSERVER_LOG_VERBOSE,
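For reference, a minimal sketch (not part of the patch) of the call path the change introduces: the shared FT model creates one instance comm via createInstanceComm(), and each per-GPU ThreadForward passes the raw pointer as the new second argument of forward(). The helper function name and the tensor-map alias below are illustrative assumptions; the types come from the FT headers included by libfastertransformer.cc above.

// Sketch only: illustrates the new instance_comm plumbing under the
// assumptions stated in the note above. In the backend itself the comm
// object is created once in the ModelInstanceState constructor and reused
// by every worker thread.
#include <memory>
#include <string>
#include <unordered_map>

#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp"
#include "src/fastertransformer/utils/instance_comm.h"

namespace ft = fastertransformer;

// Assumed alias for the tensor map passed to/returned from forward().
using TensorMap = std::unordered_map<std::string, triton::Tensor>;

// Hypothetical helper, not part of the patch.
std::shared_ptr<TensorMap> forward_with_instance_comm(
    const std::shared_ptr<AbstractTransformerModel>& shared_ft_model,
    const std::unique_ptr<AbstractTransformerModelInstance>& ft_model_instance,
    std::shared_ptr<TensorMap> input_tensors, int tp_pp_size)
{
  // createInstanceComm() builds the comm object (done once per model
  // instance state in the backend); worker threads only see a raw pointer.
  std::unique_ptr<ft::AbstractInstanceComm> instance_comm =
      shared_ft_model->createInstanceComm(tp_pp_size);

  // forward() now takes the instance comm as its second argument.
  return ft_model_instance->forward(input_tensors, instance_comm.get());
}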