diff --git a/.vscode/settings.json b/.vscode/settings.json index 6f535da99..655836ef1 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -67,6 +67,23 @@ "unordered_set": "cpp", "future": "cpp", "cfenv": "cpp", - "typeindex": "cpp" + "typeindex": "cpp", + "__bit_reference": "cpp", + "__config": "cpp", + "__debug": "cpp", + "__errc": "cpp", + "__hash_table": "cpp", + "__locale": "cpp", + "__mutex_base": "cpp", + "__node_handle": "cpp", + "__split_buffer": "cpp", + "__threading_support": "cpp", + "__tree": "cpp", + "__verbose_abort": "cpp", + "charconv": "cpp", + "ios": "cpp", + "locale": "cpp", + "variant": "cpp", + "__memory": "cpp" } -} \ No newline at end of file +} diff --git a/CMakeLists.txt b/CMakeLists.txt index a164ef827..2ed27a8a2 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -322,6 +322,7 @@ add_library(transformer-shared SHARED $ $ $ + $ $ $ $ diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index da24d72c6..efacc9c7d 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +add_subdirectory(bart) add_subdirectory(bert) add_subdirectory(bert_int8) add_subdirectory(decoding) diff --git a/examples/cpp/bart/CMakeLists.txt b/examples/cpp/bart/CMakeLists.txt new file mode 100644 index 000000000..5cceacb32 --- /dev/null +++ b/examples/cpp/bart/CMakeLists.txt @@ -0,0 +1,18 @@ +# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_executable(bart_triton_example bart_triton_example.cc) +target_link_libraries(bart_triton_example PUBLIC -lcublas -lcublasLt -lcudart -lpthread + BartTritonBackend TransformerTritonBackend custom_ar_comm + gpt_example_utils word_list mpi_utils nccl_utils nvtx_utils) diff --git a/examples/cpp/bart/bad_words.csv b/examples/cpp/bart/bad_words.csv new file mode 100644 index 000000000..6a1126ebd --- /dev/null +++ b/examples/cpp/bart/bad_words.csv @@ -0,0 +1,2 @@ +7768,3908 +1,2 diff --git a/examples/cpp/bart/bart_triton_example.cc b/examples/cpp/bart/bart_triton_example.cc new file mode 100644 index 000000000..448167273 --- /dev/null +++ b/examples/cpp/bart/bart_triton_example.cc @@ -0,0 +1,467 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "3rdparty/INIReader.h" +#include "examples/cpp/multi_gpu_gpt/gpt_example_utils.h" +#include "src/fastertransformer/triton_backend/bart/BartTritonModel.h" +#include "src/fastertransformer/triton_backend/bart/BartTritonModelInstance.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" +#include "src/fastertransformer/utils/nvtx_utils.h" +#include "src/fastertransformer/utils/word_list.h" + +#include +#include + +namespace ft = fastertransformer; + +struct RequestParam { + int beam_width; + int request_output_len; + float beam_search_diversity_rate; + uint runtime_top_k; + float runtime_top_p; + float temperature; + float len_penalty; + float repetition_penalty; + float presence_penalty; + int min_length; + unsigned long long int random_seed; + int start_id; + int end_id; +}; + +std::vector>> +broadCastRequest(const std::vector& v_start_ids, + const std::vector& v_start_lengths, + const std::vector& v_bad_words, + const int node_id, + const int gpu_count, + const RequestParam param, + std::vector* pointer_record) +{ + // broadcast the request to all nodes, and copy "gpu_count" copies on different gpu + int size_1 = v_start_ids.size(); + int size_2 = v_start_lengths.size(); + int size_bad_words = v_bad_words.size(); + ft::mpi::bcast(&size_1, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(&size_2, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(&size_bad_words, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + + std::vector v_input_ids(size_1); + std::vector v_input_lengths(size_2); + std::vector v_input_bad_words(size_bad_words); + + if (node_id == 0) { + memcpy(v_input_ids.data(), v_start_ids.data(), size_1 * sizeof(int)); + memcpy(v_input_lengths.data(), v_start_lengths.data(), size_2 * sizeof(int)); + memcpy(v_input_bad_words.data(), v_bad_words.data(), size_bad_words * sizeof(int)); + } + ft::mpi::barrier(); + + int request_batch_size = 2; + int max_input_len = size_1 / size_2; + + ft::mpi::bcast(v_input_ids.data(), size_1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(v_input_lengths.data(), size_2, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(v_input_bad_words.data(), size_bad_words, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + + std::vector>> request_list; + for (int device_id = 0; device_id < gpu_count; device_id++) { + ft::check_cuda_error(cudaSetDevice(device_id)); + + int* d_input_ids; + int* d_input_lengths; + int* d_input_bad_words; + + if (max_input_len == 0) { + // unconditional case, no input ids, so do nothing. + d_input_ids = nullptr; + d_input_lengths = nullptr; + max_input_len = 0; + } + else { + // conditional case. + ft::deviceMalloc(&d_input_ids, size_1, false); + ft::deviceMalloc(&d_input_lengths, size_2, false); + ft::cudaH2Dcpy(d_input_ids, v_input_ids.data(), size_1); + ft::cudaH2Dcpy(d_input_lengths, v_input_lengths.data(), size_2); + } + ft::deviceMalloc(&d_input_bad_words, size_bad_words, false); + ft::cudaH2Dcpy(d_input_bad_words, v_input_bad_words.data(), size_bad_words); + + uint32_t* request_output_len_ptr = (uint32_t*)malloc(request_batch_size * sizeof(uint32_t)); + for (int i = 0; i < request_batch_size; i++) { + request_output_len_ptr[i] = param.request_output_len; + } + + int* start_ids_ptr = (int*)malloc(request_batch_size * sizeof(int)); + int* end_ids_ptr = (int*)malloc(request_batch_size * sizeof(int)); + for (int i = 0; i < request_batch_size; i++) { + start_ids_ptr[i] = param.start_id; + end_ids_ptr[i] = param.end_id; + } + pointer_record->push_back(start_ids_ptr); + pointer_record->push_back(end_ids_ptr); + + request_list.push_back(std::shared_ptr>( + new std::unordered_map{ + {"input_ids", + triton::Tensor{triton::MEMORY_GPU, + triton::TYPE_INT32, + std::vector{(size_t)request_batch_size, (size_t)max_input_len}, + d_input_ids}}, + {"sequence_length", + triton::Tensor{triton::MEMORY_GPU, + triton::TYPE_INT32, + std::vector{(size_t)request_batch_size}, + d_input_lengths}}, + {"max_output_len", + triton::Tensor{triton::MEMORY_CPU, + triton::TYPE_INT32, + std::vector{(size_t)request_batch_size}, + request_output_len_ptr}}, + {"bad_words_list", + triton::Tensor{ + triton::MEMORY_GPU, triton::TYPE_INT32, {2, v_input_bad_words.size() / 2}, d_input_bad_words}}, + // {"start_id", + // triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, {(size_t)request_batch_size}, start_ids_ptr}}, + // {"end_id", + // triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, {(size_t)request_batch_size}, end_ids_ptr}} + })); + + int* beam_width_ptr = new int(param.beam_width); + pointer_record->push_back(beam_width_ptr); + request_list[device_id]->insert( + {"beam_width", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, std::vector{1}, beam_width_ptr}}); + if (param.beam_width > 1) { + float* beam_search_diversity_rate_ptr = new float(param.beam_search_diversity_rate); + pointer_record->push_back(beam_search_diversity_rate_ptr); + request_list[device_id]->insert( + {"beam_search_diversity_rate", + triton::Tensor{ + triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, beam_search_diversity_rate_ptr}}); + } + else { + if (param.runtime_top_p != 0.0f) { + float* runtime_top_p_ptr = new float(param.runtime_top_p); + pointer_record->push_back(runtime_top_p_ptr); + request_list[device_id]->insert( + {"runtime_top_p", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, runtime_top_p_ptr}}); + } + if (param.runtime_top_k != 0) { + uint* runtime_top_k_ptr = new uint(param.runtime_top_k); + pointer_record->push_back(runtime_top_k_ptr); + request_list[device_id]->insert( + {"runtime_top_k", + triton::Tensor{ + triton::MEMORY_CPU, triton::TYPE_UINT32, std::vector{1}, runtime_top_k_ptr}}); + } + } + float* temperature_ptr = new float(param.temperature); + pointer_record->push_back(temperature_ptr); + request_list[device_id]->insert( + {"temperature", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, temperature_ptr}}); + float* len_penalty_ptr = new float(param.len_penalty); + pointer_record->push_back(len_penalty_ptr); + request_list[device_id]->insert( + {"len_penalty", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, len_penalty_ptr}}); + if (param.repetition_penalty != 1.0f) { + float* repetition_penalty_ptr = new float(param.repetition_penalty); + pointer_record->push_back(repetition_penalty_ptr); + request_list[device_id]->insert( + {"repetition_penalty", + triton::Tensor{ + triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, repetition_penalty_ptr}}); + } + if (param.presence_penalty != 0.0f) { + float* presence_penalty_ptr = new float(param.presence_penalty); + pointer_record->push_back(presence_penalty_ptr); + request_list[device_id]->insert( + {"presence_penalty", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, presence_penalty_ptr}}); + } + int* min_length_ptr = new int(param.min_length); + pointer_record->push_back(min_length_ptr); + request_list[device_id]->insert( + {"min_length", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, std::vector{1}, min_length_ptr}}); + unsigned long long int* random_seed_ptr = new unsigned long long int(param.random_seed); + pointer_record->push_back(random_seed_ptr); + request_list[device_id]->insert( + {"random_seed", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_UINT64, std::vector{1}, random_seed_ptr}}); + + pointer_record->push_back(d_input_ids); + pointer_record->push_back(d_input_lengths); + pointer_record->push_back(d_input_bad_words); + pointer_record->push_back(request_output_len_ptr); + } + + return request_list; +} + +std::vector>> +prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std::vector* pointer_record) +{ + INIReader reader = INIReader(ini_name); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << ini_name << "'\n"; + ft::FT_CHECK(false); + } + + const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); + + const int start_id = reader.GetInteger("decoder", "start_id"); + const int end_id = reader.GetInteger("decoder", "end_id"); + + std::vector v_start_ids; + std::vector v_start_lengths; + + size_t max_input_len = 0; + ft::read_start_ids(request_batch_size, + &v_start_lengths, + &v_start_ids, + max_input_len, + end_id, + 1, + "../examples/cpp/bart/start_ids.csv"); + printf("v_start_ids size: %d v_start_lengths size: %d\n", v_start_ids.size(), v_start_lengths.size()); + + std::vector v_bad_words; + ft::read_word_list("../examples/cpp/bart/bad_words.csv", v_bad_words); + + RequestParam param; + param.beam_width = reader.GetInteger("request", "beam_width"); + param.request_output_len = reader.GetInteger("request", "request_output_len"); + param.beam_search_diversity_rate = reader.GetFloat("request", "beam_search_diversity_rate"); + param.runtime_top_k = reader.GetInteger("request", "top_k"); + param.runtime_top_p = reader.GetFloat("request", "top_p"); + param.temperature = reader.GetFloat("request", "temperature"); + param.len_penalty = reader.GetFloat("request", "len_penalty"); + param.repetition_penalty = reader.GetFloat("request", "repetition_penalty", 1.0f); + param.presence_penalty = reader.GetFloat("request", "presence_penalty", 0.0f); + param.min_length = reader.GetInteger("request", "min_length", 0); + param.random_seed = (unsigned long long int)0; + param.start_id = start_id; + param.end_id = end_id; + + auto request_list = + broadCastRequest(v_start_ids, v_start_lengths, v_bad_words, node_id, gpu_count, param, pointer_record); + return request_list; +} + +int threadCreateModelInstances(std::shared_ptr model, + std::vector>* model_instances, + const int device_id, + const int rank, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm = nullptr) +{ + printf("[INFO] rank = %d \n", rank); + ft::check_cuda_error(cudaSetDevice(device_id)); + cudaStream_t stream; + ft::check_cuda_error(cudaStreamCreate(&stream)); + model->createSharedWeights(device_id, rank); + auto model_instance = model->createModelInstance(device_id, rank, stream, nccl_params, custom_all_reduce_comm); + model_instances->at(device_id) = std::move(model_instance); + printf("model instance %d is created \n", device_id); + ft::print_mem_usage(); + return 0; +} + +int threadForward(std::unique_ptr* model_instance, + std::shared_ptr> request, + std::shared_ptr>* output_tensors, + const int device_id) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + *output_tensors = (*model_instance)->forward(request); + return 0; +} + +int main(int argc, char* argv[]) +{ + /* + Prepare the nccl ids, node id, device id and world size + by MPI or triton + */ + + // MPICHECK(MPI_Init(&argc, &argv)); + ft::mpi::initialize(&argc, &argv); + int node_id = ft::mpi::getCommWorldRank(); + int node_num = ft::mpi::getCommWorldSize(); + std::cout << "node_id: " << node_id << ", node_num: " << node_num << std::endl; + + // Note: Only supports that all nodes have same gpu count + const int gpu_count = ft::getDeviceCount(); + std::cout << "gpu_count: " << gpu_count << std::endl; + const int world_size = node_num * gpu_count; + std::string ini_name = argc >= 2 ? std::string(argv[1]) : "/notebooks/tmp/FasterTransformer/examples/cpp/bart/"; + + // step 1: Create model + std::shared_ptr model = AbstractTransformerModel::createBartModel("/notebooks/bart-ft/1-gpu"); + int tensor_para_size = model->getTensorParaSize(); + int pipeline_para_size = model->getPipelineParaSize(); + FT_CHECK_WITH_INFO(world_size == (tensor_para_size * pipeline_para_size), + "World Size != Tensor Parallel Size * Pipeline Parallel Size !"); + + std::cout << model->toString(); + + // step 2: Initialize the NCCL + std::pair, std::vector> nccl_comms = model->createNcclParams(node_id); + cudaDeviceSynchronize(); + + // Optional Step: create custom all reduce comm + std::vector> custom_all_reduce_comms; + model->createCustomComms(&custom_all_reduce_comms, world_size); + + // step 3: Create model instances + std::vector> model_instances((size_t)gpu_count); + std::vector threads; + for (int device_id = 0; device_id < gpu_count; device_id++) { + const int rank = node_id * gpu_count + device_id; + threads.push_back(std::thread(threadCreateModelInstances, + model, + &model_instances, + device_id, + rank, + nccl_comms, + custom_all_reduce_comms[rank])); + } + for (auto& t : threads) { + t.join(); + } + + // step 4: prepare request + std::vector pointer_record; // Used to prevent the pointers are release after leaving functions + std::vector>> request_list = + prepareRequest(ini_name + "/config.ini", node_id, gpu_count, &pointer_record); + printf("[INFO] request is created \n"); + + // step 5: Forward + std::vector>> output_tensors_lists( + (size_t)gpu_count); + printf("[INFO] gpu_count: %d %d %d %d\n", gpu_count, model_instances.size(), request_list.size(), output_tensors_lists.size()); + for (int i = 0; i < 1; i++) { + threads.clear(); + for (int device_id = 0; device_id < gpu_count; device_id++) { + threads.push_back(std::thread(threadForward, + &model_instances[device_id], + request_list[device_id], + &output_tensors_lists[device_id], + device_id)); + } + for (auto& t : threads) { + t.join(); + } + } + printf("[INFO] forward is completed. \n"); + + for (const auto& pair : *output_tensors_lists[0]) { + std::cout << "Key: " << pair.first << std::endl; + } + + const int* d_output_ids = (const int*)output_tensors_lists[0].get()->at("output_ids").data; + const int batch_size = output_tensors_lists[0].get()->at("output_ids").shape[0]; + const int beam_width = output_tensors_lists[0].get()->at("output_ids").shape[1]; + const int seq_len = output_tensors_lists[0].get()->at("output_ids").shape[2]; + // const int* d_input_lengths = (const int*)output_tensors_lists[0].get()->at("sequence_length").data; + // step 6: check results + if (node_id == 0) { + + std::string fName = "out"; + auto outFile = std::ofstream(fName, std::ios::out); + if (!outFile.is_open()) { + printf("[WARNING] Cannot write results into output file %s \n", fName.c_str()); + } + else { + size_t outCount = batch_size * beam_width * seq_len; + int* hBuf = new int[outCount]; + // int* iBuf = new int[batch_size]; + ft::cudaD2Hcpy(hBuf, d_output_ids, outCount); + // ft::cudaD2Hcpy(iBuf, d_input_lengths, batch_size); + + + { + std::cout << "Writing " << outCount << " elements\n"; + int zeroCount = 0; + // for (int i=0; ipush_back(tmp_start_lengths[i]); } } + for (int i : *v_start_lengths) { + printf("v_start_lengths %d\n", i); + } return batch_size; } diff --git a/examples/pytorch/bart/translate_example.py b/examples/pytorch/bart/translate_example.py index 3d32f9907..db2db77d5 100644 --- a/examples/pytorch/bart/translate_example.py +++ b/examples/pytorch/bart/translate_example.py @@ -213,7 +213,7 @@ def translate(args_dict): config.decoder_start_token_id, config.eos_token_id, config.vocab_size, tensor_para_size=tensor_para_size, pipeline_para_size=pipeline_para_size, bart_with_bias=bart_with_bias, mbart=is_mbart, - position_embedding_type=position_embedding_type, + position_embedding_type=position_embedding_type, activation_type=activation_type, layernorm_type=layernorm_type) ft_bart = FTBart(ft_encoder, ft_decoding) @@ -375,4 +375,4 @@ def translate(args_dict): args = parser.parse_args() log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s" logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO, format=log_format) - translate(vars(args)) \ No newline at end of file + translate(vars(args)) diff --git a/examples/pytorch/bart/utils/huggingface_bart_ckpt_convert.py b/examples/pytorch/bart/utils/huggingface_bart_ckpt_convert.py new file mode 100644 index 000000000..2f6dc2c1d --- /dev/null +++ b/examples/pytorch/bart/utils/huggingface_bart_ckpt_convert.py @@ -0,0 +1,307 @@ +# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import configparser +import multiprocessing +from datetime import datetime +import logging +from pathlib import Path + +from transformers import BartForConditionalGeneration + +import numpy as np +import torch # pytype: disable=import-error + +LOGGER = logging.getLogger(__name__) + + +def get_weight_data_type(data_type): + if data_type == "fp32": + return np.float32 + elif data_type == "fp16": + return np.float16 + else: + assert False, f"Invalid weight data type {data_type}" + + +def fuse_decoder_qkv(model, factor, saved_dir, np_weight_data_type): + model_dict = {} + for name, param in model.named_parameters(): + if name.find("self_attn") == -1 or name.find("decoder.layers") == -1: + continue + if name.find(".q_proj.") != -1 or name.find(".k_proj.") != -1 or name.find(".v_proj.") != -1: + model_dict[name] = param + + for i in range(model.config.decoder_layers): + shape = model_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"].T.shape + qkv = torch.cat([model_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"].T, + model_dict[f"model.decoder.layers.{i}.self_attn.k_proj.weight"].T, + model_dict[f"model.decoder.layers.{i}.self_attn.v_proj.weight"].T], dim=-1) + + qkv = qkv.reshape([shape[0], 3, shape[1]]) + qkv = qkv.cpu().detach().numpy().astype(np_weight_data_type) + + split_vals = np.split(qkv, factor, axis=-1) + for j in range(factor): + saved_path = saved_dir / f"decoder.{i}.layer.SelfAttention.qkv.weight.{j}.bin" + split_vals[j].tofile(saved_path.as_posix()) + + for i in range(model.config.decoder_layers): + shape = model_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"].shape + qkv = torch.cat([model_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"], + model_dict[f"model.decoder.layers.{i}.self_attn.k_proj.bias"], + model_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"]], dim=-1) + qkv = qkv.cpu().detach().numpy().astype(np_weight_data_type) + + split_vals = np.split(qkv, factor, axis=-1) + for j in range(factor): + saved_path = saved_dir / f"decoder.{i}.layer.SelfAttention.qkv.bias.{j}.bin" + split_vals[j].tofile(saved_path.as_posix()) + + +def get_encoder_or_decoder(key): + return "encoder" if key.find("encoder") != -1 else "decoder" + + +def get_fc(key): + return "fc1" if key.find("fc1.") != -1 else "fc2" + + +def split_and_convert_process(key, val, factor, saved_dir): + if val.ndim == 2: + val = val.transpose(1, 0) + + if key.find(".embed_positions.weight") != -1: + prefix = get_encoder_or_decoder(key) + saved_path = saved_dir / f"{prefix}.embed_positions.weight.bin" + val[:, 2:].T.tofile(saved_path.as_posix()) + elif key.find(".embed_tokens.weight") != -1: + prefix = get_encoder_or_decoder(key) + saved_path = saved_dir / f"{prefix}.embed_tokens.weight.bin" + val.T.tofile(saved_path.as_posix()) + elif key.find(".layernorm_embedding.weight") != -1: + prefix = get_encoder_or_decoder(key) + saved_path = saved_dir / f"{prefix}.final_layer_norm.weight.bin" + val.tofile(saved_path.as_posix()) + elif key.find(".layernorm_embedding.bias") != -1: + prefix = get_encoder_or_decoder(key) + saved_path = saved_dir / f"{prefix}.final_layer_norm.bias.bin" + val.tofile(saved_path.as_posix()) + elif ( + key.find("self_attn.k_proj.weight") != -1 + or key.find("self_attn.v_proj.weight") != -1 + or key.find("self_attn.q_proj.weight") != -1 + ): + split_vals = np.split(val, factor, axis=0) + prefix = get_encoder_or_decoder(key) + if prefix == "decoder": + # will be handled in fuse_decoder_qkv instead + return + layer = int(key.split('layers.')[1].split('.self_attn')[0]) + qkv = key.split('self_attn.')[1][:1] + for j in range(factor): + saved_path = saved_dir / f"{prefix}.{layer}.layer.SelfAttention.{qkv}.weight.{j:d}.bin" + split_vals[j].tofile(saved_path.as_posix()) + elif ( + key.find("self_attn.k_proj.bias") != -1 + or key.find("self_attn.v_proj.bias") != -1 + or key.find("self_attn.q_proj.bias") != -1 + ): + split_vals = np.split(val, factor, axis=0) + prefix = get_encoder_or_decoder(key) + if prefix == "decoder": + # will be handled in fuse_decoder_qkv instead + return + layer = int(key.split('layers.')[1].split('.self_attn')[0]) + qkv = key.split('self_attn.')[1][:1] + for j in range(factor): + saved_path = saved_dir / f"{prefix}.{layer}.layer.SelfAttention.{qkv}.bias.{j:d}.bin" + split_vals[j].tofile(saved_path.as_posix()) + elif key.find("self_attn.out_proj.weight") != -1: + split_vals = np.split(val, factor, axis=0) + prefix = get_encoder_or_decoder(key) + layer = int(key.split('layers.')[1].split('.self_attn')[0]) + for j in range(factor): + saved_path = saved_dir / f"{prefix}.{layer}.layer.SelfAttention.out_proj.weight.{j:d}.bin" + split_vals[j].tofile(saved_path.as_posix()) + elif key.find("self_attn.out_proj.bias") != -1: + split_vals = np.split(val, factor, axis=0) + prefix = get_encoder_or_decoder(key) + layer = int(key.split('layers.')[1].split('.self_attn')[0]) + for j in range(factor): + saved_path = saved_dir / f"{prefix}.{layer}.layer.SelfAttention.out_proj.bias.{j:d}.bin" + split_vals[j].tofile(saved_path.as_posix()) + elif key.find("self_attn_layer_norm.weight") != -1: + prefix = get_encoder_or_decoder(key) + layer = int(key.split('layers.')[1].split('.self_attn')[0]) + saved_path = saved_dir / f"{prefix}.{layer}.layer.SelfAttention.attn_layer_norm.weight.bin" + val.tofile(saved_path.as_posix()) + elif key.find("self_attn_layer_norm.bias") != -1: + prefix = get_encoder_or_decoder(key) + layer = int(key.split('layers.')[1].split('.self_attn')[0]) + saved_path = saved_dir / f"{prefix}.{layer}.layer.SelfAttention.attn_layer_norm.bias.bin" + val.tofile(saved_path.as_posix()) + elif ( + key.find("encoder_attn.k_proj.weight") != -1 + or key.find("encoder_attn.v_proj.weight") != -1 + or key.find("encoder_attn.q_proj.weight") != -1 + ): + split_vals = np.split(val, factor, axis=0) + layer = int(key.split('layers.')[1].split('.encoder_attn')[0]) + qkv = key.split('encoder_attn.')[1][:1] + for j in range(factor): + saved_path = saved_dir / f"decoder.{layer}.layer.CrossAttention.{qkv}.weight.{j:d}.bin" + split_vals[j].tofile(saved_path.as_posix()) + elif ( + key.find("encoder_attn.k_proj.bias") != -1 + or key.find("encoder_attn.v_proj.bias") != -1 + or key.find("encoder_attn.q_proj.bias") != -1 + ): + split_vals = np.split(val, factor, axis=0) + layer = int(key.split('layers.')[1].split('.encoder_attn')[0]) + qkv = key.split('encoder_attn.')[1][:1] + for j in range(factor): + saved_path = saved_dir / f"decoder.{layer}.layer.CrossAttention.{qkv}.bias.{j:d}.bin" + split_vals[j].tofile(saved_path.as_posix()) + elif key.find("encoder_attn.out_proj.weight") != -1: + split_vals = np.split(val, factor, axis=0) + layer = int(key.split('layers.')[1].split('.encoder_attn')[0]) + for j in range(factor): + saved_path = saved_dir / f"decoder.{layer}.layer.CrossAttention.out_proj.weight.{j:d}.bin" + split_vals[j].tofile(saved_path.as_posix()) + elif key.find("encoder_attn.out_proj.bias") != -1: + split_vals = np.split(val, factor, axis=0) + layer = int(key.split('layers.')[1].split('.encoder_attn')[0]) + for j in range(factor): + saved_path = saved_dir / f"decoder.{layer}.layer.CrossAttention.out_proj.bias.{j:d}.bin" + split_vals[j].tofile(saved_path.as_posix()) + elif key.find("encoder_attn_layer_norm.weight") != -1: + layer = int(key.split('layers.')[1].split('.encoder_attn')[0]) + saved_path = saved_dir / f"decoder.{layer}.layer.CrossAttention.attn_layer_norm.weight.bin" + val.tofile(saved_path.as_posix()) + elif key.find("encoder_attn_layer_norm.bias") != -1: + layer = int(key.split('layers.')[1].split('.encoder_attn')[0]) + saved_path = saved_dir / f"decoder.{layer}.layer.CrossAttention.attn_layer_norm.bias.bin" + val.tofile(saved_path.as_posix()) + elif key.find("fc1.weight") != -1 or key.find("fc2.weight") != -1: + prefix = get_encoder_or_decoder(key) + split_vals = np.split(val, factor, axis=0) + fc = get_fc(key) + layer = int(key.split('layers.')[1].split(f'.{fc}.')[0]) + for j in range(factor): + saved_path = saved_dir / f"{prefix}.{layer}.layer.SelfAttention.{fc}.weight.{j:d}.bin" + split_vals[j].tofile(saved_path.as_posix()) + elif key.find("fc1.bias") != -1 or key.find("fc2.bias") != -1: + prefix = get_encoder_or_decoder(key) + fc = get_fc(key) + layer = int(key.split('layers.')[1].split(f'.{fc}.')[0]) + split_vals = np.split(val, factor, axis=0) + for j in range(factor): + saved_path = saved_dir / f"{prefix}.{layer}.layer.SelfAttention.{fc}.bias.{j:d}.bin" + split_vals[j].tofile(saved_path.as_posix()) + elif key.find("final_layer_norm.weight") != -1: + prefix = get_encoder_or_decoder(key) + layer = int(key.split('layers.')[1].split('.final_layer_norm.')[0]) + saved_path = saved_dir / f"{prefix}.{layer}.layer.SelfAttention.final_layer_norm.weight.bin" + val.tofile(saved_path.as_posix()) + elif key.find("final_layer_norm.bias") != -1: + prefix = get_encoder_or_decoder(key) + layer = int(key.split('layers.')[1].split('.final_layer_norm.')[0]) + saved_path = saved_dir / f"{prefix}.{layer}.layer.SelfAttention.final_layer_norm.bias.bin" + val.tofile(saved_path.as_posix()) + elif key.find("lm_head.weight") != -1: + saved_path = saved_dir / "decoder.lm_head.weight.bin" + val.T.tofile(saved_path.as_posix()) + elif key.find("final_logits_bias") != -1: + saved_path = saved_dir / "decoder.final_logits_bias.bin" + val.tofile(saved_path.as_posix()) + elif key.find("encoder.embed_tokens.weight") != -1 or \ + key.find("decoder.embed_tokens.weight") != -1: + LOGGER.warning(f"Not save {key}, using shared.weight directly.") + else: + LOGGER.warning(f"Not save '{key}' with shape {val.shape}") + + +def convert_checkpoint(args): + saved_dir = Path(args.saved_dir) / f"{args.inference_tensor_para_size:d}-gpu" + saved_dir.mkdir(parents=True, exist_ok=True) + + bart_model = BartForConditionalGeneration.from_pretrained(args.in_file) + hf_config = vars(bart_model.config) + config = configparser.ConfigParser() + + config["encoder"] = {} + config["encoder"]["model_name"] = "bart" + config["encoder"]["num_heads"] = str(hf_config["encoder_attention_heads"]) + config["encoder"]["d_kv"] = str(hf_config["d_model"] // hf_config["encoder_attention_heads"]) + config["encoder"]["d_model"] = str(hf_config["d_model"]) + config["encoder"]["d_ff"] = str(hf_config["encoder_ffn_dim"]) + config["encoder"]["num_layers"] = str(hf_config["encoder_layers"]) + config["encoder"]["vocab_size"] = str(hf_config["vocab_size"]) + config["encoder"]["max_pos_seq_len"] = str(hf_config["max_position_embeddings"]) + config["encoder"]["feed_forward_proj"] = str(hf_config["activation_function"]) + config["encoder"]["weight_data_type"] = args.weight_data_type + + config["decoder"] = {} + config["decoder"]["num_heads"] = str(hf_config["decoder_attention_heads"]) + config["decoder"]["d_kv"] = str(hf_config["d_model"] // hf_config["decoder_attention_heads"]) + config["decoder"]["d_model"] = str(hf_config["d_model"]) + config["decoder"]["d_ff"] = str(hf_config["decoder_ffn_dim"]) + config["decoder"]["num_layers"] = str(hf_config["decoder_layers"]) + config["decoder"]["vocab_size"] = str(hf_config["vocab_size"]) + config["decoder"]["max_pos_seq_len"] = str(hf_config["max_position_embeddings"]) + config["decoder"]["decoder_start_token_id"] = str(hf_config["decoder_start_token_id"]) + config["decoder"]["eos_token_id"] = str(hf_config["eos_token_id"]) + config["decoder"]["weight_data_type"] = args.weight_data_type + + with open((saved_dir / "config.ini").as_posix(), 'w') as configfile: + config.write(configfile) + np_weight_data_type = get_weight_data_type(args.weight_data_type) + + i_gpu_num = args.inference_tensor_para_size + pool = multiprocessing.Pool(args.processes) + pool.starmap_async(split_and_convert_process, + [(name, param.cpu().detach().numpy().astype(np_weight_data_type), i_gpu_num, saved_dir) + for name, param in bart_model.state_dict().items()]) + + pool.close() + pool.join() + + fuse_decoder_qkv(bart_model, i_gpu_num, saved_dir, np_weight_data_type) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument("-saved_dir", "-o", type=str, help="file name of output file", required=True) + parser.add_argument("-in_file", "-i", type=str, help="file name of input checkpoint file", required=True) + parser.add_argument("-inference_tensor_para_size", "-i_g", type=int, help="How many gpus for inference", + required=True) + parser.add_argument("-processes", "-p", type=int, help="How many processes to spawn for conversion (default: 4)", + default=4) + parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16"]) + parser.add_argument("--verbose", action="store_true", help="Provide verbose messages") + args = parser.parse_args() + log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s" + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO, format=log_format) + LOGGER.info("\n=============== Argument ===============") + for key in vars(args): + LOGGER.info(f"{key}: {vars(args)[key]}") + LOGGER.info("========================================") + + start_time = datetime.now() + convert_checkpoint(args) + stop_time = datetime.now() + run_time = (stop_time - start_time) + LOGGER.info("Spend {} (h:m:s) to convert the model".format(run_time)) diff --git a/src/fastertransformer/layers/DynamicDecodeLayer.h b/src/fastertransformer/layers/DynamicDecodeLayer.h index 3b63cda92..774300731 100644 --- a/src/fastertransformer/layers/DynamicDecodeLayer.h +++ b/src/fastertransformer/layers/DynamicDecodeLayer.h @@ -26,6 +26,17 @@ namespace fastertransformer { +// fallback to fp32 dynamic decoder when bf16 specified +template +struct fallBackType { + using Type = float; +}; + +template<> +struct fallBackType { + using Type = half; +}; + template class DynamicDecodeLayer: public BaseLayer { protected: diff --git a/src/fastertransformer/models/bart/BartDecoder.cc b/src/fastertransformer/models/bart/BartDecoder.cc index 2c7180549..d8c928e86 100644 --- a/src/fastertransformer/models/bart/BartDecoder.cc +++ b/src/fastertransformer/models/bart/BartDecoder.cc @@ -546,7 +546,24 @@ void BartDecoder::forward(std::vector* outp stream_); } sync_check_cuda_error(); - +// { +// { +// T* buf; +// int st = local_batch_size * d_model_; +// buf = new T[st]; +// cudaMemcpy(buf, decoder_output, sizeof(T) * st, cudaMemcpyDeviceToHost); +// auto step_ptr = input_tensors->at(4).data; +// int step = ((int*)step_ptr)[0]; +// if (step == 1) { +// printf("decoder_output at layer %d step %d\n", l, step); +// for (int i=0; i<50; i++) { +// printf("%f ", double(buf[i])); +// } +// printf("buf last: %f\n", double(buf[st-1])); +// printf("\n"); +// } +// } +// } if (isLastLayerParallelId(l) == true && pipeline_para_.rank_ != pipeline_para_.world_size_ - 1 && pipeline_para_.world_size_ > 1) { // ftNcclSend(decoder_output, local_batch_size * d_model_, pipeline_para_.rank_ + 1, diff --git a/src/fastertransformer/models/bart/BartDecoderLayerWeight.cc b/src/fastertransformer/models/bart/BartDecoderLayerWeight.cc index 3b17c7317..e77b81b17 100644 --- a/src/fastertransformer/models/bart/BartDecoderLayerWeight.cc +++ b/src/fastertransformer/models/bart/BartDecoderLayerWeight.cc @@ -274,8 +274,118 @@ void BartDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType m { FT_LOG_DEBUG("BartDecoderLayerWeight " + std::string(__func__) + " start"); - FT_LOG_DEBUG( - "Currently only support checkpoint loading from PyTorch interface outside FT. Direct checkpoint .bin loading support TBD"); + const auto tp_rank = std::to_string(tensor_para_rank_); + + loadWeightFromBin(weights_ptr[0], + {weights_size[0]}, + dir_path + "layer.SelfAttention.final_layer_norm.weight.bin", + model_file_type); + loadWeightFromBin(weights_ptr[1], + {weights_size[1]}, + dir_path + "layer.SelfAttention.qkv.weight." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[2], + {weights_size[2]}, + dir_path + "layer.SelfAttention.out_proj.weight." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[3], + {weights_size[3]}, + dir_path + "layer.SelfAttention.attn_layer_norm.weight.bin", + model_file_type); + loadWeightFromBin(weights_ptr[4], + {weights_size[4]}, + dir_path + "layer.CrossAttention.q.weight." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[5], + {weights_size[5]}, + dir_path + "layer.CrossAttention.k.weight." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[6], + {weights_size[6]}, + dir_path + "layer.CrossAttention.v.weight." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[7], + {weights_size[7]}, + dir_path + "layer.CrossAttention.out_proj.weight." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[8], + {weights_size[8]}, + dir_path + "layer.CrossAttention.attn_layer_norm.weight.bin", + model_file_type); + + loadWeightFromBin(weights_ptr[9], + {weights_size[9]}, + dir_path + "layer.SelfAttention.fc1.weight." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[10], + {weights_size[10]}, + dir_path + "layer.SelfAttention.fc2.weight." + tp_rank + ".bin", + model_file_type); + + if (bart_with_bias_) { + /* + layernorm_weights.beta = weights_ptr[11]; + self_attention_weights.query_weight.bias = weights_ptr[12]; + self_attention_weights.attention_output_weight.bias = weights_ptr[13]; + self_attn_layernorm_weights.beta = weights_ptr[14]; + + cross_attention_weights.query_weight.bias = weights_ptr[15]; + cross_attention_weights.key_weight.bias = weights_ptr[16]; + cross_attention_weights.value_weight.bias = weights_ptr[17]; + cross_attention_weights.attention_output_weight.bias = weights_ptr[18]; + cross_attn_layernorm_weights.beta = weights_ptr[19]; + + ffn_weights.intermediate_weight.bias = weights_ptr[20]; + ffn_weights.output_weight.bias = weights_ptr[21]; + */ + loadWeightFromBin(weights_ptr[11], + {weights_size[11]}, + dir_path + "layer.SelfAttention.final_layer_norm.bias.bin", + model_file_type); + loadWeightFromBin(weights_ptr[12], + {weights_size[12]}, + dir_path + "layer.SelfAttention.qkv.bias." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[13], + {weights_size[13]}, + dir_path + "layer.SelfAttention.out_proj.bias." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[14], + {weights_size[14]}, + dir_path + "layer.SelfAttention.attn_layer_norm.bias.bin", + model_file_type); + + loadWeightFromBin(weights_ptr[15], + {weights_size[15]}, + dir_path + "layer.CrossAttention.q.bias." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[16], + {weights_size[16]}, + dir_path + "layer.CrossAttention.k.bias." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[17], + {weights_size[17]}, + dir_path + "layer.CrossAttention.v.bias." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[18], + {weights_size[18]}, + dir_path + "layer.CrossAttention.out_proj.bias." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[19], + {weights_size[19]}, + dir_path + "layer.CrossAttention.attn_layer_norm.bias.bin", + model_file_type); + + loadWeightFromBin(weights_ptr[20], + {weights_size[20]}, + dir_path + "layer.SelfAttention.fc1.bias." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr[21], + {weights_size[21]}, + dir_path + "layer.SelfAttention.fc2.bias." + tp_rank + ".bin", + model_file_type); + + } FT_LOG_DEBUG("BartDecoderLayerWeight " + std::string(__func__) + " end"); } diff --git a/src/fastertransformer/models/bart/BartDecoding.cc b/src/fastertransformer/models/bart/BartDecoding.cc index a0e2d876e..08457c14f 100644 --- a/src/fastertransformer/models/bart/BartDecoding.cc +++ b/src/fastertransformer/models/bart/BartDecoding.cc @@ -321,6 +321,17 @@ BartDecoding::~BartDecoding() freeBuffer(); } +template +void BartDecoding::registerCallback(callback_sig* fn, void* ctx) +{ +} + +template +void BartDecoding::unRegisterCallback() +{ +} + + template void BartDecoding::forward(TensorMap* output_tensors, TensorMap* input_tensors, @@ -371,6 +382,7 @@ void BartDecoding::forward(TensorMap* output_tensors, dynamic_decode_layer_->setup(batch_size, beam_width, &input_map); handleOptArg(&input_map, "start_id", start_ids_buf_, start_id_, batch_size); handleOptArg(&input_map, "end_id", end_ids_buf_, end_id_, batch_size); + printf("start_id_ end_id_ %d %d\n", start_id_, end_id_); } FT_CHECK_WITH_INFO(input_tensors->at("encoder_output").shape[2] == d_model_, @@ -421,6 +433,17 @@ void BartDecoding::forward(TensorMap* output_tensors, max_input_length - 1, stream_); sync_check_cuda_error(); + { + int* buf; + int st = batch_size * (max_seq_len+1); + buf = new int[st]; + cudaMemcpy(buf, output_ids_buf_, sizeof(int) * st, cudaMemcpyDeviceToHost); + printf("output_ids_buf_ batch_size: %d\n", batch_size); + for (int i=0; iabsolute_or_relative_position_embedding, @@ -501,6 +524,17 @@ void BartDecoding::forward(TensorMap* output_tensors, sync_check_cuda_error(); } + if (step == max_input_length) { + T* buf; + int st = batch_size * d_model_; + buf = new T[st]; + cudaMemcpy(buf, decoder_input_buf_, sizeof(T) * st, cudaMemcpyDeviceToHost); + printf("decoder_input_buf_: %d\n", batch_size); + for (int i=0; i::forward(TensorMap* output_tensors, {"local_batch_size", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &tmp_local_batch_size}}, {"is_initialize_random_table", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &is_initialize_random_table}}}); +// { +// T* buf; +// int st = batch_size * beam_width * vocab_size_padded_; +// buf = new T[st]; +// cudaMemcpy(buf, logits_buf_, sizeof(T) * st, cudaMemcpyDeviceToHost); +// printf("logits_buf_\n"); +// for (int i=0; i<50; i++) { +// printf("%f ", double(buf[i])); +// } +// printf("buf last: %f\n", double(buf[st-1])); +// printf("\n"); +// } if (cache_indirections_[src_indir_idx] != nullptr) { dynamic_decode_input_tensors.insert( "src_cache_indirection", @@ -780,8 +826,30 @@ void BartDecoding::forward(TensorMap* output_tensors, } dynamic_decode_output_tensors.insert(*t); } + // { + // int* buf; + // int st = batch_size * (max_seq_len+1); + // buf = new int[st]; + // cudaMemcpy(buf, output_ids_buf_, sizeof(int) * st, cudaMemcpyDeviceToHost); + // printf("start_ids_buf_ before forward: %d\n", batch_size); + // for (int i=0; iforward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors); + { + int* buf; + int st = batch_size * (max_seq_len+1); + buf = new int[st]; + cudaMemcpy(buf, output_ids_buf_, sizeof(int) * st, cudaMemcpyDeviceToHost); + printf("output_ids_buf_ after forward: %d\n", batch_size); + for (int i=0; i::forward(TensorMap* output_tensors, } } + { + int* buf; + int st = batch_size * (max_seq_len+1); + buf = new int[st]; + cudaMemcpy(buf, output_ids_buf_, sizeof(int) * st, cudaMemcpyDeviceToHost); + printf("output_ids_buf_ after finalize: %d\n", batch_size); + for (int i=0; i 1) { ftNcclGroupStart(); if (pipeline_para_.rank_ == pipeline_para_.world_size_ - 1) { @@ -976,6 +1057,17 @@ void BartDecoding::forward(TensorMap* output_tensors, // throw errors when detected ftNcclStreamSynchronize(tensor_para_, pipeline_para_, stream_); + { + int* buf; + int st = 32; + buf = new int[st]; + cudaMemcpy(buf, output_tensors->at("output_ids").data, sizeof(int) * st, cudaMemcpyDeviceToHost); + printf("output_ids after finalize: %s %d\n", output_tensors->at("output_ids").toString().c_str(), batch_size); + for (int i=0; i -struct fallBackType { - using Type = float; -}; - -template<> -struct fallBackType { - using Type = half; -}; - template class BartDecoding: public BaseLayer { private: @@ -128,6 +117,8 @@ class BartDecoding: public BaseLayer { const bool using_beam_hyps = true; BeamHypotheses beam_hyps_; + using callback_sig = void(TensorMap*, void*); + public: BartDecoding(size_t max_batch_size, size_t max_seq_len, @@ -170,6 +161,9 @@ class BartDecoding: public BaseLayer { void forward(TensorMap* output_tensors, TensorMap* input_tensors, const BartDecodingWeight* Decoding_weights); void setStream(cudaStream_t stream) override; + + void registerCallback(callback_sig* fn, void* ctx); + void unRegisterCallback(); }; } // namespace fastertransformer diff --git a/src/fastertransformer/models/bart/BartDecodingWeight.cc b/src/fastertransformer/models/bart/BartDecodingWeight.cc index 7789eb0cf..f990d49c0 100644 --- a/src/fastertransformer/models/bart/BartDecodingWeight.cc +++ b/src/fastertransformer/models/bart/BartDecodingWeight.cc @@ -256,8 +256,28 @@ void BartDecodingWeight::loadModel(std::string dir_path) { FT_LOG_DEBUG("BartDecodingWeight " + std::string(__func__) + " start"); - FT_LOG_DEBUG( - "Currently only support checkpoint loading from PyTorch interface outside FT. Direct checkpoint .bin loading support TBD"); + FtCudaDataType model_file_type = getModelFileType(dir_path + "/config.ini", "decoder"); + FT_CHECK(is_maintain_buffer_ == true); + + loadWeightFromBin(weights_ptr[0], {(size_t)weights_size[0]}, dir_path + "/decoder.embed_positions.weight.bin", model_file_type); + loadWeightFromBin(weights_ptr[1], {(size_t)weights_size[1]}, dir_path + "/decoder.embed_tokens.weight.bin", model_file_type); + loadWeightFromBin(weights_ptr[2], {(size_t)weights_size[2]}, dir_path + "/decoder.lm_head.weight.bin", model_file_type); + loadWeightFromBin( + weights_ptr[3], {(size_t)weights_size[3]}, dir_path + "/decoder.final_layer_norm.weight.bin", model_file_type); + if (bart_with_bias) { + loadWeightFromBin(weights_ptr[4], + {(size_t)weights_size[4]}, + dir_path + "/decoder.final_layer_norm.bias.bin", + model_file_type); + loadWeightFromBin(weights_ptr[5], {(size_t)weights_size[5]}, dir_path + "/decoder.final_logits_bias.bin", model_file_type); + } + + for (int l = 0; l < num_layer_; l++) { + if (isValidLayerParallelId(l)) { + decoder_layer_weights[l]->loadModel(dir_path + "/decoder." + std::to_string(l) + ".", + model_file_type); + } + } FT_LOG_DEBUG("BartDecodingWeight " + std::string(__func__) + " end"); } diff --git a/src/fastertransformer/models/bart/BartEncoder.cc b/src/fastertransformer/models/bart/BartEncoder.cc index bc55b9e45..0f6487b6a 100644 --- a/src/fastertransformer/models/bart/BartEncoder.cc +++ b/src/fastertransformer/models/bart/BartEncoder.cc @@ -368,6 +368,7 @@ void BartEncoder::forward(TensorMap* output_tensors, FT_CHECK(input_tensors->at("input_ids").shape.size() == 2); } std::string input_tensor_name = use_inputs_embeds ? "inputs_embeds" : "input_ids"; + printf("input_tensor_name: %s\n", input_tensor_name.c_str()); const size_t request_batch_size = input_tensors->at(input_tensor_name).shape[0]; const size_t request_seq_len = input_tensors->at(input_tensor_name).shape[1]; const bool return_attentions = output_tensors->at("output_attentions", {}).size(); @@ -413,8 +414,9 @@ void BartEncoder::forward(TensorMap* output_tensors, size_t d_model_offset = id_offset * request_seq_len * d_model_; const int* sequence_lengths = input_tensors->at("sequence_length").getPtr() + id_offset; - + printf("use_inputs_embeds: %d\n", use_inputs_embeds); if (position_embedding_type == PositionEmbeddingType::absolute) { + printf("invokeInputIdsEmbeddingLookupPosEncoding\n"); invokeInputIdsEmbeddingLookupPosEncoding( bart_encoder_emb_buf_, nullptr, @@ -453,6 +455,30 @@ void BartEncoder::forward(TensorMap* output_tensors, sync_check_cuda_error(); +{ + T* buf; + int batch_size = 1; + int seq_len = 11; + int st = batch_size * seq_len * d_model_; + printf("st: %d %d %d %d\n",batch_size, seq_len, d_model_, st); + buf = new T[st]; + cudaMemcpy(buf, bart_encoder_emb_buf_, sizeof(T) * st, cudaMemcpyDeviceToHost); + printf("bart_encoder_emb_buf_\n"); + for (int i=0; i < seq_len; i++) { + for (int j=0; j 10) { + break; + } + } + printf("\n"); + } + for (int i=0; i<50; i++) { + printf("%f ", double(buf[i])); + } + printf("buf last: %f\n", double(buf[st-1])); + printf("\n"); +} size_t h_token_num; T* bart_encoder_input_ptr; T* bart_encoder_output_ptr; diff --git a/src/fastertransformer/models/bart/BartEncoderLayerWeight.cc b/src/fastertransformer/models/bart/BartEncoderLayerWeight.cc index 579e8aec3..7f8f42b3c 100644 --- a/src/fastertransformer/models/bart/BartEncoderLayerWeight.cc +++ b/src/fastertransformer/models/bart/BartEncoderLayerWeight.cc @@ -38,6 +38,7 @@ BartEncoderLayerWeight::BartEncoderLayerWeight(const size_t head_num, bart_with_bias_(bart_with_bias), use_gated_activation_(use_gated_activation) { + printf("BartEncoderLayerWeight\n"); real_weights_num_ = (8 + (use_gated_activation_ ? 1 : 0)) * (bart_with_bias_ ? 2 : 1); // 8: Q, K, V, O, LayerNorm1, FC1, FC2, LayerNorm2 FT_LOG_DEBUG("BartEncoderLayerWeight " + std::string(__func__) + " start"); @@ -293,7 +294,77 @@ void BartEncoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType m { FT_LOG_DEBUG("BartEncoderLayerWeight " + std::string(__func__) + " start"); - FT_LOG_DEBUG("Megatron BART support TBD"); + const auto tp_rank = std::to_string(tensor_para_rank_); + loadWeightFromBin(weights_ptr_[0], + {weights_size_[0]}, + dir_path + "layer.SelfAttention.q.weight." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr_[1], + {weights_size_[1]}, + dir_path + "layer.SelfAttention.k.weight." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr_[2], + {weights_size_[2]}, + dir_path + "layer.SelfAttention.v.weight." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr_[3], + {weights_size_[3]}, + dir_path + "layer.SelfAttention.out_proj.weight." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr_[4], + {weights_size_[4]}, + dir_path + "layer.SelfAttention.attn_layer_norm.weight.bin", + model_file_type); + + loadWeightFromBin(weights_ptr_[5], + {weights_size_[5]}, + dir_path + "layer.SelfAttention.fc1.weight." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr_[6], + {weights_size_[6]}, + dir_path + "layer.SelfAttention.fc2.weight." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr_[7], + {weights_size_[7]}, + dir_path + "layer.SelfAttention.final_layer_norm.weight.bin", + model_file_type); + + if (bart_with_bias_) { + loadWeightFromBin(weights_ptr_[8], + {weights_size_[8]}, + dir_path + "layer.SelfAttention.q.bias." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr_[9], + {weights_size_[9]}, + dir_path + "layer.SelfAttention.k.bias." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr_[10], + {weights_size_[10]}, + dir_path + "layer.SelfAttention.v.bias." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr_[11], + {weights_size_[11]}, + dir_path + "layer.SelfAttention.out_proj.bias." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr_[12], + {weights_size_[12]}, + dir_path + "layer.SelfAttention.attn_layer_norm.bias.bin", + model_file_type); + + loadWeightFromBin(weights_ptr_[13], + {weights_size_[13]}, + dir_path + "layer.SelfAttention.fc1.bias." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr_[14], + {weights_size_[14]}, + dir_path + "layer.SelfAttention.fc2.bias." + tp_rank + ".bin", + model_file_type); + loadWeightFromBin(weights_ptr_[15], + {weights_size_[15]}, + dir_path + "layer.SelfAttention.final_layer_norm.bias.bin", + model_file_type); + } + FT_LOG_DEBUG("BartEncoderLayerWeight " + std::string(__func__) + " end"); } diff --git a/src/fastertransformer/models/bart/BartEncoderWeight.cc b/src/fastertransformer/models/bart/BartEncoderWeight.cc index 47028260a..275f21435 100644 --- a/src/fastertransformer/models/bart/BartEncoderWeight.cc +++ b/src/fastertransformer/models/bart/BartEncoderWeight.cc @@ -62,6 +62,7 @@ BartEncoderWeight::BartEncoderWeight(const size_t head_num, setWeightPtr(); bart_encoder_layer_weights.clear(); bart_encoder_layer_weights.reserve(num_layer_); + printf("bart_encoder_layer_weights.reserve(num_layer_);\n"); for (int l = 0; l < num_layer_; l++) { if (isValidLayerParallelId(l)) { bart_encoder_layer_weights.push_back(new BartEncoderLayerWeight(head_num_, @@ -79,6 +80,7 @@ BartEncoderWeight::BartEncoderWeight(const size_t head_num, } } FT_LOG_DEBUG("BartEncoderWeight " + std::string(__func__) + " end"); + printf("BartEncoderWeight Done\n"); } template @@ -154,6 +156,7 @@ BartEncoderWeight::BartEncoderWeight(const BartEncoderWeight& other): position_embedding_type(other.position_embedding_type), real_weights_num_(other.real_weights_num_) { + printf("Copy BartEncoderWeight\n"); FT_LOG_DEBUG("BartEncoderWeight " + std::string(__func__) + " start"); initialize(); mallocWeights(); @@ -249,7 +252,41 @@ void BartEncoderWeight::loadModel(std::string dir_path) { FT_LOG_DEBUG("BartEncoderWeight " + std::string(__func__) + " start"); - FT_LOG_DEBUG("Megatron BART support TBD"); + FtCudaDataType model_file_type = getModelFileType(dir_path + "/config.ini", "encoder"); + FT_CHECK(is_maintain_buffer == true); + + loadWeightFromBin(weights_ptr[0], {(size_t)weights_size[0]}, dir_path + "/encoder.embed_positions.weight.bin", model_file_type); + loadWeightFromBin(weights_ptr[1], {(size_t)weights_size[1]}, dir_path + "/encoder.embed_tokens.weight.bin", model_file_type); +{ + T* buf; + int batch_size = 1; + int seq_len = 11; + int st = weights_size[1]; + printf("weights_size: %d \n",weights_size[1]); + buf = new T[st]; + cudaMemcpy(buf, weights_ptr[1], sizeof(T) * st, cudaMemcpyDeviceToHost); + printf("weights_ptr[0]\n"); + for (int i=0; i<50; i++) { + printf("%f ", double(buf[i])); + } + printf("buf last: %f\n", double(buf[st-1])); + printf("\n"); +} + loadWeightFromBin( + weights_ptr[2], {(size_t)weights_size[2]}, dir_path + "/encoder.final_layer_norm.weight.bin", model_file_type); + if (bart_with_bias) { + loadWeightFromBin(weights_ptr[3], + {(size_t)weights_size[3]}, + dir_path + "/encoder.final_layer_norm.bias.bin", + model_file_type); + } + + for (int l = 0; l < num_layer_; l++) { + if (isValidLayerParallelId(l)) { + bart_encoder_layer_weights[l]->loadModel(dir_path + "/encoder." + std::to_string(l) + ".", + model_file_type); + } + } FT_LOG_DEBUG("BartEncoderWeight " + std::string(__func__) + " end"); } diff --git a/src/fastertransformer/models/t5/T5Decoding.h b/src/fastertransformer/models/t5/T5Decoding.h index 67f04d480..cf74652a9 100644 --- a/src/fastertransformer/models/t5/T5Decoding.h +++ b/src/fastertransformer/models/t5/T5Decoding.h @@ -27,17 +27,6 @@ namespace fastertransformer { -// fallback to fp32 dynamic decoder when bf16 specified -template -struct fallBackType { - using Type = float; -}; - -template<> -struct fallBackType { - using Type = half; -}; - template class T5Decoding: public BaseLayer { private: diff --git a/src/fastertransformer/triton_backend/CMakeLists.txt b/src/fastertransformer/triton_backend/CMakeLists.txt index 037c36c36..c27c1bb13 100644 --- a/src/fastertransformer/triton_backend/CMakeLists.txt +++ b/src/fastertransformer/triton_backend/CMakeLists.txt @@ -19,6 +19,7 @@ target_link_libraries(TransformerTritonBackend PRIVATE nccl_utils mpi_utils) add_subdirectory(gptj) add_subdirectory(gptneox) +add_subdirectory(bart) add_subdirectory(t5) add_subdirectory(t5-encoder) add_subdirectory(multi_gpu_gpt) diff --git a/src/fastertransformer/triton_backend/bart/BartTritonModel.cc b/src/fastertransformer/triton_backend/bart/BartTritonModel.cc new file mode 100644 index 000000000..9d82e4d38 --- /dev/null +++ b/src/fastertransformer/triton_backend/bart/BartTritonModel.cc @@ -0,0 +1,357 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/triton_backend/bart/BartTritonModel.h" +#include "src/fastertransformer/triton_backend/bart/BartTritonModelInstance.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include "src/fastertransformer/utils/allocator.h" + +namespace ft = fastertransformer; + +std::shared_ptr AbstractTransformerModel::createBartModel(std::string model_dir) +{ + INIReader reader = INIReader(model_dir + "/config.ini"); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << model_dir << "/config.ini" + << "'\n"; + return nullptr; + } + + const std::string data_type = "fp32"; //reader.Get("ft_instance_hyperparameter", "data_type"); + if (data_type == "fp16") { + // return std::make_shared>(reader, model_dir); + return std::make_shared>(1, 1, 0, model_dir, 0); + } +#ifdef ENABLE_BF16 + else if (data_type == "bf16") { + return std::make_shared>(1, 1, 0, model_dir, 0); + } +#endif + else if (data_type == "fp32") { + return std::make_shared>(1, 1, 0, model_dir, 0); + } + else { + FT_LOG_ERROR("Unsupported data type " + data_type); + exit(-1); + } +} + +template +BartTritonModel::BartTritonModel(INIReader reader, std::string model_dir): model_dir_(model_dir) +{ + // encoder + encoder_head_num_ = reader.GetInteger("encoder", "num_heads"); + encoder_size_per_head_ = reader.GetInteger("encoder", "d_kv"); + encoder_d_model_ = reader.GetInteger("encoder", "d_model"); + encoder_inter_size_ = reader.GetInteger("encoder", "d_ff"); + encoder_num_layer_ = reader.GetInteger("encoder", "num_layers"); + encoder_vocab_size_ = reader.GetInteger("encoder", "vocab_size"); + encoder_max_pos_seq_len_ = reader.GetInteger("encoder", "max_pos_seq_len"); + + // decoding + decoding_head_num_ = reader.GetInteger("decoder", "num_heads"); + decoding_size_per_head_ = reader.GetInteger("decoder", "d_kv"); + decoding_d_model_ = reader.GetInteger("decoder", "d_model"); + decoding_inter_size_ = reader.GetInteger("decoder", "d_ff"); + decoding_num_layer_ = reader.GetInteger("decoder", "num_layers"); + decoding_vocab_size_ = reader.GetInteger("decoder", "vocab_size"); + decoding_max_pos_seq_len_ = reader.GetInteger("decoder", "max_pos_seq_len"); + + start_id_ = reader.GetInteger("decoder", "decoder_start_token_id"); + end_id_ = reader.GetInteger("decoder", "eos_token_id"); + tensor_para_size_ = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); + pipeline_para_size_ = reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"); + enable_custom_all_reduce_ = reader.GetInteger("ft_instance_hyperparameter", "enable_custom_all_reduce", 0); + max_distance_ = 128; // use default value of huggingface here +} + +template +BartTritonModel::BartTritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, + std::string model_dir, + int int8_mode): + tensor_para_size_(tensor_para_size), + pipeline_para_size_(pipeline_para_size), + encoder_shared_weights_(std::vector>>(ft::getDeviceCount())), + decoding_shared_weights_(std::vector>>(ft::getDeviceCount())), + enable_custom_all_reduce_(enable_custom_all_reduce), + model_dir_(model_dir), + int8_mode_(int8_mode) +{ + INIReader reader = INIReader(model_dir + "/config.ini"); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << model_dir << "/config.ini" + << "'\n"; + ft::FT_CHECK(false); + } + + ft::FT_CHECK(int8_mode_ == 0); + + model_name_ = reader.Get("encoder", "model_name"); + // encoder + encoder_head_num_ = reader.GetInteger("encoder", "num_heads"); + encoder_size_per_head_ = reader.GetInteger("encoder", "d_kv"); + encoder_d_model_ = reader.GetInteger("encoder", "d_model"); + encoder_inter_size_ = reader.GetInteger("encoder", "d_ff"); + encoder_num_layer_ = reader.GetInteger("encoder", "num_layers"); + encoder_vocab_size_ = reader.GetInteger("encoder", "vocab_size"); + encoder_max_pos_seq_len_ = + reader.GetInteger("encoder", "max_pos_seq_len"); + + // decoding + decoding_head_num_ = reader.GetInteger("decoder", "num_heads"); + decoding_size_per_head_ = reader.GetInteger("decoder", "d_kv"); + decoding_d_model_ = reader.GetInteger("decoder", "d_model"); + decoding_inter_size_ = reader.GetInteger("decoder", "d_ff"); + decoding_num_layer_ = reader.GetInteger("decoder", "num_layers"); + decoding_vocab_size_ = reader.GetInteger("decoder", "vocab_size"); + decoding_max_pos_seq_len_ = + reader.GetInteger("decoder", "max_pos_seq_len"); + + start_id_ = reader.GetInteger("decoder", "decoder_start_token_id"); + end_id_ = reader.GetInteger("decoder", "eos_token_id"); + + // common settings + activation_type_ = ft::getActivationType(reader.Get("encoder", "feed_forward_proj")); + + max_distance_ = 128; // use default value of huggingface here +} + +template +std::unique_ptr +BartTritonModel::createModelInstance(int device_id, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm) +{ + printf("createModelInstance\n"); + ft::check_cuda_error(cudaSetDevice(device_id)); + const int comms_rank = device_id % (tensor_para_size_ * pipeline_para_size_); + + std::unique_ptr> allocator( + new ft::Allocator(device_id)); + + allocator->setStream(stream); + + cublasHandle_t cublas_handle; + cublasLtHandle_t cublaslt_handle; + + cublasCreate(&cublas_handle); + cublasLtCreate(&cublaslt_handle); + cublasSetStream(cublas_handle, stream); + + std::unique_ptr cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in")); + std::unique_ptr cublas_wrapper_mutex(new std::mutex()); + std::unique_ptr cublas_wrapper(new ft::cublasMMWrapper( + cublas_handle, cublaslt_handle, stream, cublas_algo_map.get(), cublas_wrapper_mutex.get(), allocator.get())); + + std::unique_ptr cuda_device_prop_ptr(new cudaDeviceProp); + ft::check_cuda_error(cudaGetDeviceProperties(cuda_device_prop_ptr.get(), device_id)); + + if (std::is_same::value) { + cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); + } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper->setBF16GemmConfig(); + } +#endif + else if (std::is_same::value) { + cublas_wrapper->setFP32GemmConfig(); + } + + const int sm_ = ft::getSMVersion(); + + // TODO(bhsueh) not support fused mha + // NOTE: fmha doesn't support bart-style relative position bias + ft::AttentionType attention_type = + ft::getAttentionType(encoder_size_per_head_, sm_, true, encoder_max_pos_seq_len_, false); + + ft::NcclParam tensor_para_ = nccl_params.first[comms_rank]; + ft::NcclParam pipeline_para_ = nccl_params.second[comms_rank]; + + auto encoder = std::make_unique>(ft::BartEncoder(0, + 0, + encoder_head_num_, + encoder_size_per_head_, + encoder_inter_size_, + encoder_d_model_, + encoder_num_layer_, + encoder_max_pos_seq_len_, + max_distance_, + sm_, + q_scaling_, + stream, + cublas_wrapper.get(), + allocator.get(), + false, + attention_type, + false, + activation_type_, + layernorm_type_, + tensor_para_, + pipeline_para_, + custom_all_reduce_comm, + enable_custom_all_reduce_)); + + auto decoding = std::make_unique>(ft::BartDecoding(0, + 0, + 0, + 0, + decoding_head_num_, + decoding_size_per_head_, + decoding_inter_size_, + decoding_d_model_, + decoding_num_layer_, + decoding_vocab_size_, + decoding_max_pos_seq_len_, + max_distance_, + q_scaling_, + start_id_, + end_id_, + 0.0f, // beam_search_diversity_rate_, + 1, // top_k_, + 0.0f, // top_p_, + 1.0f, // temperature_, + 0.0f, // len_penalty_, + 1.0f, // repetition_penalty_, + stream, + cublas_wrapper.get(), + allocator.get(), + false, + cuda_device_prop_ptr.get(), + tensor_para_, + pipeline_para_, + activation_type_, + layernorm_type_, + tie_word_embeddings_, + custom_all_reduce_comm, + enable_custom_all_reduce_)); + + return std::unique_ptr>(new BartTritonModelInstance(std::move(encoder), + std::move(decoding), + encoder_shared_weights_[device_id], + decoding_shared_weights_[device_id], + std::move(allocator), + std::move(cublas_algo_map), + std::move(cublas_wrapper_mutex), + std::move(cublas_wrapper), + std::move(cuda_device_prop_ptr))); +} + +template +void BartTritonModel::createSharedWeights(int device_id, int rank) +{ + printf("createSharedWeights\n"); + ft::check_cuda_error(cudaSetDevice(device_id)); + const int tensor_para_rank = rank % tensor_para_size_; + const int pipeline_para_rank = rank / tensor_para_size_; + + printf("BartEncoderWeight %d %d\n", encoder_shared_weights_.size(), device_id); + encoder_shared_weights_[device_id] = + std::make_shared>(encoder_head_num_, + encoder_size_per_head_, + encoder_d_model_, + encoder_inter_size_, + encoder_vocab_size_, + encoder_num_layer_, + encoder_max_pos_seq_len_, + tensor_para_size_, + tensor_para_rank, + pipeline_para_size_, + pipeline_para_rank, + bart_with_bias_, + mbart_para_, + use_gated_activation_, + position_embedding_type_); + + printf("BartDecodingWeight\n"); + decoding_shared_weights_[device_id] = + std::make_shared>(decoding_head_num_, + decoding_size_per_head_, + decoding_d_model_, + decoding_inter_size_, + decoding_vocab_size_, + decoding_num_layer_, + encoder_d_model_, + decoding_max_pos_seq_len_, + tensor_para_size_, + tensor_para_rank, + pipeline_para_size_, + pipeline_para_rank, + bart_with_bias_, + mbart_para_, + use_gated_activation_, + position_embedding_type_); + + printf("load model\n"); + encoder_shared_weights_[device_id]->loadModel(model_dir_); + decoding_shared_weights_[device_id]->loadModel(model_dir_); +} + +template +std::string BartTritonModel::toString() +{ + std::stringstream ss; + std::string position_embedding_type_string = + position_embedding_type_ == ft::PositionEmbeddingType::relative ? "relative" : "absolute"; + + ss << "\nModel: " + << "\n encoder_head_num_: " << encoder_head_num_ << "\n encoder_size_per_head_: " << encoder_size_per_head_ + << "\n encoder_d_model_: " << encoder_d_model_ << "\n encoder_inter_size_: " << encoder_inter_size_ + << "\n encoder_num_layer_: " << encoder_num_layer_ << "\n encoder_vocab_size_: " << encoder_vocab_size_ + << "\n encoder_max_pos_seq_len_: " << encoder_max_pos_seq_len_ + << "\n decoding_head_num_: " << decoding_head_num_ + << "\n decoding_size_per_head_: " << decoding_size_per_head_ + << "\n decoding_d_model_: " << decoding_d_model_ << "\n decoding_inter_size_: " << decoding_inter_size_ + << "\n decoding_num_layer_: " << decoding_num_layer_ << "\n decoding_vocab_size_: " << decoding_vocab_size_ + << "\n decoding_max_pos_seq_len_: " << decoding_max_pos_seq_len_ + << "\n bart_with_bias_: " << bart_with_bias_ + << "\n use_gated_activation_: " << use_gated_activation_ + << "\n position_embedding_type_: " << position_embedding_type_string << "\n start_id_: " << start_id_ + << "\n end_id_: " << end_id_ << "\n model_name_: " << model_name_ << "\n model_dir_: " << model_dir_ + << std::endl; + + return ss.str(); +} + +template +void BartTritonModel::createCustomComms(std::vector>* custom_all_reduce_comms, + int world_size) +{ + using commDataType = typename ft::CustomARCommTypeConverter::Type; + ft::initCustomAllReduceComm(custom_all_reduce_comms, enable_custom_all_reduce_, world_size); +} + +template +int BartTritonModel::getTensorParaSize() +{ + return tensor_para_size_; +} + +template +int BartTritonModel::getPipelineParaSize() +{ + return pipeline_para_size_; +} + +template struct BartTritonModel; +template struct BartTritonModel; +#ifdef ENABLE_BF16 +template struct BartTritonModel<__nv_bfloat16>; +#endif diff --git a/src/fastertransformer/triton_backend/bart/BartTritonModel.h b/src/fastertransformer/triton_backend/bart/BartTritonModel.h new file mode 100644 index 000000000..47ab7f08f --- /dev/null +++ b/src/fastertransformer/triton_backend/bart/BartTritonModel.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "3rdparty/INIReader.h" +#include "src/fastertransformer/models/bart/BartDecoding.h" +#include "src/fastertransformer/models/bart/BartEncoder.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/nccl_utils.h" +#include + +namespace ft = fastertransformer; + +template +struct BartTritonModel: public AbstractTransformerModel { + BartTritonModel(INIReader reader, std::string model_dir); + + BartTritonModel(size_t tensor_para_size, + size_t pipeline_para_size, + int enable_custom_all_reduce, + std::string model_dir, + int int8_mode); + + ~BartTritonModel() = default; + + virtual std::unique_ptr + createModelInstance(int deviceId, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm = nullptr); + + virtual void createSharedWeights(int deviceId, int rank) override; + + virtual void createCustomComms(std::vector>* custom_all_reduce_comms, + int world_size) override; + + virtual std::string toString() override; + virtual int getTensorParaSize() override; + virtual int getPipelineParaSize() override; + +private: + // encoder + size_t encoder_head_num_; + size_t encoder_size_per_head_; + size_t encoder_d_model_; + size_t encoder_inter_size_; + size_t encoder_num_layer_; + size_t encoder_vocab_size_; + size_t encoder_max_pos_seq_len_; + + // decoding + size_t decoding_head_num_; + size_t decoding_size_per_head_; + size_t decoding_d_model_; + size_t decoding_inter_size_; + size_t decoding_num_layer_; + size_t decoding_vocab_size_; + size_t decoding_max_pos_seq_len_; + + float q_scaling_ = 1.f; + + size_t max_distance_; + int start_id_; + int end_id_; + + bool tie_word_embeddings_ = false; + + size_t tensor_para_size_; + size_t pipeline_para_size_; + + // shared weights for each device + std::vector>> encoder_shared_weights_; + std::vector>> decoding_shared_weights_; + + // bart structure difference + bool bart_with_bias_ = true; + // TODO(zhwang): support mbart. + bool mbart_para_ = false; + bool use_gated_activation_ = false; + ft::PositionEmbeddingType position_embedding_type_ = ft::PositionEmbeddingType::absolute; + ft::ActivationType activation_type_; + ft::LayerNormType layernorm_type_ = ft::LayerNormType::post_layernorm; + + bool is_fp16_; + int int8_mode_; + + int enable_custom_all_reduce_ = 0; + + std::string model_name_; + std::string model_dir_; +}; diff --git a/src/fastertransformer/triton_backend/bart/BartTritonModelInstance.cc b/src/fastertransformer/triton_backend/bart/BartTritonModelInstance.cc new file mode 100644 index 000000000..2c8add9b6 --- /dev/null +++ b/src/fastertransformer/triton_backend/bart/BartTritonModelInstance.cc @@ -0,0 +1,273 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/triton_backend/bart/BartTritonModelInstance.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include "src/fastertransformer/triton_backend/triton_utils.hpp" +#include "src/fastertransformer/utils/Tensor.h" +#include + +namespace ft = fastertransformer; + +template +void triton_stream_callback(ft::TensorMap* output_tensors, void* ctx) +{ + auto* const model = reinterpret_cast*>(ctx); + auto const result = BartTritonModelInstance::convert_outputs(*output_tensors); + + model->stream_cb_(result, model->stream_ctx_); +} + +template +BartTritonModelInstance::BartTritonModelInstance(std::unique_ptr> bart_encoder, + std::unique_ptr> bart_decoding, + std::shared_ptr> bart_encoder_weight, + std::shared_ptr> bart_decoding_weight, + std::unique_ptr> allocator, + std::unique_ptr cublas_algo_map, + std::unique_ptr cublas_wrapper_mutex, + std::unique_ptr cublas_wrapper, + std::unique_ptr cuda_device_prop_ptr): + bart_encoder_(std::move(bart_encoder)), + bart_decoding_(std::move(bart_decoding)), + bart_encoder_weight_(bart_encoder_weight), + bart_decoding_weight_(bart_decoding_weight), + allocator_(std::move(allocator)), + cublas_algo_map_(std::move(cublas_algo_map)), + cublas_wrapper_mutex_(std::move(cublas_wrapper_mutex)), + cublas_wrapper_(std::move(cublas_wrapper)), + cuda_device_prop_ptr_(std::move(cuda_device_prop_ptr)) +{ +} + +template +ft::TensorMap +BartTritonModelInstance::convert_inputs(std::shared_ptr> input_tensors) +{ + move_tensor_H2D(input_tensors->at("input_ids"), d_input_ids_, &allocator_); + move_tensor_H2D(input_tensors->at("sequence_length"), d_input_lengths_, &allocator_); + + ft::TensorMap ft_input_tensors( + {{"input_ids", as_GPU_tensor(input_tensors->at("input_ids"), d_input_ids_)}, + {"sequence_length", as_GPU_tensor(input_tensors->at("sequence_length"), d_input_lengths_)}}); + + return ft_input_tensors; +} + +template +std::shared_ptr> +BartTritonModelInstance::convert_outputs(ft::TensorMap& output_tensors) +{ + std::unordered_map* outputs_mapping = + new std::unordered_map(); + + for (auto it = output_tensors.begin(); it != output_tensors.end(); it++) { + outputs_mapping->insert({it->first, triton::Tensor::convertFtTensorToTriton(it->second)}); + } + + return std::shared_ptr>(outputs_mapping); +} + +template +std::shared_ptr> +BartTritonModelInstance::forward(std::shared_ptr> input_tensors) +{ + printf("BartTritonModelInstance::forward\n"); + for (const auto& pair : *input_tensors) { + std::cout << "Key: " << pair.first << std::endl; + input_tensors->at(pair.first); + } + + printf("input_tensors input_ids\n"); + printf("done\n"); + const size_t request_batch_size = input_tensors->at("input_ids").shape[0]; + const size_t mem_max_seq_len = input_tensors->at("input_ids").shape[1]; + const size_t max_output_len = *((uint*)input_tensors->at("max_output_len").data); + const size_t beam_width = + input_tensors->count("beam_width") ? (size_t)(*(uint*)input_tensors->at("beam_width").data) : 1; + + printf("allocateBuffer\n"); + allocateBuffer(request_batch_size, beam_width, max_output_len, mem_max_seq_len); + + ft::TensorMap encoder_input_tensors(convert_inputs(input_tensors)); + printf("encoder_input_tensors\n"); + ft::TensorMap encoder_output_tensors( + {{"output_hidden_state", + ft::Tensor{ft::MEMORY_GPU, + ft::getTensorType(), + std::vector{request_batch_size, mem_max_seq_len, bart_encoder_->getDModel()}, + d_encoder_outputs_}}}); + + ft::TensorMap decoding_input_tensors({{"encoder_output", encoder_output_tensors.at("output_hidden_state")}, + {"encoder_sequence_length", encoder_input_tensors.at("sequence_length")}}); + + if (input_tensors->find("top_p_decay") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("top_p_decay"), d_top_p_decay_, &allocator_); + decoding_input_tensors.insert({"top_p_decay", as_GPU_tensor(input_tensors->at("top_p_decay"), d_top_p_decay_)}); + } + if (input_tensors->find("top_p_min") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("top_p_min"), d_top_p_min_, &allocator_); + decoding_input_tensors.insert({"top_p_min", as_GPU_tensor(input_tensors->at("top_p_min"), d_top_p_min_)}); + } + if (input_tensors->find("top_p_reset_ids") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("top_p_reset_ids"), d_top_p_reset_ids_, &allocator_); + decoding_input_tensors.insert( + {"top_p_reset_ids", as_GPU_tensor(input_tensors->at("top_p_reset_ids"), d_top_p_reset_ids_)}); + } + + std::set keys_on_gpu = {"input_ids", + "sequence_length", + "bad_words_list", + "stop_words_list", + "top_p_decay", + "top_p_min", + "top_p_reset_ids"}; + for (auto& t : *input_tensors) { + if (keys_on_gpu.count(t.first) == 0) { + decoding_input_tensors.insert({t.first, t.second.convertTritonTensorToFt()}); + } + } + + if (input_tensors->find("bad_words_list") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("bad_words_list"), d_input_bad_words_, &allocator_); + decoding_input_tensors.insert( + {"bad_words_list", as_GPU_tensor(input_tensors->at("bad_words_list"), d_input_bad_words_)}); + } + + if (input_tensors->find("stop_words_list") != input_tensors->end()) { + move_tensor_H2D(input_tensors->at("stop_words_list"), d_input_stop_words_, &allocator_); + decoding_input_tensors.insert( + {"stop_words_list", as_GPU_tensor(input_tensors->at("stop_words_list"), d_input_stop_words_)}); + } + + ft::TensorMap decoding_output_tensors( + {{"output_ids", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_INT32, + std::vector{request_batch_size, beam_width, max_output_len}, + d_output_ids_}}, + {"sequence_length", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_INT32, + std::vector{request_batch_size, beam_width}, + d_sequence_lengths_}}}); + if (input_tensors->count("is_return_log_probs") > 0 + && input_tensors->at("is_return_log_probs").convertTritonTensorToFt().getVal()) { + decoding_output_tensors.insert({"output_log_probs", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_FP32, + std::vector{request_batch_size, beam_width, max_output_len}, + d_output_log_probs_}}); + decoding_output_tensors.insert({"cum_log_probs", + ft::Tensor{ft::MEMORY_GPU, + ft::TYPE_FP32, + std::vector{request_batch_size, beam_width}, + d_cum_log_probs_}}); + } + + try { + if (stream_cb_ != nullptr) { + bart_decoding_->registerCallback(triton_stream_callback, this); + } + + bart_encoder_->forward(&encoder_output_tensors, &encoder_input_tensors, bart_encoder_weight_.get()); + + +{ + T* buf; + int st = request_batch_size * mem_max_seq_len * bart_encoder_->getDModel(); + buf = new T[st]; + cudaMemcpy(buf, d_encoder_outputs_, sizeof(T) * st, cudaMemcpyDeviceToHost); + printf("cudaMemcpy\n"); + for (int i=0; i<10; i++) { + printf("%f ", double(buf[i])); + if (i % 500 == 10 ) { + printf("\n"); + } + } + printf("\n"); +} + + bart_decoding_->forward(&decoding_output_tensors, &decoding_input_tensors, bart_decoding_weight_.get()); + +{ + int* buf; + int st = request_batch_size * max_output_len; + buf = new int[st]; + cudaMemcpy(buf, d_output_ids_, sizeof(int) * st, cudaMemcpyDeviceToHost); + printf("cudaMemcpy d_output_ids_\n"); + for (int i=0; i<10; i++) { + printf("%d ", (buf[i])); + if (i % 500 == 10 ) { + printf("\n"); + } + } + printf("\n"); +} + if (stream_cb_ != nullptr) { + bart_decoding_->unRegisterCallback(); + } + } + catch (...) { + h_exception_ = std::current_exception(); + decoding_output_tensors.insert( + {"error_message", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_BYTES, {1}, &h_exception_}}); + } + + return convert_outputs(decoding_output_tensors); +} + +template +BartTritonModelInstance::~BartTritonModelInstance() +{ + freeBuffer(); +} + +template +void BartTritonModelInstance::allocateBuffer(const size_t request_batch_size, + const size_t beam_width, + const size_t max_output_len, + const size_t mem_max_seq_len) +{ + d_output_ids_ = (int*)(allocator_->reMalloc( + d_output_ids_, sizeof(int) * request_batch_size * beam_width * max_output_len, false)); + d_encoder_outputs_ = (T*)(allocator_->reMalloc( + d_encoder_outputs_, sizeof(T) * request_batch_size * mem_max_seq_len * bart_encoder_->getDModel(), false)); + d_sequence_lengths_ = + (int*)(allocator_->reMalloc(d_sequence_lengths_, sizeof(int) * request_batch_size * beam_width, false)); + d_output_log_probs_ = (float*)(allocator_->reMalloc( + d_output_log_probs_, sizeof(float) * request_batch_size * beam_width * max_output_len, false)); + d_cum_log_probs_ = (float*)(allocator_->reMalloc( + d_cum_log_probs_, sizeof(float) * request_batch_size * beam_width * max_output_len, false)); + d_within_range_ = (bool*)(allocator_->reMalloc(d_within_range_, sizeof(bool))); +} + +template +void BartTritonModelInstance::freeBuffer() +{ + allocator_->free((void**)(&d_encoder_outputs_)); + allocator_->free((void**)(&d_output_ids_)); + allocator_->free((void**)(&d_sequence_lengths_)); + allocator_->free((void**)(&d_output_log_probs_)); + allocator_->free((void**)(&d_cum_log_probs_)); + allocator_->free((void**)(&d_within_range_)); +} + +template struct BartTritonModelInstance; +template struct BartTritonModelInstance; +#ifdef ENABLE_BF16 +template struct BartTritonModelInstance<__nv_bfloat16>; +#endif diff --git a/src/fastertransformer/triton_backend/bart/BartTritonModelInstance.h b/src/fastertransformer/triton_backend/bart/BartTritonModelInstance.h new file mode 100644 index 000000000..8c14901e4 --- /dev/null +++ b/src/fastertransformer/triton_backend/bart/BartTritonModelInstance.h @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/fastertransformer/models/bart/BartDecoding.h" +#include "src/fastertransformer/models/bart/BartEncoder.h" +#include "src/fastertransformer/triton_backend/bart/BartTritonModel.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include + +namespace ft = fastertransformer; + +template +struct BartTritonModelInstance: AbstractTransformerModelInstance { + + BartTritonModelInstance(std::unique_ptr> bart_encoder, + std::unique_ptr> bart_decoding, + std::shared_ptr> bart_encoder_weight, + std::shared_ptr> bart_decoding_weight, + std::unique_ptr> allocator, + std::unique_ptr cublas_algo_map, + std::unique_ptr cublas_wrapper_mutex, + std::unique_ptr cublas_wrapper, + std::unique_ptr cuda_device_prop_ptr); + ~BartTritonModelInstance(); + + std::shared_ptr> + forward(std::shared_ptr> input_tensors) override + { + ft::FT_CHECK(false); + return nullptr; + }; + + std::shared_ptr> + forward(std::shared_ptr> input_tensors) override; + + static std::shared_ptr> + convert_outputs(ft::TensorMap& output_tensors); + +private: + const std::unique_ptr> bart_encoder_; + const std::shared_ptr> bart_encoder_weight_; + const std::unique_ptr> bart_decoding_; + const std::shared_ptr> bart_decoding_weight_; + const std::unique_ptr> allocator_; + const std::unique_ptr cublas_algo_map_; + const std::unique_ptr cublas_wrapper_mutex_; + const std::unique_ptr cublas_wrapper_; + const std::unique_ptr cuda_device_prop_ptr_; + + ft::TensorMap convert_inputs(std::shared_ptr> input_tensors); + + void allocateBuffer(const size_t request_batch_size, + const size_t beam_width, + const size_t max_output_len, + const size_t mem_max_seq_len); + void freeBuffer(); + + int* d_input_ids_ = nullptr; + int* d_input_lengths_ = nullptr; + int* d_input_bad_words_ = nullptr; + int* d_input_stop_words_ = nullptr; + int* d_input_ia3_tasks_ = nullptr; + int* d_request_prompt_lengths_ = nullptr; + T* d_request_prompt_embedding_ = nullptr; + float* d_top_p_decay_ = nullptr; + float* d_top_p_min_ = nullptr; + int* d_top_p_reset_ids_ = nullptr; + + T* d_encoder_outputs_ = nullptr; + int* d_output_ids_ = nullptr; + int* d_sequence_lengths_ = nullptr; + float* d_output_log_probs_ = nullptr; + float* d_cum_log_probs_ = nullptr; + bool* d_within_range_ = nullptr; + + int h_total_output_len_; + + std::exception_ptr h_exception_ = nullptr; +}; diff --git a/src/fastertransformer/triton_backend/bart/CMakeLists.txt b/src/fastertransformer/triton_backend/bart/CMakeLists.txt new file mode 100644 index 000000000..f37028e7f --- /dev/null +++ b/src/fastertransformer/triton_backend/bart/CMakeLists.txt @@ -0,0 +1,25 @@ +# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.8) + +set(bart_triton_backend_files + BartTritonModel.cc + BartTritonModelInstance.cc +) + +add_library(BartTritonBackend STATIC ${bart_triton_backend_files}) +set_property(TARGET BartTritonBackend PROPERTY POSITION_INDEPENDENT_CODE ON) +target_link_libraries(BartTritonBackend PRIVATE TransformerTritonBackend BartEncoder BartDecoding -lcublasLt) +target_compile_features(BartTritonBackend PRIVATE cxx_std_14) diff --git a/src/fastertransformer/triton_backend/transformer_triton_backend.hpp b/src/fastertransformer/triton_backend/transformer_triton_backend.hpp index edffabfd7..3b9ef2d08 100644 --- a/src/fastertransformer/triton_backend/transformer_triton_backend.hpp +++ b/src/fastertransformer/triton_backend/transformer_triton_backend.hpp @@ -294,6 +294,7 @@ struct AbstractTransformerModel { static std::shared_ptr createT5Model(std::string model_dir); static std::shared_ptr createT5EncoderModel(std::string model_dir); static std::shared_ptr createLlamaModel(std::string model_dir); + static std::shared_ptr createBartModel(std::string model_dir); std::pair, std::vector> createNcclParams(const int node_id, const int device_id_start = 0, const bool multi_node = false);