diff --git a/.vscode/settings.json b/.vscode/settings.json index 6f535da99..ea7bddf67 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -67,6 +67,33 @@ "unordered_set": "cpp", "future": "cpp", "cfenv": "cpp", - "typeindex": "cpp" + "typeindex": "cpp", + "__bit_reference": "cpp", + "__bits": "cpp", + "__config": "cpp", + "__debug": "cpp", + "__errc": "cpp", + "__hash_table": "cpp", + "__locale": "cpp", + "__mutex_base": "cpp", + "__node_handle": "cpp", + "__split_buffer": "cpp", + "__threading_support": "cpp", + "__tree": "cpp", + "__tuple": "cpp", + "__verbose_abort": "cpp", + "bit": "cpp", + "ios": "cpp", + "locale": "cpp", + "queue": "cpp", + "stack": "cpp", + "variant": "cpp", + "__nullptr": "cpp", + "__string": "cpp", + "compare": "cpp", + "concepts": "cpp", + "filesystem": "cpp", + "__memory": "cpp", + "version": "cpp" } -} \ No newline at end of file +} diff --git a/CMakeLists.txt b/CMakeLists.txt index 22c0f9c1f..b0879809b 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,6 +16,21 @@ project(FasterTransformer LANGUAGES CXX CUDA) find_package(CUDA 10.2 REQUIRED) +include(FetchContent) + +FetchContent_Declare( + repo-cutlass + GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git + GIT_TAG cc85b64cf676c45f98a17e3a47c0aafcf817f088 +) + +set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") + +FetchContent_MakeAvailable(repo-cutlass) + +set(CUTLASS_HEADER_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass/include) +set(CUTLASS_EXTENSIONS_DIR ${PROJECT_SOURCE_DIR}/src/turbomind/cutlass_extensions/include) + if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11") add_definitions("-DENABLE_BF16") message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.0, enable -DENABLE_BF16 flag") @@ -346,6 +361,9 @@ add_library(transformer-shared SHARED $ $ $ + $ + $ + $ $ $ $ @@ -466,6 +484,7 @@ set_target_properties(transformer-shared PROPERTIES POSITION_INDEPENDENT_CODE ON 
set_target_properties(transformer-shared PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON) set_target_properties(transformer-shared PROPERTIES LINKER_LANGUAGE CXX) target_link_libraries(transformer-shared PUBLIC -lcudart -lcublas -lcublasLt -lcurand) +target_link_libraries(transformer-shared PUBLIC stdc++fs) include(GNUInstallDirs) set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/FasterTransformer) diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/decoder_masked_multihead_attention.h index 5a768184c..d86773f67 100644 --- a/src/fastertransformer/kernels/decoder_masked_multihead_attention.h +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention.h @@ -117,6 +117,9 @@ struct Multihead_attention_params_base { const float* qkv_scale_out = nullptr; const float* attention_out_scale = nullptr; int int8_mode = 0; + + float attention_k_scale = 0.f; + float attention_v_scale = 0.f; }; template @@ -135,6 +138,12 @@ struct Multihead_attention_params: public Multihead_attention_params_base { // required in case of masked attention with different length const int* length_per_sample = nullptr; + + T** k_cache_per_sample = nullptr; + T** v_cache_per_sample = nullptr; + size_t kv_cache_per_sample_offset = 0; + bool k_cache_interleaved = true; + int num_kv_heads = 0; }; template @@ -152,6 +161,12 @@ struct Multihead_attention_params: public Multihead_attention_params_ba // required in case of masked attention with different length const int* length_per_sample = nullptr; + + T** k_cache_per_sample = nullptr; + T** v_cache_per_sample = nullptr; + size_t kv_cache_per_sample_offset = 0; + bool k_cache_interleaved = true; + int num_kv_heads = 0; }; template diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu index d0fb0a197..510865587 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.cu +++ 
b/src/fastertransformer/kernels/unfused_attention_kernels.cu @@ -1556,6 +1556,42 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, } } +template +void invokeAddFusedQKVBiasTranspose(T* q_buf, + T* k_buf, + T* v_buf, + PrefixPromptBatchWeightsParam param, + T* QKV, + const T* qkv_bias, + const int* padding_offset, + const int* history_length, + const int batch_size, + const int seq_len, + const int token_num, + const int head_num, + const int kv_head_num, + const int size_per_head, + const int rotary_embedding_dim, + const int neox_rotary_style, + const float* scale, + const int int8_mode, + cudaStream_t stream) +{ + FT_CHECK(rotary_embedding_dim); + FT_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with prefix prompt"); // TODO(mseznec) + // To implement rotary embeddings, each thread processes two QKV elems: + dim3 block((size_per_head / Vec_t::size + 31) / 32 * 32); + dim3 grid(token_num + batch_size * param.max_prefix_prompt_length, head_num); + size_t smem_size = neox_rotary_style ? 
2 * rotary_embedding_dim * sizeof(T) : 0; + // NOTE: add offset for rotary embedding + if (param.max_prefix_prompt_length == 0) { + FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, false); + } + else { + FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, true); + } +} + #define INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(T) \ template void invokeAddFusedQKVBiasTranspose(T* q_buf, \ T* k_buf, \ @@ -1573,6 +1609,25 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, const int neox_rotary_style, \ const float* scale, \ const int int8_mode, \ + cudaStream_t stream); \ + template void invokeAddFusedQKVBiasTranspose(T* q_buf, \ + T* k_buf, \ + T* v_buf, \ + PrefixPromptBatchWeightsParam param, \ + T* QKV, \ + const T* qkv_bias, \ + const int* padding_offset, \ + const int* history_length, \ + const int batch_size, \ + const int seq_len, \ + const int token_num, \ + const int head_num, \ + const int kv_head_num, \ + const int size_per_head, \ + const int rotary_embedding_dim, \ + const int neox_rotary_style, \ + const float* scale, \ + const int int8_mode, \ cudaStream_t stream) INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(float); INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(half); diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.h b/src/fastertransformer/kernels/unfused_attention_kernels.h index 7ac7604d4..d47b16275 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.h +++ b/src/fastertransformer/kernels/unfused_attention_kernels.h @@ -113,6 +113,27 @@ struct PrefixPromptBatchWeightsParam { const size_t prefix_prompt_layer_offset_per_seq = 0; }; +template +void invokeAddFusedQKVBiasTranspose(T* q_buf, + T* k_buf, + T* v_buf, + PrefixPromptBatchWeightsParam param, + T* QKV, + const T* qkv_bias, + const int* padding_offset, + const int* history_length, + const int batch_size, + const int seq_len, + const int token_num, + const int head_num, + const int kv_head_num, + const int size_per_head, + const int rotary_embedding_dim, + const int neox_rotary_style, + const float* scale, + const int 
int8_mode, + cudaStream_t stream); + template void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, diff --git a/src/fastertransformer/layers/DynamicDecodeLayer.h b/src/fastertransformer/layers/DynamicDecodeLayer.h index 3b63cda92..3ed83ebff 100644 --- a/src/fastertransformer/layers/DynamicDecodeLayer.h +++ b/src/fastertransformer/layers/DynamicDecodeLayer.h @@ -56,6 +56,14 @@ class DynamicDecodeLayer: public BaseLayer { int* h_pinned_finished_sum_ = nullptr; public: + curandState_t* topk_curandstate_buf() + { + return static_cast*>(topk_decode_)->curandstate_buf(); + } + curandState_t* topp_curandstate_buf() + { + return static_cast*>(topp_decode_)->curandstate_buf(); + } DynamicDecodeLayer(size_t vocab_size, size_t vocab_size_padded, int end_id, diff --git a/src/fastertransformer/layers/sampling_layers/BaseSamplingLayer.h b/src/fastertransformer/layers/sampling_layers/BaseSamplingLayer.h index 3b02cd596..3489b8a1f 100644 --- a/src/fastertransformer/layers/sampling_layers/BaseSamplingLayer.h +++ b/src/fastertransformer/layers/sampling_layers/BaseSamplingLayer.h @@ -59,6 +59,11 @@ class BaseSamplingLayer: public DynamicDecodeBaseLayer { virtual void allocateBuffer(size_t batch_size, Tensor top_k, Tensor top_p); public: + curandState_t* curandstate_buf() + { + return curandstate_buf_; + } + BaseSamplingLayer(size_t max_batch_size, size_t vocab_size, size_t vocab_size_padded, diff --git a/src/fastertransformer/models/CMakeLists.txt b/src/fastertransformer/models/CMakeLists.txt index 248b4af3d..6ff04fd59 100644 --- a/src/fastertransformer/models/CMakeLists.txt +++ b/src/fastertransformer/models/CMakeLists.txt @@ -19,6 +19,7 @@ add_subdirectory(bert_fp8) endif() add_subdirectory(deberta) add_subdirectory(decoder) +add_subdirectory(llama) add_subdirectory(longformer) add_subdirectory(decoding) add_subdirectory(xlnet) diff --git a/src/fastertransformer/models/llama/Barrier.h b/src/fastertransformer/models/llama/Barrier.h new file mode 100644 index 000000000..12015b5a7 
--- /dev/null +++ b/src/fastertransformer/models/llama/Barrier.h @@ -0,0 +1,37 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#pragma once + +#include "src/fastertransformer/utils/logger.h" +#include + +namespace fastertransformer { + +class Barrier { +public: + Barrier(unsigned count) + { + FT_LOG_INFO("Barrier(%d)", (int)count); + pthread_barrier_init(&barrier_, nullptr, count); + } + + Barrier(const Barrier&) = delete; + Barrier& operator=(const Barrier&) = delete; + Barrier(Barrier&&) noexcept = delete; + Barrier& operator=(Barrier&&) noexcept = delete; + + void wait() + { + pthread_barrier_wait(&barrier_); + } + + ~Barrier() + { + pthread_barrier_destroy(&barrier_); + } + +private: + pthread_barrier_t barrier_{}; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/CMakeLists.txt b/src/fastertransformer/models/llama/CMakeLists.txt new file mode 100644 index 000000000..74d61ecd2 --- /dev/null +++ b/src/fastertransformer/models/llama/CMakeLists.txt @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +cmake_minimum_required(VERSION 3.8) + +add_subdirectory(fused_multi_head_attention) + +add_library(Llama STATIC + LlamaV2.cc + LlamaBatch.cc + LlamaCacheManager.cc + LlamaContextDecoder.cc + LlamaContextAttentionLayer.cc + LlamaDecoderSelfAttentionLayer.cc + LlamaDecoder.cc + LlamaWeight.cc + LlamaDecoderLayerWeight.cc + LlamaFfnLayer.cc + llama_kernels.cu + llama_decoder_kernels.cu + llama_utils.cu) +set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(Llama PUBLIC -lcudart + cublasMMWrapper + DynamicDecodeLayer + activation_kernels + decoder_masked_multihead_attention + bert_preprocess_kernels + decoding_kernels + unfused_attention_kernels + custom_ar_kernels + custom_ar_comm + gpt_kernels + tensor + memory_utils + nccl_utils + cuda_utils + logger + stdc++fs + llama_fmha) + +add_executable(llama_gemm llama_gemm.cc) +target_link_libraries(llama_gemm PUBLIC -lcudart gpt_gemm_func memory_utils cuda_utils logger) diff --git a/src/fastertransformer/models/llama/LlamaBatch.cc b/src/fastertransformer/models/llama/LlamaBatch.cc new file mode 100644 index 000000000..4fce8f70d --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaBatch.cc @@ -0,0 +1,1117 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/fastertransformer/models/llama/LlamaBatch.h" +#include "src/fastertransformer/kernels/decoding_kernels.h" +#include "src/fastertransformer/models/llama/LlamaNcclGuard.h" +#include "src/fastertransformer/models/llama/LlamaV2.h" +#include "src/fastertransformer/models/llama/Request.h" +#include "src/fastertransformer/models/llama/llama_utils.h" +#include "src/fastertransformer/utils/Tensor.h" +#include "src/fastertransformer/utils/logger.h" +#include +#include +#include +#include + +namespace fastertransformer { + +template +void LlamaBatch::verifyRequests(std::vector>& stop_reqs, + std::vector>& infer_reqs) +{ + std::unordered_map occurrence; + + auto count_occurrence = [&occurrence](const std::vector>& rs) { + for (const auto& r : rs) { + ++occurrence[r->id]; + } + }; + + auto invalidate = [](const char* type, std::shared_ptr& req, int ec) { + FT_LOG_WARNING("[verifyRequests] Skipping invalid %s request for id %ld, code = %d", type, (long)req->id, ec); + req->signal.set_value(ec); + req.reset(); + }; + + auto handle_conflict_or_invalid = [this, &occurrence, &invalidate](std::vector>& rs, + const char* type) { + for (auto& r : rs) { + if (r) { + int ec = 0; + + if (occurrence[r->id] != 1) { + ec = Request::kConflict; + } + else if (r->start_flag && r->stop_flag) { + ec = Request::kInvalid; + } + else if (!r->start_flag && !llama_->kv_cache_mgr_->contains(r->id)) { + ec = Request::kInvalid; + } + + if (ec) { + invalidate(type, r, ec); + } + } + } + }; + + auto drop_invalid = [](std::vector>& rs) { + int count = 0; + for (int i = 0; i < rs.size(); ++i) { + if (rs[i]) { + rs[count++] = std::move(rs[i]); + } + } + rs.resize(count); + }; + + count_occurrence(stop_reqs); + count_occurrence(infer_reqs); + + if (!stop_reqs.empty()) { + handle_conflict_or_invalid(stop_reqs, "stop"); + + // invalidate stop-only requests for inactive sequences + for (auto& r : stop_reqs) { + if (r && r->end_flag == false) { + int ec = Request::kInactive; + for (int i = 0; i < 
batch_size_; ++i) { + if (requests_[i] && requests_[i]->id == r->id) { + ec = 0; + break; + } + } + if (ec) { + invalidate("stop", r, ec); + } + } + } + + drop_invalid(stop_reqs); + } + + if (!infer_reqs.empty()) { + handle_conflict_or_invalid(infer_reqs, "infer"); + + // invalidate requests for busy sequences + for (auto& r : infer_reqs) { + if (r) { + for (int i = 0; i < batch_size_; ++i) { + if (requests_[i] && requests_[i]->id == r->id) { + invalidate("infer", r, Request::kBusy); + break; + } + } + } + } + + drop_invalid(infer_reqs); + } +} + +template +void LlamaBatch::handleStopRequests(const std::vector>& requests) +{ + for (const auto& r : requests) { + int ec = Request::kFail; + // find matching active sequence + for (int i = 0; i < batch_size_; ++i) { + // stop & optionally erase active sequence + if (requests_[i] && requests_[i]->id == r->id) { + ec = 0; + finishRequest(i, r->end_flag); + break; + } + } + // mismatch, try erase inactive sequence + if (ec && r->end_flag) { + ec = 0; + llama_->kv_cache_mgr_->erase(r->id); + } + // clear output buffers (prevent leaking conversations) if request is successful + if (ec == 0) { + auto& output_ids = r->outputs[rank_].at("output_ids"); + auto& sequence_length = r->outputs[rank_].at("sequence_length"); + check_cuda_error( + cudaMemsetAsync(output_ids.getPtr(), 0, sizeof(int) * output_ids.shape.at(2), stream_)); + check_cuda_error(cudaMemsetAsync(sequence_length.getPtr(), 0, sizeof(int), stream_)); + check_cuda_error(cudaStreamSynchronize(stream_)); + } + if (rank_ == 0) { + r->signal.set_value(ec); + } + } +} + +template +void LlamaBatch::allocateBuffer(size_t batch_size, size_t session_len) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + const size_t batchxbeam = batch_size; + + const size_t hidden_units = llama_->hidden_units_; + const size_t vocab_size = llama_->vocab_size_; + + context_decoder_input_buf_ = + (T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_context_token_num_ * hidden_units, 
false); + context_decoder_output_buf_ = + (T*)allocator_->reMalloc(context_decoder_output_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false); + context_decoder_ids_buf_ = + (int*)allocator_->reMalloc(context_decoder_ids_buf_, sizeof(int) * max_context_token_num_, false); + + decoder_input_buf_ = (T*)allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * hidden_units, false); + decoder_output_buf_ = (T*)allocator_->reMalloc(decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units, false); + + input_ids_buf_ = (int*)allocator_->reMalloc(input_ids_buf_, sizeof(int) * batchxbeam * session_len, true); + input_length_buf_ = (int*)allocator_->reMalloc(input_length_buf_, sizeof(int) * batchxbeam); + history_length_buf_ = (int*)allocator_->reMalloc(history_length_buf_, sizeof(int) * batchxbeam); + context_length_buf_ = (int*)allocator_->reMalloc(context_length_buf_, sizeof(int) * batchxbeam); + + total_padding_count_ = (int*)allocator_->reMalloc(total_padding_count_, sizeof(int) * batchxbeam, false); + sequence_lengths_ = (int*)allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false); + + k_cache_ptr_buf_ = (uint64_t*)allocator_->reMalloc(k_cache_ptr_buf_, sizeof(uint64_t) * batchxbeam); + v_cache_ptr_buf_ = (uint64_t*)allocator_->reMalloc(v_cache_ptr_buf_, sizeof(uint64_t) * batchxbeam); + + logits_buf_ = (float*)allocator_->reMalloc(logits_buf_, sizeof(float) * batchxbeam * vocab_size, false); + local_logits_buf_ = (float*)allocator_->reMalloc(local_logits_buf_, sizeof(float) * batchxbeam * vocab_size, false); + + token_ids_buf_ = (int*)allocator_->reMalloc(token_ids_buf_, sizeof(int) * batchxbeam * session_len * 2, true); + + end_ids_buf_ = (int*)allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false); + finished_buf_ = (bool*)allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false); + seq_limit_len_ = (uint32_t*)allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false); + + 
is_allocate_buffer_ = true; +} + +template +void LlamaBatch::allocatePersistantBuffer(size_t max_batch_size) +{ + output_ids_buf_ = (int*)allocator_->reMalloc(output_ids_buf_, sizeof(int) * max_batch_size * session_len_, true); + + stop_words_buf_ = + (int*)allocator_->reMalloc(stop_words_buf_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true); + bad_words_buf_ = + (int*)allocator_->reMalloc(bad_words_buf_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true); + + h_runtime_top_k_ = (int*)allocator_->reMalloc(h_runtime_top_k_, sizeof(int) * max_batch_size, true, true); + h_runtime_top_p_ = (float*)allocator_->reMalloc(h_runtime_top_p_, sizeof(float) * max_batch_size, true, true); + h_temperature_ = (float*)allocator_->reMalloc(h_temperature_, sizeof(float) * max_batch_size, true, true); + h_repetition_penalty_ = + (float*)allocator_->reMalloc(h_repetition_penalty_, sizeof(float) * max_batch_size, true, true); + h_random_seed_ = (uint64_t*)allocator_->reMalloc(h_random_seed_, sizeof(uint64_t) * max_batch_size, true, true); + + sampling_params_ = {{"stop_words_list", stop_words_buf_}, + {"bad_words_list", bad_words_buf_}, + {"runtime_top_k", h_runtime_top_k_}, + {"runtime_top_p", h_runtime_top_p_}, + {"temperature", h_temperature_}, + {"repetition_penalty", h_repetition_penalty_}, + {"random_seed", h_random_seed_}}; + + topk_curandstate_buf_ = allocator_->reMalloc(topk_curandstate_buf_, sizeof(curandState_t) * max_batch_size, true); + topp_curandstate_buf_ = allocator_->reMalloc(topp_curandstate_buf_, sizeof(curandState_t) * max_batch_size, true); + + { + NcclGuard barrier(llama_->tensor_para_, stream_, true); + h_input_ids_buf_ = + (int*)allocator_->reMalloc(h_input_ids_buf_, sizeof(int) * max_batch_size * session_len_, false, true); + h_input_length_buf_ = + (int*)allocator_->reMalloc(h_input_length_buf_, sizeof(int) * max_batch_size, false, true); + h_history_length_buf_ = + (int*)allocator_->reMalloc(h_history_length_buf_, sizeof(int) * 
max_batch_size, false, true); + h_context_length_buf_ = + (int*)allocator_->reMalloc(h_context_length_buf_, sizeof(int) * max_batch_size, false, true); + h_sequence_lengths_ = + (int*)allocator_->reMalloc(h_sequence_lengths_, sizeof(int) * max_batch_size, false, true); + h_k_cache_ptr_buf_ = + (uintptr_t*)allocator_->reMalloc(h_k_cache_ptr_buf_, sizeof(uintptr_t) * max_batch_size, true, true); + h_v_cache_ptr_buf_ = + (uintptr_t*)allocator_->reMalloc(h_v_cache_ptr_buf_, sizeof(uintptr_t) * max_batch_size, true, true); + h_finished_buf_ = (bool*)allocator_->reMalloc(h_finished_buf_, sizeof(bool) * max_batch_size, false, true); + h_seq_limit_len_ = + (uint32_t*)allocator_->reMalloc(h_seq_limit_len_, sizeof(uint32_t) * max_batch_size, false, true); + } + + is_allocate_persistant_buffer_ = true; +} + +template +void LlamaBatch::freeBuffer() +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (is_allocate_buffer_) { + allocator_->free((void**)&context_decoder_input_buf_); + allocator_->free((void**)&context_decoder_output_buf_); + allocator_->free((void**)&context_decoder_ids_buf_); + + allocator_->free((void**)&decoder_input_buf_); + allocator_->free((void**)&decoder_output_buf_); + + allocator_->free((void**)&input_ids_buf_); + allocator_->free((void**)&input_length_buf_); + allocator_->free((void**)&history_length_buf_); + allocator_->free((void**)&context_length_buf_); + + allocator_->free((void**)&total_padding_count_); + allocator_->free((void**)&sequence_lengths_); + + allocator_->free((void**)&k_cache_ptr_buf_); + allocator_->free((void**)&v_cache_ptr_buf_); + + allocator_->free((void**)&logits_buf_); + allocator_->free((void**)&local_logits_buf_); + + if (local_context_logits_buf_) { + allocator_->free((void**)&local_context_logits_buf_); + } + if (context_logits_buf_) { + allocator_->free((void**)&context_logits_buf_); + } + + allocator_->free((void**)&token_ids_buf_); + + allocator_->free((void**)&end_ids_buf_); + allocator_->free((void**)&finished_buf_); + 
allocator_->free((void**)&seq_limit_len_); + + is_allocate_buffer_ = false; + } + + if (is_allocate_persistant_buffer_) { + allocator_->free((void**)&h_input_ids_buf_, true); + allocator_->free((void**)&h_input_length_buf_, true); + allocator_->free((void**)&h_history_length_buf_, true); + allocator_->free((void**)&h_context_length_buf_, true); + allocator_->free((void**)&h_sequence_lengths_, true); + allocator_->free((void**)&h_k_cache_ptr_buf_, true); + allocator_->free((void**)&h_v_cache_ptr_buf_, true); + allocator_->free((void**)&h_seq_limit_len_, true); + allocator_->free((void**)&h_finished_buf_, true); + + allocator_->free((void**)&output_ids_buf_); + + is_allocate_persistant_buffer_ = false; + } +} + +template +LlamaBatch::LlamaBatch(int max_batch_size, int max_context_token_num, int session_len, LlamaV2* llama): + max_batch_size_(max_batch_size), + max_context_token_num_(max_context_token_num), + session_len_(session_len), + rank_(llama->tensor_para_.rank_), + debug_(llama->debug_), + llama_(llama), + data_type_(getTensorType()) +{ + stream_ = llama_->stream_; + allocator_ = llama_->allocator_; + cublas_wrapper_ = llama_->cublas_wrapper_; + + requests_.resize(max_batch_size); + request_seq_len_limit_.resize(max_batch_size); + cached_seq_.resize(max_batch_size); + + allocatePersistantBuffer(max_batch_size); +} + +template +void LlamaBatch::initializeSampling(int infer_request_count) +{ + TensorMap inputs; + for (const auto& param : sampling_params_) { + const Tensor* ptr{}; + for (int i = 0; i < batch_size_; ++i) { + if (requests_[i]->inputs[rank_].isExist(param.first)) { + ptr = &requests_[i]->inputs[rank_].at(param.first); + break; + } + } + if (ptr) { + const auto& ref = *ptr; + auto shape = ref.shape; + FT_CHECK(shape[0] == 1); + shape[0] = batch_size_; + const int size_in_bytes = ref.sizeBytes(); + check_cuda_error(cudaMemsetAsync(param.second, 0, size_in_bytes * batch_size_, stream_)); + for (int i = 0; i < batch_size_; ++i) { + if 
(requests_[i]->inputs[rank_].isExist(param.first)) { + auto& src = requests_[i]->inputs[rank_].at(param.first); + FT_CHECK(ref.shape == src.shape); + check_cuda_error(cudaMemcpyAsync((uint8_t*)param.second + size_in_bytes * i, + src.getPtr(), + size_in_bytes, + cudaMemcpyDefault, + stream_)); + } + } + inputs.insert({param.first, {ref.where, ref.type, shape, param.second}}); + if (debug_ && rank_ == 0) { + FT_LOG_INFO("[initializeSampling] %s", format({param.first, inputs.at(param.first)}).c_str()); + } + } + } + + inputs_ = std::move(inputs); + + llama_->dynamic_decode_layer_->setup(batch_size_, 1, &inputs_); + + for (int i = 0; i < batch_size_; ++i) { + // recover random states if not a new request or new request w/o "random_seed" + if (i < batch_size_ - infer_request_count || !requests_[i]->inputs[rank_].isExist("random_seed")) { + check_cuda_error(cudaMemcpyAsync(llama_->dynamic_decode_layer_->topk_curandstate_buf() + i, + (curandState_t*)topk_curandstate_buf_ + i, + sizeof(curandState_t), + cudaMemcpyDefault, + stream_)); + check_cuda_error(cudaMemcpyAsync(llama_->dynamic_decode_layer_->topp_curandstate_buf() + i, + (curandState_t*)topp_curandstate_buf_ + i, + sizeof(curandState_t), + cudaMemcpyDefault, + stream_)); + } + } + + handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_); + cudaStreamSynchronize(0); +} + +template +void LlamaBatch::initializeGeneration() +{ + max_context_len_ = *std::max_element(h_context_length_buf_, h_context_length_buf_ + batch_size_); + + check_cuda_error(cudaMemsetAsync(token_ids_buf_, 0, sizeof(int) * batch_size_ * session_len_ * 2, stream_)); + invokeTransposeAxis01(token_ids_buf_, output_ids_buf_, batch_size_, session_len_, 1, stream_); + sync_check_cuda_error(); + + // token_ids_buf_[s, b] + // ABCDe ABCDe e + // ABCDEFGHIJk ABCDEFGHIJk + // ABCDEFGHi -> ABCDEFGHi i + // ABCDEFGh ABCDEFGh h + // ABCd ABCd d + for (int i = 0; i < batch_size_; ++i) { + auto token_ids = token_ids_buf_ + i; + auto p_src = 
h_context_length_buf_[i] - 1; + auto p_dst = max_context_len_ - 1; + if (p_src != p_dst) { // dst and src of `cudaMemcpyAsync` must not overlap + check_cuda_error(cudaMemcpyAsync(token_ids + p_dst * batch_size_, + token_ids + p_src * batch_size_, + sizeof(int), + cudaMemcpyDefault, + stream_)); + } + } + + check_cuda_error(cudaMemcpyAsync( + context_length_buf_, h_context_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_)); + check_cuda_error(cudaMemcpyAsync( + k_cache_ptr_buf_, h_k_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_)); + check_cuda_error(cudaMemcpyAsync( + v_cache_ptr_buf_, h_v_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_)); + + check_cuda_error( + cudaMemcpyAsync(sequence_lengths_, context_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_)); + // `sequence_lengths_` will be increased by dynamic decode + // note that in decoder and in output "sequence length" has different semantic + // - in decoder it means length of sequence that has kv cache already computed + // - in output it means length of all tokens (the last generated token does not have k/v cache computed yet) + invokePlusScalar(sequence_lengths_, -1, batch_size_, stream_); + sync_check_cuda_error(); + + // total_padding_count_ + // decoding starts at max_context_len + check_cuda_error(cudaMemsetAsync(total_padding_count_, 0, sizeof(int) * batch_size_, stream_)); + invokeUpdatePaddingCount(total_padding_count_, // + context_length_buf_, + max_context_len_, + batch_size_, + 1, + stream_); + sync_check_cuda_error(); + + // seq_limit_len_, will be compared to `step` instead of `sequence_length`, so padding len should be accounted for + for (int i = 0; i < batch_size_; ++i) { + h_seq_limit_len_[i] = request_seq_len_limit_[i] + (max_context_len_ - h_context_length_buf_[i]); + // mask finished sequences + h_finished_buf_[i] = max_context_len_ >= h_seq_limit_len_[i]; + } + check_cuda_error( + 
cudaMemcpyAsync(seq_limit_len_, h_seq_limit_len_, sizeof(uint32_t) * batch_size_, cudaMemcpyDefault, stream_)); + check_cuda_error( + cudaMemcpyAsync(finished_buf_, h_finished_buf_, sizeof(bool) * batch_size_, cudaMemcpyDefault, stream_)); + + // ! range of step_ [1, 2 * session_len] + // consider a sequence with context_len == session_len and another sequence with context_len == 1 and + // request_output_len == session_len - 1 => step_ will loop in [session_len, 2 * session_len) + step_ = max_context_len_; + + if (rank_ == 0) { + FT_LOG_INFO("[initGen] batch_size = %d", (int)batch_size_); + FT_LOG_INFO("[initGen] max_context_len = %d", (int)max_context_len_); + + FT_LOG_INFO("[initGen] slot sequence_id context_len seq_limit_len finished"); + for (int i = 0; i < batch_size_; ++i) { + FT_LOG_INFO("[initGen] %4d %11ld %11d %13d %8d", + i, + (long)cached_seq_[i].id, + h_context_length_buf_[i], + (int)h_seq_limit_len_[i], + (int)h_finished_buf_[i]); + } + } +} + +template +bool LlamaBatch::generate() +{ + constexpr int kLogInterval = 10; + if (rank_ == 0 && (step_ - 1) % kLogInterval == 0) { + FT_LOG_INFO("------------------------- step = %d -------------------------", step_ - 1); + } + + const bool is_first_step = step_ == max_context_len_; + + std::vector prev; + if (debug_ && rank_ == 0 && is_first_step) { + prev.resize(batch_size_); + cudaMemcpyAsync(prev.data(), + token_ids_buf_ + (step_ - 1) * batch_size_, + sizeof(int) * batch_size_, + cudaMemcpyDefault, + stream_); + } + + // embeddingLookup(step_ - 1); + llama_->embeddingLookup(decoder_input_buf_, // + token_ids_buf_, + batch_size_, + step_ - 1); + + llama_->decoderForward(decoder_output_buf_, + k_cache_ptr_buf_, + v_cache_ptr_buf_, + decoder_input_buf_, + sequence_lengths_, + total_padding_count_, + finished_buf_, + step_, + 0, + session_len_, + batch_size_); + + llama_->postDecodeEmbedding(logits_buf_, // + local_logits_buf_, + decoder_output_buf_, + batch_size_); + + // stop-words & bad-words require the 
matched tokens to be contiguous, so item size > 1 is + // not supported yet. + bool should_stop{}; + llama_->dynamicDecode(token_ids_buf_, + finished_buf_, + sequence_lengths_, + &should_stop, + &inputs_, + &outputs_, + logits_buf_, + seq_limit_len_, + context_length_buf_, + end_ids_buf_, + step_, + 0, + max_context_len_, + session_len_ * 2, + batch_size_); + + if (debug_ && rank_ == 0) { + std::vector curr(batch_size_); + + cudaMemcpyAsync( + curr.data(), token_ids_buf_ + step_ * batch_size_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_); + cudaStreamSynchronize(stream_); + + if (is_first_step) { + std::stringstream sprev; + for (int k = 0; k < prev.size(); ++k) { + sprev << std::setw(6) << prev[k]; + } + FT_LOG_INFO("[ lookup ] step = %d, [%s]", step_ - 1, sprev.str().c_str()); + } + + std::stringstream scurr; + for (int k = 0; k < curr.size(); ++k) { + scurr << std::setw(6) << curr[k]; + } + FT_LOG_INFO("[generate] step = %d, [%s]", step_ - 1, scurr.str().c_str()); + } + + //////////////////////////////////////////////// + /// ! 
increase the step counter + ++step_; + + return !should_stop; +} + +template +void LlamaBatch::initialize(const std::vector>& infer_requests) +{ + FT_CHECK(batch_size_ + infer_requests.size() <= max_batch_size_); + + const int infer_request_count = infer_requests.size(); + + allocateBuffer(batch_size_ + infer_request_count, session_len_); + + // handle infer requests + std::vector tmp_input_length(infer_request_count); + std::vector tmp_cached_seq; + tmp_cached_seq.reserve(infer_request_count); + + int tmp_max_input_length = 0; + for (int i = 0; i < infer_request_count; ++i) { + auto& r = *infer_requests[i]; + + LlamaCacheManager::Sequence seq{}; + if (r.start_flag) { + seq = llama_->kv_cache_mgr_->create(r.id, stream_); + } + else { + seq = llama_->kv_cache_mgr_->fetch(r.id, stream_); + } + + const int step = r.inputs[rank_].getVal("step", -1); + if (step >= 0) { + if (step <= seq.token_ids.size()) { + seq.token_ids.resize(step); + seq.cache_len = std::min(seq.cache_len, (size_t)step); + } + else if (rank_ == 0) { + FT_LOG_WARNING("[initialize] Skipping invalid step (%d) setting for ID %ld", step, (long)seq.id); + } + } + + // input length with missing cache accounted for + int actual_input_len = r.inputs[rank_].getVal("input_lengths") + (seq.token_ids.size() - seq.cache_len); + + // insert `start_id` for empty sequences + if (seq.token_ids.empty() && actual_input_len == 0) { + seq.token_ids.push_back(llama_->start_id_); + seq.cache_len = 0; + actual_input_len = seq.token_ids.size() - seq.cache_len; + } + + tmp_input_length[i] = actual_input_len; + + tmp_max_input_length = std::max((int)tmp_max_input_length, actual_input_len); + tmp_cached_seq.push_back(std::move(seq)); + } + + FT_CHECK(tmp_max_input_length > 0); + const int max_input_length = tmp_max_input_length; + + // arrange requests in ascending order w.r.t actual input lengths, so that requests need context decoding will + // be together + { + std::vector idxs(tmp_input_length.size()); + 
std::iota(idxs.begin(), idxs.end(), 0); + std::sort(idxs.begin(), idxs.end(), [&](int i, int j) { return tmp_input_length[i] < tmp_input_length[j]; }); + for (int i = 0; i < idxs.size(); ++i) { + requests_[batch_size_ + i] = infer_requests[idxs[i]]; + cached_seq_[batch_size_ + i] = tmp_cached_seq[idxs[i]]; + } + } + + const int count = batch_size_ + infer_requests.size(); + + std::vector tmp_input_len(count); + + for (int i = batch_size_; i < count; ++i) { + const auto& seq = cached_seq_[i]; + + h_input_length_buf_[i] = requests_[i]->inputs[rank_].getVal("input_lengths"); + tmp_input_len[i] = h_input_length_buf_[i]; + // prepare output ids + // <--------> max_context_len + // aaaAAAA + // bbbbBBBBBB + // ccCCC + auto output_ids_ptr = output_ids_buf_ + i * session_len_; + + // clear the persistent buffer to prevent leaking previous conversation + check_cuda_error(cudaMemsetAsync(output_ids_ptr, 0, sizeof(int) * session_len_, stream_)); + + if (!seq.token_ids.empty()) { + check_cuda_error(cudaMemcpyAsync(output_ids_ptr, // + seq.token_ids.data(), + sizeof(int) * seq.token_ids.size(), + cudaMemcpyDefault, + stream_)); + output_ids_ptr += seq.token_ids.size(); + } + + if (h_input_length_buf_[i]) { + auto input_ids_ptr = requests_[i]->inputs[rank_].getPtr("input_ids"); + check_cuda_error(cudaMemcpyAsync(output_ids_ptr, // + input_ids_ptr, + sizeof(int) * h_input_length_buf_[i], + cudaMemcpyDefault, + stream_)); + } + + if (!requests_[i]->start_flag && !seq.random_state_.empty()) { + check_cuda_error(cudaMemcpyAsync((curandState_t*)topk_curandstate_buf_ + i, + seq.random_state_.data(), + sizeof(curandState_t), + cudaMemcpyDefault, + stream_)); + check_cuda_error(cudaMemcpyAsync((curandState_t*)topp_curandstate_buf_ + i, + seq.random_state_.data() + sizeof(curandState_t), + sizeof(curandState_t), + cudaMemcpyDefault, + stream_)); + } + } + + for (int i = batch_size_; i < count; ++i) { + const auto& seq = cached_seq_[i]; + const int missed = (int)seq.token_ids.size() - 
seq.cache_len; + auto input_ids_buf = input_ids_buf_ + i * session_len_; + FT_CHECK(missed >= 0); + if (missed > 0) { + check_cuda_error(cudaMemcpyAsync(input_ids_buf, // + seq.token_ids.data() + seq.cache_len, + sizeof(int) * missed, + cudaMemcpyDefault, + stream_)); + input_ids_buf += missed; + } + auto& input_ids = requests_[i]->inputs[rank_].at("input_ids"); + check_cuda_error(cudaMemcpyAsync(input_ids_buf, // + input_ids.getPtr(), + sizeof(int) * h_input_length_buf_[i], + cudaMemcpyDefault, + stream_)); + h_input_length_buf_[i] += missed; + h_history_length_buf_[i] = seq.cache_len; + h_context_length_buf_[i] = h_input_length_buf_[i] + h_history_length_buf_[i]; + + const int request_output_len = requests_[i]->inputs[rank_].getVal("request_output_len"); + request_seq_len_limit_[i] = h_context_length_buf_[i] + request_output_len; + // `length_criterion` sets finish flag when step >= seq_limit_len, however when step == seq_limit_len + // the actual sequence length is seq_limit_len + 1, hence seq_limit_len must truncated to session_len - 1 + if (request_seq_len_limit_[i] >= session_len_) { + request_seq_len_limit_[i] = session_len_ - 1; + if (rank_ == 0) { + const int trunc_output_len = request_seq_len_limit_[i] - h_context_length_buf_[i]; + FT_LOG_WARNING( + "[initialize] [%ld] total sequence length (%d + %d) exceeds session_len (%d), request_output_len is truncated to %d", + (long)seq.id, + h_context_length_buf_[i], + request_output_len, + (int)session_len_, + trunc_output_len); + } + } + + h_k_cache_ptr_buf_[i] = (uint64_t)seq.k_cache; + h_v_cache_ptr_buf_[i] = (uint64_t)seq.v_cache; + } + + const int max_context_len = *std::max_element(h_context_length_buf_ + batch_size_, h_context_length_buf_ + count); + + batch_size_ = count; + max_context_len_ = max_context_len; + step_ = max_context_len; + + check_cuda_error( + cudaMemcpyAsync(input_length_buf_, h_input_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_)); + check_cuda_error(cudaMemcpyAsync( 
+ history_length_buf_, h_history_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_)); + check_cuda_error(cudaMemcpyAsync( + context_length_buf_, h_context_length_buf_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_)); + check_cuda_error(cudaMemcpyAsync( + k_cache_ptr_buf_, h_k_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_)); + check_cuda_error(cudaMemcpyAsync( + v_cache_ptr_buf_, h_v_cache_ptr_buf_, sizeof(uintptr_t) * batch_size_, cudaMemcpyDefault, stream_)); + + if (llama_->tensor_para_.rank_ == 0) { + FT_LOG_INFO("[init] infer_request_count = %d", (int)infer_request_count); + FT_LOG_INFO("[init] batch_size = %d", (int)batch_size_); + FT_LOG_INFO("[init] session_len = %d", (int)session_len_); + FT_LOG_INFO("[init] max_input_length = %d", (int)max_input_length); + FT_LOG_INFO("[init] max_context_len = %d", (int)max_context_len); + FT_LOG_INFO( + "[init] slot sequence_id history_len input_len context_len tmp_input_len token_ids.size cache_len"); + for (int i = batch_size_ - infer_request_count; i < batch_size_; ++i) { + FT_LOG_INFO("[init] %4d %11ld %11d %9d %11d %13d %14d %9d", + i, + (int)cached_seq_[i].id, + h_history_length_buf_[i], + h_input_length_buf_[i], + h_context_length_buf_[i], + tmp_input_len[i], + (int)cached_seq_[i].token_ids.size(), + (int)cached_seq_[i].cache_len); + } + } +} + +template +void LlamaBatch::contextDecode() +{ + int base = -1; + for (int i = 0; i < batch_size_; ++i) { + if (h_input_length_buf_[i] > 1) { + base = i; + break; + } + } + if (base >= 0) { + check_cuda_error(cudaStreamSynchronize(stream_)); + const auto tick = std::chrono::high_resolution_clock::now(); + + const int context_decode_count = batch_size_ - base; + if (rank_ == 0) { + FT_LOG_INFO("[decodeContext] base = %d, count = %d", base, context_decode_count); + } + invokePlusScalar(input_length_buf_ + base, -1, context_decode_count, stream_); + invokePlusScalar(context_length_buf_ + base, -1, context_decode_count, 
stream_); + + auto get_input_len = [this](int index) { return h_input_length_buf_[index] - 1; }; + auto get_context_len = [this](int index) { return h_context_length_buf_[index] - 1; }; + + std::vector decode_indices{base}; + std::vector decode_lengths{get_input_len(base)}; + + auto token_num = get_input_len(base); + auto max_input_len = get_input_len(base); + auto max_context_len = get_context_len(base); + auto offset = base; + for (int i = offset + 1; i <= batch_size_; ++i) { + if (i == batch_size_ || token_num + h_context_length_buf_[i] > max_context_token_num_) { + const int context_decode_batch_size = i - offset; + if (rank_ == 0) { + FT_LOG_INFO( + "[decodeContext] offset = %d, batch_size = %d, token_num = %d, max_input_len = %d, max_context_len = %d", + base, + context_decode_batch_size, + token_num, + max_input_len, + max_context_len); + } + // construct context_decoder_ids w/o padding + // aaaa____ + // bb______ -> aaaabbcccccccc + // cccccccc + auto context_decoder_ids = context_decoder_ids_buf_; + for (int j = offset; j < i; ++j) { + check_cuda_error(cudaMemcpyAsync(context_decoder_ids, + input_ids_buf_ + j * session_len_, + sizeof(int) * get_input_len(j), + cudaMemcpyDefault, + stream_)); + context_decoder_ids += get_input_len(j); + } + llama_->contextDecode(nullptr, + k_cache_ptr_buf_ + offset, + v_cache_ptr_buf_ + offset, + context_decoder_input_buf_, + context_decoder_output_buf_, + context_decoder_ids_buf_, + input_length_buf_ + offset, + history_length_buf_ + offset, + context_length_buf_ + offset, + token_num, + max_input_len, + max_context_len, + session_len_, + context_decode_batch_size); + + // compute logits of inputs if requested + outputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths); + + if (i < batch_size_) { + // initialize next sub-batch + token_num = get_input_len(i); + max_input_len = get_input_len(i); + max_context_len = get_context_len(i); + offset = i; + + decode_indices = {i}; + decode_lengths = 
{get_input_len(i)}; + } + } + else { + // add to current sub-batch + token_num += get_input_len(i); + max_input_len = std::max(max_input_len, get_input_len(i)); + max_context_len = std::max(max_context_len, get_context_len(i)); + + decode_indices.push_back(i); + decode_lengths.push_back(get_input_len(i)); + } + } + + invokePlusScalar(context_length_buf_ + base, 1, context_decode_count, stream_); + invokePlusScalar(input_length_buf_ + base, 1, context_decode_count, stream_); + + for (int i = offset; i < batch_size_; ++i) { + h_input_length_buf_[i] = 0; + } + + check_cuda_error(cudaStreamSynchronize(stream_)); + const auto tock = std::chrono::high_resolution_clock::now(); + if (rank_ == 0) { + FT_LOG_INFO("[decodeContext] %.2f ms", std::chrono::duration(tock - tick).count()); + } + } + else if (rank_ == 0) { + FT_LOG_INFO("[decodeContext] Context decoding is not needed."); + } +} + +template +void LlamaBatch::outputContextLogits(T* context_decoder_output, + const std::vector& indices, + const std::vector& lengths) +{ + std::vector output_logits; + int num_token = 0; + { + bool is_return_logits = false; + for (int k = 0; k < indices.size(); ++k) { + auto& request = requests_[indices[k]]; + output_logits.push_back(request->outputs[rank_].getPtr("logits", nullptr)); + num_token += lengths[k]; + if (output_logits.back()) { + is_return_logits = true; + } + } + if (!is_return_logits) { + return; + } + } + + if (context_logits_buf_ == nullptr) { + NcclGuard guard(llama_->tensor_para_, stream_, true); + context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_ * max_context_token_num_); + const auto tp = llama_->tensor_para_.world_size_; + if (tp > 1) { + FT_CHECK(llama_->vocab_size_ % tp == 0); + const auto local_vocab_size = llama_->vocab_size_ / tp; + local_context_logits_buf_ = + (float*)allocator_->malloc(sizeof(float) * local_vocab_size * max_context_token_num_); + } + } + + llama_->postDecodeEmbedding(context_logits_buf_, 
local_context_logits_buf_, context_decoder_output, num_token); + + auto logits = context_logits_buf_; + + for (int k = 0; k < indices.size(); ++k) { + if (output_logits[k]) { + check_cuda_error(cudaMemcpyAsync(output_logits[k], + logits, + sizeof(float) * llama_->vocab_size_ * lengths[k], + cudaMemcpyDefault, + stream_)); + } + logits += llama_->vocab_size_ * lengths[k]; + } +} + +template +void LlamaBatch::finish() +{ + // secure info needed by `synchronize()` + check_cuda_error( + cudaMemcpyAsync(h_finished_buf_, finished_buf_, sizeof(bool) * batch_size_, cudaMemcpyDefault, stream_)); + check_cuda_error( + cudaMemcpyAsync(h_sequence_lengths_, sequence_lengths_, sizeof(int) * batch_size_, cudaMemcpyDefault, stream_)); + + setOutputTensors(step_); + + check_cuda_error(cudaStreamSynchronize(stream_)); + + for (int i = 0; i < batch_size_; ++i) { + FT_CHECK(requests_[i] != nullptr); + if (requests_[i]->stream_cb && rank_ == 0) { + requests_[i]->stream_cb(&requests_[i]->outputs[rank_].get()); + } + } + + if (debug_ && rank_ == 0) { + std::stringstream ss; + for (int i = 0; i < batch_size_; ++i) { + ss << (i ? 
", " : "") << "(" << h_sequence_lengths_[i] << "," << h_finished_buf_[i] << ")"; + } + FT_LOG_INFO("[finish] [%s]", ss.str().c_str()); + } + + for (int i = 0; i < batch_size_; ++i) { + if (h_finished_buf_[i]) { + finishRequest(i, false); + ++finished_count_; + } + } +} + +template +void LlamaBatch::synchronize() +{ + // compact + int idx = 0; + for (int i = 0; i < batch_size_; ++i) { + if (requests_[i]) { + h_input_length_buf_[idx] = 0; + h_history_length_buf_[idx] = 0; + + h_context_length_buf_[idx] = h_sequence_lengths_[i] + 1; + h_sequence_lengths_[idx] = h_context_length_buf_[idx]; + + check_cuda_error(cudaMemcpyAsync((curandState_t*)topk_curandstate_buf_ + idx, + llama_->dynamic_decode_layer_->topk_curandstate_buf() + i, + sizeof(curandState_t), + cudaMemcpyDefault, + stream_)); + check_cuda_error(cudaMemcpyAsync((curandState_t*)topp_curandstate_buf_ + idx, + llama_->dynamic_decode_layer_->topp_curandstate_buf() + i, + sizeof(curandState_t), + cudaMemcpyDefault, + stream_)); + + if (i != idx) { + h_finished_buf_[idx] = h_finished_buf_[i]; + request_seq_len_limit_[idx] = request_seq_len_limit_[i]; + + h_k_cache_ptr_buf_[idx] = h_k_cache_ptr_buf_[i]; + h_v_cache_ptr_buf_[idx] = h_v_cache_ptr_buf_[i]; + + requests_[idx] = std::move(requests_[i]); + cached_seq_[idx] = std::move(cached_seq_[i]); + check_cuda_error(cudaMemcpyAsync(output_ids_buf_ + idx * session_len_, + output_ids_buf_ + i * session_len_, + sizeof(int) * h_context_length_buf_[idx], + cudaMemcpyDefault, + stream_)); + } + ++idx; + } + } + batch_size_ = idx; + + if (rank_ == 0) { + FT_LOG_INFO("[synchronize] batch_size = %d", (int)batch_size_); + } + + finished_count_ = 0; +} + +template +void LlamaBatch::setOutputTensors(int max_gen_step) +{ + // [s,b] -> [b,s] and skip padding in [context_len, max_context_len) + invokeGatherOutput(output_ids_buf_, + token_ids_buf_, + context_length_buf_, + max_context_len_, + max_gen_step, + session_len_, + batch_size_, + stream_); + sync_check_cuda_error(); + + /// 
TODO: fuse the loop into a single kernel + for (int i = 0; i < batch_size_; ++i) { + if (requests_[i]) { + auto& output_ids = requests_[i]->outputs[rank_].at("output_ids"); + auto& sequence_length = requests_[i]->outputs[rank_].at("sequence_length"); + check_cuda_error(cudaMemcpyAsync(output_ids.getPtr(), + output_ids_buf_ + i * session_len_, + sizeof(int) * output_ids.shape.at(2), + cudaMemcpyDefault, + stream_)); + check_cuda_error(cudaMemcpyAsync( + sequence_length.getPtr(), sequence_lengths_ + i, sizeof(int), cudaMemcpyDefault, stream_)); + if (max_gen_step > max_context_len_) { // +1 for newly generated token + invokePlusScalar(sequence_length.getPtr(), 1, 1, stream_); + } + } + } +} + +template +void LlamaBatch::finishRequest(int index, bool force_end) +{ + if (rank_ == 0) { + FT_LOG_INFO("[finishRequest] slot = %d, id = %lu", index, (long)requests_[index]->id); + } + + if (debug_ && rank_ == 0) { + std::vector tokens(h_sequence_lengths_[index] + 1); + cudaMemcpyAsync(tokens.data(), + output_ids_buf_ + index * session_len_, + sizeof(int) * tokens.size(), + cudaMemcpyDefault, + stream_); + cudaStreamSynchronize(stream_); + std::stringstream ss; + for (const auto& t : tokens) { + ss << " " << t; + } + FT_LOG_INFO("[finishRequest] slot %d, tokens [%s]", index, ss.str().c_str()); + } + + auto& output_ids_tensor = requests_[index]->outputs[rank_].at("output_ids"); + const auto output_ids_data = output_ids_tensor.getPtr(); + if (requests_[index]->end_flag || force_end) { + llama_->kv_cache_mgr_->erase(requests_[index]->id); + } + else { + // the last generated token is not processed by decoder thus dont have k/v cache + const int n_steps = step_ - max_context_len_; + const int cache_len = h_sequence_lengths_[index]; + const int output_len = n_steps > 0 ? 
cache_len + 1 : cache_len; + + auto& seq = cached_seq_[index]; + + seq.cache_len = cache_len; + + // update token IDs + seq.token_ids.resize(output_len); + check_cuda_error(cudaMemcpyAsync( + seq.token_ids.data(), output_ids_data, sizeof(int) * output_len, cudaMemcpyDefault, stream_)); + + // update random states + seq.random_state_.resize(sizeof(curandState_t) * 2); + check_cuda_error(cudaMemcpyAsync(seq.random_state_.data(), + llama_->dynamic_decode_layer_->topk_curandstate_buf() + index, + sizeof(curandState_t), + cudaMemcpyDefault, + stream_)); + check_cuda_error(cudaMemcpyAsync(seq.random_state_.data() + sizeof(curandState_t), + llama_->dynamic_decode_layer_->topp_curandstate_buf() + index, + sizeof(curandState_t), + cudaMemcpyDefault, + stream_)); + + check_cuda_error(cudaStreamSynchronize(stream_)); + + llama_->kv_cache_mgr_->update(cached_seq_[index], stream_); + } + + if (rank_ == 0) { + requests_[index]->signal.set_value(0); + } + + requests_[index] = nullptr; +} + +template class LlamaBatch; +template class LlamaBatch; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaBatch.h b/src/fastertransformer/models/llama/LlamaBatch.h new file mode 100644 index 000000000..30f4f84ee --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaBatch.h @@ -0,0 +1,159 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#pragma once + +#include "src/fastertransformer/models/llama/LlamaCacheManager.h" +#include "src/fastertransformer/models/llama/LlamaNcclGuard.h" +#include "src/fastertransformer/models/llama/Request.h" +#include "src/fastertransformer/utils/allocator.h" +#include "src/fastertransformer/utils/cublasMMWrapper.h" + +namespace fastertransformer { + +template +class LlamaV2; + +template +class LlamaBatch { +public: + int size() const noexcept + { + return batch_size_; + }; + + int maxSize() const noexcept + { + return max_batch_size_; + } + + int finishedCount() const noexcept + { + return finished_count_; + } + + void verifyRequests(std::vector>& stop_reqs, + std::vector>& infer_reqs); + void handleStopRequests(const std::vector>& requests); + + void allocateBuffer(size_t batch_size, size_t session_len); + void allocatePersistantBuffer(size_t max_batch_size); + void freeBuffer(); + + void initializeSampling(int infer_request_count); + + void initialize(const std::vector>& infer_requests); + void contextDecode(); + + void initializeGeneration(); + bool generate(); + + void finish(); + void finishRequest(int index, bool force_end); + + void synchronize(); + + void setOutputTensors(int max_gen_step); + + void + outputContextLogits(T* context_decoder_output, const std::vector& indices, const std::vector& lengths); + + explicit LlamaBatch(int max_batch_size, int max_context_token_num, int session_len, LlamaV2* llama); + + ~LlamaBatch() + { + freeBuffer(); + } + +private: + const int max_batch_size_; + const int max_context_token_num_; + const int session_len_; + const int rank_; + const bool debug_; + + LlamaV2* const llama_; + + // active requests + std::vector> requests_; + + T* context_decoder_input_buf_{}; // CTXDEC + T* context_decoder_output_buf_{}; // CTXDEC + int* context_decoder_ids_buf_{}; + + T* decoder_input_buf_{}; // CTXDEC, GENERATE + T* decoder_output_buf_{}; // CTXDEC, GENERATE + + int* input_ids_buf_{}; // input token ids + cache missed token ids, 
CTXDEC + int* input_length_buf_{}; // input + cache missed length, CTXDEC, GENERATE + int* history_length_buf_{}; // history length, CTXDEC + int* context_length_buf_{}; // history length + input_length, CTXDEC, GENERATE + + int* total_padding_count_{}; // GENERATE + int* sequence_lengths_{}; // current sequence length + + uint64_t* k_cache_ptr_buf_{}; + uint64_t* v_cache_ptr_buf_{}; + + float* logits_buf_{}; // combined logits + float* local_logits_buf_{}; // tensor parallel local logits + float* context_logits_buf_{}; + float* local_context_logits_buf_{}; + + // used by dynamic decoder + int* token_ids_buf_{}; // all token IDs in [S, B], indexed using `step` + int* output_ids_buf_{}; // output ids in [B, S] + int* end_ids_buf_{}; + bool* finished_buf_{}; + uint32_t* seq_limit_len_{}; + + // pinned buffers + int* h_input_ids_buf_{}; + int* h_input_length_buf_{}; + int* h_history_length_buf_{}; + int* h_context_length_buf_{}; + int* h_sequence_lengths_{}; + bool* h_finished_buf_{}; + uintptr_t* h_k_cache_ptr_buf_{}; + uintptr_t* h_v_cache_ptr_buf_{}; + uint32_t* h_seq_limit_len_{}; + + int* stop_words_buf_{}; // [batch_size, 2, kMaxStopWordsLen] + int* bad_words_buf_{}; + int* h_runtime_top_k_{}; + float* h_runtime_top_p_{}; + float* h_temperature_{}; + float* h_repetition_penalty_{}; + uint64_t* h_random_seed_{}; + + void* topk_curandstate_buf_{}; + void* topp_curandstate_buf_{}; + + // hard limits for persistent buffers + static constexpr int kMaxStopBadWordsLen = 32; + + using CachedSeq = LlamaCacheManager::Sequence; + + std::vector cached_seq_; + std::vector request_seq_len_limit_; + + const DataType data_type_{}; + + int batch_size_{}; + int max_context_len_{}; + int step_{}; + int finished_count_{}; + + bool is_allocate_persistant_buffer_ = false; + bool is_allocate_buffer_ = false; + + TensorMap inputs_; + TensorMap outputs_; + + std::unordered_map sampling_params_; + + cudaStream_t stream_{}; + cublasMMWrapper* cublas_wrapper_{}; + IAllocator* allocator_{}; 
+}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaCacheManager.cc b/src/fastertransformer/models/llama/LlamaCacheManager.cc new file mode 100644 index 000000000..2f51b3a01 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaCacheManager.cc @@ -0,0 +1,192 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "src/fastertransformer/models/llama/LlamaCacheManager.h" +#include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/logger.h" + +namespace fastertransformer { + +LlamaCacheManager::~LlamaCacheManager() +{ + for (auto& p : device_mem_) { + allocator_->free(&p, false); + } +} + +void* LlamaCacheManager::allocate(bool is_preallocte) +{ + if (rank_ == 0) { + FT_LOG_INFO("[LlamaCacheManager][allocate]"); + } + + void* mem_ptr{}; + + if (!device_free_.empty()) { + mem_ptr = device_free_.front(); + device_free_.pop(); + + if (rank_ == 0) { + FT_LOG_INFO("[LlamaCacheManager][allocate] free = %d", (int)device_free_.size()); + } + } + else if (entry_count_ < max_entry_count_) { + const auto alloc_count = std::min(chunk_size_, max_entry_count_ - entry_count_); + const size_t entry_byte_size = 2 * cache_byte_size_; // 2 for k,v + + if (rank_ == 0) { + FT_LOG_INFO("[LlamaCacheManager][allocate] malloc %d", (int)alloc_count); + } + const auto chunk_ptr = allocator_->malloc(alloc_count * entry_byte_size, false); + FT_CHECK(chunk_ptr); + device_mem_.push_back(chunk_ptr); + entry_count_ += alloc_count; + if (rank_ == 0) { + FT_LOG_INFO("[LlamaCacheManager][allocate] count = %d", entry_count_); + } + + for (int i = 0; i < alloc_count; ++i) { + device_free_.push((uint8_t*)chunk_ptr + entry_byte_size * i); + } + + if (!is_preallocte) { + mem_ptr = device_free_.front(); + device_free_.pop(); + } + + if (rank_ == 0) { + FT_LOG_INFO("[LlamaCacheManager][allocate] free = %d", (int)device_free_.size()); + } + } + else { + mem_ptr = evict(); + FT_CHECK_WITH_INFO(mem_ptr, "No enough cache 
entries."); + } + + return mem_ptr; +} + +auto LlamaCacheManager::create(uint64_t id, cudaStream_t stream) -> Sequence +{ + if (rank_ == 0) { + FT_LOG_INFO("[LlamaCacheManager][create] %ld", (long)id); + } + + for (const auto& e : device_cache_) { + if (e.id == id) { + if (rank_ == 0) { + FT_LOG_WARNING("[LlamaCacheManager][create] Removing conflicting id %ld", (long)id); + } + erase(id); + } + } + + const auto mem_ptr = (uint8_t*)allocate(false); + check_cuda_error(cudaMemsetAsync(mem_ptr, 0, cache_byte_size_ * 2, stream)); + + device_cache_.push_back({ + id, + max_seq_len_, + {}, + 0, + mem_ptr, + mem_ptr + cache_byte_size_, + {}, + static_cast(-1), + }); + + return device_cache_.back(); +} + +auto LlamaCacheManager::getEntryOrThrow(uint64_t id) -> std::vector::iterator +{ + auto pred = [&](const Sequence& s) { return s.id == id; }; + auto it = std::find_if(device_cache_.begin(), device_cache_.end(), pred); + if (it == device_cache_.end()) { + FT_LOG_ERROR("[LlamaCacheManager] %ld not found.\n", (long)id); + FT_CHECK(0); + } + return it; +} + +auto LlamaCacheManager::fetch(uint64_t id, cudaStream_t stream) -> Sequence +{ + if (rank_ == 0) { + FT_LOG_INFO("[LlamaCacheManager][fetch] %ld", (long)id); + } + + auto entry = getEntryOrThrow(id); + + if (entry->k_cache == nullptr) { + FT_CHECK(entry->cache_len == 0); + const auto mem_ptr = allocate(false); + check_cuda_error(cudaMemsetAsync(mem_ptr, 0, cache_byte_size_ * 2, stream)); + entry->k_cache = mem_ptr; + entry->v_cache = (uint8_t*)entry->k_cache + cache_byte_size_; + } + + entry->timestamp = static_cast(-1); + return *entry; +} + +void LlamaCacheManager::update(const Sequence& seq, cudaStream_t stream) +{ + if (rank_ == 0) { + FT_LOG_INFO("[LlamaCacheManager][update] %ld", (long)seq.id); + } + + auto entry = getEntryOrThrow(seq.id); + + entry->timestamp = ++timestamp_; + entry->token_ids = seq.token_ids; + entry->cache_len = seq.cache_len; + FT_CHECK(seq.k_cache == entry->k_cache && seq.v_cache == 
entry->v_cache); +} + +void LlamaCacheManager::erase(uint64_t id) +{ + if (rank_ == 0) { + FT_LOG_INFO("[LlamaCacheManager][erase] %ld", (long)id); + } + + auto entry = getEntryOrThrow(id); + + if (entry->k_cache) { + device_free_.push(entry->k_cache); + if (rank_ == 0) { + FT_LOG_INFO("[LlamaCacheManager][erase] free = %d", (int)device_free_.size()); + } + } + device_cache_.erase(entry); +} + +void* LlamaCacheManager::evict() +{ + FT_CHECK(!device_cache_.empty()); + auto it = std::min_element(device_cache_.begin(), device_cache_.end(), [](const auto& a, const auto& b) { + return a.timestamp < b.timestamp; + }); + + if (it->timestamp == static_cast(-1)) { + return nullptr; + } + + if (rank_ == 0) { + FT_LOG_INFO("[LlamaCacheManager][evict] %ld", (long)it->id); + } + + FT_CHECK(it->k_cache); + auto mem_ptr = it->k_cache; + it->k_cache = it->v_cache = nullptr; + it->cache_len = 0; + it->timestamp = static_cast(-1); + return mem_ptr; +} + +bool LlamaCacheManager::contains(uint64_t id) const noexcept +{ + auto pred = [&](const Sequence& s) { return s.id == id; }; + auto it = std::find_if(device_cache_.begin(), device_cache_.end(), pred); + return it != device_cache_.end(); +} + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaCacheManager.h b/src/fastertransformer/models/llama/LlamaCacheManager.h new file mode 100644 index 000000000..82000447c --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaCacheManager.h @@ -0,0 +1,102 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "src/fastertransformer/utils/allocator.h" +#include "src/fastertransformer/utils/logger.h" +#include +#include +#include +#include +#include + +namespace fastertransformer { + +// k-cache layout [L, H, D/x, S[s:], x] +// v-cache layout [L, H, S[s:], D/x, x] + +class LlamaCacheManager { +public: + LlamaCacheManager(size_t layer_num, + size_t head_num, + size_t size_per_head, + size_t max_seq_len, + size_t elem_bits, + size_t max_entry_count, + size_t chunk_size, + int rank, + IAllocator* allocator): + layer_num_(layer_num), + head_num_(head_num), + size_per_head_(size_per_head), + max_seq_len_(max_seq_len), + elem_bits_(elem_bits), + cache_byte_size_(layer_num_ * head_num_ * max_seq_len_ * size_per_head_ * elem_bits_ / 8), + max_entry_count_(max_entry_count), + chunk_size_(chunk_size), + rank_(rank), + allocator_(allocator) + { + if (rank == 0) { + FT_LOG_INFO("[LlamaCacheManager] max_entry_count = %d", (int)max_entry_count_); + FT_LOG_INFO("[LlamaCacheManager] chunk_size = %d", (int)chunk_size_); + } + allocate(true); + } + + ~LlamaCacheManager(); + + struct Sequence { + // header + uint64_t id; + size_t max_seq_len; + + // payloads + std::vector token_ids; // all token ids + size_t cache_len; // cache_len == 0 -> cache miss + void* k_cache; + void* v_cache; + + std::vector random_state_; // states for RNGs + + // for LRU policy + uint64_t timestamp; + }; + + Sequence create(uint64_t id, cudaStream_t stream); + + Sequence fetch(uint64_t id, cudaStream_t stream); + + void update(const Sequence& seq, cudaStream_t stream); + + void erase(uint64_t id); + + bool contains(uint64_t id) const noexcept; + +private: + std::vector::iterator getEntryOrThrow(uint64_t id); + + void* allocate(bool is_preallocte); + + void* evict(); + +private: + const size_t layer_num_{}; + const size_t head_num_{}; + const size_t size_per_head_{}; + const size_t max_seq_len_{}; + const size_t elem_bits_{}; + const size_t cache_byte_size_{}; + const size_t max_entry_count_{}; + const 
size_t chunk_size_{}; + const int rank_{}; + IAllocator* allocator_{}; + + std::queue device_free_; + std::vector device_mem_; + int entry_count_{}; + + uint64_t timestamp_{}; + + std::vector device_cache_; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc b/src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc new file mode 100644 index 000000000..f22dbe4b1 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaContextAttentionLayer.cc @@ -0,0 +1,415 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc + +#include "src/fastertransformer/models/llama/LlamaContextAttentionLayer.h" +#include "src/fastertransformer/kernels/bert_preprocess_kernels.h" +#include "src/fastertransformer/kernels/unfused_attention_kernels.h" +#include "src/fastertransformer/models/llama/LlamaNcclGuard.h" +#include "src/fastertransformer/models/llama/llama_kernels.h" +#include "src/fastertransformer/models/llama/llama_utils.h" +#include "src/fastertransformer/utils/Tensor.h" +#include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/logger.h" + +namespace fastertransformer { + +template +void LlamaContextAttentionLayer::allocateBuffer(size_t batch_size, + size_t num_token, + size_t max_q_len, + size_t max_k_len) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + + const int local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_; + + // no padding + qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_, true); + + // padding is rebuilt for q/k/v_buf_2_ + // [qH + 2kvH, B, S, D] + q_buf_2_ = (T*)allocator_->reMalloc( + q_buf_2_, sizeof(T) * local_q_kv_head_num * batch_size * max_q_len * size_per_head_, true); + k_buf_2_ = q_buf_2_ + local_head_num_ * batch_size * max_q_len * size_per_head_; + v_buf_2_ = k_buf_2_ + local_kv_head_num_ * batch_size * max_q_len * size_per_head_; + + if (use_fmha_) { + FlashAttentionOp flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_); + if (flash_attention.get_workspace_size() > 0) { + qk_buf_float_ = (float*)allocator_->reMalloc(qk_buf_float_, flash_attention.get_workspace_size(), true); + } + } + else { + // kv heads are repeated for unfused attention + k_cache_buf_ = (T*)allocator_->reMalloc( + k_cache_buf_, 2 * sizeof(T) * batch_size * local_head_num_ * max_k_len * size_per_head_, true); + 
v_cache_buf_ = k_cache_buf_ + batch_size * local_head_num_ * max_k_len * size_per_head_; + + qk_buf_ = + (T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * max_q_len * max_k_len, true); + + // qkv_buf_2_ has padding + qkv_buf_2_ = (T*)allocator_->reMalloc( + qkv_buf_2_, sizeof(T) * batch_size * max_q_len * local_head_num_ * size_per_head_, true); + } + + // qkv_buf_3_ padding is removed + qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_head_num_ * size_per_head_, true); + + is_allocate_buffer_ = true; +} + +template +void LlamaContextAttentionLayer::freeBuffer() +{ + if (is_allocate_buffer_) { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + + allocator_->free((void**)(&qkv_buf_)); + allocator_->free((void**)(&q_buf_2_)); + if (use_fmha_) { + allocator_->free((void**)&qk_buf_float_); + } + else { + allocator_->free((void**)(&k_cache_buf_)); + allocator_->free((void**)(&qk_buf_)); + allocator_->free((void**)(&qkv_buf_2_)); + } + allocator_->free((void**)(&qkv_buf_3_)); + + is_allocate_buffer_ = false; + } +} + +template +inline void LlamaContextAttentionLayer::forward(TensorMap* output_tensors, + const TensorMap* input_tensors, + const LlamaAttentionWeight* weights) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + + /** + * input_tensors: + * \param input_query [token_num, hidden_dim] + * \param attention_mask [batch_size, 1, max_q_len, max_kv_len] + * \param padding_offset [token_num], int + * \param input_lengths [batch_size], int + * \param history_lengths [batch_size], int + * \param context_lengths [batch_size], int + * \param cu_seqlens [batch_size+1], int + * \param max_seq_len [1], int on cpu + * \param is_final_layer [1], bool on cpu + * \param layer_id [1], int on cpu + * + * output_tensors: + * \param hidden_features [token_num, hidden_dim] + * \param key_cache [batch_size], uint64 + * \param value_cache [batch_size], uint64 + */ + + ///////////////////////////////////////////// + /// parse inputs + const int 
batch_size = input_tensors->at("attention_mask").shape[0]; + const int max_q_len = input_tensors->at("attention_mask").shape[2]; + const int max_k_len = input_tensors->at("attention_mask").shape[3]; + const int layer_id = input_tensors->getVal("layer_id"); + + const int num_token = input_tensors->at("input_query").shape[0]; + + const int max_seq_len = input_tensors->at("max_seq_len").getVal(); + + T* attention_out = output_tensors->at("hidden_features").getPtr(); + T* attention_input = input_tensors->at("input_query").getPtr(); + T* attention_mask = input_tensors->at("attention_mask").getPtr(); + + const auto input_length = input_tensors->at("input_lengths").getPtr(); + const auto history_length = input_tensors->at("history_lengths").getPtr(); + const auto context_length = input_tensors->at("context_lengths").getPtr(); + int* cu_seqlens = input_tensors->at("cu_seqlens").getPtr(); + + const auto padding_offset = input_tensors->at("padding_offset").getPtr(); + + ///////////////////////////////////////////// + /// allocate buffers + allocateBuffer(batch_size, num_token, max_q_len, max_k_len); + + ////////////////////////////////////////////// + /// qkv gemm + // [token_num, hidden_dim] -> [token_num, 3, local_hidden_dim] + linear_.forward(qkv_buf_, attention_input, num_token, weights->qkv); + + ////////////////////////////////////////////// + /// transpose qkv & apply rotary embedding & rebuild padding + /// qkv [B, s, H + 2kvH, D] -> (q [B, H, s, D], k [B, kvH, s, D], v [B, kvH, s, D]) + invokeAddFusedQKVBiasTranspose(q_buf_2_, + k_buf_2_, + v_buf_2_, + PrefixPromptBatchWeightsParam{}, + qkv_buf_, + weights->qkv.bias, + padding_offset, // padding_offset, + history_length, // used for applying rotary embedding + batch_size, + max_q_len, // seq_len + num_token, // batch_size * seq_len + local_head_num_, + local_kv_head_num_, + size_per_head_, + rotary_embedding_dim_, + neox_rotary_style_, + nullptr, // query_weight.scale_out + 0, // int8 mode + stream_); + 
sync_check_cuda_error(); + + const size_t layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_; + + auto k_cache_ptrs = output_tensors->getPtr("key_cache"); + auto v_cache_ptrs = output_tensors->getPtr("value_cache"); + ////////////////////////////////////////////////////////// + /// insert the k/v computed from inputs into k/v cache + /// transpose kv -> kv cache + // put k/v_buf from shape [B, kvH, s, D] to + // k_buf_2 [B, kvH, s, D] -> key_cache [B, kvH, S[t:t+s], D/x, x] + // v_buf_2 [B, kvH, s, D] -> val_cache [B, kvH, S[t:t+s], D/x, x] + invokeExtendKVCache(k_cache_ptrs, + v_cache_ptrs, + layer_offset, + k_buf_2_, + v_buf_2_, + batch_size, + input_length, + max_q_len, + history_length, + max_seq_len, + size_per_head_, + local_kv_head_num_, + stream_, + quant_policy_, + weights->past_kv_scale.data()); + + sync_check_cuda_error(); + if (use_fmha_) { + fusedMultiHeadAttention(k_cache_ptrs, + v_cache_ptrs, + layer_offset, + attention_mask, + cu_seqlens, + batch_size, + max_q_len, + max_k_len, + max_seq_len); + } + else { + unfusedMultiHeadAttention(k_cache_ptrs, + v_cache_ptrs, + layer_offset, + attention_mask, + padding_offset, + context_length, + batch_size, + num_token, + max_q_len, + max_k_len, + max_seq_len, + quant_policy_, + weights->past_kv_scale.data()); + } + + ////////////////////////////////////////////// + /// output gemm -> + linear_.forward(attention_out, qkv_buf_3_, num_token, weights->output); + + if (tensor_para_.world_size_ > 1) { + NcclGuard nccl_guard(tensor_para_, stream_); + ftNcclAllReduceSum(attention_out, attention_out, num_token * hidden_units_, tensor_para_, stream_); + sync_check_cuda_error(); + } + + if (is_free_buffer_after_forward_ == true) { + freeBuffer(); + } + sync_check_cuda_error(); +} + +template +void LlamaContextAttentionLayer::fusedMultiHeadAttention(T** key_cache_ptrs, + T** val_cache_ptrs, + size_t cache_layer_offset, + T* attention_mask, + int* cu_seqlens, + int batch_size, + int max_q_len, + 
int max_k_len, + int max_seq_len) +{ + ////////////////////////////////////////////// + // flash attention + using AttentionOp = FlashAttentionOp; + using Layout = typename AttentionOp::AttentionLayout; + Layout layout_q{.stride_batch = int(local_head_num_ * max_q_len * size_per_head_), + .stride_seq = int(size_per_head_), + .stride_head = int(max_q_len * size_per_head_)}; + Layout layout_k{.stride_batch = int(local_head_num_ * max_seq_len * size_per_head_), + .stride_seq = int(size_per_head_), + .stride_head = int(max_seq_len * size_per_head_), + .batch_seqs_offset = int(cache_layer_offset), + .batch_seqs = key_cache_ptrs}; + Layout layout_v{.stride_batch = int(local_head_num_ * max_seq_len * size_per_head_), + .stride_seq = int(size_per_head_), + .stride_head = int(max_seq_len * size_per_head_), + .batch_seqs_offset = int(cache_layer_offset), + .batch_seqs = val_cache_ptrs}; + Layout layout_o{ + .stride_batch = int(local_head_num_ * max_q_len * size_per_head_), + .stride_seq = int(local_head_num_ * size_per_head_), + .stride_head = int(size_per_head_), + .use_seqlens = true, + }; + size_t group_size = size_t(local_head_num_ / local_kv_head_num_); + AttentionOp flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_); + typename AttentionOp::Params attn_params{.attn_out = qkv_buf_3_, + .query = q_buf_2_, + .key = k_cache_buf_, + .val = v_cache_buf_, + .mask = attention_mask, + .out_accum = qk_buf_float_, + .cu_seqlens_q = cu_seqlens, + .cu_seqlens_k = nullptr, + .group_size = group_size, + .layout_q = layout_q, + .layout_k = layout_k, + .layout_v = layout_v, + .layout_o = layout_o}; + + // + flash_attention(attn_params, stream_); +} + +template +void LlamaContextAttentionLayer::unfusedMultiHeadAttention(T** key_cache_ptrs, + T** val_cache_ptrs, + size_t cache_layer_offset, + const T* attention_mask, + const int* padding_offset, + const int* context_length, + int batch_size, + int num_token, + int max_q_len, + int max_k_len, + int 
max_seq_len, + int quant, + const float* kv_scale) +{ + // key_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D] + // val_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D] + invokeTransposeKVCache(k_cache_buf_, + v_cache_buf_, + (const T**)key_cache_ptrs, + (const T**)val_cache_ptrs, + cache_layer_offset, + batch_size, + context_length, // history_len + input_len = context_len + max_k_len, + max_seq_len, + size_per_head_, + local_head_num_, + head_n_rep_, + stream_, + quant, + kv_scale); + sync_check_cuda_error(); + + const T qk_scale = static_cast(1.f / sqrtf(size_per_head_ * 1.f)); + + ////////////////////////////////////////////// + /// Q*K batch gemm + /// -> [B, H, s, t + s] + cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_T, + CUBLAS_OP_N, + max_k_len, // m + max_q_len, // n + size_per_head_, // k + k_cache_buf_, // A + size_per_head_, // lda + max_k_len * size_per_head_, // strideA + q_buf_2_, // B + size_per_head_, // ldb + max_q_len * size_per_head_, // strideB + qk_buf_, // C + max_k_len, // ldc + max_q_len * max_k_len, // strideC + batch_size * local_head_num_); // batchCount + + ////////////////////////////////////////////// + /// ! 
masked softmax (kernel asserts k_length <= 4096) + MaskedSoftmaxParam param{}; + param.attention_score = qk_buf_; + param.qk = qk_buf_; + param.attention_mask = attention_mask; + param.batch_size = batch_size; + param.q_length = max_q_len; + param.k_length = max_k_len; + param.num_heads = local_head_num_; + param.qk_scale = qk_scale; + param.linear_bias_slopes = nullptr; + invokeMaskedSoftmax(param, stream_); + sync_check_cuda_error(); + + ////////////////////////////////////////////// + /// softmax(QK)*V batch gemm + // -> [B, H, S, D] + cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_N, + CUBLAS_OP_N, + size_per_head_, // m + max_q_len, // n + max_k_len, // k + v_cache_buf_, // A + size_per_head_, // lda + max_k_len * size_per_head_, // strideA, + qk_buf_, // B + max_k_len, // ldb + max_k_len * max_q_len, // strideB + qkv_buf_2_, // C + size_per_head_, // ldc, + max_q_len * size_per_head_, // strideC + batch_size * local_head_num_); // batchCount + + ////////////////////////////////////////////// + /// transpose -> + invokeTransposeAttentionOutRemovePadding(qkv_buf_2_, + qkv_buf_3_, + num_token, + batch_size, + max_q_len, + local_head_num_, + size_per_head_, + padding_offset, + nullptr, + 0, + stream_); + sync_check_cuda_error(); +} + +template class LlamaContextAttentionLayer; +template class LlamaContextAttentionLayer; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaContextAttentionLayer.h b/src/fastertransformer/models/llama/LlamaContextAttentionLayer.h new file mode 100644 index 000000000..161cda666 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaContextAttentionLayer.h @@ -0,0 +1,131 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h + +#pragma once + +#include "src/fastertransformer/models/llama/LlamaDenseWeight.h" +#include "src/fastertransformer/models/llama/LlamaLinear.h" +#include "src/fastertransformer/utils/Tensor.h" +#include "src/fastertransformer/utils/nccl_utils.h" + +namespace fastertransformer { + +template +class LlamaContextAttentionLayer { +public: + void freeBuffer(); + void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len); + + LlamaContextAttentionLayer(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t rotary_embedding_dim, + bool neox_rotary_style, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool use_fmha, + int quant_policy): + head_num_(head_num), + size_per_head_(size_per_head), + hidden_units_(head_num * size_per_head), + local_head_num_(head_num / tensor_para.world_size_), + local_kv_head_num_(kv_head_num / tensor_para.world_size_), + head_n_rep_(head_num / kv_head_num), + rotary_embedding_dim_(rotary_embedding_dim), + neox_rotary_style_(neox_rotary_style), + tensor_para_(tensor_para), + stream_(stream), + cublas_wrapper_(cublas_wrapper), + linear_(cublas_wrapper, stream), + allocator_(allocator), + is_free_buffer_after_forward_(is_free_buffer_after_forward), + use_fmha_(use_fmha), + quant_policy_(quant_policy) + { + FT_CHECK(head_num % 
kv_head_num == 0); + } + + void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaAttentionWeight* weights); + + void fusedMultiHeadAttention(T** key_cache_ptrs, + T** val_cache_ptrs, + size_t cache_layer_offset, + T* attention_mask, + int* cu_seqlens, + int batch_size, + int max_q_len, + int max_k_len, + int max_seq_len); + + void unfusedMultiHeadAttention(T** key_cache_ptrs, + T** val_cache_ptrs, + size_t cache_layer_offset, + const T* attention_mask, + const int* padding_offset, + const int* context_length, + int batch_size, + int num_token, + int max_q_len, + int max_k_len, + int max_seq_len, + int quant_policy, + const float* kv_scale); + +private: + const size_t head_num_; + const size_t size_per_head_; + const size_t hidden_units_; + const size_t local_kv_head_num_; + const size_t local_head_num_; + const size_t head_n_rep_; + const size_t rotary_embedding_dim_; + const bool is_free_buffer_after_forward_; + + const bool neox_rotary_style_; + + const bool use_fmha_; + const int quant_policy_; + + NcclParam tensor_para_; + + cudaStream_t stream_; + IAllocator* allocator_; + cublasMMWrapper* cublas_wrapper_; + LlamaLinear linear_; + + T* qkv_buf_{}; + T* q_buf_2_{}; + T* k_buf_2_{}; + T* v_buf_2_{}; + T* k_cache_buf_{}; + T* v_cache_buf_{}; + T* qk_buf_{}; + float* qk_buf_float_{}; + T* qkv_buf_2_{}; + T* qkv_buf_3_{}; + + bool is_allocate_buffer_ = false; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaContextDecoder.cc b/src/fastertransformer/models/llama/LlamaContextDecoder.cc new file mode 100644 index 000000000..7af7e7098 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaContextDecoder.cc @@ -0,0 +1,288 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.cc + +#include "src/fastertransformer/models/llama/LlamaContextDecoder.h" +#include "src/fastertransformer/kernels/bert_preprocess_kernels.h" +#include "src/fastertransformer/kernels/gpt_kernels.h" +#include "src/fastertransformer/models/llama/LlamaContextDecoder.h" +#include "src/fastertransformer/models/llama/llama_decoder_kernels.h" +#include "src/fastertransformer/models/llama/llama_kernels.h" +#include "src/fastertransformer/utils/Tensor.h" + +namespace fastertransformer { + +template +void LlamaContextDecoder::allocateBuffer() +{ + FT_CHECK(false); +} + +template +void LlamaContextDecoder::allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + + attention_mask_ = (T*)allocator_->reMalloc(attention_mask_, sizeof(T) * batch_size * max_q_len * max_kv_len, false); + padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * max_q_len, false); + cu_seqlens_ = (int*)allocator_->reMalloc(cu_seqlens_, sizeof(int) * (batch_size + 1), false); + + is_allocate_buffer_ = true; +} + +template +void LlamaContextDecoder::freeBuffer() +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (is_allocate_buffer_) { + allocator_->free((void**)&padding_offset_); + allocator_->free((void**)&cu_seqlens_); + allocator_->free((void**)&attention_mask_); + allocator_->free((void**)&h_pinned_token_num_ptr_, true); + 
is_allocate_buffer_ = false; + } +} + +template +void LlamaContextDecoder::initialize(size_t kv_head_num, bool use_fmha, int quant_policy) +{ + h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true); + + context_attention_layer_ = new LlamaContextAttentionLayer(head_num_, + kv_head_num, + size_per_head_, + rotary_embedding_dim_, + false, // neox_rotary_style + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + use_fmha, + quant_policy); + + silu_ffn_layer_ = new LlamaFfnLayer(head_num_, + size_per_head_, + inter_size_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_); +} + +template +void LlamaContextDecoder::forwardSelfAttn(const Session& sess, + T* attn_io, + const std::unordered_map* input_tensors, + int layer, + bool is_final) +{ + // FT_LOG_ERROR(__PRETTY_FUNCTION__); + TensorMap self_attention_input_tensors{ + {"input_query", Tensor{MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_io}}, + {"attention_mask", + {MEMORY_GPU, data_type_, {sess.batch_size, 1, sess.max_query_len, sess.max_key_len}, attention_mask_}}, + {"layer_id", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &layer}}, + {"is_final_layer", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &is_final}}, + {"padding_offset", {MEMORY_GPU, TYPE_INT32, {sess.token_num}, padding_offset_}}, + {"cu_seqlens", {MEMORY_GPU, TYPE_INT32, {sess.batch_size + 1}, cu_seqlens_}}, + {"input_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.input_length}}, + {"history_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.history_length}}, + {"context_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.context_length}}, + {"max_seq_len", input_tensors->at("max_seq_len")}}; + + auto& k_cache = *sess.k_cache; + auto& v_cache = *sess.v_cache; + + TensorMap self_attention_output_tensors{ + {"hidden_features", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_io}}, 
+ {"key_cache", k_cache},
+ {"value_cache", v_cache},
+ };
+
+ context_attention_layer_->forward(&self_attention_output_tensors, //
+ &self_attention_input_tensors,
+ &sess.weights->at(layer)->self_attn_weights);
+}
+
+template
+LlamaContextDecoder::LlamaContextDecoder(size_t head_num,
+ size_t kv_head_num,
+ size_t size_per_head,
+ size_t inter_size,
+ size_t num_layer,
+ size_t rotary_embedding_dim,
+ float rmsnorm_eps,
+ NcclParam tensor_para,
+ cudaStream_t stream,
+ cublasMMWrapper* cublas_wrapper,
+ IAllocator* allocator,
+ bool is_free_buffer_after_forward,
+ bool use_fmha,
+ int quant_policy):
+ BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
+ head_num_(head_num),
+ size_per_head_(size_per_head),
+ inter_size_(inter_size),
+ hidden_units_(head_num * size_per_head),
+ num_layer_(num_layer),
+ rotary_embedding_dim_(rotary_embedding_dim),
+ rmsnorm_eps_(rmsnorm_eps),
+ tensor_para_(tensor_para),
+ data_type_(getTensorType())
+{
+ initialize(kv_head_num, use_fmha, quant_policy);
+}
+
+template
+LlamaContextDecoder::~LlamaContextDecoder()
+{
+ delete context_attention_layer_;
+ delete silu_ffn_layer_;
+ freeBuffer();
+}
+
+template
+void LlamaContextDecoder::forward(std::vector* output_tensors,
+ const std::vector* input_tensors,
+ const std::vector*>* decoder_layer_weights)
+{
+ FT_CHECK(false);
+}
+
+template
+void LlamaContextDecoder::forward(std::unordered_map* output_tensors,
+ const std::unordered_map* input_tensors,
+ const std::vector*>* decoder_layer_weights)
+{
+ /**
+ * input tensors:
+ * \param decoder_input [num_token, hidden_units], float
+ * \param input_lengths [batch_size], int
+ * \param history_lengths [batch_size], int
+ * \param context_lengths [batch_size], int
+ * \param output_norm_weight [hidden_dims], float
+ * \param max_q_len [1], int on cpu
+ * \param max_kv_len [1], int on cpu
+ * \param max_seq_len [1], int on cpu
+ *
+ * output tensors:
+ * \param decoder_output [num_token, hidden_units],
+ * \param 
key_cache [num_layer, batch, local_head_num, size_per_head // x, max_seq_len, x] + * \param value_cache [num_layer, batch, local_head_num, max_seq_len, size_per_head] + * \param last_token_hidden_units [batch_size, hidden_units] + */ + + Session sess{}; + + sess.token_num = input_tensors->at("decoder_input").shape[0]; + sess.batch_size = input_tensors->at("input_lengths").shape[0]; + sess.max_query_len = input_tensors->at("max_q_len").getVal(); + sess.max_key_len = input_tensors->at("max_kv_len").getVal(); + sess.weights = decoder_layer_weights; + + sess.input_length = input_tensors->at("input_lengths").getPtr(); + sess.history_length = input_tensors->at("history_lengths").getPtr(); + sess.context_length = input_tensors->at("context_lengths").getPtr(); + + T* decoder_input_output = input_tensors->at("decoder_input").getPtr(); + T* decoder_output = output_tensors->at("decoder_output").getPtr(); + + sess.k_cache = &output_tensors->at("key_cache"); + sess.v_cache = &output_tensors->at("value_cache"); + + allocateBuffer(sess.batch_size, sess.token_num, sess.max_query_len, sess.max_key_len); + + size_t tmp_token_num{}; + invokeGetPaddingOffsetAndCuSeqLens(h_pinned_token_num_ptr_, + &tmp_token_num, // updated token num + padding_offset_, + cu_seqlens_, + input_tensors->at("input_lengths").getPtr(), + sess.batch_size, + sess.max_query_len, + stream_); + sync_check_cuda_error(); + FT_CHECK(tmp_token_num == sess.token_num); + + invokeCreateCausalMasks(attention_mask_, + sess.input_length, + sess.context_length, + sess.max_query_len, + sess.max_key_len, + sess.batch_size, + stream_); + sync_check_cuda_error(); + + ///////////////////////////////////////////// + /// RMSNorm + invokeRootMeanSquareNorm(decoder_output, + decoder_input_output, + decoder_layer_weights->at(0)->self_attn_norm_weights, + rmsnorm_eps_, + sess.token_num, + hidden_units_, + stream_); + sync_check_cuda_error(); + + for (size_t layer = 0; layer < num_layer_; ++layer) { + 
///////////////////////////////////////////// + /// self-attention + forwardSelfAttn(sess, decoder_output, input_tensors, layer, false); + + invokeFusedAddBiasResidualRMSNorm(decoder_input_output, + decoder_output, + decoder_layer_weights->at(layer)->self_attn_weights.output.bias, + decoder_layer_weights->at(layer)->ffn_norm_weights, + rmsnorm_eps_, + sess.token_num, + hidden_units_, + stream_); + sync_check_cuda_error(); + + //////////////////////////////////////////// + /// feed-forward network + TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}}; + TensorMap ffn_outputs{ + {"ffn_output", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}}; + silu_ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &decoder_layer_weights->at(layer)->ffn_weights); + + auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights : + input_tensors->at("output_norm_weight").getPtr(); + invokeFusedAddBiasResidualRMSNorm(decoder_input_output, // + decoder_output, + decoder_layer_weights->at(layer)->ffn_weights.output.bias, + scale_weight, + rmsnorm_eps_, + sess.token_num, + hidden_units_, + stream_); + sync_check_cuda_error(); + } + + if (is_free_buffer_after_forward_) { + freeBuffer(); + } +} + +template class LlamaContextDecoder; +template class LlamaContextDecoder; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaContextDecoder.h b/src/fastertransformer/models/llama/LlamaContextDecoder.h new file mode 100644 index 000000000..56535f36b --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaContextDecoder.h @@ -0,0 +1,116 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptContextDecoder.h + +#pragma once + +// #include "src/fastertransformer/kernels/add_residual_kernels.h" +// #include "src/fastertransformer/kernels/layernorm_kernels.h" +#include "src/fastertransformer/layers/BaseLayer.h" +// #include "src/fastertransformer/layers/FfnLayer.h" +// #include "src/fastertransformer/layers/attention_layers/BaseAttentionLayer.h" +#include "src/fastertransformer/models/llama/LlamaContextAttentionLayer.h" +#include "src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h" +#include "src/fastertransformer/models/llama/LlamaFfnLayer.h" +#include "src/fastertransformer/utils/Tensor.h" +#include "src/fastertransformer/utils/allocator.h" +#include "src/fastertransformer/utils/cublasMMWrapper.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/nccl_utils.h" + +namespace fastertransformer { + +template +class LlamaContextDecoder: public BaseLayer { +protected: + void allocateBuffer() override; + void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len); + void freeBuffer() override; + + void initialize(size_t kv_head_num, bool use_fmha, int quant_policy); + + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t rotary_embedding_dim_; + size_t hidden_units_; + float rmsnorm_eps_; + + NcclParam tensor_para_; + + T* attention_mask_{}; + int* padding_offset_{}; + int* 
cu_seqlens_{}; // cu for cumulative + + size_t* h_pinned_token_num_ptr_{}; + + LlamaContextAttentionLayer* context_attention_layer_{}; + LlamaFfnLayer* silu_ffn_layer_{}; + + const DataType data_type_; + + struct Session { + size_t batch_size; + size_t token_num; + size_t max_query_len; + size_t max_key_len; + Tensor* k_cache; + Tensor* v_cache; + int* input_length{}; + int* history_length{}; + int* context_length{}; + + const std::vector*>* weights; + }; + + void forwardSelfAttn(const Session& sess, + T* attn_io, + const std::unordered_map* input_tensors, + int layer, + bool is_final); + +public: + LlamaContextDecoder(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t rotary_embedding_dim, + float rmsnorm_eps, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool use_fmha, + int quant_policy); + + ~LlamaContextDecoder() override; + + virtual void forward(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const std::vector*>* decoder_layer_weights); + + virtual void forward(std::vector* output_tensors, + const std::vector* input_tensors, + const std::vector*>* decoder_layer_weights); +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaDecoder.cc b/src/fastertransformer/models/llama/LlamaDecoder.cc new file mode 100644 index 000000000..5c2cf7e74 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaDecoder.cc @@ -0,0 +1,247 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022, SK Telecom Authored by A. Dialog + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc + +#include "src/fastertransformer/models/llama/LlamaDecoder.h" +#include "src/fastertransformer/models/llama/llama_decoder_kernels.h" +#include "src/fastertransformer/models/llama/llama_kernels.h" +#include "src/fastertransformer/models/llama/llama_utils.h" + +namespace fastertransformer { + +template +LlamaDecoder::LlamaDecoder(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t rotary_embedding_dim, + float rmsnorm_eps, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + int quant_policy): + BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), + head_num_(head_num), + size_per_head_(size_per_head), + inter_size_(inter_size), + num_layer_(num_layer), + rotary_embedding_dim_(rotary_embedding_dim), + hidden_units_(head_num * size_per_head), + rmsnorm_eps_(rmsnorm_eps), + tensor_para_(tensor_para), + data_type_(getTensorType()) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + initialize(kv_head_num, quant_policy); +} + +template +LlamaDecoder::~LlamaDecoder() +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + delete self_attention_layer_; + delete silu_ffn_layer_; +} + +template +void LlamaDecoder::initialize(size_t kv_head_num, int quant_policy) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + + self_attention_layer_ = new 
LlamaDecoderSelfAttentionLayer(head_num_, + kv_head_num, + size_per_head_, + rotary_embedding_dim_, + false, // neox_rotary_style + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + quant_policy); + + silu_ffn_layer_ = new LlamaFfnLayer(head_num_, + size_per_head_, + inter_size_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_); +} + +template +void LlamaDecoder::allocateBuffer() +{ + FT_CHECK(false); +} + +template +void LlamaDecoder::allocateBuffer(size_t batch_size) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + is_allocate_buffer_ = true; +} + +template +void LlamaDecoder::freeBuffer() +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (is_allocate_buffer_) { + is_allocate_buffer_ = false; + } +} + +template +void LlamaDecoder::forwardSelfAttn(const LlamaDecoder::Session& sess, + T* attn_io, + const std::unordered_map* input_tensors, + size_t layer) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + TensorMap self_attention_input_tensors(*input_tensors); + self_attention_input_tensors.insert("input_query", + {MEMORY_GPU, data_type_, {sess.batch_size, hidden_units_}, attn_io}); + const int layer_id = layer; + self_attention_input_tensors.insert("layer_id", {MEMORY_CPU, TYPE_INT32, {1}, &layer_id}); + auto& k_cache = *sess.k_cache; + auto& v_cache = *sess.v_cache; + + TensorMap self_attention_output_tensors{ + {"attention_output", {MEMORY_GPU, data_type_, {sess.batch_size, hidden_units_}, attn_io}}, + {"key_cache", k_cache}, + {"value_cache", v_cache}, + }; + + self_attention_layer_->forward(&self_attention_output_tensors, // + &self_attention_input_tensors, + &sess.weights->at(layer)->self_attn_weights); +} + +template +void LlamaDecoder::forwardFfn(const LlamaDecoder::Session& sess, T* ffn_io, size_t layer) +{ + TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, data_type_, {sess.batch_size, hidden_units_}, ffn_io}}}; + TensorMap ffn_outputs{{"ffn_output", {MEMORY_GPU, data_type_, {sess.batch_size, 
hidden_units_}, ffn_io}}}; + silu_ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &sess.weights->at(layer)->ffn_weights); +} + +template +void LlamaDecoder::forward(std::vector* output_tensors, + const std::vector* input_tensors, + const std::vector*>* decoder_layer_weights) +{ + FT_CHECK(false); +} + +template +void LlamaDecoder::forward(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const std::vector*>* decoder_layer_weights) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + /** + * input_tensors: + * \param decoder_input [batch_size, hidden_dims] + * \param sequence_lengths [batch_size] int + * \param output_norm_weight [hidden_dims] + * \param step [1] on cpu + * \param ite [1] on cpu + * \param finished [batch_size] bool + * \param total_padding_tokens [batch_size], int + * \param max_seq_len [1] on cpu + * \param masked_tokens [batch_size, memory_len] bool (optional), NOT USED YET + * + * output_tensors: + * \param decoder_output [batch_size, hidden_dimension] + * \param key_cache [batch_size] uint64_t + * \param value_cache [batch_size] uint64_t + */ + + // for the shape of key cache, refer to decoder_masked_multihead_attention_template.hpp + + Session sess{}; + sess.batch_size = input_tensors->at("decoder_input").shape[0]; + sess.weights = decoder_layer_weights; + + allocateBuffer(sess.batch_size); + + sess.ite = input_tensors->at("ite").getVal(); + sess.k_cache = &output_tensors->at("key_cache"); + sess.v_cache = &output_tensors->at("value_cache"); + + sess.max_memory_len = input_tensors->at("max_seq_len").getVal(); + + T* decoder_input = input_tensors->at("decoder_input").getPtr(); + T* decoder_output = output_tensors->at("decoder_output").getPtr(); + + //////////////////////////////////////////// + /// RMSNorm + invokeRootMeanSquareNorm(decoder_output, + decoder_input, + decoder_layer_weights->at(0)->self_attn_norm_weights, + rmsnorm_eps_, + sess.batch_size, + hidden_units_, + stream_); + sync_check_cuda_error(); + + for (size_t 
layer = 0; layer < num_layer_; ++layer) { + // output: self_attn_output_, k_cache, v_cache = self_attn(decoder_normed_input_) + forwardSelfAttn(sess, decoder_output, input_tensors, layer); + + invokeFusedAddBiasResidualRMSNorm(decoder_input, + decoder_output, + decoder_layer_weights->at(layer)->self_attn_weights.output.bias, + decoder_layer_weights->at(layer)->ffn_norm_weights, + rmsnorm_eps_, + sess.batch_size, + hidden_units_, + stream_); + sync_check_cuda_error(); + + // decoder_layer_output_ = ffn(decoder_normed_input_) + forwardFfn(sess, decoder_output, layer); + + auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights : + input_tensors->at("output_norm_weight").getPtr(); + invokeFusedAddBiasResidualRMSNorm(decoder_input, // + decoder_output, + decoder_layer_weights->at(layer)->ffn_weights.output.bias, + scale_weight, + rmsnorm_eps_, + sess.batch_size, + hidden_units_, + stream_); + sync_check_cuda_error(); + } + + if (is_free_buffer_after_forward_) { + freeBuffer(); + } +} + +template class LlamaDecoder; +template class LlamaDecoder; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaDecoder.h b/src/fastertransformer/models/llama/LlamaDecoder.h new file mode 100644 index 000000000..588e1c94c --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaDecoder.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022, SK Telecom Authored by A. Dialog + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.h + +#include "src/fastertransformer/layers/BaseLayer.h" +// #include "src/fastertransformer/layers/FfnLayer.h" +#include "src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h" +#include "src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h" +#include "src/fastertransformer/models/llama/LlamaFfnLayer.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/nccl_utils.h" + +namespace fastertransformer { + +template +class LlamaDecoder: public BaseLayer { +protected: + void allocateBuffer() override; // deprecated + void allocateBuffer(size_t batch_size); + void freeBuffer() override; + void initialize(size_t kv_head_num, int quant_policy); + + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t rotary_embedding_dim_; + size_t hidden_units_; + float rmsnorm_eps_; + + NcclParam tensor_para_; + + LlamaDecoderSelfAttentionLayer* self_attention_layer_{}; + LlamaFfnLayer* silu_ffn_layer_{}; + + const DataType data_type_; + + struct Session { + size_t batch_size; + int ite; + size_t max_memory_len; + Tensor* k_cache; + Tensor* v_cache; + const std::vector*>* weights; + }; + + void forwardSelfAttn(const Session& sess, + T* attn_io, + const std::unordered_map* input_tensors, + size_t layer); + + void forwardFfn(const LlamaDecoder::Session& sess, T* ffn_io, size_t layer); + +public: + LlamaDecoder(size_t 
head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t rotary_embedding_dim, + float rmsnorm_eps, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + int quant_policy); + + ~LlamaDecoder() override; + + virtual void forward(std::unordered_map* output_tensors, + const std::unordered_map* input_tensors, + const std::vector*>* decoder_layer_weights); + + virtual void forward(std::vector* output_tensors, + const std::vector* input_tensors, + const std::vector*>* decoder_layer_weights); +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc new file mode 100644 index 000000000..6515902af --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.cc @@ -0,0 +1,302 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.cc + +#include "src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h" +#include "src/fastertransformer/utils/logger.h" +#include "src/fastertransformer/utils/memory_utils.h" +// #include +#include + +namespace fastertransformer { + +template +LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + WeightType weight_type, + bool attn_bias, + size_t tensor_para_size, + size_t tensor_para_rank): + head_num_(head_num), + kv_head_num_(kv_head_num), + size_per_head_(size_per_head), + hidden_units_(head_num * size_per_head), + inter_size_(inter_size), + weight_type_(weight_type), + attn_bias_(attn_bias), + tensor_para_size_(tensor_para_size), + tensor_para_rank_(tensor_para_rank) +{ + self_attn_weights.qkv.input_dims = hidden_units_; + self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_; + self_attn_weights.qkv.type = weight_type; + + self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_; + self_attn_weights.output.output_dims = hidden_units_; + self_attn_weights.output.type = weight_type; + + ffn_weights.gating.input_dims = hidden_units_; + ffn_weights.gating.output_dims = inter_size_ / tensor_para_size_; + ffn_weights.gating.type = weight_type; + + ffn_weights.intermediate.input_dims = hidden_units_; + ffn_weights.intermediate.output_dims = inter_size_ / tensor_para_size_; + ffn_weights.intermediate.type = weight_type; + + ffn_weights.output.input_dims = inter_size_ / tensor_para_size_; + ffn_weights.output.output_dims = hidden_units_; + ffn_weights.output.type = weight_type; + mallocWeights(); +} + +template +void freeWeights(LlamaDenseWeight& weights) +{ + cudaFree(weights.kernel); + cudaFree(weights.bias); + cudaFree(weights.scales); + cudaFree(weights.zeros); + 
+ weights.kernel = nullptr; + weights.bias = nullptr; + weights.scales = nullptr; + weights.zeros = nullptr; +} + +template +void mallocWeights(LlamaDenseWeight& weights, bool bias) +{ + if (bias) { + deviceMalloc((T**)&weights.bias, weights.output_dims); + } + const size_t bit_size = getBitSize(weights.type); + if (bit_size >= 16) { // fp16, fp32 + deviceMalloc((T**)&weights.kernel, weights.input_dims * weights.output_dims); + } + else { // int8, int4 + const int factor = sizeof(float) * 8 / bit_size; + FT_CHECK(weights.input_dims % factor == 0); + deviceMalloc((float**)&weights.kernel, weights.input_dims / factor * weights.output_dims); + deviceMalloc((T**)&weights.scales, weights.output_dims); + deviceMalloc((T**)&weights.zeros, weights.output_dims); + } +} + +template +void loadWeights(LlamaDenseWeight& w, + std::string prefix, + int rank, + FtCudaDataType model_file_type, + size_t tensor_para_size, + int slice_dim = 0, + std::vector slice_shape = {}) +{ + auto max_prefix = prefix + "." + std::to_string(tensor_para_size - 1); + const auto type = model_file_type; + + bool enable_slice = true; + // Disable slice if tensor param rank is 1 + if (tensor_para_size <= 1) { + enable_slice = false; + } + else { + // Disable slice if weight has already been sliced + if (std::experimental::filesystem::exists(max_prefix + ".weight") || std::experimental::filesystem::exists(max_prefix + ".qweight")) { + // if (std::filesystem::exists(max_prefix + ".weight") || std::filesystem::exists(max_prefix + ".qweight")) { + FT_LOG_DEBUG("TP weight exists. Disable runtime TP."); + enable_slice = false; + } + } + + size_t dim0 = w.input_dims; + size_t dim1 = w.output_dims; + if (enable_slice) { + // multiple tp size for slice stride + if (slice_dim == 0) { + dim0 = dim0 * tensor_para_size; + if (slice_shape.size() == 0) { + slice_shape = {dim0}; + } + } + else { + dim1 = dim1 * tensor_para_size; + if (slice_shape.size() == 0) { + slice_shape = {dim1}; + } + } + + prefix += "." 
+ std::to_string(0); + } + else { + prefix += "." + std::to_string(rank); + } + + if (w.bias) { + std::vector bias_slices{}; + if (enable_slice) { + if (slice_dim == 1) { + size_t start = 0; + ConcateSlice slice0{.slices = {{0, 1}}}; + ConcateSlice slice1{.slices = {{}}}; + for (auto len : slice_shape) { + size_t stride = len / tensor_para_size; + slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)}); + start += len; + } + bias_slices = {slice0, slice1}; + } + } + loadWeightFromBin((T*)w.bias, {1, dim1}, prefix + ".bias", type, bias_slices); + } + const size_t bit_size = getBitSize(w.type); + if (bit_size >= 16) { // fp16, fp32 + std::vector weight_slices{}; + if (enable_slice) { + if (slice_dim == 1) { + size_t start = 0; + ConcateSlice slice0{.slices = {{0, dim0}}}; + ConcateSlice slice1{.slices = {{}}}; + for (auto len : slice_shape) { + size_t stride = len / tensor_para_size; + slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)}); + start += len; + } + weight_slices = {slice0, slice1}; + } + else { + size_t start = 0; + ConcateSlice slice0{.slices = {}}; + ConcateSlice slice1{.slices = {{0, dim1}}}; + for (auto len : slice_shape) { + size_t stride = len / tensor_para_size; + slice0.slices.push_back({start + stride * rank, start + stride * (rank + 1)}); + start += len; + } + weight_slices = {slice0, slice1}; + } + } + loadWeightFromBin((T*)w.kernel, {dim0, dim1}, prefix + ".weight", type, weight_slices); + } + else { // int8, int4 + const int factor = sizeof(float) * 8 / bit_size; + FT_CHECK(dim0 % factor == 0); + const auto f32_type = FtCudaDataType::FP32; + std::vector weight_slices{}; + std::vector bias_slices{}; + if (enable_slice) { + if (slice_dim == 1) { + size_t start = 0; + ConcateSlice slice0{.slices = {{0, dim0}}}; + ConcateSlice slice1{.slices = {{}}}; + for (auto len : slice_shape) { + size_t stride = len / tensor_para_size; + slice1.slices.push_back({start + stride * rank, start + stride * (rank + 
1)}); + start += len; + } + weight_slices = {slice0, slice1}; + + ConcateSlice bias_slice0{.slices = {{0, 1}}}; + bias_slices = {bias_slice0, slice1}; + } + else { + size_t start = 0; + ConcateSlice slice0{.slices = {}}; + ConcateSlice slice1{.slices = {{0, dim1}}}; + for (auto len : slice_shape) { + size_t stride = len / factor / tensor_para_size; + slice0.slices.push_back({start + stride * rank, start + stride * (rank + 1)}); + start += len; + } + weight_slices = {slice0, slice1}; + } + } + loadWeightFromBin((float*)w.kernel, {dim0 / factor, dim1}, prefix + ".qweight", f32_type, weight_slices); + loadWeightFromBin((T*)w.scales, {1, dim1}, prefix + ".scales", type, bias_slices); + loadWeightFromBin((T*)w.zeros, {1, dim1}, prefix + ".zeros", type, bias_slices); + } +} + +template +void LlamaDecoderLayerWeight::mallocWeights() +{ + deviceMalloc((T**)&self_attn_norm_weights, hidden_units_); + deviceMalloc((T**)&ffn_norm_weights, hidden_units_); + + fastertransformer::mallocWeights(self_attn_weights.qkv, attn_bias_); + fastertransformer::mallocWeights(self_attn_weights.output, attn_bias_); + + fastertransformer::mallocWeights(ffn_weights.gating, false); + fastertransformer::mallocWeights(ffn_weights.intermediate, false); + fastertransformer::mallocWeights(ffn_weights.output, false); +} + +template +LlamaDecoderLayerWeight::~LlamaDecoderLayerWeight() +{ + cudaFree((void*)self_attn_norm_weights); + cudaFree((void*)ffn_norm_weights); + + freeWeights(self_attn_weights.qkv); + freeWeights(self_attn_weights.output); + freeWeights(ffn_weights.gating); + freeWeights(ffn_weights.intermediate); + freeWeights(ffn_weights.output); +} + +template +void LlamaDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType model_file_type) +{ + const auto rank_spec = std::to_string(tensor_para_rank_); + const auto type = model_file_type; + + loadWeightFromBin( + (T*)self_attn_norm_weights, {hidden_units_}, dir_path + ".attention_norm.weight", model_file_type); + 
loadWeightFromBin((T*)ffn_norm_weights, {hidden_units_}, dir_path + ".ffn_norm.weight", model_file_type); + + loadWeights(self_attn_weights.qkv, + dir_path + ".attention.w_qkv", + tensor_para_rank_, + type, + tensor_para_size_, + 1, + {head_num_ * size_per_head_, kv_head_num_ * size_per_head_, kv_head_num_ * size_per_head_}); + loadWeights(self_attn_weights.output, dir_path + ".attention.wo", tensor_para_rank_, type, tensor_para_size_, 0); + loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type, tensor_para_size_, 1); + loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type, tensor_para_size_, 1); + loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type, tensor_para_size_, 0); + + // load kv_cache quant scale + // if file not exist, get empty vector + std::string scale_path = dir_path + ".past_kv_scale." + rank_spec + ".weight"; + std::ifstream in(scale_path, std::ios::in); + if (in.is_open()) { + in.close(); + self_attn_weights.past_kv_scale = loadArrayFromBin({2}, scale_path); + } + else { + self_attn_weights.past_kv_scale = {}; + } +} + +template struct LlamaDecoderLayerWeight; +template struct LlamaDecoderLayerWeight; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h new file mode 100644 index 000000000..8b8bb9c52 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.h + +#pragma once + +#include "src/fastertransformer/models/llama/LlamaDenseWeight.h" + +namespace fastertransformer { + +template +struct LlamaDecoderLayerWeight { +public: + LlamaDecoderLayerWeight() = delete; + LlamaDecoderLayerWeight(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + WeightType weight_type, + bool attn_bias, + size_t tensor_para_size, + size_t tensor_para_rank); + ~LlamaDecoderLayerWeight(); + LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight& other) = delete; + LlamaDecoderLayerWeight& operator=(const LlamaDecoderLayerWeight& other) = delete; + + void loadModel(std::string dir_path, FtCudaDataType model_file_type); + + T* self_attn_norm_weights{}; + T* ffn_norm_weights{}; + LlamaAttentionWeight self_attn_weights{}; + LlamaFfnWeight ffn_weights{}; + +private: + size_t head_num_; + size_t kv_head_num_; + size_t size_per_head_; + size_t hidden_units_; + size_t inter_size_; + WeightType weight_type_; + size_t bit_size_; + bool attn_bias_; + size_t tensor_para_size_; + size_t tensor_para_rank_; + bool is_maintain_buffer_ = false; + + void mallocWeights(); +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc new file mode 100644 index 000000000..a644d4669 --- /dev/null +++ 
b/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.cc @@ -0,0 +1,303 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc +#include "src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/models/llama/LlamaNcclGuard.h" +#include "src/fastertransformer/models/llama/llama_kernels.h" +#include "src/fastertransformer/models/llama/llama_utils.h" +#include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/logger.h" +#include "src/fastertransformer/utils/nvtx_utils.h" +#include +// #include + +namespace fastertransformer { + +template +struct SATypeConverter { + using Type = T; +}; + +template<> +struct SATypeConverter { + using Type = uint16_t; +}; + +template +static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, + const T* qkv_bias, + const T* relative_attention_bias, + T* key_cache, + T* value_cache, + T** k_cache_per_sample, + T** v_cache_per_sample, + size_t kv_cache_per_sample_offset, + const int* cache_indir, + T* context_buf, + const bool* finished, + const int* sequence_lengths, + const int 
max_batch_size, + const int inference_batch_size, + const int beam_width, + const int head_num, + const int kv_head_num, + const int size_per_head, + const int rotary_embedding_dim, + const int memory_max_len, + const int* prefix_prompt_lengths, + const int max_prefix_prompt_length, + const int max_input_len, + const int* total_padding_tokens, + const int step, + const float q_scaling, + const int relative_attention_bias_stride, + const T* linear_bias_slopes, + const bool* masked_tokens, + const int* ia3_tasks, + const T* ia3_key_weights, + const T* ia3_value_weights, + const float* qkv_scale_out, + const float* attention_out_scale, + const int int8_mode, + const float* attention_kv_scale, + cudaStream_t stream) +{ + using DataType = typename SATypeConverter::Type; + // Prepare the parameters. + Masked_multihead_attention_params params; + memset(¶ms, 0, sizeof(params)); + // int hidden_units = head_num * size_per_head; + if (qkv_bias != nullptr) { + params.q_bias = reinterpret_cast(qkv_bias); + params.k_bias = reinterpret_cast(qkv_bias) + head_num * size_per_head; + params.v_bias = reinterpret_cast(qkv_bias) + (head_num + kv_head_num) * size_per_head; + } + else { + params.q_bias = nullptr; + params.k_bias = nullptr; + params.v_bias = nullptr; + } + + // Set the output buffer. + params.out = reinterpret_cast(context_buf); + + // Set the input buffers. 
+ // [B, nH + kvH, D] + params.q = reinterpret_cast(qkv_buf); + params.k = reinterpret_cast(qkv_buf) + head_num * size_per_head; + params.v = reinterpret_cast(qkv_buf) + (head_num + kv_head_num) * size_per_head; + + params.stride = (head_num + 2 * kv_head_num) * size_per_head; + params.finished = const_cast(finished); + + FT_CHECK(k_cache_per_sample && v_cache_per_sample); + + params.k_cache = reinterpret_cast(key_cache); + params.v_cache = reinterpret_cast(value_cache); + params.k_cache_per_sample = reinterpret_cast(k_cache_per_sample); + params.v_cache_per_sample = reinterpret_cast(v_cache_per_sample); + params.kv_cache_per_sample_offset = kv_cache_per_sample_offset; + params.k_cache_interleaved = false; + params.cache_indir = cache_indir; + params.batch_size = inference_batch_size; + params.beam_width = beam_width; + params.memory_max_len = memory_max_len; + params.prefix_prompt_lengths = prefix_prompt_lengths; + params.max_prefix_prompt_length = max_prefix_prompt_length; + params.length_per_sample = sequence_lengths; // max_input_length + current output length + // timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation + params.timestep = step + max_prefix_prompt_length - 1; + params.num_heads = head_num; + params.num_kv_heads = kv_head_num; + + params.hidden_size_per_head = size_per_head; + params.rotary_embedding_dim = rotary_embedding_dim; + // Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust) + params.inv_sqrt_dh = 1.F / (sqrtf((float)params.hidden_size_per_head) * q_scaling); + + params.total_padding_tokens = total_padding_tokens; + if (relative_attention_bias != nullptr) { + params.relative_attention_bias = reinterpret_cast(relative_attention_bias); + } + params.relative_attention_bias_stride = relative_attention_bias_stride; + params.masked_tokens = masked_tokens; + + // The slope of linear position bias per head, e.g., ALiBi. 
+ if (linear_bias_slopes != nullptr) { + params.linear_bias_slopes = reinterpret_cast(linear_bias_slopes); + } + params.max_input_length = max_input_len; + + params.ia3_tasks = ia3_tasks; + params.ia3_key_weights = reinterpret_cast(ia3_key_weights); + params.ia3_value_weights = reinterpret_cast(ia3_value_weights); + + params.int8_mode = int8_mode; + + if (int8_mode & QuantPolicy::kCacheKVInt8) { + params.attention_k_scale = attention_kv_scale[0]; + params.attention_v_scale = attention_kv_scale[1]; + } + + PUSH_RANGE("scaled dot-product fusion"); + masked_multihead_attention(params, stream); + POP_RANGE; +} + +template +void LlamaDecoderSelfAttentionLayer::allocateBuffer(size_t batch_size, int key_len, int max_memory_len) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + + const size_t local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_; + + qkv_buf_ = reinterpret_cast( + allocator_->reMalloc(qkv_buf_, sizeof(T) * batch_size * local_q_kv_head_num * size_per_head_, false)); + context_buf_ = + reinterpret_cast(allocator_->reMalloc(context_buf_, sizeof(T) * batch_size * local_hidden_units_, false)); + + is_allocate_buffer_ = true; +} + +template +void LlamaDecoderSelfAttentionLayer::freeBuffer() +{ + if (is_allocate_buffer_) { + allocator_->free((void**)(&qkv_buf_)); + allocator_->free((void**)(&context_buf_)); + // allocator_->free((void**)(&k_cache_buf_)); + // allocator_->free((void**)(&v_cache_buf_)); + is_allocate_buffer_ = false; + } +} + +template +void LlamaDecoderSelfAttentionLayer::forward(TensorMap* output_tensors, + const TensorMap* input_tensors, + const LlamaAttentionWeight* weights) +{ + /** + * input tensors: + * \param input_query [batch_size, hidden_units], + * \param sequence_lengths [batch_size] + * \param step [1] on cpu + * \param finished [batch_size] + * \param total_padding_tokens [batch_size] + * \param layer_id [1], int on cpu + * \param max_seq_len [1] on cpu + * \param masked_tokens [batch_size, memory_len], (optional), NOT USED YET + * 
\param cache_indirection [batch_size / beam_width, beam_width, memory_max_len] (optional) + * + * output tensors: + * \param attention_output [batch_size, hidden_units], + * \param key_cache [batch, local_head_num, memory_max_len, size_per_head] + * \param value_cache [batch, local_head_num, memory_max_len, size_per_head] + */ + + const T* input_query_data = input_tensors->getPtr("input_query"); + const int* sequence_lengths_data = input_tensors->getPtr("sequence_lengths"); + const int* total_padding_len = input_tensors->getPtr("total_padding_tokens"); + const bool* finished_data = input_tensors->getPtr("finished", nullptr); + const bool* masked_tokens_data = input_tensors->getPtr("masked_tokens", nullptr); + const int* cache_indir = input_tensors->getPtr("cache_indirection", nullptr); + + T* hidden_features_data = output_tensors->getPtr("attention_output"); + T** key_cache_ptrs = output_tensors->getPtr("key_cache"); + T** value_cache_ptrs = output_tensors->getPtr("value_cache"); + + const int layer_id = input_tensors->getVal("layer_id"); + + const int max_seq_len = input_tensors->getVal("max_seq_len"); + const int step = input_tensors->getVal("step"); + + const int step_1 = step - 1; + + const int batch_size = input_tensors->at("input_query").shape[0]; + const int beam_width = cache_indir != nullptr ? 
input_tensors->at("cache_indirection").shape[1] : 1; + + allocateBuffer(batch_size, step, max_seq_len); + + PUSH_RANGE("qkv_gemm"); + linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv); + POP_RANGE; + + const auto kv_cache_layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_; + const int memory_len = max_seq_len; + + fusedQKV_masked_attention_dispatch( + qkv_buf_, + weights->qkv.bias, // query_weight.bias, + nullptr, // relative_attention_bias, + nullptr, + nullptr, + key_cache_ptrs, + value_cache_ptrs, + kv_cache_layer_offset, + cache_indir, + context_buf_, + finished_data, + sequence_lengths_data, // NOTE: current seq len including padding (fixed after meeting the finished id) + batch_size, + batch_size, + beam_width, + local_head_num_, + local_kv_head_num_, + size_per_head_, + rotary_embedding_dim_, + memory_len, + nullptr, // prefix_prompt_lengths + 0, // max_prefix_prompt_length + 0, // max_input_length, not used w/o linear_bias_slopes + input_tensors->getPtr("total_padding_tokens", nullptr), + step, + 1.f, // q_scaling + 0, // relative_attention_bias_stride + nullptr, // linear_bias_slopes + nullptr, // masked_tokens_data, + nullptr, // ia3_tasks + nullptr, // ia3_key_weights + nullptr, // ia3_value_weights + nullptr, // qkv_scale_out + nullptr, // attention_out_scale + quant_policy_, // int8_mode + weights->past_kv_scale.data(), // attention kv scale + stream_); + sync_check_cuda_error(); + + linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output); + + if (tensor_para_.world_size_ > 1) { + NcclGuard nccl_guard(tensor_para_, stream_); + ftNcclAllReduceSum( + hidden_features_data, hidden_features_data, batch_size * hidden_units_, tensor_para_, stream_); + sync_check_cuda_error(); + } + + if (is_free_buffer_after_forward_) { + freeBuffer(); + } + + // LOG(WARNING); +} + +template class LlamaDecoderSelfAttentionLayer; +template class LlamaDecoderSelfAttentionLayer; + +} // namespace 
fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h b/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h new file mode 100644 index 000000000..86b9afb49 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.h + +#pragma once + +#include "src/fastertransformer/models/llama/LlamaDenseWeight.h" +#include "src/fastertransformer/models/llama/LlamaLinear.h" +#include "src/fastertransformer/utils/Tensor.h" +#include "src/fastertransformer/utils/nccl_utils.h" + +namespace fastertransformer { + +template +class LlamaDecoderSelfAttentionLayer { +public: + void freeBuffer(); + void allocateBuffer(size_t batch_size, int key_len, int max_memory_len); + + LlamaDecoderSelfAttentionLayer(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t rotary_embedding_dim, + bool neox_rotary_style, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + int quant_policy): + head_num_(head_num), + kv_head_num_(kv_head_num), + size_per_head_(size_per_head), + hidden_units_(head_num * size_per_head), + local_head_num_(head_num / tensor_para.world_size_), + local_kv_head_num_(kv_head_num_ / tensor_para.world_size_), + local_hidden_units_(hidden_units_ / tensor_para.world_size_), + rotary_embedding_dim_(rotary_embedding_dim), + neox_rotary_style_(neox_rotary_style), + tensor_para_(tensor_para), + stream_(stream), + linear_(cublas_wrapper, stream), + allocator_(allocator), + is_free_buffer_after_forward_(is_free_buffer_after_forward), + quant_policy_(quant_policy) + { + } + + ~LlamaDecoderSelfAttentionLayer() + { + freeBuffer(); + } + + void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaAttentionWeight* weights); + +private: + const size_t head_num_; + const size_t kv_head_num_; + const size_t size_per_head_; + const size_t hidden_units_; + const size_t local_head_num_; + const size_t local_kv_head_num_; + const size_t local_hidden_units_; + const size_t rotary_embedding_dim_; + const bool 
is_free_buffer_after_forward_; + const int quant_policy_; + + const bool neox_rotary_style_; + + NcclParam tensor_para_; + + cudaStream_t stream_; + IAllocator* allocator_; + LlamaLinear linear_; + + T* qkv_buf_ = nullptr; + T* context_buf_ = nullptr; + // T* weight_buf_ = nullptr; + // T* k_cache_buf_{}; + // T* v_cache_buf_{}; + + // T* tmp_k_cache_buf_{}; + // T* tmp_v_cache_buf_{}; + // T* tmp_cache_buf_{}; + + bool is_allocate_buffer_{}; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaDenseWeight.h b/src/fastertransformer/models/llama/LlamaDenseWeight.h new file mode 100644 index 000000000..3d3776698 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaDenseWeight.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/DenseWeight.h + +#pragma once + +#include "src/fastertransformer/layers/FfnWeight.h" +#include "src/fastertransformer/layers/attention_layers/AttentionWeight.h" +#include "src/fastertransformer/utils/cuda_utils.h" + +namespace fastertransformer { + +enum class WeightType : int +{ + kFP32, + kFP16, + kFP8, // not supported yet + kINT8, + kINT4 +}; + +inline size_t getBitSize(WeightType type) +{ + switch (type) { + case WeightType::kFP32: + return 32; + case WeightType::kFP16: + return 16; + case WeightType::kFP8: + return 8; + case WeightType::kINT8: + return 8; + case WeightType::kINT4: default: // default keeps a value-returning function from flowing off the end (UB) + return 4; + } +} + +template +struct LlamaDenseWeight { + + size_t input_dims; + size_t output_dims; + void* kernel; + WeightType type; + T* bias; + T* scales; + T* zeros; +}; + +template +struct LlamaAttentionWeight { + LlamaDenseWeight qkv; + LlamaDenseWeight output; + std::vector past_kv_scale; +}; + +template +struct LlamaFfnWeight { + LlamaDenseWeight gating; + LlamaDenseWeight intermediate; + LlamaDenseWeight output; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaFfnLayer.cc b/src/fastertransformer/models/llama/LlamaFfnLayer.cc new file mode 100644 index 000000000..3ed15f4c4 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaFfnLayer.cc @@ -0,0 +1,113 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/FfnLayer.h + +#include "src/fastertransformer/models/llama/LlamaFfnLayer.h" +#include "src/fastertransformer/kernels/activation_kernels.h" +#include "src/fastertransformer/models/llama/LlamaNcclGuard.h" +#include "src/fastertransformer/utils/nvtx_utils.h" +// #include + +namespace fastertransformer { + +template +void LlamaFfnLayer::allocateBuffer(size_t token_num) +{ + inter_buf_ = (T*)allocator_->reMalloc(inter_buf_, sizeof(T) * token_num * inter_size_, false); + gating_buf_ = (T*)allocator_->reMalloc(gating_buf_, sizeof(T) * token_num * inter_size_, false); + is_allocate_buffer_ = true; +} + +template +void LlamaFfnLayer::freeBuffer() +{ + if (is_allocate_buffer_) { + allocator_->free((void**)&inter_buf_); + allocator_->free((void**)&gating_buf_); + is_allocate_buffer_ = false; + } +} + +template +void LlamaFfnLayer::activation(int num_token) +{ + invokeGenericActivation(gating_buf_, + (const T*)nullptr, // bias + inter_buf_, + (const T*)nullptr, // gated_bias + nullptr, // ia3_tasks + (const T*)nullptr, // ia3_weights + num_token, // m + inter_size_, // n + 0, // int8_mode + nullptr, // activation_in + nullptr, // activation_out + nullptr, // padding_offset + 0, // seq_len + stream_); + sync_check_cuda_error(); +} + +template +void LlamaFfnLayer::forward(TensorMap* output_tensors, + const TensorMap* input_tensors, + const LlamaFfnWeight* weights) +{ + /** + * input_tensors: + * \param ffn_input [token_num, hidden_dimension] + * + * 
output_tensors: + * \param ffn_output [token_num, hidden_dimension] + */ + + const size_t num_token = input_tensors->at("ffn_input").shape[0]; + // LOG(WARNING); + + allocateBuffer(num_token); + + const T* ffn_input_data = input_tensors->at("ffn_input").getPtr(); + T* ffn_output_data = output_tensors->at("ffn_output").getPtr(); + + PUSH_RANGE("ffn"); + // TODO: fuse the two GEMMs with activation + linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating); + + linear_.forward(inter_buf_, ffn_input_data, num_token, weights->intermediate); + + activation(num_token); + + linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output); + POP_RANGE; + + if (tensor_para_.world_size_ > 1) { + NcclGuard nccl_guard(tensor_para_, stream_); + ftNcclAllReduceSum(ffn_output_data, ffn_output_data, num_token * hidden_units_, tensor_para_, stream_); + sync_check_cuda_error(); + } + + if (is_free_buffer_after_forward_) { + freeBuffer(); + } + // LOG(WARNING); +} + +template class LlamaFfnLayer; +template class LlamaFfnLayer; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaFfnLayer.h b/src/fastertransformer/models/llama/LlamaFfnLayer.h new file mode 100644 index 000000000..4dcdb2589 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaFfnLayer.h @@ -0,0 +1,85 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/FfnLayer.cc + +#pragma once + +// #include "src/fastertransformer/layers/FfnLayer.h" +#include "src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h" +#include "src/fastertransformer/models/llama/LlamaLinear.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/nccl_utils.h" +#include + +namespace fastertransformer { + +template +class LlamaFfnLayer { +public: + LlamaFfnLayer(size_t head_num, + size_t size_per_head, + size_t inter_size, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward): + head_num_(head_num), + size_per_head_(size_per_head), + inter_size_(inter_size / tensor_para.world_size_), + hidden_units_(head_num * size_per_head), + stream_(stream), + linear_(cublas_wrapper, stream), + allocator_(allocator), + tensor_para_(tensor_para), + is_free_buffer_after_forward_(is_free_buffer_after_forward) + { + } + + ~LlamaFfnLayer() + { + freeBuffer(); + } + + void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaFfnWeight* weights); + +private: + void allocateBuffer(size_t token_num); + + void freeBuffer(); + + void activation(int num_token); + + size_t head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t hidden_units_; + cudaStream_t stream_; + LlamaLinear linear_; + IAllocator* allocator_; + bool is_free_buffer_after_forward_; + + T* gating_buf_{}; + T* inter_buf_{}; + + NcclParam tensor_para_; + + bool is_allocate_buffer_{}; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaInstanceComm.h b/src/fastertransformer/models/llama/LlamaInstanceComm.h new file mode 100644 index 000000000..559abdadd --- /dev/null +++ 
b/src/fastertransformer/models/llama/LlamaInstanceComm.h @@ -0,0 +1,34 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#pragma once + +#include "src/fastertransformer/models/llama/Barrier.h" +#include "src/fastertransformer/utils/instance_comm.h" + +namespace fastertransformer { + +class LlamaInstanceComm: public AbstractInstanceComm { +public: + LlamaInstanceComm(int count): barrier_(count) {} + + void barrier() override + { + barrier_.wait(); + } + + void setSharedObject(void* p) override + { + ptr = p; + } + + void* getSharedObject() override + { + return ptr; + } + +private: + Barrier barrier_; + void* ptr{}; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaLinear.h b/src/fastertransformer/models/llama/LlamaLinear.h new file mode 100644 index 000000000..eb0e9c7b6 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaLinear.h @@ -0,0 +1,61 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#pragma once + +#include "src/fastertransformer/models/llama/LlamaDenseWeight.h" +#include "src/fastertransformer/models/llama/llama_kernels.h" +#include "src/fastertransformer/utils/cublasMMWrapper.h" +#include "src/fastertransformer/utils/cuda_utils.h" + +namespace fastertransformer { + +template +class LlamaLinear { +public: + LlamaLinear(cublasMMWrapper* cublas_wrapper, cudaStream_t stream): cublas_wrapper_(cublas_wrapper), stream_(stream) + { + } + + void forward(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight& weight) + { + switch (weight.type) { + case WeightType::kFP16: + case WeightType::kFP32: + forwardFp(output_data, input_data, batch_size, weight); + break; + case WeightType::kINT4: + forwardInt4(output_data, input_data, batch_size, weight); + break; + default: + FT_CHECK(0); + } + } + +private: + void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight& weight) + { + cublas_wrapper_->Gemm(CUBLAS_OP_N, + CUBLAS_OP_N, + weight.output_dims, + 
batch_size, + weight.input_dims, + (const T*)weight.kernel, + weight.output_dims, + input_data, + weight.input_dims, + output_data, + weight.output_dims); + sync_check_cuda_error(); + } + + void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight& weight) + { + FT_CHECK_WITH_INFO(0, "Not implemented"); + } + +private: + cublasMMWrapper* cublas_wrapper_; + cudaStream_t stream_{}; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaNcclGuard.h b/src/fastertransformer/models/llama/LlamaNcclGuard.h new file mode 100644 index 000000000..5f9e14046 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaNcclGuard.h @@ -0,0 +1,92 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#pragma once + +#include "src/fastertransformer/utils/nccl_utils.h" +#include +#include +#include +#include +#include + +namespace fastertransformer { + +struct NcclGuard { + static constexpr int kMaxGroupCount = 32; + + static std::mutex& globalNcclMutex() + { + static std::mutex inst; + return inst; + } + + struct GroupState { + std::mutex mutex; + std::condition_variable cv; + int ref_count; + }; + + static GroupState& groupState(int group_id) + { + static std::array array{}; + FT_CHECK(group_id < kMaxGroupCount); + return array[group_id]; + } + + NcclGuard(NcclParam tensor_para, cudaStream_t stream, bool barrier = false): + tensor_para_(tensor_para), stream_(stream), barrier_(barrier) + { + if (is_active()) { + auto& group = groupState(tensor_para_.group_id_); + if (tensor_para_.rank_ == 0) { + /// TODO: use std::optional after switching to C++17 + global_nccl_lock_ = std::make_unique>(globalNcclMutex()); + { + std::lock_guard lock(group.mutex); + group.ref_count = tensor_para_.world_size_; + } + group.cv.notify_all(); + } + else { + std::unique_lock lock(group.mutex); + group.cv.wait(lock, [&] { return group.ref_count > 0; }); + } + } + } + + ~NcclGuard() + { + if (is_active()) { + 
ftNcclStreamSynchronize(tensor_para_, NcclParam{}, stream_); + + auto& group = groupState(tensor_para_.group_id_); + + int value = -1; + { + std::lock_guard lock(group.mutex); + value = --group.ref_count; + } + if (value == 0) { + group.cv.notify_all(); + } + else if (barrier_ || tensor_para_.rank_ == 0) { + std::unique_lock lock(group.mutex); + group.cv.wait(lock, [&] { return group.ref_count == 0; }); + } + + // rank 0 unlocks global NCCL mutex automatically + } + } + + bool is_active() + { + return barrier_ || (ftNcclGroupCount() > 1 && tensor_para_.world_size_ > 1); + } + + NcclParam tensor_para_; + cudaStream_t stream_; + bool barrier_; + std::unique_ptr> global_nccl_lock_; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaV2.cc b/src/fastertransformer/models/llama/LlamaV2.cc new file mode 100644 index 000000000..2ce57b452 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaV2.cc @@ -0,0 +1,603 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. + * Copyright (c) 2022, SK Telecom Authored by A. Dialog + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc + +#include "src/fastertransformer/models/llama/LlamaV2.h" +#include "src/fastertransformer/kernels/decoding_kernels.h" +#include "src/fastertransformer/kernels/gpt_kernels.h" +#include "src/fastertransformer/models/llama/LlamaBatch.h" +#include "src/fastertransformer/models/llama/LlamaNcclGuard.h" +#include "src/fastertransformer/models/llama/LlamaWeight.h" +#include "src/fastertransformer/models/llama/Request.h" +#include "src/fastertransformer/models/llama/llama_utils.h" +#include "src/fastertransformer/utils/Tensor.h" +#include "src/fastertransformer/utils/cuda_utils.h" +#include +#include +#include +#include + +namespace fastertransformer { + +template +LlamaV2::LlamaV2(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + float norm_eps, + int max_batch_size, + int max_context_token_num, + int session_len, + int step_length, + int start_id, + int end_id, + int cache_max_entry_count, + int cache_chunk_size, + int quant_policy, + bool use_context_fmha, + std::shared_ptr shared_state, + LlamaWeight* weights, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop): + head_num_(head_num), + size_per_head_(size_per_head), + inter_size_(inter_size), + num_layer_(num_layer), + vocab_size_(vocab_size), + rotary_embedding_dim_(rotary_embedding_dim), + rmsnorm_eps_(norm_eps), + start_id_(start_id), + end_id_(end_id), + hidden_units_(head_num * size_per_head), + local_head_num_(head_num / tensor_para.world_size_), + weights_(weights), + tensor_para_(tensor_para), + stream_(stream), + cublas_wrapper_(cublas_wrapper), + allocator_(allocator), + is_free_buffer_after_forward_(is_free_buffer_after_forward), + 
cuda_device_prop_(cuda_device_prop), + debug_(isDebug()), + step_length_(step_length), + batch_(max_batch_size, max_context_token_num, session_len, this), + shared_state_(shared_state) + +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + FT_CHECK(vocab_size_ % tensor_para_.world_size_ == 0); + FT_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_); + + size_t elem_bits = 0; + if (quant_policy & QuantPolicy::kCacheKVInt8) { + elem_bits = sizeof(int8_t) * 8; + if (use_context_fmha) { + FT_LOG_ERROR("use_context_fmha not support int8"); + assert(0); + } + } + else { + elem_bits = sizeof(T) * 8; + } + + const size_t local_kv_head_num = kv_head_num / tensor_para.world_size_; + + kv_cache_mgr_ = std::make_unique(num_layer_, + local_kv_head_num, + size_per_head_, + session_len, + elem_bits, + cache_max_entry_count, + cache_chunk_size, + tensor_para.rank_, + allocator); + initialize(kv_head_num, use_context_fmha, quant_policy); + start(); +} + +template +LlamaV2::~LlamaV2() +{ + internal_thread_.join(); + + delete decoder_; + delete dynamic_decode_layer_; + delete context_decoder_; +} + +template +void LlamaV2::initialize(size_t kv_head_num, bool use_context_fmha, int quant_policy) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + + context_decoder_ = new LlamaContextDecoder(head_num_, + kv_head_num, + size_per_head_, + inter_size_, + num_layer_, + rotary_embedding_dim_, + rmsnorm_eps_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + use_context_fmha, + quant_policy); + + decoder_ = new LlamaDecoder(head_num_, + kv_head_num, + size_per_head_, + inter_size_, + num_layer_, + rotary_embedding_dim_, + rmsnorm_eps_, + tensor_para_, + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + quant_policy); + + dynamic_decode_layer_ = new DynamicDecodeLayer(vocab_size_, + vocab_size_, // vocab_size_padded, + 0, // end_id, deprecated + stream_, + cublas_wrapper_, + allocator_, + is_free_buffer_after_forward_, + 
cuda_device_prop_); +} + +template +void LlamaV2::embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + // ! This kernel can't be used in context decoding + invokeEmbeddingLookupPosEncodingPadCount(embeddings, + weights_->pre_decoder_embedding_table, + static_cast(nullptr), // position encoding + token_ids_buf, + static_cast(nullptr), // padding count, not used w/o pos-code + batch_size, + hidden_units_, + static_cast(1.), // scale + step, // step, used int index into output_ids_buf_ + batch_size, // token_num + 0, // ite + stream_); + sync_check_cuda_error(); +} + +template +void LlamaV2::contextDecode(T* decoder_output, + uintptr_t* k_cache_ptr, + uintptr_t* v_cache_ptr, + T* context_decoder_input_buf, + T* context_decoder_output_buf, + const int* input_ids, + const int* input_length, + const int* history_length, + const int* context_length, + size_t token_num, + size_t max_input_len, + size_t max_context_len, + size_t session_len, + size_t batch_size) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + + if (tensor_para_.rank_ == 0) { + FT_LOG_INFO("context decoding start"); + } + + invokeInputIdsEmbeddingLookupPosEncoding(context_decoder_input_buf, + nullptr, // processed somewhere else + weights_->pre_decoder_embedding_table, + static_cast(nullptr), + pPromptTuningParam{}, + input_ids, + 0, // only used for position encoding + token_num, + token_num, + 1, + hidden_units_, + stream_); + sync_check_cuda_error(); + + const auto dtype = getTensorType(); + const auto bsz = batch_size; + + const int max_q_len = max_input_len; + const int max_kv_len = max_context_len; + const int max_seq_len = session_len; + + std::unordered_map decoder_input_tensors{ + {"decoder_input", {MEMORY_GPU, dtype, {token_num, hidden_units_}, context_decoder_input_buf}}, + {"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}}, + {"input_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, input_length}}, +
{"history_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, history_length}}, + {"context_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, context_length}}, + {"max_q_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_q_len}}, + {"max_kv_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_kv_len}}, + {"max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_seq_len}}, + }; + + std::unordered_map decoder_output_tensors{ + {"decoder_output", {MEMORY_GPU, dtype, {bsz, max_input_len, hidden_units_}, context_decoder_output_buf}}, + {"key_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, k_cache_ptr}}, + {"value_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, v_cache_ptr}}, + {"last_token_hidden_units", {MEMORY_GPU, dtype, {bsz, hidden_units_}, decoder_output}}}; + + context_decoder_->forward(&decoder_output_tensors, &decoder_input_tensors, &weights_->decoder_layer_weights); + + if (tensor_para_.rank_ == 0) { + FT_LOG_INFO("context decoding end"); + } +} + +template +void LlamaV2::decoderForward(T* decoder_output, + uintptr_t* k_cache_ptr, + uintptr_t* v_cache_ptr, + T* decoder_input, + const int* sequence_length, + const int* total_padding_count, + bool* finished, + int step, + int ite, + size_t session_len, + size_t batch_size) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + + const int max_seq_len = session_len; + const auto dtype = getTensorType(); + + // max_input_length is not used w/o linear_bias_slopes + // sequence_lengths_ will be incremented in dynamic decode + std::unordered_map decoder_input_tensors{ + {"decoder_input", {MEMORY_GPU, dtype, {batch_size, hidden_units_}, decoder_input}}, + {"sequence_lengths", {MEMORY_GPU, TYPE_INT32, {batch_size}, sequence_length}}, + {"total_padding_tokens", {MEMORY_GPU, TYPE_INT32, {batch_size}, total_padding_count}}, + {"max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_seq_len}}, + {"finished", {MEMORY_GPU, TYPE_BOOL, {batch_size}, finished}}, + {"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}}, + {"step", {MEMORY_CPU, TYPE_INT32, {1}, &step}}, +
{"ite", {MEMORY_CPU, TYPE_INT32, {1}, &ite}}, + }; + + // LOG(ERROR) << key_cache_ << " " << value_cache_; + std::unordered_map decoder_output_tensors{ + {"decoder_output", {MEMORY_GPU, dtype, {batch_size, hidden_units_}, decoder_output}}, + {"key_cache", {MEMORY_GPU, TYPE_UINT64, {batch_size}, k_cache_ptr}}, + {"value_cache", {MEMORY_GPU, TYPE_UINT64, {batch_size}, v_cache_ptr}}, + }; + + decoder_->forward(&decoder_output_tensors, &decoder_input_tensors, &weights_->decoder_layer_weights); +} + +template +void LlamaV2::postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + cudaDataType_t data_type = getCudaDataType(); + float alpha = 1.f; + float beta = 0.f; + if (tensor_para_.world_size_ == 1) { + cublas_wrapper_->Gemm(CUBLAS_OP_T, + CUBLAS_OP_N, + vocab_size_, // n + batch_size, + hidden_units_, // k + &alpha, + weights_->post_decoder_embedding_kernel, + data_type, + hidden_units_, // k + decoder_output, + data_type, + hidden_units_, // k + &beta, + logits, + CUDA_R_32F, + vocab_size_, // n + CUDA_R_32F, + cublasGemmAlgo_t(-1)); + } + else { + FT_CHECK(vocab_size_ % tensor_para_.world_size_ == 0); + const size_t local_vocab_size = vocab_size_ / tensor_para_.world_size_; + cublas_wrapper_->Gemm(CUBLAS_OP_T, + CUBLAS_OP_N, + local_vocab_size, // n + batch_size, + hidden_units_, // k + &alpha, + weights_->post_decoder_embedding_kernel + + tensor_para_.rank_ * local_vocab_size * hidden_units_, + data_type, + hidden_units_, // k + decoder_output, + data_type, + hidden_units_, // k + &beta, + local_logits + tensor_para_.rank_ * batch_size * local_vocab_size, + CUDA_R_32F, + local_vocab_size, // n + CUDA_R_32F, + cublasGemmAlgo_t(-1)); + { + NcclGuard nccl_guard(tensor_para_, stream_); + ftNcclAllGather(local_logits, // send_buf + local_logits, // recv_buf + batch_size * local_vocab_size, // data_size + tensor_para_.rank_, + tensor_para_, + stream_); + } + invokeTransposeAxis01(logits, 
local_logits, tensor_para_.world_size_, batch_size, local_vocab_size, stream_); + sync_check_cuda_error(); + } +} + +template +void LlamaV2::dynamicDecode(int* token_ids, + bool* finished, + int* sequence_length, + bool* should_stop, + TensorMap* inputs, + TensorMap* outputs, + const float* logits, + const uint32_t* seq_limit_len, + const int* context_length, + const int* end_ids, + int step, + int ite, + size_t max_context_len, + size_t token_ids_len, + size_t batch_size) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + int local_batch_size = (int)batch_size; + + std::unordered_map dynamic_decode_input_tensors{ + {"logits", {MEMORY_GPU, TYPE_FP32, {batch_size, (size_t)1, vocab_size_}, logits}}, + {"step", {MEMORY_CPU, TYPE_INT32, {1}, &step}}, + {"max_input_length", {MEMORY_CPU, TYPE_INT32, {1}, &max_context_len}}, + {"sequence_limit_length", {MEMORY_GPU, TYPE_UINT32, {batch_size}, seq_limit_len}}, + {"input_lengths", {MEMORY_GPU, TYPE_INT32, {batch_size, 1}, context_length}}, + {"ite", {MEMORY_CPU, TYPE_UINT32, {1}, &ite}}, + {"end_id", {MEMORY_GPU, TYPE_INT32, {batch_size}, end_ids}}, + {"local_batch_size", {MEMORY_CPU, TYPE_INT32, {1}, &local_batch_size}}, + }; + + const std::vector optional_inputs{"stop_words_list", + "bad_words_list", + "runtime_top_k", + "runtime_top_p", + "temperature", + "repetition_penalty", + "random_seed"}; + for (const auto& key : optional_inputs) { + if (inputs->isExist(key)) { + dynamic_decode_input_tensors.insert({key, inputs->at(key)}); + } + } + + std::unordered_map dynamic_decode_output_tensors{ + {"output_ids", {MEMORY_GPU, TYPE_INT32, {token_ids_len, batch_size, 1U}, token_ids}}, + {"finished", {MEMORY_GPU, TYPE_BOOL, {batch_size}, finished}}, + {"sequence_length", {MEMORY_GPU, TYPE_INT32, {batch_size}, sequence_length}}, + {"should_stop", {MEMORY_CPU, TYPE_BOOL, {1}, should_stop}}}; + + const std::vector optional_outputs{"cum_log_probs", "output_log_probs"}; + for (const auto& key : optional_outputs) { + if (outputs->isExist(key)) { 
+ dynamic_decode_output_tensors.insert({key, outputs->at(key)}); + } + } + + dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors); +} + +template +void LlamaV2::internalThreadEntry(int device_id) +{ + FT_LOG_INFO("[internalThreadEntry] %d", (int)tensor_para_.rank_); + check_cuda_error(cudaSetDevice(device_id)); + + auto& request_queue = shared_state_->request_queue; + auto& infer_requests = shared_state_->infer_requests; + auto& stop_requests = shared_state_->stop_requests; + + while (1) { + if (tensor_para_.rank_ == 0) { + const int free_slot_count = batch_.maxSize() - batch_.size() + batch_.finishedCount(); + const bool is_empty = free_slot_count == batch_.maxSize(); + + request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty); + + batch_.verifyRequests(stop_requests, infer_requests); + } + + // wait while rank-0 is dequeueing + shared_state_->barrier->wait(); + + bool modified = false; + + if (!(batch_.finishedCount() == 0 && stop_requests.empty() && infer_requests.empty())) { + batch_.handleStopRequests(stop_requests); + batch_.synchronize(); + modified = true; + } + + const int infer_request_count = infer_requests.size(); + + if (!infer_requests.empty()) { + batch_.initialize(infer_requests); // reinitialize when new requests come, possible buffer allocation + batch_.contextDecode(); + modified = true; + } + + // wait while shared stop/infer_requests is being used + shared_state_->barrier->wait(); + + if (batch_.size()) { + if (modified) { + batch_.initializeGeneration(); + batch_.initializeSampling(infer_request_count); + } + for (int i = 0; i < step_length_; ++i) { + if (!batch_.generate()) { + break; + } + } + batch_.finish(); + } + } + + FT_CHECK(0); +} + +template +void LlamaV2::start() +{ + int device_id = -1; + check_cuda_error(cudaGetDevice(&device_id)); + internal_thread_ = std::thread(&LlamaV2::internalThreadEntry, this, device_id); +} + +static inline Tensor slice(const Tensor& tensor, 
int index) +{ + auto shape = tensor.shape; + if (shape.at(0) == 1) { + return tensor; + } + shape[0] = 1; + const auto offset = std::accumulate(shape.begin(), shape.end(), (size_t)index, std::multiplies<>{}); + return tensor.slice(shape, offset); +} + +// ! implicit conversion from `unordered_map` to `TensorMap` drops 0-sized tensors +static inline TensorMap slice(const std::unordered_map& src, int index) +{ + TensorMap dst; + for (const auto& kv : src) { + dst.insert({kv.first, slice(kv.second, index)}); + } + return dst; +} + +template +void LlamaV2::forward(std::unordered_map* outputs, + const std::unordered_map* inputs, + Control control) +{ + if (debug_) { + if (tensor_para_.rank_ == 0) { + for (const auto& kv : *inputs) { + FT_LOG_INFO("[forward][rank=%d] INPUT: %s", (int)tensor_para_.rank_, format(kv).c_str()); + } + for (const auto& kv : *outputs) { + FT_LOG_INFO("[forward][rank=%d] OUTPUT: %s", (int)tensor_para_.rank_, format(kv).c_str()); + } + } + } + + const int batch_size = outputs->at("output_ids").shape[0]; + + const auto rank = tensor_para_.rank_; + + std::vector> requests(batch_size); + + // rank-0 allocates all requests for the batch + if (rank == 0) { + for (int i = 0; i < batch_size; ++i) { + requests[i] = std::make_shared(); + requests[i]->inputs.resize(tensor_para_.world_size_); + requests[i]->outputs.resize(tensor_para_.world_size_); + } + control.comm->setSharedObject(&requests); + } + + control.comm->barrier(); + + if (rank != 0) { + requests = *(std::vector>*)control.comm->getSharedObject(); + } + + for (int i = 0; i < batch_size; ++i) { + auto& r = requests[i]; + + r->inputs[rank] = slice(*inputs, i); + r->outputs[rank] = slice(*outputs, i); + + if (rank == 0) { + r->id = r->inputs[rank].getVal("CORRID", i); + r->start_flag = r->inputs[rank].getVal("START", 1); + r->end_flag = r->inputs[rank].getVal("END", 1); + r->stop_flag = r->inputs[rank].getVal("STOP", 0); + r->stream_cb = control.callback; + } + } + + control.comm->barrier(); + + // 
rank-0 now takes the ownership of `requests` + // rank-0 submits the tasks and wait for finish + std::vector error_codes; + bool has_error = 0; + if (rank == 0) { + FT_LOG_INFO("[forward] Enqueue requests"); + auto futures = shared_state_->request_queue.enqueue(std::move(requests)); + + FT_LOG_INFO("[forward] Wait for requests to complete ..."); + for (auto& f : futures) { + auto ec = f.get(); + error_codes.push_back(ec); + if (ec) { + has_error = true; + } + } + } + + // prevents request tensors being freed before the batch completes + control.comm->barrier(); + + if (rank == 0 && has_error) { + std::stringstream ss; + for (int i = 0; i < error_codes.size(); ++i) { + ss << (i ? "" : " ") << error_codes[i]; + } + throw std::runtime_error(ss.str()); + } +} + +template class LlamaV2; +template class LlamaV2; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaV2.h b/src/fastertransformer/models/llama/LlamaV2.h new file mode 100644 index 000000000..dcbfe5ffc --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaV2.h @@ -0,0 +1,192 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h + +#pragma once + +#include "src/fastertransformer/layers/DynamicDecodeLayer.h" +#include "src/fastertransformer/models/llama/Barrier.h" +#include "src/fastertransformer/models/llama/LlamaBatch.h" +#include "src/fastertransformer/models/llama/LlamaContextDecoder.h" +#include "src/fastertransformer/models/llama/LlamaDecoder.h" +#include "src/fastertransformer/models/llama/LlamaWeight.h" +#include "src/fastertransformer/models/llama/Request.h" +#include "src/fastertransformer/utils/allocator.h" +#include "src/fastertransformer/utils/cublasMMWrapper.h" +#include "src/fastertransformer/utils/instance_comm.h" +#include "src/fastertransformer/utils/nccl_utils.h" +#include + +namespace fastertransformer { + +template +class LlamaV2 { +public: + struct SharedState { + std::vector> infer_requests; + std::vector> stop_requests; + RequestQueue request_queue; + std::shared_ptr barrier; + }; + + ~LlamaV2(); + + LlamaV2(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t num_layer, + size_t vocab_size, + size_t rotary_embedding_dim, + float norm_eps, + int max_batch_size, + int max_context_token_num, + int session_len, + int step_length, + int start_id, + int end_id, + int cache_max_entry_count, + int cache_chunk_size, + int quant_policy, + bool use_context_fmha, + std::shared_ptr shared_state, + LlamaWeight* weights, + NcclParam tensor_para, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + bool is_free_buffer_after_forward, + cudaDeviceProp* cuda_device_prop); + + struct Control { + AbstractInstanceComm* comm; + Request::Callback callback; + }; + + void forward(std::unordered_map* outputs, + const std::unordered_map* inputs, + Control control); + + void stop(const std::vector& seq_ids); + + size_t vocab_size() const noexcept + { + return vocab_size_; + } + 
+private: + friend class Batch; + + void internalThreadEntry(int device_id); + + void initialize(size_t kv_head_num, bool use_context_fmha, int quant_policy); + + void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step); + + void contextDecode(T* deocder_output, + uintptr_t* k_cache_ptr, + uintptr_t* v_cache_ptr, + T* context_decoder_input_buf, + T* context_decoder_output_buf, + const int* input_ids, + const int* input_length, + const int* history_length, + const int* context_length, + size_t token_num, + size_t max_input_len, + size_t max_context_len, + size_t session_len, + size_t batch_size); + + void decoderForward(T* decoder_output, + uintptr_t* k_cache_ptr, + uintptr_t* v_cache_ptr, + T* decoder_input, + const int* sequence_length, + const int* total_padding_count, + bool* finished, + int step, + int ite, + size_t session_len, + size_t batch_size); + + void postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size); + + void dynamicDecode(int* token_ids, + bool* finished, + int* sequence_length, + bool* should_stop, + TensorMap* inputs, + TensorMap* outputs, + const float* logits, + const uint32_t* seq_limit_len, + const int* context_length, + const int* end_ids, + int step, + int ite, + size_t max_context_len, + size_t token_ids_len, + size_t batch_size); + + void start(); + +private: + friend class LlamaBatch; + + const size_t head_num_; + const size_t size_per_head_; + const size_t inter_size_; + const size_t num_layer_; + const size_t vocab_size_; + const size_t rotary_embedding_dim_; + float rmsnorm_eps_ = 1e-6f; + + static constexpr bool neox_rotary_style_ = false; + + const int start_id_; + const int end_id_; + const size_t hidden_units_; + + const size_t local_head_num_; + NcclParam tensor_para_; + + cudaStream_t stream_; + cublasMMWrapper* cublas_wrapper_; + IAllocator* allocator_; + bool is_free_buffer_after_forward_; + cudaDeviceProp* cuda_device_prop_; + + const bool 
debug_{false}; + + std::unique_ptr kv_cache_mgr_; + + LlamaWeight* weights_{}; + LlamaDecoder* decoder_{}; + LlamaContextDecoder* context_decoder_{}; + DynamicDecodeLayer* dynamic_decode_layer_{}; + + const int step_length_; + LlamaBatch batch_; + std::shared_ptr shared_state_; + + std::thread internal_thread_; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaWeight.cc b/src/fastertransformer/models/llama/LlamaWeight.cc new file mode 100644 index 000000000..e693fe128 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaWeight.cc @@ -0,0 +1,132 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc + +#include "src/fastertransformer/models/llama/LlamaWeight.h" + +namespace fastertransformer { + +template +LlamaWeight::LlamaWeight(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t vocab_size, + size_t num_layer, + WeightType weight_type, + bool attn_bias, + size_t tensor_para_size, + size_t tensor_para_rank, + int prefix_cache_len): + hidden_units_(head_num * size_per_head), + inter_size_(inter_size), + vocab_size_(vocab_size), + num_layer_(num_layer), + weight_type_(weight_type), + tensor_para_size_(tensor_para_size), + tensor_para_rank_(tensor_para_rank), + prefix_cache_len_(prefix_cache_len) +{ + decoder_layer_weights.reserve(num_layer_); + for (unsigned l = 0; l < num_layer_; ++l) { + decoder_layer_weights.push_back(new LlamaDecoderLayerWeight(head_num, + kv_head_num, + size_per_head, + inter_size_, + weight_type_, + attn_bias, + tensor_para_size_, + tensor_para_rank_)); + } + + mallocWeights(); +} + +template +LlamaWeight::~LlamaWeight() +{ + cudaFree((void*)pre_decoder_embedding_table); + cudaFree((void*)output_norm_weight); + cudaFree((void*)post_decoder_embedding_kernel); + + if (prefix_cache_key) { + cudaFree((void*)prefix_cache_key); + cudaFree((void*)prefix_cache_token); + } + + pre_decoder_embedding_table = nullptr; + post_decoder_embedding_kernel = nullptr; + + prefix_cache_token = nullptr; + prefix_cache_key = nullptr; + prefix_cache_value = nullptr; +} + +template +void LlamaWeight::mallocWeights() +{ + deviceMalloc((T**)&pre_decoder_embedding_table, vocab_size_ * hidden_units_); + deviceMalloc((T**)&output_norm_weight, hidden_units_); + deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_); + + if (prefix_cache_len_) { + size_t cache_size = num_layer_ * prefix_cache_len_ * hidden_units_ / tensor_para_size_; + 
deviceMalloc((T**)&prefix_cache_key, cache_size * 2); + prefix_cache_value = prefix_cache_key + cache_size; + deviceMalloc((int**)&prefix_cache_token, prefix_cache_len_); + } +} + +template +void LlamaWeight::loadModel(std::string dir_path) +{ + FtCudaDataType model_file_type = FtCudaDataType::FP16; + dir_path += '/'; + + loadWeightFromBin((T*)pre_decoder_embedding_table, + {vocab_size_ * hidden_units_}, + dir_path + "tok_embeddings.weight", + model_file_type); + + loadWeightFromBin((T*)output_norm_weight, {hidden_units_}, dir_path + "norm.weight", model_file_type); + + loadWeightFromBin( + (T*)post_decoder_embedding_kernel, {hidden_units_ * vocab_size_}, dir_path + "output.weight", model_file_type); + + if (prefix_cache_len_) { + loadWeightFromBin((float*)prefix_cache_token, {prefix_cache_len_}, dir_path + "prefix_cache.token"); + loadWeightFromBin((T*)prefix_cache_key, + {num_layer_ * prefix_cache_len_, hidden_units_ / tensor_para_size_}, + dir_path + "prefix_cache." + std::to_string(tensor_para_rank_) + ".key", + model_file_type); + loadWeightFromBin((T*)prefix_cache_value, + {num_layer_ * prefix_cache_len_, hidden_units_ / tensor_para_size_}, + dir_path + "prefix_cache." + std::to_string(tensor_para_rank_) + ".value", + model_file_type); + } + + for (unsigned layer = 0; layer < num_layer_; ++layer) { + decoder_layer_weights[layer]->loadModel(dir_path + "layers." + std::to_string(layer), model_file_type); + } +} + +template struct LlamaWeight; +template struct LlamaWeight; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/LlamaWeight.h b/src/fastertransformer/models/llama/LlamaWeight.h new file mode 100644 index 000000000..f5394aae5 --- /dev/null +++ b/src/fastertransformer/models/llama/LlamaWeight.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.h + +#pragma once + +#include "src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h" +#include "src/fastertransformer/utils/memory_utils.h" + +namespace fastertransformer { + +template +struct LlamaWeight { + LlamaWeight() = default; + LlamaWeight(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t vocab_size, + size_t num_layer, + WeightType weight_type, + bool attn_bias, + size_t tensor_para_size, + size_t tensor_para_rank, + int prefix_cache_len); + + ~LlamaWeight(); + + LlamaWeight(const LlamaWeight& other) = delete; + LlamaWeight& operator=(const LlamaWeight& other) = delete; + + void loadModel(std::string dir_path); + + std::vector*> decoder_layer_weights; + const T* pre_decoder_embedding_table{}; + const T* output_norm_weight{}; + const T* post_decoder_embedding_kernel{}; + + size_t prefix_cache_len_; + int* prefix_cache_token{}; + T* prefix_cache_key{}; + T* prefix_cache_value{}; + +private: + void mallocWeights(); + + size_t hidden_units_; + size_t inter_size_; + size_t vocab_size_; + size_t num_layer_; + WeightType weight_type_; + size_t tensor_para_size_; + size_t tensor_para_rank_; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/Request.h 
b/src/fastertransformer/models/llama/Request.h new file mode 100644 index 000000000..b24cf3c8e --- /dev/null +++ b/src/fastertransformer/models/llama/Request.h @@ -0,0 +1,91 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#pragma once + +#include "src/fastertransformer/utils/Tensor.h" +#include +#include +#include +#include +#include +#include + +namespace fastertransformer { + +struct Request { + uint64_t id; + bool start_flag; + bool end_flag; + bool stop_flag; + + // per rank inputs/outputs + std::vector inputs; + std::vector outputs; + + using Callback = std::function*)>; + Callback stream_cb; + + enum + { + kInvalid = 1, + kConflict = 2, + kBusy = 3, + kInactive = 4, + kFail = 5 + }; + std::promise signal; +}; + +class RequestQueue { +public: + std::vector> enqueue(std::vector> requests) + { + std::vector> futures; + futures.reserve(requests.size()); + { + std::lock_guard lock(mutex_); + for (auto& r : requests) { + futures.push_back(r->signal.get_future()); + if (r->stop_flag) { + stop_queue_.push(std::move(r)); + } + else { + infer_queue_.push(std::move(r)); + } + } + } + cv_.notify_one(); + return futures; + } + + void dequeue(std::vector>& stop_requests, + std::vector>& infer_requests, + unsigned max_infer_count, + bool blocking) + { + std::unique_lock lock(mutex_); + if (blocking) { + cv_.wait(lock, [this] { return !(stop_queue_.empty() && infer_queue_.empty()); }); + } + + stop_requests.clear(); + while (!stop_queue_.empty()) { + stop_requests.push_back(std::move(stop_queue_.front())); + stop_queue_.pop(); + } + + infer_requests.clear(); + while (!infer_queue_.empty() && infer_requests.size() < max_infer_count) { + infer_requests.push_back(std::move(infer_queue_.front())); + infer_queue_.pop(); + } + } + +private: + std::queue> stop_queue_; + std::queue> infer_queue_; + std::mutex mutex_; + std::condition_variable cv_; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/fused_multi_head_attention/CMakeLists.txt 
b/src/fastertransformer/models/llama/fused_multi_head_attention/CMakeLists.txt new file mode 100644 index 000000000..b423a20d7 --- /dev/null +++ b/src/fastertransformer/models/llama/fused_multi_head_attention/CMakeLists.txt @@ -0,0 +1,8 @@ + +cmake_minimum_required(VERSION 3.8) + +add_library(llama_fmha STATIC llama_flash_attention_kernel.cu) +target_include_directories(llama_fmha PRIVATE ${CUTLASS_DIR}/examples) +target_link_libraries(llama_fmha PRIVATE nvidia::cutlass::cutlass) +set_property(TARGET llama_fmha PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET llama_fmha PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/fastertransformer/models/llama/fused_multi_head_attention/llama_flash_attention_kernel.cu b/src/fastertransformer/models/llama/fused_multi_head_attention/llama_flash_attention_kernel.cu new file mode 100644 index 000000000..4778f3407 --- /dev/null +++ b/src/fastertransformer/models/llama/fused_multi_head_attention/llama_flash_attention_kernel.cu @@ -0,0 +1,915 @@ +#include "src/fastertransformer/models/llama/llama_kernels.h" +#include "src/fastertransformer/utils/cuda_utils.h" + +#include "42_fused_multi_head_attention/kernel_forward.h" +#include "mma_accum_lambda_iterator.h" +#include "tile_smem_loader.h" +#include +#include +#include +#include +#include + +// modified from: +// https://github.com/NVIDIA/cutlass/blob/main/examples/41_fused_multi_head_attention/kernel_forward.h + +namespace fastertransformer { + +template< + // dtype of Q/K/V/M + typename Element_, + typename ArchTag, + int kQueriesPerBlock, + int kKeysPerBlock_, + int kSingleValueIteration_ = false> +struct LlamaAttentionKernel: + AttentionKernel { + using Base = AttentionKernel; + + using scalar_t = typename Base::scalar_t; + using accum_t = typename Base::accum_t; + using output_t = typename Base::output_t; + using output_accum_t = typename Base::output_accum_t; + using BaseParams = typename Base::Params; + static constexpr auto kSingleValueIteration = 
kSingleValueIteration_; + static constexpr bool kSupportsBias = true; + static constexpr auto kKeysPerBlock = kKeysPerBlock_; + static constexpr auto kNumThreads = Base::kNumThreads; + static constexpr auto kMinBlocksPerSm = Base::kMinBlocksPerSm; + static constexpr auto kNumWarpsPerBlock = Base::kNumWarpsPerBlock; + static constexpr auto kWarpSize = Base::kWarpSize; + static constexpr auto kKeepOutputInRF = Base::kKeepOutputInRF; + static constexpr bool kNeedsOutputAccumulatorBuffer = Base::kNeedsOutputAccumulatorBuffer; + static constexpr auto kAlignLSE = Base::kAlignLSE; + static constexpr auto kPreloadV = Base::kPreloadV; + static constexpr auto kAlignmentQ = Base::kAlignmentQ; + static constexpr auto kAlignmentK = Base::kAlignmentK; + static constexpr auto kAlignmentV = Base::kAlignmentV; + + struct Params: BaseParams { + scalar_t* attn_bias_ptr; + int32_t bias_strideM; + int32_t bias_strideH; + int32_t bias_strideB; + + bool q_use_seqlens = false; + bool o_use_seqlens = false; + + scalar_t** q_batch_seqs_ptr = nullptr; + scalar_t** k_batch_seqs_ptr = nullptr; + scalar_t** v_batch_seqs_ptr = nullptr; + output_t** o_batch_seqs_ptr = nullptr; + + int q_batch_seqs_offset = 0; + int k_batch_seqs_offset = 0; + int v_batch_seqs_offset = 0; + int o_batch_seqs_offset = 0; + + int32_t o_strideM_custom = 0; + + int32_t group_size = 1; + + float scale; + + CUTLASS_HOST_DEVICE int32_t o_strideM() const + { + if (o_strideM_custom == 0) + return BaseParams::head_dim_value; + else + return o_strideM_custom; + } + + template + CUTLASS_DEVICE void + update_batched_ptr(ptr_t& data_ptr, ptr_t* batch_seq_ptr, int batch_seq_offset, int batch_id, int strideB) + { + if (batch_seq_ptr != nullptr) + data_ptr = batch_seq_ptr[batch_id] + batch_seq_offset; + else + data_ptr += batch_id * strideB; + } + + CUTLASS_DEVICE bool advance_to_block() + { + + auto& query_ptr = BaseParams::query_ptr; + auto& key_ptr = BaseParams::key_ptr; + auto& value_ptr = BaseParams::value_ptr; + auto& 
cu_seqlens_q_ptr = BaseParams::cu_seqlens_q_ptr; + auto& cu_seqlens_k_ptr = BaseParams::cu_seqlens_k_ptr; + + auto& output_ptr = BaseParams::output_ptr; + auto& output_accum_ptr = BaseParams::output_accum_ptr; + auto& logsumexp_ptr = BaseParams::logsumexp_ptr; + + auto& head_dim = BaseParams::head_dim; + auto& head_dim_value = BaseParams::head_dim_value; + auto& num_queries = BaseParams::num_queries; + auto& num_keys = BaseParams::num_keys; + + auto& causal = BaseParams::causal; + + auto& q_strideM = BaseParams::q_strideM; + auto& k_strideM = BaseParams::k_strideM; + auto& v_strideM = BaseParams::v_strideM; + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + auto& q_strideH = BaseParams::q_strideH; + auto& k_strideH = BaseParams::k_strideH; + auto& v_strideH = BaseParams::v_strideH; + auto& o_strideH = BaseParams::o_strideH; + auto& q_strideB = BaseParams::q_strideB; + auto& k_strideB = BaseParams::k_strideB; + auto& v_strideB = BaseParams::v_strideB; + auto& o_strideB = BaseParams::o_strideB; + auto& num_batches = BaseParams::num_batches; + auto& num_heads = BaseParams::num_heads; + + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; + + int64_t q_start, k_start; + + if (kSupportsBias && attn_bias_ptr != nullptr) { + attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH); + attn_bias_ptr = warp_uniform(attn_bias_ptr); + } + + // Advance to current batch - in case of different sequence lengths + int qq_start, qo_start; + if (cu_seqlens_q_ptr != nullptr) { + cu_seqlens_q_ptr += batch_id; + q_start = cu_seqlens_q_ptr[0]; + int64_t q_next_start = cu_seqlens_q_ptr[1]; + num_queries = q_next_start - q_start; + + if (query_start >= num_queries) { + return false; + } + if (!q_use_seqlens) { + update_batched_ptr(query_ptr, q_batch_seqs_ptr, q_batch_seqs_offset, batch_id, q_strideB); + 
qq_start = 0; + } + else { + qq_start = q_start; + } + if (!o_use_seqlens) { + update_batched_ptr(output_ptr, o_batch_seqs_ptr, o_batch_seqs_offset, batch_id, o_strideB); + qo_start = 0; + } + else { + qo_start = q_start; + } + } + else { + update_batched_ptr(query_ptr, q_batch_seqs_ptr, q_batch_seqs_offset, batch_id, q_strideB); + update_batched_ptr(output_ptr, o_batch_seqs_ptr, o_batch_seqs_offset, batch_id, o_strideB); + if (output_accum_ptr != nullptr) { + output_accum_ptr += batch_id * o_strideB; + } + q_start = 0; + qq_start = qo_start = q_start; + } + + if (cu_seqlens_k_ptr != nullptr) { + cu_seqlens_k_ptr += batch_id; + k_start = cu_seqlens_k_ptr[0]; + int64_t k_next_start = cu_seqlens_k_ptr[1]; + num_keys = k_next_start - k_start; + } + else { + update_batched_ptr(key_ptr, k_batch_seqs_ptr, k_batch_seqs_offset, batch_id, k_strideB); + update_batched_ptr(value_ptr, v_batch_seqs_ptr, v_batch_seqs_offset, batch_id, v_strideB); + k_start = 0; + } + + // Advance to the current batch / head / query_start + query_ptr += (qq_start + query_start) * q_strideM + head_id * q_strideH; + key_ptr += k_start * k_strideM + int64_t(head_id / group_size) * k_strideH; + value_ptr += k_start * v_strideM + int64_t(head_id / group_size) * v_strideH; + output_ptr += int64_t(qo_start + query_start) * o_strideM() + head_id * o_strideH; + + if (output_accum_ptr != nullptr) { + output_accum_ptr += int64_t(query_start) * o_strideM() + head_id * o_strideH; + } + else { + // Accumulate directly in the destination buffer (eg for f32) + output_accum_ptr = (accum_t*)output_ptr; + } + if (logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + logsumexp_ptr += batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; + } + + num_queries -= query_start; + if (causal) { + num_keys = cutlass::fast_min(int32_t(query_start + kQueriesPerBlock), num_keys); + } + num_batches = 0; // no longer used after + + // Make sure the compiler knows these variables are the same on all + 
// the threads of the warp. + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + output_ptr = warp_uniform(output_ptr); + output_accum_ptr = warp_uniform(output_accum_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + head_dim = warp_uniform(head_dim); + head_dim_value = warp_uniform(head_dim_value); + return true; + } + }; + + using MM0 = typename Base::MM0; + using MM1 = typename Base::MM1; + using BaseSharedStorageEpilogueAtEnd = typename Base::SharedStorageEpilogueAtEnd; + using BaseSharedStorageEpilogueInLoop = typename Base::SharedStorageEpilogueInLoop; + + // TODO: find a way to optimize non aligned bias + using BiasLoader = TileSmemLoader, + MM0::MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 1>; // 1 per load. unless bias is aligned. + + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator::Iterator; + + struct SharedStorageEpilogueAtEnd: BaseSharedStorageEpilogueAtEnd { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::SharedStorageMM1 mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + }; + + struct SharedStorageEpilogueInLoop: BaseSharedStorageEpilogueInLoop { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::SharedStorageMM1 mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + }; + + using SharedStorage = typename 
cutlass::platform::conditional::type; + + static bool __host__ check_supported(Params const& p) + { + if (kSupportsBias) { + CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ); + XFORMERS_CHECK(p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0, + "attn_bias is not correctly aligned (strideH)"); + } + return Base::check_supported(p); + } + + static void CUTLASS_DEVICE attention_kernel(Params& p) + { + + // In this block, we will only ever: + // - read query[query_start:query_end, :] + // - write to output[query_start:query_end, :] + + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; + auto& si = shared_storage.after_mm0.si; + auto& mi = shared_storage.mi; + const uint32_t query_start = blockIdx.x * kQueriesPerBlock; + + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = accum_t(0); + m_prime[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + } + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator(typename OutputTileIterator::Params{(int32_t)p.o_strideM()}, + p.output_ptr, + typename OutputTileIterator::TensorCoord{p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = [&](int col) -> typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()}, + p.output_accum_ptr, + typename OutputTileIteratorAccum::TensorCoord{p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + // Iterate through 
keys + for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; iter_key_start += kKeysPerBlock) { + int32_t problem_size_0_m = cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries); + int32_t problem_size_0_n = cutlass::fast_min(int32_t(kKeysPerBlock), p.num_keys - iter_key_start); + int32_t const& problem_size_0_k = p.head_dim; + int32_t const& problem_size_1_n = p.head_dim_value; + int32_t const& problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V(typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + MM1::Mma::prologue(shared_storage.after_mm0.mm1.mm, iterator_V, thread_id(), problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and `m_prime` + // updated from end of prev iter + + // MATMUL: Q.K_t + // + // Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN}; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params(typename MM0::MmaCore::LayoutA(p.q_strideM)), + p.query_ptr, + {problem_size_0_m, problem_size_0_k}, + thread_id(), + tb_offset_A); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params(typename MM0::MmaCore::LayoutB(p.k_strideM)), + p.key_ptr + iter_key_start * p.k_strideM, + {problem_size_0_k, problem_size_0_n}, + thread_id(), + tb_offset_B); + + auto my_warp_id = 
warp_id(); + auto my_lane_id = lane_id(); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma(shared_storage.mm0, thread_id(), my_warp_id, my_lane_id); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord iteratorC_tile_offset = { + (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) + (my_warp_id % MM0::Mma::WarpCount::kM), + (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + (my_warp_id / MM0::Mma::WarpCount::kM)}; + + // multiply by scaling factor + if (kSupportsBias) { + accum = cutlass::multiplies()(p.scale, accum); + } + + // apply attention bias if applicable + if (kSupportsBias && p.attn_bias_ptr != nullptr) { + // load bias tile Bij into shared memory + typename BiasLoader::GmemTileIterator bias_iter( + {cutlass::layout::RowMajor(p.bias_strideM)}, + // attn_bias_pointer points to matrix of size (n_queries, n_keys) + // for the relevant batch_id and head_id + p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start, + {problem_size_0_m, problem_size_0_n}, + thread_id()); + cutlass::TensorRef bias_tensor_ref( + shared_storage.after_mm0.bias.data(), cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); + typename BiasLoader::SmemTileIterator smem_tile_iter(bias_tensor_ref, thread_id()); + BiasLoader::load(bias_iter, smem_tile_iter); + + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = AccumLambdaIterator::get_lane_offset(lane_id(), warp_id(), iteratorC_tile_offset); + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { + accum[idx] += (1.0f 
- bias_tensor_ref.at({accum_m, accum_n})) * -1e5f; + } + }, + [&](int accum_m) {}); + } + + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + p.num_keys - iter_key_start >= kKeysPerBlock, kFullColumns, ([&] { + // Update `mi` from accum stored in registers + // Also updates `accum` with accum[i] <- + // exp(accum[i] * scale + // - mi) + MM0::ScalingCoefsUpdater::update( + accum_o, + accum, + mi, + m_prime, + s_prime, + lane_id(), + thread_id(), + warp_id(), + p.num_keys - iter_key_start, + iteratorC_tile_offset, + kSupportsBias ? 1.0f : p.scale); + })); + })); + + // Output results to shared-memory + int warp_idx_mn_0 = my_warp_id % (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem(shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); + + __syncthreads(); + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = + kSingleValueIteration ? 
1 : ceil_div((int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN)); + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in accum + // (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V(typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + typename MM1::Mma mma_pv(shared_storage.after_mm0.mm1.mm, + shared_storage.after_mm0.si, + (int)thread_id(), + (int)warp_id(), + (int)lane_id(), + (int)problem_size_1_k); + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { + accum_o.clear(); + } + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= p.num_keys, kIsLast, ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< + typename cutlass::platform::conditional:: + type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + ElementCompute, + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast, + typename 
MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>::type, + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto dest_iter = + call_conditional::apply(createOutputIter, + createOutputAccumIter, + col); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), thread_id(), warp_id(), lane_id()); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kSingleValueIteration) { + __syncthreads(); + } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + constexpr bool kIsFirst = true; + constexpr bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< + output_t, // output + output_accum_t, // source + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // 
IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue(shared_storage.epilogue_shared_storage(), thread_id(), warp_id(), lane_id()); + epilogue(rescale, dest_iter, accum_o); + } + + // 7. Calculate logsumexp + // To make the backward easier, we pad logsumexp with `inf` + // this avoids a few bound checks, and is not more expensive during fwd + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { + auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; + if (thread_id() < p.num_queries) { + p.logsumexp_ptr[thread_id()] = + accum_t(mi[thread_id()]) + cutlass::fast_log(accum_t(s_prime[thread_id()])); + } + else if (thread_id() < lse_dim) { + p.logsumexp_ptr[thread_id()] = cutlass::platform::numeric_limits::infinity(); + } + } + } + + static CUTLASS_DEVICE int8_t lane_id() + { + return Base::lane_id(); + } + static CUTLASS_DEVICE int8_t warp_id() + { + return Base::warp_id(); + } + static CUTLASS_DEVICE int16_t thread_id() + { + return Base::thread_id(); + } +}; + +template +void invokeFlashAttention_impl(int batch_size, + int head_num, + int key_len, + int seq_len, + int size_per_head, + typename FlashAttentionOp::Params& attention_params, + cudaStream_t st) +{ + T* out_ptr = attention_params.attn_out; + T* query_ptr = attention_params.query; + T* key_ptr = attention_params.key; + T* value_ptr = attention_params.val; + T* mask_ptr = attention_params.mask; + float* output_accum_ptr = attention_params.out_accum; + auto* cu_seqlens_q_ptr = attention_params.cu_seqlens_q; + auto layout_q = attention_params.layout_q; + auto layout_k = attention_params.layout_k; + auto layout_v = attention_params.layout_v; + auto layout_o = attention_params.layout_o; + auto group_size = attention_params.group_size; + + using scalar_t = + typename 
std::conditional_t::type>::value, cutlass::half_t, T>; + + const float qk_scale = static_cast(1.f / sqrtf(size_per_head * 1.f)); + + constexpr bool kNeedsOutputAccumulatorBuffer = Attention::kNeedsOutputAccumulatorBuffer; + if (kNeedsOutputAccumulatorBuffer) { + assert(output_accum_ptr != nullptr); + } + + // fill param + typename Attention::Params params{}; + { + params.query_ptr = (scalar_t*)(query_ptr); + params.key_ptr = (scalar_t*)(key_ptr); + params.value_ptr = (scalar_t*)(value_ptr); + params.attn_bias_ptr = (scalar_t*)(mask_ptr); + params.cu_seqlens_q_ptr = cu_seqlens_q_ptr; + + params.output_ptr = (scalar_t*)(out_ptr); + params.output_accum_ptr = kNeedsOutputAccumulatorBuffer ? output_accum_ptr : nullptr; + params.logsumexp_ptr = nullptr; + + params.scale = qk_scale; + + params.head_dim = size_per_head; + params.head_dim_value = size_per_head; + params.num_queries = seq_len; + params.num_keys = key_len; + + params.bias_strideH = 0; + params.bias_strideM = key_len; + params.bias_strideB = seq_len * params.bias_strideM; + + params.q_strideH = layout_q.stride_head; + params.q_strideM = layout_q.stride_seq; + params.q_strideB = layout_q.stride_batch; + params.q_use_seqlens = layout_q.use_seqlens; + params.q_batch_seqs_ptr = (scalar_t**)(layout_q.batch_seqs); + params.q_batch_seqs_offset = layout_q.batch_seqs_offset; + + params.k_strideH = layout_k.stride_head; + params.k_strideM = layout_k.stride_seq; + params.k_strideB = layout_k.stride_batch; + params.k_batch_seqs_ptr = (scalar_t**)layout_k.batch_seqs; + params.k_batch_seqs_offset = layout_k.batch_seqs_offset; + + params.v_strideH = layout_v.stride_head; + params.v_strideM = layout_v.stride_seq; + params.v_strideB = layout_v.stride_batch; + params.v_batch_seqs_ptr = (scalar_t**)layout_v.batch_seqs; + params.v_batch_seqs_offset = layout_v.batch_seqs_offset; + + params.o_strideH = layout_o.stride_head; + params.o_strideM_custom = layout_o.stride_seq; + params.o_strideB = layout_o.stride_batch; + 
params.o_use_seqlens = layout_o.use_seqlens; + params.o_batch_seqs_ptr = (scalar_t**)layout_o.batch_seqs; + params.o_batch_seqs_offset = layout_o.batch_seqs_offset; + + params.num_batches = batch_size; + params.num_heads = head_num; + + params.group_size = int32_t(group_size); + } + + Attention::check_supported(params); + + // start kernel + auto block_grid = params.getBlocksGrid(); + auto thread_grid = params.getThreadsGrid(); + + int smem_bytes = sizeof(typename Attention::SharedStorage); + if (smem_bytes > 0xc000) { + cudaFuncSetAttribute( + attention_kernel_batched_impl, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + } + + attention_kernel_batched_impl<<>>(params); +} + +#define CUTLASS_ARCH(sm) cutlass::arch::Sm##sm + +#define ATTENTION_KERNEL(scalar_t, sm, querys_per_block, keys_per_block, single_value) \ + LlamaAttentionKernel + +template +bool get_needs_accum_buffer() +{ + using scalar_t = + typename std::conditional_t::type>::value, cutlass::half_t, T>; + +#define GET_NEED_ACCUM_BUFFER(sm) \ + ATTENTION_KERNEL(scalar_t, sm, kQueriesPerBlock, kKeysPerBlock, false)::kNeedsOutputAccumulatorBuffer + + auto sm = getSMVersion(); + + switch (sm) { + case 75: + return GET_NEED_ACCUM_BUFFER(75); + default: + if (sm >= 80) { + return GET_NEED_ACCUM_BUFFER(80); + } + else { + return GET_NEED_ACCUM_BUFFER(70); + } + } +#undef GET_NEED_ACCUM_BUFFER +} + +template +void invoke_attention_impl(bool single_val_iteration, + int batch_size, + int head_num, + int key_len, + int seq_len, + int size_per_head, + typename FlashAttentionOp::Params& params, + cudaStream_t st) +{ + using scalar_t = + typename std::conditional_t::type>::value, cutlass::half_t, T>; + +#define INVOKE_ATTEN_IMPL(sm, single_value) \ + { \ + using AttentionKernel = ATTENTION_KERNEL(scalar_t, sm, kQueriesPerBlock, kKeysPerBlock, single_value); \ + invokeFlashAttention_impl( \ + batch_size, head_num, key_len, seq_len, size_per_head, params, st); \ + } + +#define INVOKE_ATTENN_IMPL_V2(sm) \ + { 
\ + if (single_val_iteration) \ + INVOKE_ATTEN_IMPL(sm, true) \ + else \ + INVOKE_ATTEN_IMPL(sm, false) \ + } + + auto sm = getSMVersion(); + switch (sm) { + case 75: + INVOKE_ATTENN_IMPL_V2(75); + break; + default: + if (sm >= 80) { + INVOKE_ATTENN_IMPL_V2(80); + } + else { + INVOKE_ATTENN_IMPL_V2(70); + } + } + +#undef INVOKE_ATTENN_IMPL_V2 +#undef INVOKE_ATTEN_IMPL +} + +template +class FlashAttentionOp::impl { + +private: + static constexpr int kQueriesPerBlock = 32; + static constexpr int kKeysPerBlock = 128; + using ArchTag = cutlass::arch::Sm80; + using scalar_t = + typename std::conditional_t::type>::value, cutlass::half_t, T>; + using SingleValueAttention = LlamaAttentionKernel; + using MultiValueAttention = LlamaAttentionKernel; + using AttentionLayout = typename FlashAttentionOp::AttentionLayout; + using Params = typename FlashAttentionOp::Params; + + int batch_size_; + int head_num_; + int key_len_; + int seq_len_; + int size_per_head_; + bool kSingleValueIteration; + +public: + impl(int batch_size, int head_num, int key_len, int seq_len, int size_per_head): + batch_size_(batch_size), + head_num_(head_num), + key_len_(key_len), + seq_len_(seq_len), + size_per_head_(size_per_head) + { + kSingleValueIteration = (size_per_head <= kKeysPerBlock); + } + + ~impl() {} + + int get_workspace_size() const + { + if (kSingleValueIteration) { + return 0; + } + else { + bool kNeedsOutputAccumulatorBuffer = get_needs_accum_buffer(); + if (kNeedsOutputAccumulatorBuffer) { + return batch_size_ * head_num_ * seq_len_ * size_per_head_ * sizeof(float); + } + else { + return 0; + } + } + } + + void operator()(Params& params, cudaStream_t st) const + { + invoke_attention_impl( + kSingleValueIteration, batch_size_, head_num_, key_len_, seq_len_, size_per_head_, params, st); + } +}; + +template +FlashAttentionOp::FlashAttentionOp(int batch_size, int head_num, int key_len, int seq_len, int size_per_head): + pimpl{std::make_unique::impl>(batch_size, head_num, key_len, seq_len, 
size_per_head)} +{ +} + +template +FlashAttentionOp::~FlashAttentionOp() +{ +} + +template +int FlashAttentionOp::get_workspace_size() const +{ + return pimpl->get_workspace_size(); +} + +template +void FlashAttentionOp::operator()(Params& params, cudaStream_t st) const +{ + pimpl->operator()(params, st); +} + +template class FlashAttentionOp; +template class FlashAttentionOp; + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/fused_multi_head_attention/mma_accum_lambda_iterator.h b/src/fastertransformer/models/llama/fused_multi_head_attention/mma_accum_lambda_iterator.h new file mode 100644 index 000000000..3ddeb1339 --- /dev/null +++ b/src/fastertransformer/models/llama/fused_multi_head_attention/mma_accum_lambda_iterator.h @@ -0,0 +1,309 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/functional.h" +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" +#include "cutlass/matrix_shape.h" + +/* +TensorCores have different accumulator layouts. +This file provides a class to easily map the accumulator +i-th element with the corresponding matrix row/col. 
+*/ + +template +struct AccumLambdaIteratorSm80 { + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + static int const kElementsPerAccess = InstructionShape::kN / 4; + static int const kRowsPerTile = 8; + static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) + { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + return cutlass::MatrixCoord(quad + tile_offset.row() * Shape::kRow, + lane_in_quad * kElementsPerAccess + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, FA beginRow, FB op, FC endRow) + { + // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + row * kRowsPerTile + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + int mma_accum_start = + kAccumulatorRows * kElementsPerAccess * (mma_n * Policy::MmaIterations::kRow + mma_m); + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col + lane_offset.column(); + int idx = mma_accum_start + row * kElementsPerAccess + col; + op(accum_m, accum_n, idx); + } + } + + endRow(accum_m); + } + } + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) + { + // In each warp, 4 threads will work on the same row + // - the 
ones with the same `quad` + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1); + myValue = fn(myValue, otherV); + otherV = __shfl_xor_sync(0xffffffff, myValue, 2); + myValue = fn(myValue, otherV); + int lane_in_quad = (lane_id & 3); + return lane_in_quad == 0; + } +}; + +template +struct AccumLambdaIteratorSm70 { + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + using Element = accum_t; + + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) + { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + (lane_in_quad & 2); + } + else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + return cutlass::MatrixCoord(accum_m + tile_offset.row() * Shape::kRow, + accum_n + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) + { + static_assert(cutlass::platform::is_same::value, "update to support non-float accum"); + // 
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + // T0 & T2 share same line within a quad + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1); + myValue = fn(myValue, otherV); + // quad 0 and quad 2 are on the same lines + otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3); + myValue = fn(myValue, otherV); + return (lane_id & ((1 << 1) | (1 << 3))) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, FA beginRow, FB op, FC endRow) + { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + mma_m * QuadShapePerPatialMma::kRow + m * 2 + + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int mma_accum_start = (((tile_n * Policy::TileIterations::kRow + tile_m) + * Policy::MmaIterations::kColumn + + mma_n) + * Policy::MmaIterations::kRow + + mma_m) + * kElementsPerMma; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2 + n + lane_offset.column(); + int idx = + mma_accum_start + p * kElementsPerPartial + m * EleShapePerPatial::kColumn + n; + op(accum_m, accum_n, idx); + } + } + } + } + endRow(accum_m); + } + } + } + } +}; + +template +struct AccumLambdaIteratorSimt { + using Policy = typename T::Policy; + using Iterations = typename T::Iterations; + using 
Element = typename T::Element; + using Delta = typename T::Delta; + using Shape = typename T::Shape; + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) + { + CUTLASS_PRAGMA_UNROLL + for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { + auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); + myValue = fn(myValue, otherV); + } + return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, FA beginRow, FB op, FC endRow) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + int accum_n = mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + lane_offset.column(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + int idx = n + + Policy::LaneMmaShape::kN + * (mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM)); + op(accum_m, accum_n + n, idx); + } + } + endRow(accum_m); + } + } + } + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) + { + static_assert( + cutlass::platform::is_same>::value, + ""); + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + cutlass::MatrixCoord lane_offset = + lane_layout.inverse(lane_id) * cutlass::MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); + return lane_offset + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); + } +}; + +template +struct DefaultMmaAccumLambdaIterator; + +// Simt +template +struct DefaultMmaAccumLambdaIterator< + 
cutlass::gemm::warp:: + MmaSimtTileIterator, + accum_t, + kWarpSize> { + using WarpIterator = typename cutlass::gemm::warp:: + MmaSimtTileIterator; + using Iterator = AccumLambdaIteratorSimt; +}; + +// TensorOp - Volta +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp:: + MmaVoltaTensorOpAccumulatorTileIterator>, + accum_t, + kWarpSize> { + using WarpIterator = typename cutlass::gemm::warp:: + MmaVoltaTensorOpAccumulatorTileIterator>; + using Iterator = AccumLambdaIteratorSm70; +}; + +// TensorOp - Sm75+ +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator, + accum_t, + kWarpSize> { + using WarpIterator = typename cutlass::gemm::warp:: + MmaTensorOpAccumulatorTileIterator; + using Iterator = AccumLambdaIteratorSm80; +}; diff --git a/src/fastertransformer/models/llama/fused_multi_head_attention/tile_smem_loader.h b/src/fastertransformer/models/llama/fused_multi_head_attention/tile_smem_loader.h new file mode 100644 index 000000000..070da1bb8 --- /dev/null +++ b/src/fastertransformer/models/llama/fused_multi_head_attention/tile_smem_loader.h @@ -0,0 +1,84 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include <cstdint>  // NOTE(review): original include target was destroyed by markup stripping — confirm upstream

#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/numeric_types.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"

/// Cooperatively loads one threadblock tile from global into shared memory.
/// NOTE(review): template parameter lists restored after markup stripping —
/// verify against the upstream xformers header.
template <typename scalar_t,              // scalar type
          typename ThreadblockTileShape,  // size of tile to load
          int Threads,                    // number of threads
          int ElementsPerAccess>          // thread access width in elements
class TileSmemLoader {
public:
    using SmemTile = cutlass::AlignedBuffer<scalar_t, ThreadblockTileShape::kCount>;

    using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
        cutlass::layout::PitchLinearShape<ThreadblockTileShape::kColumn,  // contiguous
                                          ThreadblockTileShape::kRow>,    // strided
        Threads,                                                          // Threads
        ElementsPerAccess>;                                               // ElementsPerAccess

    using GmemTileIterator =
        cutlass::transform::threadblock::PredicatedTileIterator<ThreadblockTileShape,       // Shape
                                                                scalar_t,                   // Element
                                                                cutlass::layout::RowMajor,  // Layout
                                                                0,                          // AdvanceRank
                                                                ThreadMap>;                 // ThreadMap

    using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator<ThreadblockTileShape,       // Shape
                                                                                  scalar_t,                   // Element
                                                                                  cutlass::layout::RowMajor,  // Layout
                                                                                  0,            // AdvanceRank
                                                                                  ThreadMap>;   // ThreadMap

    using Fragment = typename GmemTileIterator::Fragment;

    /// load a tile from global memory into shared memory
    CUTLASS_DEVICE
    static void load(GmemTileIterator tile_load_iter, SmemTileIterator tile_store_iter)
    {
        Fragment tb_frag;
        tb_frag.clear();
        tile_load_iter.load(tb_frag);
        tile_store_iter.store(tb_frag);

        __syncthreads();
    }
};
diff --git a/src/fastertransformer/models/llama/llama_decoder_kernels.cu b/src/fastertransformer/models/llama/llama_decoder_kernels.cu
new file mode 100644
index 000000000..04a81593d
--- /dev/null
+++ b/src/fastertransformer/models/llama/llama_decoder_kernels.cu
@@ -0,0 +1,165 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/fastertransformer/models/llama/llama_decoder_kernels.h"
#include "src/fastertransformer/utils/cuda_utils.h"
#include <cooperative_groups.h>         // cg::this_thread_block / cg::this_grid
#include <cooperative_groups/reduce.h>  // cg::reduce, cg::plus
#include <cuda_fp16.h>                  // half2 conversion intrinsics

namespace cg = cooperative_groups;

namespace fastertransformer {

// Per-type scalar ops (pack/unpack, add-accumulate, normalize); specialized below.
template<typename T>
struct res_norm_ops_t {
};

// Applies res_norm_ops_t across a uint4 (128-bit) vector of packed elements.
template<typename T>
struct res_norm_t {
    res_norm_ops_t<T> f;
    // c = a + b + bias, folding the squared sum of c into `accum`.
    __device__ uint4 addvec(const uint4& a, const uint4& b, const uint4& bias, float& accum) const
    {
        uint4 c;
        c.x = f.cast(f.add(f.cast(a.x), f.cast(b.x), f.cast(bias.x), accum));
        c.y = f.cast(f.add(f.cast(a.y), f.cast(b.y), f.cast(bias.y), accum));
        c.z = f.cast(f.add(f.cast(a.z), f.cast(b.z), f.cast(bias.z), accum));
        c.w = f.cast(f.add(f.cast(a.w), f.cast(b.w), f.cast(bias.w), accum));
        return c;
    }
    // v = u * s * factor (elementwise normalization step).
    __device__ uint4 normvec(const uint4& u, const uint4& s, float factor) const
    {
        uint4 v;
        v.x = f.cast(f.norm(f.cast(u.x), f.cast(s.x), factor));
        v.y = f.cast(f.norm(f.cast(u.y), f.cast(s.y), factor));
        v.z = f.cast(f.norm(f.cast(u.z), f.cast(s.z), factor));
        v.w = f.cast(f.norm(f.cast(u.w), f.cast(s.w), factor));
        return v;
    }
};

// half: each uint holds a half2; arithmetic is done in float2 for accuracy.
template<>
struct res_norm_ops_t<half> {
    __device__ float2 cast(const uint& x) const
    {
        return __half22float2(reinterpret_cast<const half2&>(x));
    }
    __device__ uint cast(const float2& x) const
    {
        auto y = __float22half2_rn(x);
        return reinterpret_cast<uint&>(y);
    }
    __device__ float2 add(const float2& a, const float2& b, const float2& bias, float& accum) const
    {
        float2 c{a.x + b.x + bias.x, a.y + b.y + bias.y};
        accum += c.x * c.x + c.y * c.y;
        return c;
    }
    __device__ float2 norm(const float2& a, const float2& s, float factor) const
    {
        return {a.x * s.x * factor, a.y * s.y * factor};
    }
};

// float: each uint holds one float, bit-cast in place.
template<>
struct res_norm_ops_t<float> {
    __device__ float cast(const uint& x) const
    {
        return reinterpret_cast<const float&>(x);
    }
    __device__ uint cast(const float& x) const
    {
        return reinterpret_cast<const uint&>(x);
    }
    __device__ float add(const float& a, const float& b, const float& bias, float& accum) const
    {
        float c = a + b + bias;
        accum += c * c;
        return c;
    }
    __device__ float norm(const float& a, const float& s, float factor) const
    {
        return a * s * factor;
    }
};

// Block-wide sum: warp-level cg::reduce, then cross-warp through shared memory.
// NOTE(review): `partial` is float regardless of T — fine for the float/half
// instantiations used here, lossy if T were ever wider.
template<typename T>
__device__ T blockReduceSum(const cg::thread_block& block, T value)
{
    __shared__ float partial[32];

    auto tile = cg::tiled_partition<32>(block);
    value     = cg::reduce(tile, value, cg::plus<float>{});

    if (tile.thread_rank() == 0) {
        partial[tile.meta_group_rank()] = value;
    }

    block.sync();

    value = tile.thread_rank() < tile.meta_group_size() ? partial[tile.thread_rank()] : T{};
    return cg::reduce(tile, value, cg::plus<float>{});
}

// One block per batch row: r += x + bias, then x = RMSNorm(r) * scale.
// All loads/stores are 128-bit (uint4) vectorized.
template<typename T>
__global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data,
                                         T* __restrict__ x_data,
                                         const T* __restrict__ bias,
                                         const T* __restrict__ scale,
                                         float eps,
                                         int batch_size,
                                         int n_dims)
{
    auto block = cg::this_thread_block();
    auto grid  = cg::this_grid();

    constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);  // elements per 128-bit access

    const auto batch_idx            = grid.block_rank();
    uint4* __restrict__ r_ptr       = reinterpret_cast<uint4*>(r_data + batch_idx * n_dims);
    uint4* __restrict__ x_ptr       = reinterpret_cast<uint4*>(x_data + batch_idx * n_dims);
    const uint4* __restrict__ b_ptr = reinterpret_cast<const uint4*>(bias);

    res_norm_t<T> ops;

    float thread_sum{};
    for (auto i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.num_threads()) {
        auto  r  = r_ptr[i];
        auto  x  = x_ptr[i];
        uint4 b  = b_ptr ? b_ptr[i] : uint4{};  // bias is optional
        r        = ops.addvec(r, x, b, thread_sum);
        r_ptr[i] = r;
    }

    auto total_sum = blockReduceSum(block, thread_sum);

    float s_inv_mean = rsqrt(total_sum / n_dims + eps);

    const uint4* __restrict__ s_ptr = reinterpret_cast<const uint4*>(scale);
    for (uint i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.num_threads()) {
        auto r   = r_ptr[i];
        auto s   = s_ptr[i];
        auto o   = ops.normvec(r, s, s_inv_mean);
        x_ptr[i] = o;
    }
}

// Host launcher. `n_dims` must be a multiple of the 128-bit pack width.
template<typename T>
void invokeFusedAddBiasResidualRMSNorm(
    T* residual, T* in_out, const T* bias, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream)
{
    constexpr int PACK_DIM = sizeof(uint4) / sizeof(T);
    FT_CHECK(n_dims % PACK_DIM == 0);
    const int n_pack    = n_dims / PACK_DIM;
    const int n_iter    = ((n_pack + 1023) / 1024);       // iterations when block size == 1024
    int       n_threads = (n_pack + n_iter - 1) / n_iter;  // adjust block size to avoid tail effect
    n_threads           = (n_threads + 31) / 32 * 32;      // round up to the nearest multiple of warp size

    fusedAddBiasResidualNorm<<<batch_size, n_threads, 0, stream>>>(
        residual, in_out, bias, scale, eps, batch_size, n_dims);
}

+template void +invokeFusedAddBiasResidualRMSNorm(float*, float*, const float*, const float*, float, int, int, cudaStream_t); +template void invokeFusedAddBiasResidualRMSNorm(half*, half*, const half*, const half*, float, int, int, cudaStream_t); + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/llama_decoder_kernels.h b/src/fastertransformer/models/llama/llama_decoder_kernels.h new file mode 100644 index 000000000..af299a9df --- /dev/null +++ b/src/fastertransformer/models/llama/llama_decoder_kernels.h @@ -0,0 +1,11 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include + +namespace fastertransformer { + +template +void invokeFusedAddBiasResidualRMSNorm( + T* residual, T* in_out, const T* bias, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream); + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/llama_gemm.cc b/src/fastertransformer/models/llama/llama_gemm.cc new file mode 100644 index 000000000..526398a7c --- /dev/null +++ b/src/fastertransformer/models/llama/llama_gemm.cc @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Copied from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/gpt_gemm.cc + +#include "src/fastertransformer/utils/gemm_test/gpt_gemm_func.h" +#include "src/fastertransformer/utils/memory_utils.h" + +namespace ft = fastertransformer; + +int main(int argc, char* argv[]) +{ + if (argc < 9 || argc > 11) { + FT_LOG_ERROR("./bin/llama_gemm batch_size \\ \n" + " beam_width \\ \n" + " max_input_len \\ \n" + " head_number \\ \n" + " size_per_head \\ \n" + " inter_size \\ \n" + " vocab_size \\ \n" + " data_type \\ \n" + " tensor_para_size \\\n" + " is_append (append new config into exist gemm_config.ini or not)"); + FT_LOG_ERROR("e.g. ./bin/llama_gemm 8 4 32 96 128 49152 51200 1 8 1"); + return 0; + } + + const int batch_size = atoi(argv[1]); + const int beam_width = atoi(argv[2]); + const int max_input_len = atoi(argv[3]); + const int head_num = atoi(argv[4]); + const int size_per_head = atoi(argv[5]); + const int inter_size = atoi(argv[6]); + const int vocab_size = atoi(argv[7]); + const ft::CublasDataType data_type = static_cast(atoi(argv[8])); // 0 FP32, 1 FP16, 2 BF 16 + const int tensor_para_size = argc < 10 ? 1 : atoi(argv[9]); + const bool is_append = argc < 11 ? 
false : (bool)(atoi(argv[10])); + + FT_LOG_INFO("Arguments:"); + FT_LOG_INFO(" batch_size: %d", batch_size); + FT_LOG_INFO(" beam_width: %d", beam_width); + FT_LOG_INFO(" max_input_len: %d", max_input_len); + FT_LOG_INFO(" head_num: %d", head_num); + FT_LOG_INFO(" size_per_head: %d", size_per_head); + FT_LOG_INFO(" inter_size: %d", inter_size); + FT_LOG_INFO(" vocab_size: %d", vocab_size); + FT_LOG_INFO(" data_type: %d", data_type); + FT_LOG_INFO(" tensor_para_size: %d", tensor_para_size); + FT_LOG_INFO(" is_append: %d", (int)is_append); + std::cout << std::endl; + + void* gemm_test_buf; + size_t buf_size_in_byte = ft::calGptGemmTestBufSizeInByte(batch_size, + beam_width, + max_input_len, + head_num, + size_per_head, + inter_size, + vocab_size, + tensor_para_size, + data_type); + size_t total, free; + ft::check_cuda_error(cudaMemGetInfo(&free, &total)); + if (free < buf_size_in_byte + 10 * 1024 * 1024) { + printf("[ERROR] There is no enough device memory for gemm test!\n" + " %ld Bytes is needed, but only %ld Bytes is free.\n", + buf_size_in_byte, + free); + gemm_test_buf = NULL; + return -1; + } + else { + ft::deviceMalloc(reinterpret_cast(&gemm_test_buf), buf_size_in_byte, false); + } + + if (data_type == ft::FLOAT_DATATYPE) { + ft::generate_gpt_gemm_config(batch_size, + beam_width, + max_input_len, + head_num, + size_per_head, + inter_size, + vocab_size, + tensor_para_size, + gemm_test_buf, + is_append); + } + else if (data_type == ft::HALF_DATATYPE) { + ft::generate_gpt_gemm_config(batch_size, + beam_width, + max_input_len, + head_num, + size_per_head, + inter_size, + vocab_size, + tensor_para_size, + gemm_test_buf, + is_append); + } +#ifdef ENABLE_BF16 + else if (data_type == ft::BFLOAT16_DATATYPE) { + ft::generate_gpt_gemm_config<__nv_bfloat16>(batch_size, + beam_width, + max_input_len, + head_num, + size_per_head, + inter_size, + vocab_size, + tensor_para_size, + gemm_test_buf, + is_append); + } +#endif +#ifdef ENABLE_FP8 + else if (data_type == 
ft::FP8_DATATYPE) { + ft::generate_gpt_gemm_config<__nv_fp8_e4m3>(batch_size, + beam_width, + max_input_len, + head_num, + size_per_head, + inter_size, + vocab_size, + tensor_para_size, + gemm_test_buf, + false); + } +#endif + else { + printf("[ERROR] data type only supports fp32(0), fp16(1), bf16(2), fp8(4). \n"); + return -1; + } + + ft::check_cuda_error(cudaFree(gemm_test_buf)); + return 0; +} diff --git a/src/fastertransformer/models/llama/llama_kernels.cu b/src/fastertransformer/models/llama/llama_kernels.cu new file mode 100644 index 000000000..5ddc08809 --- /dev/null +++ b/src/fastertransformer/models/llama/llama_kernels.cu @@ -0,0 +1,712 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh" +#include "src/fastertransformer/models/llama/llama_kernels.h" +#include "src/fastertransformer/models/llama/llama_utils.h" +#include "src/fastertransformer/utils/cuda_type_utils.cuh" + +namespace fastertransformer { + +// fp16, bf16 +// n is divided by 2 for this impl +template +__global__ void rootMeanSquareNorm(T* out, const T* input, const T* scale, float eps, int m, int n) +{ + using T2 = typename TypeConverter::Type; + __shared__ float s_inv_mean; + float mean = 0.f; + + T2* out_ptr = (T2*)out; + const T2* input_ptr = (const T2*)input; + const T2* scale_ptr = (const T2*)scale; + + for (uint idx = threadIdx.x; idx < n; idx += blockDim.x) { + float2 tmp2 = cuda_cast(input_ptr[blockIdx.x * n + idx]); + mean += tmp2.x * tmp2.x; + mean += tmp2.y * tmp2.y; + } + + mean = blockReduceSum(mean); + if (threadIdx.x == 0) { + s_inv_mean = rsqrt(.5f * mean / (float)n + eps); + } + __syncthreads(); + + for (uint idx = threadIdx.x; idx < n; idx += blockDim.x) { + float2 tmp2 = cuda_cast(input_ptr[blockIdx.x * n + idx]); + float2 sca2 = cuda_cast(scale_ptr[idx]); + tmp2.x = tmp2.x * s_inv_mean * sca2.x; + tmp2.y = tmp2.y * 
s_inv_mean * sca2.y; + out_ptr[blockIdx.x * n + idx] = cuda_cast(tmp2); + } +} + +template<> +__global__ void rootMeanSquareNorm(float* out, const float* input, const float* scale, float eps, int m, int n) +{ + __shared__ float s_inv_mean; + float mean = 0.f; + + for (uint idx = threadIdx.x; idx < n; idx += blockDim.x) { + float tmp = input[blockIdx.x * n + idx]; + mean += tmp * tmp; + } + + mean = blockReduceSum(mean); + if (threadIdx.x == 0) { + s_inv_mean = rsqrt(mean / static_cast(n) + eps); + } + __syncthreads(); + + for (uint idx = threadIdx.x; idx < n; idx += blockDim.x) { + float tmp = input[blockIdx.x * n + idx]; + out[blockIdx.x * n + idx] = tmp * s_inv_mean * scale[idx]; + } +} + +template +void invokeRootMeanSquareNorm(T* out, const T* input, const T* scale, float eps, int m, int n, cudaStream_t stream) +{ + if (sizeof(T) == 2) { + FT_CHECK(n % 2 == 0); + n /= 2; + } + dim3 grid(m); + dim3 block(std::min(n, 1024)); + rootMeanSquareNorm<<>>(out, input, scale, eps, m, n); +} + +template void invokeRootMeanSquareNorm(float*, const float*, const float*, float, int, int, cudaStream_t); +template void invokeRootMeanSquareNorm(half*, const half*, const half*, float, int, int, cudaStream_t); + +// #ifdef ENABLE_BF16 + +// template void invokeRootMeanSquareNorm(__nv_bfloat16*, const __nv_bfloat16*, float, int, int, cudaStream_t); + +// #endif + +template +__device__ T saturate_cast(T0 x) +{ + return x; +} + +template<> +__device__ half saturate_cast(float x) +{ + return (x > 64512.f || x < -64512.f) ? (x > 0.f ? 
64512.f : -64512.f) : x; +} + +template +__global__ void addResidual(T* out, const T* in, size_t n) +{ + auto idx = threadIdx.x + (size_t)blockIdx.x * blockDim.x; + if (idx < n) { + out[idx] = static_cast(static_cast(out[idx]) + static_cast(in[idx])); + } +} + +template +void invokeAddResidual(T* out, const T* in, int m, int n, cudaStream_t stream) +{ + auto total = static_cast(m) * n; + dim3 block(std::min(total, 1024UL)); + dim3 grid((total + block.x - 1) / block.x); + + addResidual<<>>(out, in, total); +} + +template void invokeAddResidual(float*, const float*, int, int, cudaStream_t); +template void invokeAddResidual(half*, const half*, int, int, cudaStream_t); + +// ids [seq_len, batch_size] +// input_ids [batch_size, max_input_len] +__global__ void +fixInputIds(int* ids, const int* input_ids, const int* input_lengths, int batch_size, int seq_len, int max_input_len) +{ + int seq_id = threadIdx.x; + int batch_id = blockIdx.x; + for (; seq_id < input_lengths[batch_id]; seq_id += blockDim.x) { + ids[seq_id * batch_size + batch_id] = input_ids[batch_id * max_input_len + seq_id]; + } +} + +void invokeFixInputIds(int* ids, + const int* input_ids, + const int* input_lengths, + int batch_size, + int seq_len, + int max_input_len, + cudaStream_t st) +{ + dim3 block(std::min(1024, max_input_len)); + dim3 grid(batch_size); + fixInputIds<<>>(ids, input_ids, input_lengths, batch_size, seq_len, max_input_len); +} + +template +__global__ void sliceCausalMask(T* mask, int seq_len, int key_len, int step) +{ + mask += (size_t)blockIdx.x * seq_len * key_len; + for (int i = threadIdx.x; i < seq_len * key_len; i += blockDim.x) { + int row = i / key_len; + int col = i % key_len; + if (col <= row + step) { + mask[i] = static_cast(1.f); + } + else { + mask[i] = static_cast(0.f); + } + } +} + +// [step: step+Q, :] of the K*K causal mask +template +void invokeSliceCausalMask(T* mask, int seq_len, int key_len, int step, int batch_size, cudaStream_t stream) +{ + FT_CHECK(step == key_len - 
seq_len); + sliceCausalMask<<>>(mask, seq_len, key_len, step); +} + +template void invokeSliceCausalMask(half*, int, int, int, int, cudaStream_t); +template void invokeSliceCausalMask(float*, int, int, int, int, cudaStream_t); + +// mask [bsz, max_q_len, max_k_len] + +template +__global__ void createCausalMasks(T* mask, const int* q_lens, const int* k_lens, int max_q_len, int max_k_len) +{ + const auto q_len = q_lens[blockIdx.x]; + const auto k_len = k_lens[blockIdx.x]; + mask += blockIdx.x * max_q_len * max_k_len; + for (int i = threadIdx.x; i < max_q_len * max_k_len; i += blockDim.x) { + const int q = i / max_k_len; // [0, max_q_len) + const int k = i % max_k_len; // [0, max_k_len) + bool is_valid = q < q_len && k < k_len && k <= q + (k_len - q_len); + mask[i] = static_cast(is_valid); + } +} + +template +void invokeCreateCausalMasks( + T* mask, const int* q_lens, const int* k_lens, int max_q_len, int max_k_len, int batch_size, cudaStream_t stream) +{ + createCausalMasks<<>>(mask, q_lens, k_lens, max_q_len, max_k_len); +} + +template void invokeCreateCausalMasks(float* mask, const int*, const int*, int, int, int, cudaStream_t); +template void invokeCreateCausalMasks(half* mask, const int*, const int*, int, int, int, cudaStream_t); + +template +__global__ void extend_key_cache(T** k_dst, + const size_t dst_offset, + const T* k_src, + const int head_num, + const int size_per_head, + const int* query_length, + const int* history_length, + const int max_q_len, + const int max_seq_len) +{ + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + constexpr int X_ELEMS = (sizeof(T) == 4) ? 
4 : 8; + + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + int size_per_head_div_x = size_per_head / X_ELEMS; + + // x dim is now handled by uint4 type + const auto key_src = reinterpret_cast(k_src); + const auto key_dst = reinterpret_cast(k_dst[batch_id] + dst_offset); + + const auto seq_len = query_length[batch_id]; + const auto t_offset = history_length[batch_id]; + + const int k_head_size_id = idx % size_per_head_div_x; + const int k_seq_len_id = idx / size_per_head_div_x; + + if (k_seq_len_id < seq_len) { + // [B, H, s, D/x] -> [H, D/x, S[t:t+s]] + + const int64_t dst_idx = head_id * size_per_head_div_x * max_seq_len + // H + k_head_size_id * max_seq_len + // D/x + t_offset + k_seq_len_id; // s + offset + + const int64_t src_idx = batch_id * head_num * size_per_head_div_x * max_q_len + // B + head_id * size_per_head_div_x * max_q_len + // H + k_seq_len_id * size_per_head_div_x + // s + k_head_size_id; // D/x + + key_dst[dst_idx] = key_src[src_idx]; + } +} + +template +__global__ void extend_value_cache(T** v_dst, + const size_t dst_offset, + const T* v_src, + const int head_num, + const int size_per_head, + const int* query_length, + const int* history_length, + const int max_q_len, + const int max_seq_len) +{ + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + constexpr int X_ELEMS = (sizeof(T) == 4) ? 
4 : 8; + + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + int size_per_head_div_x = size_per_head / X_ELEMS; + + // x dim is now handled by uint4 type + const auto val_src = reinterpret_cast(v_src); + const auto val_dst = reinterpret_cast(v_dst[batch_id] + dst_offset); + + const auto seq_len = query_length[batch_id]; + const auto t_offset = history_length[batch_id]; + + const int v_head_size_id = idx % size_per_head_div_x; + const int v_seq_len_id = idx / size_per_head_div_x; + + if (v_seq_len_id < seq_len) { + // [B, H, s, D/x] -> [H, S[t:t+s], D/x] + const int64_t dst_idx = head_id * size_per_head_div_x * max_seq_len + // H + (v_seq_len_id + t_offset) * size_per_head_div_x + // s + offset + v_head_size_id; // D/x + + const int64_t src_idx = batch_id * head_num * size_per_head_div_x * max_q_len + // B + head_id * size_per_head_div_x * max_q_len + // H + v_seq_len_id * size_per_head_div_x + // s + v_head_size_id; // D/x + + val_dst[dst_idx] = val_src[src_idx]; + } +} + +inline __device__ float2 float2div(float a, float2 b) +{ + float2 c; + c.x = b.x / a; + c.y = b.y / a; + return c; +} + +static inline __device__ half4 char4_scale_to_half4(char4 value, const float scale) +{ + half4 dst; + dst.x = __float2half(value.x * scale); + dst.y = __float2half(value.y * scale); + dst.z = __float2half(value.z * scale); + dst.w = __float2half(value.w * scale); + return dst; +} + +static inline __device__ uint32_t float4_to_char4(float x, float y, float z, float w) +{ + uint32_t dst; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 720 + uint32_t a; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(a) : "f"(x)); + uint32_t b; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(b) : "f"(y)); + uint32_t c; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(c) : "f"(z)); + uint32_t d; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(d) : "f"(w)); + + asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;\n" : "=r"(dst) : "r"(d), "r"(c)); + asm 
volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, %0;\n" : "+r"(dst) : "r"(b), "r"(a)); +#else + char4 tmp; + tmp.x = x; + tmp.y = y; + tmp.z = z; + tmp.w = w; + dst = reinterpret_cast(tmp); +#endif + return dst; +} + +template +__global__ void extend_value_cache_int8(int8_t** v_dst, + const size_t dst_offset, + const T* v_src, + const int head_num, + const int size_per_head, + const int* query_length, + const int* history_length, + const int max_q_len, + const int max_seq_len, + const float v_scale) +{ + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8; + + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + int size_per_head_div_x = size_per_head / X_ELEMS; + + // x dim is now handled by uint4 type + const auto val_src = reinterpret_cast(v_src); + const auto val_dst = reinterpret_cast(v_dst[batch_id] + dst_offset); + + const auto seq_len = query_length[batch_id]; + const auto t_offset = history_length[batch_id]; + + const int v_head_size_id = idx % size_per_head_div_x; + const int v_seq_len_id = idx / size_per_head_div_x; + + if (v_seq_len_id < seq_len) { + // [B, H, s, D/x] -> [H, S[t:t+s], D/x] + const int64_t dst_idx = head_id * size_per_head_div_x * max_seq_len + // H + (v_seq_len_id + t_offset) * size_per_head_div_x + // s + offset + v_head_size_id; // D/x + + const int64_t src_idx = batch_id * head_num * size_per_head_div_x * max_q_len + // B + head_id * size_per_head_div_x * max_q_len + // H + v_seq_len_id * size_per_head_div_x + // s + v_head_size_id; // D/x + + // scale to int8 and write + const auto value = val_src[src_idx]; + auto to_ptr = reinterpret_cast(val_dst + dst_idx); + + float2 float2_0 = float2div(v_scale, mmha::half2_to_float2(value.x)); + float2 float2_1 = float2div(v_scale, mmha::half2_to_float2(value.y)); + to_ptr[0] = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y); + + float2_0 = float2div(v_scale, mmha::half2_to_float2(value.z)); + float2_1 = 
float2div(v_scale, mmha::half2_to_float2(value.w)); + to_ptr[1] = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y); + } +} + +template +void invokeExtendKVCache(T** k_dst, + T** v_dst, + size_t dst_offset, + const T* k_src, + const T* v_src, + int local_batch_size, + const int* query_length, + int max_q_len, + const int* history_length, + int max_seq_len, + int size_per_head, + int local_head_num, + cudaStream_t stream, + int quant, + const float* kv_scale) +{ + constexpr int block_sz = 128; + constexpr int x = (sizeof(T) == 4) ? 4 : 8; + + dim3 grid((max_q_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_head_num); + + if (quant & QuantPolicy::kCacheKVInt8) { + extend_value_cache_int8<<>>(reinterpret_cast(k_dst), + dst_offset, + k_src, + local_head_num, + size_per_head, + query_length, + history_length, + max_q_len, + max_seq_len, + kv_scale[0]); + + extend_value_cache_int8<<>>(reinterpret_cast(v_dst), + dst_offset, + v_src, + local_head_num, + size_per_head, + query_length, + history_length, + max_q_len, + max_seq_len, + kv_scale[1]); + } + else { + extend_value_cache<<>>(k_dst, + dst_offset, + k_src, + local_head_num, + size_per_head, + query_length, + history_length, + max_q_len, + max_seq_len); + + extend_value_cache<<>>(v_dst, + dst_offset, + v_src, + local_head_num, + size_per_head, + query_length, + history_length, + max_q_len, + max_seq_len); + } +} + +template void invokeExtendKVCache(float**, + float**, + size_t, + const float*, + const float*, + int, + const int*, + int, + const int*, + int, + int, + int, + cudaStream_t stream, + int, + const float*); + +template void invokeExtendKVCache(half**, + half**, + size_t, + const half*, + const half*, + int, + const int*, + int, + const int*, + int, + int, + int, + cudaStream_t stream, + int, + const float*); + +template +__global__ void transpose_value_cache(T* v_dst, // + const T** v_src, + const size_t src_offset, + const int head_num, + const int head_n_rep, + 
const int size_per_head, + const int* seq_length, + const int max_kv_len, + const int max_seq_len) +{ + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8; + + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + int size_per_head_div_x = size_per_head / X_ELEMS; + + // x dim is now handled by uint4 type + const auto val_src = reinterpret_cast(v_src[batch_id] + src_offset); + const auto val_dst = reinterpret_cast(v_dst); + + const auto seq_len = seq_length[batch_id]; + + const int v_head_size_id = idx % size_per_head_div_x; + const int v_seq_len_id = idx / size_per_head_div_x; + + if (v_seq_len_id < seq_len) { + // [B, H, s, D/x] <- [B, H, S[:s], D/x] + const int64_t src_idx = head_id / head_n_rep * size_per_head_div_x * max_seq_len + // H + v_seq_len_id * size_per_head_div_x + // s + v_head_size_id; // D/x + + const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len + // B + head_id * size_per_head_div_x * max_kv_len + // H + v_seq_len_id * size_per_head_div_x + // s + v_head_size_id; // D/x + + val_dst[dst_idx] = val_src[src_idx]; + } +} + +template +__global__ void transpose_value_cache_int8(T* v_dst, // + const int8_t** v_src, + const size_t src_offset, + const int head_num, + const int head_n_rep, + const int size_per_head, + const int* seq_length, + const int max_kv_len, + const int max_seq_len, + const float v_scale) +{ + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + constexpr int X_ELEMS = (sizeof(T) == 4) ? 
4 : 8; + + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + int size_per_head_div_x = size_per_head / X_ELEMS; + + // x dim is now handled by uint4 type + const auto val_src = reinterpret_cast(v_src[batch_id] + src_offset); + const auto val_dst = reinterpret_cast(v_dst); + + const auto seq_len = seq_length[batch_id]; + + const int v_head_size_id = idx % size_per_head_div_x; + const int v_seq_len_id = idx / size_per_head_div_x; + + if (v_seq_len_id < seq_len) { + // [B, H, s, D/x] <- [B, H, S[:s], D/x] + const int64_t src_idx = head_id / head_n_rep * size_per_head_div_x * max_seq_len + // H + v_seq_len_id * size_per_head_div_x + // s + v_head_size_id; // D/x + + const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len + // B + head_id * size_per_head_div_x * max_kv_len + // H + v_seq_len_id * size_per_head_div_x + // s + v_head_size_id; // D/x + + // int8x8 -> fp16x8 + const auto from_ptr = reinterpret_cast(val_src + src_idx); + auto to_ptr = reinterpret_cast(val_dst + dst_idx); + + to_ptr[0] = char4_scale_to_half4(from_ptr[0], v_scale); + to_ptr[1] = char4_scale_to_half4(from_ptr[1], v_scale); + } +} + +template +void invokeTransposeKVCache(T* key_cache_trans, + T* val_cache_trans, + const T** key_cache, + const T** val_cache, + size_t src_offset, + int batch_size, + const int* key_length, + int max_kv_len, + int max_seq_len, + int size_per_head, + int head_num, + int head_n_rep, + cudaStream_t stream, + int quant, + const float* kv_scale) +{ + constexpr int block_sz = 128; + constexpr int x = (sizeof(T) == 4) ? 
4 : 8; + + dim3 grid((max_kv_len * size_per_head / x + block_sz - 1) / block_sz, batch_size, head_num); + + if (quant & QuantPolicy::kCacheKVInt8) { + transpose_value_cache_int8<<>>(key_cache_trans, + reinterpret_cast(key_cache), + src_offset, + head_num, + head_n_rep, + size_per_head, + key_length, + max_kv_len, + max_seq_len, + kv_scale[0]); + + transpose_value_cache_int8<<>>(val_cache_trans, + reinterpret_cast(val_cache), + src_offset, + head_num, + head_n_rep, + size_per_head, + key_length, + max_kv_len, + max_seq_len, + kv_scale[1]); + } + else { + transpose_value_cache<<>>(key_cache_trans, + key_cache, + src_offset, + head_num, + head_n_rep, + size_per_head, + key_length, + max_kv_len, + max_seq_len); + + transpose_value_cache<<>>(val_cache_trans, + val_cache, + src_offset, + head_num, + head_n_rep, + size_per_head, + key_length, + max_kv_len, + max_seq_len); + } +} + +template void invokeTransposeKVCache(float*, + float*, + const float**, + const float**, + size_t, + int, + const int*, + int, + int, + int, + int, + int, + cudaStream_t stream, + int, + const float*); +template void invokeTransposeKVCache(half*, + half*, + const half**, + const half**, + size_t, + int, + const int*, + int, + int, + int, + int, + int, + cudaStream_t stream, + int, + const float*); + +__global__ void gatherOutput(int* output_ids, + const int* ids, + const int* context_length, + int max_context_len, + int max_gen_step, + int max_output_len, + int batch_size) +{ + const int batch_id = blockIdx.x; + const int context_len = context_length[batch_id]; + output_ids += batch_id * max_output_len; + for (int src_idx = threadIdx.x; src_idx < max_gen_step; src_idx += blockDim.x) { + // skip padding for src + if (context_len <= src_idx && src_idx < max_context_len) { + continue; + } + // skip padding for dst + const int dst_idx = src_idx < context_len ? 
src_idx : src_idx - (max_context_len - context_len); + output_ids[dst_idx] = ids[src_idx * batch_size + batch_id]; + } +} + +void invokeGatherOutput(int* output_ids, + const int* ids, + const int* context_length, + int max_context_len, + int max_gen_step, + int max_output_len, + int batch_size, + cudaStream_t stream) +{ + int block_size = 512; + int grid_size = batch_size; + gatherOutput<<>>( + output_ids, ids, context_length, max_context_len, max_gen_step, max_output_len, batch_size); +} + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/llama_kernels.h b/src/fastertransformer/models/llama/llama_kernels.h new file mode 100644 index 000000000..351ee724e --- /dev/null +++ b/src/fastertransformer/models/llama/llama_kernels.h @@ -0,0 +1,168 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#pragma once + +#include "src/fastertransformer/kernels/gpt_kernels.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include "src/fastertransformer/utils/cuda_utils.h" +#include +#include +#include +#include + +namespace fastertransformer { + +template +void invokeRootMeanSquareNorm(T* out, const T* input, const T* scale, float eps, int m, int n, cudaStream_t stream); + +template +void invokeAddResidual(T* out, const T* in, int m, int n, cudaStream_t stream); + +void invokeFixInputIds(int* ids, + const int* input_ids, + const int* input_lengths, + int batch_size, + int seq_len, + int max_input_len, + cudaStream_t st); + +template +void invokeSliceCausalMask(T* mask, int seq_len, int key_len, int step, int batch_size, cudaStream_t stream); + +template +void invokeCreateCausalMasks( + T* mask, const int* q_lens, const int* k_lens, int max_q_len, int max_k_len, int batch_size, cudaStream_t stream); + +template +void invokeExtendKVCache(T** k_dst, + T** v_dst, + size_t layer_offset, + const T* k_src, + const T* v_src, + int batch_size, + const int* query_length, + int max_q_len, + const int* history_length, + int max_seq_len, + int 
size_per_head, + int local_head_num, + cudaStream_t stream, + int quant, + const float* kv_scale); + +template +void invokeTransposeKVCache(T* key_cache_trans, + T* val_cache_trans, + const T** key_cache, + const T** val_cache, + size_t layer_offset, + int batch_size, + const int* key_length, + int max_kv_len, + int max_seq_len, + int size_per_head, + int head_num, + int head_n_rep, + cudaStream_t stream, + int quant_policy, + const float* kv_scale); + +void invokeGatherOutput(int* output_ids, + const int* ids, + const int* context_length, + int max_context_len, + int max_gen_step, + int max_output_len, + int batch_size, + cudaStream_t stream); + +void invokeMyCopyInt(int* dst, const int* src, size_t count, cudaStream_t st); + +template +class FlashAttentionOp { +public: + struct AttentionLayout { + int stride_batch; + int stride_seq; + int stride_head; + bool use_seqlens = false; + int batch_seqs_offset = 0; + T** batch_seqs = nullptr; + }; + + struct Params { + T* attn_out; + T* query; + T* key; + T* val; + T* mask; + float* out_accum = nullptr; + int* cu_seqlens_q = nullptr; + int* cu_seqlens_k = nullptr; + size_t group_size = 1; + AttentionLayout layout_q; + AttentionLayout layout_k; + AttentionLayout layout_v; + AttentionLayout layout_o; + }; + +public: + FlashAttentionOp(int batch_size, int head_num, int key_len, int seq_len, int size_per_head); + ~FlashAttentionOp(); + + int get_workspace_size() const; + + void operator()(Params& params, cudaStream_t st) const; + +private: + class impl; + std::unique_ptr pimpl; +}; + +template +inline void dump(const T* x, int size, cudaStream_t st, const char* msg, bool full = false) +{ + std::vector h_x(size); + cudaMemcpyAsync(h_x.data(), x, sizeof(T) * size, cudaMemcpyDefault, st); + cudaStreamSynchronize(st); + fprintf(stderr, "\n%s:\n", msg); + std::vector h_y(h_x.begin(), h_x.end()); + float asum = 0.f; + for (const auto& x : h_y) { + asum += std::fabs(x); + } + if (full) { + for (int i = 0; i < size; ++i) { + 
printf("%d %.8f\n", i, h_y[i]); + } + } + else { + for (int i = 0; i < 8; ++i) { + fprintf(stderr, "%.8f\n", h_y[i]); + } + for (int i = size - 8; i < size; ++i) { + fprintf(stderr, "%.8f\n", h_y[i]); + } + } + fprintf(stderr, "\nasum = %f\n", asum); + // getchar(); +} + +template +struct TempBuffer { + TempBuffer(size_t size) + { + deviceMalloc(&data, size, false); + } + T* data; +}; + +inline void dump_sequence_len(int* d_seq_len, int step, int tp_rank, cudaStream_t st) +{ + int h_seq_len = -1; + cudaMemcpyAsync(&h_seq_len, d_seq_len, sizeof(int), cudaMemcpyDefault, st); + cudaStreamSynchronize(st); + FT_LOG_ERROR("--------> rank = %d, step = %d, seq_len = %d <--------", tp_rank, step, h_seq_len); +} + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/llama_utils.cu b/src/fastertransformer/models/llama/llama_utils.cu new file mode 100644 index 000000000..42bb556f5 --- /dev/null +++ b/src/fastertransformer/models/llama/llama_utils.cu @@ -0,0 +1,160 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh" +#include "src/fastertransformer/models/llama/llama_utils.h" +#include "src/fastertransformer/utils/cuda_utils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fastertransformer { + +CmpMode compare_mode = kCmpNone; + +template +struct abs_diff_t { + using type = T; +}; + +template<> +struct abs_diff_t { + using type = float; +}; + +template +struct abs_diff: public thrust::unary_function, typename abs_diff_t::type> { + __host__ __device__ float operator()(thrust::tuple x) const + { + using R = typename abs_diff_t::type; + auto r = R(thrust::get<0>(x)) - R(thrust::get<1>(x)); + return r < R(0) ? 
-r : r; + } +}; + +template +void CheckNan(const T* ptr, size_t size, std::string key, cudaStream_t stream) +{ + std::vector h_data(size); + cudaMemcpyAsync(h_data.data(), ptr, sizeof(T) * size, cudaMemcpyDefault, stream); + + check_cuda_error(cudaStreamSynchronize(stream)); + + size_t nan_cnt = 0; + for (const auto& x : h_data) { + nan_cnt += std::isnan(static_cast(x)); + } + if (nan_cnt) { + std::cerr << key << ": NaN count " << nan_cnt << "\n"; + } +} + +template +void CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream) +{ + // wait for b + check_cuda_error(cudaStreamSynchronize(stream)); + // read a from file + thrust::host_vector h_a(size); + { + const auto filename = "tmp/" + key + ".cmp"; + std::ifstream ifs(filename, std::ios::binary); + if (!ifs.is_open()) { + std::cerr << key << ": failed to open " + filename << "\n"; + return; + } + ifs.seekg(0, ifs.end); + const auto actual_size_in_bytes = ifs.tellg(); + ifs.seekg(0, ifs.beg); + const auto expect_size_in_bytes = sizeof(T) * size; + if (actual_size_in_bytes != expect_size_in_bytes) { + std::cerr << key << ": file size in bytes mismatch, expect " << expect_size_in_bytes << ", got " + << actual_size_in_bytes << "\n"; + return; + } + ifs.read((char*)h_a.data(), sizeof(T) * h_a.size()); + } + // copy a to device + thrust::device_vector a = h_a; + // create abs(a - b) iterator + thrust::device_ptr dev_ptr(ptr); + auto zip_iter = thrust::make_zip_iterator(thrust::make_tuple(a.begin(), dev_ptr)); + auto transform_iter = thrust::make_transform_iterator(zip_iter, abs_diff{}); + // sum(abs(a - b)) + auto asum = thrust::reduce(thrust::device, transform_iter, transform_iter + size); + std::cerr << key << ": " << asum << " " << asum / size << "\n"; +} + +template +void CmpWrite(T* ptr, size_t size, std::string key, cudaStream_t stream) +{ + std::vector a(size); + // copy a to host + check_cuda_error(cudaMemcpyAsync(a.data(), ptr, sizeof(T) * size, cudaMemcpyDefault, stream)); + 
check_cuda_error(cudaStreamSynchronize(stream)); + // write to file + { + std::ofstream ofs("tmp/" + key + ".cmp", std::ios::binary); + ofs.write((char*)a.data(), sizeof(T) * a.size()); + } +} + +template +void Compare(T* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream) +{ + // std::cerr << "Comparing " << key << "\n"; + if (mode == kCmpRead) { + CmpRead(ptr, size, key, stream); + } + else if (mode == kCmpWrite) { + CmpWrite(ptr, size, key, stream); + } + else { + // kCmpNone + } +} + +template void Compare(int* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream); +template void Compare(float* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream); +template void Compare(half* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream); + +template void CheckNan(const float* ptr, size_t size, std::string key, cudaStream_t stream); +template void CheckNan(const half* ptr, size_t size, std::string key, cudaStream_t stream); + +std::string format(const std::pair& p) +{ + std::stringstream ss; + ss << p.first << " ["; + bool first = true; + for (const auto& x : p.second.shape) { + ss << (first ? "" : ", ") << x; + first = false; + } + ss << "]"; + return ss.str(); +} + +size_t curandStateGetSize() +{ + return sizeof(curandState_t); +} + +bool isDebug() +{ + static const bool is_debug = [] { + const auto level = std::getenv("FT_DEBUG_LEVEL"); + if (level && level == std::string("DEBUG")) { + return true; + } + return false; + }(); + return is_debug; +} + +} // namespace fastertransformer diff --git a/src/fastertransformer/models/llama/llama_utils.h b/src/fastertransformer/models/llama/llama_utils.h new file mode 100644 index 000000000..d472d6640 --- /dev/null +++ b/src/fastertransformer/models/llama/llama_utils.h @@ -0,0 +1,69 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
namespace detail {

// Stringify helper used by Concat: numbers go through std::to_string,
// strings pass through untouched.
template<typename T>
std::string to_string(T x)
{
    return std::to_string(x);
}

inline std::string to_string(std::string x)
{
    return x;
}

}  // namespace detail

// Build a composite key by appending "_<arg>" for every trailing argument,
// e.g. Concat("k", 1, "x") -> "k_1_x". Accepts anything detail::to_string
// can handle.
template<typename... Args>
std::string Concat(std::string key, Args&&... args)
{
    // Expand the pack in order without materialising an intermediate vector.
    using expander = int[];
    (void)expander{0, (key.append("_").append(detail::to_string(std::forward<Args>(args))), 0)...};
    return key;
}
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/triton_backend/multi_gpu_gpt/CMakeLists.txt + +cmake_minimum_required(VERSION 3.8) + +set(llama_triton_backend_files + LlamaTritonModel.cc + LlamaTritonModelInstance.cc +) + +add_library(LlamaTritonBackend STATIC ${llama_triton_backend_files}) +set_property(TARGET LlamaTritonBackend PROPERTY POSITION_INDEPENDENT_CODE ON) +target_link_libraries(LlamaTritonBackend PRIVATE TransformerTritonBackend Llama tensor memory_utils -lcublasLt) +target_compile_features(LlamaTritonBackend PRIVATE cxx_std_14) diff --git a/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc new file mode 100644 index 000000000..265ef4a6e --- /dev/null +++ b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.cc @@ -0,0 +1,384 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/c/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.cc + +#include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h" +#include "3rdparty/INIReader.h" +#include "src/fastertransformer/models/llama/LlamaInstanceComm.h" +#include "src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include "src/fastertransformer/utils/allocator.h" +#include + +namespace ft = fastertransformer; + +std::shared_ptr AbstractTransformerModel::createLlamaModel(std::string inifile) +{ + INIReader reader = INIReader(inifile); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << inifile << "'\n"; + return nullptr; + } + + const std::string data_type = reader.Get("ft_instance_hyperparameter", "data_type"); + int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); + std::string model_dir = reader.Get("ft_instance_hyperparameter", "model_dir"); + + if (data_type == "half" || data_type == "fp16") { + return std::make_shared>( + reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0), + model_dir); + } + else { + return std::make_shared>( + reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"), + reader.GetInteger("ft_instance_hyperparameter", 
"pipeline_para_size"), + reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0), + model_dir); + } +} + +template +void LlamaTritonModel::handleMissingParams() +{ + if (kv_head_num_ == 0) { + kv_head_num_ = head_num_; + FT_LOG_WARNING("[LlamaTritonModel] `kv_head_num` is not set, default to `head_num` (%d).", (int)kv_head_num_); + } + + if (!max_batch_size_) { + max_batch_size_ = 32; + FT_LOG_WARNING("[LlamaTritonModel] `max_batch_size` is not set, default to %d.", (int)max_batch_size_); + } + + if (!session_len_) { + session_len_ = 2160; + FT_LOG_WARNING("[LlamaTritonModel] `session_len` is not set, default to %d.", (int)session_len_); + } + + if (!max_context_token_num_) { + max_context_token_num_ = (int)std::sqrt(max_batch_size_); + FT_LOG_WARNING("[LlamaTritonModel] `max_context_token_num` is not set, default to %d.", + (int)max_context_token_num_); + } + + if (!step_length_) { + step_length_ = 1; + FT_LOG_WARNING("[LlamaTritonModel] `step_length` is not set, default to %d.", (int)step_length_); + } + + if (!cache_max_entry_count_) { + cache_max_entry_count_ = 32; + FT_LOG_WARNING("[LlamaTritonModel] `cache_max_entry_count` is not set, default to %d.", + (int)cache_max_entry_count_); + } + + if (!cache_chunk_size_) { + cache_chunk_size_ = cache_max_entry_count_; + FT_LOG_WARNING("[LlamaTritonModel] `cache_chunk_size` is not set, default to %d.", (int)cache_chunk_size_); + } +} + +template +LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, + std::string model_dir): + tensor_para_size_(tensor_para_size), + pipeline_para_size_(pipeline_para_size), + shared_weights_(std::vector>>(ft::getDeviceCount())), + enable_custom_all_reduce_(enable_custom_all_reduce) +{ + model_dir_ = model_dir; + const std::string inifile{model_dir + "/config.ini"}; + INIReader reader = INIReader(inifile); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << inifile << "'\n"; + 
ft::FT_CHECK(false); + } + + model_name_ = reader.Get("llama", "model_name"); + head_num_ = reader.GetInteger("llama", "head_num"); + kv_head_num_ = reader.GetInteger("llama", "kv_head_num", 0); + size_per_head_ = reader.GetInteger("llama", "size_per_head"); + inter_size_ = reader.GetInteger("llama", "inter_size"); + num_layer_ = reader.GetInteger("llama", "num_layer"); + vocab_size_ = reader.GetInteger("llama", "vocab_size"); + rotary_embedding_dim_ = reader.GetInteger("llama", "rotary_embedding"); + norm_eps_ = reader.GetFloat("llama", "norm_eps"); + start_id_ = reader.GetInteger("llama", "start_id"); + end_id_ = reader.GetInteger("llama", "end_id"); + max_batch_size_ = reader.GetInteger("llama", "max_batch_size", 0); + max_context_token_num_ = reader.GetInteger("llama", "max_context_token_num", 0); + session_len_ = reader.GetInteger("llama", "session_len", 0); + step_length_ = reader.GetInteger("llama", "step_length", 0); + cache_max_entry_count_ = reader.GetInteger("llama", "cache_max_entry_count", 0); + use_context_fmha_ = reader.GetInteger("llama", "use_context_fmha", 1); + cache_chunk_size_ = reader.GetInteger("llama", "cache_chunk_size", 0); + prefix_cache_len_ = reader.GetInteger("llama", "prefix_cache_len", 0); + attn_bias_ = reader.GetInteger("llama", "attn_bias", 0); + quant_policy_ = reader.GetInteger("llama", "quant_policy", 0); + + handleMissingParams(); + + if (max_context_token_num_ <= max_batch_size_) { + max_context_token_num_ *= session_len_; + } + + shared_state_ = std::make_shared::SharedState>(); + shared_state_->barrier = std::make_shared(tensor_para_size); + + const auto device_count = ft::getDeviceCount(); + shared_instances_.resize(device_count); + shared_mutexes_.resize(device_count); + + const std::string weight_type_str = reader.Get("llama", "weight_type"); + if (weight_type_str == "fp16") { + weight_type_ = ft::WeightType::kFP16; + } + else if (weight_type_str == "fp32") { + weight_type_ = ft::WeightType::kFP32; + } + else if 
(weight_type_str == "int8") { + weight_type_ = ft::WeightType::kINT8; + } + else if (weight_type_str == "int4") { + weight_type_ = ft::WeightType::kINT4; + } + else { + std::cout << "[ERROR] Unsupported weight type: '" << weight_type_str << "'\n"; + ft::FT_CHECK(0); + } +} + +template +std::unique_ptr> LlamaTritonModel::createSharedModelInstance( + int device_id, + int rank, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + const int comms_rank = device_id % (tensor_para_size_ * pipeline_para_size_); + + std::unique_ptr> allocator( + new ft::Allocator(device_id)); + + /// TODO: this stream handle is leaked + cudaStream_t stream{}; + ft::check_cuda_error(cudaStreamCreate(&stream)); + + allocator->setStream(stream); + + cublasHandle_t cublas_handle; + cublasLtHandle_t cublaslt_handle; + + cublasCreate(&cublas_handle); + cublasLtCreate(&cublaslt_handle); + cublasSetStream(cublas_handle, stream); + + std::unique_ptr cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in")); + std::unique_ptr cublas_wrapper_mutex(new std::mutex()); + std::unique_ptr cublas_wrapper(new ft::cublasMMWrapper( + cublas_handle, cublaslt_handle, stream, cublas_algo_map.get(), cublas_wrapper_mutex.get(), allocator.get())); + + std::unique_ptr cuda_device_prop_ptr(new cudaDeviceProp); + ft::check_cuda_error(cudaGetDeviceProperties(cuda_device_prop_ptr.get(), device_id)); + + if (std::is_same::value) { + cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); + } + else if (std::is_same::value) { + cublas_wrapper->setFP32GemmConfig(); + } + + ft::NcclParam tensor_para = nccl_params.first[comms_rank]; + ft::NcclParam pipeline_para = nccl_params.second[comms_rank]; + + ft::FT_CHECK(tensor_para.world_size_ == tensor_para_size_); + ft::FT_CHECK(pipeline_para.world_size_ = pipeline_para_size_); + + auto llama = std::make_unique>(head_num_, + kv_head_num_, + size_per_head_, + inter_size_, + 
num_layer_, + vocab_size_, + rotary_embedding_dim_, + norm_eps_, + max_batch_size_, + max_context_token_num_, + session_len_, + step_length_, + start_id_, + end_id_, + cache_max_entry_count_, + cache_chunk_size_, + quant_policy_, + use_context_fmha_, + shared_state_, + shared_weights_[device_id].get(), + tensor_para, + stream, + cublas_wrapper.get(), + allocator.get(), + false, // is_free_buffer_after_forward, + cuda_device_prop_ptr.get()); + + return std::make_unique>( + LlamaTritonSharedModelInstance{std::move(llama), + shared_weights_[device_id], + std::move(allocator), + std::move(cublas_algo_map), + std::move(cublas_wrapper_mutex), + std::move(cublas_wrapper), + std::move(cuda_device_prop_ptr), + session_len_}); +} + +template +std::unique_ptr +LlamaTritonModel::createModelInstance(int device_id, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + // const int comms_rank = device_id % (tensor_para_size_ * pipeline_para_size_); + + std::shared_ptr> instance; + { + std::lock_guard lock(shared_mutexes_[device_id]); + instance = shared_instances_[device_id].lock(); + if (!instance) { + instance = createSharedModelInstance(device_id, rank, nccl_params, custom_all_reduce_comm); + shared_instances_[device_id] = instance; + } + } + + std::unique_ptr> allocator( + new ft::Allocator(device_id)); + + allocator->setStream(stream); + + return std::make_unique>(instance, std::move(allocator)); +} + +template +void LlamaTritonModel::createSharedWeights(int device_id, int rank) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + const int tensor_para_rank = rank % tensor_para_size_; + const int pipeline_para_rank = rank / tensor_para_size_; + ft::FT_CHECK(pipeline_para_size_ == 1 && pipeline_para_rank == 0); + shared_weights_[device_id] = std::make_shared>(head_num_, + kv_head_num_, + size_per_head_, + inter_size_, + vocab_size_, + num_layer_, + 
weight_type_, + attn_bias_, + tensor_para_size_, + tensor_para_rank, + prefix_cache_len_); + shared_weights_[device_id]->loadModel(model_dir_); + return; +} + +template +std::string LlamaTritonModel::toString() +{ + std::stringstream ss; + ss << "Model: " + << "\nhead_num: " << head_num_ << "\nkv_head_num: " << kv_head_num_ << "\nsize_per_head: " << size_per_head_ + << "\ninter_size: " << inter_size_ << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ + << "\nattn_bias: " << attn_bias_ << "\nmax_batch_size: " << max_batch_size_ + << "\nmax_context_token_num: " << max_context_token_num_ << "\nsession_len: " << session_len_ + << "\nstep_length: " << step_length_ << "\ncache_max_entry_count: " << cache_max_entry_count_ + << "\ncache_chunk_size: " << cache_chunk_size_ << "\nuse_context_fmha: " << use_context_fmha_ + << "\nstart_id: " << start_id_ << "\ntensor_para_size: " << tensor_para_size_ + << "\npipeline_para_size: " << pipeline_para_size_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ + << "\nmodel_name: " << model_name_ << "\nprefix_cache_len: " << prefix_cache_len_ + << "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << quant_policy_ << std::endl; + + return ss.str(); +} + +template +void LlamaTritonModel::createCustomComms( + std::vector>* custom_all_reduce_comms, int world_size) +{ + using commDataType = typename ft::CustomARCommTypeConverter::Type; + ft::initCustomAllReduceComm(custom_all_reduce_comms, enable_custom_all_reduce_, world_size); +} + +template +std::pair, std::vector> +LlamaTritonModel::createNcclParams(const int node_id, const int device_id_start, const bool multi_node) +{ + const auto device_count = ft::getDeviceCount(); + bool need_nccl_params = false; + // create nccl group when there are non-occupied devices + for (int i = 0; i < device_count; ++i) { + std::lock_guard lock(shared_mutexes_[i]); + if (shared_instances_[i].expired()) { + need_nccl_params = true; + break; + } + } + if (need_nccl_params) { + 
return AbstractTransformerModel::createNcclParams(node_id, device_id_start, multi_node); + } + else { + FT_LOG_INFO("Skipping NCCL param creation."); + + const int tensor_para_size = getTensorParaSize(); + const int pipeline_para_size = getPipelineParaSize(); + const int local_comm_size = multi_node ? device_count : tensor_para_size * pipeline_para_size; + + std::vector tensor_para_params(local_comm_size); + std::vector pipeline_para_params(local_comm_size); + return {std::move(tensor_para_params), std::move(pipeline_para_params)}; + } +} + +template +std::unique_ptr LlamaTritonModel::createInstanceComm(int size) +{ + return std::make_unique(size); +} + +template +int LlamaTritonModel::getTensorParaSize() +{ + return tensor_para_size_; +} + +template +int LlamaTritonModel::getPipelineParaSize() +{ + return pipeline_para_size_; +} + +template struct LlamaTritonModel; +template struct LlamaTritonModel; diff --git a/src/fastertransformer/triton_backend/llama/LlamaTritonModel.h b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.h new file mode 100644 index 000000000..a78bf1be0 --- /dev/null +++ b/src/fastertransformer/triton_backend/llama/LlamaTritonModel.h @@ -0,0 +1,124 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h + +#pragma once + +#include "src/fastertransformer/models/llama/LlamaV2.h" +#include "src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/nccl_utils.h" +#include +#include + +namespace ft = fastertransformer; + +template +struct LlamaTritonSharedModelInstance; + +template +struct LlamaTritonModel: public AbstractTransformerModel { + LlamaTritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, + std::string model_dir); + + ~LlamaTritonModel() = default; + + std::unique_ptr + createModelInstance(int deviceId, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm = nullptr) override; + + void createSharedWeights(int deviceId, int rank) override; + + void createCustomComms(std::vector>* custom_all_reduce_comms, + int world_size) override; + + std::pair, std::vector> + createNcclParams(const int node_id, const int device_id_start, const bool multi_node) override; + + std::unique_ptr createInstanceComm(int size) override; + + void handleMissingParams(); + + std::string toString() override; + int getTensorParaSize() override; + int getPipelineParaSize() override; + +private: + std::unique_ptr> + createSharedModelInstance(int deviceId, + int rank, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm = nullptr); + + size_t head_num_; + size_t kv_head_num_; + size_t size_per_head_; + size_t inter_size_; + size_t num_layer_; + size_t vocab_size_; + size_t rotary_embedding_dim_; + float norm_eps_; + int max_batch_size_; + int max_context_token_num_; 
+ int session_len_; + int step_length_; + int start_id_; + int end_id_; + int cache_max_entry_count_; + int cache_chunk_size_; + int use_context_fmha_; + size_t tensor_para_size_; + size_t pipeline_para_size_; + ft::WeightType weight_type_; + bool attn_bias_; + int quant_policy_; + + size_t prefix_cache_len_{}; + + // shared weights for each device + std::vector>> shared_weights_; + + std::shared_ptr::SharedState> shared_state_; + + // weak_ptr is used so that the instances get released when all strong references are gone + std::vector>> shared_instances_; + std::deque shared_mutexes_; // is locking really needed? + + // // residual type + // bool use_gptj_residual_ = true; + + // // number of tasks (for prefix-prompt, p/prompt-tuning) + // size_t num_tasks_ = 0; + // int prompt_learning_start_id_ = 0; + // ft::PromptLearningType prompt_learning_type_ = ft::PromptLearningType::no_prompt; + // std::map> prompt_learning_table_pair_ = {}; + + bool is_fp16_; + int enable_custom_all_reduce_ = 0; + + std::string model_name_; + std::string model_dir_; +}; diff --git a/src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.cc b/src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.cc new file mode 100644 index 000000000..35cacd3f0 --- /dev/null +++ b/src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.cc @@ -0,0 +1,298 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h + +#include "src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include "src/fastertransformer/triton_backend/triton_utils.hpp" +#include "src/fastertransformer/utils/Tensor.h" +#include "src/fastertransformer/utils/cuda_utils.h" +#include +#include +#include +#include +#include +#include + +namespace ft = fastertransformer; + +template +void triton_stream_callback(std::unordered_map* output_tensors, void* ctx) +{ + LlamaTritonModelInstance* model = reinterpret_cast*>(ctx); + auto result = LlamaTritonModelInstance::convert_outputs(*output_tensors); + + model->stream_cb_(result, model->stream_ctx_); +} + +template +LlamaTritonModelInstance::LlamaTritonModelInstance( + std::shared_ptr> instance, + std::unique_ptr> allocator): + instance_(std::move(instance)), allocator_(std::move(allocator)) +{ +} + +template +std::unordered_map LlamaTritonModelInstance::convert_inputs( + std::shared_ptr> input_tensors) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + + move_tensor_H2D(input_tensors->at("input_ids"), d_input_ids_, &allocator_); + move_tensor_H2D(input_tensors->at("input_lengths"), d_input_lengths_, &allocator_); + + const size_t request_batch_size = input_tensors->at("input_ids").shape[0]; + const size_t input_data_len = input_tensors->at("input_ids").shape[1]; + // freed in forward() + h_total_output_lengths_ = reinterpret_cast(malloc(request_batch_size * sizeof(uint32_t))); + + std::unordered_map ft_input_tensors = std::unordered_map{ + {"input_ids", as_GPU_tensor(input_tensors->at("input_ids"), d_input_ids_)}, + // {"input_lengths", as_GPU_tensor(input_tensors->at("input_lengths"), d_input_lengths_)}, + }; + + if 
(input_tensors->find("bad_words_list") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("bad_words_list"), d_input_bad_words_, &allocator_); + ft_input_tensors.insert( + {"bad_words_list", as_GPU_tensor(input_tensors->at("bad_words_list"), d_input_bad_words_)}); + } + + if (input_tensors->find("stop_words_list") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("stop_words_list"), d_input_stop_words_, &allocator_); + ft_input_tensors.insert( + {"stop_words_list", as_GPU_tensor(input_tensors->at("stop_words_list"), d_input_stop_words_)}); + } + + if (input_tensors->count("request_prompt_embedding") && input_tensors->count("request_prompt_lengths") + && input_tensors->count("request_prompt_type")) { + + move_tensor_H2D(input_tensors->at("request_prompt_lengths"), d_request_prompt_lengths_, &allocator_); + ft_input_tensors.insert( + {"request_prompt_lengths", + as_GPU_tensor(input_tensors->at("request_prompt_lengths"), d_request_prompt_lengths_)}); + + move_tensor_H2D(input_tensors->at("request_prompt_embedding"), d_request_prompt_embedding_, &allocator_); + ft_input_tensors.insert( + {"request_prompt_embedding", + as_GPU_tensor(input_tensors->at("request_prompt_embedding"), d_request_prompt_embedding_)}); + } + + if (input_tensors->find("top_p_decay") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("top_p_decay"), d_top_p_decay_, &allocator_); + ft_input_tensors.insert({"top_p_decay", as_GPU_tensor(input_tensors->at("top_p_decay"), d_top_p_decay_)}); + } + if (input_tensors->find("top_p_min") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("top_p_min"), d_top_p_min_, &allocator_); + ft_input_tensors.insert({"top_p_min", as_GPU_tensor(input_tensors->at("top_p_min"), d_top_p_min_)}); + } + if (input_tensors->find("top_p_reset_ids") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("top_p_reset_ids"), d_top_p_reset_ids_, &allocator_); + ft_input_tensors.insert( + {"top_p_reset_ids", 
as_GPU_tensor(input_tensors->at("top_p_reset_ids"), d_top_p_reset_ids_)}); + } + + for (auto t = input_tensors->begin(); t != input_tensors->end(); ++t) { + if (t->first.find("input_ids") == std::string::npos // && t->first.find("input_lengths") == std::string::npos + && t->first.find("output_seq_len") == std::string::npos + && t->first.find("prefix_soft_prompt_embedding") == std::string::npos + && t->first.find("prefix_soft_prompt_lengths") == std::string::npos) { + if (ft_input_tensors.count(t->first) == 0) { + ft_input_tensors.insert({t->first, t->second.convertTritonTensorToFt()}); + } + } + } + + return ft_input_tensors; +} + +template +std::shared_ptr> +LlamaTritonModelInstance::convert_outputs(const std::unordered_map& output_tensors) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + std::unordered_map* outputs_mapping = + new std::unordered_map(); + + for (auto it = output_tensors.begin(); it != output_tensors.end(); it++) { + outputs_mapping->insert({it->first, triton::Tensor::convertFtTensorToTriton(it->second)}); + } + + return std::shared_ptr>(outputs_mapping); +} + +template +std::shared_ptr> +LlamaTritonModelInstance::forward(std::shared_ptr> input_tensors) +{ + ft::FT_CHECK(false); + return nullptr; +} + +template +std::shared_ptr> +LlamaTritonModelInstance::forward(std::shared_ptr> input_tensors) +{ + ft::FT_CHECK(false); + return nullptr; +} + +template +std::string format_vector(const std::vector& vec) +{ + std::stringstream ss; + ss << "["; + bool first = true; + for (const auto& x : vec) { + ss << (first ? 
// Render a vector as "[a, b, c]" for debug logging.
template<typename T>
std::string format_vector(const std::vector<T>& vec)
{
    std::stringstream out;
    out << "[";
    for (size_t i = 0; i < vec.size(); ++i) {
        if (i > 0) {
            out << ", ";
        }
        out << vec[i];
    }
    out << "]";
    return out.str();
}
(size_t)(*(uint*)input_tensors->at("beam_width").data) : 1; + FT_CHECK_WITH_INFO(beam_width == 1, "Beam search is not implemented"); + + std::unordered_map ft_input_tensors = convert_inputs(input_tensors); + + const size_t max_input_len = input_tensors->at("input_ids").shape[1]; + const bool is_return_logits = + input_tensors->count("is_return_logits") && *(bool*)input_tensors->at("is_return_logits").data; + + const size_t vocab_size = instance_->llm->vocab_size(); + + allocateBuffer(request_batch_size, max_input_len, beam_width, instance_->session_len, is_return_logits); + + std::unordered_map output_tensors = std::unordered_map{ + {"output_ids", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_UINT32, + std::vector{request_batch_size, beam_width, (size_t)instance_->session_len}, + d_output_ids_}}, + {"sequence_length", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_UINT32, + std::vector{request_batch_size, beam_width}, + d_sequence_lengths_}}}; + + if (input_tensors->count("is_return_log_probs") && *((bool*)input_tensors->at("is_return_log_probs").data)) { + output_tensors.insert({"output_log_probs", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_FP32, + std::vector{request_batch_size, beam_width, max_request_output_len}, + d_output_log_probs_}}); + output_tensors.insert({"cum_log_probs", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_FP32, + std::vector{request_batch_size, beam_width}, + d_cum_log_probs_}}); + } + + if (is_return_logits) { + output_tensors.insert( + {"logits", + {ft::MEMORY_GPU, ft::TYPE_FP32, {request_batch_size, max_input_len, vocab_size}, d_output_logits_}}); + } + + try { + ft::Request::Callback callback; + + if (stream_cb_) { + callback = [this](std::unordered_map* outputs) { + triton_stream_callback(outputs, this); + }; + } + + ft::check_cuda_error(cudaStreamSynchronize(allocator_->returnStream())); + instance_->llm->forward(&output_tensors, &ft_input_tensors, {instance_comm, callback}); + // ! stream synced by the model before returning + } + catch (...) 
{ + h_exception_ = std::current_exception(); + output_tensors.insert({"error_message", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_BYTES, {1}, &h_exception_}}); + } + + if (h_total_output_lengths_ != nullptr) { + free(h_total_output_lengths_); + h_total_output_lengths_ = nullptr; + } + + return convert_outputs(output_tensors); +} + +template +LlamaTritonModelInstance::~LlamaTritonModelInstance() +{ + freeBuffer(); +} + +template +void LlamaTritonModelInstance::allocateBuffer(const size_t request_batch_size, + const size_t max_input_len, + const size_t beam_width, + const size_t session_len, + const bool is_return_logits) +{ + d_output_ids_ = + (int*)(allocator_->reMalloc(d_output_ids_, sizeof(int) * request_batch_size * beam_width * session_len, false)); + d_sequence_lengths_ = + (int*)(allocator_->reMalloc(d_sequence_lengths_, sizeof(int) * request_batch_size * beam_width, false)); + d_output_log_probs_ = (float*)(allocator_->reMalloc( + d_output_log_probs_, sizeof(float) * request_batch_size * beam_width * session_len, false)); + d_cum_log_probs_ = + (float*)(allocator_->reMalloc(d_cum_log_probs_, sizeof(float) * request_batch_size * beam_width, false)); + if (is_return_logits) { + d_output_logits_ = (float*)allocator_->reMalloc( + d_output_logits_, sizeof(float) * request_batch_size * max_input_len * instance_->llm->vocab_size(), false); + } +} + +template +void LlamaTritonModelInstance::freeBuffer() +{ + allocator_->free((void**)(&d_output_ids_)); + allocator_->free((void**)(&d_sequence_lengths_)); + allocator_->free((void**)(&d_output_log_probs_)); + allocator_->free((void**)(&d_cum_log_probs_)); +} + +template struct LlamaTritonModelInstance; +template struct LlamaTritonModelInstance; diff --git a/src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h b/src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h new file mode 100644 index 000000000..8ae8ab332 --- /dev/null +++ 
b/src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h @@ -0,0 +1,94 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Modified from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h + +#pragma once + +#include "src/fastertransformer/models/llama/LlamaV2.h" +#include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include + +namespace ft = fastertransformer; + +template +struct LlamaTritonSharedModelInstance { + std::unique_ptr> llm; + std::shared_ptr> llm_weight; + std::unique_ptr> allocator; + std::unique_ptr cublas_algo_map; + std::unique_ptr cublas_wrapper_mutex; + std::unique_ptr cublas_wrapper; + std::unique_ptr cuda_device_prop_ptr; + const int session_len; +}; + +template +struct LlamaTritonModelInstance: AbstractTransformerModelInstance { + + LlamaTritonModelInstance(std::shared_ptr> instance, + std::unique_ptr> allocator); + ~LlamaTritonModelInstance(); + + std::shared_ptr> + forward(std::shared_ptr> input_tensors) override; + + std::shared_ptr> + forward(std::shared_ptr> input_tensors) override; + + std::shared_ptr> + forward(std::shared_ptr> input_tensors, + ft::AbstractInstanceComm*) override; + + static std::shared_ptr> + 
convert_outputs(const std::unordered_map& output_tensors); + +private: + const std::shared_ptr> instance_; + const std::unique_ptr> allocator_; + + std::unordered_map + convert_inputs(std::shared_ptr> input_tensors); + + void allocateBuffer(const size_t request_batch_size, + const size_t max_input_len, + const size_t beam_width, + const size_t session_len, + const bool is_return_logits); + void freeBuffer(); + + int* d_input_ids_ = nullptr; + int* d_input_lengths_ = nullptr; + int* d_input_bad_words_ = nullptr; + int* d_input_stop_words_ = nullptr; + int* d_request_prompt_lengths_ = nullptr; + T* d_request_prompt_embedding_ = nullptr; + float* d_top_p_decay_ = nullptr; + float* d_top_p_min_ = nullptr; + int* d_top_p_reset_ids_ = nullptr; + + int* d_output_ids_ = nullptr; + int* d_sequence_lengths_ = nullptr; + float* d_output_log_probs_ = nullptr; + float* d_cum_log_probs_ = nullptr; + float* d_output_logits_ = nullptr; + + uint32_t* h_total_output_lengths_ = nullptr; + std::exception_ptr h_exception_ = nullptr; +}; diff --git a/src/fastertransformer/triton_backend/transformer_triton_backend.hpp b/src/fastertransformer/triton_backend/transformer_triton_backend.hpp index 47cf6750c..cc5bce919 100644 --- a/src/fastertransformer/triton_backend/transformer_triton_backend.hpp +++ b/src/fastertransformer/triton_backend/transformer_triton_backend.hpp @@ -23,6 +23,7 @@ #include "src/fastertransformer/utils/Tensor.h" #include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/instance_comm.h" #include "src/fastertransformer/utils/mpi_utils.h" #include "src/fastertransformer/utils/nccl_utils.h" @@ -270,6 +271,13 @@ struct AbstractTransformerModelInstance { virtual std::shared_ptr> forward(std::shared_ptr> input_tensors) = 0; + virtual std::shared_ptr> + forward(std::shared_ptr> input_tensors, + ft::AbstractInstanceComm*) + { + return forward(input_tensors); + } + void registerCallback(triton_stream_cb_t* cb, void* ctx) { stream_cb_ = cb; @@ 
-293,13 +301,19 @@ struct AbstractTransformerModel { static std::shared_ptr createGptNeoXModel(std::string inifile); static std::shared_ptr createT5Model(std::string model_dir); static std::shared_ptr createT5EncoderModel(std::string model_dir); + static std::shared_ptr createLlamaModel(std::string model_dir); - std::pair, std::vector> + virtual std::pair, std::vector> createNcclParams(const int node_id, const int device_id_start = 0, const bool multi_node = false); virtual void createCustomComms(std::vector>* custom_all_reduce_comms, int world_size) = 0; + virtual std::unique_ptr createInstanceComm(int size) + { + return nullptr; + } + virtual std::unique_ptr createModelInstance(int deviceId, int rank, diff --git a/src/fastertransformer/utils/Tensor.h b/src/fastertransformer/utils/Tensor.h index d3511c9a6..cd084e5b1 100644 --- a/src/fastertransformer/utils/Tensor.h +++ b/src/fastertransformer/utils/Tensor.h @@ -498,6 +498,21 @@ class TensorMap { return tensor_map_.end(); } + inline std::unordered_map& get() + { + return tensor_map_; + } + + inline std::unordered_map::const_iterator begin() const + { + return tensor_map_.begin(); + } + + inline std::unordered_map::const_iterator end() const + { + return tensor_map_.end(); + } + std::string toString(); static TensorMap fromNpyFolder(const std::string& base_folder); void saveNpy(const std::string& base_folder); diff --git a/src/fastertransformer/utils/instance_comm.h b/src/fastertransformer/utils/instance_comm.h new file mode 100644 index 000000000..28c328e25 --- /dev/null +++ b/src/fastertransformer/utils/instance_comm.h @@ -0,0 +1,16 @@ +#pragma once + +namespace fastertransformer { + +class AbstractInstanceComm { +public: + virtual ~AbstractInstanceComm() = default; + + virtual void barrier() = 0; + + virtual void setSharedObject(void*) = 0; + + virtual void* getSharedObject() = 0; +}; + +} // namespace fastertransformer diff --git a/src/fastertransformer/utils/memory_utils.cu 
b/src/fastertransformer/utils/memory_utils.cu index 134224a09..56f1a879c 100644 --- a/src/fastertransformer/utils/memory_utils.cu +++ b/src/fastertransformer/utils/memory_utils.cu @@ -297,53 +297,158 @@ template void cudaRandomUniform(__nv_fp8_e4m3* buffer, const size_t size); // loads data from binary file. If it succeeds, returns a non-empty vector. If loading fails or // the product of the elements in shape is 0, this function will return an empty vector. template -std::vector loadWeightFromBinHelper(std::vector shape, std::string filename) +std::vector +loadWeightFromBinHelper(std::vector shape, std::string filename, std::vector slices = {}) { if (shape.size() > 2) { printf("[ERROR] shape should have less than two dims \n"); return std::vector(); } + size_t dim0 = shape[0], dim1 = 1; if (shape.size() == 2) { dim1 = shape[1]; } - size_t size = dim0 * dim1; - if (size == 0) { - FT_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str()); - return std::vector(); - } - std::vector host_array(size); - std::ifstream in(filename, std::ios::in | std::ios::binary); - if (!in.is_open()) { - FT_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str()); - return std::vector(); + if (slices.size() == 0) { + size_t size = dim0 * dim1; + if (size == 0) { + FT_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str()); + return std::vector(); + } + + std::vector host_array(size); + std::ifstream in(filename, std::ios::in | std::ios::binary); + if (!in.is_open()) { + FT_LOG_WARNING("file %s cannot be opened, loading model fails! 
\n", filename.c_str()); + return std::vector(); + } + + size_t loaded_data_size = sizeof(T) * size; + in.seekg(0, in.end); + in.seekg(0, in.beg); + + FT_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename); + in.read((char*)host_array.data(), loaded_data_size); + + size_t in_get_size = in.gcount(); + if (in_get_size != loaded_data_size) { + FT_LOG_WARNING("file %s only has %ld, but request %ld, loading model fails! \n", + filename.c_str(), + in_get_size, + loaded_data_size); + return std::vector(); + } + in.close(); + // If we succeed, return an array with values. + return host_array; } + else { + // concate all slices on the same dims - size_t loaded_data_size = sizeof(T) * size; - in.seekg(0, in.end); - in.seekg(0, in.beg); + if (slices.size() != shape.size()) { + printf("[ERROR] slices should have same dims as shape \n"); + return std::vector(); + } - FT_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename); - in.read((char*)host_array.data(), loaded_data_size); + // get slices + ConcateSlice slice0{.slices = {{0, dim0}}}; + ConcateSlice slice1{.slices = {{0, dim1}}}; + if (slices.size() > 0 && slices[0].slices.size() > 0) { + slice0 = slices[0]; + } + if (shape.size() == 2 && slices[1].slices.size() > 0) { + slice1 = slices[1]; + } - size_t in_get_size = in.gcount(); - if (in_get_size != loaded_data_size) { - FT_LOG_WARNING("file %s only has %ld, but request %ld, loading model fails! 
\n", - filename.c_str(), - in_get_size, - loaded_data_size); - return std::vector(); + size_t w0 = 0; + for (auto& s : slice0.slices) { + if (s.second > dim0) { + s.second = dim0; + } + if (s.second < s.first) { + printf("[ERROR] slice0: end < start \n"); + return std::vector(); + } + w0 += s.second - s.first; + } + + size_t w1 = 0; + for (auto& s : slice1.slices) { + if (s.second > dim1) { + s.second = dim1; + } + if (s.second < s.first) { + printf("[ERROR] slice1: end < start \n"); + return std::vector(); + } + w1 += s.second - s.first; + } + + size_t size = w0 * w1; + size_t loaded_data_size = size * sizeof(T); + + FT_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename + " with slice."); + if (size == 0) { + FT_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str()); + return std::vector(); + } + + std::vector host_array(size); + std::ifstream in(filename, std::ios::in | std::ios::binary); + if (!in.is_open()) { + FT_LOG_WARNING("file %s cannot be opened, loading model fails! 
\n", filename.c_str()); + return std::vector(); + } + + char* host_ptr = (char*)host_array.data(); + if (slice1.slices.size() == 0 + || (slice1.slices.size() == 1 && slice1.slices[0].second - slice1.slices[0].first == dim1)) { + for (auto& s : slice0.slices) { + size_t read_size = (s.second - s.first) * dim1 * sizeof(T); + size_t pos = s.first * dim1; + in.seekg(pos * sizeof(T)); + in.read((char*)host_ptr, read_size); + host_ptr += read_size; + } + in.close(); + return host_array; + } + + { + for (auto& s0 : slice0.slices) { + // loop over outer slice + for (size_t line_id = s0.first; line_id < s0.second; ++line_id) { + // loop over lines + size_t pos0 = line_id * dim1; + for (auto& s1 : slice1.slices) { + // loop over inner slice + size_t pos = pos0 + s1.first; + size_t read_size = (s1.second - s1.first) * sizeof(T); + in.seekg(pos * sizeof(T)); + in.read(host_ptr, read_size); + host_ptr += read_size; + } + } + } + in.close(); + } + return host_array; } - in.close(); - // If we succeed, return an array with values. 
- return host_array; +} + +std::vector loadArrayFromBin(std::vector shape, std::string filename, std::vector slices) +{ + return loadWeightFromBinHelper(shape, filename, slices); } template -int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filename) +int loadWeightFromBinFunc(T* ptr, + std::vector shape, + std::string filename, + std::vector slices = std::vector()) { - std::vector host_array = loadWeightFromBinHelper(shape, filename); + std::vector host_array = loadWeightFromBinHelper(shape, filename, slices); if (host_array.empty()) { return 0; @@ -362,49 +467,84 @@ int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filenam return 0; } -template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(int8_t* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(float* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc(half* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc(float* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc(half* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc(int8_t* ptr, + std::vector shape, + std::string filename, + std::vector slices); #ifdef ENABLE_BF16 -template int -loadWeightFromBinFunc<__nv_bfloat16, float>(__nv_bfloat16* ptr, std::vector shape, std::string filename); -template int -loadWeightFromBinFunc<__nv_bfloat16, half>(__nv_bfloat16* ptr, std::vector shape, std::string filename); -template int 
loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>(__nv_bfloat16* ptr, - std::vector shape, - std::string filename); +template int loadWeightFromBinFunc<__nv_bfloat16, float>(__nv_bfloat16* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc<__nv_bfloat16, half>(__nv_bfloat16* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc(float* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc(half* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>(__nv_bfloat16* ptr, + std::vector shape, + std::string filename, + std::vector slices); #endif // ENABLE_BF16 -template int loadWeightFromBinFunc(int* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(int* ptr, + std::vector shape, + std::string filename, + std::vector slices); #ifdef ENABLE_FP8 -template int -loadWeightFromBinFunc<__nv_fp8_e4m3, float>(__nv_fp8_e4m3* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc<__nv_fp8_e4m3, float>(__nv_fp8_e4m3* ptr, + std::vector shape, + std::string filename, + std::vector slices); #endif // ENABLE_FP8 template -int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type) +int loadWeightFromBin(T* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices) { switch (model_file_type) { case FtCudaDataType::FP32: - loadWeightFromBinFunc(ptr, shape, filename); + loadWeightFromBinFunc(ptr, shape, filename, slices); break; case FtCudaDataType::FP16: - loadWeightFromBinFunc(ptr, shape, filename); + 
loadWeightFromBinFunc(ptr, shape, filename, slices); break; case FtCudaDataType::INT8: - loadWeightFromBinFunc(ptr, shape, filename); + loadWeightFromBinFunc(ptr, shape, filename, slices); break; #ifdef ENABLE_BF16 case FtCudaDataType::BF16: - loadWeightFromBinFunc(ptr, shape, filename); + loadWeightFromBinFunc(ptr, shape, filename, slices); break; #endif #ifdef ENABLE_FP8 case FtCudaDataType::FP8: - loadWeightFromBinFunc(ptr, shape, filename); + loadWeightFromBinFunc(ptr, shape, filename, slices); break; #endif default: @@ -415,28 +555,50 @@ int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, F } template<> -int loadWeightFromBin(int* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type) +int loadWeightFromBin(int* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices) { - loadWeightFromBinFunc(ptr, shape, filename); + loadWeightFromBinFunc(ptr, shape, filename, slices); return 0; } -template int -loadWeightFromBin(float* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); -template int -loadWeightFromBin(half* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); -template int -loadWeightFromBin(int8_t* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); +template int loadWeightFromBin(float* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices); +template int loadWeightFromBin(half* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices); +template int loadWeightFromBin(int8_t* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices); #ifdef ENABLE_BF16 -template int -loadWeightFromBin(__nv_bfloat16* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); +template int loadWeightFromBin(__nv_bfloat16* ptr, + std::vector 
shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices); #endif #ifdef ENABLE_FP8 -template int -loadWeightFromBin(__nv_fp8_e4m3* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); +template int loadWeightFromBin(__nv_fp8_e4m3* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices); #endif -template int -loadWeightFromBin(int* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); +template int loadWeightFromBin(int* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices); template int loadWeightFromBinAndQuantizeForWeightOnlyFunc(int8_t* ptr, diff --git a/src/fastertransformer/utils/memory_utils.h b/src/fastertransformer/utils/memory_utils.h index 9efd53f53..bfeadbe13 100644 --- a/src/fastertransformer/utils/memory_utils.h +++ b/src/fastertransformer/utils/memory_utils.h @@ -49,11 +49,20 @@ void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream = template void cudaRandomUniform(T* buffer, const size_t size); +struct ConcateSlice { + std::vector> slices; +}; + template -int loadWeightFromBin(T* ptr, - std::vector shape, - std::string filename, - FtCudaDataType model_file_type = FtCudaDataType::FP32); +int loadWeightFromBin(T* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type = FtCudaDataType::FP32, + std::vector slices = std::vector()); + +std::vector loadArrayFromBin(std::vector shape, + std::string filename, + std::vector slices = std::vector()); template int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr, diff --git a/src/fastertransformer/utils/nccl_utils.cc b/src/fastertransformer/utils/nccl_utils.cc index 29ac95979..c13c54bca 100644 --- a/src/fastertransformer/utils/nccl_utils.cc +++ b/src/fastertransformer/utils/nccl_utils.cc @@ -15,6 +15,7 @@ */ #include "src/fastertransformer/utils/nccl_utils.h" 
+#include namespace fastertransformer { @@ -417,6 +418,22 @@ void ftNcclInitialize(NcclParam& tensor_para, FT_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); } +static std::atomic& ncclGroupCount() +{ + static std::atomic value{}; + return value; +} + +int ftNcclNextGroupId() +{ + return ncclGroupCount()++; +} + +int ftNcclGroupCount() +{ + return ncclGroupCount(); +} + size_t getLocalBatchSize(const size_t batch_size, const size_t seq_len, const size_t pipeline_para_size) { size_t local_batch_size = batch_size; diff --git a/src/fastertransformer/utils/nccl_utils.h b/src/fastertransformer/utils/nccl_utils.h index cf5371487..eab2b602f 100644 --- a/src/fastertransformer/utils/nccl_utils.h +++ b/src/fastertransformer/utils/nccl_utils.h @@ -60,6 +60,7 @@ struct NcclUid { struct NcclParam { int rank_{0}; int world_size_{1}; + int group_id_{0}; #ifdef BUILD_MULTI_GPU ncclUniqueId nccl_uid_; ncclComm_t nccl_comm_ = nullptr; @@ -69,10 +70,15 @@ struct NcclParam { NcclParam(): rank_(0), world_size_(1), nccl_comm_(nullptr){}; NcclParam(int rank, int world_size): rank_(rank), world_size_(world_size){}; NcclParam(NcclParam const& param): - rank_(param.rank_), world_size_(param.world_size_), nccl_uid_(param.nccl_uid_), nccl_comm_(param.nccl_comm_){}; + rank_(param.rank_), + world_size_(param.world_size_), + group_id_(param.group_id_), + nccl_uid_(param.nccl_uid_), + nccl_comm_(param.nccl_comm_){}; std::string toString() { - return fmtstr("NcclParam[rank=%d, world_size=%d, nccl_comm=%p]", rank_, world_size_, nccl_comm_); + return fmtstr( + "NcclParam[rank=%d, world_size=%d, nccl_comm=%p, group_id=%d]", rank_, world_size_, nccl_comm_, group_id_); } #else NcclParam(): rank_(0), world_size_(1){}; @@ -111,6 +117,9 @@ void ftNcclGetUniqueId(NcclUid& uid); void ftNcclCommInitRank(NcclParam& param, const int rank, const int world_size, const NcclUid uid); void ftNcclParamDestroy(NcclParam& param); +int ftNcclNextGroupId(); +int ftNcclGroupCount(); + void ftNcclInitialize(NcclParam& 
tensor_para, NcclParam& pipeline_para, const int tensor_para_size,