diff --git a/script/get-ml-model-gptj/meta.yaml b/script/get-ml-model-gptj/meta.yaml index f6db9adac..cbe7e6e28 100644 --- a/script/get-ml-model-gptj/meta.yaml +++ b/script/get-ml-model-gptj/meta.yaml @@ -16,6 +16,8 @@ input_mapping: new_env_keys: - MLC_ML_MODEL_* - GPTJ_CHECKPOINT_PATH +- MLC_NVIDIA_TP_SIZE +- MLC_NVIDIA_PP_SIZE prehook_deps: - enable_if_env: MLC_TMP_REQUIRE_DOWNLOAD: @@ -152,11 +154,13 @@ variations: pytorch,nvidia: default_variations: precision: fp8 + tp-size: tp-size.1 + pp-size: pp-size.1 deps: - env: MLC_GIT_CHECKOUT_PATH_ENV_NAME: MLC_TENSORRT_LLM_CHECKOUT_PATH extra_cache_tags: tensorrt-llm - tags: get,git,repo,_lfs,_repo.https://github.com/NVIDIA/TensorRT-LLM.git,_sha.0ab9d17a59c284d2de36889832fe9fc7c8697604 + tags: get,git,repo,_repo.https://github.com/NVIDIA/TensorRT-LLM.git,_sha.2ea17cdad28bed0f30e80eea5b1380726a7c6493,_submodules.3rdparty/NVTX;3rdparty/cutlass;3rdparty/cxxopts;3rdparty/json;3rdparty/pybind11;3rdparty/ucxx;3rdparty/xgrammar - names: - cuda tags: get,cuda @@ -253,6 +257,14 @@ variations: MLC_ML_MODEL_PRECISION: uint8 MLC_ML_MODEL_WEIGHT_DATA_TYPES: uint8 group: precision + tp-size.#: + env: + MLC_NVIDIA_TP_SIZE: '#' + group: tp-size + pp-size.#: + env: + MLC_NVIDIA_PP_SIZE: '#' + group: pp-size wget: add_deps_recursive: dae: