diff --git a/CMakeLists.txt b/CMakeLists.txt index 8880526e4..a164ef827 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -399,6 +399,7 @@ add_library(transformer-shared SHARED $ $ $ + $ $ $ $ diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index da24d72c6..40bf46574 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -21,6 +21,7 @@ add_subdirectory(swin) add_subdirectory(swin_int8) add_subdirectory(vit) add_subdirectory(vit_int8) +add_subdirectory(llama) add_subdirectory(wenet) diff --git a/examples/cpp/llama/CMakeLists.txt b/examples/cpp/llama/CMakeLists.txt new file mode 100644 index 000000000..cdf9033dd --- /dev/null +++ b/examples/cpp/llama/CMakeLists.txt @@ -0,0 +1,22 @@ +# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_executable(llama_example llama_example.cc) +target_link_libraries(llama_example PUBLIC -lcublas -lcublasLt -lcudart + Llama nvtx_utils gpt_example_utils word_list mpi_utils nccl_utils) + +add_executable(llama_triton_example llama_triton_example.cc) +target_link_libraries(llama_triton_example PUBLIC -lcublas -lcublasLt -lcudart -lpthread + LlamaTritonBackend TransformerTritonBackend custom_ar_comm + gpt_example_utils word_list mpi_utils nccl_utils nvtx_utils) diff --git a/examples/cpp/llama/bad_words.csv b/examples/cpp/llama/bad_words.csv new file mode 100644 index 000000000..6a1126ebd --- /dev/null +++ b/examples/cpp/llama/bad_words.csv @@ -0,0 +1,2 @@ +7768,3908 +1,2 diff --git a/examples/cpp/llama/check_with_huggingface.py b/examples/cpp/llama/check_with_huggingface.py new file mode 100644 index 000000000..d1f356cc1 --- /dev/null +++ b/examples/cpp/llama/check_with_huggingface.py @@ -0,0 +1,16 @@ +import transformers + +from transformers import LlamaForCausalLM, LlamaTokenizer + +tokenizer = LlamaTokenizer.from_pretrained('/data/llama-7b-hf') + +prompt = "Hey, are you consciours? Can you talk to me?" +inputs = tokenizer(prompt, return_tensors='pt') +model = LlamaForCausalLM.from_pretrained("/data/llama-7b-hf") +hf_config = vars(model.config) +print(hf_config) +generated_ids = model.forward(inputs.input_ids, output_hidden_states=True) +print(generated_ids) + +tokens = [0,18637,29892,526,366,1136,455,2470,29973,1815,366,5193,304,592,29973,18637,29892,526,366,1136,455,2470,29973,1815,366,5193,304,592,29973,18637,29892,526,366,1136,455,2470,29973,1815,366,5193,304,592,29973,18637,29892,526,366] +print(tokenizer.decode(tokens)) diff --git a/examples/cpp/llama/huggingface_llama_convert.py b/examples/cpp/llama/huggingface_llama_convert.py new file mode 100644 index 000000000..d771c0b2c --- /dev/null +++ b/examples/cpp/llama/huggingface_llama_convert.py @@ -0,0 +1,233 @@ +# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import configparser +import numpy as np +from pathlib import Path + +import torch +import os +from transformers import LlamaForCausalLM, AutoConfig + +def get_weight_data_type(data_type): + if data_type == "fp32": + return np.float32 + elif data_type == "fp16": + return np.float16 + else: + assert False, f"Invalid weight data type {data_type}" + + +def split_and_convert_process(saved_dir, factor, key, val): + if key.find("input_layernorm.weight") != -1 or key.find("post_attention_layernorm.weight") != -1: + # shared weights, only need to convert the weights of rank 0 + saved_path = saved_dir + "/" + key + ".bin" + val.tofile(saved_path) + elif key.find("attention.dense.weight") != -1 or key.find("mlp.down_proj.weight") != -1: + split_vals = np.split(val, factor, axis=0) + for j in range(factor): + saved_path = saved_dir + "/" + key + ".%d.bin" % j + split_vals[j].tofile(saved_path) + elif key.find("mlp.gate_proj.weight") != -1 or key.find("mlp.up_proj.weight") != -1: + split_vals = np.split(val, factor, axis=-1) + for j in range(factor): + saved_path = saved_dir + "/" + key + ".%d.bin" % j + split_vals[j].tofile(saved_path) + elif key.find("attention.query_key_value.weight") != -1: + split_vals = np.split(val, factor, axis=-1) + for j in range(factor): + saved_path = saved_dir + "/" + key + ".%d.bin" % j + split_vals[j].tofile(saved_path) + else: + print("[ERROR] cannot find key '{}'".format(key)) + +def split_and_convert(args): + saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_gpu_num + + if(os.path.exists(saved_dir) == False): + os.makedirs(saved_dir) + + t_gpu_num = args.trained_gpu_num + i_gpu_num = args.infer_gpu_num + assert(i_gpu_num % t_gpu_num == 0) + + factor = (int)(i_gpu_num / t_gpu_num) + # load position_embedding from rank 0 + # model = torch.load(ckpt_name) + print(f'load model from {args.in_file}') + # model = LlamaForCausalLM.from_pretrained(args.in_file, device_map='auto') + config = AutoConfig.from_pretrained(args.in_file) + # num_layers = 3 + # config.num_hidden_layers = num_layers + print(config) + state_dict = {} + for f in os.listdir(args.in_file): + if not f.endswith('.bin'): + continue + w = torch.load(os.path.join(args.in_file, f), map_location='cpu') + keys = list(w.keys()) + for k in keys: + if 'model.layers.' 
not in k: + continue + l = int(k.split('.')[2]) + if l < config.num_hidden_layers: + continue + del w[k] + state_dict.update(w) + + model = LlamaForCausalLM.from_pretrained(None, config=config, state_dict=state_dict) + hf_config = vars(model.config) + print(f"hf_config: {hf_config}") + + print("named parameters:") + for name, param in model.named_parameters(): + print(f"- {name}") + + hidden_size = hf_config["hidden_size"] + head_num = hf_config["num_attention_heads"] + kv_head_num = hf_config["num_key_value_heads"] + head_size = hidden_size // head_num + num_layers = hf_config["num_hidden_layers"] + + + np_weight_data_type = get_weight_data_type(args.weight_data_type) + + try: + model_name = args.model_name + config = configparser.ConfigParser() + config['llama'] = {} + config['llama']['model_name'] = model_name + config['llama']["head_num"] = str(head_num) + config['llama']["kv_head_num"] = str(kv_head_num) + config['llama']["size_per_head"] = str(head_size) + config['llama']["inter_size"] = str(hf_config["intermediate_size"]) + config['llama']["num_layer"] = str(num_layers) + config['llama']["rotary_embedding"] = str(head_size) + config['llama']['layernorm_eps'] = str(hf_config["rms_norm_eps"]) + config['llama']["vocab_size"] = str(hf_config["vocab_size"]) + config['llama']["start_id"] = str(hf_config["bos_token_id"]) + config['llama']["end_id"] = str(hf_config["eos_token_id"]) + config['llama']["weight_data_type"] = args.weight_data_type + + with open((Path(saved_dir) / f"config.ini").as_posix(), 'w') as configfile: + config.write(configfile) + except Exception as e: + print(f"Failed to save the config in config.ini.") + print(e) + + param_to_weights = lambda param: param.detach().cpu().numpy().astype(np_weight_data_type) + + # layer-wise weights, example: + # - model.layers.0.self_attn.q_proj.weight + # - model.layers.0.self_attn.k_proj.weight + # - model.layers.0.self_attn.v_proj.weight + # - model.layers.0.self_attn.o_proj.weight + # - model.layers.0.mlp.gate_proj.weight + # - model.layers.0.mlp.down_proj.weight + # - model.layers.0.mlp.up_proj.weight + # - model.layers.0.input_layernorm.weight + # - model.layers.0.post_attention_layernorm.weight + for l in range(num_layers): + print(f"converting layer {l}") + # first merge QKV into a single weight + # concat direct to FT shape: [hidden_size, 3, head_num, head_size] + # copied from huggingface_gptj_ckpt_convert.py + # qkv_weights = np.stack([ + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight']), + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight']), + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight']), + # ]) + # qkv_weights = np.transpose(qkv_weights, (2, 0, 1)) + q_proj = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight']) + k_proj = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight']) + v_proj = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight']) + q_proj = np.split(q_proj, factor, axis=0) + k_proj = np.split(k_proj, factor, axis=0) + v_proj = np.split(v_proj, factor, axis=0) + for j in range(factor): + qkv_weights = np.concatenate((q_proj[j], k_proj[j], v_proj[j]), axis=0) + print(qkv_weights.shape) + # qkv_weights = np.transpose(qkv_weights, (2, 0, 1)) + qkv_weights = np.transpose(qkv_weights) + qkv_weights_base_name = f'model.layers.{l}.attention.query_key_value.weight' + saved_path = saved_dir + "/" + qkv_weights_base_name + ".%d.bin" % j +
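+            # qkv_weights is now [hidden_size, (head_num + 2 * kv_head_num) * head_size / factor]:
+            # this rank's Q slice followed by its K and V slices (grouped-query attention layout)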
qkv_weights.tofile(saved_path) + # qkv_weights = np.concatenate(( + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight']), + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight']), + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight']), + # ), axis=0) + # print(qkv_weights.shape) + # # qkv_weights = np.transpose(qkv_weights, (2, 0, 1)) + # qkv_weights = np.transpose(qkv_weights) + # qkv_weights_base_name = f'model.layers.{l}.attention.query_key_value.weight' + # split_and_convert_process(saved_dir, factor, qkv_weights_base_name, qkv_weights) + + # attention dense + o_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.o_proj.weight']).T + o_weight_base_name = f'model.layers.{l}.attention.dense.weight' + split_and_convert_process(saved_dir, factor, o_weight_base_name, o_weight) + + # MLP + mlp_down_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.mlp.down_proj.weight']).T + mlp_down_base_name = f'model.layers.{l}.mlp.down_proj.weight' + split_and_convert_process(saved_dir, factor, mlp_down_base_name, mlp_down_weight) + + mlp_gate_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.mlp.gate_proj.weight']).T + mlp_gate_base_name = f'model.layers.{l}.mlp.gate_proj.weight' + split_and_convert_process(saved_dir, factor, mlp_gate_base_name, mlp_gate_weight) + + mlp_up_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.mlp.up_proj.weight']).T + mlp_up_base_name = f'model.layers.{l}.mlp.up_proj.weight' + split_and_convert_process(saved_dir, factor, mlp_up_base_name, mlp_up_weight) + + # LayerNorm + input_ln_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.input_layernorm.weight']) + input_ln_base_name = f'model.layers.{l}.input_layernorm.weight' + split_and_convert_process(saved_dir, factor, input_ln_base_name, input_ln_weight) + + post_attn_ln_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.post_attention_layernorm.weight']) + post_attn_ln_base_name = f'model.layers.{l}.post_attention_layernorm.weight' + split_and_convert_process(saved_dir, factor, post_attn_ln_base_name, post_attn_ln_weight) + + print(f"done layer {l}") + + + # final common weights + for name, param in model.named_parameters(): + if name == 'model.embed_tokens.weight': + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.wte.weight.bin") + elif name == 'model.norm.weight': + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.final_layernorm.weight.bin") + elif name == 'lm_head.weight': + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.lm_head.weight.bin") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument('-saved_dir', '-o', type=str, help='file name of output file', required=True) + parser.add_argument('-in_file', '-i', type=str, help='file name of input checkpoint file', required=True) + parser.add_argument('-trained_gpu_num', '-t_g', type=int, help='How many gpus for inference', default=1) + parser.add_argument('-infer_gpu_num', '-i_g', type=int, help='How many gpus for inference', required=True) + parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16", "bf16"]) + parser.add_argument('-model_name', '-m_n', type=str, help='model name', required=True) + + args = parser.parse_args() + print("\n=============== Argument 
===============") + for key in vars(args): + print("{}: {}".format(key, vars(args)[key])) + print("========================================") + + split_and_convert(args) diff --git a/examples/cpp/llama/huggingface_llama_convert2.py b/examples/cpp/llama/huggingface_llama_convert2.py new file mode 100644 index 000000000..fe38238ff --- /dev/null +++ b/examples/cpp/llama/huggingface_llama_convert2.py @@ -0,0 +1,259 @@ +# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import configparser +import numpy as np +from pathlib import Path + +import torch +import os +from transformers import LlamaForCausalLM, AutoConfig +# using numpy extension: https://github.com/GreenWaves-Technologies/bfloat16 +# install the library with `pip install bfloat16` +from bfloat16 import bfloat16 + +def get_weight_data_type(data_type): + if data_type == "fp32": + return np.float32 + elif data_type == "fp16": + return np.float16 + elif data_type == "bf16": + return bfloat16 + else: + assert False, f"Invalid weight data type {data_type}" + + +def split_and_convert_process(saved_dir, factor, key, val): + if key.find("input_layernorm.weight") != -1 or key.find("post_attention_layernorm.weight") != -1: + # shared weights, only need to convert the weights of rank 0 + saved_path = saved_dir + "/" + key + ".bin" + val.tofile(saved_path) + elif key.find("attention.dense.weight") != -1 or key.find("mlp.down_proj.weight") != -1: + split_vals = np.split(val, factor, axis=0) + for j in range(factor): + saved_path = saved_dir + "/" + key + ".%d.bin" % j + split_vals[j].tofile(saved_path) + elif key.find("mlp.gate_proj.weight") != -1 or key.find("mlp.up_proj.weight") != -1: + split_vals = np.split(val, factor, axis=-1) + for j in range(factor): + saved_path = saved_dir + "/" + key + ".%d.bin" % j + split_vals[j].tofile(saved_path) + elif key.find("attention.query_key_value.weight") != -1: + split_vals = np.split(val, factor, axis=-1) + for j in range(factor): + saved_path = saved_dir + "/" + key + ".%d.bin" % j + split_vals[j].tofile(saved_path) + else: + print("[ERROR] cannot find key '{}'".format(key)) + +def split_and_convert(args): + saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_gpu_num + + if(os.path.exists(saved_dir) == False): + os.makedirs(saved_dir) + + t_gpu_num = args.trained_gpu_num + i_gpu_num = args.infer_gpu_num + assert(i_gpu_num % t_gpu_num == 0) + + factor = (int)(i_gpu_num / t_gpu_num) + print(f'load model from {args.in_file}') + # model = LlamaForCausalLM.from_pretrained(args.in_file, device_map='auto') + config = AutoConfig.from_pretrained(args.in_file) + # num_layers = 3 + # config.num_hidden_layers = num_layers + + hf_config = vars(config) + print(f"hf_config: {hf_config}") + + hidden_size = hf_config["hidden_size"] + head_num = hf_config["num_attention_heads"] + kv_head_num = hf_config["num_key_value_heads"] + head_size = hidden_size // head_num + # num_layers = hf_config["num_hidden_layers"] + + + np_weight_data_type = 
get_weight_data_type(args.weight_data_type) + + try: + model_name = args.model_name + config = configparser.ConfigParser() + config['llama'] = {} + config['llama']['model_name'] = model_name + config['llama']["head_num"] = str(head_num) + config['llama']["kv_head_num"] = str(kv_head_num) + config['llama']["size_per_head"] = str(head_size) + config['llama']["inter_size"] = str(hf_config["intermediate_size"]) + config['llama']["num_layer"] = str(hf_config["num_hidden_layers"]) + config['llama']["rotary_embedding"] = str(head_size) + config['llama']['layernorm_eps'] = str(hf_config["rms_norm_eps"]) + config['llama']["vocab_size"] = str(hf_config["vocab_size"]) + config['llama']["start_id"] = str(hf_config["bos_token_id"]) + config['llama']["end_id"] = str(hf_config["eos_token_id"]) + config['llama']["weight_data_type"] = args.weight_data_type + + with open((Path(saved_dir) / f"config.ini").as_posix(), 'w') as configfile: + config.write(configfile) + except Exception as e: + print(f"Fail to save the config in config.ini.") + print(e) + + param_to_weights = lambda param: param.detach().cpu().float().numpy().astype(np_weight_data_type) + + def get_param(key, cache, loaded): + if key in cache: + return param_to_weights(cache[key]) + if key in loaded: + return param_to_weights(loaded[key]) + return None + + def clear_param(key, cache, loaded): + if key in cache: + del cache[key] + if key in loaded: + del loaded[key] + + def try_dump(key, cache, loaded, save_name, saved_dir, factor, transpose=True): + weight = get_param(key, cache, loaded) + if weight is None: + return + if transpose: + weight = weight.T + split_and_convert_process(saved_dir, factor, save_name, weight) + clear_param(key, state_dict, w) + # layer-wise weights, example: + # - model.layers.0.self_attn.q_proj.weight + # - model.layers.0.self_attn.k_proj.weight + # - model.layers.0.self_attn.v_proj.weight + # - model.layers.0.self_attn.o_proj.weight + # - model.layers.0.mlp.gate_proj.weight + # - model.layers.0.mlp.down_proj.weight + # - model.layers.0.mlp.up_proj.weight + # - model.layers.0.input_layernorm.weight + # - model.layers.0.post_attention_layernorm.weight + state_dict = {} + for f in os.listdir(args.in_file): + if not f.endswith('.bin'): + continue + f = os.path.join(args.in_file, f) + print(f'processing {f}') + w = torch.load(f, map_location='cpu') + for l in range(hf_config["num_hidden_layers"]): + # first merge QKV into a single weight + # concat direct to FT shape: [hidden_size, 3, head_num, head_size] + # copied from huggingface_gptj_ckpt_convert.py + # qkv_weights = np.stack([ + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight']), + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight']), + # param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight']), + # ]) + # qkv_weights = np.transpose(qkv_weights, (2, 0, 1)) + q_key = f'model.layers.{l}.self_attn.q_proj.weight' + k_key = f'model.layers.{l}.self_attn.k_proj.weight' + v_key = f'model.layers.{l}.self_attn.v_proj.weight' + q_proj = get_param(q_key, state_dict, w) + k_proj = get_param(k_key, state_dict, w) + v_proj = get_param(v_key, state_dict, w) + + if q_proj is not None and k_proj is not None and v_proj is not None: + q_proj = np.split(q_proj, factor, axis=0) + k_proj = np.split(k_proj, factor, axis=0) + v_proj = np.split(v_proj, factor, axis=0) + for j in range(factor): + qkv_weights = np.concatenate((q_proj[j], k_proj[j], v_proj[j]), axis=0) + qkv_weights = np.transpose(qkv_weights) 
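+                    # HF nn.Linear stores weights as [out_features, in_features]; the FT loader reads
+                    # them as [input_dim, output_dim], hence the transpose before writing each rank's shard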
+ qkv_weights_base_name = f'model.layers.{l}.attention.query_key_value.weight' + saved_path = saved_dir + "/" + qkv_weights_base_name + ".%d.bin" % j + qkv_weights.tofile(saved_path) + clear_param(q_key, state_dict, w) + clear_param(k_key, state_dict, w) + clear_param(v_key, state_dict, w) + + # attention dense + try_dump(key=f'model.layers.{l}.self_attn.o_proj.weight', + cache=state_dict, + loaded=w, + save_name=f'model.layers.{l}.attention.dense.weight', + saved_dir=saved_dir, + factor=factor) + + # MLP + try_dump(key=f'model.layers.{l}.mlp.down_proj.weight', + cache=state_dict, + loaded=w, + save_name=f'model.layers.{l}.mlp.down_proj.weight', + saved_dir=saved_dir, + factor=factor) + try_dump(key=f'model.layers.{l}.mlp.gate_proj.weight', + cache=state_dict, + loaded=w, + save_name=f'model.layers.{l}.mlp.gate_proj.weight', + saved_dir=saved_dir, + factor=factor) + try_dump(key=f'model.layers.{l}.mlp.up_proj.weight', + cache=state_dict, + loaded=w, + save_name=f'model.layers.{l}.mlp.up_proj.weight', + saved_dir=saved_dir, + factor=factor) + + # LayerNorm + try_dump(key=f'model.layers.{l}.input_layernorm.weight', + cache=state_dict, + loaded=w, + save_name=f'model.layers.{l}.input_layernorm.weight', + saved_dir=saved_dir, + factor=factor, + transpose=False) + try_dump(key=f'model.layers.{l}.post_attention_layernorm.weight', + cache=state_dict, + loaded=w, + save_name=f'model.layers.{l}.post_attention_layernorm.weight', + saved_dir=saved_dir, + factor=factor, + transpose=False) + to_del = [] + for name, param in w.items(): + if name == 'model.embed_tokens.weight': + param.detach().cpu().float().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.wte.weight.bin") + elif name == 'model.norm.weight': + param.detach().cpu().float().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.final_layernorm.weight.bin") + elif name == 'lm_head.weight': + param.detach().cpu().float().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.lm_head.weight.bin") + else: + continue + to_del.append(name) + # for k in to_del: + # del w[k] + print(w.keys()) + state_dict.update(w) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument('-saved_dir', '-o', type=str, help='file name of output file', required=True) + parser.add_argument('-in_file', '-i', type=str, help='file name of input checkpoint file', required=True) + parser.add_argument('-trained_gpu_num', '-t_g', type=int, help='How many gpus for inference', default=1) + parser.add_argument('-infer_gpu_num', '-i_g', type=int, help='How many gpus for inference', required=True) + parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16", "bf16"]) + parser.add_argument('-model_name', '-m_n', type=str, help='model name', required=True) + + args = parser.parse_args() + print("\n=============== Argument ===============") + for key in vars(args): + print("{}: {}".format(key, vars(args)[key])) + print("========================================") + + split_and_convert(args) diff --git a/examples/cpp/llama/llama_config.ini b/examples/cpp/llama/llama_config.ini new file mode 100644 index 000000000..ef789d35d --- /dev/null +++ b/examples/cpp/llama/llama_config.ini @@ -0,0 +1,34 @@ +[ft_instance_hyperparameter] +data_type=fp16 +enable_custom_all_reduce=0 + +tensor_para_size=1 +pipeline_para_size=1 + +model_name=llama_7b +model_dir=/notebooks/llama-2-70b-hf-ft-tp-1_llama_decoder/1/1-gpu/ + +[request] +beam_width=1 # beam width for beam 
search +top_k=1 ; k value for top k sampling +top_p=0.0 ; p value for top p sampling +temperature=1.0 ; Use for sampling +repetition_penalty=1.0 ; Use for sampling +presence_penalty=0.0 ; Only one of repetition_penalty and presence_penalty are allowed. +len_penalty=0.0 +beam_search_diversity_rate=0.0 +request_batch_size=8 # determine by the request +request_output_len=32 # determine by the request + +[llama_7b] +head_num = 64 +kv_head_num = 8 +size_per_head = 128 +inter_size = 28672 +num_layer = 3 +rotary_embedding = 128 +layernorm_eps = 1e-05 +vocab_size = 32000 +start_id = 1 +end_id = 2 +weight_data_type = fp16 diff --git a/examples/cpp/llama/llama_example.cc b/examples/cpp/llama/llama_example.cc new file mode 100644 index 000000000..84a0b54aa --- /dev/null +++ b/examples/cpp/llama/llama_example.cc @@ -0,0 +1,542 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "3rdparty/INIReader.h" +#include "examples/cpp/multi_gpu_gpt/gpt_example_utils.h" +#include "src/fastertransformer/models/llama/Llama.h" +#include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" +#include "src/fastertransformer/utils/nvtx_utils.h" +#include "src/fastertransformer/utils/word_list.h" + +#include +#include +#include +#include +#include +#include + +using namespace fastertransformer; + +template +void llama_example(const INIReader reader); + +int main(int argc, char* argv[]) +{ + fastertransformer::mpi::initialize(&argc, &argv); + srand(0); + + std::string ini_name; + if (argc == 2) { + ini_name = std::string(argv[1]); + } + else { + ini_name = "/notebooks/FasterTransformer/examples/cpp/llama/llama_config.ini"; + } + + INIReader reader = INIReader(ini_name); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << ini_name << "'\n"; + return -1; + } + const std::string data_type = reader.Get("ft_instance_hyperparameter", "data_type"); + + if (data_type == "fp32") { + llama_example(reader); + } + else if (data_type == "fp16") { + llama_example(reader); + } +#ifdef ENABLE_BF16 + else if (data_type == "bf16") { + llama_example<__nv_bfloat16>(reader); + } +#endif + else { + FT_LOG_ERROR("is_fp16 should be 0 (use float) or 1 (use half)."); + return -1; + } + mpi::finalize(); + return 0; +} + +template +void llama_example(const INIReader reader) +{ + const std::string model_name = reader.Get("ft_instance_hyperparameter", "model_name"); + std::string model_dir = std::string(reader.Get("ft_instance_hyperparameter", "model_dir")); + + int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size"); + int pipeline_para_size = reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size"); + int int8_mode = reader.GetInteger("ft_instance_hyperparameter", "int8_mode", 0); + + const size_t head_num = reader.GetInteger(model_name, "head_num"); + const size_t kv_head_num = reader.GetInteger(model_name, "kv_head_num"); + const size_t 
size_per_head = reader.GetInteger(model_name, "size_per_head"); + const size_t vocab_size = reader.GetInteger(model_name, "vocab_size"); + const size_t decoder_layers = reader.GetInteger(model_name, "num_layer"); + const size_t rotary_embedding_dim = reader.GetInteger(model_name, "rotary_embedding"); + const float layernorm_eps = reader.GetFloat(model_name, "layernorm_eps"); + const int start_id = reader.GetInteger(model_name, "start_id"); + const int end_id = reader.GetInteger(model_name, "end_id"); + + const size_t hidden_units = head_num * size_per_head; + const size_t inter_size = reader.GetInteger(model_name, "inter_size"); + + const size_t beam_width = reader.GetInteger("request", "beam_width"); + const uint top_k = (uint)reader.GetInteger("request", "top_k"); + const float top_p = reader.GetFloat("request", "top_p"); + const float temperature = reader.GetFloat("request", "temperature"); + const float repetition_penalty = reader.GetFloat("request", "repetition_penalty", 1.0f); + const float presence_penalty = reader.GetFloat("request", "presence_penalty", 0.0f); + const float len_penalty = reader.GetFloat("request", "len_penalty"); + const float beam_search_diversity_rate = reader.GetFloat("request", "beam_search_diversity_rate"); + const int min_length = reader.GetInteger("request", "min_length", 0); + const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); + // The length of tokens we hope this model to generate + const int request_output_len = reader.GetInteger("request", "request_output_len"); + + FT_CHECK(head_num % tensor_para_size == 0); + FT_CHECK(decoder_layers % pipeline_para_size == 0); + FT_CHECK_WITH_INFO( + repetition_penalty == 1.0f || presence_penalty == 0.0f, + fmtstr("Found ambiguous parameters repetition_penalty (%f) and presence_penalty (%f) " + "which are mutually exclusive. 
Please remove one of repetition_penalty or presence_penalty " + "or set to a default value.", + repetition_penalty, + presence_penalty)); + + // Prepare the parallelism parameters + int rank = mpi::getCommWorldRank(); + int world_size = mpi::getCommWorldSize(); + // world_size = 4; + if (rank == 0) { + printf("Total ranks: %d.\n", world_size); + } + int device, device_count; + check_cuda_error(cudaGetDeviceCount(&device_count)); + check_cuda_error(cudaSetDevice(rank % device_count)); + check_cuda_error(cudaGetDevice(&device)); + + struct cudaDeviceProp prop; + check_cuda_error(cudaGetDeviceProperties(&prop, device)); + printf("Device %s\n", prop.name); + + printf("P%d is running with GPU #%d.\n", rank, device); + if (tensor_para_size * pipeline_para_size != world_size) { + if (world_size % pipeline_para_size) { + printf("[ERROR] tensor_para_size * pipeline_para_size should equal to world_size \n"); + exit(-1); + } + tensor_para_size = world_size / pipeline_para_size; + printf("[INFO] Setting tensor_para_size to %d \n", tensor_para_size); + } + + const int layers_per_group = decoder_layers / pipeline_para_size; + if (layers_per_group * pipeline_para_size != (int)decoder_layers) { + printf("[ERROR] layers_per_group (%d) * pipeline_para_size (%d) should equal to decoder_layers (%ld) \n", + layers_per_group, + pipeline_para_size, + decoder_layers); + exit(-1); + } + + // assume gpu_num = k * n, + // tensor parallelism group size is n + // pipeline parallelism group size is k + NcclParam tensor_para; + NcclParam pipeline_para; + ftNcclInitialize(tensor_para, pipeline_para, tensor_para_size, pipeline_para_size); + + // Handle bad_words dictionary + std::vector bad_words; + read_word_list("/notebooks/FasterTransformer/examples/cpp/llama/bad_words.csv", bad_words); + + int* d_bad_words = nullptr; + deviceMalloc(&d_bad_words, bad_words.size(), false); + cudaH2Dcpy(d_bad_words, bad_words.data(), bad_words.size()); + + // Handle stop_words dictionary + std::vector stop_words; + read_word_list("/notebooks/FasterTransformer/examples/cpp/llama/stop_words.csv", stop_words); + + const size_t stop_words_len = stop_words.size() / 2; + // Tile with same dict for each element + std::vector tiled_stop_words; + for (int i = 0; i < request_batch_size; i++) { + tiled_stop_words.insert(tiled_stop_words.end(), stop_words.begin(), stop_words.end()); + } + + + int* d_stop_words = nullptr; + deviceMalloc(&d_stop_words, tiled_stop_words.size(), false); + cudaH2Dcpy(d_stop_words, tiled_stop_words.data(), tiled_stop_words.size()); + + // Read ids of request from file. + size_t max_input_len = -1; + std::vector v_start_lengths; + std::vector v_start_ids; + read_start_ids(request_batch_size, + &v_start_lengths, + &v_start_ids, + max_input_len, + end_id, + 1, + "/notebooks/FasterTransformer/examples/cpp/llama/start_ids.csv"); + + + int* d_input_ids; + int* d_input_lengths; + if (max_input_len == 0) { + // unconditional case, no input ids, so do nothing. + d_input_ids = nullptr; + d_input_lengths = nullptr; + } + else { + // conditional case. 
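+        // copy the padded prompt tokens (row-major [request_batch_size, max_input_len]) and the
+        // per-request prompt lengths onto the GPU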
+ deviceMalloc(&d_input_ids, request_batch_size * max_input_len, false); + deviceMalloc(&d_input_lengths, request_batch_size, false); + cudaH2Dcpy(d_input_ids, v_start_ids.data(), request_batch_size * max_input_len); + cudaH2Dcpy(d_input_lengths, v_start_lengths.data(), request_batch_size); + } + std::vector start_ids(request_batch_size, start_id); + std::vector end_ids(request_batch_size, end_id); + + // Prompt Learning Configurations + // NOTE: if you don't need prefix prompts, remember to set max_prefix_len to 0 and others to nullptr + int prompt_learning_start_id = reader.GetInteger(model_name, "prompt_learning_start_id", end_id + 1); + fastertransformer::PromptLearningType prompt_learning_type = + static_cast(reader.GetInteger(model_name, "prompt_learning_type", 0)); + + // NOTE: specify task names, take name id, prompt length in order to load those prompt learning tables. + // NOTE: Please make sure task ids are continuous and start from 0 + // for example: + // std::map> prefix_prompt_table_pair{{"no_prompt", {0, 0}}, + // {"prompt_1", {1, 1}}, + // {"prompt_2", {2, 2}}, + // {"prompt_3", {3, 3}}, + // {"prompt_4", {4, 4}}, + // {"prompt_5", {5, 5}}}; + + std::map> prefix_prompt_table_pair; + + // NOTE: get prompt table pairs from configuration files + const int num_tasks = reader.GetInteger(model_name, "num_tasks", 0); + for (int task_name_id = 0; task_name_id < num_tasks; task_name_id++) { + std::string config_task_name = model_name + "_task_" + std::to_string(task_name_id); + std::string task_name = reader.Get(config_task_name, "task_name"); + const int prompt_length = reader.GetInteger(config_task_name, "prompt_length", 0); + prefix_prompt_table_pair.insert({task_name, {task_name_id, prompt_length}}); + } + + // NOTE: task_name_ids for each sequence in one batch + // Each sequence can have different prompt learning task ids + std::vector prefix_prompt_task_ids(request_batch_size, 0); + + // Set different task ids + for (int i = 0; i < request_batch_size; i++) { + prefix_prompt_task_ids[i] = (num_tasks > 0) ? 
i % num_tasks : 0; + } + + const int total_output_len = max_input_len + request_output_len; + + cudaStream_t stream; + cublasHandle_t cublas_handle; + cublasLtHandle_t cublaslt_handle; + cudaStreamCreate(&stream); + cublasCreate(&cublas_handle); + cublasLtCreate(&cublaslt_handle); + cublasSetStream(cublas_handle, stream); + cublasAlgoMap* cublas_algo_map = new cublasAlgoMap("gemm_config.in"); + + Allocator allocator(getDevice()); + + std::mutex* cublas_wrapper_mutex = new std::mutex(); + cublasMMWrapper cublas_wrapper = + cublasMMWrapper(cublas_handle, cublaslt_handle, stream, cublas_algo_map, cublas_wrapper_mutex, &allocator); + if (std::is_same::value) { + cublas_wrapper.setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); + } +#ifdef ENABLE_BF16 + else if (std::is_same::value) { + cublas_wrapper.setBF16GemmConfig(); + } +#endif + else if (std::is_same::value) { + cublas_wrapper.setFP32GemmConfig(); + } + + const bool use_gptj_residual = false; + printf("kv_head_num: %d\n", kv_head_num); + fastertransformer::LlamaWeight gpt_weights(head_num, + kv_head_num, + size_per_head, + inter_size, + vocab_size, + decoder_layers, + 0, // max_seq_len, deprecated + tensor_para.world_size_, + tensor_para.rank_, + pipeline_para.world_size_, + pipeline_para.rank_, + use_gptj_residual, + int8_mode, + prompt_learning_type, + prefix_prompt_table_pair); + + gpt_weights.loadModel(model_dir); + unsigned long long random_seed; + if (rank == 0) { + random_seed = (unsigned long long)(0); + } + if (world_size > 1) { + mpi::bcast(&random_seed, 1, mpi::MPI_TYPE_UNSIGNED_LONG_LONG, 0, mpi::COMM_WORLD); + } + + AttentionType attention_type = getAttentionType(size_per_head, + getSMVersion(), + true, // remove_padding + 0, // gpt supports any-seq-length fmha + true, // is_fuse + false, // with_relative_position_bias + true); // causal_mask + + Llama gpt = Llama(head_num, + kv_head_num, + size_per_head, + inter_size, + decoder_layers, + vocab_size, + rotary_embedding_dim, + layernorm_eps, + start_id, + end_id, + prompt_learning_start_id, + prompt_learning_type, + use_gptj_residual, + 0.0f, + top_k, + top_p, + random_seed, + temperature, + len_penalty, + repetition_penalty, + tensor_para, + pipeline_para, + stream, + &cublas_wrapper, + &allocator, + false, + &prop, + attention_type, + int8_mode, + nullptr, + 0, + 1.0f); + + int* d_output_ids; + int* d_sequence_lengths; + + + deviceMalloc(&d_output_ids, request_batch_size * beam_width * total_output_len, false); + deviceMalloc(&d_sequence_lengths, request_batch_size * beam_width, false); + + std::vector output_seq_len(request_batch_size, total_output_len); + std::unordered_map input_tensors = std::unordered_map{ + {"input_ids", + Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size, (size_t)max_input_len}, d_input_ids}}, + {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size}, d_input_lengths}}, + // NOTE: if you need prefix prompts, remember to add prefix_prompt_task_ids here + // {"prompt_learning_task_name_ids", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{request_batch_size}, + // prefix_prompt_task_ids.data()}}, + {"output_seq_len", + Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{request_batch_size}, output_seq_len.data()}}, + {"bad_words_list", Tensor{MEMORY_GPU, TYPE_INT32, {2, bad_words.size() / 2}, d_bad_words}}, + {"stop_words_list", Tensor{MEMORY_GPU, TYPE_INT32, {request_batch_size, 2, stop_words_len}, d_stop_words}}, + {"temperature", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &temperature}}, + {"len_penalty", 
Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &len_penalty}}, + {"min_length", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{1}, &min_length}}, + {"start_id", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{request_batch_size}, start_ids.data()}}, + {"end_id", Tensor{MEMORY_CPU, TYPE_INT32, std::vector{request_batch_size}, end_ids.data()}}}; + + if (repetition_penalty != 1.0f) { + input_tensors.insert( + {"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &repetition_penalty}}); + } + if (presence_penalty != 0.0f) { + input_tensors.insert( + {"presence_penalty", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &presence_penalty}}); + } + + if (num_tasks > 0) { + // Prefix Prompt Task Name Ids here + input_tensors.insert( + {"prompt_learning_task_name_ids", + Tensor{MEMORY_CPU, TYPE_INT32, std::vector{request_batch_size}, prefix_prompt_task_ids.data()}}); + } + + if (top_k == 0 && top_p == 0.0f) { + FT_CHECK(beam_width > 1); + input_tensors.insert({"beam_search_diversity_rate", + Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &beam_search_diversity_rate}}); + } + else { + input_tensors.insert({"random_seed", Tensor{MEMORY_CPU, TYPE_UINT64, std::vector{1}, &random_seed}}); + if (top_p != 0.0f) { + input_tensors.insert({"runtime_top_p", Tensor{MEMORY_CPU, TYPE_FP32, std::vector{1}, &top_p}}); + } + if (top_k != 0) { + input_tensors.insert({"runtime_top_k", Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{1}, &top_k}}); + } + } + + std::unordered_map output_tensors = std::unordered_map{ + {"output_ids", + Tensor{MEMORY_GPU, + TYPE_INT32, + std::vector{request_batch_size, beam_width, (size_t)total_output_len}, + d_output_ids}}, + {"sequence_length", + Tensor{MEMORY_GPU, TYPE_INT32, std::vector{request_batch_size, beam_width}, d_sequence_lengths}}, + {"output_log_probs", + Tensor{MEMORY_GPU, + TYPE_FP32, + std::vector{(size_t)request_output_len, request_batch_size, beam_width}, + nullptr}}}; + + print_mem_usage(); + + int ite = 1; + cudaDeviceSynchronize(); + mpi::barrier(); + + cudaProfilerStart(); + // warm up + ite = 1; + ft_nvtx::setScope("warmup_time"); + PUSH_RANGE("warmup time") + + for (int i = 0; i < ite; ++i) { + gpt.forward(&output_tensors, &input_tensors, &gpt_weights); + } + + cudaDeviceSynchronize(); + mpi::barrier(); + + POP_RANGE; + ft_nvtx::resetScope(); + + + if (rank == 0) { + + std::string fName = "out"; + auto outFile = std::ofstream(fName, std::ios::out); + if (!outFile.is_open()) { + printf("[WARNING] Cannot write results into output file %s \n", fName.c_str()); + } + else { + size_t outCount = total_output_len * request_batch_size * beam_width; + int* hBuf = new int[outCount]; + + size_t seqLCount = request_batch_size * beam_width; + int* seqlBuf = new int[seqLCount]; + + size_t inLCount = request_batch_size * beam_width; + int* inlBuf = new int[inLCount]; + + cudaD2Hcpy(hBuf, d_output_ids, outCount); + cudaD2Hcpy(seqlBuf, d_sequence_lengths, seqLCount); + cudaD2Hcpy(inlBuf, d_sequence_lengths, seqLCount); + printf("seqlBuf: %d\n", seqlBuf[0]); + + { + std::cout << "Writing " << outCount << " elements\n"; + int zeroCount = 0; + for (size_t i = 0; i < outCount; i++) { + if (hBuf[i] == int(0)) { + zeroCount++; + } + outFile << hBuf[i] << " "; + if ((i + 1) % (total_output_len) == 0) { + outFile << std::endl; + } + printf("%5d ", hBuf[i]); + // if (i < 10) { + // printf("%5d ", hBuf[i]); + // } + if ((i + 1) % (total_output_len) == 0) { + std::cout << std::endl; + } + } + std::cout << std::endl << "zeroCount = " << zeroCount << std::endl; + } + delete[] hBuf; + } + } + 
return; + // test time + struct timeval start, end; + mpi::barrier(); + cudaDeviceSynchronize(); + gettimeofday(&start, NULL); + + ft_nvtx::setScope("total_time"); + PUSH_RANGE("total time") + for (int i = 0; i < ite; ++i) { + gpt.forward(&output_tensors, &input_tensors, &gpt_weights); + } + cudaDeviceSynchronize(); + mpi::barrier(); + + POP_RANGE; + ft_nvtx::resetScope(); + gettimeofday(&end, NULL); + + cudaProfilerStop(); + + printf("[INFO] request_batch_size %ld beam_width %ld head_num %ld size_per_head %ld total_output_len %d" + " decoder_layers %ld vocab_size %ld FT-CPP-decoding-beamsearch-time %.2f ms\n", + request_batch_size, + beam_width, + head_num, + size_per_head, + total_output_len, + decoder_layers, + vocab_size, + ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); + + ftNcclParamDestroy(tensor_para); + ftNcclParamDestroy(pipeline_para); + + delete cublas_algo_map; + delete cublas_wrapper_mutex; + + cudaFree(d_bad_words); + cudaFree(d_stop_words); + if (d_input_ids != nullptr) { + cudaFree(d_input_ids); + } + if (d_input_lengths != nullptr) { + cudaFree(d_input_lengths); + } + if (d_output_ids != nullptr) { + deviceFree(d_output_ids); + } + if (d_sequence_lengths != nullptr) { + deviceFree(d_sequence_lengths); + } + return; +} diff --git a/examples/cpp/llama/llama_triton_example.cc b/examples/cpp/llama/llama_triton_example.cc new file mode 100644 index 000000000..3df2a2203 --- /dev/null +++ b/examples/cpp/llama/llama_triton_example.cc @@ -0,0 +1,457 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "3rdparty/INIReader.h" +#include "examples/cpp/multi_gpu_gpt/gpt_example_utils.h" +#include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h" +#include "src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" +#include "src/fastertransformer/utils/nvtx_utils.h" +#include "src/fastertransformer/utils/word_list.h" + +#include +#include + +namespace ft = fastertransformer; + +struct RequestParam { + int beam_width; + int request_output_len; + float beam_search_diversity_rate; + uint runtime_top_k; + float runtime_top_p; + float temperature; + float len_penalty; + float repetition_penalty; + float presence_penalty; + int min_length; + unsigned long long int random_seed; + int start_id; + int end_id; +}; + +std::vector>> +broadCastRequest(const std::vector& v_start_ids, + const std::vector& v_start_lengths, + const std::vector& v_bad_words, + const int node_id, + const int gpu_count, + const RequestParam param, + std::vector* pointer_record) +{ + // broadcast the request to all nodes, and copy "gpu_count" copies on different gpu + int size_1 = v_start_ids.size(); + int size_2 = v_start_lengths.size(); + int size_bad_words = v_bad_words.size(); + ft::mpi::bcast(&size_1, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(&size_2, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(&size_bad_words, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + + std::vector v_input_ids(size_1); + std::vector v_input_lengths(size_2); + std::vector v_input_bad_words(size_bad_words); + + if (node_id == 0) { + memcpy(v_input_ids.data(), v_start_ids.data(), size_1 * sizeof(int)); + memcpy(v_input_lengths.data(), v_start_lengths.data(), size_2 * sizeof(int)); + memcpy(v_input_bad_words.data(), v_bad_words.data(), size_bad_words * sizeof(int)); + } + ft::mpi::barrier(); + + int request_batch_size = size_2; + int max_input_len = size_1 / size_2; + + ft::mpi::bcast(v_input_ids.data(), size_1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(v_input_lengths.data(), size_2, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(v_input_bad_words.data(), size_bad_words, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + + std::vector>> request_list; + for (int device_id = 0; device_id < gpu_count; device_id++) { + ft::check_cuda_error(cudaSetDevice(device_id)); + + int* d_input_ids; + int* d_input_lengths; + int* d_input_bad_words; + + if (max_input_len == 0) { + // unconditional case, no input ids, so do nothing. + d_input_ids = nullptr; + d_input_lengths = nullptr; + max_input_len = 0; + } + else { + // conditional case. 
+ ft::deviceMalloc(&d_input_ids, size_1, false); + ft::deviceMalloc(&d_input_lengths, size_2, false); + ft::cudaH2Dcpy(d_input_ids, v_input_ids.data(), size_1); + ft::cudaH2Dcpy(d_input_lengths, v_input_lengths.data(), size_2); + } + ft::deviceMalloc(&d_input_bad_words, size_bad_words, false); + ft::cudaH2Dcpy(d_input_bad_words, v_input_bad_words.data(), size_bad_words); + + uint32_t* request_output_len_ptr = (uint32_t*)malloc(request_batch_size * sizeof(uint32_t)); + for (int i = 0; i < request_batch_size; i++) { + request_output_len_ptr[i] = param.request_output_len; + } + + int* start_ids_ptr = (int*)malloc(request_batch_size * sizeof(int)); + int* end_ids_ptr = (int*)malloc(request_batch_size * sizeof(int)); + for (int i = 0; i < request_batch_size; i++) { + start_ids_ptr[i] = param.start_id; + end_ids_ptr[i] = param.end_id; + } + pointer_record->push_back(start_ids_ptr); + pointer_record->push_back(end_ids_ptr); + + request_list.push_back(std::shared_ptr>( + new std::unordered_map{ + {"input_ids", + triton::Tensor{triton::MEMORY_GPU, + triton::TYPE_INT32, + std::vector{(size_t)request_batch_size, (size_t)max_input_len}, + d_input_ids}}, + {"input_lengths", + triton::Tensor{triton::MEMORY_GPU, + triton::TYPE_INT32, + std::vector{(size_t)request_batch_size}, + d_input_lengths}}, + {"request_output_len", + triton::Tensor{triton::MEMORY_CPU, + triton::TYPE_INT32, + std::vector{(size_t)request_batch_size}, + request_output_len_ptr}}, + {"bad_words_list", + triton::Tensor{ + triton::MEMORY_GPU, triton::TYPE_INT32, {2, v_input_bad_words.size() / 2}, d_input_bad_words}}, + {"start_id", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, {(size_t)request_batch_size}, start_ids_ptr}}, + {"end_id", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, {(size_t)request_batch_size}, end_ids_ptr}}})); + + int* beam_width_ptr = new int(param.beam_width); + pointer_record->push_back(beam_width_ptr); + request_list[device_id]->insert( + {"beam_width", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, std::vector{1}, beam_width_ptr}}); + if (param.beam_width > 1) { + float* beam_search_diversity_rate_ptr = new float(param.beam_search_diversity_rate); + pointer_record->push_back(beam_search_diversity_rate_ptr); + request_list[device_id]->insert( + {"beam_search_diversity_rate", + triton::Tensor{ + triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, beam_search_diversity_rate_ptr}}); + } + else { + if (param.runtime_top_p != 0.0f) { + float* runtime_top_p_ptr = new float(param.runtime_top_p); + pointer_record->push_back(runtime_top_p_ptr); + request_list[device_id]->insert( + {"runtime_top_p", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, runtime_top_p_ptr}}); + } + if (param.runtime_top_k != 0) { + uint* runtime_top_k_ptr = new uint(param.runtime_top_k); + pointer_record->push_back(runtime_top_k_ptr); + request_list[device_id]->insert( + {"runtime_top_k", + triton::Tensor{ + triton::MEMORY_CPU, triton::TYPE_UINT32, std::vector{1}, runtime_top_k_ptr}}); + } + } + float* temperature_ptr = new float(param.temperature); + pointer_record->push_back(temperature_ptr); + request_list[device_id]->insert( + {"temperature", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, temperature_ptr}}); + float* len_penalty_ptr = new float(param.len_penalty); + pointer_record->push_back(len_penalty_ptr); + request_list[device_id]->insert( + {"len_penalty", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, len_penalty_ptr}}); + if 
(param.repetition_penalty != 1.0f) { + float* repetition_penalty_ptr = new float(param.repetition_penalty); + pointer_record->push_back(repetition_penalty_ptr); + request_list[device_id]->insert( + {"repetition_penalty", + triton::Tensor{ + triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, repetition_penalty_ptr}}); + } + if (param.presence_penalty != 0.0f) { + float* presence_penalty_ptr = new float(param.presence_penalty); + pointer_record->push_back(presence_penalty_ptr); + request_list[device_id]->insert( + {"presence_penalty", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, presence_penalty_ptr}}); + } + int* min_length_ptr = new int(param.min_length); + pointer_record->push_back(min_length_ptr); + request_list[device_id]->insert( + {"min_length", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, std::vector{1}, min_length_ptr}}); + unsigned long long int* random_seed_ptr = new unsigned long long int(param.random_seed); + pointer_record->push_back(random_seed_ptr); + request_list[device_id]->insert( + {"random_seed", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_UINT64, std::vector{1}, random_seed_ptr}}); + + pointer_record->push_back(d_input_ids); + pointer_record->push_back(d_input_lengths); + pointer_record->push_back(d_input_bad_words); + pointer_record->push_back(request_output_len_ptr); + } + + return request_list; +} + +std::vector>> +prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std::vector* pointer_record) +{ + INIReader reader = INIReader(ini_name); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << ini_name << "'\n"; + ft::FT_CHECK(false); + } + + const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); + + const int start_id = reader.GetInteger("llama_7b", "start_id"); + const int end_id = reader.GetInteger("llama_7b", "end_id"); + + std::vector v_start_ids; + std::vector v_start_lengths; + + size_t max_input_len = 0; + ft::read_start_ids(request_batch_size, + &v_start_lengths, + &v_start_ids, + max_input_len, + end_id, + 1, + "../examples/cpp/llama/start_ids.csv"); + + std::vector v_bad_words; + ft::read_word_list("../examples/cpp/llama/bad_words.csv", v_bad_words); + + RequestParam param; + param.beam_width = reader.GetInteger("request", "beam_width"); + param.request_output_len = reader.GetInteger("request", "request_output_len"); + param.beam_search_diversity_rate = reader.GetFloat("request", "beam_search_diversity_rate"); + param.runtime_top_k = reader.GetInteger("request", "top_k"); + param.runtime_top_p = reader.GetFloat("request", "top_p"); + param.temperature = reader.GetFloat("request", "temperature"); + param.len_penalty = reader.GetFloat("request", "len_penalty"); + param.repetition_penalty = reader.GetFloat("request", "repetition_penalty", 1.0f); + param.presence_penalty = reader.GetFloat("request", "presence_penalty", 0.0f); + param.min_length = reader.GetInteger("request", "min_length", 0); + param.random_seed = (unsigned long long int)0; + param.start_id = start_id; + param.end_id = end_id; + + auto request_list = + broadCastRequest(v_start_ids, v_start_lengths, v_bad_words, node_id, gpu_count, param, pointer_record); + return request_list; +} + +int threadCreateModelInstances(std::shared_ptr model, + std::vector>* model_instances, + const int device_id, + const int rank, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm = nullptr) +{ + printf("[INFO] rank = %d \n", rank); + 
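+    // bind this worker thread to its GPU, create the rank's shared weights, then build a model
+    // instance that runs on its own CUDA stream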
ft::check_cuda_error(cudaSetDevice(device_id)); + cudaStream_t stream; + ft::check_cuda_error(cudaStreamCreate(&stream)); + model->createSharedWeights(device_id, rank); + auto model_instance = model->createModelInstance(device_id, rank, stream, nccl_params, custom_all_reduce_comm); + model_instances->at(device_id) = std::move(model_instance); + printf("model instance %d is created \n", device_id); + ft::print_mem_usage(); + return 0; +} + +int threadForward(std::unique_ptr* model_instance, + std::shared_ptr> request, + std::shared_ptr>* output_tensors, + const int device_id) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + *output_tensors = (*model_instance)->forward(request); + return 0; +} + +int main(int argc, char* argv[]) +{ + /* + Prepare the nccl ids, node id, device id and world size + by MPI or triton + */ + + // MPICHECK(MPI_Init(&argc, &argv)); + ft::mpi::initialize(&argc, &argv); + int node_id = ft::mpi::getCommWorldRank(); + int node_num = ft::mpi::getCommWorldSize(); + std::cout << "node_id: " << node_id << ", node_num: " << node_num << std::endl; + + // Note: Only supports that all nodes have same gpu count + const int gpu_count = ft::getDeviceCount(); + std::cout << "gpu_count: " << gpu_count << std::endl; + const int world_size = node_num * gpu_count; + std::string ini_name = argc >= 2 ? std::string(argv[1]) : "/notebooks/FasterTransformer/examples/cpp/llama/llama_config.ini"; + + // step 1: Create model + std::shared_ptr model = AbstractTransformerModel::createLlamaModel(ini_name); + int tensor_para_size = model->getTensorParaSize(); + int pipeline_para_size = model->getPipelineParaSize(); + FT_CHECK_WITH_INFO(world_size == (tensor_para_size * pipeline_para_size), + "World Size != Tensor Parallel Size * Pipeline Parallel Size !"); + + std::cout << model->toString(); + + // step 2: Initialize the NCCL + std::pair, std::vector> nccl_comms = model->createNcclParams(node_id); + cudaDeviceSynchronize(); + + // Optional Step: create custom all reduce comm + std::vector> custom_all_reduce_comms; + model->createCustomComms(&custom_all_reduce_comms, world_size); + + // step 3: Create model instances + std::vector> model_instances((size_t)gpu_count); + std::vector threads; + for (int device_id = 0; device_id < gpu_count; device_id++) { + const int rank = node_id * gpu_count + device_id; + threads.push_back(std::thread(threadCreateModelInstances, + model, + &model_instances, + device_id, + rank, + nccl_comms, + custom_all_reduce_comms[rank])); + } + for (auto& t : threads) { + t.join(); + } + + // step 4: prepare request + std::vector pointer_record; // Used to prevent the pointers are release after leaving functions + std::vector>> request_list = + prepareRequest(ini_name, node_id, gpu_count, &pointer_record); + printf("[INFO] request is created \n"); + + // step 5: Forward + std::vector>> output_tensors_lists( + (size_t)gpu_count); + for (int i = 0; i < 2; i++) { + threads.clear(); + for (int device_id = 0; device_id < gpu_count; device_id++) { + threads.push_back(std::thread(threadForward, + &model_instances[device_id], + request_list[device_id], + &output_tensors_lists[device_id], + device_id)); + } + for (auto& t : threads) { + t.join(); + } + } + printf("[INFO] forward is completed. 
\n"); + + const int* d_output_ids = (const int*)output_tensors_lists[0].get()->at("output_ids").data; + const int batch_size = output_tensors_lists[0].get()->at("output_ids").shape[0]; + const int beam_width = output_tensors_lists[0].get()->at("output_ids").shape[1]; + const int seq_len = output_tensors_lists[0].get()->at("output_ids").shape[2]; + const int* d_input_lengths = (const int*)output_tensors_lists[0].get()->at("input_lengths").data; + // step 6: check results + if (node_id == 0) { + + std::string fName = "out"; + auto outFile = std::ofstream(fName, std::ios::out); + if (!outFile.is_open()) { + printf("[WARNING] Cannot write results into output file %s \n", fName.c_str()); + } + else { + size_t outCount = batch_size * beam_width * seq_len; + int* hBuf = new int[outCount]; + int* iBuf = new int[batch_size]; + ft::cudaD2Hcpy(hBuf, d_output_ids, outCount); + ft::cudaD2Hcpy(iBuf, d_input_lengths, batch_size); + + + { + std::cout << "Writing " << outCount << " elements\n"; + int zeroCount = 0; + for (int i=0; i +#include +#include + +template +void groupedquery_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + switch (params.hidden_size_per_head) { + case 32: + mgqa_launch_kernel(params, stream); + break; + case 48: + mgqa_launch_kernel(params, stream); + break; + case 64: + mgqa_launch_kernel(params, stream); + break; + case 80: + mgqa_launch_kernel(params, stream); + break; + case 96: + mgqa_launch_kernel(params, stream); + break; + case 128: + mgqa_launch_kernel(params, stream); + break; + case 144: + mgqa_launch_kernel(params, stream); + break; + case 160: + mgqa_launch_kernel(params, stream); + break; + case 192: + mgqa_launch_kernel(params, stream); + break; + case 224: + mgqa_launch_kernel(params, stream); + break; + case 256: + mgqa_launch_kernel(params, stream); + break; + default: + assert(false); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_groupedquery_attention(const Masked_groupedquery_attention_params& params, const cudaStream_t& stream) +{ + groupedquery_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_groupedquery_attention(const Masked_groupedquery_attention_params& params, const cudaStream_t& stream) +{ + groupedquery_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream) +{ + groupedquery_attention_<__nv_bfloat16, Masked_groupedquery_attention_params<__nv_bfloat16>>(params, stream); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_FP8 +void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_fp8_e4m3>& params, + const cudaStream_t& stream) +{ + groupedquery_attention_<__nv_fp8_e4m3, Masked_groupedquery_attention_params<__nv_fp8_e4m3>>(params, stream); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h new file mode 100644 index 000000000..b0968519f --- /dev/null +++ 
b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/layers/attention_layers_fp8/AttentionFP8Weight.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include "src/fastertransformer/utils/cuda_fp8_utils.h" +#include +#include +#include +#include +#include + +template +struct GroupedQuery_attention_params: public Multihead_attention_params_base { + // allows to exist attention eary + bool* finished = nullptr; + int num_kv_heads = 0; + // required in case of masked attention with different length + const int* length_per_sample = nullptr; +}; + +template +using Masked_groupedquery_attention_params = GroupedQuery_attention_params; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_groupedquery_attention(const Masked_groupedquery_attention_params& params, const cudaStream_t& stream); +void masked_groupedquery_attention(const Masked_groupedquery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_fp8_e4m3>& params, + const cudaStream_t& stream); +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu new file mode 100644 index 000000000..9f9f7ca3f --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength); + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 128, 128, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 128, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu new file mode 100644 index 000000000..6da6da083 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 144, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 144, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu new file mode 100644 index 000000000..bde08b41d --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu 
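Note (editorial sketch, not part of the patch): each decoder_masked_groupedquery_attention_NN.cu translation unit explicitly instantiates mgqa_launch_kernel for head size Dh = NN with Dh_MAX padded up to the next supported power of two (48 -> 64, 80 -> 128, 144 -> 256, ...), and the switch on params.hidden_size_per_head in decoder_masked_groupedquery_attention.cc selects the matching instantiation. A minimal sketch of that padding rule; the helper name is illustrative only:

// Mirrors the (Dh, Dh_MAX) pairs instantiated by the per-head-size .cu files:
// 32->32, 48->64, 64->64, 80->128, 96->128, 128->128, 144->256, ..., 256->256.
static constexpr int dh_max_for_head_size(int dh)
{
    int dh_max = 32;
    while (dh_max < dh) {
        dh_max *= 2;  // round up to the next power of two so vectorized loads stay aligned
    }
    return dh_max;
}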
@@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 160, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 160, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu new file mode 100644 index 000000000..7fa77808f --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! 
Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 192, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 192, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu new file mode 100644 index 000000000..8fdf2e1a5 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 224, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 224, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu new file mode 100644 index 000000000..359bd9214 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 256, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 256, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu new file mode 100644 index 000000000..827efd738 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu 
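Note (editorial sketch, not part of the patch): inside every launcher, THREADS_PER_VALUE is chosen so that each thread owns one 16-byte chunk of a value row (Dh_MAX * sizeof(T) / 16, i.e. 16 threads for Dh_MAX = 128 in fp16), while the thread-block size grows with the decoded length: 64 threads when tlength < 32, 128 when tlength < 2048, 256 otherwise. A small sketch of that selection logic; the helper names are illustrative only:

// Threads cooperating on one value row: one thread per 16-byte chunk.
static constexpr int threads_per_value(int dh_max, int bytes_per_elt)
{
    return dh_max * bytes_per_elt / 16;  // e.g. Dh_MAX = 128, fp16 -> 16 threads
}

// Block size passed to MGQA_LAUNCH_KERNEL as a function of the current timestep.
static constexpr int threads_per_block(int tlength)
{
    return tlength < 32 ? 64 : (tlength < 2048 ? 128 : 256);
}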
@@ -0,0 +1,88 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + //constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + constexpr bool DO_CROSS_ATTENTION = false; + int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 32, 32, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 32, 32, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu new file mode 100644 index 000000000..cb7abfbcc --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! 
Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = false; // std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 48, 64, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 48, 64, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu new file mode 100644 index 000000000..4f3105526 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = false; // std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 64, 64, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 64, 64, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu new file mode 100644 index 000000000..81645f4fd --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 80, 128, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 80, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu new file mode 100644 index 000000000..c8a978952 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu @@ 
-0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 96, 128, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 96, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_template.hpp new file mode 100644 index 000000000..581d566ca --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_template.hpp @@ -0,0 +1,1878 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include "src/fastertransformer/utils/cuda_fp8_utils.h" +#include "src/fastertransformer/utils/cuda_type_utils.cuh" +#include +#include +#include + +// #define MMHA_USE_HMMA_FOR_REDUCTION + +// Below are knobs to extend FP32 accumulation for higher FP16 accuracy + +// Does not seem to affect the accuracy that much +// #define MMHA_USE_FP32_ACUM_FOR_FMA + +// Seems to slightly improve the accuracy +#define MMHA_USE_FP32_ACUM_FOR_OUT + +#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) + // Does not seem to improve the accuracy + //#define MMHA_USE_FP32_ACUM_FOR_LOGITS +#endif + +namespace mmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// We use the following terminology to describe the different dimensions. +// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. +// +// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use +// 64, 128 and 256 threads per block. +// +// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to +// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The +// cache buffer helps with memory accesses and contains keys with bias. +// +// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and +// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The +// values for x are chosen to create chunks of 16 bytes. +// +// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs +// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At +// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an +// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. +// +// After that loop, a parallel softmax is computed across the different Q * K^T values stored in +// shared memory. +// +// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many +// timesteps are computed by loop iteration. As with the keys, the values are read from a cache +// except for the current timestep. The layout of the cache buffer for the values is much simpler +// as it is [B, H, L, Dh]. 
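// (Editorial sketch, not part of the patch.) Element offset implied by the
// [B, H, Dh/x, L, x] key-cache layout described above, with x = 16 / sizeof(T)
// elements per 16-byte chunk; the function name and parameters are illustrative only.
static inline unsigned long long k_cache_offset(unsigned long long b, unsigned long long h,
                                                unsigned long long d, unsigned long long t,
                                                unsigned long long H, unsigned long long Dh,
                                                unsigned long long L, unsigned long long x)
{
    // batch b, head h, chunk d / x, timestep t, element d % x inside the chunk
    return (((b * H + h) * (Dh / x) + d / x) * L + t) * x + d % x;
}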
+// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_vec_m_ { +}; + +template<> +struct Qk_vec_m_ { + using Type = float; +}; +template<> +struct Qk_vec_m_ { + using Type = float2; +}; +template<> +struct Qk_vec_m_ { + using Type = float4; +}; +template<> +struct Qk_vec_m_ { + using Type = float4; +}; +template<> +struct Qk_vec_m_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_m_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_m_ { + using Type = uint2; +}; +template<> +struct Qk_vec_m_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct Qk_vec_m_<__nv_bfloat16, 32> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_m_<__nv_bfloat16, 64> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_m_<__nv_bfloat16, 128> { + using Type = bf16_4_t; +}; +template<> +struct Qk_vec_m_<__nv_bfloat16, 256> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 + +#ifdef ENABLE_FP8 +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 32> { + using Type = fp8_4_t; +}; +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 64> { + using Type = fp8_4_t; +}; +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 128> { + using Type = fp8_4_t; +}; +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 256> { + using Type = fp8_4_t; +}; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_vec_k_ { + using Type = typename Qk_vec_m_::Type; +}; +#ifdef ENABLE_FP8 +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 32> { + using Type = float4; +}; +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 64> { + using Type = float4; +}; +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 128> { + using Type = float4; +}; +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 256> { + using Type = float4; +}; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_m_ { +}; + +template<> +struct K_vec_m_ { + using Type = float; +}; +template<> +struct K_vec_m_ { + using Type = float2; +}; +template<> +struct K_vec_m_ { + using Type = float4; +}; +template<> +struct K_vec_m_ { + using Type = uint32_t; +}; +template<> +struct K_vec_m_ { + using Type = uint2; +}; +template<> +struct K_vec_m_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct K_vec_m_<__nv_bfloat16, 4> { + using Type = __nv_bfloat162; +}; +template<> +struct K_vec_m_<__nv_bfloat16, 2> { + using Type = bf16_4_t; +}; +template<> +struct K_vec_m_<__nv_bfloat16, 1> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 + +// NOTE: THREADS_PER_KEY * sizeof(K_vec_m_) = 128 bytes +#ifdef ENABLE_FP8 +template<> +struct K_vec_m_<__nv_fp8_e4m3, 4> { + using Type = fp8_4_t; +}; +template<> +struct K_vec_m_<__nv_fp8_e4m3, 2> { + using Type = fp8_4_t; +}; // Defined for compilation-purpose only, do not use +template<> +struct K_vec_m_<__nv_fp8_e4m3, 1> { + using Type = fp8_4_t; +}; // Defined for compilation-purpose only, do not use +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_k_ { + using Type = typename K_vec_m_::Type; +}; +#ifdef ENABLE_FP8 +template<> +struct K_vec_k_<__nv_fp8_e4m3, 4> { + using Type = float4; +}; +template<> +struct K_vec_k_<__nv_fp8_e4m3, 2> { + using Type = float4; +}; // Defined for compilation-purpose only, do not use +template<> +struct K_vec_k_<__nv_fp8_e4m3, 1> { + 
using Type = float4; +}; // Defined for compilation-purpose only, do not use +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct V_vec_m_ { +}; + +template<> +struct V_vec_m_ { + using Type = float; +}; +template<> +struct V_vec_m_ { + using Type = float2; +}; +template<> +struct V_vec_m_ { + using Type = float4; +}; +template<> +struct V_vec_m_ { + using Type = uint32_t; +}; +template<> +struct V_vec_m_ { + using Type = uint2; +}; +template<> +struct V_vec_m_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_m_<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template<> +struct V_vec_m_<__nv_bfloat16, 4> { + using Type = bf16_4_t; +}; +template<> +struct V_vec_m_<__nv_bfloat16, 8> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 +#ifdef ENABLE_FP8 +template<> +struct V_vec_m_<__nv_fp8_e4m3, 4> { + using Type = fp8_4_t; +}; +template<> +struct V_vec_m_<__nv_fp8_e4m3, 8> { + using Type = fp8_4_t; +}; +template<> +struct V_vec_m_<__nv_fp8_e4m3, 16> { + using Type = fp8_4_t; +}; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct V_vec_k_ { + using Type = typename V_vec_m_::Type; +}; +#ifdef ENABLE_FP8 +template<> +struct V_vec_k_<__nv_fp8_e4m3, 4> { + using Type = float4; +}; +template<> +struct V_vec_k_<__nv_fp8_e4m3, 8> { + using Type = float4; +}; +template<> +struct V_vec_k_<__nv_fp8_e4m3, 16> { + using Type = float4; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA +template +struct Qk_vec_acum_fp32_ { +}; + +template<> +struct Qk_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float4; +}; +// template<> struct Qk_vec_acum_fp32_ { using Type = float; }; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; + +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_FP8 +// template<> +// struct Qk_vec_acum_fp32_ { +// using Type = float2; +// }; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +// template<> +// struct Qk_vec_acum_fp32_ { +// using Type = Float4_; +// }; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_acum_fp32_ { +}; + +template<> +struct K_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct K_vec_acum_fp32_ { + 
using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_FP8 +// template<> +// struct K_vec_acum_fp32_ { +// using Type = float2; +// }; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +// template<> +// struct K_vec_acum_fp32_ { +// using Type = Float4_; +// }; +#endif // ENABLE_FP8 +#endif // MMHA_USE_FP32_ACUM_FOR_FMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT +template +struct V_vec_acum_fp32_ { +}; + +template<> +struct V_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#endif // ENABLE_BF16 +#ifdef ENABLE_FP8 +// template<> +// struct V_vec_acum_fp32_ { +// using Type = float2; +// }; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +// template<> +// struct V_vec_acum_fp32_ { +// using Type = Float4_; +// }; +#endif // ENABLE_FP8 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__inline__ __device__ Tout vec_conversion(const Tin& x) +{ + return x; +} +#ifdef ENABLE_FP8 +// fp8_t +template<> +__inline__ __device__ float vec_conversion(const __nv_fp8_e4m3& a) +{ + return float(a); +} +template<> +__inline__ __device__ __nv_fp8_e4m3 vec_conversion<__nv_fp8_e4m3, float>(const float& a) +{ + return __nv_fp8_e4m3(a); +} +// fp8_2_t +template<> +__inline__ __device__ float2 vec_conversion(const fp8_2_t& a) +{ + return float2(a); +} +template<> +__inline__ __device__ fp8_2_t vec_conversion(const float2& a) +{ + return fp8_2_t(a); +} +// fp8_4_t +template<> +__inline__ __device__ float4 vec_conversion(const fp8_4_t& a) +{ + return float4(a); +} +template<> +__inline__ __device__ fp8_4_t vec_conversion(const float4& a) +{ + return fp8_4_t(a); +} +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) +{ +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = K_vec; +#endif + // Compute the parallel products for Q*K^T (treat vector lanes separately). + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. 
+ float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_dot { + template + static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) + { + return qk_dot_(q, k); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) +{ + float4 c; + float zero = 0.f; + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%7, %7, %7, %7}; \n" + + : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) + : "r"(a.x), "r"(a.y), "r"(b), "f"(zero)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = uint32_t; +#endif + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + uint32_t qk_vec_ = float2_to_half2(qk_vec); + return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; +#else + return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; +#endif +#else + return 0.f; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Qk_dot { + template + static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) + { +#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) + return qk_hmma_dot_(q, k); +#else + return qk_dot_<4>(q, k); +#endif // defined MMHA_USE_HMMA_FOR_REDUCTION + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float block_sum(float* red_smem, float sum) +{ + + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + +// Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < WARPS_PER_BLOCK) { + sum = red_smem[lane]; + } + +// Parallel reduction inside the warp. +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. 
+ return __shfl_sync(uint32_t(-1), sum, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float& dst, float src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint16_t& dst, float src) +{ + dst = float_to_half(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint32_t& dst, float2 src) +{ + dst = float2_to_half2(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) +{ + dst = __float2bfloat16(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst = __float22bfloat162_rn(src); +#else + dst = __floats2bfloat162_rn(src.x, src.y); +#endif +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, Float4_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, float4 src) +{ + convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint4& dst, Float8_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_4_t& dst, float4 src) +{ + convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); + dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); + dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); +#endif +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
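For reference, a minimal standalone sketch (not part of this patch) of what the convert_from_float(uint32_t&, float2) overload above boils down to: the accumulation paths run in float (float2 / Float4_ / Float8_), and the result is packed back into the 16-bit storage type only when it is written to shared or global memory. The kernel and variable names below are illustrative; only the CUDA intrinsics are real.

#include <cuda_fp16.h>
#include <cstdio>

__global__ void pack_demo()
{
    // A float2 accumulator, as produced by the FP32-accumulation paths above.
    float2 acc = make_float2(0.25f, -1.5f);
    // Round-to-nearest pack into a half2, essentially what float2_to_half2 does.
    __half2 h2 = __float22half2_rn(acc);
    // The attention kernel stores this as a uint32_t, i.e. two packed 16-bit halves.
    uint32_t bits = *reinterpret_cast<uint32_t*>(&h2);
    float2 back = __half22float2(h2);
    printf("packed=0x%08x unpacked=(%f, %f)\n", bits, back.x, back.y);
}

int main()
{
    pack_demo<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}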
+#ifdef ENABLE_FP8 +inline __device__ void convert_from_float(fp8_4_t& dst, float4 src) +{ + dst = fp8_4_t(src); +} +inline __device__ void convert_from_float(fp8_2_t& dst, float2 src) +{ + dst = fp8_2_t(src); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float2& dst, float2 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float4& dst, float4 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(float4 u) +{ + return u.x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(uint4 u) +{ + float2 tmp = half2_to_float2(u.x); + return tmp.x; +} + +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float cast_to_float(float u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(float2 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 cast_to_float(float4 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(Float4_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(Float8_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(uint32_t u) +{ + return half2_to_float2(u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(uint2 u) +{ + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(uint4 u) +{ + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float float_from_int8(int8_t u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 float_from_int8(int16_t u) +{ + union { + int16_t int16; + int8_t int8[2]; + }; + int16 = u; + return make_float2(int8[0], int8[1]); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 float_from_int8(int32_t u) +{ + union { + int32_t int32; + int8_t int8[4]; + }; + int32 = u; + return make_float4(int8[0], int8[1], int8[2], int8[3]); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// clang-format off 
+inline __device__ Float8_ float_from_int8(int64_t u) +{ + union { + int64_t int64; + int16_t int16[4]; + }; + int64 = u; + return Float8_ {float_from_int8(int16[0]), + float_from_int8(int16[1]), + float_from_int8(int16[2]), + float_from_int8(int16[3])}; +} +// clang-format on + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int8_t cast_to_int8(float val) +{ + union { + int8_t int8[2]; + int16_t int16; + }; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); + return int8[0]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int32_t cast_to_int8(float4 val) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + int8[0] = cast_to_int8(val.x); + int8[1] = cast_to_int8(val.y); + int8[2] = cast_to_int8(val.z); + int8[3] = cast_to_int8(val.w); + return int32; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int64_t cast_to_int8(Float8_ val) +{ + union { + int8_t int8[8]; + int64_t int64; + }; + int8[0] = cast_to_int8(val.x.x); + int8[1] = cast_to_int8(val.x.y); + int8[2] = cast_to_int8(val.y.x); + int8[3] = cast_to_int8(val.y.y); + int8[4] = cast_to_int8(val.z.x); + int8[5] = cast_to_int8(val.z.y); + int8[6] = cast_to_int8(val.w.x); + int8[7] = cast_to_int8(val.w.y); + return int64; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ __host__ T div_up(T m, T n) +{ + return (m + n - 1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct kernel_type_t { + using Type = T; +}; + +#ifdef ENABLE_FP8 +template<> +struct kernel_type_t<__nv_fp8_e4m3> { + using Type = float; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline size_t smem_size_in_bytes(const GroupedQuery_attention_params& params, + int threads_per_value, + int threads_per_block) +{ + using Tk = typename kernel_type_t::Type; + // The amount of shared memory needed to store the Q*K^T values in float. + const int max_timesteps = min(params.timestep, params.memory_max_len); + size_t qk_sz = div_up(max_timesteps + 1, 4) * 16; + + // The extra memory needed if we are not using floats for the final logits. + size_t logits_sz = 0; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(Tk) != 4) { + // TDOD + logits_sz = div_up(max_timesteps + 1, 4) * 4 * sizeof(Tk); + } +#endif + + // The total size needed during softmax. + size_t softmax_sz = qk_sz + logits_sz; + + // The number of partial rows to reduce in the final reduction. + int rows_per_red = threads_per_block / threads_per_value; + // The amount of storage needed to finalize the outputs. + size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(Tk) / 2; + + size_t transpose_rotary_size = 0; + if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { + transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(Tk); + } + + // The max. + return max(max(softmax_sz, red_sz), transpose_rotary_size); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ constexpr uint32_t shfl_mask(int threads) +{ + return threads == 32 ? 
uint32_t(-1) : (1u << threads) - 1u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The type of the inputs. Supported types: float and half. + typename T, + // The hidden dimension per head. + int Dh, + int Dh_MAX, + // The number of threads per key. + int THREADS_PER_KEY, + // The number of threads per value. + int THREADS_PER_VALUE, + // The number of threads in a threadblock. + int THREADS_PER_BLOCK, + bool HAS_BEAMS> +__global__ void masked_groupedquery_attention_kernel(GroupedQuery_attention_params params) +{ + using Tk = typename kernel_type_t::Type; +#ifdef ENABLE_FP8 + // FP8 MHA Scales + constexpr bool FP8_MHA_KERNEL = std::is_same::value; +#else + constexpr bool FP8_MHA_KERNEL = false; +#endif + // Make sure the hidden dimension per head is a multiple of the number of threads per key. + static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); + // Make sure the hidden dimension per head is a multiple of the number of threads per value. + static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); + + // The size of a warp. + constexpr int WARP_SIZE = 32; + // The number of warps in a threadblock. + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + // Use smem_size_in_bytes (above) to determine the amount of shared memory. + extern __shared__ char smem_[]; + + // The shared memory for the Q*K^T values and partial logits in softmax. + float* qk_smem = reinterpret_cast(smem_); + + // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. + char* logits_smem_ = smem_; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(Tk) != 4) { + // TODO - change to tlength + const int max_timesteps = min(params.timestep, params.memory_max_len); + logits_smem_ += div_up(max_timesteps + 1, 4) * 16; + } + Tk* logits_smem = reinterpret_cast(logits_smem_); +#else + float* logits_smem = reinterpret_cast(logits_smem_); +#endif + + // The shared memory to do the final reduction for the output values. Reuse qk_smem. + Tk* out_smem = reinterpret_cast(smem_); + + // The shared memory buffers for the block-wide reductions. One for max, one for sum. + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + + // A vector of Q or K elements for the current timestep. + using Qk_vec_k = typename Qk_vec_k_::Type; // with kernel-used precision + using Qk_vec_m = typename Qk_vec_m_::Type; // with memory-used precision + + // Use alignment for safely casting the shared buffers as Qk_vec_k. + // Shared memory to store Q inputs. + __shared__ __align__(sizeof(Qk_vec_k)) Tk q_smem[Dh_MAX]; + + // The number of elements per vector. + constexpr int QK_VEC_SIZE = sizeof(Qk_vec_m) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); + // We will use block wide reduction if needed + // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); + // The number of vectors per warp. + constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; + + // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8/16 for FP32/FP16/FP8. Since each thread + // owns x elements, we have to decompose the linear index into chunks of x values and the posi- + // tion of the thread in that chunk. + + // The number of elements in a chunk of 16B (that's the x in the above formula). + constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); + // The number of K vectors in 16B. 
+ constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec_m); + + // The batch/beam idx + const int bi = blockIdx.y; + if (params.finished != nullptr && params.finished[bi] == true) { + return; + } + // The beam idx + const int beami = bi % params.beam_width; + // The "beam-aware" batch idx + const int bbi = bi / params.beam_width; + const int head_n_rep = params.num_heads / params.num_kv_heads; + // const int head_n_rep = 1; + // The head. + const int hi = blockIdx.x; + const int kvhi = hi / head_n_rep; + // Combine the batch and the head indices. + const int bhi = bi * params.num_heads + hi; + const int bkvhi = bi * params.num_kv_heads + kvhi; + // Combine the "beam-aware" batch idx and the head indices. + const int bbhi = bbi * params.beam_width * params.num_heads + hi; + const int bbkvhi = bbi * params.beam_width * params.num_kv_heads + kvhi; + // The thread in the block. + const int tidx = threadIdx.x; + + constexpr bool handle_kv = true; + + // here. + + // While doing the product Q*K^T for the different keys we track the max. + float qk_max = -FLT_MAX; + + float qk = 0.0F; + + int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh; + + const size_t bi_seq_len_offset = bi * params.memory_max_len; + + int tlength = (params.length_per_sample == nullptr) ? + params.timestep : + params.length_per_sample[bi] + params.max_prefix_prompt_length; + const int first_step = max(0, tlength + 1 - params.memory_max_len); + const int tlength_circ = tlength % params.memory_max_len; + + // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. + const bool is_masked = tidx >= QK_VECS_PER_WARP; + + // The offset in the Q and K buffer also accounts for the batch. + int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; + // The offset in the bias buffer. + int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; + + const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr; + const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0; + + // Trigger the loads from the Q and K buffers. + Qk_vec_k q; + zero(q); + if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto q_scaling = params.qkv_scale_out[0]; + const auto q_quant = + *reinterpret_cast(&reinterpret_cast(params.q)[qk_offset]); + + convert_from_float(q, mul(q_scaling, float_from_int8(q_quant))); + } + else { + q = vec_conversion(*reinterpret_cast(¶ms.q[qk_offset])); + } + } + + Qk_vec_k k; + zero(k); + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto k_scaling = params.qkv_scale_out[1]; + const auto k_quant = + *reinterpret_cast(&reinterpret_cast(params.k)[qk_offset]); + + convert_from_float(k, mul(k_scaling, float_from_int8(k_quant))); + } + else { + k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? + vec_conversion(*reinterpret_cast(¶ms.k[qk_offset])) : + k; + } + + // Trigger the loads from the Q and K bias buffers. + Qk_vec_k q_bias; + zero(q_bias); + q_bias = + (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? + vec_conversion(*reinterpret_cast(¶ms.q_bias[qk_bias_offset])) : + q_bias; + + Qk_vec_k k_bias; + zero(k_bias); + if (handle_kv) { + k_bias = + !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? 
+ vec_conversion(*reinterpret_cast(¶ms.k_bias[qk_bias_offset])) : + k_bias; + } + + // Computes the Q/K values with bias. + q = add(q, q_bias); + if (handle_kv) { + k = add(k, k_bias); + } + if (do_ia3 && !is_masked) { + k = mul( + k, + vec_conversion(*reinterpret_cast( + ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE]))); + } + + // Padded len + const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi]; + if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) { + if (handle_kv) { + apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep - padd_len); + } + else { + apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep - padd_len); + } + } + else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { + const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; + + T* q_smem = reinterpret_cast(smem_); + T* k_smem = q_smem + params.rotary_embedding_dim; + + const int half_rotary_dim = params.rotary_embedding_dim / 2; + const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim; + const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim; + const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts + + assert(half_rotary_dim % QK_VEC_SIZE == 0); + + if (do_rotary) { + *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; + + if (handle_kv) { + *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx) = k; + } + } + + __syncthreads(); + + const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; + constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1; + if (do_rotary) { + mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + + if (handle_kv) { + mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + + mmha::apply_rotary_embedding( + q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep - padd_len); + + mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + } + else { + mmha::apply_rotary_embedding( + q, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep); + } + mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + } + + __syncthreads(); + + if (do_rotary) { + q = *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx); + if (handle_kv) { + k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); + } + } + + __syncthreads(); + } + + if (!is_masked) { + // Store the Q values to shared memory. + *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; + + // Write the K values to the global memory cache. + // + // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory + // system. We designed it this way as it allows much better memory loads (and there are many + // more loads) + the stores are really "write and forget" since we won't need the ack before + // the end of the kernel. There's plenty of time for the transactions to complete. + + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. 
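+        // Grouped-query change: the K cache is indexed by the KV head (bkvhi = bi * num_kv_heads + kvhi),
+        // not by the query head (bhi), so all head_n_rep query heads of a group share one
+        // [B, num_kv_heads, Dh/x, L, x] slice of the cache.
+        // Worked example (illustrative numbers: T = half, Dh = 128, memory_max_len = 1024):
+        // QK_ELTS_IN_16B = 16 / sizeof(half) = 8, so the thread with co = 2, ci = 4 writing
+        // timestep tlength_circ = 10 stores at
+        //   bkvhi * 1024 * 128 + 2 * 1024 * 8 + 10 * 8 + 4,
+        // i.e. chunks are contiguous along the timestep dimension within each Dh/x row.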
+ // int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + // // params.timestep*QK_ELTS_IN_16B + + // tlength_circ * QK_ELTS_IN_16B + ci; + int offset = bkvhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + // params.timestep*QK_ELTS_IN_16B + + tlength_circ * QK_ELTS_IN_16B + ci; + + if (handle_kv && bhi%head_n_rep==0) { + // Trigger the stores to global memory. + if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { + *reinterpret_cast(¶ms.k_cache[offset]) = vec_conversion(k); + } + } + + // Compute \sum_i Q[i] * K^T[i] for the current timestep. +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; +#else + using Qk_vec_acum = Qk_vec_k; +#endif + qk = dot(q, k); + if (QK_VECS_PER_WARP <= WARP_SIZE) { +#pragma unroll + for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); + } + } + } + + if (QK_VECS_PER_WARP > WARP_SIZE) { + constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; + qk = block_sum(&red_smem[WARPS_PER_RED], qk); + } + + // Store that value in shared memory. Keep the Q*K^T value in register for softmax. + if (tidx == 0) { + // Normalize qk. + qk *= params.inv_sqrt_dh; + if (params.relative_attention_bias != nullptr) { + qk = add(qk, + params.relative_attention_bias[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + (tlength - padd_len) * params.relative_attention_bias_stride + + (tlength - padd_len)]); + } + // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0. + + qk_max = qk; + qk_smem[tlength - first_step] = qk; + // qk_smem[params.timestep] = qk; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The type of queries and keys for the math in the Q*K^T product. + using K_vec_k = typename K_vec_k_::Type; + using K_vec_m = typename K_vec_m_::Type; + // The number of elements per vector. + constexpr int K_VEC_SIZE = sizeof(K_vec_m) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); + // The number of elements per thread. + constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; + // The number of vectors per thread. + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + + // The position the first key loaded by each thread from the cache buffer (for this B * H). + int ko = tidx / THREADS_PER_KEY; + // The position of the thread in the chunk of keys. + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + + static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD); + + // Load the Q values from shared memory. The values are reused during the loop on K. + K_vec_k q_vec[K_VECS_PER_THREAD]; +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + q_vec[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); + } + + // The number of timesteps loaded per iteration. + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // The number of keys per warp. + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + // The base pointer for the key in the cache buffer. 
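+    // Each group of THREADS_PER_KEY threads cooperates on one timestep: a thread owns
+    // K_VECS_PER_THREAD vectors of K_VEC_SIZE elements, so together the group covers
+    // Dh_MAX = THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD elements of the key.
+    // The base pointer below uses the "beam-aware" KV-head index (bbkvhi); the concrete
+    // beam for each timestep is resolved inside the loop via params.cache_indir.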
+ // T* k_cache = ¶ms.k_cache[bkvhi * params.memory_max_len * Dh + ki]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* k_cache_batch = ¶ms.k_cache[bbkvhi * params.memory_max_len * Dh + ki]; + + // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). + // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; + int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; + + // prefix prompt length if has + const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi]; + + // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. + const int* beam_indices = HAS_BEAMS ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; + + for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) { + const int ti_circ = ti % params.memory_max_len; + bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; + + // The keys loaded from the key cache. + K_vec_k k[K_VECS_PER_THREAD]; + K_vec_k k_vec_zero; + zero(k_vec_zero); +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.memory_max_len + ti_circ; + // if( ti < params.timestep ) { + const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len); + if (ti < tlength) { + if (!within_bounds) { + k[ii] = k_vec_zero; + } + else { + if (HAS_BEAMS) { + // const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; + const int beam_offset = beam_indices[ti_circ] * params.num_kv_heads * params.memory_max_len * Dh; + k[ii] = vec_conversion( + (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); + } + else { + k[ii] = vec_conversion( + (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); + } + } + } + } + + // Perform the dot product and normalize qk. + // + // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! + float qk = Qk_dot::dot(q_vec, k) * params.inv_sqrt_dh; + + // Store the product to shared memory. There's one qk value per timestep. Update the max. + // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { + if (ti < tlength && tidx % THREADS_PER_KEY == 0) { + if (params.relative_attention_bias != nullptr) { + qk = add(qk, + params.relative_attention_bias[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + tlength * params.relative_attention_bias_stride + ti]); + } + if (params.linear_bias_slopes != nullptr) { + // Apply the linear position bias: (ki - qi) * slope[hi]. + // The padding token locates between the input context and the generated tokens. + // We need to remove the number of padding tokens in the distance computation. + // ti : 0 1 2 3 4 5 6 7 8 9(tlength) + // token: i i i i p p p o o o where i=input, p=pad, o=output. + // e.g. ti = 2, dist = (9 - 3) - 2 = 4. + int max_context_length = params.max_prefix_prompt_length + params.max_input_length; + float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength; + + qk += mul(params.linear_bias_slopes[hi], dist); + } + qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); + qk_smem[ti - first_step] = qk; + } + } + +// Perform the final reduction to compute the max inside each warp. +// +// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the +// group so it's not needed to run the reduction inside the group (again). 
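+    // The block-wide maximum is needed for a numerically stable softmax: every logit is
+    // shifted by qk_max before __expf so the exponentials stay in a safe range. The
+    // reduction runs in two stages, first across the lanes of each warp with shuffles,
+    // then across warps through the red_smem buffer.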
+#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + const int warp = tidx / WARP_SIZE; + const int lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if (lane == 0) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. + __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Compute the logits and start the sum. + float sum = 0.f; + // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { + for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; +#ifdef FP8_MHA + float logit = 0.f; + if (FP8_MHA_KERNEL) { + logit = is_mask ? 0.f : + __expf((qk_smem[ti - first_step] - qk_max) * params.query_weight_output_scale[0] + * params.query_weight_output_scale[0]); + } + else { + logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); + } +#else + float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); +#endif + sum += logit; + qk_smem[ti - first_step] = logit; + } + + // Compute the sum. + sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); + + // Normalize the logits. + float inv_sum = __fdividef(1.f, sum + 1.e-6f); + for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + float logit = qk_smem[ti - first_step] * inv_sum; + convert_from_float(logits_smem[ti - first_step], logit); + } + + // Put Values part below so we leverage __syncthreads + // from the previous step + + // The number of elements per vector. + constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; + // A vector of V elements for the current timestep. + using V_vec_k = typename V_vec_k_::Type; + using V_vec_m = typename V_vec_m_::Type; + + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + // The base pointer for the value in the cache buffer. + // if (bkvhi == 63) { + // printf("%d %d %d %d %d\n", bkvhi, params.memory_max_len, Dh, vi, (bkvhi * params.memory_max_len * Dh + vi)); + // } + T* v_cache = ¶ms.v_cache[bkvhi * params.memory_max_len * Dh + vi]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* v_cache_batch = ¶ms.v_cache[bbkvhi * params.memory_max_len * Dh + vi]; + + // The number of values processed per iteration of the loop. + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + // One group of threads computes the product(s) for the current timestep. + V_vec_k v_bias; + zero(v_bias); + // if( vo == params.timestep % V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + if (vo == tlength % V_PER_ITER) { + // Trigger the loads from the V bias buffer. + if (params.v_bias != nullptr) { + v_bias = vec_conversion( + *reinterpret_cast(¶ms.v_bias[hi * Dh + vi])); + } + } + } + + // From previous, before values, step + // Also make sure the logits are in shared memory. 
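+    // The barrier below publishes the normalized logits written to logits_smem above to
+    // every thread before the value pass consumes them; note that logits_smem aliases
+    // qk_smem when the logits are kept in float.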
+ __syncthreads(); + + // Values continued +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + using V_vec_acum = typename V_vec_acum_fp32_::Type; +#else + using V_vec_acum = V_vec_k; +#endif + // The partial outputs computed by each thread. + V_vec_acum out; + zero(out); + + // Loop over the timesteps to compute the partial outputs. + // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + + // Separate the ti < memory_max_len and ti > memory_max_len + // to prevent ti % memory_len when ti < memory_len, and + // the compiler cannot optimize the codes automatically. + const int min_length = min(tlength, params.memory_max_len); + for (int ti = first_step + vo; ti < min_length; ti += V_PER_ITER) { + // Fetch offset based on cache_indir when beam sampling + const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti] : 0; + // const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; + const int beam_offset = HAS_BEAMS ? beam_src * params.num_kv_heads * params.memory_max_len * Dh : 0; + // Load the values from the cache. + V_vec_k v = vec_conversion( + *reinterpret_cast(&v_cache_batch[beam_offset + ti * Dh])); + // Load the logits from shared memory. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti - first_step]; + out = fma(logit, cast_to_float(v), out); +#else // MMHA_USE_FP32_ACUM_FOR_LOGITS +#ifdef FP8_MHA + Tk logit; + if (FP8_MHA_KERNEL) { + // NOTE: fake quantization + // logit = vec_conversion(vec_conversion(mul(1.0f / + // params.attention_qk_scale[0], logits_smem[ti]))); + logit = logits_smem[ti - first_step]; + } + else { + logit = logits_smem[ti - first_step]; + } + out = fma(logit, v, out); +#else // FP8_MHA + Tk logit = logits_smem[ti - first_step]; + out = fma(logit, v, out); +#endif // FP8_MHA +#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS + } + for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { + if (ti < params.memory_max_len) { + // handled by previous loop + continue; + } + const int ti_circ = ti % params.memory_max_len; + + // Fetch offset based on cache_indir when beam sampling + const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; + const int beam_offset = HAS_BEAMS ? beam_src * params.num_kv_heads * params.memory_max_len * Dh : 0; + // Load the values from the cache. + V_vec_k v = vec_conversion( + *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh])); + // Load the logits from shared memory. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti - first_step]; + out = fma(logit, cast_to_float(v), out); +#else // MMHA_USE_FP32_ACUM_FOR_LOGITS +#ifdef FP8_MHA + Tk logit; + if (FP8_MHA_KERNEL) { + // NOTE: fake quantization + // logit = vec_conversion(vec_conversion(mul(1.0f / + // params.attention_qk_scale[0], logits_smem[ti]))); + logit = logits_smem[ti - first_step]; + } + else { + logit = logits_smem[ti - first_step]; + } + out = fma(logit, v, out); +#else // FP8_MHA + Tk logit = logits_smem[ti - first_step]; + out = fma(logit, v, out); +#endif // FP8_MHA +#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS + } + } + + // One group of threads computes the product(s) for the current timestep. + // if( vo == params.timestep % V_PER_ITER ) { + if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { + + V_vec_k v; + // Trigger the loads from the V buffer. 
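+        // The value of the current timestep is not in the cache yet: it is read straight
+        // from the fused QKV buffer (params.v), bias-added, and written into the V cache
+        // only by the first query head of each KV group (bhi % head_n_rep == 0), so each
+        // grouped-query cache slot is stored exactly once.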
+ const auto v_offset = qkv_base_offset + vi; + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto v_scaling = params.qkv_scale_out[2]; + const auto v_quant = + *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); + + convert_from_float(v, mul(v_scaling, float_from_int8(v_quant))); + } + else { + v = vec_conversion(*reinterpret_cast(¶ms.v[v_offset])); + } + // Trigger the loads from the V bias buffer. + // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); + + // Compute the V values with bias. + v = add(v, v_bias); + + if (do_ia3) { + v = mul( + v, + *reinterpret_cast( + ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); + } + if (bhi % head_n_rep == 0) { + // Store the values with bias back to global memory in the cache for V. + //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; + *reinterpret_cast(&v_cache[tlength_circ * Dh]) = vec_conversion(v); + } + + // Initialize the output value with the current timestep. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + // out = fma(logits_smem[params.timestep], cast_to_float(v), out); + out = fma(logits_smem[tlength - first_step], cast_to_float(v), out); +#else // MMHA_USE_FP32_ACUM_FOR_LOGITS + // out = fma(logits_smem[params.timestep], v, out); +#ifdef FP8_MHA + Tk logit; + if (FP8_MHA_KERNEL) { + // NOTE: fake quantization + // logit = mul(1.0f / params.attention_qk_scale[0], logits_smem[tlength]); + logit = logits_smem[tlength - first_step]; + } + else { + logit = logits_smem[tlength - first_step]; + } + out = fma(logit, v, out); +#else // FP8_MHA + out = fma(logits_smem[tlength - first_step], v, out); +#endif // FP8_MHA +#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS + } + + // Make sure we can start writing to shared memory. + __syncthreads(); + + // Run the final reduction amongst the different groups computing different partial outputs. + if (Dh == Dh_MAX || vi < Dh) { +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) { + + // The midpoint in the number of active groups. + int midpoint = active_groups / 2; + + // The upper part of active threads store to shared memory. + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float(*reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), out); +#else + *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; +#endif + } + __syncthreads(); + + // The bottom warps update their values. + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); + } + __syncthreads(); + } + } + + // Output the final values. 
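+    // Only the first value group (vo == 0) holds the fully reduced output for its slice of
+    // the head dimension. Depending on the build, the accumulator is converted back from
+    // float (optionally with the FP8 result scale), quantized for int8_mode == 2, or stored
+    // directly in the compute type before being written to params.out.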
+ if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + if (FP8_MHA_KERNEL) { +#ifdef FP8_MHA + // float result_scale = params.attention_qk_scale[0] * params.query_weight_output_scale[0] * + // params.attention_output_weight_input_scale_inv[0]; + float result_scale = + params.query_weight_output_scale[0] * params.attention_output_weight_input_scale_inv[0]; + convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), + mul(result_scale, out)); +#endif // FP8_MHA + } + else if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + out = mul(*params.attention_out_scale, out); + *reinterpret_cast(&(reinterpret_cast(params.out)[bhi * Dh + vi])) = + cast_to_int8(out); + } + else { + convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), out); + } +#else // MMHA_USE_FP32_ACUM_FOR_OUT + // TODO: support int8_mode? + *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = vec_conversion(out); +#endif // MMHA_USE_FP32_ACUM_FOR_OUT + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace mmha + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct threads_per_value_t { + static const int value = Dh_MAX * sizeof(T) / 16; +}; +#ifdef ENABLE_FP8 +template +struct threads_per_value_t<__nv_fp8_e4m3, Dh_MAX> { + static const int value = Dh_MAX * 4 / 16; // DEBUG: float v +}; +#endif + +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream); diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu index d0fb0a197..4f0df238e 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.cu +++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu @@ -1698,6 +1698,34 @@ __global__ void transpose_4d_batch_major_k_cache( } } +template +__global__ void transpose_4d_batch_major_k_cache( + T* k_dst, const T* k_src, const int head_n_rep, const int kv_head_num, const int size_per_head, const int seq_len, const int max_seq_len) +{ + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + constexpr int X_ELEMS = (sizeof(T) == 4) ? 
4 : 8; + auto key_src = reinterpret_cast(k_src + batch_id * head_n_rep * kv_head_num * size_per_head * seq_len + + head_id * head_n_rep * size_per_head * seq_len); + auto key_dst = reinterpret_cast(k_dst + batch_id * kv_head_num * size_per_head * max_seq_len + + head_id * size_per_head * max_seq_len); + + const int out_idx = blockIdx.x * blockDim.x + threadIdx.x; + int size_per_head_div_x = size_per_head / X_ELEMS; + if (out_idx >= size_per_head_div_x * max_seq_len) { + return; + } + + int idx = out_idx; + const int k_seq_len_id = idx % max_seq_len; + idx = (idx - k_seq_len_id) / max_seq_len; + const int k_head_size_id = idx % size_per_head_div_x; + + if (k_seq_len_id < seq_len) { + key_dst[out_idx] = key_src[k_seq_len_id * size_per_head_div_x + k_head_size_id]; + } +} + template __global__ void transpose_4d_batch_major_v_cache( T* v_dst, const T* v_src, const int head_num, const int size_per_head, const int seq_len, const int max_seq_len) @@ -1724,6 +1752,32 @@ __global__ void transpose_4d_batch_major_v_cache( val_dst[idx] = val_src[idx]; } +template +__global__ void transpose_4d_batch_major_v_cache( + T* v_dst, const T* v_src, const int head_n_rep, const int kv_head_num, const int size_per_head, const int seq_len, const int max_seq_len) +{ + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + + // 16 byte loads will handle "x" dimension + auto val_src = reinterpret_cast(v_src + batch_id * kv_head_num * head_n_rep * size_per_head * seq_len + + head_id * head_n_rep * size_per_head * seq_len); + auto val_dst = reinterpret_cast(v_dst + batch_id * kv_head_num * size_per_head * max_seq_len + + head_id * size_per_head * max_seq_len); + + // idx is over output dimension L * size_per_head / x for values + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + + constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8; + const int size_per_head_div_x = size_per_head / X_ELEMS; + + if (idx >= size_per_head_div_x * seq_len) { + return; + } + + val_dst[idx] = val_src[idx]; +} + template void invokeTranspose4dBatchMajor(T* k_dst, T* v_dst, @@ -1749,6 +1803,33 @@ void invokeTranspose4dBatchMajor(T* k_dst, v_dst, v_src, local_head_num, size_per_head, seq_len, max_seq_len); } +template +void invokeTranspose4dBatchMajor(T* k_dst, + T* v_dst, + const T* k_src, + const T* v_src, + const int local_batch_size, + const int seq_len, + const int max_seq_len, + const int size_per_head, + const int local_head_num, + const int local_kv_head_num, + cudaStream_t stream) +{ + constexpr int block_sz = 128; + constexpr int x = (sizeof(T) == 4) ? 
4 : 8; + int size = max_seq_len * size_per_head / x; + int head_n_rep = local_head_num / local_kv_head_num; + dim3 grid((size + block_sz - 1) / block_sz, local_batch_size, local_kv_head_num); + dim3 grid_v((seq_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_kv_head_num); + + transpose_4d_batch_major_k_cache<<>>( + k_dst, k_src, head_n_rep, local_kv_head_num, size_per_head, seq_len, max_seq_len); + + transpose_4d_batch_major_v_cache<<>>( + v_dst, v_src, head_n_rep, local_kv_head_num, size_per_head, seq_len, max_seq_len); +} + #define INSTANTIATETRANSPOSE4DBATCHMAJOR(T) \ template void invokeTranspose4dBatchMajor(T* k_dst, \ T* v_dst, \ @@ -1759,6 +1840,17 @@ void invokeTranspose4dBatchMajor(T* k_dst, const int max_seq_len, \ const int size_per_head, \ const int local_head_num, \ + cudaStream_t stream); \ + template void invokeTranspose4dBatchMajor(T* k_dst, \ + T* v_dst, \ + const T* k_src, \ + const T* v_src, \ + const int local_batch_size, \ + const int seq_len, \ + const int max_seq_len, \ + const int size_per_head, \ + const int local_head_num, \ + const int local_kv_head_num, \ cudaStream_t stream) INSTANTIATETRANSPOSE4DBATCHMAJOR(float); INSTANTIATETRANSPOSE4DBATCHMAJOR(half); diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.h b/src/fastertransformer/kernels/unfused_attention_kernels.h index 7ac7604d4..569c40f81 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.h +++ b/src/fastertransformer/kernels/unfused_attention_kernels.h @@ -189,6 +189,19 @@ void invokeTranspose4dBatchMajor(T* k_dst, const int local_head_num, cudaStream_t stream); +template +void invokeTranspose4dBatchMajor(T* k_dst, + T* v_dst, + const T* k_src, + const T* v_src, + const int local_batch_size, + const int seq_len, + const int max_seq_len, + const int size_per_head, + const int local_head_num, + const int local_kv_head_num, + cudaStream_t stream); + template void invokeAddRelativeAttentionBias(T* qk_buf, const T* relative_attention_bias, diff --git a/src/fastertransformer/layers/attention_layers/CMakeLists.txt b/src/fastertransformer/layers/attention_layers/CMakeLists.txt index 628b3083a..60bbcffba 100644 --- a/src/fastertransformer/layers/attention_layers/CMakeLists.txt +++ b/src/fastertransformer/layers/attention_layers/CMakeLists.txt @@ -42,7 +42,7 @@ target_link_libraries(DecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasM add_library(LlamaDecoderSelfAttentionLayer STATIC LlamaDecoderSelfAttentionLayer.cc) set_property(TARGET LlamaDecoderSelfAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET LlamaDecoderSelfAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(LlamaDecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils decoder_masked_multihead_attention fpA_intB_gemm int8_gemm tensor nvtx_utils) +target_link_libraries(LlamaDecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils decoder_masked_groupedquery_attention fpA_intB_gemm int8_gemm tensor nvtx_utils) add_library(LlamaContextAttentionLayer STATIC LlamaContextAttentionLayer.cc) set_property(TARGET LlamaContextAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc index 1f3734bb6..57557eb43 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc +++ 
b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc @@ -352,6 +352,7 @@ void LlamaContextAttentionLayer::forward(TensorMap* output_ten max_seq_len, size_per_head_, local_head_num_, + local_kv_head_num_, stream_); // IDEA : after this, k_cache = (batch_size, num_heads, Dh/x, prefix_prompt_len + L, x) // k_cache = (batch_size, num_heads, prefix_prompt_len + L, Dh) diff --git a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc index 5d12ff9a4..ed53c22d3 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc @@ -15,7 +15,7 @@ */ #include "src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/utils/logger.h" #include "src/fastertransformer/utils/memory_utils.h" #include "src/fastertransformer/kernels/repeat_kv_kernels.h" @@ -47,6 +47,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, const int inference_batch_size, const int beam_width, const int head_num, + const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const bool neox_rotary_style, @@ -70,7 +71,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, { using DataType = typename SATypeConverter::Type; // Prepare the parameters. - Masked_multihead_attention_params params; + Masked_groupedquery_attention_params params; memset(¶ms, 0, sizeof(params)); int hidden_units = head_num * size_per_head; if (qkv_bias != nullptr) { @@ -112,6 +113,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, // timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation params.timestep = step + max_prefix_prompt_length - 1; params.num_heads = head_num; + params.num_kv_heads = kv_head_num; params.hidden_size_per_head = size_per_head; params.rotary_embedding_dim = rotary_embedding_dim; params.neox_rotary_style = neox_rotary_style; @@ -142,7 +144,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, } PUSH_RANGE("scaled dot-product fusion"); - masked_multihead_attention(params, stream); + masked_groupedquery_attention(params, stream); POP_RANGE; } @@ -160,6 +162,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, const int inference_batch_size, \ const int beam_width, \ const int head_num, \ + const int kv_head_num, \ const int size_per_head, \ const int rotary_embedding_dim, \ const bool neox_rotary_style, \ @@ -629,6 +632,7 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* output_tens batch_size, beam_width, local_head_num_, + local_kv_head_num_, size_per_head_, rotary_embedding_dim_, neox_rotary_style_, diff --git a/src/fastertransformer/models/llama/Llama.cc b/src/fastertransformer/models/llama/Llama.cc index c139aa9f8..ebf39eaa4 100644 --- a/src/fastertransformer/models/llama/Llama.cc +++ b/src/fastertransformer/models/llama/Llama.cc @@ -104,7 +104,7 @@ void Llama::allocateBuffer( FT_LOG_DEBUG(__PRETTY_FUNCTION__); const size_t batchxbeam = batch_size * beam_width; const size_t self_cache_size = (num_layer_ / pipeline_para_.world_size_) * batchxbeam * max_cache_seq_len - * hidden_units_ / tensor_para_.world_size_; + * kv_head_num_ * size_per_head_ / 
tensor_para_.world_size_; if (vocab_size_ != vocab_size_padded_) { padded_embedding_kernel_ = @@ -596,13 +596,13 @@ void Llama::forward(std::unordered_map* output_ten const std::vector self_k_cache_shape = {num_layer_ / pipeline_para_.world_size_, batch_size * beam_width, - local_head_num_, + local_kv_head_num_, size_per_head_ / (16 / sizeof(T)), max_cache_seq_len, 16 / sizeof(T)}; const std::vector self_v_cache_shape = {num_layer_ / pipeline_para_.world_size_, batch_size * beam_width, - local_head_num_, + local_kv_head_num_, max_cache_seq_len, size_per_head_}; diff --git a/src/fastertransformer/models/llama/LlamaContextDecoder.cc b/src/fastertransformer/models/llama/LlamaContextDecoder.cc index f1c9382ca..c359266ea 100644 --- a/src/fastertransformer/models/llama/LlamaContextDecoder.cc +++ b/src/fastertransformer/models/llama/LlamaContextDecoder.cc @@ -461,7 +461,7 @@ void LlamaContextDecoder::forward(std::unordered_map* // element in batch_idx_to_compact_idx may reference the local batch // we're processing. We also need to discard references that aren't in // that particular local batch. - const size_t cache_stride_per_batch = hidden_units_ / tensor_para_.world_size_ * max_seq_len; + const size_t cache_stride_per_batch = kv_head_num_ * size_per_head_ / tensor_para_.world_size_ * max_seq_len; const size_t cache_layer_offset = (l - getFirstLayerParallelId()) * request_batch_size * cache_stride_per_batch; invokeUnCompactCaches(k_cache.getPtrWithOffset(cache_layer_offset),