diff --git a/kt-kernel/CMakeLists.txt b/kt-kernel/CMakeLists.txt index 6429b7b6..82d56b7c 100644 --- a/kt-kernel/CMakeLists.txt +++ b/kt-kernel/CMakeLists.txt @@ -495,7 +495,7 @@ if(NOT DEFINED CLANG_FORMAT_BIN) ) endif() if(NOT CLANG_FORMAT_BIN) - message(WARNING "clang-format not found. Please install clang-format (>=18) or pass -DCLANG_FORMAT_BIN=/full/path and reconfigure.") + message(WARNING "Developer-only: clang-format not found. Please install clang-format (>=18) or pass -DCLANG_FORMAT_BIN=/full/path and reconfigure.") else() execute_process( COMMAND ${CLANG_FORMAT_BIN} --version diff --git a/kt-kernel/CMakePresets.json b/kt-kernel/CMakePresets.json index 1c3b00b9..c8db5fc8 100644 --- a/kt-kernel/CMakePresets.json +++ b/kt-kernel/CMakePresets.json @@ -39,6 +39,20 @@ "KTRANSFORMERS_CPU_USE_AMX_AVX512": "ON", "KTRANSFORMERS_USE_CUDA": "ON" } + }, + { + "name": "amd", + "displayName": "amd_platform", + "description": "Preset for AMD platforms (AVX2, no AMX/AVX-512)", + "cacheVariables": { + "KTRANSFORMERS_CPU_USE_AMX": "OFF", + "LLAMA_AVX512": "OFF", + "LLAMA_AVX2": "ON", + "KTRANSFORMERS_CPU_USE_AMX_AVX512": "OFF", + "KTRANSFORMERS_USE_CUDA": "ON", + "KTRANSFORMERS_CPU_MOE_AMD": "ON", + "KTRANSFORMERS_CPU_MOE_KERNEL": "ON" + } } ] diff --git a/kt-kernel/operators/moe_kernel/la/kernel.hpp b/kt-kernel/operators/moe_kernel/la/kernel.hpp index 34d55fc0..bc685e38 100644 --- a/kt-kernel/operators/moe_kernel/la/kernel.hpp +++ b/kt-kernel/operators/moe_kernel/la/kernel.hpp @@ -340,7 +340,7 @@ struct GemmKernelInt8 { static inline const int PACK_SIZE_M = 8; static inline const int PACK_SIZE_K = 32; - static std::string name() { return "INT8"; } + static std::string name() { return "MOE_INT8"; } static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; } // type_: d for decode, p for prefill static int recommended_nth_down(int n, char type_ = 'd') { @@ -833,7 +833,7 @@ struct GemmKernelInt4 { static inline const int PACK_SIZE_K = 32; static inline const int PACK_SIZE_M = 8; - static std::string name() { return "INT4"; } + static std::string name() { return "MOE_INT4"; } static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; } static int recommended_nth_down(int n, char type_ = 'd') { diff --git a/kt-kernel/operators/moe_kernel/moe.hpp b/kt-kernel/operators/moe_kernel/moe.hpp index c5d3acbc..2ff82853 100644 --- a/kt-kernel/operators/moe_kernel/moe.hpp +++ b/kt-kernel/operators/moe_kernel/moe.hpp @@ -57,6 +57,9 @@ class MOE_KERNEL_TP std::vector> down_bb_; std::vector> down_bc_; + std::vector<void*> gate_up_owner_ptr_; + std::vector<void*> down_owner_ptr_; + inline void write_weights(std::filesystem::path prefix, std::string mat_class, char* bb, int expert_idx, size_t size, size_t scale_size) { // printf("expert %d, size %ld, scale size %ld\n", expert_idx, size, scale_size); @@ -182,6 +185,7 @@ class MOE_KERNEL_TP down_ba_.push_back(std::make_shared(config_.max_len, config_.intermediate_size, nullptr)); down_bc_.push_back(std::make_shared(config_.max_len, config_.hidden_size, nullptr)); void* gate_up_down_per_exp_ptr = std::aligned_alloc(64, gate_up_exp_size); + gate_up_owner_ptr_.push_back(gate_up_down_per_exp_ptr); gate_bb_.push_back(std::make_shared(config_.intermediate_size, config_.hidden_size, gate_up_down_per_exp_ptr, PACKED, 'u', PLAIN)); @@ -193,6 +197,7 @@ class MOE_KERNEL_TP void* down_bb_ptr = std::aligned_alloc( 64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, PACKED, 'd', PLAIN)); + down_owner_ptr_.push_back(down_bb_ptr);
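The new `gate_up_owner_ptr_` / `down_owner_ptr_` vectors record the raw `std::aligned_alloc` buffers handed to the `BufferB` objects, so the destructor change further down in this diff can release memory that was previously never freed. A minimal sketch of that ownership pattern, using an illustrative stand-alone class rather than the actual `MOE_KERNEL_TP`:

```cpp
#include <cstddef>
#include <cstdlib>
#include <vector>

// Illustrative stand-in for the gate_up_owner_ptr_/down_owner_ptr_ bookkeeping:
// keep every aligned allocation in a vector and free it exactly once on teardown.
class AlignedOwner {
 public:
  void* allocate(std::size_t alignment, std::size_t size) {
    // std::aligned_alloc requires size to be a multiple of alignment.
    void* p = std::aligned_alloc(alignment, size);
    if (p != nullptr) owned_.push_back(p);
    return p;
  }
  ~AlignedOwner() {
    for (void* p : owned_) std::free(p);  // aligned_alloc memory must go through std::free
  }
 private:
  std::vector<void*> owned_;
};
```

Memory obtained from `std::aligned_alloc` must be released with `std::free`, which is exactly what the destructor added in the next hunk does with these recorded pointers.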
down_bb_.push_back(std::make_shared(config_.hidden_size, config_.intermediate_size, down_bb_ptr, PACKED, 'd', PLAIN)); } @@ -220,27 +225,41 @@ class MOE_KERNEL_TP ~MOE_KERNEL_TP() { // printf(" Destroying KML_MOE_TP %lx\n", (intptr_t)(this)); + for (void* ptr : gate_up_owner_ptr_) { + std::free(ptr); + } + for (void* ptr : down_owner_ptr_) { + std::free(ptr); + } } void load_weights() { auto pool = config_.pool->get_subpool(tp_part_idx); const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map; if (config_.gate_projs.size()) { + printf("load from safetensor"); pool->do_work_stealing_job( config_.expert_num, nullptr, [this, physical_to_logical_map](int expert_id) { uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_id); { size_t scale_size = config_.intermediate_size * sizeof(float); - size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size) - scale_size; + size_t whole_size_ = + T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN); + size_t size = whole_size_ - scale_size; + void* dst_ = PLAIN ? gate_bb_[expert_id]->b_pack[0] : gate_bb_[expert_id]->b; - memcpy(gate_bb_[expert_id]->b, config_.gate_projs[tp_part_idx][logical_expert_id], size); + memcpy(dst_, config_.gate_projs[tp_part_idx][logical_expert_id], size); if constexpr (T::BufferB::SCALE) { memcpy(gate_bb_[expert_id]->d, config_.gate_scales[tp_part_idx][logical_expert_id], scale_size); } - memcpy(up_bb_[expert_id]->b, config_.up_projs[tp_part_idx][logical_expert_id], size); + whole_size_ = + T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN); + size = whole_size_ - scale_size; + dst_ = PLAIN ? up_bb_[expert_id]->b_pack[0] : up_bb_[expert_id]->b; + memcpy(dst_, config_.up_projs[tp_part_idx][logical_expert_id], size); if constexpr (T::BufferB::SCALE) { memcpy(up_bb_[expert_id]->d, config_.up_scales[tp_part_idx][logical_expert_id], scale_size); @@ -249,9 +268,11 @@ class MOE_KERNEL_TP { size_t scale_size = config_.hidden_size * sizeof(float); - size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size) - scale_size; - - memcpy(down_bb_[expert_id]->b, config_.down_projs[tp_part_idx][logical_expert_id], size); + size_t whole_size_ = + T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, PACKED, 'd', PLAIN); + size_t size = whole_size_ - scale_size; + void* dst_ = PLAIN ? 
down_bb_[expert_id]->b_pack[0] : down_bb_[expert_id]->b; + memcpy(dst_, config_.down_projs[tp_part_idx][logical_expert_id], size); if constexpr (T::BufferB::SCALE) { memcpy(down_bb_[expert_id]->d, config_.down_scales[tp_part_idx][logical_expert_id], scale_size); @@ -270,19 +291,19 @@ class MOE_KERNEL_TP uint8_t mat_split_idex = task_id % mat_split; uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx); if (mat_class == 0) { // the up matrix - size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size); + size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN); size_t scale_size = config_.intermediate_size * sizeof(float); - read_weights(prefix, "_up_", (char*)up_bb_[expert_idx]->b, logical_expert_id, size, scale_size, mat_split, - mat_split_idex); + read_weights(prefix, "_up_", (char*)up_bb_[expert_idx]->b_pack[0], logical_expert_id, size, scale_size, + mat_split, mat_split_idex); } else if (mat_class == 1) { - size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size); + size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN); size_t scale_size = config_.intermediate_size * sizeof(float); - read_weights(prefix, "_gate_", (char*)gate_bb_[expert_idx]->b, logical_expert_id, size, scale_size, + read_weights(prefix, "_gate_", (char*)gate_bb_[expert_idx]->b_pack[0], logical_expert_id, size, scale_size, mat_split, mat_split_idex); } else { - size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size); + size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, PACKED, 'd', PLAIN); size_t scale_size = config_.hidden_size * sizeof(float); - read_weights(prefix, "_down_", (char*)down_bb_[expert_idx]->b, logical_expert_id, size, scale_size, + read_weights(prefix, "_down_", (char*)down_bb_[expert_idx]->b_pack[0], logical_expert_id, size, scale_size, mat_split, mat_split_idex); } } @@ -342,17 +363,20 @@ class MOE_KERNEL_TP expert_idx = expert_map(physical_to_logical_map, expert_idx); uint8_t mat_class = task_id % mat_type_all; if (mat_class == 0) { // the up matrix - size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size); + size_t size = + T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN); size_t scale_size = config_.intermediate_size * sizeof(float); - write_weights(prefix, "_up_", (char*)up_bb_[expert_idx]->b, expert_idx, size, scale_size); + write_weights(prefix, "_up_", (char*)up_bb_[expert_idx]->b_pack[0], expert_idx, size, scale_size); } else if (mat_class == 1) { - size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size); + size_t size = + T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN); size_t scale_size = config_.intermediate_size * sizeof(float); - write_weights(prefix, "_gate_", (char*)gate_bb_[expert_idx]->b, expert_idx, size, scale_size); + write_weights(prefix, "_gate_", (char*)gate_bb_[expert_idx]->b_pack[0], expert_idx, size, scale_size); } else if (mat_class == 2) { - size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size); + size_t size = + T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, PACKED, 'd', PLAIN); size_t scale_size = config_.hidden_size * sizeof(float); - write_weights(prefix, "_down_", (char*)down_bb_[expert_idx]->b, 
expert_idx, size, scale_size); + write_weights(prefix, "_down_", (char*)down_bb_[expert_idx]->b_pack[0], expert_idx, size, scale_size); } }, nullptr); @@ -691,6 +715,10 @@ class TP_MOE> : public TP_MOE_Common> { delete[] (ggml_bf16_t*)(tpc.up_proj); delete[] (ggml_bf16_t*)(tpc.down_proj); } + if (config.save) { + // free the bf16 weights after saving + tps.clear(); + } this->weights_loaded = true; } else if (config.path != "") { diff --git a/kt-kernel/python/experts.py b/kt-kernel/python/experts.py index 55fb4915..78807eeb 100644 --- a/kt-kernel/python/experts.py +++ b/kt-kernel/python/experts.py @@ -19,6 +19,7 @@ # Import backend implementations from .utils.amx import AMXMoEWrapper from .utils.llamafile import LlamafileMoEWrapper +from .utils.moe_kernel import GeneralMoEWrapper class KTMoEWrapper: @@ -76,7 +77,7 @@ def __new__( chunked_prefill_size: Maximum prefill chunk size cpu_save: Whether to save weights to CPU memory max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0. - method: Backend method ("AMXINT4", "AMXINT8", "LLAMAFILE") + method: Backend method ("AMXINT4", "AMXINT8", "LLAMAFILE", "MOE_INT4", "MOE_INT8") Returns: An instance of the appropriate backend implementation (e.g., AMXMoEWrapper) @@ -86,6 +87,8 @@ def __new__( backend_cls = AMXMoEWrapper elif method == "LLAMAFILE": backend_cls = LlamafileMoEWrapper + elif method in ["MOE_INT4", "MOE_INT8"]: + backend_cls = GeneralMoEWrapper else: raise NotImplementedError(f"Unsupported method: {method}") diff --git a/kt-kernel/python/utils/moe_kernel.py b/kt-kernel/python/utils/moe_kernel.py new file mode 100644 index 00000000..a238b00f --- /dev/null +++ b/kt-kernel/python/utils/moe_kernel.py @@ -0,0 +1,315 @@ +import os +import torch +import ctypes + +# Use relative imports for package structure +from ..experts_base import BaseMoEWrapper +from .loader import SafeTensorLoader +from kt_kernel_ext.moe import MOEConfig + +try: + from kt_kernel_ext.moe import Int8_KERNEL_MOE + + _HAS_INT8_SUPPORT = True +except (ImportError, AttributeError): + Int8_KERNEL_MOE = None + _HAS_INT8_SUPPORT = False +try: + from kt_kernel_ext.moe import Int4_KERNEL_MOE + + _HAS_INT4_SUPPORT = True +except (ImportError, AttributeError): + Int4_KERNEL_MOE = None + _HAS_INT4_SUPPORT = False + +from typing import Optional + + +class GeneralMoEWrapper(BaseMoEWrapper): + """ + General MoE kernel wrapper implementation. + Supports MOE_INT4 and MOE_INT8 quantization methods. + """ + + _safetensor_loader_instance = None # Singleton SafeTensorLoader + + def __init__( + self, + layer_idx: int, + num_experts: int, + num_experts_per_tok: int, + hidden_size: int, + moe_intermediate_size: int, + num_gpu_experts: int, + cpuinfer_threads: int, + threadpool_count: int, + weight_path: str, + chunked_prefill_size: int, + cpu_save: bool = False, + max_deferred_experts_per_token: Optional[int] = None, + method: str = "MOE_INT8", + ): + """ + Initialize the general MoE wrapper. + + Args: + layer_idx: Layer index + num_experts: Total number of experts + num_experts_per_tok: Number of experts per token (top-k) + hidden_size: Hidden dimension size + moe_intermediate_size: MoE intermediate size + num_gpu_experts: Number of experts to run on GPU + cpuinfer_threads: Number of CPU inference threads + threadpool_count: Number of NUMA subpools + weight_path: Path to weights (SafeTensor format) + chunked_prefill_size: Maximum prefill chunk size + cpu_save: Whether to save weights to CPU memory + max_deferred_experts_per_token: Number of experts per token to defer.
Defaults to 0. + method: general quantization method ("MOE_INT4" or "MOE_INT8") + """ + if not _HAS_INT4_SUPPORT and method == "MOE_INT4": + raise RuntimeError( + "MoE_INT4 backend not available. kt_kernel_ext was not compiled with int4 support.\n" + "Please recompile with int4 enabled." + ) + if not _HAS_INT8_SUPPORT and method == "MOE_INT8": + raise RuntimeError( + "MoE_INT8 backend not available. kt_kernel_ext was not compiled with int8 support.\n" + "Please recompile with int8 enabled." + ) + + # Initialize base class + super().__init__( + layer_idx=layer_idx, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + hidden_size=hidden_size, + moe_intermediate_size=moe_intermediate_size, + num_gpu_experts=num_gpu_experts, + cpuinfer_threads=cpuinfer_threads, + threadpool_count=threadpool_count, + weight_path=weight_path, + chunked_prefill_size=chunked_prefill_size, + cpu_save=cpu_save, + max_deferred_experts_per_token=max_deferred_experts_per_token, + method=method, + ) + + # AMX-specific: Check if we should load merged safetensor weights + self.load_merged_weight = False + import glob + + if glob.glob(os.path.join(weight_path, "*.safetensors")): + self.load_merged_weight = True + + # Initialize SafeTensor loader (singleton) + if self.load_merged_weight: + if GeneralMoEWrapper._safetensor_loader_instance is None: + GeneralMoEWrapper._safetensor_loader_instance = SafeTensorLoader(weight_path) + self.safetensor_loader = GeneralMoEWrapper._safetensor_loader_instance + + # AMX-specific weight storage + self.gate_weights = None + self.up_weights = None + self.down_weights = None + self.gate_scales = None + self.up_scales = None + self.down_scales = None + + def load_weights_from_tensors( + self, + gate_proj: torch.Tensor, + up_proj: torch.Tensor, + down_proj: torch.Tensor, + physical_to_logical_map_cpu: torch.Tensor, + ): + """ + Load and quantize weights from BF16/FP16 tensors (online quantization). 
+ + Args: + gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size] + up_proj: Up projection weights [num_experts, intermediate_size, hidden_size] + down_proj: Down projection weights [num_experts, hidden_size, intermediate_size] + physical_to_logical_map_cpu: Mapping from physical to logical expert IDs + """ + # Store tensors as instance variables to keep them alive + self.gate_proj = gate_proj.contiguous() + self.up_proj = up_proj.contiguous() + self.down_proj = down_proj.contiguous() + + # Configure MoE with online quantization (cpu_save mode) + moe_config = MOEConfig( + self.num_experts, + self.num_experts_per_tok, + self.hidden_size, + self.moe_intermediate_size, + self.num_gpu_experts, + ) + moe_config.layer_idx = self.layer_idx + moe_config.pool = self.cpu_infer.backend_ + moe_config.max_len = self.chunked_prefill_size + + # Enable save mode for online quantization + moe_config.save = True + moe_config.load = False + + # Set weight pointers + moe_config.gate_proj = self.gate_proj.data_ptr() + moe_config.up_proj = self.up_proj.data_ptr() + moe_config.down_proj = self.down_proj.data_ptr() + + # Set output path for quantized weights + moe_config.path = self.weight_path + + # Create MoE module based on method + if self.method == "MOE_INT4": + self.moe = Int4_KERNEL_MOE(moe_config) + elif self.method == "MOE_INT8": + self.moe = Int8_KERNEL_MOE(moe_config) + else: + raise NotImplementedError(f"Unsupported MoE method: {self.method}") + + # Submit quantization and save task + self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr())) + self.cpu_infer.sync() + + def load_weights(self, physical_to_logical_map_cpu: torch.Tensor): + """ + Load weights for this layer and initialize the MoE module. + + Args: + physical_to_logical_map_cpu: Mapping from physical to logical expert IDs + """ + gate_ptr = 0 + up_ptr = 0 + down_ptr = 0 + + gate_ptrs = [] + up_ptrs = [] + down_ptrs = [] + + gate_scale_ptrs = [] + up_scale_ptrs = [] + down_scale_ptrs = [] + + if self.load_merged_weight: + base_key = f"blk.{self.layer_idx}" + w = self.safetensor_loader.load_experts(base_key) + + self.gate_weights = w["gate"] + self.up_weights = w["up"] + self.down_weights = w["down"] + self.gate_scales = w["gate_scale"] + self.up_scales = w["up_scale"] + self.down_scales = w["down_scale"] + + # Get pointers to weight arrays + gate_ptrs = [ + [ + ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents) + for et in numa_array + ] + for numa_array in self.gate_weights + ] + + up_ptrs = [ + [ + ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents) + for et in numa_array + ] + for numa_array in self.up_weights + ] + + down_ptrs = [ + [ + ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents) + for et in numa_array + ] + for numa_array in self.down_weights + ] + + gate_scale_ptrs = [ + [ + ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents) + for et in numa_array + ] + for numa_array in self.gate_scales + ] + + up_scale_ptrs = [ + [ + ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents) + for et in numa_array + ] + for numa_array in self.up_scales + ] + + down_scale_ptrs = [ + [ + ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents) + for et in numa_array + ] + for numa_array in self.down_scales + ] + + # Configure MoE + moe_config = MOEConfig( + self.num_experts, + 
self.num_experts_per_tok, + self.hidden_size, + self.moe_intermediate_size, + self.num_gpu_experts, + ) + moe_config.layer_idx = self.layer_idx + moe_config.pool = self.cpu_infer.backend_ + moe_config.max_len = self.chunked_prefill_size + + moe_config.gate_proj = gate_ptr + moe_config.up_proj = up_ptr + moe_config.down_proj = down_ptr + moe_config.gate_projs = gate_ptrs + moe_config.up_projs = up_ptrs + moe_config.down_projs = down_ptrs + moe_config.gate_scales = gate_scale_ptrs + moe_config.up_scales = up_scale_ptrs + moe_config.down_scales = down_scale_ptrs + + if self.cpu_save: + moe_config.save = True + moe_config.load = False + base_key = f"model.layers.{self.layer_idx}" + w = self.safetensor_loader.load_experts(base_key) + + self.gate_proj = torch.cat(w["gate_weight"], dim=0).contiguous() + self.up_proj = torch.cat(w["up_weight"], dim=0).contiguous() + self.down_proj = torch.cat(w["down_weight"], dim=0).contiguous() + + moe_config.gate_proj = self.gate_proj.data_ptr() + moe_config.up_proj = self.up_proj.data_ptr() + moe_config.down_proj = self.down_proj.data_ptr() + else: + moe_config.load = True + + if not self.load_merged_weight: + moe_config.path = self.weight_path + + # Create MoE module based on AMX method + if self.method == "MOE_INT4": + self.moe = Int4_KERNEL_MOE(moe_config) + elif self.method == "MOE_INT8": + self.moe = Int8_KERNEL_MOE(moe_config) + else: + raise NotImplementedError(f"Unsupported MoE method: {self.method}") + + # Load weights + self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr())) + self.cpu_infer.sync() + + # Clean up temporary weight storage if using merged weights + if self.load_merged_weight: + del self.gate_weights + del self.up_weights + del self.down_weights + del self.gate_scales + del self.up_scales + del self.down_scales diff --git a/kt-kernel/scripts/convert_cpu_weights.py b/kt-kernel/scripts/convert_cpu_weights.py index 92f3a442..a8444d95 100644 --- a/kt-kernel/scripts/convert_cpu_weights.py +++ b/kt-kernel/scripts/convert_cpu_weights.py @@ -606,6 +606,8 @@ def _load_layer_tensors_from_disk(self, layer_idx: int) -> Dict[str, torch.Tenso quant_to_amx_map = { "int4": "INT4", "int8": "INT8", + "moe_int4": "MOE_INT4", + "moe_int8": "MOE_INT8", } amx_method = quant_to_amx_map.get(self.quant_method, "INT4") @@ -613,6 +615,7 @@ def _load_layer_tensors_from_disk(self, layer_idx: int) -> Dict[str, torch.Tenso for numa_idx in range(self.threadpool_count): numa_folder = os.path.join(layer_path, f"_numa_{numa_idx}") if not os.path.exists(numa_folder): + print(f" Warning: NUMA folder not found: {numa_folder}, skipping...") continue # Iterate through all experts @@ -746,6 +749,8 @@ def _convert_layer_experts(self, layer_idx: int, expert_ids: List[int]) -> Dict[ quant_to_amx_map = { "int4": "AMXINT4", "int8": "AMXINT8", + "moe_int4": "MOE_INT4", + "moe_int8": "MOE_INT8", } amx_method = quant_to_amx_map.get(self.quant_method, "AMXINT4") @@ -817,7 +822,7 @@ def main(): parser.add_argument("--output", "-o", required=True, help="Output directory for converted safetensors") parser.add_argument( "--quant-method", - choices=["int4", "int8", "awq"], + choices=["int4", "int8", "awq", "moe_int4", "moe_int8"], default="int4", help="Quantization method for output (default: int4)", ) @@ -875,7 +880,7 @@ def main(): input_type=None, merge_to_safetensor=merge_to_safetensor, ) - elif quant_method in ["int4", "int8"] and args.input_type in ["fp8", "fp16", "bf16"]: + elif quant_method in ["int4", "int8", "moe_int4", "moe_int8"] and 
args.input_type in ["fp8", "fp16", "bf16"]: # Use OnlineQuantConverter for both INT4 and INT8 quantization converter = OnlineQuantConverter( args.input_path, diff --git a/kt-kernel/setup.py b/kt-kernel/setup.py index 3860f35c..8ed4abf9 100644 --- a/kt-kernel/setup.py +++ b/kt-kernel/setup.py @@ -28,6 +28,7 @@ CPUINFER_LTO_MODE=auto Forward to -DCPUINFER_LTO_MODE CPUINFER_NATIVE=ON (override LLAMA_NATIVE) + GPU backends (if ever added later, keep placeholders): CPUINFER_USE_CUDA=0/1 -DKTRANSFORMERS_USE_CUDA CPUINFER_USE_ROCM=0/1 -DKTRANSFORMERS_USE_ROCM @@ -51,6 +52,43 @@ from setuptools.command.build_ext import build_ext import shutil +# ------------------------- +# Env parsing helpers +# ------------------------- +def _env_get_bool(name: str, default: bool | None = None) -> bool | None: + v = os.environ.get(name) + if v is None: + return default + val = v.strip().lower() + if val in ("1", "on", "true", "yes", "y", "enable", "enabled"): + return True + if val in ("0", "off", "false", "no", "n", "disable", "disabled"): + return False + return default + + +def _cmake_onoff(flag: bool) -> str: + return "ON" if flag else "OFF" + + +def _forward_bool_env(cmake_args: list[str], env_name: str, cmake_flag: str) -> bool: + """If env exists, forward it to CMake as -D=ON/OFF and return True; else return False.""" + b = _env_get_bool(env_name, None) + if b is None: + return False + cmake_args.append(f"-D{cmake_flag}={_cmake_onoff(b)}") + print(f"-- Forward {env_name} -> -D{cmake_flag}={_cmake_onoff(b)}") + return True + + +def _forward_str_env(cmake_args: list[str], env_name: str, cmake_flag: str) -> bool: + v = os.environ.get(env_name) + if not v: + return False + cmake_args.append(f"-D{cmake_flag}={v}") + print(f"-- Forward {env_name} -> -D{cmake_flag}={v}") + return True + ################################################################################ # Helpers ################################################################################ @@ -204,7 +242,34 @@ def detect_cuda_toolkit() -> bool: return True return False - if os.environ.get("CPUINFER_USE_CUDA") is None: + # Locate nvcc executable (without forcing user to set -DCMAKE_CUDA_COMPILER) + def find_nvcc_path() -> str | None: + cuda_home = os.environ.get("CUDA_HOME") + if cuda_home: + cand = Path(cuda_home) / "bin" / "nvcc" + if cand.exists(): + return str(cand) + which_nvcc = shutil.which("nvcc") + if which_nvcc: + return which_nvcc + # Common fallbacks (ordered by preference) + for cand in [ + "/usr/local/cuda-12.6/bin/nvcc", + "/usr/local/cuda/bin/nvcc", + "/usr/bin/nvcc", + "/usr/lib/nvidia-cuda-toolkit/bin/nvcc", + ]: + if Path(cand).exists(): + return cand + return None + + # Note: We no longer set CMAKE_CUDA_ARCHITECTURES by default. + # If users want to specify CUDA archs, they can set env CPUINFER_CUDA_ARCHS + # (e.g. "89" or "86;89") or pass it via CMAKE_ARGS. 
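The `_env_get_bool` / `_cmake_onoff` / `_forward_bool_env` / `_forward_str_env` helpers added earlier in this setup.py diff normalise the many spellings people use for boolean environment variables before they are forwarded as `-D<flag>=ON/OFF`. A small self-contained sketch of the accepted spellings (a re-implementation for illustration, not an import from setup.py):

```python
import os

def env_bool(name, default=None):
    """Mirror of setup.py's _env_get_bool: accept common truthy/falsey spellings."""
    v = os.environ.get(name)
    if v is None:
        return default
    val = v.strip().lower()
    if val in ("1", "on", "true", "yes", "y", "enable", "enabled"):
        return True
    if val in ("0", "off", "false", "no", "n", "disable", "disabled"):
        return False
    return default

for spelling in ("1", "ON", "Yes", "enabled", "0", "off", "nonsense"):
    os.environ["CPUINFER_USE_CUDA"] = spelling
    print(spelling, "->", env_bool("CPUINFER_USE_CUDA", default=False))
# Unrecognised values such as "nonsense" fall back to the default rather than raising.
```

Anything outside the recognised truthy/falsey sets silently falls back to the default, so a typo in an env var cannot accidentally flip a feature on.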
+ auto_moe_kernel_ = False + # Normalize CPUINFER_USE_CUDA: if unset, auto-detect; otherwise respect truthy/falsey values + cuda_env = _env_get_bool("CPUINFER_USE_CUDA", None) + if cuda_env is None: auto_cuda = detect_cuda_toolkit() os.environ["CPUINFER_USE_CUDA"] = "1" if auto_cuda else "0" print(f"-- CPUINFER_USE_CUDA not set; auto-detected CUDA toolkit: {'YES' if auto_cuda else 'NO'}") @@ -228,56 +293,86 @@ def detect_cuda_toolkit() -> bool: print(f"Detected CPU info: {d}") # Vendor / feature specific toggles - # Enable AMD MoE kernel on AMD by default unless user explicitly set CPUINFER_ENABLE_AMD - # temporarily disabled this opt, use llamafile backend for now - # if d.get("vendor") == "amd" and os.environ.get("CPUINFER_ENABLE_AMD") is None: - # cmake_args.append("-DKTRANSFORMERS_CPU_MOE_AMD=ON") - # print("-- Detected AMD CPU; enabling AMD MoE kernel (-DKTRANSFORMERS_CPU_MOE_AMD=ON)") - - # On ARM, enable KML by default if not explicitly toggled - if d.get("vendor") == "arm" and os.environ.get("CPUINFER_ENABLE_KML") is None: - cmake_args.append("-DKTRANSFORMERS_CPU_USE_KML=ON") - print("-- Detected ARM CPU; enabling KML (-DKTRANSFORMERS_CPU_USE_KML=ON)") - - # If AMX or AVX512 present, enable umbrella unless overridden; enable AMX specifically when present - if "AMX" in d["features"]: - if os.environ.get("CPUINFER_ENABLE_AMX") is None: + # AMD MoE: explicit env overrides; otherwise default ON on AMD CPU + if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_AMD", "KTRANSFORMERS_CPU_MOE_AMD"): + if d.get("vendor") == "amd": + auto_moe_kernel_ = True + cmake_args.append("-DKTRANSFORMERS_CPU_MOE_AMD=ON") + print("-- Detected AMD CPU; enabling AMD MoE kernel (-DKTRANSFORMERS_CPU_MOE_AMD=ON)") + + # KML: explicit env overrides; otherwise default ON on ARM + if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_KML", "KTRANSFORMERS_CPU_USE_KML"): + if d.get("vendor") == "arm": + auto_moe_kernel_ = True + cmake_args.append("-DKTRANSFORMERS_CPU_USE_KML=ON") + print("-- Detected ARM CPU; enabling KML (-DKTRANSFORMERS_CPU_USE_KML=ON)") + + # AMX: explicit env overrides; else enable if detected + if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_AMX", "KTRANSFORMERS_CPU_USE_AMX"): + if "AMX" in d["features"]: cmake_args.append("-DKTRANSFORMERS_CPU_USE_AMX=ON") print("-- AMX support detected; enabling (-DKTRANSFORMERS_CPU_USE_AMX=ON)") - if ("AMX" in d["features"] or "AVX512" in d["features"]) and os.environ.get( - "CPUINFER_ENABLE_AVX512" - ) is None: - cmake_args.append("-DKTRANSFORMERS_CPU_USE_AMX_AVX512=ON") - print("-- Enabling AMX/AVX512 umbrella (-DKTRANSFORMERS_CPU_USE_AMX_AVX512=ON)") + # AVX512 umbrella: explicit env overrides; else enable if AMX or AVX512 detected + if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_AVX512", "KTRANSFORMERS_CPU_USE_AMX_AVX512"): + if "AMX" in d["features"] or "AVX512" in d["features"]: + cmake_args.append("-DKTRANSFORMERS_CPU_USE_AMX_AVX512=ON") + print("-- Enabling AMX/AVX512 umbrella (-DKTRANSFORMERS_CPU_USE_AMX_AVX512=ON)") + + # Auto-enable MOE kernel only when env explicitly turns on AMD or KML backend + # (Do not enable purely on vendor auto-detection to avoid surprise behavior.) 
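The feature toggles above (AMD MoE, KML, AMX, the AVX-512 umbrella) all follow the same precedence: an explicit `CPUINFER_*` variable, when present, is forwarded verbatim, and only when it is absent does CPU detection choose a default. A hedged sketch of that precedence; function and variable names here are illustrative, not the real setup.py code:

```python
import os

def resolve_flag(env_name: str, auto_default: bool) -> bool:
    """Explicit env setting wins; otherwise fall back to auto-detection."""
    v = os.environ.get(env_name)
    if v is not None:
        return v.strip().lower() in ("1", "on", "true", "yes", "y", "enable", "enabled")
    return auto_default  # e.g. vendor == "amd" for KTRANSFORMERS_CPU_MOE_AMD

# Example: AMD CPU detected, but the user explicitly disables the AMD MoE kernel.
os.environ["CPUINFER_ENABLE_AMD"] = "off"
flag = resolve_flag("CPUINFER_ENABLE_AMD", auto_default=True)
print(f"-DKTRANSFORMERS_CPU_MOE_AMD={'ON' if flag else 'OFF'}")  # -> OFF
```

In the example, auto-detection would have enabled the AMD MoE kernel, but the explicit `off` wins.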
+ amd_env = _env_get_bool("CPUINFER_ENABLE_AMD", None) + kml_env = _env_get_bool("CPUINFER_ENABLE_KML", None) + if amd_env or kml_env: + auto_moe_kernel_ = True + already_set = any("KTRANSFORMERS_CPU_MOE_KERNEL" in a for a in cmake_args) + if not already_set and auto_moe_kernel_: + cmake_args.append("-DKTRANSFORMERS_CPU_MOE_KERNEL=ON") + print("-- Auto-enabling MOE kernel (-DKTRANSFORMERS_CPU_MOE_KERNEL=ON) because CPUINFER_ENABLE_AMD or CPUINFER_ENABLE_KML is ON") # Friendly summary print( f"-- CPU detection: vendor={d.get('vendor')} arch={d.get('arch')} features={sorted(list(d.get('features', [])))}" ) - # Optional AMX / MLA toggles (explicit env overrides auto detection above) - if os.environ.get("CPUINFER_ENABLE_AMX"): - cmake_args.append(f"-DKTRANSFORMERS_CPU_USE_AMX={os.environ['CPUINFER_ENABLE_AMX']}") - if os.environ.get("CPUINFER_ENABLE_KML"): - cmake_args.append(f"-DKTRANSFORMERS_CPU_USE_KML={os.environ['CPUINFER_ENABLE_KML']}") - if os.environ.get("CPUINFER_ENABLE_MLA"): - cmake_args.append(f"-DKTRANSFORMERS_CPU_MLA={os.environ['CPUINFER_ENABLE_MLA']}") - - # LTO toggles if user added them in CMakeLists - if os.environ.get("CPUINFER_ENABLE_LTO"): - cmake_args.append(f"-DCPUINFER_ENABLE_LTO={os.environ['CPUINFER_ENABLE_LTO']}") - if os.environ.get("CPUINFER_LTO_JOBS"): - cmake_args.append(f"-DCPUINFER_LTO_JOBS={os.environ['CPUINFER_LTO_JOBS']}") - if os.environ.get("CPUINFER_LTO_MODE"): - cmake_args.append(f"-DCPUINFER_LTO_MODE={os.environ['CPUINFER_LTO_MODE']}") + # MLA toggle (string/boolean allowed) + if not _forward_bool_env(cmake_args, "CPUINFER_ENABLE_MLA", "KTRANSFORMERS_CPU_MLA"): + _forward_str_env(cmake_args, "CPUINFER_ENABLE_MLA", "KTRANSFORMERS_CPU_MLA") + + # LTO toggles + _forward_bool_env(cmake_args, "CPUINFER_ENABLE_LTO", "CPUINFER_ENABLE_LTO") + _forward_str_env(cmake_args, "CPUINFER_LTO_JOBS", "CPUINFER_LTO_JOBS") + _forward_str_env(cmake_args, "CPUINFER_LTO_MODE", "CPUINFER_LTO_MODE") # GPU backends (mutually exclusive expected) - if os.environ.get("CPUINFER_USE_CUDA") == "1": + if _env_get_bool("CPUINFER_USE_CUDA", False): cmake_args.append("-DKTRANSFORMERS_USE_CUDA=ON") print("-- Enabling CUDA backend (-DKTRANSFORMERS_USE_CUDA=ON)") - if os.environ.get("CPUINFER_USE_ROCM") == "1": + # Inject nvcc compiler path automatically unless user already specified one. + user_specified_compiler = any("CMAKE_CUDA_COMPILER" in a for a in cmake_args) + if not user_specified_compiler: + extra_env = os.environ.get("CMAKE_ARGS", "") + if "CMAKE_CUDA_COMPILER" in extra_env: + user_specified_compiler = True + if not user_specified_compiler: + nvcc_path = find_nvcc_path() + if nvcc_path: + cmake_args.append(f"-DCMAKE_CUDA_COMPILER={nvcc_path}") + print(f"-- Auto-detected nvcc: {nvcc_path} (adding -DCMAKE_CUDA_COMPILER)") + else: + print("-- Warning: nvcc not found via CUDA_HOME/PATH/common prefixes; CUDA configure may fail.") + # Optional host compiler for nvcc if user set CUDAHOSTCXX + if os.environ.get("CUDAHOSTCXX"): + hostcxx = os.environ["CUDAHOSTCXX"] + cmake_args.append(f"-DCMAKE_CUDA_HOST_COMPILER={hostcxx}") + print(f"-- Using CUDA host compiler from CUDAHOSTCXX: {hostcxx}") + # Respect user-provided architectures only (no default auto-detection). 
+ archs_env = os.environ.get("CPUINFER_CUDA_ARCHS", "").strip() + if archs_env and not any("CMAKE_CUDA_ARCHITECTURES" in a for a in cmake_args): + cmake_args.append(f"-DCMAKE_CUDA_ARCHITECTURES={archs_env}") + print(f"-- Set CUDA architectures from CPUINFER_CUDA_ARCHS: {archs_env}") + if _env_get_bool("CPUINFER_USE_ROCM", False): cmake_args.append("-DKTRANSFORMERS_USE_ROCM=ON") - if os.environ.get("CPUINFER_USE_MUSA") == "1": + if _env_get_bool("CPUINFER_USE_MUSA", False): cmake_args.append("-DKTRANSFORMERS_USE_MUSA=ON") # Respect user extra CMAKE_ARGS (space separated) @@ -286,7 +381,7 @@ def detect_cuda_toolkit() -> bool: cmake_args += [a for a in extra.split() if a] # Force rebuild? (delete cache) - if os.environ.get("CPUINFER_FORCE_REBUILD") == "1": + if _env_get_bool("CPUINFER_FORCE_REBUILD", False): cache = build_temp / "CMakeCache.txt" if cache.exists(): cache.unlink()
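For context on how the renamed kernels are reached from Python: `KTMoEWrapper` (see the `experts.py` hunk above) dispatches `method="MOE_INT4"` / `"MOE_INT8"` to the new `GeneralMoEWrapper`. A usage sketch, assuming the package is importable as `kt_kernel` and using placeholder model dimensions rather than values taken from this PR:

```python
# Sketch only: the import path and the concrete sizes below are assumptions for
# illustration; the keyword arguments mirror GeneralMoEWrapper.__init__ in this diff.
from kt_kernel.experts import KTMoEWrapper  # adjust to the installed package layout

moe = KTMoEWrapper(
    layer_idx=0,
    num_experts=256,
    num_experts_per_tok=8,
    hidden_size=7168,
    moe_intermediate_size=2048,
    num_gpu_experts=0,
    cpuinfer_threads=32,
    threadpool_count=2,
    weight_path="/path/to/quantized/weights",
    chunked_prefill_size=8192,
    cpu_save=False,
    max_deferred_experts_per_token=0,
    method="MOE_INT8",  # or "MOE_INT4"; both route to GeneralMoEWrapper
)
```

Building with `KTRANSFORMERS_CPU_MOE_KERNEL=ON` (for example via the new `amd` preset, or by setting `CPUINFER_ENABLE_AMD=1` so setup.py auto-enables it) is what should make `Int4_KERNEL_MOE` / `Int8_KERNEL_MOE` importable for this wrapper.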