Commits (46, all by sfc-gh-zhwang; every commit message is simply "commit"):

Aug 10, 2023: 4debf6d, 05303f9
Aug 11, 2023: 8925434, 110490a, c8f2976, 255ca79, c804c16, fc21557
Aug 12, 2023: c74dfca, 2a03595
Aug 13, 2023: 91d18a4
Aug 14, 2023: 4c6cd41, d42ee28, f8005eb, 1bc93f0, 0208d27, c9c9cee, 1785956, d2bfc16, d6ae2e2, 6e29c25, f84bb7f
Aug 15, 2023: ab960ad, 714c8a6, 30e610a, 66b1b51, e568f16, 02702e4, f3f9215, 7c5f204, e0a69a5, 4882622, 11ea051, 0a01134, b63fb3a
Aug 16, 2023: 46dabdf, c3d1a9c, bc97cad, 3b21384, 2ddb1ed, c121271
Aug 17, 2023: faf12ff, a57e70a, 1237419, e78f75b
Aug 19, 2023: 379221c
CMakeLists.txt (4 changes: 2 additions & 2 deletions)

@@ -110,8 +110,8 @@ if (EXISTS ${FT_DIR})
 else()
     FetchContent_Declare(
         repo-ft
-        GIT_REPOSITORY https://github.com/neevaco/FasterTransformer.git
-        GIT_TAG main
+        GIT_REPOSITORY https://github.com/sfc-gh-zhwang/FasterTransformer
+        GIT_TAG e770ddf2bc66217034b6e9e3b0c3256ebf1c1b40
         GIT_SHALLOW ON
     )
 endif()
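This repoints the FasterTransformer dependency from the neevaco fork's moving main branch to an exact commit on sfc-gh-zhwang's fork, so the backend now builds against a fixed FT revision, presumably one that already carries the LlamaTritonModel classes the backend starts referencing below.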
src/libfastertransformer.cc (21 changes: 19 additions & 2 deletions)
@@ -54,6 +54,8 @@
 #include "src/fastertransformer/triton_backend/gptj/GptJTritonModelInstance.h"
 #include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModel.h"
 #include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModelInstance.h"
+#include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h"
+#include "src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h"
 #include "src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h"
 #include "src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModelInstance.h"
 #include "src/fastertransformer/triton_backend/t5/T5TritonModel.h"
@@ -63,6 +65,7 @@
 #include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp"
 #include "src/fastertransformer/utils/Tensor.h"
 #include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
+#include "src/fastertransformer/utils/instance_comm.h"
 #include "src/fastertransformer/utils/mpi_utils.h"
 #include "src/fastertransformer/utils/nccl_utils.h"
@@ -346,6 +349,15 @@ std::shared_ptr<AbstractTransformerModel> ModelState::ModelFactory(
                 tp, pp, custom_ar, model_dir, is_sparse, remove_padding);
 #endif
         }
+    } else if (model_type == "Llama") {
+        const int is_sparse = param_get_bool(param, "is_sparse", false);
+        const int remove_padding = param_get_bool(param, "is_remove_padding", false);
+
+        if (data_type == "fp16") {
+            ft_model = std::make_shared<LlamaTritonModel<half>>(tp, pp, custom_ar, model_dir);
+        } else if (data_type == "fp32") {
+            ft_model = std::make_shared<LlamaTritonModel<float>>(tp, pp, custom_ar, model_dir);
+        }
     } else {
         THROW_IF_BACKEND_MODEL_ERROR(TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_UNSUPPORTED,
             ("Unknown model \"" + model_type + "\"").c_str()));
@@ -713,6 +725,8 @@ class ModelInstanceState : public BackendModelInstance {
     std::vector<std::unique_ptr<AbstractTransformerModelInstance>>
         ft_model_instance_;
 
+    std::unique_ptr<ft::AbstractInstanceComm> instance_comm_;
+
     // inter-node broadcast buffer
     std::vector<char*> bcast_buffers;
@@ -838,6 +852,8 @@ ModelInstanceState::ModelInstanceState(
         t.join();
     }
 
+    instance_comm_ = shared_ft_model->createInstanceComm(tp_pp_size_);
+
     LOG_MESSAGE(
         TRITONSERVER_LOG_INFO,
         (std::string("Model instance is created on GPU ") + model_instance_gpu_ids).c_str());
@@ -1263,6 +1279,7 @@ ThreadForward(
     std::unique_ptr<AbstractTransformerModelInstance>* ft_model_instance,
     std::shared_ptr<std::unordered_map<std::string, Tensor>>* input_tensors,
     std::shared_ptr<std::unordered_map<std::string, Tensor>>* output_tensors,
+    ft::AbstractInstanceComm* instance_comm,
     std::exception_ptr* exception_ptr,
     const int device_id, const int use_stream_cb,
     stream_callback_ctx_t* context)
@@ -1276,7 +1293,7 @@
     if (use_stream_cb) {
         (*ft_model_instance)->registerCallback(streaming_callback, (void*)context);
     }
-    *output_tensors = (*ft_model_instance)->forward(*input_tensors);
+    *output_tensors = (*ft_model_instance)->forward(*input_tensors, instance_comm);
     if (use_stream_cb) {
         (*ft_model_instance)->unRegisterCallback();
     }
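ThreadForward is the per-GPU worker: each thread runs one model instance's forward pass, and after this change every thread also receives the same borrowed communicator. The surrounding machinery (one thread per device, a shared comm, errors carried back through std::exception_ptr) can be sketched standalone as below, with illustrative names only and under the assumption that forward may throw:

    #include <cstddef>
    #include <exception>
    #include <thread>
    #include <vector>

    struct InstanceComm {};  // stand-in for ft::AbstractInstanceComm

    struct ModelInstance {   // stand-in for one per-GPU FT model instance
        void forward(InstanceComm* /*comm*/) {
            // run the model; may throw on CUDA / runtime errors
        }
    };

    // One worker per GPU; every worker borrows the same communicator and
    // reports failure through its own exception slot.
    void worker(ModelInstance* inst, InstanceComm* comm, std::exception_ptr* eptr) {
        try {
            inst->forward(comm);
        } catch (...) {
            *eptr = std::current_exception();  // defer the error to the caller
        }
    }

    int main() {
        std::vector<ModelInstance> instances(4);   // e.g. 4 GPUs
        InstanceComm comm;                         // one comm shared by all threads
        std::vector<std::exception_ptr> errors(instances.size());
        std::vector<std::thread> threads;
        for (std::size_t i = 0; i < instances.size(); ++i) {
            threads.emplace_back(worker, &instances[i], &comm, &errors[i]);
        }
        for (auto& t : threads) t.join();          // join before inspecting errors
        for (auto& e : errors) {
            if (e) std::rethrow_exception(e);      // surface the first failure
        }
        return 0;
    }

The Execute() hunk below is the real counterpart: the unique_ptr member keeps the comm alive, each spawned thread gets instance_comm_.get(), and the threads are joined before the exception slots are inspected.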
@@ -1435,7 +1452,7 @@ ModelInstanceState::Execute(
             .c_str());
     threads.push_back(std::thread(
         ThreadForward, &ft_model_instance_[instance_local_id], &input_tensors,
-        &output_tensors_list[instance_local_id], &exception_ptr[instance_local_id], gid,
+        &output_tensors_list[instance_local_id], instance_comm_.get(), &exception_ptr[instance_local_id], gid,
         is_decoupled_ && gid == model_instance_device_id_start_, context));
     LOG_MESSAGE(
         TRITONSERVER_LOG_VERBOSE,