2 changes: 1 addition & 1 deletion kt-kernel/CMakeLists.txt
@@ -495,7 +495,7 @@ if(NOT DEFINED CLANG_FORMAT_BIN)
)
endif()
if(NOT CLANG_FORMAT_BIN)
message(WARNING "clang-format not found. Please install clang-format (>=18) or pass -DCLANG_FORMAT_BIN=/full/path and reconfigure.")
message(WARNING "ONLY for developer: clang-format not found. Please install clang-format (>=18) or pass -DCLANG_FORMAT_BIN=/full/path and reconfigure.")
else()
execute_process(
COMMAND ${CLANG_FORMAT_BIN} --version
14 changes: 14 additions & 0 deletions kt-kernel/CMakePresets.json
@@ -39,6 +39,20 @@
"KTRANSFORMERS_CPU_USE_AMX_AVX512": "ON",
"KTRANSFORMERS_USE_CUDA": "ON"
}
},
{
"name": "amd",
"displayName": "amd_platform",
"description": "for amd platform",
"cacheVariables": {
"KTRANSFORMERS_CPU_USE_AMX": "OFF",
"LLAMA_AVX512": "OFF",
"LLAMA_AVX2": "ON",
"KTRANSFORMERS_CPU_USE_AMX_AVX512": "OFF",
"KTRANSFORMERS_USE_CUDA": "ON",
"KTRANSFORMERS_CPU_MOE_AMD": "ON",
"KTRANSFORMERS_CPU_MOE_KERNEL": "ON"
}
}

]
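The new amd preset turns the AMX-specific options off and the AVX2 and generic MoE-kernel options (KTRANSFORMERS_CPU_MOE_KERNEL, KTRANSFORMERS_CPU_MOE_AMD) on: AMD CPUs do not implement AMX, so on that platform the build is expected to fall back to the AVX2 kernel path.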
4 changes: 2 additions & 2 deletions kt-kernel/operators/moe_kernel/la/kernel.hpp
@@ -340,7 +340,7 @@ struct GemmKernelInt8 {
static inline const int PACK_SIZE_M = 8;
static inline const int PACK_SIZE_K = 32;

static std::string name() { return "INT8"; }
static std::string name() { return "MOE_INT8"; }
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
// type_: d for decode, p for prefill
static int recommended_nth_down(int n, char type_ = 'd') {
@@ -833,7 +833,7 @@ struct GemmKernelInt4 {
static inline const int PACK_SIZE_K = 32;
static inline const int PACK_SIZE_M = 8;

static std::string name() { return "INT4"; }
static std::string name() { return "MOE_INT4"; }
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }

static int recommended_nth_down(int n, char type_ = 'd') {
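The renamed kernel identifiers ("MOE_INT8", "MOE_INT4") match the new method strings accepted by the Python wrapper in kt-kernel/python/experts.py further down.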
66 changes: 47 additions & 19 deletions kt-kernel/operators/moe_kernel/moe.hpp
@@ -57,6 +57,9 @@ class MOE_KERNEL_TP
std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;

std::vector<void*> gate_up_owner_ptr_;
std::vector<void*> down_owner_ptr_;

inline void write_weights(std::filesystem::path prefix, std::string mat_class, char* bb, int expert_idx, size_t size,
size_t scale_size) {
// printf("expert %d, size %ld, scale size %ld\n", expert_idx, size, scale_size);
@@ -182,6 +185,7 @@ class MOE_KERNEL_TP
down_ba_.push_back(std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, nullptr));
down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, nullptr));
void* gate_up_down_per_exp_ptr = std::aligned_alloc(64, gate_up_exp_size);
gate_up_owner_ptr_.push_back(gate_up_down_per_exp_ptr);

gate_bb_.push_back(std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size,
gate_up_down_per_exp_ptr, PACKED, 'u', PLAIN));
Expand All @@ -193,6 +197,7 @@ class MOE_KERNEL_TP

void* down_bb_ptr = std::aligned_alloc(
64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, PACKED, 'd', PLAIN));
down_owner_ptr_.push_back(down_bb_ptr);
down_bb_.push_back(std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size,
down_bb_ptr, PACKED, 'd', PLAIN));
}
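These two hunks record the raw std::aligned_alloc results in gate_up_owner_ptr_ and down_owner_ptr_ so the buffers handed to the BufferB objects can be released later; the matching std::free calls are added to the destructor in the next hunk. For comparison, a minimal sketch of the same ownership expressed with std::unique_ptr and a std::free deleter; the names here (AlignedBuffer, OwnerDemo) are illustrative, not kt-kernel types:

```cpp
// Illustrative only: RAII form of the owner-pointer bookkeeping above.
#include <cstdlib>
#include <memory>
#include <vector>

using AlignedBuffer = std::unique_ptr<void, decltype(&std::free)>;

inline AlignedBuffer make_aligned(std::size_t alignment, std::size_t size) {
  return AlignedBuffer(std::aligned_alloc(alignment, size), &std::free);
}

struct OwnerDemo {
  std::vector<AlignedBuffer> gate_up_owner_;  // freed automatically in ~OwnerDemo
  std::vector<AlignedBuffer> down_owner_;

  void add_expert(std::size_t gate_up_size, std::size_t down_size) {
    gate_up_owner_.push_back(make_aligned(64, gate_up_size));
    down_owner_.push_back(make_aligned(64, down_size));
    // BufferB instances would then wrap .back().get() instead of the raw pointer.
  }
};
```

The explicit pointer vectors plus std::free in the destructor achieve the same effect without touching the BufferB constructor signatures.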
@@ -220,27 +225,41 @@

~MOE_KERNEL_TP() {
// printf(" Destroying KML_MOE_TP %lx\n", (intptr_t)(this));
for (void* ptr : gate_up_owner_ptr_) {
std::free(ptr);
}
for (void* ptr : down_owner_ptr_) {
std::free(ptr);
}
}

void load_weights() {
auto pool = config_.pool->get_subpool(tp_part_idx);
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
if (config_.gate_projs.size()) {
printf("load from safetensor");
pool->do_work_stealing_job(
config_.expert_num, nullptr,
[this, physical_to_logical_map](int expert_id) {
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_id);
{
size_t scale_size = config_.intermediate_size * sizeof(float);
size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size) - scale_size;
size_t whole_size_ =
T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN);
size_t size = whole_size_ - scale_size;
void* dst_ = PLAIN ? gate_bb_[expert_id]->b_pack[0] : gate_bb_[expert_id]->b;

memcpy(gate_bb_[expert_id]->b, config_.gate_projs[tp_part_idx][logical_expert_id], size);
memcpy(dst_, config_.gate_projs[tp_part_idx][logical_expert_id], size);

if constexpr (T::BufferB::SCALE) {
memcpy(gate_bb_[expert_id]->d, config_.gate_scales[tp_part_idx][logical_expert_id], scale_size);
}

memcpy(up_bb_[expert_id]->b, config_.up_projs[tp_part_idx][logical_expert_id], size);
whole_size_ =
T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN);
size = whole_size_ - scale_size;
dst_ = PLAIN ? up_bb_[expert_id]->b_pack[0] : up_bb_[expert_id]->b;
memcpy(dst_, config_.up_projs[tp_part_idx][logical_expert_id], size);

if constexpr (T::BufferB::SCALE) {
memcpy(up_bb_[expert_id]->d, config_.up_scales[tp_part_idx][logical_expert_id], scale_size);
Expand All @@ -249,9 +268,11 @@ class MOE_KERNEL_TP

{
size_t scale_size = config_.hidden_size * sizeof(float);
size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size) - scale_size;

memcpy(down_bb_[expert_id]->b, config_.down_projs[tp_part_idx][logical_expert_id], size);
size_t whole_size_ =
T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, PACKED, 'd', PLAIN);
size_t size = whole_size_ - scale_size;
void* dst_ = PLAIN ? down_bb_[expert_id]->b_pack[0] : down_bb_[expert_id]->b;
memcpy(dst_, config_.down_projs[tp_part_idx][logical_expert_id], size);

if constexpr (T::BufferB::SCALE) {
memcpy(down_bb_[expert_id]->d, config_.down_scales[tp_part_idx][logical_expert_id], scale_size);
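In the safetensor load path the copy size is now derived from the full packed-layout signature: required_size(rows, cols, PACKED, tag, PLAIN) accounts for both the quantized weights and the per-row float scales, so the weight-only byte count is that total minus scale_size, and when PLAIN is set the bytes are copied into b_pack[0] instead of b. The disk read_weights/write_weights calls in the hunks below follow the same convention. A compact, self-contained sketch of the pattern, with DemoBuffer as an illustrative stand-in for the real BufferB:

```cpp
// Illustrative stand-in for BufferB: packed weights plus one float scale per row.
#include <cstring>
#include <vector>

template <bool PLAIN, bool SCALE>
struct DemoBuffer {
  std::vector<char>  packed;   // plays the role of b_pack[0]
  std::vector<char>  staging;  // plays the role of b
  std::vector<float> d;        // per-row scales

  static std::size_t required_size(std::size_t rows, std::size_t row_bytes) {
    return rows * row_bytes + (SCALE ? rows * sizeof(float) : 0);  // weights + scales
  }

  void load(const char* src_w, const float* src_s, std::size_t rows, std::size_t row_bytes) {
    const std::size_t scale_size  = SCALE ? rows * sizeof(float) : 0;
    const std::size_t weight_size = required_size(rows, row_bytes) - scale_size;  // weights only

    std::vector<char>& dst = PLAIN ? packed : staging;  // PLAIN targets the pre-packed region
    dst.resize(weight_size);
    std::memcpy(dst.data(), src_w, weight_size);

    if constexpr (SCALE) {
      d.resize(rows);
      std::memcpy(d.data(), src_s, scale_size);
    }
  }
};
```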
Expand All @@ -270,19 +291,19 @@ class MOE_KERNEL_TP
uint8_t mat_split_idex = task_id % mat_split;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
if (mat_class == 0) { // the up matrix
size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size);
size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN);
size_t scale_size = config_.intermediate_size * sizeof(float);
read_weights(prefix, "_up_", (char*)up_bb_[expert_idx]->b, logical_expert_id, size, scale_size, mat_split,
mat_split_idex);
read_weights(prefix, "_up_", (char*)up_bb_[expert_idx]->b_pack[0], logical_expert_id, size, scale_size,
mat_split, mat_split_idex);
} else if (mat_class == 1) {
size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size);
size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN);
size_t scale_size = config_.intermediate_size * sizeof(float);
read_weights(prefix, "_gate_", (char*)gate_bb_[expert_idx]->b, logical_expert_id, size, scale_size,
read_weights(prefix, "_gate_", (char*)gate_bb_[expert_idx]->b_pack[0], logical_expert_id, size, scale_size,
mat_split, mat_split_idex);
} else {
size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size);
size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, PACKED, 'd', PLAIN);
size_t scale_size = config_.hidden_size * sizeof(float);
read_weights(prefix, "_down_", (char*)down_bb_[expert_idx]->b, logical_expert_id, size, scale_size,
read_weights(prefix, "_down_", (char*)down_bb_[expert_idx]->b_pack[0], logical_expert_id, size, scale_size,
mat_split, mat_split_idex);
}
}
@@ -342,17 +363,20 @@ class MOE_KERNEL_TP
expert_idx = expert_map(physical_to_logical_map, expert_idx);
uint8_t mat_class = task_id % mat_type_all;
if (mat_class == 0) { // the up matrix
size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size);
size_t size =
T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN);
size_t scale_size = config_.intermediate_size * sizeof(float);
write_weights(prefix, "_up_", (char*)up_bb_[expert_idx]->b, expert_idx, size, scale_size);
write_weights(prefix, "_up_", (char*)up_bb_[expert_idx]->b_pack[0], expert_idx, size, scale_size);
} else if (mat_class == 1) {
size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size);
size_t size =
T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN);
size_t scale_size = config_.intermediate_size * sizeof(float);
write_weights(prefix, "_gate_", (char*)gate_bb_[expert_idx]->b, expert_idx, size, scale_size);
write_weights(prefix, "_gate_", (char*)gate_bb_[expert_idx]->b_pack[0], expert_idx, size, scale_size);
} else if (mat_class == 2) {
size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size);
size_t size =
T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, PACKED, 'd', PLAIN);
size_t scale_size = config_.hidden_size * sizeof(float);
write_weights(prefix, "_down_", (char*)down_bb_[expert_idx]->b, expert_idx, size, scale_size);
write_weights(prefix, "_down_", (char*)down_bb_[expert_idx]->b_pack[0], expert_idx, size, scale_size);
}
},
nullptr);
@@ -691,6 +715,10 @@ class TP_MOE<MOE_KERNEL_TP<K, T>> : public TP_MOE_Common<MOE_KERNEL_TP<K, T>> {
delete[] (ggml_bf16_t*)(tpc.up_proj);
delete[] (ggml_bf16_t*)(tpc.down_proj);
}
if (config.save) {
// free the bf16 weights after saving
tps.clear();
}

this->weights_loaded = true;
} else if (config.path != "") {
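When the run is only converting and saving weights (config.save), the per-TP kernel objects are cleared right after the save so the temporary bf16 copies do not stay resident.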
5 changes: 4 additions & 1 deletion kt-kernel/python/experts.py
@@ -19,6 +19,7 @@
# Import backend implementations
from .utils.amx import AMXMoEWrapper
from .utils.llamafile import LlamafileMoEWrapper
from .utils.moe_kernel import GeneralMoEWrapper


class KTMoEWrapper:
@@ -76,7 +77,7 @@ def __new__(
chunked_prefill_size: Maximum prefill chunk size
cpu_save: Whether to save weights to CPU memory
max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
method: Backend method ("AMXINT4", "AMXINT8", "LLAMAFILE")
method: Backend method ("AMXINT4", "AMXINT8", "LLAMAFILE", "MOE_INT4", "MOE_INT8")

Returns:
An instance of the appropriate backend implementation (e.g., AMXMoEWrapper)
Expand All @@ -86,6 +87,8 @@ def __new__(
backend_cls = AMXMoEWrapper
elif method == "LLAMAFILE":
backend_cls = LlamafileMoEWrapper
elif method in ["MOE_INT4", "MOE_INT8"]:
backend_cls = GeneralMoEWrapper
else:
raise NotImplementedError(f"Unsupported method: {method}")

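Methods "MOE_INT4" and "MOE_INT8" now route construction to GeneralMoEWrapper, the wrapper for the new generic MoE kernel backend imported from .utils.moe_kernel.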