diff --git a/CMakeLists.txt b/CMakeLists.txt
index e9c6c5c..7fe767f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -110,8 +110,8 @@ if (EXISTS ${FT_DIR})
 else()
   FetchContent_Declare(
     repo-ft
-    GIT_REPOSITORY https://github.com/NVIDIA/FasterTransformer.git
-    GIT_TAG main
+    GIT_REPOSITORY https://github.com/neevaco/FasterTransformer.git
+    GIT_TAG 7bb372317da21dc7a898cb0e6e0ce7c11b0b38ec
     GIT_SHALLOW ON
   )
 endif()
diff --git a/src/libfastertransformer.cc b/src/libfastertransformer.cc
index a870aa0..f478c22 100644
--- a/src/libfastertransformer.cc
+++ b/src/libfastertransformer.cc
@@ -49,10 +49,12 @@
 // FT's libraries have dependency with triton's lib
 #include "src/fastertransformer/triton_backend/bert/BertTritonModel.h"
+#include "src/fastertransformer/triton_backend/deberta/DebertaTritonModel.h"
 #include "src/fastertransformer/triton_backend/gptj/GptJTritonModel.h"
 #include "src/fastertransformer/triton_backend/gptj/GptJTritonModelInstance.h"
 #include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModel.h"
 #include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModelInstance.h"
+#include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h"
 #include "src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h"
 #include "src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModelInstance.h"
 #include "src/fastertransformer/triton_backend/t5/T5TritonModel.h"
 #include "src/fastertransformer/triton_backend/t5/T5TritonModelInstance.h"
@@ -327,6 +329,37 @@ std::shared_ptr<AbstractTransformerModel> ModelState::ModelFactory(
         } else if (data_type == "bf16") {
             ft_model = std::make_shared<T5TritonModel<__nv_bfloat16>>(
                 tp, pp, custom_ar, model_dir, int8_mode, is_sparse, remove_padding);
+#endif
+        }
+    } else if (model_type == "llama") {
+        const int int8_mode = param_get_int(param, "int8_mode");
+
+        if (data_type == "fp16") {
+            ft_model = std::make_shared<LlamaTritonModel<half>>(
+                tp, pp, custom_ar, model_dir, int8_mode);
+        } else if (data_type == "fp32") {
+            ft_model = std::make_shared<LlamaTritonModel<float>>(
+                tp, pp, custom_ar, model_dir, int8_mode);
+#ifdef ENABLE_BF16
+        } else if (data_type == "bf16") {
+            ft_model = std::make_shared<LlamaTritonModel<__nv_bfloat16>>(
+                tp, pp, custom_ar, model_dir, int8_mode);
+#endif
+        }
+    } else if (model_type == "deberta") {
+        const int is_sparse = param_get_bool(param,"is_sparse", false);
+        const int remove_padding = param_get_bool(param,"is_remove_padding", false);
+
+        if (data_type == "fp16") {
+            ft_model = std::make_shared<DebertaTritonModel<half>>(
+                tp, pp, custom_ar, model_dir, is_sparse, remove_padding);
+        } else if (data_type == "fp32") {
+            ft_model = std::make_shared<DebertaTritonModel<float>>(
+                tp, pp, custom_ar, model_dir, is_sparse, remove_padding);
+#ifdef ENABLE_BF16
+        } else if (data_type == "bf16") {
+            ft_model = std::make_shared<DebertaTritonModel<__nv_bfloat16>>(
+                tp, pp, custom_ar, model_dir, is_sparse, remove_padding);
 #endif
         }
     } else {