diff --git a/src/libfastertransformer.cc b/src/libfastertransformer.cc
index a870aa0..41729e7 100644
--- a/src/libfastertransformer.cc
+++ b/src/libfastertransformer.cc
@@ -271,22 +271,24 @@ std::shared_ptr<AbstractTransformerModel> ModelState::ModelFactory(
             LOG_MESSAGE(TRITONSERVER_LOG_ERROR, dt_message.c_str());
         }
     } else if (model_type == "GPT-J") {
+        const int int8_mode = param_get_int(param, "int8_mode", 0);
         if (data_type == "fp16") {
-            ft_model = std::make_shared<GptJTritonModel<half>>(tp, pp, custom_ar, model_dir);
+            ft_model = std::make_shared<GptJTritonModel<half>>(tp, pp, custom_ar, model_dir, int8_mode);
 #ifdef ENABLE_BF16
         } else if (data_type == "bf16") {
-            ft_model = std::make_shared<GptJTritonModel<__nv_bfloat16>>(tp, pp, custom_ar, model_dir);
+            ft_model = std::make_shared<GptJTritonModel<__nv_bfloat16>>(tp, pp, custom_ar, model_dir, int8_mode);
 #endif
         } else if (data_type == "fp32") {
-            ft_model = std::make_shared<GptJTritonModel<float>>(tp, pp, custom_ar, model_dir);
+            ft_model = std::make_shared<GptJTritonModel<float>>(tp, pp, custom_ar, model_dir, int8_mode);
         } else {
             LOG_MESSAGE(TRITONSERVER_LOG_ERROR, dt_message.c_str());
         }
     } else if (model_type == "GPT-NeoX") {
+        const int int8_mode = param_get_int(param, "int8_mode", 0);
         if (data_type == "fp16") {
-            ft_model = std::make_shared<GptNeoXTritonModel<half>>(tp, pp, custom_ar, model_dir);
+            ft_model = std::make_shared<GptNeoXTritonModel<half>>(tp, pp, custom_ar, model_dir, int8_mode);
         } else {
-            ft_model = std::make_shared<GptNeoXTritonModel<float>>(tp, pp, custom_ar, model_dir);
+            ft_model = std::make_shared<GptNeoXTritonModel<float>>(tp, pp, custom_ar, model_dir, int8_mode);
         }
     } else if (model_type == "T5") {
         if (data_type == "fp16") {
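
For context, the change above only reads an optional `int8_mode` parameter from the model configuration (defaulting to 0, i.e. disabled) and forwards it to the GPT-J and GPT-NeoX Triton model constructors. The sketch below illustrates the lookup-with-default behavior that `param_get_int` is assumed to provide, written against a plain string-to-string parameter map; the actual helper in this file may have a different signature.

```cpp
#include <cstdlib>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the existing param_get_int helper used in the
// patch: look up a key in the model-config parameter map and fall back to a
// default when the key is absent (so int8_mode stays 0 unless configured).
static int param_get_int(
    const std::unordered_map<std::string, std::string>& param,
    const std::string& key,
    int default_value)
{
    auto it = param.find(key);
    if (it == param.end()) {
        return default_value;
    }
    return std::atoi(it->second.c_str());
}

int main()
{
    std::unordered_map<std::string, std::string> param{{"int8_mode", "1"}};
    // Returns 1 here; would return 0 if "int8_mode" were not set at all.
    const int int8_mode = param_get_int(param, "int8_mode", 0);
    return int8_mode == 1 ? 0 : 1;
}
```

Usage note: in a Triton model configuration, `int8_mode` would be supplied as a string-valued `parameters` entry (e.g. "1" to enable it); when the entry is absent, the default of 0 preserves the previous non-int8 behavior for GPT-J and GPT-NeoX models.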