Commit 121a130

feat(amx): add --amx toggle; prefer CPU 'extra' with GPU host+mmap when enabled
- CLI/server/bench: --amx (presence=enabled) -> mparams.amx_enable_mmap
- Loader: with mmap + GPU host buft, prefer CPU 'extra' if supported (AMX repack), else fall back to the CPU default
- llama-bench: add --amx flag to match CLI/server behavior
1 parent d9e0e7c commit 121a130
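
For reference, a minimal usage sketch (the model path and flag values are illustrative; mmap is on by default, and -ngl offloads layers so a GPU host buffer would otherwise be chosen):

    llama-cli    -m model.gguf -ngl 99 --amx -p "hello"
    llama-server -m model.gguf -ngl 99 --amx
    llama-bench  -m model.gguf --amx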

File tree

6 files changed: +149 -64 lines changed

common/arg.cpp

Lines changed: 8 additions & 0 deletions

@@ -2887,6 +2887,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.use_mmap = false;
         }
     ).set_env("LLAMA_ARG_NO_MMAP"));
+    add_opt(common_arg(
+        {"--amx"},
+        "enable AMX-aware CPU repack when mmap is on and a GPU host buffer would be used; prefers CPU \"extra\" buffer types (e.g., AMX) for weights on CPU.",
+        [](common_params & params) {
+            params.amx_enable_mmap = true;
+        }
+    ));
+
     add_opt(common_arg(
         {"--numa"}, "TYPE",
         "attempt optimizations that help on some NUMA systems\n"

common/common.cpp

Lines changed: 22 additions & 8 deletions

@@ -1126,28 +1126,42 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
         mparams.n_gpu_layers = params.n_gpu_layers;
     }

-    mparams.main_gpu = params.main_gpu;
-    mparams.split_mode = params.split_mode;
-    mparams.tensor_split = params.tensor_split;
-    mparams.use_mmap = params.use_mmap;
-    mparams.use_mlock = params.use_mlock;
-    mparams.check_tensors = params.check_tensors;
+    mparams.main_gpu   = params.main_gpu;
+    mparams.split_mode = params.split_mode;
+
+    // NOTE: common_params::tensor_split is a C-array (float [LLAMA_MAX_DEVICES]).
+    // Upstream expects a pointer to the first element; do NOT use .data().
+    mparams.tensor_split = params.tensor_split;
+
+    mparams.use_mmap      = params.use_mmap;
+    mparams.use_mlock     = params.use_mlock;
+    mparams.check_tensors = params.check_tensors;
+
+    // Keep upstream policy: disable extra buffer types when --no-extra-bufts is set
     mparams.use_extra_bufts = !params.no_extra_bufts;

+    // NEW: forward the AMX toggle from CLI into model params
+    mparams.amx_enable_mmap = params.amx_enable_mmap;
+
+    // Preserve upstream sentinel handling for KV overrides
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
     } else {
-        GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
+        GGML_ASSERT(params.kv_overrides.back().key[0] == 0 &&
+                    "KV overrides not terminated with empty key");
         mparams.kv_overrides = params.kv_overrides.data();
     }

+    // Preserve upstream sentinel handling for tensor buffer overrides
     if (params.tensor_buft_overrides.empty()) {
         mparams.tensor_buft_overrides = NULL;
     } else {
-        GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
+        GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr &&
+                    "Tensor buffer overrides not terminated with empty pattern");
         mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
     }

+    // Keep upstream progress callback wiring
     mparams.progress_callback           = params.load_progress_callback;
     mparams.progress_callback_user_data = params.load_progress_callback_user_data;
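
To illustrate the sentinel convention the two GGML_ASSERTs above check, here is a sketch (not code from this commit; the key/value strings are made up): the overrides vector must end with a zero-initialized entry.

    #include <cstdio>
    #include <vector>
    #include "llama.h"

    std::vector<llama_model_kv_override> make_overrides() {
        std::vector<llama_model_kv_override> kv;

        llama_model_kv_override o{};
        o.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
        std::snprintf(o.key, sizeof(o.key), "%s", "general.name");
        std::snprintf(o.val_str, sizeof(o.val_str), "%s", "my-model");
        kv.push_back(o);

        // terminator: zero-initialized entry, so key[0] == 0; this is what
        // the GGML_ASSERT in common_model_params_to_llama() verifies
        kv.push_back(llama_model_kv_override{});
        return kv;
    }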

common/common.h

Lines changed: 2 additions & 0 deletions

@@ -392,6 +392,8 @@ struct common_params {
     bool check_tensors = false; // validate tensor data
     bool no_op_offload = false; // globally disable offload host tensor operations to device
     bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
+    bool amx_enable_mmap = false; // prefer CPU "extra" buffers when GPU host+mmap is chosen (enable AMX)
+

     bool single_turn = false; // single turn chat conversation

include/llama.h

Lines changed: 1 addition & 0 deletions

@@ -296,6 +296,7 @@ extern "C" {
         bool use_mlock;       // force system to keep model in RAM
         bool check_tensors;   // validate model tensor data
         bool use_extra_bufts; // use extra buffer types (used for weight repacking)
+        bool amx_enable_mmap; // prefer CPU 'extra' buffers with GPU host+mmap (enable AMX repack on CPU)
     };

     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
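
For context, the new field slots into the usual C-API flow; a minimal sketch, assuming the current llama_model_load_from_file loader (the model path is a placeholder):

    #include "llama.h"

    int main() {
        llama_backend_init();

        llama_model_params mp = llama_model_default_params();
        mp.use_mmap        = true; // the toggle only matters when mmap is used
        mp.n_gpu_layers    = 99;   // offload, so a GPU host buffer would be considered
        mp.amx_enable_mmap = true; // new field from this commit

        llama_model * model = llama_model_load_from_file("model.gguf", mp);
        if (model) {
            llama_model_free(model);
        }

        llama_backend_free();
        return model ? 0 : 1;
    }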

src/llama-model.cpp

Lines changed: 47 additions & 14 deletions

@@ -2295,24 +2295,56 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 }
             }

-            // avoid using a host buffer when using mmap
-            auto * buft_dev = ggml_backend_buft_get_device(buft);
-            if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
-                auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-                if (!cpu_dev) {
-                    throw std::runtime_error("no CPU backend found");
-                }
-                buft = ggml_backend_dev_buffer_type(cpu_dev);
+            // avoid using a host buffer when using mmap
+            auto * buft_dev = ggml_backend_buft_get_device(buft);
+            if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
+                auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+                if (!cpu_dev) {
+                    throw std::runtime_error("no CPU backend found");
+                }
+
+                // If enabled, prefer CPU "extra" (AMX) buffer types for weights on CPU; else use CPU default
+                ggml_backend_buffer_type_t cpu_default_buft = ggml_backend_dev_buffer_type(cpu_dev);
+                const bool prefer_cpu_extra = params.amx_enable_mmap;
+
+                if (!prefer_cpu_extra) {
+                    buft = cpu_default_buft;
+                } else {
+                    ggml_backend_buffer_type_t chosen = nullptr;
+
+                    // Iterate available buffer types, skipping device-host buffer types
+                    for (const auto & cur : *buft_list) {
+                        ggml_backend_dev_t cur_dev = cur.first;
+                        ggml_backend_buffer_type_t cur_buft = cur.second;
+
+                        if (cur_dev && cur_buft == ggml_backend_dev_host_buffer_type(cur_dev)) {
+                            continue;
             }

-            if (buft != buft_list->front().second) {
-                n_moved_tensors++;
-                if (!first_moved_tensor) {
-                    first_moved_tensor = t_meta;
-                    first_moved_from_buft = buft_list->front().second;
-                    first_moved_to_buft   = buft;
+                        // Prefer CPU "extra" (non-default) if supported for this tensor/op
+                        if (cur_dev == cpu_dev && cur_buft != cpu_default_buft) {
+                            if (weight_buft_supported(hparams, t_meta, op, cur_buft, cur_dev)) {
+                                chosen = cur_buft;
+                                break;
                 }
             }
+                    }
+
+                    buft = chosen ? chosen : cpu_default_buft;
+                }
+            }
+
+
+            // (keep your existing moved-tensors accounting exactly as-is)
+            if (buft != buft_list->front().second) {
+                n_moved_tensors++;
+                if (!first_moved_tensor) {
+                    first_moved_tensor = t_meta;
+                    first_moved_from_buft = buft_list->front().second;
+                    first_moved_to_buft   = buft;
+                }
+            }
+

             ggml_context * ctx = ctx_for_buft(buft);

@@ -19649,6 +19681,7 @@ llama_model_params llama_model_default_params() {
         /*.use_mlock       =*/ false,
         /*.check_tensors   =*/ false,
         /*.use_extra_bufts =*/ true,
+        /*.amx_enable_mmap =*/ false,
     };

     return result;
